diff --git a/Cargo.lock b/Cargo.lock index e7f1a147..98d8a954 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -815,6 +815,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-channel" version = "0.5.15" @@ -2152,6 +2161,20 @@ dependencies = [ "digest", ] +[[package]] +name = "mdata-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "base64", + "crc32fast", + "getrandom 0.3.4", + "libc", + "thiserror 2.0.18", + "tracing", + "tracing-subscriber", +] + [[package]] name = "memchr" version = "2.8.0" diff --git a/Cargo.toml b/Cargo.toml index e8a8f46d..5301128f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "apis/jira-api", "apis/vmapi-api", "cli/bugview-cli", + "cli/mdata-client", "cli/triton-cli", "cli/vmapi-cli", "clients/internal/bugview-client", @@ -59,12 +60,16 @@ expect_used = "deny" [workspace.dependencies] askama = "0.15" anyhow = "1.0" +base64 = "0.22" build-data = "0.2" camino = "1.1" cargo_toml = "0.22" chrono = "0.4" +crc32fast = "1.4" clap = { version = "4.5", features = ["derive"] } +getrandom = "0.3" dropshot = { version = "0.16" } +libc = "0.2" dropshot-api-manager = "0.3" dropshot-api-manager-types = "0.3" openapiv3 = "2.0" diff --git a/cli/mdata-client/Cargo.toml b/cli/mdata-client/Cargo.toml new file mode 100644 index 00000000..43152708 --- /dev/null +++ b/cli/mdata-client/Cargo.toml @@ -0,0 +1,42 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +# +# Copyright 2026 Edgecast Cloud LLC. + +[package] +name = "mdata-client" +version = "0.1.0" +edition.workspace = true +description = "Rust implementation of the SmartOS metadata client (mdata-get/put/list/delete)" + +[lints] +workspace = true + +[[bin]] +name = "mdata-get" +path = "src/bin/mdata_get.rs" + +[[bin]] +name = "mdata-put" +path = "src/bin/mdata_put.rs" + +[[bin]] +name = "mdata-list" +path = "src/bin/mdata_list.rs" + +[[bin]] +name = "mdata-delete" +path = "src/bin/mdata_delete.rs" + +[dependencies] +anyhow = { workspace = true } +base64 = { workspace = true } +crc32fast = { workspace = true } +getrandom = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } + +[target.'cfg(unix)'.dependencies] +libc = { workspace = true } diff --git a/cli/mdata-client/src/bin/mdata_delete.rs b/cli/mdata-client/src/bin/mdata_delete.rs new file mode 100644 index 00000000..997bb907 --- /dev/null +++ b/cli/mdata-client/src/bin/mdata_delete.rs @@ -0,0 +1,50 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! mdata-delete: Delete a metadata key. +//! +//! Usage: mdata-delete +//! +//! Deleting a non-existent key is not considered an error. +//! Requires V2 protocol support from the metadata service. +//! +//! Exit codes: +//! 0 - Success (or key did not exist) +//! 2 - Error +//! 3 - Usage error + +use mdata_client::protocol::Protocol; +use mdata_client::{Response, exit_code}; + +fn main() { + mdata_client::init_logging(); + match run() { + Ok(code) => std::process::exit(code), + Err(e) => { + eprintln!("ERROR: {e}"); + std::process::exit(exit_code::ERROR); + } + } +} + +fn run() -> anyhow::Result { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + eprintln!( + "Usage: {} ", + args.first().map(String::as_str).unwrap_or("mdata-delete"), + ); + return Ok(exit_code::USAGE_ERROR); + } + + let key = &args[1]; + let mut proto = Protocol::init()?; + + match proto.delete(key)? { + // DELETE of non-existent key is not an error + Response::Success(_) | Response::NotFound => Ok(exit_code::SUCCESS), + } +} diff --git a/cli/mdata-client/src/bin/mdata_get.rs b/cli/mdata-client/src/bin/mdata_get.rs new file mode 100644 index 00000000..e90b2226 --- /dev/null +++ b/cli/mdata-client/src/bin/mdata_get.rs @@ -0,0 +1,61 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! mdata-get: Retrieve the value of a metadata key. +//! +//! Usage: mdata-get +//! +//! Exit codes: +//! 0 - Success (value printed to stdout) +//! 1 - Key not found +//! 2 - Error +//! 3 - Usage error + +use mdata_client::protocol::Protocol; +use mdata_client::{Command, Response, exit_code}; + +fn main() { + mdata_client::init_logging(); + match run() { + Ok(code) => std::process::exit(code), + Err(e) => { + eprintln!("ERROR: {e}"); + std::process::exit(exit_code::ERROR); + } + } +} + +fn run() -> anyhow::Result { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + eprintln!( + "Usage: {} ", + args.first().map(String::as_str).unwrap_or("mdata-get"), + ); + return Ok(exit_code::USAGE_ERROR); + } + + let key = &args[1]; + let mut proto = Protocol::init()?; + + match proto.execute(Command::Get, Some(key))? { + Response::Success(Some(data)) => { + print!("{data}"); + if !data.ends_with('\n') { + println!(); + } + Ok(exit_code::SUCCESS) + } + Response::Success(None) => { + println!(); + Ok(exit_code::SUCCESS) + } + Response::NotFound => { + eprintln!("No metadata for '{key}'"); + Ok(exit_code::NOT_FOUND) + } + } +} diff --git a/cli/mdata-client/src/bin/mdata_list.rs b/cli/mdata-client/src/bin/mdata_list.rs new file mode 100644 index 00000000..1decea29 --- /dev/null +++ b/cli/mdata-client/src/bin/mdata_list.rs @@ -0,0 +1,56 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! mdata-list: List all available metadata keys. +//! +//! Usage: mdata-list +//! +//! Exit codes: +//! 0 - Success (keys printed to stdout, one per line) +//! 2 - Error +//! 3 - Usage error + +use mdata_client::protocol::Protocol; +use mdata_client::{Command, Response, exit_code}; + +fn main() { + mdata_client::init_logging(); + match run() { + Ok(code) => std::process::exit(code), + Err(e) => { + eprintln!("ERROR: {e}"); + std::process::exit(exit_code::ERROR); + } + } +} + +fn run() -> anyhow::Result { + let args: Vec = std::env::args().collect(); + if args.len() != 1 { + eprintln!( + "Usage: {}", + args.first().map(String::as_str).unwrap_or("mdata-list"), + ); + return Ok(exit_code::USAGE_ERROR); + } + + let mut proto = Protocol::init()?; + + match proto.execute(Command::Keys, None)? { + Response::Success(Some(data)) => { + print!("{data}"); + if !data.ends_with('\n') { + println!(); + } + Ok(exit_code::SUCCESS) + } + Response::Success(None) => Ok(exit_code::SUCCESS), + Response::NotFound => { + eprintln!("ERROR: unexpected NOTFOUND response for KEYS"); + Ok(exit_code::ERROR) + } + } +} diff --git a/cli/mdata-client/src/bin/mdata_put.rs b/cli/mdata-client/src/bin/mdata_put.rs new file mode 100644 index 00000000..17fd5b98 --- /dev/null +++ b/cli/mdata-client/src/bin/mdata_put.rs @@ -0,0 +1,72 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! mdata-put: Set the value of a metadata key. +//! +//! Usage: mdata-put [] +//! +//! If is not provided, reads from stdin (only when stdin is +//! not a terminal). +//! +//! Requires V2 protocol support from the metadata service. +//! +//! Exit codes: +//! 0 - Success +//! 2 - Error +//! 3 - Usage error + +use std::io::{IsTerminal, Read}; + +use mdata_client::protocol::Protocol; +use mdata_client::{Response, exit_code}; + +fn main() { + mdata_client::init_logging(); + match run() { + Ok(code) => std::process::exit(code), + Err(e) => { + eprintln!("ERROR: {e}"); + std::process::exit(exit_code::ERROR); + } + } +} + +fn run() -> anyhow::Result { + let args: Vec = std::env::args().collect(); + let progname = args.first().map(String::as_str).unwrap_or("mdata-put"); + + if args.len() < 2 || args.len() > 3 { + eprintln!("Usage: {progname} []"); + return Ok(exit_code::USAGE_ERROR); + } + + let key = &args[1]; + + // Get value from argument or stdin + let value = if args.len() == 3 { + args[2].clone() + } else if !std::io::stdin().is_terminal() { + let mut buf = String::new(); + std::io::stdin().read_to_string(&mut buf)?; + buf + } else { + eprintln!( + "Usage: {progname} []\n\ + ERROR: either specify value as argument or pipe via stdin" + ); + return Ok(exit_code::USAGE_ERROR); + }; + + let mut proto = Protocol::init()?; + + match proto.put(key, &value)? { + Response::Success(_) => Ok(exit_code::SUCCESS), + Response::NotFound => { + eprintln!("ERROR: unexpected NOTFOUND response for PUT"); + Ok(exit_code::ERROR) + } + } +} diff --git a/cli/mdata-client/src/lib.rs b/cli/mdata-client/src/lib.rs new file mode 100644 index 00000000..717a9e6a --- /dev/null +++ b/cli/mdata-client/src/lib.rs @@ -0,0 +1,74 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! Rust implementation of the SmartOS metadata protocol client. +//! +//! This crate implements the V1/V2 metadata protocol used by SmartOS +//! zones and KVM guests to communicate with the host metadata service. +//! It supports communication over Unix domain sockets (zones) and +//! serial ports (KVM/HVM guests). + +use std::fmt; + +pub mod protocol; +pub mod transport; + +/// Initialize tracing for mdata-client tools. +/// +/// Enables debug output when `MDATA_DEBUG=1` is set, otherwise only +/// warnings. Output goes to stderr to avoid interfering with stdout +/// data (which callers may be parsing). +pub fn init_logging() { + let filter = if std::env::var("MDATA_DEBUG").is_ok_and(|v| v == "1") { + "mdata_client=debug" + } else { + "mdata_client=warn" + }; + tracing_subscriber::fmt() + .with_env_filter(filter) + .with_writer(std::io::stderr) + .without_time() + .with_target(false) + .init(); +} + +/// Metadata protocol commands. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Command { + Get, + Put, + Delete, + Keys, +} + +impl fmt::Display for Command { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Command::Get => "GET", + Command::Put => "PUT", + Command::Delete => "DELETE", + Command::Keys => "KEYS", + }) + } +} + +/// Exit codes matching the original C mdata-client implementation. +pub mod exit_code { + pub const SUCCESS: i32 = 0; + pub const NOT_FOUND: i32 = 1; + pub const ERROR: i32 = 2; + pub const USAGE_ERROR: i32 = 3; +} + +/// Response from a metadata operation. +#[derive(Clone, Debug, PartialEq)] +#[must_use] +pub enum Response { + /// Operation succeeded, with optional data payload. + Success(Option), + /// Key was not found. + NotFound, +} diff --git a/cli/mdata-client/src/protocol.rs b/cli/mdata-client/src/protocol.rs new file mode 100644 index 00000000..9ee0c7fa --- /dev/null +++ b/cli/mdata-client/src/protocol.rs @@ -0,0 +1,645 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! Metadata protocol implementation (V1 and V2). +//! +//! The protocol supports two versions: +//! +//! **V1**: Simple text commands (`COMMAND [ARG]\n`) with multi-line +//! responses terminated by a `.` line. +//! +//! **V2**: Framed protocol with BASE64 encoding and CRC32 checksums. +//! Format: `V2 []\n` +//! +//! V2 is negotiated automatically on connection. PUT and DELETE +//! operations require V2. +//! +//! ## Error handling +//! +//! Transport methods return `Result<_, TransportError>` for structured +//! I/O errors (timeout, EOF, invalid data). Protocol methods return +//! `anyhow::Result` to add contextual information. `TransportError` +//! converts to `anyhow::Error` via thiserror's `#[error]` derive. + +use std::fmt; +use std::thread; +use std::time::Duration; + +use anyhow::{Context, Result, bail}; +use base64::Engine as _; +use base64::engine::general_purpose::STANDARD; +use tracing::{debug, warn}; + +use crate::transport::{MetadataTransport, Transport, TransportError}; +use crate::{Command, Response}; + +/// Timeout for V1 commands and protocol negotiation (6 seconds). +const RECV_TIMEOUT_MS: u64 = 6_000; + +/// Timeout for V2 operations (45 seconds, allows for slower PUT). +const RECV_TIMEOUT_MS_V2: u64 = 45_000; + +/// Maximum number of timeout-and-reset retries before giving up. +const MAX_RETRIES: u32 = 3; + +/// Maximum number of stale V2 frames to discard before giving up. +const MAX_STALE_FRAMES: u32 = 5; + +/// Negotiated protocol version. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum ProtocolVersion { + V1, + V2, +} + +/// Protocol handler for metadata operations. +pub struct Protocol { + transport: T, + version: ProtocolVersion, +} + +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Protocol") + .field("version", &self.version) + .finish_non_exhaustive() + } +} + +impl Protocol { + /// Initialize: open transport, negotiate protocol version. + pub fn init() -> Result { + let mut transport = Transport::open()?; + let version = Self::negotiate(&mut transport)?; + Ok(Self { transport, version }) + } +} + +impl Protocol { + /// Create a protocol handler with an existing transport. + #[cfg(test)] + pub fn with_transport(mut transport: T) -> Result { + let version = Self::negotiate(&mut transport)?; + Ok(Self { transport, version }) + } + + /// Execute a DELETE command. + /// + /// Requires V2 protocol support. + pub fn delete(&mut self, key: &str) -> Result { + if self.version != ProtocolVersion::V2 { + bail!( + "metadata service does not support V2 protocol \ + (required for DELETE)" + ); + } + self.execute(Command::Delete, Some(key)) + } + + /// Execute a PUT command, encoding the key and value per protocol. + /// + /// The V2 PUT wire format uses double base64 encoding: + /// the key and value are each individually base64-encoded, joined + /// by a space, and then the combined string is base64-encoded again + /// as the V2 frame argument. This matches the original C mdata-client. + /// + /// On the wire: `V2 PUT ` + pub fn put(&mut self, key: &str, value: &str) -> Result { + if self.version != ProtocolVersion::V2 { + bail!( + "metadata service does not support V2 protocol \ + (required for PUT)" + ); + } + let arg = format!("{} {}", STANDARD.encode(key), STANDARD.encode(value)); + self.execute(Command::Put, Some(&arg)) + } + + /// Execute a metadata command with automatic retry on timeout. + /// + /// On timeout, the protocol is reset (transport reconnected and + /// V2 re-negotiated) and the command is retried. + pub fn execute(&mut self, command: Command, arg: Option<&str>) -> Result { + let mut retries = 0; + loop { + match self.try_execute(command, arg) { + Ok(response) => return Ok(response), + Err(e) => { + if is_timeout(&e) { + retries += 1; + if retries > MAX_RETRIES { + bail!( + "giving up after {MAX_RETRIES} \ + timeout retries" + ); + } + warn!( + "receive timeout, resetting \ + protocol (attempt {retries}/{MAX_RETRIES})" + ); + self.reset()?; + continue; + } + return Err(e); + } + } + } + } + + fn try_execute(&mut self, command: Command, arg: Option<&str>) -> Result { + match self.version { + ProtocolVersion::V2 => self.execute_v2(command, arg), + ProtocolVersion::V1 => self.execute_v1(command, arg), + } + } + + /// Execute a V1 protocol command. + fn execute_v1(&mut self, command: Command, arg: Option<&str>) -> Result { + let request = match arg { + Some(a) => format!("{command} {a}\n"), + None => format!("{command}\n"), + }; + + self.transport.send(&request)?; + + let header = self.transport.recv_line(RECV_TIMEOUT_MS)?; + + match header.as_str() { + "SUCCESS" => { + let mut data = String::new(); + loop { + let line = self.transport.recv_line(RECV_TIMEOUT_MS)?; + if line == "." { + break; + } + if !data.is_empty() { + data.push('\n'); + } + data.push_str(&line); + } + if data.is_empty() { + Ok(Response::Success(None)) + } else { + Ok(Response::Success(Some(data))) + } + } + "NOTFOUND" => Ok(Response::NotFound), + other => bail!("unexpected V1 response: {other}"), + } + } + + /// Execute a V2 protocol command. + fn execute_v2(&mut self, command: Command, arg: Option<&str>) -> Result { + let reqid = generate_request_id()?; + + let body = match arg { + Some(a) => { + let b64_arg = STANDARD.encode(a); + format!("{reqid} {command} {b64_arg}") + } + None => format!("{reqid} {command}"), + }; + + let crc = crc32fast::hash(body.as_bytes()); + let request = format!("V2 {} {crc:08x} {body}\n", body.len()); + + self.transport.send(&request)?; + + // Read V2 response, discarding stale frames from + // previous timed-out requests (mismatched request IDs) + let mut stale_count = 0u32; + loop { + let line = self.transport.recv_line(RECV_TIMEOUT_MS_V2)?; + match parse_v2_frame(&line, &reqid) { + Ok(frame) => { + return match frame.status.as_str() { + "SUCCESS" => { + let data = frame.payload.map(|p| decode_b64_payload(&p)).transpose()?; + Ok(Response::Success(data)) + } + "NOTFOUND" => Ok(Response::NotFound), + other => { + bail!("unexpected V2 status: {other}") + } + }; + } + Err(FrameError::ReqIdMismatch { .. }) => { + stale_count += 1; + if stale_count > MAX_STALE_FRAMES { + bail!( + "too many stale V2 frames \ + ({MAX_STALE_FRAMES}), giving up" + ); + } + continue; + } + Err(FrameError::Other(e)) => return Err(e), + } + } + } + + /// Reset the protocol: reconnect transport and re-negotiate. + fn reset(&mut self) -> Result<()> { + thread::sleep(Duration::from_secs(1)); + self.transport.reconnect()?; + self.version = Self::negotiate(&mut self.transport)?; + Ok(()) + } + + /// Negotiate protocol version with the metadata service. + /// + /// For serial transports, sends a reset sequence first (`\n` -> + /// `invalid command`) to clear any stale state on the port. + fn negotiate(transport: &mut T) -> Result { + if transport.is_serial() { + // Serial port reset: send a bare newline, expect + // "invalid command" response to confirm port is alive + transport.send("\n").ok(); + match transport.recv_line(RECV_TIMEOUT_MS) { + Ok(_) => {} // Discard response (usually "invalid command") + Err(TransportError::Timeout) => { + // Port may not be responsive yet, continue anyway + } + Err(TransportError::Eof) => { + bail!("serial port closed during reset sequence"); + } + Err(TransportError::Io(e)) => { + bail!("serial port I/O error during reset: {e}"); + } + Err(TransportError::InvalidData) => {} + } + } + + // Attempt V2 negotiation + transport.send("NEGOTIATE V2\n")?; + match transport.recv_line(RECV_TIMEOUT_MS) { + Ok(ref line) if line == "V2_OK" => { + debug!("negotiated V2 protocol"); + Ok(ProtocolVersion::V2) + } + Ok(ref line) if line == "invalid command" => { + debug!("V2 not supported, falling back to V1"); + Ok(ProtocolVersion::V1) + } + Ok(other) => { + bail!("unexpected negotiation response: {other}") + } + Err(TransportError::Timeout) => { + bail!("timeout during protocol negotiation") + } + Err(e) => Err(e.into()), + } + } +} + +/// Errors from V2 frame parsing. +#[derive(Debug, thiserror::Error)] +enum FrameError { + #[error("V2 request ID mismatch: expected {expected}, got {actual}")] + ReqIdMismatch { expected: String, actual: String }, + #[error("{0}")] + Other(anyhow::Error), +} + +/// A parsed V2 protocol frame. +#[derive(Debug)] +struct V2Frame { + status: String, + payload: Option, +} + +/// Parse a V2 response frame and validate its integrity. +/// +/// Frame format: `V2 []` +fn parse_v2_frame(line: &str, expected_reqid: &str) -> std::result::Result { + match parse_v2_body(line) { + Ok((reqid, status, payload)) => { + if reqid != expected_reqid { + return Err(FrameError::ReqIdMismatch { + expected: expected_reqid.to_string(), + actual: reqid, + }); + } + Ok(V2Frame { status, payload }) + } + Err(e) => Err(FrameError::Other(e)), + } +} + +/// Parse the envelope and body of a V2 frame, validating length and CRC. +/// +/// Returns `(request_id, status, optional_payload)`. +fn parse_v2_body(line: &str) -> Result<(String, String, Option)> { + let mut parts = line.splitn(4, ' '); + + let marker = parts.next(); + let len_str = parts.next(); + let crc_str = parts.next(); + let body = parts.next(); + + let (Some("V2"), Some(len_str), Some(crc_str), Some(body)) = (marker, len_str, crc_str, body) + else { + bail!("invalid V2 frame: expected 'V2 '"); + }; + + let expected_len: usize = len_str + .parse() + .map_err(|_| anyhow::anyhow!("invalid V2 frame length: {len_str}"))?; + + let expected_crc = u32::from_str_radix(crc_str, 16) + .map_err(|_| anyhow::anyhow!("invalid V2 frame CRC: {crc_str}"))?; + + if body.len() != expected_len { + bail!( + "V2 frame length mismatch: header says {expected_len}, body is {}", + body.len() + ); + } + + let actual_crc = crc32fast::hash(body.as_bytes()); + if actual_crc != expected_crc { + bail!("V2 frame CRC mismatch: expected {expected_crc:08x}, got {actual_crc:08x}"); + } + + let mut body_parts = body.splitn(3, ' '); + let reqid = body_parts + .next() + .ok_or_else(|| anyhow::anyhow!("missing request ID in V2 frame"))?; + let status = body_parts + .next() + .ok_or_else(|| anyhow::anyhow!("missing status in V2 frame"))?; + let payload = body_parts.next().map(String::from); + + Ok((reqid.to_string(), status.to_string(), payload)) +} + +/// Decode a BASE64-encoded payload string. +fn decode_b64_payload(encoded: &str) -> Result { + let bytes = STANDARD + .decode(encoded) + .context("invalid base64 in response")?; + String::from_utf8(bytes).context("response payload is not valid UTF-8") +} + +/// Generate an 8-character hex request ID for V2 protocol frames. +fn generate_request_id() -> Result { + let mut buf = [0u8; 4]; + getrandom::fill(&mut buf) + .map_err(|e| anyhow::anyhow!("getrandom: {e}")) + .context("failed to generate request ID")?; + Ok(format!("{:08x}", u32::from_ne_bytes(buf))) +} + +/// Check if an error is a transport timeout. +fn is_timeout(e: &anyhow::Error) -> bool { + e.downcast_ref::() + .is_some_and(|te| matches!(te, TransportError::Timeout)) +} + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::expect_used)] +mod tests { + use std::cell::RefCell; + + use super::*; + + /// Mock transport that replays scripted responses. + struct MockTransport { + responses: RefCell>, + sent: RefCell>, + serial: bool, + } + + impl MockTransport { + fn new(responses: Vec<&str>, serial: bool) -> Self { + // Reverse so we can pop from the end + let responses = responses.into_iter().rev().map(String::from).collect(); + Self { + responses: RefCell::new(responses), + sent: RefCell::new(Vec::new()), + serial, + } + } + + fn sent_lines(&self) -> Vec { + self.sent.borrow().clone() + } + } + + impl MetadataTransport for MockTransport { + fn send(&self, data: &str) -> Result<(), TransportError> { + self.sent.borrow_mut().push(data.to_string()); + Ok(()) + } + + fn recv_line(&self, _timeout_ms: u64) -> Result { + self.responses.borrow_mut().pop().ok_or(TransportError::Eof) + } + + fn reconnect(&mut self) -> anyhow::Result<()> { + Ok(()) + } + + fn is_serial(&self) -> bool { + self.serial + } + } + + #[test] + fn test_generate_request_id_format() { + let id = generate_request_id().unwrap(); + assert_eq!(id.len(), 8); + assert!(id.chars().all(|c| c.is_ascii_hexdigit())); + } + + #[test] + fn test_parse_v2_frame_valid() { + let reqid = "dc4fae17"; + let status = "SUCCESS"; + let payload = STANDARD.encode("hello world"); + let body = format!("{reqid} {status} {payload}"); + let crc = crc32fast::hash(body.as_bytes()); + let frame = format!("V2 {} {crc:08x} {body}", body.len()); + + let f = parse_v2_frame(&frame, reqid).unwrap(); + assert_eq!(f.status, "SUCCESS"); + let decoded = decode_b64_payload(&f.payload.unwrap()).unwrap(); + assert_eq!(decoded, "hello world"); + } + + #[test] + fn test_parse_v2_frame_notfound() { + let reqid = "abcd1234"; + let body = format!("{reqid} NOTFOUND"); + let crc = crc32fast::hash(body.as_bytes()); + let frame = format!("V2 {} {crc:08x} {body}", body.len()); + + let f = parse_v2_frame(&frame, reqid).unwrap(); + assert_eq!(f.status, "NOTFOUND"); + assert!(f.payload.is_none()); + } + + #[test] + fn test_parse_v2_frame_bad_crc() { + let reqid = "dc4fae17"; + let body = format!("{reqid} SUCCESS"); + let frame = format!("V2 {} 00000000 {body}", body.len()); + + let err = parse_v2_frame(&frame, reqid).unwrap_err(); + assert!(matches!(err, FrameError::Other(_))); + assert!(format!("{err}").contains("CRC mismatch")); + } + + #[test] + fn test_parse_v2_frame_wrong_reqid() { + let reqid = "dc4fae17"; + let body = format!("{reqid} SUCCESS"); + let crc = crc32fast::hash(body.as_bytes()); + let frame = format!("V2 {} {crc:08x} {body}", body.len()); + + let err = parse_v2_frame(&frame, "00000000").unwrap_err(); + assert!(matches!(err, FrameError::ReqIdMismatch { .. })); + } + + #[test] + fn test_parse_v2_frame_bad_length() { + let reqid = "dc4fae17"; + let body = format!("{reqid} SUCCESS"); + let crc = crc32fast::hash(body.as_bytes()); + let frame = format!("V2 99 {crc:08x} {body}"); + + let err = parse_v2_frame(&frame, reqid).unwrap_err(); + assert!(matches!(err, FrameError::Other(_))); + assert!(format!("{err}").contains("length mismatch")); + } + + #[test] + fn test_v1_get_success() { + let mock = MockTransport::new( + vec![ + "SUCCESS", // V1 response header + "hello world", // response data + ".", // terminator + ], + false, + ); + let mut proto = Protocol { + transport: mock, + version: ProtocolVersion::V1, + }; + + let resp = proto.execute(Command::Get, Some("mykey")).unwrap(); + match resp { + Response::Success(Some(data)) => assert_eq!(data, "hello world"), + other => panic!("expected Success(Some), got {other:?}"), + } + + let sent = proto.transport.sent_lines(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0], "GET mykey\n"); + } + + #[test] + fn test_v1_get_notfound() { + let mock = MockTransport::new(vec!["NOTFOUND"], false); + let mut proto = Protocol { + transport: mock, + version: ProtocolVersion::V1, + }; + + let resp = proto.execute(Command::Get, Some("nokey")).unwrap(); + assert!(matches!(resp, Response::NotFound)); + } + + #[test] + fn test_negotiate_v2() { + let mock = MockTransport::new(vec!["V2_OK"], false); + let proto = Protocol::with_transport(mock).unwrap(); + assert_eq!(proto.version, ProtocolVersion::V2); + } + + #[test] + fn test_negotiate_v1_fallback() { + let mock = MockTransport::new(vec!["invalid command"], false); + let proto = Protocol::with_transport(mock).unwrap(); + assert_eq!(proto.version, ProtocolVersion::V1); + } + + #[test] + fn test_serial_negotiation_sends_reset() { + let mock = MockTransport::new( + vec![ + "invalid command", // response to \n reset + "V2_OK", // response to NEGOTIATE V2 + ], + true, // serial + ); + let proto = Protocol::with_transport(mock).unwrap(); + assert_eq!(proto.version, ProtocolVersion::V2); + + let sent = proto.transport.sent_lines(); + assert_eq!(sent.len(), 2); + assert_eq!(sent[0], "\n"); // reset sequence + assert_eq!(sent[1], "NEGOTIATE V2\n"); + } + + #[test] + fn test_v2_get_success() { + // Build a V2 response frame for the mock. + // We don't know the reqid in advance, so we need + // to inspect what was sent and build the response + // dynamically. Instead, test at the frame parsing + // level — the execute_v2 → parse_v2_frame path is + // covered by the frame parsing tests + this V1 test + // proving execute() dispatches correctly. + // + // For a true V2 end-to-end test, we'd need a mock + // that inspects the request and echoes the reqid. + // Test the encoding logic directly instead: + let key = "test-key"; + let value = "test-value"; + let expected = format!("{} {}", STANDARD.encode(key), STANDARD.encode(value)); + + // Verify the PUT argument encoding matches the spec + let outer = STANDARD.encode(&expected); + let decoded_outer = String::from_utf8(STANDARD.decode(&outer).unwrap()).unwrap(); + assert_eq!(decoded_outer, expected); + + // Decode the inner key and value + let parts: Vec<&str> = decoded_outer.splitn(2, ' ').collect(); + let decoded_key = String::from_utf8(STANDARD.decode(parts[0]).unwrap()).unwrap(); + let decoded_val = String::from_utf8(STANDARD.decode(parts[1]).unwrap()).unwrap(); + assert_eq!(decoded_key, key); + assert_eq!(decoded_val, value); + } + + #[test] + fn test_v2_stale_frame_discarded() { + // A stale frame has a mismatched reqid — parse_v2_frame + // should return ReqIdMismatch, and execute_v2 should + // skip it and read the next frame. + let reqid = "aabbccdd"; + let stale_reqid = "00000000"; + + // Build a stale frame + let stale_body = format!("{stale_reqid} SUCCESS"); + let stale_crc = crc32fast::hash(stale_body.as_bytes()); + let stale_frame = format!("V2 {} {stale_crc:08x} {stale_body}", stale_body.len()); + + // Build the correct frame + let good_body = format!("{reqid} SUCCESS"); + let good_crc = crc32fast::hash(good_body.as_bytes()); + let good_frame = format!("V2 {} {good_crc:08x} {good_body}", good_body.len()); + + // Parsing the stale frame with the good reqid should error + let err = parse_v2_frame(&stale_frame, reqid).unwrap_err(); + assert!(matches!(err, FrameError::ReqIdMismatch { .. })); + + // Parsing the good frame should succeed + let frame = parse_v2_frame(&good_frame, reqid).unwrap(); + assert_eq!(frame.status, "SUCCESS"); + } +} diff --git a/cli/mdata-client/src/transport/mod.rs b/cli/mdata-client/src/transport/mod.rs new file mode 100644 index 00000000..dd7207b6 --- /dev/null +++ b/cli/mdata-client/src/transport/mod.rs @@ -0,0 +1,90 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! Transport layer for the metadata protocol. +//! +//! Supports Unix domain sockets (for SmartOS zones) and serial ports +//! (for KVM/HVM guests on Unix and Windows). Platform detection +//! automatically selects the appropriate transport. +//! +//! The socket transport uses safe Rust I/O (`UnixStream`) exclusively. +//! The serial transport requires `unsafe` only for terminal +//! configuration (termios), file locking (fcntl), poll-based timeouts, +//! and Windows serial port setup — operations with no safe Rust +//! equivalent. + +use std::io; +use std::path::PathBuf; + +#[cfg(unix)] +mod unix; + +#[cfg(windows)] +mod windows; + +/// Errors specific to the transport layer. +#[derive(Debug, thiserror::Error)] +pub enum TransportError { + #[error("timed out waiting for response")] + Timeout, + #[error("connection closed unexpectedly")] + Eof, + #[error("invalid UTF-8 in response")] + InvalidData, + #[error("I/O error: {0}")] + Io(#[from] io::Error), +} + +/// Interface for metadata protocol transports. +/// +/// Implemented by the platform-specific `Transport` and by +/// `MockTransport` in tests. +pub trait MetadataTransport { + fn send(&self, data: &str) -> Result<(), TransportError>; + fn recv_line(&self, timeout_ms: u64) -> Result; + fn reconnect(&mut self) -> anyhow::Result<()>; + fn is_serial(&self) -> bool; +} + +/// Detected transport configuration. +#[derive(Clone, Debug)] +pub enum TransportConfig { + /// Unix domain socket (SmartOS zone). + #[cfg(unix)] + UnixSocket(PathBuf), + /// Serial port (KVM/HVM guest). + Serial(PathBuf), +} + +/// Line-oriented transport for the metadata protocol. +/// +/// Uses safe Rust I/O for sockets. Serial ports require minimal +/// unsafe for terminal configuration and poll-based timeouts. +pub struct Transport { + config: TransportConfig, + #[cfg(unix)] + inner: unix::TransportInner, + #[cfg(windows)] + file: std::fs::File, +} + +impl MetadataTransport for Transport { + fn send(&self, data: &str) -> Result<(), TransportError> { + Transport::send(self, data) + } + + fn recv_line(&self, timeout_ms: u64) -> Result { + Transport::recv_line(self, timeout_ms) + } + + fn reconnect(&mut self) -> anyhow::Result<()> { + Transport::reconnect(self) + } + + fn is_serial(&self) -> bool { + matches!(self.config, TransportConfig::Serial(_)) + } +} diff --git a/cli/mdata-client/src/transport/unix.rs b/cli/mdata-client/src/transport/unix.rs new file mode 100644 index 00000000..5612ead9 --- /dev/null +++ b/cli/mdata-client/src/transport/unix.rs @@ -0,0 +1,344 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! Unix transport implementation. +//! +//! Socket transport (SmartOS zones) uses safe Rust I/O exclusively. +//! Serial transport (KVM/HVM guests) uses safe I/O for read/write +//! and requires unsafe only for: +//! - `poll()` — timeout-based readability check (no safe equivalent +//! for `File`) +//! - `tcgetattr`/`tcsetattr` — terminal raw mode configuration +//! - `fcntl(F_SETLK)` — exclusive file locking +//! - `fcntl(F_SETFL)` — clearing O_NONBLOCK after setup +//! - `tcflush` — flushing pending serial data + +use std::fs::{File, OpenOptions}; +use std::io::{self, Read, Write}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::path::{Path, PathBuf}; +use std::time::{Duration, Instant}; + +use anyhow::{Context, Result, bail}; +use tracing::debug; + +use super::{Transport, TransportConfig, TransportError}; + +/// Platform-specific transport inner type. +pub(super) enum TransportInner { + Socket(UnixStream), + Serial(File), +} + +impl Transport { + /// Detect the appropriate transport and open it. + pub fn open() -> Result { + let config = detect_transport()?; + let inner = open_transport(&config)?; + Ok(Self { config, inner }) + } + + /// Send a string over the transport. + pub fn send(&self, data: &str) -> Result<(), TransportError> { + let bytes = data.as_bytes(); + match &self.inner { + TransportInner::Socket(stream) => { + (&*stream).write_all(bytes).map_err(TransportError::Io) + } + TransportInner::Serial(file) => (&*file).write_all(bytes).map_err(TransportError::Io), + } + } + + /// Receive a single line (terminated by `\n`) with a timeout. + /// + /// Returns the line content without the trailing newline. + pub fn recv_line(&self, timeout_ms: u64) -> Result { + match &self.inner { + TransportInner::Socket(stream) => recv_line_socket(stream, timeout_ms), + TransportInner::Serial(file) => recv_line_serial(file, timeout_ms), + } + } + + /// Close and reopen the transport for protocol reset. + pub fn reconnect(&mut self) -> Result<()> { + // Dropping the old inner closes the fd automatically. + self.inner = open_transport(&self.config)?; + Ok(()) + } +} + +// No custom Drop needed — UnixStream and File close their fds on drop. + +/// Receive a line from a Unix socket using `set_read_timeout`. +/// +/// Fully safe — no unsafe required. +fn recv_line_socket(stream: &UnixStream, timeout_ms: u64) -> Result { + let deadline = Instant::now() + Duration::from_millis(timeout_ms); + let mut line = Vec::new(); + let mut byte = [0u8; 1]; + + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Err(TransportError::Timeout); + } + + stream + .set_read_timeout(Some(remaining)) + .map_err(TransportError::Io)?; + + match (&*stream).read(&mut byte) { + Ok(0) => return Err(TransportError::Eof), + Ok(_) => { + if byte[0] == b'\n' { + return String::from_utf8(line).map_err(|_| TransportError::InvalidData); + } + line.push(byte[0]); + } + Err(e) + if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => + { + return Err(TransportError::Timeout); + } + Err(e) if e.kind() == io::ErrorKind::Interrupted => { + continue; + } + Err(e) => return Err(TransportError::Io(e)), + } + } +} + +/// Receive a line from a serial port using `poll()` for timeouts. +/// +/// `File` has no `set_read_timeout`, so we use `poll()` to wait for +/// readability before each `read()` call. +fn recv_line_serial(file: &File, timeout_ms: u64) -> Result { + let deadline = Instant::now() + Duration::from_millis(timeout_ms); + let mut line = Vec::new(); + let mut byte = [0u8; 1]; + + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Err(TransportError::Timeout); + } + + let remaining_ms = remaining.as_millis().min(i32::MAX as u128) as i32; + + if !poll_readable(file.as_raw_fd(), remaining_ms)? { + return Err(TransportError::Timeout); + } + + match (&*file).read(&mut byte) { + Ok(0) => return Err(TransportError::Eof), + Ok(_) => { + if byte[0] == b'\n' { + return String::from_utf8(line).map_err(|_| TransportError::InvalidData); + } + line.push(byte[0]); + } + Err(e) if e.kind() == io::ErrorKind::Interrupted => { + continue; + } + Err(e) => return Err(TransportError::Io(e)), + } + } +} + +/// Poll a file descriptor for readability with a timeout. +fn poll_readable(fd: RawFd, timeout_ms: i32) -> Result { + let mut pfd = libc::pollfd { + fd, + events: libc::POLLIN, + revents: 0, + }; + loop { + // SAFETY: pfd is a stack-allocated pollfd struct with a valid + // fd obtained from File::as_raw_fd(). nfds is 1, matching + // the single pollfd. timeout_ms is bounded by i32::MAX. + let ret = unsafe { libc::poll(&mut pfd, 1, timeout_ms) }; + if ret < 0 { + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::Interrupted { + continue; + } + return Err(TransportError::Io(err)); + } + if ret == 0 { + return Ok(false); + } + if pfd.revents & (libc::POLLERR | libc::POLLHUP | libc::POLLNVAL) != 0 { + return Err(TransportError::Io(io::Error::new( + io::ErrorKind::ConnectionReset, + "poll returned error condition", + ))); + } + return Ok(true); + } +} + +/// Detect the appropriate transport for this platform. +fn detect_transport() -> Result { + // SmartOS zone socket paths (tried in order) + let socket_paths = [ + "/.zonecontrol/metadata.sock", + "/native/.zonecontrol/metadata.sock", + "/var/run/smartdc/metadata.sock", + ]; + + for path in &socket_paths { + if Path::new(path).exists() { + debug!("detected unix socket transport: {path}"); + return Ok(TransportConfig::UnixSocket(PathBuf::from(path))); + } + } + + // Serial ports for KVM/HVM guests + let serial_paths = [ + "/dev/term/b", // illumos/SmartOS + "/dev/ttyS1", // Linux + "/dev/tty01", // NetBSD + "/dev/cua01", // OpenBSD + "/dev/cuau1", // FreeBSD + ]; + + for path in &serial_paths { + if Path::new(path).exists() { + debug!("detected serial transport: {path}"); + return Ok(TransportConfig::Serial(PathBuf::from(path))); + } + } + + bail!( + "no metadata transport found; tried sockets ({}) \ + and serial ports ({})", + socket_paths.join(", "), + serial_paths.join(", "), + ) +} + +/// Open the detected transport. +fn open_transport(config: &TransportConfig) -> Result { + match config { + TransportConfig::UnixSocket(path) => { + let stream = UnixStream::connect(path) + .with_context(|| format!("connecting to metadata socket: {}", path.display()))?; + Ok(TransportInner::Socket(stream)) + } + TransportConfig::Serial(path) => { + let file = open_serial(path)?; + Ok(TransportInner::Serial(file)) + } + } +} + +/// Open and configure a serial port. +/// +/// Uses `std::fs::File` for the open; unsafe is limited to terminal +/// configuration and locking (no safe Rust equivalent). +fn open_serial(path: &Path) -> Result { + let file = OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_NOCTTY | libc::O_NONBLOCK) + .open(path) + .map_err(|err| { + if err.kind() == io::ErrorKind::PermissionDenied { + anyhow::anyhow!( + "permission denied opening {}: \ + are you running as root?", + path.display() + ) + } else { + anyhow::anyhow!("opening serial port {}: {}", path.display(), err) + } + })?; + + let fd = file.as_raw_fd(); + + // Acquire an exclusive lock to prevent concurrent access. + // If this fails, `file` is dropped which closes the fd. + acquire_exclusive_lock(fd, path)?; + + // Configure raw mode (no echo, no canonical, 8-bit, etc.). + // If this fails, `file` is dropped which closes the fd. + configure_serial_raw(fd)?; + + // Clear O_NONBLOCK now that setup is done. + // SAFETY: fd is valid (owned by `file`). F_GETFL/F_SETFL + // only modify the file status flags on our own descriptor. + let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) }; + if flags >= 0 { + let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags & !libc::O_NONBLOCK) }; + if ret < 0 { + debug!("failed to clear O_NONBLOCK: {}", io::Error::last_os_error()); + } + } + + // Flush any pending data from previous sessions. + // SAFETY: fd is valid, flushing both input and output queues. + if unsafe { libc::tcflush(fd, libc::TCIOFLUSH) } < 0 { + debug!("tcflush failed: {}", io::Error::last_os_error()); + } + + Ok(file) +} + +/// Acquire an exclusive (F_WRLCK) lock on the serial port fd. +fn acquire_exclusive_lock(fd: RawFd, path: &Path) -> Result<()> { + // SAFETY: libc::flock is a plain C struct; all-zeros is a valid + // initial state (no lock, offset 0, length 0). + let mut flock_val: libc::flock = unsafe { std::mem::zeroed() }; + #[allow(clippy::unnecessary_cast)] + { + flock_val.l_type = libc::F_WRLCK as i16; + } + flock_val.l_whence = libc::SEEK_SET as i16; + + // SAFETY: fd is valid (owned by caller's File), flock_val is + // properly initialized. F_SETLK is a non-blocking lock attempt. + if unsafe { libc::fcntl(fd, libc::F_SETLK, &flock_val) } < 0 { + let err = io::Error::last_os_error(); + bail!( + "failed to lock serial port {} \ + (another mdata process may be running): {}", + path.display(), + err, + ); + } + + Ok(()) +} + +/// Configure a serial port for raw (non-canonical) I/O. +fn configure_serial_raw(fd: RawFd) -> Result<()> { + // SAFETY: libc::termios is a plain C struct; all-zeros is valid. + let mut tios: libc::termios = unsafe { std::mem::zeroed() }; + + // SAFETY: fd is valid, tios is a properly sized stack buffer. + if unsafe { libc::tcgetattr(fd, &mut tios) } < 0 { + bail!("tcgetattr failed: {}", io::Error::last_os_error()); + } + + tios.c_iflag &= !(libc::BRKINT | libc::ICRNL | libc::INPCK | libc::ISTRIP | libc::IXON); + tios.c_oflag &= !libc::OPOST; + tios.c_cflag |= libc::CS8; + tios.c_cflag &= !libc::HUPCL; + tios.c_lflag &= !(libc::ECHO | libc::ICANON | libc::IEXTEN | libc::ISIG); + tios.c_cc[libc::VMIN] = 0; + tios.c_cc[libc::VTIME] = 1; + + // SAFETY: fd is valid, tios is properly initialized from + // tcgetattr + our modifications. TCSAFLUSH drains output + // and discards pending input before applying. + if unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &tios) } < 0 { + bail!("tcsetattr failed: {}", io::Error::last_os_error()); + } + + Ok(()) +} diff --git a/cli/mdata-client/src/transport/windows.rs b/cli/mdata-client/src/transport/windows.rs new file mode 100644 index 00000000..6a06e8ca --- /dev/null +++ b/cli/mdata-client/src/transport/windows.rs @@ -0,0 +1,224 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. +// +// Copyright 2026 Edgecast Cloud LLC. + +//! Windows transport implementation. +//! +//! Communicates with the metadata service over COM2 serial port, +//! matching the transport used by the original mdata-get.exe from +//! sdc-vmtools. +//! +//! Uses `std::fs::File` for open/close/read/write (safe Rust I/O). +//! Unsafe is limited to Win32 serial port configuration APIs +//! (GetCommState, SetCommState, SetCommTimeouts, PurgeComm) which +//! have no safe Rust equivalent. +//! +//! Win32 FFI types (Dcb, CommTimeouts) are defined inline rather than +//! pulling in the `windows` crate — we only need 4 functions and the +//! crate adds ~50 MB of bindings. + +use std::fs::{File, OpenOptions}; +use std::io::{self, Read, Write}; +use std::os::windows::fs::OpenOptionsExt; +use std::os::windows::io::AsRawHandle; +use std::path::PathBuf; +use std::time::{Duration, Instant}; + +use anyhow::{Result, bail}; + +use super::{Transport, TransportConfig, TransportError}; + +/// Opaque handle type matching Windows HANDLE. +type RawHandle = *mut std::ffi::c_void; + +// ── Win32 FFI definitions ────────────────────────────────────── + +/// DCB flags bitmask: only fBinary set (bit 0). +const DCB_FLAGS_BINARY: u32 = 0x0001; + +#[repr(C)] +struct Dcb { + dcb_length: u32, + baud_rate: u32, + flags: u32, + w_reserved: u16, + xon_lim: u16, + xoff_lim: u16, + byte_size: u8, + parity: u8, + stop_bits: u8, + xon_char: i8, + xoff_char: i8, + error_char: i8, + eof_char: i8, + evt_char: i8, + w_reserved1: u16, +} + +#[repr(C)] +struct CommTimeouts { + read_interval_timeout: u32, + read_total_timeout_multiplier: u32, + read_total_timeout_constant: u32, + write_total_timeout_multiplier: u32, + write_total_timeout_constant: u32, +} + +#[link(name = "kernel32")] +unsafe extern "system" { + fn GetCommState(h_file: RawHandle, lp_dcb: *mut Dcb) -> i32; + fn SetCommState(h_file: RawHandle, lp_dcb: *mut Dcb) -> i32; + + fn SetCommTimeouts(h_file: RawHandle, lp_comm_timeouts: *const CommTimeouts) -> i32; + + fn PurgeComm(h_file: RawHandle, dw_flags: u32) -> i32; +} + +/// PURGE_RXCLEAR | PURGE_TXCLEAR +const PURGE_RX_TX: u32 = 0x0004 | 0x0008; + +// ── Transport implementation ─────────────────────────────────── + +impl Transport { + /// Open the COM2 serial port. + /// + /// On Windows, this always uses COM2 (the virtual serial port + /// connected to the SmartOS metadata service). + pub fn open() -> Result { + let config = TransportConfig::Serial(PathBuf::from("\\\\.\\COM2")); + let file = open_serial_port(&config)?; + Ok(Self { config, file }) + } + + /// Send a string over the transport. + pub fn send(&self, data: &str) -> Result<(), TransportError> { + (&self.file) + .write_all(data.as_bytes()) + .map_err(TransportError::Io) + } + + /// Receive a single line (terminated by `\n`) with a timeout. + /// + /// Uses SetCommTimeouts to set the read deadline, then safe + /// `File::read` for the actual I/O. + pub fn recv_line(&self, timeout_ms: u64) -> Result { + let deadline = Instant::now() + Duration::from_millis(timeout_ms); + let mut line = Vec::new(); + let mut byte = [0u8; 1]; + + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Err(TransportError::Timeout); + } + + let remaining_ms = remaining.as_millis().min(u32::MAX as u128) as u32; + + // Set read timeout to remaining time. + // SAFETY: handle is valid (owned by self.file), + // timeouts is a properly initialized stack struct. + let timeouts = CommTimeouts { + read_interval_timeout: 0, + read_total_timeout_multiplier: 0, + read_total_timeout_constant: remaining_ms, + write_total_timeout_multiplier: 0, + write_total_timeout_constant: 5000, + }; + let handle = self.file.as_raw_handle() as RawHandle; + if unsafe { SetCommTimeouts(handle, &timeouts) } == 0 { + return Err(TransportError::Io(io::Error::last_os_error())); + } + + let n = (&self.file).read(&mut byte).map_err(TransportError::Io)?; + + if n == 0 { + return Err(TransportError::Timeout); + } + + if byte[0] == b'\n' { + return String::from_utf8(line).map_err(|_| TransportError::InvalidData); + } + line.push(byte[0]); + } + } + + /// Close and reopen the transport for protocol reset. + pub fn reconnect(&mut self) -> Result<()> { + // Dropping the old file closes the handle automatically. + self.file = open_serial_port(&self.config)?; + Ok(()) + } +} + +// No custom Drop needed — File closes the handle on drop. + +/// Open and configure the serial port for metadata communication. +fn open_serial_port(config: &TransportConfig) -> Result { + let TransportConfig::Serial(path) = config; + + let file = OpenOptions::new() + .read(true) + .write(true) + .share_mode(0) // exclusive access + .open(path) + .map_err( + |err| anyhow::anyhow!("failed to open serial port {}: {}", path.display(), err,), + )?; + + // Configure serial port: 8N1, no flow control. + // If this fails, `file` is dropped which closes the handle. + configure_serial(&file)?; + + // Flush any pending data from previous sessions. + // SAFETY: handle is valid (owned by file), PURGE_RX_TX clears + // both input and output buffers. + let handle = file.as_raw_handle() as RawHandle; + unsafe { PurgeComm(handle, PURGE_RX_TX) }; + + Ok(file) +} + +/// Configure serial port for raw 8N1 communication. +fn configure_serial(file: &File) -> Result<()> { + let handle = file.as_raw_handle() as RawHandle; + + // SAFETY: Dcb is a plain C struct; all-zeros is valid. + let mut dcb: Dcb = unsafe { std::mem::zeroed() }; + dcb.dcb_length = std::mem::size_of::() as u32; + + // SAFETY: handle is valid, dcb is a properly sized stack buffer. + if unsafe { GetCommState(handle, &mut dcb) } == 0 { + bail!("GetCommState failed: {}", io::Error::last_os_error()); + } + + // 8 data bits, no parity, 1 stop bit, binary mode, no flow control + dcb.byte_size = 8; + dcb.parity = 0; // NOPARITY + dcb.stop_bits = 0; // ONESTOPBIT + dcb.flags = DCB_FLAGS_BINARY; + + // SAFETY: handle is valid, dcb is properly initialized from + // GetCommState + our modifications. + if unsafe { SetCommState(handle, &mut dcb) } == 0 { + bail!("SetCommState failed: {}", io::Error::last_os_error()); + } + + // Set initial timeouts. + // SAFETY: handle is valid, timeouts is a properly initialized + // stack struct. + let timeouts = CommTimeouts { + read_interval_timeout: 0, + read_total_timeout_multiplier: 0, + read_total_timeout_constant: 6000, + write_total_timeout_multiplier: 0, + write_total_timeout_constant: 5000, + }; + + if unsafe { SetCommTimeouts(handle, &timeouts) } == 0 { + bail!("SetCommTimeouts failed: {}", io::Error::last_os_error()); + } + + Ok(()) +}