diff --git a/Cargo.toml b/Cargo.toml index 0c33883..e17c816 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] # Default enable napi4 feature, see https://nodejs.org/api/n-api.html#node-api-version-matrix napi = { version = "2.12.2", default-features = false, features = ["napi4"] } napi-derive = "2.12.2" -rustix = "0.38.30" +rustix = { version = "0.38.30", features = ["event"] } rustix-openpty = "0.1.1" libc = "0.2.152" diff --git a/index.d.ts b/index.d.ts index a3194a3..b387d67 100644 --- a/index.d.ts +++ b/index.d.ts @@ -53,11 +53,14 @@ export interface Size { * // TODO: Handle the error. * }); * ``` + * + * The last parameter (a callback that gets stdin chunks) is optional and is only there for + * compatibility with bun 1.1.7. */ export class Pty { /** The pid of the forked process. */ pid: number - constructor(command: string, args: Array, envs: Record, dir: string, size: Size, onExit: (err: null | Error, exitCode: number) => void) + constructor(command: string, args: Array, envs: Record, dir: string, size: Size, onExit: (err: null | Error, exitCode: number) => void, onData?: (err: null | Error, data: Buffer) => void) /** Resize the terminal. */ resize(size: Size): void /** diff --git a/index.test.ts b/index.test.ts index b6f086a..409083d 100644 --- a/index.test.ts +++ b/index.test.ts @@ -271,6 +271,34 @@ describe('PTY', () => { Bun.write(pty.fd(), message + EOT + EOT); }); + (os.type() !== 'Darwin' ? test : test.skip)( + 'works with data callback', + (done) => { + const message = 'hello bun\n'; + let buffer = ''; + + const pty = new Pty( + '/bin/cat', + [], + {}, + CWD, + { rows: 24, cols: 80 }, + () => { + expect(buffer).toBe('hello bun\r\nhello bun\r\n'); + pty.close(); + + done(); + }, + (err: Error | null, chunk: Buffer) => { + expect(err).toBeNull(); + buffer += chunk.toString(); + }, + ); + + Bun.write(pty.fd(), message + EOT + EOT); + }, + ); + test("doesn't break when executing non-existing binary", (done) => { try { new Pty( diff --git a/src/lib.rs b/src/lib.rs index 1349d5f..14c2368 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,10 +7,11 @@ use std::process::{Command, Stdio}; use std::thread; use libc::{self, c_int}; -use napi::bindgen_prelude::JsFunction; +use napi::bindgen_prelude::{Buffer, JsFunction}; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}; use napi::Status::GenericFailure; use napi::{self, Env}; +use rustix::event::{poll, PollFd, PollFlags}; use rustix_openpty::openpty; use rustix_openpty::rustix::termios::{self, InputModes, OptionalActions, Winsize}; @@ -61,6 +62,9 @@ extern crate napi_derive; /// // TODO: Handle the error. /// }); /// ``` +/// +/// The last parameter (a callback that gets stdin chunks) is optional and is only there for +/// compatibility with bun 1.1.7. #[napi] #[allow(dead_code)] struct Pty { @@ -129,6 +133,7 @@ impl Pty { dir: String, size: Size, #[napi(ts_arg_type = "(err: null | Error, exitCode: number) => void")] on_exit: JsFunction, + #[napi(ts_arg_type = "(err: null | Error, data: Buffer) => void")] on_data: Option, ) -> Result { let is_node = env.get_node_version()?.release == "node"; let window_size = Winsize { @@ -208,34 +213,138 @@ impl Pty { // analysis to ensure that every single call goes through the wrapper to avoid double `wait`'s // on a child. // - Have a single thread loop where other entities can register children (by sending the pid - // over a channel) and this loop can use `epoll` to listen for each child's `pidfd` for when + // over a channel) and this loop can use `poll` to listen for each child's `pidfd` for when // they are ready to be `wait`'ed. This has the inconvenience that it consumes one FD per child. // // For discussion check out: https://github.com/replit/ruspty/pull/1#discussion_r1463672548 let ts_on_exit: ThreadsafeFunction = on_exit .create_threadsafe_function(0, |ctx| ctx.env.create_int32(ctx.value).map(|v| vec![v]))?; - thread::spawn(move || match child.wait() { - Ok(status) => { - if status.success() { - ts_on_exit.call(Ok(0), ThreadsafeFunctionCallMode::Blocking); - } else { - ts_on_exit.call( - Ok(status.code().unwrap_or(-1)), + let ts_on_data = on_data + .map(|on_data| { + Ok::< + ( + ThreadsafeFunction, + OwnedFd, + ), + napi::Error, + >(( + on_data.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?, + match controller_fd.try_clone() { + Ok(fd) => Ok(fd), + Err(err) => Err(napi::Error::new( + GenericFailure, + format!( + "OS error when setting up child process wait: {}", + err.raw_os_error().unwrap_or(-1) + ), + )), + }?, + )) + }) + .transpose()?; + thread::spawn(move || { + #[cfg(target_os = "linux")] + { + // The following code only works on Linux due to the reliance on pidfd. + use rustix::process::{pidfd_open, Pid, PidfdFlags}; + + if let Some((ts_on_data, controller_fd)) = ts_on_data { + if let Err(err) = || -> Result<(), napi::Error> { + let pidfd = pidfd_open( + unsafe { Pid::from_raw_unchecked(child.id() as i32) }, + PidfdFlags::empty(), + ) + .map_err(|err| napi::Error::new(GenericFailure, format!("pidfd_open: {:#?}", err)))?; + let mut poll_fds = [ + PollFd::new(&controller_fd, PollFlags::IN), + PollFd::new(&pidfd, PollFlags::IN), + ]; + let mut buf = [0u8; 16 * 1024]; + loop { + for poll_fd in &mut poll_fds[..] { + poll_fd.clear_revents(); + } + poll(&mut poll_fds, -1).map_err(|err| { + napi::Error::new( + GenericFailure, + format!("OS error when waiting for child read: {:#?}", err), + ) + })?; + // Always check the controller FD first to see if it has any events. + if poll_fds[0].revents().contains(PollFlags::IN) { + match rustix::io::read(&controller_fd, &mut buf) { + Ok(n) => { + ts_on_data.call( + Ok(buf[..n as usize].into()), + ThreadsafeFunctionCallMode::Blocking, + ); + } + Err(errno) => { + if errno == rustix::io::Errno::AGAIN || errno == rustix::io::Errno::INTR { + // These two errors are safe to retry. + continue; + } + if errno == rustix::io::Errno::IO { + // This error happens when the child closes. We can simply break the loop. + return Ok(()); + } + return Err(napi::Error::new( + GenericFailure, + format!("OS error when reading from child: {:#?}", errno,), + )); + } + } + // If there was data, keep trying to read this FD. + continue; + } + + // Now that we're sure that the controller FD doesn't have any events, we have + // successfully drained the child's output, so we can now check if the child has + // exited. + if poll_fds[1].revents().contains(PollFlags::IN) { + return Ok(()); + } + } + }() { + ts_on_data.call(Err(err), ThreadsafeFunctionCallMode::Blocking); + } + } + } + #[cfg(target_os != "linux")] + { + if let Some((ts_on_data, _controller_fd)) = ts_on_data { + ts_on_data.call( + Err(napi::Error::new( + GenericFailure, + "the data callback is only implemented in Linux", + )), ThreadsafeFunctionCallMode::Blocking, ); } } - Err(err) => { - ts_on_exit.call( - Err(napi::Error::new( - GenericFailure, - format!( - "OS error when waiting for child process to exit: {}", - err.raw_os_error().unwrap_or(-1) - ), - )), - ThreadsafeFunctionCallMode::Blocking, - ); + match child.wait() { + Ok(status) => { + if status.success() { + ts_on_exit.call(Ok(0), ThreadsafeFunctionCallMode::Blocking); + } else { + ts_on_exit.call( + Ok(status.code().unwrap_or(-1)), + ThreadsafeFunctionCallMode::Blocking, + ); + } + } + Err(err) => { + ts_on_exit.call( + Err(napi::Error::new( + GenericFailure, + format!( + "OS error when waiting for child process to exit: {}", + err.raw_os_error().unwrap_or(-1) + ), + )), + ThreadsafeFunctionCallMode::Blocking, + ); + } } });