Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec
This commit is contained in:
@@ -28,7 +28,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
@@ -37,12 +36,11 @@ import (
|
||||
type any = interface{}
|
||||
|
||||
var (
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||
ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
|
||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||
)
|
||||
|
||||
// Server is an lrpc server
|
||||
@@ -101,7 +99,7 @@ func (s *Server) Register(v any) error {
|
||||
}
|
||||
|
||||
// execute runs a method of a registered value
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
// Try to get value from receivers map
|
||||
val, ok := s.rcvrs[typ]
|
||||
if !ok {
|
||||
@@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
|
||||
// Get method type
|
||||
mtdType := mtd.Type()
|
||||
|
||||
// Return error if argument provided but isn't expected
|
||||
if mtdType.NumIn() == 1 && arg != nil {
|
||||
return nil, nil, ErrUnexpectedArgument
|
||||
}
|
||||
//TODO: if arg not nil but fn has no arg, err
|
||||
|
||||
// IF argument is []any
|
||||
anySlice, ok := arg.([]any)
|
||||
if ok {
|
||||
// Convert slice to the method's arg type and
|
||||
// set arg to the newly-converted slice
|
||||
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
|
||||
}
|
||||
argType := mtdType.In(1)
|
||||
argVal := reflect.New(argType)
|
||||
arg := argVal.Interface()
|
||||
|
||||
// Get argument value
|
||||
argVal := reflect.ValueOf(arg)
|
||||
// If argument's type does not match method's argument type
|
||||
if arg != nil && argVal.Type() != mtdType.In(1) {
|
||||
val, err = reflectutil.Convert(argVal, mtdType.In(1))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
arg = val.Interface()
|
||||
err = c.Unmarshal(data, arg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
arg = argVal.Elem().Interface()
|
||||
|
||||
|
||||
ctx = newContext(pCtx, c)
|
||||
// Get reflect value of context
|
||||
@@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
} else {
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create response
|
||||
res := types.Response{
|
||||
ID: call.ID,
|
||||
Return: val,
|
||||
Return: valData,
|
||||
}
|
||||
|
||||
// If function has created a channel
|
||||
if ctx.isChannel {
|
||||
idData, err := c.Marshal(ctx.channelID)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set IsChannel to true
|
||||
res.Type = types.ResponseTypeChannel
|
||||
// Overwrite return value with channel ID
|
||||
res.Return = ctx.channelID
|
||||
res.Return = idData
|
||||
|
||||
// Store context in map for future use
|
||||
s.contextsMtx.Lock()
|
||||
@@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
// For every value received from channel
|
||||
for val := range ctx.channel {
|
||||
codecMtx.Lock()
|
||||
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
Return: val,
|
||||
Return: valData,
|
||||
})
|
||||
|
||||
codecMtx.Unlock()
|
||||
}
|
||||
|
||||
@@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
|
||||
// sendErr sends an error response
|
||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
||||
valData, _ := c.Marshal(val)
|
||||
|
||||
// Encode error response using codec
|
||||
c.Encode(types.Response{
|
||||
Type: types.ResponseTypeError,
|
||||
ID: req.ID,
|
||||
Error: err.Error(),
|
||||
Return: val,
|
||||
Return: valData,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user