HTTP Workers: use runInterruptible instead of interrupting manually (bitfireAT/davx5#444)

* RefreshCollectionsWorker: use runInterruptible instead of interrupting manually

* SyncWorker: use CoroutineWorker + runInterruptible

* Use global SyncWorkDispatcher that guarantees classLoader to be set

* Set SyncWorkDispatcher for whole SyncWorker's doWork

* Remove obsolete test

* SyncManager: add structured concurrency again

* Use up to <number of processors> threads for synchronization

---------

Co-authored-by: Sunik Kupfer <kupfer@bitfire.at>
This commit is contained in:
Ricki Hirner 2023-11-09 16:58:53 +01:00
parent 42bd1e8449
commit 61c1ef8831
No known key found for this signature in database
GPG key ID: 79A019FCAAEDD3AA
5 changed files with 138 additions and 156 deletions

View file

@ -12,9 +12,7 @@ import android.provider.ContactsContract
import android.util.Log
import androidx.test.platform.app.InstrumentationRegistry
import androidx.work.Configuration
import androidx.work.testing.TestWorkerBuilder
import androidx.work.testing.WorkManagerTestInitHelper
import androidx.work.workDataOf
import at.bitfire.davdroid.R
import at.bitfire.davdroid.TestUtils.workScheduledOrRunningOrSuccessful
import at.bitfire.davdroid.db.Credentials
@ -28,13 +26,10 @@ import io.mockk.mockk
import io.mockk.mockkObject
import org.junit.After
import org.junit.Assert.assertFalse
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull
import org.junit.Assert.assertTrue
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.util.concurrent.Executors
@HiltAndroidTest
class SyncWorkerTest {
@ -45,15 +40,12 @@ class SyncWorkerTest {
private val account = Account("Test Account", context.getString(R.string.account_type))
private val fakeCredentials = Credentials("test", "test")
private val executor = Executors.newSingleThreadExecutor()
@get:Rule
val hiltRule = HiltAndroidRule(this)
@Before
fun inject() {
hiltRule.inject()
}
fun inject() = hiltRule.inject()
@Before
fun setUp() {
@ -129,31 +121,4 @@ class SyncWorkerTest {
// TODO: Write test
}
@Test
fun testOnStopped_interruptsSyncThread() {
val authority = CalendarContract.AUTHORITY
val inputData = workDataOf(
SyncWorker.ARG_AUTHORITY to authority,
SyncWorker.ARG_ACCOUNT_NAME to account.name,
SyncWorker.ARG_ACCOUNT_TYPE to account.type
)
// Create SyncWorker as TestWorker
val testSyncWorker = TestWorkerBuilder<SyncWorker>(context, executor, inputData).build()
assertNull(testSyncWorker.syncThread)
// Run SyncWorker and assert sync thread is alive
testSyncWorker.doWork()
assertNotNull(testSyncWorker.syncThread)
assertTrue(testSyncWorker.syncThread!!.isAlive)
assertFalse(testSyncWorker.syncThread!!.isInterrupted) // Sync running
// Stop SyncWorker and assert sync thread was interrupted
testSyncWorker.onStopped()
assertNotNull(testSyncWorker.syncThread)
assertTrue(testSyncWorker.syncThread!!.isAlive)
assertTrue(testSyncWorker.syncThread!!.isInterrupted) // Sync thread interrupted
}
}

View file

@ -8,12 +8,11 @@ import android.accounts.Account
import android.app.PendingIntent
import android.content.Context
import android.content.Intent
import android.os.Build
import androidx.concurrent.futures.CallbackToFutureAdapter
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.hilt.work.HiltWorker
import androidx.lifecycle.map
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ExistingWorkPolicy
import androidx.work.ForegroundInfo
@ -21,7 +20,6 @@ import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkInfo
import androidx.work.WorkManager
import androidx.work.Worker
import androidx.work.WorkerParameters
import at.bitfire.dav4jvm.DavResource
import at.bitfire.dav4jvm.MultiResponseCallback
@ -64,9 +62,9 @@ import at.bitfire.davdroid.ui.NotificationUtils
import at.bitfire.davdroid.ui.NotificationUtils.notifyIfPossible
import at.bitfire.davdroid.ui.account.SettingsActivity
import at.bitfire.davdroid.util.DavUtils.parent
import com.google.common.util.concurrent.ListenableFuture
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
import kotlinx.coroutines.runInterruptible
import okhttp3.HttpUrl
import okhttp3.OkHttpClient
import java.util.logging.Level
@ -93,7 +91,7 @@ class RefreshCollectionsWorker @AssistedInject constructor(
@Assisted workerParams: WorkerParameters,
var db: AppDatabase,
var settings: SettingsManager
): Worker(appContext, workerParams) {
): CoroutineWorker(appContext, workerParams) {
companion object {
@ -170,10 +168,7 @@ class RefreshCollectionsWorker @AssistedInject constructor(
val service = db.serviceDao().get(serviceId) ?: throw IllegalArgumentException("Service #$serviceId not found")
val account = Account(service.accountName, applicationContext.getString(R.string.account_type))
/** thread which runs the actual refresh code (can be interrupted to stop refreshing) */
var refreshThread: Thread? = null
override fun doWork(): Result {
override suspend fun doWork(): Result {
try {
Logger.log.info("Refreshing ${service.type} collections of service #$service")
@ -182,28 +177,29 @@ class RefreshCollectionsWorker @AssistedInject constructor(
.cancel(serviceId.toString(), NotificationUtils.NOTIFY_REFRESH_COLLECTIONS)
// create authenticating OkHttpClient (credentials taken from account settings)
refreshThread = Thread.currentThread()
HttpClient.Builder(applicationContext, AccountSettings(applicationContext, account))
.setForeground(true)
.build().use { client ->
val httpClient = client.okHttpClient
val refresher = Refresher(db, service, settings, httpClient)
runInterruptible {
HttpClient.Builder(applicationContext, AccountSettings(applicationContext, account))
.setForeground(true)
.build().use { client ->
val httpClient = client.okHttpClient
val refresher = Refresher(db, service, settings, httpClient)
// refresh home set list (from principal url)
service.principal?.let { principalUrl ->
Logger.log.fine("Querying principal $principalUrl for home sets")
refresher.discoverHomesets(principalUrl)
// refresh home set list (from principal url)
service.principal?.let { principalUrl ->
Logger.log.fine("Querying principal $principalUrl for home sets")
refresher.discoverHomesets(principalUrl)
}
// refresh home sets and their member collections
refresher.refreshHomesetsAndTheirCollections()
// also refresh collections without a home set
refresher.refreshHomelessCollections()
// Lastly, refresh the principals (collection owners)
refresher.refreshPrincipals()
}
// refresh home sets and their member collections
refresher.refreshHomesetsAndTheirCollections()
// also refresh collections without a home set
refresher.refreshHomelessCollections()
// Lastly, refresh the principals (collection owners)
refresher.refreshPrincipals()
}
}
} catch(e: InvalidAccountException) {
Logger.log.log(Level.SEVERE, "Invalid account", e)
@ -232,31 +228,23 @@ class RefreshCollectionsWorker @AssistedInject constructor(
return Result.failure()
}
// Success
return Result.success()
}
override fun onStopped() {
Logger.log.info("Stopping refresh (reason ${if (Build.VERSION.SDK_INT >= 31) stopReason else "n/a"})")
refreshThread?.interrupt()
override suspend fun getForegroundInfo(): ForegroundInfo {
val notification = NotificationUtils.newBuilder(applicationContext, NotificationUtils.CHANNEL_STATUS)
.setSmallIcon(R.drawable.ic_foreground_notify)
.setContentTitle(applicationContext.getString(R.string.foreground_service_notify_title))
.setContentText(applicationContext.getString(R.string.foreground_service_notify_text))
.setStyle(NotificationCompat.BigTextStyle())
.setCategory(NotificationCompat.CATEGORY_STATUS)
.setOngoing(true)
.setPriority(NotificationCompat.PRIORITY_LOW)
.build()
return ForegroundInfo(NotificationUtils.NOTIFY_SYNC_EXPEDITED, notification)
}
override fun getForegroundInfoAsync(): ListenableFuture<ForegroundInfo> =
CallbackToFutureAdapter.getFuture { completer ->
val notification = NotificationUtils.newBuilder(applicationContext, NotificationUtils.CHANNEL_STATUS)
.setSmallIcon(R.drawable.ic_foreground_notify)
.setContentTitle(applicationContext.getString(R.string.foreground_service_notify_title))
.setContentText(applicationContext.getString(R.string.foreground_service_notify_text))
.setStyle(NotificationCompat.BigTextStyle())
.setCategory(NotificationCompat.CATEGORY_STATUS)
.setOngoing(true)
.setPriority(NotificationCompat.PRIORITY_LOW)
.build()
completer.set(ForegroundInfo(NotificationUtils.NOTIFY_SYNC_EXPEDITED, notification))
}
private fun notifyRefreshError(contentText: String, contentIntent: Intent) {
val notify = NotificationUtils.newBuilder(applicationContext, NotificationUtils.CHANNEL_GENERAL)
.setSmallIcon(R.drawable.ic_sync_problem_notify)

View file

@ -51,14 +51,11 @@ import org.apache.commons.lang3.exception.ContextedException
import org.dmfs.tasks.contract.TaskContract
import java.io.IOException
import java.io.InterruptedIOException
import java.lang.ref.WeakReference
import java.net.HttpURLConnection
import java.security.cert.CertificateException
import java.time.Instant
import java.util.*
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.logging.Level
import javax.net.ssl.SSLHandshakeException
@ -121,27 +118,6 @@ abstract class SyncManager<ResourceType: LocalResource<*>, out CollectionType: L
}
}
var _workDispatcher: WeakReference<CoroutineDispatcher>? = null
/**
* We use our own dispatcher to
*
* - make sure that all threads have [Thread.getContextClassLoader] set, which is required for dav4jvm and ical4j (because they rely on [ServiceLoader]),
* - control the global number of sync worker threads.
*
* Threads created by a service automatically have a contextClassLoader.
*/
fun getWorkDispatcher(): CoroutineDispatcher {
val cached = _workDispatcher?.get()
if (cached != null)
return cached
val newDispatcher = ThreadPoolExecutor(
0, Integer.min(Runtime.getRuntime().availableProcessors(), 4),
10, TimeUnit.SECONDS, LinkedBlockingQueue()
).asCoroutineDispatcher()
return newDispatcher
}
}
init {
@ -162,8 +138,6 @@ abstract class SyncManager<ResourceType: LocalResource<*>, out CollectionType: L
protected var hasCollectionSync = false
val workDispatcher = getWorkDispatcher()
fun performSync() {
// dismiss previous error notifications
@ -389,7 +363,7 @@ abstract class SyncManager<ResourceType: LocalResource<*>, out CollectionType: L
var numUploaded = 0
// upload dirty resources (parallelized)
runBlocking(workDispatcher) {
runBlocking {
for (local in localCollection.findDirty())
launch {
localExceptionContext(local) {
@ -578,7 +552,7 @@ abstract class SyncManager<ResourceType: LocalResource<*>, out CollectionType: L
}
}
withContext(workDispatcher) { // structured concurrency: blocks until all inner coroutines are finished
coroutineScope { // structured concurrency: blocks until all inner coroutines are finished
listRemote { response, relation ->
// ignore non-members
if (relation != Response.HrefRelation.MEMBER)
@ -602,7 +576,7 @@ abstract class SyncManager<ResourceType: LocalResource<*>, out CollectionType: L
} else {
val localETag = local.eTag
val remoteETag = response[GetETag::class.java]?.eTag
?: throw DavException("Server didn't provide ETag")
?: throw DavException("Server didn't provide ETag")
if (localETag == remoteETag) {
Logger.log.info("$name has not been changed on server (ETag still $remoteETag)")
nSkipped.incrementAndGet()

View file

@ -0,0 +1,50 @@
/***************************************************************************************************
* Copyright © All Contributors. See LICENSE and AUTHORS in the root directory for details.
**************************************************************************************************/
package at.bitfire.davdroid.syncadapter
import android.content.Context
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
object SyncWorkDispatcher {
private var _dispatcher: CoroutineDispatcher? = null
/**
* We use our own dispatcher to
*
* - make sure that all threads have [Thread.getContextClassLoader] set,
* which is required for dav4jvm and ical4j (because they rely on [ServiceLoader]),
* - control the global number of sync worker threads.
*/
@Synchronized
fun getInstance(context: Context): CoroutineDispatcher {
// prefer cached work dispatcher
_dispatcher?.let { return it }
val newDispatcher = createDispatcher(context.applicationContext.classLoader)
_dispatcher = newDispatcher
return newDispatcher
}
private fun createDispatcher(classLoader: ClassLoader) =
ThreadPoolExecutor(
0, Runtime.getRuntime().availableProcessors(),
10, TimeUnit.SECONDS, LinkedBlockingQueue(),
object: ThreadFactory {
val group = ThreadGroup("sync-work")
override fun newThread(r: Runnable) =
Thread(group, r).apply {
contextClassLoader = classLoader
}
}
).asCoroutineDispatcher()
}

View file

@ -5,21 +5,36 @@
package at.bitfire.davdroid.syncadapter
import android.accounts.Account
import android.content.*
import android.content.ContentProviderClient
import android.content.ContentResolver
import android.content.Context
import android.content.Intent
import android.content.SyncResult
import android.net.ConnectivityManager
import android.net.wifi.WifiManager
import android.os.Build
import android.provider.CalendarContract
import android.provider.ContactsContract
import androidx.annotation.IntDef
import androidx.concurrent.futures.CallbackToFutureAdapter
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.core.content.getSystemService
import androidx.hilt.work.HiltWorker
import androidx.lifecycle.LiveData
import androidx.lifecycle.map
import androidx.work.*
import androidx.work.BackoffPolicy
import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ExistingWorkPolicy
import androidx.work.ForegroundInfo
import androidx.work.NetworkType
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkInfo
import androidx.work.WorkManager
import androidx.work.WorkQuery
import androidx.work.WorkRequest
import androidx.work.WorkerParameters
import at.bitfire.davdroid.R
import at.bitfire.davdroid.log.Logger
import at.bitfire.davdroid.network.ConnectionUtils.internetAvailable
@ -31,9 +46,10 @@ import at.bitfire.davdroid.ui.NotificationUtils.notifyIfPossible
import at.bitfire.davdroid.ui.account.WifiPermissionsActivity
import at.bitfire.davdroid.util.PermissionUtils
import at.bitfire.ical4android.TaskProvider
import com.google.common.util.concurrent.ListenableFuture
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withContext
import java.util.concurrent.TimeUnit
import java.util.logging.Level
@ -58,7 +74,7 @@ import java.util.logging.Level
class SyncWorker @AssistedInject constructor(
@Assisted appContext: Context,
@Assisted workerParams: WorkerParameters
) : Worker(appContext, workerParams) {
) : CoroutineWorker(appContext, workerParams) {
companion object {
@ -278,12 +294,10 @@ class SyncWorker @AssistedInject constructor(
}
private val dispatcher = SyncWorkDispatcher.getInstance(applicationContext)
private val notificationManager = NotificationManagerCompat.from(applicationContext)
/** thread which runs the actual sync code (can be interrupted to stop synchronization) */
var syncThread: Thread? = null
override fun doWork(): Result {
override suspend fun doWork(): Result = withContext(dispatcher) {
// ensure we got the required arguments
val account = Account(
inputData.getString(ARG_ACCOUNT_NAME) ?: throw IllegalArgumentException("$ARG_ACCOUNT_NAME required"),
@ -296,7 +310,7 @@ class SyncWorker @AssistedInject constructor(
val connectivityManager = applicationContext.getSystemService<ConnectivityManager>()!!
if (!internetAvailable(connectivityManager, ignoreVpns)) {
Logger.log.info("WorkManager started SyncWorker without Internet connection. Aborting.")
return Result.failure()
return@withContext Result.failure()
}
Logger.log.info("Running sync worker: account=$account, authority=$authority")
@ -338,19 +352,16 @@ class SyncWorker @AssistedInject constructor(
}
if (provider == null) {
Logger.log.warning("Couldn't acquire ContentProviderClient for $authority")
return Result.failure()
return@withContext Result.failure()
}
// Start syncing. We still use the sync adapter framework's SyncResult to pass the sync results, but this
// is only for legacy reasons and can be replaced by an own result class in the future.
val result = SyncResult()
try {
syncThread = Thread.currentThread()
syncer.onPerformSync(account, extras.toTypedArray(), authority, provider, result)
} catch (e: SecurityException) {
Logger.log.log(Level.WARNING, "Security exception when opening content provider for $authority")
} finally {
provider.close()
provider.use {
// Start syncing. We still use the sync adapter framework's SyncResult to pass the sync results, but this
// is only for legacy reasons and can be replaced by an own result class in the future.
runInterruptible {
syncer.onPerformSync(account, extras.toTypedArray(), authority, provider, result)
}
}
// Check for errors
@ -375,7 +386,7 @@ class SyncWorker @AssistedInject constructor(
Thread.sleep(blockDuration*1000)
Logger.log.warning("Retrying on soft error (attempt $runAttemptCount of $MAX_RUN_ATTEMPTS)")
return Result.retry()
return@withContext Result.retry()
}
Logger.log.warning("Max retries on soft errors reached ($runAttemptCount of $MAX_RUN_ATTEMPTS). Treating as failed")
@ -394,7 +405,7 @@ class SyncWorker @AssistedInject constructor(
.build()
)
return Result.failure(syncResult)
return@withContext Result.failure(syncResult)
}
// If no soft error found, dismiss sync error notification
@ -407,30 +418,24 @@ class SyncWorker @AssistedInject constructor(
// Note: SyncManager should have notified the user
if (result.hasHardError()) {
Logger.log.warning("Hard error while syncing: result=$result, stats=${result.stats}")
return Result.failure(syncResult)
return@withContext Result.failure(syncResult)
}
}
return Result.success()
return@withContext Result.success()
}
override fun onStopped() {
Logger.log.info("Work stopped (reason ${if (Build.VERSION.SDK_INT >= 31) stopReason else "n/a"}), stopping sync thread")
syncThread?.interrupt()
override suspend fun getForegroundInfo(): ForegroundInfo {
val notification = NotificationUtils.newBuilder(applicationContext, NotificationUtils.CHANNEL_STATUS)
.setSmallIcon(R.drawable.ic_foreground_notify)
.setContentTitle(applicationContext.getString(R.string.foreground_service_notify_title))
.setContentText(applicationContext.getString(R.string.foreground_service_notify_text))
.setStyle(NotificationCompat.BigTextStyle())
.setCategory(NotificationCompat.CATEGORY_STATUS)
.setOngoing(true)
.setPriority(NotificationCompat.PRIORITY_LOW)
.build()
return ForegroundInfo(NotificationUtils.NOTIFY_SYNC_EXPEDITED, notification)
}
override fun getForegroundInfoAsync(): ListenableFuture<ForegroundInfo> =
CallbackToFutureAdapter.getFuture { completer ->
val notification = NotificationUtils.newBuilder(applicationContext, NotificationUtils.CHANNEL_STATUS)
.setSmallIcon(R.drawable.ic_foreground_notify)
.setContentTitle(applicationContext.getString(R.string.foreground_service_notify_title))
.setContentText(applicationContext.getString(R.string.foreground_service_notify_text))
.setStyle(NotificationCompat.BigTextStyle())
.setCategory(NotificationCompat.CATEGORY_STATUS)
.setOngoing(true)
.setPriority(NotificationCompat.PRIORITY_LOW)
.build()
completer.set(ForegroundInfo(NotificationUtils.NOTIFY_SYNC_EXPEDITED, notification))
}
}