Skip to main content

dfir_lang/graph/
mod.rs

1//! Graph representation stages for DFIR graphs.
2
3use std::borrow::Cow;
4use std::hash::Hash;
5
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::ToTokens;
8use serde::{Deserialize, Serialize};
9use slotmap::new_key_type;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Expr, ExprPath, GenericArgument, Token, Type};
13
14use self::ops::{OperatorConstraints, Persistence};
15use crate::diagnostic::{Diagnostic, Diagnostics, Level};
16use crate::parse::{DfirCode, IndexInt, Operator, PortIndex, Ported};
17use crate::pretty_span::PrettySpan;
18
19mod di_mul_graph;
20mod eliminate_extra_unions_tees;
21mod flat_graph_builder;
22mod flat_to_partitioned;
23mod graph_write;
24mod meta_graph;
25mod meta_graph_debugging;
26
27use std::fmt::Display;
28
29pub use di_mul_graph::DiMulGraph;
30pub use eliminate_extra_unions_tees::eliminate_extra_unions_tees;
31pub use flat_graph_builder::{FlatGraphBuilder, FlatGraphBuilderOutput};
32pub use flat_to_partitioned::partition_graph;
33pub use meta_graph::{DfirGraph, WriteConfig, WriteGraphType};
34
35pub mod graph_algorithms;
36pub mod ops;
37
38new_key_type! {
39    /// ID to identify a node (operator or handoff) in [`DfirGraph`].
40    pub struct GraphNodeId;
41
42    /// ID to identify an edge.
43    pub struct GraphEdgeId;
44
45    /// ID to identify a subgraph in [`DfirGraph`].
46    pub struct GraphSubgraphId;
47
48    /// ID to identify a loop block in [`DfirGraph`].
49    pub struct GraphLoopId;
50}
51
52impl GraphSubgraphId {
53    /// Generate a deterministic `Ident` for the given loop ID.
54    pub fn as_ident(self, span: Span) -> Ident {
55        use slotmap::Key;
56        Ident::new(&format!("sgid_{:?}", self.data()), span)
57    }
58}
59
60impl GraphLoopId {
61    /// Generate a deterministic `Ident` for the given loop ID.
62    pub fn as_ident(self, span: Span) -> Ident {
63        use slotmap::Key;
64        Ident::new(&format!("loop_{:?}", self.data()), span)
65    }
66}
67
/// Name of the context identifier, as a string.
const CONTEXT: &str = "context";
/// Name of the runnable DFIR graph object identifier, as a string.
const GRAPH: &str = "df";

/// Display name used for handoff nodes (see `GraphNode::to_pretty_string`).
const HANDOFF_NODE_STR: &str = "handoff";
/// Display name used for module-boundary nodes (see `GraphNode::to_pretty_string`).
const MODULE_BOUNDARY_NODE_STR: &str = "module_boundary";
75
76mod serde_syn {
77    use serde::{Deserialize, Deserializer, Serializer};
78
79    pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
80    where
81        S: Serializer,
82        T: quote::ToTokens,
83    {
84        serializer.serialize_str(&value.to_token_stream().to_string())
85    }
86
87    pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
88    where
89        D: Deserializer<'de>,
90        T: syn::parse::Parse,
91    {
92        let s = String::deserialize(deserializer)?;
93        syn::parse_str(&s).map_err(<D::Error as serde::de::Error>::custom)
94    }
95}
96
97/// A variable name assigned to a pipeline in DFIR syntax.
98///
99/// Fundamentally a serializable/deserializable wrapper around [`syn::Ident`].
100#[derive(Clone, Debug, Serialize, Deserialize, PartialOrd, Ord, PartialEq, Eq, Hash)]
101pub struct Varname(#[serde(with = "serde_syn")] pub Ident);
102
103/// A node, corresponding to an operator or a handoff.
104#[derive(Clone, Serialize, Deserialize)]
105pub enum GraphNode {
106    /// An operator.
107    Operator(#[serde(with = "serde_syn")] Operator),
108    /// A handoff point, used between subgraphs (or within a subgraph to break a cycle).
109    Handoff {
110        /// The span of the input into the handoff.
111        #[serde(skip, default = "Span::call_site")]
112        src_span: Span,
113        /// The span of the output out of the handoff.
114        #[serde(skip, default = "Span::call_site")]
115        dst_span: Span,
116    },
117
118    /// Module Boundary, used for importing modules. Only exists prior to partitioning.
119    ModuleBoundary {
120        /// If this module is an input or output boundary.
121        input: bool,
122
123        /// The span of the import!() expression that imported this module.
124        /// The value of this span when the ModuleBoundary node is still inside the module is Span::call_site()
125        /// TODO: This could one day reference into the module file itself?
126        #[serde(skip, default = "Span::call_site")]
127        import_expr: Span,
128    },
129}
130impl GraphNode {
131    /// Return the node as a human-readable string.
132    pub fn to_pretty_string(&self) -> Cow<'static, str> {
133        match self {
134            GraphNode::Operator(op) => op.to_pretty_string().into(),
135            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
136            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
137        }
138    }
139
140    /// Return the name of the node as a string, excluding parenthesis and op source code.
141    pub fn to_name_string(&self) -> Cow<'static, str> {
142        match self {
143            GraphNode::Operator(op) => op.name_string().into(),
144            GraphNode::Handoff { .. } => HANDOFF_NODE_STR.into(),
145            GraphNode::ModuleBoundary { .. } => MODULE_BOUNDARY_NODE_STR.into(),
146        }
147    }
148
149    /// Return the source code span of the node (for operators) or input/otput spans for handoffs.
150    pub fn span(&self) -> Span {
151        match self {
152            Self::Operator(op) => op.span(),
153            &Self::Handoff {
154                src_span, dst_span, ..
155            } => src_span.join(dst_span).unwrap_or(src_span),
156            Self::ModuleBoundary { import_expr, .. } => *import_expr,
157        }
158    }
159}
160impl std::fmt::Debug for GraphNode {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            Self::Operator(operator) => {
164                write!(f, "Node::Operator({} span)", PrettySpan(operator.span()))
165            }
166            Self::Handoff { .. } => write!(f, "Node::Handoff"),
167            Self::ModuleBoundary { input, .. } => {
168                write!(f, "Node::ModuleBoundary{{input: {}}}", input)
169            }
170        }
171    }
172}
173
174/// Meta-data relating to operators which may be useful throughout the compilation process.
175///
176/// This data can be generated from the graph, but it is useful to have it readily available
177/// pre-computed as many algorithms use the same info. Stuff like port names, arguments, and the
178/// [`OperatorConstraints`] for the operator.
179///
180/// Because it is derived from the graph itself, there can be "cache invalidation"-esque issues
181/// if this data is not kept in sync with the graph.
182#[derive(Clone, Debug)]
183pub struct OperatorInstance {
184    /// Name of the operator (will match [`OperatorConstraints::name`]).
185    pub op_constraints: &'static OperatorConstraints,
186    /// Port values used as this operator's input.
187    pub input_ports: Vec<PortIndexValue>,
188    /// Port values used as this operator's output.
189    pub output_ports: Vec<PortIndexValue>,
190    /// Singleton references within the operator arguments.
191    pub singletons_referenced: Vec<Ident>,
192
193    /// Generic arguments.
194    pub generics: OpInstGenerics,
195    /// Arguments provided by the user into the operator as arguments.
196    /// I.e. the `a, b, c` in `-> my_op(a, b, c) -> `.
197    ///
198    /// These arguments do not include singleton postprocessing codegen. Instead use
199    /// [`ops::WriteContextArgs::arguments`].
200    pub arguments_pre: Punctuated<Expr, Token![,]>,
201    /// Unparsed arguments, for singleton parsing.
202    pub arguments_raw: TokenStream,
203}
204
205/// Operator generic arguments, split into specific categories.
206#[derive(Clone, Debug)]
207pub struct OpInstGenerics {
208    /// Operator generic (type or lifetime) arguments.
209    pub generic_args: Option<Punctuated<GenericArgument, Token![,]>>,
210    /// Lifetime persistence arguments. Corresponds to a prefix of [`Self::generic_args`].
211    pub persistence_args: Vec<Persistence>,
212    /// Type persistence arguments. Corersponds to a (suffix) of [`Self::generic_args`].
213    pub type_args: Vec<Type>,
214}
215
216impl OpInstGenerics {
217    /// Helper to join a sequence of spans into a single span, if possible.
218    ///
219    /// Returns `None` if there are no spans or if any `Span::join` call fails
220    /// (for example, when spans are not contiguous).
221    fn join_spans<I>(mut spans: I) -> Option<Span>
222    where
223        I: Iterator<Item = Span>,
224    {
225        let mut span = spans.next()?;
226        for s in spans {
227            span = span.join(s)?;
228        }
229        Some(span)
230    }
231
232    /// Returns a [`Span`] containing all persistence (lifetime) args if possible.
233    pub fn persistence_args_span(&self) -> Option<Span> {
234        self.generic_args.as_ref().and_then(|args| {
235            Self::join_spans(
236                args.iter()
237                    .filter(|a| matches!(a, GenericArgument::Lifetime(_)))
238                    .map(|a| a.span()),
239            )
240        })
241    }
242
243    /// Returns a [`Span`] containing all type args if possible.
244    pub fn type_args_span(&self) -> Option<Span> {
245        self.generic_args.as_ref().and_then(|args| {
246            Self::join_spans(
247                args.iter()
248                    .filter(|a| matches!(a, GenericArgument::Type(_)))
249                    .map(|a| a.span()),
250            )
251        })
252    }
253}
254
255/// Gets the generic arguments for the operator.
256///
257/// This helper method is useful due to the special handling of persistence lifetimes (`'static`,
258/// `'tick`, `'mutable`) which must come before other generic parameters.
259pub fn get_operator_generics(diagnostics: &mut Diagnostics, operator: &Operator) -> OpInstGenerics {
260    // Generic arguments.
261    let generic_args = operator.type_arguments().cloned();
262    let persistence_args = generic_args.iter().flatten().map_while(|generic_arg| match generic_arg {
263            GenericArgument::Lifetime(lifetime) => {
264                match &*lifetime.ident.to_string() {
265                    "none" => Some(Persistence::None),
266                    "loop" => Some(Persistence::Loop),
267                    "tick" => Some(Persistence::Tick),
268                    "static" => Some(Persistence::Static),
269                    "mutable" => Some(Persistence::Mutable),
270                    _ => {
271                        diagnostics.push(Diagnostic::spanned(
272                            generic_arg.span(),
273                            Level::Error,
274                            format!("Unknown lifetime generic argument `'{}`, expected `'none`, `'loop`, `'tick`, `'static`, or `'mutable`.", lifetime.ident),
275                        ));
276                        // TODO(mingwei): should really keep going and not short circuit?
277                        None
278                    }
279                }
280            },
281            _ => None,
282        }).collect::<Vec<_>>();
283    let type_args = generic_args
284        .iter()
285        .flatten()
286        .skip(persistence_args.len())
287        .map_while(|generic_arg| match generic_arg {
288            GenericArgument::Type(typ) => Some(typ),
289            _ => None,
290        })
291        .cloned()
292        .collect::<Vec<_>>();
293
294    OpInstGenerics {
295        generic_args,
296        persistence_args,
297        type_args,
298    }
299}
300
301/// Push, Pull, Comp, or Hoff polarity.
302#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
303pub enum Color {
304    /// Pull (green)
305    Pull,
306    /// Push (blue)
307    Push,
308    /// Computation (yellow)
309    Comp,
310    /// Handoff (grey) -- not a color for operators, inserted between subgraphs.
311    Hoff,
312}
313
314/// Helper struct for [`PortIndex`] which keeps span information for elided ports.
315#[derive(Clone, Debug, Serialize, Deserialize)]
316pub enum PortIndexValue {
317    /// An integer value: `[0]`, `[1]`, etc. Can be negative although we don't use that (2023-08-16).
318    Int(#[serde(with = "serde_syn")] IndexInt),
319    /// A name or path. `[pos]`, `[neg]`, etc. Can use `::` separators but we don't use that (2023-08-16).
320    Path(#[serde(with = "serde_syn")] ExprPath),
321    /// Elided, unspecified port. We have this variant, rather than wrapping in `Option`, in order
322    /// to preserve the `Span` information.
323    Elided(#[serde(skip)] Option<Span>),
324}
325impl PortIndexValue {
326    /// For a [`Ported`] value like `[port_in]name[port_out]`, get the `port_in` and `port_out` as
327    /// [`PortIndexValue`]s.
328    pub fn from_ported<Inner>(ported: Ported<Inner>) -> (Self, Inner, Self)
329    where
330        Inner: Spanned,
331    {
332        let ported_span = Some(ported.inner.span());
333        let port_inn = ported
334            .inn
335            .map(|idx| idx.index.into())
336            .unwrap_or_else(|| Self::Elided(ported_span));
337        let inner = ported.inner;
338        let port_out = ported
339            .out
340            .map(|idx| idx.index.into())
341            .unwrap_or_else(|| Self::Elided(ported_span));
342        (port_inn, inner, port_out)
343    }
344
345    /// Returns `true` if `self` is not [`PortIndexValue::Elided`].
346    pub fn is_specified(&self) -> bool {
347        !matches!(self, Self::Elided(_))
348    }
349
350    /// Returns whichever of the two ports are specified.
351    /// If both are [`Self::Elided`], returns [`Self::Elided`].
352    /// If both are specified, returns `Err(self)`.
353    #[allow(clippy::allow_attributes, reason = "Only triggered on nightly.")]
354    #[allow(
355        clippy::result_large_err,
356        reason = "variants are same size, error isn't to be propagated."
357    )]
358    pub fn combine(self, other: Self) -> Result<Self, Self> {
359        match (self.is_specified(), other.is_specified()) {
360            (false, _other) => Ok(other),
361            (true, false) => Ok(self),
362            (true, true) => Err(self),
363        }
364    }
365
366    /// Formats self as a human-readable string for error messages.
367    pub fn as_error_message_string(&self) -> String {
368        match self {
369            PortIndexValue::Int(n) => format!("`{}`", n.value),
370            PortIndexValue::Path(path) => format!("`{}`", path.to_token_stream()),
371            PortIndexValue::Elided(_) => "<elided>".to_owned(),
372        }
373    }
374
375    /// Returns the span of this port value.
376    pub fn span(&self) -> Span {
377        match self {
378            PortIndexValue::Int(x) => x.span(),
379            PortIndexValue::Path(x) => x.span(),
380            PortIndexValue::Elided(span) => span.unwrap_or_else(Span::call_site),
381        }
382    }
383}
384impl From<PortIndex> for PortIndexValue {
385    fn from(value: PortIndex) -> Self {
386        match value {
387            PortIndex::Int(x) => Self::Int(x),
388            PortIndex::Path(x) => Self::Path(x),
389        }
390    }
391}
392impl PartialEq for PortIndexValue {
393    fn eq(&self, other: &Self) -> bool {
394        match (self, other) {
395            (Self::Int(l0), Self::Int(r0)) => l0 == r0,
396            (Self::Path(l0), Self::Path(r0)) => l0 == r0,
397            (Self::Elided(_), Self::Elided(_)) => true,
398            _else => false,
399        }
400    }
401}
402impl Eq for PortIndexValue {}
403impl PartialOrd for PortIndexValue {
404    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
405        Some(self.cmp(other))
406    }
407}
408impl Ord for PortIndexValue {
409    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
410        match (self, other) {
411            (Self::Int(s), Self::Int(o)) => s.cmp(o),
412            (Self::Path(s), Self::Path(o)) => s
413                .to_token_stream()
414                .to_string()
415                .cmp(&o.to_token_stream().to_string()),
416            (Self::Elided(_), Self::Elided(_)) => std::cmp::Ordering::Equal,
417            (Self::Int(_), Self::Path(_)) => std::cmp::Ordering::Less,
418            (Self::Path(_), Self::Int(_)) => std::cmp::Ordering::Greater,
419            (_, Self::Elided(_)) => std::cmp::Ordering::Less,
420            (Self::Elided(_), _) => std::cmp::Ordering::Greater,
421        }
422    }
423}
424
425impl Display for PortIndexValue {
426    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
427        match self {
428            PortIndexValue::Int(x) => write!(f, "{}", x.to_token_stream()),
429            PortIndexValue::Path(x) => write!(f, "{}", x.to_token_stream()),
430            PortIndexValue::Elided(_) => write!(f, "[]"),
431        }
432    }
433}
434
435/// Output of [`build_dfir_code`].
436pub struct BuildDfirCodeOutput {
437    /// The now-partitioned graph.
438    pub partitioned_graph: DfirGraph,
439    /// The Rust source code tokens for the DFIR.
440    pub code: TokenStream,
441    /// Any (non-error) diagnostics emitted.
442    pub diagnostics: Diagnostics,
443}
444
445/// The main function of this module. Compiles a [`DfirCode`] AST into a [`DfirGraph`] and
446/// source code, or [`Diagnostic`] errors.
447pub fn build_dfir_code(
448    dfir_code: DfirCode,
449    root: &TokenStream,
450) -> Result<BuildDfirCodeOutput, Diagnostics> {
451    let flat_graph_builder = FlatGraphBuilder::from_dfir(dfir_code);
452
453    let FlatGraphBuilderOutput {
454        mut flat_graph,
455        uses,
456        mut diagnostics,
457    } = flat_graph_builder.build()?;
458
459    let () = match flat_graph.merge_modules() {
460        Ok(()) => (),
461        Err(d) => {
462            diagnostics.push(d);
463            return Err(diagnostics);
464        }
465    };
466
467    eliminate_extra_unions_tees(&mut flat_graph);
468    let partitioned_graph = match partition_graph(flat_graph) {
469        Ok(partitioned_graph) => partitioned_graph,
470        Err(d) => {
471            diagnostics.push(d);
472            return Err(diagnostics);
473        }
474    };
475
476    let code =
477        partitioned_graph.as_code(root, true, quote::quote! { #( #uses )* }, &mut diagnostics)?;
478
479    // Success
480    Ok(BuildDfirCodeOutput {
481        partitioned_graph,
482        code,
483        diagnostics,
484    })
485}
486
487/// Changes all of token's spans to `span`, recursing into groups.
488fn change_spans(tokens: TokenStream, span: Span) -> TokenStream {
489    use proc_macro2::{Group, TokenTree};
490    tokens
491        .into_iter()
492        .map(|token| match token {
493            TokenTree::Group(mut group) => {
494                group.set_span(span);
495                TokenTree::Group(Group::new(
496                    group.delimiter(),
497                    change_spans(group.stream(), span),
498                ))
499            }
500            TokenTree::Ident(mut ident) => {
501                ident.set_span(span.resolved_at(ident.span()));
502                TokenTree::Ident(ident)
503            }
504            TokenTree::Punct(mut punct) => {
505                punct.set_span(span);
506                TokenTree::Punct(punct)
507            }
508            TokenTree::Literal(mut literal) => {
509                literal.set_span(span);
510                TokenTree::Literal(literal)
511            }
512        })
513        .collect()
514}