Add WebSocket reconnects

This commit is contained in:
Elara 2023-01-09 12:54:06 -08:00
parent 31aa7361ec
commit 4fe5c94d9c
3 changed files with 78 additions and 6 deletions

1
go.mod
View File

@ -5,6 +5,7 @@ go 1.19
retract v0.0.0-20230105203020-27ef17a00e22
require (
github.com/cenkalti/backoff/v4 v4.2.0
github.com/dave/jennifer v1.6.0
github.com/google/go-querystring v1.1.0
github.com/gorilla/websocket v1.4.2

2
go.sum
View File

@ -1,3 +1,5 @@
github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/dave/jennifer v1.6.0 h1:MQ/6emI2xM7wt0tJzJzyUik2Q3Tcn2eE0vtYgh4GPVI=
github.com/dave/jennifer v1.6.0/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=

View File

@ -3,10 +3,15 @@ package lemmy
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/gorilla/websocket"
"go.arsenm.dev/go-lemmy/types"
)
@ -21,6 +26,7 @@ type WSClient struct {
baseURL *url.URL
respCh chan types.LemmyWebSocketMsg
errCh chan error
recHandler func(c *WSClient)
Token string
}
@ -33,7 +39,7 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
}
u = u.JoinPath("/api/v3")
conn, _, err := websocket.DefaultDialer.Dial(u.JoinPath("ws").String(), nil)
conn, _, err := keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
if err != nil {
return nil, err
}
@ -51,6 +57,31 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
err = conn.ReadJSON(&msg)
if err != nil {
out.errCh <- err
// gorilla/websocket eats the error type, so I have to check
// the string itself
if strings.HasSuffix(err.Error(), "connection timed out") {
err = backoff.RetryNotify(
func() error {
conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
if err != nil {
out.errCh <- err
return err
}
out.conn = conn
out.recHandler(out)
return nil
},
backoff.NewExponentialBackOff(),
func(err error, d time.Duration) {
out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err)
},
)
if err != nil {
out.errCh <- err
}
}
continue
}
out.respCh <- msg
@ -116,6 +147,10 @@ func (c *WSClient) Errors() <-chan error {
return c.errCh
}
func (c *WSClient) OnReconnect(rh func(c *WSClient)) {
c.recHandler = rh
}
// setAuth uses reflection to automatically
// set struct fields called Auth of type
// string or types.Optional[string] to the
@ -147,3 +182,37 @@ func (c *WSClient) setAuth(data any) any {
func DecodeResponse(data json.RawMessage, out any) error {
return json.Unmarshal(data, out)
}
func keepaliveDialer() *websocket.Dialer {
d := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, addr)
if err != nil {
return nil, err
}
conn, err := net.DialTCP(network, nil, tcpAddr)
if err != nil {
return nil, err
}
err = conn.SetKeepAlive(true)
if err != nil {
return nil, err
}
err = conn.SetKeepAlivePeriod(time.Second)
if err != nil {
return nil, err
}
return conn, nil
},
}
d.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.NetDial(network, addr)
}
return d
}