Work on error handling from Websocket. (#2091)

* Work on error handling from Websocket.

* Fix wear build.
This commit is contained in:
Justin Bassett 2021-12-30 15:55:35 -05:00 committed by GitHub
parent 5711e0e1b2
commit 476bd8984a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 305 additions and 346 deletions

View file

@ -86,7 +86,7 @@ class SettingsWearViewModel @Inject constructor(
).show()
}
viewModelScope.launch {
integrationUseCase.getEntities().forEach {
integrationUseCase.getEntities()?.forEach {
entities[it.entityId] = it
}
}

View file

@ -66,7 +66,7 @@ class HaControlsProviderService : ControlsProviderService() {
integrationRepository
.getEntities()
.mapNotNull {
?.mapNotNull {
val domain = it.entityId.split(".")[0]
domainToHaControl[domain]?.createControl(
applicationContext,
@ -74,7 +74,7 @@ class HaControlsProviderService : ControlsProviderService() {
getAreaForEntity(it.entityId, areaRegistry, deviceRegistry, entityRegistry)
)
}
.forEach {
?.forEach {
subscriber.onNext(it)
}
} catch (e: Exception) {
@ -106,13 +106,17 @@ class HaControlsProviderService : ControlsProviderService() {
val entities = mutableMapOf<String, Entity<Map<String, Any>>>()
controlIds.forEach {
val entity = integrationRepository.getEntity(it)
entities[it] = entity
if (entity != null) {
entities[it] = entity
} else {
Log.e(TAG, "Unable to get $it from Home Assistant.")
}
}
sendEntitiesToSubscriber(subscriber, entities, areaRegistry, deviceRegistry, entityRegistry)
// Listen for the state changed events.
webSocketScope.launch {
entityFlow.collect {
entityFlow?.collect {
if (controlIds.contains(it.entityId)) {
val domain = it.entityId.split(".")[0]
val control = domainToHaControl[domain]?.createControl(
@ -125,19 +129,19 @@ class HaControlsProviderService : ControlsProviderService() {
}
}
webSocketScope.launch {
areaRegistryFlow.collect {
areaRegistryFlow?.collect {
areaRegistry = webSocketRepository.getAreaRegistry()
sendEntitiesToSubscriber(subscriber, entities, areaRegistry, deviceRegistry, entityRegistry)
}
}
webSocketScope.launch {
deviceRegistryFlow.collect {
deviceRegistryFlow?.collect {
deviceRegistry = webSocketRepository.getDeviceRegistry()
sendEntitiesToSubscriber(subscriber, entities, areaRegistry, deviceRegistry, entityRegistry)
}
}
webSocketScope.launch {
entityRegistryFlow.collect { event ->
entityRegistryFlow?.collect { event ->
if (event.action == "update" && controlIds.contains(event.entityId)) {
entityRegistry = webSocketRepository.getEntityRegistry()
sendEntitiesToSubscriber(subscriber, entities, areaRegistry, deviceRegistry, entityRegistry)
@ -182,9 +186,9 @@ class HaControlsProviderService : ControlsProviderService() {
private fun sendEntitiesToSubscriber(
subscriber: Flow.Subscriber<in Control>,
entities: Map<String, Entity<Map<String, Any>>>,
areaRegistry: List<AreaRegistryResponse>,
deviceRegistry: List<DeviceRegistryResponse>,
entityRegistry: List<EntityRegistryResponse>
areaRegistry: List<AreaRegistryResponse>?,
deviceRegistry: List<DeviceRegistryResponse>?,
entityRegistry: List<EntityRegistryResponse>?
) {
entities.forEach {
val domain = it.key.split(".")[0]
@ -199,20 +203,20 @@ class HaControlsProviderService : ControlsProviderService() {
private fun getAreaForEntity(
entityId: String,
areaRegistry: List<AreaRegistryResponse>,
deviceRegistry: List<DeviceRegistryResponse>,
entityRegistry: List<EntityRegistryResponse>
areaRegistry: List<AreaRegistryResponse>?,
deviceRegistry: List<DeviceRegistryResponse>?,
entityRegistry: List<EntityRegistryResponse>?
): AreaRegistryResponse? {
val rEntity = entityRegistry.firstOrNull { it.entityId == entityId }
val rEntity = entityRegistry?.firstOrNull { it.entityId == entityId }
if (rEntity != null) {
// By default, an entity should be considered to be in the same area as the associated device (if any)
// This can be overridden for an individual entity, so check the entity registry first
if (rEntity.areaId != null) {
return areaRegistry.firstOrNull { it.areaId == rEntity.areaId }
return areaRegistry?.firstOrNull { it.areaId == rEntity.areaId }
} else if (rEntity.deviceId != null) {
val rDevice = deviceRegistry.firstOrNull { it.id == rEntity.deviceId }
val rDevice = deviceRegistry?.firstOrNull { it.id == rEntity.deviceId }
if (rDevice != null) {
return areaRegistry.firstOrNull { it.areaId == rDevice.areaId }
return areaRegistry?.firstOrNull { it.areaId == rDevice.areaId }
}
}
}

View file

@ -75,7 +75,7 @@ abstract class TileExtensions : TileService() {
}
if (tileData.entityId.split('.')[0] in toggleDomains) {
val state = runBlocking { integrationUseCase.getEntity(tileData.entityId) }
tile.state = if (state.state == "on" || state.state == "open") Tile.STATE_ACTIVE else Tile.STATE_INACTIVE
tile.state = if (state?.state == "on" || state?.state == "open") Tile.STATE_ACTIVE else Tile.STATE_INACTIVE
} else
tile.state = Tile.STATE_INACTIVE
val iconId = tileData.iconId

View file

@ -97,7 +97,7 @@ class ManageTilesFragment constructor(
runBlocking {
try {
integrationRepository.getEntities().forEach {
integrationRepository.getEntities()?.forEach {
val split = it.entityId.split(".")
if (split[0] in validDomains)
entityList = entityList + it.entityId

View file

@ -98,7 +98,7 @@ class ManageShortcutsSettingsFragment : PreferenceFragmentCompat(), IconDialog.C
runBlocking {
try {
integrationUseCase.getEntities().forEach {
integrationUseCase.getEntities()?.forEach {
entityList = entityList + it.entityId
}
} catch (e: Exception) {

View file

@ -270,7 +270,7 @@ class ButtonWidgetConfigureActivity : BaseActivity(), IconDialog.Callback {
mainScope.launch {
try {
// Fetch services
integrationUseCase.getServices().forEach {
integrationUseCase.getServices()?.forEach {
services[getServiceString(it)] = it
}
Log.d(TAG, "Services found: $services")
@ -305,7 +305,7 @@ class ButtonWidgetConfigureActivity : BaseActivity(), IconDialog.Callback {
dynamicFields.add(ServiceFieldBinder(serviceText, fieldKey))
}
}
integrationUseCase.getEntities().forEach {
integrationUseCase.getEntities()?.forEach {
entities[it.entityId] = it
}
dynamicFieldAdapter.notifyDataSetChanged()
@ -326,7 +326,7 @@ class ButtonWidgetConfigureActivity : BaseActivity(), IconDialog.Callback {
try {
// Fetch entities
integrationUseCase.getEntities().forEach {
integrationUseCase.getEntities()?.forEach {
entities[it.entityId] = it
}
} catch (e: Exception) {

View file

@ -17,7 +17,6 @@ import com.squareup.picasso.Picasso
import dagger.hilt.android.AndroidEntryPoint
import io.homeassistant.companion.android.BuildConfig
import io.homeassistant.companion.android.R
import io.homeassistant.companion.android.common.data.integration.Entity
import io.homeassistant.companion.android.common.data.integration.IntegrationRepository
import io.homeassistant.companion.android.common.data.url.UrlRepository
import io.homeassistant.companion.android.database.AppDatabase
@ -173,11 +172,9 @@ class CameraWidget : AppWidgetProvider() {
}
private suspend fun retrieveCameraImageUrl(context: Context, entityId: String): String? {
val entity: Entity<Map<String, Any>>
try {
entity = integrationUseCase.getEntity(entityId)
} catch (e: Exception) {
Log.e(TAG, "Failed to fetch entity or entity does not exist", e)
val entity = integrationUseCase.getEntity(entityId)
if (entity == null) {
Log.e(TAG, "Failed to fetch entity or entity does not exist")
if (lastIntent == UPDATE_IMAGE)
Toast.makeText(context, commonR.string.widget_entity_fetch_error, Toast.LENGTH_LONG).show()
return null

View file

@ -76,7 +76,7 @@ class CameraWidgetConfigureActivity : BaseActivity() {
try {
// Fetch entities
val fetchedEntities = integrationUseCase.getEntities()
fetchedEntities.forEach {
fetchedEntities?.forEach {
val entityId = it.entityId
val domain = entityId.split(".")[0]

View file

@ -137,7 +137,7 @@ class EntityWidgetConfigureActivity : BaseActivity() {
try {
// Fetch entities
val fetchedEntities = integrationUseCase.getEntities()
fetchedEntities.forEach {
fetchedEntities?.forEach {
entities[it.entityId] = it
}
entityAdapter.addAll(entities.values)

View file

@ -343,7 +343,7 @@ class MediaPlayerControlsWidget : BaseWidgetProvider() {
}
private suspend fun getEntity(context: Context, entityId: String, suggestedEntity: Entity<Map<String, Any>>?): Entity<Map<String, Any>>? {
val entity: Entity<Map<String, Any>>
val entity: Entity<Map<String, Any>>?
try {
entity = if (suggestedEntity != null && suggestedEntity.entityId == entityId) {
suggestedEntity
@ -477,10 +477,8 @@ class MediaPlayerControlsWidget : BaseWidgetProvider() {
"entity id: " + entity.entityId + System.lineSeparator()
)
val currentEntityInfo: Entity<Map<String, Any>>
try {
currentEntityInfo = integrationUseCase.getEntity(entity.entityId)
} catch (e: Exception) {
val currentEntityInfo = integrationUseCase.getEntity(entity.entityId)
if (currentEntityInfo == null) {
Log.d(TAG, "Failed to fetch entity or entity does not exist")
if (lastIntent != Intent.ACTION_SCREEN_ON)
Toast.makeText(context, commonR.string.widget_entity_fetch_error, Toast.LENGTH_LONG).show()
@ -548,10 +546,8 @@ class MediaPlayerControlsWidget : BaseWidgetProvider() {
"entity id: " + entity.entityId + System.lineSeparator()
)
val currentEntityInfo: Entity<Map<String, Any>>
try {
currentEntityInfo = integrationUseCase.getEntity(entity.entityId)
} catch (e: Exception) {
val currentEntityInfo = integrationUseCase.getEntity(entity.entityId)
if (currentEntityInfo == null) {
Log.d(TAG, "Failed to fetch entity or entity does not exist")
if (lastIntent != Intent.ACTION_SCREEN_ON)
Toast.makeText(context, commonR.string.widget_entity_fetch_error, Toast.LENGTH_LONG).show()

View file

@ -105,7 +105,7 @@ class MediaPlayerControlsWidgetConfigureActivity : BaseActivity() {
try {
// Fetch entities
val fetchedEntities = integrationUseCase.getEntities()
fetchedEntities.forEach {
fetchedEntities?.forEach {
val entityId = it.entityId
val domain = entityId.split(".")[0]

View file

@ -8,7 +8,6 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategy
import com.fasterxml.jackson.module.kotlin.registerKotlinModule
import io.homeassistant.companion.android.common.BuildConfig
import io.homeassistant.companion.android.common.data.url.UrlRepository
import kotlinx.coroutines.runBlocking
import okhttp3.OkHttpClient
import okhttp3.logging.HttpLoggingInterceptor
import retrofit2.Retrofit
@ -34,24 +33,13 @@ class HomeAssistantApis @Inject constructor(private val urlRepository: UrlReposi
}
)
}
builder.addInterceptor {
return@addInterceptor if (it.request().url.toString().contains(LOCAL_HOST)) {
val newRequest = runBlocking {
it.request().newBuilder()
.url(
it.request().url.toString()
.replace(LOCAL_HOST, urlRepository.getUrl().toString())
)
.header(
USER_AGENT,
"$USER_AGENT_STRING ${Build.MODEL} ${BuildConfig.VERSION_NAME}"
)
.build()
}
it.proceed(newRequest)
} else {
it.proceed(it.request())
}
builder.addNetworkInterceptor {
it.proceed(
it.request()
.newBuilder()
.header(USER_AGENT, "$USER_AGENT_STRING ${Build.MODEL} ${BuildConfig.VERSION_NAME}")
.build()
)
}
// Only deal with cookies when on non wear device and for now I don't have a better
// way to determine if we are really on wear os....

View file

@ -1,5 +1,6 @@
package io.homeassistant.companion.android.common.data.authentication.impl
import android.util.Log
import com.fasterxml.jackson.databind.ObjectMapper
import io.homeassistant.companion.android.common.data.LocalStorage
import io.homeassistant.companion.android.common.data.authentication.AuthenticationRepository
@ -22,6 +23,7 @@ class AuthenticationRepositoryImpl @Inject constructor(
) : AuthenticationRepository {
companion object {
private const val TAG = "AuthRepo"
private const val PREF_ACCESS_TOKEN = "access_token"
private const val PREF_EXPIRED_DATE = "expires_date"
private const val PREF_REFRESH_TOKEN = "refresh_token"
@ -51,7 +53,13 @@ class AuthenticationRepositoryImpl @Inject constructor(
}
override suspend fun registerAuthorizationCode(authorizationCode: String) {
val url = urlRepository.getUrl()?.toHttpUrlOrNull()
if (url == null) {
Log.e(TAG, "Unable to register auth code.")
return
}
authenticationService.getToken(
url.newBuilder().addPathSegments("auth/token").build(),
AuthenticationService.GRANT_TYPE_CODE,
authorizationCode,
AuthenticationService.CLIENT_ID
@ -76,8 +84,17 @@ class AuthenticationRepositoryImpl @Inject constructor(
}
override suspend fun revokeSession() {
val session = retrieveSession() ?: throw AuthorizationException()
authenticationService.revokeToken(session.refreshToken, AuthenticationService.REVOKE_ACTION)
val session = retrieveSession()
val url = urlRepository.getUrl()?.toHttpUrlOrNull()
if (session == null || url == null) {
Log.e(TAG, "Unable to revoke session.")
return
}
authenticationService.revokeToken(
url.newBuilder().addPathSegments("auth/token").build(),
session.refreshToken,
AuthenticationService.REVOKE_ACTION
)
saveSession(null)
urlRepository.saveUrl("", true)
urlRepository.saveUrl("", false)
@ -132,10 +149,16 @@ class AuthenticationRepositoryImpl @Inject constructor(
}
private suspend fun ensureValidSession(forceRefresh: Boolean = false): Session {
val session = retrieveSession() ?: throw AuthorizationException()
val session = retrieveSession()
val url = urlRepository.getUrl()?.toHttpUrlOrNull()
if (session == null || url == null) {
Log.e(TAG, "Unable to revoke session.")
throw AuthorizationException()
}
if (session.isExpired() || forceRefresh) {
return authenticationService.refreshToken(
url.newBuilder().addPathSegments("auth/token").build(),
AuthenticationService.GRANT_TYPE_REFRESH,
session.refreshToken,
AuthenticationService.CLIENT_ID

View file

@ -5,6 +5,7 @@ import io.homeassistant.companion.android.common.data.authentication.impl.entiti
import io.homeassistant.companion.android.common.data.authentication.impl.entities.LoginFlowInit
import io.homeassistant.companion.android.common.data.authentication.impl.entities.LoginFlowRequest
import io.homeassistant.companion.android.common.data.authentication.impl.entities.Token
import okhttp3.HttpUrl
import retrofit2.Response
import retrofit2.http.Body
import retrofit2.http.Field
@ -25,24 +26,27 @@ interface AuthenticationService {
}
@FormUrlEncoded
@POST("auth/token")
@POST
suspend fun getToken(
@Url url: HttpUrl,
@Field("grant_type") grandType: String,
@Field("code") code: String,
@Field("client_id") clientId: String
): Token
@FormUrlEncoded
@POST("auth/token")
@POST
suspend fun refreshToken(
@Url url: HttpUrl,
@Field("grant_type") grandType: String,
@Field("refresh_token") refreshToken: String,
@Field("client_id") clientId: String
): Response<Token>
@FormUrlEncoded
@POST("auth/token")
@POST
suspend fun revokeToken(
@Url url: HttpUrl,
@Field("token") refreshToken: String,
@Field("action") action: String
)

View file

@ -42,11 +42,11 @@ interface IntegrationRepository {
suspend fun getHomeAssistantVersion(): String
suspend fun getServices(): List<Service>
suspend fun getServices(): List<Service>?
suspend fun getEntities(): List<Entity<Any>>
suspend fun getEntity(entityId: String): Entity<Map<String, Any>>
suspend fun getEntityUpdates(): Flow<Entity<*>>
suspend fun getEntities(): List<Entity<Any>>?
suspend fun getEntity(entityId: String): Entity<Map<String, Any>>?
suspend fun getEntityUpdates(): Flow<Entity<*>>?
suspend fun callService(domain: String, service: String, serviceData: HashMap<String, Any>)

View file

@ -25,7 +25,6 @@ import io.homeassistant.companion.android.common.data.integration.impl.entities.
import io.homeassistant.companion.android.common.data.url.UrlRepository
import io.homeassistant.companion.android.common.data.websocket.WebSocketRepository
import io.homeassistant.companion.android.common.data.websocket.impl.entities.GetConfigResponse
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.map
@ -86,9 +85,15 @@ class IntegrationRepositoryImpl @Inject constructor(
request.supportsEncryption = false
request.deviceId = deviceId
val url = urlRepository.getUrl()?.toHttpUrlOrNull()
if (url == null) {
Log.e(TAG, "Unable to register device due to missing URL")
return
}
try {
val response =
integrationService.registerDevice(
url.newBuilder().addPathSegments("api/mobile_app/registrations").build(),
authenticationRepository.buildBearerToken(),
request
)
@ -427,38 +432,43 @@ class IntegrationRepositoryImpl @Inject constructor(
}
}
override suspend fun getServices(): List<Service> {
override suspend fun getServices(): List<Service>? {
val response = webSocketRepository.getServices()
return response.flatMap {
return response?.flatMap {
it.services.map { service ->
Service(it.domain, service.key, service.value)
}
}.toList()
}?.toList()
}
override suspend fun getEntities(): List<Entity<Any>> {
override suspend fun getEntities(): List<Entity<Any>>? {
val response = webSocketRepository.getStates()
return response
.map {
Entity(
it.entityId,
it.state,
it.attributes,
it.lastChanged,
it.lastUpdated,
it.context
)
}
.sortedBy { it.entityId }
.toList()
return response?.map {
Entity(
it.entityId,
it.state,
it.attributes,
it.lastChanged,
it.lastUpdated,
it.context
)
}
?.sortedBy { it.entityId }
?.toList()
}
override suspend fun getEntity(entityId: String): Entity<Map<String, Any>> {
override suspend fun getEntity(entityId: String): Entity<Map<String, Any>>? {
val url = urlRepository.getUrl()?.toHttpUrlOrNull()
if (url == null) {
Log.e(TAG, "Unable to register device due to missing URL")
return null
}
val response = integrationService.getState(
authenticationRepository.buildBearerToken(),
entityId
url.newBuilder().addPathSegments("api/states/$entityId").build(),
authenticationRepository.buildBearerToken()
)
return Entity(
response.entityId,
@ -470,11 +480,10 @@ class IntegrationRepositoryImpl @Inject constructor(
)
}
@ExperimentalCoroutinesApi
override suspend fun getEntityUpdates(): Flow<Entity<*>> {
override suspend fun getEntityUpdates(): Flow<Entity<*>>? {
return webSocketRepository.getStateChanges()
.filter { it.newState != null }
.map {
?.filter { it.newState != null }
?.map {
Entity(
it.newState!!.entityId,
it.newState.state,

View file

@ -14,21 +14,21 @@ import retrofit2.http.Body
import retrofit2.http.GET
import retrofit2.http.Header
import retrofit2.http.POST
import retrofit2.http.Path
import retrofit2.http.Url
interface IntegrationService {
@POST("/api/mobile_app/registrations")
@POST
suspend fun registerDevice(
@Url url: HttpUrl,
@Header("Authorization") auth: String,
@Body request: RegisterDeviceRequest
): RegisterDeviceResponse
@GET("/api/states/{entityId}")
@GET
suspend fun getState(
@Header("Authorization") auth: String,
@Path("entityId") entityId: String
@Url url: HttpUrl,
@Header("Authorization") auth: String
): EntityResponse<Map<String, Any>>
@POST

View file

@ -1,7 +1,6 @@
package io.homeassistant.companion.android.common.data.websocket
import io.homeassistant.companion.android.common.data.integration.impl.entities.EntityResponse
import io.homeassistant.companion.android.common.data.integration.impl.entities.ServiceCallRequest
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryUpdatedEvent
import io.homeassistant.companion.android.common.data.websocket.impl.entities.DeviceRegistryResponse
@ -14,22 +13,17 @@ import io.homeassistant.companion.android.common.data.websocket.impl.entities.St
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.flow.Flow
@ExperimentalCoroutinesApi
interface WebSocketRepository {
suspend fun sendPing(): Boolean
suspend fun getConfig(): GetConfigResponse?
suspend fun getStates(): List<EntityResponse<Any>>
suspend fun getAreaRegistry(): List<AreaRegistryResponse>
suspend fun getDeviceRegistry(): List<DeviceRegistryResponse>
suspend fun getEntityRegistry(): List<EntityRegistryResponse>
suspend fun getServices(): List<DomainResponse>
suspend fun getPanels(): List<String>
suspend fun callService(request: ServiceCallRequest)
@ExperimentalCoroutinesApi
suspend fun getStateChanges(): Flow<StateChangedEvent>
@ExperimentalCoroutinesApi
suspend fun getAreaRegistryUpdates(): Flow<AreaRegistryUpdatedEvent>
@ExperimentalCoroutinesApi
suspend fun getDeviceRegistryUpdates(): Flow<DeviceRegistryUpdatedEvent>
@ExperimentalCoroutinesApi
suspend fun getEntityRegistryUpdates(): Flow<EntityRegistryUpdatedEvent>
suspend fun getStates(): List<EntityResponse<Any>>?
suspend fun getAreaRegistry(): List<AreaRegistryResponse>?
suspend fun getDeviceRegistry(): List<DeviceRegistryResponse>?
suspend fun getEntityRegistry(): List<EntityRegistryResponse>?
suspend fun getServices(): List<DomainResponse>?
suspend fun getStateChanges(): Flow<StateChangedEvent>?
suspend fun getAreaRegistryUpdates(): Flow<AreaRegistryUpdatedEvent>?
suspend fun getDeviceRegistryUpdates(): Flow<DeviceRegistryUpdatedEvent>?
suspend fun getEntityRegistryUpdates(): Flow<EntityRegistryUpdatedEvent>?
}

View file

@ -4,12 +4,12 @@ import android.util.Log
import com.fasterxml.jackson.core.type.TypeReference
import com.fasterxml.jackson.databind.DeserializationFeature
import com.fasterxml.jackson.databind.PropertyNamingStrategies
import com.fasterxml.jackson.module.kotlin.convertValue
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import io.homeassistant.companion.android.common.data.authentication.AuthenticationRepository
import io.homeassistant.companion.android.common.data.integration.ServiceData
import io.homeassistant.companion.android.common.data.integration.impl.entities.EntityResponse
import io.homeassistant.companion.android.common.data.integration.impl.entities.ServiceCallRequest
import io.homeassistant.companion.android.common.data.url.UrlRepository
import io.homeassistant.companion.android.common.data.websocket.WebSocketRepository
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AreaRegistryResponse
@ -34,14 +34,13 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.shareIn
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.Response
@ -51,6 +50,7 @@ import okio.ByteString
import java.util.concurrent.atomic.AtomicLong
import javax.inject.Inject
@ExperimentalCoroutinesApi
class WebSocketRepositoryImpl @Inject constructor(
private val okHttpClient: OkHttpClient,
private val urlRepository: UrlRepository,
@ -77,7 +77,6 @@ class WebSocketRepositoryImpl @Inject constructor(
private val eventSubscriptionMutex = Mutex()
private val eventSubscriptionFlow = mutableMapOf<String, SharedFlow<*>>()
@ExperimentalCoroutinesApi
private var eventSubscriptionProducerScope = mutableMapOf<String, ProducerScope<Any>>()
override suspend fun sendPing(): Boolean {
@ -87,187 +86,135 @@ class WebSocketRepositoryImpl @Inject constructor(
)
)
return socketResponse.type == "pong"
return socketResponse?.type == "pong"
}
override suspend fun getConfig(): GetConfigResponse? {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "get_config"
)
val socketResponse = sendMessage(
mapOf(
"type" to "get_config"
)
)
mapper.convertValue(socketResponse.result!!, GetConfigResponse::class.java)
} catch (e: Exception) {
Log.e(TAG, "Unable to get config response", e)
null
return mapResponse(socketResponse)
}
override suspend fun getStates(): List<EntityResponse<Any>>? {
val socketResponse = sendMessage(
mapOf(
"type" to "get_states"
)
)
return mapResponse(socketResponse)
}
override suspend fun getAreaRegistry(): List<AreaRegistryResponse>? {
val socketResponse = sendMessage(
mapOf(
"type" to "config/area_registry/list"
)
)
return mapResponse(socketResponse)
}
override suspend fun getDeviceRegistry(): List<DeviceRegistryResponse>? {
val socketResponse = sendMessage(
mapOf(
"type" to "config/device_registry/list"
)
)
return mapResponse(socketResponse)
}
override suspend fun getEntityRegistry(): List<EntityRegistryResponse>? {
val socketResponse = sendMessage(
mapOf(
"type" to "config/entity_registry/list"
)
)
return mapResponse(socketResponse)
}
override suspend fun getServices(): List<DomainResponse>? {
val socketResponse = sendMessage(
mapOf(
"type" to "get_services"
)
)
val response: Map<String, Map<String, ServiceData>>? = mapResponse(socketResponse)
return response?.map {
DomainResponse(it.key, it.value)
}
}
override suspend fun getStates(): List<EntityResponse<Any>> {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "get_states"
)
)
override suspend fun getStateChanges(): Flow<StateChangedEvent>? =
subscribeToEventsForType(EVENT_STATE_CHANGED)
mapper.convertValue(
socketResponse.result!!,
object : TypeReference<List<EntityResponse<Any>>>() {}
)
} catch (e: Exception) {
Log.e(TAG, "Unable to get list of entities", e)
emptyList()
}
}
override suspend fun getAreaRegistryUpdates(): Flow<AreaRegistryUpdatedEvent>? =
subscribeToEventsForType(EVENT_AREA_REGISTRY_UPDATED)
override suspend fun getAreaRegistry(): List<AreaRegistryResponse> {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "config/area_registry/list"
)
)
override suspend fun getDeviceRegistryUpdates(): Flow<DeviceRegistryUpdatedEvent>? =
subscribeToEventsForType(EVENT_DEVICE_REGISTRY_UPDATED)
mapper.convertValue(
socketResponse.result!!,
object : TypeReference<List<AreaRegistryResponse>>() {}
)
} catch (e: Exception) {
Log.e(TAG, "Unable to get area registry list", e)
emptyList()
}
}
override suspend fun getEntityRegistryUpdates(): Flow<EntityRegistryUpdatedEvent>? =
subscribeToEventsForType(EVENT_ENTITY_REGISTRY_UPDATED)
override suspend fun getDeviceRegistry(): List<DeviceRegistryResponse> {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "config/device_registry/list"
)
)
private suspend fun <T : Any> subscribeToEventsForType(eventType: String): Flow<T>? {
eventSubscriptionMutex.withLock {
if (eventSubscriptionFlow[eventType] == null) {
mapper.convertValue(
socketResponse.result!!,
object : TypeReference<List<DeviceRegistryResponse>>() {}
)
} catch (e: Exception) {
Log.e(TAG, "Unable to get device registry list", e)
emptyList()
}
}
override suspend fun getEntityRegistry(): List<EntityRegistryResponse> {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "config/entity_registry/list"
)
)
mapper.convertValue(
socketResponse.result!!,
object : TypeReference<List<EntityRegistryResponse>>() {}
)
} catch (e: Exception) {
Log.e(TAG, "Unable to get entity registry list", e)
emptyList()
}
}
override suspend fun getServices(): List<DomainResponse> {
return try {
val socketResponse = sendMessage(
mapOf(
"type" to "get_services"
)
)
val response = mapper.convertValue(
socketResponse.result!!,
object : TypeReference<Map<String, Map<String, ServiceData>>>() {}
)
response.map {
DomainResponse(it.key, it.value)
}
} catch (e: Exception) {
Log.e(TAG, "Unable to get service data")
emptyList()
}
}
override suspend fun getPanels(): List<String> {
TODO("Not yet implemented")
}
override suspend fun callService(request: ServiceCallRequest) {
TODO("Not yet implemented")
}
@ExperimentalCoroutinesApi
override suspend fun getStateChanges(): Flow<StateChangedEvent> = subscribeToEventsForType(EVENT_STATE_CHANGED)
@ExperimentalCoroutinesApi
override suspend fun getAreaRegistryUpdates(): Flow<AreaRegistryUpdatedEvent> = subscribeToEventsForType(EVENT_AREA_REGISTRY_UPDATED)
@ExperimentalCoroutinesApi
override suspend fun getDeviceRegistryUpdates(): Flow<DeviceRegistryUpdatedEvent> = subscribeToEventsForType(EVENT_DEVICE_REGISTRY_UPDATED)
@ExperimentalCoroutinesApi
override suspend fun getEntityRegistryUpdates(): Flow<EntityRegistryUpdatedEvent> = subscribeToEventsForType(EVENT_ENTITY_REGISTRY_UPDATED)
@ExperimentalCoroutinesApi
private suspend fun <T : Any> subscribeToEventsForType(eventType: String): Flow<T> {
return try {
eventSubscriptionMutex.withLock {
if (eventSubscriptionFlow[eventType] == null) {
val response = sendMessage(
mapOf(
"type" to "subscribe_events",
"event_type" to eventType
)
val response = sendMessage(
mapOf(
"type" to "subscribe_events",
"event_type" to eventType
)
eventSubscriptionFlow[eventType] = callbackFlow<T> {
eventSubscriptionProducerScope[eventType] = this as ProducerScope<Any>
awaitClose {
Log.d(TAG, "Unsubscribing from $eventType")
ioScope.launch {
sendMessage(
mapOf(
"type" to "unsubscribe_events",
"subscription" to response.id
)
)
}
eventSubscriptionProducerScope.remove(eventType)
eventSubscriptionFlow.remove(eventType)
}
}.shareIn(ioScope, SharingStarted.WhileSubscribed())
)
if (response == null) {
Log.e(TAG, "Unable to register for events of type $eventType")
return null
}
}
eventSubscriptionFlow[eventType]!! as SharedFlow<T>
} catch (e: Exception) {
Log.e(TAG, "Unable to subscribe to $eventType", e)
emptyFlow()
eventSubscriptionFlow[eventType] = callbackFlow<T> {
eventSubscriptionProducerScope[eventType] = this as ProducerScope<Any>
awaitClose {
Log.d(TAG, "Unsubscribing from $eventType")
ioScope.launch {
sendMessage(
mapOf(
"type" to "unsubscribe_events",
"subscription" to response.id
)
)
}
eventSubscriptionProducerScope.remove(eventType)
eventSubscriptionFlow.remove(eventType)
}
}.shareIn(ioScope, SharingStarted.WhileSubscribed())
}
}
return eventSubscriptionFlow[eventType]!! as Flow<T>
}
/**
* This method will
*/
private suspend fun connect() {
private suspend fun connect(): Boolean {
connectedMutex.withLock {
if (connection != null && connected.isCompleted) {
return
return true
}
val url = urlRepository.getUrl()
if (url == null) {
Log.w(TAG, "No url to connect websocket too.")
return false
}
val url = urlRepository.getUrl() ?: throw Exception("Unable to get URL for WebSocket")
val urlString = url.toString()
.replace("https://", "wss://")
.replace("http://", "ws://")
@ -276,53 +223,52 @@ class WebSocketRepositoryImpl @Inject constructor(
connection = okHttpClient.newWebSocket(
Request.Builder().url(urlString).build(),
this
)
// Preemptively send auth
authenticate()
// Wait up to 30 seconds for auth response
withTimeout(30000) {
connected.join()
}
}
}
private suspend fun sendMessage(request: Map<*, *>): SocketResponse {
for (i in 0..1) {
val requestId = id.getAndIncrement()
val outbound = request.plus("id" to requestId)
Log.d(TAG, "Sending message number $requestId: $outbound")
connect()
try {
return withTimeout(30000) {
suspendCancellableCoroutine { cont ->
responseCallbackJobs[requestId] = cont
connection!!.send(mapper.writeValueAsString(outbound))
Log.d(TAG, "Message number $requestId sent")
}
}
} catch (e: Exception) {
Log.e(TAG, "Error sending request number $requestId", e)
}
}
throw Exception("Unable to send message: $request")
}
private suspend fun authenticate() {
if (connection != null) {
connection!!.send(
mapper.writeValueAsString(
mapOf(
"type" to "auth",
"access_token" to authenticationRepository.retrieveAccessToken()
).also {
// Preemptively send auth
it.send(
mapper.writeValueAsString(
mapOf(
"type" to "auth",
"access_token" to authenticationRepository.retrieveAccessToken()
)
)
)
)
} else
Log.e(TAG, "Attempted to authenticate when connection is null")
}
// Wait up to 30 seconds for auth response
return true == withTimeoutOrNull(30000) {
return@withTimeoutOrNull try {
connected.join()
true
} catch (e: Exception) {
Log.e(TAG, "Unable to authenticate", e)
false
}
}
}
}
private suspend fun sendMessage(request: Map<*, *>): SocketResponse? {
val requestId = id.getAndIncrement()
val outbound = request.plus("id" to requestId)
return if (connect()) {
Log.d(TAG, "Sending message $requestId: $outbound")
withTimeoutOrNull(30000) {
suspendCancellableCoroutine { cont ->
responseCallbackJobs[requestId] = cont
connection!!.send(mapper.writeValueAsString(outbound))
Log.d(TAG, "Message number $requestId sent")
}
}
} else {
Log.e(TAG, "Unable to send message $requestId: $outbound")
null
}
}
private inline fun <reified T> mapResponse(response: SocketResponse?): T? =
if (response?.result != null) mapper.convertValue(response.result) else null
private fun handleAuthComplete(successful: Boolean) {
if (successful)
connected.complete()
@ -336,15 +282,20 @@ class WebSocketRepositoryImpl @Inject constructor(
responseCallbackJobs.remove(id)
}
@ExperimentalCoroutinesApi
private suspend fun handleEvent(response: SocketResponse) {
val eventResponseType = response.event?.get("event_type")
if (eventResponseType != null && eventResponseType.isTextual) {
val eventResponseClass = when (eventResponseType.textValue()) {
EVENT_STATE_CHANGED -> object : TypeReference<EventResponse<StateChangedEvent>>() {}
EVENT_AREA_REGISTRY_UPDATED -> object : TypeReference<EventResponse<AreaRegistryUpdatedEvent>>() {}
EVENT_DEVICE_REGISTRY_UPDATED -> object : TypeReference<EventResponse<DeviceRegistryUpdatedEvent>>() {}
EVENT_ENTITY_REGISTRY_UPDATED -> object : TypeReference<EventResponse<EntityRegistryUpdatedEvent>>() {}
EVENT_AREA_REGISTRY_UPDATED ->
object :
TypeReference<EventResponse<AreaRegistryUpdatedEvent>>() {}
EVENT_DEVICE_REGISTRY_UPDATED ->
object :
TypeReference<EventResponse<DeviceRegistryUpdatedEvent>>() {}
EVENT_ENTITY_REGISTRY_UPDATED ->
object :
TypeReference<EventResponse<EntityRegistryUpdatedEvent>>() {}
else -> {
Log.d(TAG, "Unknown event type received")
object : TypeReference<EventResponse<Any>>() {}
@ -358,26 +309,24 @@ class WebSocketRepositoryImpl @Inject constructor(
}
}
@ExperimentalCoroutinesApi
private fun handleClosingSocket() {
connected = Job()
connection = null
// If we still have flows flowing
if (eventSubscriptionFlow.any() && ioScope.isActive) {
ioScope.launch {
try {
connect()
// Register for websocket events!
if (connect()) {
eventSubscriptionFlow.forEach { (eventType, _) ->
sendMessage(
val resp = sendMessage(
mapOf(
"type" to "subscribe_events",
"event_type" to eventType
)
)
if (resp == null) {
Log.e(TAG, "Issue re-registering event subscriptions")
}
}
} catch (e: Exception) {
Log.e(TAG, "Issue reconnecting websocket", e)
}
}
}

View file

@ -15,8 +15,8 @@ interface HomePresenter {
suspend fun isConnected(): Boolean
suspend fun getEntities(): List<Entity<*>>
suspend fun getEntityUpdates(): Flow<Entity<*>>
suspend fun getEntities(): List<Entity<*>>?
suspend fun getEntityUpdates(): Flow<Entity<*>>?
suspend fun getTileShortcuts(): List<SimplifiedEntity>
suspend fun setTileShortcuts(entities: List<SimplifiedEntity>)

View file

@ -53,16 +53,11 @@ class HomePresenterImpl @Inject constructor(
}
}
override suspend fun getEntities(): List<Entity<*>> {
return try {
integrationUseCase.getEntities()
} catch (e: Exception) {
Log.e(TAG, "Unable to get entities", e)
emptyList()
}
override suspend fun getEntities(): List<Entity<*>>? {
return integrationUseCase.getEntities()
}
override suspend fun getEntityUpdates(): Flow<Entity<*>> {
override suspend fun getEntityUpdates(): Flow<Entity<*>>? {
return integrationUseCase.getEntityUpdates()
}

View file

@ -70,11 +70,11 @@ class MainViewModel @Inject constructor(application: Application) : AndroidViewM
shortcutEntities.addAll(homePresenter.getTileShortcuts())
isHapticEnabled.value = homePresenter.getWearHapticFeedback()
isToastEnabled.value = homePresenter.getWearToastConfirmation()
homePresenter.getEntities().forEach {
homePresenter.getEntities()?.forEach {
entities[it.entityId] = it
}
updateEntityDomains()
homePresenter.getEntityUpdates().collect {
homePresenter.getEntityUpdates()?.collect {
entities[it.entityId] = it
updateEntityDomains()
}

View file

@ -41,7 +41,7 @@ class TileActionReceiver : BroadcastReceiver() {
val serviceName = when (domain) {
"lock" -> {
val lockEntity = integrationUseCase.getEntity(entityId)
if (lockEntity.state == "locked")
if (lockEntity?.state == "locked")
"unlock"
else
"lock"