Compare commits

..

No commits in common. "d35a16ec64f579a02b0f7f917b09fdb8bd740233" and "5e61e89ac154304acb6f2d7632c88cc010888d53" have entirely different histories.

6 changed files with 275 additions and 116 deletions

View File

@ -26,6 +26,7 @@ import (
"sync" "sync"
"go.arsenm.dev/lrpc/codec" "go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types" "go.arsenm.dev/lrpc/internal/types"
"github.com/gofrs/uuid" "github.com/gofrs/uuid"
@ -80,17 +81,12 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
c.chs[idStr] = make(chan *types.Response, 1) c.chs[idStr] = make(chan *types.Response, 1)
c.chMtx.Unlock() c.chMtx.Unlock()
argData, err := c.codec.Marshal(arg)
if err != nil {
return err
}
// Encode request using codec // Encode request using codec
err = c.codec.Encode(types.Request{ err = c.codec.Encode(types.Request{
ID: idStr, ID: idStr,
Receiver: rcvr, Receiver: rcvr,
Method: method, Method: method,
Arg: argData, Arg: arg,
}) })
if err != nil { if err != nil {
return err return err
@ -128,11 +124,7 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
return ErrReturnNotChannel return ErrReturnNotChannel
} }
// Get channel ID returned in response // Get channel ID returned in response
var chID string chID := resp.Return.(string)
err = c.codec.Unmarshal(resp.Return, &chID)
if resp.Return == nil {
return nil
}
// Create new channel using channel ID // Create new channel using channel ID
c.chMtx.Lock() c.chMtx.Lock()
@ -157,16 +149,21 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
retVal.Close() retVal.Close()
break break
} }
// Get reflect value from channel response
rVal := reflect.ValueOf(val.Return)
outVal := reflect.New(chElemType) // If return value is not the same as the channel
err = c.codec.Unmarshal(val.Return, outVal.Interface()) if rVal.Type() != chElemType {
// Attempt to convert value, skip if impossible
newVal, err := reflectutil.Convert(rVal, chElemType)
if err != nil { if err != nil {
continue continue
} }
outVal = outVal.Elem() rVal = newVal
}
chosen, _, _ := reflect.Select([]reflect.SelectCase{ chosen, _, _ := reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal}, {Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}}, {Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
}) })
if chosen == 1 { if chosen == 1 {
@ -182,10 +179,28 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
} }
}() }()
} else if resp.Type == types.ResponseTypeNormal { } else if resp.Type == types.ResponseTypeNormal {
err = c.codec.Unmarshal(resp.Return, ret) // IF return value is not a pointer, return error
if retVal.Kind() != reflect.Ptr {
return ErrReturnNotPointer
}
// Get return type
retType := retVal.Type().Elem()
// Get refkect value from response
rVal := reflect.ValueOf(resp.Return)
// If types do not match
if rVal.Type() != retType {
// Attempt to convert types, return error if not possible
newVal, err := reflectutil.Convert(rVal, retType)
if err != nil { if err != nil {
return err return err
} }
rVal = newVal
}
// Set return value to received value
retVal.Elem().Set(rVal)
} }
return nil return nil

View File

@ -19,7 +19,6 @@
package codec package codec
import ( import (
"bytes"
"encoding/gob" "encoding/gob"
"encoding/json" "encoding/json"
"io" "io"
@ -39,76 +38,42 @@ type CodecFunc func(io.ReadWriter) Codec
type Codec interface { type Codec interface {
Encode(val any) error Encode(val any) error
Decode(val any) error Decode(val any) error
Unmarshal(data []byte, v any) error
Marshal(v any) ([]byte, error)
} }
// Default is the default CodecFunc // Default is the default CodecFunc
var Default = Msgpack var Default = Msgpack
type JsonCodec struct { // JSON is a CodecFunc that creates a JSON Codec
func JSON(rw io.ReadWriter) Codec {
type jsonCodec struct {
*json.Encoder *json.Encoder
*json.Decoder *json.Decoder
} }
return jsonCodec{
func (JsonCodec) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}
func (JsonCodec) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}
// JSON is a CodecFunc that creates a JSON Codec
func JSON(rw io.ReadWriter) Codec {
return JsonCodec{
Encoder: json.NewEncoder(rw), Encoder: json.NewEncoder(rw),
Decoder: json.NewDecoder(rw), Decoder: json.NewDecoder(rw),
} }
} }
type MsgpackCodec struct { // Msgpack is a CodecFunc that creates a Msgpack Codec
func Msgpack(rw io.ReadWriter) Codec {
type msgpackCodec struct {
*msgpack.Encoder *msgpack.Encoder
*msgpack.Decoder *msgpack.Decoder
} }
return msgpackCodec{
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
return msgpack.Unmarshal(data, v)
}
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
return msgpack.Marshal(v)
}
// Msgpack is a CodecFunc that creates a Msgpack Codec
func Msgpack(rw io.ReadWriter) Codec {
return MsgpackCodec{
Encoder: msgpack.NewEncoder(rw), Encoder: msgpack.NewEncoder(rw),
Decoder: msgpack.NewDecoder(rw), Decoder: msgpack.NewDecoder(rw),
} }
} }
type GobCodec struct { // Gob is a CodecFunc that creates a Gob Codec
func Gob(rw io.ReadWriter) Codec {
type gobCodec struct {
*gob.Encoder *gob.Encoder
*gob.Decoder *gob.Decoder
} }
return gobCodec{
func (GobCodec) Unmarshal(data []byte, v any) error {
return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
}
func (GobCodec) Marshal(v any) ([]byte, error) {
buf := &bytes.Buffer{}
err := gob.NewEncoder(buf).Encode(v)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Gob is a CodecFunc that creates a Gob Codec
func Gob(rw io.ReadWriter) Codec {
return GobCodec{
Encoder: gob.NewEncoder(rw), Encoder: gob.NewEncoder(rw),
Decoder: gob.NewDecoder(rw), Decoder: gob.NewDecoder(rw),
} }

View File

@ -0,0 +1,188 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package reflectutil
import (
"encoding"
"fmt"
"reflect"
"github.com/mitchellh/mapstructure"
)
// <= go1.17 compatibility
type any = interface{}
// Convert attempts to convert the given value to the given type
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
// Get input type
inType := in.Type()
// If input is already the desired type, return
if inType == toType {
return in, nil
}
// If the output type is a pointer to the input type
if reflect.PtrTo(inType) == toType {
if in.CanAddr() {
// Return pointer to input
return in.Addr(), nil
}
inPtrVal := reflect.New(inType)
inPtrVal.Elem().Set(in)
return inPtrVal, nil
}
// If input is a pointer pointing to the output type
if inType.Kind() == reflect.Ptr && inType.Elem() == toType {
// Return value being pointed at by input
return reflect.Indirect(in), nil
}
// If input can be converted to desired type, convert and return
if in.CanConvert(toType) {
return in.Convert(toType), nil
}
// Create new value of desired type
to := reflect.New(toType).Elem()
// If type is a pointer
if to.Kind() == reflect.Ptr {
// Initialize value
to.Set(reflect.New(to.Type().Elem()))
}
switch val := in.Interface().(type) {
case string:
// If desired type satisfies text unmarshaler
if u, ok := to.Interface().(encoding.TextUnmarshaler); ok {
// Use text unmarshaler to get value
err := u.UnmarshalText([]byte(val))
if err != nil {
return reflect.Value{}, err
}
// Return unmarshaled value
return reflect.ValueOf(any(u)), nil
}
case []byte:
// If desired type satisfies binary unmarshaler
if u, ok := to.Interface().(encoding.BinaryUnmarshaler); ok {
// Use binary unmarshaler to get value
err := u.UnmarshalBinary(val)
if err != nil {
return reflect.Value{}, err
}
// Return unmarshaled value
return reflect.ValueOf(any(u)), nil
}
}
// If input is a map
if in.Kind() == reflect.Map {
// Use mapstructure to decode value
err := mapstructure.Decode(in.Interface(), to.Addr().Interface())
if err == nil {
return to, nil
} else {
return reflect.Value{}, err
}
}
// If input is a slice of any, and output is an array or slice
if in.Type() == reflect.TypeOf([]any{}) &&
(to.Kind() == reflect.Slice || to.Kind() == reflect.Array) {
// Use ConvertSlice to convert value
return reflect.ValueOf(ConvertSlice(
in.Interface().([]any),
toType,
)), nil
}
return to, fmt.Errorf("cannot convert %s to %s", inType, toType)
}
// ConvertSlice converts []any to an array or slice, as provided
// in the "to" argument.
func ConvertSlice(in []any, to reflect.Type) any {
// Create new value for output
out := reflect.New(to).Elem()
// Get type of slice elements
outType := out.Type().Elem()
// If output value is a slice
if out.Kind() == reflect.Slice {
// For every value provided
for i := 0; i < len(in); i++ {
// Get value of input type
inVal := reflect.ValueOf(in[i])
// Create new output type
outVal := reflect.New(outType).Elem()
// If types match
if inVal.Type() == outType {
// Set output value to input value
outVal.Set(inVal)
} else {
newVal, err := Convert(inVal, outType)
if err != nil {
// Set output value to its zero value
outVal.Set(reflect.Zero(outVal.Type()))
} else {
outVal.Set(newVal)
}
}
// Append output value to slice
out = reflect.Append(out, outVal)
}
} else if out.Kind() == reflect.Array && out.Len() == len(in) {
//If output type is array and lengths match
// For every input value
for i := 0; i < len(in); i++ {
// Get matching output index
outVal := out.Index(i)
// Get input value
inVal := reflect.ValueOf(in[i])
// If types match
if inVal.Type() == outVal.Type() {
// Set output value to input value
outVal.Set(inVal)
} else {
newVal, err := Convert(inVal, outType)
if err != nil {
// Set output value to its zero value
outVal.Set(reflect.Zero(outVal.Type()))
} else {
outVal.Set(newVal)
}
}
}
}
// Return created value
return out.Interface()
}

View File

@ -26,7 +26,7 @@ type Request struct {
ID string ID string
Receiver string Receiver string
Method string Method string
Arg []byte Arg any
} }
type ResponseType uint8 type ResponseType uint8
@ -43,5 +43,5 @@ type Response struct {
Type ResponseType Type ResponseType
ID string ID string
Error string Error string
Return []byte Return any
} }

View File

@ -28,6 +28,7 @@ import (
"sync" "sync"
"go.arsenm.dev/lrpc/codec" "go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types" "go.arsenm.dev/lrpc/internal/types"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
@ -40,6 +41,7 @@ var (
ErrNoSuchReceiver = errors.New("no such receiver registered") ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found") ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidMethod = errors.New("method invalid for lrpc call") ErrInvalidMethod = errors.New("method invalid for lrpc call")
ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
ErrArgNotProvided = errors.New("method expected an argument, but none was provided") ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
) )
@ -99,7 +101,7 @@ func (s *Server) Register(v any) error {
} }
// execute runs a method of a registered value // execute runs a method of a registered value
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) { func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
// Try to get value from receivers map // Try to get value from receivers map
val, ok := s.rcvrs[typ] val, ok := s.rcvrs[typ]
if !ok { if !ok {
@ -120,19 +122,29 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, data []b
// Get method type // Get method type
mtdType := mtd.Type() mtdType := mtd.Type()
//TODO: if arg not nil but fn has no arg, err // Return error if argument provided but isn't expected
if mtdType.NumIn() == 1 && arg != nil {
return nil, nil, ErrUnexpectedArgument
}
argType := mtdType.In(1) // IF argument is []any
argVal := reflect.New(argType) anySlice, ok := arg.([]any)
arg := argVal.Interface() if ok {
// Convert slice to the method's arg type and
// set arg to the newly-converted slice
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
}
err = c.Unmarshal(data, arg) // Get argument value
argVal := reflect.ValueOf(arg)
// If argument's type does not match method's argument type
if arg != nil && argVal.Type() != mtdType.In(1) {
val, err = reflectutil.Convert(argVal, mtdType.In(1))
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
arg = val.Interface()
arg = argVal.Elem().Interface() }
ctx = newContext(pCtx, c) ctx = newContext(pCtx, c)
// Get reflect value of context // Get reflect value of context
@ -315,30 +327,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
if err != nil { if err != nil {
s.sendErr(c, call, val, err) s.sendErr(c, call, val, err)
} else { } else {
valData, err := c.Marshal(val)
if err != nil {
s.sendErr(c, call, val, err)
continue
}
// Create response // Create response
res := types.Response{ res := types.Response{
ID: call.ID, ID: call.ID,
Return: valData, Return: val,
} }
// If function has created a channel // If function has created a channel
if ctx.isChannel { if ctx.isChannel {
idData, err := c.Marshal(ctx.channelID)
if err != nil {
s.sendErr(c, call, val, err)
continue
}
// Set IsChannel to true // Set IsChannel to true
res.Type = types.ResponseTypeChannel res.Type = types.ResponseTypeChannel
// Overwrite return value with channel ID // Overwrite return value with channel ID
res.Return = idData res.Return = ctx.channelID
// Store context in map for future use // Store context in map for future use
s.contextsMtx.Lock() s.contextsMtx.Lock()
@ -349,18 +349,11 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
// For every value received from channel // For every value received from channel
for val := range ctx.channel { for val := range ctx.channel {
codecMtx.Lock() codecMtx.Lock()
valData, err := c.Marshal(val)
if err != nil {
continue
}
// Encode response using codec // Encode response using codec
c.Encode(types.Response{ c.Encode(types.Response{
ID: ctx.channelID, ID: ctx.channelID,
Return: valData, Return: val,
}) })
codecMtx.Unlock() codecMtx.Unlock()
} }
@ -390,14 +383,12 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
// sendErr sends an error response // sendErr sends an error response
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) { func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
valData, _ := c.Marshal(val)
// Encode error response using codec // Encode error response using codec
c.Encode(types.Response{ c.Encode(types.Response{
Type: types.ResponseTypeError, Type: types.ResponseTypeError,
ID: req.ID, ID: req.ID,
Error: err.Error(), Error: err.Error(),
Return: valData, Return: val,
}) })
} }