From 4fe5c94d9caf9d218d2b893b99d5799d37ce99c4 Mon Sep 17 00:00:00 2001 From: Elara Musayelyan Date: Mon, 9 Jan 2023 12:54:06 -0800 Subject: [PATCH] Add WebSocket reconnects --- go.mod | 1 + go.sum | 2 ++ websocket.go | 81 ++++++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 37807a5..7ba8bf1 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.19 retract v0.0.0-20230105203020-27ef17a00e22 require ( + github.com/cenkalti/backoff/v4 v4.2.0 github.com/dave/jennifer v1.6.0 github.com/google/go-querystring v1.1.0 github.com/gorilla/websocket v1.4.2 diff --git a/go.sum b/go.sum index bfce658..98632fd 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4= +github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/dave/jennifer v1.6.0 h1:MQ/6emI2xM7wt0tJzJzyUik2Q3Tcn2eE0vtYgh4GPVI= github.com/dave/jennifer v1.6.0/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/websocket.go b/websocket.go index fa2bffa..24347ab 100644 --- a/websocket.go +++ b/websocket.go @@ -3,10 +3,15 @@ package lemmy import ( "context" "encoding/json" + "fmt" + "net" "net/http" "net/url" "reflect" + "strings" + "time" + "github.com/cenkalti/backoff/v4" "github.com/gorilla/websocket" "go.arsenm.dev/go-lemmy/types" ) @@ -17,11 +22,12 @@ type authData struct { // WSClient is a client for Lemmy's WebSocket API type WSClient struct { - conn *websocket.Conn - baseURL *url.URL - respCh chan types.LemmyWebSocketMsg - errCh chan error - Token string + conn *websocket.Conn + baseURL *url.URL + respCh chan types.LemmyWebSocketMsg + errCh chan error + recHandler func(c *WSClient) + Token string } // NewWebSocket creates and returns a new WSClient, and @@ -33,7 +39,7 @@ func NewWebSocket(baseURL string) (*WSClient, error) { } u = u.JoinPath("/api/v3") - conn, _, err := websocket.DefaultDialer.Dial(u.JoinPath("ws").String(), nil) + conn, _, err := keepaliveDialer().Dial(u.JoinPath("ws").String(), nil) if err != nil { return nil, err } @@ -51,6 +57,31 @@ func NewWebSocket(baseURL string) (*WSClient, error) { err = conn.ReadJSON(&msg) if err != nil { out.errCh <- err + + // gorilla/websocket eats the error type, so I have to check + // the string itself + if strings.HasSuffix(err.Error(), "connection timed out") { + err = backoff.RetryNotify( + func() error { + conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil) + if err != nil { + out.errCh <- err + return err + } + out.conn = conn + out.recHandler(out) + return nil + }, + backoff.NewExponentialBackOff(), + func(err error, d time.Duration) { + out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err) + }, + ) + if err != nil { + out.errCh <- err + } + } + continue } out.respCh <- msg @@ -116,6 +147,10 @@ func (c *WSClient) Errors() <-chan error { return c.errCh } +func (c *WSClient) OnReconnect(rh func(c *WSClient)) { + c.recHandler = rh +} + // setAuth uses reflection to automatically // set struct fields called Auth of type // string or types.Optional[string] to the @@ -147,3 +182,37 @@ func (c *WSClient) setAuth(data any) any { func DecodeResponse(data json.RawMessage, out any) error { return json.Unmarshal(data, out) } + +func keepaliveDialer() *websocket.Dialer { + d := &websocket.Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + tcpAddr, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return nil, err + } + + conn, err := net.DialTCP(network, nil, tcpAddr) + if err != nil { + return nil, err + } + + err = conn.SetKeepAlive(true) + if err != nil { + return nil, err + } + + err = conn.SetKeepAlivePeriod(time.Second) + if err != nil { + return nil, err + } + + return conn, nil + }, + } + + d.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return d.NetDial(network, addr) + } + + return d +}