223 lines
6.2 KiB
Go
223 lines
6.2 KiB
Go
/*
|
|
Copyright © 2021 Arsen Musayelyan
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
package cmd
|
|
|
|
import (
	"bytes"
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/mitchellh/mapstructure"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"github.com/vmihailenco/msgpack/v5"
	"go.arsenm.dev/lasso/internal/types"
)
|
|
|
|
// clientCmd represents the client command
|
|
var clientCmd = &cobra.Command{
|
|
Use: "client",
|
|
Short: "Start the lasso client",
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
// Disable certificate verification as server uses self-signed key
|
|
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
|
|
// Perform server status check
|
|
res, err := http.Get(url("status"))
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Server status check failed")
|
|
}
|
|
res.Body.Close()
|
|
|
|
// Get node list from server
|
|
res, err = http.Get(url("node", "list"))
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error getting node list from server")
|
|
}
|
|
|
|
var resp types.Response
|
|
// Decode server response
|
|
err = msgpack.NewDecoder(res.Body).Decode(&resp)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error decoding response body")
|
|
}
|
|
res.Body.Close()
|
|
|
|
var nodes map[string]types.Node
|
|
// Decode response data as node list
|
|
err = mapstructure.Decode(resp.Data, &nodes)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error decoding response")
|
|
}
|
|
|
|
// Get local IP
|
|
ip, err := localIP()
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error getting local IP")
|
|
}
|
|
|
|
// Attempt to get node from list
|
|
node, ok := nodes[viper.GetString("node.name")]
|
|
// If node does not exist
|
|
if !ok {
|
|
// Encode node as msgpack
|
|
data, err := msgpack.Marshal(types.Node{IP: ip})
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error encoding new node")
|
|
}
|
|
// Send node to server
|
|
res, err = http.Post(url("node", viper.GetString("node.name")), "text/plain", bytes.NewReader(data))
|
|
if err != nil {
|
|
log.Fatal().Msg("Error adding new node to server")
|
|
}
|
|
var resp types.Response
|
|
// Decode server response
|
|
err = msgpack.NewDecoder(res.Body).Decode(&resp)
|
|
if err != nil {
|
|
log.Fatal().Msg("Error decoding server response")
|
|
}
|
|
// If server returned error
|
|
if resp.Error {
|
|
log.Fatal().Str("error", resp.Message).Msg("Error returned by server")
|
|
}
|
|
// Set new node IP
|
|
node.IP = ip
|
|
// Set node ID in viper
|
|
viper.Set("node.id", resp.Data)
|
|
// Attempt to write new ID to config
|
|
if err := viper.WriteConfig(); err != nil {
|
|
log.Fatal().Err(err).Msg("Error writing new ID to config")
|
|
}
|
|
}
|
|
|
|
// If IP does not match current
|
|
if node.IP != ip {
|
|
// Encode new node as msgpack
|
|
data, err := msgpack.Marshal(types.Node{
|
|
IP: ip,
|
|
ID: viper.GetString("node.id"),
|
|
})
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error encoding updated node")
|
|
}
|
|
// Create new PATCH request with new node as data
|
|
req, err := http.NewRequest(
|
|
http.MethodPatch,
|
|
url("node", viper.GetString("node.name")),
|
|
bytes.NewReader(data),
|
|
)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error creating PATCH request")
|
|
}
|
|
// Perform request
|
|
res, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error sending update request to server")
|
|
}
|
|
var resp types.Response
|
|
// Decode server response
|
|
err = msgpack.NewDecoder(res.Body).Decode(&resp)
|
|
if err != nil {
|
|
log.Fatal().Msg("Error decoding server response")
|
|
}
|
|
// If server returned error
|
|
if resp.Error {
|
|
log.Fatal().Str("error", resp.Message).Msg("Error returned by server")
|
|
}
|
|
}
|
|
|
|
// Every minute
|
|
for range time.Tick(time.Minute) {
|
|
// Get local IP
|
|
newIP, err := localIP()
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Error getting new local IP")
|
|
continue
|
|
}
|
|
// If IP has changed since last check
|
|
if newIP != ip {
|
|
// Set IP to new IP
|
|
ip = newIP
|
|
// Encode new node as msgpack
|
|
data, err := msgpack.Marshal(types.Node{
|
|
IP: newIP,
|
|
ID: viper.GetString("node.id"),
|
|
})
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Error encoding updated node")
|
|
continue
|
|
}
|
|
// Create new PATCH request with new node as data
|
|
req, err := http.NewRequest(
|
|
http.MethodPatch,
|
|
url("node", viper.GetString("node.name")),
|
|
bytes.NewReader(data),
|
|
)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error creating PATCH request")
|
|
}
|
|
// Perform request
|
|
res, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error sending update request to server")
|
|
}
|
|
var resp types.Response
|
|
// Decode server response
|
|
err = msgpack.NewDecoder(res.Body).Decode(&resp)
|
|
if err != nil {
|
|
log.Fatal().Msg("Error decoding server response")
|
|
}
|
|
// If server returned error
|
|
if resp.Error {
|
|
log.Fatal().Str("error", resp.Message).Msg("Error returned by server")
|
|
}
|
|
}
|
|
}
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(clientCmd)
|
|
}
|
|
|
|
// localIP returns the local IP address the OS selects for outbound traffic,
// i.e. the address of the default network interface.
//
// It dials UDP to the IPv4 limited-broadcast address. A UDP dial sends no
// packets; it only binds the socket to the source address the kernel's
// routing picks for that destination. That local address is then reported.
// An error is returned if the dial fails or the local address cannot be
// split into host and port.
func localIP() (string, error) {
	conn, err := net.Dial("udp", "255.255.255.255:65535")
	if err != nil {
		return "", err
	}
	defer conn.Close()

	// LocalAddr is "host:port"; only the host part is wanted.
	host, _, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		return "", err
	}
	return host, nil
}
|
|
|
|
// url generates a URL for the given path on the server
|
|
func url(path ...string) string {
|
|
// Get server address with port
|
|
serverAddr := net.JoinHostPort(viper.GetString("server.addr"), viper.GetString("server.port"))
|
|
// Return HTTPS address
|
|
return "https://" + serverAddr + "/" + strings.Join(path, "/")
|
|
}
|