Compare commits
36 Commits
6df8cf53c6
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 7592eae318 | |||
| d35a16ec64 | |||
| e02c8bc5ff | |||
| 5e61e89ac1 | |||
| 205e0b71e4 | |||
| 1e627b833e | |||
| 368c7333c5 | |||
| 1e8e304f01 | |||
| acf262b4f0 | |||
| 8843e7faa9 | |||
| 328be35ae2 | |||
| 6ee3602128 | |||
| e518b68d8c | |||
| 771c8c136e | |||
| c0a1c3bf43 | |||
| eadee97e5e | |||
| a12224c997 | |||
| fbae725040 | |||
| bc7aa0fe5b | |||
| 3bcc01fdb6 | |||
| af77b121f8 | |||
| 349123fe25 | |||
| a7a2dc3270 | |||
| f609d5a97f | |||
| ff5f211a83 | |||
| b1e7ded874 | |||
| 7ef9e56505 | |||
| 1269203c08 | |||
| 5e99a66007 | |||
| 4d0c9da4d9 | |||
| a0d167ff75 | |||
| bbc9774a96 | |||
| f0f9422fef | |||
| 1ae94dc4c4 | |||
| f1aa0f5c4f | |||
| b53388122c |
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
/client/web/lrpc.js
|
||||
/s
|
||||
/c
|
||||
10
README.md
10
README.md
@@ -15,6 +15,12 @@ This RPC framework supports creating channels to transfer data from server to cl
|
||||
|
||||
### Codec
|
||||
|
||||
When creating a server or client, a `CodecFunc` can be provided. This `CodecFunc` is provided with an `io.ReadWriter` and returns a `Codec`, which is an interface that contains encode and decode functions with the same signature ad `json.Decoder.Decode()` and `json.Encoder.Encode()`.
|
||||
When creating a server or client, a `CodecFunc` can be provided. An `io.ReadWriter` is passed into the `CodecFunc` and it returns a `Codec`, which is an interface that contains encode and decode functions with the same signature as `json.Decoder.Decode()` and `json.Encoder.Encode()`.
|
||||
|
||||
This allows any codec to be used for the transfer of the data, making it easy to create clients in different languages.
|
||||
This allows any codec to be used for the transfer of the data, making it easy to create clients in different languages.
|
||||
|
||||
---
|
||||
|
||||
### Web Client
|
||||
|
||||
Inside `client/web`, there is a web client for lrpc using WebSockets. It is written in ruby (I don't like JS) and translated to human-readable JS using Ruby2JS. With the `bundler` gem installed, cd into `client/web` and run `make`. This will create a new file called `lrpc.js`, which can be used within a browser. It uses `crypto.randomUUID()`, so it must be used on an https site, not http.
|
||||
@@ -1,13 +1,31 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"io"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
@@ -25,19 +43,20 @@ var (
|
||||
|
||||
// Client is an lrpc client
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
conn io.ReadWriteCloser
|
||||
codec codec.Codec
|
||||
|
||||
chMtx sync.Mutex
|
||||
chMtx *sync.Mutex
|
||||
chs map[string]chan *types.Response
|
||||
}
|
||||
|
||||
// New creates and returns a new client
|
||||
func New(conn net.Conn, cf codec.CodecFunc) *Client {
|
||||
func New(conn io.ReadWriteCloser, cf codec.CodecFunc) *Client {
|
||||
out := &Client{
|
||||
conn: conn,
|
||||
codec: cf(conn),
|
||||
chs: map[string]chan *types.Response{},
|
||||
chMtx: &sync.Mutex{},
|
||||
}
|
||||
|
||||
go out.handleConn()
|
||||
@@ -46,7 +65,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
|
||||
}
|
||||
|
||||
// Call calls a method on the server
|
||||
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
|
||||
func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
|
||||
// Create new v4 UUOD
|
||||
id, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
@@ -54,24 +73,34 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
||||
}
|
||||
idStr := id.String()
|
||||
|
||||
ctxDoneVal := reflect.ValueOf(ctx.Done())
|
||||
|
||||
// Create new channel using the generated ID
|
||||
c.chMtx.Lock()
|
||||
c.chs[idStr] = make(chan *types.Response, 1)
|
||||
c.chMtx.Unlock()
|
||||
|
||||
argData, err := c.codec.Marshal(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encode request using codec
|
||||
err = c.codec.Encode(types.Request{
|
||||
ID: idStr,
|
||||
Receiver: rcvr,
|
||||
Method: method,
|
||||
Arg: arg,
|
||||
Arg: argData,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get response from channel
|
||||
resp := <-c.chs[idStr]
|
||||
c.chMtx.Lock()
|
||||
respCh := c.chs[idStr]
|
||||
c.chMtx.Unlock()
|
||||
resp := <-respCh
|
||||
|
||||
// Close and delete channel
|
||||
c.chMtx.Lock()
|
||||
@@ -80,7 +109,7 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
||||
c.chMtx.Unlock()
|
||||
|
||||
// If response is an error, return error
|
||||
if resp.IsError {
|
||||
if resp.Type == types.ResponseTypeError {
|
||||
return errors.New(resp.Error)
|
||||
}
|
||||
|
||||
@@ -93,26 +122,31 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
||||
retVal := reflect.ValueOf(ret)
|
||||
|
||||
// If response is a channel
|
||||
if resp.IsChannel {
|
||||
if resp.Type == types.ResponseTypeChannel {
|
||||
// If return value is not a channel, return error
|
||||
if retVal.Kind() != reflect.Chan {
|
||||
return ErrReturnNotChannel
|
||||
}
|
||||
// Get channel ID returned in response
|
||||
chID := resp.Return.(string)
|
||||
var chID string
|
||||
err = c.codec.Unmarshal(resp.Return, &chID)
|
||||
if resp.Return == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create new channel using channel ID
|
||||
c.chMtx.Lock()
|
||||
c.chs[chID] = make(chan *types.Response, 5)
|
||||
if _, ok := c.chs[chID]; !ok {
|
||||
c.chs[chID] = make(chan *types.Response, 5)
|
||||
}
|
||||
c.chMtx.Unlock()
|
||||
|
||||
channelClosed := false
|
||||
go func() {
|
||||
// Get type of channel elements
|
||||
chElemType := retVal.Type().Elem()
|
||||
// For every value received from channel
|
||||
for val := range c.chs[chID] {
|
||||
if val.ChannelDone {
|
||||
if val.Type == types.ResponseTypeChannelDone {
|
||||
// Close and delete channel
|
||||
c.chMtx.Lock()
|
||||
close(c.chs[chID])
|
||||
@@ -121,68 +155,37 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
||||
|
||||
// Close return channel
|
||||
retVal.Close()
|
||||
|
||||
channelClosed = true
|
||||
|
||||
break
|
||||
}
|
||||
// Get reflect value from channel response
|
||||
rVal := reflect.ValueOf(val.Return)
|
||||
|
||||
// If return value is not the same as the channel
|
||||
if rVal.Type() != chElemType {
|
||||
// Attempt to convert value, skip if impossible
|
||||
newVal, err := reflectutil.Convert(rVal, chElemType)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
rVal = newVal
|
||||
outVal := reflect.New(chElemType)
|
||||
err = c.codec.Unmarshal(val.Return, outVal.Interface())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
outVal = outVal.Elem()
|
||||
|
||||
// Send value to channel
|
||||
retVal.Send(rVal)
|
||||
}
|
||||
}()
|
||||
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
||||
{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
|
||||
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
||||
})
|
||||
if chosen == 1 {
|
||||
c.Call(context.Background(), "lrpc", "ChannelDone", chID, nil)
|
||||
// Close and delete channel
|
||||
c.chMtx.Lock()
|
||||
close(c.chs[chID])
|
||||
delete(c.chs, chID)
|
||||
c.chMtx.Unlock()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
val, ok := retVal.Recv()
|
||||
if !ok && val.IsValid() {
|
||||
break
|
||||
retVal.Close()
|
||||
}
|
||||
}
|
||||
if !channelClosed {
|
||||
c.Call("lrpc", "ChannelDone", id, nil)
|
||||
// Close and delete channel
|
||||
c.chMtx.Lock()
|
||||
close(c.chs[chID])
|
||||
delete(c.chs, chID)
|
||||
c.chMtx.Unlock()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// IF return value is not a pointer, return error
|
||||
if retVal.Kind() != reflect.Ptr {
|
||||
return ErrReturnNotPointer
|
||||
} else if resp.Type == types.ResponseTypeNormal {
|
||||
err = c.codec.Unmarshal(resp.Return, ret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get return type
|
||||
retType := retVal.Type().Elem()
|
||||
// Get refkect value from response
|
||||
rVal := reflect.ValueOf(resp.Return)
|
||||
|
||||
// If types do not match
|
||||
if rVal.Type() != retType {
|
||||
// Attempt to convert types, return error if not possible
|
||||
newVal, err := reflectutil.Convert(rVal, retType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rVal = newVal
|
||||
}
|
||||
|
||||
// Set return value to received value
|
||||
retVal.Elem().Set(rVal)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -197,11 +200,15 @@ func (c *Client) handleConn() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get channel from map, skip if it doesn't exist
|
||||
c.chMtx.Lock()
|
||||
// Attempt to get channel from map
|
||||
ch, ok := c.chs[resp.ID]
|
||||
// If channel does not exist, make it
|
||||
if !ok {
|
||||
continue
|
||||
ch = make(chan *types.Response, 5)
|
||||
c.chs[resp.ID] = ch
|
||||
}
|
||||
c.chMtx.Unlock()
|
||||
|
||||
// Send response to channel
|
||||
ch <- resp
|
||||
3
client/web/Gemfile
Normal file
3
client/web/Gemfile
Normal file
@@ -0,0 +1,3 @@
|
||||
source 'https://rubygems.org'
|
||||
|
||||
gem 'ruby2js'
|
||||
19
client/web/Gemfile.lock
Normal file
19
client/web/Gemfile.lock
Normal file
@@ -0,0 +1,19 @@
|
||||
GEM
|
||||
remote: https://rubygems.org/
|
||||
specs:
|
||||
ast (2.4.2)
|
||||
parser (3.1.2.0)
|
||||
ast (~> 2.4.1)
|
||||
regexp_parser (2.1.1)
|
||||
ruby2js (5.0.1)
|
||||
parser
|
||||
regexp_parser (~> 2.1.1)
|
||||
|
||||
PLATFORMS
|
||||
x86_64-linux
|
||||
|
||||
DEPENDENCIES
|
||||
ruby2js
|
||||
|
||||
BUNDLED WITH
|
||||
2.3.15
|
||||
3
client/web/Makefile
Normal file
3
client/web/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
lrpc.js: convert.rb lrpc.rb
|
||||
bundle install
|
||||
ruby convert.rb > lrpc.js
|
||||
10
client/web/convert.rb
Executable file
10
client/web/convert.rb
Executable file
@@ -0,0 +1,10 @@
|
||||
#!/usr/bin/ruby
|
||||
|
||||
require 'ruby2js'
|
||||
require 'ruby2js/filter/functions'
|
||||
|
||||
puts Ruby2JS.convert(
|
||||
File.read('lrpc.rb'),
|
||||
eslevel: 2016,
|
||||
exclude: [:delete],
|
||||
)
|
||||
174
client/web/lrpc.rb
Normal file
174
client/web/lrpc.rb
Normal file
@@ -0,0 +1,174 @@
|
||||
# LRPCResponseType represents the various types an LRPC
|
||||
# response can have.
|
||||
LRPCResponseType = {
|
||||
Normal: 0,
|
||||
Error: 1,
|
||||
Channel: 2,
|
||||
ChannelDone: 3,
|
||||
}
|
||||
|
||||
# LRPCClient represents a client for the LRPC protocol
|
||||
# using WebSockets and the JSON codec
|
||||
class LRPCClient
|
||||
def initialize(addr)
|
||||
# Set self variables
|
||||
@callMap = Map.new()
|
||||
@enc = TextEncoder.new()
|
||||
@dec = TextDecoder.new()
|
||||
|
||||
# Create connection to lrpc server
|
||||
@conn = WebSocket.new(addr)
|
||||
@conn.binaryType = "arraybuffer"
|
||||
@conn.onmessage = proc do |msg|
|
||||
# if msg.data is string
|
||||
if msg.data.instance_of? String
|
||||
# Set json to msg.data
|
||||
json = msg.data
|
||||
else
|
||||
# Set json to decoded msg.data
|
||||
json = @dec.decode(msg.data)
|
||||
end
|
||||
# Parse JSON string
|
||||
val = JSON.parse(json)
|
||||
# Get id from callMap
|
||||
fns = @callMap.get(val.ID)
|
||||
# If fns is undefined (key does not exist), and this is
|
||||
# a normal response, return
|
||||
return if !fns && val.Type == LRPCResponseType.Normal
|
||||
|
||||
case val.Type
|
||||
when LRPCResponseType.Normal
|
||||
# If fns is a channel, send the value. Otherwise,
|
||||
# resolve the promise with the value.
|
||||
if fns.isChannel
|
||||
fns.send(val.Return)
|
||||
else
|
||||
fns.resolve(val.Return)
|
||||
end
|
||||
when LRPCResponseType.Channel
|
||||
# Get channel ID from response
|
||||
chID = val.Return
|
||||
# Create new LRPCChannel
|
||||
ch = LRPCChannel.new(self, chID)
|
||||
# Set channel in map
|
||||
@callMap.set(chID, ch)
|
||||
# Resolve promise with channel
|
||||
fns.resolve(ch)
|
||||
when LRPCResponseType.ChannelDone
|
||||
# Close and delete channel
|
||||
fns.close()
|
||||
@callMap.delete(val.ID)
|
||||
when LRPCResponseType.Error
|
||||
# Reject promise with error
|
||||
fns.reject(val.Error)
|
||||
end
|
||||
|
||||
# Delete item from map unless it is a channel
|
||||
@callMap.delete(val.ID) unless fns.isChannel
|
||||
end
|
||||
end
|
||||
|
||||
# call calls a method on the server with the given
|
||||
# argument and returns a promise.
|
||||
def callMethod(rcvr, method, arg)
|
||||
return Promise.new do |resolve, reject|
|
||||
# Get random UUID (this only works with TLS)
|
||||
id = crypto.randomUUID()
|
||||
# Add resolve/reject functions to callMap
|
||||
@callMap.set(id, {
|
||||
resolve: resolve,
|
||||
reject: reject,
|
||||
})
|
||||
|
||||
# Encode data as JSON
|
||||
data = @enc.encode({
|
||||
Receiver: rcvr,
|
||||
Method: method,
|
||||
Arg: arg,
|
||||
ID: id,
|
||||
}.to_json())
|
||||
|
||||
# Send data to lrpc server
|
||||
@conn.send(data.buffer)
|
||||
end
|
||||
end
|
||||
|
||||
# getClient returns an object containing functions
|
||||
# corresponding to registered functions on the given
|
||||
# receiver. It uses the lrpc.Introspect() endpoint
|
||||
# to achieve this.
|
||||
def getObject(rcvr)
|
||||
return Promise.new do |resolve|
|
||||
# Introspect methods on given receiver
|
||||
self.callMethod("lrpc", "Introspect", rcvr).then do |methodDesc|
|
||||
# Create output object
|
||||
out = {}
|
||||
# For each method in description array
|
||||
methodDesc.each do |method|
|
||||
# Create and assign new function to call current method
|
||||
out[method.Name] = proc { |arg| return self.callMethod(rcvr, method.Name, arg) }
|
||||
end
|
||||
# Resolve promise with output promise
|
||||
resolve(out)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# LRPCChannel represents a channel used for lrpc.
|
||||
class LRPCChannel
|
||||
def initialize(client, id)
|
||||
# Set self variables
|
||||
@client = client
|
||||
@id = id
|
||||
@closed = false
|
||||
# Set function variables to no-ops
|
||||
@onMessage = proc {|fn|}
|
||||
@onClose = proc {}
|
||||
end
|
||||
|
||||
# isChannel is defined to allow identifying whether
|
||||
# an object is a channel.
|
||||
def isChannel() end
|
||||
|
||||
# send sends a value on the channel. This should not
|
||||
# be called by the consumer of the channel.
|
||||
def send(val)
|
||||
return if @closed
|
||||
fn = @onMessage
|
||||
fn(val)
|
||||
end
|
||||
|
||||
# done cancels the context corresponding to the channel
|
||||
# on the server side and closes the channel.
|
||||
def done()
|
||||
return if @closed
|
||||
@client.callMethod("lrpc", "ChannelDone", @id)
|
||||
self.close()
|
||||
@client._callMap.delete(@id)
|
||||
end
|
||||
|
||||
# onMessage sets the callback to be called whenever a
|
||||
# message is received. The function should have one parameter
|
||||
# that will be set to the value received. Subsequent calls
|
||||
# will overwrite the callback
|
||||
def onMessage(fn)
|
||||
@onMessage = fn
|
||||
end
|
||||
|
||||
# onClose sets the callback to be called whenever the client
|
||||
# is closed. The function should have no parameters.
|
||||
# Subsequent calls will overwrite the callback
|
||||
def onClose(fn)
|
||||
@onClose = fn
|
||||
end
|
||||
|
||||
# close closes the channel. This should not be called by the
|
||||
# consumer of the channel. Use done() instead.
|
||||
def close()
|
||||
return if @closed
|
||||
fn = @onClose
|
||||
fn()
|
||||
@closed = true
|
||||
end
|
||||
end
|
||||
@@ -1,6 +1,25 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"io"
|
||||
@@ -8,6 +27,9 @@ import (
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
// <= go1.17 compatibility
|
||||
type any = interface{}
|
||||
|
||||
// CodecFunc is a function that returns a new Codec
|
||||
// bound to the given io.ReadWriter
|
||||
type CodecFunc func(io.ReadWriter) Codec
|
||||
@@ -17,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec
|
||||
type Codec interface {
|
||||
Encode(val any) error
|
||||
Decode(val any) error
|
||||
Unmarshal(data []byte, v any) error
|
||||
Marshal(v any) ([]byte, error)
|
||||
}
|
||||
|
||||
// Default is the default CodecFunc
|
||||
var Default = Msgpack
|
||||
|
||||
type JsonCodec struct {
|
||||
*json.Encoder
|
||||
*json.Decoder
|
||||
}
|
||||
|
||||
func (JsonCodec) Unmarshal(data []byte, v any) error {
|
||||
return json.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (JsonCodec) Marshal(v any) ([]byte, error) {
|
||||
return json.Marshal(v)
|
||||
}
|
||||
|
||||
// JSON is a CodecFunc that creates a JSON Codec
|
||||
func JSON(rw io.ReadWriter) Codec {
|
||||
type jsonCodec struct {
|
||||
*json.Encoder
|
||||
*json.Decoder
|
||||
}
|
||||
return jsonCodec{
|
||||
return JsonCodec{
|
||||
Encoder: json.NewEncoder(rw),
|
||||
Decoder: json.NewDecoder(rw),
|
||||
}
|
||||
}
|
||||
|
||||
type MsgpackCodec struct {
|
||||
*msgpack.Encoder
|
||||
*msgpack.Decoder
|
||||
}
|
||||
|
||||
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
|
||||
return msgpack.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
|
||||
return msgpack.Marshal(v)
|
||||
}
|
||||
|
||||
// Msgpack is a CodecFunc that creates a Msgpack Codec
|
||||
func Msgpack(rw io.ReadWriter) Codec {
|
||||
type msgpackCodec struct {
|
||||
*msgpack.Encoder
|
||||
*msgpack.Decoder
|
||||
}
|
||||
return msgpackCodec{
|
||||
return MsgpackCodec{
|
||||
Encoder: msgpack.NewEncoder(rw),
|
||||
Decoder: msgpack.NewDecoder(rw),
|
||||
}
|
||||
}
|
||||
|
||||
type GobCodec struct {
|
||||
*gob.Encoder
|
||||
*gob.Decoder
|
||||
}
|
||||
|
||||
func (GobCodec) Unmarshal(data []byte, v any) error {
|
||||
return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
|
||||
}
|
||||
|
||||
func (GobCodec) Marshal(v any) ([]byte, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
err := gob.NewEncoder(buf).Encode(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Gob is a CodecFunc that creates a Gob Codec
|
||||
func Gob(rw io.ReadWriter) Codec {
|
||||
type gobCodec struct {
|
||||
*gob.Encoder
|
||||
*gob.Decoder
|
||||
}
|
||||
return gobCodec{
|
||||
return GobCodec{
|
||||
Encoder: gob.NewEncoder(rw),
|
||||
Decoder: gob.NewDecoder(rw),
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -12,21 +13,23 @@ import (
|
||||
func main() {
|
||||
gob.Register([2]int{})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
conn, _ := net.Dial("tcp", "localhost:9090")
|
||||
c := client.New(conn, codec.Gob)
|
||||
defer c.Close()
|
||||
|
||||
var add int
|
||||
c.Call("Arith", "Add", [2]int{5, 5}, &add)
|
||||
c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
|
||||
|
||||
var sub int
|
||||
c.Call("Arith", "Sub", [2]int{5, 5}, &sub)
|
||||
c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
|
||||
|
||||
var mul int
|
||||
c.Call("Arith", "Mul", [2]int{5, 5}, &mul)
|
||||
c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
|
||||
|
||||
var div int
|
||||
c.Call("Arith", "Div", [2]int{5, 5}, &div)
|
||||
c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
|
||||
|
||||
fmt.Printf(
|
||||
"add: %d, sub: %d, mul: %d, div: %d\n",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net"
|
||||
|
||||
@@ -33,5 +34,5 @@ func main() {
|
||||
s.Register(Arith{})
|
||||
|
||||
ln, _ := net.Listen("tcp", ":9090")
|
||||
s.Serve(ln, codec.Gob)
|
||||
s.Serve(context.Background(), ln, codec.Gob)
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@@ -1,11 +1,12 @@
|
||||
module go.arsenm.dev/lrpc
|
||||
|
||||
go 1.18
|
||||
go 1.17
|
||||
|
||||
require (
|
||||
github.com/gofrs/uuid v4.2.0+incompatible
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5
|
||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4
|
||||
)
|
||||
|
||||
require github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -13,6 +13,8 @@ github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9
|
||||
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA=
|
||||
golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -1,148 +0,0 @@
|
||||
package reflectutil
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
// Convert attempts to convert the given value to the given type
|
||||
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
|
||||
// Get input type
|
||||
inType := in.Type()
|
||||
|
||||
// If input is already the desired type, return
|
||||
if inType == toType {
|
||||
return in, nil
|
||||
}
|
||||
|
||||
// If input can be converted to desired type, convert and return
|
||||
if in.CanConvert(toType) {
|
||||
return in.Convert(toType), nil
|
||||
}
|
||||
|
||||
// Create new value of desired type
|
||||
to := reflect.New(toType).Elem()
|
||||
|
||||
// If type is a pointer
|
||||
if to.Kind() == reflect.Ptr {
|
||||
// Initialize value
|
||||
to.Set(reflect.New(to.Type().Elem()))
|
||||
}
|
||||
|
||||
switch val := in.Interface().(type) {
|
||||
case string:
|
||||
// If desired type satisfies text unmarshaler
|
||||
if u, ok := to.Interface().(encoding.TextUnmarshaler); ok {
|
||||
// Use text unmarshaler to get value
|
||||
err := u.UnmarshalText([]byte(val))
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Return unmarshaled value
|
||||
return reflect.ValueOf(any(u)), nil
|
||||
}
|
||||
case []byte:
|
||||
// If desired type satisfies binary unmarshaler
|
||||
if u, ok := to.Interface().(encoding.BinaryUnmarshaler); ok {
|
||||
// Use binary unmarshaler to get value
|
||||
err := u.UnmarshalBinary(val)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
|
||||
// Return unmarshaled value
|
||||
return reflect.ValueOf(any(u)), nil
|
||||
}
|
||||
}
|
||||
|
||||
// If input is a map
|
||||
if in.Kind() == reflect.Map {
|
||||
// Use mapstructure to decode value
|
||||
err := mapstructure.Decode(in.Interface(), to.Addr().Interface())
|
||||
if err == nil {
|
||||
return to, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If input is a slice of any, and output is an array or slice
|
||||
if in.Type() == reflect.TypeOf([]any{}) &&
|
||||
to.Kind() == reflect.Slice || to.Kind() == reflect.Array {
|
||||
// Use ConvertSlice to convert value
|
||||
return reflect.ValueOf(ConvertSlice(
|
||||
in.Interface().([]any),
|
||||
toType,
|
||||
)), nil
|
||||
}
|
||||
|
||||
return to, fmt.Errorf("cannot convert %s to %s", inType, toType)
|
||||
}
|
||||
|
||||
// ConvertSlice converts []any to an array or slice, as provided
|
||||
// in the "to" argument.
|
||||
func ConvertSlice(in []any, to reflect.Type) any {
|
||||
// Create new value for output
|
||||
out := reflect.New(to).Elem()
|
||||
|
||||
// If output value is a slice
|
||||
if out.Kind() == reflect.Slice {
|
||||
// Get type of slice elements
|
||||
outType := out.Type().Elem()
|
||||
|
||||
// For every value provided
|
||||
for i := 0; i < len(in); i++ {
|
||||
// Get value of input type
|
||||
inVal := reflect.ValueOf(in[i])
|
||||
// Create new output type
|
||||
outVal := reflect.New(outType).Elem()
|
||||
|
||||
// If types match
|
||||
if inVal.Type() == outType {
|
||||
// Set output value to input value
|
||||
outVal.Set(inVal)
|
||||
} else {
|
||||
newVal, err := Convert(inVal, outType)
|
||||
if err != nil {
|
||||
// Set output value to its zero value
|
||||
outVal.Set(reflect.Zero(outVal.Type()))
|
||||
} else {
|
||||
outVal.Set(newVal)
|
||||
}
|
||||
}
|
||||
|
||||
// Append output value to slice
|
||||
out = reflect.Append(out, outVal)
|
||||
}
|
||||
} else if out.Kind() == reflect.Array && out.Len() == len(in) {
|
||||
//If output type is array and lengths match
|
||||
|
||||
// For every input value
|
||||
for i := 0; i < len(in); i++ {
|
||||
// Get matching output index
|
||||
outVal := out.Index(i)
|
||||
// Get input value
|
||||
inVal := reflect.ValueOf(in[i])
|
||||
|
||||
// If types match
|
||||
if inVal.Type() == outVal.Type() {
|
||||
// Set output value to input value
|
||||
outVal.Set(inVal)
|
||||
} else {
|
||||
// If input value can be converted to output type
|
||||
if inVal.CanConvert(outVal.Type()) {
|
||||
// Convert and set output value to input value
|
||||
outVal.Set(inVal.Convert(outVal.Type()))
|
||||
} else {
|
||||
// Set output value to its zero value
|
||||
outVal.Set(reflect.Zero(outVal.Type()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return created value
|
||||
return out.Interface()
|
||||
}
|
||||
@@ -1,19 +1,47 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package types
|
||||
|
||||
// <= go1.17 compatibility
|
||||
type any = interface{}
|
||||
|
||||
// Request represents a request sent to the server
|
||||
type Request struct {
|
||||
ID string
|
||||
Receiver string
|
||||
Method string
|
||||
Arg any
|
||||
Arg []byte
|
||||
}
|
||||
|
||||
type ResponseType uint8
|
||||
|
||||
const (
|
||||
ResponseTypeNormal ResponseType = iota
|
||||
ResponseTypeError
|
||||
ResponseTypeChannel
|
||||
ResponseTypeChannelDone
|
||||
)
|
||||
|
||||
// Response represents a response returned by the server
|
||||
type Response struct {
|
||||
ID string
|
||||
ChannelDone bool
|
||||
IsChannel bool
|
||||
IsError bool
|
||||
Error string
|
||||
Return any
|
||||
Type ResponseType
|
||||
ID string
|
||||
Error string
|
||||
Return []byte
|
||||
}
|
||||
|
||||
212
lrpc_test.go
Normal file
212
lrpc_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package lrpc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.arsenm.dev/lrpc/client"
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/server"
|
||||
)
|
||||
|
||||
type Arith struct{}
|
||||
|
||||
func (Arith) Add(ctx *server.Context, in [2]int) int {
|
||||
return in[0] + in[1]
|
||||
}
|
||||
|
||||
func (Arith) Mul(ctx *server.Context, in [2]int) int {
|
||||
return in[0] * in[1]
|
||||
}
|
||||
|
||||
func (Arith) Div(ctx *server.Context, in [2]int) int {
|
||||
return in[0] / in[1]
|
||||
}
|
||||
|
||||
func (Arith) Sub(ctx *server.Context, in [2]int) int {
|
||||
return in[0] - in[1]
|
||||
}
|
||||
|
||||
func TestCalls(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create new network pipe
|
||||
sConn, cConn := net.Pipe()
|
||||
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
// Register Arith for RPC
|
||||
s.Register(Arith{})
|
||||
// Serve the pipe connection using default codec
|
||||
go s.ServeConn(ctx, sConn, codec.Default)
|
||||
|
||||
// Create new client using default codec
|
||||
c := client.New(cConn, codec.Default)
|
||||
defer c.Close()
|
||||
|
||||
// Call Arith.Add()
|
||||
var add int
|
||||
err := c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Call Arith.Sub()
|
||||
var sub int
|
||||
err = c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Call Arith.Mul()
|
||||
var mul int
|
||||
err = c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Call Arith.Div()
|
||||
var div int
|
||||
err = c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if add != 10 {
|
||||
t.Errorf("add: expected 10, got %d", add)
|
||||
}
|
||||
|
||||
if sub != 0 {
|
||||
t.Errorf("sub: expected 0, got %d", sub)
|
||||
}
|
||||
|
||||
if mul != 25 {
|
||||
t.Errorf("mul: expected 25, got %d", mul)
|
||||
}
|
||||
|
||||
if div != 1 {
|
||||
t.Errorf("div: expected 1, got %d", div)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodecs(t *testing.T) {
|
||||
// Register the 2-integer array for gob
|
||||
gob.Register([2]int{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create function to test each codec
|
||||
testCodec := func(cf codec.CodecFunc, name string) {
|
||||
// Create network pipe
|
||||
sConn, cConn := net.Pipe()
|
||||
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
// Register Arith for RPC
|
||||
s.Register(Arith{})
|
||||
// Serve the pipe connection using provided codec
|
||||
go s.ServeConn(ctx, sConn, cf)
|
||||
|
||||
// Create new client using provided codec
|
||||
c := client.New(cConn, cf)
|
||||
defer c.Close()
|
||||
|
||||
// Call Arith.Add()
|
||||
var add int
|
||||
err := c.Call(ctx, "Arith", "Add", [2]int{2, 2}, &add)
|
||||
if err != nil {
|
||||
t.Errorf("codec/%s: %v", name, err)
|
||||
}
|
||||
|
||||
if add != 4 {
|
||||
t.Errorf("codec/%s: add: expected 4, got %d", name, add)
|
||||
}
|
||||
}
|
||||
|
||||
// Test all codecs
|
||||
testCodec(codec.Msgpack, "msgpack")
|
||||
testCodec(codec.JSON, "json")
|
||||
testCodec(codec.Gob, "gob")
|
||||
}
|
||||
|
||||
type Channel struct{}
|
||||
|
||||
func (Channel) Time(ctx *server.Context, interval time.Duration) error {
|
||||
ch, err := ctx.MakeChannel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tick := time.NewTicker(interval)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case t := <-tick.C:
|
||||
ch <- t
|
||||
case <-ctx.Done():
|
||||
tick.Stop()
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create new network pipe
|
||||
sConn, cConn := net.Pipe()
|
||||
|
||||
s := server.New()
|
||||
defer s.Close()
|
||||
// Register Arith for RPC
|
||||
s.Register(Channel{})
|
||||
// Serve the pipe connection using default codec
|
||||
go s.ServeConn(ctx, sConn, codec.Default)
|
||||
|
||||
// Create new client using default codec
|
||||
c := client.New(cConn, codec.Default)
|
||||
defer c.Close()
|
||||
|
||||
timeCtx, timeCancel := context.WithCancel(ctx)
|
||||
defer timeCancel()
|
||||
|
||||
timeCh := make(chan *time.Time, 2)
|
||||
err := c.Call(timeCtx, "Channel", "Time", time.Millisecond, timeCh)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var loops int
|
||||
var lastTime *time.Time
|
||||
for curTime := range timeCh {
|
||||
if loops > 3 {
|
||||
timeCancel()
|
||||
break
|
||||
}
|
||||
|
||||
if lastTime == nil {
|
||||
lastTime = curTime
|
||||
continue
|
||||
}
|
||||
|
||||
diff := curTime.Sub(*lastTime)
|
||||
diff = diff.Round(time.Millisecond)
|
||||
|
||||
if diff != time.Millisecond {
|
||||
t.Fatalf("expected 1s diff, got %s", diff)
|
||||
}
|
||||
|
||||
lastTime = curTime
|
||||
loops++
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,27 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
@@ -14,7 +35,26 @@ type Context struct {
|
||||
|
||||
codec codec.Codec
|
||||
|
||||
doneCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
canceled bool
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func newContext(ctx context.Context, codec codec.Codec) *Context {
|
||||
out := &Context{
|
||||
doneCh: make(chan struct{}),
|
||||
codec: codec,
|
||||
ctx: ctx,
|
||||
}
|
||||
if ctx == nil {
|
||||
out.ctx = context.Background()
|
||||
}
|
||||
go func() {
|
||||
<-out.ctx.Done()
|
||||
out.cancel()
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
// MakeChannel changes the function it's called in into a
|
||||
@@ -37,6 +77,26 @@ func (ctx *Context) GetCodec() codec.Codec {
|
||||
return ctx.codec
|
||||
}
|
||||
|
||||
// Deadline always returns the current time and false
|
||||
// as this context does not support deadlines
|
||||
func (ctx *Context) Deadline() (time.Time, bool) {
|
||||
return time.Now(), false
|
||||
}
|
||||
|
||||
// Value always returns nil as this context stores no values
|
||||
func (ctx *Context) Value(key any) any {
|
||||
return ctx.ctx.Value(key)
|
||||
}
|
||||
|
||||
// Err returns context.Canceled if the context was canceled,
|
||||
// otherwise nil
|
||||
func (ctx *Context) Err() error {
|
||||
if ctx.canceled {
|
||||
return context.Canceled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that will be closed when
|
||||
// the context is canceled, such as when ChannelDone
|
||||
// is called by the client
|
||||
@@ -45,6 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
|
||||
}
|
||||
|
||||
// Cancel cancels the context
|
||||
func (ctx *Context) Cancel() {
|
||||
func (ctx *Context) cancel() {
|
||||
if ctx.canceled {
|
||||
return
|
||||
}
|
||||
ctx.canceled = true
|
||||
close(ctx.doneCh)
|
||||
}
|
||||
|
||||
396
server/server.go
396
server/server.go
@@ -1,28 +1,46 @@
|
||||
/*
|
||||
* lrpc allows for clients to call functions on a server remotely.
|
||||
* Copyright (C) 2022 Arsen Musayelyan
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
// <= go1.17 compatibility
|
||||
type any = interface{}
|
||||
|
||||
var (
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrTooManyInputs = errors.New("method may not have more than two inputs")
|
||||
ErrTooManyOutputs = errors.New("method may not have more than two return values")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidSecondReturn = errors.New("second return value must be error")
|
||||
ErrInvalidFirstInput = errors.New("first input must be *Context")
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidMethod = errors.New("method invalid for lrpc call")
|
||||
ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
|
||||
)
|
||||
|
||||
// Server is an lrpc server
|
||||
@@ -50,7 +68,7 @@ func New() *Server {
|
||||
// Close closes the server
|
||||
func (s *Server) Close() {
|
||||
for _, ctx := range s.contexts {
|
||||
ctx.Cancel()
|
||||
ctx.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,7 +99,7 @@ func (s *Server) Register(v any) error {
|
||||
}
|
||||
|
||||
// execute runs a method of a registered value
|
||||
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
// Try to get value from receivers map
|
||||
val, ok := s.rcvrs[typ]
|
||||
if !ok {
|
||||
@@ -94,55 +112,37 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
return nil, nil, ErrNoSuchMethod
|
||||
}
|
||||
|
||||
// Get method's type
|
||||
// If method invalid, return error
|
||||
if !mtdValid(mtd) {
|
||||
return nil, nil, ErrInvalidMethod
|
||||
}
|
||||
|
||||
// Get method type
|
||||
mtdType := mtd.Type()
|
||||
if mtdType.NumIn() > 2 {
|
||||
return nil, nil, ErrTooManyInputs
|
||||
} else if mtdType.NumIn() < 1 {
|
||||
return nil, nil, ErrInvalidFirstInput
|
||||
}
|
||||
|
||||
if mtdType.NumOut() > 2 {
|
||||
return nil, nil, ErrTooManyOutputs
|
||||
}
|
||||
|
||||
// Check to ensure first parameter is context
|
||||
if mtdType.In(0) != reflect.TypeOf(&Context{}) {
|
||||
return nil, nil, ErrInvalidFirstInput
|
||||
}
|
||||
|
||||
//TODO: if arg not nil but fn has no arg, err
|
||||
|
||||
// IF argument is []any
|
||||
anySlice, ok := arg.([]any)
|
||||
if ok {
|
||||
// Convert slice to the method's arg type and
|
||||
// set arg to the newly-converted slice
|
||||
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
|
||||
argType := mtdType.In(1)
|
||||
argVal := reflect.New(argType)
|
||||
arg := argVal.Interface()
|
||||
|
||||
err = c.Unmarshal(data, arg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get argument value
|
||||
argVal := reflect.ValueOf(arg)
|
||||
// If argument's type does not match method's argument type
|
||||
if arg != nil && argVal.Type() != mtdType.In(1) {
|
||||
val, err = reflectutil.Convert(argVal, mtdType.In(1))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
arg = val.Interface()
|
||||
}
|
||||
arg = argVal.Elem().Interface()
|
||||
|
||||
// Create new context
|
||||
ctx = &Context{
|
||||
doneCh: make(chan struct{}, 1),
|
||||
codec: c,
|
||||
}
|
||||
ctx = newContext(pCtx, c)
|
||||
// Get reflect value of context
|
||||
ctxVal := reflect.ValueOf(ctx)
|
||||
|
||||
switch mtdType.NumOut() {
|
||||
case 0: // If method has no return values
|
||||
if mtdType.NumIn() == 2 {
|
||||
if arg == nil {
|
||||
return nil, nil, ErrArgNotProvided
|
||||
}
|
||||
// Call method with arg, ignore returned value
|
||||
mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
} else {
|
||||
@@ -151,6 +151,10 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
}
|
||||
case 1: // If method has one return value
|
||||
if mtdType.NumIn() == 2 {
|
||||
if arg == nil {
|
||||
return nil, nil, ErrArgNotProvided
|
||||
}
|
||||
|
||||
// Call method with arg, get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
|
||||
@@ -185,6 +189,10 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
}
|
||||
case 2: // If method has two return values
|
||||
if mtdType.NumIn() == 2 {
|
||||
if arg == nil {
|
||||
return nil, nil, ErrArgNotProvided
|
||||
}
|
||||
|
||||
// Call method with arg and get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
|
||||
@@ -195,7 +203,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
|
||||
// If second return value is not an error, the function is invalid
|
||||
if !ok {
|
||||
a, err = nil, ErrInvalidSecondReturn
|
||||
a, err = nil, ErrInvalidMethod
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,7 +219,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
// If second return value is not an error, the function is invalid
|
||||
err, ok = out1.(error)
|
||||
if !ok {
|
||||
a, err = nil, ErrInvalidSecondReturn
|
||||
a, err = nil, ErrInvalidMethod
|
||||
}
|
||||
}
|
||||
|
||||
@@ -224,22 +232,66 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
|
||||
// Serve starts the server using the provided listener
|
||||
// and codec function
|
||||
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
|
||||
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
break
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new instance of codec bound to conn
|
||||
c := cf(conn)
|
||||
// Handle connection
|
||||
go s.handleConn(c)
|
||||
go s.handleConn(ctx, c)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn handles a listener connection
|
||||
func (s *Server) handleConn(c codec.Codec) {
|
||||
// ServeWS starts a server using WebSocket. This may be useful for
|
||||
// clients written in other languages, such as JS for a browser.
|
||||
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
|
||||
// Create new WebSocket server
|
||||
ws := websocket.Server{}
|
||||
|
||||
// Create new WebSocket config
|
||||
ws.Config = websocket.Config{
|
||||
Version: websocket.ProtocolVersionHybi13,
|
||||
}
|
||||
|
||||
// Set server handler
|
||||
ws.Handler = func(c *websocket.Conn) {
|
||||
s.handleConn(c.Request().Context(), cf(c))
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
Handler: http.HandlerFunc(ws.ServeHTTP),
|
||||
}
|
||||
|
||||
// Listen and serve on given address
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
// ServeConn uses the provided connection to serve the client.
|
||||
// This may be useful if something other than a net.Listener
|
||||
// needs to be used
|
||||
func (s *Server) ServeConn(ctx context.Context, conn io.ReadWriter, cf codec.CodecFunc) {
|
||||
s.handleConn(ctx, cf(conn))
|
||||
}
|
||||
|
||||
// handleConn handles a connection
|
||||
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
codecMtx := &sync.Mutex{}
|
||||
|
||||
for {
|
||||
var call types.Request
|
||||
// Read request using codec
|
||||
@@ -251,72 +303,103 @@ func (s *Server) handleConn(c codec.Codec) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute decoded call
|
||||
val, ctx, err := s.execute(
|
||||
call.Receiver,
|
||||
call.Method,
|
||||
call.Arg,
|
||||
c,
|
||||
)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
} else {
|
||||
// Create response
|
||||
res := types.Response{
|
||||
ID: call.ID,
|
||||
Return: val,
|
||||
}
|
||||
go func() {
|
||||
// Execute decoded call
|
||||
val, ctx, err := s.execute(
|
||||
pCtx,
|
||||
call.Receiver,
|
||||
call.Method,
|
||||
call.Arg,
|
||||
c,
|
||||
)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
} else {
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
return
|
||||
}
|
||||
|
||||
// If function has created a channel
|
||||
if ctx.isChannel {
|
||||
// Set IsChannel to true
|
||||
res.IsChannel = true
|
||||
// Overwrite return value with channel ID
|
||||
res.Return = ctx.channelID
|
||||
// Create response
|
||||
res := types.Response{
|
||||
ID: call.ID,
|
||||
Return: valData,
|
||||
}
|
||||
|
||||
// Store context in map for future use
|
||||
s.contextsMtx.Lock()
|
||||
s.contexts[ctx.channelID] = ctx
|
||||
s.contextsMtx.Unlock()
|
||||
|
||||
go func() {
|
||||
// For every value received from channel
|
||||
for val := range ctx.channel {
|
||||
// Encode response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
Return: val,
|
||||
})
|
||||
// If function has created a channel
|
||||
if ctx.isChannel {
|
||||
idData, err := c.Marshal(ctx.channelID)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.Cancel()
|
||||
// Delete context from map
|
||||
// Set IsChannel to true
|
||||
res.Type = types.ResponseTypeChannel
|
||||
// Overwrite return value with channel ID
|
||||
res.Return = idData
|
||||
|
||||
// Store context in map for future use
|
||||
s.contextsMtx.Lock()
|
||||
delete(s.contexts, ctx.channelID)
|
||||
s.contexts[ctx.channelID] = ctx
|
||||
s.contextsMtx.Unlock()
|
||||
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
ChannelDone: true,
|
||||
})
|
||||
}()
|
||||
go func() {
|
||||
// For every value received from channel
|
||||
for val := range ctx.channel {
|
||||
codecMtx.Lock()
|
||||
|
||||
valData, err := c.Marshal(val)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
Return: valData,
|
||||
})
|
||||
|
||||
codecMtx.Unlock()
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.cancel()
|
||||
// Delete context from map
|
||||
s.contextsMtx.Lock()
|
||||
delete(s.contexts, ctx.channelID)
|
||||
s.contextsMtx.Unlock()
|
||||
|
||||
codecMtx.Lock()
|
||||
c.Encode(types.Response{
|
||||
Type: types.ResponseTypeChannelDone,
|
||||
ID: ctx.channelID,
|
||||
})
|
||||
codecMtx.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
codecMtx.Lock()
|
||||
c.Encode(res)
|
||||
codecMtx.Unlock()
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
c.Encode(res)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// sendErr sends an error response
|
||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
||||
valData, _ := c.Marshal(val)
|
||||
|
||||
// Encode error response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: req.ID,
|
||||
IsError: true,
|
||||
Error: err.Error(),
|
||||
Return: val,
|
||||
Type: types.ResponseTypeError,
|
||||
ID: req.ID,
|
||||
Error: err.Error(),
|
||||
Return: valData,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -334,9 +417,114 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.Cancel()
|
||||
ctx.cancel()
|
||||
// Delete context from map
|
||||
l.srv.contextsMtx.Lock()
|
||||
delete(l.srv.contexts, id)
|
||||
l.srv.contextsMtx.Unlock()
|
||||
}
|
||||
|
||||
// MethodDesc describes methods on a receiver
|
||||
type MethodDesc struct {
|
||||
Name string
|
||||
Args []string
|
||||
Returns []string
|
||||
}
|
||||
|
||||
// Introspect returns method descriptions for the given receiver
|
||||
func (l lrpc) Introspect(_ *Context, name string) ([]MethodDesc, error) {
|
||||
// Attempt to get receiver
|
||||
rcvr, ok := l.srv.rcvrs[name]
|
||||
if !ok {
|
||||
return nil, ErrNoSuchReceiver
|
||||
}
|
||||
// Get receiver type value
|
||||
rcvrType := rcvr.Type()
|
||||
|
||||
// Create slice for output
|
||||
var out []MethodDesc
|
||||
// For every method on receiver
|
||||
for i := 0; i < rcvr.NumMethod(); i++ {
|
||||
// Get receiver method
|
||||
mtd := rcvr.Method(i)
|
||||
// If invalid, skip
|
||||
if !mtdValid(mtd) {
|
||||
continue
|
||||
}
|
||||
// Get method type
|
||||
mtdType := mtd.Type()
|
||||
|
||||
// Get amount of arguments
|
||||
numIn := mtdType.NumIn()
|
||||
args := make([]string, numIn-1)
|
||||
// For every argument, store type in slice
|
||||
// Skip first argument, as it is *Context
|
||||
for i := 1; i < numIn; i++ {
|
||||
args[i-1] = mtdType.In(i).String()
|
||||
}
|
||||
|
||||
// Get amount of returns
|
||||
numOut := mtdType.NumOut()
|
||||
returns := make([]string, numOut)
|
||||
// For every return, store type in slice
|
||||
for i := 0; i < numOut; i++ {
|
||||
returns[i] = mtdType.Out(i).String()
|
||||
}
|
||||
|
||||
out = append(out, MethodDesc{
|
||||
Name: rcvrType.Method(i).Name,
|
||||
Args: args,
|
||||
Returns: returns,
|
||||
})
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IntrospectAll runs Introspect on all registered receivers and returns all results
|
||||
func (l lrpc) IntrospectAll(_ *Context) (map[string][]MethodDesc, error) {
|
||||
// Create map for output
|
||||
out := make(map[string][]MethodDesc, len(l.srv.rcvrs))
|
||||
// For every registered receiver
|
||||
for name := range l.srv.rcvrs {
|
||||
// Introspect receiver
|
||||
descs, err := l.Introspect(nil, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set results in map
|
||||
out[name] = descs
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func mtdValid(mtd reflect.Value) bool {
|
||||
// Get method's type
|
||||
mtdType := mtd.Type()
|
||||
|
||||
// If method has more than 2 or less than 1 input, it is invalid
|
||||
if mtdType.NumIn() > 2 || mtdType.NumIn() < 1 {
|
||||
return false
|
||||
}
|
||||
|
||||
// If method has more than 2 outputs, it is invalid
|
||||
if mtdType.NumOut() > 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check to ensure first parameter is context
|
||||
if mtdType.In(0) != reflect.TypeOf((*Context)(nil)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If method has 2 outputs
|
||||
if mtdType.NumOut() == 2 {
|
||||
// Check to ensure the second one is an error
|
||||
if mtdType.Out(1).Name() != "error" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user