Support bidirectional requests over gateway
This commit is contained in:
parent
44c89408d2
commit
0ae40d69bc
137
socket.go
137
socket.go
@ -19,10 +19,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -31,6 +34,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/smallnest/rpcx/server"
|
||||
"github.com/smallnest/rpcx/share"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
"go.arsenm.dev/infinitime"
|
||||
"go.arsenm.dev/infinitime/blefs"
|
||||
@ -46,7 +50,7 @@ var (
|
||||
ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type")
|
||||
ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
|
||||
ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
|
||||
ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway")
|
||||
ErrRPCXNoReturnURL = errors.New("bidirectional requests over gateway require a returnURL field in the metadata")
|
||||
)
|
||||
|
||||
type DoneMap map[string]chan struct{}
|
||||
@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error {
|
||||
}
|
||||
|
||||
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if !ok {
|
||||
return ErrRPCXUsingGateway
|
||||
return ErrRPCXNoReturnURL
|
||||
}
|
||||
|
||||
heartRateCh, cancel, err := i.dev.WatchHeartRate()
|
||||
@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
|
||||
}
|
||||
|
||||
// Send response to connection if no done signal received
|
||||
i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data)
|
||||
msgSender.SendMessage(id, "HeartRateSample", nil, data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error {
|
||||
}
|
||||
|
||||
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if !ok {
|
||||
return ErrRPCXUsingGateway
|
||||
return ErrRPCXNoReturnURL
|
||||
}
|
||||
|
||||
battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
|
||||
@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error
|
||||
}
|
||||
|
||||
// Send response to connection if no done signal received
|
||||
i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data)
|
||||
msgSender.SendMessage(id, "BatteryLevelSample", nil, data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er
|
||||
}
|
||||
|
||||
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if !ok {
|
||||
return ErrRPCXUsingGateway
|
||||
return ErrRPCXNoReturnURL
|
||||
}
|
||||
|
||||
motionValsCh, cancel, err := i.dev.WatchMotion()
|
||||
@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
|
||||
}
|
||||
|
||||
// Send response to connection if no done signal received
|
||||
i.srv.SendMessage(clientConn, id, "MotionSample", nil, data)
|
||||
msgSender.SendMessage(id, "MotionSample", nil, data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error {
|
||||
}
|
||||
|
||||
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if !ok {
|
||||
return ErrRPCXUsingGateway
|
||||
return ErrRPCXNoReturnURL
|
||||
}
|
||||
|
||||
stepCountCh, cancel, err := i.dev.WatchStepCount()
|
||||
@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
|
||||
}
|
||||
|
||||
// Send response to connection if no done signal received
|
||||
i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data)
|
||||
msgSender.SendMessage(id, "StepCountSample", nil, data)
|
||||
}
|
||||
}
|
||||
}()
|
||||
@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
|
||||
id := uuid.New().String()
|
||||
*out = id
|
||||
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if ok {
|
||||
go func() {
|
||||
@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
|
||||
continue
|
||||
}
|
||||
|
||||
i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data)
|
||||
msgSender.SendMessage(id, "DFUProgress", nil, data)
|
||||
}
|
||||
|
||||
firmwareUpdating = false
|
||||
i.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
||||
msgSender.SendMessage(id, "Done", nil, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
|
||||
id := uuid.New().String()
|
||||
*out = id
|
||||
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, fs.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if ok {
|
||||
go func() {
|
||||
@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
|
||||
msgSender.SendMessage(id, "FSProgress", nil, data)
|
||||
}
|
||||
|
||||
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
||||
msgSender.SendMessage(id, "Done", nil, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@ -540,22 +544,22 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
|
||||
|
||||
func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error {
|
||||
fs.updateFS()
|
||||
|
||||
|
||||
localFile, err := os.Create(paths[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
remoteFile, err := fs.fs.Open(paths[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
id := uuid.New().String()
|
||||
*out = id
|
||||
|
||||
// Get client's connection
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
|
||||
// Get client message sender
|
||||
msgSender, ok := getMsgSender(ctx, fs.srv)
|
||||
// If user is using gateway, the client connection will not be available
|
||||
if ok {
|
||||
go func() {
|
||||
@ -569,11 +573,11 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
|
||||
log.Error().Err(err).Msg("Error encoding filesystem transfer progress event")
|
||||
continue
|
||||
}
|
||||
|
||||
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
|
||||
|
||||
msgSender.SendMessage(id, "FSProgress", nil, data)
|
||||
}
|
||||
|
||||
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
||||
|
||||
msgSender.SendMessage(id, "Done", nil, nil)
|
||||
localFile.Close()
|
||||
remoteFile.Close()
|
||||
}()
|
||||
@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string {
|
||||
}
|
||||
return paths
|
||||
}
|
||||
|
||||
func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) {
|
||||
// Get client message sender
|
||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||
// If the connection exists, use rpcMsgSender
|
||||
if ok {
|
||||
return &rpcMsgSender{srv, clientConn}, true
|
||||
} else {
|
||||
// Get metadata if it exists
|
||||
metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Get returnURL field from metadata if it exists
|
||||
returnURL, ok := metadata["returnURL"]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
// Use httpMsgSender
|
||||
return &httpMsgSender{returnURL}, true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type MessageSender interface {
|
||||
SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error
|
||||
}
|
||||
|
||||
type rpcMsgSender struct {
|
||||
srv *server.Server
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
|
||||
|
||||
return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data)
|
||||
}
|
||||
|
||||
type httpMsgSender struct {
|
||||
url string
|
||||
}
|
||||
|
||||
func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
|
||||
req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("X-RPCX-ServicePath", servicePath)
|
||||
req.Header.Set("X-RPCX-ServiceMethod", serviceMethod)
|
||||
|
||||
query := url.Values{}
|
||||
for k, v := range metadata {
|
||||
query.Set(k, v)
|
||||
}
|
||||
req.Header.Set("X-RPCX-Meta", query.Encode())
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return res.Body.Close()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user