diff --git a/websocket.go b/websocket.go index d45660b..e3af8e1 100644 --- a/websocket.go +++ b/websocket.go @@ -1,18 +1,27 @@ package lemmy import ( + "context" + "encoding/json" + "net/http" "net/url" + "reflect" "time" - "github.com/mitchellh/mapstructure" "github.com/recws-org/recws" "go.arsenm.dev/go-lemmy/types" ) +type authData struct { + Auth string `json:"auth"` +} + type WSClient struct { - conn *recws.RecConn - respCh chan types.LemmyWebSocketMsg - errCh chan error + conn *recws.RecConn + baseURL *url.URL + respCh chan types.LemmyWebSocketMsg + errCh chan error + token string } func NewWebSocket(baseURL string) (*WSClient, error) { @@ -26,12 +35,13 @@ func NewWebSocket(baseURL string) (*WSClient, error) { } u = u.JoinPath("/api/v3") - ws.Dial(u.String(), nil) + ws.Dial(u.JoinPath("ws").String(), nil) out := &WSClient{ - conn: ws, - respCh: make(chan types.LemmyWebSocketMsg, 10), - errCh: make(chan error, 10), + conn: ws, + baseURL: u, + respCh: make(chan types.LemmyWebSocketMsg, 10), + errCh: make(chan error, 10), } go func() { @@ -49,10 +59,40 @@ func NewWebSocket(baseURL string) (*WSClient, error) { return out, nil } +func (c *WSClient) Login(ctx context.Context, l types.Login) error { + u := &url.URL{} + *u = *c.baseURL + + if u.Scheme == "ws" { + u.Scheme = "http" + } else if u.Scheme == "wss" { + u.Scheme = "https" + } + + hc := &Client{baseURL: u, client: http.DefaultClient} + err := hc.Login(ctx, l) + if err != nil { + return err + } + c.token = hc.token + return nil +} + func (c *WSClient) Request(op types.UserOperation, data any) error { + if data == nil { + data = authData{} + } + + data = c.setAuth(data) + + d, err := json.Marshal(data) + if err != nil { + return err + } + return c.conn.WriteJSON(types.LemmyWebSocketMsg{ Op: op, - Data: data, + Data: d, }) } @@ -64,13 +104,29 @@ func (c *WSClient) Errors() <-chan error { return c.errCh } -func DecodeResponse(data, out any) error { - dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", - Result: out, - }) - if err != nil { - return err +func (c *WSClient) setAuth(data any) any { + val := reflect.New(reflect.TypeOf(data)) + val.Elem().Set(reflect.ValueOf(data)) + + authField := val.Elem().FieldByName("Auth") + if !authField.IsValid() { + return data } - return dec.Decode(data) + + switch authField.Type().String() { + case "string": + authField.SetString(c.token) + case "types.Optional[string]": + setMtd := authField.MethodByName("Set") + out := setMtd.Call([]reflect.Value{reflect.ValueOf(c.token)}) + authField.Set(out[0]) + default: + return data + } + + return val.Elem().Interface() +} + +func DecodeResponse(data json.RawMessage, out any) error { + return json.Unmarshal(data, out) }