1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A DFIR dataflow graph: nodes (operators, handoffs, module boundaries) connected by edges
/// whose two ends are labelled with ports, plus the loop and subgraph structure layered on top.
///
/// Per-node, per-edge, per-loop and per-subgraph data is kept in side tables keyed by the
/// corresponding slotmap IDs.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph (operator, handoff, or module boundary).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Operator instance for each operator node.
    ///
    /// Not serialized, so this map is empty (its `Default`) after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional tag per operator node; when present, codegen uses it instead of a
    /// source-location tag when naming the operator's generated work functions.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph topology: the edges between node IDs.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// `(source port, destination port)` labels for each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Loop each node belongs to, for nodes placed inside a loop.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Nodes belonging to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// Parent loop of each nested loop.
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops with no parent; codegen walks the loop tree starting from these.
    root_loops: Vec<GraphLoopId>,
    /// Child loops of each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Subgraph each node has been assigned to (empty before partitioning).
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Nodes belonging to each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// Stratum assigned to each subgraph.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// Singleton references of each node, one entry per referenced singleton; `None` marks a
    /// reference that has not been resolved to a node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// Variable name each node was bound to, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is lazy; passed to `add_subgraph_full` during codegen.
    /// Missing entries are treated as `false`.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics.persistence_args_span().unwrap_or(generics_span),
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics.type_args_span().unwrap_or(generics_span),
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381
382 let mut dummy_diagnostics = Diagnostics::new();
383 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
384 assert!(dummy_diagnostics.is_empty());
385
386 Some(OperatorInstance {
387 op_constraints,
388 input_ports: vec![input_port],
389 output_ports: vec![output_port],
390 singletons_referenced: operator.singletons_referenced.clone(),
391 generics,
392 arguments_pre: operator.args.clone(),
393 arguments_raw: operator.args_raw.clone(),
394 })
395 };
396
397 let node_id = self.nodes.insert(new_node);
399 if let Some(op_inst) = op_inst_opt {
401 self.operator_instances.insert(node_id, op_inst);
402 }
403 let (e0, e1) = self
405 .graph
406 .insert_intermediate_vertex(node_id, edge_id)
407 .unwrap();
408
409 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
411 self.ports
412 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
413 self.ports
414 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
415
416 (node_id, e1)
417 }
418
419 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
422 assert_eq!(
423 1,
424 self.node_degree_in(node_id),
425 "Removed intermediate node must have one predecessor"
426 );
427 assert_eq!(
428 1,
429 self.node_degree_out(node_id),
430 "Removed intermediate node must have one successor"
431 );
432 assert!(
433 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
434 "Should not remove intermediate node after subgraph partitioning"
435 );
436
437 assert!(self.nodes.remove(node_id).is_some());
438 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
439 self.graph.remove_intermediate_vertex(node_id).unwrap();
440 self.operator_instances.remove(node_id);
441 self.node_varnames.remove(node_id);
442
443 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
444 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
445 self.ports.insert(new_edge_id, (src_port, dst_port));
446 }
447
448 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
454 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
455 return Some(Color::Hoff);
456 }
457
458 if let GraphNode::Operator(op) = self.node(node_id)
460 && (op.name_string() == "resolve_futures_blocking"
461 || op.name_string() == "resolve_futures_blocking_ordered")
462 {
463 return Some(Color::Push);
464 }
465
466 let inn_degree = self.node_predecessor_nodes(node_id).count();
468 let out_degree = self.node_successor_nodes(node_id).count();
470
471 match (inn_degree, out_degree) {
472 (0, 0) => None, (0, 1) => Some(Color::Pull),
474 (1, 0) => Some(Color::Push),
475 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
477 (0 | 1, _many) => Some(Color::Push),
478 (_many, _to_many) => Some(Color::Comp),
479 }
480 }
481
482 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
484 self.operator_tag.insert(node_id, tag);
485 }
486}
487
488impl DfirGraph {
490 pub fn set_node_singleton_references(
493 &mut self,
494 node_id: GraphNodeId,
495 singletons_referenced: Vec<Option<GraphNodeId>>,
496 ) -> Option<Vec<Option<GraphNodeId>>> {
497 self.node_singleton_references
498 .insert(node_id, singletons_referenced)
499 }
500
501 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
504 self.node_singleton_references
505 .get(node_id)
506 .map(std::ops::Deref::deref)
507 .unwrap_or_default()
508 }
509}
510
511impl DfirGraph {
513 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
521 let mod_bound_nodes = self
522 .nodes()
523 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
524 .map(|(nid, _node)| nid)
525 .collect::<Vec<_>>();
526
527 for mod_bound_node in mod_bound_nodes {
528 self.remove_module_boundary(mod_bound_node)?;
529 }
530
531 Ok(())
532 }
533
534 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
538 assert!(
539 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
540 "Should not remove intermediate node after subgraph partitioning"
541 );
542
543 let mut mod_pred_ports = BTreeMap::new();
544 let mut mod_succ_ports = BTreeMap::new();
545
546 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
547 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
548 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
549 }
550
551 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
552 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
553 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
554 }
555
556 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
557 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
558 {
559 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
561 panic!();
562 };
563
564 if *input {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports into the module did not match. input: {:?}, expected: {:?}",
570 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
572 ),
573 });
574 } else {
575 return Err(Diagnostic {
576 span: *import_expr,
577 level: Level::Error,
578 message: format!(
579 "The ports out of the module did not match. output: {:?}, expected: {:?}",
580 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
581 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
582 ),
583 });
584 }
585 }
586
587 for (port, (pred_edge, pred_port)) in mod_pred_ports {
588 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
589
590 let (src, _) = self.edge(pred_edge);
591 let (_, dst) = self.edge(succ_edge);
592 self.remove_edge(pred_edge);
593 self.remove_edge(succ_edge);
594
595 let new_edge_id = self.graph.insert_edge(src, dst);
596 self.ports.insert(new_edge_id, (pred_port, succ_port));
597 }
598
599 self.graph.remove_vertex(mod_bound_node);
600 self.nodes.remove(mod_bound_node);
601
602 Ok(())
603 }
604}
605
606impl DfirGraph {
608 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
610 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
611 (src, dst)
612 }
613
614 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
616 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
617 (src_port, dst_port)
618 }
619
620 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
622 self.graph.edge_ids()
623 }
624
625 pub fn edges(
627 &self,
628 ) -> impl '_
629 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
630 + FusedIterator
631 + Clone
632 + Debug {
633 self.graph.edges()
634 }
635
636 pub fn insert_edge(
638 &mut self,
639 src: GraphNodeId,
640 src_port: PortIndexValue,
641 dst: GraphNodeId,
642 dst_port: PortIndexValue,
643 ) -> GraphEdgeId {
644 let edge_id = self.graph.insert_edge(src, dst);
645 self.ports.insert(edge_id, (src_port, dst_port));
646 edge_id
647 }
648
649 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
651 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
652 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
653 }
654}
655
656impl DfirGraph {
658 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
660 self.subgraph_nodes
661 .get(subgraph_id)
662 .expect("Subgraph not found.")
663 }
664
665 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
667 self.subgraph_nodes.keys()
668 }
669
670 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
672 self.subgraph_nodes.iter()
673 }
674
675 pub fn insert_subgraph(
677 &mut self,
678 node_ids: Vec<GraphNodeId>,
679 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
680 for &node_id in node_ids.iter() {
682 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
683 return Err((node_id, old_sg_id));
684 }
685 }
686 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
687 for &node_id in node_ids.iter() {
688 self.node_subgraph.insert(node_id, sg_id);
689 }
690 node_ids
691 });
692
693 Ok(subgraph_id)
694 }
695
696 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
698 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
699 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
700 true
701 } else {
702 false
703 }
704 }
705
706 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
708 self.subgraph_stratum.get(sg_id).copied()
709 }
710
711 pub fn set_subgraph_stratum(
713 &mut self,
714 sg_id: GraphSubgraphId,
715 stratum: usize,
716 ) -> Option<usize> {
717 self.subgraph_stratum.insert(sg_id, stratum)
718 }
719
720 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
722 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
723 }
724
725 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
727 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
728 }
729
730 pub fn max_stratum(&self) -> Option<usize> {
732 self.subgraph_stratum.values().copied().max()
733 }
734
735 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
737 subgraph_nodes
738 .iter()
739 .position(|&node_id| {
740 self.node_color(node_id)
741 .is_some_and(|color| Color::Pull != color)
742 })
743 .unwrap_or(subgraph_nodes.len())
744 }
745}
746
747impl DfirGraph {
749 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
751 let name = match &self.nodes[node_id] {
752 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
753 GraphNode::Handoff { .. } => format!(
754 "hoff_{:?}_{}",
755 node_id.data(),
756 if is_pred { "recv" } else { "send" }
757 ),
758 GraphNode::ModuleBoundary { .. } => panic!(),
759 };
760 let span = match (is_pred, &self.nodes[node_id]) {
761 (_, GraphNode::Operator(operator)) => operator.span(),
762 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
763 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
764 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
765 };
766 Ident::new(&name, span)
767 }
768
769 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
771 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
772 }
773
774 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
776 self.node_singleton_references(node_id)
777 .iter()
778 .map(|singleton_node_id| {
779 self.node_as_singleton_ident(
781 singleton_node_id
782 .expect("Expected singleton to be resolved but was not, this is a bug."),
783 span,
784 )
785 })
786 .collect::<Vec<_>>()
787 }
788
789 fn helper_collect_subgraph_handoffs(
792 &self,
793 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
794 let mut subgraph_handoffs: SecondaryMap<
796 GraphSubgraphId,
797 (Vec<GraphNodeId>, Vec<GraphNodeId>),
798 > = self
799 .subgraph_nodes
800 .keys()
801 .map(|k| (k, Default::default()))
802 .collect();
803
804 for (hoff_id, node) in self.nodes() {
806 if !matches!(node, GraphNode::Handoff { .. }) {
807 continue;
808 }
809 for (_edge, succ_id) in self.node_successors(hoff_id) {
811 let succ_sg = self.node_subgraph(succ_id).unwrap();
812 subgraph_handoffs[succ_sg].0.push(hoff_id);
813 }
814 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
816 let pred_sg = self.node_subgraph(pred_id).unwrap();
817 subgraph_handoffs[pred_sg].1.push(hoff_id);
818 }
819 }
820
821 subgraph_handoffs
822 }
823
824 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
826 let mut out = TokenStream::new();
828 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
829 while let Some(loop_id) = queue.pop_front() {
830 let parent_opt = self
831 .loop_parent(loop_id)
832 .map(|loop_id| loop_id.as_ident(Span::call_site()))
833 .map(|ident| quote! { Some(#ident) })
834 .unwrap_or_else(|| quote! { None });
835 let loop_name = loop_id.as_ident(Span::call_site());
836 out.append_all(quote! {
837 let #loop_name = #df.add_loop(#parent_opt);
838 });
839 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
840 }
841 out
842 }
843
844 pub fn as_code(
848 &self,
849 root: &TokenStream,
850 include_type_guards: bool,
851 prefix: TokenStream,
852 diagnostics: &mut Diagnostics,
853 ) -> Result<TokenStream, Diagnostics> {
854 let df = Ident::new(GRAPH, Span::call_site());
855 let context = Ident::new(CONTEXT, Span::call_site());
856
857 let handoff_code = self
859 .nodes
860 .iter()
861 .filter_map(|(node_id, node)| match node {
862 GraphNode::Operator(_) => None,
863 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
864 GraphNode::ModuleBoundary { .. } => panic!(),
865 })
866 .map(|(node_id, (src_span, dst_span))| {
867 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
868 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
869 let span = src_span.join(dst_span).unwrap_or(src_span);
870 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
871 hoff_name.set_span(span);
872 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
873 quote_spanned! {span=>
874 let (#ident_send, #ident_recv) =
875 #df.make_edge::<_, #hoff_type>(#hoff_name);
876 }
877 });
878
879 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
880
881 let (subgraphs_without_preds, subgraphs_with_preds) = self
883 .subgraph_nodes
884 .iter()
885 .partition::<Vec<_>, _>(|(_, nodes)| {
886 nodes
887 .iter()
888 .any(|&node_id| self.node_degree_in(node_id) == 0)
889 });
890
891 let mut op_prologue_code = Vec::new();
892 let mut op_prologue_after_code = Vec::new();
893 let mut subgraphs = Vec::new();
894 {
895 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
896 .iter()
897 .chain(subgraphs_with_preds.iter())
898 {
899 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
900 let recv_ports: Vec<Ident> = recv_hoffs
901 .iter()
902 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
903 .collect();
904 let send_ports: Vec<Ident> = send_hoffs
905 .iter()
906 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
907 .collect();
908
909 let recv_port_code = recv_ports.iter().map(|ident| {
910 quote_spanned! {ident.span()=>
911 let mut #ident = #ident.borrow_mut_swap();
912 let #ident = #root::dfir_pipes::pull::iter(#ident.drain(..));
913 }
914 });
915 let send_port_code = send_ports.iter().map(|ident| {
916 quote_spanned! {ident.span()=>
917 let #ident = #root::dfir_pipes::push::for_each(|v| {
918 #ident.give(Some(v));
919 });
920 }
921 });
922
923 let loop_id = self
924 .node_loop(subgraph_nodes[0]);
926
927 let mut subgraph_op_iter_code = Vec::new();
928 let mut subgraph_op_iter_after_code = Vec::new();
929 {
930 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
931
932 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
933 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
934
935 for (idx, &node_id) in nodes_iter.enumerate() {
936 let node = &self.nodes[node_id];
937 assert!(
938 matches!(node, GraphNode::Operator(_)),
939 "Handoffs are not part of subgraphs."
940 );
941 let op_inst = &self.operator_instances[node_id];
942
943 let op_span = node.span();
944 let op_name = op_inst.op_constraints.name;
945 let root = change_spans(root.clone(), op_span);
947 let op_constraints = OPERATORS
949 .iter()
950 .find(|op| op_name == op.name)
951 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
952
953 let ident = self.node_as_ident(node_id, false);
954
955 {
956 let mut input_edges = self
959 .graph
960 .predecessor_edges(node_id)
961 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
962 .collect::<Vec<_>>();
963 input_edges.sort();
965
966 let inputs = input_edges
967 .iter()
968 .map(|&(_port, edge_id)| {
969 let (pred, _) = self.edge(edge_id);
970 self.node_as_ident(pred, true)
971 })
972 .collect::<Vec<_>>();
973
974 let mut output_edges = self
976 .graph
977 .successor_edges(node_id)
978 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
979 .collect::<Vec<_>>();
980 output_edges.sort();
982
983 let outputs = output_edges
984 .iter()
985 .map(|&(_port, edge_id)| {
986 let (_, succ) = self.edge(edge_id);
987 self.node_as_ident(succ, false)
988 })
989 .collect::<Vec<_>>();
990
991 let is_pull = idx < pull_to_push_idx;
992
993 let singleton_output_ident = &if op_constraints.has_singleton_output {
994 self.node_as_singleton_ident(node_id, op_span)
995 } else {
996 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
998 };
999
1000 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1009 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1010
1011 let singletons_resolved =
1012 self.helper_resolve_singletons(node_id, op_span);
1013 let arguments = &process_singletons::postprocess_singletons(
1014 op_inst.arguments_raw.clone(),
1015 singletons_resolved.clone(),
1016 context,
1017 );
1018 let arguments_handles =
1019 &process_singletons::postprocess_singletons_handles(
1020 op_inst.arguments_raw.clone(),
1021 singletons_resolved.clone(),
1022 );
1023
1024 let source_tag = 'a: {
1025 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1026 break 'a tag;
1027 }
1028
1029 #[cfg(nightly)]
1030 if proc_macro::is_available() {
1031 let op_span = op_span.unwrap();
1032 break 'a format!(
1033 "loc_{}_{}_{}_{}_{}",
1034 crate::pretty_span::make_source_path_relative(
1035 &op_span.file()
1036 )
1037 .display()
1038 .to_string()
1039 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1040 op_span.start().line(),
1041 op_span.start().column(),
1042 op_span.end().line(),
1043 op_span.end().column(),
1044 );
1045 }
1046
1047 format!(
1048 "loc_nopath_{}_{}_{}_{}",
1049 op_span.start().line,
1050 op_span.start().column,
1051 op_span.end().line,
1052 op_span.end().column
1053 )
1054 };
1055
1056 let work_fn = format_ident!(
1057 "{}__{}__{}",
1058 ident,
1059 op_name,
1060 source_tag,
1061 span = op_span
1062 );
1063 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1064
1065 let context_args = WriteContextArgs {
1066 root: &root,
1067 df_ident: df_local,
1068 context,
1069 subgraph_id,
1070 node_id,
1071 loop_id,
1072 op_span,
1073 op_tag: self.operator_tag.get(node_id).cloned(),
1074 work_fn: &work_fn,
1075 work_fn_async: &work_fn_async,
1076 ident: &ident,
1077 is_pull,
1078 inputs: &inputs,
1079 outputs: &outputs,
1080 singleton_output_ident,
1081 op_name,
1082 op_inst,
1083 arguments,
1084 arguments_handles,
1085 };
1086
1087 let write_result =
1088 (op_constraints.write_fn)(&context_args, diagnostics);
1089 let OperatorWriteOutput {
1090 write_prologue,
1091 write_prologue_after,
1092 write_iterator,
1093 write_iterator_after,
1094 } = write_result.unwrap_or_else(|()| {
1095 assert!(
1096 diagnostics.has_error(),
1097 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1098 op_name,
1099 );
1100 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1101 });
1102
1103 op_prologue_code.push(syn::parse_quote! {
1104 #[allow(non_snake_case)]
1105 #[inline(always)]
1106 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1107 thunk()
1108 }
1109
1110 #[allow(non_snake_case)]
1111 #[inline(always)]
1112 async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1113 thunk.await
1114 }
1115 });
1116 op_prologue_code.push(write_prologue);
1117 op_prologue_after_code.push(write_prologue_after);
1118 subgraph_op_iter_code.push(write_iterator);
1119
1120 if include_type_guards {
1121 let type_guard = if is_pull {
1122 quote_spanned! {op_span=>
1123 let #ident = {
1124 #[allow(non_snake_case)]
1125 #[inline(always)]
1126 pub fn #work_fn<Item, Input>(input: Input)
1127 -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1128 where
1129 Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1130 {
1131 #root::pin_project_lite::pin_project! {
1132 #[repr(transparent)]
1133 struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1134 #[pin]
1135 inner: Input
1136 }
1137 }
1138
1139 impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1140 where
1141 Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1142 {
1143 type Ctx<'ctx> = Input::Ctx<'ctx>;
1144
1145 type Item = Item;
1146 type Meta = Input::Meta;
1147 type CanPend = Input::CanPend;
1148 type CanEnd = Input::CanEnd;
1149
1150 #[inline(always)]
1151 fn pull(
1152 self: ::std::pin::Pin<&mut Self>,
1153 ctx: &mut Self::Ctx<'_>,
1154 ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1155 #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1156 }
1157
1158 #[inline(always)]
1159 fn size_hint(&self) -> (usize, Option<usize>) {
1160 #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1161 }
1162 }
1163
1164 Pull {
1165 inner: input
1166 }
1167 }
1168 #work_fn::<_, _>( #ident )
1169 };
1170 }
1171 } else {
1172 quote_spanned! {op_span=>
1173 let #ident = {
1174 #[allow(non_snake_case)]
1175 #[inline(always)]
1176 pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1177 where
1178 Psh: #root::dfir_pipes::push::Push<Item, ()>
1179 {
1180 #root::pin_project_lite::pin_project! {
1181 #[repr(transparent)]
1182 struct PushGuard<Psh> {
1183 #[pin]
1184 inner: Psh,
1185 }
1186 }
1187
1188 impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1189 where
1190 Psh: #root::dfir_pipes::push::Push<Item, ()>,
1191 {
1192 type Ctx<'ctx> = Psh::Ctx<'ctx>;
1193
1194 type CanPend = Psh::CanPend;
1195
1196 #[inline(always)]
1197 fn poll_ready(
1198 self: ::std::pin::Pin<&mut Self>,
1199 ctx: &mut Self::Ctx<'_>,
1200 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1201 #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1202 }
1203
1204 #[inline(always)]
1205 fn start_send(
1206 self: ::std::pin::Pin<&mut Self>,
1207 item: Item,
1208 meta: (),
1209 ) {
1210 #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1211 }
1212
1213 #[inline(always)]
1214 fn poll_flush(
1215 self: ::std::pin::Pin<&mut Self>,
1216 ctx: &mut Self::Ctx<'_>,
1217 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1218 #root::dfir_pipes::push::Push::poll_flush(self.project().inner, ctx)
1219 }
1220 }
1221
1222 PushGuard {
1223 inner: psh
1224 }
1225 }
1226 #work_fn( #ident )
1227 };
1228 }
1229 };
1230 subgraph_op_iter_code.push(type_guard);
1231 }
1232 subgraph_op_iter_after_code.push(write_iterator_after);
1233 }
1234 }
1235
1236 {
1237 let pull_ident = if 0 < pull_to_push_idx {
1239 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1240 } else {
1241 recv_ports[0].clone()
1243 };
1244
1245 #[rustfmt::skip]
1246 let push_ident = if let Some(&node_id) =
1247 subgraph_nodes.get(pull_to_push_idx)
1248 {
1249 self.node_as_ident(node_id, false)
1250 } else if 1 == send_ports.len() {
1251 send_ports[0].clone()
1253 } else {
1254 diagnostics.push(Diagnostic::spanned(
1255 pull_ident.span(),
1256 Level::Error,
1257 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1258 ));
1259 continue;
1260 };
1261
1262 let pivot_span = pull_ident
1264 .span()
1265 .join(push_ident.span())
1266 .unwrap_or_else(|| push_ident.span());
1267 let pivot_fn_ident =
1268 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1269 let root = change_spans(root.clone(), pivot_span);
1270 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1271 #[inline(always)]
1272 fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1273 -> impl ::std::future::Future<Output = ()>
1274 where
1275 Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1276 Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1277 {
1278 #root::dfir_pipes::pull::Pull::send_push(pull, push)
1279 }
1280 (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1281 });
1282 }
1283 };
1284
1285 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1286 let stratum = Literal::usize_unsuffixed(
1287 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1288 );
1289 let laziness = self.subgraph_laziness(subgraph_id);
1290
1291 let loop_id_opt = loop_id
1293 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1294 .map(|ident| quote! { Some(#ident) })
1295 .unwrap_or_else(|| quote! { None });
1296
1297 let sg_ident = subgraph_id.as_ident(Span::call_site());
1298
1299 subgraphs.push(quote! {
1300 let #sg_ident = #df.add_subgraph_full(
1301 #subgraph_name,
1302 #stratum,
1303 var_expr!( #( #recv_ports ),* ),
1304 var_expr!( #( #send_ports ),* ),
1305 #laziness,
1306 #loop_id_opt,
1307 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1308 #( #recv_port_code )*
1309 #( #send_port_code )*
1310 #( #subgraph_op_iter_code )*
1311 #( #subgraph_op_iter_after_code )*
1312 },
1313 );
1314 });
1315 }
1316 }
1317
1318 if diagnostics.has_error() {
1319 return Err(std::mem::take(diagnostics));
1320 }
1321 let _ = diagnostics; let loop_code = self.codegen_nested_loops(&df);
1324
1325 let code = quote! {
1330 #( #handoff_code )*
1331 #loop_code
1332 #( #op_prologue_code )*
1333 #( #subgraphs )*
1334 #( #op_prologue_after_code )*
1335 };
1336
1337 let meta_graph_json = serde_json::to_string(&self).unwrap();
1338 let meta_graph_json = Literal::string(&meta_graph_json);
1339
1340 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1341 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1342 let diagnostics_json = Literal::string(&diagnostics_json);
1343
1344 Ok(quote! {
1345 {
1346 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1347 {
1348 #prefix
1349
1350 use #root::{var_expr, var_args};
1351
1352 let mut #df = #root::scheduled::graph::Dfir::new();
1353 #df.__assign_meta_graph(#meta_graph_json);
1354 #df.__assign_diagnostics(#diagnostics_json);
1355
1356 #code
1357
1358 #df
1359 }
1360 }
1361 })
1362 }
1363
1364 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1367 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1368 .node_ids()
1369 .filter_map(|node_id| {
1370 let op_color = self.node_color(node_id)?;
1371 Some((node_id, op_color))
1372 })
1373 .collect();
1374
1375 for sg_nodes in self.subgraph_nodes.values() {
1377 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1378
1379 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1380 let is_pull = idx < pull_to_push_idx;
1381 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1382 }
1383 }
1384
1385 node_color_map
1386 }
1387
1388 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1390 let mut output = String::new();
1391 self.write_mermaid(&mut output, write_config).unwrap();
1392 output
1393 }
1394
1395 pub fn write_mermaid(
1397 &self,
1398 output: impl std::fmt::Write,
1399 write_config: &WriteConfig,
1400 ) -> std::fmt::Result {
1401 let mut graph_write = Mermaid::new(output);
1402 self.write_graph(&mut graph_write, write_config)
1403 }
1404
1405 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1407 let mut output = String::new();
1408 let mut graph_write = Dot::new(&mut output);
1409 self.write_graph(&mut graph_write, write_config).unwrap();
1410 output
1411 }
1412
1413 pub fn write_dot(
1415 &self,
1416 output: impl std::fmt::Write,
1417 write_config: &WriteConfig,
1418 ) -> std::fmt::Result {
1419 let mut graph_write = Dot::new(output);
1420 self.write_graph(&mut graph_write, write_config)
1421 }
1422
1423 pub(crate) fn write_graph<W>(
1425 &self,
1426 mut graph_write: W,
1427 write_config: &WriteConfig,
1428 ) -> Result<(), W::Err>
1429 where
1430 W: GraphWrite,
1431 {
1432 fn helper_edge_label(
1433 src_port: &PortIndexValue,
1434 dst_port: &PortIndexValue,
1435 ) -> Option<String> {
1436 let src_label = match src_port {
1437 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1438 PortIndexValue::Int(index) => Some(index.value.to_string()),
1439 _ => None,
1440 };
1441 let dst_label = match dst_port {
1442 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1443 PortIndexValue::Int(index) => Some(index.value.to_string()),
1444 _ => None,
1445 };
1446 let label = match (src_label, dst_label) {
1447 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1448 (Some(l1), None) => Some(l1),
1449 (None, Some(l2)) => Some(l2),
1450 (None, None) => None,
1451 };
1452 label
1453 }
1454
1455 let node_color_map = self.node_color_map();
1457
1458 graph_write.write_prologue()?;
1460
1461 let mut skipped_handoffs = BTreeSet::new();
1463 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1464 for (node_id, node) in self.nodes() {
1465 if matches!(node, GraphNode::Handoff { .. }) {
1466 if write_config.no_handoffs {
1467 skipped_handoffs.insert(node_id);
1468 continue;
1469 } else {
1470 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1471 let pred_sg = self.node_subgraph(pred_node);
1472 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1473 let succ_sg = self.node_subgraph(succ_node);
1474 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1475 && pred_sg == succ_sg
1476 {
1477 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1478 }
1479 }
1480 }
1481 graph_write.write_node_definition(
1482 node_id,
1483 &if write_config.op_short_text {
1484 node.to_name_string()
1485 } else if write_config.op_text_no_imports {
1486 let full_text = node.to_pretty_string();
1488 let mut output = String::new();
1489 for sentence in full_text.split('\n') {
1490 if sentence.trim().starts_with("use") {
1491 continue;
1492 }
1493 output.push('\n');
1494 output.push_str(sentence);
1495 }
1496 output.into()
1497 } else {
1498 node.to_pretty_string()
1499 },
1500 if write_config.no_pull_push {
1501 None
1502 } else {
1503 node_color_map.get(node_id).copied()
1504 },
1505 )?;
1506 }
1507
1508 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1510 if skipped_handoffs.contains(&src_id) {
1512 continue;
1513 }
1514
1515 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1516 if skipped_handoffs.contains(&dst_id) {
1517 let mut handoff_succs = self.node_successors(dst_id);
1518 assert_eq!(1, handoff_succs.len());
1519 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1520 dst_id = succ_node;
1521 dst_port = self.edge_ports(succ_edge).1;
1522 }
1523
1524 let label = helper_edge_label(src_port, dst_port);
1525 let delay_type = self
1526 .node_op_inst(dst_id)
1527 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1528 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1529 }
1530
1531 if !write_config.no_references {
1533 for dst_id in self.node_ids() {
1534 for src_ref_id in self
1535 .node_singleton_references(dst_id)
1536 .iter()
1537 .copied()
1538 .flatten()
1539 {
1540 let delay_type = Some(DelayType::Stratum);
1541 let label = None;
1542 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1543 }
1544 }
1545 }
1546
1547 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1558 let loop_id = if write_config.no_loops {
1559 None
1560 } else {
1561 self.subgraph_loop(sg_id)
1562 };
1563 (loop_id, sg_id)
1564 });
1565 let loop_subgraphs = into_group_map(loop_subgraphs);
1566 for (loop_id, subgraph_ids) in loop_subgraphs {
1567 if let Some(loop_id) = loop_id {
1568 graph_write.write_loop_start(loop_id)?;
1569 }
1570
1571 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1573 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1574 let opt_sg_id = if write_config.no_subgraphs {
1575 None
1576 } else {
1577 Some(sg_id)
1578 };
1579 (opt_sg_id, (self.node_varname(node_id), node_id))
1580 })
1581 });
1582 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1583 for (sg_id, varnames) in subgraph_varnames_nodes {
1584 if let Some(sg_id) = sg_id {
1585 let stratum = self.subgraph_stratum(sg_id).unwrap();
1586 graph_write.write_subgraph_start(sg_id, stratum)?;
1587 }
1588
1589 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1591 let varname = if write_config.no_varnames {
1592 None
1593 } else {
1594 varname
1595 };
1596 (varname, node)
1597 });
1598 let varname_nodes = into_group_map(varname_nodes);
1599 for (varname, node_ids) in varname_nodes {
1600 if let Some(varname) = varname {
1601 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1602 }
1603
1604 for node_id in node_ids {
1606 graph_write.write_node(node_id)?;
1607 }
1608
1609 if varname.is_some() {
1610 graph_write.write_varname_end()?;
1611 }
1612 }
1613
1614 if sg_id.is_some() {
1615 graph_write.write_subgraph_end()?;
1616 }
1617 }
1618
1619 if loop_id.is_some() {
1620 graph_write.write_loop_end()?;
1621 }
1622 }
1623
1624 graph_write.write_epilogue()?;
1626
1627 Ok(())
1628 }
1629
1630 pub fn surface_syntax_string(&self) -> String {
1632 let mut string = String::new();
1633 self.write_surface_syntax(&mut string).unwrap();
1634 string
1635 }
1636
1637 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1639 for (key, node) in self.nodes.iter() {
1640 match node {
1641 GraphNode::Operator(op) => {
1642 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1643 }
1644 GraphNode::Handoff { .. } => {
1645 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1646 }
1647 GraphNode::ModuleBoundary { .. } => panic!(),
1648 }
1649 }
1650 writeln!(write)?;
1651 for (_e, (src_key, dst_key)) in self.graph.edges() {
1652 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1653 }
1654 Ok(())
1655 }
1656
1657 pub fn mermaid_string_flat(&self) -> String {
1659 let mut string = String::new();
1660 self.write_mermaid_flat(&mut string).unwrap();
1661 string
1662 }
1663
1664 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1666 writeln!(write, "flowchart TB")?;
1667 for (key, node) in self.nodes.iter() {
1668 match node {
1669 GraphNode::Operator(operator) => writeln!(
1670 write,
1671 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1672 span = PrettySpan(node.span()),
1673 id = key.data(),
1674 row_col = PrettyRowCol(node.span()),
1675 code = operator
1676 .to_token_stream()
1677 .to_string()
1678 .replace('&', "&")
1679 .replace('<', "<")
1680 .replace('>', ">")
1681 .replace('"', """)
1682 .replace('\n', "<br>"),
1683 ),
1684 GraphNode::Handoff { .. } => {
1685 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1686 }
1687 GraphNode::ModuleBoundary { .. } => {
1688 writeln!(
1689 write,
1690 r#" {:?}{{"{}"}}"#,
1691 key.data(),
1692 MODULE_BOUNDARY_NODE_STR
1693 )
1694 }
1695 }?;
1696 }
1697 writeln!(write)?;
1698 for (_e, (src_key, dst_key)) in self.graph.edges() {
1699 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1700 }
1701 Ok(())
1702 }
1703}
1704
1705impl DfirGraph {
1707 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1709 self.loop_nodes.keys()
1710 }
1711
1712 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1714 self.loop_nodes.iter()
1715 }
1716
1717 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1719 let loop_id = self.loop_nodes.insert(Vec::new());
1720 self.loop_children.insert(loop_id, Vec::new());
1721 if let Some(parent_loop) = parent_loop {
1722 self.loop_parent.insert(loop_id, parent_loop);
1723 self.loop_children
1724 .get_mut(parent_loop)
1725 .unwrap()
1726 .push(loop_id);
1727 } else {
1728 self.root_loops.push(loop_id);
1729 }
1730 loop_id
1731 }
1732
1733 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1735 self.node_loops.get(node_id).copied()
1736 }
1737
1738 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1740 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1741 let out = self.node_loop(node_id);
1742 debug_assert!(
1743 self.subgraph(subgraph_id)
1744 .iter()
1745 .all(|&node_id| self.node_loop(node_id) == out),
1746 "Subgraph nodes should all have the same loop context."
1747 );
1748 out
1749 }
1750
1751 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1753 self.loop_parent.get(loop_id).copied()
1754 }
1755
1756 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1758 self.loop_children.get(loop_id).unwrap()
1759 }
1760}
1761
/// Options controlling which parts of a [`DfirGraph`] are rendered by the graph writers
/// (Mermaid / DOT). All options default to `false`, i.e. everything is rendered.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group nodes into subgraph clusters.
    pub no_subgraphs: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group nodes by their variable name.
    pub no_varnames: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not pass a node color (pull/push or operator color) to the writer.
    pub no_pull_push: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Omit handoff nodes; edges into a handoff are drawn directly to its successor.
    pub no_handoffs: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Omit singleton reference edges.
    pub no_references: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Do not group subgraphs by their enclosing loop.
    pub no_loops: bool,

    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// Label nodes with `to_name_string()` instead of their full `to_pretty_string()` text.
    /// Takes precedence over `op_text_no_imports`.
    pub op_short_text: bool,
    #[cfg_attr(feature = "clap-derive", arg(long))]
    /// When labeling nodes with their full text, drop lines that are `use` declarations.
    pub op_text_no_imports: bool,
}
1792
/// Output format for graph visualization.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid flowchart syntax.
    Mermaid,
    /// Graphviz DOT syntax.
    Dot,
}
1802
/// Groups `(key, value)` pairs by key into a sorted map.
///
/// Within each group, values keep the order in which they appeared in `iter`.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_insert_with(Vec::new).push(value);
            groups
        })
}