1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet, VecDeque};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, TokenStreamExt, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18 DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19 null_write_iterator_fn,
20};
21use super::{
22 CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23 GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24 Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// A DFIR dataflow graph.
///
/// Stores the nodes (operators, handoffs, and module boundaries) and the edges between them,
/// together with side tables of per-node, per-edge, per-loop, and per-subgraph metadata.
/// Some tables (e.g. the subgraph assignments) are only filled in by later passes.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Every node in the graph.
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// The [`OperatorInstance`] of each operator node that has one. Not serialized.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Optional per-operator tag; when present, code generation uses it instead of a tag
    /// derived from the operator's source location.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph topology: each edge ID maps to its `(src, dst)` node pair.
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// The `(src_port, dst_port)` pair of each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// The loop each node was registered in, if any.
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// The nodes registered in each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// The parent of each nested loop (presumably absent for root loops).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// Loops without a parent; loop code generation starts from these.
    root_loops: Vec<GraphLoopId>,
    /// The child loops of each loop.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// The subgraph each node is assigned to (filled in by subgraph partitioning).
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// The ordered nodes of each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,
    /// The stratum number of each subgraph.
    subgraph_stratum: SecondaryMap<GraphSubgraphId, usize>,

    /// Singleton references of each node; a `None` entry has not been resolved to a node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// The variable name associated with each node, if any.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Whether each subgraph is lazy; missing entries are treated as not lazy.
    subgraph_laziness: SecondaryMap<GraphSubgraphId, bool>,
}
84
85impl DfirGraph {
87 pub fn new() -> Self {
89 Default::default()
90 }
91}
92
93impl DfirGraph {
95 pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
97 self.nodes.get(node_id).expect("Node not found.")
98 }
99
100 pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
105 self.operator_instances.get(node_id)
106 }
107
108 pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
110 self.node_varnames.get(node_id)
111 }
112
113 pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
115 self.node_subgraph.get(node_id).copied()
116 }
117
118 pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
120 self.graph.degree_in(node_id)
121 }
122
123 pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
125 self.graph.degree_out(node_id)
126 }
127
128 pub fn node_successors(
130 &self,
131 src: GraphNodeId,
132 ) -> impl '_
133 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
134 + ExactSizeIterator
135 + FusedIterator
136 + Clone
137 + Debug {
138 self.graph.successors(src)
139 }
140
141 pub fn node_predecessors(
143 &self,
144 dst: GraphNodeId,
145 ) -> impl '_
146 + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
147 + ExactSizeIterator
148 + FusedIterator
149 + Clone
150 + Debug {
151 self.graph.predecessors(dst)
152 }
153
154 pub fn node_successor_edges(
156 &self,
157 src: GraphNodeId,
158 ) -> impl '_
159 + DoubleEndedIterator<Item = GraphEdgeId>
160 + ExactSizeIterator
161 + FusedIterator
162 + Clone
163 + Debug {
164 self.graph.successor_edges(src)
165 }
166
167 pub fn node_predecessor_edges(
169 &self,
170 dst: GraphNodeId,
171 ) -> impl '_
172 + DoubleEndedIterator<Item = GraphEdgeId>
173 + ExactSizeIterator
174 + FusedIterator
175 + Clone
176 + Debug {
177 self.graph.predecessor_edges(dst)
178 }
179
180 pub fn node_successor_nodes(
182 &self,
183 src: GraphNodeId,
184 ) -> impl '_
185 + DoubleEndedIterator<Item = GraphNodeId>
186 + ExactSizeIterator
187 + FusedIterator
188 + Clone
189 + Debug {
190 self.graph.successor_vertices(src)
191 }
192
193 pub fn node_predecessor_nodes(
195 &self,
196 dst: GraphNodeId,
197 ) -> impl '_
198 + DoubleEndedIterator<Item = GraphNodeId>
199 + ExactSizeIterator
200 + FusedIterator
201 + Clone
202 + Debug {
203 self.graph.predecessor_vertices(dst)
204 }
205
206 pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
208 self.nodes.keys()
209 }
210
211 pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
213 self.nodes.iter()
214 }
215
216 pub fn insert_node(
218 &mut self,
219 node: GraphNode,
220 varname_opt: Option<Ident>,
221 loop_opt: Option<GraphLoopId>,
222 ) -> GraphNodeId {
223 let node_id = self.nodes.insert(node);
224 if let Some(varname) = varname_opt {
225 self.node_varnames.insert(node_id, Varname(varname));
226 }
227 if let Some(loop_id) = loop_opt {
228 self.node_loops.insert(node_id, loop_id);
229 self.loop_nodes[loop_id].push(node_id);
230 }
231 node_id
232 }
233
234 pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
236 assert!(matches!(
237 self.nodes.get(node_id),
238 Some(GraphNode::Operator(_))
239 ));
240 let old_inst = self.operator_instances.insert(node_id, op_inst);
241 assert!(old_inst.is_none());
242 }
243
244 pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
246 let mut op_insts = Vec::new();
247 for (node_id, node) in self.nodes() {
248 let GraphNode::Operator(operator) = node else {
249 continue;
250 };
251 if self.node_op_inst(node_id).is_some() {
252 continue;
253 };
254
255 let Some(op_constraints) = find_op_op_constraints(operator) else {
257 diagnostics.push(Diagnostic::spanned(
258 operator.path.span(),
259 Level::Error,
260 format!("Unknown operator `{}`", operator.name_string()),
261 ));
262 continue;
263 };
264
265 let (input_ports, output_ports) = {
267 let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
268 .node_predecessors(node_id)
269 .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
270 .collect();
271 input_edges.sort();
273 let input_ports: Vec<PortIndexValue> = input_edges
274 .into_iter()
275 .map(|(port, _pred)| port)
276 .cloned()
277 .collect();
278
279 let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
281 .node_successors(node_id)
282 .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
283 .collect();
284 output_edges.sort();
286 let output_ports: Vec<PortIndexValue> = output_edges
287 .into_iter()
288 .map(|(port, _succ)| port)
289 .cloned()
290 .collect();
291
292 (input_ports, output_ports)
293 };
294
295 let generics = get_operator_generics(diagnostics, operator);
297 {
299 let generics_span = generics
301 .generic_args
302 .as_ref()
303 .map(Spanned::span)
304 .unwrap_or_else(|| operator.path.span());
305
306 if !op_constraints
307 .persistence_args
308 .contains(&generics.persistence_args.len())
309 {
310 diagnostics.push(Diagnostic::spanned(
311 generics_span,
312 Level::Error,
313 format!(
314 "`{}` should have {} persistence lifetime arguments, actually has {}.",
315 op_constraints.name,
316 op_constraints.persistence_args.human_string(),
317 generics.persistence_args.len()
318 ),
319 ));
320 }
321 if !op_constraints.type_args.contains(&generics.type_args.len()) {
322 diagnostics.push(Diagnostic::spanned(
323 generics_span,
324 Level::Error,
325 format!(
326 "`{}` should have {} generic type arguments, actually has {}.",
327 op_constraints.name,
328 op_constraints.type_args.human_string(),
329 generics.type_args.len()
330 ),
331 ));
332 }
333 }
334
335 op_insts.push((
336 node_id,
337 OperatorInstance {
338 op_constraints,
339 input_ports,
340 output_ports,
341 singletons_referenced: operator.singletons_referenced.clone(),
342 generics,
343 arguments_pre: operator.args.clone(),
344 arguments_raw: operator.args_raw.clone(),
345 },
346 ));
347 }
348
349 for (node_id, op_inst) in op_insts {
350 self.insert_node_op_inst(node_id, op_inst);
351 }
352 }
353
354 pub fn insert_intermediate_node(
366 &mut self,
367 edge_id: GraphEdgeId,
368 new_node: GraphNode,
369 ) -> (GraphNodeId, GraphEdgeId) {
370 let span = Some(new_node.span());
371
372 let op_inst_opt = 'oc: {
374 let GraphNode::Operator(operator) = &new_node else {
375 break 'oc None;
376 };
377 let Some(op_constraints) = find_op_op_constraints(operator) else {
378 break 'oc None;
379 };
380 let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();
381
382 let mut dummy_diagnostics = Diagnostics::new();
383 let generics = get_operator_generics(&mut dummy_diagnostics, operator);
384 assert!(dummy_diagnostics.is_empty());
385
386 Some(OperatorInstance {
387 op_constraints,
388 input_ports: vec![input_port],
389 output_ports: vec![output_port],
390 singletons_referenced: operator.singletons_referenced.clone(),
391 generics,
392 arguments_pre: operator.args.clone(),
393 arguments_raw: operator.args_raw.clone(),
394 })
395 };
396
397 let node_id = self.nodes.insert(new_node);
399 if let Some(op_inst) = op_inst_opt {
401 self.operator_instances.insert(node_id, op_inst);
402 }
403 let (e0, e1) = self
405 .graph
406 .insert_intermediate_vertex(node_id, edge_id)
407 .unwrap();
408
409 let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
411 self.ports
412 .insert(e0, (src_idx, PortIndexValue::Elided(span)));
413 self.ports
414 .insert(e1, (PortIndexValue::Elided(span), dst_idx));
415
416 (node_id, e1)
417 }
418
419 pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
422 assert_eq!(
423 1,
424 self.node_degree_in(node_id),
425 "Removed intermediate node must have one predecessor"
426 );
427 assert_eq!(
428 1,
429 self.node_degree_out(node_id),
430 "Removed intermediate node must have one successor"
431 );
432 assert!(
433 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
434 "Should not remove intermediate node after subgraph partitioning"
435 );
436
437 assert!(self.nodes.remove(node_id).is_some());
438 let (new_edge_id, (pred_edge_id, succ_edge_id)) =
439 self.graph.remove_intermediate_vertex(node_id).unwrap();
440 self.operator_instances.remove(node_id);
441 self.node_varnames.remove(node_id);
442
443 let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
444 let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
445 self.ports.insert(new_edge_id, (src_port, dst_port));
446 }
447
448 pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
454 if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
455 return Some(Color::Hoff);
456 }
457
458 if let GraphNode::Operator(op) = self.node(node_id)
460 && (op.name_string() == "resolve_futures_blocking"
461 || op.name_string() == "resolve_futures_blocking_ordered")
462 {
463 return Some(Color::Push);
464 }
465
466 let inn_degree = self.node_predecessor_nodes(node_id).count();
468 let out_degree = self.node_successor_nodes(node_id).count();
470
471 match (inn_degree, out_degree) {
472 (0, 0) => None, (0, 1) => Some(Color::Pull),
474 (1, 0) => Some(Color::Push),
475 (1, 1) => None, (_many, 0 | 1) => Some(Color::Pull),
477 (0 | 1, _many) => Some(Color::Push),
478 (_many, _to_many) => Some(Color::Comp),
479 }
480 }
481
482 pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
484 self.operator_tag.insert(node_id, tag);
485 }
486}
487
488impl DfirGraph {
490 pub fn set_node_singleton_references(
493 &mut self,
494 node_id: GraphNodeId,
495 singletons_referenced: Vec<Option<GraphNodeId>>,
496 ) -> Option<Vec<Option<GraphNodeId>>> {
497 self.node_singleton_references
498 .insert(node_id, singletons_referenced)
499 }
500
501 pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
504 self.node_singleton_references
505 .get(node_id)
506 .map(std::ops::Deref::deref)
507 .unwrap_or_default()
508 }
509}
510
511impl DfirGraph {
513 pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
521 let mod_bound_nodes = self
522 .nodes()
523 .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
524 .map(|(nid, _node)| nid)
525 .collect::<Vec<_>>();
526
527 for mod_bound_node in mod_bound_nodes {
528 self.remove_module_boundary(mod_bound_node)?;
529 }
530
531 Ok(())
532 }
533
534 fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
538 assert!(
539 self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
540 "Should not remove intermediate node after subgraph partitioning"
541 );
542
543 let mut mod_pred_ports = BTreeMap::new();
544 let mut mod_succ_ports = BTreeMap::new();
545
546 for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
547 let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
548 mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
549 }
550
551 for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
552 let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
553 mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
554 }
555
556 if mod_pred_ports.keys().collect::<BTreeSet<_>>()
557 != mod_succ_ports.keys().collect::<BTreeSet<_>>()
558 {
559 let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
561 panic!();
562 };
563
564 if *input {
565 return Err(Diagnostic {
566 span: *import_expr,
567 level: Level::Error,
568 message: format!(
569 "The ports into the module did not match. input: {:?}, expected: {:?}",
570 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
571 mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
572 ),
573 });
574 } else {
575 return Err(Diagnostic {
576 span: *import_expr,
577 level: Level::Error,
578 message: format!(
579 "The ports out of the module did not match. output: {:?}, expected: {:?}",
580 mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
581 mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
582 ),
583 });
584 }
585 }
586
587 for (port, (pred_edge, pred_port)) in mod_pred_ports {
588 let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();
589
590 let (src, _) = self.edge(pred_edge);
591 let (_, dst) = self.edge(succ_edge);
592 self.remove_edge(pred_edge);
593 self.remove_edge(succ_edge);
594
595 let new_edge_id = self.graph.insert_edge(src, dst);
596 self.ports.insert(new_edge_id, (pred_port, succ_port));
597 }
598
599 self.graph.remove_vertex(mod_bound_node);
600 self.nodes.remove(mod_bound_node);
601
602 Ok(())
603 }
604}
605
606impl DfirGraph {
608 pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
610 let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
611 (src, dst)
612 }
613
614 pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
616 let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
617 (src_port, dst_port)
618 }
619
620 pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
622 self.graph.edge_ids()
623 }
624
625 pub fn edges(
627 &self,
628 ) -> impl '_
629 + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
630 + FusedIterator
631 + Clone
632 + Debug {
633 self.graph.edges()
634 }
635
636 pub fn insert_edge(
638 &mut self,
639 src: GraphNodeId,
640 src_port: PortIndexValue,
641 dst: GraphNodeId,
642 dst_port: PortIndexValue,
643 ) -> GraphEdgeId {
644 let edge_id = self.graph.insert_edge(src, dst);
645 self.ports.insert(edge_id, (src_port, dst_port));
646 edge_id
647 }
648
649 pub fn remove_edge(&mut self, edge: GraphEdgeId) {
651 let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
652 let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
653 }
654}
655
656impl DfirGraph {
658 pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
660 self.subgraph_nodes
661 .get(subgraph_id)
662 .expect("Subgraph not found.")
663 }
664
665 pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
667 self.subgraph_nodes.keys()
668 }
669
670 pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
672 self.subgraph_nodes.iter()
673 }
674
675 pub fn insert_subgraph(
677 &mut self,
678 node_ids: Vec<GraphNodeId>,
679 ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
680 for &node_id in node_ids.iter() {
682 if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
683 return Err((node_id, old_sg_id));
684 }
685 }
686 let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
687 for &node_id in node_ids.iter() {
688 self.node_subgraph.insert(node_id, sg_id);
689 }
690 node_ids
691 });
692
693 Ok(subgraph_id)
694 }
695
696 pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
698 if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
699 self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
700 true
701 } else {
702 false
703 }
704 }
705
706 pub fn subgraph_stratum(&self, sg_id: GraphSubgraphId) -> Option<usize> {
708 self.subgraph_stratum.get(sg_id).copied()
709 }
710
711 pub fn set_subgraph_stratum(
713 &mut self,
714 sg_id: GraphSubgraphId,
715 stratum: usize,
716 ) -> Option<usize> {
717 self.subgraph_stratum.insert(sg_id, stratum)
718 }
719
720 fn subgraph_laziness(&self, sg_id: GraphSubgraphId) -> bool {
722 self.subgraph_laziness.get(sg_id).copied().unwrap_or(false)
723 }
724
725 pub fn set_subgraph_laziness(&mut self, sg_id: GraphSubgraphId, lazy: bool) -> bool {
727 self.subgraph_laziness.insert(sg_id, lazy).unwrap_or(false)
728 }
729
730 pub fn max_stratum(&self) -> Option<usize> {
732 self.subgraph_stratum.values().copied().max()
733 }
734
735 fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
737 subgraph_nodes
738 .iter()
739 .position(|&node_id| {
740 self.node_color(node_id)
741 .is_some_and(|color| Color::Pull != color)
742 })
743 .unwrap_or(subgraph_nodes.len())
744 }
745}
746
747impl DfirGraph {
749 fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
751 let name = match &self.nodes[node_id] {
752 GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
753 GraphNode::Handoff { .. } => format!(
754 "hoff_{:?}_{}",
755 node_id.data(),
756 if is_pred { "recv" } else { "send" }
757 ),
758 GraphNode::ModuleBoundary { .. } => panic!(),
759 };
760 let span = match (is_pred, &self.nodes[node_id]) {
761 (_, GraphNode::Operator(operator)) => operator.span(),
762 (true, &GraphNode::Handoff { src_span, .. }) => src_span,
763 (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
764 (_, GraphNode::ModuleBoundary { .. }) => panic!(),
765 };
766 Ident::new(&name, span)
767 }
768
769 fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
771 Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
772 }
773
774 fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
776 self.node_singleton_references(node_id)
777 .iter()
778 .map(|singleton_node_id| {
779 self.node_as_singleton_ident(
781 singleton_node_id
782 .expect("Expected singleton to be resolved but was not, this is a bug."),
783 span,
784 )
785 })
786 .collect::<Vec<_>>()
787 }
788
789 fn helper_collect_subgraph_handoffs(
792 &self,
793 ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
794 let mut subgraph_handoffs: SecondaryMap<
796 GraphSubgraphId,
797 (Vec<GraphNodeId>, Vec<GraphNodeId>),
798 > = self
799 .subgraph_nodes
800 .keys()
801 .map(|k| (k, Default::default()))
802 .collect();
803
804 for (hoff_id, node) in self.nodes() {
806 if !matches!(node, GraphNode::Handoff { .. }) {
807 continue;
808 }
809 for (_edge, succ_id) in self.node_successors(hoff_id) {
811 let succ_sg = self.node_subgraph(succ_id).unwrap();
812 subgraph_handoffs[succ_sg].0.push(hoff_id);
813 }
814 for (_edge, pred_id) in self.node_predecessors(hoff_id) {
816 let pred_sg = self.node_subgraph(pred_id).unwrap();
817 subgraph_handoffs[pred_sg].1.push(hoff_id);
818 }
819 }
820
821 subgraph_handoffs
822 }
823
824 fn codegen_nested_loops(&self, df: &Ident) -> TokenStream {
826 let mut out = TokenStream::new();
828 let mut queue = VecDeque::from_iter(self.root_loops.iter().copied());
829 while let Some(loop_id) = queue.pop_front() {
830 let parent_opt = self
831 .loop_parent(loop_id)
832 .map(|loop_id| loop_id.as_ident(Span::call_site()))
833 .map(|ident| quote! { Some(#ident) })
834 .unwrap_or_else(|| quote! { None });
835 let loop_name = loop_id.as_ident(Span::call_site());
836 out.append_all(quote! {
837 let #loop_name = #df.add_loop(#parent_opt);
838 });
839 queue.extend(self.loop_children.get(loop_id).into_iter().flatten());
840 }
841 out
842 }
843
844 pub fn as_code(
848 &self,
849 root: &TokenStream,
850 include_type_guards: bool,
851 prefix: TokenStream,
852 diagnostics: &mut Diagnostics,
853 ) -> Result<TokenStream, Diagnostics> {
854 let df = Ident::new(GRAPH, Span::call_site());
855 let context = Ident::new(CONTEXT, Span::call_site());
856
857 let handoff_code = self
859 .nodes
860 .iter()
861 .filter_map(|(node_id, node)| match node {
862 GraphNode::Operator(_) => None,
863 &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
864 GraphNode::ModuleBoundary { .. } => panic!(),
865 })
866 .map(|(node_id, (src_span, dst_span))| {
867 let ident_send = Ident::new(&format!("hoff_{:?}_send", node_id.data()), dst_span);
868 let ident_recv = Ident::new(&format!("hoff_{:?}_recv", node_id.data()), src_span);
869 let span = src_span.join(dst_span).unwrap_or(src_span);
870 let mut hoff_name = Literal::string(&format!("handoff {:?}", node_id));
871 hoff_name.set_span(span);
872 let hoff_type = quote_spanned! (span=> #root::scheduled::handoff::VecHandoff<_>);
873 quote_spanned! {span=>
874 let (#ident_send, #ident_recv) =
875 #df.make_edge::<_, #hoff_type>(#hoff_name);
876 }
877 });
878
879 let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
880
881 let (subgraphs_without_preds, subgraphs_with_preds) = self
883 .subgraph_nodes
884 .iter()
885 .partition::<Vec<_>, _>(|(_, nodes)| {
886 nodes
887 .iter()
888 .any(|&node_id| self.node_degree_in(node_id) == 0)
889 });
890
891 let mut op_prologue_code = Vec::new();
892 let mut op_prologue_after_code = Vec::new();
893 let mut subgraphs = Vec::new();
894 {
895 for &(subgraph_id, subgraph_nodes) in subgraphs_without_preds
896 .iter()
897 .chain(subgraphs_with_preds.iter())
898 {
899 let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
900 let recv_ports: Vec<Ident> = recv_hoffs
901 .iter()
902 .map(|&hoff_id| self.node_as_ident(hoff_id, true))
903 .collect();
904 let send_ports: Vec<Ident> = send_hoffs
905 .iter()
906 .map(|&hoff_id| self.node_as_ident(hoff_id, false))
907 .collect();
908
909 let recv_port_code = recv_ports.iter().map(|ident| {
910 quote_spanned! {ident.span()=>
911 let mut #ident = #ident.borrow_mut_swap();
912 let #ident = #root::futures::stream::iter(#ident.drain(..));
913 }
914 });
915 let send_port_code = send_ports.iter().map(|ident| {
916 quote_spanned! {ident.span()=>
917 let #ident = #root::sinktools::for_each(|v| {
918 #ident.give(Some(v));
919 });
920 }
921 });
922
923 let loop_id = self
924 .node_loop(subgraph_nodes[0]);
926
927 let mut subgraph_op_iter_code = Vec::new();
928 let mut subgraph_op_iter_after_code = Vec::new();
929 {
930 let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
931
932 let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
933 let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
934
935 for (idx, &node_id) in nodes_iter.enumerate() {
936 let node = &self.nodes[node_id];
937 assert!(
938 matches!(node, GraphNode::Operator(_)),
939 "Handoffs are not part of subgraphs."
940 );
941 let op_inst = &self.operator_instances[node_id];
942
943 let op_span = node.span();
944 let op_name = op_inst.op_constraints.name;
945 let root = change_spans(root.clone(), op_span);
947 let op_constraints = OPERATORS
949 .iter()
950 .find(|op| op_name == op.name)
951 .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
952
953 let ident = self.node_as_ident(node_id, false);
954
955 {
956 let mut input_edges = self
959 .graph
960 .predecessor_edges(node_id)
961 .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
962 .collect::<Vec<_>>();
963 input_edges.sort();
965
966 let inputs = input_edges
967 .iter()
968 .map(|&(_port, edge_id)| {
969 let (pred, _) = self.edge(edge_id);
970 self.node_as_ident(pred, true)
971 })
972 .collect::<Vec<_>>();
973
974 let mut output_edges = self
976 .graph
977 .successor_edges(node_id)
978 .map(|edge_id| (&self.ports[edge_id].0, edge_id))
979 .collect::<Vec<_>>();
980 output_edges.sort();
982
983 let outputs = output_edges
984 .iter()
985 .map(|&(_port, edge_id)| {
986 let (_, succ) = self.edge(edge_id);
987 self.node_as_ident(succ, false)
988 })
989 .collect::<Vec<_>>();
990
991 let is_pull = idx < pull_to_push_idx;
992
993 let singleton_output_ident = &if op_constraints.has_singleton_output {
994 self.node_as_singleton_ident(node_id, op_span)
995 } else {
996 Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
998 };
999
1000 let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1009 let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1010
1011 let singletons_resolved =
1012 self.helper_resolve_singletons(node_id, op_span);
1013 let arguments = &process_singletons::postprocess_singletons(
1014 op_inst.arguments_raw.clone(),
1015 singletons_resolved.clone(),
1016 context,
1017 );
1018 let arguments_handles =
1019 &process_singletons::postprocess_singletons_handles(
1020 op_inst.arguments_raw.clone(),
1021 singletons_resolved.clone(),
1022 );
1023
1024 let source_tag = 'a: {
1025 if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1026 break 'a tag;
1027 }
1028
1029 #[cfg(nightly)]
1030 if proc_macro::is_available() {
1031 let op_span = op_span.unwrap();
1032 break 'a format!(
1033 "loc_{}_{}_{}_{}_{}",
1034 crate::pretty_span::make_source_path_relative(
1035 &op_span.file()
1036 )
1037 .display()
1038 .to_string()
1039 .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1040 op_span.start().line(),
1041 op_span.start().column(),
1042 op_span.end().line(),
1043 op_span.end().column(),
1044 );
1045 }
1046
1047 format!(
1048 "loc_nopath_{}_{}_{}_{}",
1049 op_span.start().line,
1050 op_span.start().column,
1051 op_span.end().line,
1052 op_span.end().column
1053 )
1054 };
1055
1056 let work_fn = format_ident!(
1057 "{}__{}__{}",
1058 ident,
1059 op_name,
1060 source_tag,
1061 span = op_span
1062 );
1063 let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1064
1065 let context_args = WriteContextArgs {
1066 root: &root,
1067 df_ident: df_local,
1068 context,
1069 subgraph_id,
1070 node_id,
1071 loop_id,
1072 op_span,
1073 op_tag: self.operator_tag.get(node_id).cloned(),
1074 work_fn: &work_fn,
1075 work_fn_async: &work_fn_async,
1076 ident: &ident,
1077 is_pull,
1078 inputs: &inputs,
1079 outputs: &outputs,
1080 singleton_output_ident,
1081 op_name,
1082 op_inst,
1083 arguments,
1084 arguments_handles,
1085 };
1086
1087 let write_result =
1088 (op_constraints.write_fn)(&context_args, diagnostics);
1089 let OperatorWriteOutput {
1090 write_prologue,
1091 write_prologue_after,
1092 write_iterator,
1093 write_iterator_after,
1094 } = write_result.unwrap_or_else(|()| {
1095 assert!(
1096 diagnostics.has_error(),
1097 "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1098 op_name,
1099 );
1100 OperatorWriteOutput { write_iterator: null_write_iterator_fn(&context_args), ..Default::default() }
1101 });
1102
1103 op_prologue_code.push(syn::parse_quote! {
1104 #[allow(non_snake_case)]
1105 #[inline(always)]
1106 fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1107 thunk()
1108 }
1109
1110 #[allow(non_snake_case)]
1111 #[inline(always)]
1112 async fn #work_fn_async<T>(thunk: impl ::std::future::Future<Output = T>) -> T {
1113 thunk.await
1114 }
1115 });
1116 op_prologue_code.push(write_prologue);
1117 op_prologue_after_code.push(write_prologue_after);
1118 subgraph_op_iter_code.push(write_iterator);
1119
1120 if include_type_guards {
1121 let type_guard = if is_pull {
1122 quote_spanned! {op_span=>
1123 let #ident = {
1124 #[allow(non_snake_case)]
1125 #[inline(always)]
1126 pub fn #work_fn<Item, Input: #root::futures::stream::Stream<Item = Item>>(input: Input) -> impl #root::futures::stream::Stream<Item = Item> {
1127 #root::pin_project_lite::pin_project! {
1128 #[repr(transparent)]
1129 struct Pull<Item, Input: #root::futures::stream::Stream<Item = Item>> {
1130 #[pin]
1131 inner: Input
1132 }
1133 }
1134
1135 impl<Item, Input> #root::futures::stream::Stream for Pull<Item, Input>
1136 where
1137 Input: #root::futures::stream::Stream<Item = Item>,
1138 {
1139 type Item = Item;
1140
1141 #[inline(always)]
1142 fn poll_next(
1143 self: ::std::pin::Pin<&mut Self>,
1144 cx: &mut ::std::task::Context<'_>,
1145 ) -> ::std::task::Poll<::std::option::Option<Self::Item>> {
1146 #root::futures::stream::Stream::poll_next(self.project().inner, cx)
1147 }
1148
1149 #[inline(always)]
1150 fn size_hint(&self) -> (usize, Option<usize>) {
1151 #root::futures::stream::Stream::size_hint(&self.inner)
1152 }
1153 }
1154
1155 Pull {
1156 inner: input
1157 }
1158 }
1159 #work_fn( #ident )
1160 };
1161 }
1162 } else {
1163 quote_spanned! {op_span=>
1164 let #ident = {
1165 #[allow(non_snake_case)]
1166 #[inline(always)]
1167 pub fn #work_fn<Item, Si>(si: Si) -> impl #root::futures::sink::Sink<Item, Error = #root::Never>
1168 where
1169 Si: #root::futures::sink::Sink<Item, Error = #root::Never>
1170 {
1171 #root::pin_project_lite::pin_project! {
1172 #[repr(transparent)]
1173 struct Push<Si> {
1174 #[pin]
1175 si: Si,
1176 }
1177 }
1178
1179 impl<Item, Si> #root::futures::sink::Sink<Item> for Push<Si>
1180 where
1181 Si: #root::futures::sink::Sink<Item>,
1182 {
1183 type Error = Si::Error;
1184
1185 fn poll_ready(
1186 self: ::std::pin::Pin<&mut Self>,
1187 cx: &mut ::std::task::Context<'_>,
1188 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1189 self.project().si.poll_ready(cx)
1190 }
1191
1192 fn start_send(
1193 self: ::std::pin::Pin<&mut Self>,
1194 item: Item,
1195 ) -> ::std::result::Result<(), Self::Error> {
1196 self.project().si.start_send(item)
1197 }
1198
1199 fn poll_flush(
1200 self: ::std::pin::Pin<&mut Self>,
1201 cx: &mut ::std::task::Context<'_>,
1202 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1203 self.project().si.poll_flush(cx)
1204 }
1205
1206 fn poll_close(
1207 self: ::std::pin::Pin<&mut Self>,
1208 cx: &mut ::std::task::Context<'_>,
1209 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
1210 self.project().si.poll_close(cx)
1211 }
1212 }
1213
1214 Push {
1215 si
1216 }
1217 }
1218 #work_fn( #ident )
1219 };
1220 }
1221 };
1222 subgraph_op_iter_code.push(type_guard);
1223 }
1224 subgraph_op_iter_after_code.push(write_iterator_after);
1225 }
1226 }
1227
1228 {
1229 let pull_ident = if 0 < pull_to_push_idx {
1231 self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1232 } else {
1233 recv_ports[0].clone()
1235 };
1236
1237 #[rustfmt::skip]
1238 let push_ident = if let Some(&node_id) =
1239 subgraph_nodes.get(pull_to_push_idx)
1240 {
1241 self.node_as_ident(node_id, false)
1242 } else if 1 == send_ports.len() {
1243 send_ports[0].clone()
1245 } else {
1246 diagnostics.push(Diagnostic::spanned(
1247 pull_ident.span(),
1248 Level::Error,
1249 "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1250 ));
1251 continue;
1252 };
1253
1254 let pivot_span = pull_ident
1256 .span()
1257 .join(push_ident.span())
1258 .unwrap_or_else(|| push_ident.span());
1259 let pivot_fn_ident =
1260 Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1261 let root = change_spans(root.clone(), pivot_span);
1262 subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1263 #[inline(always)]
1264 fn #pivot_fn_ident<Pull, Push, Item>(pull: Pull, push: Push)
1265 -> impl ::std::future::Future<Output = ::std::result::Result<(), #root::Never>>
1266 where
1267 Pull: #root::futures::stream::Stream<Item = Item>,
1268 Push: #root::futures::sink::Sink<Item, Error = #root::Never>,
1269 {
1270 #root::sinktools::send_stream(pull, push)
1271 }
1272 (#pivot_fn_ident)(#pull_ident, #push_ident).await.unwrap();
1273 });
1274 }
1275 };
1276
1277 let subgraph_name = Literal::string(&format!("Subgraph {:?}", subgraph_id));
1278 let stratum = Literal::usize_unsuffixed(
1279 self.subgraph_stratum.get(subgraph_id).cloned().unwrap_or(0),
1280 );
1281 let laziness = self.subgraph_laziness(subgraph_id);
1282
1283 let loop_id_opt = loop_id
1285 .map(|loop_id| loop_id.as_ident(Span::call_site()))
1286 .map(|ident| quote! { Some(#ident) })
1287 .unwrap_or_else(|| quote! { None });
1288
1289 let sg_ident = subgraph_id.as_ident(Span::call_site());
1290
1291 subgraphs.push(quote! {
1292 let #sg_ident = #df.add_subgraph_full(
1293 #subgraph_name,
1294 #stratum,
1295 var_expr!( #( #recv_ports ),* ),
1296 var_expr!( #( #send_ports ),* ),
1297 #laziness,
1298 #loop_id_opt,
1299 async move |#context, var_args!( #( #recv_ports ),* ), var_args!( #( #send_ports ),* )| {
1300 #( #recv_port_code )*
1301 #( #send_port_code )*
1302 #( #subgraph_op_iter_code )*
1303 #( #subgraph_op_iter_after_code )*
1304 },
1305 );
1306 });
1307 }
1308 }
1309
1310 if diagnostics.has_error() {
1311 return Err(std::mem::take(diagnostics));
1312 }
1313 let _ = diagnostics; let loop_code = self.codegen_nested_loops(&df);
1316
1317 let code = quote! {
1322 #( #handoff_code )*
1323 #loop_code
1324 #( #op_prologue_code )*
1325 #( #subgraphs )*
1326 #( #op_prologue_after_code )*
1327 };
1328
1329 let meta_graph_json = serde_json::to_string(&self).unwrap();
1330 let meta_graph_json = Literal::string(&meta_graph_json);
1331
1332 let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1333 let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1334 let diagnostics_json = Literal::string(&diagnostics_json);
1335
1336 Ok(quote! {
1337 {
1338 #[allow(unused_qualifications, clippy::await_holding_refcell_ref)]
1339 {
1340 #prefix
1341
1342 use #root::{var_expr, var_args};
1343
1344 let mut #df = #root::scheduled::graph::Dfir::new();
1345 #df.__assign_meta_graph(#meta_graph_json);
1346 #df.__assign_diagnostics(#diagnostics_json);
1347
1348 #code
1349
1350 #df
1351 }
1352 }
1353 })
1354 }
1355
1356 pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1359 let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1360 .node_ids()
1361 .filter_map(|node_id| {
1362 let op_color = self.node_color(node_id)?;
1363 Some((node_id, op_color))
1364 })
1365 .collect();
1366
1367 for sg_nodes in self.subgraph_nodes.values() {
1369 let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1370
1371 for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1372 let is_pull = idx < pull_to_push_idx;
1373 node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1374 }
1375 }
1376
1377 node_color_map
1378 }
1379
1380 pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1382 let mut output = String::new();
1383 self.write_mermaid(&mut output, write_config).unwrap();
1384 output
1385 }
1386
1387 pub fn write_mermaid(
1389 &self,
1390 output: impl std::fmt::Write,
1391 write_config: &WriteConfig,
1392 ) -> std::fmt::Result {
1393 let mut graph_write = Mermaid::new(output);
1394 self.write_graph(&mut graph_write, write_config)
1395 }
1396
1397 pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1399 let mut output = String::new();
1400 let mut graph_write = Dot::new(&mut output);
1401 self.write_graph(&mut graph_write, write_config).unwrap();
1402 output
1403 }
1404
1405 pub fn write_dot(
1407 &self,
1408 output: impl std::fmt::Write,
1409 write_config: &WriteConfig,
1410 ) -> std::fmt::Result {
1411 let mut graph_write = Dot::new(output);
1412 self.write_graph(&mut graph_write, write_config)
1413 }
1414
1415 pub(crate) fn write_graph<W>(
1417 &self,
1418 mut graph_write: W,
1419 write_config: &WriteConfig,
1420 ) -> Result<(), W::Err>
1421 where
1422 W: GraphWrite,
1423 {
1424 fn helper_edge_label(
1425 src_port: &PortIndexValue,
1426 dst_port: &PortIndexValue,
1427 ) -> Option<String> {
1428 let src_label = match src_port {
1429 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1430 PortIndexValue::Int(index) => Some(index.value.to_string()),
1431 _ => None,
1432 };
1433 let dst_label = match dst_port {
1434 PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
1435 PortIndexValue::Int(index) => Some(index.value.to_string()),
1436 _ => None,
1437 };
1438 let label = match (src_label, dst_label) {
1439 (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
1440 (Some(l1), None) => Some(l1),
1441 (None, Some(l2)) => Some(l2),
1442 (None, None) => None,
1443 };
1444 label
1445 }
1446
1447 let node_color_map = self.node_color_map();
1449
1450 graph_write.write_prologue()?;
1452
1453 let mut skipped_handoffs = BTreeSet::new();
1455 let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
1456 for (node_id, node) in self.nodes() {
1457 if matches!(node, GraphNode::Handoff { .. }) {
1458 if write_config.no_handoffs {
1459 skipped_handoffs.insert(node_id);
1460 continue;
1461 } else {
1462 let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
1463 let pred_sg = self.node_subgraph(pred_node);
1464 let succ_node = self.node_successor_nodes(node_id).next().unwrap();
1465 let succ_sg = self.node_subgraph(succ_node);
1466 if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
1467 && pred_sg == succ_sg
1468 {
1469 subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
1470 }
1471 }
1472 }
1473 graph_write.write_node_definition(
1474 node_id,
1475 &if write_config.op_short_text {
1476 node.to_name_string()
1477 } else if write_config.op_text_no_imports {
1478 let full_text = node.to_pretty_string();
1480 let mut output = String::new();
1481 for sentence in full_text.split('\n') {
1482 if sentence.trim().starts_with("use") {
1483 continue;
1484 }
1485 output.push('\n');
1486 output.push_str(sentence);
1487 }
1488 output.into()
1489 } else {
1490 node.to_pretty_string()
1491 },
1492 if write_config.no_pull_push {
1493 None
1494 } else {
1495 node_color_map.get(node_id).copied()
1496 },
1497 )?;
1498 }
1499
1500 for (edge_id, (src_id, mut dst_id)) in self.edges() {
1502 if skipped_handoffs.contains(&src_id) {
1504 continue;
1505 }
1506
1507 let (src_port, mut dst_port) = self.edge_ports(edge_id);
1508 if skipped_handoffs.contains(&dst_id) {
1509 let mut handoff_succs = self.node_successors(dst_id);
1510 assert_eq!(1, handoff_succs.len());
1511 let (succ_edge, succ_node) = handoff_succs.next().unwrap();
1512 dst_id = succ_node;
1513 dst_port = self.edge_ports(succ_edge).1;
1514 }
1515
1516 let label = helper_edge_label(src_port, dst_port);
1517 let delay_type = self
1518 .node_op_inst(dst_id)
1519 .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
1520 graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
1521 }
1522
1523 if !write_config.no_references {
1525 for dst_id in self.node_ids() {
1526 for src_ref_id in self
1527 .node_singleton_references(dst_id)
1528 .iter()
1529 .copied()
1530 .flatten()
1531 {
1532 let delay_type = Some(DelayType::Stratum);
1533 let label = None;
1534 graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
1535 }
1536 }
1537 }
1538
1539 let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
1550 let loop_id = if write_config.no_loops {
1551 None
1552 } else {
1553 self.subgraph_loop(sg_id)
1554 };
1555 (loop_id, sg_id)
1556 });
1557 let loop_subgraphs = into_group_map(loop_subgraphs);
1558 for (loop_id, subgraph_ids) in loop_subgraphs {
1559 if let Some(loop_id) = loop_id {
1560 graph_write.write_loop_start(loop_id)?;
1561 }
1562
1563 let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
1565 self.subgraph(sg_id).iter().copied().map(move |node_id| {
1566 let opt_sg_id = if write_config.no_subgraphs {
1567 None
1568 } else {
1569 Some(sg_id)
1570 };
1571 (opt_sg_id, (self.node_varname(node_id), node_id))
1572 })
1573 });
1574 let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
1575 for (sg_id, varnames) in subgraph_varnames_nodes {
1576 if let Some(sg_id) = sg_id {
1577 let stratum = self.subgraph_stratum(sg_id).unwrap();
1578 graph_write.write_subgraph_start(sg_id, stratum)?;
1579 }
1580
1581 let varname_nodes = varnames.into_iter().map(|(varname, node)| {
1583 let varname = if write_config.no_varnames {
1584 None
1585 } else {
1586 varname
1587 };
1588 (varname, node)
1589 });
1590 let varname_nodes = into_group_map(varname_nodes);
1591 for (varname, node_ids) in varname_nodes {
1592 if let Some(varname) = varname {
1593 graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
1594 }
1595
1596 for node_id in node_ids {
1598 graph_write.write_node(node_id)?;
1599 }
1600
1601 if varname.is_some() {
1602 graph_write.write_varname_end()?;
1603 }
1604 }
1605
1606 if sg_id.is_some() {
1607 graph_write.write_subgraph_end()?;
1608 }
1609 }
1610
1611 if loop_id.is_some() {
1612 graph_write.write_loop_end()?;
1613 }
1614 }
1615
1616 graph_write.write_epilogue()?;
1618
1619 Ok(())
1620 }
1621
1622 pub fn surface_syntax_string(&self) -> String {
1624 let mut string = String::new();
1625 self.write_surface_syntax(&mut string).unwrap();
1626 string
1627 }
1628
1629 pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1631 for (key, node) in self.nodes.iter() {
1632 match node {
1633 GraphNode::Operator(op) => {
1634 writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1635 }
1636 GraphNode::Handoff { .. } => {
1637 writeln!(write, "// {:?} = <handoff>;", key.data())?;
1638 }
1639 GraphNode::ModuleBoundary { .. } => panic!(),
1640 }
1641 }
1642 writeln!(write)?;
1643 for (_e, (src_key, dst_key)) in self.graph.edges() {
1644 writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1645 }
1646 Ok(())
1647 }
1648
1649 pub fn mermaid_string_flat(&self) -> String {
1651 let mut string = String::new();
1652 self.write_mermaid_flat(&mut string).unwrap();
1653 string
1654 }
1655
1656 pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1658 writeln!(write, "flowchart TB")?;
1659 for (key, node) in self.nodes.iter() {
1660 match node {
1661 GraphNode::Operator(operator) => writeln!(
1662 write,
1663 " %% {span}\n {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1664 span = PrettySpan(node.span()),
1665 id = key.data(),
1666 row_col = PrettyRowCol(node.span()),
1667 code = operator
1668 .to_token_stream()
1669 .to_string()
1670 .replace('&', "&")
1671 .replace('<', "<")
1672 .replace('>', ">")
1673 .replace('"', """)
1674 .replace('\n', "<br>"),
1675 ),
1676 GraphNode::Handoff { .. } => {
1677 writeln!(write, r#" {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1678 }
1679 GraphNode::ModuleBoundary { .. } => {
1680 writeln!(
1681 write,
1682 r#" {:?}{{"{}"}}"#,
1683 key.data(),
1684 MODULE_BOUNDARY_NODE_STR
1685 )
1686 }
1687 }?;
1688 }
1689 writeln!(write)?;
1690 for (_e, (src_key, dst_key)) in self.graph.edges() {
1691 writeln!(write, " {:?}-->{:?}", src_key.data(), dst_key.data())?;
1692 }
1693 Ok(())
1694 }
1695}
1696
1697impl DfirGraph {
1699 pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1701 self.loop_nodes.keys()
1702 }
1703
1704 pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1706 self.loop_nodes.iter()
1707 }
1708
1709 pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1711 let loop_id = self.loop_nodes.insert(Vec::new());
1712 self.loop_children.insert(loop_id, Vec::new());
1713 if let Some(parent_loop) = parent_loop {
1714 self.loop_parent.insert(loop_id, parent_loop);
1715 self.loop_children
1716 .get_mut(parent_loop)
1717 .unwrap()
1718 .push(loop_id);
1719 } else {
1720 self.root_loops.push(loop_id);
1721 }
1722 loop_id
1723 }
1724
1725 pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1727 self.node_loops.get(node_id).copied()
1728 }
1729
1730 pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1732 let &node_id = self.subgraph(subgraph_id).first().unwrap();
1733 let out = self.node_loop(node_id);
1734 debug_assert!(
1735 self.subgraph(subgraph_id)
1736 .iter()
1737 .all(|&node_id| self.node_loop(node_id) == out),
1738 "Subgraph nodes should all have the same loop context."
1739 );
1740 out
1741 }
1742
1743 pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1745 self.loop_parent.get(loop_id).copied()
1746 }
1747
1748 pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1750 self.loop_children.get(loop_id).unwrap()
1751 }
1752}
1753
/// Options controlling how a `DfirGraph` is rendered (e.g. by `to_mermaid` / `to_dot`).
///
/// Every flag defaults to `false`, which renders everything with full operator text.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Do not group nodes into subgraph blocks.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Do not group nodes under their variable names.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Do not color nodes by their pull/push role.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Omit handoff nodes, drawing each edge straight through the handoff to its successor.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Omit the edges for singleton references.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Do not group subgraphs into loop blocks.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Label operators with their short name instead of the full pretty-printed text.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Drop `use` import lines from the full operator text (ignored if `op_short_text` is set).
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
1784
/// Output format for rendering a `DfirGraph`.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid format.
    Mermaid,
    /// Graphviz DOT format.
    Dot,
}
1794
/// Groups `(key, value)` pairs by key, keeping each group's values in encounter order.
///
/// Unlike `Itertools::into_group_map`, which builds a `HashMap`, this returns a [`BTreeMap`].
/// Callers iterating the groups therefore visit keys in sorted order, which keeps the
/// rendered output deterministic.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::<K, Vec<V>>::new(), |mut groups, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
}