理系学生日記

おまえはいつまで学生気分なのか

複数のスレッドで同時にある処理を起動させたい

こういうときは CountDownLatch を使う。
latch、英語としては「掛け金」とかそういう意味だけど、その名の通り、スレッドに対して掛け金をかけておいて、掛け金を外したらスレッドが猪突猛進に走り出すイメージ。CountDownLatch では、コンストラクタに与えた数を countdown() で減らしていって、その値が 0 になったら掛け金が外れ、掛け金が外れるのを待っていたスレッドが一斉に走り出すことになる。
最近だと、DB の悲観ロックがきちんと効いていることを確認するテストのために、CountDownLatch 使ったテストを書いた。

簡単なサンプルとしては以下のようになる。ここでは、3 つのスレッドの同期を取るために、CountDownLatch に 3 を設定している。

import java.util.concurrent.CountDownLatch;

public class CountdownLatchTest {
        
    public static void main(String[] args) throws Exception {
        // スレッド数
        final int THREAD_NUM = 3;

        CountDownLatch latch = new CountDownLatch(THREAD_NUM);
        Thread[] threadArray = new Thread[THREAD_NUM];
        
        for ( int i = 0; i < THREAD_NUM; i++ ) {
            Thread th = new Thread(new TestRunnable(latch));
            th.setName("thread" + i);
            threadArray[i] = th;
            // ここでスレッド起動
            th.start();

            // 起動タイミングをずらすために 1000 ms 待機
            Thread.sleep(1000);
        }
        
        for (Thread th : threadArray) {
            th.join();
        }
    }

}

class TestRunnable implements Runnable {
    private CountDownLatch latch;
    
    public TestRunnable(CountDownLatch latch) {
        this.latch = latch;
    }

    @Override
    public void run() {
        try {
            System.out.println(Thread.currentThread().getName() + "#run() starts at " + System.currentTimeMillis());
            
            latch.countDown();
            // ここで各スレッドは掛け金が外れるのを待つ
            latch.await();

            // ここに同時に実行したい処理を書く
            System.out.println(Thread.currentThread().getName() + "#run() passes the latch at " + System.currentTimeMillis());
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    
}

出力結果はこんなかんじ。スレッドの起動タイミングはおおよそ 1,000 ms ズレてるけど、latch.await() を抜けるタイミングは全て同じになっていることがわかる。

thread0#run() starts at 1404923870827
thread1#run() starts at 1404923871827
thread2#run() starts at 1404923872828
thread2#run() passes the latch at 1404923872828
thread0#run() passes the latch at 1404923872828
thread1#run() passes the latch at 1404923872828