1use std::any::Any;
6use std::cell::Cell;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::ops::DerefMut;
11use std::pin::Pin;
12
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::task::JoinHandle;
15use web_time::SystemTime;
16
17use super::graph::StateLifespan;
18use super::state::StateHandle;
19use super::{LoopId, LoopTag, StateId, StateTag, SubgraphId, SubgraphTag};
20use crate::scheduled::ticks::TickInstant;
21use crate::util::priority_stack::PriorityStack;
22use crate::util::slot_vec::{SecondarySlotVec, SlotVec};
23
/// Execution context shared by the subgraphs of a graph: type-erased state storage,
/// scheduling queues and flags, tick/loop bookkeeping, and locally spawned async tasks.
pub struct Context {
    /// Type-erased state values, inserted by `add_state` and addressed by `StateHandle`s.
    states: SlotVec<StateTag, StateData>,

    /// Priority-ordered stack of strata.
    /// NOTE(review): read/written by the scheduler outside this section — confirm semantics there.
    pub(super) stratum_stack: PriorityStack<usize>,

    /// Stack of loop nonces.
    /// NOTE(review): maintained by the scheduler outside this section — confirm semantics there.
    pub(super) loop_nonce_stack: Vec<usize>,

    /// Subgraphs whose scheduling has been deferred.
    /// NOTE(review): presumably drained later by the scheduler — confirm.
    pub(super) schedule_deferred: Vec<SubgraphId>,

    /// One queue of subgraph IDs per stratum. `Default` starts with a single queue and
    /// `init_stratum` grows the list on demand.
    pub(super) stratum_queues: Vec<VecDeque<SubgraphId>>,

    /// Receiving half of the event queue. Items are `(subgraph_id, is_external)` pairs sent
    /// by `schedule_subgraph` and by wakers created with `waker()`.
    pub(super) event_queue_recv: UnboundedReceiver<(SubgraphId, bool)>,
    /// Scheduler flag, initially `false`.
    /// NOTE(review): semantics defined by the scheduler outside this section.
    pub(super) can_start_tick: bool,
    /// Scheduler flag, initially `false`.
    /// NOTE(review): presumably records that events arrived during the tick — confirm.
    pub(super) events_received_tick: bool,

    /// Sending half of the event queue; cloned into every waker returned by `waker()`.
    pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,

    /// Set by `reschedule_loop_block()`. A `Cell` so it can be set through `&self`.
    pub(super) reschedule_loop_block: Cell<bool>,
    /// Set by `allow_another_iteration()`. A `Cell` so it can be set through `&self`.
    pub(super) allow_another_iteration: Cell<bool>,

    /// Current tick, exposed by `current_tick()`.
    pub(super) current_tick: TickInstant,
    /// Current stratum, exposed by `current_stratum()`.
    pub(super) current_stratum: usize,

    /// Start time of the current tick, exposed by `current_tick_start()`.
    /// `Default` initializes it to `SystemTime::now()`.
    pub(super) current_tick_start: SystemTime,
    /// Exposed by `is_first_run_this_tick()`.
    pub(super) is_first_run_this_tick: bool,
    /// Exposed by `loop_iter_count()`.
    pub(super) loop_iter_count: usize,

    /// A `usize` per loop, keyed by loop ID.
    /// NOTE(review): presumably the loop's nesting depth — confirm.
    pub(super) loop_depth: SlotVec<LoopTag, usize>,
    /// State IDs registered with a `Loop` lifespan, keyed by loop; iterated by
    /// `run_state_hooks_loop`.
    loop_states: SecondarySlotVec<LoopTag, Vec<StateId>>,
    /// Current loop nonce.
    /// NOTE(review): maintained by the scheduler outside this section — confirm semantics.
    pub(super) loop_nonce: usize,

    /// State IDs registered with a `Subgraph` lifespan, keyed by subgraph; iterated by
    /// `run_state_hooks_subgraph`.
    subgraph_states: SecondarySlotVec<SubgraphTag, Vec<StateId>>,

    /// Current subgraph ID, exposed by `current_subgraph()` and captured by `waker()`.
    pub(super) subgraph_id: SubgraphId,

    /// Futures queued by `request_task`; spawned by `spawn_tasks`.
    tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
    /// Join handles of tasks spawned by `spawn_tasks`; consumed by `abort_tasks` or `join_tasks`.
    task_join_handles: Vec<JoinHandle<()>>,
}
87impl Context {
89 pub fn current_tick(&self) -> TickInstant {
91 self.current_tick
92 }
93
94 pub fn current_tick_start(&self) -> SystemTime {
96 self.current_tick_start
97 }
98
99 pub fn is_first_run_this_tick(&self) -> bool {
101 self.is_first_run_this_tick
102 }
103
104 pub fn loop_iter_count(&self) -> usize {
106 self.loop_iter_count
107 }
108
109 pub fn current_stratum(&self) -> usize {
111 self.current_stratum
112 }
113
114 pub fn current_subgraph(&self) -> SubgraphId {
116 self.subgraph_id
117 }
118
119 pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
125 self.event_queue_send.send((sg_id, is_external)).unwrap()
126 }
127
128 pub fn reschedule_loop_block(&self) {
130 self.reschedule_loop_block.set(true);
131 }
132
133 pub fn allow_another_iteration(&self) {
135 self.allow_another_iteration.set(true);
136 }
137
138 pub fn waker(&self) -> std::task::Waker {
141 use std::sync::Arc;
142 use std::task::Wake;
143
144 struct ContextWaker {
145 subgraph_id: SubgraphId,
146 event_queue_send: UnboundedSender<(SubgraphId, bool)>,
147 }
148 impl Wake for ContextWaker {
149 fn wake(self: Arc<Self>) {
150 self.wake_by_ref();
151 }
152
153 fn wake_by_ref(self: &Arc<Self>) {
154 let _recv_closed_error = self.event_queue_send.send((self.subgraph_id, true));
155 }
156 }
157
158 let context_waker = ContextWaker {
159 subgraph_id: self.subgraph_id,
160 event_queue_send: self.event_queue_send.clone(),
161 };
162 std::task::Waker::from(Arc::new(context_waker))
163 }
164
165 pub unsafe fn state_ref_unchecked<T>(&self, handle: StateHandle<T>) -> &'_ T
170 where
171 T: Any,
172 {
173 let state = self
174 .states
175 .get(handle.state_id)
176 .expect("Failed to find state with given handle.")
177 .state
178 .as_ref();
179
180 debug_assert!(state.is::<T>());
181
182 unsafe {
183 &*(state as *const dyn Any as *const T)
186 }
187 }
188
189 pub fn state_ref<T>(&self, handle: StateHandle<T>) -> &'_ T
191 where
192 T: Any,
193 {
194 self.states
195 .get(handle.state_id)
196 .expect("Failed to find state with given handle.")
197 .state
198 .downcast_ref()
199 .expect("StateHandle wrong type T for casting.")
200 }
201
202 pub fn state_mut<T>(&mut self, handle: StateHandle<T>) -> &'_ mut T
204 where
205 T: Any,
206 {
207 self.states
208 .get_mut(handle.state_id)
209 .expect("Failed to find state with given handle.")
210 .state
211 .downcast_mut()
212 .expect("StateHandle wrong type T for casting.")
213 }
214
215 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
217 where
218 T: Any,
219 {
220 let state_data = StateData {
221 state: Box::new(state),
222 lifespan_hook_fn: None,
223 lifespan: None,
224 };
225 let state_id = self.states.insert(state_data);
226
227 StateHandle {
228 state_id,
229 _phantom: PhantomData,
230 }
231 }
232
233 pub fn set_state_lifespan_hook<T>(
235 &mut self,
236 handle: StateHandle<T>,
237 lifespan: StateLifespan,
238 mut hook_fn: impl 'static + FnMut(&mut T),
239 ) where
240 T: Any,
241 {
242 let state_data = self
243 .states
244 .get_mut(handle.state_id)
245 .expect("Failed to find state with given handle.");
246 state_data.lifespan_hook_fn = Some(Box::new(move |state| {
247 (hook_fn)(state.downcast_mut::<T>().unwrap());
248 }));
249 state_data.lifespan = Some(lifespan);
250
251 match lifespan {
252 StateLifespan::Subgraph(key) => {
253 self.subgraph_states
254 .get_or_insert_with(key, Vec::new)
255 .push(handle.state_id);
256 }
257 StateLifespan::Loop(loop_id) => {
258 self.loop_states
259 .get_or_insert_with(loop_id, Vec::new)
260 .push(handle.state_id);
261 }
262 StateLifespan::Tick => {
263 }
265 StateLifespan::Static => {
266 }
268 }
269 }
270
271 pub fn request_task<Fut>(&mut self, future: Fut)
273 where
274 Fut: Future<Output = ()> + 'static,
275 {
276 self.tasks_to_spawn.push(Box::pin(future));
277 }
278
279 pub fn spawn_tasks(&mut self) {
281 for task in self.tasks_to_spawn.drain(..) {
282 self.task_join_handles.push(tokio::task::spawn_local(task));
283 }
284 }
285
286 pub fn abort_tasks(&mut self) {
288 for task in self.task_join_handles.drain(..) {
289 task.abort();
290 }
291 }
292
293 pub async fn join_tasks(&mut self) {
297 futures::future::join_all(self.task_join_handles.drain(..)).await;
298 }
299}
300
301impl Default for Context {
302 fn default() -> Self {
303 let stratum_queues = vec![Default::default()]; let (event_queue_send, event_queue_recv) = mpsc::unbounded_channel();
305 let (stratum_stack, loop_depth) = Default::default();
306 Self {
307 states: SlotVec::new(),
308
309 stratum_stack,
310
311 loop_nonce_stack: Vec::new(),
312
313 schedule_deferred: Vec::new(),
314
315 stratum_queues,
316 event_queue_recv,
317 can_start_tick: false,
318 events_received_tick: false,
319
320 event_queue_send,
321 reschedule_loop_block: Cell::new(false),
322 allow_another_iteration: Cell::new(false),
323
324 current_stratum: 0,
325 current_tick: TickInstant::default(),
326
327 current_tick_start: SystemTime::now(),
328 is_first_run_this_tick: false,
329 loop_iter_count: 0,
330
331 loop_depth,
332 loop_states: SecondarySlotVec::new(),
333 loop_nonce: 0,
334
335 subgraph_states: SecondarySlotVec::new(),
336
337 subgraph_id: SubgraphId::from_raw(0),
339
340 tasks_to_spawn: Vec::new(),
341 task_join_handles: Vec::new(),
342 }
343 }
344}
345impl Context {
347 pub(super) fn init_stratum(&mut self, stratum: usize) {
349 if self.stratum_queues.len() <= stratum {
350 self.stratum_queues
351 .resize_with(stratum + 1, Default::default);
352 }
353 }
354
355 pub(super) fn run_state_hooks_tick(&mut self) {
357 tracing::trace!("Running state hooks for tick.");
358 for state_data in self.states.values_mut() {
359 let StateData {
360 state,
361 lifespan_hook_fn: Some(lifespan_hook_fn),
362 lifespan: Some(StateLifespan::Tick),
363 } = state_data
364 else {
365 continue;
366 };
367 (lifespan_hook_fn)(Box::deref_mut(state));
368 }
369 }
370
371 pub(super) fn run_state_hooks_subgraph(&mut self, subgraph_id: SubgraphId) {
372 tracing::trace!("Running state hooks for subgraph.");
373 for state_id in self.subgraph_states.get(subgraph_id).into_iter().flatten() {
374 let StateData {
375 state,
376 lifespan_hook_fn,
377 lifespan: _,
378 } = self
379 .states
380 .get_mut(*state_id)
381 .expect("Failed to find state with given ID.");
382
383 if let Some(lifespan_hook_fn) = lifespan_hook_fn {
384 (lifespan_hook_fn)(Box::deref_mut(state));
385 }
386 }
387 }
388
389 pub(super) fn run_state_hooks_loop(&mut self, loop_id: LoopId) {
392 tracing::trace!(
393 loop_id = loop_id.to_string(),
394 "Running state hooks for loop."
395 );
396 for state_id in self.loop_states.get(loop_id).into_iter().flatten() {
397 let StateData {
398 state,
399 lifespan_hook_fn,
400 lifespan: _,
401 } = self
402 .states
403 .get_mut(*state_id)
404 .expect("Failed to find state with given ID.");
405
406 if let Some(lifespan_hook_fn) = lifespan_hook_fn {
407 (lifespan_hook_fn)(Box::deref_mut(state));
408 }
409 }
410 }
411}
412
/// Bookkeeping for one state value stored in a `Context`.
struct StateData {
    /// The type-erased state value.
    state: Box<dyn Any>,
    /// Hook installed by `set_state_lifespan_hook`; invoked on `state` by the
    /// `run_state_hooks_*` methods. `None` until a hook is set.
    lifespan_hook_fn: Option<LifespanResetFn>,
    /// Lifespan the hook is tied to; `None` until a hook is set.
    lifespan: Option<StateLifespan>,
}
/// Type-erased lifespan hook stored in `StateData`: it receives the state as `&mut dyn Any`
/// and is responsible for downcasting it to the concrete state type.
type LifespanResetFn = Box<dyn FnMut(&mut dyn Any) + 'static>;