Skip to main content

dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::spanned::Spanned;
14use syn::{
15    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
16    parse_quote,
17};
18
19/// Create a runnable graph instance using DFIR's custom syntax.
20///
21/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
22/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
23/// in the [Hydro repo](https://github.com/hydro-project/hydro).
24// TODO(mingwei): rustdoc examples inline.
25#[proc_macro]
26pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27    dfir_syntax_internal(input, Some(Level::Help))
28}
29
30/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
31///
32/// Used for testing, users will want to use [`dfir_syntax!`] instead.
33#[proc_macro]
34pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35    dfir_syntax_internal(input, None)
36}
37
38fn root() -> proc_macro2::TokenStream {
39    use std::env::{VarError, var as env_var};
40
41    let root_crate_name = format!(
42        "{}_rs",
43        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
44    );
45    let root_crate_ident = root_crate_name.replace('-', "_");
46    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
47        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
48    match root_crate {
49        proc_macro_crate::FoundCrate::Itself => {
50            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
51                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
52                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
53            {
54                // In the crate itself, including unit tests.
55                quote! { crate }
56            } else {
57                // In an integration test, example, bench, etc.
58                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
59                quote! { ::#ident }
60            }
61        }
62        proc_macro_crate::FoundCrate::Name(name) => {
63            let ident = Ident::new(&name, Span::call_site());
64            quote! { ::#ident }
65        }
66    }
67}
68
69fn dfir_syntax_internal(
70    input: proc_macro::TokenStream,
71    retain_diagnostic_level: Option<Level>,
72) -> proc_macro::TokenStream {
73    let input = parse_macro_input!(input as DfirCode);
74    let root = root();
75
76    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
77        Ok(BuildDfirCodeOutput {
78            partitioned_graph: _,
79            code,
80            diagnostics,
81        }) => (code, diagnostics),
82        Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
83    };
84
85    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
86        diagnostics.retain_level(level);
87        diagnostics.try_emit_all().err()
88    });
89
90    quote! {
91        {
92            #diagnostic_tokens
93            #code
94        }
95    }
96    .into()
97}
98
99/// Parse DFIR syntax without emitting code.
100///
101/// Used for testing, users will want to use [`dfir_syntax!`] instead.
102#[proc_macro]
103pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
104    let input = parse_macro_input!(input as DfirCode);
105
106    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
107    let err_diagnostics = 'err: {
108        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
109            Ok(FlatGraphBuilderOutput {
110                flat_graph,
111                uses: _,
112                diagnostics,
113            }) => (flat_graph, diagnostics),
114            Err(diagnostics) => {
115                break 'err diagnostics;
116            }
117        };
118
119        if let Err(diagnostic) = flat_graph.merge_modules() {
120            diagnostics.push(diagnostic);
121            break 'err diagnostics;
122        }
123
124        let flat_mermaid = flat_graph.mermaid_string_flat();
125
126        let part_graph = partition_graph(flat_graph).unwrap();
127        let part_mermaid = part_graph.to_mermaid(&Default::default());
128
129        let lit0 = Literal::string(&flat_mermaid);
130        let lit1 = Literal::string(&part_mermaid);
131
132        return quote! {
133            {
134                println!("{}\n\n{}\n", #lit0, #lit1);
135            }
136        }
137        .into();
138    };
139
140    err_diagnostics
141        .try_emit_all()
142        .err()
143        .unwrap_or_default()
144        .into()
145}
146
147fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
148    use quote::ToTokens;
149
150    let root = root();
151
152    let mut input: syn::ItemFn = match syn::parse(item) {
153        Ok(it) => it,
154        Err(e) => return e.into_compile_error().into(),
155    };
156
157    let statements = input.block.stmts;
158
159    input.block.stmts = parse_quote!(
160        #root::tokio::task::LocalSet::new().run_until(async {
161            #( #statements )*
162        }).await
163    );
164
165    input.attrs.push(attribute);
166
167    input.into_token_stream().into()
168}
169
170/// Checks that the given closure is a morphism. For now does nothing.
171#[proc_macro]
172pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
173    // TODO(mingwei): some sort of code analysis?
174    item
175}
176
177/// Checks that the given closure is a monotonic function. For now does nothing.
178#[proc_macro]
179pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
180    // TODO(mingwei): some sort of code analysis?
181    item
182}
183
184#[proc_macro_attribute]
185pub fn dfir_test(
186    args: proc_macro::TokenStream,
187    item: proc_macro::TokenStream,
188) -> proc_macro::TokenStream {
189    let root = root();
190    let args_2: proc_macro2::TokenStream = args.into();
191
192    wrap_localset(
193        item,
194        parse_quote!(
195            #[#root::tokio::test(flavor = "current_thread", #args_2)]
196        ),
197    )
198}
199
200#[proc_macro_attribute]
201pub fn dfir_main(
202    _: proc_macro::TokenStream,
203    item: proc_macro::TokenStream,
204) -> proc_macro::TokenStream {
205    let root = root();
206
207    wrap_localset(
208        item,
209        parse_quote!(
210            #[#root::tokio::main(flavor = "current_thread")]
211        ),
212    )
213}
214
215#[proc_macro_derive(DemuxEnum)]
216pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217    let root = root();
218
219    let ItemEnum {
220        ident: item_ident,
221        generics,
222        variants,
223        ..
224    } = parse_macro_input!(item as ItemEnum);
225
226    // Sort variants alphabetically.
227    let mut variants = variants.into_iter().collect::<Vec<_>>();
228    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
229
230    // Return type for each variant.
231    let variant_output_types = variants
232        .iter()
233        .map(|variant| match &variant.fields {
234            Fields::Named(fields) => {
235                let field_types = fields.named.iter().map(|field| &field.ty);
236                quote! {
237                    ( #( #field_types, )* )
238                }
239            }
240            Fields::Unnamed(fields) => {
241                let field_types = fields.unnamed.iter().map(|field| &field.ty);
242                quote! {
243                    ( #( #field_types, )* )
244                }
245            }
246            Fields::Unit => quote!(()),
247        })
248        .collect::<Vec<_>>();
249
250    let variant_generics_sink = variants
251        .iter()
252        .map(|variant| format_ident!("__Sink{}", variant.ident))
253        .collect::<Vec<_>>();
254    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
255        quote_spanned! {ident.span()=>
256            ::std::pin::Pin::<&mut #ident>
257        }
258    });
259    let variant_generics_pinned_sink_all = quote! {
260        ( #( #variant_generics_pinned_sink, )* )
261    };
262    let variant_localvars_sink = variants
263        .iter()
264        .map(|variant| {
265            format_ident!(
266                "__sink_{}",
267                variant.ident.to_string().to_lowercase(),
268                span = variant.ident.span()
269            )
270        })
271        .collect::<Vec<_>>();
272
273    let mut full_generics_sink = generics.clone();
274    full_generics_sink.params.extend(
275        variant_generics_sink
276            .iter()
277            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
278    );
279    full_generics_sink.make_where_clause().predicates.extend(
280        variant_generics_sink
281            .iter()
282            .zip(variant_output_types.iter())
283            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
284                parse_quote! {
285                    // TODO(mingwei): generic error types?
286                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
287                }
288            }),
289    );
290
291    let variant_pats_sink_start_send =
292        variants
293            .iter()
294            .zip(variant_localvars_sink.iter())
295            .map(|(variant, sinkvar)| {
296                let Variant { ident, fields, .. } = variant;
297                let (fields_pat, push_item) = field_pattern_item(fields);
298                quote! {
299                    Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
300                }
301            });
302
303    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
304    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
305        full_generics_sink.split_for_impl();
306
307    let variant_generics_push = variants
308        .iter()
309        .map(|variant| format_ident!("__Push{}", variant.ident))
310        .collect::<Vec<_>>();
311    let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
312        quote_spanned! {ident.span()=>
313            ::std::pin::Pin::<&mut #ident>
314        }
315    });
316    let variant_generics_pinned_push_all = quote! {
317        ( #( #variant_generics_pinned_push, )* )
318    };
319    let variant_localvars_push = variants
320        .iter()
321        .map(|variant| {
322            format_ident!(
323                "__push_{}",
324                variant.ident.to_string().to_lowercase(),
325                span = variant.ident.span()
326            )
327        })
328        .collect::<Vec<_>>();
329
330    let mut full_generics_push = generics.clone();
331    full_generics_push.params.extend(
332        variant_generics_push
333            .iter()
334            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
335    );
336    // Each push just needs Push<Item = VariantOutput, Meta = ()>.
337    full_generics_push.make_where_clause().predicates.extend(
338        variant_generics_push
339            .iter()
340            .zip(variant_output_types.iter())
341            .map::<WherePredicate, _>(|(push_generic, output_type)| {
342                parse_quote! {
343                    #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
344                }
345            }),
346    );
347
348    // Build the recursive Merged Ctx type:
349    // For 0 pushes: `()
350    // For 1 push: `Push0::Ctx<'__ctx>`
351    // For 2 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push1::Ctx<'__ctx>>`
352    // For 3 pushes: `<Push0::Ctx<'__ctx> as Context<'__ctx>>::Merged<<Push1::Ctx<'__ctx> as Context<'__ctx>>::Merged<Push2::Ctx<'__ctx>>>`
353    let ctx_type = variant_generics_push
354        .iter()
355        .zip(variant_output_types.iter())
356        .rev()
357        .map(|(push_generic, output_type)| {
358            quote_spanned! {push_generic.span()=>
359                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
360            }
361        })
362        .reduce(|rest, next| {
363            quote_spanned! {next.span()=>
364                <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
365            }
366        })
367        .unwrap_or_else(|| quote!(()));
368
369    let can_pend = variant_generics_push
370        .iter()
371        .zip(variant_output_types.iter())
372        .rev()
373        .map(|(push_generic, output_type)| {
374            quote_spanned! {push_generic.span()=>
375                <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
376            }
377        })
378        .reduce(|rest, next| {
379            quote_spanned! {next.span()=>
380                <#next as #root::dfir_pipes::Toggle>::Or<#rest>
381            }
382        })
383        .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
384
385    // Generate `Ctx`: `unmerge_self` for each push, `unmerge_other` to get remaining `__ctx`.
386    // For the last push, just pass `__ctx` directly (no unmerge needed).
387    let push_poll_unwrap_context = |method_name: Ident| {
388        variant_localvars_push.split_last().map(|(lastvar, headvar)| {
389            // `#( ... )*` zips all iterators to shortest; `headvar` (all-but-last) is shortest, so
390            // `variant_generics_push` and `variant_output_types` are naturally truncated to match.
391            quote! {
392                #(
393                    let #headvar = {
394                        let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
395                        #root::dfir_pipes::push::Push::#method_name(#headvar.as_mut(), __ctx)
396                    };
397                    let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
398                )*
399                let #lastvar = #root::dfir_pipes::push::Push::#method_name(#lastvar.as_mut(), __ctx);
400                // If any are pending, return pending.
401                #(
402                    if #variant_localvars_push.is_pending() {
403                        return #root::dfir_pipes::push::PushStep::pending();
404                    }
405                )*
406            }
407        })
408    };
409    let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
410    let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
411
412    let variant_pats_push_send =
413        variants
414            .iter()
415            .zip(variant_localvars_push.iter())
416            .map(|(variant, pushvar)| {
417                let Variant { ident, fields, .. } = variant;
418                let (fields_pat, push_item) = field_pattern_item(fields);
419                quote! {
420                    Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
421                }
422            });
423
424    let (impl_generics_push, _ty_generics_push, where_clause_push) =
425        full_generics_push.split_for_impl();
426
427    let single_impl = (1 == variants.len()).then(|| {
428        let Variant { ident, fields, .. } = variants.first().unwrap();
429        let (fields_pat, push_item) = field_pattern_item(fields);
430        let out_type = variant_output_types.first().unwrap();
431        quote! {
432            impl #impl_generics_item #root::util::demux_enum::SingleVariant
433                for #item_ident #ty_generics #where_clause_item
434            {
435                type Output = #out_type;
436                fn single_variant(self) -> Self::Output {
437                    match self {
438                        Self::#ident #fields_pat => #push_item,
439                    }
440                }
441            }
442        }
443    });
444
445    quote! {
446        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
447            for #item_ident #ty_generics #where_clause_sink
448        {
449            type Error = #root::Never;
450
451            fn poll_ready(
452                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
453                __cx: &mut ::std::task::Context<'_>,
454            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
455                // Ready all sinks simultaneously.
456                #(
457                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
458                )*
459                #(
460                    ::std::task::ready!(#variant_localvars_sink);
461                )*
462                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
463            }
464
465            fn start_send(
466                self,
467                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
468            ) -> ::std::result::Result<(), Self::Error> {
469                match self {
470                    #( #variant_pats_sink_start_send, )*
471                }
472            }
473
474            fn poll_flush(
475                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
476                __cx: &mut ::std::task::Context<'_>,
477            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
478                // Flush all sinks simultaneously.
479                #(
480                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
481                )*
482                #(
483                    ::std::task::ready!(#variant_localvars_sink);
484                )*
485                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
486            }
487
488            fn poll_close(
489                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
490                __cx: &mut ::std::task::Context<'_>,
491            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
492                // Close all sinks simultaneously.
493                #(
494                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
495                )*
496                #(
497                    ::std::task::ready!(#variant_localvars_sink);
498                )*
499                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
500            }
501        }
502
503        impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
504            for #item_ident #ty_generics #where_clause_push
505        {
506            type Ctx<'__ctx> = #ctx_type;
507            type CanPend = #can_pend;
508
509            fn poll_ready(
510                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
511                __ctx: &mut Self::Ctx<'_>,
512            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
513                #push_poll_ready_body
514                #root::dfir_pipes::push::PushStep::Done
515            }
516
517            fn start_send(
518                self,
519                __meta: (),
520                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
521            ) {
522                match self {
523                    #( #variant_pats_push_send, )*
524                }
525            }
526
527            fn poll_flush(
528                ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
529                __ctx: &mut Self::Ctx<'_>,
530            ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
531                #push_poll_flush_body
532                #root::dfir_pipes::push::PushStep::Done
533            }
534        }
535
536        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
537            for #item_ident #ty_generics #where_clause_item {}
538
539        #single_impl
540    }
541    .into()
542}
543
544/// (fields pattern, push item expr)
545fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
546    let idents = fields
547        .iter()
548        .enumerate()
549        .map(|(i, field)| {
550            field
551                .ident
552                .clone()
553                .unwrap_or_else(|| format_ident!("_{}", i))
554        })
555        .collect::<Vec<_>>();
556    let (fields_pat, push_item) = match fields {
557        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
558        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
559        Fields::Unit => (quote!(), quote!(())),
560    };
561    (fields_pat, push_item)
562}