/* Copyright © 2021 Arsen Musayelyan Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package main import ( "bufio" "bytes" "ekyu.moe/base91" "errors" "fmt" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "io" "io/ioutil" "net" "os" "path/filepath" "regexp" "strconv" "strings" ) // Encode byte slice to base91 encoded, escaped string for transmission func EncodeToSafeString(data []byte) string { // Encode data to base91 string base91Data := base91.EncodeToString(data) // Create regex to match |, ;, and % escapeRegex := regexp.MustCompile(`[|;%]`) // Map each character to its escape code escapeMap := map[string]string{"|": "%7C", ";": "%3B", "%": "%25"} // Replace all matched characters in accordance with the function escapedBase91Data := escapeRegex.ReplaceAllStringFunc(base91Data, func(in string) string { // Return character mapping return escapeMap[in] }) // Return escaped base91 string return escapedBase91Data } // Decode base91 encoded, escaped string to byte slice func DecodeSafeString(base91Str string) ([]byte, error) { // Create regex to match %7C, %3B, %25 unescapeRegex := regexp.MustCompile(`%7C|%3B|%25`) // Map each escape code to its character escapeMap := map[string]string{"%7C": "|", "%3B": ";", "%25": "%"} // Replace all matched characters in accordance with the function base91Data := unescapeRegex.ReplaceAllStringFunc(base91Str, func(in string) string { // Return escape code mapping return escapeMap[in] }) // Decode unescaped base91 string data := base91.DecodeString(base91Data) // Return byte slice return data, nil } // Save encrypted key to file func SaveEncryptedKey(encryptedKey []byte, filePath string) { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create file at given file path keyFile, err := os.Create(filePath) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } // Close file at the end of this function defer keyFile.Close() // Write encrypted key to file bytesWritten, err := keyFile.Write(encryptedKey) if err != nil { log.Fatal().Err(err).Msg("Error writing key to file") } // Log bytes written log.Info().Str("file", filepath.Base(filePath)).Msg("Wrote " + strconv.Itoa(bytesWritten) + " bytes") } // Create HTTP server to transmit files func SendFiles(dir string) { // Use ConsoleWriter logger with normal FatalHook log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Create TCP listener on port 9898 listener, err := net.Listen("tcp", ":9898") if err != nil { log.Fatal().Err(err).Msg("Error starting listener") } // Accept connection on listener connection, err := listener.Accept() if err != nil { log.Fatal().Err(err).Msg("Error accepting connection") } // Close connection at the end of this function defer connection.Close() // Create for loop to listen for messages on connection connectionLoop: for { // Use ConsoleWriter logger with TCPFatalHook log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(TCPFatalHook{conn: connection}) // Attempt to read new message on connection data, err := bufio.NewReader(connection).ReadString('\n') // If no message detected, try again if err != nil && err.Error() == "EOF" { continue } // If non-EOF error, fatally log if err != nil { log.Fatal().Err(err).Msg("Error reading data") } // Process received data processedData := strings.Split(strings.TrimSpace(data), ";") // If processedData is empty, alert the user of invalid data if len(processedData) < 1 { log.Fatal().Str("data", data).Msg("Received data invalid") } switch processedData[0] { case "key": // Inform user client has requested key log.Info().Msg("Key requested") // Read saved key key, err := ioutil.ReadFile(dir + "/key.aes") if err != nil { log.Fatal().Err(err).Msg("Error reading key") } // Write saved key to ResponseWriter _, err = fmt.Fprintln(connection, "OK;"+EncodeToSafeString(key)+";") if err != nil { log.Fatal().Err(err).Msg("Error writing response") } case "index": // Inform user a client has requested the file index log.Info().Msg("Index requested") // Get directory listing dirListing, err := ioutil.ReadDir(dir) if err != nil { log.Fatal().Err(err).Msg("Error reading directory") } // Create new slice to house filenames for index var indexSlice []string // For each file in listing for _, file := range dirListing { // If the file is not the key if !strings.Contains(file.Name(), "key.aes") { // Append the file path to indexSlice indexSlice = append(indexSlice, file.Name()) } } // Join index slice into string indexStr := strings.Join(indexSlice, "|") // Write index to ResponseWriter _, err = fmt.Fprintln(connection, "OK;"+indexStr+";") if err != nil { log.Fatal().Err(err).Msg("Error writing response") } case "file": // If processedData only has one entry if len(processedData) == 1 { // Warn user of unexpected end of line log.Warn().Err(errors.New("unexpected eol")).Msg("Invalid file request") // Send error to connection _, _ = fmt.Fprintln(connection, "ERR;") // Break out of switch break } // Set file to first path components of URL, excluding first / file := processedData[1] // Read file at specified location fileData, err := ioutil.ReadFile(dir + "/" + file) // If there was an error reading if err != nil { // Warn user of error log.Warn().Err(err).Msg("Error reading file") // Otherwise } else { // Inform user client has requested a file log.Info().Str("file", file).Msg("File requested") } // Write file as base91 to connection _, err = fmt.Fprintln(connection, "OK;"+EncodeToSafeString(fileData)+";") if err != nil { log.Fatal().Err(err).Msg("Error writing response") } case "stop": // Alert user that stop signal has been received log.Info().Msg("Received stop signal") // Print ok message to connection _, _ = fmt.Fprintln(connection, "OK;") // Break out of connectionLoop break connectionLoop } } } func ConnectToSender(senderAddr string) net.Conn { // Get server address by getting the IP without the port, and appending :9898 serverAddr := strings.Split(senderAddr, ":")[0] + ":9898" // Create error variable var err error // Create connection variable var connection net.Conn // Until break for { // Try connecting to sender connection, err = net.Dial("tcp", serverAddr) // If connection refused if err != nil && strings.Contains(err.Error(), "connection refused") { // Continue loop (retry) continue // If error other than connection refused } else if err != nil { // Fatally log log.Fatal().Err(err).Msg("Error connecting to sender") // If no error } else { // Break out of loop break } } // Returned created connection return connection } // Get files from sender func RecvFiles(connection net.Conn) { // Use ConsoleWriter logger log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).Hook(FatalHook{}) // Request index from sender _, err := fmt.Fprintln(connection, "index;") if err != nil { log.Fatal().Err(err).Msg("Error sending index request") } // Read received message message, err := bufio.NewReader(connection).ReadString('\n') if err != nil { log.Fatal().Err(err).Msg("Error getting index") } // Process received message procMessage := strings.Split(strings.TrimSpace(message), ";") // If non-ok code returned, fatally log if procMessage[0] != "OK" { log.Fatal().Err(err).Msg("Sender reported error") } // Get index from message index := strings.Split(strings.TrimSpace(procMessage[1]), "|") for _, file := range index { // Get current file in index _, err = fmt.Fprintln(connection, "file;"+file+";") if err != nil { log.Fatal().Err(err).Msg("Error sending file request") } // Read received message message, err := bufio.NewReader(connection).ReadString('\n') if err != nil { log.Fatal().Err(err).Msg("Error getting file") } // Process received message procMessage := strings.Split(message, ";") // If non-ok code returned if procMessage[0] != "OK" { // fatally log log.Fatal().Err(err).Msg("Sender reported error") // Otherwise } else { // Create new file at index filepath newFile, err := os.Create(*workDir + "/" + file) if err != nil { log.Fatal().Err(err).Msg("Error creating file") } // Decode file data from base91 string fileData, err := DecodeSafeString(strings.TrimSpace(procMessage[1])) if err != nil { log.Fatal().Err(err).Msg("Error decoding hex") } // Copy response body to new file bytesWritten, err := io.Copy(newFile, bytes.NewBuffer(fileData)) if err != nil { log.Fatal().Err(err).Msg("Error writing to file") } // Log bytes written log.Info().Str("file", filepath.Base(file)).Msg("Wrote " + strconv.Itoa(int(bytesWritten)) + " bytes") // Close new file newFile.Close() } } } // Send stop signal to sender func SendSrvStopSignal(connection net.Conn) { // Send stop signal to connection _, _ = fmt.Fprintln(connection, "stop;") // Close connection _ = connection.Close() }