forked from Elara6331/itd
		
	Support bidirectional requests over gateway
This commit is contained in:
		
							
								
								
									
										125
									
								
								socket.go
									
									
									
									
									
								
							
							
						
						
									
										125
									
								
								socket.go
									
									
									
									
									
								
							| @@ -19,10 +19,13 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"io" | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -31,6 +34,7 @@ import ( | |||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| 	"github.com/smallnest/rpcx/server" | 	"github.com/smallnest/rpcx/server" | ||||||
|  | 	"github.com/smallnest/rpcx/share" | ||||||
| 	"github.com/vmihailenco/msgpack/v5" | 	"github.com/vmihailenco/msgpack/v5" | ||||||
| 	"go.arsenm.dev/infinitime" | 	"go.arsenm.dev/infinitime" | ||||||
| 	"go.arsenm.dev/infinitime/blefs" | 	"go.arsenm.dev/infinitime/blefs" | ||||||
| @@ -46,7 +50,7 @@ var ( | |||||||
| 	ErrDFUInvalidFile    = errors.New("provided file is invalid for given upgrade type") | 	ErrDFUInvalidFile    = errors.New("provided file is invalid for given upgrade type") | ||||||
| 	ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") | 	ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") | ||||||
| 	ErrDFUInvalidUpgType = errors.New("invalid upgrade type") | 	ErrDFUInvalidUpgType = errors.New("invalid upgrade type") | ||||||
| 	ErrRPCXUsingGateway  = errors.New("bidirectional requests are unsupported over gateway") | 	ErrRPCXNoReturnURL   = errors.New("bidirectional requests over gateway require a returnURL field in the metadata") | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type DoneMap map[string]chan struct{} | type DoneMap map[string]chan struct{} | ||||||
| @@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { | func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrRPCXUsingGateway | 		return ErrRPCXNoReturnURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	heartRateCh, cancel, err := i.dev.WatchHeartRate() | 	heartRateCh, cancel, err := i.dev.WatchHeartRate() | ||||||
| @@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { | |||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// Send response to connection if no done signal received | 				// Send response to connection if no done signal received | ||||||
| 				i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data) | 				msgSender.SendMessage(id, "HeartRateSample", nil, data) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { | func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrRPCXUsingGateway | 		return ErrRPCXNoReturnURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	battLevelCh, cancel, err := i.dev.WatchBatteryLevel() | 	battLevelCh, cancel, err := i.dev.WatchBatteryLevel() | ||||||
| @@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error | |||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// Send response to connection if no done signal received | 				// Send response to connection if no done signal received | ||||||
| 				i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data) | 				msgSender.SendMessage(id, "BatteryLevelSample", nil, data) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er | |||||||
| } | } | ||||||
|  |  | ||||||
| func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { | func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrRPCXUsingGateway | 		return ErrRPCXNoReturnURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	motionValsCh, cancel, err := i.dev.WatchMotion() | 	motionValsCh, cancel, err := i.dev.WatchMotion() | ||||||
| @@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { | |||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// Send response to connection if no done signal received | 				// Send response to connection if no done signal received | ||||||
| 				i.srv.SendMessage(clientConn, id, "MotionSample", nil, data) | 				msgSender.SendMessage(id, "MotionSample", nil, data) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { | func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrRPCXUsingGateway | 		return ErrRPCXNoReturnURL | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	stepCountCh, cancel, err := i.dev.WatchStepCount() | 	stepCountCh, cancel, err := i.dev.WatchStepCount() | ||||||
| @@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { | |||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// Send response to connection if no done signal received | 				// Send response to connection if no done signal received | ||||||
| 				i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data) | 				msgSender.SendMessage(id, "StepCountSample", nil, data) | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| @@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou | |||||||
| 	id := uuid.New().String() | 	id := uuid.New().String() | ||||||
| 	*out = id | 	*out = id | ||||||
|  |  | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, i.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if ok { | 	if ok { | ||||||
| 		go func() { | 		go func() { | ||||||
| @@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) | 				msgSender.SendMessage(id, "DFUProgress", nil, data) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			firmwareUpdating = false | 			firmwareUpdating = false | ||||||
| 			i.srv.SendMessage(clientConn, id, "Done", nil, nil) | 			msgSender.SendMessage(id, "Done", nil, nil) | ||||||
| 		}() | 		}() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { | |||||||
| 	id := uuid.New().String() | 	id := uuid.New().String() | ||||||
| 	*out = id | 	*out = id | ||||||
|  |  | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, fs.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if ok { | 	if ok { | ||||||
| 		go func() { | 		go func() { | ||||||
| @@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error { | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) | 				msgSender.SendMessage(id, "FSProgress", nil, data) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			fs.srv.SendMessage(clientConn, id, "Done", nil, nil) | 			msgSender.SendMessage(id, "Done", nil, nil) | ||||||
| 		}() | 		}() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -554,8 +558,8 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error | |||||||
| 	id := uuid.New().String() | 	id := uuid.New().String() | ||||||
| 	*out = id | 	*out = id | ||||||
|  |  | ||||||
| 	// Get client's connection | 	// Get client message sender | ||||||
| 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | 	msgSender, ok := getMsgSender(ctx, fs.srv) | ||||||
| 	// If user is using gateway, the client connection will not be available | 	// If user is using gateway, the client connection will not be available | ||||||
| 	if ok { | 	if ok { | ||||||
| 		go func() { | 		go func() { | ||||||
| @@ -570,10 +574,10 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error | |||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) | 				msgSender.SendMessage(id, "FSProgress", nil, data) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			fs.srv.SendMessage(clientConn, id, "Done", nil, nil) | 			msgSender.SendMessage(id, "Done", nil, nil) | ||||||
| 			localFile.Close() | 			localFile.Close() | ||||||
| 			remoteFile.Close() | 			remoteFile.Close() | ||||||
| 		}() | 		}() | ||||||
| @@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string { | |||||||
| 	} | 	} | ||||||
| 	return paths | 	return paths | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) { | ||||||
|  | 	// Get client message sender | ||||||
|  | 	clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) | ||||||
|  | 	// If the connection exists, use rpcMsgSender | ||||||
|  | 	if ok { | ||||||
|  | 		return &rpcMsgSender{srv, clientConn}, true | ||||||
|  | 	} else { | ||||||
|  | 		// Get metadata if it exists | ||||||
|  | 		metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string) | ||||||
|  | 		if !ok { | ||||||
|  | 			return nil, false | ||||||
|  | 		} | ||||||
|  | 		// Get returnURL field from metadata if it exists | ||||||
|  | 		returnURL, ok := metadata["returnURL"] | ||||||
|  | 		if !ok { | ||||||
|  | 			return nil, false | ||||||
|  | 		} | ||||||
|  | 		// Use httpMsgSender | ||||||
|  | 		return &httpMsgSender{returnURL}, true | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type MessageSender interface { | ||||||
|  | 	SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type rpcMsgSender struct { | ||||||
|  | 	srv  *server.Server | ||||||
|  | 	conn net.Conn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { | ||||||
|  |  | ||||||
|  | 	return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type httpMsgSender struct { | ||||||
|  | 	url string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error { | ||||||
|  | 	req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	req.Header.Set("X-RPCX-ServicePath", servicePath) | ||||||
|  | 	req.Header.Set("X-RPCX-ServiceMethod", serviceMethod) | ||||||
|  |  | ||||||
|  | 	query := url.Values{} | ||||||
|  | 	for k, v := range metadata { | ||||||
|  | 		query.Set(k, v) | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("X-RPCX-Meta", query.Encode()) | ||||||
|  |  | ||||||
|  | 	res, err := http.DefaultClient.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return res.Body.Close() | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user