From 73b717f2e3ca45642219cd64d67caf48aeb4f637 Mon Sep 17 00:00:00 2001 From: Arsen Musayelyan Date: Thu, 3 Dec 2020 09:32:03 -0800 Subject: [PATCH] Add FatalHook for graceful shutdown in case of Fatal log --- config.go | 8 ++++---- deviceDiscovery.go | 2 +- fileCrypto.go | 8 ++++---- files.go | 6 +++--- keyCrypto.go | 6 +++--- keyExchange.go | 4 ++-- logging.go | 16 ++++++++++++++++ main.go | 9 +++++++-- 8 files changed, 40 insertions(+), 19 deletions(-) create mode 100644 logging.go diff --git a/config.go b/config.go index 02a8890..4886ea4 100644 --- a/config.go +++ b/config.go @@ -26,7 +26,7 @@ func NewConfig(actionType string, actionData string) *Config { // Create config file func (config *Config) CreateFile(dir string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create config file at given directory configFile, err := os.Create(dir + "/config.json") if err != nil { log.Fatal().Err(err).Msg("Error creating config file") } @@ -45,7 +45,7 @@ func (config *Config) CreateFile(dir string) { // Collect all required files into given directory func (config *Config) CollectFiles(dir string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // If action type is file if config.ActionType == "file" { // Open file path in config.ActionData @@ -69,7 +69,7 @@ func (config *Config) CollectFiles(dir string) { // Read config file at given file path func (config *Config) ReadFile(filePath string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Read file at filePath fileData, err := ioutil.ReadFile(filePath) if err != nil { log.Fatal().Err(err).Msg("Error reading config file") } @@ -81,7 +81,7 @@ func (config *Config) ReadFile(filePath string) { // Execute action specified in config func (config *Config) ExecuteAction(srcDir string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // If action is file if config.ActionType == "file" { // Open file from config at given directory diff --git a/deviceDiscovery.go b/deviceDiscovery.go index fd45a00..283ecdb 100644 --- a/deviceDiscovery.go +++ b/deviceDiscovery.go @@ -12,7 +12,7 @@ import ( // Discover opensend receivers on the network func DiscoverReceivers() ([]string, []string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create zeroconf resolver resolver, err := zeroconf.NewResolver(nil) if err != nil { log.Fatal().Err(err).Msg("Error creating zeroconf resolver") } diff --git a/fileCrypto.go b/fileCrypto.go index 8d2078a..0cc02e2 100644 --- a/fileCrypto.go +++ b/fileCrypto.go @@ -19,7 +19,7 @@ import ( // Encrypt given file using the shared key func EncryptFile(filePath string, newFilePath string, sharedKey string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Read data from file data, err := ioutil.ReadFile(filePath) if err != nil { log.Fatal().Err(err).Msg("Error reading file") } @@ -54,7 +54,7 @@ func EncryptFile(filePath string, newFilePath string, sharedKey string) { // Decrypt given file using the shared key func DecryptFile(filePath string, newFilePath string, sharedKey string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Read data from file data, err := ioutil.ReadFile(filePath) if err != nil { log.Fatal().Err(err).Msg("Error reading file") } @@ -88,7 +88,7 @@ func DecryptFile(filePath string, newFilePath string, sharedKey string) { // Encrypt files in given directory using shared key func EncryptFiles(dir string, sharedKey string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Walk given directory err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { // If error reading, return err @@ -110,7 +110,7 @@ func EncryptFiles(dir string, sharedKey string) { // Decrypt files in given directory using shared key func DecryptFiles(dir string, sharedKey string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Walk given directory err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { // If error reading, return err diff --git a/files.go b/files.go index b0e0e78..f58b8cc 100644 --- a/files.go +++ b/files.go @@ -18,7 +18,7 @@ import ( // Save encrypted key to file func SaveEncryptedKey(encryptedKey []byte, filePath string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create file at given file path keyFile, err := os.Create(filePath) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } @@ -34,7 +34,7 @@ func SaveEncryptedKey(encryptedKey []byte, filePath string) { // Create HTTP server to transmit files func SendFiles(dir string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Instantiate http.Server struct srv := &http.Server{} // Listen on all ipv4 addresses on port 9898 @@ -114,7 +114,7 @@ func SendFiles(dir string) { // Get files from sender func RecvFiles(dir string, senderAddr string) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Get server address by getting the IP without the port, prepending http:// and appending :9898 serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898" // GET /index on sender's HTTP server diff --git a/keyCrypto.go b/keyCrypto.go index 89f62ac..51929c3 100644 --- a/keyCrypto.go +++ b/keyCrypto.go @@ -15,7 +15,7 @@ import ( // Generate RSA keypair func GenerateRSAKeypair() (*rsa.PrivateKey, *rsa.PublicKey) { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Generate private/public RSA keypair privateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { log.Fatal().Err(err).Msg("Error generating RSA keypair") } @@ -28,7 +28,7 @@ func GenerateRSAKeypair() (*rsa.PrivateKey, *rsa.PublicKey) { // Get public key from sender func GetKey(senderAddr string) []byte { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Get server address by getting the IP without the port, prepending http:// and appending :9898 serverAddr := "http://" + strings.Split(senderAddr, ":")[0] + ":9898" // GET /key on the sender's HTTP server @@ -55,7 +55,7 @@ func GetKey(senderAddr string) []byte { // Encrypt shared key with received public key func EncryptKey(sharedKey string, recvPubKey *rsa.PublicKey) []byte { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Encrypt shared key using RSA encryptedSharedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, recvPubKey, []byte(sharedKey), nil) if err != nil { log.Fatal().Err(err).Msg("Error encrypting shared key") } diff --git a/keyExchange.go b/keyExchange.go index db55d1c..4bb59de 100644 --- a/keyExchange.go +++ b/keyExchange.go @@ -12,7 +12,7 @@ import ( // Exchange keys with sender func ReceiverKeyExchange(key *rsa.PublicKey) string { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create TCP listener on port 9797 listener, err := net.Listen("tcp", ":9797") if err != nil { log.Fatal().Err(err).Msg("Error starting listener") } @@ -49,7 +49,7 @@ func ReceiverKeyExchange(key *rsa.PublicKey) string { // Exchange keys with receiver func SenderKeyExchange(receiverIP string) *rsa.PublicKey { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Connect to TCP socket on receiver IP port 9797 connection, err := net.Dial("tcp", receiverIP + ":9797") if err != nil { log.Fatal().Err(err).Msg("Error connecting to sender") } diff --git a/logging.go b/logging.go new file mode 100644 index 0000000..24abb20 --- /dev/null +++ b/logging.go @@ -0,0 +1,16 @@ +package main + +import ( + "github.com/rs/zerolog" + "os" +) + +type FatalHook struct {} + +func (hook FatalHook) Run(_ *zerolog.Event, level zerolog.Level, _ string) { + // If log event is fatal + if level == zerolog.FatalLevel { + // Attempt removal of opensend directory + _ = os.RemoveAll(opensendDir) + } +} diff --git a/main.go b/main.go index 09dea2f..dcfb1a2 100644 --- a/main.go +++ b/main.go @@ -17,15 +17,20 @@ import ( "time" ) +var opensendDir string + func main() { // Use ConsoleWriter logger - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) + if os.Args[1] == "f" { + log.Fatal().Msg("Test") + } // Get user's home directory homeDir, err := os.UserHomeDir() if err != nil { log.Fatal().Err(err).Msg("Error getting home directory") } // Define opensend directory as ~/.opensend - opensendDir := homeDir + "/.opensend" + opensendDir = homeDir + "/.opensend" // Create channel for signals sig := make(chan os.Signal, 1)