lasso/cmd/lassoctl/cmd/mssh.go

362 lines
9.2 KiB
Go

/*
Copyright © 2021 NAME HERE <EMAIL ADDRESS>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
	"bufio"
	"io"
	"os"
	"path"
	"path/filepath"
	"strings"
	"sync"

	"github.com/abiosoft/ishell"
	"github.com/melbahja/goph"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)
// msshCmd represents the mssh command. It connects to every "user@node"
// argument (node names are resolved through getNodeList) and then starts an
// interactive shell whose commands fan out to all connected nodes at once.
var msshCmd = &cobra.Command{
	Use:   "mssh <user@node...>",
	Short: "Make an SSH connection to multiple nodes simultaneously",
	// Without at least one target the shell would have nothing to talk to.
	Args: cobra.MinimumNArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		shell := ishell.New()

		// Ask for the password up front, before the node list is fetched.
		var password string
		if viper.GetBool("password") {
			shell.Print("Password: ")
			password = shell.ReadPassword()
		}

		// Get node list from server
		nodes := getNodeList()
		auth := msshAuth(password)

		// Connect to every requested device, keyed by its "user@node" argument.
		clients := map[string]*goph.Client{}
		for _, device := range args {
			splitArg := strings.Split(device, "@")
			if len(splitArg) != 2 {
				log.Fatal().Str("arg", device).Msg("Invalid username/node argument")
			}
			username, nodeName := splitArg[0], splitArg[1]
			node, ok := nodes[nodeName]
			if !ok {
				log.Fatal().Str("node", nodeName).Msg("Node does not exist on the server")
			}
			// Connect to node without verifying known hosts
			client, err := goph.NewUnknown(username, node.IP, auth)
			if err != nil {
				log.Fatal().Err(err).Str("node", nodeName).Msg("Error connecting to node")
			}
			clients[device] = client
		}
		// Close all clients when the shell exits.
		defer closeClients(clients)

		shell.AddCmd(&ishell.Cmd{
			Name: "run",
			Help: "Run a shell command on all nodes simultaneously.",
			Func: func(c *ishell.Context) { msshRun(shell, clients, c.Args) },
		})

		file := &ishell.Cmd{
			Name: "file",
			Help: "Transfer files to/from nodes",
		}
		file.AddCmd(&ishell.Cmd{
			Name: "send",
			Help: "Send a file to all nodes. '~' will expand to device's home directory.",
			Func: func(c *ishell.Context) { msshSend(shell, clients, c.Args) },
		})
		file.AddCmd(&ishell.Cmd{
			Name:    "retrieve",
			Aliases: []string{"retr", "recv", "get"},
			Help:    "Retrieve file from all nodes. File will be saved as 'file-user@node.ext'",
			Func:    func(c *ishell.Context) { msshRetrieve(shell, clients, c.Args) },
		})
		shell.AddCmd(file)

		shell.Run()
	},
}

// msshAuth builds the auth methods tried when connecting to nodes, in this
// order:
//   - the password, if --password was given;
//   - the --identity key, if set;
//   - every ~/.ssh/id_* key;
//   - the SSH agent, if one is available.
//
// Any failure to load one of these is fatal.
//
// NOTE(review): x/crypto/ssh documents that only the first instance of each
// RFC 4252 method is used. Every key and the agent are separate "publickey"
// methods here, so only the first one is likely ever offered. Combining all
// signers into one ssh.PublicKeys method would fix that.
func msshAuth(password string) goph.Auth {
	var auth goph.Auth
	if viper.GetBool("password") {
		auth = append(auth, goph.Password(password)...)
	}
	if identity := viper.GetString("identity"); identity != "" {
		keyAuth, err := goph.Key(identity, "")
		if err != nil {
			log.Fatal().Err(err).Str("identity", identity).Msg("Error getting SSH key")
		}
		auth = append(auth, keyAuth...)
	}
	keys, err := getAllIdentities()
	if err != nil {
		log.Fatal().Err(err).Msg("Error getting SSH identities")
	}
	auth = append(auth, keys...)
	if goph.HasAgent() {
		agent, err := goph.UseAgent()
		if err != nil {
			log.Fatal().Err(err).Msg("Error getting SSH agent")
		}
		auth = append(auth, agent...)
	}
	return auth
}

// msshRun executes args (a command followed by its arguments) on every
// client concurrently. Each output line is printed tagged with the client's
// name. It returns once every remote command has finished.
func msshRun(shell *ishell.Shell, clients map[string]*goph.Client, args []string) {
	if len(args) == 0 {
		shell.Println("Usage: run <command> [args...]")
		return
	}
	var wg sync.WaitGroup
	for name, client := range clients {
		wg.Add(1)
		go func(name string, client *goph.Client) {
			defer wg.Done()
			if err := msshRunOne(shell, name, client, args); err != nil {
				cmdError(shell, name, err)
			}
		}(name, client)
	}
	wg.Wait()
}

// msshRunOne runs args on a single client and streams its stdout and stderr
// to the shell. It returns any error from setting up the command, or the
// result of waiting for it, which includes a non-zero exit status.
func msshRunOne(shell *ishell.Shell, name string, client *goph.Client, args []string) error {
	cmd, err := client.Command(args[0], args[1:]...)
	if err != nil {
		return err
	}
	// Release the SSH session when done. The Close error is not actionable
	// at this point.
	defer cmd.Close()

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return err
	}
	if err := cmd.Start(); err != nil {
		return err
	}

	// Drain both streams concurrently. x/crypto/ssh shares a fixed buffer
	// between stdout and stderr, so reading one to EOF before touching the
	// other can stall the remote command.
	var streams sync.WaitGroup
	for _, r := range []io.Reader{stdout, stderr} {
		streams.Add(1)
		go func(r io.Reader) {
			defer streams.Done()
			msshStream(shell, name, r)
		}(r)
	}
	streams.Wait()

	return cmd.Wait()
}

// msshStream prints every line read from r as output from the client name.
// If scanning fails (e.g. a line longer than the scanner's buffer), the
// error is reported and the rest of r is discarded. Draining keeps the
// remote command from blocking on a full pipe.
func msshStream(shell *ishell.Shell, name string, r io.Reader) {
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		cmdOut(shell, name, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		cmdError(shell, name, err)
		_, _ = io.Copy(io.Discard, r)
	}
}

// msshSend uploads args[0] (local path) to args[1] (remote path) on every
// client concurrently. A leading "~" in the remote path is expanded per
// client by msshExpandHome.
func msshSend(shell *ishell.Shell, clients map[string]*goph.Client, args []string) {
	if len(args) < 2 {
		shell.Println("Usage: file send <local path> <remote path>")
		return
	}
	localPath, remotePath := args[0], args[1]

	var wg sync.WaitGroup
	for name, client := range clients {
		remote := msshExpandHome(client.User(), remotePath)
		wg.Add(1)
		go func(name string, client *goph.Client, remote string) {
			defer wg.Done()
			if err := client.Upload(localPath, remote); err != nil {
				cmdError(shell, name, err)
				return
			}
			cmdSuccess(shell, name, "upload complete")
		}(name, client, remote)
	}
	// Wait only after every upload has been started, so they run in parallel.
	wg.Wait()
}

// msshRetrieve downloads args[0] (remote path) from every client
// concurrently into a local file named "<base>-<user@node><ext>".
//
// If args[1] (local path) is an existing directory, files are saved inside
// it and named after the remote file. Otherwise args[1] is treated as a file
// path, which does not need to exist yet, and its name is used as the
// template.
func msshRetrieve(shell *ishell.Shell, clients map[string]*goph.Client, args []string) {
	if len(args) < 2 {
		shell.Println("Usage: file retrieve <remote path> <local path>")
		return
	}
	remotePath, localPath := args[0], args[1]

	localDir := filepath.Dir(localPath)
	localBase := filepath.Base(localPath)
	localExt := filepath.Ext(localPath)
	info, err := os.Stat(localPath)
	switch {
	case err == nil && info.IsDir():
		// Save into the directory, named after the remote file.
		localDir = localPath
		localBase = path.Base(remotePath)
		localExt = path.Ext(remotePath)
	case err != nil && !os.IsNotExist(err):
		// A missing path is fine (it names a new file); any other stat
		// failure is reported without killing the interactive shell.
		log.Error().Err(err).Msg("Error getting file info")
		return
	}
	localStem := strings.TrimSuffix(localBase, localExt)

	var wg sync.WaitGroup
	for name, client := range clients {
		local := filepath.Join(localDir, localStem+"-"+name+localExt)
		remote := msshExpandHome(client.User(), remotePath)
		wg.Add(1)
		go func(name string, client *goph.Client, remote, local string) {
			defer wg.Done()
			if err := client.Download(remote, local); err != nil {
				cmdError(shell, name, err)
				return
			}
			cmdSuccess(shell, name, "download complete")
		}(name, client, remote, local)
	}
	// Wait only after every download has been started, so they run in parallel.
	wg.Wait()
}

// msshExpandHome rewrites remotePath when it is exactly "~" or starts with
// "~/", replacing the "~" with "/home/<user>". Any other path is returned
// unchanged. The remote side is joined with package path (always "/")
// rather than filepath, because the local OS separator does not apply to
// the remote host.
//
// NOTE(review): assumes home directories live under /home. Accounts such as
// root usually do not; confirm against the target nodes.
func msshExpandHome(user, remotePath string) string {
	switch {
	case remotePath == "~":
		return "/home/" + user
	case strings.HasPrefix(remotePath, "~/"):
		return path.Join("/home/"+user, strings.TrimPrefix(remotePath, "~/"))
	default:
		return remotePath
	}
}
// init attaches mssh to the root command and declares its flags. The flags
// are bound into viper, so the command reads them back via
// viper.GetBool("password") and viper.GetString("identity").
func init() {
	rootCmd.AddCommand(msshCmd)

	flags := msshCmd.Flags()
	flags.BoolP("password", "p", false, "Prompt for SSH password on start")
	flags.StringP("identity", "i", "", "SSH identity file to use")

	// The returned error is ignored, as before.
	// NOTE(review): consider checking it.
	viper.BindPFlags(flags)
}
// getAllIdentities loads every private key in ~/.ssh that has a matching
// "id_*.pub" file next to it, and returns them as goph auth methods.
//
// Keys are loaded with an empty passphrase. A key that fails to load (e.g.
// one that needs a passphrase, or a .pub file whose private half is missing)
// is skipped with a warning instead of aborting. That way one unusable key
// does not block the other auth methods. An error is returned only when the
// home directory cannot be determined or the glob fails.
func getAllIdentities() (goph.Auth, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil, err
	}
	pubKeys, err := filepath.Glob(filepath.Join(home, ".ssh", "id_*.pub"))
	if err != nil {
		return nil, err
	}

	var out goph.Auth
	for _, pubKey := range pubKeys {
		// The private key sits next to its public key, without ".pub".
		privKeyPath := strings.TrimSuffix(pubKey, ".pub")
		auth, err := goph.Key(privKeyPath, "")
		if err != nil {
			log.Warn().Err(err).Str("key", privKeyPath).Msg("Skipping SSH identity that could not be loaded")
			continue
		}
		out = append(out, auth...)
	}
	return out, nil
}
// closeClients closes every SSH client in the map. This is best-effort
// cleanup, so errors returned by Close are ignored.
func closeClients(clients map[string]*goph.Client) {
	for key := range clients {
		clients[key].Close()
	}
}