Add WebSocket reconnects
This commit is contained in:
parent
31aa7361ec
commit
4fe5c94d9c
1
go.mod
1
go.mod
@ -5,6 +5,7 @@ go 1.19
|
||||
retract v0.0.0-20230105203020-27ef17a00e22
|
||||
|
||||
require (
|
||||
github.com/cenkalti/backoff/v4 v4.2.0
|
||||
github.com/dave/jennifer v1.6.0
|
||||
github.com/google/go-querystring v1.1.0
|
||||
github.com/gorilla/websocket v1.4.2
|
||||
|
2
go.sum
2
go.sum
@ -1,3 +1,5 @@
|
||||
github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
|
||||
github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/dave/jennifer v1.6.0 h1:MQ/6emI2xM7wt0tJzJzyUik2Q3Tcn2eE0vtYgh4GPVI=
|
||||
github.com/dave/jennifer v1.6.0/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
|
81
websocket.go
81
websocket.go
@ -3,10 +3,15 @@ package lemmy
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.arsenm.dev/go-lemmy/types"
|
||||
)
|
||||
@ -17,11 +22,12 @@ type authData struct {
|
||||
|
||||
// WSClient is a client for Lemmy's WebSocket API
|
||||
type WSClient struct {
|
||||
conn *websocket.Conn
|
||||
baseURL *url.URL
|
||||
respCh chan types.LemmyWebSocketMsg
|
||||
errCh chan error
|
||||
Token string
|
||||
conn *websocket.Conn
|
||||
baseURL *url.URL
|
||||
respCh chan types.LemmyWebSocketMsg
|
||||
errCh chan error
|
||||
recHandler func(c *WSClient)
|
||||
Token string
|
||||
}
|
||||
|
||||
// NewWebSocket creates and returns a new WSClient, and
|
||||
@ -33,7 +39,7 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
|
||||
}
|
||||
u = u.JoinPath("/api/v3")
|
||||
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.JoinPath("ws").String(), nil)
|
||||
conn, _, err := keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -51,6 +57,31 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
|
||||
err = conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
out.errCh <- err
|
||||
|
||||
// gorilla/websocket eats the error type, so I have to check
|
||||
// the string itself
|
||||
if strings.HasSuffix(err.Error(), "connection timed out") {
|
||||
err = backoff.RetryNotify(
|
||||
func() error {
|
||||
conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
|
||||
if err != nil {
|
||||
out.errCh <- err
|
||||
return err
|
||||
}
|
||||
out.conn = conn
|
||||
out.recHandler(out)
|
||||
return nil
|
||||
},
|
||||
backoff.NewExponentialBackOff(),
|
||||
func(err error, d time.Duration) {
|
||||
out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err)
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
out.errCh <- err
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
out.respCh <- msg
|
||||
@ -116,6 +147,10 @@ func (c *WSClient) Errors() <-chan error {
|
||||
return c.errCh
|
||||
}
|
||||
|
||||
func (c *WSClient) OnReconnect(rh func(c *WSClient)) {
|
||||
c.recHandler = rh
|
||||
}
|
||||
|
||||
// setAuth uses reflection to automatically
|
||||
// set struct fields called Auth of type
|
||||
// string or types.Optional[string] to the
|
||||
@ -147,3 +182,37 @@ func (c *WSClient) setAuth(data any) any {
|
||||
func DecodeResponse(data json.RawMessage, out any) error {
|
||||
return json.Unmarshal(data, out)
|
||||
}
|
||||
|
||||
func keepaliveDialer() *websocket.Dialer {
|
||||
d := &websocket.Dialer{
|
||||
NetDial: func(network, addr string) (net.Conn, error) {
|
||||
tcpAddr, err := net.ResolveTCPAddr(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP(network, nil, tcpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = conn.SetKeepAlive(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = conn.SetKeepAlivePeriod(time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
}
|
||||
|
||||
d.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return d.NetDial(network, addr)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user