1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8 BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::{
14 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
15 parse_quote,
16};
17
18#[proc_macro]
25pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
26 dfir_syntax_internal(input, Some(Level::Help))
27}
28
29#[proc_macro]
33pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34 dfir_syntax_internal(input, None)
35}
36
37fn root() -> proc_macro2::TokenStream {
38 use std::env::{VarError, var as env_var};
39
40 let root_crate_name = format!(
41 "{}_rs",
42 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
43 );
44 let root_crate_ident = root_crate_name.replace('-', "_");
45 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
46 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
47 match root_crate {
48 proc_macro_crate::FoundCrate::Itself => {
49 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
50 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
51 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
52 {
53 quote! { crate }
55 } else {
56 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
58 quote! { ::#ident }
59 }
60 }
61 proc_macro_crate::FoundCrate::Name(name) => {
62 let ident = Ident::new(&name, Span::call_site());
63 quote! { ::#ident }
64 }
65 }
66}
67
68fn dfir_syntax_internal(
69 input: proc_macro::TokenStream,
70 retain_diagnostic_level: Option<Level>,
71) -> proc_macro::TokenStream {
72 let input = parse_macro_input!(input as DfirCode);
73 let root = root();
74
75 let (code, mut diagnostics) = match build_dfir_code(input, &root) {
76 Ok(BuildDfirCodeOutput {
77 partitioned_graph: _,
78 code,
79 diagnostics,
80 }) => (code, diagnostics),
81 Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
82 };
83
84 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
85 diagnostics.retain_level(level);
86 diagnostics.try_emit_all().err()
87 });
88
89 quote! {
90 {
91 #diagnostic_tokens
92 #code
93 }
94 }
95 .into()
96}
97
98#[proc_macro]
102pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
103 let input = parse_macro_input!(input as DfirCode);
104
105 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
106 let err_diagnostics = 'err: {
107 let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
108 Ok(FlatGraphBuilderOutput {
109 flat_graph,
110 uses: _,
111 diagnostics,
112 }) => (flat_graph, diagnostics),
113 Err(diagnostics) => {
114 break 'err diagnostics;
115 }
116 };
117
118 if let Err(diagnostic) = flat_graph.merge_modules() {
119 diagnostics.push(diagnostic);
120 break 'err diagnostics;
121 }
122
123 let flat_mermaid = flat_graph.mermaid_string_flat();
124
125 let part_graph = partition_graph(flat_graph).unwrap();
126 let part_mermaid = part_graph.to_mermaid(&Default::default());
127
128 let lit0 = Literal::string(&flat_mermaid);
129 let lit1 = Literal::string(&part_mermaid);
130
131 return quote! {
132 {
133 println!("{}\n\n{}\n", #lit0, #lit1);
134 }
135 }
136 .into();
137 };
138
139 err_diagnostics
140 .try_emit_all()
141 .err()
142 .unwrap_or_default()
143 .into()
144}
145
146fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
147 use quote::ToTokens;
148
149 let root = root();
150
151 let mut input: syn::ItemFn = match syn::parse(item) {
152 Ok(it) => it,
153 Err(e) => return e.into_compile_error().into(),
154 };
155
156 let statements = input.block.stmts;
157
158 input.block.stmts = parse_quote!(
159 #root::tokio::task::LocalSet::new().run_until(async {
160 #( #statements )*
161 }).await
162 );
163
164 input.attrs.push(attribute);
165
166 input.into_token_stream().into()
167}
168
169#[proc_macro]
171pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
172 item
174}
175
176#[proc_macro]
178pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
179 item
181}
182
183#[proc_macro_attribute]
184pub fn dfir_test(
185 args: proc_macro::TokenStream,
186 item: proc_macro::TokenStream,
187) -> proc_macro::TokenStream {
188 let root = root();
189 let args_2: proc_macro2::TokenStream = args.into();
190
191 wrap_localset(
192 item,
193 parse_quote!(
194 #[#root::tokio::test(flavor = "current_thread", #args_2)]
195 ),
196 )
197}
198
199#[proc_macro_attribute]
200pub fn dfir_main(
201 _: proc_macro::TokenStream,
202 item: proc_macro::TokenStream,
203) -> proc_macro::TokenStream {
204 let root = root();
205
206 wrap_localset(
207 item,
208 parse_quote!(
209 #[#root::tokio::main(flavor = "current_thread")]
210 ),
211 )
212}
213
214#[proc_macro_derive(DemuxEnum)]
215pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
216 let root = root();
217
218 let ItemEnum {
219 ident: item_ident,
220 generics,
221 variants,
222 ..
223 } = parse_macro_input!(item as ItemEnum);
224
225 let mut variants = variants.into_iter().collect::<Vec<_>>();
227 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
228
229 let variant_output_types = variants
231 .iter()
232 .map(|variant| match &variant.fields {
233 Fields::Named(fields) => {
234 let field_types = fields.named.iter().map(|field| &field.ty);
235 quote! {
236 ( #( #field_types, )* )
237 }
238 }
239 Fields::Unnamed(fields) => {
240 let field_types = fields.unnamed.iter().map(|field| &field.ty);
241 quote! {
242 ( #( #field_types, )* )
243 }
244 }
245 Fields::Unit => quote!(()),
246 })
247 .collect::<Vec<_>>();
248
249 let variant_generics_sink = variants
250 .iter()
251 .map(|variant| format_ident!("__Sink{}", variant.ident))
252 .collect::<Vec<_>>();
253 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
254 quote_spanned! {ident.span()=>
255 ::std::pin::Pin::<&mut #ident>
256 }
257 });
258 let variant_generics_pinned_sink_all = quote! {
259 ( #( #variant_generics_pinned_sink, )* )
260 };
261 let variant_localvars_sink = variants
262 .iter()
263 .map(|variant| {
264 format_ident!(
265 "__sink_{}",
266 variant.ident.to_string().to_lowercase(),
267 span = variant.ident.span()
268 )
269 })
270 .collect::<Vec<_>>();
271
272 let mut full_generics_sink = generics.clone();
273 full_generics_sink.params.extend(
274 variant_generics_sink
275 .iter()
276 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
277 );
278 full_generics_sink.make_where_clause().predicates.extend(
279 variant_generics_sink
280 .iter()
281 .zip(variant_output_types.iter())
282 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
283 parse_quote! {
284 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
286 }
287 }),
288 );
289
290 let variant_pats_sink_start_send =
291 variants
292 .iter()
293 .zip(variant_localvars_sink.iter())
294 .map(|(variant, sinkvar)| {
295 let Variant { ident, fields, .. } = variant;
296 let (fields_pat, push_item) = field_pattern_item(fields);
297 quote! {
298 Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
299 }
300 });
301
302 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
303 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
304 full_generics_sink.split_for_impl();
305
306 let single_impl = (1 == variants.len()).then(|| {
307 let Variant { ident, fields, .. } = variants.first().unwrap();
308 let (fields_pat, push_item) = field_pattern_item(fields);
309 let out_type = variant_output_types.first().unwrap();
310 quote! {
311 impl #impl_generics_item #root::util::demux_enum::SingleVariant
312 for #item_ident #ty_generics #where_clause_item
313 {
314 type Output = #out_type;
315 fn single_variant(self) -> Self::Output {
316 match self {
317 Self::#ident #fields_pat => #push_item,
318 }
319 }
320 }
321 }
322 });
323
324 quote! {
325 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
326 for #item_ident #ty_generics #where_clause_sink
327 {
328 type Error = #root::Never;
329
330 fn poll_ready(
331 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
332 __cx: &mut ::std::task::Context<'_>,
333 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
334 #(
336 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
337 )*
338 #(
339 ::std::task::ready!(#variant_localvars_sink);
340 )*
341 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
342 }
343
344 fn start_send(
345 self,
346 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
347 ) -> ::std::result::Result<(), Self::Error> {
348 match self {
349 #( #variant_pats_sink_start_send, )*
350 }
351 }
352
353 fn poll_flush(
354 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
355 __cx: &mut ::std::task::Context<'_>,
356 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
357 #(
359 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
360 )*
361 #(
362 ::std::task::ready!(#variant_localvars_sink);
363 )*
364 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
365 }
366
367 fn poll_close(
368 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
369 __cx: &mut ::std::task::Context<'_>,
370 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
371 #(
373 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
374 )*
375 #(
376 ::std::task::ready!(#variant_localvars_sink);
377 )*
378 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
379 }
380 }
381
382 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
383 for #item_ident #ty_generics #where_clause_item {}
384
385 #single_impl
386 }
387 .into()
388}
389
390fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
392 let idents = fields
393 .iter()
394 .enumerate()
395 .map(|(i, field)| {
396 field
397 .ident
398 .clone()
399 .unwrap_or_else(|| format_ident!("_{}", i))
400 })
401 .collect::<Vec<_>>();
402 let (fields_pat, push_item) = match fields {
403 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
404 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
405 Fields::Unit => (quote!(), quote!(())),
406 };
407 (fields_pat, push_item)
408}