vortex_io/dispatcher/
tokio.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
use std::future::Future;
use std::panic::resume_unwind;
use std::thread::JoinHandle;

use futures::channel::oneshot;
use tokio::task::{JoinHandle as TokioJoinHandle, LocalSet};
use vortex_error::{vortex_bail, vortex_panic, VortexResult};

use super::{Dispatch, JoinHandle as VortexJoinHandle};

/// Type-erased unit of work that can launch itself onto the Tokio runtime of
/// whichever worker thread receives it.
trait TokioSpawn {
    /// Consume the boxed task, spawn it, and return Tokio's handle for it.
    fn spawn(self: Box<Self>) -> TokioJoinHandle<()>;
}

/// A [dispatcher][Dispatch] of IO operations that runs tasks on one of several
/// Tokio `current_thread` runtimes.
///
/// Every worker thread owns a private runtime. All workers pull from one shared
/// queue, so each submitted task is picked up by exactly one of them.
#[derive(Debug)]
pub(super) struct TokioDispatcher {
    /// Sending half of the work queue that all worker threads receive from.
    submitter: flume::Sender<Box<dyn TokioSpawn + Send>>,
    /// OS-thread handles for the workers; joined during shutdown.
    threads: Vec<JoinHandle<()>>,
}

impl TokioDispatcher {
    /// Start `num_threads` worker threads. Each worker runs its own Tokio
    /// `current_thread` runtime and pulls tasks from one shared queue.
    ///
    /// A worker keeps running until every sender of the queue is dropped and the
    /// queue is empty. It then waits for the tasks it already spawned to finish
    /// before its thread exits.
    ///
    /// With `num_threads == 0` no receiver outlives this constructor, so every
    /// later dispatch attempt fails to enqueue its task.
    ///
    /// # Panics
    ///
    /// Panics if a worker thread cannot be spawned. A worker thread panics if its
    /// Tokio runtime cannot be built.
    pub fn new(num_threads: usize) -> Self {
        let (submitter, rx) = flume::unbounded();
        let threads: Vec<_> = (0..num_threads)
            .map(|tid| {
                let worker_thread =
                    std::thread::Builder::new().name(format!("tokio-dispatch-{tid}"));
                let rx: flume::Receiver<Box<dyn TokioSpawn + Send>> = rx.clone();

                worker_thread
                    .spawn(move || {
                        // Create a runtime-per-thread
                        let rt = tokio::runtime::Builder::new_current_thread()
                            // The dispatcher should not have any blocking work.
                            // Maybe in the future we can add this as a config param.
                            .max_blocking_threads(1)
                            .enable_all()
                            .build()
                            .unwrap_or_else(|e| {
                                vortex_panic!("TokioDispatcher new_current_thread build(): {e}")
                            });

                        rt.block_on(async move {
                            // Use a LocalSet so that all spawned tasks will run on the current
                            // thread. This allows spawning !Send futures.
                            let local = LocalSet::new();
                            local
                                .run_until(async {
                                    // Ends once all senders are gone and the queue is drained.
                                    while let Ok(task) = rx.recv_async().await {
                                        task.spawn();
                                    }
                                })
                                .await;
                            // `run_until` returns as soon as the receive loop ends, which can be
                            // before the tasks it spawned have finished. Dropping the LocalSet
                            // here would cancel them, and callers would see their handles as
                            // cancelled. Instead, drive the LocalSet until every spawned task
                            // has completed.
                            local.await;
                        });
                    })
                    .unwrap_or_else(|e| vortex_panic!("TokioDispatcher worker thread spawn: {e}"))
            })
            .collect();

        Self { submitter, threads }
    }
}

/// A unit of work sent to a worker thread. It pairs the closure that builds the
/// future to run with the channel on which that future's output is returned.
struct TokioTask<F, R> {
    /// Called on the worker thread to create the future.
    task: F,
    /// Delivers the future's output back to whoever dispatched the task.
    result: oneshot::Sender<R>,
}

impl<F, Fut, R> TokioSpawn for TokioTask<F, R>
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = R>,
    R: Send + 'static,
{
    /// Spawn the task onto the current thread's `LocalSet`.
    ///
    /// The closure is invoked inside the spawned future, i.e. on the worker
    /// thread, which is why `Fut` itself does not have to be `Send`.
    fn spawn(self: Box<Self>) -> TokioJoinHandle<()> {
        // Move the closure and the reply channel out of the box so they can be
        // owned by the spawned future.
        let TokioTask {
            task: make_future,
            result: reply,
        } = *self;

        tokio::task::spawn_local(async move {
            let output = make_future().await;
            // If the caller has dropped its receiver, nobody is waiting for the
            // value, so a failed send is deliberately ignored.
            let _ = reply.send(output);
        })
    }
}

impl Dispatch for TokioDispatcher {
    /// Queue `task` for execution on one of the worker threads.
    ///
    /// The closure is sent to a worker and invoked there, so the future it
    /// returns does not need to be `Send`. Only the closure and the output must
    /// be `Send`. The output is delivered through the returned handle.
    ///
    /// # Errors
    ///
    /// Fails if the task cannot be enqueued because every receiving end of the
    /// work queue has been dropped (e.g. no worker threads exist).
    fn dispatch<F, Fut, R>(&self, task: F) -> VortexResult<VortexJoinHandle<R>>
    where
        F: (FnOnce() -> Fut) + Send + 'static,
        Fut: Future<Output = R> + 'static,
        R: Send + 'static,
    {
        let (tx, rx) = oneshot::channel();

        let task = TokioTask { result: tx, task };

        match self.submitter.send(Box::new(task)) {
            Ok(()) => Ok(VortexJoinHandle(rx)),
            Err(err) => vortex_bail!("Dispatcher error spawning task: {err}"),
        }
    }

    /// Stop accepting work and wait for every worker thread to exit.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Re-raises the panic of a worker thread that panicked. All workers are
    /// joined before the first captured panic is resumed.
    fn shutdown(self) -> VortexResult<()> {
        // Drop the submitter, which is the only sender of the work queue. Each
        // worker drains whatever is still queued. Its `recv_async` then returns
        // `Err(RecvError::Disconnected)`, the receive loop ends, and the worker
        // thread finishes.
        drop(self.submitter);

        // Join every worker, even after one has panicked, so that no thread is
        // left running detached. Only the first panic payload is kept.
        // NOTE: currently, panics inside any of the tasks will not propagate to the LocalSet's join
        // handle, so they do not surface here.
        // see https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html#panics-1
        let mut first_panic = None;
        for thread in self.threads {
            if let Err(payload) = thread.join() {
                first_panic = first_panic.or(Some(payload));
            }
        }
        if let Some(payload) = first_panic {
            resume_unwind(payload);
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    use super::TokioDispatcher;
    use crate::dispatcher::Dispatch;

    /// A single dispatched task runs on a worker thread, and its side effect is
    /// visible once the returned handle resolves.
    #[tokio::test]
    async fn test_tokio_dispatch_simple() {
        let counter = Arc::new(AtomicU32::new(0));
        let dispatcher = TokioDispatcher::new(4);

        let task_counter = Arc::clone(&counter);
        let handle = dispatcher
            .dispatch(move || async move {
                task_counter.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();

        handle.await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1u32);
    }
}