Compare commits

...

2 Commits

2 changed files with 24 additions and 26 deletions

View File

@ -1,6 +1,7 @@
package client package client
import ( import (
"context"
"errors" "errors"
"net" "net"
"reflect" "reflect"
@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
} }
// Call calls a method on the server // Call calls a method on the server
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error { func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
// Create new v4 UUOD // Create new v4 UUOD
id, err := uuid.NewV4() id, err := uuid.NewV4()
if err != nil { if err != nil {
@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
} }
idStr := id.String() idStr := id.String()
ctxDoneVal := reflect.ValueOf(ctx.Done())
// Create new channel using the generated ID // Create new channel using the generated ID
c.chMtx.Lock() c.chMtx.Lock()
c.chs[idStr] = make(chan *types.Response, 1) c.chs[idStr] = make(chan *types.Response, 1)
@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
c.chs[chID] = make(chan *types.Response, 5) c.chs[chID] = make(chan *types.Response, 5)
c.chMtx.Unlock() c.chMtx.Unlock()
channelClosed := false
go func() { go func() {
// Get type of channel elements // Get type of channel elements
chElemType := retVal.Type().Elem() chElemType := retVal.Type().Elem()
// For every value received from channel // For every value received from channel
for val := range c.chs[chID] { for val := range c.chs[chID] {
//s := time.Now()
if val.ChannelDone { if val.ChannelDone {
// Close and delete channel // Close and delete channel
c.chMtx.Lock() c.chMtx.Lock()
@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
// Close return channel // Close return channel
retVal.Close() retVal.Close()
channelClosed = true
break break
} }
// Get reflect value from channel response // Get reflect value from channel response
@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
rVal = newVal rVal = newVal
} }
// Send value to channel chosen, _, _ := reflect.Select([]reflect.SelectCase{
retVal.Send(rVal) {Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
} {Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
}() })
if chosen == 1 {
c.Call(context.Background(), "lrpc", "ChannelDone", chID, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
go func() { retVal.Close()
for {
val, ok := retVal.Recv()
if !ok && val.IsValid() {
break
} }
} }
if !channelClosed {
c.Call("lrpc", "ChannelDone", id, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
}
}() }()
} else { } else {
// IF return value is not a pointer, return error // IF return value is not a pointer, return error

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/gob" "encoding/gob"
"fmt" "fmt"
"net" "net"
@ -12,21 +13,23 @@ import (
func main() { func main() {
gob.Register([2]int{}) gob.Register([2]int{})
ctx := context.Background()
conn, _ := net.Dial("tcp", "localhost:9090") conn, _ := net.Dial("tcp", "localhost:9090")
c := client.New(conn, codec.Gob) c := client.New(conn, codec.Gob)
defer c.Close() defer c.Close()
var add int var add int
c.Call("Arith", "Add", [2]int{5, 5}, &add) c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
var sub int var sub int
c.Call("Arith", "Sub", [2]int{5, 5}, &sub) c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
var mul int var mul int
c.Call("Arith", "Mul", [2]int{5, 5}, &mul) c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
var div int var div int
c.Call("Arith", "Div", [2]int{5, 5}, &div) c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
fmt.Printf( fmt.Printf(
"add: %d, sub: %d, mul: %d, div: %d\n", "add: %d, sub: %d, mul: %d, div: %d\n",