1use std::fmt::{Debug, Formatter};
2use std::marker::PhantomData;
3
4use proc_macro2::Span;
5use quote::quote;
6use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
7use stageleft::{QuotedWithContextWithProps, quote_type};
8
9use super::dynamic::LocationId;
10use super::{Location, MemberId};
11use crate::compile::builder::FlowState;
12use crate::location::LocationKey;
13use crate::location::member_id::TaglessMemberId;
14use crate::staging_util::{Invariant, get_this_crate};
15
16pub struct Cluster<'a, ClusterTag> {
17 pub(crate) key: LocationKey,
18 pub(crate) flow_state: FlowState,
19 pub(crate) _phantom: Invariant<'a, ClusterTag>,
20}
21
22impl<C> Debug for Cluster<'_, C> {
23 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
24 write!(f, "Cluster({})", self.key)
25 }
26}
27
28impl<C> Eq for Cluster<'_, C> {}
29impl<C> PartialEq for Cluster<'_, C> {
30 fn eq(&self, other: &Self) -> bool {
31 self.key == other.key && FlowState::ptr_eq(&self.flow_state, &other.flow_state)
32 }
33}
34
35impl<C> Clone for Cluster<'_, C> {
36 fn clone(&self) -> Self {
37 Cluster {
38 key: self.key,
39 flow_state: self.flow_state.clone(),
40 _phantom: PhantomData,
41 }
42 }
43}
44
45impl<'a, C> super::dynamic::DynLocation for Cluster<'a, C> {
46 fn id(&self) -> LocationId {
47 LocationId::Cluster(self.key)
48 }
49
50 fn flow_state(&self) -> &FlowState {
51 &self.flow_state
52 }
53
54 fn is_top_level() -> bool {
55 true
56 }
57
58 fn multiversioned(&self) -> bool {
59 false }
61}
62
63impl<'a, C> Location<'a> for Cluster<'a, C> {
64 type Root = Cluster<'a, C>;
65
66 fn root(&self) -> Self::Root {
67 self.clone()
68 }
69}
70
71pub struct ClusterIds<'a> {
72 pub key: LocationKey,
73 pub _phantom: PhantomData<&'a ()>,
74}
75
76impl<'a> Clone for ClusterIds<'a> {
77 fn clone(&self) -> Self {
78 Self {
79 key: self.key,
80 _phantom: Default::default(),
81 }
82 }
83}
84
85impl<'a, Ctx> FreeVariableWithContextWithProps<Ctx, ()> for ClusterIds<'a> {
86 type O = &'a [TaglessMemberId];
87
88 fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
89 where
90 Self: Sized,
91 {
92 let ident = syn::Ident::new(
93 &format!("__hydro_lang_cluster_ids_{}", self.key),
94 Span::call_site(),
95 );
96
97 (
98 QuoteTokens {
99 prelude: None,
100 expr: Some(quote! { #ident }),
101 },
102 (),
103 )
104 }
105}
106
107impl<'a, Ctx> QuotedWithContextWithProps<'a, &'a [TaglessMemberId], Ctx, ()> for ClusterIds<'a> {}
108
109pub trait IsCluster {
110 type Tag;
111}
112
113impl<C> IsCluster for Cluster<'_, C> {
114 type Tag = C;
115}
116
117pub static CLUSTER_SELF_ID: ClusterSelfId = ClusterSelfId { _private: &() };
120
/// Type of the [`CLUSTER_SELF_ID`] static.
///
/// Its only field is private and no constructor is visible in this file, so code outside
/// this module obtains values by copying `CLUSTER_SELF_ID`.
#[derive(Copy, Clone)]
pub struct ClusterSelfId<'a> {
    /// Private placeholder; carries no data.
    _private: &'a (),
}
125
126impl<'a, L> FreeVariableWithContextWithProps<L, ()> for ClusterSelfId<'a>
127where
128 L: Location<'a>,
129 <L as Location<'a>>::Root: IsCluster,
130{
131 type O = MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>;
132
133 fn to_tokens(self, ctx: &L) -> (QuoteTokens, ())
134 where
135 Self: Sized,
136 {
137 let cluster_id = if let LocationId::Cluster(id) = ctx.root().id() {
138 id
139 } else {
140 unreachable!()
141 };
142
143 let ident = syn::Ident::new(
144 &format!("__hydro_lang_cluster_self_id_{}", cluster_id),
145 Span::call_site(),
146 );
147 let root = get_this_crate();
148 let c_type: syn::Type = quote_type::<<<L as Location<'a>>::Root as IsCluster>::Tag>();
149
150 (
151 QuoteTokens {
152 prelude: None,
153 expr: Some(
154 quote! { #root::__staged::location::MemberId::<#c_type>::from_tagless((#ident).clone()) },
155 ),
156 },
157 (),
158 )
159 }
160}
161
162impl<'a, L>
163 QuotedWithContextWithProps<'a, MemberId<<<L as Location<'a>>::Root as IsCluster>::Tag>, L, ()>
164 for ClusterSelfId<'a>
165where
166 L: Location<'a>,
167 <L as Location<'a>>::Root: IsCluster,
168{
169}
170
#[cfg(test)]
mod tests {
    #[cfg(feature = "sim")]
    use stageleft::q;

    #[cfg(feature = "sim")]
    use super::CLUSTER_SELF_ID;
    #[cfg(feature = "sim")]
    use crate::location::{Location, MemberId, MembershipEvent};
    #[cfg(feature = "sim")]
    use crate::networking::TCP;
    #[cfg(feature = "sim")]
    use crate::nondet::nondet;
    #[cfg(feature = "sim")]
    use crate::prelude::FlowBuilder;

    /// Each member of two differently sized clusters sends its own `CLUSTER_SELF_ID` to a
    /// single process; the process must receive exactly the raw IDs `0..3` and `0..4`.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_self_id() {
        let mut flow = FlowBuilder::new();
        let cluster1 = flow.cluster::<()>();
        let cluster2 = flow.cluster::<()>();

        let node = flow.process::<()>();

        let out_recv = cluster1
            .source_iter(q!(vec![CLUSTER_SELF_ID]))
            .send(&node, TCP.bincode())
            .values()
            .interleave(
                cluster2
                    .source_iter(q!(vec![CLUSTER_SELF_ID]))
                    .send(&node, TCP.bincode())
                    .values(),
            )
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster1, 3)
            .with_cluster_size(&cluster2, 4)
            .exhaustive(async || {
                // Three IDs from `cluster1` followed by four from `cluster2`, in any order.
                out_recv
                    .assert_yields_only_unordered([0, 1, 2, 0, 1, 2, 3].map(MemberId::from_raw_id))
                    .await
            });
    }

    /// Each cluster member batches `[1, 2, 3]` into ticks and sends per-tick counts to a
    /// process; however the batches fall, every member's counts must add up to 3.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_with_tick() {
        use std::collections::HashMap;

        let mut flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = cluster
            .source_iter(q!(vec![1, 2, 3]))
            .batch(
                &cluster.tick(),
                // `nondet!` takes a doc comment justifying the non-determinism.
                nondet!(/** test: batch boundaries are arbitrary, only per-member totals are asserted */),
            )
            .count()
            .all_ticks()
            .send(&node, TCP.bincode())
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        let count = flow
            .sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                // Sum the per-tick counts received from each member.
                let grouped = out_recv.collect_sorted::<Vec<_>>().await.into_iter().fold(
                    HashMap::new(),
                    |mut acc: HashMap<MemberId<()>, usize>, (id, v)| {
                        *acc.entry(id).or_default() += v;
                        acc
                    },
                );

                assert_eq!(grouped.len(), 2);
                for (_id, v) in grouped {
                    assert_eq!(v, 3);
                }
            });

        // Pins the value returned by `exhaustive` (presumably the number of simulated
        // executions; confirm against the simulator's documentation).
        assert_eq!(count, 106);
    }

    /// A process observing the membership of a two-member cluster must see exactly one
    /// `Joined` event for each of the raw IDs 0 and 1.
    #[cfg(feature = "sim")]
    #[test]
    fn sim_cluster_membership() {
        let mut flow = FlowBuilder::new();
        let cluster = flow.cluster::<()>();
        let node = flow.process::<()>();

        let out_recv = node
            .source_cluster_members(&cluster)
            .entries()
            .map(q!(|(id, v)| (id, v)))
            .sim_output();

        flow.sim()
            .with_cluster_size(&cluster, 2)
            .exhaustive(async || {
                out_recv
                    .assert_yields_only_unordered(vec![
                        (MemberId::from_raw_id(0), MembershipEvent::Joined),
                        (MemberId::from_raw_id(1), MembershipEvent::Joined),
                    ])
                    .await;
            });
    }
}