Generalize socket cancellation and update API accordingly
This commit is contained in:
parent
2ab8d24a43
commit
e45bfe3de8
@ -18,11 +18,11 @@ const DefaultAddr = "/tmp/itd/socket"
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
respCh chan types.Response
|
||||
heartRateCh chan uint8
|
||||
battLevelCh chan uint8
|
||||
stepCountCh chan uint32
|
||||
motionCh chan infinitime.MotionValues
|
||||
dfuProgressCh chan DFUProgress
|
||||
heartRateCh chan types.Response
|
||||
battLevelCh chan types.Response
|
||||
stepCountCh chan types.Response
|
||||
motionCh chan types.Response
|
||||
dfuProgressCh chan types.Response
|
||||
}
|
||||
|
||||
// New creates a new client and sets it up
|
||||
@ -91,27 +91,43 @@ func (c *Client) requestNoRes(req types.Request) error {
|
||||
func (c *Client) handleResp(res types.Response) error {
|
||||
switch res.Type {
|
||||
case types.ResTypeWatchHeartRate:
|
||||
c.heartRateCh <- uint8(res.Value.(float64))
|
||||
c.heartRateCh <- res
|
||||
case types.ResTypeWatchBattLevel:
|
||||
c.battLevelCh <- uint8(res.Value.(float64))
|
||||
c.battLevelCh <- res
|
||||
case types.ResTypeWatchStepCount:
|
||||
c.stepCountCh <- uint32(res.Value.(float64))
|
||||
c.stepCountCh <- res
|
||||
case types.ResTypeWatchMotion:
|
||||
out := infinitime.MotionValues{}
|
||||
err := mapstructure.Decode(res.Value, &out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.motionCh <- out
|
||||
c.motionCh <- res
|
||||
case types.ResTypeDFUProgress:
|
||||
out := DFUProgress{}
|
||||
err := mapstructure.Decode(res.Value, &out)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.dfuProgressCh <- out
|
||||
c.dfuProgressCh <- res
|
||||
default:
|
||||
c.respCh <- res
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUint8(val interface{}) uint8 {
|
||||
return uint8(val.(float64))
|
||||
}
|
||||
|
||||
func decodeUint32(val interface{}) uint32 {
|
||||
return uint32(val.(float64))
|
||||
}
|
||||
|
||||
func decodeMotion(val interface{}) (infinitime.MotionValues, error) {
|
||||
out := infinitime.MotionValues{}
|
||||
err := mapstructure.Decode(val, &out)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func decodeDFUProgress(val interface{}) (DFUProgress, error) {
|
||||
out := DFUProgress{}
|
||||
err := mapstructure.Decode(val, &out)
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
93
api/info.go
93
api/info.go
@ -1,8 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"go.arsenm.dev/infinitime"
|
||||
"go.arsenm.dev/itd/internal/types"
|
||||
@ -48,15 +46,27 @@ func (c *Client) BatteryLevel() (uint8, error) {
|
||||
// new battery level values as they update. Do not use after
|
||||
// calling cancellation function
|
||||
func (c *Client) WatchBatteryLevel() (<-chan uint8, func(), error) {
|
||||
c.battLevelCh = make(chan uint8, 2)
|
||||
c.battLevelCh = make(chan types.Response, 2)
|
||||
err := c.requestNoRes(types.Request{
|
||||
Type: types.ReqTypeBattLevel,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancel := c.cancelFn(types.ReqTypeCancelBattLevel, c.battLevelCh)
|
||||
return c.battLevelCh, cancel, nil
|
||||
res := <-c.battLevelCh
|
||||
done, cancel := c.cancelFn(res.ID, c.battLevelCh)
|
||||
out := make(chan uint8, 2)
|
||||
go func() {
|
||||
for res := range c.battLevelCh {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
out <- decodeUint8(res.Value)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, cancel, nil
|
||||
}
|
||||
|
||||
// HeartRate gets the heart rate from the connected device
|
||||
@ -68,33 +78,46 @@ func (c *Client) HeartRate() (uint8, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint8(res.Value.(float64)), nil
|
||||
return decodeUint8(res.Value), nil
|
||||
}
|
||||
|
||||
// WatchHeartRate returns a channel which will contain
|
||||
// new heart rate values as they update. Do not use after
|
||||
// calling cancellation function
|
||||
func (c *Client) WatchHeartRate() (<-chan uint8, func(), error) {
|
||||
c.heartRateCh = make(chan uint8, 2)
|
||||
c.heartRateCh = make(chan types.Response, 2)
|
||||
err := c.requestNoRes(types.Request{
|
||||
Type: types.ReqTypeWatchHeartRate,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancel := c.cancelFn(types.ReqTypeCancelHeartRate, c.heartRateCh)
|
||||
return c.heartRateCh, cancel, nil
|
||||
res := <-c.heartRateCh
|
||||
done, cancel := c.cancelFn(res.ID, c.heartRateCh)
|
||||
out := make(chan uint8, 2)
|
||||
go func() {
|
||||
for res := range c.heartRateCh {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
out <- decodeUint8(res.Value)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, cancel, nil
|
||||
}
|
||||
|
||||
// cancelFn generates a cancellation function for the given
|
||||
// request type and channel
|
||||
func (c *Client) cancelFn(reqType int, ch interface{}) func() {
|
||||
return func() {
|
||||
reflectCh := reflect.ValueOf(ch)
|
||||
reflectCh.Close()
|
||||
reflectCh.Set(reflect.Zero(reflectCh.Type()))
|
||||
func (c *Client) cancelFn(reqID string, ch chan types.Response) (chan struct{}, func()) {
|
||||
done := make(chan struct{}, 1)
|
||||
return done, func() {
|
||||
done <- struct{}{}
|
||||
close(ch)
|
||||
c.requestNoRes(types.Request{
|
||||
Type: reqType,
|
||||
Type: types.ReqTypeCancel,
|
||||
Data: reqID,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -115,15 +138,27 @@ func (c *Client) StepCount() (uint32, error) {
|
||||
// new step count values as they update. Do not use after
|
||||
// calling cancellation function
|
||||
func (c *Client) WatchStepCount() (<-chan uint32, func(), error) {
|
||||
c.stepCountCh = make(chan uint32, 2)
|
||||
c.stepCountCh = make(chan types.Response, 2)
|
||||
err := c.requestNoRes(types.Request{
|
||||
Type: types.ReqTypeWatchStepCount,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancel := c.cancelFn(types.ReqTypeCancelStepCount, c.stepCountCh)
|
||||
return c.stepCountCh, cancel, nil
|
||||
res := <-c.stepCountCh
|
||||
done, cancel := c.cancelFn(res.ID, c.stepCountCh)
|
||||
out := make(chan uint32, 2)
|
||||
go func() {
|
||||
for res := range c.stepCountCh {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
out <- decodeUint32(res.Value)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, cancel, nil
|
||||
}
|
||||
|
||||
// Motion gets the motion values from the connected device
|
||||
@ -146,13 +181,29 @@ func (c *Client) Motion() (infinitime.MotionValues, error) {
|
||||
// new motion values as they update. Do not use after
|
||||
// calling cancellation function
|
||||
func (c *Client) WatchMotion() (<-chan infinitime.MotionValues, func(), error) {
|
||||
c.motionCh = make(chan infinitime.MotionValues, 2)
|
||||
c.motionCh = make(chan types.Response, 5)
|
||||
err := c.requestNoRes(types.Request{
|
||||
Type: types.ReqTypeWatchMotion,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancel := c.cancelFn(types.ReqTypeCancelMotion, c.motionCh)
|
||||
return c.motionCh, cancel, nil
|
||||
res := <-c.motionCh
|
||||
done, cancel := c.cancelFn(res.ID, c.motionCh)
|
||||
out := make(chan infinitime.MotionValues, 5)
|
||||
go func() {
|
||||
for res := range c.motionCh {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
default:
|
||||
motion, err := decodeMotion(res.Value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out <- motion
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, cancel, nil
|
||||
}
|
||||
|
@ -31,7 +31,18 @@ func (c *Client) FirmwareUpgrade(upgType UpgradeType, files ...string) (<-chan D
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.dfuProgressCh = make(chan DFUProgress, 5)
|
||||
c.dfuProgressCh = make(chan types.Response, 5)
|
||||
|
||||
return c.dfuProgressCh, nil
|
||||
out := make(chan DFUProgress, 5)
|
||||
go func() {
|
||||
for res := range c.dfuProgressCh {
|
||||
progress, err := decodeDFUProgress(res.Value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
out <- progress
|
||||
}
|
||||
}()
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
28
cmd/test/main.go
Normal file
28
cmd/test/main.go
Normal file
@ -0,0 +1,28 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.arsenm.dev/itd/api"
|
||||
)
|
||||
|
||||
func main() {
|
||||
itd, _ := api.New(api.DefaultAddr)
|
||||
defer itd.Close()
|
||||
|
||||
fmt.Println(itd.Address())
|
||||
|
||||
mCh, cancel, _ := itd.WatchMotion()
|
||||
|
||||
go func() {
|
||||
time.Sleep(10 * time.Second)
|
||||
cancel()
|
||||
fmt.Println("canceled")
|
||||
}()
|
||||
|
||||
for m := range mCh {
|
||||
fmt.Println(m)
|
||||
}
|
||||
|
||||
}
|
1
go.mod
1
go.mod
@ -13,6 +13,7 @@ require (
|
||||
github.com/go-gl/gl v0.0.0-20210905235341-f7a045908259 // indirect
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20210727001814-0db043d8d5be // indirect
|
||||
github.com/godbus/dbus/v5 v5.0.5
|
||||
github.com/google/uuid v1.1.2
|
||||
github.com/mattn/go-colorable v0.1.11 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.13 // indirect
|
||||
github.com/mitchellh/mapstructure v1.4.2
|
||||
|
1
go.sum
1
go.sum
@ -193,6 +193,7 @@ github.com/google/pprof v0.0.0-20210609004039-a478d1d731e9/go.mod h1:kpwsk12EmLe
|
||||
github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
|
@ -9,15 +9,12 @@ const (
|
||||
ReqTypeNotify
|
||||
ReqTypeSetTime
|
||||
ReqTypeWatchHeartRate
|
||||
ReqTypeCancelHeartRate
|
||||
ReqTypeWatchBattLevel
|
||||
ReqTypeCancelBattLevel
|
||||
ReqTypeMotion
|
||||
ReqTypeWatchMotion
|
||||
ReqTypeCancelMotion
|
||||
ReqTypeStepCount
|
||||
ReqTypeWatchStepCount
|
||||
ReqTypeCancelStepCount
|
||||
ReqTypeCancel
|
||||
)
|
||||
|
||||
const (
|
||||
@ -29,15 +26,12 @@ const (
|
||||
ResTypeNotify
|
||||
ResTypeSetTime
|
||||
ResTypeWatchHeartRate
|
||||
ResTypeCancelHeartRate
|
||||
ResTypeWatchBattLevel
|
||||
ResTypeCancelBattLevel
|
||||
ResTypeMotion
|
||||
ResTypeWatchMotion
|
||||
ResTypeCancelMotion
|
||||
ResTypeStepCount
|
||||
ResTypeWatchStepCount
|
||||
ResTypeCancelStepCount
|
||||
ResTypeCancel
|
||||
)
|
||||
|
||||
const (
|
||||
@ -54,6 +48,7 @@ type Response struct {
|
||||
Type int `json:"type"`
|
||||
Value interface{} `json:"value,omitempty"`
|
||||
Message string `json:"msg,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Error bool `json:"error"`
|
||||
}
|
||||
|
||||
|
82
socket.go
82
socket.go
@ -27,6 +27,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
@ -35,6 +36,29 @@ import (
|
||||
"go.arsenm.dev/itd/translit"
|
||||
)
|
||||
|
||||
type DoneMap map[string]chan struct{}
|
||||
|
||||
func (dm DoneMap) Exists(key string) bool {
|
||||
_, ok := dm[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (dm DoneMap) Done(key string) {
|
||||
ch := dm[key]
|
||||
ch <- struct{}{}
|
||||
}
|
||||
|
||||
func (dm DoneMap) Create(key string) {
|
||||
dm[key] = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
func (dm DoneMap) Remove(key string) {
|
||||
close(dm[key])
|
||||
delete(dm, key)
|
||||
}
|
||||
|
||||
var done = DoneMap{}
|
||||
|
||||
func startSocket(dev *infinitime.Device) error {
|
||||
// Make socket directory if non-existant
|
||||
err := os.MkdirAll(filepath.Dir(viper.GetString("socket.path")), 0755)
|
||||
@ -81,11 +105,6 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
return
|
||||
}
|
||||
|
||||
heartRateDone := make(chan struct{})
|
||||
battLevelDone := make(chan struct{})
|
||||
stepCountDone := make(chan struct{})
|
||||
motionDone := make(chan struct{})
|
||||
|
||||
// Create new scanner on connection
|
||||
scanner := bufio.NewScanner(conn)
|
||||
for scanner.Scan() {
|
||||
@ -116,27 +135,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
connErr(conn, err, "Error getting heart rate channel")
|
||||
break
|
||||
}
|
||||
reqID := uuid.New().String()
|
||||
go func() {
|
||||
done.Create(reqID)
|
||||
// For every heart rate value
|
||||
for heartRate := range heartRateCh {
|
||||
select {
|
||||
case <-heartRateDone:
|
||||
case <-done[reqID]:
|
||||
// Stop notifications if done signal received
|
||||
cancel()
|
||||
done.Remove(reqID)
|
||||
return
|
||||
default:
|
||||
// Encode response to connection if no done signal received
|
||||
json.NewEncoder(conn).Encode(types.Response{
|
||||
Type: types.ResTypeWatchHeartRate,
|
||||
ID: reqID,
|
||||
Value: heartRate,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
case types.ReqTypeCancelHeartRate:
|
||||
// Stop heart rate notifications
|
||||
heartRateDone <- struct{}{}
|
||||
json.NewEncoder(conn).Encode(types.Response{})
|
||||
case types.ReqTypeBattLevel:
|
||||
// Get battery level from watch
|
||||
battLevel, err := dev.BatteryLevel()
|
||||
@ -155,27 +174,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
connErr(conn, err, "Error getting battery level channel")
|
||||
break
|
||||
}
|
||||
reqID := uuid.New().String()
|
||||
go func() {
|
||||
done.Create(reqID)
|
||||
// For every battery level value
|
||||
for battLevel := range battLevelCh {
|
||||
select {
|
||||
case <-battLevelDone:
|
||||
case <-done[reqID]:
|
||||
// Stop notifications if done signal received
|
||||
cancel()
|
||||
done.Remove(reqID)
|
||||
return
|
||||
default:
|
||||
// Encode response to connection if no done signal received
|
||||
json.NewEncoder(conn).Encode(types.Response{
|
||||
Type: types.ResTypeWatchBattLevel,
|
||||
ID: reqID,
|
||||
Value: battLevel,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
case types.ReqTypeCancelBattLevel:
|
||||
// Stop battery level notifications
|
||||
battLevelDone <- struct{}{}
|
||||
json.NewEncoder(conn).Encode(types.Response{})
|
||||
case types.ReqTypeMotion:
|
||||
// Get battery level from watch
|
||||
motionVals, err := dev.Motion()
|
||||
@ -194,27 +213,28 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
connErr(conn, err, "Error getting heart rate channel")
|
||||
break
|
||||
}
|
||||
reqID := uuid.New().String()
|
||||
go func() {
|
||||
done.Create(reqID)
|
||||
// For every motion event
|
||||
for motionVals := range motionValCh {
|
||||
select {
|
||||
case <-motionDone:
|
||||
case <-done[reqID]:
|
||||
// Stop notifications if done signal received
|
||||
cancel()
|
||||
done.Remove(reqID)
|
||||
|
||||
return
|
||||
default:
|
||||
// Encode response to connection if no done signal received
|
||||
json.NewEncoder(conn).Encode(types.Response{
|
||||
Type: types.ResTypeWatchMotion,
|
||||
ID: reqID,
|
||||
Value: motionVals,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
case types.ReqTypeCancelMotion:
|
||||
// Stop motion notifications
|
||||
motionDone <- struct{}{}
|
||||
json.NewEncoder(conn).Encode(types.Response{})
|
||||
case types.ReqTypeStepCount:
|
||||
// Get battery level from watch
|
||||
stepCount, err := dev.StepCount()
|
||||
@ -233,27 +253,27 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
connErr(conn, err, "Error getting heart rate channel")
|
||||
break
|
||||
}
|
||||
reqID := uuid.New().String()
|
||||
go func() {
|
||||
done.Create(reqID)
|
||||
// For every step count value
|
||||
for stepCount := range stepCountCh {
|
||||
select {
|
||||
case <-stepCountDone:
|
||||
case <-done[reqID]:
|
||||
// Stop notifications if done signal received
|
||||
cancel()
|
||||
done.Remove(reqID)
|
||||
return
|
||||
default:
|
||||
// Encode response to connection if no done signal received
|
||||
json.NewEncoder(conn).Encode(types.Response{
|
||||
Type: types.ResTypeWatchStepCount,
|
||||
ID: reqID,
|
||||
Value: stepCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
case types.ReqTypeCancelStepCount:
|
||||
// Stop step count notifications
|
||||
stepCountDone <- struct{}{}
|
||||
json.NewEncoder(conn).Encode(types.Response{})
|
||||
case types.ReqTypeFwVersion:
|
||||
// Get firmware version from watch
|
||||
version, err := dev.Version()
|
||||
@ -409,6 +429,18 @@ func handleConnection(conn net.Conn, dev *infinitime.Device) {
|
||||
break
|
||||
}
|
||||
firmwareUpdating = false
|
||||
case types.ReqTypeCancel:
|
||||
if req.Data == nil {
|
||||
connErr(conn, nil, "No data provided. Cancel request requires request ID string as data.")
|
||||
continue
|
||||
}
|
||||
reqID, ok := req.Data.(string)
|
||||
if !ok {
|
||||
connErr(conn, nil, "Invalid data. Cancel request required request ID string as data.")
|
||||
}
|
||||
// Stop notifications
|
||||
done.Done(reqID)
|
||||
json.NewEncoder(conn).Encode(types.Response{Type: types.ResTypeCancel})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user