1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
//! Utterly inefficient cross-platform preemptive user-mode scheduling
use slab::Slab;
use std::{
    panic::{catch_unwind, AssertUnwindSafe},
    sync::OnceLock,
    sync::{mpsc, Arc},
    thread::Result,
};

use crate::threading;

/// Key of a `WorkerThread` entry in `State::threads` (a `Slab`). This is also
/// the value wrapped by `ThreadId`.
type SlabPtr = usize;

#[cfg(test)]
mod tests;

/// Represents a dynamic set of threads that can be scheduled for execution by
/// `Sched: `[`Scheduler`].
///
/// A `ThreadGroup` is a handle to reference-counted, mutex-protected state, so
/// clones refer to the same group. Use [`ThreadGroup::lock`] to obtain a guard
/// through which threads are spawned, preempted, or shut down.
#[derive(Debug)]
pub struct ThreadGroup<Sched: ?Sized> {
    /// Shared group state. Each worker thread keeps its own reference to it.
    state: Arc<threading::Mutex<State<Sched>>>,
}

impl<Sched: ?Sized> Clone for ThreadGroup<Sched> {
    fn clone(&self) -> Self {
        Self {
            state: Arc::clone(&self.state),
        }
    }
}

/// Object that can be used to join on a [`ThreadGroup`].
///
/// It is returned together with the group by [`ThreadGroup::new`]. It receives
/// the group's outcome: `Ok(())` once a requested shutdown completes, or `Err`
/// carrying the panic payload of a worker thread that panicked.
#[derive(Debug)]
pub struct ThreadGroupJoinHandle {
    /// Receiving end of the channel whose sender lives in `State::result_send`.
    result_recv: mpsc::Receiver<Result<()>>,
}

/// RAII guard returned by [`ThreadGroup::lock`].
///
/// The thread group's state stays locked while this guard is alive. Worker
/// threads need the same lock to yield ([`yield_now`]) or to exit, so keep the
/// guard short-lived.
pub struct ThreadGroupLockGuard<'a, Sched: ?Sized> {
    /// The shared state. It is kept so that `spawn` can give each new worker
    /// thread its own reference to it.
    state_ref: &'a Arc<threading::Mutex<State<Sched>>>,
    /// The held lock on `*state_ref`.
    guard: threading::MutexGuard<'a, State<Sched>>,
}

/// Identifies a thread in [`ThreadGroup`].
///
/// An ID is the key of the thread's slab entry. Slab keys are recycled after a
/// thread exits, so a stale `ThreadId` may later refer to a newly spawned
/// thread.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct ThreadId(SlabPtr);

/// Encapsulates the state of a client-supplied user-mode scheduler.
///
/// The thread group calls these methods while holding its state lock.
/// NOTE(review): implementations therefore presumably must not call
/// [`ThreadGroup::lock`] on the same group from inside a callback (assumes
/// `threading::Mutex` is not reentrant); confirm.
pub trait Scheduler: Send + 'static {
    /// Choose the next thread to run.
    ///
    /// Returning `None` means no thread is unparked by this scheduling
    /// decision.
    ///
    /// It's an error to return an already-exited thread. The client is
    /// responsible for tracking the lifetime of spawned threads.
    fn choose_next_thread(&mut self) -> Option<ThreadId>;

    /// Called when a thread exits, whether it returned, panicked, or called
    /// [`exit_thread`].
    ///
    /// The call happens before the thread's slot is released, so `thread_id`
    /// should be dropped from the scheduler's bookkeeping here. The default
    /// implementation does nothing.
    fn thread_exited(&mut self, thread_id: ThreadId) {
        // Keep the parameter name for documentation without an unused warning.
        let _ = thread_id;
    }
}

/// Mutable state shared by all handles and worker threads of a thread group.
#[derive(Debug)]
struct State<Sched: ?Sized> {
    /// Worker threads that have been spawned and not yet finalized.
    threads: Slab<WorkerThread>,
    /// Number of live worker threads, kept in step with `threads`.
    num_threads: usize,
    /// The thread most recently chosen by the scheduler, if any.
    cur_thread_id: Option<ThreadId>,
    /// Set by `shutdown`. The shutdown completes once `num_threads` reaches 0.
    shutting_down: bool,
    /// Delivers the group's outcome to `ThreadGroupJoinHandle::join`.
    result_send: mpsc::Sender<Result<()>>,
    /// The client-supplied scheduler. It must be the last field so that
    /// `State<Sched>` can be unsized to `State<dyn Scheduler>`.
    sched: Sched,
}

/// Bookkeeping for one worker thread.
#[derive(Debug)]
struct WorkerThread {
    /// Handle of the spawned thread. It is `None` only inside `spawn`, between
    /// reserving the slab entry and storing the handle of the new thread.
    join_handle: Option<threading::JoinHandle<()>>,
}

thread_local! {
    /// Identity of the current worker thread. It is set once at the start of
    /// every thread created by `ThreadGroupLockGuard::spawn` and stays empty on
    /// all other threads.
    static TLB: OnceLock<ThreadLocalBlock> = OnceLock::new();
}

/// Per-worker-thread data stored in `TLB`.
struct ThreadLocalBlock {
    /// The current thread ID.
    thread_id: ThreadId,
    /// The thread group the current worker thread belongs to. The scheduler
    /// type is erased so that free functions such as `yield_now` can use it.
    state: Arc<threading::Mutex<State<dyn Scheduler>>>,
}

impl<Sched: Scheduler> ThreadGroup<Sched> {
    /// Construct a new `ThreadGroup` and the corresponding
    /// [`ThreadGroupJoinHandle`].
    ///
    /// The group starts without threads. Nothing runs until threads are
    /// spawned and then picked by `sched`.
    pub fn new(sched: Sched) -> (Self, ThreadGroupJoinHandle) {
        let (result_send, result_recv) = mpsc::channel();

        let initial = State {
            threads: Slab::new(),
            num_threads: 0,
            cur_thread_id: None,
            shutting_down: false,
            result_send,
            sched,
        };
        let group = Self {
            state: Arc::new(threading::Mutex::new(initial)),
        };

        (group, ThreadGroupJoinHandle { result_recv })
    }
}

impl ThreadGroupJoinHandle {
    /// Wait for the thread group to shut down.
    ///
    /// Returns `Ok(())` once a shutdown requested through
    /// `ThreadGroupLockGuard::shutdown` has completed. Returns `Err` with the
    /// panic payload if a worker thread panicked.
    ///
    /// # Panics
    ///
    /// Panics if the group's shared state is freed without any result being
    /// sent. This happens when all `ThreadGroup` handles and worker threads
    /// are gone before a shutdown completed.
    pub fn join(self) -> Result<()> {
        self.result_recv
            .recv()
            .expect("thread group was dropped without reporting a result")
    }
}

impl<Sched: Scheduler + ?Sized> ThreadGroup<Sched> {
    /// Acquire a lock on the thread group's state.
    ///
    /// The returned guard gives access to the scheduler and to the
    /// thread-management operations. The lock is released when the guard is
    /// dropped.
    ///
    /// # Panics
    ///
    /// Panics if locking the underlying mutex fails. NOTE(review): presumably
    /// this means the mutex was poisoned; confirm against `threading`.
    pub fn lock(&self) -> ThreadGroupLockGuard<'_, Sched> {
        let shared = &self.state;
        let guard = shared.lock().unwrap();
        ThreadGroupLockGuard {
            state_ref: shared,
            guard,
        }
    }
}

impl<'a, Sched: Scheduler> ThreadGroupLockGuard<'a, Sched> {
    /// Start a worker thread.
    ///
    /// This does not automatically schedule the spawned thread. You should
    /// store the obtained `ThreadId` in the contained `Sched: `[`Scheduler`]
    /// and have it chosen by [`Scheduler::choose_next_thread`] for the thread
    /// to actually run.
    ///
    /// The following functions are available for use inside a worker thread.
    /// You should use them instead of the same-named functions defined
    /// elsewhere (e.g. in `std::thread`).
    ///
    ///  - [`exit_thread`]
    ///  - [`yield_now`]
    ///
    /// If `f` panics, the panic is caught and its payload is delivered to
    /// [`ThreadGroupJoinHandle::join`].
    ///
    /// # Panics
    ///
    /// Panics if the thread group has already completed its shutdown.
    pub fn spawn(&mut self, f: impl FnOnce(ThreadId) + Send + 'static) -> ThreadId {
        if self.guard.shutting_down && self.guard.num_threads == 0 {
            panic!("thread group has already been shut down");
        }

        let state = Arc::clone(self.state_ref);

        // Allocate a `ThreadId`. The entry gets the real join handle below,
        // once the thread exists.
        let ptr: SlabPtr = self
            .guard
            .threads
            .insert(WorkerThread { join_handle: None });
        let thread_id = ThreadId(ptr);
        self.guard.num_threads += 1;

        let join_handle = threading::spawn(move || {
            let state2 = Arc::clone(&state);
            TLB.with(|cell| {
                // A freshly spawned thread has never touched its TLB.
                if cell.set(ThreadLocalBlock { thread_id, state }).is_err() {
                    unreachable!("thread-local block of a new worker thread was already set");
                }
            });

            // Block the spawned thread until scheduled to run.
            // NOTE(review): if `threading::park` can return spuriously (as
            // `std::thread::park` is allowed to), `f` could start before the
            // scheduler picks this thread. Confirm against `threading`.
            threading::park();

            // Call the thread entry point. Any panic is caught here and passed
            // to `finalize_thread`, which forwards it to the join handle.
            let result = catch_unwind(AssertUnwindSafe(move || {
                f(thread_id);
            }));

            finalize_thread(state2, thread_id, result);
        });

        // Save the `JoinHandle` representing the spawned thread. The lock is
        // still held, so nothing can schedule this thread (and look up its
        // handle) before the handle is stored.
        self.guard.threads[ptr].join_handle = Some(join_handle);

        log::trace!("created {thread_id:?}");

        thread_id
    }

    /// Preempt the thread group to let the scheduler decide the next thread
    /// to run.
    ///
    /// The currently scheduled thread (if any) is suspended. Then
    /// [`Scheduler::choose_next_thread`] is consulted and the chosen thread is
    /// unparked.
    ///
    /// # Panics
    ///
    /// Calling this method from a worker thread is not allowed and panics.
    pub fn preempt(&mut self) {
        assert!(
            TLB.with(|cell| cell.get().is_none()),
            "this method cannot be called from a worker thread"
        );

        // Preempt the current thread. `park` here is called on the `threading`
        // module's thread handle. NOTE(review): presumably it suspends the
        // target thread from the outside, unlike `std`, which can only park
        // the calling thread.
        let guard = &mut *self.guard;
        log::trace!("preempting {:?}", guard.cur_thread_id);
        if let Some(thread_id) = guard.cur_thread_id {
            let join_handle = guard.threads[thread_id.0].join_handle.as_ref().unwrap();
            join_handle.thread().park();
        }

        guard.unpark_next_thread();
    }

    /// Initiate graceful shutdown for the thread group.
    ///
    /// The shutdown completes when all threads complete execution. After this
    /// happens, the system will not call [`Scheduler::choose_next_thread`]
    /// anymore. [`ThreadGroupJoinHandle::join`] will unblock, returning
    /// `Ok(())`.
    ///
    /// New threads may still be spawned while the shutdown is pending.
    /// Calling this method again has no further effect.
    pub fn shutdown(&mut self) {
        if self.guard.shutting_down {
            return;
        }
        log::trace!("shutdown requested");
        self.guard.shutting_down = true;
        if self.guard.num_threads == 0 {
            self.guard.complete_shutdown();
        } else {
            log::trace!(
                "shutdown is pending because there are {} thread(s) remaining",
                self.guard.num_threads
            );
        }
    }
}

impl<'a, Sched: Scheduler + ?Sized> ThreadGroupLockGuard<'a, Sched> {
    /// Get a mutable reference to the contained `Sched: `[`Scheduler`].
    pub fn scheduler(&mut self) -> &mut Sched {
        &mut self.guard.sched
    }
}

// Entry points for code that knows the concrete scheduler type. The real
// logic lives once in the type-erased `State<dyn Scheduler>` impl.
impl<Sched: Scheduler> State<Sched> {
    /// Unsize `self` and forward to `State::<dyn Scheduler>::unpark_next_thread`.
    fn unpark_next_thread(&mut self) {
        let erased: &mut State<dyn Scheduler> = self;
        erased.unpark_next_thread();
    }

    /// Unsize `self` and forward to `State::<dyn Scheduler>::complete_shutdown`.
    fn complete_shutdown(&mut self) {
        let erased: &mut State<dyn Scheduler> = self;
        erased.complete_shutdown();
    }
}

impl State<dyn Scheduler> {
    /// Find the next thread to run and unpark that thread.
    fn unpark_next_thread(&mut self) {
        self.cur_thread_id = self.sched.choose_next_thread();
        log::trace!("scheduling {:?}", self.cur_thread_id);
        if let Some(thread_id) = self.cur_thread_id {
            let join_handle = self.threads[thread_id.0].join_handle.as_ref().unwrap();
            join_handle.thread().unpark();
        }
    }

    fn complete_shutdown(&mut self) {
        assert_eq!(self.num_threads, 0);
        log::trace!("shutdown is complete");

        // Ignore if the receiver has already hung up
        let _ = self.result_send.send(Ok(()));
    }
}

/// Voluntarily yield the processor to let the scheduler decide the next thread
/// to run.
///
/// Panics if the current thread is not a worker thread of some [`ThreadGroup`].
pub fn yield_now() {
    let thread_group: Arc<threading::Mutex<State<dyn Scheduler>>> = TLB
        .with(|cell| cell.get().map(|tlb| Arc::clone(&tlb.state)))
        .expect("current thread does not belong to a thread group");

    {
        let mut state_guard = thread_group.lock().unwrap();
        log::trace!("{:?} yielded the processor", state_guard.cur_thread_id);
        state_guard.unpark_next_thread();
    }

    // Block thw thread until scheduled to run. This might end immediately if
    // the current thread is the next thread to run.
    threading::park();
}

/// Terminate the current worker thread.
///
/// The thread is finalized the same way as when its entry point returns
/// normally. The scheduler is notified, and then either the pending shutdown
/// completes or the next thread is scheduled.
///
/// # Panics
///
/// Panics if the current thread is not a worker thread of some [`ThreadGroup`].
///
/// # Safety
///
/// It comes with all the unsafety of terminating a thread, such as that it
/// could unpin pinned local variables.
pub unsafe fn exit_thread() -> ! {
    let lookup = TLB.with(|slot| {
        let block = slot.get()?;
        Some((block.thread_id, Arc::clone(&block.state)))
    });
    let (thread_id, thread_group) =
        lookup.expect("current thread does not belong to a thread group");

    finalize_thread(thread_group, thread_id, Ok(()));

    // SAFETY: The caller upholds this function's contract, which is exactly
    // the contract of `threading::exit_thread`.
    unsafe { threading::exit_thread() };
}

/// Mark the specified thread as exited and hand the processor on.
///
/// The scheduler is notified via [`Scheduler::thread_exited`] and the thread's
/// slot is released. Then one of the following happens:
///
///  - If `result` holds a panic payload, the payload is sent to the join handle
///    and nothing else is scheduled.
///  - If this was the last thread of a shutting-down group, the shutdown is
///    completed.
///  - Otherwise, the scheduler is asked for the next thread to run.
fn finalize_thread(
    thread_group: Arc<threading::Mutex<State<dyn Scheduler>>>,
    thread_id: ThreadId,
    result: Result<()>,
) {
    log::trace!("{thread_id:?} exited with result {result:?}");

    // Delete the current thread
    let mut state_guard = thread_group.lock().unwrap();
    state_guard.sched.thread_exited(thread_id);
    state_guard.threads.remove(thread_id.0);
    state_guard.num_threads -= 1;

    // The freed slab slot can be reused by the next `spawn`. Do not leave
    // `cur_thread_id` pointing at it: `preempt` would then index a vacant slot
    // (and panic) or park an unrelated thread that reused the slot. This
    // matters on the early-return paths below, which do not reschedule.
    if state_guard.cur_thread_id == Some(thread_id) {
        state_guard.cur_thread_id = None;
    }

    if let Err(e) = result {
        // Send the panic payload to the thread group's owner.
        // Leave other threads hanging because there's no way to
        // terminate them safely.
        // This should be at least sufficient for running tests and
        // apps with `panic = "abort"`.
        let _ = state_guard.result_send.send(Err(e));
        return;
    }

    if state_guard.num_threads == 0 && state_guard.shutting_down {
        // Complete the shutdown
        state_guard.complete_shutdown();
        return;
    }

    // Invoke the scheduler
    state_guard.unpark_next_thread();
}

/// Get the current worker thread.
///
/// Returns `None` when called from a thread that was not created by
/// `ThreadGroupLockGuard::spawn`, such as the thread group's owner.
pub fn current_thread() -> Option<ThreadId> {
    TLB.with(|slot| Some(slot.get()?.thread_id))
}