OkHttp POST Rate-Limited Upload

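I upload some large files over the LAN with an OkHttp POST. Testing showed that these uploads often saturate the bandwidth and disturb interactive business traffic, so the upload needs to be rate-limited.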

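Rate limiting can be done on the server or on the client. Server-side limiting only delays reading the data, which makes the TCP buffers back up, so the bandwidth problem is not really solved. The client-side approach is to throttle writes to the socket. After some searching, it turns out that OkHttp does not provide a rate-limiting API.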

So I looked into OkHttp's interceptors (Interceptor).

I. Executing the network call: RealCall

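The execute function in RealCall (more precisely, in its inner class AsyncCall, which runs enqueued calls) calls getResponseWithInterceptorChain to obtain the network Response.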

RealCall.java (inner class AsyncCall)
  final class AsyncCall extends NamedRunnable {
    // ... (fields and constructor omitted)

    @Override protected void execute() {
      boolean signalledCallback = false;
      try {
        Response response = getResponseWithInterceptorChain();
        if (retryAndFollowUpInterceptor.isCanceled()) {
          signalledCallback = true;
          responseCallback.onFailure(RealCall.this, new IOException("Canceled"));
        } else {
          signalledCallback = true;
          responseCallback.onResponse(RealCall.this, response);
        }
      } catch (IOException e) {
        if (signalledCallback) {
          // Do not signal the callback twice!
          Platform.get().log(INFO, "Callback failure for " + toLoggableString(), e);
        } else {
          responseCallback.onFailure(RealCall.this, e);
        }
      } finally {
        client.dispatcher().finished(this);
      }
    }
  }

II. How OkHttp interceptors are implemented

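The code shows that the interceptors are added to a list, in this order:

  1. client.interceptors(): custom (application) interceptors
  2. retryAndFollowUpInterceptor
  3. BridgeInterceptor
  4. CacheInterceptor
  5. ConnectInterceptor
  6. client.networkInterceptors(): custom network interceptors (skipped for WebSocket calls)
  7. CallServerInterceptor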

The order matters, because the interceptors run in exactly this order.
The actual network I/O happens in CallServerInterceptor.

RealCall.java 
  Response getResponseWithInterceptorChain() throws IOException {
    // Build a full stack of interceptors.
    List<Interceptor> interceptors = new ArrayList<>();
    interceptors.addAll(client.interceptors());
    interceptors.add(retryAndFollowUpInterceptor);
    interceptors.add(new BridgeInterceptor(client.cookieJar()));
    interceptors.add(new CacheInterceptor(client.internalCache()));
    interceptors.add(new ConnectInterceptor(client));
    if (!forWebSocket) {
      interceptors.addAll(client.networkInterceptors());
    }
    interceptors.add(new CallServerInterceptor(forWebSocket));

    Interceptor.Chain chain = new RealInterceptorChain(
        interceptors, null, null, null, 0, originalRequest);
    return chain.proceed(originalRequest);
  }

III. Executing RealInterceptorChain

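RealInterceptorChain executes as a chain. Note that RealCall constructs the first RealInterceptorChain with index 0. The proceed function then creates a new RealInterceptorChain, next, whose index is incremented by one, and calls interceptor.intercept(next) on the interceptor at the current index. Note that it is the new chain, not a new interceptor, that is passed in as the argument.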
  @Override public Response proceed(Request request) throws IOException {
    return proceed(request, streamAllocation, httpCodec, connection);
  }

  public Response proceed(Request request, StreamAllocation streamAllocation, HttpCodec httpCodec,
      RealConnection connection) throws IOException {
    // ... (checks on the index and on the connection omitted)

    // Call the next interceptor in the chain.
    RealInterceptorChain next = new RealInterceptorChain(
        interceptors, streamAllocation, httpCodec, connection, index + 1, request);
    Interceptor interceptor = interceptors.get(index);
    Response response = interceptor.intercept(next);

    // ... (checks on the returned response omitted)

    return response;
  }

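Inside an interceptor you can, as needed:

  1. To process the Request: run the interceptor's own logic first, then call proceed on the next chain.
  2. To process the Response: call proceed on the next chain first, then run the interceptor's logic.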

This is a nice design (essentially the chain-of-responsibility pattern).
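To make the index + 1 mechanics and the before/after rule concrete, here is a minimal toy model of the chain. It is not OkHttp code: ToyChain, ToyInterceptor, TracingInterceptor and ToyServer are hypothetical names, and requests and responses are plain Strings so that the sketch runs with the JDK alone. Each chain object only knows its index; proceed builds the next chain and hands it to the interceptor at the current index, just as RealInterceptorChain.proceed does.

```java
import java.util.ArrayList;
import java.util.List;

// Toy model of OkHttp's interceptor chain (hypothetical classes, not OkHttp's own).
final class ChainDemo {
    public static void main(String[] args) {
        List<String> trace = new ArrayList<>();
        List<ToyInterceptor> interceptors = new ArrayList<>();
        interceptors.add(new TracingInterceptor("A", trace)); // e.g. an application interceptor
        interceptors.add(new TracingInterceptor("B", trace)); // e.g. a network interceptor
        interceptors.add(new ToyServer(trace));               // plays CallServerInterceptor

        // Like RealCall: the first chain starts at index 0.
        String response = new ToyChain(interceptors, 0, "req").proceed("req");
        System.out.println(response); // response to req+A+B
        System.out.println(trace);    // [A request, B request, server, B response, A response]
    }
}

interface ToyInterceptor {
    String intercept(ToyChain chain);
}

final class ToyChain {
    private final List<ToyInterceptor> interceptors;
    private final int index;
    private final String request;

    ToyChain(List<ToyInterceptor> interceptors, int index, String request) {
        this.interceptors = interceptors;
        this.index = index;
        this.request = request;
    }

    String request() {
        return request;
    }

    // Mirrors RealInterceptorChain.proceed: build the next chain with index + 1,
    // then hand that chain to the interceptor at the current index.
    String proceed(String request) {
        if (index >= interceptors.size()) {
            throw new IllegalStateException("no interceptor left at index " + index);
        }
        ToyChain next = new ToyChain(interceptors, index + 1, request);
        return interceptors.get(index).intercept(next);
    }
}

final class TracingInterceptor implements ToyInterceptor {
    private final String name;
    private final List<String> trace;

    TracingInterceptor(String name, List<String> trace) {
        this.name = name;
        this.trace = trace;
    }

    @Override
    public String intercept(ToyChain chain) {
        trace.add(name + " request");                                  // request logic: before proceed
        String response = chain.proceed(chain.request() + "+" + name); // hand on to the next interceptor
        trace.add(name + " response");                                 // response logic: after proceed
        return response;
    }
}

// Terminal interceptor: like CallServerInterceptor, it never calls proceed.
final class ToyServer implements ToyInterceptor {
    private final List<String> trace;

    ToyServer(List<String> trace) {
        this.trace = trace;
    }

    @Override
    public String intercept(ToyChain chain) {
        trace.add("server");
        return "response to " + chain.request();
    }
}
```

The trace shows the nesting: request-side logic runs in list order on the way in, response-side logic runs in reverse order on the way out, which is why the position of an interceptor in the list matters.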

IV. CallServerInterceptor

CallServerInterceptor is responsible for the network reads and writes, so if rate limiting is to be implemented, this is the most likely place.

    request.body().writeTo(bufferedRequestBody);

The request's body writes the data to the network. Where does this body come from? When uploading data with a POST, you build a MultipartBody to wrap the uploaded file.

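  public CallServerInterceptor(boolean forWebSocket) {
    this.forWebSocket = forWebSocket;
  }

  @Override public Response intercept(Chain chain) throws IOException {
    // ... (obtaining httpCodec, streamAllocation, connection and request from the chain,
    //      recording sentRequestMillis and writing the request headers omitted)

    Response.Builder responseBuilder = null;
    if (HttpMethod.permitsRequestBody(request.method()) && request.body() != null) {
      // ... ("Expect: 100-continue" handling, which may set responseBuilder, omitted)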
      if (responseBuilder == null) {
        // Write the request body if the "Expect: 100-continue" expectation was met.
        Sink requestBodyOut = httpCodec.createRequestBody(request, request.body().contentLength());
        BufferedSink bufferedRequestBody = Okio.buffer(requestBodyOut);
        request.body().writeTo(bufferedRequestBody);
        bufferedRequestBody.close();
      } else if (!connection.isMultiplexed()) {
        // If the "Expect: 100-continue" expectation wasn't met, prevent the HTTP/1 connection from
        // being reused. Otherwise we're still obligated to transmit the request body to leave the
        // connection in a consistent state.
        streamAllocation.noNewStreams();
      }
    }

    httpCodec.finishRequest();

    if (responseBuilder == null) {
      responseBuilder = httpCodec.readResponseHeaders(false);
    }

    Response response = responseBuilder
        .request(request)
        .handshake(streamAllocation.connection().handshake())
        .sentRequestAtMillis(sentRequestMillis)
        .receivedResponseAtMillis(System.currentTimeMillis())
        .build();

    int code = response.code();
    if (forWebSocket && code == 101) {
      // Connection is upgrading, but we need to ensure interceptors see a non-null response body.
      response = response.newBuilder()
          .body(Util.EMPTY_RESPONSE)
          .build();
    } else {
      response = response.newBuilder()
          .body(httpCodec.openResponseBody(response))
          .build();
    }

    // ... (further checks on the response omitted)
    return response;
  }

V. RequestBody

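As the code shows, RequestBody is an abstract class. MultipartBody.create (which is really the static RequestBody.create(MediaType, File) inherited by MultipartBody) directly news up an anonymous subclass whose writeTo copies the whole file to the sink with sink.writeAll, as fast as the sink accepts it. So the idea is to override this RequestBody's writeTo function and control how fast it writes to the socket.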

    public MultipartBody.Part getMultipartBodyPart() {
        RequestBody requestFile = MultipartBody.create(MediaType.parse("multipart/form-data"), new File(mFileEncrypt));
        MultipartBody.Part fileBody = MultipartBody.Part.createFormData(FILE_ENCRYPT, mFileEncrypt, requestFile);

        return fileBody;
    }

RequestBody.java
  public static RequestBody create(final @Nullable MediaType contentType, final File file) {
    if (file == null) throw new NullPointerException("content == null");

    return new RequestBody() {
      @Override public @Nullable MediaType contentType() {
        return contentType;
      }

      @Override public long contentLength() {
        return file.length();
      }

      @Override public void writeTo(BufferedSink sink) throws IOException {
        Source source = null;
        try {
          source = Okio.source(file);
          sink.writeAll(source);
        } finally {
          Util.closeQuietly(source);
        }
      }
    };
  }
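The stock writeTo above hands the whole file to sink.writeAll in one call. Rate limiting means replacing that call with a loop that writes one chunk, compares the bytes sent so far with the time they should take at the limit, and sleeps whenever it is ahead of schedule. Below is a minimal sketch of that loop, assuming plain java.io streams instead of Okio's Source/BufferedSink so that it runs with the JDK alone; ThrottledCopy, copy and maxBitsPerSecond are hypothetical names, not OkHttp or Okio API.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

final class ThrottledCopyDemo {
    public static void main(String[] args) throws IOException, InterruptedException {
        byte[] data = new byte[64 * 1024];
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        long t0 = System.nanoTime();
        // 64 KiB at 1,000,000 bit/s should take about 65536 * 8 / 1e6 s, i.e. roughly 524 ms.
        long copied = ThrottledCopy.copy(new ByteArrayInputStream(data), out, 1_000_000L);
        long ms = (System.nanoTime() - t0) / 1_000_000;
        System.out.println("copied " + copied + " bytes in " + ms + " ms");
    }
}

// Hypothetical helper: paces a stream copy to at most maxBitsPerSecond on average.
final class ThrottledCopy {
    private ThrottledCopy() {}

    static long copy(InputStream in, OutputStream out, long maxBitsPerSecond)
            throws IOException, InterruptedException {
        if (maxBitsPerSecond <= 0) {
            throw new IllegalArgumentException("maxBitsPerSecond must be > 0");
        }
        byte[] buffer = new byte[8192]; // 8 KiB per step, like the 8192-byte reads in Okio's copy loop
        long total = 0;
        long start = System.nanoTime();
        int n;
        while ((n = in.read(buffer)) != -1) {
            out.write(buffer, 0, n);
            total += n;
            // How long `total` bytes should take at the limit vs. how long they actually took.
            long expectedMs = total * 8 * 1000 / maxBitsPerSecond;
            long elapsedMs = (System.nanoTime() - start) / 1_000_000;
            if (expectedMs > elapsedMs) {
                // Ahead of schedule: sleep only until back on schedule, so no backlog builds up.
                Thread.sleep(expectedMs - elapsedMs);
            }
        }
        out.flush();
        return total;
    }
}
```

Inside a RequestBody.writeTo the same loop could be fed a FileInputStream and sink.outputStream() (the OutputStream view that Okio's BufferedSink offers). Because the sleep runs on the thread executing the call, the limit is per request body: two uploads running in parallel would each get the full budget.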

VI. RateLimitingRequestBody

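The modified code is below. For the Okio operations, some code (the writeAll copy loop) was copied out and adapted. Also, because of a build problem, Okio.source is called via reflection.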

public MultipartBody.Part getMultipartBodyPart() {
    // UPLOAD_RATE: upload limit in bit/s
    RequestBody requestFile = RateLimitingRequestBody.createRequestBody(MediaType.parse("multipart/form-data"), new File(mFileEncrypt), UPLOAD_RATE);
    MultipartBody.Part fileBody = MultipartBody.Part.createFormData(FILE_ENCRYPT, mFileEncrypt, requestFile);

    return fileBody;
}
    
import android.support.annotation.Nullable;   // or androidx.annotation.Nullable, depending on the project

import java.io.File;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import okhttp3.MediaType;
import okhttp3.RequestBody;
import okhttp3.internal.Util;
import okio.BufferedSink;
import okio.Source;

public class RateLimitingRequestBody extends RequestBody {

    private final MediaType mContentType;
    private final File mFile;
    private final int mMaxRate;    // bit/s (the arithmetic below converts with 1000 ms/s)

    private RateLimitingRequestBody(@Nullable final MediaType contentType, final File file, int rate) {
        mContentType = contentType;
        mFile = file;
        mMaxRate = rate;
    }

    @Override
    public MediaType contentType() {
        return mContentType;
    }

    // Report the real length, as the stock RequestBody.create(MediaType, File) does;
    // the default of -1 would make the multipart body fall back to chunked encoding.
    @Override
    public long contentLength() {
        return mFile.length();
    }

    @Override
    public void writeTo(BufferedSink sink) throws IOException {
        Source source = null;
        try {
            /*
             * Reflection instead of Okio.source(mFile) because of a build error at platform 23:
             * java.nio.** can't be found.
             */
            // source = Okio.source(mFile);
            Class<?> clazz = Class.forName("okio.Okio");
            Method method = clazz.getMethod("source", File.class);
            source = (Source) method.invoke(null, mFile);
            writeAll(sink, source);
        } catch (InterruptedException e) {
            // Keep the interrupt status and fail the request instead of sending a truncated body.
            Thread.currentThread().interrupt();
            throw new InterruptedIOException("upload interrupted");
        } catch (ClassNotFoundException | NoSuchMethodException
                | IllegalAccessException | InvocationTargetException e) {
            NLog.exception("writeTo", e);
            throw new IOException("cannot open source for " + mFile, e);
        } finally {
            Util.closeQuietly(source);
        }
    }

    public long writeAll(BufferedSink sink, Source source) throws IOException, InterruptedException {
        if (source == null) {
            throw new IllegalArgumentException("source == null");
        }
        long totalBytesRead = 0L;
        long readCount;
        long start = System.currentTimeMillis();
        while ((readCount = source.read(sink.buffer(), 8192L)) != -1L) {
            totalBytesRead += readCount;
            sink.emitCompleteSegments();

            long elapsed = System.currentTimeMillis() - start;         // ms actually spent
            long expected = totalBytesRead * 8 * 1000 / mMaxRate;      // ms these bytes should take
            if (expected > elapsed) {
                // Sleep only until back on schedule; oversleeping would make the loop
                // send the resulting backlog at full link speed afterwards.
                long sleep = expected - elapsed;
                long rate = totalBytesRead * 8 * 1000 / Math.max(1, elapsed);
                NLog.v("writeAll", "totalBytesRead:" + totalBytesRead + "B " + " Rate:" + rate + "bit/s");
                NLog.d("writeAll", "sleep:" + sleep);
                Thread.sleep(sleep);
            }
        }

        long end = System.currentTimeMillis();
        long rate = (totalBytesRead * 8 * 1000) / Math.max(1, end - start);
        NLog.e("writeAll", "totalBytesRead:" + totalBytesRead + "B " + " Rate:" + rate + "bit/s" + " total time:" + (end - start));
        return totalBytesRead;
    }

    public static RequestBody createRequestBody(@Nullable final MediaType contentType, final File file, int rate) {
        if (file == null) {
            throw new NullPointerException("file == null");
        }
        if (rate <= 0) {
            throw new IllegalArgumentException("rate must be > 0 bit/s");
        }
        return new RateLimitingRequestBody(contentType, file, rate);
    }
}
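The sleep in writeAll is the gap between how long the bytes written so far should take at the limit and how long they actually took. Assuming mMaxRate (written $R$ below) is in bit/s, which is what the factor of 1000 ms/s in the code implies, after $B$ bytes have been written in $t$ ms:

```latex
T_{\text{expected}} = \frac{8 \cdot 1000 \cdot B}{R}\ \text{ms},
\qquad
s = \max\left(0,\; T_{\text{expected}} - t\right)
```

For example, with an illustrative limit of $R = 8\,000\,000$ bit/s (8 Mbit/s) and $B = 1\,048\,576$ bytes (1 MiB) written after $t = 300$ ms:

```latex
T_{\text{expected}} = \frac{8 \cdot 1000 \cdot 1\,048\,576}{8\,000\,000} = 1048.576\ \text{ms}
\;\to\; 1048\ \text{ms (integer division)},
\qquad
s = 1048 - 300 = 748\ \text{ms}
```

Since the check runs after every 8 KiB read, the average rate stays close to $R$ while the individual bursts between sleeps remain small.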
