1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator may request on one of its inputs.
///
/// Values of this type are returned (wrapped in `Option`) by
/// `OperatorConstraints::input_delaytype_fn`. The derived ordering follows the declaration
/// order of the variants.
// NOTE(review): the per-variant descriptions are inferred from the variant names only; confirm
// them against the code that consumes `DelayType`.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum DelayType {
    /// Presumably: the input is delayed across a stratum boundary.
    Stratum,
    /// Presumably: the input feeds a monotone accumulation.
    MonotoneAccum,
    /// Presumably: the input is delayed until the next tick.
    Tick,
    /// Presumably: like [`DelayType::Tick`], but lazily — TODO confirm what "lazy" means here.
    TickLazy,
}
35
/// Description of the ports on one side (inputs or outputs) of an operator.
///
/// Produced by the optional `ports_inn` / `ports_out` callbacks of [`OperatorConstraints`].
pub enum PortListSpec {
    /// No fixed list of ports.
    // NOTE(review): presumably any number of ports is accepted — confirm with the port checker.
    Variadic,
    /// A fixed, comma-separated list of port indices.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of one operator: its name, the number of inputs, outputs and arguments
/// it accepts, and the function which generates its code.
///
/// One constant of this type per operator is collected into `OPERATORS`, and
/// `operator_lookup()` indexes them by [`OperatorConstraints::name`].
pub struct OperatorConstraints {
    /// The operator's name; this is the key used by `operator_lookup()`.
    pub name: &'static str,
    /// The categories this operator is filed under.
    pub categories: &'static [OperatorCategory],

    /// Allowed range for the number of inputs.
    // NOTE(review): "hard" vs. "soft" presumably means error vs. warning when the range is
    // violated — confirm against the code that checks these ranges.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Allowed range for the number of inputs (the "soft" counterpart of `hard_range_inn`).
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Allowed range for the number of outputs.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Allowed range for the number of outputs (the "soft" counterpart of `hard_range_out`).
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments the operator takes.
    // NOTE(review): presumably the expressions exposed as `WriteContextArgs::arguments` — confirm.
    pub num_args: usize,
    /// Allowed range for the number of persistence arguments.
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Allowed range for the number of generic type arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Whether the operator is an external input.
    // NOTE(review): scheduling consequences of this flag are not visible in this file.
    pub is_external_input: bool,
    /// Whether the operator has a singleton output (see
    /// `WriteContextArgs::singleton_output_ident`).
    pub has_singleton_output: bool,
    /// Optional [`FloType`] of this operator.
    pub flo_type: Option<FloType>,

    /// Optional callback returning the specification of the input ports.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Optional callback returning the specification of the output ports.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Returns the [`DelayType`] required on the given input port, or `None` if there is none.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// The code generator for this operator; see [`WriteFn`].
    pub write_fn: WriteFn,
}
88
/// Signature of an operator's code generator.
///
/// It receives the [`WriteContextArgs`] of the operator instance plus a [`Diagnostics`] sink,
/// and returns the generated [`OperatorWriteOutput`] or `Err(())` on failure.
// NOTE(review): presumably the reason for an `Err(())` is reported through `Diagnostics` — confirm.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// The pieces of generated code returned by an operator's [`WriteFn`].
///
/// All fields default to an empty token stream, so a code generator only has to fill in the
/// parts it needs (e.g. `OperatorWriteOutput { write_iterator, ..Default::default() }`).
// NOTE(review): where exactly each piece is spliced into the final output is decided by the
// caller, which is not part of this file; the field docs below are inferred from the names.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
    /// Prologue code for this operator.
    pub write_prologue: TokenStream,
    /// Prologue code for this operator, presumably emitted after every `write_prologue`.
    pub write_prologue_after: TokenStream,
    /// The main operator code. The generators in this file emit a `let` binding which assigns
    /// to `WriteContextArgs::ident`.
    pub write_iterator: TokenStream,
    /// Code presumably emitted after the `write_iterator` code.
    pub write_iterator_after: TokenStream,
}
134
135pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
137pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
139pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
141
142pub fn identity_write_iterator_fn(
145 &WriteContextArgs {
146 root,
147 op_span,
148 ident,
149 inputs,
150 outputs,
151 is_pull,
152 op_inst:
153 OperatorInstance {
154 generics: OpInstGenerics { type_args, .. },
155 ..
156 },
157 ..
158 }: &WriteContextArgs,
159) -> TokenStream {
160 let generic_type = type_args
161 .first()
162 .map(quote::ToTokens::to_token_stream)
163 .unwrap_or(quote_spanned!(op_span=> _));
164
165 if is_pull {
166 let input = &inputs[0];
167 quote_spanned! {op_span=>
168 let #ident = {
169 fn check_input<St, Item>(stream: St) -> impl #root::futures::stream::Stream<Item = Item>
170 where
171 St: #root::futures::stream::Stream<Item = Item>,
172 {
173 stream
174 }
175 check_input::<_, #generic_type>(#input)
176 };
177 }
178 } else {
179 let output = &outputs[0];
180 quote_spanned! {op_span=>
181 let #ident = {
182 fn check_output<Si, Item>(sink: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
183 where
184 Si: #root::futures::sink::Sink<Item, Error = #root::Never>,
185 {
186 sink
187 }
188 check_output::<_, #generic_type>(#output)
189 };
190 }
191 }
192}
193
194pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
196 let write_iterator = identity_write_iterator_fn(write_context_args);
197 Ok(OperatorWriteOutput {
198 write_iterator,
199 ..Default::default()
200 })
201};
202
203pub fn null_write_iterator_fn(
206 &WriteContextArgs {
207 root,
208 op_span,
209 ident,
210 inputs,
211 outputs,
212 is_pull,
213 op_inst:
214 OperatorInstance {
215 generics: OpInstGenerics { type_args, .. },
216 ..
217 },
218 ..
219 }: &WriteContextArgs,
220) -> TokenStream {
221 let default_type = parse_quote_spanned! {op_span=> _};
222 let iter_type = type_args.first().unwrap_or(&default_type);
223
224 if is_pull {
225 quote_spanned! {op_span=>
226 let #ident = #root::futures::stream::poll_fn(move |_cx| {
227 #(
229 let #inputs = #root::futures::stream::Stream::poll_next(::std::pin::pin!(#inputs), _cx);
230 )*
231 #(
232 let _ = ::std::task::ready!(#inputs);
233 )*
234 ::std::task::Poll::Ready(::std::option::Option::None)
235 });
236 }
237 } else {
238 quote_spanned! {op_span=>
239 #[allow(clippy::let_unit_value)]
240 let _ = (#(#outputs),*);
241 let #ident = #root::sinktools::for_each::ForEach::new::<#iter_type>(::std::mem::drop::<#iter_type>);
242 }
243 }
244}
245
246pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
249 let write_iterator = null_write_iterator_fn(write_context_args);
250 Ok(OperatorWriteOutput {
251 write_iterator,
252 ..Default::default()
253 })
254};
255
/// Declares the operator submodules and the `OPERATORS` table in one go.
///
/// Takes a comma-terminated list of `module::CONSTANT` pairs. For every pair it emits
/// `pub(crate) mod module;` and adds `module::CONSTANT` to the `OPERATORS` slice, in the order
/// given.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        // The constraints of every declared operator, in declaration order.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// The operator registry: one `module::CONSTANT` entry per operator. Adding an operator means
// adding its module file and one line here. The order of this list is the order of `OPERATORS`;
// it is only roughly alphabetical, so keep existing entries where they are.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flatten::FLATTEN,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    next_stratum::NEXT_STRATUM,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
347
348pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
350 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
351 OnceLock::new();
352 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
353}
354pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
356 if let GraphNode::Operator(operator) = node {
357 find_op_op_constraints(operator)
358 } else {
359 None
360 }
361}
362pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
364 let name = &*operator.name_string();
365 operator_lookup().get(name).copied()
366}
367
/// Everything a [`WriteFn`] gets to know about the operator instance it generates code for.
// NOTE(review): several field docs below are inferred from the field names because their
// producers are outside this file; those are marked "presumably".
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path prefix of the runtime crate; generated code refers to runtime items as
    /// `#root::...`.
    pub root: &'a TokenStream,
    /// Presumably the identifier of the context variable available to generated code.
    pub context: &'a Ident,
    /// Presumably the identifier of the dataflow instance variable.
    pub df_ident: &'a Ident,
    /// Id of the subgraph the operator belongs to.
    pub subgraph_id: GraphSubgraphId,
    /// Id of the operator's graph node.
    pub node_id: GraphNodeId,
    /// Id of the enclosing loop, if the operator is inside a loop context.
    pub loop_id: Option<GraphLoopId>,
    /// Span of the operator; used for the generated tokens and for diagnostics.
    pub op_span: Span,
    /// Optional tag of the operator.
    pub op_tag: Option<String>,
    /// Presumably the identifier of a helper function generated code wraps its work in.
    pub work_fn: &'a Ident,
    /// Presumably the async counterpart of `work_fn`.
    pub work_fn_async: &'a Ident,

    /// The identifier the generated code must bind its result to (`let #ident = ...`).
    pub ident: &'a Ident,
    /// `true` if pull-side (stream) code is generated, `false` for push-side (sink) code.
    pub is_pull: bool,
    /// Identifiers of the operator's inputs.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's outputs.
    pub outputs: &'a [Ident],
    /// Identifier for the operator's singleton output.
    pub singleton_output_ident: &'a Ident,

    /// Name of the operator; used in diagnostics.
    pub op_name: &'static str,
    /// The operator instance, including its generic (persistence and type) arguments.
    pub op_inst: &'a OperatorInstance,
    /// The argument expressions of the operator.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// A second form of the argument expressions.
    // NOTE(review): how this differs from `arguments` is not visible here — confirm.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
419impl WriteContextArgs<'_> {
420 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
426 Ident::new(
427 &format!(
428 "sg_{:?}_node_{:?}_{}",
429 self.subgraph_id.data(),
430 self.node_id.data(),
431 suffix.as_ref(),
432 ),
433 self.op_span,
434 )
435 }
436
437 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
440 let root = self.root;
441 let variant =
442 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
443 Some(quote_spanned! {self.op_span=>
444 #root::scheduled::graph::StateLifespan::#variant
445 })
446 }
447
448 pub fn persistence_args_disallow_mutable<const N: usize>(
450 &self,
451 diagnostics: &mut Diagnostics,
452 ) -> [Persistence; N] {
453 let len = self.op_inst.generics.persistence_args.len();
454 if 0 != len && 1 != len && N != len {
455 diagnostics.push(Diagnostic::spanned(
456 self.op_span,
457 Level::Error,
458 format!(
459 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
460 self.op_name, N
461 ),
462 ));
463 }
464
465 let default_persistence = if self.loop_id.is_some() {
466 Persistence::None
467 } else {
468 Persistence::Tick
469 };
470 let mut out = [default_persistence; N];
471 self.op_inst
472 .generics
473 .persistence_args
474 .iter()
475 .copied()
476 .cycle() .take(N)
478 .enumerate()
479 .filter(|&(_i, p)| {
480 if p == Persistence::Mutable {
481 diagnostics.push(Diagnostic::spanned(
482 self.op_span,
483 Level::Error,
484 format!(
485 "An implementation of `'{}` does not exist",
486 p.to_str_lowercase()
487 ),
488 ));
489 false
490 } else {
491 true
492 }
493 })
494 .for_each(|(i, p)| {
495 out[i] = p;
496 });
497 out
498 }
499}
500
/// Object-safe stand-in for [`std::ops::RangeBounds`], so ranges of different concrete types
/// can be stored as `&dyn RangeTrait<T>`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// The lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// The upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in words suitable for a message, e.g. `exactly 1`,
    /// `at least 2 and at most 3`, `less than 5` or `any number of`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let lower = self.start_bound();
        let upper = self.end_bound();

        // A closed range whose two ends coincide reads best as a single number.
        if let (Bound::Included(lo), Bound::Included(hi)) = (lower, upper) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower_phrase = match lower {
            Bound::Included(n) => Some(format!("at least {}", n)),
            Bound::Excluded(n) => Some(format!("more than {}", n)),
            Bound::Unbounded => None,
        };
        let upper_phrase = match upper {
            Bound::Included(x) => Some(format!("at most {}", x)),
            Bound::Excluded(x) => Some(format!("less than {}", x)),
            Bound::Unbounded => None,
        };

        match (lower_phrase, upper_phrase) {
            (Some(lo), Some(hi)) => format!("{} and {}", lo, hi),
            (Some(lo), None) => lo,
            (None, Some(hi)) => hi,
            (None, None) => "any number of".to_owned(),
        }
    }
}
545
546impl<R, T> RangeTrait<T> for R
547where
548 R: RangeBounds<T> + Send + Sync + Debug,
549{
550 fn start_bound(&self) -> Bound<&T> {
551 self.start_bound()
552 }
553
554 fn end_bound(&self) -> Bound<&T> {
555 self.end_bound()
556 }
557
558 fn contains(&self, item: &T) -> bool
559 where
560 T: PartialOrd<T>,
561 {
562 self.contains(item)
563 }
564}
565
/// Persistence argument of an operator, i.e. how long its state is kept.
///
/// The lowercase names (see `Persistence::to_str_lowercase`) are `none`, `loop`, `tick`,
/// `static` and `mutable`. The derived ordering follows the declaration order of the variants.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Maps to the `Subgraph` state lifespan.
    None,
    /// Maps to the `Loop` state lifespan; only valid inside a loop context.
    Loop,
    /// Maps to the `Tick` state lifespan.
    Tick,
    /// Has no state lifespan.
    // NOTE(review): presumably the state is kept for the whole run — confirm.
    Static,
    /// Has no state lifespan; rejected by
    /// `WriteContextArgs::persistence_args_disallow_mutable`.
    Mutable,
}
580impl Persistence {
581 pub fn as_state_lifespan_variant(
583 self,
584 subgraph_id: GraphSubgraphId,
585 loop_id: Option<GraphLoopId>,
586 span: Span,
587 ) -> Option<TokenStream> {
588 match self {
589 Persistence::None => {
590 let sg_ident = subgraph_id.as_ident(span);
591 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
592 }
593 Persistence::Loop => {
594 let loop_ident = loop_id
595 .expect("`Persistence::Loop` outside of a loop context.")
596 .as_ident(span);
597 Some(quote_spanned!(span=> Loop(#loop_ident)))
598 }
599 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
600 Persistence::Static => None,
601 Persistence::Mutable => None,
602 }
603 }
604
605 pub fn to_str_lowercase(self) -> &'static str {
607 match self {
608 Persistence::None => "none",
609 Persistence::Tick => "tick",
610 Persistence::Loop => "loop",
611 Persistence::Static => "static",
612 Persistence::Mutable => "mutable",
613 }
614 }
615}
616
617fn make_missing_runtime_msg(op_name: &str) -> Literal {
619 Literal::string(&format!(
620 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
621 op_name
622 ))
623}
624
/// Category an operator is filed under; see `OperatorConstraints::categories`.
// NOTE: the `///` doc comment of each variant is runtime data. It is exposed through
// `DocumentedVariants::get_variant_docs` and split at the first `:` by `name()` and
// `description()`, so every variant doc must have the form `Name: description`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
659impl OperatorCategory {
660 pub fn name(self) -> &'static str {
662 self.get_variant_docs().split_once(":").unwrap().0
663 }
664 pub fn description(self) -> &'static str {
666 self.get_variant_docs().split_once(":").unwrap().1
667 }
668}
669
/// Optional classification of an operator, stored in `OperatorConstraints::flo_type`.
///
/// The derived ordering follows the declaration order of the variants.
// NOTE(review): the per-variant descriptions are inferred from the variant names only; confirm
// them against the code that consumes `FloType`.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably: a source operator.
    Source,
    /// Presumably: a windowing operator.
    Windowing,
    /// Presumably: an un-windowing operator.
    Unwindowing,
    /// Presumably: an operator which passes data to the next loop iteration.
    NextIteration,
}