Skip to main content

dfir_lang/
process_singletons.rs

1//! Utility methods for processing singleton references: `#my_var`.
2
3use itertools::Itertools;
4use proc_macro2::{Group, Ident, TokenStream, TokenTree};
5use syn::punctuated::Punctuated;
6use syn::{Expr, Token};
7
8use crate::parse::parse_terminated;
9
10/// Finds all the singleton references `#my_var` and appends them to `found_idents`. Returns the
11/// `TokenStream` but with the hashes removed from the varnames.
12///
13/// The returned tokens are used for "preflight" parsing, to check that the rest of the syntax is
14/// OK. However the returned tokens are not used in the codegen as we need to use [`postprocess_singletons`]
15/// later to substitute-in the context referencing code for each singleton
16pub fn preprocess_singletons(tokens: TokenStream, found_idents: &mut Vec<Ident>) -> TokenStream {
17    process_singletons(tokens, &mut |singleton_ident| {
18        found_idents.push(singleton_ident.clone());
19        TokenTree::Ident(singleton_ident)
20    })
21}
22
23/// Replaces singleton references `#my_var` with the code needed to actually get the value inside.
24///
25/// * `tokens` - The tokens to update singleton references within.
26/// * `resolved_idents` - The local variable idents that correspond 1:1 and in the same
27///   order as the singleton references within `tokens` (found in-order via [`preprocess_singletons`]).
28///
29/// Generates `(*&ident)` — an immutable place expression that prevents consumer mutation.
30/// Use [`postprocess_singletons_handles`] for just the raw idents.
31pub fn postprocess_singletons(
32    tokens: TokenStream,
33    resolved_idents: impl IntoIterator<Item = Ident>,
34) -> Punctuated<Expr, Token![,]> {
35    let mut resolved_idents_iter = resolved_idents.into_iter();
36    let processed = process_singletons(tokens, &mut |singleton_ident| {
37        let span = singleton_ident.span();
38        let mut resolved_ident = resolved_idents_iter.next().unwrap();
39        resolved_ident.set_span(span);
40        // Emit `(*&ident)` so consumers get an immutable place expression.
41        // The `&` prevents mutation (can't assign through a shared reference),
42        // and the `*` dereferences back to the original type for ergonomic use.
43        let deref_ref_tokens: TokenStream = [
44            TokenTree::Punct(proc_macro2::Punct::new('*', proc_macro2::Spacing::Alone)),
45            TokenTree::Punct(proc_macro2::Punct::new('&', proc_macro2::Spacing::Alone)),
46            TokenTree::Ident(resolved_ident),
47        ]
48        .into_iter()
49        .collect();
50        let mut group = Group::new(proc_macro2::Delimiter::Parenthesis, deref_ref_tokens);
51        group.set_span(span);
52        TokenTree::Group(group)
53    });
54    parse_terminated(processed).unwrap()
55}
56
57/// Same as [`postprocess_singletons`] but generates just the raw ident rather than
58/// `RefCell` borrowing code.
59pub fn postprocess_singletons_handles(
60    tokens: TokenStream,
61    resolved_idents: impl IntoIterator<Item = Ident>,
62) -> Punctuated<Expr, Token![,]> {
63    let mut resolved_idents_iter = resolved_idents.into_iter();
64    let processed = process_singletons(tokens, &mut |singleton_ident| {
65        let mut resolved_ident = resolved_idents_iter.next().unwrap();
66        resolved_ident.set_span(singleton_ident.span().resolved_at(resolved_ident.span()));
67        TokenTree::Ident(resolved_ident)
68    });
69    parse_terminated(processed).unwrap()
70}
71
72/// Traverse the token stream, applying the `map_singleton_fn` whenever a singleton is found,
73/// returning the transformed token stream.
74fn process_singletons(
75    tokens: TokenStream,
76    map_singleton_fn: &mut impl FnMut(Ident) -> TokenTree,
77) -> TokenStream {
78    tokens
79        .into_iter()
80        .peekable()
81        .batching(|iter| {
82            let out = match iter.next()? {
83                TokenTree::Group(group) => {
84                    let mut new_group = Group::new(
85                        group.delimiter(),
86                        process_singletons(group.stream(), map_singleton_fn),
87                    );
88                    new_group.set_span(group.span());
89                    TokenTree::Group(new_group)
90                }
91                TokenTree::Ident(ident) => TokenTree::Ident(ident),
92                TokenTree::Punct(punct) => {
93                    if '#' == punct.as_char() && matches!(iter.peek(), Some(TokenTree::Ident(_))) {
94                        // Found a singleton.
95                        let Some(TokenTree::Ident(mut singleton_ident)) = iter.next() else {
96                            unreachable!()
97                        };
98                        {
99                            // Include the `#` in the span.
100                            let span = singleton_ident
101                                .span()
102                                .join(punct.span())
103                                .unwrap_or(singleton_ident.span());
104                            singleton_ident.set_span(span.resolved_at(singleton_ident.span()));
105                        }
106                        (map_singleton_fn)(singleton_ident)
107                    } else {
108                        TokenTree::Punct(punct)
109                    }
110                }
111                TokenTree::Literal(lit) => TokenTree::Literal(lit),
112            };
113            Some(out)
114        })
115        .collect()
116}