1#![cfg_attr(
2 nightly,
3 feature(proc_macro_diagnostic, proc_macro_span, proc_macro_def_site)
4)]
5
6use dfir_lang::diagnostic::Level;
7use dfir_lang::graph::{
8 BuildDfirCodeOutput, FlatGraphBuilder, FlatGraphBuilderOutput, build_dfir_code, partition_graph,
9};
10use dfir_lang::parse::DfirCode;
11use proc_macro2::{Ident, Literal, Span};
12use quote::{format_ident, quote, quote_spanned};
13use syn::spanned::Spanned;
14use syn::{
15 Attribute, Fields, GenericParam, ItemEnum, Variant, WherePredicate, parse_macro_input,
16 parse_quote,
17};
18
19#[proc_macro]
26pub fn dfir_syntax(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27 dfir_syntax_internal(input, Some(Level::Help))
28}
29
30#[proc_macro]
34pub fn dfir_syntax_noemit(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
35 dfir_syntax_internal(input, None)
36}
37
38fn root() -> proc_macro2::TokenStream {
39 use std::env::{VarError, var as env_var};
40
41 let root_crate_name = format!(
42 "{}_rs",
43 env!("CARGO_PKG_NAME").strip_suffix("_macro").unwrap()
44 );
45 let root_crate_ident = root_crate_name.replace('-', "_");
46 let root_crate = proc_macro_crate::crate_name(&root_crate_name)
47 .unwrap_or_else(|_| panic!("{root_crate_name} should be present in `Cargo.toml`"));
48 match root_crate {
49 proc_macro_crate::FoundCrate::Itself => {
50 if Err(VarError::NotPresent) == env_var("CARGO_BIN_NAME")
51 && Err(VarError::NotPresent) != env_var("CARGO_PRIMARY_PACKAGE")
52 && Ok(&*root_crate_ident) == env_var("CARGO_CRATE_NAME").as_deref()
53 {
54 quote! { crate }
56 } else {
57 let ident: Ident = Ident::new(&root_crate_ident, Span::call_site());
59 quote! { ::#ident }
60 }
61 }
62 proc_macro_crate::FoundCrate::Name(name) => {
63 let ident = Ident::new(&name, Span::call_site());
64 quote! { ::#ident }
65 }
66 }
67}
68
69fn dfir_syntax_internal(
70 input: proc_macro::TokenStream,
71 retain_diagnostic_level: Option<Level>,
72) -> proc_macro::TokenStream {
73 let input = parse_macro_input!(input as DfirCode);
74 let root = root();
75
76 let (code, mut diagnostics) = match build_dfir_code(input, &root) {
77 Ok(BuildDfirCodeOutput {
78 partitioned_graph: _,
79 code,
80 diagnostics,
81 }) => (code, diagnostics),
82 Err(diagnostics) => (quote! { #root::scheduled::graph::Dfir::new() }, diagnostics),
83 };
84
85 let diagnostic_tokens = retain_diagnostic_level.and_then(|level| {
86 diagnostics.retain_level(level);
87 diagnostics.try_emit_all().err()
88 });
89
90 quote! {
91 {
92 #diagnostic_tokens
93 #code
94 }
95 }
96 .into()
97}
98
99#[proc_macro]
103pub fn dfir_parser(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
104 let input = parse_macro_input!(input as DfirCode);
105
106 let flat_graph_builder = FlatGraphBuilder::from_dfir(input);
107 let err_diagnostics = 'err: {
108 let (mut flat_graph, mut diagnostics) = match flat_graph_builder.build() {
109 Ok(FlatGraphBuilderOutput {
110 flat_graph,
111 uses: _,
112 diagnostics,
113 }) => (flat_graph, diagnostics),
114 Err(diagnostics) => {
115 break 'err diagnostics;
116 }
117 };
118
119 if let Err(diagnostic) = flat_graph.merge_modules() {
120 diagnostics.push(diagnostic);
121 break 'err diagnostics;
122 }
123
124 let flat_mermaid = flat_graph.mermaid_string_flat();
125
126 let part_graph = partition_graph(flat_graph).unwrap();
127 let part_mermaid = part_graph.to_mermaid(&Default::default());
128
129 let lit0 = Literal::string(&flat_mermaid);
130 let lit1 = Literal::string(&part_mermaid);
131
132 return quote! {
133 {
134 println!("{}\n\n{}\n", #lit0, #lit1);
135 }
136 }
137 .into();
138 };
139
140 err_diagnostics
141 .try_emit_all()
142 .err()
143 .unwrap_or_default()
144 .into()
145}
146
147fn wrap_localset(item: proc_macro::TokenStream, attribute: Attribute) -> proc_macro::TokenStream {
148 use quote::ToTokens;
149
150 let root = root();
151
152 let mut input: syn::ItemFn = match syn::parse(item) {
153 Ok(it) => it,
154 Err(e) => return e.into_compile_error().into(),
155 };
156
157 let statements = input.block.stmts;
158
159 input.block.stmts = parse_quote!(
160 #root::tokio::task::LocalSet::new().run_until(async {
161 #( #statements )*
162 }).await
163 );
164
165 input.attrs.push(attribute);
166
167 input.into_token_stream().into()
168}
169
170#[proc_macro]
172pub fn morphism(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
173 item
175}
176
177#[proc_macro]
179pub fn monotonic_fn(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
180 item
182}
183
184#[proc_macro_attribute]
185pub fn dfir_test(
186 args: proc_macro::TokenStream,
187 item: proc_macro::TokenStream,
188) -> proc_macro::TokenStream {
189 let root = root();
190 let args_2: proc_macro2::TokenStream = args.into();
191
192 wrap_localset(
193 item,
194 parse_quote!(
195 #[#root::tokio::test(flavor = "current_thread", #args_2)]
196 ),
197 )
198}
199
200#[proc_macro_attribute]
201pub fn dfir_main(
202 _: proc_macro::TokenStream,
203 item: proc_macro::TokenStream,
204) -> proc_macro::TokenStream {
205 let root = root();
206
207 wrap_localset(
208 item,
209 parse_quote!(
210 #[#root::tokio::main(flavor = "current_thread")]
211 ),
212 )
213}
214
215#[proc_macro_derive(DemuxEnum)]
216pub fn derive_demux_enum(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
217 let root = root();
218
219 let ItemEnum {
220 ident: item_ident,
221 generics,
222 variants,
223 ..
224 } = parse_macro_input!(item as ItemEnum);
225
226 let mut variants = variants.into_iter().collect::<Vec<_>>();
228 variants.sort_by(|a, b| a.ident.cmp(&b.ident));
229
230 let variant_output_types = variants
232 .iter()
233 .map(|variant| match &variant.fields {
234 Fields::Named(fields) => {
235 let field_types = fields.named.iter().map(|field| &field.ty);
236 quote! {
237 ( #( #field_types, )* )
238 }
239 }
240 Fields::Unnamed(fields) => {
241 let field_types = fields.unnamed.iter().map(|field| &field.ty);
242 quote! {
243 ( #( #field_types, )* )
244 }
245 }
246 Fields::Unit => quote!(()),
247 })
248 .collect::<Vec<_>>();
249
250 let variant_generics_sink = variants
251 .iter()
252 .map(|variant| format_ident!("__Sink{}", variant.ident))
253 .collect::<Vec<_>>();
254 let variant_generics_pinned_sink = variant_generics_sink.iter().map(|ident| {
255 quote_spanned! {ident.span()=>
256 ::std::pin::Pin::<&mut #ident>
257 }
258 });
259 let variant_generics_pinned_sink_all = quote! {
260 ( #( #variant_generics_pinned_sink, )* )
261 };
262 let variant_localvars_sink = variants
263 .iter()
264 .map(|variant| {
265 format_ident!(
266 "__sink_{}",
267 variant.ident.to_string().to_lowercase(),
268 span = variant.ident.span()
269 )
270 })
271 .collect::<Vec<_>>();
272
273 let mut full_generics_sink = generics.clone();
274 full_generics_sink.params.extend(
275 variant_generics_sink
276 .iter()
277 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
278 );
279 full_generics_sink.make_where_clause().predicates.extend(
280 variant_generics_sink
281 .iter()
282 .zip(variant_output_types.iter())
283 .map::<WherePredicate, _>(|(sink_generic, output_type)| {
284 parse_quote! {
285 #sink_generic: #root::futures::sink::Sink<#output_type, Error = #root::Never>
287 }
288 }),
289 );
290
291 let variant_pats_sink_start_send =
292 variants
293 .iter()
294 .zip(variant_localvars_sink.iter())
295 .map(|(variant, sinkvar)| {
296 let Variant { ident, fields, .. } = variant;
297 let (fields_pat, push_item) = field_pattern_item(fields);
298 quote! {
299 Self::#ident #fields_pat => #sinkvar.as_mut().start_send(#push_item)
300 }
301 });
302
303 let (impl_generics_item, ty_generics, where_clause_item) = generics.split_for_impl();
304 let (impl_generics_sink, _ty_generics_sink, where_clause_sink) =
305 full_generics_sink.split_for_impl();
306
307 let variant_generics_push = variants
308 .iter()
309 .map(|variant| format_ident!("__Push{}", variant.ident))
310 .collect::<Vec<_>>();
311 let variant_generics_pinned_push = variant_generics_push.iter().map(|ident| {
312 quote_spanned! {ident.span()=>
313 ::std::pin::Pin::<&mut #ident>
314 }
315 });
316 let variant_generics_pinned_push_all = quote! {
317 ( #( #variant_generics_pinned_push, )* )
318 };
319 let variant_localvars_push = variants
320 .iter()
321 .map(|variant| {
322 format_ident!(
323 "__push_{}",
324 variant.ident.to_string().to_lowercase(),
325 span = variant.ident.span()
326 )
327 })
328 .collect::<Vec<_>>();
329
330 let mut full_generics_push = generics.clone();
331 full_generics_push.params.extend(
332 variant_generics_push
333 .iter()
334 .map::<GenericParam, _>(|ident| parse_quote!(#ident)),
335 );
336 full_generics_push.make_where_clause().predicates.extend(
338 variant_generics_push
339 .iter()
340 .zip(variant_output_types.iter())
341 .map::<WherePredicate, _>(|(push_generic, output_type)| {
342 parse_quote! {
343 #push_generic: #root::dfir_pipes::push::Push<#output_type, ()>
344 }
345 }),
346 );
347
348 let ctx_type = variant_generics_push
354 .iter()
355 .zip(variant_output_types.iter())
356 .rev()
357 .map(|(push_generic, output_type)| {
358 quote_spanned! {push_generic.span()=>
359 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::Ctx<'__ctx>
360 }
361 })
362 .reduce(|rest, next| {
363 quote_spanned! {next.span()=>
364 <#next as #root::dfir_pipes::Context<'__ctx>>::Merged<#rest>
365 }
366 })
367 .unwrap_or_else(|| quote!(()));
368
369 let can_pend = variant_generics_push
370 .iter()
371 .zip(variant_output_types.iter())
372 .rev()
373 .map(|(push_generic, output_type)| {
374 quote_spanned! {push_generic.span()=>
375 <#push_generic as #root::dfir_pipes::push::Push<#output_type, ()>>::CanPend
376 }
377 })
378 .reduce(|rest, next| {
379 quote_spanned! {next.span()=>
380 <#next as #root::dfir_pipes::Toggle>::Or<#rest>
381 }
382 })
383 .unwrap_or_else(|| quote!(#root::dfir_pipes::No));
384
385 let push_poll_unwrap_context = |method_name: Ident| {
388 variant_localvars_push.split_last().map(|(lastvar, headvar)| {
389 quote! {
392 #(
393 let #headvar = {
394 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_self(__ctx);
395 #root::dfir_pipes::push::Push::#method_name(#headvar.as_mut(), __ctx)
396 };
397 let __ctx = <<#variant_generics_push as #root::dfir_pipes::push::Push<#variant_output_types, ()>>::Ctx<'_> as #root::dfir_pipes::Context<'_>>::unmerge_other(__ctx);
398 )*
399 let #lastvar = #root::dfir_pipes::push::Push::#method_name(#lastvar.as_mut(), __ctx);
400 #(
402 if #variant_localvars_push.is_pending() {
403 return #root::dfir_pipes::push::PushStep::pending();
404 }
405 )*
406 }
407 })
408 };
409 let push_poll_ready_body = (push_poll_unwrap_context)(format_ident!("poll_ready"));
410 let push_poll_flush_body = (push_poll_unwrap_context)(format_ident!("poll_flush"));
411
412 let variant_pats_push_send =
413 variants
414 .iter()
415 .zip(variant_localvars_push.iter())
416 .map(|(variant, pushvar)| {
417 let Variant { ident, fields, .. } = variant;
418 let (fields_pat, push_item) = field_pattern_item(fields);
419 quote! {
420 Self::#ident #fields_pat => { #root::dfir_pipes::push::Push::start_send(#pushvar.as_mut(), #push_item, __meta); }
421 }
422 });
423
424 let (impl_generics_push, _ty_generics_push, where_clause_push) =
425 full_generics_push.split_for_impl();
426
427 let single_impl = (1 == variants.len()).then(|| {
428 let Variant { ident, fields, .. } = variants.first().unwrap();
429 let (fields_pat, push_item) = field_pattern_item(fields);
430 let out_type = variant_output_types.first().unwrap();
431 quote! {
432 impl #impl_generics_item #root::util::demux_enum::SingleVariant
433 for #item_ident #ty_generics #where_clause_item
434 {
435 type Output = #out_type;
436 fn single_variant(self) -> Self::Output {
437 match self {
438 Self::#ident #fields_pat => #push_item,
439 }
440 }
441 }
442 }
443 });
444
445 quote! {
446 impl #impl_generics_sink #root::util::demux_enum::DemuxEnumSink<#variant_generics_pinned_sink_all>
447 for #item_ident #ty_generics #where_clause_sink
448 {
449 type Error = #root::Never;
450
451 fn poll_ready(
452 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
453 __cx: &mut ::std::task::Context<'_>,
454 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
455 #(
457 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_ready(__cx)?;
458 )*
459 #(
460 ::std::task::ready!(#variant_localvars_sink);
461 )*
462 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
463 }
464
465 fn start_send(
466 self,
467 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
468 ) -> ::std::result::Result<(), Self::Error> {
469 match self {
470 #( #variant_pats_sink_start_send, )*
471 }
472 }
473
474 fn poll_flush(
475 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
476 __cx: &mut ::std::task::Context<'_>,
477 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
478 #(
480 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_flush(__cx)?;
481 )*
482 #(
483 ::std::task::ready!(#variant_localvars_sink);
484 )*
485 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
486 }
487
488 fn poll_close(
489 ( #( #variant_localvars_sink, )* ): &mut #variant_generics_pinned_sink_all,
490 __cx: &mut ::std::task::Context<'_>,
491 ) -> ::std::task::Poll<::std::result::Result<(), Self::Error>> {
492 #(
494 let #variant_localvars_sink = #variant_localvars_sink.as_mut().poll_close(__cx)?;
495 )*
496 #(
497 ::std::task::ready!(#variant_localvars_sink);
498 )*
499 ::std::task::Poll::Ready(::std::result::Result::Ok(()))
500 }
501 }
502
503 impl #impl_generics_push #root::util::demux_enum::DemuxEnumPush<#variant_generics_pinned_push_all, ()>
504 for #item_ident #ty_generics #where_clause_push
505 {
506 type Ctx<'__ctx> = #ctx_type;
507 type CanPend = #can_pend;
508
509 fn poll_ready(
510 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
511 __ctx: &mut Self::Ctx<'_>,
512 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
513 #push_poll_ready_body
514 #root::dfir_pipes::push::PushStep::Done
515 }
516
517 fn start_send(
518 self,
519 __meta: (),
520 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
521 ) {
522 match self {
523 #( #variant_pats_push_send, )*
524 }
525 }
526
527 fn poll_flush(
528 ( #( #variant_localvars_push, )* ): &mut #variant_generics_pinned_push_all,
529 __ctx: &mut Self::Ctx<'_>,
530 ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
531 #push_poll_flush_body
532 #root::dfir_pipes::push::PushStep::Done
533 }
534 }
535
536 impl #impl_generics_item #root::util::demux_enum::DemuxEnumBase
537 for #item_ident #ty_generics #where_clause_item {}
538
539 #single_impl
540 }
541 .into()
542}
543
544fn field_pattern_item(fields: &Fields) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
546 let idents = fields
547 .iter()
548 .enumerate()
549 .map(|(i, field)| {
550 field
551 .ident
552 .clone()
553 .unwrap_or_else(|| format_ident!("_{}", i))
554 })
555 .collect::<Vec<_>>();
556 let (fields_pat, push_item) = match fields {
557 Fields::Named(_) => (quote!( { #( #idents, )* } ), quote!( ( #( #idents, )* ) )),
558 Fields::Unnamed(_) => (quote!( ( #( #idents ),* ) ), quote!( ( #( #idents, )* ) )),
559 Fields::Unit => (quote!(), quote!(())),
560 };
561 (fields_pat, push_item)
562}