Assist last used: remember STT and record proactively (before connected) (#3755)

Assist last used: remember STT and record before connected

 - For the last used pipeline for Assist, remember whether or not it supports STT input, and if it does start recording proactively/as soon as possible to avoid missing voice input while doing network checks.
 - Fix potential wrong server while sending voice data.
 - Fix voice input remaining active after getting an error response.
This commit is contained in:
Joris Pelgröm 2023-08-04 16:59:27 +02:00 committed by GitHub
parent 56798849fe
commit add1955901
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 111 additions and 48 deletions

View file

@ -74,6 +74,7 @@ class AssistActivity : BaseActivity() {
if (savedInstanceState == null) { if (savedInstanceState == null) {
viewModel.onCreate( viewModel.onCreate(
hasPermission = hasRecordingPermission(),
serverId = if (intent.hasExtra(EXTRA_SERVER)) { serverId = if (intent.hasExtra(EXTRA_SERVER)) {
intent.getIntExtra(EXTRA_SERVER, ServerManager.SERVER_ID_ACTIVE) intent.getIntExtra(EXTRA_SERVER, ServerManager.SERVER_ID_ACTIVE)
} else { } else {
@ -137,9 +138,7 @@ class AssistActivity : BaseActivity() {
override fun onResume() { override fun onResume() {
super.onResume() super.onResume()
viewModel.setPermissionInfo( viewModel.setPermissionInfo(hasRecordingPermission()) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
} }
override fun onPause() { override fun onPause() {
@ -152,4 +151,7 @@ class AssistActivity : BaseActivity() {
this.intent = intent this.intent = intent
viewModel.onNewIntent(intent) viewModel.onNewIntent(intent)
} }
private fun hasRecordingPermission() =
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
} }

View file

@ -53,22 +53,37 @@ class AssistViewModel @Inject constructor(
var inputMode by mutableStateOf<AssistInputMode?>(null) var inputMode by mutableStateOf<AssistInputMode?>(null)
private set private set
fun onCreate(serverId: Int?, pipelineId: String?, startListening: Boolean?) { fun onCreate(hasPermission: Boolean, serverId: Int?, pipelineId: String?, startListening: Boolean?) {
viewModelScope.launch { viewModelScope.launch {
this@AssistViewModel.hasPermission = hasPermission
serverId?.let { serverId?.let {
filteredServerId = serverId filteredServerId = serverId
selectedServerId = serverId selectedServerId = serverId
} }
startListening?.let { recorderAutoStart = it } startListening?.let { recorderAutoStart = it }
val supported = checkSupport()
if (!serverManager.isRegistered()) { if (!serverManager.isRegistered()) {
inputMode = AssistInputMode.BLOCKED inputMode = AssistInputMode.BLOCKED
_conversation.clear() _conversation.clear()
_conversation.add( _conversation.add(
AssistMessage(app.getString(commonR.string.not_registered), isInput = false) AssistMessage(app.getString(commonR.string.not_registered), isInput = false)
) )
} else if (supported == null) { // Couldn't get config return@launch
}
if (
pipelineId == PIPELINE_LAST_USED && recorderAutoStart &&
hasPermission && hasMicrophone &&
serverManager.getServer(selectedServerId) != null &&
serverManager.integrationRepository(selectedServerId).getLastUsedPipelineSttSupport()
) {
// Start microphone recording to prevent missing voice input while doing network checks
onMicrophoneInput(proactive = true)
}
val supported = checkSupport()
if (supported != true) stopRecording()
if (supported == null) { // Couldn't get config
inputMode = AssistInputMode.BLOCKED inputMode = AssistInputMode.BLOCKED
_conversation.clear() _conversation.clear()
_conversation.add( _conversation.add(
@ -86,7 +101,7 @@ class AssistViewModel @Inject constructor(
} else { } else {
setPipeline( setPipeline(
when { when {
pipelineId == PIPELINE_LAST_USED -> serverManager.integrationRepository(selectedServerId).getLastUsedPipeline() pipelineId == PIPELINE_LAST_USED -> serverManager.integrationRepository(selectedServerId).getLastUsedPipelineId()
pipelineId == PIPELINE_PREFERRED -> null pipelineId == PIPELINE_PREFERRED -> null
pipelineId?.isNotBlank() == true -> pipelineId pipelineId?.isNotBlank() == true -> pipelineId
else -> null else -> null
@ -169,7 +184,7 @@ class AssistViewModel @Inject constructor(
id = it.id, id = it.id,
name = it.name name = it.name
) )
serverManager.integrationRepository(selectedServerId).setLastUsedPipeline(it.id) serverManager.integrationRepository(selectedServerId).setLastUsedPipeline(it.id, it.sttEngine != null)
_conversation.clear() _conversation.clear()
_conversation.add(startMessage) _conversation.add(startMessage)
@ -177,7 +192,7 @@ class AssistViewModel @Inject constructor(
if (hasMicrophone && it.sttEngine != null) { if (hasMicrophone && it.sttEngine != null) {
if (recorderAutoStart && (hasPermission || requestSilently)) { if (recorderAutoStart && (hasPermission || requestSilently)) {
inputMode = AssistInputMode.VOICE_INACTIVE inputMode = AssistInputMode.VOICE_INACTIVE
onMicrophoneInput() onMicrophoneInput(proactive = null)
} else { // already requested permission once and was denied } else { // already requested permission once and was denied
inputMode = AssistInputMode.TEXT inputMode = AssistInputMode.TEXT
} }
@ -219,31 +234,37 @@ class AssistViewModel @Inject constructor(
fun onTextInput(input: String) = runAssistPipeline(input) fun onTextInput(input: String) = runAssistPipeline(input)
fun onMicrophoneInput() { /**
* Start/stop microphone input for Assist, depending on the current state.
* @param proactive true if proactive, null if not important, false if not
*/
fun onMicrophoneInput(proactive: Boolean? = false) {
if (!hasPermission) { if (!hasPermission) {
requestPermission?.let { it() } requestPermission?.let { it() }
return return
} }
if (inputMode == AssistInputMode.VOICE_ACTIVE) { if (inputMode == AssistInputMode.VOICE_ACTIVE && proactive == false) {
stopRecording() stopRecording()
return return
} }
val recording = try { val recording = try {
audioRecorder.startRecording() recorderProactive || audioRecorder.startRecording()
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Exception while starting recording", e) Log.e(TAG, "Exception while starting recording", e)
false false
} }
if (recording) { if (recording) {
setupRecorderQueue() if (!recorderProactive) setupRecorderQueue()
inputMode = AssistInputMode.VOICE_ACTIVE inputMode = AssistInputMode.VOICE_ACTIVE
runAssistPipeline(null) if (proactive == true) _conversation.add(AssistMessage("", isInput = true))
if (proactive != true) runAssistPipeline(null)
} else { } else {
_conversation.add(AssistMessage(app.getString(commonR.string.assist_error), isInput = false, isError = true)) _conversation.add(AssistMessage(app.getString(commonR.string.assist_error), isInput = false, isError = true))
} }
recorderProactive = recording && proactive == true
} }
private fun runAssistPipeline(text: String?) { private fun runAssistPipeline(text: String?) {
@ -269,6 +290,9 @@ class AssistViewModel @Inject constructor(
_conversation.add(haMessage) _conversation.add(haMessage)
message = haMessage message = haMessage
} }
if (isError && inputMode == AssistInputMode.VOICE_ACTIVE) {
stopRecording()
}
} }
} }
} }
@ -280,15 +304,16 @@ class AssistViewModel @Inject constructor(
fun onPermissionResult(granted: Boolean) { fun onPermissionResult(granted: Boolean) {
hasPermission = granted hasPermission = granted
val proactive = currentPipeline == null
if (granted) { if (granted) {
inputMode = AssistInputMode.VOICE_INACTIVE inputMode = AssistInputMode.VOICE_INACTIVE
onMicrophoneInput() onMicrophoneInput(proactive = proactive)
} else if (requestSilently) { // Don't notify the user if they haven't explicitly requested } else if (requestSilently && !proactive) { // Don't notify the user if they haven't explicitly requested
inputMode = AssistInputMode.TEXT inputMode = AssistInputMode.TEXT
} else { } else if (!requestSilently) {
_conversation.add(AssistMessage(app.getString(commonR.string.assist_permission), isInput = false)) _conversation.add(AssistMessage(app.getString(commonR.string.assist_permission), isInput = false))
} }
requestSilently = false if (!proactive) requestSilently = false
} }
fun onPause() { fun onPause() {

View file

@ -43,6 +43,7 @@ abstract class AssistViewModelBase(
protected var selectedServerId = ServerManager.SERVER_ID_ACTIVE protected var selectedServerId = ServerManager.SERVER_ID_ACTIVE
protected var recorderProactive = false
private var recorderJob: Job? = null private var recorderJob: Job? = null
private var recorderQueue: MutableList<ByteArray>? = null private var recorderQueue: MutableList<ByteArray>? = null
protected val hasMicrophone = app.packageManager.hasSystemFeature(PackageManager.FEATURE_MICROPHONE) protected val hasMicrophone = app.packageManager.hasSystemFeature(PackageManager.FEATURE_MICROPHONE)
@ -99,8 +100,11 @@ abstract class AssistViewModelBase(
} }
AssistPipelineEventType.STT_START -> { AssistPipelineEventType.STT_START -> {
viewModelScope.launch { viewModelScope.launch {
recorderQueue?.forEach { item -> binaryHandlerId?.let { id ->
sendVoiceData(item) // Manually loop here to avoid the queue being reset too soon
recorderQueue?.forEach { data ->
serverManager.webSocketRepository(selectedServerId).sendVoiceData(id, data)
}
} }
recorderQueue = null recorderQueue = null
} }
@ -156,7 +160,7 @@ abstract class AssistViewModelBase(
binaryHandlerId?.let { binaryHandlerId?.let {
viewModelScope.launch { viewModelScope.launch {
// Launch to prevent blocking the output flow if the network is slow // Launch to prevent blocking the output flow if the network is slow
serverManager.webSocketRepository().sendVoiceData(it, data) serverManager.webSocketRepository(selectedServerId).sendVoiceData(it, data)
} }
} }
} }
@ -186,8 +190,9 @@ abstract class AssistViewModelBase(
recorderQueue = null recorderQueue = null
} }
if (getInput() == AssistInputMode.VOICE_ACTIVE) { if (getInput() == AssistInputMode.VOICE_ACTIVE) {
setInput(AssistInputMode.VOICE_INACTIVE) setInput(if (recorderProactive) AssistInputMode.BLOCKED else AssistInputMode.VOICE_INACTIVE)
} }
recorderProactive = false
} }
protected fun stopPlayback() = audioUrlPlayer.stop() protected fun stopPlayback() = audioUrlPlayer.stop()

View file

@ -63,9 +63,11 @@ interface IntegrationRepository {
conversationId: String? = null conversationId: String? = null
): Flow<AssistPipelineEvent>? ): Flow<AssistPipelineEvent>?
suspend fun getLastUsedPipeline(): String? suspend fun getLastUsedPipelineId(): String?
suspend fun setLastUsedPipeline(pipelineId: String) suspend fun getLastUsedPipelineSttSupport(): Boolean
suspend fun setLastUsedPipeline(pipelineId: String, supportsStt: Boolean)
} }
@AssistedFactory @AssistedFactory

View file

@ -62,7 +62,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
private const val PREF_SESSION_EXPIRE = "session_expire" private const val PREF_SESSION_EXPIRE = "session_expire"
private const val PREF_TRUSTED = "trusted" private const val PREF_TRUSTED = "trusted"
private const val PREF_SEC_WARNING_NEXT = "sec_warning_last" private const val PREF_SEC_WARNING_NEXT = "sec_warning_last"
private const val PREF_LAST_USED_PIPELINE = "last_used_pipeline" private const val PREF_LAST_USED_PIPELINE_ID = "last_used_pipeline"
private const val PREF_LAST_USED_PIPELINE_STT = "last_used_pipeline_stt"
private const val TAG = "IntegrationRepository" private const val TAG = "IntegrationRepository"
private const val RATE_LIMIT_URL = BuildConfig.RATE_LIMIT_URL private const val RATE_LIMIT_URL = BuildConfig.RATE_LIMIT_URL
@ -166,7 +167,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
localStorage.remove("${serverId}_$PREF_SESSION_EXPIRE") localStorage.remove("${serverId}_$PREF_SESSION_EXPIRE")
localStorage.remove("${serverId}_$PREF_TRUSTED") localStorage.remove("${serverId}_$PREF_TRUSTED")
localStorage.remove("${serverId}_$PREF_SEC_WARNING_NEXT") localStorage.remove("${serverId}_$PREF_SEC_WARNING_NEXT")
localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE") localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE_ID")
localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE_STT")
// app version and push token is device-specific // app version and push token is device-specific
} }
@ -552,11 +554,16 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
} }
} }
override suspend fun getLastUsedPipeline(): String? = override suspend fun getLastUsedPipelineId(): String? =
localStorage.getString("${serverId}_$PREF_LAST_USED_PIPELINE") localStorage.getString("${serverId}_$PREF_LAST_USED_PIPELINE_ID")
override suspend fun setLastUsedPipeline(pipelineId: String) = override suspend fun getLastUsedPipelineSttSupport(): Boolean =
localStorage.putString("${serverId}_$PREF_LAST_USED_PIPELINE", pipelineId) localStorage.getBoolean("${serverId}_$PREF_LAST_USED_PIPELINE_STT")
override suspend fun setLastUsedPipeline(pipelineId: String, supportsStt: Boolean) {
localStorage.putString("${serverId}_$PREF_LAST_USED_PIPELINE_ID", pipelineId)
localStorage.putBoolean("${serverId}_$PREF_LAST_USED_PIPELINE_STT", supportsStt)
}
override suspend fun getEntities(): List<Entity<Any>>? { override suspend fun getEntities(): List<Entity<Any>>? {
val response = webSocketRepository.getStates() val response = webSocketRepository.getStates()

View file

@ -48,7 +48,7 @@ class ConversationActivity : ComponentActivity() {
super.onCreate(savedInstanceState) super.onCreate(savedInstanceState)
lifecycleScope.launch { lifecycleScope.launch {
val launchIntent = conversationViewModel.onCreate() val launchIntent = conversationViewModel.onCreate(hasRecordingPermission())
if (launchIntent) { if (launchIntent) {
launchVoiceInputIntent() launchVoiceInputIntent()
} }
@ -64,9 +64,7 @@ class ConversationActivity : ComponentActivity() {
override fun onResume() { override fun onResume() {
super.onResume() super.onResume()
conversationViewModel.setPermissionInfo( conversationViewModel.setPermissionInfo(hasRecordingPermission()) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
} }
override fun onPause() { override fun onPause() {
@ -88,6 +86,9 @@ class ConversationActivity : ComponentActivity() {
} }
} }
private fun hasRecordingPermission() =
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
private fun launchVoiceInputIntent() { private fun launchVoiceInputIntent() {
val searchIntent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply { val searchIntent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
putExtra( putExtra(

View file

@ -56,14 +56,25 @@ class ConversationViewModel @Inject constructor(
val conversation: List<AssistMessage> = _conversation val conversation: List<AssistMessage> = _conversation
/** @return `true` if the voice input intent should be fired */ /** @return `true` if the voice input intent should be fired */
suspend fun onCreate(): Boolean { suspend fun onCreate(hasPermission: Boolean): Boolean {
val supported = checkAssistSupport() this.hasPermission = hasPermission
if (!serverManager.isRegistered()) { if (!serverManager.isRegistered()) {
_conversation.clear() _conversation.clear()
_conversation.add( _conversation.add(
AssistMessage(app.getString(commonR.string.not_registered), isInput = false) AssistMessage(app.getString(commonR.string.not_registered), isInput = false)
) )
} else if (supported == null) { // Couldn't get config return false
}
if (hasPermission && hasMicrophone && serverManager.integrationRepository().getLastUsedPipelineSttSupport()) {
// Start microphone recording to prevent missing voice input while doing network checks
onMicrophoneInput(proactive = true)
}
val supported = checkAssistSupport()
if (supported != true) stopRecording()
if (supported == null) { // Couldn't get config
_conversation.clear() _conversation.clear()
_conversation.add( _conversation.add(
AssistMessage(app.getString(commonR.string.assist_connnect), isInput = false) AssistMessage(app.getString(commonR.string.assist_connnect), isInput = false)
@ -90,7 +101,7 @@ class ConversationViewModel @Inject constructor(
return setPipeline( return setPipeline(
if (useAssistPipeline) { if (useAssistPipeline) {
serverManager.integrationRepository().getLastUsedPipeline() serverManager.integrationRepository().getLastUsedPipelineId()
} else { } else {
null null
} }
@ -168,7 +179,7 @@ class ConversationViewModel @Inject constructor(
if (pipeline != null || !useAssistPipeline) { if (pipeline != null || !useAssistPipeline) {
currentPipeline = pipeline currentPipeline = pipeline
currentPipeline?.let { currentPipeline?.let {
serverManager.integrationRepository().setLastUsedPipeline(it.id) serverManager.integrationRepository().setLastUsedPipeline(it.id, pipeline?.sttEngine != null)
} }
_conversation.clear() _conversation.clear()
@ -178,7 +189,7 @@ class ConversationViewModel @Inject constructor(
if (hasPermission || requestSilently) { if (hasPermission || requestSilently) {
inputMode = AssistInputMode.VOICE_INACTIVE inputMode = AssistInputMode.VOICE_INACTIVE
useAssistPipelineStt = true useAssistPipelineStt = true
onMicrophoneInput() onMicrophoneInput(proactive = null)
} else { } else {
inputMode = AssistInputMode.TEXT inputMode = AssistInputMode.TEXT
} }
@ -198,31 +209,37 @@ class ConversationViewModel @Inject constructor(
fun updateSpeechResult(commonResult: String) = runAssistPipeline(commonResult) fun updateSpeechResult(commonResult: String) = runAssistPipeline(commonResult)
fun onMicrophoneInput() { /**
* Start/stop microphone input for Assist, depending on the current state.
* @param proactive true if proactive, null if not important, false if not
*/
fun onMicrophoneInput(proactive: Boolean? = false) {
if (!hasPermission) { if (!hasPermission) {
requestPermission?.let { it() } requestPermission?.let { it() }
return return
} }
if (inputMode == AssistInputMode.VOICE_ACTIVE) { if (inputMode == AssistInputMode.VOICE_ACTIVE && proactive == false) {
stopRecording() stopRecording()
return return
} }
val recording = try { val recording = try {
audioRecorder.startRecording() recorderProactive || audioRecorder.startRecording()
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Exception while starting recording", e) Log.e(TAG, "Exception while starting recording", e)
false false
} }
if (recording) { if (recording) {
setupRecorderQueue() if (!recorderProactive) setupRecorderQueue()
inputMode = AssistInputMode.VOICE_ACTIVE inputMode = AssistInputMode.VOICE_ACTIVE
runAssistPipeline(null) if (proactive == true) _conversation.add(AssistMessage("", isInput = true))
if (proactive != true) runAssistPipeline(null)
} else { } else {
_conversation.add(AssistMessage(app.getString(commonR.string.assist_error), isInput = false, isError = true)) _conversation.add(AssistMessage(app.getString(commonR.string.assist_error), isInput = false, isError = true))
} }
recorderProactive = recording && proactive == true
} }
private fun runAssistPipeline(text: String?) { private fun runAssistPipeline(text: String?) {
@ -248,6 +265,9 @@ class ConversationViewModel @Inject constructor(
_conversation.add(haMessage) _conversation.add(haMessage)
message = haMessage message = haMessage
} }
if (isError && inputMode == AssistInputMode.VOICE_ACTIVE) {
stopRecording()
}
} }
} }
} }
@ -260,14 +280,15 @@ class ConversationViewModel @Inject constructor(
fun onPermissionResult(granted: Boolean, voiceInputIntent: (() -> Unit)) { fun onPermissionResult(granted: Boolean, voiceInputIntent: (() -> Unit)) {
hasPermission = granted hasPermission = granted
useAssistPipelineStt = currentPipeline?.sttEngine != null && granted useAssistPipelineStt = currentPipeline?.sttEngine != null && granted
val proactive = currentPipeline == null
if (granted) { if (granted) {
inputMode = AssistInputMode.VOICE_INACTIVE inputMode = AssistInputMode.VOICE_INACTIVE
onMicrophoneInput() onMicrophoneInput(proactive = proactive)
} else if (requestSilently) { // Don't notify the user if they haven't explicitly requested } else if (requestSilently && !proactive) { // Don't notify the user if they haven't explicitly requested
inputMode = AssistInputMode.TEXT inputMode = AssistInputMode.TEXT
voiceInputIntent() voiceInputIntent()
} }
requestSilently = false if (!proactive) requestSilently = false
} }
fun onConversationScreenHidden() { fun onConversationScreenHidden() {