Refactor: Extract AbstractFileByteChannel from FTP, SFTP and SMB

This commit is contained in:
Hai Zhang 2024-02-11 21:48:23 -08:00
parent 79114bfefe
commit 46acf5894c
8 changed files with 745 additions and 749 deletions

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.compat
import java.io.IOException
import java.io.InputStream
import kotlin.reflect.KClass
fun KClass<InputStream>.nullInputStream(): InputStream =
object : InputStream() {
private var closed = false
override fun read(): Int {
ensureOpen()
return -1;
}
override fun read(bytes: ByteArray, offset: Int, length: Int): Int {
if (!(offset >= 0 && length >= 0 && length <= bytes.size - offset)) {
throw IndexOutOfBoundsException()
}
ensureOpen()
return if (length == 0) 0 else -1
}
override fun skip(length: Long): Long {
ensureOpen()
return 0
}
override fun available(): Int {
ensureOpen()
return 0
}
override fun close() {
closed = true
}
private fun ensureOpen() {
if (closed) {
throw IOException("Stream closed")
}
}
}

View File

@ -0,0 +1,299 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java8.nio.channels.SeekableByteChannel
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout
import me.zhanghai.android.files.util.closeSafe
import java.io.Closeable
import java.io.IOException
import java.io.InterruptedIOException
import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
import java.util.concurrent.CancellationException
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
abstract class AbstractFileByteChannel(
private val isAppend: Boolean,
private val shouldCancelRead: Boolean = true,
private val joinCancelledRead: Boolean = false
) : ForceableChannel, SeekableByteChannel {
private var position = 0L
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
@Throws(IOException::class)
final override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
}
val remaining = destination.remaining()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
}
}
}
}
protected open fun onReadAsync(
position: Long,
size: Int,
timeoutMillis: Long
): Future<ByteBuffer> =
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.async(Dispatchers.IO) {
withTimeout(timeoutMillis) {
runInterruptible {
onRead(position, size)
}
}
}
.asFuture()
@Throws(IOException::class)
protected open fun onRead(position: Long, size: Int): ByteBuffer {
throw NotImplementedError()
}
@Throws(IOException::class)
final override fun write(source: ByteBuffer): Int {
ensureOpen()
val remaining = source.remaining()
if (remaining == 0) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
onAppend(source)
position = onSize()
} else {
onWrite(position, source)
position += remaining - source.remaining()
}
return remaining
}
}
@Throws(IOException::class)
protected abstract fun onWrite(position: Long, source: ByteBuffer)
@Throws(IOException::class)
protected open fun onAppend(source: ByteBuffer) {
val position = onSize()
onWrite(position, source)
}
@Throws(IOException::class)
final override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = onSize()
}
return position
}
}
final override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
final override fun size(): Long {
ensureOpen()
return onSize()
}
@Throws(IOException::class)
final override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = onSize()
if (size >= currentSize) {
return this
}
onTruncate(size)
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
protected abstract fun onTruncate(size: Long)
@Throws(IOException::class)
protected abstract fun onSize(): Long
@Throws(IOException::class)
final override fun force(metaData: Boolean) {
ensureOpen()
synchronized(ioLock) {
onForce(metaData)
}
}
@Throws(IOException::class)
protected open fun onForce(metaData: Boolean) {}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
final override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
final override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
synchronized(ioLock) {
readBuffer.closeSafe()
onClose()
}
}
}
protected fun setClosed() {
synchronized(closeLock) {
isOpen = false
}
}
@Throws(IOException::class)
protected open fun onClose() {}
private inner class ReadBuffer : Closeable {
private val buffer = ByteBuffer.allocate(BUFFER_SIZE).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingRead: Future<ByteBuffer>? = null
private val pendingReadLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val future = synchronized(pendingReadLock) {
pendingRead?.also { pendingRead = null }
} ?: readIntoBufferAsync()
val newBuffer = try {
future.get()
} catch (e: CancellationException) {
throw InterruptedIOException().apply { initCause(e) }
} catch (e: InterruptedException) {
throw InterruptedIOException().apply { initCause(e) }
} catch (e: ExecutionException) {
val exception = e.cause ?: e
if (exception is IOException) {
throw exception
} else {
throw IOException(exception)
}
}
buffer.clear()
buffer.put(newBuffer)
buffer.flip()
if (!buffer.hasRemaining()) {
return
}
bufferedPosition += buffer.remaining()
synchronized(pendingReadLock) {
pendingRead = readIntoBufferAsync()
}
}
private fun readIntoBufferAsync(): Future<ByteBuffer> =
onReadAsync(bufferedPosition, BUFFER_SIZE, TIMEOUT_MILLIS)
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
cancelPendingRead()
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
cancelPendingRead()
}
private fun cancelPendingRead() {
synchronized(pendingReadLock) {
pendingRead?.let {
if (shouldCancelRead) {
it.cancel(true)
if (joinCancelledRead) {
try {
it.get()
} catch (e: Exception) {
// Ignored
}
}
}
pendingRead = null
}
}
}
}
companion object {
private const val BUFFER_SIZE = 1024 * 1024
private const val TIMEOUT_MILLIS = 15_000L
}
}

View File

@ -0,0 +1,14 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java.nio.ByteBuffer
import kotlin.reflect.KClass
private val EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0)
val KClass<ByteBuffer>.EMPTY: ByteBuffer
get() = EMPTY_BYTE_BUFFER

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java.io.IOException
import java.io.InputStream
import java.nio.ByteBuffer
class ByteBufferInputStream(buffer: ByteBuffer) : InputStream() {
private var buffer: ByteBuffer? = buffer
override fun read(): Int {
val buffer = ensureOpen()
return if (buffer.hasRemaining()) buffer.get().toInt() and 0xFF else -1
}
override fun read(bytes: ByteArray, offset: Int, length: Int): Int {
val buffer = ensureOpen()
if (length == 0) {
return 0
}
val remaining = buffer.remaining()
if (remaining == 0) {
return -1
}
val readLength = length.coerceAtMost(remaining)
buffer.get(bytes, offset, readLength)
return readLength
}
override fun skip(length: Long): Long {
val buffer = ensureOpen()
if (length <= 0) {
return 0
}
val skippedLength = length.toInt().coerceAtMost(buffer.remaining())
buffer.position(buffer.position() + skippedLength)
return skippedLength.toLong()
}
override fun available(): Int {
val buffer = ensureOpen()
return buffer.remaining()
}
override fun markSupported(): Boolean = true
override fun mark(readlimit: Int) {
val buffer = ensureOpen()
buffer.mark()
}
override fun reset() {
val buffer = ensureOpen()
buffer.reset()
}
override fun close() {
buffer = null
}
private fun ensureOpen(): ByteBuffer = buffer ?: throw IOException("Stream closed");
}

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import net.schmizz.concurrent.Promise
import java.util.concurrent.CancellationException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
inline fun <T, R> Future<T>.map(
crossinline transform: (T) -> R,
crossinline transformException: (Exception) -> Exception = { it }
): Future<R> =
object : Future<R> {
override fun cancel(mayInterruptIfRunning: Boolean): Boolean =
this@map.cancel(mayInterruptIfRunning)
override fun isCancelled(): Boolean = this@map.isCancelled
override fun isDone(): Boolean = this@map.isDone
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): R = transformGet { this@map.get() }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit): R =
transformGet { this@map.get(timeout, unit) }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
private inline fun transformGet(get: () -> T): R {
val result = try {
get()
} catch (e: Exception) {
val exception = try {
transformException(e)
} catch (e2: Exception) {
e2.addSuppressed(e)
throw ExecutionException(e2)
}
check(
exception is ExecutionException || exception is InterruptedException ||
exception is TimeoutException
)
throw exception
}
try {
return transform(result)
} catch (e: Exception) {
throw ExecutionException(e)
}
}
}
fun <T> Deferred<T>.asFuture(): Future<T> =
object : Future<T> {
private val latch = CountDownLatch(1)
init {
invokeOnCompletion { latch.countDown() }
}
override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
cancel()
return this@asFuture.isCancelled
}
override fun isCancelled(): Boolean = this@asFuture.isCancelled
override fun isDone(): Boolean = isCompleted
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): T {
latch.await()
return getCompleted()
}
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit): T {
latch.await(timeout, unit)
return getCompleted()
}
@OptIn(ExperimentalCoroutinesApi::class)
@Throws(ExecutionException::class)
private fun getCompleted(): T =
try {
this@asFuture.getCompleted()
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
throw ExecutionException(e)
}
}
fun <T> Promise<T, *>.asFuture(): Future<T> =
object : Future<T> {
override fun cancel(mayInterruptIfRunning: Boolean): Boolean = false
override fun isCancelled(): Boolean = false
override fun isDone(): Boolean = isFulfilled
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): T = tryRetrieve { retrieve() }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit?): T =
tryRetrieve { retrieve(timeout, unit) }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
private inline fun tryRetrieve(retrieve: () -> T): T =
try {
retrieve()
} catch (e: Exception) {
when (val cause = e.cause) {
is InterruptedException -> {
Thread.interrupted()
throw cause
}
is TimeoutException -> throw cause
else -> throw ExecutionException(e)
}
}
}

View File

@ -5,150 +5,80 @@
package me.zhanghai.android.files.provider.ftp.client package me.zhanghai.android.files.provider.ftp.client
import java8.nio.channels.SeekableByteChannel import me.zhanghai.android.files.compat.nullInputStream
import kotlinx.coroutines.CancellationException import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import kotlinx.coroutines.Deferred import me.zhanghai.android.files.provider.common.ByteBufferInputStream
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout
import me.zhanghai.android.files.provider.common.ForceableChannel
import me.zhanghai.android.files.provider.common.readFully import me.zhanghai.android.files.provider.common.readFully
import me.zhanghai.android.files.util.closeSafe
import org.apache.commons.net.ftp.FTPClient import org.apache.commons.net.ftp.FTPClient
import java.io.ByteArrayInputStream
import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.io.InterruptedIOException import java.io.InputStream
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
class FileByteChannel( class FileByteChannel(
private val client: FTPClient, private val client: FTPClient,
private val releaseClient: (FTPClient) -> Unit, private val releaseClient: (FTPClient) -> Unit,
private val path: String, private val path: String,
private val isAppend: Boolean isAppend: Boolean
) : ForceableChannel, SeekableByteChannel { ) : AbstractFileByteChannel(isAppend, joinCancelledRead = true) {
private val clientLock = Any() private val clientLock = Any()
private var position = 0L @Throws(IOException::class)
private val readBuffer = ReadBuffer() override fun onRead(position: Long, size: Int): ByteBuffer {
private val ioLock = Any() val destination = ByteBuffer.allocate(size)
synchronized(clientLock) {
private var isOpen = true client.restartOffset = position
private val closeLock = Any() val inputStream = client.retrieveFileStream(path)
?: client.throwNegativeReplyCodeException()
try {
val limit = inputStream.use { it.readFully(destination.array(), 0, size) }
destination.limit(limit)
} finally {
// We will likely close the input stream before the file is fully
// read and it will result in a false return value here, but that's
// totally fine.
client.completePendingCommand()
}
}
return destination
}
@Throws(IOException::class) @Throws(IOException::class)
override fun read(destination: ByteBuffer): Int { override fun onWrite(position: Long, source: ByteBuffer) {
ensureOpen() synchronized(clientLock) {
if (isAppend) { client.restartOffset = position
throw NonReadableChannelException() ByteBufferInputStream(source).use {
} if (!client.storeFile(path, it)) {
val remaining = destination.remaining() client.throwNegativeReplyCodeException()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
} }
} }
} }
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun write(source: ByteBuffer): Int { override fun onAppend(source: ByteBuffer) {
ensureOpen() synchronized(clientLock) {
val remaining = source.remaining() ByteBufferInputStream(source).use {
if (remaining == 0) { if (!client.appendFile(path, it)) {
return 0 client.throwNegativeReplyCodeException()
}
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
synchronized(ioLock) {
if (isAppend) {
synchronized(clientLock) {
ByteArrayInputStream(source.array(), source.position(), remaining).use {
if (!client.appendFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
position = getSize()
} else {
synchronized(clientLock) {
client.restartOffset = position
ByteArrayInputStream(source.array(), source.position(), remaining).use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
position += remaining
}
source.position(source.limit())
return remaining
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
synchronized(clientLock) {
client.restartOffset = size
ByteArrayInputStream(byteArrayOf()).use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
} }
} }
position = position.coerceAtMost(size)
} }
return this
} }
@Throws(IOException::class) @Throws(IOException::class)
private fun getSize(): Long { override fun onTruncate(size: Long) {
synchronized(clientLock) {
client.restartOffset = size
InputStream::class.nullInputStream().use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
}
@Throws(IOException::class)
override fun onSize(): Long {
val sizeString = synchronized(clientLock) { val sizeString = synchronized(clientLock) {
client.getSize(path) ?: client.throwNegativeReplyCodeException() client.getSize(path) ?: client.throwNegativeReplyCodeException()
} }
@ -156,145 +86,7 @@ class FileByteChannel(
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun force(metaData: Boolean) { override fun onClose() {
ensureOpen() synchronized(clientLock) { releaseClient(client) }
// Unsupported.
}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
synchronized(ioLock) {
readBuffer.closeSafe()
synchronized(clientLock) { releaseClient(client) }
}
}
}
private inner class ReadBuffer : Closeable {
private val bufferSize = DEFAULT_BUFFER_SIZE
private val timeoutMillis = 15_000L
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingDeferred: Deferred<ByteBuffer>? = null
private val pendingDeferredLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val deferred = synchronized(pendingDeferredLock) {
pendingDeferred?.also { pendingDeferred = null }
} ?: readIntoBufferAsync()
val newBuffer = try {
runBlocking { deferred.await() }
} catch (e: CancellationException) {
throw InterruptedIOException().apply { initCause(e) }
}
buffer.clear()
buffer.put(newBuffer)
buffer.flip()
if (!buffer.hasRemaining()) {
return
}
bufferedPosition += buffer.remaining()
synchronized(pendingDeferredLock) {
pendingDeferred = readIntoBufferAsync()
}
}
private fun readIntoBufferAsync(): Deferred<ByteBuffer> =
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.async(Dispatchers.IO) {
withTimeout(timeoutMillis) {
runInterruptible {
synchronized(clientLock) {
client.restartOffset = bufferedPosition
val inputStream = client.retrieveFileStream(path)
?: client.throwNegativeReplyCodeException()
try {
val buffer = ByteBuffer.allocate(bufferSize)
val limit = inputStream.use {
it.readFully(
buffer.array(), buffer.position(), buffer.remaining()
)
}
buffer.limit(limit)
buffer
} finally {
// We will likely close the input stream before the file is fully
// read and it will result in a false return value here, but that's
// totally fine.
client.completePendingCommand()
}
}
}
}
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingDeferredLock) {
pendingDeferred?.let {
it.cancel()
runBlocking { it.join() }
pendingDeferred = null
}
}
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
synchronized(pendingDeferredLock) {
pendingDeferred?.let {
it.cancel()
runBlocking { it.join() }
pendingDeferred = null
}
}
}
}
companion object {
private const val DEFAULT_BUFFER_SIZE = 1024 * 1024
} }
} }

View File

@ -5,11 +5,12 @@
package me.zhanghai.android.files.provider.sftp.client package me.zhanghai.android.files.provider.sftp.client
import java8.nio.channels.SeekableByteChannel import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import me.zhanghai.android.files.provider.common.ForceableChannel import me.zhanghai.android.files.provider.common.EMPTY
import me.zhanghai.android.files.provider.common.asFuture
import me.zhanghai.android.files.provider.common.map
import me.zhanghai.android.files.util.closeSafe import me.zhanghai.android.files.util.closeSafe
import me.zhanghai.android.files.util.findCauseByClass import me.zhanghai.android.files.util.findCauseByClass
import net.schmizz.concurrent.Promise
import net.schmizz.sshj.sftp.PacketType import net.schmizz.sshj.sftp.PacketType
import net.schmizz.sshj.sftp.RemoteFile import net.schmizz.sshj.sftp.RemoteFile
import net.schmizz.sshj.sftp.RemoteFileAccessor import net.schmizz.sshj.sftp.RemoteFileAccessor
@ -19,139 +20,76 @@ import java.io.IOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.AsynchronousCloseException import java.nio.channels.AsynchronousCloseException
import java.nio.channels.ClosedByInterruptException import java.nio.channels.ClosedByInterruptException
import java.nio.channels.ClosedChannelException import java.util.concurrent.ExecutionException
import java.nio.channels.NonReadableChannelException import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
class FileByteChannel( class FileByteChannel(
private val file: RemoteFile, private val file: RemoteFile,
private val isAppend: Boolean isAppend: Boolean
) : ForceableChannel, SeekableByteChannel { ) : AbstractFileByteChannel(isAppend) {
private var position = 0L override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future<ByteBuffer> =
private val readBuffer = ReadBuffer() try {
private val ioLock = Any() RemoteFileAccessor.asyncRead(file, position, size)
} catch (e: IOException) {
private var isOpen = true throw e.maybeToSpecificException()
private val closeLock = Any()
@Throws(IOException::class)
override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
} }
val remaining = destination.remaining() .asFuture()
if (remaining == 0) { .map(
return 0 { response ->
} val dataLength: Int
return synchronized(ioLock) { when (response.type) {
readBuffer.read(destination).also { PacketType.STATUS -> {
if (it != -1) { response.ensureStatusIs(Response.StatusCode.EOF)
position += it return@map ByteBuffer::class.EMPTY
}
PacketType.DATA -> {
dataLength = response.readUInt32AsInt()
}
else -> throw SFTPException("Unexpected packet type ${response.type}")
}
if (dataLength == 0) {
return@map ByteBuffer::class.EMPTY
}
val length = dataLength.coerceAtMost(size)
ByteBuffer.wrap(response.array(), response.rpos(), length)
}, { e ->
((e as? ExecutionException)?.cause as? IOException)?.maybeToSpecificException()
?.let { ExecutionException(it) } ?: e
} }
} )
@Throws(IOException::class)
override fun onWrite(position: Long, source: ByteBuffer) {
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
try {
file.write(position, source.array(), source.position(), source.remaining())
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
source.position(source.limit())
}
@Throws(IOException::class)
override fun onTruncate(size: Long) {
try {
file.setLength(size)
} catch (e: IOException) {
throw e.maybeToSpecificException()
} }
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun write(source: ByteBuffer): Int { override fun onSize(): Long =
ensureOpen()
val remaining = source.remaining()
if (remaining == 0) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
try {
file.write(position, source.array(), source.position(), remaining)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
source.position(source.limit())
position += remaining
return remaining
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
try {
file.setLength(size)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
private fun getSize(): Long =
try{ try{
file.length() file.length()
} catch (e: IOException) { } catch (e: IOException) {
throw e.maybeToSpecificException() throw e.maybeToSpecificException()
} }
@Throws(IOException::class)
override fun force(metaData: Boolean) {
ensureOpen()
// Unsupported.
}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
private fun IOException.maybeToSpecificException(): IOException = private fun IOException.maybeToSpecificException(): IOException =
when { when {
this is SFTPException && statusCode == Response.StatusCode.INVALID_HANDLE -> { this is SFTPException && statusCode == Response.StatusCode.INVALID_HANDLE -> {
synchronized(closeLock) { isOpen = false } setClosed()
AsynchronousCloseException().apply { initCause(this@maybeToSpecificException) } AsynchronousCloseException().apply { initCause(this@maybeToSpecificException) }
} }
findCauseByClass<InterruptedException>() != null -> { findCauseByClass<InterruptedException>() != null -> {
@ -161,122 +99,15 @@ class FileByteChannel(
else -> this else -> this
} }
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class) @Throws(IOException::class)
override fun close() { override fun onClose() {
synchronized(closeLock) { try {
if (!isOpen) { file.close()
return } catch (e: SFTPException) {
} // NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server.
isOpen = false if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) {
try { throw e
file.close()
} catch (e: SFTPException) {
// NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server.
if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) {
throw e
}
} }
} }
} }
private inner class ReadBuffer {
private val bufferSize: Int = DEFAULT_BUFFER_SIZE
private val timeout: Long
init {
val engine = RemoteFileAccessor.getRequester(file)
timeout = engine.timeoutMs.toLong()
}
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingPromise: Promise<Response, SFTPException>? = null
private val pendingPromiseLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val promise = synchronized(pendingPromiseLock) {
pendingPromise?.also { pendingPromise = null }
} ?: readIntoBufferAsync()
val response = try {
promise.retrieve(timeout, TimeUnit.MILLISECONDS)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
val dataLength: Int
when (response.type) {
PacketType.STATUS -> {
response.ensureStatusIs(Response.StatusCode.EOF)
buffer.limit(0)
return
}
PacketType.DATA -> {
dataLength = response.readUInt32AsInt()
}
else -> throw SFTPException("Unexpected packet type ${response.type}")
}
if (dataLength == 0) {
buffer.limit(0)
return
}
buffer.clear()
val length = dataLength.coerceAtMost(buffer.remaining())
buffer.put(response.array(), response.rpos(), length)
buffer.flip()
bufferedPosition += length
synchronized(pendingPromiseLock) {
try {
pendingPromise = readIntoBufferAsync()
} catch (e: IOException) {
e.printStackTrace()
}
}
}
@Throws(IOException::class)
private fun readIntoBufferAsync(): Promise<Response, SFTPException> =
try {
RemoteFileAccessor.asyncRead(file, bufferedPosition, bufferSize)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingPromiseLock) { pendingPromise = null }
buffer.limit(0)
bufferedPosition = newPosition
}
}
}
companion object {
// @see SmbConfig.DEFAULT_BUFFER_SIZE
private const val DEFAULT_BUFFER_SIZE = 1024 * 1024
}
} }

View File

@ -8,129 +8,78 @@ package me.zhanghai.android.files.provider.smb.client
import com.hierynomus.mserref.NtStatus import com.hierynomus.mserref.NtStatus
import com.hierynomus.msfscc.fileinformation.FileStandardInformation import com.hierynomus.msfscc.fileinformation.FileStandardInformation
import com.hierynomus.mssmb2.SMBApiException import com.hierynomus.mssmb2.SMBApiException
import com.hierynomus.mssmb2.messages.SMB2ReadResponse
import com.hierynomus.protocol.commons.concurrent.Futures
import com.hierynomus.protocol.transport.TransportException
import com.hierynomus.smbj.common.SMBRuntimeException import com.hierynomus.smbj.common.SMBRuntimeException
import com.hierynomus.smbj.io.ByteChunkProvider import com.hierynomus.smbj.io.ByteChunkProvider
import com.hierynomus.smbj.share.File import com.hierynomus.smbj.share.File
import com.hierynomus.smbj.share.FileAccessor import com.hierynomus.smbj.share.FileAccessor
import java8.nio.channels.SeekableByteChannel import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import me.zhanghai.android.files.provider.common.ForceableChannel import me.zhanghai.android.files.provider.common.EMPTY
import me.zhanghai.android.files.provider.common.map
import me.zhanghai.android.files.util.closeSafe import me.zhanghai.android.files.util.closeSafe
import me.zhanghai.android.files.util.findCauseByClass import me.zhanghai.android.files.util.findCauseByClass
import java.io.Closeable
import java.io.IOException import java.io.IOException
import java.io.InterruptedIOException import java.io.InterruptedIOException
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.nio.channels.AsynchronousCloseException import java.nio.channels.AsynchronousCloseException
import java.nio.channels.ClosedByInterruptException import java.nio.channels.ClosedByInterruptException
import java.nio.channels.ClosedChannelException import java.util.concurrent.ExecutionException
import java.nio.channels.NonReadableChannelException
import java.util.concurrent.Future import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
class FileByteChannel( class FileByteChannel(
private val file: File, private val file: File,
private val isAppend: Boolean isAppend: Boolean
) : ForceableChannel, SeekableByteChannel { // Cancelling reads leads to TransportException: Received response with unknown sequence number
private var position = 0L ) : AbstractFileByteChannel(isAppend, shouldCancelRead = false) {
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
@Throws(IOException::class) @Throws(IOException::class)
override fun read(destination: ByteBuffer): Int { override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future<ByteBuffer> =
ensureOpen() try {
if (isAppend) { FileAccessor.readAsync(file, position, size)
throw NonReadableChannelException() } catch (e: SMBRuntimeException) {
throw e.toIOException()
} }
val remaining = destination.remaining() .map(
if (remaining == 0) { { response ->
return 0 when (response.header.statusCode) {
} NtStatus.STATUS_END_OF_FILE.value -> {
return synchronized(ioLock) { return@map ByteBuffer::class.EMPTY
readBuffer.read(destination).also { }
if (it != -1) { NtStatus.STATUS_SUCCESS.value -> {}
position += it else -> throw SMBApiException(response.header, "Read failed for $this")
.toIOException()
}
val data = response.data
if (data.isEmpty()) {
return@map ByteBuffer::class.EMPTY
}
val length = data.size.coerceAtMost(size)
ByteBuffer.wrap(data, 0, length)
}, { e ->
ExecutionException(SMBRuntimeException(e).toIOException())
} }
} )
@Throws(IOException::class)
override fun onWrite(position: Long, source: ByteBuffer) {
val sourcePosition = source.position()
val bytesWritten = try {
file.write(ByteBufferChunkProvider(source, position)).toInt()
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
source.position(sourcePosition + bytesWritten)
}
@Throws(IOException::class)
override fun onTruncate(size: Long) {
try {
file.setLength(size)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
} }
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun write(source: ByteBuffer): Int { override fun onSize(): Long =
ensureOpen()
if (!source.hasRemaining()) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return try {
file.write(ByteBufferChunkProvider(source, position)).toInt()
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}.also {
position += it
}
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
try {
file.setLength(size)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
private fun getSize(): Long =
try { try {
file.getFileInformation(FileStandardInformation::class.java).endOfFile file.getFileInformation(FileStandardInformation::class.java).endOfFile
} catch (e: SMBRuntimeException) { } catch (e: SMBRuntimeException) {
@ -138,8 +87,7 @@ class FileByteChannel(
} }
@Throws(IOException::class) @Throws(IOException::class)
override fun force(metaData: Boolean) { override fun onForce(metaData: Boolean) {
ensureOpen()
try { try {
file.flush() file.flush()
} catch (e: SMBRuntimeException) { } catch (e: SMBRuntimeException) {
@ -147,20 +95,11 @@ class FileByteChannel(
} }
} }
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
private fun SMBRuntimeException.toIOException(): IOException = private fun SMBRuntimeException.toIOException(): IOException =
when { when {
findCauseByClass<SMBApiException>() findCauseByClass<SMBApiException>()
.let { it != null && it.status == NtStatus.STATUS_FILE_CLOSED } -> { .let { it != null && it.status == NtStatus.STATUS_FILE_CLOSED } -> {
synchronized(closeLock) { isOpen = false } setClosed()
AsynchronousCloseException().apply { initCause(this@toIOException) } AsynchronousCloseException().apply { initCause(this@toIOException) }
} }
findCauseByClass<InterruptedException>() != null -> { findCauseByClass<InterruptedException>() != null -> {
@ -170,162 +109,37 @@ class FileByteChannel(
else -> IOException(this) else -> IOException(this)
} }
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class) @Throws(IOException::class)
override fun close() { override fun onClose() {
synchronized(closeLock) { try {
if (!isOpen) { file.close()
return } catch (e: SMBRuntimeException) {
} throw when {
isOpen = false e.findCauseByClass<InterruptedException>() != null ->
readBuffer.closeSafe() InterruptedIOException().apply { initCause(e) }
try { else -> IOException(e)
file.close()
} catch (e: SMBRuntimeException) {
throw when {
e.findCauseByClass<InterruptedException>() != null ->
InterruptedIOException().apply { initCause(e) }
else -> IOException(e)
}
} }
} }
} }
private inner class ReadBuffer : Closeable { private class ByteBufferChunkProvider(
private val bufferSize: Int private val buffer: ByteBuffer,
private val timeout: Long offset: Long
) : ByteChunkProvider() {
init { init {
val treeConnect = file.diskShare.treeConnect this.offset = offset
val config = treeConnect.config
bufferSize = config.readBufferSize
.coerceAtMost(treeConnect.session.connection.negotiatedProtocol.maxReadSize)
timeout = config.readTimeout
} }
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) } override fun isAvailable(): Boolean = buffer.hasRemaining()
private var bufferedPosition = 0L
private var pendingFuture: Future<SMB2ReadResponse>? = null override fun bytesLeft(): Int = buffer.remaining()
private val pendingFutureLock = Any()
@Throws(IOException::class) override fun prepareWrite(maxBytesToPrepare: Int) {}
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) { override fun getChunk(chunk: ByteArray): Int {
readIntoBuffer() val length = chunk.size.coerceAtMost(buffer.remaining())
if (!buffer.hasRemaining()) { buffer.get(chunk, 0, length)
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length return length
} }
@Throws(IOException::class)
private fun readIntoBuffer() {
val future = synchronized(pendingFutureLock) {
pendingFuture?.also { pendingFuture = null }
} ?: readIntoBufferAsync()
val response = try {
receive(future, timeout)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
when (response.header.statusCode) {
NtStatus.STATUS_END_OF_FILE.value -> {
buffer.limit(0)
return
}
NtStatus.STATUS_SUCCESS.value -> {}
else -> throw SMBApiException(response.header, "Read failed for $this")
.toIOException()
}
val data = response.data
if (data.isEmpty()) {
buffer.limit(0)
return
}
buffer.clear()
val length = data.size.coerceAtMost(buffer.remaining())
buffer.put(data, 0, length)
buffer.flip()
bufferedPosition += length
synchronized(pendingFutureLock) {
try {
pendingFuture = readIntoBufferAsync()
} catch (e: IOException) {
e.printStackTrace()
}
}
}
// @see com.hierynomus.smbj.share.Share.receive
@Throws(SMBRuntimeException::class)
private fun <T> receive(future: Future<T>, timeout: Long): T =
try {
Futures.get(future, timeout, TimeUnit.MILLISECONDS, TransportException.Wrapper)
} catch (e: TransportException) {
throw SMBRuntimeException(e)
}
@Throws(IOException::class)
private fun readIntoBufferAsync(): Future<SMB2ReadResponse> =
try {
FileAccessor.readAsync(file, bufferedPosition, bufferSize)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingFutureLock) {
// TransportException: Received response with unknown sequence number
//pendingFuture?.cancel(true)?.also { pendingFuture = null }
pendingFuture = null
}
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
synchronized(pendingFutureLock) {
// TransportException: Received response with unknown sequence number
//pendingFuture?.cancel(true)?.also { pendingFuture = null }
pendingFuture = null
}
}
}
}
private class ByteBufferChunkProvider(
private val buffer: ByteBuffer,
offset: Long
) : ByteChunkProvider() {
init {
this.offset = offset
}
override fun isAvailable(): Boolean = buffer.hasRemaining()
override fun bytesLeft(): Int = buffer.remaining()
override fun prepareWrite(maxBytesToPrepare: Int) {}
override fun getChunk(chunk: ByteArray): Int {
val length = chunk.size.coerceAtMost(buffer.remaining())
buffer.get(chunk, 0, length)
return length
} }
} }