1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use futures::StreamExt;
8use futures::stream::FuturesUnordered;
9pub use hydro_deploy_integration::*;
10use serde::de::DeserializeOwned;
11
12use crate::scheduled::graph::Dfir;
13
/// Expands to an async block that runs the standard deploy lifecycle for a DFIR program.
///
/// Steps, in order:
/// 1. `$crate::util::deploy::init_no_ack_start()` performs the bind/connect handshake.
/// 2. `$f` is called with a reference to the resulting ports and must return the flow.
/// 3. `ack start` is printed only after the flow has been built.
/// 4. The flow is run via `$crate::util::deploy::launch_flow`.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let built_flow = $f(&deploy_ports);

            // Acknowledge only once the user's flow has been constructed.
            println!("ack start");

            $crate::util::deploy::launch_flow(built_flow).await
        }
    };
}
27
28pub use crate::launch;
29
30pub async fn launch_flow(mut flow: Dfir<'_>) {
31 let stop = tokio::sync::oneshot::channel();
32 tokio::task::spawn_blocking(|| {
33 let mut line = String::new();
34 std::io::stdin().read_line(&mut line).unwrap();
35 if line.starts_with("stop") {
36 stop.0.send(()).unwrap();
37 } else {
38 eprintln!("Unexpected stdin input: {:?}", line);
39 }
40 });
41
42 let local_set = tokio::task::LocalSet::new();
43 let flow = local_set.run_until(flow.run());
44
45 tokio::select! {
46 _ = stop.1 => {},
47 _ = flow => {}
48 }
49}
50
51pub async fn launch_flow_containerized(mut flow: Dfir<'_>) {
52 let local_set = tokio::task::LocalSet::new();
53 local_set.run_until(flow.run()).await;
54}
55
56pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
57 let mut input = String::new();
58 std::io::stdin().read_line(&mut input).unwrap();
59 let trimmed = input.trim();
60
61 let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
62
63 let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
65 let mut binds = HashMap::new();
66 for (name, config) in bind_config.0 {
67 let bound = config.bind().await;
68 bind_results.insert(name.clone(), bound.server_port());
69 binds.insert(name.clone(), bound);
70 }
71
72 let bind_serialized = serde_json::to_string(&bind_results).unwrap();
73 println!("ready: {bind_serialized}");
74
75 let mut start_buf = String::new();
76 std::io::stdin().read_line(&mut start_buf).unwrap();
77 let connection_defns = if start_buf.starts_with("start: ") {
78 serde_json::from_str::<HashMap<String, ServerPort>>(
79 start_buf.trim_start_matches("start: ").trim(),
80 )
81 .unwrap()
82 } else {
83 panic!("expected start");
84 };
85
86 let (client_conns, server_conns) = futures::join!(
87 connection_defns
88 .into_iter()
89 .map(|(name, defn)| async move { (name, Connection::AsClient(defn.connect().await)) })
90 .collect::<FuturesUnordered<_>>()
91 .collect::<Vec<_>>(),
92 binds
93 .into_iter()
94 .map(
95 |(name, defn)| async move { (name, Connection::AsServer(accept_bound(defn).await)) }
96 )
97 .collect::<FuturesUnordered<_>>()
98 .collect::<Vec<_>>()
99 );
100
101 let all_connected = client_conns
102 .into_iter()
103 .chain(server_conns.into_iter())
104 .collect();
105
106 DeployPorts {
107 ports: RefCell::new(all_connected),
108 meta: bind_config
109 .1
110 .map(|b| serde_json::from_str(&b).unwrap())
111 .unwrap_or_default(),
112 }
113}
114
115pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
116 let ret = init_no_ack_start::<T>().await;
117
118 println!("ack start");
119
120 ret
121}