362 lines
9.2 KiB
Go
362 lines
9.2 KiB
Go
/*
|
|
Copyright © 2021 NAME HERE <EMAIL ADDRESS>
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
package cmd
|
|
|
|
import (
	"bufio"
	"io"
	"os"
	"path"
	"path/filepath"
	"strings"
	"sync"

	"github.com/abiosoft/ishell"
	"github.com/melbahja/goph"
	"github.com/rs/zerolog/log"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)
|
|
|
|
// msshCmd represents the mssh command
|
|
var msshCmd = &cobra.Command{
|
|
Use: "mssh <user@node...>",
|
|
Short: "Make an SSH connection to multiple nodes simultaneously",
|
|
Run: func(cmd *cobra.Command, args []string) {
|
|
// Create new shell
|
|
shell := ishell.New()
|
|
|
|
var password string
|
|
// If password prompt requested
|
|
if viper.GetBool("password") {
|
|
// Print prompt
|
|
shell.Print("Password: ")
|
|
// Read password into variable
|
|
password = shell.ReadPassword()
|
|
}
|
|
|
|
// Get node list from server
|
|
nodes := getNodeList()
|
|
|
|
// Create new goph auth
|
|
var auth goph.Auth
|
|
|
|
// If password prompt requested
|
|
if viper.GetBool("password") {
|
|
// Add password method to auth
|
|
auth = append(auth, goph.Password(password)...)
|
|
}
|
|
|
|
// If identity given
|
|
if viper.GetString("identity") != "" {
|
|
// Get identity as ssh key
|
|
keyAuth, err := goph.Key(viper.GetString("identity"), "")
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error getting SSH key:")
|
|
}
|
|
// Add key to auth
|
|
auth = append(auth, keyAuth...)
|
|
}
|
|
|
|
// Get all identities in default location
|
|
keys, err := getAllIdentities()
|
|
if err != nil {
|
|
log.Fatal().Msg("Error getting SSH identities")
|
|
}
|
|
// Add keys to auth
|
|
auth = append(auth, keys...)
|
|
|
|
// If ssh agent exists
|
|
if goph.HasAgent() {
|
|
// Get ssh agent
|
|
agent, err := goph.UseAgent()
|
|
if err != nil {
|
|
log.Fatal().Msg("Error getting SSH agent")
|
|
}
|
|
// Add ssh agent to auth
|
|
auth = append(auth, agent...)
|
|
}
|
|
|
|
// Create map to store goph clients
|
|
clients := map[string]*goph.Client{}
|
|
|
|
// For every device
|
|
for _, device := range args {
|
|
// Split device by "@"
|
|
splitArg := strings.Split(device, "@")
|
|
|
|
// If split device has less than two elements, it is invalid
|
|
if len(splitArg) != 2 {
|
|
log.Fatal().Msg("Invalid username/node argument")
|
|
}
|
|
|
|
// Get variables from split device
|
|
username, nodeName := splitArg[0], splitArg[1]
|
|
|
|
// Get node from list if it exists
|
|
node, ok := nodes[nodeName]
|
|
if !ok {
|
|
log.Fatal().Str("node", nodeName).Msg("Node does not exist on the server")
|
|
}
|
|
|
|
// Connect to node without verifying known hosts
|
|
client, err := goph.NewUnknown(username, node.IP, auth)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error connecting to node")
|
|
}
|
|
// Add device to client list
|
|
clients[device] = client
|
|
}
|
|
|
|
// Clode all clients at the end of the function
|
|
defer closeClients(clients)
|
|
|
|
// Add run command to shell
|
|
shell.AddCmd(&ishell.Cmd{
|
|
Name: "run",
|
|
Help: "Run a shell command on all nodes simultaneously.",
|
|
Func: func(c *ishell.Context) {
|
|
// Create new wait group
|
|
wg := &sync.WaitGroup{}
|
|
// For every client
|
|
for name, client := range clients {
|
|
// Add goroutine to wait group
|
|
wg.Add(1)
|
|
go func(client *goph.Client, name string) {
|
|
// Remove from wait group at the end of the function
|
|
defer wg.Done()
|
|
// Create command from given arguments
|
|
cmd, err := client.Command(c.Args[0], c.Args[1:]...)
|
|
if err != nil {
|
|
cmdError(shell, name, err)
|
|
return
|
|
}
|
|
|
|
// Get command Stdout pipe
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
cmdError(shell, name, err)
|
|
return
|
|
}
|
|
// Get command Stderr pipe
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
cmdError(shell, name, err)
|
|
return
|
|
}
|
|
// Combine Stdout and Stderr
|
|
combined := io.MultiReader(stdout, stderr)
|
|
|
|
// Start command without waiting for it to finish
|
|
if err := cmd.Start(); err != nil {
|
|
cmdError(shell, name, err)
|
|
return
|
|
}
|
|
|
|
// Create new scanner for combined output
|
|
scanner := bufio.NewScanner(combined)
|
|
for scanner.Scan() {
|
|
// Print command output
|
|
cmdOut(shell, name, scanner.Text())
|
|
}
|
|
}(client, name)
|
|
}
|
|
// Wait for all goroutines to complete
|
|
wg.Wait()
|
|
},
|
|
})
|
|
|
|
// Create file command
|
|
file := &ishell.Cmd{
|
|
Name: "file",
|
|
Help: "Transfer files to/from nodes",
|
|
}
|
|
|
|
// Add send command to file command
|
|
file.AddCmd(&ishell.Cmd{
|
|
Name: "send",
|
|
Help: "Send a file to all nodes. '~' will expand to device's home directory.",
|
|
Func: func(c *ishell.Context) {
|
|
// Get local and remote paths
|
|
localPath := c.Args[0]
|
|
remotePath := c.Args[1]
|
|
|
|
// Create new wait group
|
|
wg := &sync.WaitGroup{}
|
|
// For every client
|
|
for name, client := range clients {
|
|
// Set new remote to remote path
|
|
newRemote := remotePath
|
|
// If new remote starts with "~"
|
|
if strings.HasPrefix(newRemote, "~") {
|
|
// Replace ~ with "/home/user"
|
|
newRemote = filepath.Join(
|
|
"/home/"+client.User(),
|
|
strings.TrimPrefix(newRemote, "~"),
|
|
)
|
|
}
|
|
|
|
// Add one goroutine to wait group
|
|
wg.Add(1)
|
|
go func(client *goph.Client, name string) {
|
|
// Remove from wait group at the end of the function
|
|
defer wg.Done()
|
|
|
|
// Upload file to remote
|
|
err := client.Upload(localPath, newRemote)
|
|
if err != nil {
|
|
shell.Printf("%s [error]: %v\n", name, err)
|
|
return
|
|
}
|
|
|
|
// Print success message
|
|
cmdSuccess(shell, name, "upload complete")
|
|
}(client, name)
|
|
|
|
// Wait for all added goroutines to complete
|
|
wg.Wait()
|
|
}
|
|
},
|
|
})
|
|
|
|
// Add retrieve command to shell
|
|
file.AddCmd(&ishell.Cmd{
|
|
Name: "retrieve",
|
|
Aliases: []string{"retr", "recv", "get"},
|
|
Help: "Retrieve file from all nodes. File will be saved as 'file-user@node.ext'",
|
|
Func: func(c *ishell.Context) {
|
|
// Get remote and local paths
|
|
remotePath := c.Args[0]
|
|
localPath := c.Args[1]
|
|
|
|
// Create new wait group
|
|
wg := &sync.WaitGroup{}
|
|
// For every client
|
|
for name, client := range clients {
|
|
// Get extension, base, and directory of local path
|
|
localExt := filepath.Ext(localPath)
|
|
localBase := filepath.Base(localPath)
|
|
localDir := filepath.Dir(localPath)
|
|
|
|
// Attempt to stat local path
|
|
info, err := os.Stat(localPath)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Msg("Error getting file info")
|
|
}
|
|
|
|
// If local path is a directory
|
|
if info.IsDir() {
|
|
// Set local directory to local path
|
|
localDir = localPath
|
|
// Set local extention to remote extension
|
|
localExt = filepath.Ext(remotePath)
|
|
// Set local base to remote base
|
|
localBase = filepath.Base(remotePath)
|
|
}
|
|
|
|
// Remove extension from local base
|
|
localBaseNoExt := strings.TrimSuffix(localBase, localExt)
|
|
// Create new local, formatted as "filename-user@node.extension"
|
|
newLocal := filepath.Join(localDir, localBaseNoExt+"-"+name+localExt)
|
|
|
|
// Set new remote to remote path
|
|
newRemote := remotePath
|
|
// If new remote starts with "~"
|
|
if strings.HasPrefix(newRemote, "~") {
|
|
// Replace ~ with "/home/user"
|
|
newRemote = filepath.Join(
|
|
"/home/"+client.User(),
|
|
strings.TrimPrefix(newRemote, "~"),
|
|
)
|
|
}
|
|
|
|
// Add goroutine to wait group
|
|
wg.Add(1)
|
|
go func(client *goph.Client, name string) {
|
|
// Remove from wait group at end of function
|
|
defer wg.Done()
|
|
|
|
// Download file from remote
|
|
err := client.Download(newRemote, newLocal)
|
|
if err != nil {
|
|
cmdError(shell, name, err)
|
|
return
|
|
}
|
|
|
|
// Print success message
|
|
cmdSuccess(shell, name, "download complete")
|
|
}(client, name)
|
|
|
|
// Wait for all added goroutines to complete
|
|
wg.Wait()
|
|
}
|
|
},
|
|
})
|
|
|
|
// Add file command to shell
|
|
shell.AddCmd(file)
|
|
|
|
// Run shell
|
|
shell.Run()
|
|
},
|
|
}
|
|
|
|
func init() {
|
|
rootCmd.AddCommand(msshCmd)
|
|
msshCmd.Flags().BoolP("password", "p", false, "Prompt for SSH password on start")
|
|
msshCmd.Flags().StringP("identity", "i", "", "SSH identity file to use")
|
|
|
|
viper.BindPFlags(msshCmd.Flags())
|
|
}
|
|
|
|
func getAllIdentities() (goph.Auth, error) {
|
|
// Get user's home directory
|
|
home, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Get all ssh keys in "~/.ssh"
|
|
matches, err := filepath.Glob(filepath.Join(home, ".ssh", "id_*.pub"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create goph auth for keys
|
|
var out goph.Auth
|
|
// For every glob match
|
|
for _, match := range matches {
|
|
// Get path of private key
|
|
privKeyPath := strings.TrimSuffix(match, ".pub")
|
|
|
|
// Get SSH key as goph ket
|
|
auth, err := goph.Key(privKeyPath, "")
|
|
if err != nil {
|
|
return out, err
|
|
}
|
|
// Add goph key to auth
|
|
out = append(out, auth...)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func closeClients(clients map[string]*goph.Client) {
|
|
// For every client
|
|
for _, client := range clients {
|
|
// Close client
|
|
client.Close()
|
|
}
|
|
}
|