From 7d6f11af4fc4508a7384cb4662da1719233ba7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joris=20Pelgr=C3=B6m?= Date: Sat, 13 May 2023 04:51:47 +0200 Subject: [PATCH] Add support for Assist pipeline, update Wear implementation (#3526) * Group incoming messages by subscription to prevent out-of-order delivery - Messages received on the websocket are processed asynchronously, which is usually fine but can cause issues if messages need to be received in a specific order for a subscription. To fix this, process messages in order for the same subscription. * Implement Assist pipeline API - Add basic support for the Assist pipeline API - Update conversation function to use the Assist pipeline when on the minimum required version - Update UI to refer to Assist pipeline requirement --- .../data/integration/IntegrationRepository.kt | 2 +- .../impl/IntegrationRepositoryImpl.kt | 38 +++++++++++-- .../data/websocket/WebSocketRepository.kt | 12 +++++ .../websocket/impl/WebSocketRepositoryImpl.kt | 54 +++++++++++++++---- .../impl/entities/AssistPipelineEvent.kt | 41 ++++++++++++++ common/src/main/res/values/strings.xml | 4 +- .../conversation/ConversationActivity.kt | 4 +- .../conversation/ConversationViewModel.kt | 25 ++++++--- .../conversation/views/ConversationView.kt | 9 +++- 9 files changed, 161 insertions(+), 28 deletions(-) create mode 100644 common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/entities/AssistPipelineEvent.kt diff --git a/common/src/main/java/io/homeassistant/companion/android/common/data/integration/IntegrationRepository.kt b/common/src/main/java/io/homeassistant/companion/android/common/data/integration/IntegrationRepository.kt index 9fda5078c..2ecc2a737 100644 --- a/common/src/main/java/io/homeassistant/companion/android/common/data/integration/IntegrationRepository.kt +++ b/common/src/main/java/io/homeassistant/companion/android/common/data/integration/IntegrationRepository.kt @@ -56,7 +56,7 @@ interface IntegrationRepository { suspend fun shouldNotifySecurityWarning(): Boolean - suspend fun getConversation(speech: String): String? + suspend fun getAssistResponse(speech: String): String? } @AssistedFactory diff --git a/common/src/main/java/io/homeassistant/companion/android/common/data/integration/impl/IntegrationRepositoryImpl.kt b/common/src/main/java/io/homeassistant/companion/android/common/data/integration/impl/IntegrationRepositoryImpl.kt index 39205a38b..5e6f31c89 100644 --- a/common/src/main/java/io/homeassistant/companion/android/common/data/integration/impl/IntegrationRepositoryImpl.kt +++ b/common/src/main/java/io/homeassistant/companion/android/common/data/integration/impl/IntegrationRepositoryImpl.kt @@ -25,13 +25,21 @@ import io.homeassistant.companion.android.common.data.integration.impl.entities. import io.homeassistant.companion.android.common.data.integration.impl.entities.Template import io.homeassistant.companion.android.common.data.integration.impl.entities.UpdateLocationRequest import io.homeassistant.companion.android.common.data.servers.ServerManager +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEventType +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentEnd import io.homeassistant.companion.android.common.data.websocket.impl.entities.GetConfigResponse +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.map +import kotlinx.coroutines.launch +import kotlinx.coroutines.suspendCancellableCoroutine import okhttp3.HttpUrl.Companion.toHttpUrlOrNull import java.util.concurrent.TimeUnit import javax.inject.Named +import kotlin.coroutines.resume class IntegrationRepositoryImpl @AssistedInject constructor( private val integrationService: IntegrationService, @@ -64,6 +72,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor( private const val APPLOCK_TIMEOUT_GRACE_MS = 1000 } + private val ioScope = CoroutineScope(Dispatchers.IO + Job()) + private val server get() = serverManager.getServer(serverId)!! private val webSocketRepository get() = serverManager.webSocketRepository(serverId) @@ -523,11 +533,29 @@ class IntegrationRepositoryImpl @AssistedInject constructor( }?.toList() } - override suspend fun getConversation(speech: String): String? { - // TODO: Also send back conversation ID for dialogue - val response = webSocketRepository.getConversation(speech) - - return response?.response?.speech?.plain?.get("speech") + override suspend fun getAssistResponse(speech: String): String? { + return if (server.version?.isAtLeast(2023, 5, 0) == true) { + var job: Job? = null + val response = suspendCancellableCoroutine { cont -> + job = ioScope.launch { + webSocketRepository.runAssistPipeline(speech)?.collect { + if (!cont.isActive) return@collect + when (it.type) { + AssistPipelineEventType.INTENT_END -> + cont.resume((it.data as AssistPipelineIntentEnd).intentOutput.response.speech.plain["speech"]) + AssistPipelineEventType.ERROR, + AssistPipelineEventType.RUN_END -> cont.resume(null) + else -> { /* Do nothing */ } + } + } ?: cont.resume(null) + } + } + job?.cancel() + response + } else { + val response = webSocketRepository.getConversation(speech) + response?.response?.speech?.plain?.get("speech") + } } override suspend fun getEntities(): List>? { diff --git a/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/WebSocketRepository.kt b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/WebSocketRepository.kt index eb9db317a..9f019a172 100644 --- a/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/WebSocketRepository.kt +++ b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/WebSocketRepository.kt @@ -5,6 +5,7 @@ import io.homeassistant.companion.android.common.data.integration.impl.entities. import io.homeassistant.companion.android.common.data.websocket.impl.WebSocketRepositoryImpl import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryUpdatedEvent +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEvent import io.homeassistant.companion.android.common.data.websocket.impl.entities.CompressedStateChangedEvent import io.homeassistant.companion.android.common.data.websocket.impl.entities.ConversationResponse import io.homeassistant.companion.android.common.data.websocket.impl.entities.CurrentUserResponse @@ -48,7 +49,18 @@ interface WebSocketRepository { suspend fun getThreadDatasets(): List? suspend fun getThreadDatasetTlv(datasetId: String): ThreadDatasetTlvResponse? suspend fun addThreadDataset(tlv: ByteArray): Boolean + + /** + * Get an Assist response for the given text input. For core >= 2023.5, use [runAssistPipeline] + * instead. + */ suspend fun getConversation(speech: String): ConversationResponse? + + /** + * Run the Assist pipeline for the given text input + * @return a Flow that will emit all events for the pipeline + */ + suspend fun runAssistPipeline(text: String): Flow? } @AssistedFactory diff --git a/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/WebSocketRepositoryImpl.kt b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/WebSocketRepositoryImpl.kt index 7d52fda50..4ad24656a 100644 --- a/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/WebSocketRepositoryImpl.kt +++ b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/WebSocketRepositoryImpl.kt @@ -23,6 +23,11 @@ import io.homeassistant.companion.android.common.data.websocket.WebSocketRequest import io.homeassistant.companion.android.common.data.websocket.WebSocketState import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryUpdatedEvent +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEvent +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEventType +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentEnd +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentStart +import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineRunStart import io.homeassistant.companion.android.common.data.websocket.impl.entities.CompressedStateChangedEvent import io.homeassistant.companion.android.common.data.websocket.impl.entities.ConversationResponse import io.homeassistant.companion.android.common.data.websocket.impl.entities.CurrentUserResponse @@ -81,6 +86,7 @@ class WebSocketRepositoryImpl @AssistedInject constructor( companion object { private const val TAG = "WebSocketRepository" + private const val SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN = "assist_pipeline/run" private const val SUBSCRIBE_TYPE_SUBSCRIBE_EVENTS = "subscribe_events" private const val SUBSCRIBE_TYPE_SUBSCRIBE_ENTITIES = "subscribe_entities" private const val SUBSCRIBE_TYPE_SUBSCRIBE_TRIGGER = "subscribe_trigger" @@ -209,6 +215,18 @@ class WebSocketRepositoryImpl @AssistedInject constructor( return mapResponse(socketResponse) } + override suspend fun runAssistPipeline(text: String): Flow? = + subscribeTo( + SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN, + mapOf( + "start_stage" to "intent", + "end_stage" to "intent", + "input" to mapOf( + "text" to text + ) + ) + ) + override suspend fun getStateChanges(): Flow? = subscribeToEventsForType(EVENT_STATE_CHANGED) @@ -629,6 +647,21 @@ class WebSocketRepositoryImpl @AssistedInject constructor( Log.w(TAG, "Received no trigger value for trigger subscription, skipping") return } + } else if (subscriptionType == SUBSCRIBE_TYPE_ASSIST_PIPELINE_RUN) { + val eventType = response.event?.get("type") + if (eventType?.isTextual == true) { + val eventDataMap = response.event.get("data") + val eventData = when (eventType.textValue()) { + AssistPipelineEventType.RUN_START -> mapper.convertValue(eventDataMap, AssistPipelineRunStart::class.java) + AssistPipelineEventType.INTENT_START -> mapper.convertValue(eventDataMap, AssistPipelineIntentStart::class.java) + AssistPipelineEventType.INTENT_END -> mapper.convertValue(eventDataMap, AssistPipelineIntentEnd::class.java) + else -> null + } + AssistPipelineEvent(eventType.textValue(), eventData) + } else { + Log.w(TAG, "Received Assist pipeline event without type, skipping") + return + } } else if (eventResponseType != null && eventResponseType.isTextual) { val eventResponseClass = when (eventResponseType.textValue()) { EVENT_STATE_CHANGED -> @@ -737,17 +770,18 @@ class WebSocketRepositoryImpl @AssistedInject constructor( listOf(mapper.readValue(text)) } - messages.forEach { message -> - Log.d(TAG, "Message number ${message.id} received") - + messages.groupBy { it.id }.values.forEach { messagesForId -> ioScope.launch { - when (message.type) { - "auth_required" -> Log.d(TAG, "Auth Requested") - "auth_ok" -> handleAuthComplete(true, message.haVersion) - "auth_invalid" -> handleAuthComplete(false, message.haVersion) - "pong", "result" -> handleMessage(message) - "event" -> handleEvent(message) - else -> Log.d(TAG, "Unknown message type: ${message.type}") + messagesForId.forEach { message -> + Log.d(TAG, "Message number ${message.id} received") + when (message.type) { + "auth_required" -> Log.d(TAG, "Auth Requested") + "auth_ok" -> handleAuthComplete(true, message.haVersion) + "auth_invalid" -> handleAuthComplete(false, message.haVersion) + "pong", "result" -> handleMessage(message) + "event" -> handleEvent(message) + else -> Log.d(TAG, "Unknown message type: ${message.type}") + } } } } diff --git a/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/entities/AssistPipelineEvent.kt b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/entities/AssistPipelineEvent.kt new file mode 100644 index 000000000..dbf3c988a --- /dev/null +++ b/common/src/main/java/io/homeassistant/companion/android/common/data/websocket/impl/entities/AssistPipelineEvent.kt @@ -0,0 +1,41 @@ +package io.homeassistant.companion.android.common.data.websocket.impl.entities + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties + +data class AssistPipelineEvent( + val type: String, + val data: AssistPipelineEventData? +) + +object AssistPipelineEventType { + const val RUN_START = "run-start" + const val RUN_END = "run-end" + const val STT_START = "stt-start" + const val STT_END = "stt-end" + const val INTENT_START = "intent-start" + const val INTENT_END = "intent-end" + const val TTS_START = "tts-start" + const val TTS_END = "tts-end" + const val ERROR = "error" +} + +interface AssistPipelineEventData + +@JsonIgnoreProperties(ignoreUnknown = true) +data class AssistPipelineRunStart( + val pipeline: String, + val language: String, + val runnerData: Map +) : AssistPipelineEventData + +@JsonIgnoreProperties(ignoreUnknown = true) +data class AssistPipelineIntentStart( + val engine: String, + val language: String, + val intentInput: String +) : AssistPipelineEventData + +@JsonIgnoreProperties(ignoreUnknown = true) +data class AssistPipelineIntentEnd( + val intentOutput: ConversationResponse +) : AssistPipelineEventData diff --git a/common/src/main/res/values/strings.xml b/common/src/main/res/values/strings.xml index 685d44ea9..fe77911c8 100644 --- a/common/src/main/res/values/strings.xml +++ b/common/src/main/res/values/strings.xml @@ -1052,7 +1052,9 @@ Vibrate when clicked Requires unlocked device No results yet - You must be at least on Home Assistant 2023.1 and have the conversation integration enabled + You must be at least on Home Assistant %1$s and have the %2$s integration enabled + conversation + Assist pipeline Conversation Assist Log in to Home Assistant to start using Assist diff --git a/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationActivity.kt b/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationActivity.kt index 656fc7ed5..3de9f6ac7 100755 --- a/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationActivity.kt +++ b/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationActivity.kt @@ -42,8 +42,8 @@ class ConversationActivity : ComponentActivity() { super.onCreate(savedInstanceState) lifecycleScope.launch { - conversationViewModel.isSupportConversation() - if (conversationViewModel.supportsConversation) { + conversationViewModel.checkAssistSupport() + if (conversationViewModel.supportsAssist) { val searchIntent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply { putExtra( RecognizerIntent.EXTRA_LANGUAGE_MODEL, diff --git a/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationViewModel.kt b/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationViewModel.kt index 32d7b21fd..d5afc65e4 100755 --- a/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationViewModel.kt +++ b/wear/src/main/java/io/homeassistant/companion/android/conversation/ConversationViewModel.kt @@ -25,7 +25,10 @@ class ConversationViewModel @Inject constructor( var conversationResult by mutableStateOf("") private set - var supportsConversation by mutableStateOf(false) + var supportsAssist by mutableStateOf(false) + private set + + var useAssistPipeline by mutableStateOf(false) private set var isHapticEnabled = mutableStateOf(false) @@ -41,20 +44,28 @@ class ConversationViewModel @Inject constructor( viewModelScope.launch { conversationResult = if (serverManager.isRegistered()) { - serverManager.integrationRepository().getConversation(speechResult) ?: "" + serverManager.integrationRepository().getAssistResponse(speechResult) ?: "" } else { "" } } } - suspend fun isSupportConversation() { + suspend fun checkAssistSupport() { checkSupportProgress = true isRegistered = serverManager.isRegistered() - supportsConversation = - serverManager.isRegistered() && - serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 1, 0) && - serverManager.webSocketRepository().getConfig()?.components?.contains("conversation") == true + + if (serverManager.isRegistered()) { + val config = serverManager.webSocketRepository().getConfig() + val onConversationVersion = serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 1, 0) + val onPipelineVersion = serverManager.integrationRepository().isHomeAssistantVersionAtLeast(2023, 5, 0) + + supportsAssist = + (onConversationVersion && !onPipelineVersion && config?.components?.contains("conversation") == true) || + (onPipelineVersion && config?.components?.contains("assist_pipeline") == true) + useAssistPipeline = onPipelineVersion + } + isHapticEnabled.value = wearPrefsRepository.getWearHapticFeedback() checkSupportProgress = false } diff --git a/wear/src/main/java/io/homeassistant/companion/android/conversation/views/ConversationView.kt b/wear/src/main/java/io/homeassistant/companion/android/conversation/views/ConversationView.kt index ab237f5ea..546e0807b 100755 --- a/wear/src/main/java/io/homeassistant/companion/android/conversation/views/ConversationView.kt +++ b/wear/src/main/java/io/homeassistant/companion/android/conversation/views/ConversationView.kt @@ -56,8 +56,13 @@ fun ConversationResultView( SpeechBubble( text = conversationViewModel.speechResult.ifEmpty { when { - (conversationViewModel.supportsConversation) -> stringResource(R.string.no_results) - (!conversationViewModel.supportsConversation && !conversationViewModel.checkSupportProgress) -> stringResource(R.string.no_conversation_support) + conversationViewModel.supportsAssist -> stringResource(R.string.no_results) + (!conversationViewModel.supportsAssist && !conversationViewModel.checkSupportProgress) -> + if (conversationViewModel.useAssistPipeline) { + stringResource(R.string.no_assist_support, "2023.5", stringResource(R.string.no_assist_support_assist_pipeline)) + } else { + stringResource(R.string.no_assist_support, "2023.1", stringResource(R.string.no_assist_support_conversation)) + } (!conversationViewModel.isRegistered) -> stringResource(R.string.not_registered) else -> "..." }