hydro_deploy/rust_crate/service.rs

1use std::collections::HashMap;
2use std::sync::{Arc, OnceLock};
3use std::time::Duration;
4
5use anyhow::{Context, Result, bail};
6use async_trait::async_trait;
7use futures::Future;
8use hydro_deploy_integration::{InitConfig, ServerPort};
9use memo_map::MemoMap;
10use serde::Serialize;
11use tokio::sync::{OnceCell, RwLock, mpsc};
12
13use super::build::{BuildError, BuildOutput, BuildParams, build_crate_memoized};
14use super::ports::{self, RustCratePortConfig};
15use super::tracing_options::TracingOptions;
16use crate::progress::ProgressTracker;
17use crate::{
18    BaseServerStrategy, Host, LaunchedBinary, LaunchedHost, PortNetworkHint, ResourceBatch,
19    ResourceResult, ServerStrategy, Service, TracingResults,
20};
21
/// A deployable service that builds a Rust crate and runs the resulting binary on a host.
///
/// The lifecycle state (`launched_host`, `launched_binary`, `started`) lives in once-cells,
/// so each piece is initialized at most once. They are filled in by the `Service` methods
/// `deploy`, `ready` and `start` respectively.
pub struct RustCrateService {
    /// Numeric identifier; used in the fallback progress label `service/{id}`.
    id: usize,
    /// The host this service requests resources from, is provisioned on, and runs on.
    pub(super) on: Arc<dyn Host>,
    /// Parameters handed to the memoized crate build.
    build_params: BuildParams,
    /// Optional tracing configuration, forwarded to `launch_binary`.
    tracing: Option<TracingOptions>,
    /// Arguments forwarded to `launch_binary`; `None` is treated as an empty list.
    args: Option<Vec<String>>,
    /// Human-readable label for progress output; `service/{id}` is used when `None`.
    display_id: Option<String>,
    /// Extra ports requested from the host as `BaseServerStrategy::ExternalTcpPort`.
    external_ports: Vec<u16>,

    /// JSON-serialized metadata set (at most once) via `update_meta`; sent to the program
    /// together with its bind configuration.
    meta: OnceLock<String>,

    /// Configuration for the ports this service will connect to as a client.
    pub(super) port_to_server: MemoMap<String, ports::ServerConfig>,
    /// Configuration for the ports this service will bind and listen on.
    pub(super) port_to_bind: MemoMap<String, ServerStrategy>,

    /// The provisioned host, set once by `deploy`.
    launched_host: OnceCell<Arc<dyn LaunchedHost>>,

    /// A map of port names to config for how other services can connect to this one.
    /// Only valid after `ready` has been called; it is filled from the program's
    /// `ready: <json>` line. Before that it is empty.
    /// NOTE(review): an earlier comment said this only contains ports configured in
    /// `server_ports`, which is not a field here. It presumably means the ports in
    /// `port_to_bind`; confirm.
    pub(super) server_defns: Arc<RwLock<HashMap<String, ServerPort>>>,

    /// The running binary, set once by `ready`.
    launched_binary: OnceCell<Box<dyn LaunchedBinary>>,
    /// Set once `start` has received the program's `ack start` line.
    started: OnceCell<()>,
}
48
49impl RustCrateService {
50    pub fn new(
51        id: usize,
52        on: Arc<dyn Host>,
53        build_params: BuildParams,
54        tracing: Option<TracingOptions>,
55        args: Option<Vec<String>>,
56        display_id: Option<String>,
57        external_ports: Vec<u16>,
58    ) -> Self {
59        Self {
60            id,
61            on,
62            build_params,
63            tracing,
64            args,
65            display_id,
66            external_ports,
67            meta: OnceLock::new(),
68            port_to_server: MemoMap::new(),
69            port_to_bind: MemoMap::new(),
70            launched_host: OnceCell::new(),
71            server_defns: Arc::new(RwLock::new(HashMap::new())),
72            launched_binary: OnceCell::new(),
73            started: OnceCell::new(),
74        }
75    }
76
77    pub fn update_meta<T: Serialize>(&self, meta: T) {
78        if self.launched_binary.get().is_some() {
79            panic!("Cannot update meta after binary has been launched")
80        }
81        self.meta
82            .set(serde_json::to_string(&meta).unwrap())
83            .expect("Cannot set meta twice.");
84    }
85
86    pub fn get_port(self: &Arc<Self>, name: String) -> RustCratePortConfig {
87        RustCratePortConfig {
88            service: Arc::downgrade(self),
89            service_host: self.on.clone(),
90            service_server_defns: self.server_defns.clone(),
91            network_hint: PortNetworkHint::Auto,
92            port: name,
93            merge: false,
94        }
95    }
96
97    pub fn get_port_with_hint(
98        self: &Arc<Self>,
99        name: String,
100        network_hint: PortNetworkHint,
101    ) -> RustCratePortConfig {
102        RustCratePortConfig {
103            service: Arc::downgrade(self),
104            service_host: self.on.clone(),
105            service_server_defns: self.server_defns.clone(),
106            network_hint,
107            port: name,
108            merge: false,
109        }
110    }
111
112    pub fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
113        self.launched_binary.get().unwrap().stdout()
114    }
115
116    pub fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
117        self.launched_binary.get().unwrap().stderr()
118    }
119
120    pub fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
121        self.launched_binary.get().unwrap().stdout_filter(prefix)
122    }
123
124    pub fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
125        self.launched_binary.get().unwrap().stderr_filter(prefix)
126    }
127
128    pub fn tracing_results(&self) -> Option<&TracingResults> {
129        self.launched_binary.get().unwrap().tracing_results()
130    }
131
132    pub fn exit_code(&self) -> Option<i32> {
133        self.launched_binary.get().unwrap().exit_code()
134    }
135
136    fn build(
137        &self,
138    ) -> impl use<> + 'static + Future<Output = Result<&'static BuildOutput, BuildError>> {
139        // Memoized, so no caching in `self` is needed.
140        build_crate_memoized(self.build_params.clone())
141    }
142}
143
144#[async_trait]
145impl Service for RustCrateService {
146    fn collect_resources(&self, _resource_batch: &mut ResourceBatch) {
147        if self.launched_host.get().is_some() {
148            return;
149        }
150
151        tokio::task::spawn(self.build());
152
153        let host = &self.on;
154
155        host.request_custom_binary();
156        for (_, bind_type) in self.port_to_bind.iter() {
157            host.request_port(bind_type);
158        }
159
160        for port in self.external_ports.iter() {
161            host.request_port_base(&BaseServerStrategy::ExternalTcpPort(*port));
162        }
163    }
164
165    async fn deploy(&self, resource_result: &Arc<ResourceResult>) -> Result<()> {
166        self.launched_host
167            .get_or_try_init::<anyhow::Error, _, _>(|| {
168                ProgressTracker::with_group(
169                    self.display_id
170                        .clone()
171                        .unwrap_or_else(|| format!("service/{}", self.id)),
172                    None,
173                    || async {
174                        let built = self.build().await?;
175
176                        let host = &self.on;
177                        let launched = host.provision(resource_result);
178
179                        launched.copy_binary(built).await?;
180                        Ok(launched)
181                    },
182                )
183            })
184            .await?;
185        Ok(())
186    }
187
188    async fn ready(&self) -> Result<()> {
189        self.launched_binary
190            .get_or_try_init(|| {
191                ProgressTracker::with_group(
192                    self.display_id
193                        .clone()
194                        .unwrap_or_else(|| format!("service/{}", self.id)),
195                    None,
196                    || async {
197                        let launched_host = self.launched_host.get().unwrap();
198
199                        let built = self.build().await?;
200                        let args = self.args.as_ref().cloned().unwrap_or_default();
201
202                        let binary = launched_host
203                            .launch_binary(
204                                self.display_id
205                                    .clone()
206                                    .unwrap_or_else(|| format!("service/{}", self.id)),
207                                built,
208                                &args,
209                                self.tracing.clone(),
210                            )
211                            .await?;
212
213                        let bind_config = self
214                            .port_to_bind
215                            .iter()
216                            .map(|(port_name, bind_type)| {
217                                (port_name.clone(), launched_host.server_config(bind_type))
218                            })
219                            .collect::<HashMap<_, _>>();
220
221                        let formatted_bind_config = serde_json::to_string::<InitConfig>(&(
222                            bind_config,
223                            self.meta.get().map(|s| s.as_str().into()),
224                        ))
225                        .unwrap();
226
227                        // request stdout before sending config so we don't miss the "ready" response
228                        let stdout_receiver = binary.deploy_stdout();
229
230                        binary.stdin().send(format!("{formatted_bind_config}\n"))?;
231
232                        let ready_line = ProgressTracker::leaf(
233                            "waiting for ready",
234                            tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
235                        )
236                        .await
237                        .context("Timed out waiting for ready")?
238                        .context("Program unexpectedly quit")?;
239                        if let Some(line_rest) = ready_line.strip_prefix("ready: ") {
240                            *self.server_defns.try_write().unwrap() =
241                                serde_json::from_str(line_rest).unwrap();
242                        } else {
243                            bail!("expected ready");
244                        }
245                        Ok(binary)
246                    },
247                )
248            })
249            .await?;
250        Ok(())
251    }
252
253    async fn start(&self) -> Result<()> {
254        self.started
255            .get_or_try_init(|| async {
256                let sink_ports_futures =
257                    self.port_to_server
258                        .iter()
259                        .map(|(port_name, outgoing)| async {
260                            (&**port_name, outgoing.load_instantiated(&|p| p).await)
261                        });
262                let sink_ports = futures::future::join_all(sink_ports_futures)
263                    .await
264                    .into_iter()
265                    .collect::<HashMap<_, _>>();
266
267                let formatted_defns = serde_json::to_string(&sink_ports).unwrap();
268
269                let stdout_receiver = self.launched_binary.get().unwrap().deploy_stdout();
270
271                self.launched_binary
272                    .get()
273                    .unwrap()
274                    .stdin()
275                    .send(format!("start: {formatted_defns}\n"))
276                    .unwrap();
277
278                let start_ack_line = ProgressTracker::leaf(
279                    self.display_id
280                        .clone()
281                        .unwrap_or_else(|| format!("service/{}", self.id))
282                        + " / waiting for ack start",
283                    tokio::time::timeout(Duration::from_secs(60), stdout_receiver),
284                )
285                .await??;
286                if !start_ack_line.starts_with("ack start") {
287                    bail!("expected ack start");
288                }
289
290                Ok(())
291            })
292            .await?;
293
294        Ok(())
295    }
296
297    async fn stop(&self) -> Result<()> {
298        ProgressTracker::with_group(
299            self.display_id
300                .clone()
301                .unwrap_or_else(|| format!("service/{}", self.id)),
302            None,
303            || async {
304                let launched_binary = self.launched_binary.get().unwrap();
305                launched_binary.stdin().send("stop\n".to_string())?;
306
307                let timeout_result = ProgressTracker::leaf(
308                    "waiting for exit",
309                    tokio::time::timeout(Duration::from_secs(60), launched_binary.wait()),
310                )
311                .await;
312                match timeout_result {
313                    Err(_timeout) => {} // `wait()` timed out, but stop will force quit.
314                    Ok(Err(unexpected_error)) => return Err(unexpected_error), // `wait()` errored.
315                    Ok(Ok(_exit_status)) => {}
316                }
317                launched_binary.stop().await?;
318
319                Ok(())
320            },
321        )
322        .await
323    }
324}