1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20 BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21 ResourceResult, ServerStrategy, Service,
22};
23
24pub struct RustCrateService {
25 id: usize,
26 pub(super) on: Arc<dyn Host>,
27 build_params: BuildParams,
28 tracing: Option<TracingOptions>,
29 args: Option<Vec<String>>,
30 display_id: Option<String>,
31 external_ports: Vec<u16>,
32 env: HashMap<String, String>,
33 pin_to_core: Option<usize>,
34
35 meta: OnceLock<String>,
36
37 pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
39 pub(super) port_to_bind: MemoMap<String, ServerStrategy>,
41
42 launched_host: OnceCell<Arc<dyn LaunchedHost>>,
43
44 pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
48
49 launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
50 started: OnceCell<()>,
51}
52
53impl RustCrateService {
54 #[expect(clippy::too_many_arguments, reason = "internal use")]
55 pub fn new(
56 id: usize,
57 on: Arc<dyn Host>,
58 build_params: BuildParams,
59 tracing: Option<TracingOptions>,
60 args: Option<Vec<String>>,
61 display_id: Option<String>,
62 external_ports: Vec<u16>,
63 env: HashMap<String, String>,
64 pin_to_core: Option<usize>,
65 ) -> Self {
66 Self {
67 id,
68 on,
69 build_params,
70 tracing,
71 args,
72 display_id,
73 external_ports,
74 env,
75 pin_to_core,
76 meta: OnceLock::new(),
77 port_to_server: MemoMap::new(),
78 port_to_bind: MemoMap::new(),
79 launched_host: OnceCell::new(),
80 server_defns: Arc::new(RwLock::new(HashMap::new())),
81 launched_binary: OnceCell::new(),
82 started: OnceCell::new(),
83 }
84 }
85
86 pub fn update_meta<T: Serialize>(&self, meta: T) {
87 if self.launched_binary.get().is_some() {
88 panic!("Cannot update meta after binary has been launched")
89 }
90 self.meta
91 .set(serde_json::to_string(&meta).unwrap())
92 .expect("Cannot set meta twice.");
93 }
94
95 pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
96 RustCratePortConfig {
97 service: Arc::downgrade(self),
98 service_host: self.on.clone(),
99 service_server_defns: self.server_defns.clone(),
100 network_hint: PortNetworkHint::Auto,
101 port: name,
102 merge: false,
103 }
104 }
105
106 pub fn get_port_with_hint(
107 self: &Arc<Self>,
108 name: String,
109 network_hint: PortNetworkHint,
110 ) -> RustCratePortConfig {
111 RustCratePortConfig {
112 service: Arc::downgrade(self),
113 service_host: self.on.clone(),
114 service_server_defns: self.server_defns.clone(),
115 network_hint,
116 port: name,
117 merge: false,
118 }
119 }
120
121 pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
122 self.launched_binary.get().unwrap().stdout()
123 }
124
125 pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
126 self.launched_binary.get().unwrap().stderr()
127 }
128
129 pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
130 self.launched_binary.get().unwrap().stdout_filter(prefix)
131 }
132
133 pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
134 self.launched_binary.get().unwrap().stderr_filter(prefix)
135 }
136
137 #[cfg(feature = "profile-folding")]
138 pub fn tracing_results(&self) -> Option<&TracingResults> {
139 self.launched_binary.get().unwrap().tracing_results()
140 }
141
142 pub fn exit_code(&self) -> Option<i32> {
143 self.launched_binary.get().unwrap().exit_code()
144 }
145
146 fn build(
147 &self,
148 ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
149 build_crate_memoized(self.build_params.clone())
151 }
152}
153
154#[async_trait]
155impl Service for RustCrateService {
156 fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
157 if self.launched_host.get().is_some() {
158 return;
159 }
160
161 tokio::task::spawn(self.build());
162
163 let host = &self.on;
164
165 host.request_custom_binary();
166 for (_, bind_type) in self.port_to_bind.iter() {
167 host.request_port(bind_type);
168 }
169
170 for port in self.external_ports.iter() {
171 host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
172 }
173 }
174
175 async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
176 self.launched_host
177 .get_or_try_init::<anyhow::Error, _, _>(|| {
178 ProgressTracker::with_group(
179 self.display_id
180 .clone()
181 .unwrap_or_else(|| format!("service/{}", self.id)),
182 None,
183 || async {
184 let built = self.build().await?;
185
186 let host = &self.on;
187 let launched = host.provision(resource_result);
188
189 launched.copy_binary(built).await?;
190 Ok(launched)
191 },
192 )
193 })
194 .await?;
195 Ok(())
196 }
197
198 async fn ready(&self) -> Result<()> {
199 self.launched_binary
200 .get_or_try_init(|| {
201 ProgressTracker::with_group(
202 self.display_id
203 .clone()
204 .unwrap_or_else(|| format!("service/{}", self.id)),
205 None,
206 || async {
207 let launched_host = self.launched_host.get().unwrap();
208
209 let built = self.build().await?;
210 let args = self.args.as_ref().cloned().unwrap_or_default();
211
212 let binary = launched_host
213 .launch_binary(
214 self.display_id
215 .clone()
216 .unwrap_or_else(|| format!("service/{}", self.id)),
217 built,
218 &args,
219 self.tracing.clone(),
220 &self.env,
221 self.pin_to_core,
222 )
223 .await?;
224
225 let bind_config = self
226 .port_to_bind
227 .iter()
228 .map(|(port_name, bind_type)| {
229 (port_name.clone(), launched_host.server_config(bind_type))
230 })
231 .collect::<HashMap<_, _>>();
232
233 let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
234 bind_config,
235 self.meta.get().map(|s| s.as_str().into()),
236 ))
237 .unwrap();
238
239 let stdout_receiver = binary.deploy_stdout();
241
242 binary.stdin().send(format!("{formatted_bind_config}\n"))?;
243
244 let ready_line = ProgressTracker::leaf(
245 "waiting for ready",
246 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
247 )
248 .await
249 .context("Timed out waiting for ready")?
250 .context("Program unexpectedly quit")?;
251 if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
252 *self.server_defns.try_write().unwrap() =
253 serde_json::from_str(line_rest).unwrap();
254 } else {
255 bail!("expected ready");
256 }
257 Ok(binary)
258 },
259 )
260 })
261 .await?;
262 Ok(())
263 }
264
265 async fn start(&self) -> Result<()> {
266 self.started
267 .get_or_try_init(|| async {
268 let sink_ports_futures =
269 self.port_to_server
270 .iter()
271 .map(|(port_name, outgoing)| async {
272 (&**port_name, outgoing.load_instantiated(&|p| p).await)
273 });
274 let sink_ports = futures::future::join_all(sink_ports_futures)
275 .await
276 .into_iter()
277 .collect::<HashMap<_, _>>();
278
279 let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
280
281 let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
282
283 self.launched_binary
284 .get()
285 .unwrap()
286 .stdin()
287 .send(format!("start: {formatted_defns}\n"))
288 .unwrap();
289
290 let start_ack_line = ProgressTracker::leaf(
291 self.display_id
292 .clone()
293 .unwrap_or_else(|| format!("service/{}", self.id))
294 + " / waiting for ack start",
295 tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
296 )
297 .await??;
298 if !start_ack_line.starts_with("ack start") {
299 bail!("expected ack start");
300 }
301
302 Ok(())
303 })
304 .await?;
305
306 Ok(())
307 }
308
309 async fn stop(&self) -> Result<()> {
310 ProgressTracker::with_group(
311 self.display_id
312 .clone()
313 .unwrap_or_else(|| format!("service/{}", self.id)),
314 None,
315 || async {
316 let launched_binary = self.launched_binary.get().unwrap();
317 launched_binary.stdin().send("stop\n".to_owned())?;
318
319 let timeout_result = ProgressTracker::leaf(
320 "waiting for exit",
321 tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
322 )
323 .await;
324 match timeout_result {
325 Err(_timeout) => {} Ok(Err(unexpected_error)) => return Err(unexpected_error), Ok(Ok(_exit_status)) => {}
328 }
329 launched_binary.stop().await?;
330
331 Ok(())
332 },
333 )
334 .await
335 }
336}