Generalize socket cancellation and update API accordingly

This commit is contained in:
2021-10-23 18:03:17 -07:00
parent ef4bad94b5
commit e198b769f9
9 changed files with 211 additions and 76 deletions

View File

@@ -18,11 +18,11 @@ const DefaultAddr = "/tmp/itd/socket"
type Client struct {
conn net.Conn
respCh chan types.Response
heartRateCh chan uint8
battLevelCh chan uint8
stepCountCh chan uint32
motionCh chan infinitime.MotionValues
dfuProgressCh chan DFUProgress
heartRateCh chan types.Response
battLevelCh chan types.Response
stepCountCh chan types.Response
motionCh chan types.Response
dfuProgressCh chan types.Response
}
// New creates a new client and sets it up
@@ -91,27 +91,43 @@ func (c *Client) requestNoRes(req types.Request) error {
func (c *Client) handleResp(res types.Response) error {
switch res.Type {
case types.ResTypeWatchHeartRate:
c.heartRateCh <- uint8(res.Value.(float64))
c.heartRateCh <- res
case types.ResTypeWatchBattLevel:
c.battLevelCh <- uint8(res.Value.(float64))
c.battLevelCh <- res
case types.ResTypeWatchStepCount:
c.stepCountCh <- uint32(res.Value.(float64))
c.stepCountCh <- res
case types.ResTypeWatchMotion:
out := infinitime.MotionValues{}
err := mapstructure.Decode(res.Value, &out)
if err != nil {
return err
}
c.motionCh <- out
c.motionCh <- res
case types.ResTypeDFUProgress:
out := DFUProgress{}
err := mapstructure.Decode(res.Value, &out)
if err != nil {
return err
}
c.dfuProgressCh <- out
c.dfuProgressCh <- res
default:
c.respCh <- res
}
return nil
}
func decodeUint8(val interface{}) uint8 {
return uint8(val.(float64))
}
func decodeUint32(val interface{}) uint32 {
return uint32(val.(float64))
}
func decodeMotion(val interface{}) (infinitime.MotionValues, error) {
out := infinitime.MotionValues{}
err := mapstructure.Decode(val, &out)
if err != nil {
return out, err
}
return out, nil
}
func decodeDFUProgress(val interface{}) (DFUProgress, error) {
out := DFUProgress{}
err := mapstructure.Decode(val, &out)
if err != nil {
return out, err
}
return out, nil
}

View File

@@ -1,8 +1,6 @@
package api
import (
"reflect"
"github.com/mitchellh/mapstructure"
"go.arsenm.dev/infinitime"
"go.arsenm.dev/itd/internal/types"
@@ -48,15 +46,27 @@ func (c *Client) BatteryLevel() (uint8, error) {
// new battery level values as they update. Do not use after
// calling cancellation function
func (c *Client) WatchBatteryLevel() (<-chan uint8, func(), error) {
c.battLevelCh = make(chan uint8, 2)
c.battLevelCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{
Type: types.ReqTypeBattLevel,
})
if err != nil {
return nil, nil, err
}
cancel := c.cancelFn(types.ReqTypeCancelBattLevel, c.battLevelCh)
return c.battLevelCh, cancel, nil
res := <-c.battLevelCh
done, cancel := c.cancelFn(res.ID, c.battLevelCh)
out := make(chan uint8, 2)
go func() {
for res := range c.battLevelCh {
select {
case <-done:
return
default:
out <- decodeUint8(res.Value)
}
}
}()
return out, cancel, nil
}
// HeartRate gets the heart rate from the connected device
@@ -68,33 +78,46 @@ func (c *Client) HeartRate() (uint8, error) {
return 0, err
}
return uint8(res.Value.(float64)), nil
return decodeUint8(res.Value), nil
}
// WatchHeartRate returns a channel which will contain
// new heart rate values as they update. Do not use after
// calling cancellation function
func (c *Client) WatchHeartRate() (<-chan uint8, func(), error) {
c.heartRateCh = make(chan uint8, 2)
c.heartRateCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchHeartRate,
})
if err != nil {
return nil, nil, err
}
cancel := c.cancelFn(types.ReqTypeCancelHeartRate, c.heartRateCh)
return c.heartRateCh, cancel, nil
res := <-c.heartRateCh
done, cancel := c.cancelFn(res.ID, c.heartRateCh)
out := make(chan uint8, 2)
go func() {
for res := range c.heartRateCh {
select {
case <-done:
return
default:
out <- decodeUint8(res.Value)
}
}
}()
return out, cancel, nil
}
// cancelFn generates a cancellation function for the given
// request type and channel
func (c *Client) cancelFn(reqType int, ch interface{}) func() {
return func() {
reflectCh := reflect.ValueOf(ch)
reflectCh.Close()
reflectCh.Set(reflect.Zero(reflectCh.Type()))
func (c *Client) cancelFn(reqID string, ch chan types.Response) (chan struct{}, func()) {
done := make(chan struct{}, 1)
return done, func() {
done <- struct{}{}
close(ch)
c.requestNoRes(types.Request{
Type: reqType,
Type: types.ReqTypeCancel,
Data: reqID,
})
}
}
@@ -115,15 +138,27 @@ func (c *Client) StepCount() (uint32, error) {
// new step count values as they update. Do not use after
// calling cancellation function
func (c *Client) WatchStepCount() (<-chan uint32, func(), error) {
c.stepCountCh = make(chan uint32, 2)
c.stepCountCh = make(chan types.Response, 2)
err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchStepCount,
})
if err != nil {
return nil, nil, err
}
cancel := c.cancelFn(types.ReqTypeCancelStepCount, c.stepCountCh)
return c.stepCountCh, cancel, nil
res := <-c.stepCountCh
done, cancel := c.cancelFn(res.ID, c.stepCountCh)
out := make(chan uint32, 2)
go func() {
for res := range c.stepCountCh {
select {
case <-done:
return
default:
out <- decodeUint32(res.Value)
}
}
}()
return out, cancel, nil
}
// Motion gets the motion values from the connected device
@@ -146,13 +181,29 @@ func (c *Client) Motion() (infinitime.MotionValues, error) {
// new motion values as they update. Do not use after
// calling cancellation function
func (c *Client) WatchMotion() (<-chan infinitime.MotionValues, func(), error) {
c.motionCh = make(chan infinitime.MotionValues, 2)
c.motionCh = make(chan types.Response, 5)
err := c.requestNoRes(types.Request{
Type: types.ReqTypeWatchMotion,
})
if err != nil {
return nil, nil, err
}
cancel := c.cancelFn(types.ReqTypeCancelMotion, c.motionCh)
return c.motionCh, cancel, nil
res := <-c.motionCh
done, cancel := c.cancelFn(res.ID, c.motionCh)
out := make(chan infinitime.MotionValues, 5)
go func() {
for res := range c.motionCh {
select {
case <-done:
return
default:
motion, err := decodeMotion(res.Value)
if err != nil {
continue
}
out <- motion
}
}
}()
return out, cancel, nil
}

View File

@@ -31,7 +31,18 @@ func (c *Client) FirmwareUpgrade(upgType UpgradeType, files ...string) (<-chan D
return nil, err
}
c.dfuProgressCh = make(chan DFUProgress, 5)
c.dfuProgressCh = make(chan types.Response, 5)
return c.dfuProgressCh, nil
out := make(chan DFUProgress, 5)
go func() {
for res := range c.dfuProgressCh {
progress, err := decodeDFUProgress(res.Value)
if err != nil {
continue
}
out <- progress
}
}()
return out, nil
}