dfir_rs/util/
deploy.rs

1//! Hydro Deploy integration for DFIR.
2#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
3
4use std::cell::RefCell;
5use std::collections::HashMap;
6
7use futures::StreamExt;
8use futures::stream::FuturesUnordered;
9pub use hydro_deploy_integration::*;
10use serde::de::DeserializeOwned;
11
12use crate::scheduled::graph::Dfir;
13
/// Expands to a future that runs a DFIR flow under Hydro Deploy.
///
/// The generated future:
/// 1. performs the startup handshake with `init_no_ack_start`,
/// 2. calls `$f` with a reference to the resulting ports to build the flow,
/// 3. prints `ack start` (only after the flow has been built), and
/// 4. runs the flow with `launch_flow`.
#[macro_export]
macro_rules! launch {
    ($f:expr) => {
        async {
            let deploy_ports = $crate::util::deploy::init_no_ack_start().await;
            let built_flow = $f(&deploy_ports);

            // Acknowledge only once the flow exists, so the deployer does not treat
            // this process as started before it can actually run.
            println!("ack start");

            $crate::util::deploy::launch_flow(built_flow).await
        }
    };
}
27
28pub use crate::launch;
29
30pub async fn launch_flow(mut flow: Dfir<'_>) {
31    let stop = tokio::sync::oneshot::channel();
32    tokio::task::spawn_blocking(|| {
33        let mut line = String::new();
34        std::io::stdin().read_line(&mut line).unwrap();
35        if line.starts_with("stop") {
36            stop.0.send(()).unwrap();
37        } else {
38            eprintln!("Unexpected stdin input: {:?}", line);
39        }
40    });
41
42    let local_set = tokio::task::LocalSet::new();
43    let flow = local_set.run_until(flow.run());
44
45    tokio::select! {
46        _ = stop.1 => {},
47        _ = flow => {}
48    }
49}
50
51pub async fn launch_flow_containerized(mut flow: Dfir<'_>) {
52    let local_set = tokio::task::LocalSet::new();
53    local_set.run_until(flow.run()).await;
54}
55
56pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
57    let mut input = String::new();
58    std::io::stdin().read_line(&mut input).unwrap();
59    let trimmed = input.trim();
60
61    let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
62
63    // config telling other services how to connect to me
64    let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
65    let mut binds = HashMap::new();
66    for (name, config) in bind_config.0 {
67        let bound = config.bind().await;
68        bind_results.insert(name.clone(), bound.server_port());
69        binds.insert(name.clone(), bound);
70    }
71
72    let bind_serialized = serde_json::to_string(&bind_results).unwrap();
73    println!("ready: {bind_serialized}");
74
75    let mut start_buf = String::new();
76    std::io::stdin().read_line(&mut start_buf).unwrap();
77    let connection_defns = if start_buf.starts_with("start: ") {
78        serde_json::from_str::<HashMap<String, ServerPort>>(
79            start_buf.trim_start_matches("start: ").trim(),
80        )
81        .unwrap()
82    } else {
83        panic!("expected start");
84    };
85
86    let (client_conns, server_conns) = futures::join!(
87        connection_defns
88            .into_iter()
89            .map(|(name, defn)| async move { (name, Connection::AsClient(defn.connect().await)) })
90            .collect::<FuturesUnordered<_>>()
91            .collect::<Vec<_>>(),
92        binds
93            .into_iter()
94            .map(
95                |(name, defn)| async move { (name, Connection::AsServer(accept_bound(defn).await)) }
96            )
97            .collect::<FuturesUnordered<_>>()
98            .collect::<Vec<_>>()
99    );
100
101    let all_connected = client_conns
102        .into_iter()
103        .chain(server_conns.into_iter())
104        .collect();
105
106    DeployPorts {
107        ports: RefCell::new(all_connected),
108        meta: bind_config
109            .1
110            .map(|b| serde_json::from_str(&b).unwrap())
111            .unwrap_or_default(),
112    }
113}
114
115pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
116    let ret = init_no_ack_start::<T>().await;
117
118    println!("ack start");
119
120    ret
121}