Problem

Suppose we have a class:

public class Foo {
  public void first() { print("first"); }
  public void second() { print("second"); }
  public void third() { print("third"); }
}

The same instance of Foo will be passed to three different threads. Thread A will call first(), thread B will call second(), and thread C will call third(). Design a mechanism and modify the program to ensure that second() is executed after first(), and third() is executed after second().

Note:

We do not know how the threads will be scheduled in the operating system, even though the numbers in the input seem to imply the ordering. The input format you see is mainly to ensure our tests’ comprehensiveness.

Examples

Example 1:

Input: nums = [1,2,3]
Output: "firstsecondthird"
Explanation: There are three threads being fired asynchronously. The input [1,2,3] means thread A calls first(), thread B calls second(), and thread C calls third(). "firstsecondthird" is the correct output.

Example 2:

Input: nums = [1,3,2]
Output: "firstsecondthird"
Explanation: The input [1,3,2] means thread A calls first(), thread B calls third(), and thread C calls second(). "firstsecondthird" is the correct output.
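The input list only decides the order in which the threads are started. Below is a minimal Python sketch of how a harness could use it. This is our assumption, not LeetCode's actual judge. The helper run() starts one thread per entry of nums, in list order, and joins them all. The Foo inside repeats the Event-based Python solution from this article so that the snippet runs on its own. The name run is hypothetical.

```python
import threading

class Foo:
    # Event-based ordering, same scheme as the Python solution in this article.
    def __init__(self):
        self.s2 = threading.Event()  # set once first() has printed
        self.s3 = threading.Event()  # set once second() has printed

    def first(self, printFirst):
        printFirst()
        self.s2.set()

    def second(self, printSecond):
        self.s2.wait()
        printSecond()
        self.s3.set()

    def third(self, printThird):
        self.s3.wait()
        printThird()

def run(nums):
    # Hypothetical harness: nums[i] says which method the i-th started thread calls
    # (1 -> first, 2 -> second, 3 -> third). Threads are started in list order.
    foo = Foo()
    out = []
    calls = {
        1: lambda: foo.first(lambda: out.append("first")),
        2: lambda: foo.second(lambda: out.append("second")),
        3: lambda: foo.third(lambda: out.append("third")),
    }
    threads = [threading.Thread(target=calls[n]) for n in nums]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return "".join(out)

print(run([1, 3, 2]))  # firstsecondthird
```

second() and third() block until their predecessor signals. As a result, every start order produces "firstsecondthird".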

Solution

Method 1 – Using Semaphores / Locks

Intuition

We need to synchronize three threads so that first() runs before second(), and second() before third(). Semaphores, locks, or condition variables can enforce this order.
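The following sketch shows the condition-variable option using Python's threading.Condition. One lock protects a stage counter, and each method waits until the counter shows that its predecessor has finished. It follows the same idea as the C++ solution in this article. The attribute names cond and stage are our own.

```python
import threading

class Foo:
    def __init__(self):
        self.cond = threading.Condition()
        self.stage = 0  # 0: nothing printed, 1: first done, 2: second done

    def first(self, printFirst):
        with self.cond:
            printFirst()
            self.stage = 1
            self.cond.notify_all()  # wake waiters; each re-checks its own predicate

    def second(self, printSecond):
        with self.cond:
            # wait_for re-checks the predicate after every wakeup,
            # so a notification meant for another thread is harmless
            self.cond.wait_for(lambda: self.stage == 1)
            printSecond()
            self.stage = 2
            self.cond.notify_all()

    def third(self, printThird):
        with self.cond:
            self.cond.wait_for(lambda: self.stage == 2)
            printThird()
```

A single condition shared by all three methods keeps the state in one place. The cost is that notify_all() may wake a thread that only checks its predicate and goes back to sleep.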

Approach

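Block second() and third() on two signals that start out unavailable. first() raises the first signal after it prints, which lets second() proceed. second() then raises the second signal after it prints, which lets third() proceed. third() releases nothing. The Java and Kotlin versions use semaphores with zero initial permits. The other versions substitute an equivalent primitive: a condition variable with a stage counter (C++), events (Python), channels (Go), barriers (Rust), and promises (TypeScript).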
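Here is a minimal Python sketch of the two-semaphore scheme. It uses threading.Semaphore in place of Java's java.util.concurrent.Semaphore; note that the Python solution in this article uses threading.Event instead. Both semaphores start with zero permits, so acquire() blocks until the predecessor calls release().

```python
import threading

class Foo:
    def __init__(self):
        self.s2 = threading.Semaphore(0)  # released by first()
        self.s3 = threading.Semaphore(0)  # released by second()

    def first(self, printFirst):
        printFirst()
        self.s2.release()  # hand the turn to second()

    def second(self, printSecond):
        self.s2.acquire()  # block until first() has printed
        printSecond()
        self.s3.release()  # hand the turn to third()

    def third(self, printThird):
        self.s3.acquire()  # block until second() has printed
        printThird()
```

Each method releases only after it has printed, not before. This guarantees that the predecessor's output is complete before the next method starts.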

Code

Java
import java.util.concurrent.Semaphore;
class Foo {
    // Zero initial permits: acquire() blocks until the predecessor calls release().
    private final Semaphore s2 = new Semaphore(0); // released by first()
    private final Semaphore s3 = new Semaphore(0); // released by second()
    public void first(Runnable printFirst) throws InterruptedException {
        printFirst.run();
        s2.release(); // let second() proceed
    }
    public void second(Runnable printSecond) throws InterruptedException {
        s2.acquire(); // wait for first()
        printSecond.run();
        s3.release(); // let third() proceed
    }
    public void third(Runnable printThird) throws InterruptedException {
        s3.acquire(); // wait for second()
        printThird.run();
    }
}

C++
#include <condition_variable>
#include <functional>
#include <mutex>
class Foo {
    int state = 0; // 0: nothing printed, 1: first done, 2: second done
    std::mutex mtx;
    std::condition_variable cv;
public:
    void first(std::function<void()> printFirst) {
        std::unique_lock<std::mutex> lock(mtx);
        printFirst();
        state = 1;
        cv.notify_all(); // wake waiters; each re-checks its own predicate
    }
    void second(std::function<void()> printSecond) {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait(lock, [&]{ return state == 1; }); // predicate guards against spurious wakeups
        printSecond();
        state = 2;
        cv.notify_all();
    }
    void third(std::function<void()> printThird) {
        std::unique_lock<std::mutex> lock(mtx);
        cv.wait(lock, [&]{ return state == 2; });
        printThird();
    }
};

Python
import threading
class Foo:
    def __init__(self):
        self.s2 = threading.Event()
        self.s3 = threading.Event()
    def first(self, printFirst):
        printFirst()
        self.s2.set()
    def second(self, printSecond):
        self.s2.wait()
        printSecond()
        self.s3.set()
    def third(self, printThird):
        self.s3.wait()
        printThird()

Go
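// Closing a channel never blocks and wakes every receiver, so First and
// Second return without waiting for the next goroutine to arrive.
type Foo struct {
    s2 chan struct{} // closed once First has printed
    s3 chan struct{} // closed once Second has printed
}
func NewFoo() *Foo {
    return &Foo{s2: make(chan struct{}), s3: make(chan struct{})}
}
func (f *Foo) First(printFirst func()) {
    printFirst()
    close(f.s2)
}
func (f *Foo) Second(printSecond func()) {
    <-f.s2 // blocks until First closes s2
    printSecond()
    close(f.s3)
}
func (f *Foo) Third(printThird func()) {
    <-f.s3 // blocks until Second closes s3
    printThird()
}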

Kotlin
import java.util.concurrent.Semaphore
class Foo {
    private val s2 = Semaphore(0)
    private val s3 = Semaphore(0)
    fun first(printFirst: () -> Unit) {
        printFirst()
        s2.release()
    }
    fun second(printSecond: () -> Unit) {
        s2.acquire()
        printSecond()
        s3.release()
    }
    fun third(printThird: () -> Unit) {
        s3.acquire()
        printThird()
    }
}

Rust
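use std::sync::{Arc, Barrier};
// Barrier::new(2) releases only when two threads have called wait(), so first()
// also blocks until second() reaches b1, and second() until third() reaches b2.
struct Foo {
    b1: Arc<Barrier>, // meeting point of first() and second()
    b2: Arc<Barrier>, // meeting point of second() and third()
}
impl Foo {
    fn new() -> Self {
        Self { b1: Arc::new(Barrier::new(2)), b2: Arc::new(Barrier::new(2)) }
    }
    fn first(&self, print_first: impl FnOnce()) {
        print_first(); // print before the barrier
        self.b1.wait();
    }
    fn second(&self, print_second: impl FnOnce()) {
        self.b1.wait(); // print only after first() has printed
        print_second();
        self.b2.wait();
    }
    fn third(&self, print_third: impl FnOnce()) {
        self.b2.wait(); // print only after second() has printed
        print_third();
    }
}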

TypeScript
class Foo {
    private p2: Promise<void>;
    private p3: Promise<void>;
    private r2!: () => void;
    private r3!: () => void;
    constructor() {
        this.p2 = new Promise(res => this.r2 = res);
        this.p3 = new Promise(res => this.r3 = res);
    }
    async first(printFirst: () => void) {
        printFirst();
        this.r2();
    }
    async second(printSecond: () => void) {
        await this.p2;
        printSecond();
        this.r3();
    }
    async third(printThird: () => void) {
        await this.p3;
        printThird();
    }
}

Complexity

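  • ⏰ Time complexity: O(1) for each method call, not counting the time a thread spends blocked waiting for its predecessor.
  • 🧺 Space complexity: O(1), since only a fixed number of synchronization primitives is used.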