forked from Elara6331/itd
Support bidirectional requests over gateway
This commit is contained in:
125
socket.go
125
socket.go
@@ -19,10 +19,13 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -31,6 +34,7 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/smallnest/rpcx/server"
|
"github.com/smallnest/rpcx/server"
|
||||||
|
"github.com/smallnest/rpcx/share"
|
||||||
"github.com/vmihailenco/msgpack/v5"
|
"github.com/vmihailenco/msgpack/v5"
|
||||||
"go.arsenm.dev/infinitime"
|
"go.arsenm.dev/infinitime"
|
||||||
"go.arsenm.dev/infinitime/blefs"
|
"go.arsenm.dev/infinitime/blefs"
|
||||||
@@ -46,7 +50,7 @@ var (
|
|||||||
ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type")
|
ErrDFUInvalidFile = errors.New("provided file is invalid for given upgrade type")
|
||||||
ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
|
ErrDFUNotEnoughFiles = errors.New("not enough files provided for given upgrade type")
|
||||||
ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
|
ErrDFUInvalidUpgType = errors.New("invalid upgrade type")
|
||||||
ErrRPCXUsingGateway = errors.New("bidirectional requests are unsupported over gateway")
|
ErrRPCXNoReturnURL = errors.New("bidirectional requests over gateway require a returnURL field in the metadata")
|
||||||
)
|
)
|
||||||
|
|
||||||
type DoneMap map[string]chan struct{}
|
type DoneMap map[string]chan struct{}
|
||||||
@@ -137,11 +141,11 @@ func (i *ITD) HeartRate(_ context.Context, _ none, out *uint8) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
|
func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrRPCXUsingGateway
|
return ErrRPCXNoReturnURL
|
||||||
}
|
}
|
||||||
|
|
||||||
heartRateCh, cancel, err := i.dev.WatchHeartRate()
|
heartRateCh, cancel, err := i.dev.WatchHeartRate()
|
||||||
@@ -168,7 +172,7 @@ func (i *ITD) WatchHeartRate(ctx context.Context, _ none, out *string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send response to connection if no done signal received
|
// Send response to connection if no done signal received
|
||||||
i.srv.SendMessage(clientConn, id, "HeartRateSample", nil, data)
|
msgSender.SendMessage(id, "HeartRateSample", nil, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -184,11 +188,11 @@ func (i *ITD) BatteryLevel(_ context.Context, _ none, out *uint8) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
|
func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error {
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrRPCXUsingGateway
|
return ErrRPCXNoReturnURL
|
||||||
}
|
}
|
||||||
|
|
||||||
battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
|
battLevelCh, cancel, err := i.dev.WatchBatteryLevel()
|
||||||
@@ -215,7 +219,7 @@ func (i *ITD) WatchBatteryLevel(ctx context.Context, _ none, out *string) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send response to connection if no done signal received
|
// Send response to connection if no done signal received
|
||||||
i.srv.SendMessage(clientConn, id, "BatteryLevelSample", nil, data)
|
msgSender.SendMessage(id, "BatteryLevelSample", nil, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -231,11 +235,11 @@ func (i *ITD) Motion(_ context.Context, _ none, out *infinitime.MotionValues) er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
|
func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrRPCXUsingGateway
|
return ErrRPCXNoReturnURL
|
||||||
}
|
}
|
||||||
|
|
||||||
motionValsCh, cancel, err := i.dev.WatchMotion()
|
motionValsCh, cancel, err := i.dev.WatchMotion()
|
||||||
@@ -262,7 +266,7 @@ func (i *ITD) WatchMotion(ctx context.Context, _ none, out *string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send response to connection if no done signal received
|
// Send response to connection if no done signal received
|
||||||
i.srv.SendMessage(clientConn, id, "MotionSample", nil, data)
|
msgSender.SendMessage(id, "MotionSample", nil, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -278,11 +282,11 @@ func (i *ITD) StepCount(_ context.Context, _ none, out *uint32) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
|
func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrRPCXUsingGateway
|
return ErrRPCXNoReturnURL
|
||||||
}
|
}
|
||||||
|
|
||||||
stepCountCh, cancel, err := i.dev.WatchStepCount()
|
stepCountCh, cancel, err := i.dev.WatchStepCount()
|
||||||
@@ -309,7 +313,7 @@ func (i *ITD) WatchStepCount(ctx context.Context, _ none, out *string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Send response to connection if no done signal received
|
// Send response to connection if no done signal received
|
||||||
i.srv.SendMessage(clientConn, id, "StepCountSample", nil, data)
|
msgSender.SendMessage(id, "StepCountSample", nil, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -386,8 +390,8 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
|
|||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
*out = id
|
*out = id
|
||||||
|
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, i.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if ok {
|
if ok {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -399,11 +403,11 @@ func (i *ITD) FirmwareUpgrade(ctx context.Context, reqData api.FwUpgradeData, ou
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
i.srv.SendMessage(clientConn, id, "DFUProgress", nil, data)
|
msgSender.SendMessage(id, "DFUProgress", nil, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
firmwareUpdating = false
|
firmwareUpdating = false
|
||||||
i.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
msgSender.SendMessage(id, "Done", nil, nil)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,8 +510,8 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
|
|||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
*out = id
|
*out = id
|
||||||
|
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, fs.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if ok {
|
if ok {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -522,10 +526,10 @@ func (fs *FS) Upload(ctx context.Context, paths [2]string, out *string) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
|
msgSender.SendMessage(id, "FSProgress", nil, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
msgSender.SendMessage(id, "Done", nil, nil)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -554,8 +558,8 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
|
|||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
*out = id
|
*out = id
|
||||||
|
|
||||||
// Get client's connection
|
// Get client message sender
|
||||||
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
msgSender, ok := getMsgSender(ctx, fs.srv)
|
||||||
// If user is using gateway, the client connection will not be available
|
// If user is using gateway, the client connection will not be available
|
||||||
if ok {
|
if ok {
|
||||||
go func() {
|
go func() {
|
||||||
@@ -570,10 +574,10 @@ func (fs *FS) Download(ctx context.Context, paths [2]string, out *string) error
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.srv.SendMessage(clientConn, id, "FSProgress", nil, data)
|
msgSender.SendMessage(id, "FSProgress", nil, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.srv.SendMessage(clientConn, id, "Done", nil, nil)
|
msgSender.SendMessage(id, "Done", nil, nil)
|
||||||
localFile.Close()
|
localFile.Close()
|
||||||
remoteFile.Close()
|
remoteFile.Close()
|
||||||
}()
|
}()
|
||||||
@@ -608,3 +612,66 @@ func cleanPaths(paths []string) []string {
|
|||||||
}
|
}
|
||||||
return paths
|
return paths
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getMsgSender(ctx context.Context, srv *server.Server) (MessageSender, bool) {
|
||||||
|
// Get client message sender
|
||||||
|
clientConn, ok := ctx.Value(server.RemoteConnContextKey).(net.Conn)
|
||||||
|
// If the connection exists, use rpcMsgSender
|
||||||
|
if ok {
|
||||||
|
return &rpcMsgSender{srv, clientConn}, true
|
||||||
|
} else {
|
||||||
|
// Get metadata if it exists
|
||||||
|
metadata, ok := ctx.Value(share.ReqMetaDataKey).(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Get returnURL field from metadata if it exists
|
||||||
|
returnURL, ok := metadata["returnURL"]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
// Use httpMsgSender
|
||||||
|
return &httpMsgSender{returnURL}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageSender interface {
|
||||||
|
SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type rpcMsgSender struct {
|
||||||
|
srv *server.Server
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rpcMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
|
||||||
|
|
||||||
|
return r.srv.SendMessage(r.conn, servicePath, serviceMethod, metadata, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpMsgSender struct {
|
||||||
|
url string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpMsgSender) SendMessage(servicePath, serviceMethod string, metadata map[string]string, data []byte) error {
|
||||||
|
req, err := http.NewRequest(http.MethodPost, h.url, bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("X-RPCX-ServicePath", servicePath)
|
||||||
|
req.Header.Set("X-RPCX-ServiceMethod", serviceMethod)
|
||||||
|
|
||||||
|
query := url.Values{}
|
||||||
|
for k, v := range metadata {
|
||||||
|
query.Set(k, v)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-RPCX-Meta", query.Encode())
|
||||||
|
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return res.Body.Close()
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user