Refactor: Extract AbstractFileByteChannel from FTP, SFTP and SMB

This commit is contained in:
Hai Zhang 2024-02-11 21:48:23 -08:00
parent 79114bfefe
commit 46acf5894c
8 changed files with 745 additions and 749 deletions

View File

@ -0,0 +1,48 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.compat
import java.io.IOException
import java.io.InputStream
import kotlin.reflect.KClass
fun KClass<InputStream>.nullInputStream(): InputStream =
object : InputStream() {
private var closed = false
override fun read(): Int {
ensureOpen()
return -1;
}
override fun read(bytes: ByteArray, offset: Int, length: Int): Int {
if (!(offset >= 0 && length >= 0 && length <= bytes.size - offset)) {
throw IndexOutOfBoundsException()
}
ensureOpen()
return if (length == 0) 0 else -1
}
override fun skip(length: Long): Long {
ensureOpen()
return 0
}
override fun available(): Int {
ensureOpen()
return 0
}
override fun close() {
closed = true
}
private fun ensureOpen() {
if (closed) {
throw IOException("Stream closed")
}
}
}

View File

@ -0,0 +1,299 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java8.nio.channels.SeekableByteChannel
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout
import me.zhanghai.android.files.util.closeSafe
import java.io.Closeable
import java.io.IOException
import java.io.InterruptedIOException
import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
import java.util.concurrent.CancellationException
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
abstract class AbstractFileByteChannel(
private val isAppend: Boolean,
private val shouldCancelRead: Boolean = true,
private val joinCancelledRead: Boolean = false
) : ForceableChannel, SeekableByteChannel {
private var position = 0L
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
@Throws(IOException::class)
final override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
}
val remaining = destination.remaining()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
}
}
}
}
protected open fun onReadAsync(
position: Long,
size: Int,
timeoutMillis: Long
): Future<ByteBuffer> =
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.async(Dispatchers.IO) {
withTimeout(timeoutMillis) {
runInterruptible {
onRead(position, size)
}
}
}
.asFuture()
@Throws(IOException::class)
protected open fun onRead(position: Long, size: Int): ByteBuffer {
throw NotImplementedError()
}
@Throws(IOException::class)
final override fun write(source: ByteBuffer): Int {
ensureOpen()
val remaining = source.remaining()
if (remaining == 0) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
onAppend(source)
position = onSize()
} else {
onWrite(position, source)
position += remaining - source.remaining()
}
return remaining
}
}
@Throws(IOException::class)
protected abstract fun onWrite(position: Long, source: ByteBuffer)
@Throws(IOException::class)
protected open fun onAppend(source: ByteBuffer) {
val position = onSize()
onWrite(position, source)
}
@Throws(IOException::class)
final override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = onSize()
}
return position
}
}
final override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
final override fun size(): Long {
ensureOpen()
return onSize()
}
@Throws(IOException::class)
final override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = onSize()
if (size >= currentSize) {
return this
}
onTruncate(size)
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
protected abstract fun onTruncate(size: Long)
@Throws(IOException::class)
protected abstract fun onSize(): Long
@Throws(IOException::class)
final override fun force(metaData: Boolean) {
ensureOpen()
synchronized(ioLock) {
onForce(metaData)
}
}
@Throws(IOException::class)
protected open fun onForce(metaData: Boolean) {}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
final override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
final override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
synchronized(ioLock) {
readBuffer.closeSafe()
onClose()
}
}
}
protected fun setClosed() {
synchronized(closeLock) {
isOpen = false
}
}
@Throws(IOException::class)
protected open fun onClose() {}
private inner class ReadBuffer : Closeable {
private val buffer = ByteBuffer.allocate(BUFFER_SIZE).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingRead: Future<ByteBuffer>? = null
private val pendingReadLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val future = synchronized(pendingReadLock) {
pendingRead?.also { pendingRead = null }
} ?: readIntoBufferAsync()
val newBuffer = try {
future.get()
} catch (e: CancellationException) {
throw InterruptedIOException().apply { initCause(e) }
} catch (e: InterruptedException) {
throw InterruptedIOException().apply { initCause(e) }
} catch (e: ExecutionException) {
val exception = e.cause ?: e
if (exception is IOException) {
throw exception
} else {
throw IOException(exception)
}
}
buffer.clear()
buffer.put(newBuffer)
buffer.flip()
if (!buffer.hasRemaining()) {
return
}
bufferedPosition += buffer.remaining()
synchronized(pendingReadLock) {
pendingRead = readIntoBufferAsync()
}
}
private fun readIntoBufferAsync(): Future<ByteBuffer> =
onReadAsync(bufferedPosition, BUFFER_SIZE, TIMEOUT_MILLIS)
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
cancelPendingRead()
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
cancelPendingRead()
}
private fun cancelPendingRead() {
synchronized(pendingReadLock) {
pendingRead?.let {
if (shouldCancelRead) {
it.cancel(true)
if (joinCancelledRead) {
try {
it.get()
} catch (e: Exception) {
// Ignored
}
}
}
pendingRead = null
}
}
}
}
companion object {
private const val BUFFER_SIZE = 1024 * 1024
private const val TIMEOUT_MILLIS = 15_000L
}
}

View File

@ -0,0 +1,14 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java.nio.ByteBuffer
import kotlin.reflect.KClass
private val EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0)
val KClass<ByteBuffer>.EMPTY: ByteBuffer
get() = EMPTY_BYTE_BUFFER

View File

@ -0,0 +1,66 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import java.io.IOException
import java.io.InputStream
import java.nio.ByteBuffer
class ByteBufferInputStream(buffer: ByteBuffer) : InputStream() {
private var buffer: ByteBuffer? = buffer
override fun read(): Int {
val buffer = ensureOpen()
return if (buffer.hasRemaining()) buffer.get().toInt() and 0xFF else -1
}
override fun read(bytes: ByteArray, offset: Int, length: Int): Int {
val buffer = ensureOpen()
if (length == 0) {
return 0
}
val remaining = buffer.remaining()
if (remaining == 0) {
return -1
}
val readLength = length.coerceAtMost(remaining)
buffer.get(bytes, offset, readLength)
return readLength
}
override fun skip(length: Long): Long {
val buffer = ensureOpen()
if (length <= 0) {
return 0
}
val skippedLength = length.toInt().coerceAtMost(buffer.remaining())
buffer.position(buffer.position() + skippedLength)
return skippedLength.toLong()
}
override fun available(): Int {
val buffer = ensureOpen()
return buffer.remaining()
}
override fun markSupported(): Boolean = true
override fun mark(readlimit: Int) {
val buffer = ensureOpen()
buffer.mark()
}
override fun reset() {
val buffer = ensureOpen()
buffer.reset()
}
override fun close() {
buffer = null
}
private fun ensureOpen(): ByteBuffer = buffer ?: throw IOException("Stream closed");
}

View File

@ -0,0 +1,132 @@
/*
* Copyright (c) 2024 Hai Zhang <dreaming.in.code.zh@gmail.com>
* All Rights Reserved.
*/
package me.zhanghai.android.files.provider.common
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import net.schmizz.concurrent.Promise
import java.util.concurrent.CancellationException
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
inline fun <T, R> Future<T>.map(
crossinline transform: (T) -> R,
crossinline transformException: (Exception) -> Exception = { it }
): Future<R> =
object : Future<R> {
override fun cancel(mayInterruptIfRunning: Boolean): Boolean =
this@map.cancel(mayInterruptIfRunning)
override fun isCancelled(): Boolean = this@map.isCancelled
override fun isDone(): Boolean = this@map.isDone
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): R = transformGet { this@map.get() }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit): R =
transformGet { this@map.get(timeout, unit) }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
private inline fun transformGet(get: () -> T): R {
val result = try {
get()
} catch (e: Exception) {
val exception = try {
transformException(e)
} catch (e2: Exception) {
e2.addSuppressed(e)
throw ExecutionException(e2)
}
check(
exception is ExecutionException || exception is InterruptedException ||
exception is TimeoutException
)
throw exception
}
try {
return transform(result)
} catch (e: Exception) {
throw ExecutionException(e)
}
}
}
fun <T> Deferred<T>.asFuture(): Future<T> =
object : Future<T> {
private val latch = CountDownLatch(1)
init {
invokeOnCompletion { latch.countDown() }
}
override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
cancel()
return this@asFuture.isCancelled
}
override fun isCancelled(): Boolean = this@asFuture.isCancelled
override fun isDone(): Boolean = isCompleted
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): T {
latch.await()
return getCompleted()
}
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit): T {
latch.await(timeout, unit)
return getCompleted()
}
@OptIn(ExperimentalCoroutinesApi::class)
@Throws(ExecutionException::class)
private fun getCompleted(): T =
try {
this@asFuture.getCompleted()
} catch (e: CancellationException) {
throw e
} catch (e: Exception) {
throw ExecutionException(e)
}
}
fun <T> Promise<T, *>.asFuture(): Future<T> =
object : Future<T> {
override fun cancel(mayInterruptIfRunning: Boolean): Boolean = false
override fun isCancelled(): Boolean = false
override fun isDone(): Boolean = isFulfilled
@Throws(ExecutionException::class, InterruptedException::class)
override fun get(): T = tryRetrieve { retrieve() }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
override fun get(timeout: Long, unit: TimeUnit?): T =
tryRetrieve { retrieve(timeout, unit) }
@Throws(ExecutionException::class, InterruptedException::class, TimeoutException::class)
private inline fun tryRetrieve(retrieve: () -> T): T =
try {
retrieve()
} catch (e: Exception) {
when (val cause = e.cause) {
is InterruptedException -> {
Thread.interrupted()
throw cause
}
is TimeoutException -> throw cause
else -> throw ExecutionException(e)
}
}
}

View File

@ -5,150 +5,80 @@
package me.zhanghai.android.files.provider.ftp.client
import java8.nio.channels.SeekableByteChannel
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.runInterruptible
import kotlinx.coroutines.withTimeout
import me.zhanghai.android.files.provider.common.ForceableChannel
import me.zhanghai.android.files.compat.nullInputStream
import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import me.zhanghai.android.files.provider.common.ByteBufferInputStream
import me.zhanghai.android.files.provider.common.readFully
import me.zhanghai.android.files.util.closeSafe
import org.apache.commons.net.ftp.FTPClient
import java.io.ByteArrayInputStream
import java.io.Closeable
import java.io.IOException
import java.io.InterruptedIOException
import java.io.InputStream
import java.nio.ByteBuffer
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
class FileByteChannel(
private val client: FTPClient,
private val releaseClient: (FTPClient) -> Unit,
private val path: String,
private val isAppend: Boolean
) : ForceableChannel, SeekableByteChannel {
isAppend: Boolean
) : AbstractFileByteChannel(isAppend, joinCancelledRead = true) {
private val clientLock = Any()
private var position = 0L
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
@Throws(IOException::class)
override fun onRead(position: Long, size: Int): ByteBuffer {
val destination = ByteBuffer.allocate(size)
synchronized(clientLock) {
client.restartOffset = position
val inputStream = client.retrieveFileStream(path)
?: client.throwNegativeReplyCodeException()
try {
val limit = inputStream.use { it.readFully(destination.array(), 0, size) }
destination.limit(limit)
} finally {
// We will likely close the input stream before the file is fully
// read and it will result in a false return value here, but that's
// totally fine.
client.completePendingCommand()
}
}
return destination
}
@Throws(IOException::class)
override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
}
val remaining = destination.remaining()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
override fun onWrite(position: Long, source: ByteBuffer) {
synchronized(clientLock) {
client.restartOffset = position
ByteBufferInputStream(source).use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
}
@Throws(IOException::class)
override fun write(source: ByteBuffer): Int {
ensureOpen()
val remaining = source.remaining()
if (remaining == 0) {
return 0
}
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
synchronized(ioLock) {
if (isAppend) {
synchronized(clientLock) {
ByteArrayInputStream(source.array(), source.position(), remaining).use {
if (!client.appendFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
position = getSize()
} else {
synchronized(clientLock) {
client.restartOffset = position
ByteArrayInputStream(source.array(), source.position(), remaining).use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
position += remaining
}
source.position(source.limit())
return remaining
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
synchronized(clientLock) {
client.restartOffset = size
ByteArrayInputStream(byteArrayOf()).use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
override fun onAppend(source: ByteBuffer) {
synchronized(clientLock) {
ByteBufferInputStream(source).use {
if (!client.appendFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
private fun getSize(): Long {
override fun onTruncate(size: Long) {
synchronized(clientLock) {
client.restartOffset = size
InputStream::class.nullInputStream().use {
if (!client.storeFile(path, it)) {
client.throwNegativeReplyCodeException()
}
}
}
}
@Throws(IOException::class)
override fun onSize(): Long {
val sizeString = synchronized(clientLock) {
client.getSize(path) ?: client.throwNegativeReplyCodeException()
}
@ -156,145 +86,7 @@ class FileByteChannel(
}
@Throws(IOException::class)
override fun force(metaData: Boolean) {
ensureOpen()
// Unsupported.
}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
synchronized(ioLock) {
readBuffer.closeSafe()
synchronized(clientLock) { releaseClient(client) }
}
}
}
private inner class ReadBuffer : Closeable {
private val bufferSize = DEFAULT_BUFFER_SIZE
private val timeoutMillis = 15_000L
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingDeferred: Deferred<ByteBuffer>? = null
private val pendingDeferredLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val deferred = synchronized(pendingDeferredLock) {
pendingDeferred?.also { pendingDeferred = null }
} ?: readIntoBufferAsync()
val newBuffer = try {
runBlocking { deferred.await() }
} catch (e: CancellationException) {
throw InterruptedIOException().apply { initCause(e) }
}
buffer.clear()
buffer.put(newBuffer)
buffer.flip()
if (!buffer.hasRemaining()) {
return
}
bufferedPosition += buffer.remaining()
synchronized(pendingDeferredLock) {
pendingDeferred = readIntoBufferAsync()
}
}
private fun readIntoBufferAsync(): Deferred<ByteBuffer> =
@OptIn(DelicateCoroutinesApi::class)
GlobalScope.async(Dispatchers.IO) {
withTimeout(timeoutMillis) {
runInterruptible {
synchronized(clientLock) {
client.restartOffset = bufferedPosition
val inputStream = client.retrieveFileStream(path)
?: client.throwNegativeReplyCodeException()
try {
val buffer = ByteBuffer.allocate(bufferSize)
val limit = inputStream.use {
it.readFully(
buffer.array(), buffer.position(), buffer.remaining()
)
}
buffer.limit(limit)
buffer
} finally {
// We will likely close the input stream before the file is fully
// read and it will result in a false return value here, but that's
// totally fine.
client.completePendingCommand()
}
}
}
}
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingDeferredLock) {
pendingDeferred?.let {
it.cancel()
runBlocking { it.join() }
pendingDeferred = null
}
}
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
synchronized(pendingDeferredLock) {
pendingDeferred?.let {
it.cancel()
runBlocking { it.join() }
pendingDeferred = null
}
}
}
}
companion object {
private const val DEFAULT_BUFFER_SIZE = 1024 * 1024
override fun onClose() {
synchronized(clientLock) { releaseClient(client) }
}
}

View File

@ -5,11 +5,12 @@
package me.zhanghai.android.files.provider.sftp.client
import java8.nio.channels.SeekableByteChannel
import me.zhanghai.android.files.provider.common.ForceableChannel
import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import me.zhanghai.android.files.provider.common.EMPTY
import me.zhanghai.android.files.provider.common.asFuture
import me.zhanghai.android.files.provider.common.map
import me.zhanghai.android.files.util.closeSafe
import me.zhanghai.android.files.util.findCauseByClass
import net.schmizz.concurrent.Promise
import net.schmizz.sshj.sftp.PacketType
import net.schmizz.sshj.sftp.RemoteFile
import net.schmizz.sshj.sftp.RemoteFileAccessor
@ -19,139 +20,76 @@ import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.ClosedByInterruptException
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
import java.util.concurrent.TimeUnit
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
class FileByteChannel(
private val file: RemoteFile,
private val isAppend: Boolean
) : ForceableChannel, SeekableByteChannel {
private var position = 0L
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
@Throws(IOException::class)
override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
isAppend: Boolean
) : AbstractFileByteChannel(isAppend) {
override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future<ByteBuffer> =
try {
RemoteFileAccessor.asyncRead(file, position, size)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
val remaining = destination.remaining()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
.asFuture()
.map(
{ response ->
val dataLength: Int
when (response.type) {
PacketType.STATUS -> {
response.ensureStatusIs(Response.StatusCode.EOF)
return@map ByteBuffer::class.EMPTY
}
PacketType.DATA -> {
dataLength = response.readUInt32AsInt()
}
else -> throw SFTPException("Unexpected packet type ${response.type}")
}
if (dataLength == 0) {
return@map ByteBuffer::class.EMPTY
}
val length = dataLength.coerceAtMost(size)
ByteBuffer.wrap(response.array(), response.rpos(), length)
}, { e ->
((e as? ExecutionException)?.cause as? IOException)?.maybeToSpecificException()
?.let { ExecutionException(it) } ?: e
}
}
)
@Throws(IOException::class)
override fun onWrite(position: Long, source: ByteBuffer) {
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
try {
file.write(position, source.array(), source.position(), source.remaining())
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
source.position(source.limit())
}
@Throws(IOException::class)
override fun onTruncate(size: Long) {
try {
file.setLength(size)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
}
@Throws(IOException::class)
override fun write(source: ByteBuffer): Int {
ensureOpen()
val remaining = source.remaining()
if (remaining == 0) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
// I don't think we are using native or read-only ByteBuffer, so just call array() here.
try {
file.write(position, source.array(), source.position(), remaining)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
source.position(source.limit())
position += remaining
return remaining
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
try {
file.setLength(size)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
private fun getSize(): Long =
override fun onSize(): Long =
try{
file.length()
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
@Throws(IOException::class)
override fun force(metaData: Boolean) {
ensureOpen()
// Unsupported.
}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
private fun IOException.maybeToSpecificException(): IOException =
when {
this is SFTPException && statusCode == Response.StatusCode.INVALID_HANDLE -> {
synchronized(closeLock) { isOpen = false }
setClosed()
AsynchronousCloseException().apply { initCause(this@maybeToSpecificException) }
}
findCauseByClass<InterruptedException>() != null -> {
@ -161,122 +99,15 @@ class FileByteChannel(
else -> this
}
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
try {
file.close()
} catch (e: SFTPException) {
// NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server.
if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) {
throw e
}
override fun onClose() {
try {
file.close()
} catch (e: SFTPException) {
// NO_SUCH_FILE is returned when canceling an in-progress copy to SFTP server.
if (e.statusCode != Response.StatusCode.NO_SUCH_FILE) {
throw e
}
}
}
private inner class ReadBuffer {
private val bufferSize: Int = DEFAULT_BUFFER_SIZE
private val timeout: Long
init {
val engine = RemoteFileAccessor.getRequester(file)
timeout = engine.timeoutMs.toLong()
}
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) }
private var bufferedPosition = 0L
private var pendingPromise: Promise<Response, SFTPException>? = null
private val pendingPromiseLock = Any()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val promise = synchronized(pendingPromiseLock) {
pendingPromise?.also { pendingPromise = null }
} ?: readIntoBufferAsync()
val response = try {
promise.retrieve(timeout, TimeUnit.MILLISECONDS)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
val dataLength: Int
when (response.type) {
PacketType.STATUS -> {
response.ensureStatusIs(Response.StatusCode.EOF)
buffer.limit(0)
return
}
PacketType.DATA -> {
dataLength = response.readUInt32AsInt()
}
else -> throw SFTPException("Unexpected packet type ${response.type}")
}
if (dataLength == 0) {
buffer.limit(0)
return
}
buffer.clear()
val length = dataLength.coerceAtMost(buffer.remaining())
buffer.put(response.array(), response.rpos(), length)
buffer.flip()
bufferedPosition += length
synchronized(pendingPromiseLock) {
try {
pendingPromise = readIntoBufferAsync()
} catch (e: IOException) {
e.printStackTrace()
}
}
}
@Throws(IOException::class)
private fun readIntoBufferAsync(): Promise<Response, SFTPException> =
try {
RemoteFileAccessor.asyncRead(file, bufferedPosition, bufferSize)
} catch (e: IOException) {
throw e.maybeToSpecificException()
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingPromiseLock) { pendingPromise = null }
buffer.limit(0)
bufferedPosition = newPosition
}
}
}
companion object {
// @see SmbConfig.DEFAULT_BUFFER_SIZE
private const val DEFAULT_BUFFER_SIZE = 1024 * 1024
}
}

View File

@ -8,129 +8,78 @@ package me.zhanghai.android.files.provider.smb.client
import com.hierynomus.mserref.NtStatus
import com.hierynomus.msfscc.fileinformation.FileStandardInformation
import com.hierynomus.mssmb2.SMBApiException
import com.hierynomus.mssmb2.messages.SMB2ReadResponse
import com.hierynomus.protocol.commons.concurrent.Futures
import com.hierynomus.protocol.transport.TransportException
import com.hierynomus.smbj.common.SMBRuntimeException
import com.hierynomus.smbj.io.ByteChunkProvider
import com.hierynomus.smbj.share.File
import com.hierynomus.smbj.share.FileAccessor
import java8.nio.channels.SeekableByteChannel
import me.zhanghai.android.files.provider.common.ForceableChannel
import me.zhanghai.android.files.provider.common.AbstractFileByteChannel
import me.zhanghai.android.files.provider.common.EMPTY
import me.zhanghai.android.files.provider.common.map
import me.zhanghai.android.files.util.closeSafe
import me.zhanghai.android.files.util.findCauseByClass
import java.io.Closeable
import java.io.IOException
import java.io.InterruptedIOException
import java.nio.ByteBuffer
import java.nio.channels.AsynchronousCloseException
import java.nio.channels.ClosedByInterruptException
import java.nio.channels.ClosedChannelException
import java.nio.channels.NonReadableChannelException
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
class FileByteChannel(
private val file: File,
private val isAppend: Boolean
) : ForceableChannel, SeekableByteChannel {
private var position = 0L
private val readBuffer = ReadBuffer()
private val ioLock = Any()
private var isOpen = true
private val closeLock = Any()
isAppend: Boolean
// Cancelling reads leads to TransportException: Received response with unknown sequence number
) : AbstractFileByteChannel(isAppend, shouldCancelRead = false) {
@Throws(IOException::class)
override fun read(destination: ByteBuffer): Int {
ensureOpen()
if (isAppend) {
throw NonReadableChannelException()
override fun onReadAsync(position: Long, size: Int, timeoutMillis: Long): Future<ByteBuffer> =
try {
FileAccessor.readAsync(file, position, size)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
val remaining = destination.remaining()
if (remaining == 0) {
return 0
}
return synchronized(ioLock) {
readBuffer.read(destination).also {
if (it != -1) {
position += it
.map(
{ response ->
when (response.header.statusCode) {
NtStatus.STATUS_END_OF_FILE.value -> {
return@map ByteBuffer::class.EMPTY
}
NtStatus.STATUS_SUCCESS.value -> {}
else -> throw SMBApiException(response.header, "Read failed for $this")
.toIOException()
}
val data = response.data
if (data.isEmpty()) {
return@map ByteBuffer::class.EMPTY
}
val length = data.size.coerceAtMost(size)
ByteBuffer.wrap(data, 0, length)
}, { e ->
ExecutionException(SMBRuntimeException(e).toIOException())
}
}
)
@Throws(IOException::class)
override fun onWrite(position: Long, source: ByteBuffer) {
val sourcePosition = source.position()
val bytesWritten = try {
file.write(ByteBufferChunkProvider(source, position)).toInt()
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
source.position(sourcePosition + bytesWritten)
}
@Throws(IOException::class)
override fun onTruncate(size: Long) {
try {
file.setLength(size)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
}
@Throws(IOException::class)
override fun write(source: ByteBuffer): Int {
ensureOpen()
if (!source.hasRemaining()) {
return 0
}
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return try {
file.write(ByteBufferChunkProvider(source, position)).toInt()
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}.also {
position += it
}
}
}
@Throws(IOException::class)
override fun position(): Long {
ensureOpen()
synchronized(ioLock) {
if (isAppend) {
position = getSize()
}
return position
}
}
override fun position(newPosition: Long): SeekableByteChannel {
ensureOpen()
if (isAppend) {
// Ignored.
return this
}
synchronized(ioLock) {
readBuffer.reposition(position, newPosition)
position = newPosition
}
return this
}
@Throws(IOException::class)
override fun size(): Long {
ensureOpen()
return getSize()
}
@Throws(IOException::class)
override fun truncate(size: Long): SeekableByteChannel {
ensureOpen()
require(size >= 0)
synchronized(ioLock) {
val currentSize = getSize()
if (size >= currentSize) {
return this
}
try {
file.setLength(size)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
position = position.coerceAtMost(size)
}
return this
}
@Throws(IOException::class)
private fun getSize(): Long =
override fun onSize(): Long =
try {
file.getFileInformation(FileStandardInformation::class.java).endOfFile
} catch (e: SMBRuntimeException) {
@ -138,8 +87,7 @@ class FileByteChannel(
}
@Throws(IOException::class)
override fun force(metaData: Boolean) {
ensureOpen()
override fun onForce(metaData: Boolean) {
try {
file.flush()
} catch (e: SMBRuntimeException) {
@ -147,20 +95,11 @@ class FileByteChannel(
}
}
@Throws(ClosedChannelException::class)
private fun ensureOpen() {
synchronized(closeLock) {
if (!isOpen) {
throw ClosedChannelException()
}
}
}
private fun SMBRuntimeException.toIOException(): IOException =
when {
findCauseByClass<SMBApiException>()
.let { it != null && it.status == NtStatus.STATUS_FILE_CLOSED } -> {
synchronized(closeLock) { isOpen = false }
setClosed()
AsynchronousCloseException().apply { initCause(this@toIOException) }
}
findCauseByClass<InterruptedException>() != null -> {
@ -170,162 +109,37 @@ class FileByteChannel(
else -> IOException(this)
}
override fun isOpen(): Boolean = synchronized(closeLock) { isOpen }
@Throws(IOException::class)
override fun close() {
synchronized(closeLock) {
if (!isOpen) {
return
}
isOpen = false
readBuffer.closeSafe()
try {
file.close()
} catch (e: SMBRuntimeException) {
throw when {
e.findCauseByClass<InterruptedException>() != null ->
InterruptedIOException().apply { initCause(e) }
else -> IOException(e)
}
override fun onClose() {
try {
file.close()
} catch (e: SMBRuntimeException) {
throw when {
e.findCauseByClass<InterruptedException>() != null ->
InterruptedIOException().apply { initCause(e) }
else -> IOException(e)
}
}
}
private inner class ReadBuffer : Closeable {
private val bufferSize: Int
private val timeout: Long
private class ByteBufferChunkProvider(
private val buffer: ByteBuffer,
offset: Long
) : ByteChunkProvider() {
init {
val treeConnect = file.diskShare.treeConnect
val config = treeConnect.config
bufferSize = config.readBufferSize
.coerceAtMost(treeConnect.session.connection.negotiatedProtocol.maxReadSize)
timeout = config.readTimeout
this.offset = offset
}
private val buffer = ByteBuffer.allocate(bufferSize).apply { limit(0) }
private var bufferedPosition = 0L
override fun isAvailable(): Boolean = buffer.hasRemaining()
private var pendingFuture: Future<SMB2ReadResponse>? = null
private val pendingFutureLock = Any()
override fun bytesLeft(): Int = buffer.remaining()
@Throws(IOException::class)
fun read(destination: ByteBuffer): Int {
if (!buffer.hasRemaining()) {
readIntoBuffer()
if (!buffer.hasRemaining()) {
return -1
}
}
val length = destination.remaining().coerceAtMost(buffer.remaining())
val bufferLimit = buffer.limit()
buffer.limit(buffer.position() + length)
destination.put(buffer)
buffer.limit(bufferLimit)
override fun prepareWrite(maxBytesToPrepare: Int) {}
override fun getChunk(chunk: ByteArray): Int {
val length = chunk.size.coerceAtMost(buffer.remaining())
buffer.get(chunk, 0, length)
return length
}
@Throws(IOException::class)
private fun readIntoBuffer() {
val future = synchronized(pendingFutureLock) {
pendingFuture?.also { pendingFuture = null }
} ?: readIntoBufferAsync()
val response = try {
receive(future, timeout)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
when (response.header.statusCode) {
NtStatus.STATUS_END_OF_FILE.value -> {
buffer.limit(0)
return
}
NtStatus.STATUS_SUCCESS.value -> {}
else -> throw SMBApiException(response.header, "Read failed for $this")
.toIOException()
}
val data = response.data
if (data.isEmpty()) {
buffer.limit(0)
return
}
buffer.clear()
val length = data.size.coerceAtMost(buffer.remaining())
buffer.put(data, 0, length)
buffer.flip()
bufferedPosition += length
synchronized(pendingFutureLock) {
try {
pendingFuture = readIntoBufferAsync()
} catch (e: IOException) {
e.printStackTrace()
}
}
}
// @see com.hierynomus.smbj.share.Share.receive
@Throws(SMBRuntimeException::class)
private fun <T> receive(future: Future<T>, timeout: Long): T =
try {
Futures.get(future, timeout, TimeUnit.MILLISECONDS, TransportException.Wrapper)
} catch (e: TransportException) {
throw SMBRuntimeException(e)
}
@Throws(IOException::class)
private fun readIntoBufferAsync(): Future<SMB2ReadResponse> =
try {
FileAccessor.readAsync(file, bufferedPosition, bufferSize)
} catch (e: SMBRuntimeException) {
throw e.toIOException()
}
fun reposition(oldPosition: Long, newPosition: Long) {
if (newPosition == oldPosition) {
return
}
val newBufferPosition = buffer.position() + (newPosition - oldPosition)
if (newBufferPosition in 0..buffer.limit()) {
buffer.position(newBufferPosition.toInt())
} else {
synchronized(pendingFutureLock) {
// TransportException: Received response with unknown sequence number
//pendingFuture?.cancel(true)?.also { pendingFuture = null }
pendingFuture = null
}
buffer.limit(0)
bufferedPosition = newPosition
}
}
override fun close() {
synchronized(pendingFutureLock) {
// TransportException: Received response with unknown sequence number
//pendingFuture?.cancel(true)?.also { pendingFuture = null }
pendingFuture = null
}
}
}
}
private class ByteBufferChunkProvider(
private val buffer: ByteBuffer,
offset: Long
) : ByteChunkProvider() {
init {
this.offset = offset
}
override fun isAvailable(): Boolean = buffer.hasRemaining()
override fun bytesLeft(): Int = buffer.remaining()
override fun prepareWrite(maxBytesToPrepare: Int) {}
override fun getChunk(chunk: ByteArray): Int {
val length = chunk.size.coerceAtMost(buffer.remaining())
buffer.get(chunk, 0, length)
return length
}
}