package profilefed import ( "bytes" "crypto/ed25519" "encoding/base64" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "sync" "queerdevs.org/profilefed/webfinger" ) const responseSizeLimit = 32_000_000 var ( // ErrPubkeyNotFound signifies that the server public key is not found. ErrPubkeyNotFound = errors.New("server pubkey not found") // ErrNoSignature signifies that the response contains no signature. ErrNoSignature = errors.New("response contains no signature") // ErrSignatureMismatch signifies that the message does not match the server signature. ErrSignatureMismatch = errors.New("message does not match server signature") ) // DefaultClient returns a default client for ProfileFed. // // It uses an in-memory synchronized map to store public keys. // For production, it's highly recommended to implement a custom // client that persists the keys to a database or similar, so that // restarting your app doesn't provide opportunities for malicious servers. func DefaultClient() Client { defaultMap := sync.Map{} return Client{ SavePubkey: func(serverName string, previousNames []string, pubkey ed25519.PublicKey) error { defaultMap.Store(serverName, pubkey) for _, name := range previousNames { defaultMap.Delete(name) } return nil }, GetPubkey: func(serverName string) (ed25519.PublicKey, error) { pubkey, ok := defaultMap.Load(serverName) if !ok { return nil, ErrPubkeyNotFound } return pubkey.(ed25519.PublicKey), nil }, } } // Client represents a ProfileFed client type Client struct { // SavePubkey saves the public key for a given server. SavePubkey func(serverName string, previousNames []string, pubkey ed25519.PublicKey) error // GetPubkey retrieves the public key for a given server. // If the key isn't found, GetPubkey should return [ErrPubkeyNotFound] GetPubkey func(serverName string) (ed25519.PublicKey, error) } // Lookup looks up the profile descriptor for the given resource. func (c Client) Lookup(resource string) (*Descriptor, error) { out := &Descriptor{} return out, c.lookup(resource, "", out) } // LookupID looks up the profile descriptor that matches the given ID // for the given resource. func (c Client) LookupID(resource, id string) (*Descriptor, error) { out := &Descriptor{} return out, c.lookup(resource, id, out) } // Lookup looks up all the available profile descriptors for the given resource. func (c Client) LookupAll(resource string) (map[string]*Descriptor, error) { out := map[string]*Descriptor{} return out, c.lookup(resource, "", &out) } func (c Client) lookup(resource, id string, dest any) error { wfdesc, err := webfinger.LookupAcct(resource) if err != nil { return err } pfdLink, ok := wfdesc.LinkByType("application/x-pfd+json") if !ok { return errors.New("server does not support the profilefed protocol") } pfdURL, err := url.Parse(pfdLink.Href) if err != nil { return err } if id != "" { q := pfdURL.Query() q.Set("id", id) pfdURL.RawQuery = q.Encode() } pubkeySaved := false pubkey, err := c.GetPubkey(pfdURL.Host) if errors.Is(err, ErrPubkeyNotFound) { info, _, err := getServerInfo(pfdURL.Scheme, pfdURL.Host) if err != nil { return err } pubkey, err = base64.StdEncoding.DecodeString(info.PublicKey) if err != nil { return err } err = c.SavePubkey(pfdURL.Host, info.PreviousNames, pubkey) if err != nil { return err } pubkeySaved = true } else if err != nil { return err } res, err := http.Get(pfdURL.String()) if err != nil { return err } defer res.Body.Close() if err := checkResp(res, "getProfileDescriptor"); err != nil { return err } data, err := io.ReadAll(io.LimitReader(res.Body, responseSizeLimit)) if err != nil { return err } if err := res.Body.Close(); err != nil { return err } sig, err := getSignature(res) if err != nil { return err } if !ed25519.Verify(pubkey, data, sig) { // If the pubkey was just saved in the current request, we probably // already have the newest one, so just return a mismatch error. if pubkeySaved { return ErrSignatureMismatch } res, err := serverInfoReq(pfdURL.Scheme, pfdURL.Host) if err != nil { return err } serverData, err := io.ReadAll(io.LimitReader(res.Body, responseSizeLimit)) if err != nil { return err } var info serverInfoData err = json.Unmarshal(serverData, &info) if err != nil { return err } newPubkey, err := base64.StdEncoding.DecodeString(info.PublicKey) if err != nil { return err } if bytes.Equal(pubkey, newPubkey) { return ErrSignatureMismatch } verified := false sigs := getPrevSignatures(res) for _, sig := range sigs { if ed25519.Verify(pubkey, serverData, sig) { verified = true break } } if !verified { return ErrSignatureMismatch } infoSig, err := getSignature(res) if err != nil { return err } if !ed25519.Verify(newPubkey, infoSig, serverData) { return ErrSignatureMismatch } err = c.SavePubkey(pfdURL.Host, info.PreviousNames, newPubkey) if err != nil { return err } if !ed25519.Verify(newPubkey, data, sig) { return ErrSignatureMismatch } } return json.Unmarshal(data, dest) } // serverInfoReq performs an HTTP request to retrieve server information. func serverInfoReq(scheme, host string) (*http.Response, error) { serverInfoURL := url.URL{ Scheme: scheme, Host: host, Path: "/_profilefed/server", } return http.Get(serverInfoURL.String()) } // getServerInfo retrieves server information. func getServerInfo(scheme, host string) (serverInfoData, [][]byte, error) { res, err := serverInfoReq(scheme, host) if err != nil { return serverInfoData{}, nil, err } defer res.Body.Close() if err := checkResp(res, "getServerInfo"); err != nil { return serverInfoData{}, nil, err } var out serverInfoData err = json.NewDecoder(io.LimitReader(res.Body, responseSizeLimit)).Decode(&out) return out, getPrevSignatures(res), err } // getPrevSignatures extracts previous signatures from a response. func getPrevSignatures(res *http.Response) [][]byte { var sigs [][]byte sigStrs := res.Header[http.CanonicalHeaderKey("X-ProfileFed-Previous")] for _, sigStr := range sigStrs { sig, err := base64.StdEncoding.DecodeString(sigStr) if err != nil { continue // Skip invalid signatures } sigs = append(sigs, sig) } return sigs } // getSignature extracts the signature from a response. func getSignature(res *http.Response) ([]byte, error) { sigStr := res.Header.Get("X-ProfileFed-Sig") if sigStr == "" { return nil, ErrNoSignature } return base64.StdEncoding.DecodeString(sigStr) } // checkResp returns an error if the response is not 200 OK. func checkResp(res *http.Response, opName string) error { if res.StatusCode != 200 { return fmt.Errorf("%s: %s", opName, res.Status) } return nil }