Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec

This commit is contained in:
2022-08-06 22:48:42 -07:00
parent 5e61e89ac1
commit e02c8bc5ff
5 changed files with 114 additions and 85 deletions

View File

@@ -28,7 +28,6 @@ import (
"sync"
"go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types"
"golang.org/x/net/websocket"
)
@@ -37,12 +36,11 @@ import (
type any = interface{}
var (
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidMethod = errors.New("method invalid for lrpc call")
ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidMethod = errors.New("method invalid for lrpc call")
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
)
// Server is an lrpc server
@@ -101,7 +99,7 @@ func (s *Server) Register(v any) error {
}
// execute runs a method of a registered value
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
// Try to get value from receivers map
val, ok := s.rcvrs[typ]
if !ok {
@@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
// Get method type
mtdType := mtd.Type()
// Return error if argument provided but isn't expected
if mtdType.NumIn() == 1 && arg != nil {
return nil, nil, ErrUnexpectedArgument
}
//TODO: if arg not nil but fn has no arg, err
// IF argument is []any
anySlice, ok := arg.([]any)
if ok {
// Convert slice to the method's arg type and
// set arg to the newly-converted slice
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
}
argType := mtdType.In(1)
argVal := reflect.New(argType)
arg := argVal.Interface()
// Get argument value
argVal := reflect.ValueOf(arg)
// If argument's type does not match method's argument type
if arg != nil && argVal.Type() != mtdType.In(1) {
val, err = reflectutil.Convert(argVal, mtdType.In(1))
if err != nil {
return nil, nil, err
}
arg = val.Interface()
err = c.Unmarshal(data, arg)
if err != nil {
return nil, nil, err
}
arg = argVal.Elem().Interface()
ctx = newContext(pCtx, c)
// Get reflect value of context
@@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
if err != nil {
s.sendErr(c, call, val, err)
} else {
valData, err := c.Marshal(val)
if err != nil {
s.sendErr(c, call, val, err)
continue
}
// Create response
res := types.Response{
ID: call.ID,
Return: val,
Return: valData,
}
// If function has created a channel
if ctx.isChannel {
idData, err := c.Marshal(ctx.channelID)
if err != nil {
s.sendErr(c, call, val, err)
continue
}
// Set IsChannel to true
res.Type = types.ResponseTypeChannel
// Overwrite return value with channel ID
res.Return = ctx.channelID
res.Return = idData
// Store context in map for future use
s.contextsMtx.Lock()
@@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
// For every value received from channel
for val := range ctx.channel {
codecMtx.Lock()
valData, err := c.Marshal(val)
if err != nil {
continue
}
// Encode response using codec
c.Encode(types.Response{
ID: ctx.channelID,
Return: val,
Return: valData,
})
codecMtx.Unlock()
}
@@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
// sendErr sends an error response
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
valData, _ := c.Marshal(val)
// Encode error response using codec
c.Encode(types.Response{
Type: types.ResponseTypeError,
ID: req.ID,
Error: err.Error(),
Return: val,
Return: valData,
})
}