2021-04-22 02:29:14 +00:00
|
|
|
/*
|
|
|
|
* Copyright (C) 2021 Arsen Musayelyan
|
|
|
|
*
|
|
|
|
* This program is free software: you can redistribute it and/or modify
|
|
|
|
* it under the terms of the GNU General Public License as published by
|
|
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
|
|
* (at your option) any later version.
|
|
|
|
*
|
|
|
|
* This program is distributed in the hope that it will be useful,
|
|
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
* GNU General Public License for more details.
|
|
|
|
*
|
|
|
|
* You should have received a copy of the GNU General Public License
|
|
|
|
* along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
*/
|
|
|
|
|
|
|
|
package main
|
|
|
|
|
|
|
|
import (
	"bufio"
	"bytes"
	"net"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	ds "github.com/asticode/go-astideepspeech"
	"github.com/gen2brain/malgo"
	flag "github.com/spf13/pflag"
)
|
|
|
|
|
|
|
|
// Package-level state populated by main.
var (
	// verbose points at the value of the --verbose/-v flag once flags are
	// defined in main; it is nil before that.
	verbose *bool
	// execDir is returned by configEnv and is the base directory for the
	// default DeepSpeech model and scorer paths.
	execDir string
	// configDir is returned by configEnv and is the base directory for the
	// default IPC socket path.
	configDir string
)
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
// Configure environment (paths to resources)
|
|
|
|
var gopath, confPath string
|
|
|
|
gopath, configDir, execDir, confPath = configEnv()
|
|
|
|
|
|
|
|
// Define and parse command line flags
|
|
|
|
tfLogLevel := flag.Int("tf-log-level", 2, "Log level for TensorFlow")
|
|
|
|
verbose = flag.BoolP("verbose", "v", false, "Log more events")
|
2021-04-23 03:00:02 +00:00
|
|
|
showDecode := flag.BoolP("show-decode", "d", false, "Show text to speech decodes")
|
2021-04-22 02:29:14 +00:00
|
|
|
configPath := flag.StringP("config", "c", confPath, "Location of trident TOML config")
|
|
|
|
modelPath := flag.StringP("model", "m", filepath.Join(execDir, "deepspeech.pbmm"), "Path to DeepSpeech model")
|
|
|
|
scorerPath := flag.StringP("scorer", "s", filepath.Join(execDir, "deepspeech.scorer"), "Path to DeepSpeech scorer")
|
|
|
|
socketPath := flag.StringP("socket", "S", filepath.Join(configDir, "trident.sock"), "Path to UNIX socket for IPC")
|
|
|
|
GOPATH := flag.String("gopath", gopath, "GOPATH for use with plugins")
|
|
|
|
flag.Parse()
|
|
|
|
|
|
|
|
// Set TensorFlow log level to specified level (default 2)
|
|
|
|
_ = os.Setenv("TF_CPP_MIN_LOG_LEVEL", strconv.Itoa(*tfLogLevel))
|
|
|
|
|
|
|
|
// Get and parse TOML config
|
|
|
|
config, err := getConfig(*configPath)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error getting TOML config")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create new channel storing os.Signal
|
|
|
|
sigChannel := make(chan os.Signal, 1)
|
|
|
|
// Notify channel upon reception of specified signals
|
|
|
|
signal.Notify(sigChannel,
|
|
|
|
syscall.SIGINT,
|
|
|
|
syscall.SIGTERM,
|
|
|
|
syscall.SIGHUP,
|
|
|
|
syscall.SIGQUIT,
|
|
|
|
)
|
|
|
|
// Create new goroutine to handle signals gracefully
|
|
|
|
go func() {
|
|
|
|
// Wait for signal
|
|
|
|
sig := <-sigChannel
|
|
|
|
// Log reception of signal
|
|
|
|
log.Info().Str("signal", sig.String()).Msg("Received signal, shutting down")
|
|
|
|
// If IPC is enabled in the config, remove the UNIX socket
|
|
|
|
if config.IPCEnabled {
|
|
|
|
_ = os.RemoveAll(*socketPath)
|
|
|
|
}
|
|
|
|
// Exit with code 0
|
|
|
|
os.Exit(0)
|
|
|
|
}()
|
|
|
|
|
|
|
|
// Create new DeepSpeech model
|
|
|
|
model, err := ds.New(*modelPath)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error opening DeepSpeech model")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Initialize available plugins
|
|
|
|
plugins := initPlugins(*GOPATH)
|
|
|
|
|
|
|
|
// If IPC is enabled in config
|
|
|
|
if config.IPCEnabled {
|
|
|
|
// Remove UNIX socket ignoring error
|
|
|
|
_ = os.RemoveAll(*socketPath)
|
|
|
|
// Listen on UNIX socket
|
|
|
|
ln, err := net.Listen("unix", *socketPath)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error listening on UNIX socket")
|
|
|
|
}
|
|
|
|
go func() {
|
|
|
|
for {
|
|
|
|
// Accept any connection when it arrives
|
|
|
|
conn, err := ln.Accept()
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error accepting connection")
|
|
|
|
}
|
|
|
|
go func(conn net.Conn) {
|
|
|
|
// Close connection at end of function
|
|
|
|
defer conn.Close()
|
|
|
|
// Create new scanner for connection (default is ScanLines)
|
|
|
|
scanner := bufio.NewScanner(conn)
|
|
|
|
// Scan until EOF
|
|
|
|
for scanner.Scan() {
|
|
|
|
// If error encountered, return from function
|
|
|
|
if scanner.Err() != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
// Get text from scanner
|
|
|
|
input := scanner.Text()
|
|
|
|
// Attempt to match text to action and return action
|
|
|
|
action, ok := getAction(config, &input)
|
|
|
|
// If match founc
|
|
|
|
if ok {
|
|
|
|
// Log performing action
|
|
|
|
log.Info().Str("action", action.Name).Str("source", "socket").Msg("Performing action")
|
|
|
|
// Perform returned action
|
|
|
|
done, err := performAction(action, &input, plugins)
|
|
|
|
if err != nil {
|
|
|
|
log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action")
|
|
|
|
}
|
|
|
|
// If action complete, close connection and return
|
|
|
|
if done {
|
|
|
|
conn.Close()
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}(conn)
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
|
|
|
|
// Initialize audio context
|
|
|
|
ctx, err := malgo.InitContext(nil, malgo.ContextConfig{}, func(message string) {
|
|
|
|
log.Warn().Msg(message)
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error initializing malgo context")
|
|
|
|
}
|
|
|
|
// Uninitialize and free at end of function
|
|
|
|
defer func() {
|
|
|
|
_ = ctx.Uninit()
|
|
|
|
ctx.Free()
|
|
|
|
}()
|
|
|
|
|
|
|
|
// Set device configuration options
|
|
|
|
deviceConfig := malgo.DefaultDeviceConfig(malgo.Capture)
|
|
|
|
deviceConfig.Capture.Format = malgo.FormatS16
|
|
|
|
deviceConfig.Capture.Channels = 1
|
|
|
|
deviceConfig.Playback.Format = malgo.FormatS16
|
|
|
|
deviceConfig.Playback.Channels = 1
|
|
|
|
deviceConfig.SampleRate = uint32(model.SampleRate())
|
|
|
|
deviceConfig.Alsa.NoMMap = 1
|
|
|
|
|
|
|
|
// Create new buffer to store audio samples
|
|
|
|
captured := &bytes.Buffer{}
|
|
|
|
onRecvFrames := func(_, sample []byte, _ uint32) {
|
|
|
|
// Upon receipt of sample, write to buffer
|
|
|
|
captured.Write(sample)
|
|
|
|
}
|
|
|
|
log.Info().Msg("Listening to audio events")
|
|
|
|
// Initialize audio device using configuration options
|
|
|
|
device, err := malgo.InitDevice(ctx.Context, deviceConfig, malgo.DeviceCallbacks{
|
|
|
|
Data: onRecvFrames,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error initializing audio device")
|
|
|
|
}
|
|
|
|
// Uninitialize at end of function
|
|
|
|
defer device.Uninit()
|
|
|
|
|
|
|
|
// Start capture device (begin recording)
|
|
|
|
err = device.Start()
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error starting capture device")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set DeepSpeech scorer
|
|
|
|
err = model.EnableExternalScorer(*scorerPath)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error opening DeepSpeech scorer")
|
|
|
|
}
|
|
|
|
|
|
|
|
// Create new stream for DeepSpeech model
|
|
|
|
sttStream, err := model.NewStream()
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error creating DeepSpeech stream")
|
|
|
|
}
|
|
|
|
// Create a safe stream using sync.Mutex
|
|
|
|
safeStream := &SafeStream{Stream: sttStream}
|
|
|
|
|
|
|
|
// Create goroutine to clean stream every minute
|
|
|
|
go func() {
|
|
|
|
for {
|
2021-04-23 03:13:47 +00:00
|
|
|
time.Sleep(20 * time.Second)
|
2021-04-22 02:29:14 +00:00
|
|
|
// Lock mutex of stream
|
|
|
|
safeStream.Lock()
|
|
|
|
// Reset stream and buffer
|
|
|
|
resetStream(safeStream, model, captured)
|
|
|
|
if *verbose {
|
|
|
|
log.Debug().Msg("1m passed; cleaning stream")
|
|
|
|
}
|
|
|
|
// Unlock mutex of stream
|
|
|
|
safeStream.Unlock()
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
var tts string
|
|
|
|
listenForActivation := true
|
|
|
|
for {
|
2021-04-23 03:13:47 +00:00
|
|
|
time.Sleep(200 * time.Millisecond)
|
2021-04-22 02:29:14 +00:00
|
|
|
// Convert captured raw audio to slice of int16
|
|
|
|
slice, err := convToInt16Slice(captured)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error converting captured audio feed")
|
|
|
|
}
|
|
|
|
// Reset buffer
|
|
|
|
captured.Reset()
|
|
|
|
// Lock mutex of stream
|
|
|
|
safeStream.Lock()
|
|
|
|
// Feed converted audio to stream
|
|
|
|
safeStream.FeedAudioContent(slice)
|
|
|
|
// Decode stream without destroying
|
|
|
|
tts, err = safeStream.IntermediateDecode()
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error intermediate decoding stream")
|
|
|
|
}
|
2021-04-23 03:00:02 +00:00
|
|
|
if *showDecode {
|
|
|
|
log.Debug().Msg("TTS Decode: " + tts)
|
|
|
|
}
|
2021-04-22 02:29:14 +00:00
|
|
|
// If decoded string contains activation phrase and listenForActivation is true
|
|
|
|
if strings.Contains(tts, config.ActivationPhrase) && listenForActivation {
|
|
|
|
// Play activation tone
|
|
|
|
err = playActivationTone(ctx)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal().Err(err).Msg("Error playing activation tone")
|
|
|
|
}
|
|
|
|
// Log detection of activation phrase
|
|
|
|
log.Info().Msg("Activation phrase detected")
|
|
|
|
// Reset stream and buffer
|
|
|
|
resetStream(safeStream, model, captured)
|
|
|
|
// Create new goroutine to listen for commands
|
|
|
|
go func() {
|
|
|
|
// Disable activation
|
|
|
|
listenForActivation = false
|
|
|
|
// Enable activation at end of function
|
|
|
|
defer func() {
|
|
|
|
listenForActivation = true
|
|
|
|
}()
|
|
|
|
// Create timeout channel to trigger after configured time
|
|
|
|
timeout := time.After(config.ActivationTime)
|
|
|
|
activationLoop:
|
|
|
|
for {
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
select {
|
|
|
|
// If timeout has elapsed
|
|
|
|
case <-timeout:
|
|
|
|
log.Warn().Msg("Unknown command")
|
|
|
|
break activationLoop
|
|
|
|
// If timeout has not elapsed
|
|
|
|
default:
|
|
|
|
// Attempt to match decoded string to action
|
|
|
|
action, ok := getAction(config, &tts)
|
|
|
|
// If match found
|
|
|
|
if ok {
|
|
|
|
// Keep listening if user is talking
|
|
|
|
for {
|
|
|
|
// Get length of text to speech string
|
|
|
|
ttsLen := len(tts)
|
|
|
|
time.Sleep(time.Second)
|
|
|
|
// If length has not changed
|
|
|
|
if ttsLen == len(tts) {
|
|
|
|
// Break out of for loop
|
|
|
|
break
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Log performing action
|
|
|
|
log.Info().Str("action", action.Name).Str("source", "voice").Msg("Performing action")
|
|
|
|
// Perform action matched by getAction()
|
|
|
|
done, err := performAction(action, &tts, plugins)
|
|
|
|
if err != nil {
|
|
|
|
log.Warn().Err(err).Str("action", action.Name).Msg("Error performing configured action")
|
|
|
|
}
|
|
|
|
// If action is complete
|
|
|
|
if done {
|
|
|
|
// Lock mutex of stream
|
|
|
|
safeStream.Lock()
|
|
|
|
// Reset stream and buffer
|
|
|
|
resetStream(safeStream, model, captured)
|
|
|
|
// Unlock mutex of stream
|
|
|
|
safeStream.Unlock()
|
|
|
|
// Return from goroutine
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|
|
|
|
// Unlock mutex of stream
|
|
|
|
safeStream.Unlock()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Function to reset stream and buffer
|
|
|
|
func resetStream(s *SafeStream, model *ds.Model, captured *bytes.Buffer) {
|
|
|
|
// Reset buffer
|
|
|
|
captured.Reset()
|
|
|
|
// Discard stream (workaround for lack of Clear function)
|
|
|
|
s.Discard()
|
|
|
|
// Create new stream, setting it to same location as old
|
|
|
|
s.Stream, _ = model.NewStream()
|
|
|
|
}
|