Initial Commit
This commit is contained in:
53
server/context.go
Normal file
53
server/context.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
|
||||
"github.com/gofrs/uuid"
|
||||
)
|
||||
|
||||
// Context is a connection context for RPC calls
|
||||
type Context struct {
|
||||
isChannel bool
|
||||
channelID string
|
||||
channel chan any
|
||||
|
||||
codec codec.Codec
|
||||
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// MakeChannel changes the function it's called in into a
|
||||
// channel function, and returns a channel which can be used
|
||||
// to send information to the client.
|
||||
//
|
||||
// This will ovewrite any return value of the function with
|
||||
// a channel ID.
|
||||
func (ctx *Context) MakeChannel() (chan<- any, error) {
|
||||
ctx.isChannel = true
|
||||
chID, err := uuid.NewV4()
|
||||
ctx.channelID = chID.String()
|
||||
ctx.channel = make(chan any, 5)
|
||||
return ctx.channel, err
|
||||
}
|
||||
|
||||
// GetCodec returns a codec bound to the connection
|
||||
// that called this function
|
||||
func (ctx *Context) GetCodec() codec.Codec {
|
||||
return ctx.codec
|
||||
}
|
||||
|
||||
// Done returns a channel that will be closed when
|
||||
// the context is canceled, such as when ChannelDone
|
||||
// is called by the client
|
||||
func (ctx *Context) Done() <-chan struct{} {
|
||||
return ctx.doneCh
|
||||
}
|
||||
|
||||
// Cancel cancels the context
|
||||
func (ctx *Context) Cancel() {
|
||||
close(ctx.doneCh)
|
||||
if ctx.channel != nil {
|
||||
close(ctx.channel)
|
||||
}
|
||||
}
|
||||
326
server/server.go
Normal file
326
server/server.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"go.arsenm.dev/lrpc/codec"
|
||||
"go.arsenm.dev/lrpc/internal/reflectutil"
|
||||
"go.arsenm.dev/lrpc/internal/types"
|
||||
)
|
||||
|
||||
// <= go1.17 compatibility
|
||||
type any = interface{}
|
||||
|
||||
var (
|
||||
ErrInvalidType = errors.New("type must be struct or pointer to struct")
|
||||
ErrTooManyInputs = errors.New("method may not have more than two inputs")
|
||||
ErrTooManyOutputs = errors.New("method may not have more than two return values")
|
||||
ErrNoSuchReceiver = errors.New("no such receiver registered")
|
||||
ErrNoSuchMethod = errors.New("no such method was found")
|
||||
ErrInvalidSecondReturn = errors.New("second return value must be error")
|
||||
ErrInvalidFirstInput = errors.New("first input must be *Context")
|
||||
)
|
||||
|
||||
// Server is an lrpc server
|
||||
type Server struct {
|
||||
rcvrs map[string]reflect.Value
|
||||
|
||||
contextsMtx sync.Mutex
|
||||
contexts map[string]*Context
|
||||
}
|
||||
|
||||
// New creates and returns a new server
|
||||
func New() *Server {
|
||||
// Create new server
|
||||
out := &Server{
|
||||
rcvrs: map[string]reflect.Value{},
|
||||
contexts: map[string]*Context{},
|
||||
}
|
||||
|
||||
// Register lrpc functions
|
||||
out.Register(lrpc{out})
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Close closes the server
|
||||
func (s *Server) Close() {
|
||||
for _, ctx := range s.contexts {
|
||||
ctx.Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a value to be called by a client
|
||||
func (s *Server) Register(v any) error {
|
||||
// Get reflect values for v
|
||||
val := reflect.ValueOf(v)
|
||||
kind := val.Kind()
|
||||
|
||||
// create variable to store name of v
|
||||
var name string
|
||||
switch kind {
|
||||
case reflect.Pointer:
|
||||
// If v is a pointer, get the name of the underlying type
|
||||
name = val.Elem().Type().Name()
|
||||
case reflect.Struct:
|
||||
// If v is a struct, get its name
|
||||
name = val.Type().Name()
|
||||
default:
|
||||
// If v is not pointer or struct, return error
|
||||
return ErrInvalidType
|
||||
}
|
||||
|
||||
// Add v to receivers map
|
||||
s.rcvrs[name] = val
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// execute runs a method of a registered value
|
||||
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
// Try to get value from receivers map
|
||||
val, ok := s.rcvrs[typ]
|
||||
if !ok {
|
||||
return nil, nil, ErrNoSuchReceiver
|
||||
}
|
||||
|
||||
// Try to retrieve given method
|
||||
mtd := val.MethodByName(name)
|
||||
if !mtd.IsValid() {
|
||||
return nil, nil, ErrNoSuchMethod
|
||||
}
|
||||
|
||||
// Get method's type
|
||||
mtdType := mtd.Type()
|
||||
if mtdType.NumIn() > 2 {
|
||||
return nil, nil, ErrTooManyInputs
|
||||
} else if mtdType.NumIn() < 1 {
|
||||
return nil, nil, ErrInvalidFirstInput
|
||||
}
|
||||
|
||||
if mtdType.NumOut() > 2 {
|
||||
return nil, nil, ErrTooManyOutputs
|
||||
}
|
||||
|
||||
// Check to ensure first parameter is context
|
||||
if mtdType.In(0) != reflect.TypeOf(&Context{}) {
|
||||
return nil, nil, ErrInvalidFirstInput
|
||||
}
|
||||
|
||||
//TODO: if arg not nil but fn has no arg, err
|
||||
|
||||
// IF argument is []any
|
||||
anySlice, ok := arg.([]any)
|
||||
if ok {
|
||||
// Convert slice to the method's arg type and
|
||||
// set arg to the newly-converted slice
|
||||
arg = reflectutil.ConvertSlice(anySlice, mtdType.In(1))
|
||||
}
|
||||
|
||||
// Get argument value
|
||||
argVal := reflect.ValueOf(arg)
|
||||
// If argument's type does not match method's argument type
|
||||
if arg != nil && argVal.Type() != mtdType.In(1) {
|
||||
// If it is possible to convert the arg to desired type
|
||||
if argVal.CanConvert(mtdType.In(1)) {
|
||||
// Convert and set arg to result
|
||||
arg = argVal.Convert(mtdType.In(1)).Interface()
|
||||
}
|
||||
//TODO: Invalid value err
|
||||
}
|
||||
|
||||
// Create new context
|
||||
ctx = &Context{
|
||||
doneCh: make(chan struct{}, 1),
|
||||
codec: c,
|
||||
}
|
||||
// Get reflect value of context
|
||||
ctxVal := reflect.ValueOf(ctx)
|
||||
|
||||
switch mtdType.NumOut() {
|
||||
case 0: // If method has no return values
|
||||
if mtdType.NumIn() == 2 {
|
||||
// Call method with arg, ignore returned value
|
||||
mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
} else {
|
||||
// Call method without arg, ignore returned value
|
||||
mtd.Call([]reflect.Value{ctxVal})
|
||||
}
|
||||
case 1: // If method has one return value
|
||||
if mtdType.NumIn() == 2 {
|
||||
// Call method with arg, get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
|
||||
// If the first return value's type is error
|
||||
if mtdType.Out(0).Name() == "error" {
|
||||
// Get first return value as interface
|
||||
out0 := out[0].Interface()
|
||||
if out0 == nil {
|
||||
a, err = nil, nil
|
||||
} else {
|
||||
a, err = nil, out0.(error)
|
||||
}
|
||||
} else {
|
||||
a, err = out[0].Interface(), nil
|
||||
}
|
||||
} else {
|
||||
// Call method without arg, get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal})
|
||||
|
||||
// If the first return value's type is error
|
||||
if mtdType.Out(0).Name() == "error" {
|
||||
// Get first return value as interface
|
||||
out0 := out[0].Interface()
|
||||
if out0 == nil {
|
||||
a, err = nil, nil
|
||||
} else {
|
||||
a, err = nil, out0.(error)
|
||||
}
|
||||
} else {
|
||||
a, err = out[0].Interface(), nil
|
||||
}
|
||||
}
|
||||
case 2: // If method has two return values
|
||||
if mtdType.NumIn() == 2 {
|
||||
// Call method with arg and get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal, reflect.ValueOf(arg)})
|
||||
|
||||
// Get second return value as interface
|
||||
out1 := out[1].Interface()
|
||||
if out1 != nil {
|
||||
err, ok = out1.(error)
|
||||
|
||||
// If second return value is not an error, the function is invalid
|
||||
if !ok {
|
||||
a, err = nil, ErrInvalidSecondReturn
|
||||
}
|
||||
}
|
||||
|
||||
a = out[0].Interface()
|
||||
} else {
|
||||
// Call method without arg and get returned values
|
||||
out := mtd.Call([]reflect.Value{ctxVal})
|
||||
|
||||
// Get second return value as interface
|
||||
out1 := out[1].Interface()
|
||||
if out1 != nil {
|
||||
|
||||
// If second return value is not an error, the function is invalid
|
||||
err, ok = out1.(error)
|
||||
if !ok {
|
||||
a, err = nil, ErrInvalidSecondReturn
|
||||
}
|
||||
}
|
||||
|
||||
a = out[0].Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return a, ctx, err
|
||||
}
|
||||
|
||||
// Serve starts the server using the provided listener
|
||||
// and codec function
|
||||
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new instance of codec bound to conn
|
||||
c := cf(conn)
|
||||
// Handle connection
|
||||
go s.handleConn(c)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn handles a listener connection
|
||||
func (s *Server) handleConn(c codec.Codec) {
|
||||
for {
|
||||
var call types.Request
|
||||
// Read request using codec
|
||||
err := c.Decode(&call)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, nil, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Execute decoded call
|
||||
val, ctx, err := s.execute(
|
||||
call.Receiver,
|
||||
call.Method,
|
||||
call.Arg,
|
||||
c,
|
||||
)
|
||||
if err != nil {
|
||||
s.sendErr(c, call, val, err)
|
||||
} else {
|
||||
// Create response
|
||||
res := types.Response{
|
||||
ID: call.ID,
|
||||
Return: val,
|
||||
}
|
||||
|
||||
// If function has created a channel
|
||||
if ctx.isChannel {
|
||||
// Set IsChannel to true
|
||||
res.IsChannel = true
|
||||
// Overwrite return value with channel ID
|
||||
res.Return = ctx.channelID
|
||||
|
||||
// Store context in map for future use
|
||||
s.contextsMtx.Lock()
|
||||
s.contexts[ctx.channelID] = ctx
|
||||
s.contextsMtx.Unlock()
|
||||
|
||||
go func() {
|
||||
// For every value received from channel
|
||||
for val := range ctx.channel {
|
||||
// Encode response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: ctx.channelID,
|
||||
Return: val,
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Encode response using codec
|
||||
c.Encode(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendErr sends an error response
|
||||
func (s *Server) sendErr(c codec.Codec, req types.Request, val any, err error) {
|
||||
// Encode error response using codec
|
||||
c.Encode(types.Response{
|
||||
ID: req.ID,
|
||||
IsError: true,
|
||||
Error: err.Error(),
|
||||
Return: val,
|
||||
})
|
||||
}
|
||||
|
||||
// lrpc contains functions registered on every server
|
||||
type lrpc struct {
|
||||
srv *Server
|
||||
}
|
||||
|
||||
// ChannelDone cancels a context and closes the associated channel
|
||||
func (l lrpc) ChannelDone(_ *Context, id string) {
|
||||
// Try to get context
|
||||
ctx, ok := l.srv.contexts[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.Cancel()
|
||||
// Delete context from map
|
||||
delete(l.srv.contexts, id)
|
||||
}
|
||||
Reference in New Issue
Block a user