Propagate context to requests
This commit is contained in:
		@@ -37,6 +37,24 @@ type Context struct {
 | 
			
		||||
 | 
			
		||||
	doneCh   chan struct{}
 | 
			
		||||
	canceled bool
 | 
			
		||||
 | 
			
		||||
	ctx context.Context
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newContext(ctx context.Context, codec codec.Codec) *Context {
 | 
			
		||||
	out := &Context{
 | 
			
		||||
		doneCh: make(chan struct{}),
 | 
			
		||||
		codec:  codec,
 | 
			
		||||
		ctx:    ctx,
 | 
			
		||||
	}
 | 
			
		||||
	if ctx == nil {
 | 
			
		||||
		out.ctx = context.Background()
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		<-out.ctx.Done()
 | 
			
		||||
		out.cancel()
 | 
			
		||||
	}()
 | 
			
		||||
	return out
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MakeChannel changes the function it's called in into a
 | 
			
		||||
@@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Cancel cancels the context
 | 
			
		||||
func (ctx *Context) Cancel() {
 | 
			
		||||
func (ctx *Context) cancel() {
 | 
			
		||||
	if ctx.canceled {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ctx.canceled = true
 | 
			
		||||
	close(ctx.doneCh)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -19,6 +19,7 @@
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
@@ -67,7 +68,7 @@ func New() *Server {
 | 
			
		||||
// Close closes the server
 | 
			
		||||
func (s *Server) Close() {
 | 
			
		||||
	for _, ctx := range s.contexts {
 | 
			
		||||
		ctx.Cancel()
 | 
			
		||||
		ctx.cancel()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -98,7 +99,7 @@ func (s *Server) Register(v any) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// execute runs a method of a registered value
 | 
			
		||||
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
 | 
			
		||||
	// Try to get value from receivers map
 | 
			
		||||
	val, ok := s.rcvrs[typ]
 | 
			
		||||
	if !ok {
 | 
			
		||||
@@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
 | 
			
		||||
		arg = val.Interface()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create new context
 | 
			
		||||
	ctx = &Context{
 | 
			
		||||
		doneCh: make(chan struct{}, 1),
 | 
			
		||||
		codec:  c,
 | 
			
		||||
	}
 | 
			
		||||
	ctx = newContext(pCtx, c)
 | 
			
		||||
	// Get reflect value of context
 | 
			
		||||
	ctxVal := reflect.ValueOf(ctx)
 | 
			
		||||
 | 
			
		||||
@@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
 | 
			
		||||
 | 
			
		||||
// Serve starts the server using the provided listener
 | 
			
		||||
// and codec function
 | 
			
		||||
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
 | 
			
		||||
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
 | 
			
		||||
	go func() {
 | 
			
		||||
		<-ctx.Done()
 | 
			
		||||
		ln.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := ln.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if errors.Is(err, net.ErrClosed) {
 | 
			
		||||
			break
 | 
			
		||||
		} else if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create new instance of codec bound to conn
 | 
			
		||||
		c := cf(conn)
 | 
			
		||||
		// Handle connection
 | 
			
		||||
		go s.handleConn(c)
 | 
			
		||||
		go s.handleConn(ctx, c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ServeWS starts a server using WebSocket. This may be useful for
 | 
			
		||||
// clients written in other languages, such as JS for a browser.
 | 
			
		||||
func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
 | 
			
		||||
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
 | 
			
		||||
	// Create new WebSocket server
 | 
			
		||||
	ws := websocket.Server{}
 | 
			
		||||
 | 
			
		||||
@@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
 | 
			
		||||
 | 
			
		||||
	// Set server handler
 | 
			
		||||
	ws.Handler = func(c *websocket.Conn) {
 | 
			
		||||
		s.handleConn(cf(c))
 | 
			
		||||
		s.handleConn(c.Request().Context(), cf(c))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	server := &http.Server{
 | 
			
		||||
		Addr: addr,
 | 
			
		||||
		BaseContext: func(net.Listener) context.Context {
 | 
			
		||||
			return ctx
 | 
			
		||||
		},
 | 
			
		||||
		Handler: http.HandlerFunc(ws.ServeHTTP),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Listen and serve on given address
 | 
			
		||||
	return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP))
 | 
			
		||||
	return server.ListenAndServe()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleConn handles a listener connection
 | 
			
		||||
func (s *Server) handleConn(c codec.Codec) {
 | 
			
		||||
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
 | 
			
		||||
	codecMtx := &sync.Mutex{}
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
@@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) {
 | 
			
		||||
 | 
			
		||||
		// Execute decoded call
 | 
			
		||||
		val, ctx, err := s.execute(
 | 
			
		||||
			pCtx,
 | 
			
		||||
			call.Receiver,
 | 
			
		||||
			call.Method,
 | 
			
		||||
			call.Arg,
 | 
			
		||||
@@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) {
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// Cancel context
 | 
			
		||||
					ctx.Cancel()
 | 
			
		||||
					ctx.cancel()
 | 
			
		||||
					// Delete context from map
 | 
			
		||||
					s.contextsMtx.Lock()
 | 
			
		||||
					delete(s.contexts, ctx.channelID)
 | 
			
		||||
@@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Cancel context
 | 
			
		||||
	ctx.Cancel()
 | 
			
		||||
	ctx.cancel()
 | 
			
		||||
	// Delete context from map
 | 
			
		||||
	l.srv.contextsMtx.Lock()
 | 
			
		||||
	delete(l.srv.contexts, id)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user