1use std::net::SocketAddr;
2use std::path::PathBuf;
3use std::sync::{Arc, OnceLock};
4use std::time::Duration;
5
6use anyhow::{Context as _, Result};
7use async_ssh2_russh::russh::client::{Config, Handler};
8use async_ssh2_russh::russh::{Disconnect, compression};
9use async_ssh2_russh::russh_sftp::protocol::{Status, StatusCode};
10use async_ssh2_russh::sftp::SftpError;
11use async_ssh2_russh::{AsyncChannel, AsyncSession, NoCheckHandler};
12use async_trait::async_trait;
13use hydro_deploy_integration::ServerBindConfig;
14use inferno::collapse::Collapse;
15use inferno::collapse::perf::Folder;
16use nanoid::nanoid;
17use tokio::fs::File;
18use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
19use tokio::net::TcpListener;
20use tokio::sync::{mpsc, oneshot};
21use tokio_stream::StreamExt;
22use tokio_stream::wrappers::LinesStream;
23use tokio_util::io::SyncIoBridge;
24
25use crate::progress::ProgressTracker;
26use crate::rust_crate::build::BuildOutput;
27use crate::rust_crate::flamegraph::handle_fold_data;
28use crate::rust_crate::tracing_options::TracingOptions;
29use crate::util::{PriorityBroadcast, async_retry, prioritized_broadcast};
30use crate::{BaseServerStrategy, LaunchedBinary, LaunchedHost, ResourceResult, TracingResults};
31
32const PERF_OUTFILE: &str = "__profile.perf.data";
33
34struct LaunchedSshBinary {
35 _resource_result: Arc<ResourceResult>,
36 session: Option<AsyncSession<NoCheckHandler>>,
40 channel: AsyncChannel,
41 stdin_sender: mpsc::UnboundedSender<String>,
42 stdout_broadcast: PriorityBroadcast,
43 stderr_broadcast: PriorityBroadcast,
44 tracing: Option<TracingOptions>,
45 tracing_results: OnceLock<TracingResults>,
46}
47
48#[async_trait]
49impl LaunchedBinary for LaunchedSshBinary {
50 fn stdin(&self) -> mpsc::UnboundedSender<String> {
51 self.stdin_sender.clone()
52 }
53
54 fn deploy_stdout(&self) -> oneshot::Receiver<String> {
55 self.stdout_broadcast.receive_priority()
56 }
57
58 fn stdout(&self) -> mpsc::UnboundedReceiver<String> {
59 self.stdout_broadcast.receive(None)
60 }
61
62 fn stderr(&self) -> mpsc::UnboundedReceiver<String> {
63 self.stderr_broadcast.receive(None)
64 }
65
66 fn stdout_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
67 self.stdout_broadcast.receive(Some(prefix))
68 }
69
70 fn stderr_filter(&self, prefix: String) -> mpsc::UnboundedReceiver<String> {
71 self.stderr_broadcast.receive(Some(prefix))
72 }
73
74 fn tracing_results(&self) -> Option<&TracingResults> {
75 self.tracing_results.get()
76 }
77
78 fn exit_code(&self) -> Option<i32> {
79 self.channel
81 .recv_exit_status()
82 .try_get()
83 .map(|&ec| ec as _)
84 .ok()
85 }
86
87 async fn wait(&self) -> Result<i32> {
88 let _ = self.channel.closed().wait().await;
89 Ok(*self.channel.recv_exit_status().try_get()? as _)
90 }
91
92 async fn stop(&self) -> Result<()> {
93 if !self.channel.closed().is_done() {
94 ProgressTracker::leaf("force stopping", async {
95 self.channel.eof().await?; self.channel.close().await?; self.channel.closed().wait().await;
99 Result::<_>::Ok(())
100 })
101 .await?;
102 }
103
104 if let Some(tracing) = self.tracing.as_ref() {
106 assert!(
107 self.tracing_results.get().is_none(),
108 "`tracing_results` already set! Was `stop()` called twice? This is a bug."
109 );
110
111 let session = self.session.as_ref().unwrap();
112 if let Some(local_raw_perf) = tracing.perf_raw_outfile.as_ref() {
113 ProgressTracker::progress_leaf("downloading perf data", |progress, _| async move {
114 let sftp =
115 async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
116
117 let mut remote_raw_perf = sftp.open(PERF_OUTFILE).await?;
118 let mut local_raw_perf = File::create(local_raw_perf).await?;
119
120 let total_size = remote_raw_perf.metadata().await?.size.unwrap();
121
122 use tokio::io::AsyncWriteExt;
123 let mut index = 0;
124 loop {
125 let mut buffer = [0; 16 * 1024];
126 let n = remote_raw_perf.read(&mut buffer).await?;
127 if n == 0 {
128 break;
129 }
130 local_raw_perf.write_all(&buffer[..n]).await?;
131 index += n;
132 progress(((index as f64 / total_size as f64) * 100.0) as u64);
133 }
134
135 Ok::<(), anyhow::Error>(())
136 })
137 .await?;
138 }
139
140 let script_channel = session.open_channel().await?;
141 let mut fold_er = Folder::from(tracing.fold_perf_options.clone().unwrap_or_default());
142
143 let fold_data = ProgressTracker::leaf("perf script & folding", async move {
144 let mut stderr_lines = script_channel.stderr().lines();
145 let stdout = script_channel.stdout();
146
147 let ((), fold_data, ()) = tokio::try_join!(
149 async move {
150 while let Ok(Some(s)) = stderr_lines.next_line().await {
152 ProgressTracker::eprintln(format!("[perf stderr] {s}"));
153 }
154 Result::<_>::Ok(())
155 },
156 async move {
157 tokio::task::spawn_blocking(move || {
159 let mut fold_data = Vec::new();
160 fold_er.collapse(
161 SyncIoBridge::new(BufReader::new(stdout)),
162 &mut fold_data,
163 )?;
164 Ok(fold_data)
165 })
166 .await?
167 },
168 async move {
169 script_channel
171 .exec(false, format!("perf script --symfs=/ -i {PERF_OUTFILE}"))
172 .await?;
173 Ok(())
174 },
175 )?;
176 Result::<_>::Ok(fold_data)
177 })
178 .await?;
179
180 self.tracing_results
181 .set(TracingResults {
182 folded_data: fold_data.clone(),
183 })
184 .expect("`tracing_results` already set! This is a bug.");
185
186 handle_fold_data(tracing, fold_data).await?;
187 };
188
189 Ok(())
190 }
191}
192
193impl Drop for LaunchedSshBinary {
194 fn drop(&mut self) {
195 if let Some(session) = self.session.take() {
196 tokio::task::block_in_place(|| {
197 tokio::runtime::Handle::current().block_on(session.disconnect(
198 Disconnect::ByApplication,
199 "",
200 "",
201 ))
202 })
203 .unwrap();
204 }
205 }
206}
207
208#[async_trait]
209pub trait LaunchedSshHost: Send + Sync {
210 fn get_internal_ip(&self) -> String;
211 fn get_external_ip(&self) -> Option<String>;
212 fn get_cloud_provider(&self) -> String;
213 fn resource_result(&self) -> &Arc<ResourceResult>;
214 fn ssh_user(&self) -> &str;
215
216 fn ssh_key_path(&self) -> PathBuf {
217 self.resource_result()
218 .terraform
219 .deployment_folder
220 .as_ref()
221 .unwrap()
222 .path()
223 .join(".ssh")
224 .join("vm_instance_ssh_key_pem")
225 }
226
227 async fn open_ssh_session(&self) -> Result<AsyncSession<NoCheckHandler>> {
228 let target_addr = SocketAddr::new(
229 self.get_external_ip()
230 .as_ref()
231 .context(
232 self.get_cloud_provider()
233 + " host must be configured with an external IP to launch binaries",
234 )?
235 .parse()
236 .unwrap(),
237 22,
238 );
239
240 let res = ProgressTracker::leaf(
241 format!(
242 "connecting to host @ {}",
243 self.get_external_ip().as_ref().unwrap()
244 ),
245 async_retry(
246 &|| async {
247 let mut config = Config::default();
248 config.preferred.compression = (&[
249 compression::ZLIB,
250 compression::ZLIB_LEGACY,
251 compression::NONE,
252 ])
253 .into();
254 AsyncSession::connect_publickey(
255 config,
256 target_addr,
257 self.ssh_user(),
258 self.ssh_key_path(),
259 )
260 .await
261 },
262 10,
263 Duration::from_secs(1),
264 ),
265 )
266 .await?;
267
268 Ok(res)
269 }
270}
271
272async fn create_channel<H>(session: &AsyncSession<H>) -> Result<AsyncChannel>
273where
274 H: 'static + Handler,
275{
276 async_retry(
277 &|| async {
278 Ok(tokio::time::timeout(Duration::from_secs(60), session.open_channel()).await??)
279 },
280 10,
281 Duration::from_secs(1),
282 )
283 .await
284}
285
286#[async_trait]
287impl<T: LaunchedSshHost> LaunchedHost for T {
288 fn base_server_config(&self, bind_type: &BaseServerStrategy) -> ServerBindConfig {
289 match bind_type {
290 BaseServerStrategy::UnixSocket => ServerBindConfig::UnixSocket,
291 BaseServerStrategy::InternalTcpPort(hint) => {
292 ServerBindConfig::TcpPort(self.get_internal_ip().clone(), *hint)
293 }
294 BaseServerStrategy::ExternalTcpPort(_) => todo!(),
295 }
296 }
297
298 async fn copy_binary(&self, binary: &BuildOutput) -> Result<()> {
299 let session = self.open_ssh_session().await?;
300
301 let sftp = async_retry(&|| session.open_sftp(), 10, Duration::from_secs(1)).await?;
302
303 let user = self.ssh_user();
304 let binary_path = format!("/home/{user}/hydro-{}", binary.unique_id());
306
307 if sftp.metadata(&binary_path).await.is_err() {
308 let random = nanoid!(8);
309 let temp_path = format!("/home/{user}/hydro-{random}");
310 let sftp = &sftp;
311
312 ProgressTracker::progress_leaf(
313 format!("uploading binary to {}", binary_path),
314 |set_progress, _| {
315 async move {
316 let mut created_file = sftp.create(&temp_path).await?;
317
318 let mut index = 0;
319 while index < binary.bin_data.len() {
320 let written = created_file
321 .write(
322 &binary.bin_data[index
323 ..std::cmp::min(index + 128 * 1024, binary.bin_data.len())],
324 )
325 .await?;
326 index += written;
327 set_progress(
328 ((index as f64 / binary.bin_data.len() as f64) * 100.0) as u64,
329 );
330 }
331 let mut orig_file_stat = sftp.metadata(&temp_path).await?;
332 orig_file_stat.permissions = Some(0o755); created_file.set_metadata(orig_file_stat).await?;
334 created_file.sync_all().await?;
335 drop(created_file);
336
337 match sftp.rename(&temp_path, binary_path).await {
338 Ok(_) => {}
339 Err(SftpError::Status(Status {
340 status_code: StatusCode::Failure, ..
342 })) => {
343 sftp.remove_file(temp_path).await?;
345 }
346 Err(e) => return Err(e.into()),
347 }
348
349 anyhow::Ok(())
350 }
351 },
352 )
353 .await?;
354 }
355 sftp.close().await?;
356
357 Ok(())
358 }
359
360 async fn launch_binary(
361 &self,
362 id: String,
363 binary: &BuildOutput,
364 args: &[String],
365 tracing: Option<TracingOptions>,
366 ) -> Result<Box<dyn LaunchedBinary>> {
367 let session = self.open_ssh_session().await?;
368
369 let user = self.ssh_user();
370 let binary_path = PathBuf::from(format!("/home/{user}/hydro-{}", binary.unique_id()));
371
372 let mut command = binary_path.to_str().unwrap().to_owned();
373 for arg in args {
374 command.push(' ');
375 command.push_str(&shell_escape::unix::escape(arg.into()))
376 }
377
378 if let Some(TracingOptions {
380 frequency,
381 setup_command,
382 ..
383 }) = tracing.clone()
384 {
385 let id_clone = id.clone();
386 ProgressTracker::leaf("install perf", async {
387 if let Some(setup_command) = setup_command {
389 let setup_channel = create_channel(&session).await?;
390 let (setup_stdout, setup_stderr) =
391 (setup_channel.stdout(), setup_channel.stderr());
392 setup_channel.exec(false, &*setup_command).await?;
393
394 let mut output_lines = LinesStream::new(setup_stdout.lines())
396 .merge(LinesStream::new(setup_stderr.lines()));
397 while let Some(line) = output_lines.next().await {
398 ProgressTracker::eprintln(format!(
399 "[{} install perf] {}",
400 id_clone,
401 line.unwrap()
402 ));
403 }
404
405 setup_channel.closed().wait().await;
406 let exit_code = setup_channel.recv_exit_status().try_get();
407 if Ok(&0) != exit_code {
408 anyhow::bail!("Failed to install perf on remote host");
409 }
410 }
411 Ok(())
412 })
413 .await?;
414
415 command = format!(
418 "perf record -F {frequency} -e cycles:u --call-graph dwarf,65528 -o {PERF_OUTFILE} {command}",
419 );
420 }
421
422 let (channel, stdout, stderr) = ProgressTracker::leaf(
423 format!("launching binary {}", binary_path.display()),
424 async {
425 let channel = create_channel(&session).await?;
426 let (stdout, stderr) = (channel.stdout(), channel.stderr());
428 channel.exec(false, command).await?;
429 anyhow::Ok((channel, stdout, stderr))
430 },
431 )
432 .await?;
433
434 let (stdin_sender, mut stdin_receiver) = mpsc::unbounded_channel::<String>();
435 let mut stdin = channel.stdin();
436
437 tokio::spawn(async move {
438 while let Some(line) = stdin_receiver.recv().await {
439 if stdin.write_all(line.as_bytes()).await.is_err() {
440 break;
441 }
442 stdin.flush().await.unwrap();
443 }
444 });
445
446 let id_clone = id.clone();
447 let stdout_broadcast = prioritized_broadcast(LinesStream::new(stdout.lines()), move |s| {
448 ProgressTracker::println(format!("[{id_clone}] {s}"));
449 });
450 let stderr_broadcast = prioritized_broadcast(LinesStream::new(stderr.lines()), move |s| {
451 ProgressTracker::println(format!("[{id} stderr] {s}"));
452 });
453
454 Ok(Box::new(LaunchedSshBinary {
455 _resource_result: self.resource_result().clone(),
456 session: Some(session),
457 channel,
458 stdin_sender,
459 stdout_broadcast,
460 stderr_broadcast,
461 tracing,
462 tracing_results: OnceLock::new(),
463 }))
464 }
465
466 async fn forward_port(&self, addr: &SocketAddr) -> Result<SocketAddr> {
467 let session = self.open_ssh_session().await?;
468
469 let local_port = TcpListener::bind("127.0.0.1:0").await?;
470 let local_addr = local_port.local_addr()?;
471
472 let internal_ip = addr.ip().to_string();
473 let port = addr.port();
474
475 tokio::spawn(async move {
476 #[expect(clippy::never_loop, reason = "tcp accept loop pattern")]
477 while let Ok((mut local_stream, _)) = local_port.accept().await {
478 let mut channel = session
479 .channel_open_direct_tcpip(internal_ip, port.into(), "127.0.0.1", 22)
480 .await
481 .unwrap()
482 .into_stream();
483 let _ = tokio::io::copy_bidirectional(&mut local_stream, &mut channel).await;
484 break;
485 }
488
489 ProgressTracker::println("[hydro] closing forwarded port");
490 });
491
492 Ok(local_addr)
493 }
494}