Use context to stop sending values rather than trying to detect channel close
This commit is contained in:
		| @@ -1,6 +1,7 @@ | |||||||
| package client | package client | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"net" | 	"net" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| @@ -46,7 +47,7 @@ func New(conn net.Conn, cf codec.CodecFunc) *Client { | |||||||
| } | } | ||||||
|  |  | ||||||
| // Call calls a method on the server | // Call calls a method on the server | ||||||
| func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) error { | func (c *Client) Call(ctx context.Context, rcvr, method string, arg interface{}, ret interface{}) error { | ||||||
| 	// Create new v4 UUOD | 	// Create new v4 UUOD | ||||||
| 	id, err := uuid.NewV4() | 	id, err := uuid.NewV4() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -54,6 +55,8 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err | |||||||
| 	} | 	} | ||||||
| 	idStr := id.String() | 	idStr := id.String() | ||||||
|  |  | ||||||
|  | 	ctxDoneVal := reflect.ValueOf(ctx.Done()) | ||||||
|  |  | ||||||
| 	// Create new channel using the generated ID | 	// Create new channel using the generated ID | ||||||
| 	c.chMtx.Lock() | 	c.chMtx.Lock() | ||||||
| 	c.chs[idStr] = make(chan *types.Response, 1) | 	c.chs[idStr] = make(chan *types.Response, 1) | ||||||
| @@ -106,12 +109,12 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err | |||||||
| 		c.chs[chID] = make(chan *types.Response, 5) | 		c.chs[chID] = make(chan *types.Response, 5) | ||||||
| 		c.chMtx.Unlock() | 		c.chMtx.Unlock() | ||||||
|  |  | ||||||
| 		channelClosed := false |  | ||||||
| 		go func() { | 		go func() { | ||||||
| 			// Get type of channel elements | 			// Get type of channel elements | ||||||
| 			chElemType := retVal.Type().Elem() | 			chElemType := retVal.Type().Elem() | ||||||
| 			// For every value received from channel | 			// For every value received from channel | ||||||
| 			for val := range c.chs[chID] { | 			for val := range c.chs[chID] { | ||||||
|  | 				//s := time.Now() | ||||||
| 				if val.ChannelDone { | 				if val.ChannelDone { | ||||||
| 					// Close and delete channel | 					// Close and delete channel | ||||||
| 					c.chMtx.Lock() | 					c.chMtx.Lock() | ||||||
| @@ -121,9 +124,6 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err | |||||||
|  |  | ||||||
| 					// Close return channel | 					// Close return channel | ||||||
| 					retVal.Close() | 					retVal.Close() | ||||||
|  |  | ||||||
| 					channelClosed = true |  | ||||||
|  |  | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// Get reflect value from channel response | 				// Get reflect value from channel response | ||||||
| @@ -139,26 +139,21 @@ func (c *Client) Call(rcvr, method string, arg interface{}, ret interface{}) err | |||||||
| 					rVal = newVal | 					rVal = newVal | ||||||
| 				} | 				} | ||||||
|  |  | ||||||
| 				// Send value to channel | 				chosen, _, _ := reflect.Select([]reflect.SelectCase{ | ||||||
| 				retVal.Send(rVal) | 					{Dir: reflect.SelectSend, Chan: retVal, Send: rVal}, | ||||||
| 			} | 					{Dir: reflect.SelectRecv, Chan: ctxDoneVal, Send: reflect.Value{}}, | ||||||
| 		}() | 				}) | ||||||
|  | 				if chosen == 1 { | ||||||
|  | 					c.Call(context.Background(), "lrpc", "ChannelDone", id, nil) | ||||||
|  | 					// Close and delete channel | ||||||
|  | 					c.chMtx.Lock() | ||||||
|  | 					close(c.chs[chID]) | ||||||
|  | 					delete(c.chs, chID) | ||||||
|  | 					c.chMtx.Unlock() | ||||||
|  |  | ||||||
| 		go func() { | 					retVal.Close() | ||||||
| 			for { |  | ||||||
| 				val, ok := retVal.Recv() |  | ||||||
| 				if !ok && val.IsValid() { |  | ||||||
| 					break |  | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			if !channelClosed { |  | ||||||
| 				c.Call("lrpc", "ChannelDone", id, nil) |  | ||||||
| 				// Close and delete channel |  | ||||||
| 				c.chMtx.Lock() |  | ||||||
| 				close(c.chs[chID]) |  | ||||||
| 				delete(c.chs, chID) |  | ||||||
| 				c.chMtx.Unlock() |  | ||||||
| 			} |  | ||||||
| 		}() | 		}() | ||||||
| 	} else { | 	} else { | ||||||
| 		// IF return value is not a pointer, return error | 		// IF return value is not a pointer, return error | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package main | package main | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"encoding/gob" | 	"encoding/gob" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| @@ -12,21 +13,23 @@ import ( | |||||||
| func main() { | func main() { | ||||||
| 	gob.Register([2]int{}) | 	gob.Register([2]int{}) | ||||||
|  |  | ||||||
|  | 	ctx := context.Background() | ||||||
|  |  | ||||||
| 	conn, _ := net.Dial("tcp", "localhost:9090") | 	conn, _ := net.Dial("tcp", "localhost:9090") | ||||||
| 	c := client.New(conn, codec.Gob) | 	c := client.New(conn, codec.Gob) | ||||||
| 	defer c.Close() | 	defer c.Close() | ||||||
|  |  | ||||||
| 	var add int | 	var add int | ||||||
| 	c.Call("Arith", "Add", [2]int{5, 5}, &add) | 	c.Call(ctx, "Arith", "Add", [2]int{5, 5}, &add) | ||||||
|  |  | ||||||
| 	var sub int | 	var sub int | ||||||
| 	c.Call("Arith", "Sub", [2]int{5, 5}, &sub) | 	c.Call(ctx, "Arith", "Sub", [2]int{5, 5}, &sub) | ||||||
|  |  | ||||||
| 	var mul int | 	var mul int | ||||||
| 	c.Call("Arith", "Mul", [2]int{5, 5}, &mul) | 	c.Call(ctx, "Arith", "Mul", [2]int{5, 5}, &mul) | ||||||
|  |  | ||||||
| 	var div int | 	var div int | ||||||
| 	c.Call("Arith", "Div", [2]int{5, 5}, &div) | 	c.Call(ctx, "Arith", "Div", [2]int{5, 5}, &div) | ||||||
|  |  | ||||||
| 	fmt.Printf( | 	fmt.Printf( | ||||||
| 		"add: %d, sub: %d, mul: %d, div: %d\n", | 		"add: %d, sub: %d, mul: %d, div: %d\n", | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user