Use context to stop sending values rather than trying to detect channel close
This commit is contained in:
parent
6df8cf53c6
commit
b53388122c
@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call calls a method on the server
|
// Call calls a method on the server
|
||||||
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
|
func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
|
||||||
// Create new v4 UUOD
|
// Create new v4 UUOD
|
||||||
id, err := uuid.NewV4()
|
id, err := uuid.NewV4()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
|||||||
}
|
}
|
||||||
idStr := id.String()
|
idStr := id.String()
|
||||||
|
|
||||||
|
ctxDoneVal := reflect.ValueOf(ctx.Done())
|
||||||
|
|
||||||
// Create new channel using the generated ID
|
// Create new channel using the generated ID
|
||||||
c.chMtx.Lock()
|
c.chMtx.Lock()
|
||||||
c.chs[idStr] = make(chan *types.Response, 1)
|
c.chs[idStr] = make(chan *types.Response, 1)
|
||||||
@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
|||||||
c.chs[chID] = make(chan *types.Response, 5)
|
c.chs[chID] = make(chan *types.Response, 5)
|
||||||
c.chMtx.Unlock()
|
c.chMtx.Unlock()
|
||||||
|
|
||||||
channelClosed := false
|
|
||||||
go func() {
|
go func() {
|
||||||
// Get type of channel elements
|
// Get type of channel elements
|
||||||
chElemType := retVal.Type().Elem()
|
chElemType := retVal.Type().Elem()
|
||||||
// For every value received from channel
|
// For every value received from channel
|
||||||
for val := range c.chs[chID] {
|
for val := range c.chs[chID] {
|
||||||
|
//s := time.Now()
|
||||||
if val.ChannelDone {
|
if val.ChannelDone {
|
||||||
// Close and delete channel
|
// Close and delete channel
|
||||||
c.chMtx.Lock()
|
c.chMtx.Lock()
|
||||||
@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
|||||||
|
|
||||||
// Close return channel
|
// Close return channel
|
||||||
retVal.Close()
|
retVal.Close()
|
||||||
|
|
||||||
channelClosed = true
|
|
||||||
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
// Get reflect value from channel response
|
// Get reflect value from channel response
|
||||||
@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
|
|||||||
rVal = newVal
|
rVal = newVal
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send value to channel
|
chosen, _, _ := reflect.Select([]reflect.SelectCase{
|
||||||
retVal.Send(rVal)
|
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
|
||||||
}
|
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
|
||||||
}()
|
})
|
||||||
|
if chosen == 1 {
|
||||||
|
c.Call(context.Background(), "lrpc", "ChannelDone", id, nil)
|
||||||
|
// Close and delete channel
|
||||||
|
c.chMtx.Lock()
|
||||||
|
close(c.chs[chID])
|
||||||
|
delete(c.chs, chID)
|
||||||
|
c.chMtx.Unlock()
|
||||||
|
|
||||||
go func() {
|
retVal.Close()
|
||||||
for {
|
|
||||||
val, ok := retVal.Recv()
|
|
||||||
if !ok && val.IsValid() {
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !channelClosed {
|
|
||||||
c.Call("lrpc", "ChannelDone", id, nil)
|
|
||||||
// Close and delete channel
|
|
||||||
c.chMtx.Lock()
|
|
||||||
close(c.chs[chID])
|
|
||||||
delete(c.chs, chID)
|
|
||||||
c.chMtx.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
} else {
|
} else {
|
||||||
// IF return value is not a pointer, return error
|
// IF return value is not a pointer, return error
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@ -12,21 +13,23 @@ import (
|
|||||||
func main() {
|
func main() {
|
||||||
gob.Register([2]int{})
|
gob.Register([2]int{})
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
conn, _ := net.Dial("tcp", "localhost:9090")
|
conn, _ := net.Dial("tcp", "localhost:9090")
|
||||||
c := client.New(conn, codec.Gob)
|
c := client.New(conn, codec.Gob)
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
var add int
|
var add int
|
||||||
c.Call("Arith", "Add", [2]int{5, 5}, &add)
|
c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add)
|
||||||
|
|
||||||
var sub int
|
var sub int
|
||||||
c.Call("Arith", "Sub", [2]int{5, 5}, &sub)
|
c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub)
|
||||||
|
|
||||||
var mul int
|
var mul int
|
||||||
c.Call("Arith", "Mul", [2]int{5, 5}, &mul)
|
c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul)
|
||||||
|
|
||||||
var div int
|
var div int
|
||||||
c.Call("Arith", "Div", [2]int{5, 5}, &div)
|
c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div)
|
||||||
|
|
||||||
fmt.Printf(
|
fmt.Printf(
|
||||||
"add: %d, sub: %d, mul: %d, div: %d\n",
|
"add: %d, sub: %d, mul: %d, div: %d\n",
|
||||||
|
Reference in New Issue
Block a user