Propagate context to requests
This commit is contained in:
parent
af77b121f8
commit
3bcc01fdb6
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
@ -33,5 +34,5 @@ func main() {
|
|||||||
s.Register(Arith{})
|
s.Register(Arith{})
|
||||||
|
|
||||||
ln, _ := net.Listen("tcp", ":9090")
|
ln, _ := net.Listen("tcp", ":9090")
|
||||||
s.Serve(ln, codec.Gob)
|
s.Serve(context.Background(), ln, codec.Gob)
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,24 @@ type Context struct {
|
|||||||
|
|
||||||
doneCh chan struct{}
|
doneCh chan struct{}
|
||||||
canceled bool
|
canceled bool
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func newContext(ctx context.Context, codec codec.Codec) *Context {
|
||||||
|
out := &Context{
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
codec: codec,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
out.ctx = context.Background()
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
<-out.ctx.Done()
|
||||||
|
out.cancel()
|
||||||
|
}()
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeChannel changes the function it's called in into a
|
// MakeChannel changes the function it's called in into a
|
||||||
@ -87,7 +105,10 @@ func (ctx *Context) Done() <-chan struct{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel cancels the context
|
// Cancel cancels the context
|
||||||
func (ctx *Context) Cancel() {
|
func (ctx *Context) cancel() {
|
||||||
|
if ctx.canceled {
|
||||||
|
return
|
||||||
|
}
|
||||||
ctx.canceled = true
|
ctx.canceled = true
|
||||||
close(ctx.doneCh)
|
close(ctx.doneCh)
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -67,7 +68,7 @@ func New() *Server {
|
|||||||
// Close closes the server
|
// Close closes the server
|
||||||
func (s *Server) Close() {
|
func (s *Server) Close() {
|
||||||
for _, ctx := range s.contexts {
|
for _, ctx := range s.contexts {
|
||||||
ctx.Cancel()
|
ctx.cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,7 +99,7 @@ func (s *Server) Register(v any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// execute runs a method of a registered value
|
// execute runs a method of a registered value
|
||||||
func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
func (s *Server) execute(pCtx context.Context, typ string, name string, arg any, c codec.Codec) (a any, ctx *Context, err error) {
|
||||||
// Try to get value from receivers map
|
// Try to get value from receivers map
|
||||||
val, ok := s.rcvrs[typ]
|
val, ok := s.rcvrs[typ]
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -140,11 +141,7 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
|||||||
arg = val.Interface()
|
arg = val.Interface()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new context
|
ctx = newContext(pCtx, c)
|
||||||
ctx = &Context{
|
|
||||||
doneCh: make(chan struct{}, 1),
|
|
||||||
codec: c,
|
|
||||||
}
|
|
||||||
// Get reflect value of context
|
// Get reflect value of context
|
||||||
ctxVal := reflect.ValueOf(ctx)
|
ctxVal := reflect.ValueOf(ctx)
|
||||||
|
|
||||||
@ -232,23 +229,30 @@ func (s *Server) execute(typ string, name string, arg any, c codec.Codec) (a any
|
|||||||
|
|
||||||
// Serve starts the server using the provided listener
|
// Serve starts the server using the provided listener
|
||||||
// and codec function
|
// and codec function
|
||||||
func (s *Server) Serve(ln net.Listener, cf codec.CodecFunc) {
|
func (s *Server) Serve(ctx context.Context, ln net.Listener, cf codec.CodecFunc) {
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
ln.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new instance of codec bound to conn
|
// Create new instance of codec bound to conn
|
||||||
c := cf(conn)
|
c := cf(conn)
|
||||||
// Handle connection
|
// Handle connection
|
||||||
go s.handleConn(c)
|
go s.handleConn(ctx, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeWS starts a server using WebSocket. This may be useful for
|
// ServeWS starts a server using WebSocket. This may be useful for
|
||||||
// clients written in other languages, such as JS for a browser.
|
// clients written in other languages, such as JS for a browser.
|
||||||
func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
|
func (s *Server) ServeWS(ctx context.Context, addr string, cf codec.CodecFunc) (err error) {
|
||||||
// Create new WebSocket server
|
// Create new WebSocket server
|
||||||
ws := websocket.Server{}
|
ws := websocket.Server{}
|
||||||
|
|
||||||
@ -259,15 +263,23 @@ func (s *Server) ServeWS(addr string, cf codec.CodecFunc) (err error) {
|
|||||||
|
|
||||||
// Set server handler
|
// Set server handler
|
||||||
ws.Handler = func(c *websocket.Conn) {
|
ws.Handler = func(c *websocket.Conn) {
|
||||||
s.handleConn(cf(c))
|
s.handleConn(c.Request().Context(), cf(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
BaseContext: func(net.Listener) context.Context {
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
Handler: http.HandlerFunc(ws.ServeHTTP),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen and serve on given address
|
// Listen and serve on given address
|
||||||
return http.ListenAndServe(addr, http.HandlerFunc(ws.ServeHTTP))
|
return server.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConn handles a listener connection
|
// handleConn handles a listener connection
|
||||||
func (s *Server) handleConn(c codec.Codec) {
|
func (s *Server) handleConn(pCtx context.Context, c codec.Codec) {
|
||||||
codecMtx := &sync.Mutex{}
|
codecMtx := &sync.Mutex{}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -283,6 +295,7 @@ func (s *Server) handleConn(c codec.Codec) {
|
|||||||
|
|
||||||
// Execute decoded call
|
// Execute decoded call
|
||||||
val, ctx, err := s.execute(
|
val, ctx, err := s.execute(
|
||||||
|
pCtx,
|
||||||
call.Receiver,
|
call.Receiver,
|
||||||
call.Method,
|
call.Method,
|
||||||
call.Arg,
|
call.Arg,
|
||||||
@ -322,7 +335,7 @@ func (s *Server) handleConn(c codec.Codec) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel context
|
// Cancel context
|
||||||
ctx.Cancel()
|
ctx.cancel()
|
||||||
// Delete context from map
|
// Delete context from map
|
||||||
s.contextsMtx.Lock()
|
s.contextsMtx.Lock()
|
||||||
delete(s.contexts, ctx.channelID)
|
delete(s.contexts, ctx.channelID)
|
||||||
@ -370,7 +383,7 @@ func (l lrpc) ChannelDone(_ *Context, id string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cancel context
|
// Cancel context
|
||||||
ctx.Cancel()
|
ctx.cancel()
|
||||||
// Delete context from map
|
// Delete context from map
|
||||||
l.srv.contextsMtx.Lock()
|
l.srv.contextsMtx.Lock()
|
||||||
delete(l.srv.contexts, id)
|
delete(l.srv.contexts, id)
|
||||||
|
Reference in New Issue
Block a user