forked from Elara6331/itd
		
	Use request type for error response type
This commit is contained in:
		
							
								
								
									
										74
									
								
								socket.go
									
									
									
									
									
								
							
							
						
						
									
										74
									
								
								socket.go
									
									
									
									
									
								
							| @@ -22,7 +22,6 @@ import ( | |||||||
| 	"bufio" | 	"bufio" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math" |  | ||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
| @@ -100,11 +99,6 @@ func startSocket(dev *infinitime.Device) error { | |||||||
|  |  | ||||||
| func handleConnection(conn net.Conn, dev *infinitime.Device) { | func handleConnection(conn net.Conn, dev *infinitime.Device) { | ||||||
| 	defer conn.Close() | 	defer conn.Close() | ||||||
| 	// If firmware is updating, return error |  | ||||||
| 	if firmwareUpdating { |  | ||||||
| 		connErr(conn, nil, "Firmware update in progress") |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Create new scanner on connection | 	// Create new scanner on connection | ||||||
| 	scanner := bufio.NewScanner(conn) | 	scanner := bufio.NewScanner(conn) | ||||||
| @@ -113,16 +107,22 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		// Decode scanned message into types.Request | 		// Decode scanned message into types.Request | ||||||
| 		err := json.Unmarshal(scanner.Bytes(), &req) | 		err := json.Unmarshal(scanner.Bytes(), &req) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			connErr(conn, err, "Error decoding JSON input") | 			connErr(conn, req.Type, err, "Error decoding JSON input") | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// If firmware is updating, return error | ||||||
|  | 		if firmwareUpdating { | ||||||
|  | 			connErr(conn, req.Type, nil, "Firmware update in progress") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		switch req.Type { | 		switch req.Type { | ||||||
| 		case types.ReqTypeHeartRate: | 		case types.ReqTypeHeartRate: | ||||||
| 			// Get heart rate from watch | 			// Get heart rate from watch | ||||||
| 			heartRate, err := dev.HeartRate() | 			heartRate, err := dev.HeartRate() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting heart rate") | 				connErr(conn, req.Type, err, "Error getting heart rate") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode heart rate to connection | 			// Encode heart rate to connection | ||||||
| @@ -133,7 +133,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeWatchHeartRate: | 		case types.ReqTypeWatchHeartRate: | ||||||
| 			heartRateCh, cancel, err := dev.WatchHeartRate() | 			heartRateCh, cancel, err := dev.WatchHeartRate() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting heart rate channel") | 				connErr(conn, req.Type, err, "Error getting heart rate channel") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			reqID := uuid.New().String() | 			reqID := uuid.New().String() | ||||||
| @@ -161,7 +161,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Get battery level from watch | 			// Get battery level from watch | ||||||
| 			battLevel, err := dev.BatteryLevel() | 			battLevel, err := dev.BatteryLevel() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting battery level") | 				connErr(conn, req.Type, err, "Error getting battery level") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode battery level to connection | 			// Encode battery level to connection | ||||||
| @@ -172,7 +172,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeWatchBattLevel: | 		case types.ReqTypeWatchBattLevel: | ||||||
| 			battLevelCh, cancel, err := dev.WatchBatteryLevel() | 			battLevelCh, cancel, err := dev.WatchBatteryLevel() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting battery level channel") | 				connErr(conn, req.Type, err, "Error getting battery level channel") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			reqID := uuid.New().String() | 			reqID := uuid.New().String() | ||||||
| @@ -200,7 +200,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Get battery level from watch | 			// Get battery level from watch | ||||||
| 			motionVals, err := dev.Motion() | 			motionVals, err := dev.Motion() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting motion values") | 				connErr(conn, req.Type, err, "Error getting motion values") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode battery level to connection | 			// Encode battery level to connection | ||||||
| @@ -211,7 +211,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeWatchMotion: | 		case types.ReqTypeWatchMotion: | ||||||
| 			motionValCh, cancel, err := dev.WatchMotion() | 			motionValCh, cancel, err := dev.WatchMotion() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting heart rate channel") | 				connErr(conn, req.Type, err, "Error getting heart rate channel") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			reqID := uuid.New().String() | 			reqID := uuid.New().String() | ||||||
| @@ -240,7 +240,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Get battery level from watch | 			// Get battery level from watch | ||||||
| 			stepCount, err := dev.StepCount() | 			stepCount, err := dev.StepCount() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting step count") | 				connErr(conn, req.Type, err, "Error getting step count") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode battery level to connection | 			// Encode battery level to connection | ||||||
| @@ -251,7 +251,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeWatchStepCount: | 		case types.ReqTypeWatchStepCount: | ||||||
| 			stepCountCh, cancel, err := dev.WatchStepCount() | 			stepCountCh, cancel, err := dev.WatchStepCount() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting heart rate channel") | 				connErr(conn, req.Type, err, "Error getting heart rate channel") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			reqID := uuid.New().String() | 			reqID := uuid.New().String() | ||||||
| @@ -279,7 +279,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Get firmware version from watch | 			// Get firmware version from watch | ||||||
| 			version, err := dev.Version() | 			version, err := dev.Version() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error getting firmware version") | 				connErr(conn, req.Type, err, "Error getting firmware version") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode version to connection | 			// Encode version to connection | ||||||
| @@ -296,14 +296,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeNotify: | 		case types.ReqTypeNotify: | ||||||
| 			// If no data, return error | 			// If no data, return error | ||||||
| 			if req.Data == nil { | 			if req.Data == nil { | ||||||
| 				connErr(conn, nil, "Data required for notify request") | 				connErr(conn, req.Type, nil, "Data required for notify request") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			var reqData types.ReqDataNotify | 			var reqData types.ReqDataNotify | ||||||
| 			// Decode data map to notify request data | 			// Decode data map to notify request data | ||||||
| 			err = mapstructure.Decode(req.Data, &reqData) | 			err = mapstructure.Decode(req.Data, &reqData) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error decoding request data") | 				connErr(conn, req.Type, err, "Error decoding request data") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			maps := viper.GetStringSlice("notifs.translit.use") | 			maps := viper.GetStringSlice("notifs.translit.use") | ||||||
| @@ -313,7 +313,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Send notification to watch | 			// Send notification to watch | ||||||
| 			err = dev.Notify(title, body) | 			err = dev.Notify(title, body) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error sending notification") | 				connErr(conn, req.Type, err, "Error sending notification") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode empty types.Response to connection | 			// Encode empty types.Response to connection | ||||||
| @@ -321,13 +321,13 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeSetTime: | 		case types.ReqTypeSetTime: | ||||||
| 			// If no data, return error | 			// If no data, return error | ||||||
| 			if req.Data == nil { | 			if req.Data == nil { | ||||||
| 				connErr(conn, nil, "Data required for settime request") | 				connErr(conn, req.Type, nil, "Data required for settime request") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Get string from data or return error | 			// Get string from data or return error | ||||||
| 			reqTimeStr, ok := req.Data.(string) | 			reqTimeStr, ok := req.Data.(string) | ||||||
| 			if !ok { | 			if !ok { | ||||||
| 				connErr(conn, nil, "Data for settime request must be RFC3339 formatted time string") | 				connErr(conn, req.Type, nil, "Data for settime request must be RFC3339 formatted time string") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -338,14 +338,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 				// Parse time as RFC3339/ISO8601 | 				// Parse time as RFC3339/ISO8601 | ||||||
| 				reqTime, err = time.Parse(time.RFC3339, reqTimeStr) | 				reqTime, err = time.Parse(time.RFC3339, reqTimeStr) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					connErr(conn, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`") | 					connErr(conn, req.Type, err, "Invalid time format. Time string must be formatted as ISO8601 or the word `now`") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 			// Set time on watch | 			// Set time on watch | ||||||
| 			err = dev.SetTime(reqTime) | 			err = dev.SetTime(reqTime) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error setting device time") | 				connErr(conn, req.Type, err, "Error setting device time") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Encode empty types.Response to connection | 			// Encode empty types.Response to connection | ||||||
| @@ -353,14 +353,14 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 		case types.ReqTypeFwUpgrade: | 		case types.ReqTypeFwUpgrade: | ||||||
| 			// If no data, return error | 			// If no data, return error | ||||||
| 			if req.Data == nil { | 			if req.Data == nil { | ||||||
| 				connErr(conn, nil, "Data required for firmware upgrade request") | 				connErr(conn, req.Type, nil, "Data required for firmware upgrade request") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			var reqData types.ReqDataFwUpgrade | 			var reqData types.ReqDataFwUpgrade | ||||||
| 			// Decode data map to firmware upgrade request data | 			// Decode data map to firmware upgrade request data | ||||||
| 			err = mapstructure.Decode(req.Data, &reqData) | 			err = mapstructure.Decode(req.Data, &reqData) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error decoding request data") | 				connErr(conn, req.Type, err, "Error decoding request data") | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			// Reset DFU to prepare for next update | 			// Reset DFU to prepare for next update | ||||||
| @@ -369,40 +369,40 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			case types.UpgradeTypeArchive: | 			case types.UpgradeTypeArchive: | ||||||
| 				// If less than one file, return error | 				// If less than one file, return error | ||||||
| 				if len(reqData.Files) < 1 { | 				if len(reqData.Files) < 1 { | ||||||
| 					connErr(conn, nil, "Archive upgrade requires one file with .zip extension") | 					connErr(conn, req.Type, nil, "Archive upgrade requires one file with .zip extension") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// If file is not zip archive, return error | 				// If file is not zip archive, return error | ||||||
| 				if filepath.Ext(reqData.Files[0]) != ".zip" { | 				if filepath.Ext(reqData.Files[0]) != ".zip" { | ||||||
| 					connErr(conn, nil, "Archive upgrade file must be a zip archive") | 					connErr(conn, req.Type, nil, "Archive upgrade file must be a zip archive") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// Load DFU archive | 				// Load DFU archive | ||||||
| 				err := dev.DFU.LoadArchive(reqData.Files[0]) | 				err := dev.DFU.LoadArchive(reqData.Files[0]) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					connErr(conn, err, "Error loading archive file") | 					connErr(conn, req.Type, err, "Error loading archive file") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 			case types.UpgradeTypeFiles: | 			case types.UpgradeTypeFiles: | ||||||
| 				// If less than two files, return error | 				// If less than two files, return error | ||||||
| 				if len(reqData.Files) < 2 { | 				if len(reqData.Files) < 2 { | ||||||
| 					connErr(conn, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.") | 					connErr(conn, req.Type, nil, "Files upgrade requires two files. First with .dat and second with .bin extension.") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// If first file is not init packet, return error | 				// If first file is not init packet, return error | ||||||
| 				if filepath.Ext(reqData.Files[0]) != ".dat" { | 				if filepath.Ext(reqData.Files[0]) != ".dat" { | ||||||
| 					connErr(conn, nil, "First file must be a .dat file") | 					connErr(conn, req.Type, nil, "First file must be a .dat file") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// If second file is not firmware image, return error | 				// If second file is not firmware image, return error | ||||||
| 				if filepath.Ext(reqData.Files[1]) != ".bin" { | 				if filepath.Ext(reqData.Files[1]) != ".bin" { | ||||||
| 					connErr(conn, nil, "Second file must be a .bin file") | 					connErr(conn, req.Type, nil, "Second file must be a .bin file") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 				// Load individual DFU files | 				// Load individual DFU files | ||||||
| 				err := dev.DFU.LoadFiles(reqData.Files[0], reqData.Files[1]) | 				err := dev.DFU.LoadFiles(reqData.Files[0], reqData.Files[1]) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					connErr(conn, err, "Error loading firmware files") | 					connErr(conn, req.Type, err, "Error loading firmware files") | ||||||
| 					break | 					break | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| @@ -426,19 +426,19 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 			// Start DFU | 			// Start DFU | ||||||
| 			err = dev.DFU.Start() | 			err = dev.DFU.Start() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				connErr(conn, err, "Error performing upgrade") | 				connErr(conn, req.Type, err, "Error performing upgrade") | ||||||
| 				firmwareUpdating = false | 				firmwareUpdating = false | ||||||
| 				break | 				break | ||||||
| 			} | 			} | ||||||
| 			firmwareUpdating = false | 			firmwareUpdating = false | ||||||
| 		case types.ReqTypeCancel: | 		case types.ReqTypeCancel: | ||||||
| 			if req.Data == nil { | 			if req.Data == nil { | ||||||
| 				connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.") | 				connErr(conn, req.Type, nil, "No data provided. Cancel request requires request ID string as data.") | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			reqID, ok := req.Data.(string) | 			reqID, ok := req.Data.(string) | ||||||
| 			if !ok { | 			if !ok { | ||||||
| 				connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.") | 				connErr(conn, req.Type, nil, "Invalid data. Cancel request required request ID string as data.") | ||||||
| 			} | 			} | ||||||
| 			// Stop notifications | 			// Stop notifications | ||||||
| 			done.Done(reqID) | 			done.Done(reqID) | ||||||
| @@ -447,7 +447,7 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func connErr(conn net.Conn, err error, msg string) { | func connErr(conn net.Conn, resType int, err error, msg string) { | ||||||
| 	var res types.Response | 	var res types.Response | ||||||
| 	// If error exists, add to types.Response, otherwise don't | 	// If error exists, add to types.Response, otherwise don't | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -455,7 +455,7 @@ func connErr(conn net.Conn, err error, msg string) { | |||||||
| 		res = types.Response{Message: fmt.Sprintf("%s: %s", msg, err)} | 		res = types.Response{Message: fmt.Sprintf("%s: %s", msg, err)} | ||||||
| 	} else { | 	} else { | ||||||
| 		log.Error().Msg(msg) | 		log.Error().Msg(msg) | ||||||
| 		res = types.Response{Message: msg, Type: math.MaxInt} | 		res = types.Response{Message: msg, Type: resType} | ||||||
| 	} | 	} | ||||||
| 	res.Error = true | 	res.Error = true | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user