Skip to main content

dfir_macro/
lib.rs

1#![cfg_attr(
2    nightly,
3    feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8    BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::{
14    Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
15    parse_quote,
16};
17
18/// Create a runnable graph instance using DFIR's custom syntax.
19///
20/// For example usage, take a look at the [`surface_*` tests in the `tests` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/tests)
21/// or the [`examples` folder](https://github.com/hydro-project/hydro/tree/main/dfir_rs/examples)
22/// in the [Hydro repo](https://github.com/hydro-project/hydro).
23// TODO(mingwei): rustdoc examples inline.
24#[proc_macro]
25pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
26    dfir_syntax_internal(input, Some(Level::Help))
27}
28
29/// [`dfir_syntax!`] but will not emit any diagnostics (errors, warnings, etc.).
30///
31/// Used for testing, users will want to use [`dfir_syntax!`] instead.
32#[proc_macro]
33pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34    dfir_syntax_internal(input, None)
35}
36
37fn root() -> proc_macro2::TokenStream {
38    use std::env::{VarError, var as env_var};
39
40    let root_crate_name = format!(
41        "{}_rs",
42        env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
43    );
44    let root_crate_ident = root_crate_name.replace('-', "_");
45    let root_crate = proc_macro_crate::crate_name(&root_crate_name)
46        .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
47    match root_crate {
48        proc_macro_crate::FoundCrate::Itself => {
49            if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
50                && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
51                && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
52            {
53                // In the crate itself, including unit tests.
54                quote! { crate }
55            } else {
56                // In an integration test, example, bench, etc.
57                let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
58                quote! { ::#ident }
59            }
60        }
61        proc_macro_crate::FoundCrate::Name(name) => {
62            let ident = Ident::new(&name, Span::call_site());
63            quote! { ::#ident }
64        }
65    }
66}
67
68fn dfir_syntax_internal(
69    input: proc_macro::TokenStream,
70    retain_diagnostic_level: Option<Level>,
71) -> proc_macro::TokenStream {
72    let input = parse_macro_input!(input as DfirCode);
73    let root = root();
74
75    let (code, mut diagnostics) = match build_dfir_code(input, &root) {
76        Ok(BuildDfirCodeOutput {
77            partitioned_graph: _,
78            code,
79            diagnostics,
80        }) => (code, diagnostics),
81        Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
82    };
83
84    let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
85        diagnostics.retain_level(level);
86        diagnostics.try_emit_all().err()
87    });
88
89    quote! {
90        {
91            #diagnostic_tokens
92            #code
93        }
94    }
95    .into()
96}
97
98/// Parse DFIR syntax without emitting code.
99///
100/// Used for testing, users will want to use [`dfir_syntax!`] instead.
101#[proc_macro]
102pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
103    let input = parse_macro_input!(input as DfirCode);
104
105    let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
106    let err_diagnostics = 'err: {
107        let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
108            Ok(FlatGraphBuilderOutput {
109                flat_graph,
110                uses: _,
111                diagnostics,
112            }) => (flat_graph, diagnostics),
113            Err(diagnostics) => {
114                break 'err diagnostics;
115            }
116        };
117
118        if let Err(diagnostic) = flat_graph.merge_modules() {
119            diagnostics.push(diagnostic);
120            break 'err diagnostics;
121        }
122
123        let flat_mermaid = flat_graph.mermaid_string_flat();
124
125        let part_graph = partition_graph(flat_graph).unwrap();
126        let part_mermaid = part_graph.to_mermaid(&Default::default());
127
128        let lit0 = Literal::string(&flat_mermaid);
129        let lit1 = Literal::string(&part_mermaid);
130
131        return quote! {
132            {
133                println!("{}\n\n{}\n", #lit0, #lit1);
134            }
135        }
136        .into();
137    };
138
139    err_diagnostics
140        .try_emit_all()
141        .err()
142        .unwrap_or_default()
143        .into()
144}
145
146fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
147    use quote::ToTokens;
148
149    let root = root();
150
151    let mut input: syn::ItemFn = match syn::parse(item) {
152        Ok(it) => it,
153        Err(e) => return e.into_compile_error().into(),
154    };
155
156    let statements = input.block.stmts;
157
158    input.block.stmts = parse_quote!(
159        #root::tokio::task::LocalSet::new().run_until(async {
160            #( #statements )*
161        }).await
162    );
163
164    input.attrs.push(attribute);
165
166    input.into_token_stream().into()
167}
168
169/// Checks that the given closure is a morphism. For now does nothing.
170#[proc_macro]
171pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
172    // TODO(mingwei): some sort of code analysis?
173    item
174}
175
176/// Checks that the given closure is a monotonic function. For now does nothing.
177#[proc_macro]
178pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
179    // TODO(mingwei): some sort of code analysis?
180    item
181}
182
183#[proc_macro_attribute]
184pub fn dfir_test(
185    args: proc_macro::TokenStream,
186    item: proc_macro::TokenStream,
187) -> proc_macro::TokenStream {
188    let root = root();
189    let args_2: proc_macro2::TokenStream = args.into();
190
191    wrap_localset(
192        item,
193        parse_quote!(
194            #[#root::tokio::test(flavor = "current_thread", #args_2)]
195        ),
196    )
197}
198
199#[proc_macro_attribute]
200pub fn dfir_main(
201    _: proc_macro::TokenStream,
202    item: proc_macro::TokenStream,
203) -> proc_macro::TokenStream {
204    let root = root();
205
206    wrap_localset(
207        item,
208        parse_quote!(
209            #[#root::tokio::main(flavor = "current_thread")]
210        ),
211    )
212}
213
214#[proc_macro_derive(DemuxEnum)]
215pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
216    let root = root();
217
218    let ItemEnum {
219        ident: item_ident,
220        generics,
221        variants,
222        ..
223    } = parse_macro_input!(item as ItemEnum);
224
225    // Sort variants alphabetically.
226    let mut variants = variants.into_iter().collect::<Vec<_>>();
227    variants.sort_by(|a, b| a.ident.cmp(&b.ident));
228
229    // Return type for each variant.
230    let variant_output_types = variants
231        .iter()
232        .map(|variant| match &variant.fields {
233            Fields::Named(fields) => {
234                let field_types = fields.named.iter().map(|field| &field.ty);
235                quote! {
236                    ( #( #field_types, )* )
237                }
238            }
239            Fields::Unnamed(fields) => {
240                let field_types = fields.unnamed.iter().map(|field| &field.ty);
241                quote! {
242                    ( #( #field_types, )* )
243                }
244            }
245            Fields::Unit => quote!(()),
246        })
247        .collect::<Vec<_>>();
248
249    let variant_generics_sink = variants
250        .iter()
251        .map(|variant| format_ident!("__Sink{}", variant.ident))
252        .collect::<Vec<_>>();
253    let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
254        quote_spanned! {ident.span()=>
255            ::std::pin::Pin::<&mut #ident>
256        }
257    });
258    let variant_generics_pinned_sink_all = quote! {
259        ( #( #variant_generics_pinned_sink, )* )
260    };
261    let variant_localvars_sink = variants
262        .iter()
263        .map(|variant| {
264            format_ident!(
265                "__sink_{}",
266                variant.ident.to_string().to_lowercase(),
267                span = variant.ident.span()
268            )
269        })
270        .collect::<Vec<_>>();
271
272    let mut full_generics_sink = generics.clone();
273    full_generics_sink.params.extend(
274        variant_generics_sink
275            .iter()
276            .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
277    );
278    full_generics_sink.make_where_clause().predicates.extend(
279        variant_generics_sink
280            .iter()
281            .zip(variant_output_types.iter())
282            .map::<WherePredicate, _>(|(sink_generic, output_type)| {
283                parse_quote! {
284                    // TODO(mingwei): generic error types?
285                    #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
286                }
287            }),
288    );
289
290    let variant_pats_sink_start_send =
291        variants
292            .iter()
293            .zip(variant_localvars_sink.iter())
294            .map(|(variant, sinkvar)| {
295                let Variant { ident, fields, .. } = variant;
296                let (fields_pat, push_item) = field_pattern_item(fields);
297                quote! {
298                    Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
299                }
300            });
301
302    let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
303    let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
304        full_generics_sink.split_for_impl();
305
306    let single_impl = (1 == variants.len()).then(|| {
307        let Variant { ident, fields, .. } = variants.first().unwrap();
308        let (fields_pat, push_item) = field_pattern_item(fields);
309        let out_type = variant_output_types.first().unwrap();
310        quote! {
311            impl #impl_generics_item #root::util::demux_enum::SingleVariant
312                for #item_ident #ty_generics #where_clause_item
313            {
314                type Output = #out_type;
315                fn single_variant(self) -> Self::Output {
316                    match self {
317                        Self::#ident #fields_pat => #push_item,
318                    }
319                }
320            }
321        }
322    });
323
324    quote! {
325        impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
326            for #item_ident #ty_generics #where_clause_sink
327        {
328            type Error = #root::Never;
329
330            fn poll_ready(
331                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
332                __cx: &mut ::std::task::Context<'_>,
333            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
334                // Ready all sinks simultaneously.
335                #(
336                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
337                )*
338                #(
339                    ::std::task::ready!(#variant_localvars_sink);
340                )*
341                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
342            }
343
344            fn start_send(
345                self,
346                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
347            ) -> ::std::result::Result<(), Self::Error> {
348                match self {
349                    #( #variant_pats_sink_start_send, )*
350                }
351            }
352
353            fn poll_flush(
354                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
355                __cx: &mut ::std::task::Context<'_>,
356            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
357                // Flush all sinks simultaneously.
358                #(
359                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
360                )*
361                #(
362                    ::std::task::ready!(#variant_localvars_sink);
363                )*
364                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
365            }
366
367            fn poll_close(
368                ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
369                __cx: &mut ::std::task::Context<'_>,
370            ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
371                // Close all sinks simultaneously.
372                #(
373                    let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
374                )*
375                #(
376                    ::std::task::ready!(#variant_localvars_sink);
377                )*
378                ::std::task::Poll::Ready(::std::result::Result::Ok(()))
379            }
380        }
381
382        impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
383            for #item_ident #ty_generics #where_clause_item {}
384
385        #single_impl
386    }
387    .into()
388}
389
390/// (fields pattern, push item expr)
391fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
392    let idents = fields
393        .iter()
394        .enumerate()
395        .map(|(i, field)| {
396            field
397                .ident
398                .clone()
399                .unwrap_or_else(|| format_ident!("_{}", i))
400        })
401        .collect::<Vec<_>>();
402    let (fields_pat, push_item) = match fields {
403        Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
404        Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
405        Fields::Unit => (quote!(), quote!(())),
406    };
407    (fields_pat, push_item)
408}