Skip to main content

hydro_deploy/rust_crate/
service.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16#[cfg(feature = "profile-folding")]
17use crate::TracingResults;
18use crate::progress::ProgressTracker;
19use crate::{
20    BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
21    ResourceResult, ServerStrategy, Service,
22};
23
24pub struct RustCrateService {
25    id: usize,
26    pub(super) on: Arc<dyn Host>,
27    build_params: BuildParams,
28    tracing: Option<TracingOptions>,
29    args: Option<Vec<String>>,
30    display_id: Option<String>,
31    external_ports: Vec<u16>,
32    env: HashMap<String, String>,
33    pin_to_core: Option<usize>,
34
35    meta: OnceLock<String>,
36
37    /// Configuration for the ports this service will connect to as a client.
38    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
39    /// Configuration for the ports that this service will listen on a port for.
40    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,
41
42    launched_host: OnceCell<Arc<dyn LaunchedHost>>,
43
44    /// A map of port names to config for how other services can connect to this one.
45    /// Only valid after `ready` has been called, only contains ports that are configured
46    /// in `server_ports`.
47    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,
48
49    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
50    started: OnceCell<()>,
51}
52
53impl RustCrateService {
54    #[expect(clippy::too_many_arguments, reason = "internal use")]
55    pub fn new(
56        id: usize,
57        on: Arc<dyn Host>,
58        build_params: BuildParams,
59        tracing: Option<TracingOptions>,
60        args: Option<Vec<String>>,
61        display_id: Option<String>,
62        external_ports: Vec<u16>,
63        env: HashMap<String, String>,
64        pin_to_core: Option<usize>,
65    ) -> Self {
66        Self {
67            id,
68            on,
69            build_params,
70            tracing,
71            args,
72            display_id,
73            external_ports,
74            env,
75            pin_to_core,
76            meta: OnceLock::new(),
77            port_to_server: MemoMap::new(),
78            port_to_bind: MemoMap::new(),
79            launched_host: OnceCell::new(),
80            server_defns: Arc::new(RwLock::new(HashMap::new())),
81            launched_binary: OnceCell::new(),
82            started: OnceCell::new(),
83        }
84    }
85
86    pub fn update_meta<T: Serialize>(&self, meta: T) {
87        if self.launched_binary.get().is_some() {
88            panic!("Cannot update meta after binary has been launched")
89        }
90        self.meta
91            .set(serde_json::to_string(&meta).unwrap())
92            .expect("Cannot set meta twice.");
93    }
94
95    pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
96        RustCratePortConfig {
97            service: Arc::downgrade(self),
98            service_host: self.on.clone(),
99            service_server_defns: self.server_defns.clone(),
100            network_hint: PortNetworkHint::Auto,
101            port: name,
102            merge: false,
103        }
104    }
105
106    pub fn get_port_with_hint(
107        self: &Arc<Self>,
108        name: String,
109        network_hint: PortNetworkHint,
110    ) -> RustCratePortConfig {
111        RustCratePortConfig {
112            service: Arc::downgrade(self),
113            service_host: self.on.clone(),
114            service_server_defns: self.server_defns.clone(),
115            network_hint,
116            port: name,
117            merge: false,
118        }
119    }
120
121    pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
122        self.launched_binary.get().unwrap().stdout()
123    }
124
125    pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
126        self.launched_binary.get().unwrap().stderr()
127    }
128
129    pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
130        self.launched_binary.get().unwrap().stdout_filter(prefix)
131    }
132
133    pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
134        self.launched_binary.get().unwrap().stderr_filter(prefix)
135    }
136
137    #[cfg(feature = "profile-folding")]
138    pub fn tracing_results(&self) -> Option<&TracingResults> {
139        self.launched_binary.get().unwrap().tracing_results()
140    }
141
142    pub fn exit_code(&self) -> Option<i32> {
143        self.launched_binary.get().unwrap().exit_code()
144    }
145
146    fn build(
147        &self,
148    ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
149        // Memoized, so no caching in `self` is needed.
150        build_crate_memoized(self.build_params.clone())
151    }
152}
153
154#[async_trait]
155impl Service for RustCrateService {
156    fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
157        if self.launched_host.get().is_some() {
158            return;
159        }
160
161        tokio::task::spawn(self.build());
162
163        let host = &self.on;
164
165        host.request_custom_binary();
166        for (_, bind_type) in self.port_to_bind.iter() {
167            host.request_port(bind_type);
168        }
169
170        for port in self.external_ports.iter() {
171            host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
172        }
173    }
174
175    async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
176        self.launched_host
177            .get_or_try_init::<anyhow::Error, _, _>(|| {
178                ProgressTracker::with_group(
179                    self.display_id
180                        .clone()
181                        .unwrap_or_else(|| format!("service/{}", self.id)),
182                    None,
183                    || async {
184                        let built = self.build().await?;
185
186                        let host = &self.on;
187                        let launched = host.provision(resource_result);
188
189                        launched.copy_binary(built).await?;
190                        Ok(launched)
191                    },
192                )
193            })
194            .await?;
195        Ok(())
196    }
197
198    async fn ready(&self) -> Result<()> {
199        self.launched_binary
200            .get_or_try_init(|| {
201                ProgressTracker::with_group(
202                    self.display_id
203                        .clone()
204                        .unwrap_or_else(|| format!("service/{}", self.id)),
205                    None,
206                    || async {
207                        let launched_host = self.launched_host.get().unwrap();
208
209                        let built = self.build().await?;
210                        let args = self.args.as_ref().cloned().unwrap_or_default();
211
212                        let binary = launched_host
213                            .launch_binary(
214                                self.display_id
215                                    .clone()
216                                    .unwrap_or_else(|| format!("service/{}", self.id)),
217                                built,
218                                &args,
219                                self.tracing.clone(),
220                                &self.env,
221                                self.pin_to_core,
222                            )
223                            .await?;
224
225                        let bind_config = self
226                            .port_to_bind
227                            .iter()
228                            .map(|(port_name, bind_type)| {
229                                (port_name.clone(), launched_host.server_config(bind_type))
230                            })
231                            .collect::<HashMap<_, _>>();
232
233                        let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
234                            bind_config,
235                            self.meta.get().map(|s| s.as_str().into()),
236                        ))
237                        .unwrap();
238
239                        // request stdout before sending config so we don't miss the "ready" response
240                        let stdout_receiver = binary.deploy_stdout();
241
242                        binary.stdin().send(format!("{formatted_bind_config}\n"))?;
243
244                        let ready_line = ProgressTracker::leaf(
245                            "waiting for ready",
246                            tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
247                        )
248                        .await
249                        .context("Timed out waiting for ready")?
250                        .context("Program unexpectedly quit")?;
251                        if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
252                            *self.server_defns.try_write().unwrap() =
253                                serde_json::from_str(line_rest).unwrap();
254                        } else {
255                            bail!("expected ready");
256                        }
257                        Ok(binary)
258                    },
259                )
260            })
261            .await?;
262        Ok(())
263    }
264
265    async fn start(&self) -> Result<()> {
266        self.started
267            .get_or_try_init(|| async {
268                let sink_ports_futures =
269                    self.port_to_server
270                        .iter()
271                        .map(|(port_name, outgoing)| async {
272                            (&**port_name, outgoing.load_instantiated(&|p| p).await)
273                        });
274                let sink_ports = futures::future::join_all(sink_ports_futures)
275                    .await
276                    .into_iter()
277                    .collect::<HashMap<_, _>>();
278
279                let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
280
281                let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
282
283                self.launched_binary
284                    .get()
285                    .unwrap()
286                    .stdin()
287                    .send(format!("start: {formatted_defns}\n"))
288                    .unwrap();
289
290                let start_ack_line = ProgressTracker::leaf(
291                    self.display_id
292                        .clone()
293                        .unwrap_or_else(|| format!("service/{}", self.id))
294                        + " / waiting for ack start",
295                    tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
296                )
297                .await??;
298                if !start_ack_line.starts_with("ack start") {
299                    bail!("expected ack start");
300                }
301
302                Ok(())
303            })
304            .await?;
305
306        Ok(())
307    }
308
309    async fn stop(&self) -> Result<()> {
310        ProgressTracker::with_group(
311            self.display_id
312                .clone()
313                .unwrap_or_else(|| format!("service/{}", self.id)),
314            None,
315            || async {
316                let launched_binary = self.launched_binary.get().unwrap();
317                launched_binary.stdin().send("stop\n".to_owned())?;
318
319                let timeout_result = ProgressTracker::leaf(
320                    "waiting for exit",
321                    tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
322                )
323                .await;
324                match timeout_result {
325                    Err(_timeout) => {} // `wait()` timed out, but stop will force quit.
326                    Ok(Err(unexpected_error)) => return Err(unexpected_error), // `wait()` errored.
327                    Ok(Ok(_exit_status)) => {}
328                }
329                launched_binary.stop().await?;
330
331                Ok(())
332            },
333        )
334        .await
335    }
336}