lasso/cmd/lasso/cmd/server.go

228 lines
6.4 KiB
Go

/*
Copyright © 2021 Arsen Musayelyan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cmd
import (
	"errors"
	"net"
	"net/http"
	"time"

	"github.com/dgraph-io/badger/v3"
	"github.com/go-chi/chi/v5"
	"github.com/google/uuid"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"github.com/vmihailenco/msgpack/v5"
	"go.arsenm.dev/lasso/internal/logging"
	"go.arsenm.dev/lasso/internal/types"
)
var (
	// ErrAlreadyExists is returned when a node is created under a name that
	// is already present in the nodes database.
	ErrAlreadyExists = errors.New("node already exists in database")
	// ErrNoExists indicates that a node is not present in the nodes database.
	ErrNoExists = errors.New("node does not exist in database")
)
// serverCmd represents the server command. It opens the nodes database at
// /etc/lasso/nodes and serves the node API over HTTPS, using the listen
// address and certificate paths from the server.* config keys.
var serverCmd = &cobra.Command{
	Use:   "server",
	Short: "Start the lasso master server",
	Run: func(cmd *cobra.Command, args []string) {
		// Open nodes database
		nodes, err := badger.Open(
			badger.DefaultOptions("/etc/lasso/nodes").
				WithLogger(logging.BadgerLogger{}),
		)
		if err != nil {
			log.Fatal().Err(err).Msg("Error opening badger database")
		}

		// Get listen address for server from config
		addr := net.JoinHostPort(viper.GetString("server.addr"), viper.GetString("server.port"))
		srv := &http.Server{
			Addr:    addr,
			Handler: newServerRouter(nodes),
			// Bound how long a client may take to send request headers so a
			// slow client cannot hold a connection open indefinitely.
			ReadHeaderTimeout: 10 * time.Second,
		}

		log.Info().Str("addr", addr).Msg("Starting HTTPS server")
		// ListenAndServeTLS always returns a non-nil error here, since the
		// server is never shut down gracefully.
		err = srv.ListenAndServeTLS(
			viper.GetString("server.tls.cert"),
			viper.GetString("server.tls.key"),
		)
		// log.Fatal exits the process without running deferred calls, so the
		// database is closed explicitly before the error is reported.
		if cerr := nodes.Close(); cerr != nil {
			log.Error().Err(cerr).Msg("Error closing badger database")
		}
		if err != nil {
			log.Fatal().Err(err).Msg("Error while running server")
		}
	},
}

// newServerRouter builds the HTTP routes served by the master server, all
// backed by the given nodes database:
//
//	GET   /status       status check
//	POST  /node/{name}  create a node and return its generated ID
//	PATCH /node/{name}  update a node (the request must carry its ID)
//	GET   /node/list    list all nodes with their IDs removed
func newServerRouter(nodes *badger.DB) http.Handler {
	router := chi.NewMux()
	router.Use(logging.ChiLogger)
	router.Get("/status", func(res http.ResponseWriter, req *http.Request) {
		// Empty response signals that the server is up
		writeServerResponse(res, types.Response{})
	})
	router.Route("/node", func(node chi.Router) {
		node.Post("/{name}", serverCreateNode(nodes))
		node.Patch("/{name}", serverUpdateNode(nodes))
		node.Get("/list", serverListNodes(nodes))
	})
	return router
}

// serverCreateNode returns the handler for POST /node/{name}. It decodes a
// msgpack node from the body, assigns it a fresh UUID, stores it under name,
// and responds with the UUID. If name is already taken, it responds with
// 409 Conflict.
func serverCreateNode(nodes *badger.DB) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		name := chi.URLParam(req, "name")

		var nodeReq types.Node
		if err := msgpack.NewDecoder(req.Body).Decode(&nodeReq); err != nil {
			httpError(res, http.StatusBadRequest, err, "Unable to decode request body")
			return
		}

		// The server assigns the ID; clients must present it to update the
		// node later (see serverUpdateNode).
		id := uuid.New().String()
		nodeReq.ID = id

		// The existence check and the write share one transaction, so two
		// concurrent creates of the same name cannot both succeed.
		err := nodes.Update(func(txn *badger.Txn) error {
			_, err := txn.Get([]byte(name))
			switch {
			case err == nil:
				return ErrAlreadyExists
			case !errors.Is(err, badger.ErrKeyNotFound):
				return err
			}
			data, err := msgpack.Marshal(nodeReq)
			if err != nil {
				return err
			}
			return txn.Set([]byte(name), data)
		})
		switch {
		case errors.Is(err, ErrAlreadyExists):
			httpError(res, http.StatusConflict, err, "Error adding node to database")
			return
		case err != nil:
			httpError(res, http.StatusInternalServerError, err, "Error adding node to database")
			return
		}

		writeServerResponse(res, types.Response{Data: id})
	}
}

// serverUpdateNode returns the handler for PATCH /node/{name}. It replaces
// the stored node with the decoded request body, but only if the request's
// ID matches the stored node's ID. Responses: 404 if the node does not
// exist, 403 on an ID mismatch, 500 on database errors.
func serverUpdateNode(nodes *badger.DB) http.HandlerFunc {
	// Internal signal from the transaction that the ID check failed; it
	// aborts the transaction so nothing is written.
	errIDMismatch := errors.New("incorrect UUID for specified node")

	return func(res http.ResponseWriter, req *http.Request) {
		name := chi.URLParam(req, "name")

		var nodeReq types.Node
		if err := msgpack.NewDecoder(req.Body).Decode(&nodeReq); err != nil {
			httpError(res, http.StatusBadRequest, err, "Unable to decode request body")
			return
		}

		// Read, verify and write in a single transaction so the ID check
		// applies to exactly the value being replaced.
		err := nodes.Update(func(txn *badger.Txn) error {
			item, err := txn.Get([]byte(name))
			if err != nil {
				return err
			}
			var dbNode types.Node
			err = item.Value(func(val []byte) error {
				return msgpack.Unmarshal(val, &dbNode)
			})
			if err != nil {
				return err
			}
			if nodeReq.ID != dbNode.ID {
				return errIDMismatch
			}
			data, err := msgpack.Marshal(nodeReq)
			if err != nil {
				return err
			}
			return txn.Set([]byte(name), data)
		})
		switch {
		case errors.Is(err, badger.ErrKeyNotFound):
			httpError(res, http.StatusNotFound, err, "Unable to get node from database")
		case errors.Is(err, errIDMismatch):
			httpError(res, http.StatusForbidden, nil, "Incorrect UUID for specified node")
		case err != nil:
			httpError(res, http.StatusInternalServerError, err, "Error updating node in database")
		default:
			writeServerResponse(res, types.Response{})
		}
	}
}

// serverListNodes returns the handler for GET /node/list. It responds with a
// map from node name to node for every stored node, with each ID cleared.
// If any entry cannot be read, it responds with 500 rather than a partial
// list.
func serverListNodes(nodes *badger.DB) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		// Non-nil so an empty database encodes as an empty map
		out := map[string]types.Node{}
		err := nodes.View(func(txn *badger.Txn) error {
			it := txn.NewIterator(badger.DefaultIteratorOptions)
			defer it.Close()
			for it.Rewind(); it.Valid(); it.Next() {
				item := it.Item()
				var node types.Node
				err := item.Value(func(val []byte) error {
					return msgpack.Unmarshal(val, &node)
				})
				if err != nil {
					return err
				}
				// The ID authorizes updates, so it is never exposed in listings.
				node.ID = ""
				// string() copies the key, which is only valid until Next.
				out[string(item.Key())] = node
			}
			return nil
		})
		if err != nil {
			httpError(res, http.StatusInternalServerError, err, "Error listing nodes from database")
			return
		}

		writeServerResponse(res, types.Response{Data: out})
	}
}

// writeServerResponse msgpack-encodes resp to res. By the time encoding
// fails, the status line may already have been written, so the error cannot
// be reported to the client and is only logged.
func writeServerResponse(res http.ResponseWriter, resp types.Response) {
	if err := msgpack.NewEncoder(res).Encode(resp); err != nil {
		log.Warn().Err(err).Msg("Error encoding response")
	}
}
// init registers the server subcommand on the root command.
func init() {
	rootCmd.AddCommand(serverCmd)
}