Skip to main content

hydro_lang/deploy/deploy_runtime_containerized_ecs.rs

1#![allow(
2    unused,
3    reason = "unused in trybuild but the __staged version is needed"
4)]
5#![allow(missing_docs, reason = "used internally")]
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::net::SocketAddr;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use bytes::BytesMut;
17use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
18use proc_macro2::Span;
19use sinktools::demux_map_lazy::LazyDemuxSink;
20use sinktools::lazy::{LazySink, LazySource};
21use sinktools::lazy_sink_source::LazySinkSource;
22use stageleft::runtime_support::{
23    FreeVariableWithContext, FreeVariableWithContextWithProps, QuoteTokens,
24};
25use stageleft::{QuotedWithContext, q};
26use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
27use tokio::net::{TcpListener, TcpStream};
28use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};
29use tracing::{Instrument, debug, error, instrument, span, trace, trace_span};
30
31use crate::location::dynamic::LocationId;
32use crate::location::member_id::TaglessMemberId;
33use crate::location::{LocationKey, MemberId, MembershipEvent};
34
/// Builds the `(sink, source)` expression pair for a process-to-process (o2o) channel.
///
/// * Sink: resolves `target_task_family` to a task ID, then to that task's private IP
///   (both helpers retry until they succeed), and opens a TCP connection to
///   `{ip}:{bind_port}`. Outgoing `bytes::Bytes` are written as length-delimited frames.
/// * Source: binds `0.0.0.0:{bind_port}`, accepts a single connection, and yields
///   length-delimited frames read from it.
///
/// The connect/bind logic lives in closures handed to `LazySink` / `LazySource`, so it
/// runs when those wrappers invoke them (presumably on first use), not when this
/// function is called. Both expressions are spliced with an empty `()` context.
pub fn deploy_containerized_o2o(
    target_task_family: &str,
    bind_port: u16,
) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || Box::pin(
            async move {
                let target_task_family = target_task_family;
                // Family -> task ID -> private IP; each lookup loops until it succeeds.
                let task_id = self::resolve_task_family_to_task_id(target_task_family).await;
                let ip = self::resolve_task_ip(&task_id).await;
                let target = format!("{}:{}", ip, bind_port);
                debug!(name: "connecting", %target, %target_task_family, %task_id);

                let stream = TcpStream::connect(&target).await?;

                Result::<_, std::io::Error>::Ok(FramedWrite::new(
                    stream,
                    LengthDelimitedCodec::new(),
                ))
            }
        )))
        .splice_untyped_ctx(&()),
        q!(LazySource::new(move || Box::pin(async move {
            let bind_addr = format!("0.0.0.0:{}", bind_port);
            let listener = TcpListener::bind(bind_addr).await?;
            // Accept exactly one peer; `listener` is dropped when this block ends, so
            // later connection attempts are not accepted.
            let (stream, peer) = listener.accept().await?;
            debug!(name: "accepting", ?peer);
            Result::<_, std::io::Error>::Ok(FramedRead::new(stream, LengthDelimitedCodec::new()))
        })))
        .splice_untyped_ctx(&()),
    )
}
67
/// Builds the `(sink, source)` expression pair for a process-to-cluster (o2m) channel.
///
/// * Sink: a demultiplexing sink keyed by [`TaglessMemberId`]. For each key, a lazy
///   per-member sink treats the key's container name as an ECS task ID, resolves it to a
///   private IP, and connects over TCP to `{ip}:{port}` with length-delimited framing.
/// * Source (on each cluster member): binds `0.0.0.0:{port}`, accepts a single
///   connection, and yields length-delimited frames from it.
pub fn deploy_containerized_o2m(port: u16) -> (syn::Expr, syn::Expr) {
    (
        QuotedWithContext::<'static, LazyDemuxSink<TaglessMemberId, _, _>, ()>::splice_untyped_ctx(
            q!(sinktools::demux_map_lazy::<_, _, _, _>(
                move |key: &TaglessMemberId| {
                    // Own a copy of the key so the per-member connect future can hold it.
                    let key = key.clone();

                    LazySink::<_, _, _, bytes::Bytes>::new(move || {
                        Box::pin(async move {
                            let port = port;
                            // Member IDs carry the ECS task ID as their container name.
                            let task_id = key.get_container_name();
                            let ip = self::resolve_task_ip(&task_id).await;
                            let target = format!("{}:{}", ip, port);
                            debug!(name: "connecting", %target, %task_id);

                            let stream = TcpStream::connect(&target).await?;

                            let sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
                            Result::<_, std::io::Error>::Ok(sink)
                        })
                    })
                }
            )),
            &(),
        ),
        q!(LazySource::new(move || Box::pin(async move {
            let bind_addr = format!("0.0.0.0:{}", port);
            debug!(name: "listening", %bind_addr);
            let listener = TcpListener::bind(bind_addr).await?;
            // Single sender: accept one connection, then drop the listener.
            let (stream, peer) = listener.accept().await?;
            debug!(name: "accepting", ?peer);

            Result::<_, std::io::Error>::Ok(FramedRead::new(stream, LengthDelimitedCodec::new()))
        })))
        .splice_untyped_ctx(&()),
    )
}
105
/// Builds the `(sink, source)` expression pair for a cluster-to-process (m2o) channel.
///
/// * Sink (on each cluster member): resolves `target_task_family` to a task ID and then
///   to a private IP, and connects to `{ip}:{port}`. Its first frame is a handshake: the
///   sender's own task ID (from `get_self_task_id`) serialized with `bincode`.
/// * Source (on the target process): binds `0.0.0.0:{port}` and accepts connections in a
///   loop. For each one it decodes the handshake frame as a `String` and then tags every
///   later frame with `TaglessMemberId::from_container_name(sender_id)`. The per-connection
///   streams are merged with `flatten_unordered(None)`.
pub fn deploy_containerized_m2o(port: u16, target_task_family: &str) -> (syn::Expr, syn::Expr) {
    (
        q!(LazySink::<_, _, _, bytes::Bytes>::new(move || {
            Box::pin(async move {
                let target_task_family = target_task_family;
                let target_task_id = self::resolve_task_family_to_task_id(target_task_family).await;
                let ip = self::resolve_task_ip(&target_task_id).await;
                let target = format!("{}:{}", ip, port);
                debug!(name: "connecting", %target, %target_task_family, %target_task_id);

                let stream = TcpStream::connect(&target).await?;

                let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());

                // Handshake: identify ourselves so the receiver can tag our messages.
                let self_task_id = self::get_self_task_id();
                sink.send(bytes::Bytes::from(
                    bincode::serialize(&self_task_id).unwrap(),
                ))
                .await?;

                Result::<_, std::io::Error>::Ok(sink)
            })
        }))
        .splice_untyped_ctx(&()),
        QuotedWithContext::<'static, LazySource<_, _, _, Result<(TaglessMemberId, BytesMut), _>>, ()>::splice_untyped_ctx(
            q!(LazySource::new(move || Box::pin(async move {
                let bind_addr = format!("0.0.0.0:{}", port);
                debug!(name: "listening", %bind_addr);
                let listener = TcpListener::bind(bind_addr).await?;
                Result::<_, std::io::Error>::Ok(
                    // Each unfold step accepts one connection and reads its handshake.
                    // NOTE(review): every `?` below returns `None` from this step, which
                    // ENDS the unfold. One accept error, or one peer that disconnects or
                    // sends an undecodable handshake, therefore stops all further accepts.
                    // The handshake is also awaited inside the step, so a peer that
                    // connects but never sends it blocks other peers. Consider skipping
                    // bad peers instead of terminating.
                    futures::stream::unfold(listener, |listener| {
                        Box::pin(async move {
                            let (stream, peer) = listener.accept().await.ok()?;
                            let mut source = FramedRead::new(stream, LengthDelimitedCodec::new());
                            let from_task_id =
                                bincode::deserialize::<String>(&source.next().await?.ok()?[..])
                                    .ok()?;

                            debug!(name: "accepting", endpoint = format!("{}:{}", peer, from_task_id));

                            Some((
                                source.map(move |v| {
                                    v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
                                }),
                                listener,
                            ))
                        })
                    })
                    .flatten_unordered(None),
                )
            }))),
            &(),
        ),
    )
}
161
/// Builds the `(sink, source)` expression pair for a cluster-to-cluster (m2m) channel.
///
/// * Sink: a demultiplexing sink keyed by [`TaglessMemberId`]. For each key, it resolves
///   the key's container name (an ECS task ID) to a private IP and connects to
///   `{ip}:{port}`. Its first frame is a handshake: this task's ID, serialized with `bincode`.
/// * Source: binds `0.0.0.0:{port}` and accepts connections in a loop. It decodes each
///   connection's handshake as a `String` sender ID and tags every later frame with
///   `TaglessMemberId::from_container_name(sender_id)`. The per-connection streams are
///   merged with `flatten_unordered(None)`.
pub fn deploy_containerized_m2m(port: u16) -> (syn::Expr, syn::Expr) {
    (
        QuotedWithContext::<'static, LazyDemuxSink<TaglessMemberId, _, _>, ()>::splice_untyped_ctx(
            q!(sinktools::demux_map_lazy::<_, _, _, _>(
                move |key: &TaglessMemberId| {
                    let key = key.clone();

                    LazySink::<_, _, _, bytes::Bytes>::new(move || {
                        Box::pin(async move {
                            let port = port;
                            let task_id = key.get_container_name();
                            let ip = self::resolve_task_ip(&task_id).await;
                            let target = format!("{}:{}", ip, port);
                            debug!(name: "connecting", %target, %task_id);

                            let stream = TcpStream::connect(&target).await?;

                            let mut sink = FramedWrite::new(stream, LengthDelimitedCodec::new());
                            debug!(name: "connected", %target);

                            // Handshake: tell the receiver which member these frames come from.
                            let self_task_id = self::get_self_task_id();
                            sink.send(bytes::Bytes::from(
                                bincode::serialize(&self_task_id).unwrap(),
                            ))
                            .await?;

                            Result::<_, std::io::Error>::Ok(sink)
                        })
                    })
                }
            )),
            &(),
        ),
        QuotedWithContext::<'static, LazySource<_, _, _, Result<(TaglessMemberId, BytesMut), _>>, ()>::splice_untyped_ctx(
            q!(LazySource::new(move || Box::pin(async move {
                let bind_addr = format!("0.0.0.0:{}", port);
                debug!(name: "listening", %bind_addr);
                let listener = TcpListener::bind(bind_addr).await?;

                Result::<_, std::io::Error>::Ok(
                    // NOTE(review): any `?` failure in this step (accept error, peer closing
                    // before the handshake, undecodable handshake) returns `None`. That ends
                    // the unfold, so no further peers are accepted. A peer that never sends
                    // its handshake also blocks subsequent accepts.
                    futures::stream::unfold(listener, |listener| {
                        Box::pin(async move {
                            let (stream, peer) = listener.accept().await.ok()?;
                            let mut source = FramedRead::new(stream, LengthDelimitedCodec::new());
                            let from_task_id =
                                bincode::deserialize::<String>(&source.next().await?.ok()?[..])
                                    .ok()?;

                            debug!(name: "accepting", endpoint = format!("{}:{}", peer, from_task_id));

                            Some((
                                source.map(move |v| {
                                    v.map(|v| (TaglessMemberId::from_container_name(from_task_id.clone()), v))
                                }),
                                listener,
                            ))
                        })
                    })
                    .flatten_unordered(None),
                )
            }))),
            &(),
        ),
    )
}
227
228pub struct SocketIdent {
229    pub socket_ident: syn::Ident,
230}
231
232impl<Ctx> FreeVariableWithContextWithProps<Ctx, ()> for SocketIdent {
233    type O = TcpListener;
234
235    fn to_tokens(self, _ctx: &Ctx) -> (QuoteTokens, ())
236    where
237        Self: Sized,
238    {
239        let ident = self.socket_ident;
240
241        (
242            QuoteTokens {
243                prelude: None,
244                expr: Some(quote::quote! { #ident }),
245            },
246            (),
247        )
248    }
249}
250
/// Builds a lazily connected bidirectional endpoint for an external client.
///
/// The quoted code awaits one `accept()` on the pre-bound listener named by
/// `socket_ident` (spliced via [`SocketIdent`]). It splits the accepted stream into
/// read/write halves and wraps both with `LengthDelimitedCodec`. The result is a
/// `LazySinkSource` that yields `FramedRead` frames and accepts `bytes::Bytes`.
///
/// `bind_addr` is only used for tracing here; the listener is expected to be bound
/// elsewhere (presumably already listening on `bind_addr`; confirm at the call site).
pub fn deploy_containerized_external_sink_source_ident(
    bind_addr: String,
    socket_ident: syn::Ident,
) -> syn::Expr {
    let socket_ident = SocketIdent { socket_ident };

    q!(LazySinkSource::<
        _,
        FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
        FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
        bytes::Bytes,
        // (Previously considered item type, kept for reference:)
        // Result<bytes::BytesMut, std::io::Error>,
        std::io::Error,
    >::new(async move {
        let span = span!(tracing::Level::TRACE, "lazy_sink_source");
        let guard = span.enter();
        let bind_addr = bind_addr;
        trace!(name: "attempting to accept from external", %bind_addr);
        // Exit the span before awaiting: the `Entered` guard must not be held across
        // an `.await`. The accept future is instrumented with the span instead.
        std::mem::drop(guard);
        let (stream, peer) = socket_ident.accept().instrument(span.clone()).await?;
        // Re-enter for the remaining synchronous setup; no further awaits follow.
        let guard = span.enter();

        debug!(name: "external accepting", ?peer);
        let (rx, tx) = stream.into_split();

        let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
        let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());

        Result::<_, std::io::Error>::Ok((fr, fw))
    },))
    .splice_untyped_ctx(&())
}
283
/// Quoted placeholder for a static list of cluster member IDs.
///
/// Containerized clusters have dynamic membership (tracked by
/// `cluster_membership_stream`), so there is no real static list. The quoted expression
/// evaluates to a one-element slice holding the sentinel member
/// `"INVALID CONTAINER NAME cluster_ids"`.
pub fn cluster_ids<'a>() -> impl QuotedWithContext<'a, &'a [TaglessMemberId], ()> + Clone {
    // NOTE: believed to be unused for containerized deployments; this is dummy code.
    // `Box::leak` yields a borrow that outlives the call, at the cost of leaking one
    // small allocation each time the quoted expression is evaluated.
    q!(Box::leak(Box::new([TaglessMemberId::from_container_name(
        "INVALID CONTAINER NAME cluster_ids"
    )]))
    .as_slice())
}
293
/// Quoted expression for this cluster member's own ID.
///
/// At runtime it wraps the ID returned by `get_self_task_id` (read from the ECS metadata
/// environment variable) in a `TaglessMemberId` container name.
pub fn cluster_self_id<'a>() -> impl QuotedWithContext<'a, TaglessMemberId, ()> + Clone + 'a {
    q!(TaglessMemberId::from_container_name(
        self::get_self_task_id()
    ))
}
299
/// Quoted expression producing a boxed membership-event stream for the cluster at
/// `location_id`.
///
/// At runtime the generated code reads the `CLUSTER_NAME` environment variable (and
/// panics there if it is unset). It then polls ECS through `ecs_membership_stream`,
/// filtered to this location's key.
pub fn cluster_membership_stream<'a>(
    location_id: &LocationId,
) -> impl QuotedWithContext<'a, Box<dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin>, ()>
{
    // Captured into the quoted code as a free variable.
    let location_key = location_id.key();

    q!(Box::new(self::ecs_membership_stream(
        std::env::var("CLUSTER_NAME").unwrap(),
        location_key
    ))
        as Box<
            dyn Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin,
        >)
}
314
315#[instrument(skip_all, fields(%cluster_name, %location_key))]
316fn ecs_membership_stream(
317    cluster_name: String,
318    location_key: LocationKey,
319) -> impl Stream<Item = (TaglessMemberId, MembershipEvent)> + Unpin {
320    use std::collections::HashSet;
321
322    use futures::stream::{StreamExt, once};
323
324    trace!(name: "ecs_membership_stream_created", %cluster_name, %location_key);
325
326    let ecs_poller_span = trace_span!("ecs_poller");
327
328    // Task family format: hy-{name_hint}-loc{idx}v{version}
329    // Example: hy-p1-loc2v1
330    let task_definition_arn_parser =
331        regex::Regex::new(r#"arn:aws:ecs:(?<region>.*):(?<account_id>.*):task-definition\/(?<container_id>hy-(?<type>[^-]+)-loc(?<location_idx>[0-9]+)v(?<location_version>[0-9]+)(?:-(?<instance_id>.*))?):.*"#).unwrap();
332
333    let poll_stream = futures::stream::unfold(
334        (HashSet::<String>::new(), cluster_name, location_key),
335        move |(known_tasks, cluster_name, location_key)| {
336            let task_definition_arn_parser = task_definition_arn_parser.clone();
337
338            async move {
339                let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
340                let ecs_client = aws_sdk_ecs::Client::new(&config);
341
342                let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
343                    Ok(tasks) => tasks,
344                    Err(e) => {
345                        trace!(name: "list_tasks_error", error = %e);
346                        tokio::time::sleep(Duration::from_secs(2)).await;
347                        return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
348                    }
349                };
350
351                let task_arns: Vec<String> =
352                    tasks.task_arns().iter().map(|s| s.to_string()).collect();
353
354                let mut events = Vec::new();
355                let mut current_tasks = HashSet::<String>::new();
356
357                if !task_arns.is_empty() {
358                    let task_details = match ecs_client
359                        .describe_tasks()
360                        .cluster(&cluster_name)
361                        .set_tasks(Some(task_arns.clone()))
362                        .send()
363                        .await
364                    {
365                        Ok(details) => details,
366                        Err(e) => {
367                            trace!(name: "describe_tasks_error", error = %e);
368                            tokio::time::sleep(Duration::from_secs(2)).await;
369                            return Some((Vec::new(), (known_tasks, cluster_name, location_key)));
370                        }
371                    };
372
373                    for task in task_details.tasks() {
374                        let Some(last_status) = task.last_status() else {
375                            continue;
376                        };
377
378                        if last_status != "RUNNING" {
379                            continue;
380                        }
381
382                        let Some(task_def_arn) = task.task_definition_arn() else {
383                            continue;
384                        };
385
386                        let Some(captures) = task_definition_arn_parser.captures(task_def_arn)
387                        else {
388                            continue;
389                        };
390
391                        let Some(location_idx) = captures.name("location_idx") else {
392                            continue;
393                        };
394                        let Some(location_version) = captures.name("location_version") else {
395                            continue;
396                        };
397                        // Reconstruct the location key string and parse it
398                        let location_key_str =
399                            format!("loc{}v{}", location_idx.as_str(), location_version.as_str());
400                        let task_location_key: LocationKey = match location_key_str.parse() {
401                            Ok(key) => key,
402                            Err(_) => {
403                                continue;
404                            }
405                        };
406
407                        // Filter by location_id - only include tasks for this specific cluster
408                        if task_location_key != location_key {
409                            continue;
410                        }
411
412                        // Extract task ID from task ARN (last segment after final /)
413                        // Task ARN format: arn:aws:ecs:region:account:task/cluster-name/task-id
414                        let Some(task_arn) = task.task_arn() else {
415                            continue;
416                        };
417                        let Some(task_id) = task_arn.rsplit('/').next() else {
418                            continue;
419                        };
420
421                        // Use task_id as the member identifier
422                        current_tasks.insert(task_id.to_owned());
423                        if !known_tasks.contains(task_id) {
424                            trace!(name: "task_joined", %task_id);
425                            events.push((task_id.to_owned(), MembershipEvent::Joined));
426                        }
427                    }
428                }
429
430                #[expect(
431                    clippy::disallowed_methods,
432                    reason = "nondeterministic iteration order, container events are not deterministically ordered"
433                )]
434                for task_id in known_tasks.iter() {
435                    if !current_tasks.contains(task_id) {
436                        trace!(name: "task_left", %task_id);
437                        events.push((task_id.to_owned(), MembershipEvent::Left));
438                    }
439                }
440
441                tokio::time::sleep(Duration::from_secs(2)).await;
442
443                Some((events, (current_tasks, cluster_name, location_key)))
444            }
445            .instrument(ecs_poller_span.clone())
446        },
447    )
448    .flat_map(futures::stream::iter);
449
450    Box::pin(
451        poll_stream
452            .map(|(k, v)| (TaglessMemberId::from_container_name(k), v))
453            .inspect(|(member_id, event)| trace!(name: "membership_event", ?member_id, ?event)),
454    )
455}
456
457/// Resolve a task ID to its private IP address via ECS API.
458async fn resolve_task_ip(task_id: &str) -> String {
459    let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
460
461    let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
462    let ecs_client = aws_sdk_ecs::Client::new(&config);
463
464    loop {
465        let tasks = match ecs_client.list_tasks().cluster(&cluster_name).send().await {
466            Ok(t) => t,
467            Err(e) => {
468                trace!(name: "resolve_ip_list_error", %task_id, error = %e);
469                tokio::time::sleep(Duration::from_secs(1)).await;
470                continue;
471            }
472        };
473
474        let task_arns: Vec<_> = tasks.task_arns().to_vec();
475        if task_arns.is_empty() {
476            trace!(name: "resolve_ip_no_tasks", %task_id);
477            tokio::time::sleep(Duration::from_secs(1)).await;
478            continue;
479        }
480
481        let task_details = match ecs_client
482            .describe_tasks()
483            .cluster(&cluster_name)
484            .set_tasks(Some(task_arns))
485            .send()
486            .await
487        {
488            Ok(d) => d,
489            Err(e) => {
490                trace!(name: "resolve_ip_describe_error", %task_id, error = %e);
491                tokio::time::sleep(Duration::from_secs(1)).await;
492                continue;
493            }
494        };
495
496        // Find the task with matching task ID
497        for task in task_details.tasks() {
498            let Some(task_arn) = task.task_arn() else {
499                continue;
500            };
501            let current_task_id = task_arn.rsplit('/').next().unwrap_or_default();
502
503            if current_task_id == task_id
504                && let Some(ip) = task
505                    .attachments()
506                    .iter()
507                    .flat_map(|a| a.details())
508                    .find(|d| d.name() == Some("privateIPv4Address"))
509                    .and_then(|d| d.value())
510            {
511                trace!(name: "resolved_ip", %task_id, %ip);
512                return ip.to_owned();
513            }
514        }
515
516        trace!(name: "resolve_ip_not_found", %task_id);
517        tokio::time::sleep(Duration::from_secs(1)).await;
518    }
519}
520
521/// Resolve a task family name to its task ID via ECS API.
522/// Used for process-to-process connections where the target is known by task family at compile time.
523async fn resolve_task_family_to_task_id(task_family: &str) -> String {
524    let cluster_name = std::env::var("CLUSTER_NAME").unwrap();
525
526    let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
527    let ecs_client = aws_sdk_ecs::Client::new(&config);
528
529    loop {
530        let tasks = match ecs_client
531            .list_tasks()
532            .cluster(&cluster_name)
533            .family(task_family)
534            .send()
535            .await
536        {
537            Ok(t) => t,
538            Err(e) => {
539                trace!(name: "resolve_family_list_error", %task_family, error = %e);
540                tokio::time::sleep(Duration::from_secs(1)).await;
541                continue;
542            }
543        };
544
545        let Some(task_arn) = tasks.task_arns().first() else {
546            trace!(name: "resolve_family_no_task", %task_family);
547            tokio::time::sleep(Duration::from_secs(1)).await;
548            continue;
549        };
550
551        // Extract task ID from ARN
552        let task_id = task_arn.rsplit('/').next().unwrap_or_default();
553        if !task_id.is_empty() {
554            trace!(name: "resolved_task_id", %task_family, %task_id);
555            return task_id.to_owned();
556        }
557
558        trace!(name: "resolve_family_invalid_arn", %task_family, %task_arn);
559        tokio::time::sleep(Duration::from_secs(1)).await;
560    }
561}
562
/// Returns this task's identifier, taken from the ECS container metadata URI.
///
/// The value is the final `/`-separated segment of `ECS_CONTAINER_METADATA_URI_V4`. That
/// is the whole string if it contains no `/`, and an empty string if it ends with `/`.
///
/// NOTE(review): depending on the launch type, this segment may be a container/Docker
/// ID rather than the bare task ID taken from task ARNs elsewhere in this file. Confirm
/// the two agree, since peers use this value as a member identifier.
///
/// # Panics
/// Panics if `ECS_CONTAINER_METADATA_URI_V4` is unset or is not valid Unicode.
fn get_self_task_id() -> String {
    const METADATA_ENV_VAR: &str = "ECS_CONTAINER_METADATA_URI_V4";

    let uri = std::env::var(METADATA_ENV_VAR)
        .expect("ECS_CONTAINER_METADATA_URI_V4 not set - are we running in ECS?");

    let mut segments_from_end = uri.rsplit('/');
    // `rsplit` always yields at least one segment, so this cannot fail in practice.
    let last_segment = segments_from_end
        .next()
        .expect("Invalid ECS metadata URI format");

    String::from(last_segment)
}