Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec
This commit is contained in:
parent
5e61e89ac1
commit
e02c8bc5ff
@ -26,7 +26,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"go.arsenm.dev/lrpc/codec"
|
"go.arsenm.dev/lrpc/codec"
|
||||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
|
||||||
"go.arsenm.dev/lrpc/internal/types"
|
"go.arsenm.dev/lrpc/internal/types"
|
||||||
|
|
||||||
"github.com/gofrs/uuid"
|
"github.com/gofrs/uuid"
|
||||||
@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
|||||||
c.chs[idStr] = make(chan *types.Response, 1)
|
c.chs[idStr] = make(chan *types.Response, 1)
|
||||||
c.chMtx.Unlock()
|
c.chMtx.Unlock()
|
||||||
|
|
||||||
|
argData, err := c.codec.Marshal(arg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Encode request using codec
|
// Encode request using codec
|
||||||
err = c.codec.Encode(types.Request{
|
err = c.codec.Encode(types.Request{
|
||||||
ID: idStr,
|
ID: idStr,
|
||||||
Receiver: rcvr,
|
Receiver: rcvr,
|
||||||
Method: method,
|
Method: method,
|
||||||
Arg: arg,
|
Arg: argData,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
|||||||
return ErrReturnNotChannel
|
return ErrReturnNotChannel
|
||||||
}
|
}
|
||||||
// Get channel ID returned in response
|
// Get channel ID returned in response
|
||||||
chID := resp.Return.(string)
|
var chID string
|
||||||
|
err = c.codec.Unmarshal(resp.Return, &chID)
|
||||||
|
if resp.Return == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Create new channel using channel ID
|
// Create new channel using channel ID
|
||||||
c.chMtx.Lock()
|
c.chMtx.Lock()
|
||||||
@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
|||||||
retVal.Close()
|
retVal.Close()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
// Get reflect value from channel response
|
|
||||||
rVal := reflect.ValueOf(val.Return)
|
|
||||||
|
|
||||||
// If return value is not the same as the channel
|
outVal := reflect.New(chElemType)
|
||||||
if rVal.Type() != chElemType {
|
err = c.codec.Unmarshal(val.Return, outVal.Interface())
|
||||||
// Attempt to convert value, skip if impossible
|
if err != nil {
|
||||||
newVal, err := reflectutil.Convert(rVal, chElemType)
|
continue
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rVal = newVal
|
|
||||||
}
|
}
|
||||||
|
outVal = outVal.Elem()
|
||||||
|
|
||||||
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
||||||
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
|
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
|
||||||
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
||||||
})
|
})
|
||||||
if chosen == 1 {
|
if chosen == 1 {
|
||||||
@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
} else if resp.Type == types.ResponseTypeNormal {
|
} else if resp.Type == types.ResponseTypeNormal {
|
||||||
// IF return value is not a pointer, return error
|
err = c.codec.Unmarshal(resp.Return, ret)
|
||||||
if retVal.Kind() != reflect.Ptr {
|
if err != nil {
|
||||||
return ErrReturnNotPointer
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return type
|
|
||||||
retType := retVal.Type().Elem()
|
|
||||||
// Get refkect value from response
|
|
||||||
rVal := reflect.ValueOf(resp.Return)
|
|
||||||
|
|
||||||
// If types do not match
|
|
||||||
if rVal.Type() != retType {
|
|
||||||
// Attempt to convert types, return error if not possible
|
|
||||||
newVal, err := reflectutil.Convert(rVal, retType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rVal = newVal
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set return value to received value
|
|
||||||
retVal.Elem().Set(rVal)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
package codec
|
package codec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
@ -38,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec
|
|||||||
type Codec interface {
|
type Codec interface {
|
||||||
Encode(val any) error
|
Encode(val any) error
|
||||||
Decode(val any) error
|
Decode(val any) error
|
||||||
|
Unmarshal(data []byte, v any) error
|
||||||
|
Marshal(v any) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default is the default CodecFunc
|
// Default is the default CodecFunc
|
||||||
var Default = Msgpack
|
var Default = Msgpack
|
||||||
|
|
||||||
|
type JsonCodec struct {
|
||||||
|
*json.Encoder
|
||||||
|
*json.Decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (JsonCodec) Unmarshal(data []byte, v any) error {
|
||||||
|
return json.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (JsonCodec) Marshal(v any) ([]byte, error) {
|
||||||
|
return json.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
// JSON is a CodecFunc that creates a JSON Codec
|
// JSON is a CodecFunc that creates a JSON Codec
|
||||||
func JSON(rw io.ReadWriter) Codec {
|
func JSON(rw io.ReadWriter) Codec {
|
||||||
type jsonCodec struct {
|
return JsonCodec{
|
||||||
*json.Encoder
|
|
||||||
*json.Decoder
|
|
||||||
}
|
|
||||||
return jsonCodec{
|
|
||||||
Encoder: json.NewEncoder(rw),
|
Encoder: json.NewEncoder(rw),
|
||||||
Decoder: json.NewDecoder(rw),
|
Decoder: json.NewDecoder(rw),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MsgpackCodec struct {
|
||||||
|
*msgpack.Encoder
|
||||||
|
*msgpack.Decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
|
||||||
|
return msgpack.Unmarshal(data, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
|
||||||
|
return msgpack.Marshal(v)
|
||||||
|
}
|
||||||
|
|
||||||
// Msgpack is a CodecFunc that creates a Msgpack Codec
|
// Msgpack is a CodecFunc that creates a Msgpack Codec
|
||||||
func Msgpack(rw io.ReadWriter) Codec {
|
func Msgpack(rw io.ReadWriter) Codec {
|
||||||
type msgpackCodec struct {
|
return MsgpackCodec{
|
||||||
*msgpack.Encoder
|
|
||||||
*msgpack.Decoder
|
|
||||||
}
|
|
||||||
return msgpackCodec{
|
|
||||||
Encoder: msgpack.NewEncoder(rw),
|
Encoder: msgpack.NewEncoder(rw),
|
||||||
Decoder: msgpack.NewDecoder(rw),
|
Decoder: msgpack.NewDecoder(rw),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GobCodec struct {
|
||||||
|
*gob.Encoder
|
||||||
|
*gob.Decoder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (GobCodec) Unmarshal(data []byte, v any) error {
|
||||||
|
return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (GobCodec) Marshal(v any) ([]byte, error) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
err := gob.NewEncoder(buf).Encode(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Gob is a CodecFunc that creates a Gob Codec
|
// Gob is a CodecFunc that creates a Gob Codec
|
||||||
func Gob(rw io.ReadWriter) Codec {
|
func Gob(rw io.ReadWriter) Codec {
|
||||||
type gobCodec struct {
|
return GobCodec{
|
||||||
*gob.Encoder
|
|
||||||
*gob.Decoder
|
|
||||||
}
|
|
||||||
return gobCodec{
|
|
||||||
Encoder: gob.NewEncoder(rw),
|
Encoder: gob.NewEncoder(rw),
|
||||||
Decoder: gob.NewDecoder(rw),
|
Decoder: gob.NewDecoder(rw),
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,7 @@ type Request struct {
|
|||||||
ID string
|
ID string
|
||||||
Receiver string
|
Receiver string
|
||||||
Method string
|
Method string
|
||||||
Arg any
|
Arg []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseType uint8
|
type ResponseType uint8
|
||||||
@ -43,5 +43,5 @@ type Response struct {
|
|||||||
Type ResponseType
|
Type ResponseType
|
||||||
ID string
|
ID string
|
||||||
Error string
|
Error string
|
||||||
Return any
|
Return []byte
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"go.arsenm.dev/lrpc/codec"
|
"go.arsenm.dev/lrpc/codec"
|
||||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
|
||||||
"go.arsenm.dev/lrpc/internal/types"
|
"go.arsenm.dev/lrpc/internal/types"
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
@ -37,12 +36,11 @@ import (
|
|||||||
type any = interface{}
|
type any = interface{}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||||
ErrNoSuchMethod = errors.New("no such method was found")
|
ErrNoSuchMethod = errors.New("no such method was found")
|
||||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||||
ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
|
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server is an lrpc server
|
// Server is an lrpc server
|
||||||
@ -101,7 +99,7 @@ func (s *Server) Register(v any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute runs a method of a registered value
|
// execute runs a method of a registered value
|
||||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
|
||||||
// Try to get value from receivers map
|
// Try to get value from receivers map
|
||||||
val, ok := s.rcvrs[typ]
|
val, ok := s.rcvrs[typ]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
|
|||||||
// Get method type
|
// Get method type
|
||||||
mtdType := mtd.Type()
|
mtdType := mtd.Type()
|
||||||
|
|
||||||
// Return error if argument provided but isn't expected
|
//TODO: if arg not nil but fn has no arg, err
|
||||||
if mtdType.NumIn() == 1 && arg != nil {
|
|
||||||
return nil, nil, ErrUnexpectedArgument
|
argType := mtdType.In(1)
|
||||||
|
argVal := reflect.New(argType)
|
||||||
|
arg := argVal.Interface()
|
||||||
|
|
||||||
|
err = c.Unmarshal(data, arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// IF argument is []any
|
arg = argVal.Elem().Interface()
|
||||||
anySlice, ok := arg.([]any)
|
|
||||||
if ok {
|
|
||||||
// Convert slice to the method's arg type and
|
|
||||||
// set arg to the newly-converted slice
|
|
||||||
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get argument value
|
|
||||||
argVal := reflect.ValueOf(arg)
|
|
||||||
// If argument's type does not match method's argument type
|
|
||||||
if arg != nil && argVal.Type() != mtdType.In(1) {
|
|
||||||
val, err = reflectutil.Convert(argVal, mtdType.In(1))
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
arg = val.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = newContext(pCtx, c)
|
ctx = newContext(pCtx, c)
|
||||||
// Get reflect value of context
|
// Get reflect value of context
|
||||||
@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
s.sendErr(c, call, val, err)
|
s.sendErr(c, call, val, err)
|
||||||
} else {
|
} else {
|
||||||
|
valData, err := c.Marshal(val)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErr(c, call, val, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Create response
|
// Create response
|
||||||
res := types.Response{
|
res := types.Response{
|
||||||
ID: call.ID,
|
ID: call.ID,
|
||||||
Return: val,
|
Return: valData,
|
||||||
}
|
}
|
||||||
|
|
||||||
// If function has created a channel
|
// If function has created a channel
|
||||||
if ctx.isChannel {
|
if ctx.isChannel {
|
||||||
|
idData, err := c.Marshal(ctx.channelID)
|
||||||
|
if err != nil {
|
||||||
|
s.sendErr(c, call, val, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Set IsChannel to true
|
// Set IsChannel to true
|
||||||
res.Type = types.ResponseTypeChannel
|
res.Type = types.ResponseTypeChannel
|
||||||
// Overwrite return value with channel ID
|
// Overwrite return value with channel ID
|
||||||
res.Return = ctx.channelID
|
res.Return = idData
|
||||||
|
|
||||||
// Store context in map for future use
|
// Store context in map for future use
|
||||||
s.contextsMtx.Lock()
|
s.contextsMtx.Lock()
|
||||||
@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
|||||||
// For every value received from channel
|
// For every value received from channel
|
||||||
for val := range ctx.channel {
|
for val := range ctx.channel {
|
||||||
codecMtx.Lock()
|
codecMtx.Lock()
|
||||||
|
|
||||||
|
valData, err := c.Marshal(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Encode response using codec
|
// Encode response using codec
|
||||||
c.Encode(types.Response{
|
c.Encode(types.Response{
|
||||||
ID: ctx.channelID,
|
ID: ctx.channelID,
|
||||||
Return: val,
|
Return: valData,
|
||||||
})
|
})
|
||||||
|
|
||||||
codecMtx.Unlock()
|
codecMtx.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
|||||||
|
|
||||||
// sendErr sends an error response
|
// sendErr sends an error response
|
||||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
||||||
|
valData, _ := c.Marshal(val)
|
||||||
|
|
||||||
// Encode error response using codec
|
// Encode error response using codec
|
||||||
c.Encode(types.Response{
|
c.Encode(types.Response{
|
||||||
Type: types.ResponseTypeError,
|
Type: types.ResponseTypeError,
|
||||||
ID: req.ID,
|
ID: req.ID,
|
||||||
Error: err.Error(),
|
Error: err.Error(),
|
||||||
Return: val,
|
Return: valData,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user