Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec
This commit is contained in:
@@ -26,7 +26,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
@@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
c.chs[idStr] = make(chan *types.Response, 1)
|
||||
c.chMtx.Unlock()
|
||||
|
||||
argData, err := c.codec.Marshal(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode request using codec
|
||||
err = c.codec.Encode(types.Request{
|
||||
ID: idStr,
|
||||
Receiver: rcvr,
|
||||
Method: method,
|
||||
Arg: arg,
|
||||
Arg: argData,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
return ErrReturnNotChannel
|
||||
}
|
||||
// Get channel ID returned in response
|
||||
chID := resp.Return.(string)
|
||||
var chID string
|
||||
err = c.codec.Unmarshal(resp.Return, &chID)
|
||||
if resp.Return == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new channel using channel ID
|
||||
c.chMtx.Lock()
|
||||
@@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
retVal.Close()
|
||||
break
|
||||
}
|
||||
// Get reflect value from channel response
|
||||
rVal := reflect.ValueOf(val.Return)
|
||||
|
||||
// If return value is not the same as the channel
|
||||
if rVal.Type() != chElemType {
|
||||
// Attempt to convert value, skip if impossible
|
||||
newVal, err := reflectutil.Convert(rVal, chElemType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
rVal = newVal
|
||||
outVal := reflect.New(chElemType)
|
||||
err = c.codec.Unmarshal(val.Return, outVal.Interface())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
outVal = outVal.Elem()
|
||||
|
||||
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
||||
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
|
||||
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
|
||||
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
||||
})
|
||||
if chosen == 1 {
|
||||
@@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
}
|
||||
}()
|
||||
} else if resp.Type == types.ResponseTypeNormal {
|
||||
// IF return value is not a pointer, return error
|
||||
if retVal.Kind() != reflect.Ptr {
|
||||
return ErrReturnNotPointer
|
||||
err = c.codec.Unmarshal(resp.Return, ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get return type
|
||||
retType := retVal.Type().Elem()
|
||||
// Get refkect value from response
|
||||
rVal := reflect.ValueOf(resp.Return)
|
||||
|
||||
// If types do not match
|
||||
if rVal.Type() != retType {
|
||||
// Attempt to convert types, return error if not possible
|
||||
newVal, err := reflectutil.Convert(rVal, retType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rVal = newVal
|
||||
}
|
||||
|
||||
// Set return value to received value
|
||||
retVal.Elem().Set(rVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user