Skip to main content

hydro_deploy/
ssh.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::path::PathBuf;
4use std::sync::Arc;
5#[cfg(feature = "profile-folding")]
6use std::sync::OnceLock;
7use std::time::Duration;
8
9use anyhow::{Context as _, Result};
10use async_ssh2_russh::russh::client::{Config, Handler};
11use async_ssh2_russh::russh::{Disconnect, compression};
12use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
13use async_ssh2_russh::sftp::SftpError;
14use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
15use async_trait::async_trait;
16use hydro_deploy_integration::ServerBindConfig;
17#[cfg(feature = "profile-folding")]
18use inferno::collapse::Collapse;
19#[cfg(feature = "profile-folding")]
20use inferno::collapse::perf::Folder;
21use nanoid::nanoid;
22use tokio::fs::File;
23#[cfg(feature = "profile-folding")]
24use tokio::io::BufReader;
25use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
26use tokio::net::TcpListener;
27use tokio::sync::{mpsc, oneshot};
28use tokio_stream::StreamExt;
29use tokio_stream::wrappers::LinesStream;
30#[cfg(feature = "profile-folding")]
31use tokio_util::io::SyncIoBridge;
32
33#[cfg(feature = "profile-folding")]
34use crate::TracingResults;
35use crate::progress::ProgressTracker;
36use crate::rust_crate::build::BuildOutput;
37#[cfg(feature = "profile-folding")]
38use crate::rust_crate::flamegraph::handle_fold_data;
39use crate::rust_crate::tracing_options::TracingOptions;
40use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
41use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult};
42
/// File name that `perf record` writes to (`-o`) on the remote host. The same name is later
/// passed to `perf script -i` and opened over SFTP to download the raw data. It is a relative
/// path, so it resolves against the remote session's working directory (presumably the SSH
/// user's home directory; confirm against the host configuration).
const PERF_OUTFILE: &str = "__profile.perf.data";
44
/// Handle to a binary that was started on a remote host through an SSH channel.
struct LaunchedSshBinary {
    /// Clone of the host's resource result. It is never read in this file; presumably it is
    /// held only to keep the provisioned resources alive while the binary runs (TODO confirm).
    _resource_result: Arc<ResourceResult>,
    // TODO(mingwei): instead of using `NoCheckHandler`, we should check the server's public key
    // fingerprint (get it somehow via terraform), but ssh `publickey` authentication already
    // generally prevents MITM attacks.
    /// SSH session the binary was launched on. Wrapped in an `Option` so that `Drop` can take
    /// ownership of it in order to disconnect.
    session: Option<AsyncSession<NoCheckHandler>>,
    /// Channel on which the binary's command was executed; used to observe exit status and to
    /// close the remote process.
    channel: AsyncChannel,
    /// Sender handed out by `stdin()`; a spawned task forwards each message to the remote stdin.
    stdin_sender: mpsc::UnboundedSender<String>,
    /// Broadcast of the remote process's stdout lines.
    stdout_broadcast: PriorityBroadcast,
    /// Broadcast of the remote process's stderr lines.
    stderr_broadcast: PriorityBroadcast,
    /// Tracing options the binary was launched with; when set, `stop()` post-processes and
    /// downloads the perf output.
    tracing: Option<TracingOptions>,
    /// Folded tracing data, written once by `stop()`.
    #[cfg(feature = "profile-folding")]
    tracing_results: OnceLock<TracingResults>,
}
59
60#[async_trait]
61impl LaunchedBinary for LaunchedSshBinary {
62    fn stdin(&self) -> mpsc::UnboundedSender<String> {
63        self.stdin_sender.clone()
64    }
65
66    fn deploy_stdout(&self) -> oneshot::Receiver<String> {
67        self.stdout_broadcast.receive_priority()
68    }
69
70    fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
71        self.stdout_broadcast.receive(None)
72    }
73
74    fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
75        self.stderr_broadcast.receive(None)
76    }
77
78    fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
79        self.stdout_broadcast.receive(Some(prefix))
80    }
81
82    fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
83        self.stderr_broadcast.receive(Some(prefix))
84    }
85
86    #[cfg(feature = "profile-folding")]
87    fn tracing_results(&self) -> Option<&TracingResults> {
88        self.tracing_results.get()
89    }
90
91    fn exit_code(&self) -> Option<i32> {
92        // until the program exits, the exit status is meaningless
93        self.channel
94            .recv_exit_status()
95            .try_get()
96            .map(|&ec| ec as _)
97            .ok()
98    }
99
100    async fn wait(&self) -> Result<i32> {
101        let _ = self.channel.closed().wait().await;
102        Ok(*self.channel.recv_exit_status().try_get()? as _)
103    }
104
105    async fn stop(&self) -> Result<()> {
106        if !self.channel.closed().is_done() {
107            ProgressTracker::leaf("force stopping", async {
108                // self.channel.signal(russh::Sig::INT).await?; // `^C`
109                self.channel.eof().await?; // Send EOF.
110                self.channel.close().await?; // Close the channel.
111                self.channel.closed().wait().await;
112                Result::<_>::Ok(())
113            })
114            .await?;
115        }
116
117        // Run perf post-processing and download perf output.
118        if let Some(tracing) = self.tracing.as_ref() {
119            #[cfg(feature = "profile-folding")]
120            assert!(
121                self.tracing_results.get().is_none(),
122                "`tracing_results` already set! Was `stop()` called twice? This is a bug."
123            );
124
125            let session = self.session.as_ref().unwrap();
126            if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
127                ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
128                    let sftp =
129                        async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
130
131                    let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
132                    let mut local_raw_perf = File::create(local_raw_perf).await?;
133
134                    let total_size = remote_raw_perf.metadata().await?.size.unwrap();
135
136                    use tokio::io::AsyncWriteExt;
137                    let mut index = 0;
138                    loop {
139                        let mut buffer = [0; 16 * 1024];
140                        let n = remote_raw_perf.read(&mut buffer).await?;
141                        if n == 0 {
142                            break;
143                        }
144                        local_raw_perf.write_all(&buffer[..n]).await?;
145                        index += n;
146                        progress(((index as f64 / total_size as f64) * 100.0) as u64);
147                    }
148
149                    Ok::<(), anyhow::Error>(())
150                })
151                .await?;
152            }
153
154            #[cfg(feature = "profile-folding")]
155            let script_channel = session.open_channel().await?;
156            #[cfg(feature = "profile-folding")]
157            let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
158
159            #[cfg(feature = "profile-folding")]
160            let fold_data = ProgressTracker::leaf("perf script & folding", async move {
161                let mut stderr_lines = script_channel.stderr().lines();
162                let stdout = script_channel.stdout();
163
164                // Pattern on `()` to make sure no `Result`s are ignored.
165                let ((), fold_data, ()) = tokio::try_join!(
166                    async move {
167                        // Log stderr.
168                        while let Ok(Some(s)) = stderr_lines.next_line().await {
169                            ProgressTracker::eprintln(format!("[perf stderr] {s}"));
170                        }
171                        Result::<_>::Ok(())
172                    },
173                    async move {
174                        // Download perf output and fold.
175                        tokio::task::spawn_blocking(move || {
176                            let mut fold_data = Vec::new();
177                            fold_er.collapse(
178                                SyncIoBridge::new(BufReader::new(stdout)),
179                                &mut fold_data,
180                            )?;
181                            Ok(fold_data)
182                        })
183                        .await?
184                    },
185                    async move {
186                        // Run command (last!).
187                        script_channel
188                            .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
189                            .await?;
190                        Ok(())
191                    },
192                )?;
193                Result::<_>::Ok(fold_data)
194            })
195            .await?;
196
197            #[cfg(feature = "profile-folding")]
198            self.tracing_results
199                .set(TracingResults {
200                    folded_data: fold_data.clone(),
201                })
202                .expect("`tracing_results` already set! This is a bug.");
203
204            #[cfg(feature = "profile-folding")]
205            handle_fold_data(tracing, fold_data).await?;
206        };
207
208        Ok(())
209    }
210}
211
212impl Drop for LaunchedSshBinary {
213    fn drop(&mut self) {
214        if let Some(session) = self.session.take() {
215            tokio::task::block_in_place(|| {
216                tokio::runtime::Handle::current().block_on(session.disconnect(
217                    Disconnect::ByApplication,
218                    "",
219                    "",
220                ))
221            })
222            .unwrap();
223        }
224    }
225}
226
227#[async_trait]
228pub trait LaunchedSshHost: Send + Sync {
229    fn get_internal_ip(&self) -> &str;
230    fn get_external_ip(&self) -> Option<&str>;
231    fn get_cloud_provider(&self) -> &'static str;
232    fn resource_result(&self) -> &Arc<ResourceResult>;
233    fn ssh_user(&self) -> &str;
234
235    fn ssh_key_path(&self) -> PathBuf {
236        self.resource_result()
237            .terraform
238            .deployment_folder
239            .as_ref()
240            .unwrap()
241            .path()
242            .join(".ssh")
243            .join("vm_instance_ssh_key_pem")
244    }
245
246    async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
247        let target_addr = SocketAddr::new(
248            self.get_external_ip()
249                .context(format!(
250                    "{} host must be configured with an external IP to launch binaries",
251                    self.get_cloud_provider()
252                ))?
253                .parse()
254                .unwrap(),
255            22,
256        );
257
258        let res = ProgressTracker::leaf(
259            format!("connecting to host @ {}", self.get_external_ip().unwrap()),
260            async_retry(
261                &|| async {
262                    let mut config = Config::default();
263                    config.preferred.compression = (&[
264                        compression::ZLIB,
265                        compression::ZLIB_LEGACY,
266                        compression::NONE,
267                    ])
268                        .into();
269                    AsyncSession::connect_publickey(
270                        config,
271                        target_addr,
272                        self.ssh_user(),
273                        self.ssh_key_path(),
274                    )
275                    .await
276                },
277                10,
278                Duration::from_secs(1),
279            ),
280        )
281        .await?;
282
283        Ok(res)
284    }
285}
286
287async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
288where
289    H: 'static + Handler,
290{
291    async_retry(
292        &|| async {
293            Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
294        },
295        10,
296        Duration::from_secs(1),
297    )
298    .await
299}
300
301#[async_trait]
302impl<T: LaunchedSshHost> LaunchedHost for T {
303    fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
304        match bind_type {
305            BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
306            BaseServerStrategy::InternalTcpPort(hint) => {
307                ServerBindConfig::TcpPort(self.get_internal_ip().to_owned(), *hint)
308            }
309            BaseServerStrategy::ExternalTcpPort(_) => todo!(),
310        }
311    }
312
313    async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
314        let session = self.open_ssh_session().await?;
315
316        let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
317
318        let user = self.ssh_user();
319        // we may be deploying multiple binaries, so give each a unique name
320        let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
321
322        if sftp.metadata(&binary_path).await.is_err() {
323            let random = nanoid!(8);
324            let temp_path = format!("/home/{user}/hydro-{random}");
325            let sftp = &sftp;
326
327            ProgressTracker::progress_leaf(
328                format!("uploading binary to {}", binary_path),
329                |set_progress, _| {
330                    async move {
331                        let mut created_file = sftp.create(&temp_path).await?;
332
333                        let mut index = 0;
334                        while index < binary.bin_data.len() {
335                            let written = created_file
336                                .write(
337                                    &binary.bin_data[index
338                                        ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
339                                )
340                                .await?;
341                            index += written;
342                            set_progress(
343                                ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
344                            );
345                        }
346                        let mut orig_file_stat = sftp.metadata(&temp_path).await?;
347                        orig_file_stat.permissions = Some(0o755); // allow the copied binary to be executed by anyone
348                        created_file.set_metadata(orig_file_stat).await?;
349                        created_file.sync_all().await?;
350                        drop(created_file);
351
352                        match sftp.rename(&temp_path, binary_path).await {
353                            Ok(_) => {}
354                            Err(SftpError::Status(Status {
355                                status_code: StatusCode::Failure, // SSH_FXP_STATUS = 4
356                                ..
357                            })) => {
358                                // file already exists
359                                sftp.remove_file(temp_path).await?;
360                            }
361                            Err(e) => return Err(e.into()),
362                        }
363
364                        anyhow::Ok(())
365                    }
366                },
367            )
368            .await?;
369        }
370        sftp.close().await?;
371
372        Ok(())
373    }
374
375    async fn launch_binary(
376        &self,
377        id: String,
378        binary: &BuildOutput,
379        args: &[String],
380        tracing: Option<TracingOptions>,
381        env: &HashMap<String, String>,
382        pin_to_core: Option<usize>,
383    ) -> Result<Box<dyn LaunchedBinary>> {
384        let session = self.open_ssh_session().await?;
385
386        let user = self.ssh_user();
387        let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
388
389        let mut command = String::new();
390        // Prepend env variables
391        for (k, v) in env {
392            command.push_str(&format!("{}={} ", k, shell_escape::unix::escape(v.into())));
393        }
394
395        if let Some(core) = pin_to_core {
396            command.push_str(&format!("taskset -c {core} "));
397        }
398        command.push_str(binary_path.to_str().unwrap());
399        for arg in args {
400            command.push(' ');
401            command.push_str(&shell_escape::unix::escape(arg.into()))
402        }
403
404        // Launch with tracing if specified.
405        if let Some(TracingOptions {
406            frequency,
407            setup_command,
408            ..
409        }) = tracing.clone()
410        {
411            let id_clone = id.clone();
412            ProgressTracker::leaf("install perf", async {
413                // Run setup command
414                if let Some(setup_command) = setup_command {
415                    let setup_channel = create_channel(&session).await?;
416                    let (setup_stdout, setup_stderr) =
417                        (setup_channel.stdout(), setup_channel.stderr());
418                    setup_channel.exec(false, &*setup_command).await?;
419
420                    // log outputs
421                    let mut output_lines = LinesStream::new(setup_stdout.lines())
422                        .merge(LinesStream::new(setup_stderr.lines()));
423                    while let Some(line) = output_lines.next().await {
424                        ProgressTracker::eprintln(format!(
425                            "[{} install perf] {}",
426                            id_clone,
427                            line.unwrap()
428                        ));
429                    }
430
431                    setup_channel.closed().wait().await;
432                    let exit_code = setup_channel.recv_exit_status().try_get();
433                    if Ok(&0) != exit_code {
434                        anyhow::bail!("Failed to install perf on remote host");
435                    }
436                }
437                Ok(())
438            })
439            .await?;
440
441            // Attach perf to the command
442            // Note: `LaunchedSshHost` assumes `perf` on linux.
443            command = format!(
444                "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
445            );
446        }
447
448        let (channel, stdout, stderr) = ProgressTracker::leaf(
449            format!("launching binary {}", binary_path.display()),
450            async {
451                let channel = create_channel(&session).await?;
452                // Make sure to begin reading stdout/stderr before running the command.
453                let (stdout, stderr) = (channel.stdout(), channel.stderr());
454                channel.exec(false, command).await?;
455                anyhow::Ok((channel, stdout, stderr))
456            },
457        )
458        .await?;
459
460        let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
461        let mut stdin = channel.stdin();
462
463        tokio::spawn(async move {
464            while let Some(line) = stdin_receiver.recv().await {
465                if stdin.write_all(line.as_bytes()).await.is_err() {
466                    break;
467                }
468                stdin.flush().await.unwrap();
469            }
470        });
471
472        let id_clone = id.clone();
473        let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
474            ProgressTracker::println(format!("[{id_clone}] {s}"));
475        });
476        let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
477            ProgressTracker::println(format!("[{id} stderr] {s}"));
478        });
479
480        Ok(Box::new(LaunchedSshBinary {
481            _resource_result: self.resource_result().clone(),
482            session: Some(session),
483            channel,
484            stdin_sender,
485            stdout_broadcast,
486            stderr_broadcast,
487            tracing,
488            #[cfg(feature = "profile-folding")]
489            tracing_results: OnceLock::new(),
490        }))
491    }
492
493    async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
494        let session = self.open_ssh_session().await?;
495
496        let local_port = TcpListener::bind("127.0.0.1:0").await?;
497        let local_addr = local_port.local_addr()?;
498
499        let internal_ip = addr.ip().to_string();
500        let port = addr.port();
501
502        tokio::spawn(async move {
503            #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
504            while let Ok((mut local_stream, _)) = local_port.accept().await {
505                let mut channel = session
506                    .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
507                    .await
508                    .unwrap()
509                    .into_stream();
510                let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
511                break;
512                // TODO(shadaj): we should be returning an Arc so that we know
513                // if anyone wants to connect to this forwarded port
514            }
515
516            ProgressTracker::println("[hydro] closing forwarded port");
517        });
518
519        Ok(local_addr)
520    }
521}