Java Concurrency - Fork-Join framework



The fork-join framework lets you break a task into smaller pieces that several worker threads process in parallel, then wait for their results and combine them. This makes good use of multi-processor machines. The following are the core concepts and classes used in the fork-join framework.

Fork

Fork is the step in which a task splits itself into smaller, independent sub-tasks that can be executed concurrently.

Syntax

Sum left  = new Sum(array, low, mid);
left.fork();

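Here Sum is a subclass of RecursiveTask, and left.fork() schedules the left sub-task for asynchronous execution in the pool, so the current task can carry on with other work while it runs.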
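Fork is not limited to splitting a task into two halves. The sketch below, using hypothetical classes SquareSum and ChunkedSquareSum, has one parent task fork one sub-task per chunk of a range and then collect the results with join(). It runs on ForkJoinPool.commonPool(), the JVM-wide shared pool, and assumes chunks is at least 1. Treat it as an illustration of fork(), not a tuned implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkDemo {

    // Hypothetical leaf task: sums i * i for every i in [from, to).
    static class SquareSum extends RecursiveTask<Long> {
        private final int from;
        private final int to;

        SquareSum(int from, int to) {
            this.from = from;
            this.to = to;
        }

        @Override
        protected Long compute() {
            long total = 0;
            for (int i = from; i < to; i++) {
                total += (long) i * i;
            }
            return total;
        }
    }

    // Hypothetical parent task: forks one SquareSum per chunk of [0, n).
    static class ChunkedSquareSum extends RecursiveTask<Long> {
        private final int n;
        private final int chunks; // assumed >= 1

        ChunkedSquareSum(int n, int chunks) {
            this.n = n;
            this.chunks = chunks;
        }

        @Override
        protected Long compute() {
            List<SquareSum> forked = new ArrayList<>();
            int step = (n + chunks - 1) / chunks; // ceiling division
            for (int from = 0; from < n; from += step) {
                SquareSum sub = new SquareSum(from, Math.min(from + step, n));
                sub.fork(); // queue the sub-task; fork() returns immediately
                forked.add(sub);
            }
            long total = 0;
            // Join in reverse fork order (innermost first), as the ForkJoinTask docs suggest
            for (int i = forked.size() - 1; i >= 0; i--) {
                total += forked.get(i).join();
            }
            return total;
        }
    }

    static long sumOfSquares(int n, int chunks) {
        return ForkJoinPool.commonPool().invoke(new ChunkedSquareSum(n, chunks));
    }

    public static void main(String[] args) {
        System.out.println(sumOfSquares(100, 4)); // 328350
    }
}
```

Each fork() call returns at once, so all chunks are queued before the parent starts waiting on any of them.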

Join

Join is the step in which a task waits for its sub-tasks to finish and then combines their results. If a sub-task has not finished yet, join() does not return until it has.

Syntax

left.join();

Here left is an instance of the Sum class, and left.join() returns its result (a Long) once that sub-task has completed.
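join() differs from the get() method that ForkJoinTask inherits from Future. A minimal sketch, using a hypothetical Answer task: join() returns the value directly and, if the sub-task failed, rethrows an unchecked exception. get() instead wraps the failure in a checked ExecutionException. This is one reason compute() methods usually call join(), since compute() cannot throw checked exceptions.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class JoinDemo {

    // Hypothetical task: returns 42, or throws if asked to fail.
    static class Answer extends RecursiveTask<Integer> {
        private final boolean fail;

        Answer(boolean fail) {
            this.fail = fail;
        }

        @Override
        protected Integer compute() {
            if (fail) {
                throw new IllegalStateException("sub-task failed");
            }
            return 42;
        }
    }

    // join() waits for the result; a failure is rethrown as an unchecked exception.
    static String viaJoin(boolean fail) {
        ForkJoinTask<Integer> task = ForkJoinPool.commonPool().submit(new Answer(fail));
        try {
            return "result " + task.join();
        } catch (RuntimeException e) {
            return "join threw an unchecked exception";
        }
    }

    // get(), inherited from Future, wraps a failure in a checked ExecutionException.
    static String viaGet(boolean fail) {
        ForkJoinTask<Integer> task = ForkJoinPool.commonPool().submit(new Answer(fail));
        try {
            return "result " + task.get();
        } catch (ExecutionException e) {
            return "get threw ExecutionException";
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            return "interrupted";
        }
    }

    public static void main(String[] args) {
        System.out.println(viaJoin(false)); // result 42
        System.out.println(viaJoin(true));  // join threw an unchecked exception
        System.out.println(viaGet(true));   // get threw ExecutionException
    }
}
```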

ForkJoinPool

It is a special thread pool designed to run fork-join tasks. It uses work stealing: idle worker threads take queued tasks from busy ones, which keeps all processors occupied.

Syntax

ForkJoinPool forkJoinPool = new ForkJoinPool(4);

Here a new ForkJoinPool is created with a parallelism level of 4, meaning it aims to keep 4 worker threads active (typically one per CPU core).
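The sketch below runs the same task on a dedicated pool with parallelism 4 and on the shared ForkJoinPool.commonPool(). The CountEvens task, which counts even numbers in a range, and its threshold of 1,000 are hypothetical choices made for illustration. A dedicated pool should be shut down when it is no longer needed. The common pool is managed by the JVM and does not need shutting down.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class PoolDemo {

    // Hypothetical task: counts the even numbers in [low, high).
    static class CountEvens extends RecursiveTask<Integer> {
        private final int low;
        private final int high;

        CountEvens(int low, int high) {
            this.low = low;
            this.high = high;
        }

        @Override
        protected Integer compute() {
            if (high - low <= 1_000) { // small range: count directly
                int count = 0;
                for (int i = low; i < high; i++) {
                    if (i % 2 == 0) {
                        count++;
                    }
                }
                return count;
            }
            int mid = low + (high - low) / 2;
            CountEvens left = new CountEvens(low, mid);
            left.fork();                                 // left half runs asynchronously
            int right = new CountEvens(mid, high).compute();
            return left.join() + right;
        }
    }

    // invoke() submits the task and waits for its result
    static int countWith(ForkJoinPool pool, int n) {
        return pool.invoke(new CountEvens(0, n));
    }

    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool(4);     // dedicated pool, parallelism 4
        System.out.println(pool.getParallelism());   // 4
        System.out.println(countWith(pool, 10_000)); // 5000
        pool.shutdown();                             // release the worker threads

        ForkJoinPool common = ForkJoinPool.commonPool(); // shared, JVM-wide pool
        System.out.println(countWith(common, 10_000));   // 5000
    }
}
```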

RecursiveAction

RecursiveAction represents a task which does not return any value.

Syntax

class Writer extends RecursiveAction {
   @Override
   protected void compute() { }
}
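Because a RecursiveAction returns nothing, its work is usually a side effect, such as updating an array in place. The sketch below uses a hypothetical DoubleInPlace action that doubles every element of an array. Its threshold of 4 is deliberately small so that the demo actually splits. It uses the static ForkJoinTask.invokeAll(t1, t2), which forks both sub-tasks and waits until both have finished.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class ActionDemo {

    // Hypothetical RecursiveAction: doubles every element of array[low, high) in place.
    static class DoubleInPlace extends RecursiveAction {
        private static final int THRESHOLD = 4; // small so the demo actually splits
        private final int[] array;
        private final int low;
        private final int high;

        DoubleInPlace(int[] array, int low, int high) {
            this.array = array;
            this.low = low;
            this.high = high;
        }

        @Override
        protected void compute() {
            if (high - low <= THRESHOLD) {
                for (int i = low; i < high; i++) {
                    array[i] *= 2;
                }
                return; // no value to return: the result is the modified array
            }
            int mid = low + (high - low) / 2;
            // invokeAll forks both halves and waits for both to complete
            invokeAll(new DoubleInPlace(array, low, mid),
                      new DoubleInPlace(array, mid, high));
        }
    }

    // Works on a copy so the caller's array is left unchanged
    static int[] doubled(int[] input) {
        int[] copy = Arrays.copyOf(input, input.length);
        ForkJoinPool.commonPool().invoke(new DoubleInPlace(copy, 0, copy.length));
        return copy;
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        System.out.println(Arrays.toString(doubled(data)));
        // [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
    }
}
```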

RecursiveTask

RecursiveTask represents a task which returns a value.

Syntax

class Sum extends RecursiveTask<Long> {
   @Override
   protected Long compute() { return null; }
}
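A RecursiveTask combines the values returned by its sub-tasks. The sketch below uses a hypothetical Max task that finds the largest element of a non-empty array, with a threshold of 8 chosen for illustration. It follows the same pattern as the Sum example later on: fork the left half, compute the right half in the current thread, then join and combine.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class MaxDemo {

    // Hypothetical RecursiveTask: returns the largest element of array[low, high).
    static class Max extends RecursiveTask<Integer> {
        private static final int THRESHOLD = 8;
        private final int[] array;
        private final int low;
        private final int high; // assumes high > low (non-empty range)

        Max(int[] array, int low, int high) {
            this.array = array;
            this.low = low;
            this.high = high;
        }

        @Override
        protected Integer compute() {
            if (high - low <= THRESHOLD) {
                int best = array[low];
                for (int i = low + 1; i < high; i++) {
                    best = Math.max(best, array[i]);
                }
                return best;
            }
            int mid = low + (high - low) / 2;
            Max left = new Max(array, low, mid);
            Max right = new Max(array, mid, high);
            left.fork();                            // left half runs asynchronously
            int rightMax = right.compute();         // right half runs in this thread
            return Math.max(left.join(), rightMax); // combine the two returned values
        }
    }

    static int max(int[] array) {
        if (array.length == 0) {
            throw new IllegalArgumentException("empty array");
        }
        return ForkJoinPool.commonPool().invoke(new Max(array, 0, array.length));
    }

    public static void main(String[] args) {
        int[] data = {7, -3, 42, 0, 15, 9, 41, 2, 8, 30, 1, 5};
        System.out.println(max(data)); // 42
    }
}
```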

Example

The following TestThread program uses the fork-join framework to sum the integers 0 to 999 in parallel.

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class TestThread {

   public static void main(final String[] arguments) throws InterruptedException,
      ExecutionException {

      // Number of processors available to the JVM
      int nThreads = Runtime.getRuntime().availableProcessors();
      System.out.println(nThreads);

      // Fill the array with 0, 1, ..., 999
      int[] numbers = new int[1000];

      for(int i = 0; i < numbers.length; i++) {
         numbers[i] = i;
      }

      // One worker per processor; invoke() runs the task and waits for its result
      ForkJoinPool forkJoinPool = new ForkJoinPool(nThreads);
      Long result = forkJoinPool.invoke(new Sum(numbers, 0, numbers.length));
      System.out.println(result);
   }

   static class Sum extends RecursiveTask<Long> {
      int low;
      int high;
      int[] array;

      Sum(int[] array, int low, int high) {
         this.array = array;
         this.low   = low;
         this.high  = high;
      }

      @Override
      protected Long compute() {
         // Small range: add the elements directly
         if(high - low <= 10) {
            long sum = 0;

            for(int i = low; i < high; ++i) {
               sum += array[i];
            }
            return sum;
         } else {
            // Split in two: fork the left half, compute the right half here, then join
            int mid = low + (high - low) / 2;
            Sum left  = new Sum(array, low, mid);
            Sum right = new Sum(array, mid, high);
            left.fork();
            long rightResult = right.compute();
            long leftResult  = left.join();
            return leftResult + rightResult;
         }
      }
   }
}

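This will produce a result similar to the following. The first line is the number of available processors, so it varies from machine to machine.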

Output

32
499500
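The second line can be checked by hand. The array holds 0, 1, ..., n - 1 with n = 1000, so the sum is n(n - 1) / 2 = 1000 × 999 / 2 = 499500. The small sketch below, with hypothetical helper names, compares that closed form against a plain sequential sum. It is only a cross-check and does not involve fork-join.

```java
import java.util.stream.LongStream;

public class CheckSum {

    // Closed form for 0 + 1 + ... + (n - 1)
    static long closedForm(int n) {
        return (long) n * (n - 1) / 2;
    }

    // Sequential sum over the same values the TestThread example stores in numbers[]
    static long sequential(int n) {
        return LongStream.range(0, n).sum();
    }

    public static void main(String[] args) {
        System.out.println(closedForm(1000)); // 499500
        System.out.println(sequential(1000)); // 499500
    }
}
```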