1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16use crate::progress::ProgressTracker;
17use crate::{
18 BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
19 ResourceResult, ServerStrategy, Service, TracingResults,
20};
21
/// A deployable service that builds a Rust crate and runs the resulting binary on a [`Host`].
///
/// Lifecycle state is held in write-once cells (`launched_host`, `launched_binary`, `started`)
/// that are filled in by the [`Service`] methods `deploy`, `ready`, and `start` respectively.
///
/// NOTE(review): fields are dropped in declaration order; keep that in mind before reordering.
pub struct RustCrateService {
    /// Numeric id; used in the fallback display name `service/{id}`.
    id: usize,
    /// Host this service is provisioned on and launched from.
    pub(super) on: Arc<dyn Host>,
    /// Build configuration passed to `build_crate_memoized`.
    build_params: BuildParams,
    /// Optional tracing configuration forwarded to `launch_binary`.
    tracing: Option<TracingOptions>,
    /// Command-line arguments for the binary; `None` launches it with no arguments.
    args: Option<Vec<String>>,
    /// Name used for progress output and when launching; defaults to `service/{id}`.
    display_id: Option<String>,
    /// TCP ports requested on the host as `BaseServerStrategy::ExternalTcpPort`.
    external_ports: Vec<u16>,

    /// JSON-serialized metadata (set at most once via `update_meta`), sent in the init config.
    meta: OnceLock<String>,

    /// Per-port server configs, resolved in `start` and sent to the binary as `start:` defns.
    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
    /// Per-port bind strategies: requested from the host in `collect_resources` and turned
    /// into the bind config sent to the binary in `ready`.
    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,

    /// Provisioned host; set once by `deploy`.
    launched_host: OnceCell<Arc<dyn LaunchedHost>>,

    /// Server port definitions parsed from the binary's `ready: ` line; shared with the port
    /// handles returned by `get_port` / `get_port_with_hint`.
    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,

    /// Launched binary; set once by `ready`.
    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
    /// Filled once `start` has received the binary's start acknowledgement.
    started: OnceCell<()>,
}
48
49impl RustCrateService {
50 pub fn new(
51 id: usize,
52 on: Arc<dyn Host>,
53 build_params: BuildParams,
54 tracing: Option<TracingOptions>,
55 args: Option<Vec<String>>,
56 display_id: Option<String>,
57 external_ports: Vec<u16>,
58 ) -> Self {
59 Self {
60 id,
61 on,
62 build_params,
63 tracing,
64 args,
65 display_id,
66 external_ports,
67 meta: OnceLock::new(),
68 port_to_server: MemoMap::new(),
69 port_to_bind: MemoMap::new(),
70 launched_host: OnceCell::new(),
71 server_defns: Arc::new(RwLock::new(HashMap::new())),
72 launched_binary: OnceCell::new(),
73 started: OnceCell::new(),
74 }
75 }
76
77 pub fn update_meta<T: Serialize>(&self, meta: T) {
78 if self.launched_binary.get().is_some() {
79 panic!("Cannot update meta after binary has been launched")
80 }
81 self.meta
82 .set(serde_json::to_string(&meta).unwrap())
83 .expect("Cannot set meta twice.");
84 }
85
86 pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
87 RustCratePortConfig {
88 service: Arc::downgrade(self),
89 service_host: self.on.clone(),
90 service_server_defns: self.server_defns.clone(),
91 network_hint: PortNetworkHint::Auto,
92 port: name,
93 merge: false,
94 }
95 }
96
97 pub fn get_port_with_hint(
98 self: &Arc<Self>,
99 name: String,
100 network_hint: PortNetworkHint,
101 ) -> RustCratePortConfig {
102 RustCratePortConfig {
103 service: Arc::downgrade(self),
104 service_host: self.on.clone(),
105 service_server_defns: self.server_defns.clone(),
106 network_hint,
107 port: name,
108 merge: false,
109 }
110 }
111
112 pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
113 self.launched_binary.get().unwrap().stdout()
114 }
115
116 pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
117 self.launched_binary.get().unwrap().stderr()
118 }
119
120 pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
121 self.launched_binary.get().unwrap().stdout_filter(prefix)
122 }
123
124 pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
125 self.launched_binary.get().unwrap().stderr_filter(prefix)
126 }
127
128 pub fn tracing_results(&self) -> Option<&TracingResults> {
129 self.launched_binary.get().unwrap().tracing_results()
130 }
131
132 pub fn exit_code(&self) -> Option<i32> {
133 self.launched_binary.get().unwrap().exit_code()
134 }
135
136 fn build(
137 &self,
138 ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
139 build_crate_memoized(self.build_params.clone())
141 }
142}
143
144#[async_trait]
145impl Service for RustCrateService {
146 fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
147 if self.launched_host.get().is_some() {
148 return;
149 }
150
151 tokio::task::spawn(self.build());
152
153 let host = &self.on;
154
155 host.request_custom_binary();
156 for (_, bind_type) in self.port_to_bind.iter() {
157 host.request_port(bind_type);
158 }
159
160 for port in self.external_ports.iter() {
161 host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
162 }
163 }
164
165 async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
166 self.launched_host
167 .get_or_try_init::<anyhow::Error, _, _>(|| {
168 ProgressTracker::with_group(
169 self.display_id
170 .clone()
171 .unwrap_or_else(|| format!("service/{}", self.id)),
172 None,
173 || async {
174 let built = self.build().await?;
175
176 let host = &self.on;
177 let launched = host.provision(resource_result);
178
179 launched.copy_binary(built).await?;
180 Ok(launched)
181 },
182 )
183 })
184 .await?;
185 Ok(())
186 }
187
188 async fn ready(&self) -> Result<()> {
189 self.launched_binary
190 .get_or_try_init(|| {
191 ProgressTracker::with_group(
192 self.display_id
193 .clone()
194 .unwrap_or_else(|| format!("service/{}", self.id)),
195 None,
196 || async {
197 let launched_host = self.launched_host.get().unwrap();
198
199 let built = self.build().await?;
200 let args = self.args.as_ref().cloned().unwrap_or_default();
201
202 let binary = launched_host
203 .launch_binary(
204 self.display_id
205 .clone()
206 .unwrap_or_else(|| format!("service/{}", self.id)),
207 built,
208 &args,
209 self.tracing.clone(),
210 )
211 .await?;
212
213 let bind_config = self
214 .port_to_bind
215 .iter()
216 .map(|(port_name, bind_type)| {
217 (port_name.clone(), launched_host.server_config(bind_type))
218 })
219 .collect::<HashMap<_, _>>();
220
221 let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
222 bind_config,
223 self.meta.get().map(|s| s.as_str().into()),
224 ))
225 .unwrap();
226
227 let stdout_receiver = binary.deploy_stdout();
229
230 binary.stdin().send(format!("{formatted_bind_config}\n"))?;
231
232 let ready_line = ProgressTracker::leaf(
233 "waiting for ready",
234 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
235 )
236 .await
237 .context("Timed out waiting for ready")?
238 .context("Program unexpectedly quit")?;
239 if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
240 *self.server_defns.try_write().unwrap() =
241 serde_json::from_str(line_rest).unwrap();
242 } else {
243 bail!("expected ready");
244 }
245 Ok(binary)
246 },
247 )
248 })
249 .await?;
250 Ok(())
251 }
252
253 async fn start(&self) -> Result<()> {
254 self.started
255 .get_or_try_init(|| async {
256 let sink_ports_futures =
257 self.port_to_server
258 .iter()
259 .map(|(port_name, outgoing)| async {
260 (&**port_name, outgoing.load_instantiated(&|p| p).await)
261 });
262 let sink_ports = futures::future::join_all(sink_ports_futures)
263 .await
264 .into_iter()
265 .collect::<HashMap<_, _>>();
266
267 let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
268
269 let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
270
271 self.launched_binary
272 .get()
273 .unwrap()
274 .stdin()
275 .send(format!("start: {formatted_defns}\n"))
276 .unwrap();
277
278 let start_ack_line = ProgressTracker::leaf(
279 self.display_id
280 .clone()
281 .unwrap_or_else(|| format!("service/{}", self.id))
282 + " / waiting for ack start",
283 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
284 )
285 .await??;
286 if !start_ack_line.starts_with("ack start") {
287 bail!("expected ack start");
288 }
289
290 Ok(())
291 })
292 .await?;
293
294 Ok(())
295 }
296
297 async fn stop(&self) -> Result<()> {
298 ProgressTracker::with_group(
299 self.display_id
300 .clone()
301 .unwrap_or_else(|| format!("service/{}", self.id)),
302 None,
303 || async {
304 let launched_binary = self.launched_binary.get().unwrap();
305 launched_binary.stdin().send("stop\n".to_string())?;
306
307 let timeout_result = ProgressTracker::leaf(
308 "waiting for exit",
309 tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
310 )
311 .await;
312 match timeout_result {
313 Err(_timeout) => {} Ok(Err(unexpected_error)) => return Err(unexpected_error), Ok(Ok(_exit_status)) => {}
316 }
317 launched_binary.stop().await?;
318
319 Ok(())
320 },
321 )
322 .await
323 }
324}