198 lines
4.2 KiB
Go
198 lines
4.2 KiB
Go
package client
|
|
|
|
import (
	"errors"
	"io"
	"net"
	"reflect"
	"sync"

	"go.arsenm.dev/lrpc/codec"
	"go.arsenm.dev/lrpc/internal/reflectutil"
	"go.arsenm.dev/lrpc/internal/types"

	"github.com/gofrs/uuid"
)
|
|
|
|
// any is an alias for the empty interface, declared here for
// compatibility with Go 1.17 and earlier.
type any = interface{}
|
|
|
|
// Error values returned by the client.
var (
	// ErrReturnNotChannel is returned when the server responds with a
	// channel but the provided return value is not a channel.
	ErrReturnNotChannel = errors.New("function call returns channel but return value is not a channel type")

	// ErrReturnNotPointer is returned when the server responds with a
	// value but the provided return value is not a pointer.
	ErrReturnNotPointer = errors.New("function call returns value but return value is not a pointer")

	// ErrMismatchedType reports that a channel's type does not match the
	// type returned by the server.
	ErrMismatchedType = errors.New("type of channel does not match type returned by server")
)
|
|
|
|
// Client is an lrpc client
|
|
type Client struct {
|
|
conn net.Conn
|
|
codec codec.Codec
|
|
|
|
chMtx sync.Mutex
|
|
chs map[string]chan *types.Response
|
|
}
|
|
|
|
// New creates and returns a new client
|
|
func New(conn net.Conn, cf codec.CodecFunc) *Client {
|
|
out := &Client{
|
|
conn: conn,
|
|
codec: cf(conn),
|
|
chs: map[string]chan *types.Response{},
|
|
}
|
|
|
|
go out.handleConn()
|
|
|
|
return out
|
|
}
|
|
|
|
// Call calls a method on the server
|
|
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
|
|
// Create new v4 UUOD
|
|
id, err := uuid.NewV4()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idStr := id.String()
|
|
|
|
// Create new channel using the generated ID
|
|
c.chMtx.Lock()
|
|
c.chs[idStr] = make(chan *types.Response, 1)
|
|
c.chMtx.Unlock()
|
|
|
|
// Encode request using codec
|
|
err = c.codec.Encode(types.Request{
|
|
ID: idStr,
|
|
Receiver: rcvr,
|
|
Method: method,
|
|
Arg: arg,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get response from channel
|
|
resp := <-c.chs[idStr]
|
|
|
|
// Close and delete channel
|
|
c.chMtx.Lock()
|
|
close(c.chs[idStr])
|
|
delete(c.chs, idStr)
|
|
c.chMtx.Unlock()
|
|
|
|
// If response is an error, return error
|
|
if resp.IsError {
|
|
return errors.New(resp.Error)
|
|
}
|
|
|
|
// If there is no return value, stop now
|
|
if resp.Return == nil {
|
|
return nil
|
|
}
|
|
|
|
// Get reflect value of return value
|
|
retVal := reflect.ValueOf(ret)
|
|
|
|
// If response is a channel
|
|
if resp.IsChannel {
|
|
// If return value is not a channel, return error
|
|
if retVal.Kind() != reflect.Chan {
|
|
return ErrReturnNotChannel
|
|
}
|
|
// Get channel ID returned in response
|
|
chID := resp.Return.(string)
|
|
|
|
// Create new channel using channel ID
|
|
c.chMtx.Lock()
|
|
c.chs[chID] = make(chan *types.Response, 5)
|
|
c.chMtx.Unlock()
|
|
|
|
go func() {
|
|
// Get type of channel elements
|
|
chElemType := retVal.Type().Elem()
|
|
// For every value received from channel
|
|
for val := range c.chs[chID] {
|
|
// Get reflect value from channel response
|
|
rVal := reflect.ValueOf(val.Return)
|
|
|
|
// If return value is not the same as the channel
|
|
if rVal.Type() != chElemType {
|
|
// Attempt to convert value, skip if impossible
|
|
newVal, err := reflectutil.Convert(rVal, chElemType)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
rVal = newVal
|
|
}
|
|
|
|
// Try to read from the channel
|
|
recvVal, ok := retVal.TryRecv()
|
|
// IF the channel cannot be read but the value is valid,
|
|
// the channel must be closed
|
|
if !ok && recvVal.IsValid() {
|
|
// Send done signal
|
|
c.Call("lrpc", "ChannelDone", idStr, nil)
|
|
// Close and delete channel
|
|
c.chMtx.Lock()
|
|
close(c.chs[chID])
|
|
delete(c.chs, chID)
|
|
c.chMtx.Unlock()
|
|
break
|
|
}
|
|
|
|
// Send value to channel
|
|
retVal.Send(rVal)
|
|
}
|
|
}()
|
|
} else {
|
|
// IF return value is not a pointer, return error
|
|
if retVal.Kind() != reflect.Pointer {
|
|
return ErrReturnNotPointer
|
|
}
|
|
|
|
// Get return type
|
|
retType := retVal.Type().Elem()
|
|
// Get refkect value from response
|
|
rVal := reflect.ValueOf(resp.Return)
|
|
|
|
// If types do not match
|
|
if rVal.Type() != retType {
|
|
// Attempt to convert types, return error if not possible
|
|
newVal, err := reflectutil.Convert(rVal, retType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rVal = newVal
|
|
}
|
|
|
|
// Set return value to received value
|
|
retVal.Elem().Set(rVal)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) handleConn() {
|
|
for {
|
|
resp := &types.Response{}
|
|
// Attempt to decode response using codec
|
|
err := c.codec.Decode(resp)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Get channel from map, skip if it doesn't exist
|
|
ch, ok := c.chs[resp.ID]
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
// Send response to channel
|
|
ch <- resp
|
|
}
|
|
}
|
|
|
|
// Close closes the client
|
|
func (c *Client) Close() error {
|
|
return c.conn.Close()
|
|
}
|