diff --git a/config.go b/config.go index 4886ea4..1eabbd3 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,7 @@ package main import ( + "archive/tar" "encoding/json" "github.com/pkg/browser" "github.com/rs/zerolog" @@ -10,6 +11,7 @@ import ( "os" "path/filepath" "strconv" + "strings" ) // Create config type to store action type and data @@ -63,6 +65,42 @@ func (config *Config) CollectFiles(dir string) { if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") } // Replace file path in config.ActionData with file name config.ActionData = filepath.Base(config.ActionData) + } else if config.ActionType == "dir" { + // Create tar archive + tarFile, err := os.Create(dir + "/" + filepath.Base(config.ActionData) + ".tar") + if err != nil { log.Fatal().Err(err).Msg("Error creating file") } + // Close tar file at the end of this function + defer tarFile.Close() + // Create writer for tar archive + tarArchiver := tar.NewWriter(tarFile) + // Close archiver at the end of this function + defer tarArchiver.Close() + // Walk given directory + err = filepath.Walk(config.ActionData, func(path string, info os.FileInfo, err error) error { + // Return if error walking + if err != nil { return err } + // Skip if file is not normal mode + if !info.Mode().IsRegular() { return nil } + // Create tar header for file + header, err := tar.FileInfoHeader(info, info.Name()) + if err != nil { return err } + // Change header name to reflect decompressed filepath + header.Name = strings.TrimPrefix(strings.ReplaceAll(path, config.ActionData, ""), string(filepath.Separator)) + // Write header to archive + if err := tarArchiver.WriteHeader(header); err != nil { return err } + // Open source file + src, err := os.Open(path) + if err != nil { return err } + // Close source file at the end of this function + defer src.Close() + // Copy source bytes to tar archive + if _, err := io.Copy(tarArchiver, src); err != nil { return err } + // Return at the end of the function + return nil + }) + if err != nil { log.Fatal().Err(err).Msg("Error creating tar archive") } + // Set config data to base path for receiver + config.ActionData = filepath.Base(config.ActionData) } } @@ -80,6 +118,9 @@ func (config *Config) ReadFile(filePath string) { // Execute action specified in config func (config *Config) ExecuteAction(srcDir string) { + // Get user's home directory + homeDir, err := os.UserHomeDir() + if err != nil { log.Fatal().Err(err).Msg("Error getting home directory") } // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // If action is file @@ -89,9 +130,6 @@ func (config *Config) ExecuteAction(srcDir string) { if err != nil { log.Fatal().Err(err).Msg("Error reading file from config") } // Close source file at the end of this function defer src.Close() - // Get user's home directory - homeDir, err := os.UserHomeDir() - if err != nil { log.Fatal().Err(err).Msg("Error getting home directory") } // Create file in user's Downloads directory dst, err := os.Create(homeDir + "/Downloads/" + config.ActionData) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } @@ -105,6 +143,51 @@ func (config *Config) ExecuteAction(srcDir string) { // Attempt to open URL in browser err := browser.OpenURL(config.ActionData) if err != nil { log.Fatal().Err(err).Msg("Error opening browser") } + // If action is dir + } else if config.ActionType == "dir" { + // Set destination directory to ~/Downloads/{dir name} + dstDir := homeDir + "/Downloads/" + config.ActionData + // Try to create destination directory + err := os.Mkdir(dstDir, 0755) + if err != nil { log.Fatal().Err(err).Msg("Error creating directory") } + // Try to open tar archive file + tarFile, err := os.Open(srcDir + "/" + config.ActionData + ".tar") + if err != nil { log.Fatal().Err(err).Msg("Error opening tar archive") } + // Close tar archive file at the end of this function + defer tarFile.Close() + // Create tar reader to unarchive tar archive + tarUnarchiver := tar.NewReader(tarFile) + // Loop to recursively unarchive tar file + unarchiveLoop: for { + // Jump to next header in tar archive + header, err := tarUnarchiver.Next() + switch { + // If EOF + case err == io.EOF: + // break loop + break unarchiveLoop + case err != nil: + log.Fatal().Err(err).Msg("Error unarchiving tar archive") + // If nil header + case header == nil: + // Skip + continue + } + // Set target path to header name in destination dir + targetPath := filepath.Join(dstDir, header.Name) + switch header.Typeflag { + // If regular file + case tar.TypeReg: + // Try to create containing folder ignoring errors + _ = os.MkdirAll(strings.TrimSuffix(targetPath, filepath.Base(targetPath)), 0755) + // Create file with mode contained in header at target path + dstFile, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, os.FileMode(header.Mode)) + if err != nil { log.Fatal().Err(err).Msg("Error creating file during unarchiving") } + // Copy data from tar archive into file + _, err = io.Copy(dstFile, tarUnarchiver) + if err != nil { log.Fatal().Err(err).Msg("Error copying data to file") } + } + } // Catchall } else { // Log unknown action type diff --git a/fileCrypto.go b/fileCrypto.go index cb6caa5..9f81224 100644 --- a/fileCrypto.go +++ b/fileCrypto.go @@ -1,11 +1,13 @@ package main import ( + "bytes" "crypto/aes" "crypto/cipher" "crypto/md5" "crypto/rand" "encoding/hex" + "github.com/klauspost/compress/zstd" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "io" @@ -17,12 +19,20 @@ import ( ) // Encrypt given file using the shared key -func EncryptFile(filePath string, newFilePath string, sharedKey string) { +func CompressAndEncryptFile(filePath string, newFilePath string, sharedKey string) { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Read data from file - data, err := ioutil.ReadFile(filePath) + file, err := os.Open(filePath) + if err != nil { log.Fatal().Err(err).Msg("Error opening file") } + compressedBuffer := new(bytes.Buffer) + zstdEncoder, err := zstd.NewWriter(compressedBuffer) + if err != nil { log.Fatal().Err(err).Msg("Error creating Zstd encoder") } + _, err = io.Copy(zstdEncoder, file) if err != nil { log.Fatal().Err(err).Msg("Error reading file") } + zstdEncoder.Close() + data, err := ioutil.ReadAll(compressedBuffer) + if err != nil { log.Fatal().Err(err).Msg("Error reading compressed buffer") } // Create md5 hash of password in order to make it the required size md5Hash := md5.New() md5Hash.Write([]byte(sharedKey)) @@ -54,7 +64,7 @@ func EncryptFile(filePath string, newFilePath string, sharedKey string) { } // Decrypt given file using the shared key -func DecryptFile(filePath string, newFilePath string, sharedKey string) { +func DecryptAndDecompressFile(filePath string, newFilePath string, sharedKey string) { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Read data from file @@ -76,16 +86,19 @@ func DecryptFile(filePath string, newFilePath string, sharedKey string) { // Decrypt data plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { log.Fatal().Err(err).Msg("Error decrypting data") } + zstdDecoder, err := zstd.NewReader(bytes.NewBuffer(plaintext)) + if err != nil { log.Fatal().Err(err).Msg("Error creating Zstd decoder") } // Create new file newFile, err := os.Create(newFilePath) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } // Defer file close defer newFile.Close() - // Write ciphertext to new file - bytesWritten, err := newFile.Write(plaintext) + // Write plaintext to new file + bytesWritten, err := io.Copy(newFile, zstdDecoder) if err != nil { log.Fatal().Err(err).Msg("Error writing to file") } + zstdDecoder.Close() // Log bytes written and to which file - log.Info().Str("file", filepath.Base(newFilePath)).Msg("Wrote " + strconv.Itoa(bytesWritten) + " bytes") + log.Info().Str("file", filepath.Base(newFilePath)).Msg("Wrote " + strconv.Itoa(int(bytesWritten)) + " bytes") } // Encrypt files in given directory using shared key @@ -99,7 +112,7 @@ func EncryptFiles(dir string, sharedKey string) { // If file is not a directory and is not the key if !info.IsDir() && !strings.Contains(path, "key.aes"){ // Encrypt the file using shared key, appending .enc - EncryptFile(path, path + ".enc", sharedKey) + CompressAndEncryptFile(path, path + ".zst.enc", sharedKey) // Remove unencrypted file err := os.Remove(path) if err != nil { return err } @@ -121,7 +134,7 @@ func DecryptFiles(dir string, sharedKey string) { // If file is not a directory and is encrypted if !info.IsDir() && strings.Contains(path, ".enc") { // Decrypt the file using the shared key, removing .enc - DecryptFile(path, strings.TrimSuffix(path, ".enc"), sharedKey) + DecryptAndDecompressFile(path, strings.TrimSuffix(path, ".zst.enc"), sharedKey) } // Return nil if no errors occurred return nil diff --git a/go.mod b/go.mod index 26e9f9a..988eab4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.15 require ( github.com/grandcat/zeroconf v1.0.0 + github.com/klauspost/compress v1.11.3 github.com/pkg/browser v0.0.0-20201112035734-206646e67786 github.com/rs/zerolog v1.20.0 ) diff --git a/go.sum b/go.sum index dfef2ce..cf6f34e 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE= github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs= +github.com/klauspost/compress v1.11.3 h1:dB4Bn0tN3wdCzQxnS8r06kV74qN/TAfaIS0bVE8h3jc= +github.com/klauspost/compress v1.11.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/pkg/browser v0.0.0-20201112035734-206646e67786 h1:4Gk0Dsp90g2YwfsxDOjvkEIgKGh+2R9FlvormRycveA=