Compare commits
	
		
			15 Commits
		
	
	
		
			eadee97e5e
			...
			master
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7592eae318 | |||
| d35a16ec64 | |||
| e02c8bc5ff | |||
| 5e61e89ac1 | |||
| 205e0b71e4 | |||
| 1e627b833e | |||
| 368c7333c5 | |||
| 1e8e304f01 | |||
| acf262b4f0 | |||
| 8843e7faa9 | |||
| 328be35ae2 | |||
| 6ee3602128 | |||
| e518b68d8c | |||
| 771c8c136e | |||
| c0a1c3bf43 | 
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
/client/web/lrpc.js
 | 
			
		||||
/s
 | 
			
		||||
/c
 | 
			
		||||
@@ -18,3 +18,9 @@ This RPC framework supports creating channels to transfer data from server to cl
 | 
			
		||||
When creating a server or client, a `CodecFunc` can be provided. An `io.ReadWriter` is passed into the `CodecFunc` and it returns a `Codec`, which is an interface that contains encode and decode functions with the same signature as `json.Decoder.Decode()` and `json.Encoder.Encode()`.
 | 
			
		||||
 | 
			
		||||
This allows any codec to be used for the transfer of the data, making it easy to create clients in different languages.
 | 
			
		||||
 | 
			
		||||
---
 | 
			
		||||
 | 
			
		||||
### Web Client
 | 
			
		||||
 | 
			
		||||
Inside `client/web`, there is a web client for lrpc using WebSockets. It is written in ruby (I don't like JS) and translated to human-readable JS using Ruby2JS. With the `bundler` gem installed, cd into `client/web` and run `make`. This will create a new file called `lrpc.js`, which can be used within a browser. It uses `crypto.randomUUID()`, so it must be used on an https site, not http.
 | 
			
		||||
@@ -26,7 +26,6 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"go.arsenm.dev/lrpc/codec"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/reflectutil"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/types"
 | 
			
		||||
 | 
			
		||||
	"github.com/gofrs/uuid"
 | 
			
		||||
@@ -81,12 +80,17 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
	c.chs[idStr] = make(chan *types.Response, 1)
 | 
			
		||||
	c.chMtx.Unlock()
 | 
			
		||||
 | 
			
		||||
	argData, err := c.codec.Marshal(arg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Encode request using codec
 | 
			
		||||
	err = c.codec.Encode(types.Request{
 | 
			
		||||
		ID:       idStr,
 | 
			
		||||
		Receiver: rcvr,
 | 
			
		||||
		Method:   method,
 | 
			
		||||
		Arg:      arg,
 | 
			
		||||
		Arg:      argData,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -124,7 +128,11 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
			return ErrReturnNotChannel
 | 
			
		||||
		}
 | 
			
		||||
		// Get channel ID returned in response
 | 
			
		||||
		chID := resp.Return.(string)
 | 
			
		||||
		var chID string
 | 
			
		||||
		err = c.codec.Unmarshal(resp.Return, &chID)
 | 
			
		||||
		if resp.Return == nil {
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create new channel using channel ID
 | 
			
		||||
		c.chMtx.Lock()
 | 
			
		||||
@@ -149,21 +157,16 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
					retVal.Close()
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				// Get reflect value from channel response
 | 
			
		||||
				rVal := reflect.ValueOf(val.Return)
 | 
			
		||||
 | 
			
		||||
				// If return value is not the same as the channel
 | 
			
		||||
				if rVal.Type() != chElemType {
 | 
			
		||||
					// Attempt to convert value, skip if impossible
 | 
			
		||||
					newVal, err := reflectutil.Convert(rVal, chElemType)
 | 
			
		||||
				outVal := reflect.New(chElemType)
 | 
			
		||||
				err = c.codec.Unmarshal(val.Return, outVal.Interface())
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
					rVal = newVal
 | 
			
		||||
				}
 | 
			
		||||
				outVal = outVal.Elem()
 | 
			
		||||
 | 
			
		||||
				chosen, _, _ := reflect.Select([]reflect.SelectCase{
 | 
			
		||||
					{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
 | 
			
		||||
					{Dir: reflect.SelectSend, Chan: retVal, Send: outVal},
 | 
			
		||||
					{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
 | 
			
		||||
				})
 | 
			
		||||
				if chosen == 1 {
 | 
			
		||||
@@ -179,28 +182,10 @@ func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{},
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	} else if resp.Type == types.ResponseTypeNormal {
 | 
			
		||||
		// IF return value is not a pointer, return error
 | 
			
		||||
		if retVal.Kind() != reflect.Ptr {
 | 
			
		||||
			return ErrReturnNotPointer
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Get return type
 | 
			
		||||
		retType := retVal.Type().Elem()
 | 
			
		||||
		// Get refkect value from response
 | 
			
		||||
		rVal := reflect.ValueOf(resp.Return)
 | 
			
		||||
 | 
			
		||||
		// If types do not match
 | 
			
		||||
		if rVal.Type() != retType {
 | 
			
		||||
			// Attempt to convert types, return error if not possible
 | 
			
		||||
			newVal, err := reflectutil.Convert(rVal, retType)
 | 
			
		||||
		err = c.codec.Unmarshal(resp.Return, ret)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
			rVal = newVal
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Set return value to received value
 | 
			
		||||
		retVal.Elem().Set(rVal)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										3
									
								
								client/web/Gemfile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								client/web/Gemfile
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
source 'https://rubygems.org'
 | 
			
		||||
 | 
			
		||||
gem 'ruby2js'
 | 
			
		||||
							
								
								
									
										19
									
								
								client/web/Gemfile.lock
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								client/web/Gemfile.lock
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
GEM
 | 
			
		||||
  remote: https://rubygems.org/
 | 
			
		||||
  specs:
 | 
			
		||||
    ast (2.4.2)
 | 
			
		||||
    parser (3.1.2.0)
 | 
			
		||||
      ast (~> 2.4.1)
 | 
			
		||||
    regexp_parser (2.1.1)
 | 
			
		||||
    ruby2js (5.0.1)
 | 
			
		||||
      parser
 | 
			
		||||
      regexp_parser (~> 2.1.1)
 | 
			
		||||
 | 
			
		||||
PLATFORMS
 | 
			
		||||
  x86_64-linux
 | 
			
		||||
 | 
			
		||||
DEPENDENCIES
 | 
			
		||||
  ruby2js
 | 
			
		||||
 | 
			
		||||
BUNDLED WITH
 | 
			
		||||
   2.3.15
 | 
			
		||||
							
								
								
									
										3
									
								
								client/web/Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								client/web/Makefile
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
lrpc.js: convert.rb lrpc.rb
 | 
			
		||||
	bundle install
 | 
			
		||||
	ruby convert.rb > lrpc.js
 | 
			
		||||
							
								
								
									
										10
									
								
								client/web/convert.rb
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										10
									
								
								client/web/convert.rb
									
									
									
									
									
										Executable file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
#!/usr/bin/ruby
 | 
			
		||||
 | 
			
		||||
require 'ruby2js'
 | 
			
		||||
require 'ruby2js/filter/functions'
 | 
			
		||||
 | 
			
		||||
puts Ruby2JS.convert(
 | 
			
		||||
    File.read('lrpc.rb'),
 | 
			
		||||
    eslevel: 2016,
 | 
			
		||||
    exclude: [:delete],
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										174
									
								
								client/web/lrpc.rb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										174
									
								
								client/web/lrpc.rb
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,174 @@
 | 
			
		||||
# LRPCResponseType represents the various types an LRPC
 | 
			
		||||
# response can have.
 | 
			
		||||
LRPCResponseType = {
 | 
			
		||||
  Normal: 0,
 | 
			
		||||
  Error: 1,
 | 
			
		||||
  Channel: 2,
 | 
			
		||||
  ChannelDone: 3,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# LRPCClient represents a client for the LRPC protocol
 | 
			
		||||
# using WebSockets and the JSON codec
 | 
			
		||||
class LRPCClient
 | 
			
		||||
  def initialize(addr)
 | 
			
		||||
    # Set self variables
 | 
			
		||||
    @callMap = Map.new()
 | 
			
		||||
    @enc = TextEncoder.new()
 | 
			
		||||
    @dec = TextDecoder.new()
 | 
			
		||||
    
 | 
			
		||||
    # Create connection to lrpc server
 | 
			
		||||
    @conn = WebSocket.new(addr)
 | 
			
		||||
    @conn.binaryType = "arraybuffer"
 | 
			
		||||
    @conn.onmessage = proc do |msg|
 | 
			
		||||
      # if msg.data is string
 | 
			
		||||
      if msg.data.instance_of? String
 | 
			
		||||
        # Set json to msg.data
 | 
			
		||||
        json = msg.data
 | 
			
		||||
      else
 | 
			
		||||
        # Set json to decoded msg.data
 | 
			
		||||
        json = @dec.decode(msg.data)
 | 
			
		||||
      end
 | 
			
		||||
      # Parse JSON string
 | 
			
		||||
      val = JSON.parse(json)
 | 
			
		||||
      # Get id from callMap
 | 
			
		||||
      fns = @callMap.get(val.ID)
 | 
			
		||||
      # If fns is undefined (key does not exist), and this is
 | 
			
		||||
      # a normal response, return
 | 
			
		||||
      return if !fns && val.Type == LRPCResponseType.Normal
 | 
			
		||||
 | 
			
		||||
      case val.Type
 | 
			
		||||
      when LRPCResponseType.Normal
 | 
			
		||||
        # If fns is a channel, send the value. Otherwise,
 | 
			
		||||
        # resolve the promise with the value.
 | 
			
		||||
        if fns.isChannel
 | 
			
		||||
          fns.send(val.Return)
 | 
			
		||||
        else
 | 
			
		||||
          fns.resolve(val.Return)
 | 
			
		||||
        end
 | 
			
		||||
      when LRPCResponseType.Channel
 | 
			
		||||
        # Get channel ID from response
 | 
			
		||||
        chID = val.Return
 | 
			
		||||
        # Create new LRPCChannel
 | 
			
		||||
        ch = LRPCChannel.new(self, chID)
 | 
			
		||||
        # Set channel in map
 | 
			
		||||
        @callMap.set(chID, ch)
 | 
			
		||||
        # Resolve promise with channel
 | 
			
		||||
        fns.resolve(ch)
 | 
			
		||||
      when LRPCResponseType.ChannelDone
 | 
			
		||||
        # Close and delete channel
 | 
			
		||||
        fns.close()
 | 
			
		||||
        @callMap.delete(val.ID)
 | 
			
		||||
      when LRPCResponseType.Error
 | 
			
		||||
        # Reject promise with error
 | 
			
		||||
        fns.reject(val.Error)
 | 
			
		||||
      end
 | 
			
		||||
 | 
			
		||||
      # Delete item from map unless it is a channel
 | 
			
		||||
      @callMap.delete(val.ID) unless fns.isChannel
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # call calls a method on the server with the given
 | 
			
		||||
  # argument and returns a promise.
 | 
			
		||||
  def callMethod(rcvr, method, arg)
 | 
			
		||||
    return Promise.new do |resolve, reject|
 | 
			
		||||
      # Get random UUID (this only works with TLS)
 | 
			
		||||
      id = crypto.randomUUID()
 | 
			
		||||
      # Add resolve/reject functions to callMap
 | 
			
		||||
      @callMap.set(id, {
 | 
			
		||||
        resolve: resolve,
 | 
			
		||||
        reject: reject,
 | 
			
		||||
      })
 | 
			
		||||
 | 
			
		||||
      # Encode data as JSON
 | 
			
		||||
      data = @enc.encode({
 | 
			
		||||
        Receiver: rcvr,
 | 
			
		||||
        Method: method,
 | 
			
		||||
        Arg: arg,
 | 
			
		||||
        ID: id,
 | 
			
		||||
      }.to_json())
 | 
			
		||||
 | 
			
		||||
      # Send data to lrpc server
 | 
			
		||||
      @conn.send(data.buffer)
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # getClient returns an object containing functions
 | 
			
		||||
  # corresponding to registered functions on the given
 | 
			
		||||
  # receiver. It uses the lrpc.Introspect() endpoint
 | 
			
		||||
  # to achieve this.
 | 
			
		||||
  def getObject(rcvr)
 | 
			
		||||
    return Promise.new do |resolve|
 | 
			
		||||
      # Introspect methods on given receiver
 | 
			
		||||
      self.callMethod("lrpc", "Introspect", rcvr).then do |methodDesc|
 | 
			
		||||
        # Create output object
 | 
			
		||||
        out = {}
 | 
			
		||||
        # For each method in description array
 | 
			
		||||
        methodDesc.each do |method|
 | 
			
		||||
          # Create and assign new function to call current method
 | 
			
		||||
          out[method.Name] = proc { |arg| return self.callMethod(rcvr, method.Name, arg) }
 | 
			
		||||
        end
 | 
			
		||||
        # Resolve promise with output promise
 | 
			
		||||
        resolve(out)
 | 
			
		||||
      end
 | 
			
		||||
    end
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
 | 
			
		||||
# LRPCChannel represents a channel used for lrpc.
 | 
			
		||||
class LRPCChannel
 | 
			
		||||
  def initialize(client, id)
 | 
			
		||||
    # Set self variables
 | 
			
		||||
    @client = client
 | 
			
		||||
    @id = id
 | 
			
		||||
    @closed = false
 | 
			
		||||
    # Set function variables to no-ops
 | 
			
		||||
    @onMessage = proc {|fn|}
 | 
			
		||||
    @onClose = proc {}
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # isChannel is defined to allow identifying whether
 | 
			
		||||
  # an object is a channel.
 | 
			
		||||
  def isChannel() end
 | 
			
		||||
 | 
			
		||||
  # send sends a value on the channel. This should not
 | 
			
		||||
  # be called by the consumer of the channel.
 | 
			
		||||
  def send(val)
 | 
			
		||||
    return if @closed
 | 
			
		||||
    fn = @onMessage
 | 
			
		||||
    fn(val)
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # done cancels the context corresponding to the channel
 | 
			
		||||
  # on the server side and closes the channel.
 | 
			
		||||
  def done()
 | 
			
		||||
    return if @closed
 | 
			
		||||
    @client.callMethod("lrpc", "ChannelDone", @id)
 | 
			
		||||
    self.close()
 | 
			
		||||
    @client._callMap.delete(@id)
 | 
			
		||||
  end
 | 
			
		||||
  
 | 
			
		||||
  # onMessage sets the callback to be called whenever a
 | 
			
		||||
  # message is received. The function should have one parameter
 | 
			
		||||
  # that will be set to the value received. Subsequent calls
 | 
			
		||||
  # will overwrite the callback
 | 
			
		||||
  def onMessage(fn)
 | 
			
		||||
    @onMessage = fn
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # onClose sets the callback to be called whenever the client
 | 
			
		||||
  # is closed. The function should have no parameters. 
 | 
			
		||||
  # Subsequent calls will overwrite the callback
 | 
			
		||||
  def onClose(fn)
 | 
			
		||||
    @onClose = fn
 | 
			
		||||
  end
 | 
			
		||||
 | 
			
		||||
  # close closes the channel. This should not be called by the
 | 
			
		||||
  # consumer of the channel. Use done() instead.
 | 
			
		||||
  def close()
 | 
			
		||||
    return if @closed
 | 
			
		||||
    fn = @onClose
 | 
			
		||||
    fn()
 | 
			
		||||
    @closed = true
 | 
			
		||||
  end
 | 
			
		||||
end
 | 
			
		||||
@@ -19,6 +19,7 @@
 | 
			
		||||
package codec
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/gob"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -38,42 +39,76 @@ type CodecFunc func(io.ReadWriter) Codec
 | 
			
		||||
type Codec interface {
 | 
			
		||||
	Encode(val any) error
 | 
			
		||||
	Decode(val any) error
 | 
			
		||||
	Unmarshal(data []byte, v any) error
 | 
			
		||||
	Marshal(v any) ([]byte, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Default is the default CodecFunc
 | 
			
		||||
var Default = Msgpack
 | 
			
		||||
 | 
			
		||||
// JSON is a CodecFunc that creates a JSON Codec
 | 
			
		||||
func JSON(rw io.ReadWriter) Codec {
 | 
			
		||||
	type jsonCodec struct {
 | 
			
		||||
type JsonCodec struct {
 | 
			
		||||
	*json.Encoder
 | 
			
		||||
	*json.Decoder
 | 
			
		||||
	}
 | 
			
		||||
	return jsonCodec{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (JsonCodec) Unmarshal(data []byte, v any) error {
 | 
			
		||||
	return json.Unmarshal(data, v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (JsonCodec) Marshal(v any) ([]byte, error) {
 | 
			
		||||
	return json.Marshal(v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JSON is a CodecFunc that creates a JSON Codec
 | 
			
		||||
func JSON(rw io.ReadWriter) Codec {
 | 
			
		||||
	return JsonCodec{
 | 
			
		||||
		Encoder: json.NewEncoder(rw),
 | 
			
		||||
		Decoder: json.NewDecoder(rw),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Msgpack is a CodecFunc that creates a Msgpack Codec
 | 
			
		||||
func Msgpack(rw io.ReadWriter) Codec {
 | 
			
		||||
	type msgpackCodec struct {
 | 
			
		||||
type MsgpackCodec struct {
 | 
			
		||||
	*msgpack.Encoder
 | 
			
		||||
	*msgpack.Decoder
 | 
			
		||||
	}
 | 
			
		||||
	return msgpackCodec{
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (MsgpackCodec) Unmarshal(data []byte, v any) error {
 | 
			
		||||
	return msgpack.Unmarshal(data, v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (MsgpackCodec) Marshal(v any) ([]byte, error) {
 | 
			
		||||
	return msgpack.Marshal(v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Msgpack is a CodecFunc that creates a Msgpack Codec
 | 
			
		||||
func Msgpack(rw io.ReadWriter) Codec {
 | 
			
		||||
	return MsgpackCodec{
 | 
			
		||||
		Encoder: msgpack.NewEncoder(rw),
 | 
			
		||||
		Decoder: msgpack.NewDecoder(rw),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Gob is a CodecFunc that creates a Gob Codec
 | 
			
		||||
func Gob(rw io.ReadWriter) Codec {
 | 
			
		||||
	type gobCodec struct {
 | 
			
		||||
type GobCodec struct {
 | 
			
		||||
	*gob.Encoder
 | 
			
		||||
	*gob.Decoder
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (GobCodec) Unmarshal(data []byte, v any) error {
 | 
			
		||||
	return gob.NewDecoder(bytes.NewReader(data)).Decode(v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (GobCodec) Marshal(v any) ([]byte, error) {
 | 
			
		||||
	buf := &bytes.Buffer{}
 | 
			
		||||
	err := gob.NewEncoder(buf).Encode(v)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return gobCodec{
 | 
			
		||||
	return buf.Bytes(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Gob is a CodecFunc that creates a Gob Codec
 | 
			
		||||
func Gob(rw io.ReadWriter) Codec {
 | 
			
		||||
	return GobCodec{
 | 
			
		||||
		Encoder: gob.NewEncoder(rw),
 | 
			
		||||
		Decoder: gob.NewDecoder(rw),
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,186 +0,0 @@
 | 
			
		||||
/*
 | 
			
		||||
 *	lrpc allows for clients to call functions on a server remotely.
 | 
			
		||||
 *	Copyright (C) 2022 Arsen Musayelyan
 | 
			
		||||
 *
 | 
			
		||||
 *	This program is free software: you can redistribute it and/or modify
 | 
			
		||||
 *	it under the terms of the GNU General Public License as published by
 | 
			
		||||
 *	the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
 *	(at your option) any later version.
 | 
			
		||||
 *
 | 
			
		||||
 *	This program is distributed in the hope that it will be useful,
 | 
			
		||||
 *	but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
 *	MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
 *	GNU General Public License for more details.
 | 
			
		||||
 *
 | 
			
		||||
 *	You should have received a copy of the GNU General Public License
 | 
			
		||||
 *	along with this program.  If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
package reflectutil
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"github.com/mitchellh/mapstructure"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// <= go1.17 compatibility
 | 
			
		||||
type any = interface{}
 | 
			
		||||
 | 
			
		||||
// Convert attempts to convert the given value to the given type
 | 
			
		||||
func Convert(in reflect.Value, toType reflect.Type) (reflect.Value, error) {
 | 
			
		||||
	// Get input type
 | 
			
		||||
	inType := in.Type()
 | 
			
		||||
 | 
			
		||||
	// If input is already the desired type, return
 | 
			
		||||
	if inType == toType {
 | 
			
		||||
		return in, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the output type is a pointer to the input type
 | 
			
		||||
	if reflect.PtrTo(inType) == toType {
 | 
			
		||||
		if in.CanAddr() {
 | 
			
		||||
			// Return pointer to input
 | 
			
		||||
			return in.Addr(), nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		inPtrVal := reflect.New(inType)
 | 
			
		||||
		inPtrVal.Elem().Set(in)
 | 
			
		||||
		return inPtrVal, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If input is a pointer pointing to the output type
 | 
			
		||||
	if inType.Kind() == reflect.Ptr && inType.Elem() == toType {
 | 
			
		||||
		// Return value being pointed at by input
 | 
			
		||||
		return reflect.Indirect(in), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If input can be converted to desired type, convert and return
 | 
			
		||||
	if in.CanConvert(toType) {
 | 
			
		||||
		return in.Convert(toType), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create new value of desired type
 | 
			
		||||
	to := reflect.New(toType).Elem()
 | 
			
		||||
 | 
			
		||||
	// If type is a pointer
 | 
			
		||||
	if to.Kind() == reflect.Ptr {
 | 
			
		||||
		// Initialize value
 | 
			
		||||
		to.Set(reflect.New(to.Type().Elem()))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch val := in.Interface().(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		// If desired type satisfies text unmarshaler
 | 
			
		||||
		if u, ok := to.Interface().(encoding.TextUnmarshaler); ok {
 | 
			
		||||
			// Use text unmarshaler to get value
 | 
			
		||||
			err := u.UnmarshalText([]byte(val))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return reflect.Value{}, err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Return unmarshaled value
 | 
			
		||||
			return reflect.ValueOf(any(u)), nil
 | 
			
		||||
		}
 | 
			
		||||
	case []byte:
 | 
			
		||||
		// If desired type satisfies binary unmarshaler
 | 
			
		||||
		if u, ok := to.Interface().(encoding.BinaryUnmarshaler); ok {
 | 
			
		||||
			// Use binary unmarshaler to get value
 | 
			
		||||
			err := u.UnmarshalBinary(val)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return reflect.Value{}, err
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Return unmarshaled value
 | 
			
		||||
			return reflect.ValueOf(any(u)), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If input is a map
 | 
			
		||||
	if in.Kind() == reflect.Map {
 | 
			
		||||
		// Use mapstructure to decode value
 | 
			
		||||
		err := mapstructure.Decode(in.Interface(), to.Addr().Interface())
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return to, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If input is a slice of any, and output is an array or slice
 | 
			
		||||
	if in.Type() == reflect.TypeOf([]any{}) &&
 | 
			
		||||
		to.Kind() == reflect.Slice || to.Kind() == reflect.Array {
 | 
			
		||||
		// Use ConvertSlice to convert value
 | 
			
		||||
		return reflect.ValueOf(ConvertSlice(
 | 
			
		||||
			in.Interface().([]any),
 | 
			
		||||
			toType,
 | 
			
		||||
		)), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return to, fmt.Errorf("cannot convert %s to %s", inType, toType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConvertSlice converts []any to an array or slice, as provided
 | 
			
		||||
// in the "to" argument.
 | 
			
		||||
func ConvertSlice(in []any, to reflect.Type) any {
 | 
			
		||||
	// Create new value for output
 | 
			
		||||
	out := reflect.New(to).Elem()
 | 
			
		||||
 | 
			
		||||
	// Get type of slice elements
 | 
			
		||||
	outType := out.Type().Elem()
 | 
			
		||||
 | 
			
		||||
	// If output value is a slice
 | 
			
		||||
	if out.Kind() == reflect.Slice {
 | 
			
		||||
		// For every value provided
 | 
			
		||||
		for i := 0; i < len(in); i++ {
 | 
			
		||||
			// Get value of input type
 | 
			
		||||
			inVal := reflect.ValueOf(in[i])
 | 
			
		||||
			// Create new output type
 | 
			
		||||
			outVal := reflect.New(outType).Elem()
 | 
			
		||||
 | 
			
		||||
			// If types match
 | 
			
		||||
			if inVal.Type() == outType {
 | 
			
		||||
				// Set output value to input value
 | 
			
		||||
				outVal.Set(inVal)
 | 
			
		||||
			} else {
 | 
			
		||||
				newVal, err := Convert(inVal, outType)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					// Set output value to its zero value
 | 
			
		||||
					outVal.Set(reflect.Zero(outVal.Type()))
 | 
			
		||||
				} else {
 | 
			
		||||
					outVal.Set(newVal)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Append output value to slice
 | 
			
		||||
			out = reflect.Append(out, outVal)
 | 
			
		||||
		}
 | 
			
		||||
	} else if out.Kind() == reflect.Array && out.Len() == len(in) {
 | 
			
		||||
		//If output type is array and lengths match
 | 
			
		||||
 | 
			
		||||
		// For every input value
 | 
			
		||||
		for i := 0; i < len(in); i++ {
 | 
			
		||||
			// Get matching output index
 | 
			
		||||
			outVal := out.Index(i)
 | 
			
		||||
			// Get input value
 | 
			
		||||
			inVal := reflect.ValueOf(in[i])
 | 
			
		||||
 | 
			
		||||
			// If types match
 | 
			
		||||
			if inVal.Type() == outVal.Type() {
 | 
			
		||||
				// Set output value to input value
 | 
			
		||||
				outVal.Set(inVal)
 | 
			
		||||
			} else {
 | 
			
		||||
				newVal, err := Convert(inVal, outType)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					// Set output value to its zero value
 | 
			
		||||
					outVal.Set(reflect.Zero(outVal.Type()))
 | 
			
		||||
				} else {
 | 
			
		||||
					outVal.Set(newVal)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Return created value
 | 
			
		||||
	return out.Interface()
 | 
			
		||||
}
 | 
			
		||||
@@ -26,7 +26,7 @@ type Request struct {
 | 
			
		||||
	ID       string
 | 
			
		||||
	Receiver string
 | 
			
		||||
	Method   string
 | 
			
		||||
	Arg      any
 | 
			
		||||
	Arg      []byte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ResponseType uint8
 | 
			
		||||
@@ -43,5 +43,5 @@ type Response struct {
 | 
			
		||||
	Type   ResponseType
 | 
			
		||||
	ID     string
 | 
			
		||||
	Error  string
 | 
			
		||||
	Return any
 | 
			
		||||
	Return []byte
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										78
									
								
								lrpc_test.go
									
									
									
									
									
								
							
							
						
						
									
										78
									
								
								lrpc_test.go
									
									
									
									
									
								
							@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"encoding/gob"
 | 
			
		||||
	"net"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"go.arsenm.dev/lrpc/client"
 | 
			
		||||
	"go.arsenm.dev/lrpc/codec"
 | 
			
		||||
@@ -132,3 +133,80 @@ func TestCodecs(t *testing.T) {
 | 
			
		||||
	testCodec(codec.JSON, "json")
 | 
			
		||||
	testCodec(codec.Gob, "gob")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Channel struct{}
 | 
			
		||||
 | 
			
		||||
func (Channel) Time(ctx *server.Context, interval time.Duration) error {
 | 
			
		||||
	ch, err := ctx.MakeChannel()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tick := time.NewTicker(interval)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			select {
 | 
			
		||||
			case t := <-tick.C:
 | 
			
		||||
				ch <- t
 | 
			
		||||
			case <-ctx.Done():
 | 
			
		||||
				tick.Stop()
 | 
			
		||||
				close(ch)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestChannel(t *testing.T) {
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	// Create new network pipe
 | 
			
		||||
	sConn, cConn := net.Pipe()
 | 
			
		||||
 | 
			
		||||
	s := server.New()
 | 
			
		||||
	defer s.Close()
 | 
			
		||||
	// Register Arith for RPC
 | 
			
		||||
	s.Register(Channel{})
 | 
			
		||||
	// Serve the pipe connection using default codec
 | 
			
		||||
	go s.ServeConn(ctx, sConn, codec.Default)
 | 
			
		||||
 | 
			
		||||
	// Create new client using default codec
 | 
			
		||||
	c := client.New(cConn, codec.Default)
 | 
			
		||||
	defer c.Close()
 | 
			
		||||
 | 
			
		||||
	timeCtx, timeCancel := context.WithCancel(ctx)
 | 
			
		||||
	defer timeCancel()
 | 
			
		||||
 | 
			
		||||
	timeCh := make(chan *time.Time, 2)
 | 
			
		||||
	err := c.Call(timeCtx, "Channel", "Time", time.Millisecond, timeCh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var loops int
 | 
			
		||||
	var lastTime *time.Time
 | 
			
		||||
	for curTime := range timeCh {
 | 
			
		||||
		if loops > 3 {
 | 
			
		||||
			timeCancel()
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if lastTime == nil {
 | 
			
		||||
			lastTime = curTime
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		diff := curTime.Sub(*lastTime)
 | 
			
		||||
		diff = diff.Round(time.Millisecond)
 | 
			
		||||
 | 
			
		||||
		if diff != time.Millisecond {
 | 
			
		||||
			t.Fatalf("expected 1s diff, got %s", diff)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		lastTime = curTime
 | 
			
		||||
		loops++
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -28,7 +28,6 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"go.arsenm.dev/lrpc/codec"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/reflectutil"
 | 
			
		||||
	"go.arsenm.dev/lrpc/internal/types"
 | 
			
		||||
	"golang.org/x/net/websocket"
 | 
			
		||||
)
 | 
			
		||||
@@ -41,7 +40,7 @@ var (
 | 
			
		||||
	ErrNoSuchReceiver = errors.New("no such receiver registered")
 | 
			
		||||
	ErrNoSuchMethod   = errors.New("no such method was found")
 | 
			
		||||
	ErrInvalidMethod  = errors.New("method invalid for lrpc call")
 | 
			
		||||
	ErrUnexpectedArgument = errors.New("argument provided but the function does not accept any arguments")
 | 
			
		||||
	ErrArgNotProvided = errors.New("method expected an argument, but none was provided")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Server is an lrpc server
 | 
			
		||||
@@ -100,7 +99,7 @@ func (s *Server) Register(v any) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// execute runs a method of a registered value
 | 
			
		||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
func (s *Server) execute(pCtx context.Context, typ string, name string, data []byte, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
	// Try to get value from receivers map
 | 
			
		||||
	val, ok := s.rcvrs[typ]
 | 
			
		||||
	if !ok {
 | 
			
		||||
@@ -121,29 +120,18 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
 | 
			
		||||
	// Get method type
 | 
			
		||||
	mtdType := mtd.Type()
 | 
			
		||||
 | 
			
		||||
	// Return error if argument provided but isn't expected
 | 
			
		||||
	if mtdType.NumIn() == 1 && arg != nil {
 | 
			
		||||
		return nil, nil, ErrUnexpectedArgument
 | 
			
		||||
	}
 | 
			
		||||
	//TODO: if arg not nil but fn has no arg, err
 | 
			
		||||
 | 
			
		||||
	// IF argument is []any
 | 
			
		||||
	anySlice, ok := arg.([]any)
 | 
			
		||||
	if ok {
 | 
			
		||||
		// Convert slice to the method's arg type and
 | 
			
		||||
		// set arg to the newly-converted slice
 | 
			
		||||
		arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
 | 
			
		||||
	}
 | 
			
		||||
	argType := mtdType.In(1)
 | 
			
		||||
	argVal := reflect.New(argType)
 | 
			
		||||
	arg := argVal.Interface()
 | 
			
		||||
 | 
			
		||||
	// Get argument value
 | 
			
		||||
	argVal := reflect.ValueOf(arg)
 | 
			
		||||
	// If argument's type does not match method's argument type
 | 
			
		||||
	if arg != nil && argVal.Type() != mtdType.In(1) {
 | 
			
		||||
		val, err = reflectutil.Convert(argVal, mtdType.In(1))
 | 
			
		||||
	err = c.Unmarshal(data, arg)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
		arg = val.Interface()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	arg = argVal.Elem().Interface()
 | 
			
		||||
 | 
			
		||||
	ctx = newContext(pCtx, c)
 | 
			
		||||
	// Get reflect value of context
 | 
			
		||||
@@ -152,6 +140,9 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
 | 
			
		||||
	switch mtdType.NumOut() {
 | 
			
		||||
	case 0: // If method has no return values
 | 
			
		||||
		if mtdType.NumIn() == 2 {
 | 
			
		||||
			if arg == nil {
 | 
			
		||||
				return nil, nil, ErrArgNotProvided
 | 
			
		||||
			}
 | 
			
		||||
			// Call method with arg, ignore returned value
 | 
			
		||||
			mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -160,6 +151,10 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
 | 
			
		||||
		}
 | 
			
		||||
	case 1: // If method has one return value
 | 
			
		||||
		if mtdType.NumIn() == 2 {
 | 
			
		||||
			if arg == nil {
 | 
			
		||||
				return nil, nil, ErrArgNotProvided
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Call method with arg, get returned values
 | 
			
		||||
			out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
 | 
			
		||||
 | 
			
		||||
@@ -194,6 +189,10 @@ func (s *Server) execute(pCtx context.Context, typ string, name string, arg any,
 | 
			
		||||
		}
 | 
			
		||||
	case 2: // If method has two return values
 | 
			
		||||
		if mtdType.NumIn() == 2 {
 | 
			
		||||
			if arg == nil {
 | 
			
		||||
				return nil, nil, ErrArgNotProvided
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Call method with arg and get returned values
 | 
			
		||||
			out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
 | 
			
		||||
 | 
			
		||||
@@ -304,6 +303,7 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		go func() {
 | 
			
		||||
			// Execute decoded call
 | 
			
		||||
			val, ctx, err := s.execute(
 | 
			
		||||
				pCtx,
 | 
			
		||||
@@ -315,18 +315,30 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				s.sendErr(c, call, val, err)
 | 
			
		||||
			} else {
 | 
			
		||||
				valData, err := c.Marshal(val)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					s.sendErr(c, call, val, err)
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Create response
 | 
			
		||||
				res := types.Response{
 | 
			
		||||
					ID:     call.ID,
 | 
			
		||||
				Return: val,
 | 
			
		||||
					Return: valData,
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// If function has created a channel
 | 
			
		||||
				if ctx.isChannel {
 | 
			
		||||
					idData, err := c.Marshal(ctx.channelID)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						s.sendErr(c, call, val, err)
 | 
			
		||||
						return
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// Set IsChannel to true
 | 
			
		||||
					res.Type = types.ResponseTypeChannel
 | 
			
		||||
					// Overwrite return value with channel ID
 | 
			
		||||
				res.Return = ctx.channelID
 | 
			
		||||
					res.Return = idData
 | 
			
		||||
 | 
			
		||||
					// Store context in map for future use
 | 
			
		||||
					s.contextsMtx.Lock()
 | 
			
		||||
@@ -337,11 +349,18 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
						// For every value received from channel
 | 
			
		||||
						for val := range ctx.channel {
 | 
			
		||||
							codecMtx.Lock()
 | 
			
		||||
 | 
			
		||||
							valData, err := c.Marshal(val)
 | 
			
		||||
							if err != nil {
 | 
			
		||||
								continue
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							// Encode response using codec
 | 
			
		||||
							c.Encode(types.Response{
 | 
			
		||||
								ID:     ctx.channelID,
 | 
			
		||||
							Return: val,
 | 
			
		||||
								Return: valData,
 | 
			
		||||
							})
 | 
			
		||||
 | 
			
		||||
							codecMtx.Unlock()
 | 
			
		||||
						}
 | 
			
		||||
 | 
			
		||||
@@ -366,17 +385,21 @@ func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
				c.Encode(res)
 | 
			
		||||
				codecMtx.Unlock()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// sendErr sends an error response
 | 
			
		||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
 | 
			
		||||
	valData, _ := c.Marshal(val)
 | 
			
		||||
 | 
			
		||||
	// Encode error response using codec
 | 
			
		||||
	c.Encode(types.Response{
 | 
			
		||||
		Type:   types.ResponseTypeError,
 | 
			
		||||
		ID:     req.ID,
 | 
			
		||||
		Error:  err.Error(),
 | 
			
		||||
		Return: val,
 | 
			
		||||
		Return: valData,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user