Use mutex

This commit is contained in:
Hugh Nimmo-Smith 2022-10-18 08:48:28 +01:00
parent 8a62dfb34a
commit f297117df2

View file

@ -19,6 +19,8 @@ package org.matrix.android.sdk.api.rendezvous.channels
import android.util.Base64
import com.squareup.moshi.Json
import com.squareup.moshi.JsonClass
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import okhttp3.MediaType.Companion.toMediaType
import org.matrix.android.sdk.api.logger.LoggerTag
import org.matrix.android.sdk.api.rendezvous.RendezvousChannel
@ -71,6 +73,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Json val iv: String? = null
)
private var olmSASMutex = Mutex()
private var olmSAS: OlmSAS?
private val ourPublicKey: ByteArray
private val ecdhAdapter = MatrixJsonParser.getMoshi().adapter(ECDHPayload::class.java)
@ -87,45 +90,44 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
@Throws(RendezvousError::class)
override suspend fun connect(): String {
olmSAS ?.let { olmSAS ->
val isInitiator = theirPublicKey == null
val sas = olmSAS ?: throw RendezvousError("Channel closed", RendezvousFailureReason.Unknown)
val isInitiator = theirPublicKey == null
if (isInitiator) {
Timber.tag(TAG).i("Waiting for other device to send their public key")
val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError)
if (isInitiator) {
Timber.tag(TAG).i("Waiting for other device to send their public key")
val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError)
if (res.key == null) {
throw RendezvousError(
"Unsupported algorithm: ${res.algorithm}",
RendezvousFailureReason.UnsupportedAlgorithm,
)
}
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
if (res.key == null) {
throw RendezvousError(
"Unsupported algorithm: ${res.algorithm}",
RendezvousFailureReason.UnsupportedAlgorithm,
)
}
theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP)
} else {
// send our public key unencrypted
Timber.tag(TAG).i("Sending public key")
send(
ECDHPayload(
algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1,
key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP)
)
)
}
synchronized(olmSAS) {
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
olmSASMutex.withLock {
sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP))
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP)
val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP)
val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey"
aesKey = olmSAS.generateShortCode(aesInfo, 32)
aesKey = sas.generateShortCode(aesInfo, 32)
val rawChecksum = olmSAS.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum)
}
} ?: throw RuntimeException("Channel closed")
val rawChecksum = sas.generateShortCode(aesInfo, 5)
return getDecimalCodeRepresentation(rawChecksum)
}
}
private suspend fun send(payload: ECDHPayload) {
@ -154,12 +156,11 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu
}
override suspend fun close() {
olmSAS ?.let {
synchronized(it) {
// this does a double release check already so we don't re-check ourselves
it.releaseSas()
olmSAS = null
}
val sas = olmSAS ?: throw IllegalStateException("Channel already closed")
olmSASMutex.withLock {
// this does a double release check already so we don't re-check ourselves
sas.releaseSas()
olmSAS = null
}
transport.close()
}