Use request type for error response type

This commit is contained in:
Elara 2021-10-24 01:09:27 -07:00
parent 28610d9ebb
commit 0d164aef3d

View File

@ -22,7 +22,6 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -100,11 +99,6 @@ func startSocket(dev *infinitime.Device) error {
func handleConnection(conn net.Conn, dev *infinitime.Device) { func handleConnection(conn net.Conn, dev *infinitime.Device) {
defer conn.Close() defer conn.Close()
// If firmware is updating, return error
if firmwareUpdating {
connErr(conn, nil, "Firmware update in progress")
return
}
// Create new scanner on connection // Create new scanner on connection
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(conn)
@ -113,16 +107,22 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Decode scanned message into types.Request // Decode scanned message into types.Request
err := json.Unmarshal(scanner.Bytes(), &req) err := json.Unmarshal(scanner.Bytes(), &req)
if err != nil { if err != nil {
connErr(conn, err, "Error decoding JSON input") connErr(conn, req.Type, err, "Error decoding JSON input")
continue continue
} }
// If firmware is updating, return error
if firmwareUpdating {
connErr(conn, req.Type, nil, "Firmware update in progress")
return
}
switch req.Type { switch req.Type {
case types.ReqTypeHeartRate: case types.ReqTypeHeartRate:
// Get heart rate from watch // Get heart rate from watch
heartRate, err := dev.HeartRate() heartRate, err := dev.HeartRate()
if err != nil { if err != nil {
connErr(conn, err, "Error getting heart rate") connErr(conn, req.Type, err, "Error getting heart rate")
break break
} }
// Encode heart rate to connection // Encode heart rate to connection
@ -133,7 +133,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeWatchHeartRate: case types.ReqTypeWatchHeartRate:
heartRateCh, cancel, err := dev.WatchHeartRate() heartRateCh, cancel, err := dev.WatchHeartRate()
if err != nil { if err != nil {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, req.Type, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String() reqID := uuid.New().String()
@ -161,7 +161,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Get battery level from watch // Get battery level from watch
battLevel, err := dev.BatteryLevel() battLevel, err := dev.BatteryLevel()
if err != nil { if err != nil {
connErr(conn, err, "Error getting battery level") connErr(conn, req.Type, err, "Error getting battery level")
break break
} }
// Encode battery level to connection // Encode battery level to connection
@ -172,7 +172,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeWatchBattLevel: case types.ReqTypeWatchBattLevel:
battLevelCh, cancel, err := dev.WatchBatteryLevel() battLevelCh, cancel, err := dev.WatchBatteryLevel()
if err != nil { if err != nil {
connErr(conn, err, "Error getting battery level channel") connErr(conn, req.Type, err, "Error getting battery level channel")
break break
} }
reqID := uuid.New().String() reqID := uuid.New().String()
@ -200,7 +200,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Get battery level from watch // Get battery level from watch
motionVals, err := dev.Motion() motionVals, err := dev.Motion()
if err != nil { if err != nil {
connErr(conn, err, "Error getting motion values") connErr(conn, req.Type, err, "Error getting motion values")
break break
} }
// Encode battery level to connection // Encode battery level to connection
@ -211,7 +211,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeWatchMotion: case types.ReqTypeWatchMotion:
motionValCh, cancel, err := dev.WatchMotion() motionValCh, cancel, err := dev.WatchMotion()
if err != nil { if err != nil {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, req.Type, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String() reqID := uuid.New().String()
@ -240,7 +240,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Get battery level from watch // Get battery level from watch
stepCount, err := dev.StepCount() stepCount, err := dev.StepCount()
if err != nil { if err != nil {
connErr(conn, err, "Error getting step count") connErr(conn, req.Type, err, "Error getting step count")
break break
} }
// Encode battery level to connection // Encode battery level to connection
@ -251,7 +251,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeWatchStepCount: case types.ReqTypeWatchStepCount:
stepCountCh, cancel, err := dev.WatchStepCount() stepCountCh, cancel, err := dev.WatchStepCount()
if err != nil { if err != nil {
connErr(conn, err, "Error getting heart rate channel") connErr(conn, req.Type, err, "Error getting heart rate channel")
break break
} }
reqID := uuid.New().String() reqID := uuid.New().String()
@ -279,7 +279,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Get firmware version from watch // Get firmware version from watch
version, err := dev.Version() version, err := dev.Version()
if err != nil { if err != nil {
connErr(conn, err, "Error getting firmware version") connErr(conn, req.Type, err, "Error getting firmware version")
break break
} }
// Encode version to connection // Encode version to connection
@ -296,14 +296,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeNotify: case types.ReqTypeNotify:
// If no data, return error // If no data, return error
if req.Data == nil { if req.Data == nil {
connErr(conn, nil, "Data required for notify request") connErr(conn, req.Type, nil, "Data required for notify request")
break break
} }
var reqData types.ReqDataNotify var reqData types.ReqDataNotify
// Decode data map to notify request data // Decode data map to notify request data
err = mapstructure.Decode(req.Data, &reqData) err = mapstructure.Decode(req.Data, &reqData)
if err != nil { if err != nil {
connErr(conn, err, "Error decoding request data") connErr(conn, req.Type, err, "Error decoding request data")
break break
} }
maps := viper.GetStringSlice("notifs.translit.use") maps := viper.GetStringSlice("notifs.translit.use")
@ -313,7 +313,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Send notification to watch // Send notification to watch
err = dev.Notify(title, body) err = dev.Notify(title, body)
if err != nil { if err != nil {
connErr(conn, err, "Error sending notification") connErr(conn, req.Type, err, "Error sending notification")
break break
} }
// Encode empty types.Response to connection // Encode empty types.Response to connection
@ -321,13 +321,13 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeSetTime: case types.ReqTypeSetTime:
// If no data, return error // If no data, return error
if req.Data == nil { if req.Data == nil {
connErr(conn, nil, "Data required for settime request") connErr(conn, req.Type, nil, "Data required for settime request")
break break
} }
// Get string from data or return error // Get string from data or return error
reqTimeStr, ok := req.Data.(string) reqTimeStr, ok := req.Data.(string)
if !ok { if !ok {
connErr(conn, nil, "Data for settime request must be RFC3339 formatted time string") connErr(conn, req.Type, nil, "Data for settime request must be RFC3339 formatted time string")
break break
} }
@ -338,14 +338,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Parse time as RFC3339/ISO8601 // Parse time as RFC3339/ISO8601
reqTime, err = time.Parse(time.RFC3339, reqTimeStr) reqTime, err = time.Parse(time.RFC3339, reqTimeStr)
if err != nil { if err != nil {
connErr(conn, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`") connErr(conn, req.Type, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`")
break break
} }
} }
// Set time on watch // Set time on watch
err = dev.SetTime(reqTime) err = dev.SetTime(reqTime)
if err != nil { if err != nil {
connErr(conn, err, "Error setting device time") connErr(conn, req.Type, err, "Error setting device time")
break break
} }
// Encode empty types.Response to connection // Encode empty types.Response to connection
@ -353,14 +353,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.ReqTypeFwUpgrade: case types.ReqTypeFwUpgrade:
// If no data, return error // If no data, return error
if req.Data == nil { if req.Data == nil {
connErr(conn, nil, "Data required for firmware upgrade request") connErr(conn, req.Type, nil, "Data required for firmware upgrade request")
break break
} }
var reqData types.ReqDataFwUpgrade var reqData types.ReqDataFwUpgrade
// Decode data map to firmware upgrade request data // Decode data map to firmware upgrade request data
err = mapstructure.Decode(req.Data, &reqData) err = mapstructure.Decode(req.Data, &reqData)
if err != nil { if err != nil {
connErr(conn, err, "Error decoding request data") connErr(conn, req.Type, err, "Error decoding request data")
break break
} }
// Reset DFU to prepare for next update // Reset DFU to prepare for next update
@ -369,40 +369,40 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
case types.UpgradeTypeArchive: case types.UpgradeTypeArchive:
// If less than one file, return error // If less than one file, return error
if len(reqData.Files) < 1 { if len(reqData.Files) < 1 {
connErr(conn, nil, "Archive upgrade requires one file with .zip extension") connErr(conn, req.Type, nil, "Archive upgrade requires one file with .zip extension")
break break
} }
// If file is not zip archive, return error // If file is not zip archive, return error
if filepath.Ext(reqData.Files[0]) != ".zip" { if filepath.Ext(reqData.Files[0]) != ".zip" {
connErr(conn, nil, "Archive upgrade file must be a zip archive") connErr(conn, req.Type, nil, "Archive upgrade file must be a zip archive")
break break
} }
// Load DFU archive // Load DFU archive
err := dev.DFU.LoadArchive(reqData.Files[0]) err := dev.DFU.LoadArchive(reqData.Files[0])
if err != nil { if err != nil {
connErr(conn, err, "Error loading archive file") connErr(conn, req.Type, err, "Error loading archive file")
break break
} }
case types.UpgradeTypeFiles: case types.UpgradeTypeFiles:
// If less than two files, return error // If less than two files, return error
if len(reqData.Files) < 2 { if len(reqData.Files) < 2 {
connErr(conn, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.") connErr(conn, req.Type, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.")
break break
} }
// If first file is not init packet, return error // If first file is not init packet, return error
if filepath.Ext(reqData.Files[0]) != ".dat" { if filepath.Ext(reqData.Files[0]) != ".dat" {
connErr(conn, nil, "First file must be a .dat file") connErr(conn, req.Type, nil, "First file must be a .dat file")
break break
} }
// If second file is not firmware image, return error // If second file is not firmware image, return error
if filepath.Ext(reqData.Files[1]) != ".bin" { if filepath.Ext(reqData.Files[1]) != ".bin" {
connErr(conn, nil, "Second file must be a .bin file") connErr(conn, req.Type, nil, "Second file must be a .bin file")
break break
} }
// Load individual DFU files // Load individual DFU files
err := dev.DFU.LoadFiles(reqData.Files[0], reqData.Files[1]) err := dev.DFU.LoadFiles(reqData.Files[0], reqData.Files[1])
if err != nil { if err != nil {
connErr(conn, err, "Error loading firmware files") connErr(conn, req.Type, err, "Error loading firmware files")
break break
} }
} }
@ -426,19 +426,19 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
// Start DFU // Start DFU
err = dev.DFU.Start() err = dev.DFU.Start()
if err != nil { if err != nil {
connErr(conn, err, "Error performing upgrade") connErr(conn, req.Type, err, "Error performing upgrade")
firmwareUpdating = false firmwareUpdating = false
break break
} }
firmwareUpdating = false firmwareUpdating = false
case types.ReqTypeCancel: case types.ReqTypeCancel:
if req.Data == nil { if req.Data == nil {
connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.") connErr(conn, req.Type, nil, "No data provided. Cancel request requires request ID string as data.")
continue continue
} }
reqID, ok := req.Data.(string) reqID, ok := req.Data.(string)
if !ok { if !ok {
connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.") connErr(conn, req.Type, nil, "Invalid data. Cancel request required request ID string as data.")
} }
// Stop notifications // Stop notifications
done.Done(reqID) done.Done(reqID)
@ -447,7 +447,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
} }
} }
func connErr(conn net.Conn, err error, msg string) { func connErr(conn net.Conn, resType int, err error, msg string) {
var res types.Response var res types.Response
// If error exists, add to types.Response, otherwise don't // If error exists, add to types.Response, otherwise don't
if err != nil { if err != nil {
@ -455,7 +455,7 @@ func connErr(conn net.Conn, err error, msg string) {
res = types.Response{Message: fmt.Sprintf("%s: %s", msg, err)} res = types.Response{Message: fmt.Sprintf("%s: %s", msg, err)}
} else { } else {
log.Error().Msg(msg) log.Error().Msg(msg)
res = types.Response{Message: msg, Type: math.MaxInt} res = types.Response{Message: msg, Type: resType}
} }
res.Error = true res.Error = true