Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec

This commit is contained in:
2022-08-06 22:48:42 -07:00
parent 5e61e89ac1
commit e02c8bc5ff
5 changed files with 114 additions and 85 deletions

View File

@@ -26,7 +26,6 @@ import (
"sync"
"go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types"
"github.com/gofrs/uuid"
@@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
c.chs[idStr] = make(chan *types.Response, 1)
c.chMtx.Unlock()
argData, err := c.codec.Marshal(arg)
if err != nil {
return err
}
// Encode request using codec
err = c.codec.Encode(types.Request{
ID: idStr,
Receiver: rcvr,
Method: method,
Arg: arg,
Arg: argData,
})
if err != nil {
return err
@@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
return ErrReturnNotChannel
}
// Get channel ID returned in response
chID := resp.Return.(string)
var chID string
err = c.codec.Unmarshal(resp.Return, &chID)
if resp.Return == nil {
return nil
}
// Create new channel using channel ID
c.chMtx.Lock()
@@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
retVal.Close()
break
}
// Get reflect value from channel response
rVal := reflect.ValueOf(val.Return)
// If return value is not the same as the channel
if rVal.Type() != chElemType {
// Attempt to convert value, skip if impossible
newVal, err := reflectutil.Convert(rVal, chElemType)
if err != nil {
continue
}
rVal = newVal
outVal := reflect.New(chElemType)
err = c.codec.Unmarshal(val.Return, outVal.Interface())
if err != nil {
continue
}
outVal = outVal.Elem()
chosen, _, _ := reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
})
if chosen == 1 {
@@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
}
}()
} else if resp.Type == types.ResponseTypeNormal {
// IF return value is not a pointer, return error
if retVal.Kind() != reflect.Ptr {
return ErrReturnNotPointer
err = c.codec.Unmarshal(resp.Return, ret)
if err != nil {
return err
}
// Get return type
retType := retVal.Type().Elem()
// Get refkect value from response
rVal := reflect.ValueOf(resp.Return)
// If types do not match
if rVal.Type() != retType {
// Attempt to convert types, return error if not possible
newVal, err := reflectutil.Convert(rVal, retType)
if err != nil {
return err
}
rVal = newVal
}
// Set return value to received value
retVal.Elem().Set(rVal)
}
return nil