Compare commits
2 Commits
5e61e89ac1
...
d35a16ec64
Author | SHA1 | Date | |
---|---|---|---|
d35a16ec64 | |||
e02c8bc5ff |
@ -26,7 +26,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
c.chs[idStr] = make(chan *types.Response, 1)
|
||||
c.chMtx.Unlock()
|
||||
|
||||
argData, err := c.codec.Marshal(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode request using codec
|
||||
err = c.codec.Encode(types.Request{
|
||||
ID: idStr,
|
||||
Receiver: rcvr,
|
||||
Method: method,
|
||||
Arg: arg,
|
||||
Arg: argData,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
return ErrReturnNotChannel
|
||||
}
|
||||
// Get channel ID returned in response
|
||||
chID := resp.Return.(string)
|
||||
var chID string
|
||||
err = c.codec.Unmarshal(resp.Return, &chID)
|
||||
if resp.Return == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new channel using channel ID
|
||||
c.chMtx.Lock()
|
||||
@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
retVal.Close()
|
||||
break
|
||||
}
|
||||
// Get reflect value from channel response
|
||||
rVal := reflect.ValueOf(val.Return)
|
||||
|
||||
// If return value is not the same as the channel
|
||||
if rVal.Type() != chElemType {
|
||||
// Attempt to convert value, skip if impossible
|
||||
newVal, err := reflectutil.Convert(rVal, chElemType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
rVal = newVal
|
||||
outVal := reflect.New(chElemType)
|
||||
err = c.codec.Unmarshal(val.Return, outVal.Interface())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
outVal = outVal.Elem()
|
||||
|
||||
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
||||
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
|
||||
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
|
||||
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
||||
})
|
||||
if chosen == 1 {
|
||||
@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
|
||||
}
|
||||
}()
|
||||
} else if resp.Type == types.ResponseTypeNormal {
|
||||
// IF return value is not a pointer, return error
|
||||
if retVal.Kind() != reflect.Ptr {
|
||||
return ErrReturnNotPointer
|
||||
err = c.codec.Unmarshal(resp.Return, ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get return type
|
||||
retType := retVal.Type().Elem()
|
||||
// Get refkect value from response
|
||||
rVal := reflect.ValueOf(resp.Return)
|
||||
|
||||
// If types do not match
|
||||
if rVal.Type() != retType {
|
||||
// Attempt to convert types, return error if not possible
|
||||
newVal, err := reflectutil.Convert(rVal, retType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rVal = newVal
|
||||
}
|
||||
|
||||
// Set return value to received value
|
||||
retVal.Elem().Set(rVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -19,6 +19,7 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@ -38,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec
|
||||
type Codec interface {
|
||||
Encode(val any) error
|
||||
Decode(val any) error
|
||||
Unmarshal(data []byte, v any) error
|
||||
Marshal(v any) ([]byte, error)
|
||||
}
|
||||
|
||||
// Default is the default CodecFunc
|
||||
var Default = Msgpack
|
||||
|
||||
type JsonCodec struct {
|
||||
*json.Encoder
|
||||
*json.Decoder
|
||||
}
|
||||
|
||||
func (JsonCodec) Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (JsonCodec) Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// JSON is a CodecFunc that creates a JSON Codec
|
||||
func JSON(rw io.ReadWriter) Codec {
|
||||
type jsonCodec struct {
|
||||
*json.Encoder
|
||||
*json.Decoder
|
||||
}
|
||||
return jsonCodec{
|
||||
return JsonCodec{
|
||||
Encoder: json.NewEncoder(rw),
|
||||
Decoder: json.NewDecoder(rw),
|
||||
}
|
||||
}
|
||||
|
||||
type MsgpackCodec struct {
|
||||
*msgpack.Encoder
|
||||
*msgpack.Decoder
|
||||
}
|
||||
|
||||
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
|
||||
return msgpack.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
|
||||
return msgpack.Marshal(v)
|
||||
}
|
||||
|
||||
// Msgpack is a CodecFunc that creates a Msgpack Codec
|
||||
func Msgpack(rw io.ReadWriter) Codec {
|
||||
type msgpackCodec struct {
|
||||
*msgpack.Encoder
|
||||
*msgpack.Decoder
|
||||
}
|
||||
return msgpackCodec{
|
||||
return MsgpackCodec{
|
||||
Encoder: msgpack.NewEncoder(rw),
|
||||
Decoder: msgpack.NewDecoder(rw),
|
||||
}
|
||||
}
|
||||
|
||||
type GobCodec struct {
|
||||
*gob.Encoder
|
||||
*gob.Decoder
|
||||
}
|
||||
|
||||
func (GobCodec) Unmarshal(data []byte, v any) error {
|
||||
return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
}
|
||||
|
||||
func (GobCodec) Marshal(v any) ([]byte, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
err := gob.NewEncoder(buf).Encode(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Gob is a CodecFunc that creates a Gob Codec
|
||||
func Gob(rw io.ReadWriter) Codec {
|
||||
type gobCodec struct {
|
||||
*gob.Encoder
|
||||
*gob.Decoder
|
||||
}
|
||||
return gobCodec{
|
||||
return GobCodec{
|
||||
Encoder: gob.NewEncoder(rw),
|
||||
Decoder: gob.NewDecoder(rw),
|
||||
}
|
||||
|
@ -1,188 +0,0 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package reflectutil
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// <= go1.17 compatibility
|
||||
type any = interface{}
|
||||
|
||||
// Convert attempts to convert the given value to the given type
|
||||
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
|
||||
// Get input type
|
||||
inType := in.Type()
|
||||
|
||||
// If input is already the desired type, return
|
||||
if inType == toType {
|
||||
return in, nil
|
||||
}
|
||||
|
||||
// If the output type is a pointer to the input type
|
||||
if reflect.PtrTo(inType) == toType {
|
||||
if in.CanAddr() {
|
||||
// Return pointer to input
|
||||
return in.Addr(), nil
|
||||
}
|
||||
|
||||
inPtrVal := reflect.New(inType)
|
||||
inPtrVal.Elem().Set(in)
|
||||
return inPtrVal, nil
|
||||
}
|
||||
|
||||
// If input is a pointer pointing to the output type
|
||||
if inType.Kind() == reflect.Ptr && inType.Elem() == toType {
|
||||
// Return value being pointed at by input
|
||||
return reflect.Indirect(in), nil
|
||||
}
|
||||
|
||||
// If input can be converted to desired type, convert and return
|
||||
if in.CanConvert(toType) {
|
||||
return in.Convert(toType), nil
|
||||
}
|
||||
|
||||
// Create new value of desired type
|
||||
to := reflect.New(toType).Elem()
|
||||
|
||||
// If type is a pointer
|
||||
if to.Kind() == reflect.Ptr {
|
||||
// Initialize value
|
||||
to.Set(reflect.New(to.Type().Elem()))
|
||||
}
|
||||
|
||||
switch val := in.Interface().(type) {
|
||||
case string:
|
||||
// If desired type satisfies text unmarshaler
|
||||
if u, ok := to.Interface().(encoding.TextUnmarshaler); ok {
|
||||
// Use text unmarshaler to get value
|
||||
err := u.UnmarshalText([]byte(val))
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Return unmarshaled value
|
||||
return reflect.ValueOf(any(u)), nil
|
||||
}
|
||||
case []byte:
|
||||
// If desired type satisfies binary unmarshaler
|
||||
if u, ok := to.Interface().(encoding.BinaryUnmarshaler); ok {
|
||||
// Use binary unmarshaler to get value
|
||||
err := u.UnmarshalBinary(val)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Return unmarshaled value
|
||||
return reflect.ValueOf(any(u)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// If input is a map
|
||||
if in.Kind() == reflect.Map {
|
||||
// Use mapstructure to decode value
|
||||
err := mapstructure.Decode(in.Interface(), to.Addr().Interface())
|
||||
if err == nil {
|
||||
return to, nil
|
||||
} else {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// If input is a slice of any, and output is an array or slice
|
||||
if in.Type() == reflect.TypeOf([]any{}) &&
|
||||
(to.Kind() == reflect.Slice || to.Kind() == reflect.Array) {
|
||||
// Use ConvertSlice to convert value
|
||||
return reflect.ValueOf(ConvertSlice(
|
||||
in.Interface().([]any),
|
||||
toType,
|
||||
)), nil
|
||||
}
|
||||
|
||||
return to, fmt.Errorf("cannot convert %s to %s", inType, toType)
|
||||
}
|
||||
|
||||
// ConvertSlice converts []any to an array or slice, as provided
|
||||
// in the "to" argument.
|
||||
func ConvertSlice(in []any, to reflect.Type) any {
|
||||
// Create new value for output
|
||||
out := reflect.New(to).Elem()
|
||||
|
||||
// Get type of slice elements
|
||||
outType := out.Type().Elem()
|
||||
|
||||
// If output value is a slice
|
||||
if out.Kind() == reflect.Slice {
|
||||
// For every value provided
|
||||
for i := 0; i < len(in); i++ {
|
||||
// Get value of input type
|
||||
inVal := reflect.ValueOf(in[i])
|
||||
// Create new output type
|
||||
outVal := reflect.New(outType).Elem()
|
||||
|
||||
// If types match
|
||||
if inVal.Type() == outType {
|
||||
// Set output value to input value
|
||||
outVal.Set(inVal)
|
||||
} else {
|
||||
newVal, err := Convert(inVal, outType)
|
||||
if err != nil {
|
||||
// Set output value to its zero value
|
||||
outVal.Set(reflect.Zero(outVal.Type()))
|
||||
} else {
|
||||
outVal.Set(newVal)
|
||||
}
|
||||
}
|
||||
|
||||
// Append output value to slice
|
||||
out = reflect.Append(out, outVal)
|
||||
}
|
||||
} else if out.Kind() == reflect.Array && out.Len() == len(in) {
|
||||
//If output type is array and lengths match
|
||||
|
||||
// For every input value
|
||||
for i := 0; i < len(in); i++ {
|
||||
// Get matching output index
|
||||
outVal := out.Index(i)
|
||||
// Get input value
|
||||
inVal := reflect.ValueOf(in[i])
|
||||
|
||||
// If types match
|
||||
if inVal.Type() == outVal.Type() {
|
||||
// Set output value to input value
|
||||
outVal.Set(inVal)
|
||||
} else {
|
||||
newVal, err := Convert(inVal, outType)
|
||||
if err != nil {
|
||||
// Set output value to its zero value
|
||||
outVal.Set(reflect.Zero(outVal.Type()))
|
||||
} else {
|
||||
outVal.Set(newVal)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return created value
|
||||
return out.Interface()
|
||||
}
|
@ -26,7 +26,7 @@ type Request struct {
|
||||
ID string
|
||||
Receiver string
|
||||
Method string
|
||||
Arg any
|
||||
Arg []byte
|
||||
}
|
||||
|
||||
type ResponseType uint8
|
||||
@ -43,5 +43,5 @@ type Response struct {
|
||||
Type ResponseType
|
||||
ID string
|
||||
Error string
|
||||
Return any
|
||||
Return []byte
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ func TestCodecs(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Errorf("codec/%s: %v", name, err)
|
||||
}
|
||||
|
||||
|
||||
if add != 4 {
|
||||
t.Errorf("codec/%s: add: expected 4, got %d", name, add)
|
||||
}
|
||||
|
@ -28,7 +28,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
@ -37,12 +36,11 @@ import (
|
||||
type any = interface{}
|
||||
|
||||
var (
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||
ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
|
||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||
)
|
||||
|
||||
// Server is an lrpc server
|
||||
@ -101,7 +99,7 @@ func (s *Server) Register(v any) error {
|
||||
}
|
||||
|
||||
// execute runs a method of a registered value
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
// Try to get value from receivers map
|
||||
val, ok := s.rcvrs[typ]
|
||||
if !ok {
|
||||
@ -122,29 +120,19 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
|
||||
// Get method type
|
||||
mtdType := mtd.Type()
|
||||
|
||||
// Return error if argument provided but isn't expected
|
||||
if mtdType.NumIn() == 1 && arg != nil {
|
||||
return nil, nil, ErrUnexpectedArgument
|
||||
}
|
||||
//TODO: if arg not nil but fn has no arg, err
|
||||
|
||||
// IF argument is []any
|
||||
anySlice, ok := arg.([]any)
|
||||
if ok {
|
||||
// Convert slice to the method's arg type and
|
||||
// set arg to the newly-converted slice
|
||||
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
|
||||
}
|
||||
argType := mtdType.In(1)
|
||||
argVal := reflect.New(argType)
|
||||
arg := argVal.Interface()
|
||||
|
||||
// Get argument value
|
||||
argVal := reflect.ValueOf(arg)
|
||||
// If argument's type does not match method's argument type
|
||||
if arg != nil && argVal.Type() != mtdType.In(1) {
|
||||
val, err = reflectutil.Convert(argVal, mtdType.In(1))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
arg = val.Interface()
|
||||
err = c.Unmarshal(data, arg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
arg = argVal.Elem().Interface()
|
||||
|
||||
|
||||
ctx = newContext(pCtx, c)
|
||||
// Get reflect value of context
|
||||
@ -327,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
} else {
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Create response
|
||||
res := types.Response{
|
||||
ID: call.ID,
|
||||
Return: val,
|
||||
Return: valData,
|
||||
}
|
||||
|
||||
// If function has created a channel
|
||||
if ctx.isChannel {
|
||||
idData, err := c.Marshal(ctx.channelID)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set IsChannel to true
|
||||
res.Type = types.ResponseTypeChannel
|
||||
// Overwrite return value with channel ID
|
||||
res.Return = ctx.channelID
|
||||
res.Return = idData
|
||||
|
||||
// Store context in map for future use
|
||||
s.contextsMtx.Lock()
|
||||
@ -349,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
// For every value received from channel
|
||||
for val := range ctx.channel {
|
||||
codecMtx.Lock()
|
||||
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
Return: val,
|
||||
Return: valData,
|
||||
})
|
||||
|
||||
codecMtx.Unlock()
|
||||
}
|
||||
|
||||
@ -383,12 +390,14 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
|
||||
// sendErr sends an error response
|
||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
||||
valData, _ := c.Marshal(val)
|
||||
|
||||
// Encode error response using codec
|
||||
c.Encode(types.Response{
|
||||
Type: types.ResponseTypeError,
|
||||
ID: req.ID,
|
||||
Error: err.Error(),
|
||||
Return: val,
|
||||
Return: valData,
|
||||
})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user