Propagate context to requests
This commit is contained in:
parent
af77b121f8
commit
3bcc01fdb6
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"net"
|
||||
|
||||
@ -33,5 +34,5 @@ func main() {
|
||||
s.Register(Arith{})
|
||||
|
||||
ln, _ := net.Listen("tcp", ":9090")
|
||||
s.Serve(ln, codec.Gob)
|
||||
s.Serve(context.Background(), ln, codec.Gob)
|
||||
}
|
||||
|
@ -37,6 +37,24 @@ type Context struct {
|
||||
|
||||
doneCh chan struct{}
|
||||
canceled bool
|
||||
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func newContext(ctx context.Context, codec codec.Codec) *Context {
|
||||
out := &Context{
|
||||
doneCh: make(chan struct{}),
|
||||
codec: codec,
|
||||
ctx: ctx,
|
||||
}
|
||||
if ctx == nil {
|
||||
out.ctx = context.Background()
|
||||
}
|
||||
go func() {
|
||||
<-out.ctx.Done()
|
||||
out.cancel()
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
// MakeChannel changes the function it's called in into a
|
||||
@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
|
||||
}
|
||||
|
||||
// Cancel cancels the context
|
||||
func (ctx *Context) Cancel() {
|
||||
func (ctx *Context) cancel() {
|
||||
if ctx.canceled {
|
||||
return
|
||||
}
|
||||
ctx.canceled = true
|
||||
close(ctx.doneCh)
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@ -67,7 +68,7 @@ func New() *Server {
|
||||
// Close closes the server
|
||||
func (s *Server) Close() {
|
||||
for _, ctx := range s.contexts {
|
||||
ctx.Cancel()
|
||||
ctx.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,7 +99,7 @@ func (s *Server) Register(v any) error {
|
||||
}
|
||||
|
||||
// execute runs a method of a registered value
|
||||
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||
// Try to get value from receivers map
|
||||
val, ok := s.rcvrs[typ]
|
||||
if !ok {
|
||||
@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
arg = val.Interface()
|
||||
}
|
||||
|
||||
// Create new context
|
||||
ctx = &Context{
|
||||
doneCh: make(chan struct{}, 1),
|
||||
codec: c,
|
||||
}
|
||||
ctx = newContext(pCtx, c)
|
||||
// Get reflect value of context
|
||||
ctxVal := reflect.ValueOf(ctx)
|
||||
|
||||
@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
||||
|
||||
// Serve starts the server using the provided listener
|
||||
// and codec function
|
||||
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
|
||||
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
ln.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
break
|
||||
} else if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create new instance of codec bound to conn
|
||||
c := cf(conn)
|
||||
// Handle connection
|
||||
go s.handleConn(c)
|
||||
go s.handleConn(ctx, c)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeWS starts a server using WebSocket. This may be useful for
|
||||
// clients written in other languages, such as JS for a browser.
|
||||
func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
|
||||
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
|
||||
// Create new WebSocket server
|
||||
ws := websocket.Server{}
|
||||
|
||||
@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
|
||||
|
||||
// Set server handler
|
||||
ws.Handler = func(c *websocket.Conn) {
|
||||
s.handleConn(cf(c))
|
||||
s.handleConn(c.Request().Context(), cf(c))
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: addr,
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
Handler: http.HandlerFunc(ws.ServeHTTP),
|
||||
}
|
||||
|
||||
// Listen and serve on given address
|
||||
return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP))
|
||||
return server.ListenAndServe()
|
||||
}
|
||||
|
||||
// handleConn handles a listener connection
|
||||
func (s *Server) handleConn(c codec.Codec) {
|
||||
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||
codecMtx := &sync.Mutex{}
|
||||
|
||||
for {
|
||||
@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) {
|
||||
|
||||
// Execute decoded call
|
||||
val, ctx, err := s.execute(
|
||||
pCtx,
|
||||
call.Receiver,
|
||||
call.Method,
|
||||
call.Arg,
|
||||
@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) {
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.Cancel()
|
||||
ctx.cancel()
|
||||
// Delete context from map
|
||||
s.contextsMtx.Lock()
|
||||
delete(s.contexts, ctx.channelID)
|
||||
@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
ctx.Cancel()
|
||||
ctx.cancel()
|
||||
// Delete context from map
|
||||
l.srv.contextsMtx.Lock()
|
||||
delete(l.srv.contexts, id)
|
||||
|
Reference in New Issue
Block a user