Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec
This commit is contained in:
		@@ -26,7 +26,6 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"go.arsenm.dev/lrpc/codec"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/reflectutil"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/types"
 | 
			
		||||
 | 
			
		||||
	"github.com/gofrs/uuid"
 | 
			
		||||
@@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
	c.chs[idStr] = make(chan *types.Response, 1)
 | 
			
		||||
	c.chMtx.Unlock()
 | 
			
		||||
 | 
			
		||||
	argData, err := c.codec.Marshal(arg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Encode request using codec
 | 
			
		||||
	err = c.codec.Encode(types.Request{
 | 
			
		||||
		ID:       idStr,
 | 
			
		||||
		Receiver: rcvr,
 | 
			
		||||
		Method:   method,
 | 
			
		||||
		Arg:      arg,
 | 
			
		||||
		Arg:      argData,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
			return ErrReturnNotChannel
 | 
			
		||||
		}
 | 
			
		||||
		// Get channel ID returned in response
 | 
			
		||||
		chID := resp.Return.(string)
 | 
			
		||||
		var chID string
 | 
			
		||||
		err = c.codec.Unmarshal(resp.Return, &chID)
 | 
			
		||||
		if resp.Return == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create new channel using channel ID
 | 
			
		||||
		c.chMtx.Lock()
 | 
			
		||||
@@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
					retVal.Close()
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				// Get reflect value from channel response
 | 
			
		||||
				rVal := reflect.ValueOf(val.Return)
 | 
			
		||||
 | 
			
		||||
				// If return value is not the same as the channel
 | 
			
		||||
				if rVal.Type() != chElemType {
 | 
			
		||||
					// Attempt to convert value, skip if impossible
 | 
			
		||||
					newVal, err := reflectutil.Convert(rVal, chElemType)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					rVal = newVal
 | 
			
		||||
				outVal := reflect.New(chElemType)
 | 
			
		||||
				err = c.codec.Unmarshal(val.Return, outVal.Interface())
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				outVal = outVal.Elem()
 | 
			
		||||
 | 
			
		||||
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
 | 
			
		||||
					{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
 | 
			
		||||
					{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
 | 
			
		||||
					{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
 | 
			
		||||
				})
 | 
			
		||||
				if chosen == 1 {
 | 
			
		||||
@@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	} else if resp.Type == types.ResponseTypeNormal {
 | 
			
		||||
		// IF return value is not a pointer, return error
 | 
			
		||||
		if retVal.Kind() != reflect.Ptr {
 | 
			
		||||
			return ErrReturnNotPointer
 | 
			
		||||
		err = c.codec.Unmarshal(resp.Return, ret)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Get return type
 | 
			
		||||
		retType := retVal.Type().Elem()
 | 
			
		||||
		// Get refkect value from response
 | 
			
		||||
		rVal := reflect.ValueOf(resp.Return)
 | 
			
		||||
 | 
			
		||||
		// If types do not match
 | 
			
		||||
		if rVal.Type() != retType {
 | 
			
		||||
			// Attempt to convert types, return error if not possible
 | 
			
		||||
			newVal, err := reflectutil.Convert(rVal, retType)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			rVal = newVal
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Set return value to received value
 | 
			
		||||
		retVal.Elem().Set(rVal)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user