Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24    Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// An abstract "meta graph" representation of a DFIR graph.
///
/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
/// the meta graph used for generating Rust source code in macros from DFIR syntax.
///
/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
/// separate `impl` blocks. You might notice a few particularly specific arbitrary-seeming methods
/// in here--those are just what was needed for the compilation algorithms. If you need another
/// method then add it.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Each node (operator, handoff, or module boundary).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Instance data corresponding to each operator node.
    /// This field is not serialized (`#[serde(skip)]`), so it will be empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Debugging/tracing tag for each operator node.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph data structure (two-way adjacency list).
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// Input and output port for each edge: `(src port, dst port)`.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Which loop a node belongs to (no entry for top-level nodes).
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Which nodes belong to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// For the loop, what is its parent (no entry for top-level loops).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// What loops are at the root.
    root_loops: Vec<GraphLoopId>,
    /// For the loop, what are its child loops.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Which subgraph each node belongs to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Which nodes belong to each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Which stratum each subgraph belongs to.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// Resolved singletons varnames references, per node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is "lazy". A lazy subgraph that sends data to a lower stratum does
    /// not cause a new tick to start; this supports lazy defers.
    /// If no value exists for a given subgraph ID then that subgraph is not lazy.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85/// Basic methods.
86impl DfirGraph {
87    /// Create a new empty graph.
88    pub fn new() -> Self {
89        Default::default()
90    }
91}
92
93/// Node methods.
94impl DfirGraph {
95    /// Get a node with its operator instance (if applicable).
96    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97        self.nodes.get(node_id).expect("Node not found.")
98    }
99
100    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
101    /// `OperatorInstance` present, otherwise will return `None`.
102    ///
103    /// Note that no operator instances will be persent after deserialization.
104    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105        self.operator_instances.get(node_id)
106    }
107
108    /// Get the debug variable name attached to a graph node.
109    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110        self.node_varnames.get(node_id)
111    }
112
113    /// Get subgraph for node.
114    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115        self.node_subgraph.get(node_id).copied()
116    }
117
118    /// Degree into a node, i.e. the number of predecessors.
119    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120        self.graph.degree_in(node_id)
121    }
122
123    /// Degree out of a node, i.e. the number of successors.
124    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125        self.graph.degree_out(node_id)
126    }
127
128    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
129    pub fn node_successors(
130        &self,
131        src: GraphNodeId,
132    ) -> impl '_
133    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134    + ExactSizeIterator
135    + FusedIterator
136    + Clone
137    + Debug {
138        self.graph.successors(src)
139    }
140
141    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
142    pub fn node_predecessors(
143        &self,
144        dst: GraphNodeId,
145    ) -> impl '_
146    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147    + ExactSizeIterator
148    + FusedIterator
149    + Clone
150    + Debug {
151        self.graph.predecessors(dst)
152    }
153
154    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
155    pub fn node_successor_edges(
156        &self,
157        src: GraphNodeId,
158    ) -> impl '_
159    + DoubleEndedIterator<Item = GraphEdgeId>
160    + ExactSizeIterator
161    + FusedIterator
162    + Clone
163    + Debug {
164        self.graph.successor_edges(src)
165    }
166
167    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
168    pub fn node_predecessor_edges(
169        &self,
170        dst: GraphNodeId,
171    ) -> impl '_
172    + DoubleEndedIterator<Item = GraphEdgeId>
173    + ExactSizeIterator
174    + FusedIterator
175    + Clone
176    + Debug {
177        self.graph.predecessor_edges(dst)
178    }
179
180    /// Successor nodes, iterator of `GraphNodeId`.
181    pub fn node_successor_nodes(
182        &self,
183        src: GraphNodeId,
184    ) -> impl '_
185    + DoubleEndedIterator<Item = GraphNodeId>
186    + ExactSizeIterator
187    + FusedIterator
188    + Clone
189    + Debug {
190        self.graph.successor_vertices(src)
191    }
192
193    /// Predecessor nodes, iterator of `GraphNodeId`.
194    pub fn node_predecessor_nodes(
195        &self,
196        dst: GraphNodeId,
197    ) -> impl '_
198    + DoubleEndedIterator<Item = GraphNodeId>
199    + ExactSizeIterator
200    + FusedIterator
201    + Clone
202    + Debug {
203        self.graph.predecessor_vertices(dst)
204    }
205
206    /// Iterator of node IDs `GraphNodeId`.
207    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208        self.nodes.keys()
209    }
210
211    /// Iterator over `(GraphNodeId, &Node)` pairs.
212    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213        self.nodes.iter()
214    }
215
216    /// Insert a node, assigning the given varname.
217    pub fn insert_node(
218        &mut self,
219        node: GraphNode,
220        varname_opt: Option<Ident>,
221        loop_opt: Option<GraphLoopId>,
222    ) -> GraphNodeId {
223        let node_id = self.nodes.insert(node);
224        if let Some(varname) = varname_opt {
225            self.node_varnames.insert(node_id, Varname(varname));
226        }
227        if let Some(loop_id) = loop_opt {
228            self.node_loops.insert(node_id, loop_id);
229            self.loop_nodes[loop_id].push(node_id);
230        }
231        node_id
232    }
233
234    /// Insert an operator instance for the given node. Panics if already set.
235    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236        assert!(matches!(
237            self.nodes.get(node_id),
238            Some(GraphNode::Operator(_))
239        ));
240        let old_inst = self.operator_instances.insert(node_id, op_inst);
241        assert!(old_inst.is_none());
242    }
243
244    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
245    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
246        let mut op_insts = Vec::new();
247        for (node_id, node) in self.nodes() {
248            let GraphNode::Operator(operator) = node else {
249                continue;
250            };
251            if self.node_op_inst(node_id).is_some() {
252                continue;
253            };
254
255            // Op constraints.
256            let Some(op_constraints) = find_op_op_constraints(operator) else {
257                diagnostics.push(Diagnostic::spanned(
258                    operator.path.span(),
259                    Level::Error,
260                    format!("Unknown operator `{}`", operator.name_string()),
261                ));
262                continue;
263            };
264
265            // Input and output ports.
266            let (input_ports, output_ports) = {
267                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268                    .node_predecessors(node_id)
269                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270                    .collect();
271                // Ensure sorted by port index.
272                input_edges.sort();
273                let input_ports: Vec<PortIndexValue> = input_edges
274                    .into_iter()
275                    .map(|(port, _pred)| port)
276                    .cloned()
277                    .collect();
278
279                // Collect output arguments (successors).
280                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281                    .node_successors(node_id)
282                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283                    .collect();
284                // Ensure sorted by port index.
285                output_edges.sort();
286                let output_ports: Vec<PortIndexValue> = output_edges
287                    .into_iter()
288                    .map(|(port, _succ)| port)
289                    .cloned()
290                    .collect();
291
292                (input_ports, output_ports)
293            };
294
295            // Generic arguments.
296            let generics = get_operator_generics(diagnostics, operator);
297            // Generic argument errors.
298            {
299                // Span of `generic_args` (if it exists), otherwise span of the operator name.
300                let generics_span = generics
301                    .generic_args
302                    .as_ref()
303                    .map(Spanned::span)
304                    .unwrap_or_else(|| operator.path.span());
305
306                if !op_constraints
307                    .persistence_args
308                    .contains(&generics.persistence_args.len())
309                {
310                    diagnostics.push(Diagnostic::spanned(
311                        generics_span,
312                        Level::Error,
313                        format!(
314                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
315                            op_constraints.name,
316                            op_constraints.persistence_args.human_string(),
317                            generics.persistence_args.len()
318                        ),
319                    ));
320                }
321                if !op_constraints.type_args.contains(&generics.type_args.len()) {
322                    diagnostics.push(Diagnostic::spanned(
323                        generics_span,
324                        Level::Error,
325                        format!(
326                            "`{}` should have {} generic type arguments, actually has {}.",
327                            op_constraints.name,
328                            op_constraints.type_args.human_string(),
329                            generics.type_args.len()
330                        ),
331                    ));
332                }
333            }
334
335            op_insts.push((
336                node_id,
337                OperatorInstance {
338                    op_constraints,
339                    input_ports,
340                    output_ports,
341                    singletons_referenced: operator.singletons_referenced.clone(),
342                    generics,
343                    arguments_pre: operator.args.clone(),
344                    arguments_raw: operator.args_raw.clone(),
345                },
346            ));
347        }
348
349        for (node_id, op_inst) in op_insts {
350            self.insert_node_op_inst(node_id, op_inst);
351        }
352    }
353
354    /// Inserts a node between two existing nodes connected by the given `edge_id`.
355    ///
356    /// `edge`: (src, dst, dst_idx)
357    ///
358    /// Before: A (src) ------------> B (dst)
359    /// After:  A (src) -> X (new) -> B (dst)
360    ///
361    /// Returns the ID of X & ID of edge OUT of X.
362    ///
363    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
364    /// get the edge type of the original edge.
365    pub fn insert_intermediate_node(
366        &mut self,
367        edge_id: GraphEdgeId,
368        new_node: GraphNode,
369    ) -> (GraphNodeId, GraphEdgeId) {
370        let span = Some(new_node.span());
371
372        // Make corresponding operator instance (if `node` is an operator).
373        let op_inst_opt = 'oc: {
374            let GraphNode::Operator(operator) = &new_node else {
375                break 'oc None;
376            };
377            let Some(op_constraints) = find_op_op_constraints(operator) else {
378                break 'oc None;
379            };
380            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381
382            let mut dummy_diagnostics = Diagnostics::new();
383            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
384            assert!(dummy_diagnostics.is_empty());
385
386            Some(OperatorInstance {
387                op_constraints,
388                input_ports: vec![input_port],
389                output_ports: vec![output_port],
390                singletons_referenced: operator.singletons_referenced.clone(),
391                generics,
392                arguments_pre: operator.args.clone(),
393                arguments_raw: operator.args_raw.clone(),
394            })
395        };
396
397        // Insert new `node`.
398        let node_id = self.nodes.insert(new_node);
399        // Insert corresponding `OperatorInstance` if applicable.
400        if let Some(op_inst) = op_inst_opt {
401            self.operator_instances.insert(node_id, op_inst);
402        }
403        // Update edges to insert node within `edge_id`.
404        let (e0, e1) = self
405            .graph
406            .insert_intermediate_vertex(node_id, edge_id)
407            .unwrap();
408
409        // Update corresponding ports.
410        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
411        self.ports
412            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
413        self.ports
414            .insert(e1, (PortIndexValue::Elided(span), dst_idx));
415
416        (node_id, e1)
417    }
418
419    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
420    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
421    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
422        assert_eq!(
423            1,
424            self.node_degree_in(node_id),
425            "Removed intermediate node must have one predecessor"
426        );
427        assert_eq!(
428            1,
429            self.node_degree_out(node_id),
430            "Removed intermediate node must have one successor"
431        );
432        assert!(
433            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
434            "Should not remove intermediate node after subgraph partitioning"
435        );
436
437        assert!(self.nodes.remove(node_id).is_some());
438        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
439            self.graph.remove_intermediate_vertex(node_id).unwrap();
440        self.operator_instances.remove(node_id);
441        self.node_varnames.remove(node_id);
442
443        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
444        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
445        self.ports.insert(new_edge_id, (src_port, dst_port));
446    }
447
448    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
449    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
450    /// either push or pull.
451    ///
452    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
453    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
454        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
455            return Some(Color::Hoff);
456        }
457
458        // TODO(shadaj): this is a horrible hack
459        if let GraphNode::Operator(op) = self.node(node_id)
460            && (op.name_string() == "resolve_futures_blocking"
461                || op.name_string() == "resolve_futures_blocking_ordered")
462        {
463            return Some(Color::Push);
464        }
465
466        // In-degree, excluding ref-edges.
467        let inn_degree = self.node_predecessor_nodes(node_id).count();
468        // Out-degree excluding ref-edges.
469        let out_degree = self.node_successor_nodes(node_id).count();
470
471        match (inn_degree, out_degree) {
472            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
473            (0, 1) => Some(Color::Pull),
474            (1, 0) => Some(Color::Push),
475            (1, 1) => None, // Linear, can be either push or pull.
476            (_many, 0 | 1) => Some(Color::Pull),
477            (0 | 1, _many) => Some(Color::Push),
478            (_many, _to_many) => Some(Color::Comp),
479        }
480    }
481
482    /// Set the operator tag (for debugging/tracing).
483    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
484        self.operator_tag.insert(node_id, tag);
485    }
486}
487
488/// Singleton references.
489impl DfirGraph {
490    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
491    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
492    pub fn set_node_singleton_references(
493        &mut self,
494        node_id: GraphNodeId,
495        singletons_referenced: Vec<Option<GraphNodeId>>,
496    ) -> Option<Vec<Option<GraphNodeId>>> {
497        self.node_singleton_references
498            .insert(node_id, singletons_referenced)
499    }
500
501    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
502    /// operators that do not reference singletons.
503    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
504        self.node_singleton_references
505            .get(node_id)
506            .map(std::ops::Deref::deref)
507            .unwrap_or_default()
508    }
509}
510
511/// Module methods.
512impl DfirGraph {
513    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
514    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
515    /// merge_modules removes them from the graph, stitching the input and ouput sides of the ModuleBondaries based on their ports
516    /// For example:
517    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
518    /// in the above eaxmple, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
519    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
520    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
521        let mod_bound_nodes = self
522            .nodes()
523            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
524            .map(|(nid, _node)| nid)
525            .collect::<Vec<_>>();
526
527        for mod_bound_node in mod_bound_nodes {
528            self.remove_module_boundary(mod_bound_node)?;
529        }
530
531        Ok(())
532    }
533
534    /// see `merge_modules`
535    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
536    /// `merge_modules` calls this function for each module boundary in the graph.
537    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
538        assert!(
539            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
540            "Should not remove intermediate node after subgraph partitioning"
541        );
542
543        let mut mod_pred_ports = BTreeMap::new();
544        let mut mod_succ_ports = BTreeMap::new();
545
546        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
547            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
548            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
549        }
550
551        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
552            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
553            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
554        }
555
556        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
557            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
558        {
559            // get module boundary node
560            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
561                panic!();
562            };
563
564            if *input {
565                return Err(Diagnostic {
566                    span: *import_expr,
567                    level: Level::Error,
568                    message: format!(
569                        "The ports into the module did not match. input: {:?}, expected: {:?}",
570                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
571                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
572                    ),
573                });
574            } else {
575                return Err(Diagnostic {
576                    span: *import_expr,
577                    level: Level::Error,
578                    message: format!(
579                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
580                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
581                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
582                    ),
583                });
584            }
585        }
586
587        for (port, (pred_edge, pred_port)) in mod_pred_ports {
588            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
589
590            let (src, _) = self.edge(pred_edge);
591            let (_, dst) = self.edge(succ_edge);
592            self.remove_edge(pred_edge);
593            self.remove_edge(succ_edge);
594
595            let new_edge_id = self.graph.insert_edge(src, dst);
596            self.ports.insert(new_edge_id, (pred_port, succ_port));
597        }
598
599        self.graph.remove_vertex(mod_bound_node);
600        self.nodes.remove(mod_bound_node);
601
602        Ok(())
603    }
604}
605
606/// Edge methods.
607impl DfirGraph {
608    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
609    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
610        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
611        (src, dst)
612    }
613
614    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
615    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
616        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
617        (src_port, dst_port)
618    }
619
620    /// Iterator of all edge IDs `GraphEdgeId`.
621    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
622        self.graph.edge_ids()
623    }
624
625    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
626    pub fn edges(
627        &self,
628    ) -> impl '_
629    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
630    + FusedIterator
631    + Clone
632    + Debug {
633        self.graph.edges()
634    }
635
636    /// Insert an edge between nodes thru the given ports.
637    pub fn insert_edge(
638        &mut self,
639        src: GraphNodeId,
640        src_port: PortIndexValue,
641        dst: GraphNodeId,
642        dst_port: PortIndexValue,
643    ) -> GraphEdgeId {
644        let edge_id = self.graph.insert_edge(src, dst);
645        self.ports.insert(edge_id, (src_port, dst_port));
646        edge_id
647    }
648
649    /// Removes an edge and its corresponding ports and edge type info.
650    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
651        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
652        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
653    }
654}
655
656/// Subgraph methods.
657impl DfirGraph {
658    /// Nodes belonging to the given subgraph.
659    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
660        self.subgraph_nodes
661            .get(subgraph_id)
662            .expect("Subgraph not found.")
663    }
664
665    /// Iterator over all subgraph IDs.
666    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
667        self.subgraph_nodes.keys()
668    }
669
670    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
671    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
672        self.subgraph_nodes.iter()
673    }
674
675    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
676    pub fn insert_subgraph(
677        &mut self,
678        node_ids: Vec<GraphNodeId>,
679    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
680        // Check none are already in subgraphs
681        for &node_id in node_ids.iter() {
682            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
683                return Err((node_id, old_sg_id));
684            }
685        }
686        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
687            for &node_id in node_ids.iter() {
688                self.node_subgraph.insert(node_id, sg_id);
689            }
690            node_ids
691        });
692
693        Ok(subgraph_id)
694    }
695
696    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
697    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
698        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
699            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
700            true
701        } else {
702            false
703        }
704    }
705
706    /// Gets the stratum number of the subgraph.
707    pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
708        self.subgraph_stratum.get(sg_id).copied()
709    }
710
711    /// Set subgraph's stratum number, returning the old value if exists.
712    pub fn set_subgraph_stratum(
713        &mut self,
714        sg_id: GraphSubgraphId,
715        stratum: usize,
716    ) -> Option<usize> {
717        self.subgraph_stratum.insert(sg_id, stratum)
718    }
719
720    /// Gets whether the subgraph is lazy or not
721    fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
722        self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
723    }
724
725    /// Set subgraph's laziness, returning the old value.
726    pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
727        self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
728    }
729
730    /// Returns the the stratum number of the largest (latest) stratum (inclusive).
731    pub fn max_stratum(&self) -> Option<usize> {
732        self.subgraph_stratum.values().copied().max()
733    }
734
735    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
736    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
737        subgraph_nodes
738            .iter()
739            .position(|&node_id| {
740                self.node_color(node_id)
741                    .is_some_and(|color| Color::Pull != color)
742            })
743            .unwrap_or(subgraph_nodes.len())
744    }
745}
746
747/// Display/output methods.
748impl DfirGraph {
749    /// Helper to generate a deterministic `Ident` for the given node.
750    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
751        let name = match &self.nodes[node_id] {
752            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
753            GraphNode::Handoff { .. } => format!(
754                "hoff_{:?}_{}",
755                node_id.data(),
756                if is_pred { "recv" } else { "send" }
757            ),
758            GraphNode::ModuleBoundary { .. } => panic!(),
759        };
760        let span = match (is_pred, &self.nodes[node_id]) {
761            (_, GraphNode::Operator(operator)) => operator.span(),
762            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
763            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
764            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
765        };
766        Ident::new(&name, span)
767    }
768
769    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
770    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
771        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
772    }
773
774    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
775    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
776        self.node_singleton_references(node_id)
777            .iter()
778            .map(|singleton_node_id| {
779                // TODO(mingwei): this `expect` should be caught in error checking
780                self.node_as_singleton_ident(
781                    singleton_node_id
782                        .expect("Expected singleton to be resolved but was not, this is a bug."),
783                    span,
784                )
785            })
786            .collect::<Vec<_>>()
787    }
788
789    /// Returns each subgraph's receive and send handoffs.
790    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
791    fn helper_collect_subgraph_handoffs(
792        &self,
793    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
794        // Get data on handoff src and dst subgraphs.
795        let mut subgraph_handoffs: SecondaryMap<
796            GraphSubgraphId,
797            (Vec<GraphNodeId>, Vec<GraphNodeId>),
798        > = self
799            .subgraph_nodes
800            .keys()
801            .map(|k| (k, Default::default()))
802            .collect();
803
804        // For each handoff node, add it to the `send`/`recv` lists for the corresponding subgraphs.
805        for (hoff_id, node) in self.nodes() {
806            if !matches!(node, GraphNode::Handoff { .. }) {
807                continue;
808            }
809            // Receivers from the handoff. (Should really only be one).
810            for (_edge, succ_id) in self.node_successors(hoff_id) {
811                let succ_sg = self.node_subgraph(succ_id).unwrap();
812                subgraph_handoffs[succ_sg].0.push(hoff_id);
813            }
814            // Senders into the handoff. (Should really only be one).
815            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
816                let pred_sg = self.node_subgraph(pred_id).unwrap();
817                subgraph_handoffs[pred_sg].1.push(hoff_id);
818            }
819        }
820
821        subgraph_handoffs
822    }
823
824    /// Code for adding all nested loops.
825    fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
826        // Breadth-first iteration from outermost (root) loops to deepest nested loops.
827        let mut out = TokenStream::new();
828        let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
829        while let Some(loop_id) = queue.pop_front() {
830            let parent_opt = self
831                .loop_parent(loop_id)
832                .map(|loop_id| loop_id.as_ident(Span::call_site()))
833                .map(|ident| quote! { Some(#ident) })
834                .unwrap_or_else(|| quote! { None });
835            let loop_name = loop_id.as_ident(Span::call_site());
836            out.append_all(quote! {
837                let #loop_name = #df.add_loop(#parent_opt);
838            });
839            queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
840        }
841        out
842    }
843
    /// Emit this graph as runnable Rust source code tokens.
    ///
    /// * `root` - path tokens for the runtime crate; generated code refers to items such as
    ///   `#root::scheduled::graph::Dfir` and `#root::futures` through it.
    /// * `include_type_guards` - when `true`, each operator's output is additionally wrapped by a
    ///   small forwarding function (a `Stream` wrapper for pull operators, a `Sink` wrapper for
    ///   push operators) named after that operator's `work_fn`.
    /// * `prefix` - tokens emitted verbatim at the start of the generated block.
    /// * `diagnostics` - operator codegen pushes its diagnostics here.
    ///
    /// Returns all diagnostics as `Err(diagnostics)` if any are errors (leaving `&mut diagnostics` empty).
    /// On success, any remaining (non-error) diagnostics are left in `diagnostics` and are also
    /// serialized into the generated code via `__assign_diagnostics`, alongside a JSON
    /// serialization of this graph passed to `__assign_meta_graph`.
    ///
    /// # Panics
    /// If the graph still contains a `GraphNode::ModuleBoundary` node, if an operator's name is
    /// not found in `OPERATORS`, or if an operator's `write_fn` returns `Err` while `diagnostics`
    /// holds no error.
    pub fn as_code(
        &self,
        root: &TokenStream,
        include_type_guards: bool,
        prefix: TokenStream,
        diagnostics: &mut Diagnostics,
    ) -> Result<TokenStream, Diagnostics> {
        let df = Ident::new(GRAPH, Span::call_site());
        let context = Ident::new(CONTEXT, Span::call_site());

        // Code for adding handoffs.
        // NOTE: this is a lazy iterator; it only runs when interpolated into `code` below (after
        // the error check). Each handoff becomes a `VecHandoff` edge created via `make_edge`.
        let handoff_code = self
            .nodes
            .iter()
            .filter_map(|(node_id, node)| match node {
                GraphNode::Operator(_) => None,
                &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
                GraphNode::ModuleBoundary { .. } => panic!(),
            })
            .map(|(node_id, (src_span, dst_span))| {
                let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
                let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
                let span = src_span.join(dst_span).unwrap_or(src_span);
                let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
                hoff_name.set_span(span);
                let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
                quote_spanned! {span=>
                    let (#ident_send, #ident_recv) =
                        #df.make_edge::<_, #hoff_type>(#hoff_name);
                }
            });

        // `subgraph_id -> (recv handoffs, send handoffs)`.
        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();

        // We first generate the subgraphs that have no inputs, to guide type inference.
        // Concretely, the first partition holds subgraphs containing at least one node with
        // in-degree zero; all other subgraphs are emitted after them.
        let (subgraphs_without_preds, subgraphs_with_preds) = self
            .subgraph_nodes
            .iter()
            .partition::<Vec<_>, _>(|(_, nodes)| {
                nodes
                    .iter()
                    .any(|&node_id| self.node_degree_in(node_id) == 0)
            });

        let mut op_prologue_code = Vec::new();
        let mut op_prologue_after_code = Vec::new();
        let mut subgraphs = Vec::new();
        {
            for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
                .iter()
                .chain(subgraphs_with_preds.iter())
            {
                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
                let recv_ports: Vec<Ident> = recv_hoffs
                    .iter()
                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
                    .collect();
                let send_ports: Vec<Ident> = send_hoffs
                    .iter()
                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
                    .collect();

                // Each recv handoff is borrowed with `borrow_mut_swap()` and drained into a stream.
                let recv_port_code = recv_ports.iter().map(|ident| {
                    quote_spanned! {ident.span()=>
                        let mut #ident = #ident.borrow_mut_swap();
                        let #ident = #root::futures::stream::iter(#ident.drain(..));
                    }
                });
                // Each send handoff is wrapped in a `for_each` sink that calls `give(Some(v))`.
                let send_port_code = send_ports.iter().map(|ident| {
                    quote_spanned! {ident.span()=>
                        let #ident = #root::sinktools::for_each(|v| {
                            #ident.give(Some(v));
                        });
                    }
                });

                // NOTE(review): `subgraph_nodes[0]` panics if a subgraph has no nodes.
                let loop_id = self
                    // All nodes in a subgraph should be in the same loop.
                    .node_loop(subgraph_nodes[0]);

                let mut subgraph_op_iter_code = Vec::new();
                let mut subgraph_op_iter_after_code = Vec::new();
                {
                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);

                    // Visit pull operators in order, then push operators in reverse order.
                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());

                    for (idx, &node_id) in nodes_iter.enumerate() {
                        let node = &self.nodes[node_id];
                        assert!(
                            matches!(node, GraphNode::Operator(_)),
                            "Handoffs are not part of subgraphs."
                        );
                        let op_inst = &self.operator_instances[node_id];

                        let op_span = node.span();
                        let op_name = op_inst.op_constraints.name;
                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
                        let root = change_spans(root.clone(), op_span);
                        // TODO(mingwei): Just use `op_inst.op_constraints`?
                        let op_constraints = OPERATORS
                            .iter()
                            .find(|op| op_name == op.name)
                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));

                        let ident = self.node_as_ident(node_id, false);

                        {
                            // TODO clean this up.
                            // Collect input arguments (predecessors).
                            let mut input_edges = self
                                .graph
                                .predecessor_edges(node_id)
                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
                                .collect::<Vec<_>>();
                            // Ensure sorted by port index.
                            input_edges.sort();

                            let inputs = input_edges
                                .iter()
                                .map(|&(_port, edge_id)| {
                                    let (pred, _) = self.edge(edge_id);
                                    self.node_as_ident(pred, true)
                                })
                                .collect::<Vec<_>>();

                            // Collect output arguments (successors).
                            let mut output_edges = self
                                .graph
                                .successor_edges(node_id)
                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
                                .collect::<Vec<_>>();
                            // Ensure sorted by port index.
                            output_edges.sort();

                            let outputs = output_edges
                                .iter()
                                .map(|&(_port, edge_id)| {
                                    let (_, succ) = self.edge(edge_id);
                                    self.node_as_ident(succ, false)
                                })
                                .collect::<Vec<_>>();

                            // `idx` counts positions in `nodes_iter`; since the pull half comes
                            // first, this is true exactly for pull operators.
                            let is_pull = idx < pull_to_push_idx;

                            let singleton_output_ident = &if op_constraints.has_singleton_output {
                                self.node_as_singleton_ident(node_id, op_span)
                            } else {
                                // This ident *should* go unused.
                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
                            };

                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));

                            let singletons_resolved =
                                self.helper_resolve_singletons(node_id, op_span);
                            let arguments = &process_singletons::postprocess_singletons(
                                op_inst.arguments_raw.clone(),
                                singletons_resolved.clone(),
                                context,
                            );
                            let arguments_handles =
                                &process_singletons::postprocess_singletons_handles(
                                    op_inst.arguments_raw.clone(),
                                    singletons_resolved.clone(),
                                );

                            // Tag used in the operator's `work_fn` name: the explicit operator tag
                            // if present, otherwise one derived from the source location (the
                            // file path is included only under `cfg(nightly)` when `proc_macro`
                            // is available).
                            let source_tag = 'a: {
                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
                                    break 'a tag;
                                }

                                #[cfg(nightly)]
                                if proc_macro::is_available() {
                                    let op_span = op_span.unwrap();
                                    break 'a format!(
                                        "loc_{}_{}_{}_{}_{}",
                                        crate::pretty_span::make_source_path_relative(
                                            &op_span.file()
                                        )
                                        .display()
                                        .to_string()
                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
                                        op_span.start().line(),
                                        op_span.start().column(),
                                        op_span.end().line(),
                                        op_span.end().column(),
                                    );
                                }

                                format!(
                                    "loc_nopath_{}_{}_{}_{}",
                                    op_span.start().line,
                                    op_span.start().column,
                                    op_span.end().line,
                                    op_span.end().column
                                )
                            };

                            let work_fn = format_ident!(
                                "{}__{}__{}",
                                ident,
                                op_name,
                                source_tag,
                                span = op_span
                            );
                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);

                            let context_args = WriteContextArgs {
                                root: &root,
                                df_ident: df_local,
                                context,
                                subgraph_id,
                                node_id,
                                loop_id,
                                op_span,
                                op_tag: self.operator_tag.get(node_id).cloned(),
                                work_fn: &work_fn,
                                work_fn_async: &work_fn_async,
                                ident: &ident,
                                is_pull,
                                inputs: &inputs,
                                outputs: &outputs,
                                singleton_output_ident,
                                op_name,
                                op_inst,
                                arguments,
                                arguments_handles,
                            };

                            // On `Err`, the operator must already have reported an error; fall
                            // back to a null iterator so codegen can continue collecting diagnostics.
                            let write_result =
                                (op_constraints.write_fn)(&context_args, diagnostics);
                            let OperatorWriteOutput {
                                write_prologue,
                                write_prologue_after,
                                write_iterator,
                                write_iterator_after,
                            } = write_result.unwrap_or_else(|()| {
                                assert!(
                                    diagnostics.has_error(),
                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
                                    op_name,
                                );
                                OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
                            });

                            // Pass-through helpers named after the operator (`work_fn` /
                            // `work_fn_async`), also handed to the operator via `context_args`.
                            op_prologue_code.push(syn::parse_quote! {
                                #[allow(non_snake_case)]
                                #[inline(always)]
                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
                                    thunk()
                                }

                                #[allow(non_snake_case)]
                                #[inline(always)]
                                async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
                                    thunk.await
                                }
                            });
                            op_prologue_code.push(write_prologue);
                            op_prologue_after_code.push(write_prologue_after);
                            subgraph_op_iter_code.push(write_iterator);

                            if include_type_guards {
                                let type_guard = if is_pull {
                                    quote_spanned! {op_span=>
                                        let #ident = {
                                            #[allow(non_snake_case)]
                                            #[inline(always)]
                                            pub fn #work_fn<Item, Input: #root::futures::stream::Stream<Item = Item>>(input: Input) -> impl #root::futures::stream::Stream<Item = Item> {
                                                #root::pin_project_lite::pin_project! {
                                                    #[repr(transparent)]
                                                    struct Pull<Item, Input: #root::futures::stream::Stream<Item = Item>> {
                                                        #[pin]
                                                        inner: Input
                                                    }
                                                }

                                                impl<Item, Input> #root::futures::stream::Stream for Pull<Item, Input>
                                                where
                                                    Input: #root::futures::stream::Stream<Item = Item>,
                                                {
                                                    type Item = Item;

                                                    #[inline(always)]
                                                    fn poll_next(
                                                        self: ::std::pin::Pin<&mut Self>,
                                                        cx: &mut ::std::task::Context<'_>,
                                                    ) -> ::std::task::Poll<::std::option::Option<Self::Item>> {
                                                        #root::futures::stream::Stream::poll_next(self.project().inner, cx)
                                                    }

                                                    #[inline(always)]
                                                    fn size_hint(&self) -> (usize, Option<usize>) {
                                                        #root::futures::stream::Stream::size_hint(&self.inner)
                                                    }
                                                }

                                                Pull {
                                                    inner: input
                                                }
                                            }
                                            #work_fn( #ident )
                                        };
                                    }
                                } else {
                                    quote_spanned! {op_span=>
                                        let #ident = {
                                            #[allow(non_snake_case)]
                                            #[inline(always)]
                                            pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
                                            where
                                                Si: #root::futures::sink::Sink<Item, Error = #root::Never>
                                            {
                                                #root::pin_project_lite::pin_project! {
                                                    #[repr(transparent)]
                                                    struct Push<Si> {
                                                        #[pin]
                                                        si: Si,
                                                    }
                                                }

                                                impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
                                                where
                                                    Si: #root::futures::sink::Sink<Item>,
                                                {
                                                    type Error = Si::Error;

                                                    fn poll_ready(
                                                        self: ::std::pin::Pin<&mut Self>,
                                                        cx: &mut ::std::task::Context<'_>,
                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
                                                        self.project().si.poll_ready(cx)
                                                    }

                                                    fn start_send(
                                                        self: ::std::pin::Pin<&mut Self>,
                                                        item: Item,
                                                    ) -> ::std::result::Result<(), Self::Error> {
                                                        self.project().si.start_send(item)
                                                    }

                                                    fn poll_flush(
                                                        self: ::std::pin::Pin<&mut Self>,
                                                        cx: &mut ::std::task::Context<'_>,
                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
                                                        self.project().si.poll_flush(cx)
                                                    }

                                                    fn poll_close(
                                                        self: ::std::pin::Pin<&mut Self>,
                                                        cx: &mut ::std::task::Context<'_>,
                                                    ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
                                                        self.project().si.poll_close(cx)
                                                    }
                                                }

                                                Push {
                                                    si
                                                }
                                            }
                                            #work_fn( #ident )
                                        };
                                    }
                                };
                                subgraph_op_iter_code.push(type_guard);
                            }
                            subgraph_op_iter_after_code.push(write_iterator_after);
                        }
                    }

                    {
                        // Determine pull and push halves of the `Pivot`.
                        let pull_ident = if 0 < pull_to_push_idx {
                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
                        } else {
                            // Entire subgraph is push (with a single recv/pull handoff input).
                            // NOTE(review): `recv_ports[0]` is indexed unchecked and panics if a
                            // subgraph starting with a push node has no recv handoff; the push
                            // side below reports its degenerate case as a diagnostic instead.
                            recv_ports[0].clone()
                        };

                        #[rustfmt::skip]
                        let push_ident = if let Some(&node_id) =
                            subgraph_nodes.get(pull_to_push_idx)
                        {
                            self.node_as_ident(node_id, false)
                        } else if 1 == send_ports.len() {
                            // Entire subgraph is pull (with a single send/push handoff output).
                            send_ports[0].clone()
                        } else {
                            // Skip this subgraph; the pushed error makes `as_code` return `Err` below.
                            diagnostics.push(Diagnostic::spanned(
                                pull_ident.span(),
                                Level::Error,
                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
                            ));
                            continue;
                        };

                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
                        // The pivot drives the pull stream into the push sink via `send_stream`,
                        // awaited inside the subgraph's closure.
                        let pivot_span = pull_ident
                            .span()
                            .join(push_ident.span())
                            .unwrap_or_else(|| push_ident.span());
                        let pivot_fn_ident =
                            Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
                        let root = change_spans(root.clone(), pivot_span);
                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
                            #[inline(always)]
                            fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
                                -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
                            where
                                Pull: #root::futures::stream::Stream<Item = Item>,
                                Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
                            {
                                #root::sinktools::send_stream(pull, push)
                            }
                            (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap(/* Never */);
                        });
                    }
                };

                let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
                // Subgraphs without an assigned stratum default to stratum 0.
                let stratum = Literal::usize_unsuffixed(
                    self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
                );
                let laziness = self.subgraph_laziness(subgraph_id);

                // Codegen: the loop that this subgraph is in `Some(<loop_id>)`, or `None` if not in a loop.
                let loop_id_opt = loop_id
                    .map(|loop_id| loop_id.as_ident(Span::call_site()))
                    .map(|ident| quote! { Some(#ident) })
                    .unwrap_or_else(|| quote! { None });

                let sg_ident = subgraph_id.as_ident(Span::call_site());

                subgraphs.push(quote! {
                    let #sg_ident = #df.add_subgraph_full(
                        #subgraph_name,
                        #stratum,
                        var_expr!( #( #recv_ports ),* ),
                        var_expr!( #( #send_ports ),* ),
                        #laziness,
                        #loop_id_opt,
                        async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
                            #( #recv_port_code )*
                            #( #send_port_code )*
                            #( #subgraph_op_iter_code )*
                            #( #subgraph_op_iter_after_code )*
                        },
                    );
                });
            }
        }

        if diagnostics.has_error() {
            return Err(std::mem::take(diagnostics));
        }
        // NOTE(review): `let _ = <place>` neither moves nor borrows, so this line does not actually
        // prevent later use; `diagnostics` is still read below to serialize non-error diagnostics.
        let _ = diagnostics; // Intended to signal that no more diagnostics may be added after checking for errors.

        let loop_code = self.codegen_nested_loops(&df);

        // These two are quoted separately here because iterators are lazily evaluated, so this
        // forces them to do their work. This work includes populating some data, namely
        // `diagnostics`, which we need to determine if compilation was actually successful.
        // -Mingwei
        // NOTE(review): this rationale looks stale: the `has_error()` check above has already run,
        // and `handoff_code` (the only lazy iterator interpolated here) does not touch `diagnostics`.
        let code = quote! {
            #( #handoff_code )*
            #loop_code
            #( #op_prologue_code )*
            #( #subgraphs )*
            #( #op_prologue_after_code )*
        };

        // Embed this graph and the remaining (non-error) diagnostics as JSON string literals.
        let meta_graph_json = serde_json::to_string(&self).unwrap();
        let meta_graph_json = Literal::string(&meta_graph_json);

        let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
        let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
        let diagnostics_json = Literal::string(&diagnostics_json);

        Ok(quote! {
            {
                #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
                {
                    #prefix

                    use #root::{var_expr, var_args};

                    let mut #df = #root::scheduled::graph::Dfir::new();
                    #df.__assign_meta_graph(#meta_graph_json);
                    #df.__assign_diagnostics(#diagnostics_json);

                    #code

                    #df
                }
            }
        })
    }
1355
1356    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1357    /// those nodes will not be set in the returned map.
1358    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1359        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1360            .node_ids()
1361            .filter_map(|node_id| {
1362                let op_color = self.node_color(node_id)?;
1363                Some((node_id, op_color))
1364            })
1365            .collect();
1366
1367        // Fill in rest via subgraphs.
1368        for sg_nodes in self.subgraph_nodes.values() {
1369            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1370
1371            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1372                let is_pull = idx < pull_to_push_idx;
1373                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1374            }
1375        }
1376
1377        node_color_map
1378    }
1379
1380    /// Writes this graph as mermaid into a string.
1381    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1382        let mut output = String::new();
1383        self.write_mermaid(&mut output, write_config).unwrap();
1384        output
1385    }
1386
1387    /// Writes this graph as mermaid into the given `Write`.
1388    pub fn write_mermaid(
1389        &self,
1390        output: impl std::fmt::Write,
1391        write_config: &WriteConfig,
1392    ) -> std::fmt::Result {
1393        let mut graph_write = Mermaid::new(output);
1394        self.write_graph(&mut graph_write, write_config)
1395    }
1396
1397    /// Writes this graph as DOT (graphviz) into a string.
1398    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1399        let mut output = String::new();
1400        let mut graph_write = Dot::new(&mut output);
1401        self.write_graph(&mut graph_write, write_config).unwrap();
1402        output
1403    }
1404
1405    /// Writes this graph as DOT (graphviz) into the given `Write`.
1406    pub fn write_dot(
1407        &self,
1408        output: impl std::fmt::Write,
1409        write_config: &WriteConfig,
1410    ) -> std::fmt::Result {
1411        let mut graph_write = Dot::new(output);
1412        self.write_graph(&mut graph_write, write_config)
1413    }
1414
1415    /// Write out this graph using the given `GraphWrite`. E.g. `Mermaid` or `Dot.
1416    pub(crate) fn write_graph<W>(
1417        &self,
1418        mut graph_write: W,
1419        write_config: &WriteConfig,
1420    ) -> Result<(), W::Err>
1421    where
1422        W: GraphWrite,
1423    {
1424        fn helper_edge_label(
1425            src_port: &PortIndexValue,
1426            dst_port: &PortIndexValue,
1427        ) -> Option<String> {
1428            let src_label = match src_port {
1429                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1430                PortIndexValue::Int(index) => Some(index.value.to_string()),
1431                _ => None,
1432            };
1433            let dst_label = match dst_port {
1434                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1435                PortIndexValue::Int(index) => Some(index.value.to_string()),
1436                _ => None,
1437            };
1438            let label = match (src_label, dst_label) {
1439                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1440                (Some(l1), None) => Some(l1),
1441                (None, Some(l2)) => Some(l2),
1442                (None, None) => None,
1443            };
1444            label
1445        }
1446
1447        // Make node color map one time.
1448        let node_color_map = self.node_color_map();
1449
1450        // Write prologue.
1451        graph_write.write_prologue()?;
1452
1453        // Define nodes.
1454        let mut skipped_handoffs = BTreeSet::new();
1455        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1456        for (node_id, node) in self.nodes() {
1457            if matches!(node, GraphNode::Handoff { .. }) {
1458                if write_config.no_handoffs {
1459                    skipped_handoffs.insert(node_id);
1460                    continue;
1461                } else {
1462                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1463                    let pred_sg = self.node_subgraph(pred_node);
1464                    let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1465                    let succ_sg = self.node_subgraph(succ_node);
1466                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1467                        && pred_sg == succ_sg
1468                    {
1469                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1470                    }
1471                }
1472            }
1473            graph_write.write_node_definition(
1474                node_id,
1475                &if write_config.op_short_text {
1476                    node.to_name_string()
1477                } else if write_config.op_text_no_imports {
1478                    // Remove any lines that start with "use" (imports)
1479                    let full_text = node.to_pretty_string();
1480                    let mut output = String::new();
1481                    for sentence in full_text.split('\n') {
1482                        if sentence.trim().starts_with("use") {
1483                            continue;
1484                        }
1485                        output.push('\n');
1486                        output.push_str(sentence);
1487                    }
1488                    output.into()
1489                } else {
1490                    node.to_pretty_string()
1491                },
1492                if write_config.no_pull_push {
1493                    None
1494                } else {
1495                    node_color_map.get(node_id).copied()
1496                },
1497            )?;
1498        }
1499
1500        // Write edges.
1501        for (edge_id, (src_id, mut dst_id)) in self.edges() {
1502            // Handling for if `write_config.no_handoffs` true.
1503            if skipped_handoffs.contains(&src_id) {
1504                continue;
1505            }
1506
1507            let (src_port, mut dst_port) = self.edge_ports(edge_id);
1508            if skipped_handoffs.contains(&dst_id) {
1509                let mut handoff_succs = self.node_successors(dst_id);
1510                assert_eq!(1, handoff_succs.len());
1511                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1512                dst_id = succ_node;
1513                dst_port = self.edge_ports(succ_edge).1;
1514            }
1515
1516            let label = helper_edge_label(src_port, dst_port);
1517            let delay_type = self
1518                .node_op_inst(dst_id)
1519                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1520            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1521        }
1522
1523        // Write reference edges.
1524        if !write_config.no_references {
1525            for dst_id in self.node_ids() {
1526                for src_ref_id in self
1527                    .node_singleton_references(dst_id)
1528                    .iter()
1529                    .copied()
1530                    .flatten()
1531                {
1532                    let delay_type = Some(DelayType::Stratum);
1533                    let label = None;
1534                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1535                }
1536            }
1537        }
1538
1539        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
1540        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
1541        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
1542        // is disabled, then the HashMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
1543        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.
1544        //
1545        // (Note: `stratum` could also be included in this hierarchy, but it is being phased-out/deprecated in favor of
1546        // Flo loops).
1547
1548        // Loop -> Subgraphs
1549        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1550            let loop_id = if write_config.no_loops {
1551                None
1552            } else {
1553                self.subgraph_loop(sg_id)
1554            };
1555            (loop_id, sg_id)
1556        });
1557        let loop_subgraphs = into_group_map(loop_subgraphs);
1558        for (loop_id, subgraph_ids) in loop_subgraphs {
1559            if let Some(loop_id) = loop_id {
1560                graph_write.write_loop_start(loop_id)?;
1561            }
1562
1563            // Subgraph -> Varnames.
1564            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1565                self.subgraph(sg_id).iter().copied().map(move |node_id| {
1566                    let opt_sg_id = if write_config.no_subgraphs {
1567                        None
1568                    } else {
1569                        Some(sg_id)
1570                    };
1571                    (opt_sg_id, (self.node_varname(node_id), node_id))
1572                })
1573            });
1574            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1575            for (sg_id, varnames) in subgraph_varnames_nodes {
1576                if let Some(sg_id) = sg_id {
1577                    let stratum = self.subgraph_stratum(sg_id).unwrap();
1578                    graph_write.write_subgraph_start(sg_id, stratum)?;
1579                }
1580
1581                // Varnames -> Nodes.
1582                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1583                    let varname = if write_config.no_varnames {
1584                        None
1585                    } else {
1586                        varname
1587                    };
1588                    (varname, node)
1589                });
1590                let varname_nodes = into_group_map(varname_nodes);
1591                for (varname, node_ids) in varname_nodes {
1592                    if let Some(varname) = varname {
1593                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1594                    }
1595
1596                    // Write all nodes.
1597                    for node_id in node_ids {
1598                        graph_write.write_node(node_id)?;
1599                    }
1600
1601                    if varname.is_some() {
1602                        graph_write.write_varname_end()?;
1603                    }
1604                }
1605
1606                if sg_id.is_some() {
1607                    graph_write.write_subgraph_end()?;
1608                }
1609            }
1610
1611            if loop_id.is_some() {
1612                graph_write.write_loop_end()?;
1613            }
1614        }
1615
1616        // Write epilogue.
1617        graph_write.write_epilogue()?;
1618
1619        Ok(())
1620    }
1621
1622    /// Convert back into surface syntax.
1623    pub fn surface_syntax_string(&self) -> String {
1624        let mut string = String::new();
1625        self.write_surface_syntax(&mut string).unwrap();
1626        string
1627    }
1628
1629    /// Convert back into surface syntax.
1630    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1631        for (key, node) in self.nodes.iter() {
1632            match node {
1633                GraphNode::Operator(op) => {
1634                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1635                }
1636                GraphNode::Handoff { .. } => {
1637                    writeln!(write, "// {:?} = <handoff>;", key.data())?;
1638                }
1639                GraphNode::ModuleBoundary { .. } => panic!(),
1640            }
1641        }
1642        writeln!(write)?;
1643        for (_e, (src_key, dst_key)) in self.graph.edges() {
1644            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1645        }
1646        Ok(())
1647    }
1648
1649    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1650    pub fn mermaid_string_flat(&self) -> String {
1651        let mut string = String::new();
1652        self.write_mermaid_flat(&mut string).unwrap();
1653        string
1654    }
1655
1656    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1657    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1658        writeln!(write, "flowchart TB")?;
1659        for (key, node) in self.nodes.iter() {
1660            match node {
1661                GraphNode::Operator(operator) => writeln!(
1662                    write,
1663                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1664                    span = PrettySpan(node.span()),
1665                    id = key.data(),
1666                    row_col = PrettyRowCol(node.span()),
1667                    code = operator
1668                        .to_token_stream()
1669                        .to_string()
1670                        .replace('&', "&amp;")
1671                        .replace('<', "&lt;")
1672                        .replace('>', "&gt;")
1673                        .replace('"', "&quot;")
1674                        .replace('\n', "<br>"),
1675                ),
1676                GraphNode::Handoff { .. } => {
1677                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1678                }
1679                GraphNode::ModuleBoundary { .. } => {
1680                    writeln!(
1681                        write,
1682                        r#"    {:?}{{"{}"}}"#,
1683                        key.data(),
1684                        MODULE_BOUNDARY_NODE_STR
1685                    )
1686                }
1687            }?;
1688        }
1689        writeln!(write)?;
1690        for (_e, (src_key, dst_key)) in self.graph.edges() {
1691            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
1692        }
1693        Ok(())
1694    }
1695}
1696
1697/// Loops
1698impl DfirGraph {
1699    /// Iterator over all loop IDs.
1700    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1701        self.loop_nodes.keys()
1702    }
1703
1704    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
1705    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1706        self.loop_nodes.iter()
1707    }
1708
1709    /// Create a new loop context, with the given parent loop (or `None`).
1710    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1711        let loop_id = self.loop_nodes.insert(Vec::new());
1712        self.loop_children.insert(loop_id, Vec::new());
1713        if let Some(parent_loop) = parent_loop {
1714            self.loop_parent.insert(loop_id, parent_loop);
1715            self.loop_children
1716                .get_mut(parent_loop)
1717                .unwrap()
1718                .push(loop_id);
1719        } else {
1720            self.root_loops.push(loop_id);
1721        }
1722        loop_id
1723    }
1724
1725    /// Get a node's loop context (or `None` for root).
1726    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1727        self.node_loops.get(node_id).copied()
1728    }
1729
1730    /// Get a subgraph's loop context (or `None` for root).
1731    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1732        let &node_id = self.subgraph(subgraph_id).first().unwrap();
1733        let out = self.node_loop(node_id);
1734        debug_assert!(
1735            self.subgraph(subgraph_id)
1736                .iter()
1737                .all(|&node_id| self.node_loop(node_id) == out),
1738            "Subgraph nodes should all have the same loop context."
1739        );
1740        out
1741    }
1742
1743    /// Get a loop context's parent loop context (or `None` for root).
1744    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1745        self.loop_parent.get(loop_id).copied()
1746    }
1747
1748    /// Get a loop context's child loops.
1749    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1750        self.loop_children.get(loop_id).unwrap()
1751    }
1752}
1753
/// Configuration for writing graphs.
// Read by `DfirGraph::write_graph` (and therefore by `to_mermaid`/`to_dot`). With the
// `clap-derive` feature this struct also derives `clap::Args`. Each `arg(long)` field then
// becomes a `--kebab-case` boolean flag, and clap uses the field `///` doc comments below as the
// flags' `--help` text, so treat that wording as user-facing.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    // When set, `write_graph` emits no `write_subgraph_start`/`write_subgraph_end` pairs.
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    // When set, nodes are not wrapped in `write_varname_start`/`write_varname_end`.
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    // When set, node definitions are written with no `Color`.
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    // When set, handoff nodes are not written, and each edge into a handoff is redirected to
    // that handoff's successor node.
    /// Will not render handoffs if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    // When set, singleton reference edges are not written.
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    // When set, subgraphs are not wrapped in `write_loop_start`/`write_loop_end`.
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    // Checked before `op_text_no_imports` in `write_graph`, so it takes precedence.
    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    // Filters lines of the full pretty-printed node text. Has no effect when `op_short_text`
    // is set.
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1784
/// Enum for choosing between mermaid and dot graph writing.
// NOTE(review): the code that maps a variant to a writer (presumably `DfirGraph::to_mermaid` /
// `DfirGraph::to_dot`) is not in this section; confirm at the call site. With the
// `clap-derive` feature, clap's `ValueEnum` derive uses the variant `///` comments below as
// help text for the accepted values, so treat their wording as user-facing.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
1794
/// Groups `(key, value)` pairs by key, like [`itertools::Itertools::into_group_map`], but into a
/// `BTreeMap`.
///
/// Keys therefore come out in sorted order. Within each group, values keep the order in which
/// they were yielded.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}