228 lines
6.4 KiB
Go
228 lines
6.4 KiB
Go
/*
|
|
Copyright © 2021 Arsen Musayelyan
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
*/
|
|
package cmd
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
|
|
"github.com/dgraph-io/badger/v3"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/spf13/cobra"
|
|
"github.com/spf13/viper"
|
|
"github.com/vmihailenco/msgpack/v5"
|
|
"go.arsenm.dev/lasso/internal/logging"
|
|
"go.arsenm.dev/lasso/internal/types"
|
|
)
|
|
|
|
// Sentinel errors describing the state of a node in the nodes database.
var (
	// ErrAlreadyExists is reported when a node is created under a name that
	// is already present in the nodes database.
	ErrAlreadyExists = errors.New("node already exists in database")

	// ErrNoExists reports that a node name is not present in the nodes
	// database.
	ErrNoExists = errors.New("node does not exist in database")
)
|
|
|
|
// serverCmd represents the server command
|
|
var serverCmd = &cobra.Command{
|
|
Use: "server",
|
|
Short: "Start the lasso master server",
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
|
|
// Open nodes database
|
|
nodes, err := badger.Open(
|
|
badger.DefaultOptions("/etc/lasso/nodes").
|
|
WithLogger(logging.BadgerLogger{}),
|
|
)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error opening badger database")
|
|
}
|
|
// Close database at end of function
|
|
defer nodes.Close()
|
|
|
|
// Create new router
|
|
router := chi.NewMux()
|
|
|
|
router.Use(logging.ChiLogger)
|
|
|
|
// GET /status (Status check)
|
|
router.Get("/status", func(res http.ResponseWriter, req *http.Request) {
|
|
// Encode empty response to connection
|
|
msgpack.NewEncoder(res).Encode(types.Response{})
|
|
})
|
|
|
|
router.Route("/node", func(node chi.Router) {
|
|
// POST /node/:name (Create new node and return id)
|
|
node.Post("/{name}", func(res http.ResponseWriter, req *http.Request) {
|
|
// Get name parameter from URL
|
|
name := chi.URLParam(req, "name")
|
|
|
|
var nodeReq types.Node
|
|
// Decode request body as a node
|
|
err = msgpack.NewDecoder(req.Body).Decode(&nodeReq)
|
|
if err != nil {
|
|
httpError(res, http.StatusBadRequest, err, "Unable to decode request body")
|
|
return
|
|
}
|
|
|
|
// Create new UUID
|
|
id := uuid.New().String()
|
|
// Set node ID
|
|
nodeReq.ID = id
|
|
|
|
// Update nodes database
|
|
err = nodes.Update(func(txn *badger.Txn) error {
|
|
// Attempt to get node from database
|
|
item, err := txn.Get([]byte(name))
|
|
// If error exists but item found
|
|
if err != nil && err != badger.ErrKeyNotFound {
|
|
return err
|
|
} else if item != nil {
|
|
return ErrAlreadyExists
|
|
}
|
|
|
|
// Encode request node to msgpack
|
|
data, _ := msgpack.Marshal(nodeReq)
|
|
// Set new node in database and return error
|
|
return txn.Set([]byte(name), data)
|
|
})
|
|
if err != nil {
|
|
httpError(res, http.StatusInternalServerError, err, "Error adding node to database")
|
|
return
|
|
}
|
|
|
|
// Encode response with id to connection
|
|
msgpack.NewEncoder(res).Encode(types.Response{Data: id})
|
|
})
|
|
|
|
// PATCH /node/:name (Update IP address of a node)
|
|
node.Patch("/{name}", func(res http.ResponseWriter, req *http.Request) {
|
|
// Get name parameter from URL
|
|
name := chi.URLParam(req, "name")
|
|
|
|
var nodeReq types.Node
|
|
// Decode request body as node
|
|
err = msgpack.NewDecoder(req.Body).Decode(&nodeReq)
|
|
if err != nil {
|
|
httpError(res, http.StatusBadRequest, err, "Unable to decode request body")
|
|
return
|
|
}
|
|
|
|
var dbNode types.Node
|
|
// View nodes database
|
|
err = nodes.View(func(txn *badger.Txn) error {
|
|
// Attempt to get node from database
|
|
item, err := txn.Get([]byte(name))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Decode node and return error
|
|
return item.Value(func(val []byte) error {
|
|
return msgpack.Unmarshal(val, &dbNode)
|
|
})
|
|
})
|
|
if err != nil {
|
|
httpError(res, http.StatusBadRequest, err, "Unable to get node from database")
|
|
return
|
|
}
|
|
|
|
// If request and database IDs are not the same
|
|
if nodeReq.ID != dbNode.ID {
|
|
httpError(res, http.StatusForbidden, nil, "Incorrect UUID for specified node")
|
|
return
|
|
}
|
|
|
|
// Update nodes database
|
|
err = nodes.Update(func(txn *badger.Txn) error {
|
|
// Encode node sent in request
|
|
data, err := msgpack.Marshal(nodeReq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Set new node in database and return error
|
|
return txn.Set([]byte(name), data)
|
|
})
|
|
if err != nil {
|
|
httpError(res, http.StatusInternalServerError, err, "Error updating node in database")
|
|
return
|
|
}
|
|
|
|
// Encode response to connection
|
|
msgpack.NewEncoder(res).Encode(types.Response{})
|
|
})
|
|
|
|
// GET /node/list (Get list of all nodes)
|
|
node.Get("/list", func(res http.ResponseWriter, req *http.Request) {
|
|
// Create map of nodes for output
|
|
out := map[string]types.Node{}
|
|
|
|
// View nodes database
|
|
nodes.View(func(txn *badger.Txn) error {
|
|
// Create new database iterator
|
|
it := txn.NewIterator(badger.DefaultIteratorOptions)
|
|
// Close iterator at end of function
|
|
defer it.Close()
|
|
for it.Rewind(); it.Valid(); it.Next() {
|
|
// Get item from iterator
|
|
item := it.Item()
|
|
// Get key from item
|
|
key := item.Key()
|
|
var node types.Node
|
|
// Get value from iterator
|
|
err = item.Value(func(val []byte) error {
|
|
// Decode value as node and return error
|
|
return msgpack.Unmarshal(val, &node)
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Remove node ID
|
|
node.ID = ""
|
|
// Set node in map
|
|
out[string(key)] = node
|
|
}
|
|
return nil
|
|
})
|
|
|
|
// Encode map in response on connection
|
|
msgpack.NewEncoder(res).Encode(types.Response{
|
|
Data: out,
|
|
})
|
|
})
|
|
})
|
|
|
|
// Get listen address for server from config
|
|
addr := net.JoinHostPort(viper.GetString("server.addr"), viper.GetString("server.port"))
|
|
// Log HTTPS server starting
|
|
log.Info().Str("addr", addr).Msg("Starting HTTPS server")
|
|
// Start HTTPS server using certificate paths from config
|
|
err = http.ListenAndServeTLS(
|
|
addr,
|
|
viper.GetString("server.tls.cert"),
|
|
viper.GetString("server.tls.key"),
|
|
router,
|
|
)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error while running server")
|
|
}
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(serverCmd)
|
|
}
|