1#![allow(
2 unused,
3 reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23 FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{Instrument, debug, error, instrument, span, trace, trace_span};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// Builds the `(sink, source)` expression pair for a one-to-one channel between two
/// containerized processes.
///
/// - Sink (sending side): resolves `target_task_family` to a task ID with
///   `resolve_task_family_to_task_id`, resolves that task's IP with `resolve_task_ip`,
///   connects to `<ip>:<bind_port>`, and writes length-delimited `bytes::Bytes` frames.
/// - Source (receiving side): binds `0.0.0.0:<bind_port>`, accepts exactly one connection,
///   and reads length-delimited frames from it.
///
/// Both halves are `q!` quotations spliced into untyped `syn::Expr`s. The connect/bind work
/// lives inside the closures handed to `LazySink::new` / `LazySource::new`; it is not run
/// by this function itself.
pub fn deploy_containerized_o2o(
    target_task_family: &str,
    bind_port: u16,
) -> (syn::Expr, syn::Expr) {
    (
        // Sending half: look up the single peer through ECS, then open a TCP connection to it.
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                let target_task_family = target_task_family;
                let task_id = self::resolve_task_family_to_task_id(target_task_family).await;
                let ip = self::resolve_task_ip(&task_id).await;
                let target = format!("{}:{}", ip, bind_port);
                debug!(name: "connecting", %target, %target_task_family, %task_id);

                let stream = TcpStream::connect(&target).await?;

                Result::<_, std::io::Error>::Ok(FramedWrite::new(
                    stream,
                    LengthDelimitedCodec::new(),
                ))
            }
        )))
        .splice_untyped_ctx(&()),
        // Receiving half: listen on all interfaces and take the first (and only) connection.
        q!(LazySource::new(move || Box::pin(async move {
            let bind_addr = format!("0.0.0.0:{}", bind_port);
            let listener = TcpListener::bind(bind_addr).await?;
            let (stream, peer) = listener.accept().await?;
            debug!(name: "accepting", ?peer);
            Result::<_, std::io::Error>::Ok(FramedRead::new(stream, LengthDelimitedCodec::new()))
        })))
        .splice_untyped_ctx(&()),
    )
}
67
/// Builds the `(sink, source)` expression pair for a one-to-many channel from one
/// containerized process to the members of a cluster.
///
/// - Sink: a `sinktools::demux_map_lazy` sink keyed by `TaglessMemberId`. For each key the
///   closure builds a `LazySink` whose connection future takes the key's container name as
///   an ECS task ID. It resolves that task's IP with `resolve_task_ip`, connects to
///   `<ip>:<port>`, and writes length-delimited `bytes::Bytes` frames. NOTE(review):
///   `LazySink` presumably runs that future on first use; this file does not show it.
/// - Source (on each member): binds `0.0.0.0:<port>`, accepts exactly one connection, and
///   reads length-delimited frames from it.
pub fn deploy_containerized_o2m(port: u16) -> (syn::Expr, syn::Expr) {
    (
        // Sending half: one lazily-connected TCP sink per destination member.
        QuotedWithContext::<'static, LazyDemuxSink<TaglessMemberId, _, _>, ()>::splice_untyped_ctx(
            q!(sinktools::demux_map_lazy::<_, _, _, _>(
                move |key: &TaglessMemberId| {
                    let key = key.clone();

                    LazySink::<_, _, _, bytes::Bytes>::new(move || {
                        Box::pin(async move {
                            let port = port;
                            let task_id = key.get_container_name();
                            let ip = self::resolve_task_ip(&task_id).await;
                            let target = format!("{}:{}", ip, port);
                            debug!(name: "connecting", %target, %task_id);

                            let stream = TcpStream::connect(&target).await?;

                            let sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
                            Result::<_, std::io::Error>::Ok(sink)
                        })
                    })
                }
            )),
            &(),
        ),
        // Receiving half: each member accepts a single inbound connection from the sender.
        q!(LazySource::new(move || Box::pin(async move {
            let bind_addr = format!("0.0.0.0:{}", port);
            debug!(name: "listening", %bind_addr);
            let listener = TcpListener::bind(bind_addr).await?;
            let (stream, peer) = listener.accept().await?;
            debug!(name: "accepting", ?peer);

            Result::<_, std::io::Error>::Ok(FramedRead::new(stream, LengthDelimitedCodec::new()))
        })))
        .splice_untyped_ctx(&()),
    )
}
105
106pub fn deploy_containerized_m2o(port: u16, target_task_family: &str) -> (syn::Expr, syn::Expr) {
107 (
108 q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
109 Box::pin(async move {
110 let target_task_family = target_task_family;
111 let target_task_id = self::resolve_task_family_to_task_id(target_task_family).await;
112 let ip = self::resolve_task_ip(&target_task_id).await;
113 let target = format!("{}:{}", ip, port);
114 debug!(name: "connecting", %target, %target_task_family, %target_task_id);
115
116 let stream = TcpStream::connect(&target).await?;
117
118 let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
119
120 let self_task_id = self::get_self_task_id();
121 sink.send(bytes::Bytes::from(
122 bincode::serialize(&self_task_id).unwrap(),
123 ))
124 .await?;
125
126 Result::<_, std::io::Error>::Ok(sink)
127 })
128 }))
129 .splice_untyped_ctx(&()),
130 QuotedWithContext::<'static, LazySource<_, _, _, Result<(TaglessMemberId, BytesMut), _>>, ()>::splice_untyped_ctx(
131 q!(LazySource::new(move || Box::pin(async move {
132 let bind_addr = format!("0.0.0.0:{}", port);
133 debug!(name: "listening", %bind_addr);
134 let listener = TcpListener::bind(bind_addr).await?;
135 Result::<_, std::io::Error>::Ok(
136 futures::stream::unfold(listener, |listener| {
137 Box::pin(async move {
138 let (stream, peer) = listener.accept().await.ok()?;
139 let mut source = FramedRead::new(stream, LengthDelimitedCodec::new());
140 let from_task_id =
141 bincode::deserialize::<String>(&source.next().await?.ok()?[..])
142 .ok()?;
143
144 debug!(name: "accepting", endpoint = format!("{}:{}", peer, from_task_id));
145
146 Some((
147 source.map(move |v| {
148 v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
149 }),
150 listener,
151 ))
152 })
153 })
154 .flatten_unordered(None),
155 )
156 }))),
157 &(),
158 ),
159 )
160}
161
162pub fn deploy_containerized_m2m(port: u16) -> (syn::Expr, syn::Expr) {
163 (
164 QuotedWithContext::<'static, LazyDemuxSink<TaglessMemberId, _, _>, ()>::splice_untyped_ctx(
165 q!(sinktools::demux_map_lazy::<_, _, _, _>(
166 move |key: &TaglessMemberId| {
167 let key = key.clone();
168
169 LazySink::<_, _, _, bytes::Bytes>::new(move || {
170 Box::pin(async move {
171 let port = port;
172 let task_id = key.get_container_name();
173 let ip = self::resolve_task_ip(&task_id).await;
174 let target = format!("{}:{}", ip, port);
175 debug!(name: "connecting", %target, %task_id);
176
177 let stream = TcpStream::connect(&target).await?;
178
179 let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
180 debug!(name: "connected", %target);
181
182 let self_task_id = self::get_self_task_id();
183 sink.send(bytes::Bytes::from(
184 bincode::serialize(&self_task_id).unwrap(),
185 ))
186 .await?;
187
188 Result::<_, std::io::Error>::Ok(sink)
189 })
190 })
191 }
192 )),
193 &(),
194 ),
195 QuotedWithContext::<'static, LazySource<_, _, _, Result<(TaglessMemberId, BytesMut), _>>, ()>::splice_untyped_ctx(
196 q!(LazySource::new(move || Box::pin(async move {
197 let bind_addr = format!("0.0.0.0:{}", port);
198 debug!(name: "listening", %bind_addr);
199 let listener = TcpListener::bind(bind_addr).await?;
200
201 Result::<_, std::io::Error>::Ok(
202 futures::stream::unfold(listener, |listener| {
203 Box::pin(async move {
204 let (stream, peer) = listener.accept().await.ok()?;
205 let mut source = FramedRead::new(stream, LengthDelimitedCodec::new());
206 let from_task_id =
207 bincode::deserialize::<String>(&source.next().await?.ok()?[..])
208 .ok()?;
209
210 debug!(name: "accepting", endpoint = format!("{}:{}", peer, from_task_id));
211
212 Some((
213 source.map(move |v| {
214 v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
215 }),
216 listener,
217 ))
218 })
219 })
220 .flatten_unordered(None),
221 )
222 }))),
223 &(),
224 ),
225 )
226}
227
228pub struct SocketIdent {
229 pub socket_ident: syn::Ident,
230}
231
232impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
233 type O = TcpListener;
234
235 fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
236 where
237 Self: Sized,
238 {
239 let ident = self.socket_ident;
240
241 (
242 QuoteTokens {
243 prelude: None,
244 expr: Some(quote::quote! { #ident }),
245 },
246 (),
247 )
248 }
249}
250
/// Builds a `LazySinkSource` expression for one external client connection.
///
/// `socket_ident` must name a `TcpListener` visible where the expression is spliced (it is
/// emitted verbatim through `SocketIdent`). The generated future does the following:
/// - accepts one connection on that listener;
/// - splits the stream into owned read/write halves;
/// - returns a length-delimited `FramedRead`/`FramedWrite` pair over them.
///
/// `bind_addr` is only used in the trace message below. NOTE(review): the listener itself is
/// presumably bound by the caller; confirm against the call site.
pub fn deploy_containerized_external_sink_source_ident(
    bind_addr: String,
    socket_ident: syn::Ident,
) -> syn::Expr {
    // Wrap the ident so stageleft splices it as a `TcpListener` free variable.
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        std::io::Error,
    >::new(async move {
        let span = span!(tracing::Level::TRACE, "lazy_sink_source");
        let guard = span.enter();
        let bind_addr = bind_addr;
        trace!(name: "attempting to accept from external", %bind_addr);
        // Leave the span before awaiting: an entered-span guard must not be held across
        // `.await`. The accept future is instrumented with the same span instead.
        std::mem::drop(guard);
        let (stream, peer) = socket_ident.accept().instrument(span.clone()).await?;
        // Re-enter for the remaining synchronous setup (no further awaits below).
        let guard = span.enter();

        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
283
/// Quoted expression standing in for a cluster's static member list.
///
/// Yields a one-element slice holding a placeholder `TaglessMemberId` whose container name is
/// the sentinel `"INVALID CONTAINER NAME cluster_ids"`. Each evaluation of the spliced
/// expression leaks that one-element allocation via `Box::leak` to obtain a `'a` slice.
///
/// NOTE(review): in this deployment mode membership appears to come from
/// `cluster_membership_stream` instead; confirm that no caller relies on these IDs being real.
pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
        "INVALID CONTAINER NAME cluster_ids"
    )]))
    .as_slice())
}
293
/// Quoted expression for this process's own member ID.
///
/// The ID is built from `get_self_task_id()`, which reads the final path segment of
/// `ECS_CONTAINER_METADATA_URI_V4` at runtime.
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        self::get_self_task_id()
    ))
}
299
/// Quoted expression producing a boxed stream of membership events for `location_id`.
///
/// At runtime the spliced code calls `ecs_membership_stream`. It passes the cluster name read
/// from the `CLUSTER_NAME` environment variable (panicking via `unwrap` if it is unset or not
/// Unicode) and the location key captured here.
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    // Captured into the quoted code as a free variable.
    let location_key = location_id.key();

    q!(Box::new(self::ecs_membership_stream(
        std::env::var("CLUSTER_NAME").unwrap(),
        location_key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
314
315#[instrument(skip_all, fields(%cluster_name, %location_key))]
316fn ecs_membership_stream(
317 cluster_name: String,
318 location_key: LocationKey,
319) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
320 use std::collections::HashSet;
321
322 use futures::stream::{StreamExt, once};
323
324 trace!(name: "ecs_membership_stream_created", %cluster_name, %location_key);
325
326 let ecs_poller_span = trace_span!("ecs_poller");
327
328 let task_definition_arn_parser =
331 regex::Regex::new(r#"arn:aws:ecs:(?<region>.*):(?<account_id>.*):task-definition\/(?<container_id>hy-(?<type>[^-]+)-loc(?<location_idx>[0-9]+)v(?<location_version>[0-9]+)(?:-(?<instance_id>.*))?):.*"#).unwrap();
332
333 let poll_stream = futures::stream::unfold(
334 (HashSet::<String>::new(), cluster_name, location_key),
335 move |(known_tasks, cluster_name, location_key)| {
336 let task_definition_arn_parser = task_definition_arn_parser.clone();
337
338 async move {
339 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
340 let ecs_client = aws_sdk_ecs::Client::new(&config);
341
342 let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
343 Ok(tasks) => tasks,
344 Err(e) => {
345 trace!(name: "list_tasks_error", error = %e);
346 tokio::time::sleep(Duration::from_secs(2)).await;
347 return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
348 }
349 };
350
351 let task_arns: Vec<String> =
352 tasks.task_arns().iter().map(|s| s.to_string()).collect();
353
354 let mut events = Vec::new();
355 let mut current_tasks = HashSet::<String>::new();
356
357 if !task_arns.is_empty() {
358 let task_details = match ecs_client
359 .describe_tasks()
360 .cluster(&cluster_name)
361 .set_tasks(Some(task_arns.clone()))
362 .send()
363 .await
364 {
365 Ok(details) => details,
366 Err(e) => {
367 trace!(name: "describe_tasks_error", error = %e);
368 tokio::time::sleep(Duration::from_secs(2)).await;
369 return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
370 }
371 };
372
373 for task in task_details.tasks() {
374 let Some(last_status) = task.last_status() else {
375 continue;
376 };
377
378 if last_status != "RUNNING" {
379 continue;
380 }
381
382 let Some(task_def_arn) = task.task_definition_arn() else {
383 continue;
384 };
385
386 let Some(captures) = task_definition_arn_parser.captures(task_def_arn)
387 else {
388 continue;
389 };
390
391 let Some(location_idx) = captures.name("location_idx") else {
392 continue;
393 };
394 let Some(location_version) = captures.name("location_version") else {
395 continue;
396 };
397 let location_key_str =
399 format!("loc{}v{}", location_idx.as_str(), location_version.as_str());
400 let task_location_key: LocationKey = match location_key_str.parse() {
401 Ok(key) => key,
402 Err(_) => {
403 continue;
404 }
405 };
406
407 if task_location_key != location_key {
409 continue;
410 }
411
412 let Some(task_arn) = task.task_arn() else {
415 continue;
416 };
417 let Some(task_id) = task_arn.rsplit('/').next() else {
418 continue;
419 };
420
421 current_tasks.insert(task_id.to_owned());
423 if !known_tasks.contains(task_id) {
424 trace!(name: "task_joined", %task_id);
425 events.push((task_id.to_owned(), MembershipEvent::Joined));
426 }
427 }
428 }
429
430 #[expect(
431 clippy::disallowed_methods,
432 reason = "nondeterministic iteration order, container events are not deterministically ordered"
433 )]
434 for task_id in known_tasks.iter() {
435 if !current_tasks.contains(task_id) {
436 trace!(name: "task_left", %task_id);
437 events.push((task_id.to_owned(), MembershipEvent::Left));
438 }
439 }
440
441 tokio::time::sleep(Duration::from_secs(2)).await;
442
443 Some((events, (current_tasks, cluster_name, location_key)))
444 }
445 .instrument(ecs_poller_span.clone())
446 },
447 )
448 .flat_map(futures::stream::iter);
449
450 Box::pin(
451 poll_stream
452 .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
453 .inspect(|(member_id, event)| trace!(name: "membership_event", ?member_id, ?event)),
454 )
455}
456
457async fn resolve_task_ip(task_id: &str) -> String {
459 let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
460
461 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
462 let ecs_client = aws_sdk_ecs::Client::new(&config);
463
464 loop {
465 let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
466 Ok(t) => t,
467 Err(e) => {
468 trace!(name: "resolve_ip_list_error", %task_id, error = %e);
469 tokio::time::sleep(Duration::from_secs(1)).await;
470 continue;
471 }
472 };
473
474 let task_arns: Vec<_> = tasks.task_arns().to_vec();
475 if task_arns.is_empty() {
476 trace!(name: "resolve_ip_no_tasks", %task_id);
477 tokio::time::sleep(Duration::from_secs(1)).await;
478 continue;
479 }
480
481 let task_details = match ecs_client
482 .describe_tasks()
483 .cluster(&cluster_name)
484 .set_tasks(Some(task_arns))
485 .send()
486 .await
487 {
488 Ok(d) => d,
489 Err(e) => {
490 trace!(name: "resolve_ip_describe_error", %task_id, error = %e);
491 tokio::time::sleep(Duration::from_secs(1)).await;
492 continue;
493 }
494 };
495
496 for task in task_details.tasks() {
498 let Some(task_arn) = task.task_arn() else {
499 continue;
500 };
501 let current_task_id = task_arn.rsplit('/').next().unwrap_or_default();
502
503 if current_task_id == task_id
504 && let Some(ip) = task
505 .attachments()
506 .iter()
507 .flat_map(|a| a.details())
508 .find(|d| d.name() == Some("privateIPv4Address"))
509 .and_then(|d| d.value())
510 {
511 trace!(name: "resolved_ip", %task_id, %ip);
512 return ip.to_owned();
513 }
514 }
515
516 trace!(name: "resolve_ip_not_found", %task_id);
517 tokio::time::sleep(Duration::from_secs(1)).await;
518 }
519}
520
521async fn resolve_task_family_to_task_id(task_family: &str) -> String {
524 let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
525
526 let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
527 let ecs_client = aws_sdk_ecs::Client::new(&config);
528
529 loop {
530 let tasks = match ecs_client
531 .list_tasks()
532 .cluster(&cluster_name)
533 .family(task_family)
534 .send()
535 .await
536 {
537 Ok(t) => t,
538 Err(e) => {
539 trace!(name: "resolve_family_list_error", %task_family, error = %e);
540 tokio::time::sleep(Duration::from_secs(1)).await;
541 continue;
542 }
543 };
544
545 let Some(task_arn) = tasks.task_arns().first() else {
546 trace!(name: "resolve_family_no_task", %task_family);
547 tokio::time::sleep(Duration::from_secs(1)).await;
548 continue;
549 };
550
551 let task_id = task_arn.rsplit('/').next().unwrap_or_default();
553 if !task_id.is_empty() {
554 trace!(name: "resolved_task_id", %task_family, %task_id);
555 return task_id.to_owned();
556 }
557
558 trace!(name: "resolve_family_invalid_arn", %task_family, %task_arn);
559 tokio::time::sleep(Duration::from_secs(1)).await;
560 }
561}
562
/// Returns the ECS task identifier of the current process.
///
/// The value is the text after the last `/` in the `ECS_CONTAINER_METADATA_URI_V4`
/// environment variable, or the whole value if it contains no `/`.
///
/// NOTE(review): membership and IP resolution in this module derive IDs from the last
/// segment of task ARNs. Confirm that the metadata URI's final segment is the same string on
/// the launch types in use; on some platforms it may carry an extra container suffix.
///
/// # Panics
///
/// Panics if `ECS_CONTAINER_METADATA_URI_V4` is unset or not valid Unicode.
fn get_self_task_id() -> String {
    let uri: String = std::env::var("ECS_CONTAINER_METADATA_URI_V4")
        .expect("ECS_CONTAINER_METADATA_URI_V4 not set - are we running in ECS?");
    let mut segments_from_end = uri.rsplit('/');
    let last_segment = segments_from_end
        .next()
        .expect("Invalid ECS metadata URI format");
    String::from(last_segment)
}