forked from Elara6331/itd
		
	Support bidirectional requests over gateway
This commit is contained in:
		
							
								
								
									
										125
									
								
								socket.go
									
									
									
									
									
								
							
							
						
						
									
										125
									
								
								socket.go
									
									
									
									
									
								
							@@ -19,10 +19,13 @@
 | 
				
			|||||||
package main
 | 
					package main
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@@ -31,6 +34,7 @@ import (
 | 
				
			|||||||
	"github.com/google/uuid"
 | 
						"github.com/google/uuid"
 | 
				
			||||||
	"github.com/rs/zerolog/log"
 | 
						"github.com/rs/zerolog/log"
 | 
				
			||||||
	"github.com/smallnest/rpcx/server"
 | 
						"github.com/smallnest/rpcx/server"
 | 
				
			||||||
 | 
						"github.com/smallnest/rpcx/share"
 | 
				
			||||||
	"github.com/vmihailenco/msgpack/v5"
 | 
						"github.com/vmihailenco/msgpack/v5"
 | 
				
			||||||
	"go.arsenm.dev/infinitime"
 | 
						"go.arsenm.dev/infinitime"
 | 
				
			||||||
	"go.arsenm.dev/infinitime/blefs"
 | 
						"go.arsenm.dev/infinitime/blefs"
 | 
				
			||||||
@@ -46,7 +50,7 @@ var (
 | 
				
			|||||||
	ErrDFUInvalidFile    = errors.New("provided file is invalid for given upgrade type")
 | 
						ErrDFUInvalidFile    = errors.New("provided file is invalid for given upgrade type")
 | 
				
			||||||
	ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
 | 
						ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
 | 
				
			||||||
	ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
 | 
						ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
 | 
				
			||||||
	ErrRPCXUsingGateway  = errors.New("bidirectional requests are unsupported over gateway")
 | 
						ErrRPCXNoReturnURL   = errors.New("bidirectional requests over gateway require a returnURL field in the metadata")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type DoneMap map[string]chan struct{}
 | 
					type DoneMap map[string]chan struct{}
 | 
				
			||||||
@@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
 | 
					func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, i.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return ErrRPCXUsingGateway
 | 
							return ErrRPCXNoReturnURL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	heartRateCh, cancel, err := i.dev.WatchHeartRate()
 | 
						heartRateCh, cancel, err := i.dev.WatchHeartRate()
 | 
				
			||||||
@@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// Send response to connection if no done signal received
 | 
									// Send response to connection if no done signal received
 | 
				
			||||||
				i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data)
 | 
									msgSender.SendMessage(id, "HeartRateSample", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
 | 
					func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, i.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return ErrRPCXUsingGateway
 | 
							return ErrRPCXNoReturnURL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
 | 
						battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
 | 
				
			||||||
@@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// Send response to connection if no done signal received
 | 
									// Send response to connection if no done signal received
 | 
				
			||||||
				i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data)
 | 
									msgSender.SendMessage(id, "BatteryLevelSample", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
 | 
					func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, i.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return ErrRPCXUsingGateway
 | 
							return ErrRPCXNoReturnURL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	motionValsCh, cancel, err := i.dev.WatchMotion()
 | 
						motionValsCh, cancel, err := i.dev.WatchMotion()
 | 
				
			||||||
@@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// Send response to connection if no done signal received
 | 
									// Send response to connection if no done signal received
 | 
				
			||||||
				i.srv.SendMessage(clientConn, id, "MotionSample", nil, data)
 | 
									msgSender.SendMessage(id, "MotionSample", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
 | 
					func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, i.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if !ok {
 | 
						if !ok {
 | 
				
			||||||
		return ErrRPCXUsingGateway
 | 
							return ErrRPCXNoReturnURL
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	stepCountCh, cancel, err := i.dev.WatchStepCount()
 | 
						stepCountCh, cancel, err := i.dev.WatchStepCount()
 | 
				
			||||||
@@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
 | 
				
			|||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				// Send response to connection if no done signal received
 | 
									// Send response to connection if no done signal received
 | 
				
			||||||
				i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data)
 | 
									msgSender.SendMessage(id, "StepCountSample", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
@@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
 | 
				
			|||||||
	id := uuid.New().String()
 | 
						id := uuid.New().String()
 | 
				
			||||||
	*out = id
 | 
						*out = id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, i.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
@@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
 | 
				
			|||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data)
 | 
									msgSender.SendMessage(id, "DFUProgress", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			firmwareUpdating = false
 | 
								firmwareUpdating = false
 | 
				
			||||||
			i.srv.SendMessage(clientConn, id, "Done", nil, nil)
 | 
								msgSender.SendMessage(id, "Done", nil, nil)
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
 | 
				
			|||||||
	id := uuid.New().String()
 | 
						id := uuid.New().String()
 | 
				
			||||||
	*out = id
 | 
						*out = id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, fs.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
@@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
 | 
				
			|||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
 | 
									msgSender.SendMessage(id, "FSProgress", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
 | 
								msgSender.SendMessage(id, "Done", nil, nil)
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -554,8 +558,8 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
 | 
				
			|||||||
	id := uuid.New().String()
 | 
						id := uuid.New().String()
 | 
				
			||||||
	*out = id
 | 
						*out = id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Get client's connection
 | 
						// Get client message sender
 | 
				
			||||||
	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
						msgSender, ok := getMsgSender(ctx, fs.srv)
 | 
				
			||||||
	// If user is using gateway, the client connection will not be available
 | 
						// If user is using gateway, the client connection will not be available
 | 
				
			||||||
	if ok {
 | 
						if ok {
 | 
				
			||||||
		go func() {
 | 
							go func() {
 | 
				
			||||||
@@ -570,10 +574,10 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
 | 
				
			|||||||
					continue
 | 
										continue
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
 | 
									msgSender.SendMessage(id, "FSProgress", nil, data)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
 | 
								msgSender.SendMessage(id, "Done", nil, nil)
 | 
				
			||||||
			localFile.Close()
 | 
								localFile.Close()
 | 
				
			||||||
			remoteFile.Close()
 | 
								remoteFile.Close()
 | 
				
			||||||
		}()
 | 
							}()
 | 
				
			||||||
@@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return paths
 | 
						return paths
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) {
 | 
				
			||||||
 | 
						// Get client message sender
 | 
				
			||||||
 | 
						clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
 | 
				
			||||||
 | 
						// If the connection exists, use rpcMsgSender
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							return &rpcMsgSender{srv, clientConn}, true
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							// Get metadata if it exists
 | 
				
			||||||
 | 
							metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string)
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return nil, false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// Get returnURL field from metadata if it exists
 | 
				
			||||||
 | 
							returnURL, ok := metadata["returnURL"]
 | 
				
			||||||
 | 
							if !ok {
 | 
				
			||||||
 | 
								return nil, false
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							// Use httpMsgSender
 | 
				
			||||||
 | 
							return &httpMsgSender{returnURL}, true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type MessageSender interface {
 | 
				
			||||||
 | 
						SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type rpcMsgSender struct {
 | 
				
			||||||
 | 
						srv  *server.Server
 | 
				
			||||||
 | 
						conn net.Conn
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type httpMsgSender struct {
 | 
				
			||||||
 | 
						url string
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
 | 
				
			||||||
 | 
						req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data))
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						req.Header.Set("X-RPCX-ServicePath", servicePath)
 | 
				
			||||||
 | 
						req.Header.Set("X-RPCX-ServiceMethod", serviceMethod)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						query := url.Values{}
 | 
				
			||||||
 | 
						for k, v := range metadata {
 | 
				
			||||||
 | 
							query.Set(k, v)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						req.Header.Set("X-RPCX-Meta", query.Encode())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						res, err := http.DefaultClient.Do(req)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return res.Body.Close()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user