forked from Elara6331/itd
		
	Support bidirectional requests over gateway
This commit is contained in:
		
							
								
								
									
										125
									
								
								socket.go
									
									
									
									
									
								
							
							
						
						
									
										125
									
								
								socket.go
									
									
									
									
									
								
							| @@ -19,10 +19,13 @@ | ||||
| package main | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| @@ -31,6 +34,7 @@ import ( | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"github.com/smallnest/rpcx/server" | ||||
| 	"github.com/smallnest/rpcx/share" | ||||
| 	"github.com/vmihailenco/msgpack/v5" | ||||
| 	"go.arsenm.dev/infinitime" | ||||
| 	"go.arsenm.dev/infinitime/blefs" | ||||
| @@ -46,7 +50,7 @@ var ( | ||||
| 	ErrDFUInvalidFile    = errors.New("provided file is invalid for given upgrade type") | ||||
| 	ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") | ||||
| 	ErrDFUInvalidUpgType = errors.New("invalid upgrade type") | ||||
| 	ErrRPCXUsingGateway  = errors.New("bidirectional requests are unsupported over gateway") | ||||
| 	ErrRPCXNoReturnURL   = errors.New("bidirectional requests over gateway require a returnURL field in the metadata") | ||||
| ) | ||||
|  | ||||
| type DoneMap map[string]chan struct{} | ||||
| @@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error { | ||||
| } | ||||
|  | ||||
| func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if !ok { | ||||
| 		return ErrRPCXUsingGateway | ||||
| 		return ErrRPCXNoReturnURL | ||||
| 	} | ||||
|  | ||||
| 	heartRateCh, cancel, err := i.dev.WatchHeartRate() | ||||
| @@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { | ||||
| 				} | ||||
|  | ||||
| 				// Send response to connection if no done signal received | ||||
| 				i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data) | ||||
| 				msgSender.SendMessage(id, "HeartRateSample", nil, data) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| @@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error { | ||||
| } | ||||
|  | ||||
| func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if !ok { | ||||
| 		return ErrRPCXUsingGateway | ||||
| 		return ErrRPCXNoReturnURL | ||||
| 	} | ||||
|  | ||||
| 	battLevelCh, cancel, err := i.dev.WatchBatteryLevel() | ||||
| @@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error | ||||
| 				} | ||||
|  | ||||
| 				// Send response to connection if no done signal received | ||||
| 				i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data) | ||||
| 				msgSender.SendMessage(id, "BatteryLevelSample", nil, data) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| @@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er | ||||
| } | ||||
|  | ||||
| func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if !ok { | ||||
| 		return ErrRPCXUsingGateway | ||||
| 		return ErrRPCXNoReturnURL | ||||
| 	} | ||||
|  | ||||
| 	motionValsCh, cancel, err := i.dev.WatchMotion() | ||||
| @@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { | ||||
| 				} | ||||
|  | ||||
| 				// Send response to connection if no done signal received | ||||
| 				i.srv.SendMessage(clientConn, id, "MotionSample", nil, data) | ||||
| 				msgSender.SendMessage(id, "MotionSample", nil, data) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| @@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error { | ||||
| } | ||||
|  | ||||
| func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if !ok { | ||||
| 		return ErrRPCXUsingGateway | ||||
| 		return ErrRPCXNoReturnURL | ||||
| 	} | ||||
|  | ||||
| 	stepCountCh, cancel, err := i.dev.WatchStepCount() | ||||
| @@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { | ||||
| 				} | ||||
|  | ||||
| 				// Send response to connection if no done signal received | ||||
| 				i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data) | ||||
| 				msgSender.SendMessage(id, "StepCountSample", nil, data) | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| @@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou | ||||
| 	id := uuid.New().String() | ||||
| 	*out = id | ||||
|  | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if ok { | ||||
| 		go func() { | ||||
| @@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) | ||||
| 				msgSender.SendMessage(id, "DFUProgress", nil, data) | ||||
| 			} | ||||
|  | ||||
| 			firmwareUpdating = false | ||||
| 			i.srv.SendMessage(clientConn, id, "Done", nil, nil) | ||||
| 			msgSender.SendMessage(id, "Done", nil, nil) | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| @@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { | ||||
| 	id := uuid.New().String() | ||||
| 	*out = id | ||||
|  | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, fs.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if ok { | ||||
| 		go func() { | ||||
| @@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) | ||||
| 				msgSender.SendMessage(id, "FSProgress", nil, data) | ||||
| 			} | ||||
|  | ||||
| 			fs.srv.SendMessage(clientConn, id, "Done", nil, nil) | ||||
| 			msgSender.SendMessage(id, "Done", nil, nil) | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| @@ -554,8 +558,8 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error | ||||
| 	id := uuid.New().String() | ||||
| 	*out = id | ||||
|  | ||||
| 	// Get client's connection | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// Get client message sender | ||||
| 	msgSender, ok := getMsgSender(ctx, fs.srv) | ||||
| 	// If user is using gateway, the client connection will not be available | ||||
| 	if ok { | ||||
| 		go func() { | ||||
| @@ -570,10 +574,10 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error | ||||
| 					continue | ||||
| 				} | ||||
|  | ||||
| 				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) | ||||
| 				msgSender.SendMessage(id, "FSProgress", nil, data) | ||||
| 			} | ||||
|  | ||||
| 			fs.srv.SendMessage(clientConn, id, "Done", nil, nil) | ||||
| 			msgSender.SendMessage(id, "Done", nil, nil) | ||||
| 			localFile.Close() | ||||
| 			remoteFile.Close() | ||||
| 		}() | ||||
| @@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string { | ||||
| 	} | ||||
| 	return paths | ||||
| } | ||||
|  | ||||
| func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) { | ||||
| 	// Get client message sender | ||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||
| 	// If the connection exists, use rpcMsgSender | ||||
| 	if ok { | ||||
| 		return &rpcMsgSender{srv, clientConn}, true | ||||
| 	} else { | ||||
| 		// Get metadata if it exists | ||||
| 		metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string) | ||||
| 		if !ok { | ||||
| 			return nil, false | ||||
| 		} | ||||
| 		// Get returnURL field from metadata if it exists | ||||
| 		returnURL, ok := metadata["returnURL"] | ||||
| 		if !ok { | ||||
| 			return nil, false | ||||
| 		} | ||||
| 		// Use httpMsgSender | ||||
| 		return &httpMsgSender{returnURL}, true | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| type MessageSender interface { | ||||
| 	SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error | ||||
| } | ||||
|  | ||||
| type rpcMsgSender struct { | ||||
| 	srv  *server.Server | ||||
| 	conn net.Conn | ||||
| } | ||||
|  | ||||
| func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { | ||||
|  | ||||
| 	return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data) | ||||
| } | ||||
|  | ||||
| type httpMsgSender struct { | ||||
| 	url string | ||||
| } | ||||
|  | ||||
| func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { | ||||
| 	req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	req.Header.Set("X-RPCX-ServicePath", servicePath) | ||||
| 	req.Header.Set("X-RPCX-ServiceMethod", serviceMethod) | ||||
|  | ||||
| 	query := url.Values{} | ||||
| 	for k, v := range metadata { | ||||
| 		query.Set(k, v) | ||||
| 	} | ||||
| 	req.Header.Set("X-RPCX-Meta", query.Encode()) | ||||
|  | ||||
| 	res, err := http.DefaultClient.Do(req) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return res.Body.Close() | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user