Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec
This commit is contained in:
		@@ -28,7 +28,6 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"go.arsenm.dev/lrpc/codec"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/reflectutil"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/types"
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
)
 | 
			
		||||
@@ -37,12 +36,11 @@ import (
 | 
			
		||||
type any = interface{}
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	ErrInvalidType        = errors.New("type must be struct or pointer to struct")
 | 
			
		||||
	ErrNoSuchReceiver     = errors.New("no such receiver registered")
 | 
			
		||||
	ErrNoSuchMethod       = errors.New("no such method was found")
 | 
			
		||||
	ErrInvalidMethod      = errors.New("method invalid for lrpc call")
 | 
			
		||||
	ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
 | 
			
		||||
	ErrArgNotProvided     = errors.New("method expected an argument, but none was provided")
 | 
			
		||||
	ErrInvalidType    = errors.New("type must be struct or pointer to struct")
 | 
			
		||||
	ErrNoSuchReceiver = errors.New("no such receiver registered")
 | 
			
		||||
	ErrNoSuchMethod   = errors.New("no such method was found")
 | 
			
		||||
	ErrInvalidMethod  = errors.New("method invalid for lrpc call")
 | 
			
		||||
	ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Server is an lrpc server
 | 
			
		||||
@@ -101,7 +99,7 @@ func (s *Server) Register(v any) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// execute runs a method of a registered value
 | 
			
		||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
	// Try to get value from receivers map
 | 
			
		||||
	val, ok := s.rcvrs[typ]
 | 
			
		||||
	if !ok {
 | 
			
		||||
@@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
 | 
			
		||||
	// Get method type
 | 
			
		||||
	mtdType := mtd.Type()
 | 
			
		||||
 | 
			
		||||
	// Return error if argument provided but isn't expected
 | 
			
		||||
	if mtdType.NumIn() == 1 && arg != nil {
 | 
			
		||||
		return nil, nil, ErrUnexpectedArgument
 | 
			
		||||
	}
 | 
			
		||||
	//TODO: if arg not nil but fn has no arg, err
 | 
			
		||||
 | 
			
		||||
	// IF argument is []any
 | 
			
		||||
	anySlice, ok := arg.([]any)
 | 
			
		||||
	if ok {
 | 
			
		||||
		// Convert slice to the method's arg type and
 | 
			
		||||
		// set arg to the newly-converted slice
 | 
			
		||||
		arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
 | 
			
		||||
	}
 | 
			
		||||
	argType := mtdType.In(1)
 | 
			
		||||
	argVal := reflect.New(argType)
 | 
			
		||||
	arg := argVal.Interface()
 | 
			
		||||
 | 
			
		||||
	// Get argument value
 | 
			
		||||
	argVal := reflect.ValueOf(arg)
 | 
			
		||||
	// If argument's type does not match method's argument type
 | 
			
		||||
	if arg != nil && argVal.Type() != mtdType.In(1) {
 | 
			
		||||
		val, err = reflectutil.Convert(argVal, mtdType.In(1))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, nil, err
 | 
			
		||||
		}
 | 
			
		||||
		arg = val.Interface()
 | 
			
		||||
	err = c.Unmarshal(data, arg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
	arg = argVal.Elem().Interface()
 | 
			
		||||
	
 | 
			
		||||
 | 
			
		||||
	ctx = newContext(pCtx, c)
 | 
			
		||||
	// Get reflect value of context
 | 
			
		||||
@@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.sendErr(c, call, val, err)
 | 
			
		||||
		} else {
 | 
			
		||||
			valData, err := c.Marshal(val)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				s.sendErr(c, call, val, err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Create response
 | 
			
		||||
			res := types.Response{
 | 
			
		||||
				ID:     call.ID,
 | 
			
		||||
				Return: val,
 | 
			
		||||
				Return: valData,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// If function has created a channel
 | 
			
		||||
			if ctx.isChannel {
 | 
			
		||||
				idData, err := c.Marshal(ctx.channelID)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					s.sendErr(c, call, val, err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Set IsChannel to true
 | 
			
		||||
				res.Type = types.ResponseTypeChannel
 | 
			
		||||
				// Overwrite return value with channel ID
 | 
			
		||||
				res.Return = ctx.channelID
 | 
			
		||||
				res.Return = idData
 | 
			
		||||
 | 
			
		||||
				// Store context in map for future use
 | 
			
		||||
				s.contextsMtx.Lock()
 | 
			
		||||
@@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
					// For every value received from channel
 | 
			
		||||
					for val := range ctx.channel {
 | 
			
		||||
						codecMtx.Lock()
 | 
			
		||||
 | 
			
		||||
						valData, err := c.Marshal(val)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
						// Encode response using codec
 | 
			
		||||
						c.Encode(types.Response{
 | 
			
		||||
							ID:     ctx.channelID,
 | 
			
		||||
							Return: val,
 | 
			
		||||
							Return: valData,
 | 
			
		||||
						})
 | 
			
		||||
 | 
			
		||||
						codecMtx.Unlock()
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
@@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
 | 
			
		||||
// sendErr sends an error response
 | 
			
		||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
 | 
			
		||||
	valData, _ := c.Marshal(val)
 | 
			
		||||
 | 
			
		||||
	// Encode error response using codec
 | 
			
		||||
	c.Encode(types.Response{
 | 
			
		||||
		Type:   types.ResponseTypeError,
 | 
			
		||||
		ID:     req.ID,
 | 
			
		||||
		Error:  err.Error(),
 | 
			
		||||
		Return: val,
 | 
			
		||||
		Return: valData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user