Use context to stop sending values rather than trying to detect channel close

This commit is contained in:
2022-05-01 15:13:07 -07:00
parent 6df8cf53c6
commit b53388122c
2 changed files with 24 additions and 26 deletions

View File

@@ -1,6 +1,7 @@
package client
import (
"context"
"errors"
"net"
"reflect"
@@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client {
}
// Call calls a method on the server
func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error {
func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error {
// Create new v4 UUOD
id, err := uuid.NewV4()
if err != nil {
@@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
}
idStr := id.String()
ctxDoneVal := reflect.ValueOf(ctx.Done())
// Create new channel using the generated ID
c.chMtx.Lock()
c.chs[idStr] = make(chan *types.Response, 1)
@@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
c.chs[chID] = make(chan *types.Response, 5)
c.chMtx.Unlock()
channelClosed := false
go func() {
// Get type of channel elements
chElemType := retVal.Type().Elem()
// For every value received from channel
for val := range c.chs[chID] {
//s := time.Now()
if val.ChannelDone {
// Close and delete channel
c.chMtx.Lock()
@@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
// Close return channel
retVal.Close()
channelClosed = true
break
}
// Get reflect value from channel response
@@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err
rVal = newVal
}
// Send value to channel
retVal.Send(rVal)
}
}()
chosen, _, _ := reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: retVal, Send: rVal},
{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}},
})
if chosen == 1 {
c.Call(context.Background(), "lrpc", "ChannelDone", id, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
go func() {
for {
val, ok := retVal.Recv()
if !ok && val.IsValid() {
break
retVal.Close()
}
}
if !channelClosed {
c.Call("lrpc", "ChannelDone", id, nil)
// Close and delete channel
c.chMtx.Lock()
close(c.chs[chID])
delete(c.chs, chID)
c.chMtx.Unlock()
}
}()
} else {
// IF return value is not a pointer, return error