Java Multithreaded Programming: File Download

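Preface:

This article is based on Chapter 4 of 《Java多线程编程实战指南》. The source code is copied from the book's author, annotated with my own understanding, and run end to end as a working example. Chapter 4 is where the hands-on part of the book begins; the first three chapters mostly introduce multithreading, the thread life cycle, goals and challenges, and synchronization mechanisms.

I will summarize the first three chapters in a later post. If you find this article helpful, please consider following.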

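Concurrency through data partitioning:

Downloading a large file is a case where splitting the data and fetching the parts concurrently pays off. For example, if a single connection achieves 100 Mbps, downloading a 600 MB file takes 600/(100/8) = 48 seconds. With 3 threads, in the ideal case where every connection sustains that rate, it takes about 16 seconds.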
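The arithmetic above can be checked with a few lines of Java. This is only a sketch under an idealized assumption: throughput scales linearly with the number of threads, which real networks and servers do not guarantee. The class and method names are made up for illustration.

```java
class TransferTime {
    // Seconds needed to move sizeMB megabytes at the given throughput in Mbps.
    static double seconds(double sizeMB, double throughputMbps) {
        double megabytesPerSecond = throughputMbps / 8.0; // 8 bits per byte
        return sizeMB / megabytesPerSecond;
    }

    public static void main(String[] args) {
        double single = seconds(600, 100);
        // Assumption: three connections each sustain the same rate.
        double ideal = single / 3;
        System.out.printf("1 thread: %.0f s, 3 threads (ideal): %.0f s%n", single, ideal);
    }
}
```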

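The following figure, taken from the book, illustrates concurrency based on data partitioning. The idea is to split the original data into equal parts and let each worker thread handle one part.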

[Figure 1: concurrency based on data partitioning]

Basic code walkthrough:

    public void download(int taskCount, long reportInterval)
            throws Exception {

        long chunkSizePerThread = fileSize / taskCount;
        // First byte of the segment to download
        long lowerBound = 0;
        // Last byte of the segment to download
        long upperBound = 0;

        DownloadTask dt;
        for (int i = taskCount - 1; i >= 0; i--) {
            lowerBound = i * chunkSizePerThread;
            if (i == taskCount - 1) {
                upperBound = fileSize;
            } else {
                upperBound = lowerBound + chunkSizePerThread - 1;
            }

            // Create the download task
            dt = new DownloadTask(lowerBound, upperBound, requestURL, storage,
                    taskCanceled);
            dispatchWork(dt, i);
        }
        // Report download progress periodically
        reportProgress(reportInterval);
        // Release the resources held by the program
        doCleanup();

    }
    protected void dispatchWork(final DownloadTask dt, int workerIndex) {
        // Create the download thread
        Thread workerThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    dt.run();
                } catch (Exception e) {
                    e.printStackTrace();
                    // Cancel the download of the whole file
                    cancelDownload();
                }
            }
        });
        workerThread.setName("downloader-" + workerIndex);
        workerThread.start();
    }
    @Override
    public void run() {
        if (cancelFlag.get()) {
            return;
        }
        ReadableByteChannel channel = null;
        try {
            channel = Channels.newChannel(issueRequest(requestURL, lowerBound,
                    upperBound));
            ByteBuffer buf = ByteBuffer.allocate(1024);
            while (!cancelFlag.get() && channel.read(buf) > 0) {
                // Write the data read from the network into the buffer
                xbuf.write(buf);
                buf.clear();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            Tools.silentClose(channel, xbuf);
        }
    }
    public void write(ByteBuffer buf) throws IOException {
        int length = buf.position();
        final int capacity = byteBuf.capacity();
        // The buffer is full, or its remaining capacity cannot hold the new data
        if (offset + length > capacity || length == capacity) {
            // Write the buffered data to the file
            flush();
        }
        byteBuf.position(offset);
        buf.flip();
        byteBuf.put(buf);
        offset += length;
    }

    public void flush() throws IOException {
        int length;
        byteBuf.flip();
        length = storage.store(globalOffset, byteBuf);
        byteBuf.clear();
        globalOffset += length;
        offset = 0;
    }

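1. The download method uses the given taskCount (one task corresponds to one worker thread) to compute how many bytes each thread should download. lowerBound is the first byte of a thread's segment and upperBound is its last byte. For example, with 2 threads and 10 bytes in total, the first thread gets lowerBound=0, upperBound=4 and the second gets lowerBound=5, upperBound=10. Note that the valid byte offsets are 0 to 9: the code sets the last upperBound to fileSize rather than fileSize - 1. This still works because an HTTP server treats a range end at or beyond the end of the file as "up to the last byte".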

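This relies on the HTTP Range header to download the data segment by segment. The DownloadTask class sets the request header to Range: bytes=<lowerBound>-<upperBound>, where both ends are inclusive, and expects a 206 (Partial Content) response.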
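The partitioning in step 1 and the resulting Range header values can be sketched as follows. ChunkRanges, split, and rangeHeader are hypothetical names used only for this illustration. The loop mirrors the one in download, including the detail that the last upperBound is set to fileSize.

```java
class ChunkRanges {
    // Returns one {lowerBound, upperBound} pair per task, mirroring download().
    static long[][] split(long fileSize, int taskCount) {
        long chunkSizePerThread = fileSize / taskCount;
        long[][] ranges = new long[taskCount][2];
        for (int i = 0; i < taskCount; i++) {
            long lowerBound = i * chunkSizePerThread;
            // The last task also takes the remainder of the integer division.
            long upperBound = (i == taskCount - 1)
                    ? fileSize
                    : lowerBound + chunkSizePerThread - 1;
            ranges[i][0] = lowerBound;
            ranges[i][1] = upperBound;
        }
        return ranges;
    }

    // Value of the HTTP Range request header; both ends are inclusive.
    static String rangeHeader(long lowerBound, long upperBound) {
        return "bytes=" + lowerBound + "-" + upperBound;
    }

    public static void main(String[] args) {
        long[][] ranges = split(10, 2);
        for (int i = 0; i < ranges.length; i++) {
            System.out.println("downloader-" + i + " Range: "
                    + rangeHeader(ranges[i][0], ranges[i][1]));
        }
    }
}
```

With 10 bytes and 2 tasks this yields bytes=0-4 and bytes=5-10, matching the example in step 1. When fileSize is not divisible by taskCount, the last task picks up the leftover bytes.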

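2. download then creates a DownloadTask for each segment, and dispatchWork is where the worker thread is actually created (new Thread). The thread calls the run method of DownloadTask. If the task throws an exception, the whole download is canceled.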
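The dispatch-and-cancel pattern of step 2 can be tried in isolation. The sketch below is not the book's code: DispatchSketch and its members are hypothetical, the tasks fail on purpose, and a counter stands in for doCleanup. It uses lambdas, so it assumes Java 8 or later. It shows that compareAndSet(false, true) lets cleanup run once even when several workers fail.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

class DispatchSketch {
    static final AtomicBoolean canceled = new AtomicBoolean(false);
    static final AtomicInteger cleanups = new AtomicInteger(0);

    // Only the first caller flips the flag, so cleanup runs once.
    static void cancel() {
        if (canceled.compareAndSet(false, true)) {
            cleanups.incrementAndGet(); // stands in for doCleanup()
        }
    }

    // Runs the task in its own named thread; a failure cancels everything.
    static Thread dispatch(Runnable task, int workerIndex) {
        Thread worker = new Thread(() -> {
            try {
                task.run();
            } catch (RuntimeException e) {
                cancel();
            }
        }, "downloader-" + workerIndex);
        worker.start();
        return worker;
    }

    // Starts failing workers and returns how many times cleanup ran.
    static int runFailingWorkers(int workers) {
        canceled.set(false);
        cleanups.set(0);
        Thread[] threads = new Thread[workers];
        for (int i = 0; i < workers; i++) {
            threads[i] = dispatch(() -> {
                throw new IllegalStateException("simulated failure");
            }, i);
        }
        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return cleanups.get();
    }

    public static void main(String[] args) {
        System.out.println("cleanup runs: " + runFailingWorkers(3));
    }
}
```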

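3. The run method mainly uses Java NIO. I will cover NIO in a separate article, so here is only a brief introduction:

The run method keeps reading data from the socket into a ByteBuffer with a capacity of 1024. xbuf.write then copies that data into xbuf's own ByteBuffer, which has a capacity of 1024*1024. It first sets the position of xbuf's ByteBuffer to where the new data should start, then flips the socket ByteBuffer (flip prepares a buffer for being read out: it sets limit to the current position and resets position to 0), and finally puts the socket ByteBuffer into xbuf's ByteBuffer. When xbuf's ByteBuffer cannot hold the new data, its content is flushed to the file first.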
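The copy from the small buffer into the large one can be observed without any network. In the sketch below, FlipSketch and append are hypothetical names and the buffer sizes are shrunk to 8 and 32 bytes for readability. The sequence of calls follows DownloadBuffer.write, except that the flush branch is left out.

```java
import java.nio.ByteBuffer;

class FlipSketch {
    // Copies what was written into small to big at offset; returns the new offset.
    static int append(ByteBuffer small, ByteBuffer big, int offset) {
        int length = small.position(); // number of bytes written so far
        big.position(offset);          // where the new data starts in big
        small.flip();                  // limit = position, position = 0
        big.put(small);                // transfers the bytes between position and limit
        small.clear();                 // position = 0, limit = capacity, ready for reuse
        return offset + length;
    }

    public static void main(String[] args) {
        ByteBuffer small = ByteBuffer.allocate(8);
        ByteBuffer big = ByteBuffer.allocate(32);

        small.put(new byte[] {1, 2, 3});
        System.out.println("before flip: position=" + small.position()
                + " limit=" + small.limit());
        int offset = append(small, big, 0);

        small.put(new byte[] {4, 5});
        offset = append(small, big, offset);
        System.out.println("offset=" + offset + " big[3]=" + big.get(3));
    }
}
```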

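4. Finally, the Storage class sets the location of the output file and, based on the offset, writes each worker thread's data into the file piece by piece. For example, with 2048 bytes in total and 2 threads, thread 1 fills positions 0 to 1023 and thread 2 fills positions 1024 to 2047 (the code sets its upperBound to 2048, one past the last byte).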
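Writing at explicit offsets is what lets several threads share one output file. The following sketch is a simplified stand-in for Storage, not the original class: PositionalWriteSketch, writeInTwoHalves, and writeAt are hypothetical names, the data comes from a byte array instead of the network, and the file is a temporary one. Unlike Storage.store, it loops until the buffer is drained, because FileChannel.write may write fewer bytes than requested.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

class PositionalWriteSketch {
    // Two threads each write one half of data at its own offset.
    // Returns the file content read back afterwards.
    static byte[] writeInTwoHalves(byte[] data) {
        try {
            Path file = Files.createTempFile("chunks", ".bin");
            try (FileChannel ch = FileChannel.open(file, StandardOpenOption.WRITE)) {
                int mid = data.length / 2;
                Thread t1 = new Thread(() -> writeAt(ch, data, 0, mid));
                Thread t2 = new Thread(() -> writeAt(ch, data, mid, data.length - mid));
                t2.start(); // the start order does not matter
                t1.start();
                t1.join();
                t2.join();
            }
            byte[] result = Files.readAllBytes(file);
            Files.delete(file);
            return result;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    // Writes data[offset .. offset+length) at the same offset in the file.
    static void writeAt(FileChannel ch, byte[] data, int offset, int length) {
        ByteBuffer buf = ByteBuffer.wrap(data, offset, length);
        long position = offset;
        try {
            while (buf.hasRemaining()) {
                position += ch.write(buf, position); // does not move the channel position
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        byte[] data = new byte[2048];
        for (int i = 0; i < data.length; i++) {
            data[i] = (byte) i;
        }
        byte[] written = writeInTwoHalves(data);
        System.out.println("bytes written: " + written.length);
    }
}
```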

The meaning of position, limit, capacity, and mark is shown in the following figure, quoted from 《Java NIO》.

[Figure 2: position, limit, capacity, and mark of a buffer]
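Since the figure is not reproduced here, a short trace may help, especially for mark, which the download code does not use. BufferStateSketch and state are hypothetical names. The comments show the state after each call.

```java
import java.nio.ByteBuffer;

class BufferStateSketch {
    static String state(ByteBuffer b) {
        return "position=" + b.position() + " limit=" + b.limit()
                + " capacity=" + b.capacity();
    }

    public static void main(String[] args) {
        ByteBuffer b = ByteBuffer.allocate(8);
        b.put((byte) 10).put((byte) 20).put((byte) 30);
        System.out.println(state(b)); // position=3 limit=8 capacity=8

        b.flip();                     // switch from writing to reading
        System.out.println(state(b)); // position=0 limit=3 capacity=8

        b.get();                      // reads 10, position=1
        b.mark();                     // remember position 1
        b.get();                      // reads 20, position=2
        b.reset();                    // back to the marked position
        System.out.println(state(b)); // position=1 limit=3 capacity=8

        b.clear();                    // discards the mark; content is not erased
        System.out.println(state(b)); // position=0 limit=8 capacity=8
    }
}
```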

Code:

Project structure:

[Figure 3: project structure]

Main.java

public class Main {
    public static void main(String[] args) throws Exception {
        String url ="yoururl";
        BigFileDownloader bigFileDownloader = new BigFileDownloader(url);
        bigFileDownloader.download(2,3000);
    }
}

BigFileDownloader.java

import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.atomic.AtomicBoolean;

public class BigFileDownloader {
    protected final URL requestURL;
    protected final long fileSize;


    protected final Storage storage;
    protected final AtomicBoolean taskCanceled = new AtomicBoolean(false);

    public BigFileDownloader(String strURL) throws Exception {
        requestURL = new URL(strURL);

        fileSize = retieveFileSize(requestURL);
        String filename = strURL.substring(strURL.lastIndexOf("/")+1);

        storage = new Storage(fileSize, filename);
    }

    /**
     * Downloads the specified file.
     *
     * @param taskCount
     *          number of download tasks
     * @param reportInterval
     *          interval between progress reports, in milliseconds
     * @throws Exception
     */
    public void download(int taskCount, long reportInterval)
            throws Exception {

        long chunkSizePerThread = fileSize / taskCount;
        // First byte of the segment to download
        long lowerBound = 0;
        // Last byte of the segment to download
        long upperBound = 0;

        DownloadTask dt;
        for (int i = taskCount - 1; i >= 0; i--) {
            lowerBound = i * chunkSizePerThread;
            if (i == taskCount - 1) {
                upperBound = fileSize;
            } else {
                upperBound = lowerBound + chunkSizePerThread - 1;
            }

            // Create the download task
            dt = new DownloadTask(lowerBound, upperBound, requestURL, storage,
                    taskCanceled);
            dispatchWork(dt, i);
        }
        // Report download progress periodically
        reportProgress(reportInterval);
        // Release the resources held by the program
        doCleanup();

    }

    protected void doCleanup() {
        Tools.silentClose(storage);
    }

    protected void cancelDownload() {
        if (taskCanceled.compareAndSet(false, true)) {
            doCleanup();
        }
    }

    protected void dispatchWork(final DownloadTask dt, int workerIndex) {
        // Create the download thread
        Thread workerThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    dt.run();
                } catch (Exception e) {
                    e.printStackTrace();
                    // Cancel the download of the whole file
                    cancelDownload();
                }
            }
        });
        workerThread.setName("downloader-" + workerIndex);
        workerThread.start();
    }

    // Retrieves the size of the file at the given URL
    private static long retieveFileSize(URL requestURL) throws Exception {
        long size = -1;
        HttpURLConnection conn = null;
        try {
            conn = (HttpURLConnection) requestURL.openConnection();

            conn.setRequestMethod("HEAD");
            conn.setRequestProperty("Connection", "Keep-alive");
            conn.connect();
            int statusCode = conn.getResponseCode();
            if (HttpURLConnection.HTTP_OK != statusCode) {
                throw new Exception("Server exception,status code:" + statusCode);
            }

            String cl = conn.getHeaderField("Content-Length");
            size = Long.valueOf(cl);
        } finally {
            if (null != conn) {
                conn.disconnect();
            }
        }
        return size;
    }

    // Reports download progress
    private void reportProgress(long reportInterval) throws InterruptedException {
        float lastCompletion;
        int completion = 0;
        while (!taskCanceled.get()) {
            lastCompletion = completion;
            completion = (int) (storage.getTotalWrites() * 100 / fileSize);
            if (completion == 100) {
                break;
            } else if (completion - lastCompletion >= 1) {
                Debug.info("Completion:%s%%", completion);
                if (completion >= 90) {
                    reportInterval = 1000;
                }
            }
            Thread.sleep(reportInterval);
        }
        Debug.info("Completion:%s%%", completion);
    }
}

 Debug.java

import java.io.PrintStream;
import java.text.SimpleDateFormat;
import java.util.Date;

public class Debug {
    private static ThreadLocal<SimpleDateFormat> sdfWrapper = new ThreadLocal<SimpleDateFormat>() {
        @Override
        protected SimpleDateFormat initialValue() {
            return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
        }

    };

    enum Label {
        INFO("INFO"),
        ERR("ERROR");
        String name;

        Label(String name) {
            this.name = name;
        }

        public String getName() {
            return name;
        }
    }

    // public static void info(String message) {
    // printf(Label.INFO, "%s", message);
    // }

    public static void info(String format, Object... args) {
        printf(Label.INFO, format, args);
    }

    public static void info(boolean message) {
        info("%s", message);
    }

    public static void info(int message) {
        info("%d", message);
    }

    public static void error(String message, Object... args) {
        printf(Label.ERR, message, args);
    }

    public static void printf(Label label, String format, Object... args) {
        SimpleDateFormat sdf = sdfWrapper.get();
        @SuppressWarnings("resource")
        final PrintStream ps = label == Label.INFO ? System.out : System.err;
        ps.printf('[' + sdf.format(new Date()) + "][" + label.getName()
                + "]["
                + Thread.currentThread().getName() + "]:" + format + " %n", args);
    }
}

DownloadBuffer.java 

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;

public class DownloadBuffer implements Closeable {
    /**
     * Offset, within the whole stored file, of the data currently held in this buffer
     */
    private long globalOffset;
    private long upperBound;
    private int offset = 0;
    public final ByteBuffer byteBuf;
    private final Storage storage;

    public DownloadBuffer(long globalOffset, long upperBound,
                          final Storage storage) {
        this.globalOffset = globalOffset;
        this.upperBound = upperBound;
        this.byteBuf = ByteBuffer.allocate(1024 * 1024);
        this.storage = storage;
    }

    public void write(ByteBuffer buf) throws IOException {
        int length = buf.position();
        final int capacity = byteBuf.capacity();
        // The buffer is full, or its remaining capacity cannot hold the new data
        if (offset + length > capacity || length == capacity) {
            // Write the buffered data to the file
            flush();
        }
        byteBuf.position(offset);
        buf.flip();
        byteBuf.put(buf);
        offset += length;
    }

    public void flush() throws IOException {
        int length;
        byteBuf.flip();
        length = storage.store(globalOffset, byteBuf);
        byteBuf.clear();
        globalOffset += length;
        offset = 0;
    }

    @Override
    public void close() throws IOException {
        Debug.info("globalOffset:%s,upperBound:%s", globalOffset, upperBound);
        if (globalOffset < upperBound) {
            flush();
        }
    }
}

DownloadTask.java 

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Download subtask
 *
 * @author Viscent Huang
 */
public class DownloadTask implements Runnable {
    private final long lowerBound;
    private final long upperBound;
    private final DownloadBuffer xbuf;
    private final URL requestURL;
    private final AtomicBoolean cancelFlag;

    public DownloadTask(long lowerBound, long upperBound, URL requestURL,
                        Storage storage, AtomicBoolean cancelFlag) {
        this.lowerBound = lowerBound;
        this.upperBound = upperBound;
        this.requestURL = requestURL;
        this.xbuf = new DownloadBuffer(lowerBound, upperBound, storage);
        this.cancelFlag = cancelFlag;
    }

    // Issues an HTTP range (partial content) request for the given URL
    private static InputStream issueRequest(URL requestURL, long lowerBound,
                                            long upperBound) throws IOException {
        Thread me = Thread.currentThread();
        Debug.info(me + "->[" + lowerBound + "," + upperBound + "]");
        final HttpURLConnection conn;
        InputStream in = null;
        conn = (HttpURLConnection) requestURL.openConnection();
        String strConnTimeout = System.getProperty("x.dt.conn.timeout");
        int connTimeout = null == strConnTimeout ? 60000 : Integer
                .valueOf(strConnTimeout);
        conn.setConnectTimeout(connTimeout);

        String strReadTimeout = System.getProperty("x.dt.read.timeout");
        int readTimeout = null == strReadTimeout ? 60000 : Integer
                .valueOf(strReadTimeout);
        conn.setReadTimeout(readTimeout);

        conn.setRequestMethod("GET");
        conn.setRequestProperty("Connection", "Keep-alive");
        // Range: bytes=0-1024
        conn.setRequestProperty("Range", "bytes=" + lowerBound + "-" + upperBound);
        conn.setDoInput(true);
        conn.connect();

        int statusCode = conn.getResponseCode();
        if (HttpURLConnection.HTTP_PARTIAL != statusCode) {
            conn.disconnect();
            throw new IOException("Server exception,status code:" + statusCode);
        }

        Debug.info(me + "-Content-Range:" + conn.getHeaderField("Content-Range")
                + ",connection:" + conn.getHeaderField("connection"));

        in = new BufferedInputStream(conn.getInputStream()) {
            @Override
            public void close() throws IOException {
                try {
                    super.close();
                } finally {
                    conn.disconnect();
                }
            }
        };

        return in;
    }

    @Override
    public void run() {
        if (cancelFlag.get()) {
            return;
        }
        ReadableByteChannel channel = null;
        try {
            channel = Channels.newChannel(issueRequest(requestURL, lowerBound,
                    upperBound));
            ByteBuffer buf = ByteBuffer.allocate(1024);
            while (!cancelFlag.get() && channel.read(buf) > 0) {
                // Write the data read from the network into the buffer
                xbuf.write(buf);
                buf.clear();
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            Tools.silentClose(channel, xbuf);
        }
    }
}

Storage.java 

import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.atomic.AtomicLong;

public class Storage implements Closeable, AutoCloseable {
    private final RandomAccessFile storeFile;
    private final FileChannel storeChannel;
    protected final AtomicLong totalWrites = new AtomicLong(0);

    public Storage(long fileSize, String fileShortName) throws IOException {
        String fullFileName = System.getProperty("java.io.tmpdir") + "/"
                + fileShortName;
        String localFileName;
        localFileName = createStoreFile(fileSize, fullFileName);
        storeFile = new RandomAccessFile(localFileName, "rw");
        storeChannel = storeFile.getChannel();
    }

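    /**
     * Writes the data in byteBuf to the file.
     *
     * @param offset
     *          starting offset of the data within the whole file
     * @param byteBuf
     *          byteBuf.flip() must be called before invoking this method
     * @throws IOException
     * @return the number of bytes written to the file
     */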
    public int store(long offset, ByteBuffer byteBuf)
            throws IOException {
        int length;
        storeChannel.write(byteBuf, offset);
        length = byteBuf.limit();
        totalWrites.addAndGet(length);
        return length;
    }

    public long getTotalWrites() {
        return totalWrites.get();
    }

    private String createStoreFile(final long fileSize, String fullFileName)
            throws IOException {
        File file = new File(fullFileName);
        Debug.info("create local file:%s", fullFileName);
        RandomAccessFile raf;
        raf = new RandomAccessFile(file, "rw");
        try {
            raf.setLength(fileSize);
        } finally {
            Tools.silentClose(raf);
        }
        return fullFileName;
    }

    @Override
    public synchronized void close() throws IOException {
        if (storeChannel.isOpen()) {
            Tools.silentClose(storeChannel, storeFile);
        }
    }
}

Tools.java 

import java.io.*;
import java.lang.reflect.Field;
import java.math.BigInteger;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import sun.misc.Unsafe;
public final class Tools {
    private static final Random rnd = new Random();
    private static final Logger LOGGER = Logger.getAnonymousLogger();

    public static void startAndWaitTerminated(Thread... threads)
            throws InterruptedException {
        if (null == threads) {
            throw new IllegalArgumentException("threads is null!");
        }
        for (Thread t : threads) {
            t.start();
        }
        for (Thread t : threads) {
            t.join();
        }
    }

    public static void startThread(Thread... threads) {
        if (null == threads) {
            throw new IllegalArgumentException("threads is null!");
        }
        for (Thread t : threads) {
            t.start();
        }
    }

    public static void startAndWaitTerminated(Iterable<Thread> threads)
            throws InterruptedException {
        if (null == threads) {
            throw new IllegalArgumentException("threads is null!");
        }
        for (Thread t : threads) {
            t.start();
        }
        for (Thread t : threads) {
            t.join();
        }
    }

    public static void randomPause(int maxPauseTime) {
        int sleepTime = rnd.nextInt(maxPauseTime);
        try {
            Thread.sleep(sleepTime);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static void randomPause(int maxPauseTime, int minPauseTime) {
        int sleepTime = maxPauseTime == minPauseTime ? minPauseTime : rnd
                .nextInt(maxPauseTime - minPauseTime) + minPauseTime;
        try {
            Thread.sleep(sleepTime);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    public static Unsafe getUnsafe() {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            ((Field) f).setAccessible(true);
            return (Unsafe) f.get(null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    public static void silentClose(Closeable... closeable) {
        if (null == closeable) {
            return;
        }
        for (Closeable c : closeable) {
            if (null == c) {
                continue;
            }
            try {
                c.close();
            } catch (Exception ignored) {
            }
        }
    }

    public static void split(String str, String[] result, char delimeter) {
        int partsCount = result.length;
        int posOfDelimeter;
        int fromIndex = 0;
        String recordField;
        int i = 0;
        while (i < partsCount) {
            posOfDelimeter = str.indexOf(delimeter, fromIndex);
            if (-1 == posOfDelimeter) {
                recordField = str.substring(fromIndex);
                result[i] = recordField;
                break;
            }
            recordField = str.substring(fromIndex, posOfDelimeter);
            result[i] = recordField;
            i++;
            fromIndex = posOfDelimeter + 1;
        }
    }

    public static void log(String message) {
        LOGGER.log(Level.INFO, message);
    }

    public static String md5sum(final InputStream in) throws NoSuchAlgorithmException, IOException {
        MessageDigest md = MessageDigest.getInstance("MD5");
        byte[] buf = new byte[1024];
        try (DigestInputStream dis = new DigestInputStream(in, md)) {
            while (-1 != dis.read(buf))
                ;
        }
        byte[] digest = md.digest();
        BigInteger bigInt = new BigInteger(1, digest);
        String checkSum = bigInt.toString(16);

        while (checkSum.length() < 32) {
            checkSum = "0" + checkSum;
        }
        return checkSum;
    }

    public static String md5sum(final File file) throws NoSuchAlgorithmException, IOException {
        return md5sum(new BufferedInputStream(new FileInputStream(file)));
    }

    public static String md5sum(String str) throws NoSuchAlgorithmException, IOException {
        ByteArrayInputStream in = new ByteArrayInputStream(str.getBytes("UTF-8"));
        return md5sum(in);
    }

    public static void delayedAction(String prompt, Runnable action, int delay/* seconds */) {
        Debug.info("%s in %d seconds.", prompt, delay);
        try {
            Thread.sleep(delay * 1000);
        } catch (InterruptedException ignored) {
        }
        action.run();
    }

    public static Object newInstanceOf(String className) throws InstantiationException,
            IllegalAccessException, ClassNotFoundException {
        return Class.forName(className).newInstance();
    }

}

References:

《Java多线程编程实战指南》
