Compare commits

...

2 Commits

Author SHA1 Message Date
a4a015a4cc Add login support for WebSocket API 2022-12-12 18:12:22 -08:00
a90e43691d Add WebSocket types 2022-12-12 18:11:57 -08:00
3 changed files with 113 additions and 19 deletions

View File

@ -1,6 +1,7 @@
package types package types
import ( import (
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
) )
@ -27,6 +28,6 @@ func (le LemmyError) Error() string {
} }
type LemmyWebSocketMsg struct { type LemmyWebSocketMsg struct {
Op UserOperation `json:"op"` Op UserOperation `json:"op"`
Data any `json:"data"` Data json.RawMessage `json:"data"`
} }

37
types/websocket.go Normal file
View File

@ -0,0 +1,37 @@
package types
type UserJoin struct {
Auth string `json:"auth"`
}
type UserJoinResponse struct {
Joined bool `json:"joined"`
LemmyResponse
}
type CommunityJoin struct {
CommunityID int `json:"community_id"`
}
type CommunityJoinResponse struct {
Joined bool `json:"joined"`
LemmyResponse
}
type ModJoin struct {
CommunityID int `json:"community_id"`
}
type ModJoinResponse struct {
Joined bool `json:"joined"`
LemmyResponse
}
type PostJoin struct {
PostID int `json:"post_id"`
}
type PostJoinResponse struct {
Joined bool `json:"joined"`
LemmyResponse
}

View File

@ -1,18 +1,27 @@
package lemmy package lemmy
import ( import (
"context"
"encoding/json"
"net/http"
"net/url" "net/url"
"reflect"
"time" "time"
"github.com/mitchellh/mapstructure"
"github.com/recws-org/recws" "github.com/recws-org/recws"
"go.arsenm.dev/go-lemmy/types" "go.arsenm.dev/go-lemmy/types"
) )
type authData struct {
Auth string `json:"auth"`
}
type WSClient struct { type WSClient struct {
conn *recws.RecConn conn *recws.RecConn
respCh chan types.LemmyWebSocketMsg baseURL *url.URL
errCh chan error respCh chan types.LemmyWebSocketMsg
errCh chan error
token string
} }
func NewWebSocket(baseURL string) (*WSClient, error) { func NewWebSocket(baseURL string) (*WSClient, error) {
@ -26,12 +35,13 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
} }
u = u.JoinPath("/api/v3") u = u.JoinPath("/api/v3")
ws.Dial(u.String(), nil) ws.Dial(u.JoinPath("ws").String(), nil)
out := &WSClient{ out := &WSClient{
conn: ws, conn: ws,
respCh: make(chan types.LemmyWebSocketMsg, 10), baseURL: u,
errCh: make(chan error, 10), respCh: make(chan types.LemmyWebSocketMsg, 10),
errCh: make(chan error, 10),
} }
go func() { go func() {
@ -49,10 +59,40 @@ func NewWebSocket(baseURL string) (*WSClient, error) {
return out, nil return out, nil
} }
func (c *WSClient) Login(ctx context.Context, l types.Login) error {
u := &url.URL{}
*u = *c.baseURL
if u.Scheme == "ws" {
u.Scheme = "http"
} else if u.Scheme == "wss" {
u.Scheme = "https"
}
hc := &Client{baseURL: u, client: http.DefaultClient}
err := hc.Login(ctx, l)
if err != nil {
return err
}
c.token = hc.token
return nil
}
func (c *WSClient) Request(op types.UserOperation, data any) error { func (c *WSClient) Request(op types.UserOperation, data any) error {
if data == nil {
data = authData{}
}
data = c.setAuth(data)
d, err := json.Marshal(data)
if err != nil {
return err
}
return c.conn.WriteJSON(types.LemmyWebSocketMsg{ return c.conn.WriteJSON(types.LemmyWebSocketMsg{
Op: op, Op: op,
Data: data, Data: d,
}) })
} }
@ -64,13 +104,29 @@ func (c *WSClient) Errors() <-chan error {
return c.errCh return c.errCh
} }
func DecodeResponse(data, out any) error { func (c *WSClient) setAuth(data any) any {
dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ val := reflect.New(reflect.TypeOf(data))
TagName: "json", val.Elem().Set(reflect.ValueOf(data))
Result: out,
}) authField := val.Elem().FieldByName("Auth")
if err != nil { if !authField.IsValid() {
return err return data
} }
return dec.Decode(data)
switch authField.Type().String() {
case "string":
authField.SetString(c.token)
case "types.Optional[string]":
setMtd := authField.MethodByName("Set")
out := setMtd.Call([]reflect.Value{reflect.ValueOf(c.token)})
authField.Set(out[0])
default:
return data
}
return val.Elem().Interface()
}
func DecodeResponse(data json.RawMessage, out any) error {
return json.Unmarshal(data, out)
} }