dfir_rs/scheduled/
context.rs

//! Module for the user-facing [`Context`] object.
//!
//! Provides APIs for state and scheduling.
4
5use std::any::Any;
6use std::cell::Cell;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::ops::DerefMut;
11use std::pin::Pin;
12
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::task::JoinHandle;
15use web_time::SystemTime;
16
17use super::graph::StateLifespan;
18use super::state::StateHandle;
19use super::{LoopId, LoopTag, StateId, StateTag, SubgraphId, SubgraphTag};
20use crate::scheduled::ticks::TickInstant;
21use crate::util::priority_stack::PriorityStack;
22use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
23
24/// The main state and scheduler of the runtime instance. Provided as the `context` API to each
25/// subgraph/operator as it is run.
26///
27/// Each instance stores eactly one Context inline. Before the `Context` is provided to
28/// a running operator, the `subgraph_id` field must be updated.
29pub struct Context {
30    /// Storage for the user-facing State API.
31    states: SlotVec<StateTag, StateData>,
32
33    /// Priority stack for handling strata within loops. Prioritized by loop depth.
34    pub(super) stratum_stack: PriorityStack<usize>,
35
36    /// Stack of loop nonces. Used to identify when a new loop iteration begins.
37    pub(super) loop_nonce_stack: Vec<usize>,
38
39    /// TODO(mingwei):
40    /// used for loop iteration scheduling.
41    pub(super) schedule_deferred: Vec<SubgraphId>,
42
43    /// TODO(mingwei): separate scheduler into its own struct/trait?
44    /// Index is stratum, value is FIFO queue for that stratum.
45    pub(super) stratum_queues: Vec<VecDeque<SubgraphId>>,
46
47    /// Receive events, if second arg indicates if it is an external "important" event (true).
48    pub(super) event_queue_recv: UnboundedReceiver<(SubgraphId, bool)>,
49    /// If external events or data can justify starting the next tick.
50    pub(super) can_start_tick: bool,
51    /// If the events have been received for this tick.
52    pub(super) events_received_tick: bool,
53
54    // TODO(mingwei): as long as this is here, it's impossible to know when all work is done.
55    // Second field (bool) is for if the event is an external "important" event (true).
56    pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,
57
58    /// If the current subgraph wants to reschedule the current loop block (in the current tick).
59    pub(super) reschedule_loop_block: Cell<bool>,
60    pub(super) allow_another_iteration: Cell<bool>,
61
62    pub(super) current_tick: TickInstant,
63    pub(super) current_stratum: usize,
64
65    pub(super) current_tick_start: SystemTime,
66    pub(super) is_first_run_this_tick: bool,
67    pub(super) loop_iter_count: usize,
68
69    /// Depth of loop (zero for top-level).
70    pub(super) loop_depth: SlotVec<LoopTag, usize>,
71    /// For each loop, state which needs to be reset between loop executions.
72    loop_states: SecondarySlotVec<LoopTag, Vec<StateId>>,
73    /// Used to differentiate between loop executions. Incremented at the start of each loop execution.
74    pub(super) loop_nonce: usize,
75
76    /// For each subgraph, state which needs to be reset between executions.
77    subgraph_states: SecondarySlotVec<SubgraphTag, Vec<StateId>>,
78
79    /// The SubgraphId of the currently running operator. When this context is
80    /// not being forwarded to a running operator, this field is meaningless.
81    pub(super) subgraph_id: SubgraphId,
82
83    tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
84    /// Join handles for spawned tasks.
85    task_join_handles: Vec<JoinHandle<()>>,
86}
87/// Public APIs.
88impl Context {
89    /// Gets the current tick (local time) count.
90    pub fn current_tick(&self) -> TickInstant {
91        self.current_tick
92    }
93
94    /// Gets the timestamp of the beginning of the current tick.
95    pub fn current_tick_start(&self) -> SystemTime {
96        self.current_tick_start
97    }
98
99    /// Gets whether this is the first time this subgraph is being scheduled for this tick
100    pub fn is_first_run_this_tick(&self) -> bool {
101        self.is_first_run_this_tick
102    }
103
104    /// Gets the current loop iteration count.
105    pub fn loop_iter_count(&self) -> usize {
106        self.loop_iter_count
107    }
108
109    /// Gets the current stratum nubmer.
110    pub fn current_stratum(&self) -> usize {
111        self.current_stratum
112    }
113
114    /// Gets the ID of the current subgraph.
115    pub fn current_subgraph(&self) -> SubgraphId {
116        self.subgraph_id
117    }
118
119    /// Schedules a subgraph for the next tick.
120    ///
121    /// If `is_external` is `true`, the scheduling will trigger the next tick to begin. If it is
122    /// `false` then scheduling will be lazy and the next tick will not begin unless there is other
123    /// reason to.
124    pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
125        self.event_queue_send.send((sg_id, is_external)).unwrap()
126    }
127
128    /// Schedules the current loop block to be run again (_in this tick_).
129    pub fn reschedule_loop_block(&self) {
130        self.reschedule_loop_block.set(true);
131    }
132
133    /// Allow another iteration of the loop, if more data comes.
134    pub fn allow_another_iteration(&self) {
135        self.allow_another_iteration.set(true);
136    }
137
138    /// Returns a `Waker` for interacting with async Rust.
139    /// Waker events are considered to be extenral.
140    pub fn waker(&self) -> std::task::Waker {
141        use std::sync::Arc;
142        use std::task::Wake;
143
144        struct ContextWaker {
145            subgraph_id: SubgraphId,
146            event_queue_send: UnboundedSender<(SubgraphId, bool)>,
147        }
148        impl Wake for ContextWaker {
149            fn wake(self: Arc<Self>) {
150                self.wake_by_ref();
151            }
152
153            fn wake_by_ref(self: &Arc<Self>) {
154                let _recv_closed_error = self.event_queue_send.send((self.subgraph_id, true));
155            }
156        }
157
158        let context_waker = ContextWaker {
159            subgraph_id: self.subgraph_id,
160            event_queue_send: self.event_queue_send.clone(),
161        };
162        std::task::Waker::from(Arc::new(context_waker))
163    }
164
165    /// Returns a shared reference to the state.
166    ///
167    /// # Safety
168    /// `StateHandle<T>` must be from _this_ instance, created via [`Self::add_state`].
169    pub unsafe fn state_ref_unchecked<T>(&self, handle: StateHandle<T>) -> &'_ T
170    where
171        T: Any,
172    {
173        let state = self
174            .states
175            .get(handle.state_id)
176            .expect("Failed to find state with given handle.")
177            .state
178            .as_ref();
179
180        debug_assert!(state.is::<T>());
181
182        unsafe {
183            // SAFETY: `handle` is from this instance.
184            // TODO(shadaj): replace with `downcast_ref_unchecked` when it's stabilized
185            &*(state as *const dyn Any as *const T)
186        }
187    }
188
189    /// Returns a shared reference to the state.
190    pub fn state_ref<T>(&self, handle: StateHandle<T>) -> &'_ T
191    where
192        T: Any,
193    {
194        self.states
195            .get(handle.state_id)
196            .expect("Failed to find state with given handle.")
197            .state
198            .downcast_ref()
199            .expect("StateHandle wrong type T for casting.")
200    }
201
202    /// Returns an exclusive reference to the state.
203    pub fn state_mut<T>(&mut self, handle: StateHandle<T>) -> &'_ mut T
204    where
205        T: Any,
206    {
207        self.states
208            .get_mut(handle.state_id)
209            .expect("Failed to find state with given handle.")
210            .state
211            .downcast_mut()
212            .expect("StateHandle wrong type T for casting.")
213    }
214
215    /// Adds state to the context and returns the handle.
216    pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
217    where
218        T: Any,
219    {
220        let state_data = StateData {
221            state: Box::new(state),
222            lifespan_hook_fn: None,
223            lifespan: None,
224        };
225        let state_id = self.states.insert(state_data);
226
227        StateHandle {
228            state_id,
229            _phantom: PhantomData,
230        }
231    }
232
233    /// Sets a hook to modify the state at the end of each tick, using the supplied closure.
234    pub fn set_state_lifespan_hook<T>(
235        &mut self,
236        handle: StateHandle<T>,
237        lifespan: StateLifespan,
238        mut hook_fn: impl 'static + FnMut(&mut T),
239    ) where
240        T: Any,
241    {
242        let state_data = self
243            .states
244            .get_mut(handle.state_id)
245            .expect("Failed to find state with given handle.");
246        state_data.lifespan_hook_fn = Some(Box::new(move |state| {
247            (hook_fn)(state.downcast_mut::<T>().unwrap());
248        }));
249        state_data.lifespan = Some(lifespan);
250
251        match lifespan {
252            StateLifespan::Subgraph(key) => {
253                self.subgraph_states
254                    .get_or_insert_with(key, Vec::new)
255                    .push(handle.state_id);
256            }
257            StateLifespan::Loop(loop_id) => {
258                self.loop_states
259                    .get_or_insert_with(loop_id, Vec::new)
260                    .push(handle.state_id);
261            }
262            StateLifespan::Tick => {
263                // Already included in `run_state_hooks_tick`.
264            }
265            StateLifespan::Static => {
266                // Never resets.
267            }
268        }
269    }
270
271    /// Prepares an async task to be launched by [`Self::spawn_tasks`].
272    pub fn request_task<Fut>(&mut self, future: Fut)
273    where
274        Fut: Future<Output = ()> + 'static,
275    {
276        self.tasks_to_spawn.push(Box::pin(future));
277    }
278
279    /// Launches all tasks requested with [`Self::request_task`] on the internal Tokio executor.
280    pub fn spawn_tasks(&mut self) {
281        for task in self.tasks_to_spawn.drain(..) {
282            self.task_join_handles.push(tokio::task::spawn_local(task));
283        }
284    }
285
286    /// Aborts all tasks spawned with [`Self::spawn_tasks`].
287    pub fn abort_tasks(&mut self) {
288        for task in self.task_join_handles.drain(..) {
289            task.abort();
290        }
291    }
292
293    /// Waits for all tasks spawned with [`Self::spawn_tasks`] to complete.
294    ///
295    /// Will probably just hang.
296    pub async fn join_tasks(&mut self) {
297        futures::future::join_all(self.task_join_handles.drain(..)).await;
298    }
299}
300
301impl Default for Context {
302    fn default() -> Self {
303        let stratum_queues = vec![Default::default()]; // Always initialize stratum #0.
304        let (event_queue_send, event_queue_recv) = mpsc::unbounded_channel();
305        let (stratum_stack, loop_depth) = Default::default();
306        Self {
307            states: SlotVec::new(),
308
309            stratum_stack,
310
311            loop_nonce_stack: Vec::new(),
312
313            schedule_deferred: Vec::new(),
314
315            stratum_queues,
316            event_queue_recv,
317            can_start_tick: false,
318            events_received_tick: false,
319
320            event_queue_send,
321            reschedule_loop_block: Cell::new(false),
322            allow_another_iteration: Cell::new(false),
323
324            current_stratum: 0,
325            current_tick: TickInstant::default(),
326
327            current_tick_start: SystemTime::now(),
328            is_first_run_this_tick: false,
329            loop_iter_count: 0,
330
331            loop_depth,
332            loop_states: SecondarySlotVec::new(),
333            loop_nonce: 0,
334
335            subgraph_states: SecondarySlotVec::new(),
336
337            // Will be re-set before use.
338            subgraph_id: SubgraphId::from_raw(0),
339
340            tasks_to_spawn: Vec::new(),
341            task_join_handles: Vec::new(),
342        }
343    }
344}
345/// Internal APIs.
346impl Context {
347    /// Makes sure stratum STRATUM is initialized.
348    pub(super) fn init_stratum(&mut self, stratum: usize) {
349        if self.stratum_queues.len() <= stratum {
350            self.stratum_queues
351                .resize_with(stratum + 1, Default::default);
352        }
353    }
354
355    /// Call this at the end of a tick,
356    pub(super) fn run_state_hooks_tick(&mut self) {
357        tracing::trace!("Running state hooks for tick.");
358        for state_data in self.states.values_mut() {
359            let StateData {
360                state,
361                lifespan_hook_fn: Some(lifespan_hook_fn),
362                lifespan: Some(StateLifespan::Tick),
363            } = state_data
364            else {
365                continue;
366            };
367            (lifespan_hook_fn)(Box::deref_mut(state));
368        }
369    }
370
371    pub(super) fn run_state_hooks_subgraph(&mut self, subgraph_id: SubgraphId) {
372        tracing::trace!("Running state hooks for subgraph.");
373        for state_id in self.subgraph_states.get(subgraph_id).into_iter().flatten() {
374            let StateData {
375                state,
376                lifespan_hook_fn,
377                lifespan: _,
378            } = self
379                .states
380                .get_mut(*state_id)
381                .expect("Failed to find state with given ID.");
382
383            if let Some(lifespan_hook_fn) = lifespan_hook_fn {
384                (lifespan_hook_fn)(Box::deref_mut(state));
385            }
386        }
387    }
388
389    // Run the state hooks for each state in the loop.
390    // Call at the end of each loop execution.
391    pub(super) fn run_state_hooks_loop(&mut self, loop_id: LoopId) {
392        tracing::trace!(
393            loop_id = loop_id.to_string(),
394            "Running state hooks for loop."
395        );
396        for state_id in self.loop_states.get(loop_id).into_iter().flatten() {
397            let StateData {
398                state,
399                lifespan_hook_fn,
400                lifespan: _,
401            } = self
402                .states
403                .get_mut(*state_id)
404                .expect("Failed to find state with given ID.");
405
406            if let Some(lifespan_hook_fn) = lifespan_hook_fn {
407                (lifespan_hook_fn)(Box::deref_mut(state));
408            }
409        }
410    }
411}
412
413/// Internal struct containing a pointer to instance-owned state.
414struct StateData {
415    state: Box<dyn Any>,
416    lifespan_hook_fn: Option<LifespanResetFn>, // TODO(mingwei): replace with trait?
417    /// `None` for static.
418    lifespan: Option<StateLifespan>,
419}
/// Type-erased lifespan hook: receives the state as `&mut dyn Any` and is expected to downcast
/// it to the concrete state type itself.
type LifespanResetFn = Box<dyn FnMut(&mut dyn Any)>;