From 46acf5894c3d6110f99fc3a80ead78d89fb415ee Mon Sep 17 00:00:00 2001 From: Hai Zhang Date: Sun, 11 Feb 2024 21:48:23 -0800 Subject: [PATCH] Refactor: Extract AbstractFileByteChannel from FTP, SFTP and SMB --- .../android/files/compat/InputStreamCompat.kt | 48 +++ .../common/AbstractFileByteChannel.kt | 299 ++++++++++++++++ .../provider/common/ByteBufferExtensions.kt | 14 + .../provider/common/ByteBufferInputStream.kt | 66 ++++ .../files/provider/common/FutureExtensions.kt | 132 +++++++ .../provider/ftp/client/FileByteChannel.kt | 310 +++------------- .../provider/sftp/client/FileByteChannel.kt | 295 ++++------------ .../provider/smb/client/FileByteChannel.kt | 330 ++++-------------- 8 files changed, 745 insertions(+), 749 deletions(-) create mode 100644 app/src/main/java/me/zhanghai/android/files/compat/InputStreamCompat.kt create mode 100644 app/src/main/java/me/zhanghai/android/files/provider/common/AbstractFileByteChannel.kt create mode 100644 app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferExtensions.kt create mode 100644 app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferInputStream.kt create mode 100644 app/src/main/java/me/zhanghai/android/files/provider/common/FutureExtensions.kt diff --git a/app/src/main/java/me/zhanghai/android/files/compat/InputStreamCompat.kt b/app/src/main/java/me/zhanghai/android/files/compat/InputStreamCompat.kt new file mode 100644 index 00000000..8246474c --- /dev/null +++ b/app/src/main/java/me/zhanghai/android/files/compat/InputStreamCompat.kt @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024 Hai Zhang + * All Rights Reserved. + */ + +package me.zhanghai.android.files.compat + +import java.io.IOException +import java.io.InputStream +import kotlin.reflect.KClass + +fun KClass.nullInputStream(): InputStream = + object : InputStream() { + private var closed = false + + override fun read(): Int { + ensureOpen() + return -1; + } + + override fun read(bytes: ByteArray, offset: Int, length: Int): Int { + if (!(offset >= 0 && length >= 0 && length <= bytes.size - offset)) { + throw IndexOutOfBoundsException() + } + ensureOpen() + return if (length == 0) 0 else -1 + } + + override fun skip(length: Long): Long { + ensureOpen() + return 0 + } + + override fun available(): Int { + ensureOpen() + return 0 + } + + override fun close() { + closed = true + } + + private fun ensureOpen() { + if (closed) { + throw IOException("Stream closed") + } + } + } diff --git a/app/src/main/java/me/zhanghai/android/files/provider/common/AbstractFileByteChannel.kt b/app/src/main/java/me/zhanghai/android/files/provider/common/AbstractFileByteChannel.kt new file mode 100644 index 00000000..14455e2a --- /dev/null +++ b/app/src/main/java/me/zhanghai/android/files/provider/common/AbstractFileByteChannel.kt @@ -0,0 +1,299 @@ +/* + * Copyright (c) 2024 Hai Zhang + * All Rights Reserved. + */ + +package me.zhanghai.android.files.provider.common + +import java8.nio.channels.SeekableByteChannel +import kotlinx.coroutines.DelicateCoroutinesApi +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.GlobalScope +import kotlinx.coroutines.async +import kotlinx.coroutines.runInterruptible +import kotlinx.coroutines.withTimeout +import me.zhanghai.android.files.util.closeSafe +import java.io.Closeable +import java.io.IOException +import java.io.InterruptedIOException +import java.nio.ByteBuffer +import java.nio.channels.ClosedChannelException +import java.nio.channels.NonReadableChannelException +import java.util.concurrent.CancellationException +import java.util.concurrent.ExecutionException +import java.util.concurrent.Future + +abstract class AbstractFileByteChannel( + private val isAppend: Boolean, + private val shouldCancelRead: Boolean = true, + private val joinCancelledRead: Boolean = false +) : ForceableChannel, SeekableByteChannel { + private var position = 0L + private val readBuffer = ReadBuffer() + private val ioLock = Any() + + private var isOpen = true + private val closeLock = Any() + + @Throws(IOException::class) + final override fun read(destination: ByteBuffer): Int { + ensureOpen() + if (isAppend) { + throw NonReadableChannelException() + } + val remaining = destination.remaining() + if (remaining == 0) { + return 0 + } + return synchronized(ioLock) { + readBuffer.read(destination).also { + if (it != -1) { + position += it + } + } + } + } + + protected open fun onReadAsync( + position: Long, + size: Int, + timeoutMillis: Long + ): Future = + @OptIn(DelicateCoroutinesApi::class) + GlobalScope.async(Dispatchers.IO) { + withTimeout(timeoutMillis) { + runInterruptible { + onRead(position, size) + } + } + } + .asFuture() + + @Throws(IOException::class) + protected open fun onRead(position: Long, size: Int): ByteBuffer { + throw NotImplementedError() + } + + @Throws(IOException::class) + final override fun write(source: ByteBuffer): Int { + ensureOpen() + val remaining = source.remaining() + if (remaining == 0) { + return 0 + } + synchronized(ioLock) { + if (isAppend) { + onAppend(source) + position = onSize() + } else { + onWrite(position, source) + position += remaining - source.remaining() + } + return remaining + } + } + + @Throws(IOException::class) + protected abstract fun onWrite(position: Long, source: ByteBuffer) + + @Throws(IOException::class) + protected open fun onAppend(source: ByteBuffer) { + val position = onSize() + onWrite(position, source) + } + + @Throws(IOException::class) + final override fun position(): Long { + ensureOpen() + synchronized(ioLock) { + if (isAppend) { + position = onSize() + } + return position + } + } + + final override fun position(newPosition: Long): SeekableByteChannel { + ensureOpen() + if (isAppend) { + // Ignored. + return this + } + synchronized(ioLock) { + readBuffer.reposition(position, newPosition) + position = newPosition + } + return this + } + + @Throws(IOException::class) + final override fun size(): Long { + ensureOpen() + return onSize() + } + + @Throws(IOException::class) + final override fun truncate(size: Long): SeekableByteChannel { + ensureOpen() + require(size >= 0) + synchronized(ioLock) { + val currentSize = onSize() + if (size >= currentSize) { + return this + } + onTruncate(size) + position = position.coerceAtMost(size) + } + return this + } + + @Throws(IOException::class) + protected abstract fun onTruncate(size: Long) + + @Throws(IOException::class) + protected abstract fun onSize(): Long + + @Throws(IOException::class) + final override fun force(metaData: Boolean) { + ensureOpen() + synchronized(ioLock) { + onForce(metaData) + } + } + + @Throws(IOException::class) + protected open fun onForce(metaData: Boolean) {} + + @Throws(ClosedChannelException::class) + private fun ensureOpen() { + synchronized(closeLock) { + if (!isOpen) { + throw ClosedChannelException() + } + } + } + + final override fun isOpen(): Boolean = synchronized(closeLock) { isOpen } + + @Throws(IOException::class) + final override fun close() { + synchronized(closeLock) { + if (!isOpen) { + return + } + isOpen = false + synchronized(ioLock) { + readBuffer.closeSafe() + onClose() + } + } + } + + protected fun setClosed() { + synchronized(closeLock) { + isOpen = false + } + } + + @Throws(IOException::class) + protected open fun onClose() {} + + private inner class ReadBuffer : Closeable { + private val buffer = ByteBuffer.allocate(BUFFER_SIZE).apply { limit(0) } + private var bufferedPosition = 0L + + private var pendingRead: Future? = null + private val pendingReadLock = Any() + + @Throws(IOException::class) + fun read(destination: ByteBuffer): Int { + if (!buffer.hasRemaining()) { + readIntoBuffer() + if (!buffer.hasRemaining()) { + return -1 + } + } + val length = destination.remaining().coerceAtMost(buffer.remaining()) + val bufferLimit = buffer.limit() + buffer.limit(buffer.position() + length) + destination.put(buffer) + buffer.limit(bufferLimit) + return length + } + + @Throws(IOException::class) + private fun readIntoBuffer() { + val future = synchronized(pendingReadLock) { + pendingRead?.also { pendingRead = null } + } ?: readIntoBufferAsync() + val newBuffer = try { + future.get() + } catch (e: CancellationException) { + throw InterruptedIOException().apply { initCause(e) } + } catch (e: InterruptedException) { + throw InterruptedIOException().apply { initCause(e) } + } catch (e: ExecutionException) { + val exception = e.cause ?: e + if (exception is IOException) { + throw exception + } else { + throw IOException(exception) + } + } + buffer.clear() + buffer.put(newBuffer) + buffer.flip() + if (!buffer.hasRemaining()) { + return + } + bufferedPosition += buffer.remaining() + synchronized(pendingReadLock) { + pendingRead = readIntoBufferAsync() + } + } + + private fun readIntoBufferAsync(): Future = + onReadAsync(bufferedPosition, BUFFER_SIZE, TIMEOUT_MILLIS) + + fun reposition(oldPosition: Long, newPosition: Long) { + if (newPosition == oldPosition) { + return + } + val newBufferPosition = buffer.position() + (newPosition - oldPosition) + if (newBufferPosition in 0..buffer.limit()) { + buffer.position(newBufferPosition.toInt()) + } else { + cancelPendingRead() + buffer.limit(0) + bufferedPosition = newPosition + } + } + + override fun close() { + cancelPendingRead() + } + + private fun cancelPendingRead() { + synchronized(pendingReadLock) { + pendingRead?.let { + if (shouldCancelRead) { + it.cancel(true) + if (joinCancelledRead) { + try { + it.get() + } catch (e: Exception) { + // Ignored + } + } + } + pendingRead = null + } + } + } + } + + companion object { + private const val BUFFER_SIZE = 1024 * 1024 + private const val TIMEOUT_MILLIS = 15_000L + } +} diff --git a/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferExtensions.kt b/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferExtensions.kt new file mode 100644 index 00000000..6c99cefa --- /dev/null +++ b/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferExtensions.kt @@ -0,0 +1,14 @@ +/* + * Copyright (c) 2024 Hai Zhang + * All Rights Reserved. + */ + +package me.zhanghai.android.files.provider.common + +import java.nio.ByteBuffer +import kotlin.reflect.KClass + +private val EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0) + +val KClass.EMPTY: ByteBuffer + get() = EMPTY_BYTE_BUFFER diff --git a/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferInputStream.kt b/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferInputStream.kt new file mode 100644 index 00000000..3899adb0 --- /dev/null +++ b/app/src/main/java/me/zhanghai/android/files/provider/common/ByteBufferInputStream.kt @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2024 Hai Zhang + * All Rights Reserved. + */ + +package me.zhanghai.android.files.provider.common + +import java.io.IOException +import java.io.InputStream +import java.nio.ByteBuffer + +class ByteBufferInputStream(buffer: ByteBuffer) : InputStream() { + private var buffer: ByteBuffer? = buffer + + override fun read(): Int { + val buffer = ensureOpen() + return if (buffer.hasRemaining()) buffer.get().toInt() and 0xFF else -1 + } + + override fun read(bytes: ByteArray, offset: Int, length: Int): Int { + val buffer = ensureOpen() + if (length == 0) { + return 0 + } + val remaining = buffer.remaining() + if (remaining == 0) { + return -1 + } + val readLength = length.coerceAtMost(remaining) + buffer.get(bytes, offset, readLength) + return readLength + } + + override fun skip(length: Long): Long { + val buffer = ensureOpen() + if (length <= 0) { + return 0 + } + val skippedLength = length.toInt().coerceAtMost(buffer.remaining()) + buffer.position(buffer.position() + skippedLength) + return skippedLength.toLong() + } + + override fun available(): Int { + val buffer = ensureOpen() + return buffer.remaining() + } + + override fun markSupported(): Boolean = true + + override fun mark(readlimit: Int) { + val buffer = ensureOpen() + buffer.mark() + } + + override fun reset() { + val buffer = ensureOpen() + buffer.reset() + } + + override fun close() { + buffer = null + } + + private fun ensureOpen(): ByteBuffer = buffer ?: throw IOException("Stream closed"); +} diff --git a/app/src/main/java/me/zhanghai/android/files/provider/common/FutureExtensions.kt b/app/src/main/java/me/zhanghai/android/files/provider/common/FutureExtensions.kt new file mode 100644 index 00000000..0d3bb1b3 --- /dev/null +++ b/app/src/main/java/me/zhanghai/android/files/provider/common/FutureExtensions.kt @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2024 Hai Zhang + * All Rights Reserved. + */ + +package me.zhanghai.android.files.provider.common + +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.ExperimentalCoroutinesApi +import net.schmizz.concurrent.Promise +import java.util.concurrent.CancellationException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.ExecutionException +import java.util.concurrent.Future +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException + +inline fun Future.map( + crossinline transform: (T) -> R, + crossinline transformException: (Exception) -> Exception = { it } +): Future = + object : Future { + override fun cancel(mayInterruptIfRunning: Boolean): Boolean = + this@map.cancel(mayInterruptIfRunning) + + override fun isCancelled(): Boolean = this@map.isCancelled + + override fun isDone(): Boolean = this@map.isDone + + @Throws(ExecutionException::class, InterruptedException::class) + override fun get(): R = transformGet { this@map.get() } + + @Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class) + override fun get(timeout: Long, unit: TimeUnit): R = + transformGet { this@map.get(timeout, unit) } + + @Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class) + private inline fun transformGet(get: () -> T): R { + val result = try { + get() + } catch (e: Exception) { + val exception = try { + transformException(e) + } catch (e2: Exception) { + e2.addSuppressed(e) + throw ExecutionException(e2) + } + check( + exception is ExecutionException || exception is InterruptedException || + exception is TimeoutException + ) + throw exception + } + try { + return transform(result) + } catch (e: Exception) { + throw ExecutionException(e) + } + } + } + +fun Deferred.asFuture(): Future = + object : Future { + private val latch = CountDownLatch(1) + + init { + invokeOnCompletion { latch.countDown() } + } + + override fun cancel(mayInterruptIfRunning: Boolean): Boolean { + cancel() + return this@asFuture.isCancelled + } + + override fun isCancelled(): Boolean = this@asFuture.isCancelled + + override fun isDone(): Boolean = isCompleted + + @Throws(ExecutionException::class, InterruptedException::class) + override fun get(): T { + latch.await() + return getCompleted() + } + + @Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class) + override fun get(timeout: Long, unit: TimeUnit): T { + latch.await(timeout, unit) + return getCompleted() + } + + @OptIn(ExperimentalCoroutinesApi::class) + @Throws(ExecutionException::class) + private fun getCompleted(): T = + try { + this@asFuture.getCompleted() + } catch (e: CancellationException) { + throw e + } catch (e: Exception) { + throw ExecutionException(e) + } + } + +fun Promise.asFuture(): Future = + object : Future { + override fun cancel(mayInterruptIfRunning: Boolean): Boolean = false + + override fun isCancelled(): Boolean = false + + override fun isDone(): Boolean = isFulfilled + + @Throws(ExecutionException::class, InterruptedException::class) + override fun get(): T = tryRetrieve { retrieve() } + + @Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class) + override fun get(timeout: Long, unit: TimeUnit?): T = + tryRetrieve { retrieve(timeout, unit) } + + @Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class) + private inline fun tryRetrieve(retrieve: () -> T): T = + try { + retrieve() + } catch (e: Exception) { + when (val cause = e.cause) { + is InterruptedException -> { + Thread.interrupted() + throw cause + } + is TimeoutException -> throw cause + else -> throw ExecutionException(e) + } + } + } diff --git a/app/src/main/java/me/zhanghai/android/files/provider/ftp/client/FileByteChannel.kt b/app/src/main/java/me/zhanghai/android/files/provider/ftp/client/FileByteChannel.kt index 0ee8962a..e333e01e 100644 --- a/app/src/main/java/me/zhanghai/android/files/provider/ftp/client/FileByteChannel.kt +++ b/app/src/main/java/me/zhanghai/android/files/provider/ftp/client/FileByteChannel.kt @@ -5,150 +5,80 @@ package me.zhanghai.android.files.provider.ftp.client -import java8.nio.channels.SeekableByteChannel -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.DelicateCoroutinesApi -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.GlobalScope -import kotlinx.coroutines.async -import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.runInterruptible -import kotlinx.coroutines.withTimeout -import me.zhanghai.android.files.provider.common.ForceableChannel +import me.zhanghai.android.files.compat.nullInputStream +import me.zhanghai.android.files.provider.common.AbstractFileByteChannel +import me.zhanghai.android.files.provider.common.ByteBufferInputStream import me.zhanghai.android.files.provider.common.readFully -import me.zhanghai.android.files.util.closeSafe import org.apache.commons.net.ftp.FTPClient -import java.io.ByteArrayInputStream -import java.io.Closeable import java.io.IOException -import java.io.InterruptedIOException +import java.io.InputStream import java.nio.ByteBuffer -import java.nio.channels.ClosedChannelException -import java.nio.channels.NonReadableChannelException class FileByteChannel( private val client: FTPClient, private val releaseClient: (FTPClient) -> Unit, private val path: String, - private val isAppend: Boolean -) : ForceableChannel, SeekableByteChannel { + isAppend: Boolean +) : AbstractFileByteChannel(isAppend, joinCancelledRead = true) { private val clientLock = Any() - private var position = 0L - private val readBuffer = ReadBuffer() - private val ioLock = Any() - - private var isOpen = true - private val closeLock = Any() + @Throws(IOException::class) + override fun onRead(position: Long, size: Int): ByteBuffer { + val destination = ByteBuffer.allocate(size) + synchronized(clientLock) { + client.restartOffset = position + val inputStream = client.retrieveFileStream(path) + ?: client.throwNegativeReplyCodeException() + try { + val limit = inputStream.use { it.readFully(destination.array(), 0, size) } + destination.limit(limit) + } finally { + // We will likely close the input stream before the file is fully + // read and it will result in a false return value here, but that's + // totally fine. + client.completePendingCommand() + } + } + return destination + } @Throws(IOException::class) - override fun read(destination: ByteBuffer): Int { - ensureOpen() - if (isAppend) { - throw NonReadableChannelException() - } - val remaining = destination.remaining() - if (remaining == 0) { - return 0 - } - return synchronized(ioLock) { - readBuffer.read(destination).also { - if (it != -1) { - position += it + override fun onWrite(position: Long, source: ByteBuffer) { + synchronized(clientLock) { + client.restartOffset = position + ByteBufferInputStream(source).use { + if (!client.storeFile(path, it)) { + client.throwNegativeReplyCodeException() } } } } @Throws(IOException::class) - override fun write(source: ByteBuffer): Int { - ensureOpen() - val remaining = source.remaining() - if (remaining == 0) { - return 0 - } - // I don't think we are using native or read-only ByteBuffer, so just call array() here. - synchronized(ioLock) { - if (isAppend) { - synchronized(clientLock) { - ByteArrayInputStream(source.array(), source.position(), remaining).use { - if (!client.appendFile(path, it)) { - client.throwNegativeReplyCodeException() - } - } - } - position = getSize() - } else { - synchronized(clientLock) { - client.restartOffset = position - ByteArrayInputStream(source.array(), source.position(), remaining).use { - if (!client.storeFile(path, it)) { - client.throwNegativeReplyCodeException() - } - } - } - position += remaining - } - source.position(source.limit()) - return remaining - } - } - - @Throws(IOException::class) - override fun position(): Long { - ensureOpen() - synchronized(ioLock) { - if (isAppend) { - position = getSize() - } - return position - } - } - - override fun position(newPosition: Long): SeekableByteChannel { - ensureOpen() - if (isAppend) { - // Ignored. - return this - } - synchronized(ioLock) { - readBuffer.reposition(position, newPosition) - position = newPosition - } - return this - } - - @Throws(IOException::class) - override fun size(): Long { - ensureOpen() - return getSize() - } - - @Throws(IOException::class) - override fun truncate(size: Long): SeekableByteChannel { - ensureOpen() - require(size >= 0) - synchronized(ioLock) { - val currentSize = getSize() - if (size >= currentSize) { - return this - } - synchronized(clientLock) { - client.restartOffset = size - ByteArrayInputStream(byteArrayOf()).use { - if (!client.storeFile(path, it)) { - client.throwNegativeReplyCodeException() - } + override fun onAppend(source: ByteBuffer) { + synchronized(clientLock) { + ByteBufferInputStream(source).use { + if (!client.appendFile(path, it)) { + client.throwNegativeReplyCodeException() } } - position = position.coerceAtMost(size) } - return this } @Throws(IOException::class) - private fun getSize(): Long { + override fun onTruncate(size: Long) { + synchronized(clientLock) { + client.restartOffset = size + InputStream::class.nullInputStream().use { + if (!client.storeFile(path, it)) { + client.throwNegativeReplyCodeException() + } + } + } + } + + @Throws(IOException::class) + override fun onSize(): Long { val sizeString = synchronized(clientLock) { client.getSize(path) ?: client.throwNegativeReplyCodeException() } @@ -156,145 +86,7 @@ class FileByteChannel( } @Throws(IOException::class) - override fun force(metaData: Boolean) { - ensureOpen() - // Unsupported. - } - - @Throws(ClosedChannelException::class) - private fun ensureOpen() { - synchronized(closeLock) { - if (!isOpen) { - throw ClosedChannelException() - } - } - } - - override fun isOpen(): Boolean = synchronized(closeLock) { isOpen } - - @Throws(IOException::class) - override fun close() { - synchronized(closeLock) { - if (!isOpen) { - return - } - isOpen = false - synchronized(ioLock) { - readBuffer.closeSafe() - synchronized(clientLock) { releaseClient(client) } - } - } - } - - private inner class ReadBuffer : Closeable { - private val bufferSize = DEFAULT_BUFFER_SIZE - private val timeoutMillis = 15_000L - - private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) } - private var bufferedPosition = 0L - - private var pendingDeferred: Deferred? = null - private val pendingDeferredLock = Any() - - @Throws(IOException::class) - fun read(destination: ByteBuffer): Int { - if (!buffer.hasRemaining()) { - readIntoBuffer() - if (!buffer.hasRemaining()) { - return -1 - } - } - val length = destination.remaining().coerceAtMost(buffer.remaining()) - val bufferLimit = buffer.limit() - buffer.limit(buffer.position() + length) - destination.put(buffer) - buffer.limit(bufferLimit) - return length - } - - @Throws(IOException::class) - private fun readIntoBuffer() { - val deferred = synchronized(pendingDeferredLock) { - pendingDeferred?.also { pendingDeferred = null } - } ?: readIntoBufferAsync() - val newBuffer = try { - runBlocking { deferred.await() } - } catch (e: CancellationException) { - throw InterruptedIOException().apply { initCause(e) } - } - buffer.clear() - buffer.put(newBuffer) - buffer.flip() - if (!buffer.hasRemaining()) { - return - } - bufferedPosition += buffer.remaining() - synchronized(pendingDeferredLock) { - pendingDeferred = readIntoBufferAsync() - } - } - - private fun readIntoBufferAsync(): Deferred = - @OptIn(DelicateCoroutinesApi::class) - GlobalScope.async(Dispatchers.IO) { - withTimeout(timeoutMillis) { - runInterruptible { - synchronized(clientLock) { - client.restartOffset = bufferedPosition - val inputStream = client.retrieveFileStream(path) - ?: client.throwNegativeReplyCodeException() - try { - val buffer = ByteBuffer.allocate(bufferSize) - val limit = inputStream.use { - it.readFully( - buffer.array(), buffer.position(), buffer.remaining() - ) - } - buffer.limit(limit) - buffer - } finally { - // We will likely close the input stream before the file is fully - // read and it will result in a false return value here, but that's - // totally fine. - client.completePendingCommand() - } - } - } - } - } - - fun reposition(oldPosition: Long, newPosition: Long) { - if (newPosition == oldPosition) { - return - } - val newBufferPosition = buffer.position() + (newPosition - oldPosition) - if (newBufferPosition in 0..buffer.limit()) { - buffer.position(newBufferPosition.toInt()) - } else { - synchronized(pendingDeferredLock) { - pendingDeferred?.let { - it.cancel() - runBlocking { it.join() } - pendingDeferred = null - } - } - buffer.limit(0) - bufferedPosition = newPosition - } - } - - override fun close() { - synchronized(pendingDeferredLock) { - pendingDeferred?.let { - it.cancel() - runBlocking { it.join() } - pendingDeferred = null - } - } - } - } - - companion object { - private const val DEFAULT_BUFFER_SIZE = 1024 * 1024 + override fun onClose() { + synchronized(clientLock) { releaseClient(client) } } } diff --git a/app/src/main/java/me/zhanghai/android/files/provider/sftp/client/FileByteChannel.kt b/app/src/main/java/me/zhanghai/android/files/provider/sftp/client/FileByteChannel.kt index 8ef1f7a1..86d9557c 100644 --- a/app/src/main/java/me/zhanghai/android/files/provider/sftp/client/FileByteChannel.kt +++ b/app/src/main/java/me/zhanghai/android/files/provider/sftp/client/FileByteChannel.kt @@ -5,11 +5,12 @@ package me.zhanghai.android.files.provider.sftp.client -import java8.nio.channels.SeekableByteChannel -import me.zhanghai.android.files.provider.common.ForceableChannel +import me.zhanghai.android.files.provider.common.AbstractFileByteChannel +import me.zhanghai.android.files.provider.common.EMPTY +import me.zhanghai.android.files.provider.common.asFuture +import me.zhanghai.android.files.provider.common.map import me.zhanghai.android.files.util.closeSafe import me.zhanghai.android.files.util.findCauseByClass -import net.schmizz.concurrent.Promise import net.schmizz.sshj.sftp.PacketType import net.schmizz.sshj.sftp.RemoteFile import net.schmizz.sshj.sftp.RemoteFileAccessor @@ -19,139 +20,76 @@ import java.io.IOException import java.nio.ByteBuffer import java.nio.channels.AsynchronousCloseException import java.nio.channels.ClosedByInterruptException -import java.nio.channels.ClosedChannelException -import java.nio.channels.NonReadableChannelException -import java.util.concurrent.TimeUnit +import java.util.concurrent.ExecutionException +import java.util.concurrent.Future class FileByteChannel( private val file: RemoteFile, - private val isAppend: Boolean -) : ForceableChannel, SeekableByteChannel { - private var position = 0L - private val readBuffer = ReadBuffer() - private val ioLock = Any() - - private var isOpen = true - private val closeLock = Any() - - @Throws(IOException::class) - override fun read(destination: ByteBuffer): Int { - ensureOpen() - if (isAppend) { - throw NonReadableChannelException() + isAppend: Boolean +) : AbstractFileByteChannel(isAppend) { + override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future = + try { + RemoteFileAccessor.asyncRead(file, position, size) + } catch (e: IOException) { + throw e.maybeToSpecificException() } - val remaining = destination.remaining() - if (remaining == 0) { - return 0 - } - return synchronized(ioLock) { - readBuffer.read(destination).also { - if (it != -1) { - position += it + .asFuture() + .map( + { response -> + val dataLength: Int + when (response.type) { + PacketType.STATUS -> { + response.ensureStatusIs(Response.StatusCode.EOF) + return@map ByteBuffer::class.EMPTY + } + PacketType.DATA -> { + dataLength = response.readUInt32AsInt() + } + else -> throw SFTPException("Unexpected packet type ${response.type}") + } + if (dataLength == 0) { + return@map ByteBuffer::class.EMPTY + } + val length = dataLength.coerceAtMost(size) + ByteBuffer.wrap(response.array(), response.rpos(), length) + }, { e -> + ((e as? ExecutionException)?.cause as? IOException)?.maybeToSpecificException() + ?.let { ExecutionException(it) } ?: e } - } + ) + + @Throws(IOException::class) + override fun onWrite(position: Long, source: ByteBuffer) { + // I don't think we are using native or read-only ByteBuffer, so just call array() here. + try { + file.write(position, source.array(), source.position(), source.remaining()) + } catch (e: IOException) { + throw e.maybeToSpecificException() + } + source.position(source.limit()) + } + + @Throws(IOException::class) + override fun onTruncate(size: Long) { + try { + file.setLength(size) + } catch (e: IOException) { + throw e.maybeToSpecificException() } } @Throws(IOException::class) - override fun write(source: ByteBuffer): Int { - ensureOpen() - val remaining = source.remaining() - if (remaining == 0) { - return 0 - } - synchronized(ioLock) { - if (isAppend) { - position = getSize() - } - // I don't think we are using native or read-only ByteBuffer, so just call array() here. - try { - file.write(position, source.array(), source.position(), remaining) - } catch (e: IOException) { - throw e.maybeToSpecificException() - } - source.position(source.limit()) - position += remaining - return remaining - } - } - - @Throws(IOException::class) - override fun position(): Long { - ensureOpen() - synchronized(ioLock) { - if (isAppend) { - position = getSize() - } - return position - } - } - - override fun position(newPosition: Long): SeekableByteChannel { - ensureOpen() - if (isAppend) { - // Ignored. - return this - } - synchronized(ioLock) { - readBuffer.reposition(position, newPosition) - position = newPosition - } - return this - } - - @Throws(IOException::class) - override fun size(): Long { - ensureOpen() - return getSize() - } - - @Throws(IOException::class) - override fun truncate(size: Long): SeekableByteChannel { - ensureOpen() - require(size >= 0) - synchronized(ioLock) { - val currentSize = getSize() - if (size >= currentSize) { - return this - } - try { - file.setLength(size) - } catch (e: IOException) { - throw e.maybeToSpecificException() - } - position = position.coerceAtMost(size) - } - return this - } - - @Throws(IOException::class) - private fun getSize(): Long = + override fun onSize(): Long = try{ file.length() } catch (e: IOException) { throw e.maybeToSpecificException() } - @Throws(IOException::class) - override fun force(metaData: Boolean) { - ensureOpen() - // Unsupported. - } - - @Throws(ClosedChannelException::class) - private fun ensureOpen() { - synchronized(closeLock) { - if (!isOpen) { - throw ClosedChannelException() - } - } - } - private fun IOException.maybeToSpecificException(): IOException = when { this is SFTPException && statusCode == Response.StatusCode.INVALID_HANDLE -> { - synchronized(closeLock) { isOpen = false } + setClosed() AsynchronousCloseException().apply { initCause(this@maybeToSpecificException) } } findCauseByClass() != null -> { @@ -161,122 +99,15 @@ class FileByteChannel( else -> this } - override fun isOpen(): Boolean = synchronized(closeLock) { isOpen } - @Throws(IOException::class) - override fun close() { - synchronized(closeLock) { - if (!isOpen) { - return - } - isOpen = false - try { - file.close() - } catch (e: SFTPException) { - // NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server. - if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) { - throw e - } + override fun onClose() { + try { + file.close() + } catch (e: SFTPException) { + // NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server. + if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) { + throw e } } } - - private inner class ReadBuffer { - private val bufferSize: Int = DEFAULT_BUFFER_SIZE - private val timeout: Long - - init { - val engine = RemoteFileAccessor.getRequester(file) - timeout = engine.timeoutMs.toLong() - } - - private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) } - private var bufferedPosition = 0L - - private var pendingPromise: Promise? = null - private val pendingPromiseLock = Any() - - @Throws(IOException::class) - fun read(destination: ByteBuffer): Int { - if (!buffer.hasRemaining()) { - readIntoBuffer() - if (!buffer.hasRemaining()) { - return -1 - } - } - val length = destination.remaining().coerceAtMost(buffer.remaining()) - val bufferLimit = buffer.limit() - buffer.limit(buffer.position() + length) - destination.put(buffer) - buffer.limit(bufferLimit) - return length - } - - @Throws(IOException::class) - private fun readIntoBuffer() { - val promise = synchronized(pendingPromiseLock) { - pendingPromise?.also { pendingPromise = null } - } ?: readIntoBufferAsync() - val response = try { - promise.retrieve(timeout, TimeUnit.MILLISECONDS) - } catch (e: IOException) { - throw e.maybeToSpecificException() - } - val dataLength: Int - when (response.type) { - PacketType.STATUS -> { - response.ensureStatusIs(Response.StatusCode.EOF) - buffer.limit(0) - return - } - PacketType.DATA -> { - dataLength = response.readUInt32AsInt() - } - else -> throw SFTPException("Unexpected packet type ${response.type}") - } - if (dataLength == 0) { - buffer.limit(0) - return - } - buffer.clear() - val length = dataLength.coerceAtMost(buffer.remaining()) - buffer.put(response.array(), response.rpos(), length) - buffer.flip() - bufferedPosition += length - synchronized(pendingPromiseLock) { - try { - pendingPromise = readIntoBufferAsync() - } catch (e: IOException) { - e.printStackTrace() - } - } - } - - @Throws(IOException::class) - private fun readIntoBufferAsync(): Promise = - try { - RemoteFileAccessor.asyncRead(file, bufferedPosition, bufferSize) - } catch (e: IOException) { - throw e.maybeToSpecificException() - } - - fun reposition(oldPosition: Long, newPosition: Long) { - if (newPosition == oldPosition) { - return - } - val newBufferPosition = buffer.position() + (newPosition - oldPosition) - if (newBufferPosition in 0..buffer.limit()) { - buffer.position(newBufferPosition.toInt()) - } else { - synchronized(pendingPromiseLock) { pendingPromise = null } - buffer.limit(0) - bufferedPosition = newPosition - } - } - } - - companion object { - // @see SmbConfig.DEFAULT_BUFFER_SIZE - private const val DEFAULT_BUFFER_SIZE = 1024 * 1024 - } } diff --git a/app/src/main/java/me/zhanghai/android/files/provider/smb/client/FileByteChannel.kt b/app/src/main/java/me/zhanghai/android/files/provider/smb/client/FileByteChannel.kt index f8ccc07f..fbf57611 100644 --- a/app/src/main/java/me/zhanghai/android/files/provider/smb/client/FileByteChannel.kt +++ b/app/src/main/java/me/zhanghai/android/files/provider/smb/client/FileByteChannel.kt @@ -8,129 +8,78 @@ package me.zhanghai.android.files.provider.smb.client import com.hierynomus.mserref.NtStatus import com.hierynomus.msfscc.fileinformation.FileStandardInformation import com.hierynomus.mssmb2.SMBApiException -import com.hierynomus.mssmb2.messages.SMB2ReadResponse -import com.hierynomus.protocol.commons.concurrent.Futures -import com.hierynomus.protocol.transport.TransportException import com.hierynomus.smbj.common.SMBRuntimeException import com.hierynomus.smbj.io.ByteChunkProvider import com.hierynomus.smbj.share.File import com.hierynomus.smbj.share.FileAccessor -import java8.nio.channels.SeekableByteChannel -import me.zhanghai.android.files.provider.common.ForceableChannel +import me.zhanghai.android.files.provider.common.AbstractFileByteChannel +import me.zhanghai.android.files.provider.common.EMPTY +import me.zhanghai.android.files.provider.common.map import me.zhanghai.android.files.util.closeSafe import me.zhanghai.android.files.util.findCauseByClass -import java.io.Closeable import java.io.IOException import java.io.InterruptedIOException import java.nio.ByteBuffer import java.nio.channels.AsynchronousCloseException import java.nio.channels.ClosedByInterruptException -import java.nio.channels.ClosedChannelException -import java.nio.channels.NonReadableChannelException +import java.util.concurrent.ExecutionException import java.util.concurrent.Future -import java.util.concurrent.TimeUnit class FileByteChannel( private val file: File, - private val isAppend: Boolean -) : ForceableChannel, SeekableByteChannel { - private var position = 0L - private val readBuffer = ReadBuffer() - private val ioLock = Any() - - private var isOpen = true - private val closeLock = Any() - + isAppend: Boolean +// Cancelling reads leads to TransportException: Received response with unknown sequence number +) : AbstractFileByteChannel(isAppend, shouldCancelRead = false) { @Throws(IOException::class) - override fun read(destination: ByteBuffer): Int { - ensureOpen() - if (isAppend) { - throw NonReadableChannelException() + override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future = + try { + FileAccessor.readAsync(file, position, size) + } catch (e: SMBRuntimeException) { + throw e.toIOException() } - val remaining = destination.remaining() - if (remaining == 0) { - return 0 - } - return synchronized(ioLock) { - readBuffer.read(destination).also { - if (it != -1) { - position += it + .map( + { response -> + when (response.header.statusCode) { + NtStatus.STATUS_END_OF_FILE.value -> { + return@map ByteBuffer::class.EMPTY + } + NtStatus.STATUS_SUCCESS.value -> {} + else -> throw SMBApiException(response.header, "Read failed for $this") + .toIOException() + } + val data = response.data + if (data.isEmpty()) { + return@map ByteBuffer::class.EMPTY + } + val length = data.size.coerceAtMost(size) + ByteBuffer.wrap(data, 0, length) + }, { e -> + ExecutionException(SMBRuntimeException(e).toIOException()) } - } + ) + + @Throws(IOException::class) + override fun onWrite(position: Long, source: ByteBuffer) { + val sourcePosition = source.position() + val bytesWritten = try { + file.write(ByteBufferChunkProvider(source, position)).toInt() + } catch (e: SMBRuntimeException) { + throw e.toIOException() + } + source.position(sourcePosition + bytesWritten) + } + + @Throws(IOException::class) + override fun onTruncate(size: Long) { + try { + file.setLength(size) + } catch (e: SMBRuntimeException) { + throw e.toIOException() } } @Throws(IOException::class) - override fun write(source: ByteBuffer): Int { - ensureOpen() - if (!source.hasRemaining()) { - return 0 - } - synchronized(ioLock) { - if (isAppend) { - position = getSize() - } - return try { - file.write(ByteBufferChunkProvider(source, position)).toInt() - } catch (e: SMBRuntimeException) { - throw e.toIOException() - }.also { - position += it - } - } - } - - @Throws(IOException::class) - override fun position(): Long { - ensureOpen() - synchronized(ioLock) { - if (isAppend) { - position = getSize() - } - return position - } - } - - override fun position(newPosition: Long): SeekableByteChannel { - ensureOpen() - if (isAppend) { - // Ignored. - return this - } - synchronized(ioLock) { - readBuffer.reposition(position, newPosition) - position = newPosition - } - return this - } - - @Throws(IOException::class) - override fun size(): Long { - ensureOpen() - return getSize() - } - - @Throws(IOException::class) - override fun truncate(size: Long): SeekableByteChannel { - ensureOpen() - require(size >= 0) - synchronized(ioLock) { - val currentSize = getSize() - if (size >= currentSize) { - return this - } - try { - file.setLength(size) - } catch (e: SMBRuntimeException) { - throw e.toIOException() - } - position = position.coerceAtMost(size) - } - return this - } - - @Throws(IOException::class) - private fun getSize(): Long = + override fun onSize(): Long = try { file.getFileInformation(FileStandardInformation::class.java).endOfFile } catch (e: SMBRuntimeException) { @@ -138,8 +87,7 @@ class FileByteChannel( } @Throws(IOException::class) - override fun force(metaData: Boolean) { - ensureOpen() + override fun onForce(metaData: Boolean) { try { file.flush() } catch (e: SMBRuntimeException) { @@ -147,20 +95,11 @@ class FileByteChannel( } } - @Throws(ClosedChannelException::class) - private fun ensureOpen() { - synchronized(closeLock) { - if (!isOpen) { - throw ClosedChannelException() - } - } - } - private fun SMBRuntimeException.toIOException(): IOException = when { findCauseByClass() .let { it != null && it.status == NtStatus.STATUS_FILE_CLOSED } -> { - synchronized(closeLock) { isOpen = false } + setClosed() AsynchronousCloseException().apply { initCause(this@toIOException) } } findCauseByClass() != null -> { @@ -170,162 +109,37 @@ class FileByteChannel( else -> IOException(this) } - override fun isOpen(): Boolean = synchronized(closeLock) { isOpen } - @Throws(IOException::class) - override fun close() { - synchronized(closeLock) { - if (!isOpen) { - return - } - isOpen = false - readBuffer.closeSafe() - try { - file.close() - } catch (e: SMBRuntimeException) { - throw when { - e.findCauseByClass() != null -> - InterruptedIOException().apply { initCause(e) } - else -> IOException(e) - } + override fun onClose() { + try { + file.close() + } catch (e: SMBRuntimeException) { + throw when { + e.findCauseByClass() != null -> + InterruptedIOException().apply { initCause(e) } + else -> IOException(e) } } } - private inner class ReadBuffer : Closeable { - private val bufferSize: Int - private val timeout: Long - + private class ByteBufferChunkProvider( + private val buffer: ByteBuffer, + offset: Long + ) : ByteChunkProvider() { init { - val treeConnect = file.diskShare.treeConnect - val config = treeConnect.config - bufferSize = config.readBufferSize - .coerceAtMost(treeConnect.session.connection.negotiatedProtocol.maxReadSize) - timeout = config.readTimeout + this.offset = offset } - private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) } - private var bufferedPosition = 0L + override fun isAvailable(): Boolean = buffer.hasRemaining() - private var pendingFuture: Future? = null - private val pendingFutureLock = Any() + override fun bytesLeft(): Int = buffer.remaining() - @Throws(IOException::class) - fun read(destination: ByteBuffer): Int { - if (!buffer.hasRemaining()) { - readIntoBuffer() - if (!buffer.hasRemaining()) { - return -1 - } - } - val length = destination.remaining().coerceAtMost(buffer.remaining()) - val bufferLimit = buffer.limit() - buffer.limit(buffer.position() + length) - destination.put(buffer) - buffer.limit(bufferLimit) + override fun prepareWrite(maxBytesToPrepare: Int) {} + + override fun getChunk(chunk: ByteArray): Int { + val length = chunk.size.coerceAtMost(buffer.remaining()) + buffer.get(chunk, 0, length) return length } - - @Throws(IOException::class) - private fun readIntoBuffer() { - val future = synchronized(pendingFutureLock) { - pendingFuture?.also { pendingFuture = null } - } ?: readIntoBufferAsync() - val response = try { - receive(future, timeout) - } catch (e: SMBRuntimeException) { - throw e.toIOException() - } - when (response.header.statusCode) { - NtStatus.STATUS_END_OF_FILE.value -> { - buffer.limit(0) - return - } - NtStatus.STATUS_SUCCESS.value -> {} - else -> throw SMBApiException(response.header, "Read failed for $this") - .toIOException() - } - val data = response.data - if (data.isEmpty()) { - buffer.limit(0) - return - } - buffer.clear() - val length = data.size.coerceAtMost(buffer.remaining()) - buffer.put(data, 0, length) - buffer.flip() - bufferedPosition += length - synchronized(pendingFutureLock) { - try { - pendingFuture = readIntoBufferAsync() - } catch (e: IOException) { - e.printStackTrace() - } - } - } - - // @see com.hierynomus.smbj.share.Share.receive - @Throws(SMBRuntimeException::class) - private fun receive(future: Future, timeout: Long): T = - try { - Futures.get(future, timeout, TimeUnit.MILLISECONDS, TransportException.Wrapper) - } catch (e: TransportException) { - throw SMBRuntimeException(e) - } - - @Throws(IOException::class) - private fun readIntoBufferAsync(): Future = - try { - FileAccessor.readAsync(file, bufferedPosition, bufferSize) - } catch (e: SMBRuntimeException) { - throw e.toIOException() - } - - fun reposition(oldPosition: Long, newPosition: Long) { - if (newPosition == oldPosition) { - return - } - val newBufferPosition = buffer.position() + (newPosition - oldPosition) - if (newBufferPosition in 0..buffer.limit()) { - buffer.position(newBufferPosition.toInt()) - } else { - synchronized(pendingFutureLock) { - // TransportException: Received response with unknown sequence number - //pendingFuture?.cancel(true)?.also { pendingFuture = null } - pendingFuture = null - } - buffer.limit(0) - bufferedPosition = newPosition - } - } - - override fun close() { - synchronized(pendingFutureLock) { - // TransportException: Received response with unknown sequence number - //pendingFuture?.cancel(true)?.also { pendingFuture = null } - pendingFuture = null - } - } - } -} - -private class ByteBufferChunkProvider( - private val buffer: ByteBuffer, - offset: Long -) : ByteChunkProvider() { - init { - this.offset = offset - } - - override fun isAvailable(): Boolean = buffer.hasRemaining() - - override fun bytesLeft(): Int = buffer.remaining() - - override fun prepareWrite(maxBytesToPrepare: Int) {} - - override fun getChunk(chunk: ByteArray): Int { - val length = chunk.size.coerceAtMost(buffer.remaining()) - buffer.get(chunk, 0, length) - return length } }