lasso/cmd/lasso/cmd/client.go

223 lines
6.2 KiB
Go

/*
Copyright © 2021 Arsen Musayelyan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cmd
import (
"bytes"
"crypto/tls"
"net"
"net/http"
"strings"
"time"
"github.com/mitchellh/mapstructure"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/vmihailenco/msgpack/v5"
"go.arsenm.dev/lasso/internal/types"
)
// clientCmd represents the client command.
//
// When run, it:
//  1. checks that the server answers on its "status" endpoint,
//  2. fetches the node list and looks this node up by the "node.name" config key,
//  3. registers the node with a POST (storing the returned data as "node.id"
//     and writing the config) if the server does not know it yet,
//  4. sends a PATCH if the IP the server has for the node differs from the
//     local IP,
//  5. then re-checks the local IP every minute and sends a PATCH whenever it
//     changes. Failures in this loop are logged and retried on the next tick
//     instead of terminating the client.
var clientCmd = &cobra.Command{
	Use:   "client",
	Short: "Start the lasso client",
	Run: func(cmd *cobra.Command, args []string) {
		// Disable certificate verification as server uses self-signed key
		http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}

		// Perform server status check
		res, err := http.Get(url("status"))
		if err != nil {
			log.Fatal().Err(err).Msg("Server status check failed")
		}
		res.Body.Close()

		// Get node list from server
		res, err = http.Get(url("node", "list"))
		if err != nil {
			log.Fatal().Err(err).Msg("Error getting node list from server")
		}
		var resp types.Response
		// Decode server response; close the body regardless of the outcome
		err = msgpack.NewDecoder(res.Body).Decode(&resp)
		res.Body.Close()
		if err != nil {
			log.Fatal().Err(err).Msg("Error decoding response body")
		}

		var nodes map[string]types.Node
		// Decode response data as node list
		if err := mapstructure.Decode(resp.Data, &nodes); err != nil {
			log.Fatal().Err(err).Msg("Error decoding response")
		}

		// Get local IP
		ip, err := localIP()
		if err != nil {
			log.Fatal().Err(err).Msg("Error getting local IP")
		}

		// Attempt to get node from list
		node, ok := nodes[viper.GetString("node.name")]
		if !ok {
			// Node is unknown to the server: register it with the current IP
			resp, err := clientSendNode(http.MethodPost, types.Node{IP: ip})
			if err != nil {
				log.Fatal().Err(err).Msg("Error adding new node to server")
			}
			if resp.Error {
				log.Fatal().Str("error", resp.Message).Msg("Error returned by server")
			}
			// The server was just given the current IP, so skip the PATCH below
			node.IP = ip
			// Persist the ID returned by the server
			viper.Set("node.id", resp.Data)
			if err := viper.WriteConfig(); err != nil {
				log.Fatal().Err(err).Msg("Error writing new ID to config")
			}
		}

		// If the IP known to the server does not match the current one, update it
		if node.IP != ip {
			resp, err := clientSendNode(http.MethodPatch, types.Node{
				IP: ip,
				ID: viper.GetString("node.id"),
			})
			if err != nil {
				log.Fatal().Err(err).Msg("Error sending update request to server")
			}
			if resp.Error {
				log.Fatal().Str("error", resp.Message).Msg("Error returned by server")
			}
		}

		// Re-check the local IP every minute. Run never returns, but stop the
		// ticker anyway so the resource is released if that ever changes.
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for range ticker.C {
			newIP, err := localIP()
			if err != nil {
				log.Error().Err(err).Msg("Error getting new local IP")
				continue
			}
			// Nothing to do if the IP has not changed since the last successful update
			if newIP == ip {
				continue
			}
			resp, err := clientSendNode(http.MethodPatch, types.Node{
				IP: newIP,
				ID: viper.GetString("node.id"),
			})
			if err != nil {
				// Keep the old IP so the update is retried on the next tick
				log.Error().Err(err).Msg("Error sending update request to server")
				continue
			}
			if resp.Error {
				log.Error().Str("error", resp.Message).Msg("Error returned by server")
				continue
			}
			// Only record the new IP once the server has accepted it
			ip = newIP
		}
	},
}

// clientSendNode msgpack-encodes node and sends it with the given HTTP method
// to the server's node/<node.name> endpoint, then msgpack-decodes the reply.
//
// The response body is always closed. A non-nil error means the node could
// not be encoded, the request could not be built or sent, or the reply could
// not be decoded. An error reported by the server itself is returned through
// resp.Error / resp.Message with a nil error.
func clientSendNode(method string, node types.Node) (types.Response, error) {
	data, err := msgpack.Marshal(node)
	if err != nil {
		return types.Response{}, err
	}
	req, err := http.NewRequest(method, url("node", viper.GetString("node.name")), bytes.NewReader(data))
	if err != nil {
		return types.Response{}, err
	}
	// Node payloads are sent as text/plain (the body itself is msgpack)
	req.Header.Set("Content-Type", "text/plain")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return types.Response{}, err
	}
	defer res.Body.Close()
	var resp types.Response
	if err := msgpack.NewDecoder(res.Body).Decode(&resp); err != nil {
		return types.Response{}, err
	}
	return resp, nil
}
func init() {
rootCmd.AddCommand(clientCmd)
}
// localIP returns the IP address of the default network interface.
//
// It does this by dialing UDP to the limited-broadcast address
// 255.255.255.255. Dialing UDP sends no packets; it only associates the
// socket with a destination, which makes the OS pick a route and a local
// source address. That address is then read back.
//
// It returns an error if the dial fails or the local address cannot be
// split into host and port.
func localIP() (string, error) {
	conn, err := net.Dial("udp", "255.255.255.255:65535")
	if err != nil {
		return "", err
	}
	defer conn.Close()

	// LocalAddr is "host:port"; only the host part is wanted
	host, _, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		return "", err
	}
	return host, nil
}
// url builds an HTTPS URL on the configured server for the given path
// segments. The segments are joined with "/" and inserted verbatim, without
// escaping. The host and port come from the "server.addr" and "server.port"
// config keys.
func url(path ...string) string {
	host := viper.GetString("server.addr")
	port := viper.GetString("server.port")

	var b strings.Builder
	b.WriteString("https://")
	b.WriteString(net.JoinHostPort(host, port))
	b.WriteByte('/')
	b.WriteString(strings.Join(path, "/"))
	return b.String()
}