Fork me on GitHub

一千万个数高效求和

背景

“一千万个数如何高效求和?”,当看到这个问题的时候,第一反应就是分段求和再相加,而JDK1.8提供的LongAdder类,就是通过分段求和再汇总的思想设计的。为了对比实践,我们先用单线程直接求和,然后再使用多线程求和。

单线程求和

1
2
3
4
5
6
7
8
9
10
11
// 单线程直接求和
public static int singleThreadSum(List<Integer> list){
long startTime = System.currentTimeMillis();
int sum = 0;
for (Integer num:list) {
sum+=num;
}
long endTime = System.currentTimeMillis();
System.out.println("单线程总和——>"+sum+" 耗时——>"+((endTime-startTime))+"毫秒");
return sum;
}

多线程求和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// 一千万个数
private static final int TOTAL_NUMBER = 10000000;

// 每个task求和的规模
private static final int SIZE_PER_TASK = 200000;
// 线程池
private static ThreadPoolExecutor executor = null;

static {
// 核心线程数 CPU数量 + 1
int corePoolSize = Runtime.getRuntime().availableProcessors() + 1;
int maxPoolSize = corePoolSize*2+1;
executor = new ThreadPoolExecutor(corePoolSize, maxPoolSize, 3, TimeUnit.SECONDS, new LinkedBlockingQueue<>());
}

// LogAdder多线程求和
public static int multiThreadLongAdderSum(List<Integer> list) throws InterruptedException {
long startTime = System.currentTimeMillis();
LongAdder longAdder = new LongAdder();
//拆分任务
List<List<Integer>> resultList = split(list,SIZE_PER_TASK);
final int taskSize = resultList.size();
final CountDownLatch countDownLatch = new CountDownLatch(taskSize);
for (int i = 0; i < taskSize; i++) {
List<Integer> subList = resultList.get(i);
executor.execute(()->{
try {
for (int num : subList) {
// 把每个task中的数字累加
longAdder.add(num);
}
} finally {
// task执行完成后,计数器减一
countDownLatch.countDown();
}

});
}
// 主线程等待所有子线程执行完成
countDownLatch.await();
long endTime = System.currentTimeMillis();
System.out.println("LongAdder多线程总和——>"+longAdder.intValue()+" 耗时——>"+((endTime-startTime))+"毫秒");
// 关闭线程池
executor.shutdown();
return longAdder.intValue();

}

// 多线程求和
public static int multiThreadSum(List<Integer> list) throws InterruptedException {
long startTime = System.currentTimeMillis();
//拆分任务
List<List<Integer>> resultList = split(list,SIZE_PER_TASK);
final int taskSize = resultList.size();
final CountDownLatch countDownLatch = new CountDownLatch(taskSize);
int[] result = new int[taskSize];
for (int i = 0; i < taskSize; i++) {
List<Integer> subList = resultList.get(i);
int index = i;
executor.execute(()->{
try {
for (int num : subList) {
// 把每个task中的数字累加
result[index]+=num;
}
} finally {
// task执行完成后,计数器减一
countDownLatch.countDown();
}

});
}
// 主线程等待所有子线程执行完成
countDownLatch.await();
int sum = 0;
for (int per:result) {
sum += per;
}
long endTime = System.currentTimeMillis();
System.out.println("多线程总和——>"+sum+" 耗时——>"+((endTime-startTime))+"毫秒");
// 关闭线程池
executor.shutdown();
return sum;

}

// 分割列表
public static final List<List<Integer>> split(List<Integer> rawList, int perSize){
List<List<Integer>> resultList = new ArrayList<List<Integer>>();
for (int i = 0; i < TOTAL_NUMBER ; i = i+SIZE_PER_TASK) {
List<Integer> list = rawList.subList(i,i+SIZE_PER_TASK);
resultList.add(list);
}
System.out.println("分割好的list->"+resultList.size());
return resultList;
}

运行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public static void main(String[] args) throws InterruptedException {
Random random = new Random();
List<Integer> list = new ArrayList<>();
for (int i = 0; i < TOTAL_NUMBER; i++) {
int temp = random.nextInt(10);
list.add(temp);
}

// 单线程执行
singleThreadSum(list);
// LongAdder多线程直接求和
multiThreadLongAdderSum(list);
// 多线程直接求和
multiThreadSum(list);

}

结果

1
2
单线程总和——>45006860 耗时——>24毫秒
LongAdder多线程总和——>45006860 耗时——>251毫秒
1
2
单线程总和——>45001009 耗时——>19毫秒
多线程总和——>45001009 耗时——>120毫秒

结果有点出乎意料,单线程比多线程花费时间更少。

JDK1.8的stream

1
2
3
4
5
6
7
8
// JDK1.8的stream
public static int streamSum(List<Integer> list) {
long start = System.currentTimeMillis();
int sum = list.stream().mapToInt(num -> num).sum();
long end = System.currentTimeMillis();
System.out.printf("stream方式计算结果:%d, 耗时:%d 毫秒",sum, (end - start));
return sum;
}

结果:

1
2
3
单线程总和——>45011698 耗时——>19毫秒
多线程总和——>45011698 耗时——>130毫秒
stream方式计算结果:45011698, 耗时:24 毫秒

JDK1.8的 parallelStream方式

parallelStream见名知意,就是并行的stream。

1
2
3
4
5
6
7
8
9
// JDK1.8的parallelStream方式
//parallelStream见名知意,就是并行的stream。
public static int parallelStreamSum(List<Integer> list) {
long start = System.currentTimeMillis();
int sum = list.parallelStream().mapToInt(num -> num).sum();
long end = System.currentTimeMillis();
System.out.printf("parallel stream方式计算结果:%d, 耗时:%d 毫秒",sum, (end - start));
return sum;
}

结果:

1
2
3
4
单线程总和——>45016893 耗时——>17毫秒
多线程总和——>45016893 耗时——>112毫秒
stream方式计算结果:45016893, 耗时:23 毫秒
parallel stream方式计算结果:45016893, 耗时:28 毫秒

ForkJoin方式

ForkJoin框架是JDK1.7提出的,用于拆分任务计算再合并计算结果的框架。

1
当我们需要执行大量的小任务时,有经验的Java开发人员都会采用线程池来高效执行这些小任务。然而,有一种任务,例如,对超过1000万个元素的数组进行排序,这种任务本身可以并发执行,但如何拆解成小任务需要在任务执行的过程中动态拆分。这样,大任务可以拆成小任务,小任务还可以继续拆成更小的任务,最后把任务的结果汇总合并,得到最终结果,这种模型就是Fork/Join模型。

ForkJoin框架的使用大致分为两个部分:实现ForkJoin任务、执行任务

  • 实现ForkJoin任务
    自定义类继承RecursiveTask(有返回值)或者RecursiveAction(无返回值),实现compute方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
/**
* 静态内部类的方式实现
* forkjoin任务
*/
static class SicForkJoinTask extends RecursiveTask<Integer> {
// 子任务计算区间开始
private Integer left;
// 子任务计算区间结束
private Integer right;
private int[] arr;

@Override
protected Integer compute() {
if (right - left < SIZE_PER_TASK) {
// 任务足够小时,直接计算
int sum = 0;
for (int i = left; i < right; i++) {
sum += arr[i];
}
return sum;
}
// 继续拆分任务
int middle = left + (right - left) / 2;
SicForkJoinTask leftTask = new SicForkJoinTask(arr, left, middle);
SicForkJoinTask rightTask = new SicForkJoinTask(arr, middle, right);
invokeAll(leftTask, rightTask);
Integer leftResult = leftTask.join();
Integer rightResult = rightTask.join();
return leftResult + rightResult;
}

public SicForkJoinTask(int[] arr, Integer left, Integer right) {
this.arr = arr;
this.left = left;
this.right = right;
}
}
  • 执行任务
    通过ForkJoinPool的invoke方法执行ForkJoin任务
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    // ForkJoin线程池
    private static final ForkJoinPool forkJoinPool = new ForkJoinPool();

    public static int forkJoinSum(int[] arr) {
    long start = System.currentTimeMillis();
    // 执行ForkJoin任务
    Integer sum = forkJoinPool.invoke(new SicForkJoinTask(arr, 0, TOTAL_NUMBER));
    long end = System.currentTimeMillis();
    System.out.printf("\nforkjoin方式计算结果:%d, 耗时:%d 毫秒", sum, (end - start));
    return sum;
    }

结果:

1
2
3
4
5
单线程总和——>44994415 耗时——>17毫秒
多线程总和——>44994415 耗时——>132毫秒
stream方式计算结果:44994415, 耗时:22 毫秒
parallel stream方式计算结果:44994415, 耗时:30 毫秒
forkjoin方式计算结果:44994415, 耗时:174 毫秒