hydro_deploy/
ssh.rs

1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::{Arc, OnceLock};
4use std::time::Duration;
5
6use anyhow::{Context as _, Result};
7use async_ssh2_russh::russh::client::{Config, Handler};
8use async_ssh2_russh::russh::{Disconnect, compression};
9use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
10use async_ssh2_russh::sftp::SftpError;
11use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
12use async_trait::async_trait;
13use hydro_deploy_integration::ServerBindConfig;
14use inferno::collapse::Collapse;
15use inferno::collapse::perf::Folder;
16use nanoid::nanoid;
17use tokio::fs::File;
18use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::net::TcpListener;
20use tokio::sync::{mpsc, oneshot};
21use tokio_stream::StreamExt;
22use tokio_stream::wrappers::LinesStream;
23use tokio_util::io::SyncIoBridge;
24
25use crate::progress::ProgressTracker;
26use crate::rust_crate::build::BuildOutput;
27use crate::rust_crate::flamegraph::handle_fold_data;
28use crate::rust_crate::tracing_options::TracingOptions;
29use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
30use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult, TracingResults};
31
/// Relative path, on the remote host, of the file that `perf record` writes its samples to when
/// tracing is enabled. `stop()` later reads it back (via SFTP and `perf script -i`).
const PERF_OUTFILE: &str = "__profile.perf.data";

/// A binary running on a remote host, driven over an SSH channel (created by `launch_binary`).
///
/// NOTE: `Drop` takes and disconnects `session`; the remaining fields are then dropped in
/// declaration order. Keep that in mind before reordering fields.
struct LaunchedSshBinary {
    /// Not read in this file; presumably held so the host's resources outlive the binary.
    _resource_result: Arc<ResourceResult>,
    // TODO(mingwei): instead of using `NoCheckHandler`, we should check the server's public key
    // fingerprint (get it somehow via terraform), but ssh `publickey` authentication already
    // generally prevents MITM attacks.
    /// `Some` until `Drop` takes it to disconnect; `stop()` uses it to open SFTP/exec channels.
    session: Option<AsyncSession<NoCheckHandler>>,
    /// Channel on which the (possibly `perf record`-wrapped) command was `exec`'d.
    channel: AsyncChannel,
    /// Strings sent here are written to the remote process's stdin by a spawned task.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Fan-out of the remote process's stdout lines.
    stdout_broadcast: PriorityBroadcast,
    /// Fan-out of the remote process's stderr lines.
    stderr_broadcast: PriorityBroadcast,
    /// Tracing options the binary was launched with; `Some` means it runs under `perf record`.
    tracing: Option<TracingOptions>,
    /// Set once by `stop()` with the folded perf data when tracing is enabled.
    tracing_results: OnceLock<TracingResults>,
}
47
48#[async_trait]
49impl LaunchedBinary for LaunchedSshBinary {
50    fn stdin(&self) -> mpsc::UnboundedSender<String> {
51        self.stdin_sender.clone()
52    }
53
54    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
55        self.stdout_broadcast.receive_priority()
56    }
57
58    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
59        self.stdout_broadcast.receive(None)
60    }
61
62    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
63        self.stderr_broadcast.receive(None)
64    }
65
66    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
67        self.stdout_broadcast.receive(Some(prefix))
68    }
69
70    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
71        self.stderr_broadcast.receive(Some(prefix))
72    }
73
74    fn tracing_results(&self) -> Option<&TracingResults> {
75        self.tracing_results.get()
76    }
77
78    fn exit_code(&self) -> Option<i32> {
79        // until the program exits, the exit status is meaningless
80        self.channel
81            .recv_exit_status()
82            .try_get()
83            .map(|&ec| ec as _)
84            .ok()
85    }
86
87    async fn wait(&self) -> Result<i32> {
88        let _ = self.channel.closed().wait().await;
89        Ok(*self.channel.recv_exit_status().try_get()? as _)
90    }
91
92    async fn stop(&self) -> Result<()> {
93        if !self.channel.closed().is_done() {
94            ProgressTracker::leaf("force stopping", async {
95                // self.channel.signal(russh::Sig::INT).await?; // `^C`
96                self.channel.eof().await?; // Send EOF.
97                self.channel.close().await?; // Close the channel.
98                self.channel.closed().wait().await;
99                Result::<_>::Ok(())
100            })
101            .await?;
102        }
103
104        // Run perf post-processing and download perf output.
105        if let Some(tracing) = self.tracing.as_ref() {
106            assert!(
107                self.tracing_results.get().is_none(),
108                "`tracing_results` already set! Was `stop()` called twice? This is a bug."
109            );
110
111            let session = self.session.as_ref().unwrap();
112            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
113                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
114                    let sftp =
115                        async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
116
117                    let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
118                    let mut local_raw_perf = File::create(local_raw_perf).await?;
119
120                    let total_size = remote_raw_perf.metadata().await?.size.unwrap();
121
122                    use tokio::io::AsyncWriteExt;
123                    let mut index = 0;
124                    loop {
125                        let mut buffer = [0; 16 * 1024];
126                        let n = remote_raw_perf.read(&mut buffer).await?;
127                        if n == 0 {
128                            break;
129                        }
130                        local_raw_perf.write_all(&buffer[..n]).await?;
131                        index += n;
132                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
133                    }
134
135                    Ok::<(), anyhow::Error>(())
136                })
137                .await?;
138            }
139
140            let script_channel = session.open_channel().await?;
141            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
142
143            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
144                let mut stderr_lines = script_channel.stderr().lines();
145                let stdout = script_channel.stdout();
146
147                // Pattern on `()` to make sure no `Result`s are ignored.
148                let ((), fold_data, ()) = tokio::try_join!(
149                    async move {
150                        // Log stderr.
151                        while let Ok(Some(s)) = stderr_lines.next_line().await {
152                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
153                        }
154                        Result::<_>::Ok(())
155                    },
156                    async move {
157                        // Download perf output and fold.
158                        tokio::task::spawn_blocking(move || {
159                            let mut fold_data = Vec::new();
160                            fold_er.collapse(
161                                SyncIoBridge::new(BufReader::new(stdout)),
162                                &mut fold_data,
163                            )?;
164                            Ok(fold_data)
165                        })
166                        .await?
167                    },
168                    async move {
169                        // Run command (last!).
170                        script_channel
171                            .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
172                            .await?;
173                        Ok(())
174                    },
175                )?;
176                Result::<_>::Ok(fold_data)
177            })
178            .await?;
179
180            self.tracing_results
181                .set(TracingResults {
182                    folded_data: fold_data.clone(),
183                })
184                .expect("`tracing_results` already set! This is a bug.");
185
186            handle_fold_data(tracing, fold_data).await?;
187        };
188
189        Ok(())
190    }
191}
192
193impl Drop for LaunchedSshBinary {
194    fn drop(&mut self) {
195        if let Some(session) = self.session.take() {
196            tokio::task::block_in_place(|| {
197                tokio::runtime::Handle::current().block_on(session.disconnect(
198                    Disconnect::ByApplication,
199                    "",
200                    "",
201                ))
202            })
203            .unwrap();
204        }
205    }
206}
207
208#[async_trait]
209pub trait LaunchedSshHost: Send + Sync {
210    fn get_internal_ip(&self) -> String;
211    fn get_external_ip(&self) -> Option<String>;
212    fn get_cloud_provider(&self) -> String;
213    fn resource_result(&self) -> &Arc<ResourceResult>;
214    fn ssh_user(&self) -> &str;
215
216    fn ssh_key_path(&self) -> PathBuf {
217        self.resource_result()
218            .terraform
219            .deployment_folder
220            .as_ref()
221            .unwrap()
222            .path()
223            .join(".ssh")
224            .join("vm_instance_ssh_key_pem")
225    }
226
227    async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
228        let target_addr = SocketAddr::new(
229            self.get_external_ip()
230                .as_ref()
231                .context(
232                    self.get_cloud_provider()
233                        + " host must be configured with an external IP to launch binaries",
234                )?
235                .parse()
236                .unwrap(),
237            22,
238        );
239
240        let res = ProgressTracker::leaf(
241            format!(
242                "connecting to host @ {}",
243                self.get_external_ip().as_ref().unwrap()
244            ),
245            async_retry(
246                &|| async {
247                    let mut config = Config::default();
248                    config.preferred.compression = (&[
249                        compression::ZLIB,
250                        compression::ZLIB_LEGACY,
251                        compression::NONE,
252                    ])
253                        .into();
254                    AsyncSession::connect_publickey(
255                        config,
256                        target_addr,
257                        self.ssh_user(),
258                        self.ssh_key_path(),
259                    )
260                    .await
261                },
262                10,
263                Duration::from_secs(1),
264            ),
265        )
266        .await?;
267
268        Ok(res)
269    }
270}
271
272async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
273where
274    H: 'static + Handler,
275{
276    async_retry(
277        &|| async {
278            Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
279        },
280        10,
281        Duration::from_secs(1),
282    )
283    .await
284}
285
286#[async_trait]
287impl<T: LaunchedSshHost> LaunchedHost for T {
288    fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
289        match bind_type {
290            BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
291            BaseServerStrategy::InternalTcpPort(hint) => {
292                ServerBindConfig::TcpPort(self.get_internal_ip().clone(), *hint)
293            }
294            BaseServerStrategy::ExternalTcpPort(_) => todo!(),
295        }
296    }
297
298    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
299        let session = self.open_ssh_session().await?;
300
301        let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
302
303        let user = self.ssh_user();
304        // we may be deploying multiple binaries, so give each a unique name
305        let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
306
307        if sftp.metadata(&binary_path).await.is_err() {
308            let random = nanoid!(8);
309            let temp_path = format!("/home/{user}/hydro-{random}");
310            let sftp = &sftp;
311
312            ProgressTracker::progress_leaf(
313                format!("uploading binary to {}", binary_path),
314                |set_progress, _| {
315                    async move {
316                        let mut created_file = sftp.create(&temp_path).await?;
317
318                        let mut index = 0;
319                        while index < binary.bin_data.len() {
320                            let written = created_file
321                                .write(
322                                    &binary.bin_data[index
323                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
324                                )
325                                .await?;
326                            index += written;
327                            set_progress(
328                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
329                            );
330                        }
331                        let mut orig_file_stat = sftp.metadata(&temp_path).await?;
332                        orig_file_stat.permissions = Some(0o755); // allow the copied binary to be executed by anyone
333                        created_file.set_metadata(orig_file_stat).await?;
334                        created_file.sync_all().await?;
335                        drop(created_file);
336
337                        match sftp.rename(&temp_path, binary_path).await {
338                            Ok(_) => {}
339                            Err(SftpError::Status(Status {
340                                status_code: StatusCode::Failure, // SSH_FXP_STATUS = 4
341                                ..
342                            })) => {
343                                // file already exists
344                                sftp.remove_file(temp_path).await?;
345                            }
346                            Err(e) => return Err(e.into()),
347                        }
348
349                        anyhow::Ok(())
350                    }
351                },
352            )
353            .await?;
354        }
355        sftp.close().await?;
356
357        Ok(())
358    }
359
360    async fn launch_binary(
361        &self,
362        id: String,
363        binary: &BuildOutput,
364        args: &[String],
365        tracing: Option<TracingOptions>,
366    ) -> Result<Box<dyn LaunchedBinary>> {
367        let session = self.open_ssh_session().await?;
368
369        let user = self.ssh_user();
370        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
371
372        let mut command = binary_path.to_str().unwrap().to_owned();
373        for arg in args {
374            command.push(' ');
375            command.push_str(&shell_escape::unix::escape(arg.into()))
376        }
377
378        // Launch with tracing if specified.
379        if let Some(TracingOptions {
380            frequency,
381            setup_command,
382            ..
383        }) = tracing.clone()
384        {
385            let id_clone = id.clone();
386            ProgressTracker::leaf("install perf", async {
387                // Run setup command
388                if let Some(setup_command) = setup_command {
389                    let setup_channel = create_channel(&session).await?;
390                    let (setup_stdout, setup_stderr) =
391                        (setup_channel.stdout(), setup_channel.stderr());
392                    setup_channel.exec(false, &*setup_command).await?;
393
394                    // log outputs
395                    let mut output_lines = LinesStream::new(setup_stdout.lines())
396                        .merge(LinesStream::new(setup_stderr.lines()));
397                    while let Some(line) = output_lines.next().await {
398                        ProgressTracker::eprintln(format!(
399                            "[{} install perf] {}",
400                            id_clone,
401                            line.unwrap()
402                        ));
403                    }
404
405                    setup_channel.closed().wait().await;
406                    let exit_code = setup_channel.recv_exit_status().try_get();
407                    if Ok(&0) != exit_code {
408                        anyhow::bail!("Failed to install perf on remote host");
409                    }
410                }
411                Ok(())
412            })
413            .await?;
414
415            // Attach perf to the command
416            // Note: `LaunchedSshHost` assumes `perf` on linux.
417            command = format!(
418                "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
419            );
420        }
421
422        let (channel, stdout, stderr) = ProgressTracker::leaf(
423            format!("launching binary {}", binary_path.display()),
424            async {
425                let channel = create_channel(&session).await?;
426                // Make sure to begin reading stdout/stderr before running the command.
427                let (stdout, stderr) = (channel.stdout(), channel.stderr());
428                channel.exec(false, command).await?;
429                anyhow::Ok((channel, stdout, stderr))
430            },
431        )
432        .await?;
433
434        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
435        let mut stdin = channel.stdin();
436
437        tokio::spawn(async move {
438            while let Some(line) = stdin_receiver.recv().await {
439                if stdin.write_all(line.as_bytes()).await.is_err() {
440                    break;
441                }
442                stdin.flush().await.unwrap();
443            }
444        });
445
446        let id_clone = id.clone();
447        let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
448            ProgressTracker::println(format!("[{id_clone}] {s}"));
449        });
450        let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
451            ProgressTracker::println(format!("[{id} stderr] {s}"));
452        });
453
454        Ok(Box::new(LaunchedSshBinary {
455            _resource_result: self.resource_result().clone(),
456            session: Some(session),
457            channel,
458            stdin_sender,
459            stdout_broadcast,
460            stderr_broadcast,
461            tracing,
462            tracing_results: OnceLock::new(),
463        }))
464    }
465
466    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
467        let session = self.open_ssh_session().await?;
468
469        let local_port = TcpListener::bind("127.0.0.1:0").await?;
470        let local_addr = local_port.local_addr()?;
471
472        let internal_ip = addr.ip().to_string();
473        let port = addr.port();
474
475        tokio::spawn(async move {
476            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
477            while let Ok((mut local_stream, _)) = local_port.accept().await {
478                let mut channel = session
479                    .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
480                    .await
481                    .unwrap()
482                    .into_stream();
483                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
484                break;
485                // TODO(shadaj): we should be returning an Arc so that we know
486                // if anyone wants to connect to this forwarded port
487            }
488
489            ProgressTracker::println("[hydro] closing forwarded port");
490        });
491
492        Ok(local_addr)
493    }
494}