Compare commits

..

36 Commits

Author SHA1 Message Date
7592eae318 Handle requests concurrently 2022-08-06 22:52:58 -07:00
d35a16ec64 Remove now useless internal/reflectutil 2022-08-06 22:49:12 -07:00
e02c8bc5ff Marshal/Unmarshal arguments and return values separately to allow struct tags to take effect for each codec 2022-08-06 22:48:42 -07:00
5e61e89ac1 Add LRPCClient.getObject() 2022-06-03 17:59:58 -07:00
205e0b71e4 Actually skip LRPCChannel.close() this time 2022-06-03 14:36:34 -07:00
1e627b833e Skip LRPCChannel send(), done(), and close() if already closed 2022-06-03 12:33:03 -07:00
368c7333c5 Use callMethod() because call() is reserved 2022-06-02 18:54:00 -07:00
1e8e304f01 Remove unneeded array 2022-06-02 18:35:22 -07:00
acf262b4f0 Exclude delete transformation from functions filter 2022-06-02 18:27:41 -07:00
8843e7faa9 Mention web client in README 2022-06-02 14:14:18 -07:00
328be35ae2 Add web client 2022-06-02 14:09:27 -07:00
6ee3602128 Fix reflectutils.Convert() for maps and []any 2022-06-02 02:09:47 -07:00
e518b68d8c Return error if expected argument not provided 2022-06-02 02:09:07 -07:00
771c8c136e fix time.Ticker leak 2022-05-28 14:58:39 -07:00
c0a1c3bf43 Add channel test to lrpc 2022-05-28 14:52:00 -07:00
eadee97e5e Add tests 2022-05-26 13:32:19 -07:00
a12224c997 Add error for unexpected arguments 2022-05-26 13:01:29 -07:00
fbae725040 Add (*Server).HandleConn() 2022-05-16 15:42:15 -07:00
bc7aa0fe5b Propagate parent value to request context 2022-05-12 17:15:43 -07:00
3bcc01fdb6 Propagate context to requests 2022-05-12 17:13:44 -07:00
af77b121f8 Fix bug where the connection handler tries to access a channel before it has been created 2022-05-10 02:07:35 -07:00
349123fe25 Set module go version to 1.17 2022-05-07 21:41:51 -07:00
a7a2dc3270 Use Convert() for arrays in reflectutil.ConvertSlice() 2022-05-07 15:02:06 -07:00
f609d5a97f Add introspection functions 2022-05-07 14:59:04 -07:00
ff5f211a83 Use type uint8 to replace boolean fields in response 2022-05-07 14:01:10 -07:00
b1e7ded874 Resolve data races using mutex 2022-05-04 16:15:35 -07:00
7ef9e56505 Fix conversion to pointer 2022-05-03 18:54:39 -07:00
1269203c08 Handle pointer types in reflectutil.Convert() 2022-05-03 18:47:49 -07:00
5e99a66007 Use io.ReadWriteCloser instead of net.Conn in client 2022-05-02 16:24:50 -07:00
4d0c9da4d9 Add comments for ServeWS 2022-05-02 14:48:45 -07:00
a0d167ff75 Add WebSocket server support 2022-05-02 14:47:00 -07:00
bbc9774a96 Make *Context compliant with context.Context 2022-05-01 21:39:16 -07:00
f0f9422fef Add GPLv3 copyright headers 2022-05-01 21:30:00 -07:00
1ae94dc4c4 Clarify README 2022-05-01 21:27:14 -07:00
f1aa0f5c4f Fix call to lrpc.ChannelDone 2022-05-01 15:17:46 -07:00
b53388122c Use context to stop sending values rather than trying to detect channel close 2022-05-01 15:13:07 -07:00
18 changed files with 984 additions and 352 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/client/web/lrpc.js
/s
/c

View File

@@ -15,6 +15,12 @@ This RPC framework supports creating channels to transfer data from server to cl
### Codec
When creating a server or client, a `CodecFunc` can be provided. This `CodecFunc` is provided with an `io.ReadWriter` and returns a `Codec`, which is an interface that contains encode and decode functions with the same signature ad `json.Decoder.Decode()` and `json.Encoder.Encode()`.
When creating a server or client, a `CodecFunc` can be provided. An `io.ReadWriter` is passed into the `CodecFunc` and it returns a `Codec`, which is an interface that contains encode and decode functions with the same signature as `json.Decoder.Decode()` and `json.Encoder.Encode()`.
This allows any codec to be used for the transfer of the data, making it easy to create clients in different languages.
This allows any codec to be used for the transfer of the data, making it easy to create clients in different languages.
---
### Web Client
Inside `client/web`, there is a web client for lrpc using WebSockets. It is written in ruby (I don't like JS) and translated to human-readable JS using Ruby2JS. With the `bundler` gem installed, cd into `client/web` and run `make`. This will create a new file called `lrpc.js`, which can be used within a browser. It uses `crypto.randomUUID()`, so it must be used on an https site, not http.

View File

@@ -1,13 +1,31 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package client
import (
"context"
"errors"
"net"
"io"
"reflect"
"sync"
"go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types"
"github.com/gofrs/uuid"
@@ -25,19 +43,20 @@ var (
// Client is an lrpc client
type Client struct {
conn net.Conn
conn io.ReadWriteCloser
codec codec.Codec
chMtx sync.Mutex
chMtx *sync.Mutex
chs map[string]chan *types.Response
}
// New creates and returns a new client
func New(conn net.Conn, cf codec.CodecFunc) *Client {
func New(conn io.ReadWriteCloser, cf codec.CodecFunc) *Client {
out := &Client{
conn: conn,
codec: cf(conn),
chs: map[string]chan *types.Response{},
chMtx: &sync.Mutex{},
}
go out.handleConn()
@@ -46,7 +65,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
}
// Call calls a method on the server
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
// Create new v4 UUOD
id, err := uuid.NewV4()
if err != nil {
@@ -54,24 +73,34 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
}
idStr := id.String()
ctxDoneVal := reflect.ValueOf(ctx.Done())
// Create new channel using the generated ID
c.chMtx.Lock()
c.chs[idStr] = make(chan *types.Response, 1)
c.chMtx.Unlock()
argData, err := c.codec.Marshal(arg)
if err != nil {
return err
}
// Encode request using codec
err = c.codec.Encode(types.Request{
ID: idStr,
Receiver: rcvr,
Method: method,
Arg: arg,
Arg: argData,
})
if err != nil {
return err
}
// Get response from channel
resp := <-c.chs[idStr]
c.chMtx.Lock()
respCh := c.chs[idStr]
c.chMtx.Unlock()
resp := <-respCh
// Close and delete channel
c.chMtx.Lock()
@@ -80,7 +109,7 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
c.chMtx.Unlock()
// If response is an error, return error
if resp.IsError {
if resp.Type == types.ResponseTypeError {
return errors.New(resp.Error)
}
@@ -93,26 +122,31 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
retVal := reflect.ValueOf(ret)
// If response is a channel
if resp.IsChannel {
if resp.Type == types.ResponseTypeChannel {
// If return value is not a channel, return error
if retVal.Kind() != reflect.Chan {
return ErrReturnNotChannel
}
// Get channel ID returned in response
chID := resp.Return.(string)
var chID string
err = c.codec.Unmarshal(resp.Return, &chID)
if resp.Return == nil {
return nil
}
// Create new channel using channel ID
c.chMtx.Lock()
c.chs[chID] = make(chan *types.Response, 5)
if _, ok := c.chs[chID]; !ok {
c.chs[chID] = make(chan *types.Response, 5)
}
c.chMtx.Unlock()
channelClosed := false
go func() {
// Get type of channel elements
chElemType := retVal.Type().Elem()
// For every value received from channel
for val := range c.chs[chID] {
if val.ChannelDone {
if val.Type == types.ResponseTypeChannelDone {
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
@@ -121,68 +155,37 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
// Close return channel
retVal.Close()
channelClosed = true
break
}
// Get reflect value from channel response
rVal := reflect.ValueOf(val.Return)
// If return value is not the same as the channel
if rVal.Type() != chElemType {
// Attempt to convert value, skip if impossible
newVal, err := reflectutil.Convert(rVal, chElemType)
if err != nil {
continue
}
rVal = newVal
outVal := reflect.New(chElemType)
err = c.codec.Unmarshal(val.Return, outVal.Interface())
if err != nil {
continue
}
outVal = outVal.Elem()
// Send value to channel
retVal.Send(rVal)
}
}()
chosen, _, _ := reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
})
if chosen == 1 {
c.Call(context.Background(), "lrpc", "ChannelDone", chID, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
go func() {
for {
val, ok := retVal.Recv()
if !ok && val.IsValid() {
break
retVal.Close()
}
}
if !channelClosed {
c.Call("lrpc", "ChannelDone", id, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
}
}()
} else {
// IF return value is not a pointer, return error
if retVal.Kind() != reflect.Ptr {
return ErrReturnNotPointer
} else if resp.Type == types.ResponseTypeNormal {
err = c.codec.Unmarshal(resp.Return, ret)
if err != nil {
return err
}
// Get return type
retType := retVal.Type().Elem()
// Get refkect value from response
rVal := reflect.ValueOf(resp.Return)
// If types do not match
if rVal.Type() != retType {
// Attempt to convert types, return error if not possible
newVal, err := reflectutil.Convert(rVal, retType)
if err != nil {
return err
}
rVal = newVal
}
// Set return value to received value
retVal.Elem().Set(rVal)
}
return nil
@@ -197,11 +200,15 @@ func (c *Client) handleConn() {
continue
}
// Get channel from map, skip if it doesn't exist
c.chMtx.Lock()
// Attempt to get channel from map
ch, ok := c.chs[resp.ID]
// If channel does not exist, make it
if !ok {
continue
ch = make(chan *types.Response, 5)
c.chs[resp.ID] = ch
}
c.chMtx.Unlock()
// Send response to channel
ch <- resp

3
client/web/Gemfile Normal file
View File

@@ -0,0 +1,3 @@
source 'https://rubygems.org'
gem 'ruby2js'

19
client/web/Gemfile.lock Normal file
View File

@@ -0,0 +1,19 @@
GEM
remote: https://rubygems.org/
specs:
ast (2.4.2)
parser (3.1.2.0)
ast (~> 2.4.1)
regexp_parser (2.1.1)
ruby2js (5.0.1)
parser
regexp_parser (~> 2.1.1)
PLATFORMS
x86_64-linux
DEPENDENCIES
ruby2js
BUNDLED WITH
2.3.15

3
client/web/Makefile Normal file
View File

@@ -0,0 +1,3 @@
lrpc.js: convert.rb lrpc.rb
bundle install
ruby convert.rb > lrpc.js

10
client/web/convert.rb Executable file
View File

@@ -0,0 +1,10 @@
#!/usr/bin/ruby
require 'ruby2js'
require 'ruby2js/filter/functions'
puts Ruby2JS.convert(
File.read('lrpc.rb'),
eslevel: 2016,
exclude: [:delete],
)

174
client/web/lrpc.rb Normal file
View File

@@ -0,0 +1,174 @@
# LRPCResponseType represents the various types an LRPC
# response can have.
LRPCResponseType = {
Normal: 0,
Error: 1,
Channel: 2,
ChannelDone: 3,
}
# LRPCClient represents a client for the LRPC protocol
# using WebSockets and the JSON codec
class LRPCClient
def initialize(addr)
# Set self variables
@callMap = Map.new()
@enc = TextEncoder.new()
@dec = TextDecoder.new()
# Create connection to lrpc server
@conn = WebSocket.new(addr)
@conn.binaryType = "arraybuffer"
@conn.onmessage = proc do |msg|
# if msg.data is string
if msg.data.instance_of? String
# Set json to msg.data
json = msg.data
else
# Set json to decoded msg.data
json = @dec.decode(msg.data)
end
# Parse JSON string
val = JSON.parse(json)
# Get id from callMap
fns = @callMap.get(val.ID)
# If fns is undefined (key does not exist), and this is
# a normal response, return
return if !fns && val.Type == LRPCResponseType.Normal
case val.Type
when LRPCResponseType.Normal
# If fns is a channel, send the value. Otherwise,
# resolve the promise with the value.
if fns.isChannel
fns.send(val.Return)
else
fns.resolve(val.Return)
end
when LRPCResponseType.Channel
# Get channel ID from response
chID = val.Return
# Create new LRPCChannel
ch = LRPCChannel.new(self, chID)
# Set channel in map
@callMap.set(chID, ch)
# Resolve promise with channel
fns.resolve(ch)
when LRPCResponseType.ChannelDone
# Close and delete channel
fns.close()
@callMap.delete(val.ID)
when LRPCResponseType.Error
# Reject promise with error
fns.reject(val.Error)
end
# Delete item from map unless it is a channel
@callMap.delete(val.ID) unless fns.isChannel
end
end
# call calls a method on the server with the given
# argument and returns a promise.
def callMethod(rcvr, method, arg)
return Promise.new do |resolve, reject|
# Get random UUID (this only works with TLS)
id = crypto.randomUUID()
# Add resolve/reject functions to callMap
@callMap.set(id, {
resolve: resolve,
reject: reject,
})
# Encode data as JSON
data = @enc.encode({
Receiver: rcvr,
Method: method,
Arg: arg,
ID: id,
}.to_json())
# Send data to lrpc server
@conn.send(data.buffer)
end
end
# getClient returns an object containing functions
# corresponding to registered functions on the given
# receiver. It uses the lrpc.Introspect() endpoint
# to achieve this.
def getObject(rcvr)
return Promise.new do |resolve|
# Introspect methods on given receiver
self.callMethod("lrpc", "Introspect", rcvr).then do |methodDesc|
# Create output object
out = {}
# For each method in description array
methodDesc.each do |method|
# Create and assign new function to call current method
out[method.Name] = proc { |arg| return self.callMethod(rcvr, method.Name, arg) }
end
# Resolve promise with output promise
resolve(out)
end
end
end
end
# LRPCChannel represents a channel used for lrpc.
class LRPCChannel
def initialize(client, id)
# Set self variables
@client = client
@id = id
@closed = false
# Set function variables to no-ops
@onMessage = proc {|fn|}
@onClose = proc {}
end
# isChannel is defined to allow identifying whether
# an object is a channel.
def isChannel() end
# send sends a value on the channel. This should not
# be called by the consumer of the channel.
def send(val)
return if @closed
fn = @onMessage
fn(val)
end
# done cancels the context corresponding to the channel
# on the server side and closes the channel.
def done()
return if @closed
@client.callMethod("lrpc", "ChannelDone", @id)
self.close()
@client._callMap.delete(@id)
end
# onMessage sets the callback to be called whenever a
# message is received. The function should have one parameter
# that will be set to the value received. Subsequent calls
# will overwrite the callback
def onMessage(fn)
@onMessage = fn
end
# onClose sets the callback to be called whenever the client
# is closed. The function should have no parameters.
# Subsequent calls will overwrite the callback
def onClose(fn)
@onClose = fn
end
# close closes the channel. This should not be called by the
# consumer of the channel. Use done() instead.
def close()
return if @closed
fn = @onClose
fn()
@closed = true
end
end

View File

@@ -1,6 +1,25 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package codec
import (
"bytes"
"encoding/gob"
"encoding/json"
"io"
@@ -8,6 +27,9 @@ import (
"github.com/vmihailenco/msgpack/v5"
)
// <= go1.17 compatibility
type any = interface{}
// CodecFunc is a function that returns a new Codec
// bound to the given io.ReadWriter
type CodecFunc func(io.ReadWriter) Codec
@@ -17,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec
type Codec interface {
Encode(val any) error
Decode(val any) error
Unmarshal(data []byte, v any) error
Marshal(v any) ([]byte, error)
}
// Default is the default CodecFunc
var Default = Msgpack
type JsonCodec struct {
*json.Encoder
*json.Decoder
}
func (JsonCodec) Unmarshal(data []byte, v any) error {
return json.Unmarshal(data, v)
}
func (JsonCodec) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}
// JSON is a CodecFunc that creates a JSON Codec
func JSON(rw io.ReadWriter) Codec {
type jsonCodec struct {
*json.Encoder
*json.Decoder
}
return jsonCodec{
return JsonCodec{
Encoder: json.NewEncoder(rw),
Decoder: json.NewDecoder(rw),
}
}
type MsgpackCodec struct {
*msgpack.Encoder
*msgpack.Decoder
}
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
return msgpack.Unmarshal(data, v)
}
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
return msgpack.Marshal(v)
}
// Msgpack is a CodecFunc that creates a Msgpack Codec
func Msgpack(rw io.ReadWriter) Codec {
type msgpackCodec struct {
*msgpack.Encoder
*msgpack.Decoder
}
return msgpackCodec{
return MsgpackCodec{
Encoder: msgpack.NewEncoder(rw),
Decoder: msgpack.NewDecoder(rw),
}
}
type GobCodec struct {
*gob.Encoder
*gob.Decoder
}
func (GobCodec) Unmarshal(data []byte, v any) error {
return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
}
func (GobCodec) Marshal(v any) ([]byte, error) {
buf := &bytes.Buffer{}
err := gob.NewEncoder(buf).Encode(v)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Gob is a CodecFunc that creates a Gob Codec
func Gob(rw io.ReadWriter) Codec {
type gobCodec struct {
*gob.Encoder
*gob.Decoder
}
return gobCodec{
return GobCodec{
Encoder: gob.NewEncoder(rw),
Decoder: gob.NewDecoder(rw),
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/gob"
"fmt"
"net"
@@ -12,21 +13,23 @@ import (
func main() {
gob.Register([2]int{})
ctx := context.Background()
conn, _ := net.Dial("tcp", "localhost:9090")
c := client.New(conn, codec.Gob)
defer c.Close()
var add int
c.Call("Arith", "Add", [2]int{5, 5}, &add)
c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
var sub int
c.Call("Arith", "Sub", [2]int{5, 5}, &sub)
c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
var mul int
c.Call("Arith", "Mul", [2]int{5, 5}, &mul)
c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
var div int
c.Call("Arith", "Div", [2]int{5, 5}, &div)
c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
fmt.Printf(
"add: %d, sub: %d, mul: %d, div: %d\n",

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/gob"
"net"
@@ -33,5 +34,5 @@ func main() {
s.Register(Arith{})
ln, _ := net.Listen("tcp", ":9090")
s.Serve(ln, codec.Gob)
s.Serve(context.Background(), ln, codec.Gob)
}

3
go.mod
View File

@@ -1,11 +1,12 @@
module go.arsenm.dev/lrpc
go 1.18
go 1.17
require (
github.com/gofrs/uuid v4.2.0+incompatible
github.com/mitchellh/mapstructure v1.5.0
github.com/vmihailenco/msgpack/v5 v5.3.5
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4
)
require github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect

2
go.sum
View File

@@ -13,6 +13,8 @@ github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA=
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,148 +0,0 @@
package reflectutil
import (
"encoding"
"fmt"
"reflect"
"github.com/mitchellh/mapstructure"
)
// Convert attempts to convert the given value to the given type
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
// Get input type
inType := in.Type()
// If input is already the desired type, return
if inType == toType {
return in, nil
}
// If input can be converted to desired type, convert and return
if in.CanConvert(toType) {
return in.Convert(toType), nil
}
// Create new value of desired type
to := reflect.New(toType).Elem()
// If type is a pointer
if to.Kind() == reflect.Ptr {
// Initialize value
to.Set(reflect.New(to.Type().Elem()))
}
switch val := in.Interface().(type) {
case string:
// If desired type satisfies text unmarshaler
if u, ok := to.Interface().(encoding.TextUnmarshaler); ok {
// Use text unmarshaler to get value
err := u.UnmarshalText([]byte(val))
if err != nil {
return reflect.Value{}, err
}
// Return unmarshaled value
return reflect.ValueOf(any(u)), nil
}
case []byte:
// If desired type satisfies binary unmarshaler
if u, ok := to.Interface().(encoding.BinaryUnmarshaler); ok {
// Use binary unmarshaler to get value
err := u.UnmarshalBinary(val)
if err != nil {
return reflect.Value{}, err
}
// Return unmarshaled value
return reflect.ValueOf(any(u)), nil
}
}
// If input is a map
if in.Kind() == reflect.Map {
// Use mapstructure to decode value
err := mapstructure.Decode(in.Interface(), to.Addr().Interface())
if err == nil {
return to, nil
}
}
// If input is a slice of any, and output is an array or slice
if in.Type() == reflect.TypeOf([]any{}) &&
to.Kind() == reflect.Slice || to.Kind() == reflect.Array {
// Use ConvertSlice to convert value
return reflect.ValueOf(ConvertSlice(
in.Interface().([]any),
toType,
)), nil
}
return to, fmt.Errorf("cannot convert %s to %s", inType, toType)
}
// ConvertSlice converts []any to an array or slice, as provided
// in the "to" argument.
func ConvertSlice(in []any, to reflect.Type) any {
// Create new value for output
out := reflect.New(to).Elem()
// If output value is a slice
if out.Kind() == reflect.Slice {
// Get type of slice elements
outType := out.Type().Elem()
// For every value provided
for i := 0; i < len(in); i++ {
// Get value of input type
inVal := reflect.ValueOf(in[i])
// Create new output type
outVal := reflect.New(outType).Elem()
// If types match
if inVal.Type() == outType {
// Set output value to input value
outVal.Set(inVal)
} else {
newVal, err := Convert(inVal, outType)
if err != nil {
// Set output value to its zero value
outVal.Set(reflect.Zero(outVal.Type()))
} else {
outVal.Set(newVal)
}
}
// Append output value to slice
out = reflect.Append(out, outVal)
}
} else if out.Kind() == reflect.Array && out.Len() == len(in) {
//If output type is array and lengths match
// For every input value
for i := 0; i < len(in); i++ {
// Get matching output index
outVal := out.Index(i)
// Get input value
inVal := reflect.ValueOf(in[i])
// If types match
if inVal.Type() == outVal.Type() {
// Set output value to input value
outVal.Set(inVal)
} else {
// If input value can be converted to output type
if inVal.CanConvert(outVal.Type()) {
// Convert and set output value to input value
outVal.Set(inVal.Convert(outVal.Type()))
} else {
// Set output value to its zero value
outVal.Set(reflect.Zero(outVal.Type()))
}
}
}
}
// Return created value
return out.Interface()
}

View File

@@ -1,19 +1,47 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package types
// <= go1.17 compatibility
type any = interface{}
// Request represents a request sent to the server
type Request struct {
ID string
Receiver string
Method string
Arg any
Arg []byte
}
type ResponseType uint8
const (
ResponseTypeNormal ResponseType = iota
ResponseTypeError
ResponseTypeChannel
ResponseTypeChannelDone
)
// Response represents a response returned by the server
type Response struct {
ID string
ChannelDone bool
IsChannel bool
IsError bool
Error string
Return any
Type ResponseType
ID string
Error string
Return []byte
}

212
lrpc_test.go Normal file
View File

@@ -0,0 +1,212 @@
package lrpc_test
import (
"context"
"encoding/gob"
"net"
"testing"
"time"
"go.arsenm.dev/lrpc/client"
"go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/server"
)
type Arith struct{}
func (Arith) Add(ctx *server.Context, in [2]int) int {
return in[0] + in[1]
}
func (Arith) Mul(ctx *server.Context, in [2]int) int {
return in[0] * in[1]
}
func (Arith) Div(ctx *server.Context, in [2]int) int {
return in[0] / in[1]
}
func (Arith) Sub(ctx *server.Context, in [2]int) int {
return in[0] - in[1]
}
func TestCalls(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create new network pipe
sConn, cConn := net.Pipe()
s := server.New()
defer s.Close()
// Register Arith for RPC
s.Register(Arith{})
// Serve the pipe connection using default codec
go s.ServeConn(ctx, sConn, codec.Default)
// Create new client using default codec
c := client.New(cConn, codec.Default)
defer c.Close()
// Call Arith.Add()
var add int
err := c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
if err != nil {
t.Error(err)
}
// Call Arith.Sub()
var sub int
err = c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
if err != nil {
t.Error(err)
}
// Call Arith.Mul()
var mul int
err = c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
if err != nil {
t.Error(err)
}
// Call Arith.Div()
var div int
err = c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
if err != nil {
t.Error(err)
}
if add != 10 {
t.Errorf("add: expected 10, got %d", add)
}
if sub != 0 {
t.Errorf("sub: expected 0, got %d", sub)
}
if mul != 25 {
t.Errorf("mul: expected 25, got %d", mul)
}
if div != 1 {
t.Errorf("div: expected 1, got %d", div)
}
}
func TestCodecs(t *testing.T) {
// Register the 2-integer array for gob
gob.Register([2]int{})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create function to test each codec
testCodec := func(cf codec.CodecFunc, name string) {
// Create network pipe
sConn, cConn := net.Pipe()
s := server.New()
defer s.Close()
// Register Arith for RPC
s.Register(Arith{})
// Serve the pipe connection using provided codec
go s.ServeConn(ctx, sConn, cf)
// Create new client using provided codec
c := client.New(cConn, cf)
defer c.Close()
// Call Arith.Add()
var add int
err := c.Call(ctx, "Arith", "Add", [2]int{2, 2}, &add)
if err != nil {
t.Errorf("codec/%s: %v", name, err)
}
if add != 4 {
t.Errorf("codec/%s: add: expected 4, got %d", name, add)
}
}
// Test all codecs
testCodec(codec.Msgpack, "msgpack")
testCodec(codec.JSON, "json")
testCodec(codec.Gob, "gob")
}
type Channel struct{}
func (Channel) Time(ctx *server.Context, interval time.Duration) error {
ch, err := ctx.MakeChannel()
if err != nil {
return err
}
tick := time.NewTicker(interval)
go func() {
for {
select {
case t := <-tick.C:
ch <- t
case <-ctx.Done():
tick.Stop()
close(ch)
return
}
}
}()
return nil
}
func TestChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create new network pipe
sConn, cConn := net.Pipe()
s := server.New()
defer s.Close()
// Register Arith for RPC
s.Register(Channel{})
// Serve the pipe connection using default codec
go s.ServeConn(ctx, sConn, codec.Default)
// Create new client using default codec
c := client.New(cConn, codec.Default)
defer c.Close()
timeCtx, timeCancel := context.WithCancel(ctx)
defer timeCancel()
timeCh := make(chan *time.Time, 2)
err := c.Call(timeCtx, "Channel", "Time", time.Millisecond, timeCh)
if err != nil {
t.Error(err)
}
var loops int
var lastTime *time.Time
for curTime := range timeCh {
if loops > 3 {
timeCancel()
break
}
if lastTime == nil {
lastTime = curTime
continue
}
diff := curTime.Sub(*lastTime)
diff = diff.Round(time.Millisecond)
if diff != time.Millisecond {
t.Fatalf("expected 1s diff, got %s", diff)
}
lastTime = curTime
loops++
}
}

View File

@@ -1,6 +1,27 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"time"
"go.arsenm.dev/lrpc/codec"
"github.com/gofrs/uuid"
@@ -14,7 +35,26 @@ type Context struct {
codec codec.Codec
doneCh chan struct{}
doneCh chan struct{}
canceled bool
ctx context.Context
}
func newContext(ctx context.Context, codec codec.Codec) *Context {
out := &Context{
doneCh: make(chan struct{}),
codec: codec,
ctx: ctx,
}
if ctx == nil {
out.ctx = context.Background()
}
go func() {
<-out.ctx.Done()
out.cancel()
}()
return out
}
// MakeChannel changes the function it's called in into a
@@ -37,6 +77,26 @@ func (ctx *Context) GetCodec() codec.Codec {
return ctx.codec
}
// Deadline always returns the current time and false
// as this context does not support deadlines
func (ctx *Context) Deadline() (time.Time, bool) {
return time.Now(), false
}
// Value always returns nil as this context stores no values
func (ctx *Context) Value(key any) any {
return ctx.ctx.Value(key)
}
// Err returns context.Canceled if the context was canceled,
// otherwise nil
func (ctx *Context) Err() error {
if ctx.canceled {
return context.Canceled
}
return nil
}
// Done returns a channel that will be closed when
// the context is canceled, such as when ChannelDone
// is called by the client
@@ -45,6 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
}
// Cancel cancels the context
func (ctx *Context) Cancel() {
func (ctx *Context) cancel() {
if ctx.canceled {
return
}
ctx.canceled = true
close(ctx.doneCh)
}

View File

@@ -1,28 +1,46 @@
/*
* lrpc allows for clients to call functions on a server remotely.
* Copyright (C) 2022 Arsen Musayelyan
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package server
import (
"context"
"errors"
"io"
"net"
"net/http"
"reflect"
"sync"
"go.arsenm.dev/lrpc/codec"
"go.arsenm.dev/lrpc/internal/reflectutil"
"go.arsenm.dev/lrpc/internal/types"
"golang.org/x/net/websocket"
)
// <= go1.17 compatibility
type any = interface{}
var (
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrTooManyInputs = errors.New("method may not have more than two inputs")
ErrTooManyOutputs = errors.New("method may not have more than two return values")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidSecondReturn = errors.New("second return value must be error")
ErrInvalidFirstInput = errors.New("first input must be *Context")
ErrInvalidType = errors.New("type must be struct or pointer to struct")
ErrNoSuchReceiver = errors.New("no such receiver registered")
ErrNoSuchMethod = errors.New("no such method was found")
ErrInvalidMethod = errors.New("method invalid for lrpc call")
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
)
// Server is an lrpc server
@@ -50,7 +68,7 @@ func New() *Server {
// Close closes the server
func (s *Server) Close() {
for _, ctx := range s.contexts {
ctx.Cancel()
ctx.cancel()
}
}
@@ -81,7 +99,7 @@ func (s *Server) Register(v any) error {
}
// execute runs a method of a registered value
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
// Try to get value from receivers map
val, ok := s.rcvrs[typ]
if !ok {
@@ -94,55 +112,37 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
return nil, nil, ErrNoSuchMethod
}
// Get method's type
// If method invalid, return error
if !mtdValid(mtd) {
return nil, nil, ErrInvalidMethod
}
// Get method type
mtdType := mtd.Type()
if mtdType.NumIn() > 2 {
return nil, nil, ErrTooManyInputs
} else if mtdType.NumIn() < 1 {
return nil, nil, ErrInvalidFirstInput
}
if mtdType.NumOut() > 2 {
return nil, nil, ErrTooManyOutputs
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf(&Context{}) {
return nil, nil, ErrInvalidFirstInput
}
//TODO: if arg not nil but fn has no arg, err
// IF argument is []any
anySlice, ok := arg.([]any)
if ok {
// Convert slice to the method's arg type and
// set arg to the newly-converted slice
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
argType := mtdType.In(1)
argVal := reflect.New(argType)
arg := argVal.Interface()
err = c.Unmarshal(data, arg)
if err != nil {
return nil, nil, err
}
// Get argument value
argVal := reflect.ValueOf(arg)
// If argument's type does not match method's argument type
if arg != nil && argVal.Type() != mtdType.In(1) {
val, err = reflectutil.Convert(argVal, mtdType.In(1))
if err != nil {
return nil, nil, err
}
arg = val.Interface()
}
arg = argVal.Elem().Interface()
// Create new context
ctx = &Context{
doneCh: make(chan struct{}, 1),
codec: c,
}
ctx = newContext(pCtx, c)
// Get reflect value of context
ctxVal := reflect.ValueOf(ctx)
switch mtdType.NumOut() {
case 0: // If method has no return values
if mtdType.NumIn() == 2 {
if arg == nil {
return nil, nil, ErrArgNotProvided
}
// Call method with arg, ignore returned value
mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
} else {
@@ -151,6 +151,10 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
}
case 1: // If method has one return value
if mtdType.NumIn() == 2 {
if arg == nil {
return nil, nil, ErrArgNotProvided
}
// Call method with arg, get returned values
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
@@ -185,6 +189,10 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
}
case 2: // If method has two return values
if mtdType.NumIn() == 2 {
if arg == nil {
return nil, nil, ErrArgNotProvided
}
// Call method with arg and get returned values
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
@@ -195,7 +203,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid
if !ok {
a, err = nil, ErrInvalidSecondReturn
a, err = nil, ErrInvalidMethod
}
}
@@ -211,7 +219,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// If second return value is not an error, the function is invalid
err, ok = out1.(error)
if !ok {
a, err = nil, ErrInvalidSecondReturn
a, err = nil, ErrInvalidMethod
}
}
@@ -224,22 +232,66 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
// Serve starts the server using the provided listener
// and codec function
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
go func() {
<-ctx.Done()
ln.Close()
}()
for {
conn, err := ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
break
} else if err != nil {
continue
}
// Create new instance of codec bound to conn
c := cf(conn)
// Handle connection
go s.handleConn(c)
go s.handleConn(ctx, c)
}
}
// handleConn handles a listener connection
func (s *Server) handleConn(c codec.Codec) {
// ServeWS starts a server using WebSocket. This may be useful for
// clients written in other languages, such as JS for a browser.
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
// Create new WebSocket server
ws := websocket.Server{}
// Create new WebSocket config
ws.Config = websocket.Config{
Version: websocket.ProtocolVersionHybi13,
}
// Set server handler
ws.Handler = func(c *websocket.Conn) {
s.handleConn(c.Request().Context(), cf(c))
}
server := &http.Server{
Addr: addr,
BaseContext: func(net.Listener) context.Context {
return ctx
},
Handler: http.HandlerFunc(ws.ServeHTTP),
}
// Listen and serve on given address
return server.ListenAndServe()
}
// ServeConn uses the provided connection to serve the client.
// This may be useful if something other than a net.Listener
// needs to be used
func (s *Server) ServeConn(ctx context.Context, conn io.ReadWriter, cf codec.CodecFunc) {
s.handleConn(ctx, cf(conn))
}
// handleConn handles a connection
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
codecMtx := &sync.Mutex{}
for {
var call types.Request
// Read request using codec
@@ -251,72 +303,103 @@ func (s *Server) handleConn(c codec.Codec) {
continue
}
// Execute decoded call
val, ctx, err := s.execute(
call.Receiver,
call.Method,
call.Arg,
c,
)
if err != nil {
s.sendErr(c, call, val, err)
} else {
// Create response
res := types.Response{
ID: call.ID,
Return: val,
}
go func() {
// Execute decoded call
val, ctx, err := s.execute(
pCtx,
call.Receiver,
call.Method,
call.Arg,
c,
)
if err != nil {
s.sendErr(c, call, val, err)
} else {
valData, err := c.Marshal(val)
if err != nil {
s.sendErr(c, call, val, err)
return
}
// If function has created a channel
if ctx.isChannel {
// Set IsChannel to true
res.IsChannel = true
// Overwrite return value with channel ID
res.Return = ctx.channelID
// Create response
res := types.Response{
ID: call.ID,
Return: valData,
}
// Store context in map for future use
s.contextsMtx.Lock()
s.contexts[ctx.channelID] = ctx
s.contextsMtx.Unlock()
go func() {
// For every value received from channel
for val := range ctx.channel {
// Encode response using codec
c.Encode(types.Response{
ID: ctx.channelID,
Return: val,
})
// If function has created a channel
if ctx.isChannel {
idData, err := c.Marshal(ctx.channelID)
if err != nil {
s.sendErr(c, call, val, err)
return
}
// Cancel context
ctx.Cancel()
// Delete context from map
// Set IsChannel to true
res.Type = types.ResponseTypeChannel
// Overwrite return value with channel ID
res.Return = idData
// Store context in map for future use
s.contextsMtx.Lock()
delete(s.contexts, ctx.channelID)
s.contexts[ctx.channelID] = ctx
s.contextsMtx.Unlock()
c.Encode(types.Response{
ID: ctx.channelID,
ChannelDone: true,
})
}()
go func() {
// For every value received from channel
for val := range ctx.channel {
codecMtx.Lock()
valData, err := c.Marshal(val)
if err != nil {
continue
}
// Encode response using codec
c.Encode(types.Response{
ID: ctx.channelID,
Return: valData,
})
codecMtx.Unlock()
}
// Cancel context
ctx.cancel()
// Delete context from map
s.contextsMtx.Lock()
delete(s.contexts, ctx.channelID)
s.contextsMtx.Unlock()
codecMtx.Lock()
c.Encode(types.Response{
Type: types.ResponseTypeChannelDone,
ID: ctx.channelID,
})
codecMtx.Unlock()
}()
}
// Encode response using codec
codecMtx.Lock()
c.Encode(res)
codecMtx.Unlock()
}
// Encode response using codec
c.Encode(res)
}
}()
}
}
// sendErr sends an error response
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
valData, _ := c.Marshal(val)
// Encode error response using codec
c.Encode(types.Response{
ID: req.ID,
IsError: true,
Error: err.Error(),
Return: val,
Type: types.ResponseTypeError,
ID: req.ID,
Error: err.Error(),
Return: valData,
})
}
@@ -334,9 +417,114 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
}
// Cancel context
ctx.Cancel()
ctx.cancel()
// Delete context from map
l.srv.contextsMtx.Lock()
delete(l.srv.contexts, id)
l.srv.contextsMtx.Unlock()
}
// MethodDesc describes methods on a receiver
type MethodDesc struct {
Name string
Args []string
Returns []string
}
// Introspect returns method descriptions for the given receiver
func (l lrpc) Introspect(_ *Context, name string) ([]MethodDesc, error) {
// Attempt to get receiver
rcvr, ok := l.srv.rcvrs[name]
if !ok {
return nil, ErrNoSuchReceiver
}
// Get receiver type value
rcvrType := rcvr.Type()
// Create slice for output
var out []MethodDesc
// For every method on receiver
for i := 0; i < rcvr.NumMethod(); i++ {
// Get receiver method
mtd := rcvr.Method(i)
// If invalid, skip
if !mtdValid(mtd) {
continue
}
// Get method type
mtdType := mtd.Type()
// Get amount of arguments
numIn := mtdType.NumIn()
args := make([]string, numIn-1)
// For every argument, store type in slice
// Skip first argument, as it is *Context
for i := 1; i < numIn; i++ {
args[i-1] = mtdType.In(i).String()
}
// Get amount of returns
numOut := mtdType.NumOut()
returns := make([]string, numOut)
// For every return, store type in slice
for i := 0; i < numOut; i++ {
returns[i] = mtdType.Out(i).String()
}
out = append(out, MethodDesc{
Name: rcvrType.Method(i).Name,
Args: args,
Returns: returns,
})
}
return out, nil
}
// IntrospectAll runs Introspect on all registered receivers and returns all results
func (l lrpc) IntrospectAll(_ *Context) (map[string][]MethodDesc, error) {
// Create map for output
out := make(map[string][]MethodDesc, len(l.srv.rcvrs))
// For every registered receiver
for name := range l.srv.rcvrs {
// Introspect receiver
descs, err := l.Introspect(nil, name)
if err != nil {
return nil, err
}
// Set results in map
out[name] = descs
}
return out, nil
}
func mtdValid(mtd reflect.Value) bool {
// Get method's type
mtdType := mtd.Type()
// If method has more than 2 or less than 1 input, it is invalid
if mtdType.NumIn() > 2 || mtdType.NumIn() < 1 {
return false
}
// If method has more than 2 outputs, it is invalid
if mtdType.NumOut() > 2 {
return false
}
// Check to ensure first parameter is context
if mtdType.In(0) != reflect.TypeOf((*Context)(nil)) {
return false
}
// If method has 2 outputs
if mtdType.NumOut() == 2 {
// Check to ensure the second one is an error
if mtdType.Out(1).Name() != "error" {
return false
}
}
return true
}