Support bidirectional requests over gateway

This commit is contained in:
2022-04-24 00:54:04 -07:00
parent 9034ef7c6b
commit 4b6f7d408e

137
socket.go
View File

@@ -19,10 +19,13 @@
package main package main
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"io" "io"
"net" "net"
"net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -31,6 +34,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/smallnest/rpcx/server" "github.com/smallnest/rpcx/server"
"github.com/smallnest/rpcx/share"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
"go.arsenm.dev/infinitime" "go.arsenm.dev/infinitime"
"go.arsenm.dev/infinitime/blefs" "go.arsenm.dev/infinitime/blefs"
@@ -46,7 +50,7 @@ var (
ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type") ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type")
ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type") ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
ErrDFUInvalidUpgType = errors.New("invalid upgrade type") ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway") ErrRPCXNoReturnURL = errors.New("bidirectional requests over gateway require a returnURL field in the metadata")
) )
type DoneMap map[string]chan struct{} type DoneMap map[string]chan struct{}
@@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error {
} }
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error { func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if !ok { if !ok {
return ErrRPCXUsingGateway return ErrRPCXNoReturnURL
} }
heartRateCh, cancel, err := i.dev.WatchHeartRate() heartRateCh, cancel, err := i.dev.WatchHeartRate()
@@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
} }
// Send response to connection if no done signal received // Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data) msgSender.SendMessage(id, "HeartRateSample", nil, data)
} }
} }
}() }()
@@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error {
} }
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error { func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if !ok { if !ok {
return ErrRPCXUsingGateway return ErrRPCXNoReturnURL
} }
battLevelCh, cancel, err := i.dev.WatchBatteryLevel() battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
@@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error
} }
// Send response to connection if no done signal received // Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data) msgSender.SendMessage(id, "BatteryLevelSample", nil, data)
} }
} }
}() }()
@@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er
} }
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error { func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if !ok { if !ok {
return ErrRPCXUsingGateway return ErrRPCXNoReturnURL
} }
motionValsCh, cancel, err := i.dev.WatchMotion() motionValsCh, cancel, err := i.dev.WatchMotion()
@@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
} }
// Send response to connection if no done signal received // Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "MotionSample", nil, data) msgSender.SendMessage(id, "MotionSample", nil, data)
} }
} }
}() }()
@@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error {
} }
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error { func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if !ok { if !ok {
return ErrRPCXUsingGateway return ErrRPCXNoReturnURL
} }
stepCountCh, cancel, err := i.dev.WatchStepCount() stepCountCh, cancel, err := i.dev.WatchStepCount()
@@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
} }
// Send response to connection if no done signal received // Send response to connection if no done signal received
i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data) msgSender.SendMessage(id, "StepCountSample", nil, data)
} }
} }
}() }()
@@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
id := uuid.New().String() id := uuid.New().String()
*out = id *out = id
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, i.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if ok { if ok {
go func() { go func() {
@@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
continue continue
} }
i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data) msgSender.SendMessage(id, "DFUProgress", nil, data)
} }
firmwareUpdating = false firmwareUpdating = false
i.srv.SendMessage(clientConn, id, "Done", nil, nil) msgSender.SendMessage(id, "Done", nil, nil)
}() }()
} }
@@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
id := uuid.New().String() id := uuid.New().String()
*out = id *out = id
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, fs.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if ok { if ok {
go func() { go func() {
@@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
continue continue
} }
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) msgSender.SendMessage(id, "FSProgress", nil, data)
} }
fs.srv.SendMessage(clientConn, id, "Done", nil, nil) msgSender.SendMessage(id, "Done", nil, nil)
}() }()
} }
@@ -540,22 +544,22 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error { func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error {
fs.updateFS() fs.updateFS()
localFile, err := os.Create(paths[0]) localFile, err := os.Create(paths[0])
if err != nil { if err != nil {
return err return err
} }
remoteFile, err := fs.fs.Open(paths[1]) remoteFile, err := fs.fs.Open(paths[1])
if err != nil { if err != nil {
return err return err
} }
id := uuid.New().String() id := uuid.New().String()
*out = id *out = id
// Get client's connection // Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn) msgSender, ok := getMsgSender(ctx, fs.srv)
// If user is using gateway, the client connection will not be available // If user is using gateway, the client connection will not be available
if ok { if ok {
go func() { go func() {
@@ -569,11 +573,11 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
log.Error().Err(err).Msg("Error encoding filesystem transfer progress event") log.Error().Err(err).Msg("Error encoding filesystem transfer progress event")
continue continue
} }
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data) msgSender.SendMessage(id, "FSProgress", nil, data)
} }
fs.srv.SendMessage(clientConn, id, "Done", nil, nil) msgSender.SendMessage(id, "Done", nil, nil)
localFile.Close() localFile.Close()
remoteFile.Close() remoteFile.Close()
}() }()
@@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string {
} }
return paths return paths
} }
func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) {
// Get client message sender
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
// If the connection exists, use rpcMsgSender
if ok {
return &rpcMsgSender{srv, clientConn}, true
} else {
// Get metadata if it exists
metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string)
if !ok {
return nil, false
}
// Get returnURL field from metadata if it exists
returnURL, ok := metadata["returnURL"]
if !ok {
return nil, false
}
// Use httpMsgSender
return &httpMsgSender{returnURL}, true
}
}
type MessageSender interface {
SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error
}
type rpcMsgSender struct {
srv *server.Server
conn net.Conn
}
func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data)
}
type httpMsgSender struct {
url string
}
func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data))
if err != nil {
return err
}
req.Header.Set("X-RPCX-ServicePath", servicePath)
req.Header.Set("X-RPCX-ServiceMethod", serviceMethod)
query := url.Values{}
for k, v := range metadata {
query.Set(k, v)
}
req.Header.Set("X-RPCX-Meta", query.Encode())
res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
return res.Body.Close()
}