diff --git a/muxserver/muxserver.go b/muxserver/muxserver.go index d9a8c2d..f63cde1 100644 --- a/muxserver/muxserver.go +++ b/muxserver/muxserver.go @@ -30,18 +30,24 @@ func NewWithOptions(handler drpc.Handler, opts drpcserver.Options) *Server { // Serve listens on the given listener and handles all multiplexed // streams. func (s *Server) Serve(ctx context.Context, ln net.Listener) error { + connCh := make(chan net.Conn, 2) + errCh := make(chan error, 2) + go connChannel(ln, connCh, errCh) + for { - conn, err := ln.Accept() - if err != nil { - return err - } + select { + case conn := <-connCh: + sess, err := yamux.Server(conn, nil) + if err != nil { + return err + } - sess, err := yamux.Server(conn, nil) - if err != nil { + go s.handleSession(ctx, sess) + case err := <-errCh: return err + case <-ctx.Done(): + return nil } - - go s.handleSession(ctx, sess) } } @@ -62,3 +68,14 @@ func (s *Server) handleSession(ctx context.Context, sess *yamux.Session) { func (s *Server) ServeOne(ctx context.Context, conn io.ReadWriteCloser) error { return s.srv.ServeOne(ctx, conn) } + +func connChannel(ln net.Listener, connCh chan net.Conn, errCh chan error) { + for { + conn, err := ln.Accept() + if err != nil { + errCh <- err + continue + } + connCh <- conn + } +}