Add WebSocket reconnects
This commit is contained in:
		
							
								
								
									
										1
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								go.mod
									
									
									
									
									
								
							@@ -5,6 +5,7 @@ go 1.19
 | 
				
			|||||||
retract v0.0.0-20230105203020-27ef17a00e22
 | 
					retract v0.0.0-20230105203020-27ef17a00e22
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
 | 
						github.com/cenkalti/backoff/v4 v4.2.0
 | 
				
			||||||
	github.com/dave/jennifer v1.6.0
 | 
						github.com/dave/jennifer v1.6.0
 | 
				
			||||||
	github.com/google/go-querystring v1.1.0
 | 
						github.com/google/go-querystring v1.1.0
 | 
				
			||||||
	github.com/gorilla/websocket v1.4.2
 | 
						github.com/gorilla/websocket v1.4.2
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.sum
									
									
									
									
									
								
							@@ -1,3 +1,5 @@
 | 
				
			|||||||
 | 
					github.com/cenkalti/backoff/v4 v4.2.0 h1:HN5dHm3WBOgndBH6E8V0q2jIYIR3s9yglV8k/+MN3u4=
 | 
				
			||||||
 | 
					github.com/cenkalti/backoff/v4 v4.2.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
 | 
				
			||||||
github.com/dave/jennifer v1.6.0 h1:MQ/6emI2xM7wt0tJzJzyUik2Q3Tcn2eE0vtYgh4GPVI=
 | 
					github.com/dave/jennifer v1.6.0 h1:MQ/6emI2xM7wt0tJzJzyUik2Q3Tcn2eE0vtYgh4GPVI=
 | 
				
			||||||
github.com/dave/jennifer v1.6.0/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk=
 | 
					github.com/dave/jennifer v1.6.0/go.mod h1:AxTG893FiZKqxy3FP1kL80VMshSMuz2G+EgvszgGRnk=
 | 
				
			||||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
 | 
					github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										71
									
								
								websocket.go
									
									
									
									
									
								
							
							
						
						
									
										71
									
								
								websocket.go
									
									
									
									
									
								
							@@ -3,10 +3,15 @@ package lemmy
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/cenkalti/backoff/v4"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
	"go.arsenm.dev/go-lemmy/types"
 | 
						"go.arsenm.dev/go-lemmy/types"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -21,6 +26,7 @@ type WSClient struct {
 | 
				
			|||||||
	baseURL    *url.URL
 | 
						baseURL    *url.URL
 | 
				
			||||||
	respCh     chan types.LemmyWebSocketMsg
 | 
						respCh     chan types.LemmyWebSocketMsg
 | 
				
			||||||
	errCh      chan error
 | 
						errCh      chan error
 | 
				
			||||||
 | 
						recHandler func(c *WSClient)
 | 
				
			||||||
	Token      string
 | 
						Token      string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -33,7 +39,7 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	u = u.JoinPath("/api/v3")
 | 
						u = u.JoinPath("/api/v3")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	conn, _, err := websocket.DefaultDialer.Dial(u.JoinPath("ws").String(), nil)
 | 
						conn, _, err := keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -51,6 +57,31 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
 | 
				
			|||||||
			err = conn.ReadJSON(&msg)
 | 
								err = conn.ReadJSON(&msg)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				out.errCh <- err
 | 
									out.errCh <- err
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									// gorilla/websocket eats the error type, so I have to check
 | 
				
			||||||
 | 
									// the string itself
 | 
				
			||||||
 | 
									if strings.HasSuffix(err.Error(), "connection timed out") {
 | 
				
			||||||
 | 
										err = backoff.RetryNotify(
 | 
				
			||||||
 | 
											func() error {
 | 
				
			||||||
 | 
												conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
 | 
				
			||||||
 | 
												if err != nil {
 | 
				
			||||||
 | 
													out.errCh <- err
 | 
				
			||||||
 | 
													return err
 | 
				
			||||||
 | 
												}
 | 
				
			||||||
 | 
												out.conn = conn
 | 
				
			||||||
 | 
												out.recHandler(out)
 | 
				
			||||||
 | 
												return nil
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
											backoff.NewExponentialBackOff(),
 | 
				
			||||||
 | 
											func(err error, d time.Duration) {
 | 
				
			||||||
 | 
												out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err)
 | 
				
			||||||
 | 
											},
 | 
				
			||||||
 | 
										)
 | 
				
			||||||
 | 
										if err != nil {
 | 
				
			||||||
 | 
											out.errCh <- err
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			out.respCh <- msg
 | 
								out.respCh <- msg
 | 
				
			||||||
@@ -116,6 +147,10 @@ func (c *WSClient) Errors() <-chan error {
 | 
				
			|||||||
	return c.errCh
 | 
						return c.errCh
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (c *WSClient) OnReconnect(rh func(c *WSClient)) {
 | 
				
			||||||
 | 
						c.recHandler = rh
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// setAuth uses reflection to automatically
 | 
					// setAuth uses reflection to automatically
 | 
				
			||||||
// set struct fields called Auth of type
 | 
					// set struct fields called Auth of type
 | 
				
			||||||
// string or types.Optional[string] to the
 | 
					// string or types.Optional[string] to the
 | 
				
			||||||
@@ -147,3 +182,37 @@ func (c *WSClient) setAuth(data any) any {
 | 
				
			|||||||
func DecodeResponse(data json.RawMessage, out any) error {
 | 
					func DecodeResponse(data json.RawMessage, out any) error {
 | 
				
			||||||
	return json.Unmarshal(data, out)
 | 
						return json.Unmarshal(data, out)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func keepaliveDialer() *websocket.Dialer {
 | 
				
			||||||
 | 
						d := &websocket.Dialer{
 | 
				
			||||||
 | 
							NetDial: func(network, addr string) (net.Conn, error) {
 | 
				
			||||||
 | 
								tcpAddr, err := net.ResolveTCPAddr(network, addr)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								conn, err := net.DialTCP(network, nil, tcpAddr)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								err = conn.SetKeepAlive(true)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								err = conn.SetKeepAlivePeriod(time.Second)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return conn, nil
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						d.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
 | 
				
			||||||
 | 
							return d.NetDial(network, addr)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return d
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user