diff --git a/app/src/app/java/io/legado/app/lib/cronet/AbsCallBack.kt b/app/src/app/java/io/legado/app/lib/cronet/AbsCallBack.kt index 7627949b8..e47cb3b12 100644 --- a/app/src/app/java/io/legado/app/lib/cronet/AbsCallBack.kt +++ b/app/src/app/java/io/legado/app/lib/cronet/AbsCallBack.kt @@ -9,12 +9,18 @@ import okhttp3.EventListener import okhttp3.MediaType.Companion.toMediaTypeOrNull import okhttp3.ResponseBody.Companion.asResponseBody import okio.Buffer +import okio.Source +import okio.Timeout +import okio.buffer import org.chromium.net.CronetException import org.chromium.net.UrlRequest import org.chromium.net.UrlResponseInfo import java.io.IOException import java.nio.ByteBuffer import java.util.* +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean @Keep abstract class AbsCallBack( @@ -22,13 +28,14 @@ abstract class AbsCallBack( val mCall: Call, private val eventListener: EventListener? = null, private val responseCallback: Callback? = null -) : UrlRequest.Callback(), AutoCloseable { - - val buffer = Buffer() +) : UrlRequest.Callback() { var mResponse: Response - private var followCount = 0 + private var request: UrlRequest? = null + private var finished = AtomicBoolean(false) + private val callbackResults = ArrayBlockingQueue(2) + private val urlResponseInfoChain = arrayListOf() @Throws(IOException::class) @@ -63,6 +70,7 @@ abstract class AbsCallBack( return } followCount += 1 + urlResponseInfoChain.add(info) val client = okHttpClient if (originalRequest.url.isHttps && newLocationUrl.startsWith("http://") && client.followSslRedirects) { request.followRedirect() @@ -78,14 +86,31 @@ abstract class AbsCallBack( override fun onResponseStarted(request: UrlRequest, info: UrlResponseInfo) { - this.mResponse = responseFromResponse(this.mResponse, info) + this.request = request + val contentLength = info.allHeaders["Content-Length"]?.lastOrNull()?.toLongOrNull() ?: -1 + val contentType = (info.allHeaders["content-type"]?.lastOrNull() + ?: "text/plain; charset=\"utf-8\"").toMediaTypeOrNull() + val responseBody = CronetBodySource().buffer().asResponseBody(contentType, contentLength) + val newRequest = originalRequest.newBuilder().url(info.url).build() + val response = createResponse(originalRequest, info) + .request(newRequest) + .body(responseBody) + .priorResponse(buildPriorResponse(originalRequest, urlResponseInfoChain, info.urlChain)) + .build() + mResponse = response + onSuccess(response) + //打印协议,用于调试 DebugLog.i(javaClass.simpleName, "start[${info.negotiatedProtocol}]${info.url}") if (eventListener != null) { - eventListener.responseHeadersEnd(mCall, this.mResponse) + eventListener.responseHeadersEnd(mCall, response) eventListener.responseBodyStart(mCall) } - request.read(ByteBuffer.allocateDirect(32 * 1024)) + try { + responseCallback?.onResponse(mCall, response) + } catch (e: IOException) { + // Pass? + } } @@ -95,60 +120,31 @@ abstract class AbsCallBack( info: UrlResponseInfo, byteBuffer: ByteBuffer ) { - - - if (mCall.isCanceled()) { - request.cancel() - onError(IOException("Request Canceled")) - } - - byteBuffer.flip() - - try { - buffer.write(byteBuffer) - } catch (e: IOException) { - DebugLog.e(javaClass.name, "IOException during ByteBuffer read. Details: ", e) - onError(IOException("IOException during ByteBuffer read. Details:", e)) - return - } - byteBuffer.clear() - request.read(byteBuffer) + callbackResults.add(CallbackResult(CallbackStep.ON_READ_COMPLETED, byteBuffer)) } override fun onSucceeded(request: UrlRequest, info: UrlResponseInfo) { + callbackResults.add(CallbackResult(CallbackStep.ON_SUCCESS)) eventListener?.responseBodyEnd(mCall, info.receivedByteCount) - val contentType: MediaType? = (this.mResponse.header("content-type") - ?: "text/plain; charset=\"utf-8\"").toMediaTypeOrNull() - val responseBody: ResponseBody = - buffer.asResponseBody(contentType) - val newRequest = originalRequest.newBuilder().url(info.url).build() - this.mResponse = this.mResponse.newBuilder().body(responseBody).request(newRequest).build() - onSuccess(this.mResponse) //DebugLog.i(javaClass.simpleName, "end[${info.negotiatedProtocol}]${info.url}") eventListener?.callEnd(mCall) - if (responseCallback != null) { - try { - responseCallback.onResponse(mCall, this.mResponse) - } catch (e: IOException) { - // Pass? - } - } } //UrlResponseInfo可能为null override fun onFailed(request: UrlRequest, info: UrlResponseInfo?, error: CronetException) { + callbackResults.add(CallbackResult(CallbackStep.ON_FAILED, null, error)) DebugLog.e(javaClass.name, error.message.toString()) onError(error.asIOException()) - this.eventListener?.callFailed(mCall, error) + eventListener?.callFailed(mCall, error) responseCallback?.onFailure(mCall, error) } override fun onCanceled(request: UrlRequest?, info: UrlResponseInfo?) { - super.onCanceled(request, info) - this.eventListener?.callEnd(mCall) + callbackResults.add(CallbackResult(CallbackStep.ON_CANCELED)) + eventListener?.callEnd(mCall) //onError(IOException("Cronet Request Canceled")) } @@ -169,21 +165,26 @@ abstract class AbsCallBack( val negotiatedProtocol = responseInfo.negotiatedProtocol.lowercase(Locale.getDefault()) return when { negotiatedProtocol.contains("h3") -> { - return Protocol.QUIC + Protocol.QUIC } + negotiatedProtocol.contains("quic") -> { Protocol.QUIC } + negotiatedProtocol.contains("spdy") -> { @Suppress("DEPRECATION") Protocol.SPDY_3 } + negotiatedProtocol.contains("h2") -> { Protocol.HTTP_2 } + negotiatedProtocol.contains("1.1") -> { Protocol.HTTP_1_1 } + else -> { Protocol.HTTP_1_0 } @@ -211,23 +212,114 @@ abstract class AbsCallBack( } - private fun responseFromResponse( - response: Response, + private fun createResponse( + request: Request, responseInfo: UrlResponseInfo - ): Response { + ): Response.Builder { val protocol = protocolFromNegotiatedProtocol(responseInfo) val headers = headersFromResponse(responseInfo) - return response.newBuilder() + return Response.Builder() + .request(request) .receivedResponseAtMillis(System.currentTimeMillis()) .protocol(protocol) .code(responseInfo.httpStatusCode) .message(responseInfo.httpStatusText) .headers(headers) - .build() + } + + private fun buildPriorResponse( + request: Request, + redirectResponseInfos: List, + urlChain: List + ): Response? { + var priorResponse: Response? = null + if (redirectResponseInfos.isNotEmpty()) { + check(urlChain.size == redirectResponseInfos.size + 1) { + "The number of redirects should be consistent across URLs and headers!" + } + for (i in redirectResponseInfos.indices) { + val redirectedRequest = request.newBuilder().url(urlChain[i]).build() + priorResponse = createResponse(redirectedRequest, redirectResponseInfos[i]) + .priorResponse(priorResponse) + .build() + } + + } + return priorResponse } } - override fun close() { - buffer.clear() + inner class CronetBodySource : Source { + + private var buffer = ByteBuffer.allocateDirect(32 * 1024) + private var closed = false + private val timeout = mCall.timeout().timeoutNanos() + override fun close() { + if (closed) { + return + } + closed = true + if (!finished.get()) { + request?.cancel() + } + } + + @Suppress("NULLABILITY_MISMATCH_BASED_ON_JAVA_ANNOTATIONS") + override fun read(sink: Buffer, byteCount: Long): Long { + if (mCall.isCanceled()) { + throw IOException("Request Canceled") + } + + if (closed) { + throw IOException("Source Closed") + } + + if (finished.get()) { + return -1 + } + + if (byteCount < buffer.limit()) { + buffer.limit(byteCount.toInt()) + } + + request?.read(buffer) + + val result = callbackResults.poll(timeout, TimeUnit.NANOSECONDS) + if (result == null) { + request?.cancel() + throw IOException("Request Timeout") + } + + return when (result.callbackStep) { + CallbackStep.ON_FAILED -> { + finished.set(true) + buffer = null + throw IOException(result.exception) + } + + CallbackStep.ON_SUCCESS -> { + finished.set(true) + buffer = null + -1 + } + + CallbackStep.ON_CANCELED -> { + buffer = null + throw IOException("Request Canceled") + } + + CallbackStep.ON_READ_COMPLETED -> { + result.buffer!!.flip() + val bytesWritten = sink.write(result.buffer) + result.buffer.clear() + bytesWritten.toLong() + } + } + } + + override fun timeout(): Timeout { + return mCall.timeout() + } + } -} \ No newline at end of file +} diff --git a/app/src/app/java/io/legado/app/lib/cronet/CallbackResult.kt b/app/src/app/java/io/legado/app/lib/cronet/CallbackResult.kt new file mode 100644 index 000000000..e2b31125c --- /dev/null +++ b/app/src/app/java/io/legado/app/lib/cronet/CallbackResult.kt @@ -0,0 +1,12 @@ +package io.legado.app.lib.cronet + +import org.chromium.net.CronetException + +import java.nio.ByteBuffer + + +data class CallbackResult( + val callbackStep: CallbackStep, + val buffer: ByteBuffer? = null, + val exception: CronetException? = null +) diff --git a/app/src/app/java/io/legado/app/lib/cronet/CallbackStep.kt b/app/src/app/java/io/legado/app/lib/cronet/CallbackStep.kt new file mode 100644 index 000000000..968b0555f --- /dev/null +++ b/app/src/app/java/io/legado/app/lib/cronet/CallbackStep.kt @@ -0,0 +1,8 @@ +package io.legado.app.lib.cronet + +enum class CallbackStep { + ON_READ_COMPLETED, + ON_SUCCESS, + ON_FAILED, + ON_CANCELED +} diff --git a/app/src/main/java/io/legado/app/web/HttpServer.kt b/app/src/main/java/io/legado/app/web/HttpServer.kt index 260ed1ed1..0093e73a6 100644 --- a/app/src/main/java/io/legado/app/web/HttpServer.kt +++ b/app/src/main/java/io/legado/app/web/HttpServer.kt @@ -11,6 +11,8 @@ import io.legado.app.help.coroutine.Coroutine import io.legado.app.service.WebService import io.legado.app.utils.* import io.legado.app.web.utils.AssetsWeb +import okio.Pipe +import okio.buffer import java.io.* class HttpServer(port: Int) : NanoHTTPD(port) { @@ -101,11 +103,10 @@ class HttpServer(port: Int) : NanoHTTPD(port) { ) } else { val data = returnData.data - if (data is List<*> && data.size > 1000) { - val pis = PipedInputStream(1024 * 1024) + if (data is List<*> && data.size > 3000) { + val pipe = Pipe(16 * 1024) Coroutine.async { - @Suppress("BlockingMethodInNonBlockingContext") - PipedOutputStream(pis).use { out -> + pipe.sink.buffer().outputStream().use { out -> BufferedWriter(OutputStreamWriter(out, "UTF-8")).use { GSON.toJson(returnData, it) } @@ -114,7 +115,7 @@ class HttpServer(port: Int) : NanoHTTPD(port) { newChunkedResponse( Response.Status.OK, "application/json", - pis + pipe.source.buffer().inputStream() ) } else { newFixedLengthResponse(GSON.toJson(returnData))