go-lemmy/websocket.go

215 lines
4.5 KiB
Go
Raw Normal View History

2022-12-13 01:35:41 +00:00
package lemmy
import (
2022-12-13 02:12:22 +00:00
"context"
"encoding/json"
2023-01-09 20:54:06 +00:00
"fmt"
"net"
2022-12-13 02:12:22 +00:00
"net/http"
2022-12-13 01:35:41 +00:00
"net/url"
2022-12-13 02:12:22 +00:00
"reflect"
2023-01-09 20:54:06 +00:00
"time"
2022-12-13 01:35:41 +00:00
2023-01-09 20:54:06 +00:00
"github.com/cenkalti/backoff/v4"
"github.com/gorilla/websocket"
2022-12-13 01:35:41 +00:00
"go.arsenm.dev/go-lemmy/types"
)
2022-12-13 02:12:22 +00:00
// authData is the payload sent by Request when the caller supplies
// no data of its own. It carries only the authentication token.
type authData struct {
	Auth string `json:"auth"`
}
2022-12-13 18:49:58 +00:00
// WSClient is a client for Lemmy's WebSocket API
2022-12-13 01:35:41 +00:00
type WSClient struct {
2023-01-09 20:54:06 +00:00
conn *websocket.Conn
baseURL *url.URL
respCh chan types.LemmyWebSocketMsg
errCh chan error
recHandler func(c *WSClient)
Token string
2022-12-13 01:35:41 +00:00
}
2022-12-13 18:49:58 +00:00
// NewWebSocket creates and returns a new WSClient, and
// starts a goroutine to read server responses and errors
2022-12-13 01:35:41 +00:00
func NewWebSocket(baseURL string) (*WSClient, error) {
u, err := url.Parse(baseURL)
if err != nil {
return nil, err
}
u = u.JoinPath("/api/v3")
2023-01-09 20:54:06 +00:00
conn, _, err := keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
if err != nil {
return nil, err
}
2022-12-13 01:35:41 +00:00
out := &WSClient{
conn: conn,
2022-12-13 02:12:22 +00:00
baseURL: u,
respCh: make(chan types.LemmyWebSocketMsg, 10),
errCh: make(chan error, 10),
2022-12-13 01:35:41 +00:00
}
go func() {
for {
var msg types.LemmyWebSocketMsg
err = conn.ReadJSON(&msg)
2022-12-13 01:35:41 +00:00
if err != nil {
out.errCh <- err
2023-01-09 20:54:06 +00:00
2023-01-24 21:19:22 +00:00
conn.Close()
err = backoff.RetryNotify(
func() error {
conn, _, err = keepaliveDialer().Dial(u.JoinPath("ws").String(), nil)
if err != nil {
out.errCh <- err
return err
}
out.conn = conn
out.recHandler(out)
return nil
},
backoff.NewExponentialBackOff(),
func(err error, d time.Duration) {
out.errCh <- fmt.Errorf("reconnect backoff (%s): %w", d, err)
},
)
if err != nil {
out.errCh <- err
2023-01-09 20:54:06 +00:00
}
2022-12-13 01:35:41 +00:00
continue
}
out.respCh <- msg
}
}()
return out, nil
}
// ClientLogin logs in to Lemmy by sending an HTTP request to the
2022-12-13 18:49:58 +00:00
// login endpoint. It stores the returned token in the client
// for future use.
func (c *WSClient) ClientLogin(ctx context.Context, l types.Login) error {
2022-12-13 02:12:22 +00:00
u := &url.URL{}
*u = *c.baseURL
if u.Scheme == "ws" {
u.Scheme = "http"
} else if u.Scheme == "wss" {
u.Scheme = "https"
}
hc := &Client{baseURL: u, client: http.DefaultClient}
err := hc.ClientLogin(ctx, l)
2022-12-13 02:12:22 +00:00
if err != nil {
return err
}
2022-12-13 18:49:58 +00:00
c.Token = hc.Token
2022-12-13 02:12:22 +00:00
return nil
}
2022-12-13 18:49:58 +00:00
// Request sends a request to the server. If data is nil,
// the authentication token will be sent instead. If data
// has an Auth field, it will be set to the authentication
// token automatically.
2023-01-05 21:17:10 +00:00
func (c *WSClient) Request(op types.Operation, data any) error {
2022-12-13 02:12:22 +00:00
if data == nil {
data = authData{}
}
data = c.setAuth(data)
d, err := json.Marshal(data)
if err != nil {
return err
}
2022-12-13 01:35:41 +00:00
return c.conn.WriteJSON(types.LemmyWebSocketMsg{
2023-01-05 21:17:10 +00:00
Op: op.Operation(),
2022-12-13 02:12:22 +00:00
Data: d,
2022-12-13 01:35:41 +00:00
})
}
2022-12-13 18:49:58 +00:00
// Responses returns a channel that receives messages from
// Lemmy.
2022-12-13 01:35:41 +00:00
func (c *WSClient) Responses() <-chan types.LemmyWebSocketMsg {
return c.respCh
}
2022-12-13 18:49:58 +00:00
// Errors returns a channel that receives errors
// received while attempting to read responses
2022-12-13 01:35:41 +00:00
func (c *WSClient) Errors() <-chan error {
return c.errCh
}
2023-01-09 20:54:06 +00:00
// OnReconnect registers rh as the client's reconnect handler,
// replacing any handler set previously. The handler receives
// the client it was registered on.
func (c *WSClient) OnReconnect(rh func(c *WSClient)) {
	c.recHandler = rh
}
2022-12-13 18:49:58 +00:00
// setAuth uses reflection to automatically
// set struct fields called Auth of type
// string or types.Optional[string] to the
// authentication token, then returns the
// updated struct
2022-12-13 02:12:22 +00:00
func (c *WSClient) setAuth(data any) any {
val := reflect.New(reflect.TypeOf(data))
val.Elem().Set(reflect.ValueOf(data))
authField := val.Elem().FieldByName("Auth")
if !authField.IsValid() {
return data
}
switch authField.Type().String() {
case "string":
2022-12-13 18:49:58 +00:00
authField.SetString(c.Token)
2022-12-13 02:12:22 +00:00
case "types.Optional[string]":
setMtd := authField.MethodByName("Set")
2022-12-13 18:49:58 +00:00
out := setMtd.Call([]reflect.Value{reflect.ValueOf(c.Token)})
2022-12-13 02:12:22 +00:00
authField.Set(out[0])
default:
return data
2022-12-13 01:35:41 +00:00
}
2022-12-13 02:12:22 +00:00
return val.Elem().Interface()
}
// DecodeResponse decodes the raw JSON in data into out, which
// should be a pointer to the value to fill. It returns any error
// reported by the JSON decoder.
func DecodeResponse(data json.RawMessage, out any) error {
	if err := json.Unmarshal(data, out); err != nil {
		return err
	}
	return nil
}
2023-01-09 20:54:06 +00:00
func keepaliveDialer() *websocket.Dialer {
d := &websocket.Dialer{
NetDial: func(network, addr string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, addr)
if err != nil {
return nil, err
}
conn, err := net.DialTCP(network, nil, tcpAddr)
if err != nil {
return nil, err
}
err = conn.SetKeepAlive(true)
if err != nil {
return nil, err
}
2023-01-24 21:19:22 +00:00
err = conn.SetKeepAlivePeriod(10 * time.Second)
2023-01-09 20:54:06 +00:00
if err != nil {
return nil, err
}
return conn, nil
},
}
d.NetDialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.NetDial(network, addr)
}
return d
}