diff --git a/Makefile b/Makefile index f616809a0c05..a4cfb857e509 100644 --- a/Makefile +++ b/Makefile @@ -55,13 +55,13 @@ endef # Fast bootstrap. fast: release-image barretenberg boxes playground docs aztec-up \ - bb-tests l1-contracts-tests yarn-project-tests boxes-tests playground-tests aztec-up-tests docs-tests noir-protocol-circuits-tests release-image-tests spartan claude-tests + bb-tests l1-contracts-tests yarn-project-tests boxes-tests playground-tests aztec-up-tests docs-tests noir-protocol-circuits-tests release-image-tests spartan claude-tests ipc-codegen-tests # Full bootstrap. full: fast bb-full-tests bb-cpp-full yarn-project-benches # Release. Everything plus copy bb cross compiles to ts projects. -release: fast bb-cpp-release-dir bb-ts-cross-copy +release: fast bb-cpp-release-dir bb-ts-cross-copy ipc-runtime-cross #============================================================================== # Noir @@ -211,7 +211,7 @@ bb-cpp-release-dir: bb-cpp-native bb-cpp-cross bb-cpp-full: bb-cpp bb-cpp-gcc bb-cpp-fuzzing bb-cpp-asan bb-cpp-smt bb-cpp-cross-arm64-macos bb-cpp-cross-arm64-ios bb-cpp-cross-arm64-android # BB TypeScript - TypeScript bindings -bb-ts: bb-cpp-wasm bb-cpp-wasm-threads bb-cpp-native +bb-ts: bb-cpp-wasm bb-cpp-wasm-threads bb-cpp-native ipc-runtime $(call build,$@,barretenberg/ts) # Copies the cross-compiles into bb.js. 
@@ -275,6 +275,37 @@ bb-tests: bb-cpp-native-tests bb-acir-tests bb-ts-tests bb-sol-tests bb-bbup-tes bb-full-tests: bb-cpp-wasm-threads-tests bb-cpp-asan-tests bb-cpp-smt-tests +#============================================================================== +# IPC Codegen +#============================================================================== + +.PHONY: ipc-codegen ipc-codegen-tests +ipc-codegen: + $(call build,$@,ipc-codegen) + +ipc-codegen-tests: ipc-codegen + $(call test,$@,ipc-codegen) + +.PHONY: ipc-runtime ipc-runtime-tests ipc-runtime-cross +ipc-runtime: + $(call build,$@,ipc-runtime) + +ipc-runtime-tests: ipc-runtime + $(call test,$@,ipc-runtime) + +# Cross-compile the NAPI addon for the 3 non-host release targets. +# Host (amd64-linux) addon is produced by the standalone `ipc-runtime` target. +ipc-runtime-cross-arm64-linux: + $(call build,$@,ipc-runtime,build_cross arm64-linux) + +ipc-runtime-cross-amd64-macos: + $(call build,$@,ipc-runtime,build_cross amd64-macos) + +ipc-runtime-cross-arm64-macos: + $(call build,$@,ipc-runtime,build_cross arm64-macos) + +ipc-runtime-cross: ipc-runtime ipc-runtime-cross-arm64-linux ipc-runtime-cross-amd64-macos ipc-runtime-cross-arm64-macos + #============================================================================== # .claude tooling #============================================================================== diff --git a/barretenberg/.gitignore b/barretenberg/.gitignore index ba04309b3eac..891a8e0ec957 100644 --- a/barretenberg/.gitignore +++ b/barretenberg/.gitignore @@ -13,3 +13,6 @@ bench-out rust/barretenberg-rs/src/generated_types.rs rust/barretenberg-rs/src/api.rs ts/src/cbind/generated/ + +# Codegen output dirs (ipc-codegen emits into a `generated/` subdir under each consumer) +**/generated/ diff --git a/bootstrap.sh b/bootstrap.sh index 919c1a331448..f2639d4985db 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -536,6 +536,7 @@ function release { projects=( barretenberg/cpp + ipc-runtime 
barretenberg/ts barretenberg/rust noir diff --git a/ipc-codegen/.rebuild_patterns b/ipc-codegen/.rebuild_patterns new file mode 100644 index 000000000000..385c00373c76 --- /dev/null +++ b/ipc-codegen/.rebuild_patterns @@ -0,0 +1,9 @@ +^ipc-codegen/src/.*\.ts$ +^ipc-codegen/templates/ +^ipc-codegen/echo_example/ +^ipc-codegen/package\.json$ +^ipc-codegen/bootstrap\.sh$ +^ipc-runtime/cpp/ +^ipc-runtime/ts/ +^ipc-runtime/zig/ +^ipc-runtime/rust/ diff --git a/ipc-codegen/README.md b/ipc-codegen/README.md new file mode 100644 index 000000000000..2e53e7b3736d --- /dev/null +++ b/ipc-codegen/README.md @@ -0,0 +1,279 @@ +# ipc-codegen + +Schema-driven IPC code generator for **C++**, **TypeScript**, **Rust**, and **Zig**. + +Given a hand-authored JSONC schema describing a service's commands and +responses, emits matching wire-type definitions plus a typed client and/or +server-side dispatcher in the target language. The schema is the source of +truth — see `SCHEMA_SPEC.md` for its format. Wire format is msgpack; the actual byte transport +(Unix-domain socket or MPSC shared memory) is provided by +[`/ipc-runtime`](../ipc-runtime) — clients and servers in different languages +talk byte-compatibly because they all pack the same wire types. 
+ +## Quick start + +```sh +cd ipc-codegen +./bootstrap.sh build # generate echo example bindings, compile all 4 languages +./bootstrap.sh test # run the cross-language wire-compat matrix +``` + +## How it fits together + +``` + ┌──────────────────┐ + │ *_schema.jsonc │ (hand-authored, committed next + └────────┬─────────┘ to the C++ server it describes) + │ + ▼ + ┌──────────────────┐ + │ ipc-codegen │ (this package) + └────────┬─────────┘ + │ + ┌──────────┬───────┴───────┬──────────┐ + ▼ ▼ ▼ ▼ + wire types, wire types, wire types, wire types, + typed typed typed typed + client + client + client + client + + server server server server + (C++) (TS) (Rust) (Zig) + │ │ │ │ + └──────────┴────────┬──────┴──────────┘ + │ + ▼ + ┌──────────────────┐ + │ ipc-runtime │ (transport: UDS / MPSC-SHM, + └──────────────────┘ same path-suffix dispatch in + every language) +``` + +ipc-codegen knows nothing about sockets, shared memory, or processes — it just +serialises typed commands to msgpack bytes and back. ipc-runtime knows nothing +about your service's commands — it just moves bytes. Consumers wire the two +together (codegen-emitted dispatcher on top of an ipc-runtime server, or +codegen-emitted typed client on top of an ipc-runtime client). 
+ +## Layout + +``` +ipc-codegen/ + bootstrap.sh # build / test / update_goldens / hash + src/ # generator (TypeScript, runs under Node 22+) + generate.ts # CLI entry point + schema_visitor.ts # friendly/positional schema -> CompiledSchema IR + cpp_codegen.ts # IR -> C++ output + typescript_codegen.ts # IR -> TypeScript output + rust_codegen.ts # IR -> Rust output + zig_codegen.ts # IR -> Zig output + naming.ts # snake_case / PascalCase helpers + templates/ # static templates copied alongside generated code + cpp/ipc_codegen/*.hpp # C++ support headers copied into generated output + rust/{backend,error,ffi_backend}.rs + zig/{backend,ffi_backend}.zig + echo_example/ # 4-language echo service (cross-lang test harness) + SCHEMA_SPEC.md # wire protocol and schema-format reference +``` + +The package contains no service schemas of its own. Each consumer owns and +commits its hand-authored schema next to the C++ server that implements the +service, and invokes `generate.ts` with that local path. + +## CLI: `src/generate.ts` + +Invoked once per (schema, language) pair. Run directly with +`node --experimental-strip-types`, or via `bootstrap.sh`. + +``` +node --experimental-strip-types --experimental-transform-types --no-warnings \ + src/generate.ts --schema <path> --lang <lang> --out <dir> [flags] +``` + +### Required flags + +| Flag | Purpose | +|---|---| +| `--schema <path>` | Path to the schema (JSON or JSONC; friendly or legacy positional form — see `SCHEMA_SPEC.md`). | +| `--lang <lang>` | Target language. | +| `--out <dir>` | Output directory. Generated files are (re)written every run; static templates are copied alongside and re-copied only if missing (so handwritten edits to templated scaffolding are preserved). | + +### Role flags + +| Flag | Purpose | +|---|---| +| `--server` | Emit server dispatch (matches request name to handler, deserialises, calls handler, serialises response). Pair it with an `ipc::IpcServer` from ipc-runtime. 
| +| `--client` | Emit a typed client class/struct with one method per command. Pair it with an `ipc::IpcClient` (C++) or the equivalent Rust/Zig/TS binding. | +| `--package ` | TS only. Emit a complete package wrapper around the generated async client. The wrapper launches a native service binary, connects over UDS or SHM, and resolves the binary from an override path, environment variable, installed arch package, or local `build//` directory. | +| `--uds` | Rust/Zig only. Copies the `Backend` trait template (and `error.rs` for Rust) into `` so consumers can plug ipc-runtime — or any custom transport — behind the generated client. The flag name is historical: the trait is transport-agnostic. | +| `--ffi` | Rust/Zig only. Adds the `ffi_backend` template (a thin wrapper exposing the generated client over a C ABI for embedding in other languages). | + +### Naming flags + +Friendly-format schemas set naming via the schema's `service` field: generated +types are `` and client methods are the bare command name. The +flags below are only for legacy positional schemas, which have no `service`. + +| Flag | Purpose | +|---|---| +| `--prefix ` | Positional schemas only. Type prefix applied to generated type names (`CircuitProve`, etc.). Auto-detected from the command names if omitted. Ignored when the schema declares `service`. | +| `--strip-method-prefix` | Positional schemas only. Drops the prefix from client *method* names: `bbCircuitProve()` → `circuitProve()`. Types keep the prefix. Implied when the schema declares `service`. | + +### C++-specific flags + +| Flag | Purpose | +|---|---| +| `--cpp-namespace ` | C++ namespace, e.g. `my::service`. Default: lowercased prefix. | +| `--cpp-wire-namespace ` | Inner namespace for wire types, default `wire`. | +| `--cpp-include-dir ` | Include-path prefix for cross-includes between generated files, e.g. `myservice/generated`. Leave unset when generated files are in the same directory as their consumer. 
| + +### Other + +| Flag | Purpose | +|---|---| +| `--curve-constants` | TS only. Also emit `curve_constants.ts` with bn254/grumpkin/secp moduli & generators for schemas that need curve constants. | +| `--skeleton ` | One-shot scaffolding: writes a `_handlers.{ts,rs,zig,cpp}` stub, `main`, and a build file into `` if they don't already exist. Skipped on subsequent runs. | +| `--package-name ` | TS package mode only. Package name to write into the generated wrapper `package.json`. | +| `--binary-name ` | TS package mode only. Native service binary name to launch. | +| `--binary-env-var ` | TS package mode only. Environment variable that can override the binary path. Defaults to `_PATH`. | +| `--package-transports ` | TS package mode only. Comma-separated transports supported by the generated wrapper. | +| `--ipc-runtime-dependency ` | TS package mode only. Dependency spec for `@aztec/ipc-runtime`, e.g. a release version or local `file:` dependency in examples. | + +## Worked examples + +Paths below are illustrative — consumers commit their own schema next to the +C++ server that owns the wire format and supply absolute or relative paths on +the command line. + +### TypeScript client, with curve constants + +```sh +src/generate.ts \ + --schema /path/to/myservice_schema.jsonc \ + --lang ts \ + --out /path/to/output/generated \ + --client \ + --curve-constants +``` + +Produces `api_types.ts`, `async.ts`, `sync.ts`, `curve_constants.ts`. The TS +client uses `@aztec/ipc-runtime`'s `UdsIpcClient` or `NapiShmSyncClient` for +transport — no template copy. 
+ +### TypeScript spawned-service package + +```sh +src/generate.ts \ + --schema /path/to/myservice_schema.jsonc \ + --lang ts \ + --out /path/to/myservice/src/generated \ + --client \ + --package /path/to/myservice \ + --package-name @aztec/myservice \ + --binary-name myservice \ + --package-transports uds,shm +``` + +Produces the generated TS client under `src/generated/` plus a package shell +(`package.json`, `tsconfig.json`, `src/index.ts`, `src/platform.ts`, and +`scripts/prepare_arch_packages.sh`). The package exports a +`MyServiceService.spawn(...)` helper that launches the native binary and wraps +the generated async client. `scripts/prepare_arch_packages.sh` turns +`build//` directories into per-architecture npm packages +matching the binary resolution path. + +### C++ server + client, under a project sub-include path + +```sh +src/generate.ts \ + --schema /path/to/myservice_schema.jsonc \ + --lang cpp \ + --out /path/to/myservice/generated \ + --server --client \ + --cpp-namespace my::ns \ + --cpp-include-dir myservice/generated +``` + +Produces `myservice_types.hpp`, `myservice_ipc_client.{hpp,cpp}`, and +`myservice_ipc_server.hpp`. Cross-includes use the supplied `--cpp-include-dir` prefix +(`#include "myservice/generated/myservice_types.hpp"`). Wire to an +`ipc::IpcServer` (from ipc-runtime) plus a hand-written +`_handlers.cpp` that supplies one `handle_(...)` per command. +Generated C++ includes support headers as `ipc_codegen/...`; the generator +copies those headers from `templates/cpp/ipc_codegen/` into the output +directory. + +### Rust client + FFI backend + +```sh +src/generate.ts \ + --schema /path/to/myservice_schema.jsonc \ + --lang rust \ + --out /path/to/crate/src/generated \ + --client --uds --ffi \ + --skeleton /path/to/crate/src +``` + +Produces `myservice_types.rs`, `myservice_client.rs`, plus `backend.rs`, +`error.rs`, `ffi_backend.rs`. 
UDS/SHM transport is provided by the +`ipc-runtime` Rust crate; the consumer chooses which to use via the path +suffix passed at runtime. The skeleton flag also writes a one-time +`myservice_handlers.rs`, `main.rs`, `Cargo.toml`, and `generate.sh` into the +skeleton dir so a new service crate is buildable on first run. + +### Zig client + server + +```sh +src/generate.ts \ + --schema /path/to/myservice_schema.jsonc \ + --lang zig \ + --out /path/to/output/generated \ + --server --client --uds --ffi +``` + +Produces `myservice_types.zig`, `myservice_client.zig`, +`myservice_server.zig`, plus `backend.zig` and `ffi_backend.zig`. Consumers +`@import("ipc_runtime")` for transport. + +## Adding a new service + +1. **Author the schema** as friendly JSONC (`service`, `types`, `error`, + `commands`), and commit it next to the C++ server that will implement the + service. This file is the wire-format source of truth — see `SCHEMA_SPEC.md` + for the format. The C++ wire structs (`MSGPACK_SCHEMA_NAME` + + `SERIALIZATION_FIELDS`) are generated from it, not hand-written. +2. **Wire your consumer's build to invoke `src/generate.ts`**, passing the + absolute path to the committed schema and the desired output directory. + Generated files go under a `generated/` directory which is gitignored by + convention. +3. **Wire transport.** On the C++ server side, instantiate an + `ipc::IpcServer` via `ipc::make_server(path)` (from ipc-runtime) and feed + it the codegen-emitted `make__handler(...)`. On the client side + (any language), point an `ipc::IpcClient` / equivalent at the same path + and wrap it with the codegen-emitted client. +4. **Run `./bootstrap.sh test`** in `ipc-codegen/` to confirm the codegen and + cross-language wire-compat tests still pass. + +## Schemas are the source of truth + +The JSONC schema is the wire contract between client and server. 
Consumers +commit it next to the C++ server that implements the service, so the file lives +close to what it describes and tracks with that code. Whenever the wire contract +changes, edit the schema, regenerate the bindings, and commit the diff. Both +sides regenerate from the same committed schema, so they stay byte-compatible. + +Each generated file embeds a `SCHEMA_HASH` (a hash of the committed schema) so +callers can detect at connection time that their bindings predate the server. + +## Wire-format contract + +`echo_example/schema/golden/*.msgpack` is a frozen set of byte-level +fixtures covering every relevant msgpack encoding boundary (variable-width +ints, fixstr/str8/str16, bin8/bin16, optional `Some`/`None`, empty +containers, multi-byte UTF-8). The per-language golden tests +(`echo_example/{rust,ts,cpp,zig}/...`) both decode the fixtures and re-encode +round-trip — pinning down canonical msgpack output across implementations. + +If you intentionally change the wire format, run +`./bootstrap.sh update_goldens` and review the diff. Any byte-level change +is a breaking change for external implementations of the schema. + +See `SCHEMA_SPEC.md` for the wire protocol details. diff --git a/ipc-codegen/SCHEMA_SPEC.md b/ipc-codegen/SCHEMA_SPEC.md new file mode 100644 index 000000000000..79f75956e935 --- /dev/null +++ b/ipc-codegen/SCHEMA_SPEC.md @@ -0,0 +1,228 @@ +# IPC Schema Format Specification + +This document specifies the schema format used for cross-language code generation +in the IPC codegen system. A schema is a single hand-authored JSONC file per +service and is the source of truth: ipc-codegen reads it to generate the wire +types, client, and server dispatch for every target language (TypeScript, C++, +Rust, Zig). The committed golden msgpack corpus is the cross-language wire-format +contract; the schema is a normal reviewed source file. + +JSONC is plain JSON with `//` and `/* */` comments stripped before parsing — no +extra dependencies. 
+ +## Top-level structure + +A schema is a single object describing one service: + +```jsonc +{ + "service": "Echo", + + // Named byte aliases — nominal 32-byte types. Only bin32 today. + "aliases": { + "Fr": "bin32" + }, + + // Shared struct types, referenced by name from commands or other types. + "types": { + "EchoInner": { + "values": "bytes[]", + "flag": "bool?" + } + }, + + // The error variant, declared once and shared by every command. + "error": { "message": "string" }, + + // command -> { request, response }. + "commands": { + "Bytes": { "request": { "data": "bytes" }, + "response": { "data": "bytes" } }, + + "Fields": { "request": { "a": "u32", "b": "u64", "name": "string" }, + "response": { "a": "u32", "b": "u64", "name": "string" } }, + + "Nested": { "request": { "inner": "EchoInner" }, + "response": { "inner": "EchoInner" } }, + + "Aliases": { "request": { "treeId": "u32", "hash": "Fr", + "maybeHash": "Fr?", "hashes": "Fr[]" }, + "response": { "treeId": "u32", "hash": "Fr", + "maybeHash": "Fr?", "hashes": "Fr[]" } }, + + "Blobs": { "request": { "maybeData": "bytes?", "parts": "bytes[2]" }, + "response": { "maybeData": "bytes?", "parts": "bytes[2]" } }, + + "Fail": { "request": { "message": "string" }, + "response": {} } + } +} +``` + +### `service` + +The service name. It is the prefix for generated **type** names and is *not* +included in **method** names: + +- Command `Bytes` under `"service": "Echo"` generates the wire type `EchoBytes` + and the response type `EchoBytesResponse`. +- The corresponding client method / server handler is the bare command name + (`bytes` / `handle_bytes`), projected to each language's casing convention. + +The error type is named `ErrorResponse` (e.g. `EchoErrorResponse`). + +### `aliases` + +A map of alias name to underlying type. Two kinds: + +- **Nominal byte alias** (`bin32`): a distinct named 32-byte value (e.g. `Fr` is + a field element, not raw bytes). 
It carries its name as a dispatch tag and is + generated as a distinct wrapper type per language. `bin32` is the only nominal + byte width supported today. +- **Scalar synonym**: an alias whose underlying is a primitive (e.g. + `MerkleTreeId: u32`). These are transparent — generated as plain type + aliases — so consumers can `static_cast`/coerce them to and from the + underlying integer or enum. Because they are transparent, declaring them is + optional: a field may simply use the primitive (`u32`) directly. + +### `types` + +Named shared struct types, each a field-name → type-reference map. A type is +inlined at every reference and deduplicated by name, so it may be referenced +from multiple commands or from other `types`. + +### `error` + +The error struct, declared once. It must have exactly one field `message` +of type `string`. Generated servers wrap handler failures into this variant; +generated clients surface its `message`. + +### `commands` + +A map of command name to `{ request, response }`, where each of `request` and +`response` is a field-name → type-reference map. An empty object `{}` denotes a +command with no fields (e.g. a `Fail` command whose response carries nothing). + +A `response` may instead be a **string** naming another command's response type +to reuse its shape — e.g. `"response": "AliasesResponse"` reuses the +`EchoAliases` response. Use the generated response type name (`Response`). + +## Type-reference shorthand grammar + +Every field type is a shorthand string. The grammar is a leaf type optionally +followed by suffixes, applied right to left: + +| Suffix | Meaning | +|---------|-------------------| +| `T?` | optional | +| `T[]` | vector of `T` | +| `T[N]` | fixed array of N | + +Suffixes compose, e.g. `Fr[]`, `bytes?`, `Fr[2]`, `EchoInner[]`. 
+ +Leaf types: + +| Leaf | Meaning | +|-----------------|------------------------------------------| +| `bool` | boolean | +| `u8 u16 u32 u64`| unsigned 8/16/32/64-bit integers | +| `f64` | 64-bit float | +| `string` | UTF-8 string | +| `bytes` | variable-length byte string (msgpack bin)| +| `bin32` | fixed 32-byte value (msgpack bin) | +| alias name | a declared `aliases` entry (e.g. `Fr`) | +| type name | a declared `types` entry (e.g. `EchoInner`) | + +## Validation rules + +Schemas are validated at generation time; violations are hard errors: + +- `service` must be a non-empty string. +- The `error` struct must have exactly one field, `message: string`. +- Each command produces a matching `Response`; the command and + non-error response counts must agree. +- Command names must be unique. +- A type reference must resolve to a primitive, a declared alias, or a declared + type. +- Field names must not project (via the snake_case or camelCase mapping) to a + reserved word in any target language, and two fields in one struct must not + collapse to the same projected identifier. +- A struct supports at most 20 fields (the C++ serialization macro limit). + +## Wire protocol + +The schema defines the types; this section specifies how a value of each type +is serialized. The golden corpus pins these encodings across all languages. + +### Framing + +All messages use length-prefix framing: + +``` +[4 bytes: payload length, little-endian uint32][payload: msgpack bytes] +``` + +### Request wire format + +A request is a 1-element msgpack array wrapping a `[name, payload]` pair: + +``` +array(1) [ array(2) [ str: "", map: { field: value, ... } ] ] +``` + +The dispatch tag is the generated command type name (e.g. `EchoBytes`). The +outer array exists for extensibility. + +### Response wire format + +A response is a `[name, payload]` pair (no outer wrapper): + +``` +array(2) [ str: "Response" | "ErrorResponse", map: { ... 
} ] +``` + +A response whose name is `ErrorResponse` indicates an error; its +`message` field carries the text. + +### Type wire encoding + +| Schema type | msgpack encoding | +|--------------------|-----------------------------------------------| +| `bool` | bool | +| `u8 u16 u32 u64` | integer (smallest encoding that fits) | +| `f64` | float64 | +| `string` | str | +| `bytes` | bin | +| `bin32` | bin (32 bytes) | +| `T?` (optional) | nil if absent, else the encoding of `T` | +| `T[]` (vector) | array | +| `T[N]` (array) | array (fixed length) | +| alias | same encoding as the alias's underlying type | +| struct | map with string keys (field names) | + +### Integer encoding note + +msgpack uses the smallest encoding that fits the value, not the declared type: +a `u64` of `5` encodes as a single positive-fixint byte. Decoders MUST accept +any integer encoding width for any integer field. + +## Schema versioning + +A SHA-256 hash of the schema can be computed and embedded in generated code for +optional compatibility checking at connection time. A mismatch indicates the +service binary and client were generated from different schema versions. + +## Adding a new command + +1. Add an entry to `commands` with its `request`/`response` field maps (declare + any new `types`/`aliases` it needs). +2. Re-run ipc-codegen for every target language and confirm everything compiles. +3. If the change alters the wire format, refresh the golden corpus + (`./bootstrap.sh update_goldens`) and review the byte-level diff — any + change is breaking for external implementations of the schema. 
+ +## Source files + +- Schema front-end + IR compiler: `ipc-codegen/src/schema_visitor.ts` +- CLI entry point: `ipc-codegen/src/generate.ts` +- Example schema: `ipc-codegen/echo_example/schema/schema.jsonc` diff --git a/ipc-codegen/bootstrap.sh b/ipc-codegen/bootstrap.sh new file mode 100755 index 000000000000..fa01a18fa960 --- /dev/null +++ b/ipc-codegen/bootstrap.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash +# IPC codegen package. +# Generates IPC bindings from consumer-supplied JSON/JSONC schemas, in TS, C++, +# Rust and Zig. Zero npm dependencies — runs with just Node.js (v22+). +# +# The build's only direct consumer is its own cross-language test harness under +# echo_example/. Service consumers invoke ipc-codegen from their own build +# scripts with their own schema inputs. + +source $(git rev-parse --show-toplevel)/ci3/source_bootstrap + +hash=$(cache_content_hash .rebuild_patterns) + +function build { + echo_header "ipc-codegen build" + + # Service generation is invoked by each service's own build flow. The build + # step here invokes each echo example project's own bootstrap, so every + # project documents and owns its generation/build flow. + (cd echo_example/cpp && ./bootstrap.sh) + (cd echo_example/rust && ./bootstrap.sh) + (cd echo_example/ts && ./bootstrap.sh) + (cd echo_example/ts_package && ./bootstrap.sh) + (cd echo_example/zig && ./bootstrap.sh) + + # NB: the golden msgpack fixtures under echo_example/schema/golden/ are + # COMMITTED and FROZEN — they're the binding wire-format contract. Don't + # regenerate them here. If a deliberate wire-format change requires + # refreshing them, run `./bootstrap.sh update_goldens` and commit the diff. +} + +function update_goldens { + echo_header "ipc-codegen update_goldens" + # Rebuild the rust generate_golden binary first. 
+ (cd echo_example/rust && cargo build --quiet --bin generate_golden) + echo_example/rust/target/debug/generate_golden --output-dir echo_example/schema/golden + echo "" + echo "Goldens refreshed. Review the diff carefully — these are the wire-format" + echo "contract, and any byte-level change is a breaking change for external" + echo "implementations of the schema." +} + +function test_cmds { + local matrix_langs=(rust ts zig cpp) + + local prefix="$hash:CPUS=1:TIMEOUT=120s" + local script="ipc-codegen/echo_example/scripts/run_cross_language_test.sh" + + # Generator unit tests (schema validation). + echo "$prefix node --experimental-strip-types --no-warnings ipc-codegen/test/schema_visitor.test.ts" + + # Golden tests (each language verifies it can deserialize the committed + # goldens, and re-encode them byte-identically). + echo "$prefix $script golden rust" + echo "$prefix $script golden ts" + echo "$prefix $script golden cpp" + echo "$prefix $script golden zig" + echo "$prefix ipc-codegen/echo_example/ts_package/test.sh uds" + echo "$prefix ipc-codegen/echo_example/ts_package/test.sh shm" + + # Matrix: one command per (server, client) pair over UDS. + for server in "${matrix_langs[@]}"; do + for client in "${matrix_langs[@]}"; do + echo "$prefix $script matrix $server $client uds" + done + done + + # SHM matrix. Argument order is server then client. TS is only covered as a + # client because ipc-runtime/ts has no SHM server. + local shm_server_langs=(rust zig cpp) + local native_shm_client_langs=(rust zig cpp) + for server in "${shm_server_langs[@]}"; do + for client in "${native_shm_client_langs[@]}"; do + echo "$prefix $script matrix $server $client shm" + done + done + + # TS SHM client coverage. The NAPI addon is built by ipc-runtime/bootstrap.sh + # during build(); a missing addon should fail these tests loudly rather than + # silently dropping coverage. 
+ for server in "${shm_server_langs[@]}"; do + echo "$prefix $script matrix $server ts shm" + done +} + +function test { + echo_header "ipc-codegen test" + test_cmds | filter_test_cmds | parallelize +} + +case "$cmd" in + "") + build + ;; + "hash") + echo "$hash" + ;; + *) + default_cmd_handler "$@" + ;; +esac diff --git a/ipc-codegen/echo_example/cpp/.gitignore b/ipc-codegen/echo_example/cpp/.gitignore new file mode 100644 index 000000000000..cff0b76e42aa --- /dev/null +++ b/ipc-codegen/echo_example/cpp/.gitignore @@ -0,0 +1,2 @@ +build/ +src/generated/ diff --git a/ipc-codegen/echo_example/cpp/CMakeLists.txt b/ipc-codegen/echo_example/cpp/CMakeLists.txt new file mode 100644 index 000000000000..ab156f9b048a --- /dev/null +++ b/ipc-codegen/echo_example/cpp/CMakeLists.txt @@ -0,0 +1,52 @@ +cmake_minimum_required(VERSION 3.20) + +project(echo_wire_compat_cpp CXX) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(REPO_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../..") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/bin") + +include(FetchContent) + +FetchContent_Declare( + msgpack_c + GIT_REPOSITORY https://github.com/msgpack/msgpack-c.git + GIT_TAG 8c602e8579c7e7d65d6f9c6703c9699db3fb0488 + GIT_SHALLOW TRUE +) + +FetchContent_GetProperties(msgpack_c) +if(NOT msgpack_c_POPULATED) + FetchContent_Populate(msgpack_c) +endif() + +add_subdirectory( + "${REPO_ROOT}/ipc-runtime/cpp" + "${CMAKE_CURRENT_BINARY_DIR}/ipc-runtime" + EXCLUDE_FROM_ALL +) + +add_library(echo_common INTERFACE) +target_include_directories(echo_common INTERFACE + "${CMAKE_CURRENT_SOURCE_DIR}/src" + "${CMAKE_CURRENT_SOURCE_DIR}/src/generated" + "${msgpack_c_SOURCE_DIR}/include" +) +target_compile_definitions(echo_common INTERFACE + MSGPACK_NO_BOOST + MSGPACK_USE_STD_VARIANT_ADAPTOR +) + +add_executable(echo_server src/echo_server.cpp) +target_link_libraries(echo_server PRIVATE echo_common ipc_runtime) + +add_executable(echo_client + src/echo_client.cpp + 
src/generated/echo_ipc_client.cpp +) +target_link_libraries(echo_client PRIVATE echo_common ipc_runtime) + +add_executable(golden_test src/golden_test.cpp) +target_link_libraries(golden_test PRIVATE echo_common) diff --git a/ipc-codegen/echo_example/cpp/README.md b/ipc-codegen/echo_example/cpp/README.md new file mode 100644 index 000000000000..fa7ee5f97044 --- /dev/null +++ b/ipc-codegen/echo_example/cpp/README.md @@ -0,0 +1,18 @@ +# C++ Echo Example + +Build from this directory: + +```sh +./bootstrap.sh +``` + +The bootstrap generates bindings and C++ codegen support headers into +`src/generated/`, fetches upstream `msgpack-c`, and builds `ipc-runtime/cpp` +as a subproject. Binaries are written to `build/bin/`. + +Run locally: + +```sh +build/bin/echo_server --socket /tmp/echo.sock +build/bin/echo_client --socket /tmp/echo.sock +``` diff --git a/ipc-codegen/echo_example/cpp/bootstrap.sh b/ipc-codegen/echo_example/cpp/bootstrap.sh new file mode 100755 index 000000000000..9f536ff54fe4 --- /dev/null +++ b/ipc-codegen/echo_example/cpp/bootstrap.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +CODEGEN="$(cd "$DIR/../.." && pwd)" +NODE="node --experimental-strip-types --experimental-transform-types --no-warnings" + +$NODE "$CODEGEN/src/generate.ts" \ + --schema "$DIR/../schema/schema.jsonc" \ + --lang cpp \ + --server \ + --client \ + --out "$DIR/src/generated" \ + --cpp-namespace echo + +cmake -S "$DIR" -B "$DIR/build" +cmake --build "$DIR/build" --target echo_server echo_client golden_test diff --git a/ipc-codegen/echo_example/cpp/src/echo_client.cpp b/ipc-codegen/echo_example/cpp/src/echo_client.cpp new file mode 100644 index 000000000000..ca2dc3751d9e --- /dev/null +++ b/ipc-codegen/echo_example/cpp/src/echo_client.cpp @@ -0,0 +1,133 @@ +// Echo IPC client (C++) — uses the generated EchoIpcClient. 
+// Usage: echo_client --socket /tmp/echo.sock + +#include "generated/echo_ipc_client.hpp" + +#include +#include +#include + +// Explicit check (not assert): verification must survive NDEBUG builds. +#define CHECK(cond, label) \ + do { \ + if (!(cond)) { \ + std::cerr << "echo_client(cpp): " << (label) << " FAIL\n"; \ + return 1; \ + } \ + } while (0) + +namespace { +echo::Fr test_hash(uint8_t base) { + echo::Fr hash{}; + for (size_t i = 0; i < hash.size(); ++i) { + hash[i] = static_cast(base + i); + } + return hash; +} +} // namespace + +int main(int argc, char **argv) { + const char *socket_path = nullptr; + for (int i = 1; i < argc - 1; i++) { + if (std::string_view(argv[i]) == "--socket") + socket_path = argv[i + 1]; + } + if (!socket_path) { + std::cerr << "Usage: echo_client --socket \n"; + return 1; + } + + echo::EchoIpcClient client(socket_path); + + { + auto resp = client.bytes({.data = {0xDE, 0xAD, 0xBE, 0xEF, 0x42}}); + CHECK((resp.data == std::vector{0xDE, 0xAD, 0xBE, 0xEF, 0x42}), + "EchoBytes"); + std::cerr << "echo_client(cpp): EchoBytes OK\n"; + } + + { + auto resp = + client.fields({.a = 42, .b = 999999, .name = "hello wire compat"}); + CHECK(resp.a == 42 && resp.b == 999999 && resp.name == "hello wire compat", + "EchoFields"); + std::cerr << "echo_client(cpp): EchoFields OK\n"; + } + + { + auto resp = + client.nested({.inner = {.values = {{1, 2, 3}, {4, 5}}, .flag = true}}); + CHECK((resp.inner.values == + std::vector>{{1, 2, 3}, {4, 5}}), + "EchoNested values"); + CHECK(resp.inner.flag == true, "EchoNested flag"); + std::cerr << "echo_client(cpp): EchoNested OK\n"; + } + + { + auto hash = test_hash(0x10); + auto second = test_hash(0x40); + auto resp = client.aliases({.treeId = 7, + .hash = hash, + .maybeHash = second, + .hashes = {hash, second}}); + CHECK(resp.treeId == 7, "EchoAliases treeId"); + CHECK(resp.hash == hash, "EchoAliases hash"); + CHECK(resp.maybeHash == second, "EchoAliases maybeHash"); + CHECK((resp.hashes == std::vector{hash, 
second}), + "EchoAliases hashes"); + std::cerr << "echo_client(cpp): EchoAliases OK\n"; + } + + // Optional-absent over live IPC. + { + auto hash = test_hash(0x10); + auto resp = client.aliases({.treeId = 7, + .hash = hash, + .maybeHash = std::nullopt, + .hashes = {hash}}); + CHECK(!resp.maybeHash.has_value(), "EchoAliases none"); + std::cerr << "echo_client(cpp): EchoAliases none OK\n"; + } + + // uint64 wire encoding above 2^32 over live IPC. + { + const uint64_t big = (1ULL << 53) - 1; + auto resp = client.fields({.a = 42, .b = big, .name = "big"}); + CHECK(resp.b == big, "EchoFields u64"); + std::cerr << "echo_client(cpp): EchoFields u64 OK\n"; + } + + // Optional bytes Some/None and fixed [bytes; 2]. + { + auto resp = client.blobs({.maybeData = std::vector{0xAA, 0xBB}, + .parts = {{{1, 2, 3}, {4}}}}); + CHECK((resp.maybeData == std::vector{0xAA, 0xBB}), + "EchoBlobs maybeData"); + CHECK((resp.parts == std::array, 2>{{{1, 2, 3}, {4}}}), + "EchoBlobs parts"); + auto resp_none = + client.blobs({.maybeData = std::nullopt, .parts = {{{}, {9}}}}); + CHECK(!resp_none.maybeData.has_value(), "EchoBlobs none"); + std::cerr << "echo_client(cpp): EchoBlobs OK\n"; + } + + // Server error surfaces with its message. 
+ { + bool threw = false; + std::string message; + try { + client.fail({.message = "deliberate failure"}); + } catch (const std::exception &e) { + threw = true; + message = e.what(); + } + CHECK(threw, "EchoFail threw"); + CHECK(message.find("deliberate failure") != std::string::npos, + "EchoFail message"); + std::cerr << "echo_client(cpp): EchoFail OK\n"; + } + + std::cerr << "echo_client(cpp): all tests passed\n"; + return 0; +} diff --git a/ipc-codegen/echo_example/cpp/src/echo_server.cpp b/ipc-codegen/echo_example/cpp/src/echo_server.cpp new file mode 100644 index 000000000000..87fd89a67a6a --- /dev/null +++ b/ipc-codegen/echo_example/cpp/src/echo_server.cpp @@ -0,0 +1,69 @@ +// Echo IPC server (C++) — provides handler specializations for the +// header-only generated dispatch. +// Usage: echo_server --socket /tmp/echo.sock + +#include "generated/echo_ipc_server.hpp" + +#include +#include +#include + +namespace echo { + +struct EchoCtx {}; // empty context for the echo service + +// Template specializations — echo input fields back in response. 
+template <> +wire::EchoBytesResponse handle_bytes(EchoCtx & /*ctx*/, wire::EchoBytes &&cmd) { + return {.data = std::move(cmd.data)}; +} + +template <> +wire::EchoFieldsResponse handle_fields(EchoCtx & /*ctx*/, + wire::EchoFields &&cmd) { + return {.a = cmd.a, .b = cmd.b, .name = std::move(cmd.name)}; +} + +template <> +wire::EchoNestedResponse handle_nested(EchoCtx & /*ctx*/, + wire::EchoNested &&cmd) { + return {.inner = std::move(cmd.inner)}; +} + +template <> +wire::EchoAliasesResponse handle_aliases(EchoCtx & /*ctx*/, + wire::EchoAliases &&cmd) { + return {.treeId = cmd.treeId, + .hash = cmd.hash, + .maybeHash = cmd.maybeHash, + .hashes = std::move(cmd.hashes)}; +} + +template <> +wire::EchoBlobsResponse handle_blobs(EchoCtx & /*ctx*/, wire::EchoBlobs &&cmd) { + return {.maybeData = std::move(cmd.maybeData), + .parts = std::move(cmd.parts)}; +} + +template <> +wire::EchoFailResponse handle_fail(EchoCtx & /*ctx*/, wire::EchoFail &&cmd) { + throw std::runtime_error(cmd.message); +} + +} // namespace echo + +int main(int argc, char **argv) { + const char *socket_path = nullptr; + for (int i = 1; i < argc - 1; i++) { + if (std::string_view(argv[i]) == "--socket") + socket_path = argv[i + 1]; + } + if (!socket_path) { + std::cerr << "Usage: echo_server --socket \n"; + return 1; + } + + echo::EchoCtx ctx; + echo::serve(socket_path, ctx); + return 0; +} diff --git a/ipc-codegen/echo_example/cpp/src/golden_test.cpp b/ipc-codegen/echo_example/cpp/src/golden_test.cpp new file mode 100644 index 000000000000..747317683b3e --- /dev/null +++ b/ipc-codegen/echo_example/cpp/src/golden_test.cpp @@ -0,0 +1,246 @@ +// Golden file wire-format conformance test (C++). +// For each golden file, asserts: +// 1. We can decode the bytes into the expected typed wire value. +// 2. Re-encoding the same value with the request/response framing produces +// byte-identical output. +// The combination pins down the wire format as a binding contract. 
+// +// Usage: golden_test --golden-dir + +#include "generated/echo_dispatch.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +int g_pass = 0; +int g_fail = 0; + +std::vector read_file(const std::string &path) { + std::ifstream in(path, std::ios::binary); + if (!in) { + THROW std::runtime_error("cannot open " + path); + } + return std::vector(std::istreambuf_iterator(in), + std::istreambuf_iterator()); +} + +void report(const std::string &file, const std::string &error) { + if (error.empty()) { + std::cerr << " PASS: " << file << "\n"; + g_pass++; + } else { + std::cerr << " FAIL: " << file << ": " << error << "\n"; + g_fail++; + } +} + +bool bytes_equal(const msgpack::sbuffer &buf, + const std::vector &golden) { + return buf.size() == golden.size() && + std::memcmp(buf.data(), golden.data(), golden.size()) == 0; +} + +template bool equals(const T &a, const T &b) { return a == b; } + +bool equals(const echo::wire::EchoAliases &a, + const echo::wire::EchoAliases &b) { + return a == b; +} + +bool equals(const echo::wire::EchoAliasesResponse &a, + const echo::wire::EchoAliasesResponse &b) { + return a == b; +} + +// Requests are framed as [[name, payload-map]]. 
+template +void check_request(const std::string &dir, const std::string &file, + const std::string &name, const T &expected) { + try { + auto golden = read_file(dir + "/" + file); + auto unpacked = msgpack::unpack( + reinterpret_cast(golden.data()), golden.size()); + auto obj = unpacked.get(); + if (obj.type != msgpack::type::ARRAY || obj.via.array.size != 1) { + report(file, "expected outer array of size 1"); + return; + } + auto &inner = obj.via.array.ptr[0]; + if (inner.type != msgpack::type::ARRAY || inner.via.array.size != 2 || + inner.via.array.ptr[0].type != msgpack::type::STR) { + report(file, "expected [CommandName, payload]"); + return; + } + std::string got_name(inner.via.array.ptr[0].via.str.ptr, + inner.via.array.ptr[0].via.str.size); + if (got_name != name) { + report(file, "wrong command name: " + got_name); + return; + } + T value; + inner.via.array.ptr[1].convert(value); + if (!equals(value, expected)) { + report(file, "decoded value mismatch"); + return; + } + msgpack::sbuffer buf; + msgpack::packer pk(buf); + pk.pack_array(1); + pk.pack_array(2); + pk.pack(name); + pk.pack(value); + if (!bytes_equal(buf, golden)) { + report(file, "roundtrip byte mismatch (" + std::to_string(buf.size()) + + " vs " + std::to_string(golden.size()) + " bytes)"); + return; + } + report(file, ""); + } catch (const std::exception &e) { + report(file, e.what()); + } +} + +// Responses are framed as [name, payload-map]. 
+template +void check_response(const std::string &dir, const std::string &file, + const std::string &name, const T &expected) { + try { + auto golden = read_file(dir + "/" + file); + auto unpacked = msgpack::unpack( + reinterpret_cast(golden.data()), golden.size()); + auto obj = unpacked.get(); + if (obj.type != msgpack::type::ARRAY || obj.via.array.size != 2 || + obj.via.array.ptr[0].type != msgpack::type::STR) { + report(file, "expected [ResponseName, payload]"); + return; + } + std::string got_name(obj.via.array.ptr[0].via.str.ptr, + obj.via.array.ptr[0].via.str.size); + if (got_name != name) { + report(file, "wrong response name: " + got_name); + return; + } + T value; + obj.via.array.ptr[1].convert(value); + if (!equals(value, expected)) { + report(file, "decoded value mismatch"); + return; + } + msgpack::sbuffer buf; + msgpack::packer pk(buf); + pk.pack_array(2); + pk.pack(name); + pk.pack(value); + if (!bytes_equal(buf, golden)) { + report(file, "roundtrip byte mismatch (" + std::to_string(buf.size()) + + " vs " + std::to_string(golden.size()) + " bytes)"); + return; + } + report(file, ""); + } catch (const std::exception &e) { + report(file, e.what()); + } +} + +echo::Fr test_hash(uint8_t base) { + std::array bytes{}; + for (size_t i = 0; i < bytes.size(); i++) { + bytes[i] = static_cast(base + i); + } + return echo::Fr(bytes); +} + +} // namespace + +int main(int argc, char **argv) { + std::string dir; + for (int i = 1; i + 1 < argc; i++) { + if (std::string(argv[i]) == "--golden-dir") { + dir = argv[i + 1]; + } + } + if (dir.empty()) { + std::cerr << "Usage: golden_test --golden-dir \n"; + return 1; + } + + using namespace echo::wire; + + const std::vector bytes_payload = {0xDE, 0xAD, 0xBE, 0xEF, 0x42}; + const EchoInner nested_inner{{{1, 2, 3}, {4, 5}}, true}; + const auto hash = test_hash(0x10); + const auto second = test_hash(0x40); + + // ============ Original happy-path cases ============ + + check_request(dir, "echo_bytes_request.msgpack", 
"EchoBytes", + EchoBytes{bytes_payload}); + check_request(dir, "echo_fields_request.msgpack", "EchoFields", + EchoFields{42, 999999, "hello wire compat"}); + check_request(dir, "echo_nested_request.msgpack", "EchoNested", + EchoNested{nested_inner}); + check_request(dir, "echo_aliases_request.msgpack", "EchoAliases", + EchoAliases{7, hash, second, {hash, second}}); + + check_response(dir, "echo_bytes_response.msgpack", "EchoBytesResponse", + EchoBytesResponse{bytes_payload}); + check_response(dir, "echo_fields_response.msgpack", "EchoFieldsResponse", + EchoFieldsResponse{42, 999999, "hello wire compat"}); + check_response(dir, "echo_nested_response.msgpack", "EchoNestedResponse", + EchoNestedResponse{nested_inner}); + check_response(dir, "echo_aliases_response.msgpack", "EchoAliasesResponse", + EchoAliasesResponse{7, hash, second, {hash, second}}); + + // ============ Boundary cases ============ + + check_request(dir, "echo_bytes_empty.msgpack", "EchoBytes", EchoBytes{{}}); + check_request(dir, "echo_bytes_bin16.msgpack", "EchoBytes", + EchoBytes{std::vector(256, 0xAA)}); + check_request(dir, "echo_fields_max.msgpack", "EchoFields", + EchoFields{UINT32_MAX, UINT64_MAX, ""}); + check_request(dir, "echo_fields_uint_boundary.msgpack", "EchoFields", + EchoFields{128, uint64_t(UINT32_MAX) + 1, "x"}); + check_request(dir, "echo_fields_unicode.msgpack", "EchoFields", + EchoFields{0, 0, "héllo τέστ 🚀 mañana"}); + check_request(dir, "echo_fields_str16.msgpack", "EchoFields", + EchoFields{0, 0, std::string(300, 'a')}); + check_request(dir, "echo_nested_flag_none.msgpack", "EchoNested", + EchoNested{EchoInner{{}, std::nullopt}}); + check_request( + dir, "echo_nested_flag_false.msgpack", "EchoNested", + EchoNested{EchoInner{std::vector>{{}}, false}}); + + // ============ Blob / fail / error cases ============ + + const std::array, 2> blob_parts = { + std::vector{1, 2, 3}, std::vector{4}}; + const std::array, 2> blob_parts_none = { + std::vector{}, std::vector{9}}; + + 
check_request(dir, "echo_blobs_request.msgpack", "EchoBlobs", + EchoBlobs{std::vector{0xAA, 0xBB}, blob_parts}); + check_request(dir, "echo_blobs_none.msgpack", "EchoBlobs", + EchoBlobs{std::nullopt, blob_parts_none}); + check_response( + dir, "echo_blobs_response.msgpack", "EchoBlobsResponse", + EchoBlobsResponse{std::vector{0xAA, 0xBB}, blob_parts}); + check_request(dir, "echo_fail_request.msgpack", "EchoFail", + EchoFail{"deliberate failure"}); + check_response(dir, "echo_fail_response.msgpack", "EchoFailResponse", + EchoFailResponse{}); + check_response(dir, "echo_error_response.msgpack", "EchoErrorResponse", + EchoErrorResponse{"deliberate failure"}); + + std::cerr << "\nResults: " << g_pass << "/" << (g_pass + g_fail) + << " passed, " << g_fail << " failed\n"; + return g_fail > 0 ? 1 : 0; +} diff --git a/ipc-codegen/echo_example/rust/.gitignore b/ipc-codegen/echo_example/rust/.gitignore new file mode 100644 index 000000000000..68a9d34d4dc3 --- /dev/null +++ b/ipc-codegen/echo_example/rust/.gitignore @@ -0,0 +1,2 @@ +target/ +src/generated/ diff --git a/ipc-codegen/echo_example/rust/Cargo.lock b/ipc-codegen/echo_example/rust/Cargo.lock new file mode 100644 index 000000000000..2a7c80f2c944 --- /dev/null +++ b/ipc-codegen/echo_example/rust/Cargo.lock @@ -0,0 +1,168 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "echo-wire-compat" +version = "0.1.0" +dependencies = [ + "ipc-runtime", + "libc", + "rmp-serde", + "serde", + "thiserror", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "ipc-runtime" +version = "0.1.0" +dependencies = [ + "cc", +] + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rmp" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba8be72d372b2c9b35542551678538b562e7cf86c3315773cae48dfbfe7790c" 
+dependencies = [ + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72f81bee8c8ef9b577d1681a70ebbc962c232461e397b22c208c43c04b67a155" +dependencies = [ + "rmp", + "serde", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + 
+[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" diff --git a/ipc-codegen/echo_example/rust/Cargo.toml b/ipc-codegen/echo_example/rust/Cargo.toml new file mode 100644 index 000000000000..b951bba882df --- /dev/null +++ b/ipc-codegen/echo_example/rust/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "echo-wire-compat" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "echo_server" +path = "src/echo_server.rs" + +[[bin]] +name = "echo_client" +path = "src/echo_client.rs" + +[features] +default = ["ipc-runtime"] +ipc-runtime = ["dep:ipc-runtime"] +# Compile-checks the generated FFI backend (no real FFI library is linked; +# the extern symbol is only required at link time of a consumer binary). +ffi = ["dep:libc"] + +[dependencies] +rmp-serde = "1.1" +serde = { version = "1.0", features = ["derive"] } +thiserror = "1.0" +ipc-runtime = { path = "../../../ipc-runtime/rust", optional = true } +libc = { version = "0.2", optional = true } diff --git a/ipc-codegen/echo_example/rust/README.md b/ipc-codegen/echo_example/rust/README.md new file mode 100644 index 000000000000..5acd0c5e02f4 --- /dev/null +++ b/ipc-codegen/echo_example/rust/README.md @@ -0,0 +1,17 @@ +# Rust Echo Example + +Build from this directory: + +```sh +./bootstrap.sh +``` + +The Cargo project depends on `ipc-runtime/rust` via a repo-relative path. +Binaries are written to `target/debug/`. + +Run locally: + +```sh +target/debug/echo_server --socket /tmp/echo.sock +target/debug/echo_client --socket /tmp/echo.sock +``` diff --git a/ipc-codegen/echo_example/rust/bootstrap.sh b/ipc-codegen/echo_example/rust/bootstrap.sh new file mode 100755 index 000000000000..570221c29f41 --- /dev/null +++ b/ipc-codegen/echo_example/rust/bootstrap.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +CODEGEN="$(cd "$DIR/../.." 
&& pwd)" +NODE="node --experimental-strip-types --experimental-transform-types --no-warnings" + +$NODE "$CODEGEN/src/generate.ts" \ + --schema "$DIR/../schema/schema.jsonc" \ + --lang rust \ + --server \ + --client \ + --uds \ + --ffi \ + --out "$DIR/src/generated" + +(cd "$DIR" && cargo build --locked --quiet) +# Compile-check the generated FFI backend (not linked into the binaries). +(cd "$DIR" && cargo check --locked --quiet --features ffi) diff --git a/ipc-codegen/echo_example/rust/src/bin/generate_golden.rs b/ipc-codegen/echo_example/rust/src/bin/generate_golden.rs new file mode 100644 index 000000000000..6311becfcba5 --- /dev/null +++ b/ipc-codegen/echo_example/rust/src/bin/generate_golden.rs @@ -0,0 +1,248 @@ +//! Generate golden msgpack files for wire compatibility testing. +//! Usage: generate_golden --output-dir golden/ +//! +//! The goldens are a binding wire-format contract: any new implementation +//! of the echo service (in any language) must decode these bytes into the +//! expected values, and re-encode the same inputs back to byte-identical +//! output. They cover the msgpack encoding boundaries that codegen tweaks +//! are most likely to silently break: +//! +//! - Variable-width integer encodings (fixint / uint8 / uint16 / uint32 / uint64) +//! - String encodings (fixstr / str8 / str16) plus multi-byte UTF-8 +//! - Bin encodings (bin8 / bin16) +//! - Optional = Some vs None +//! - Empty containers + +use echo_wire_compat::types_gen::*; +use std::fs; +use std::path::Path; + +fn main() { + let args: Vec = std::env::args().collect(); + let output_dir = args + .iter() + .position(|a| a == "--output-dir") + .and_then(|i| args.get(i + 1)) + .expect("Usage: generate_golden --output-dir "); + + fs::create_dir_all(output_dir).unwrap(); + + // ---------------------------------------------------------------------- + // Original happy-path cases. 
+ // ---------------------------------------------------------------------- + write_request( + output_dir, + "echo_bytes_request.msgpack", + Command::EchoBytes(EchoBytes::new(vec![0xDE, 0xAD, 0xBE, 0xEF, 0x42])), + ); + + write_request( + output_dir, + "echo_fields_request.msgpack", + Command::EchoFields(EchoFields::new(42, 999999, "hello wire compat".to_string())), + ); + + write_request( + output_dir, + "echo_nested_request.msgpack", + Command::EchoNested(EchoNested::new(EchoInner { + values: vec![vec![1, 2, 3], vec![4, 5]], + flag: Some(true), + })), + ); + + let hash = test_hash(0x10); + let second = test_hash(0x40); + write_request( + output_dir, + "echo_aliases_request.msgpack", + Command::EchoAliases(EchoAliases::new( + 7, + hash.clone(), + Some(second.clone()), + vec![hash.clone(), second.clone()], + )), + ); + + write_response( + output_dir, + "echo_bytes_response.msgpack", + Response::EchoBytesResponse(EchoBytesResponse { + data: vec![0xDE, 0xAD, 0xBE, 0xEF, 0x42], + }), + ); + + write_response( + output_dir, + "echo_fields_response.msgpack", + Response::EchoFieldsResponse(EchoFieldsResponse { + a: 42, + b: 999999, + name: "hello wire compat".to_string(), + }), + ); + + write_response( + output_dir, + "echo_nested_response.msgpack", + Response::EchoNestedResponse(EchoNestedResponse { + inner: EchoInner { + values: vec![vec![1, 2, 3], vec![4, 5]], + flag: Some(true), + }, + }), + ); + + write_response( + output_dir, + "echo_aliases_response.msgpack", + Response::EchoAliasesResponse(EchoAliasesResponse { + tree_id: 7, + hash: hash.clone(), + maybe_hash: Some(second.clone()), + hashes: vec![hash, second], + }), + ); + + // ---------------------------------------------------------------------- + // Boundary cases — these are what catch silent format regressions. + // ---------------------------------------------------------------------- + + // Empty Vec. bin8-with-len-0 vs bin16-with-len-0 vs absent — picks one. 
+ write_request( + output_dir, + "echo_bytes_empty.msgpack", + Command::EchoBytes(EchoBytes::new(vec![])), + ); + + // 256-byte Vec. Crosses the bin8 → bin16 boundary (bin8 max is 255). + write_request( + output_dir, + "echo_bytes_bin16.msgpack", + Command::EchoBytes(EchoBytes::new(vec![0xAA; 256])), + ); + + // u32::MAX (= 2^32 - 1) and u64::MAX. Largest uint encodings; empty string + // exercises fixstr-len-0 framing. + write_request( + output_dir, + "echo_fields_max.msgpack", + Command::EchoFields(EchoFields::new(u32::MAX, u64::MAX, String::new())), + ); + + // u32 = 128 (smallest uint8) and u64 above u32::MAX (forces uint64 encoding). + write_request( + output_dir, + "echo_fields_uint_boundary.msgpack", + Command::EchoFields(EchoFields::new(128, (u32::MAX as u64) + 1, "x".to_string())), + ); + + // Multi-byte UTF-8 in name. Catches encoders that mistakenly count bytes + // by char-count, or that switch str/bin tags depending on content. + write_request( + output_dir, + "echo_fields_unicode.msgpack", + Command::EchoFields(EchoFields::new(0, 0, "héllo τέστ 🚀 mañana".to_string())), + ); + + // 300-char ASCII string. Crosses fixstr (≤31) → str8 (≤255) → str16 boundary. + write_request( + output_dir, + "echo_fields_str16.msgpack", + Command::EchoFields(EchoFields::new(0, 0, "a".repeat(300))), + ); + + // Optional = None plus empty outer Vec>. + write_request( + output_dir, + "echo_nested_flag_none.msgpack", + Command::EchoNested(EchoNested::new(EchoInner { + values: vec![], + flag: None, + })), + ); + + // Optional = Some(false) plus a Vec> containing an empty inner. + write_request( + output_dir, + "echo_nested_flag_false.msgpack", + Command::EchoNested(EchoNested::new(EchoInner { + values: vec![vec![]], + flag: Some(false), + })), + ); + + // ---------------------------------------------------------------------- + // Blob / fail / error cases — optional bytes, fixed-size byte arrays, + // the empty response struct, and the error variant's wire format. 
+ // ---------------------------------------------------------------------- + + // Optional> = Some plus [Vec; 2] fixed array of bins. + write_request( + output_dir, + "echo_blobs_request.msgpack", + Command::EchoBlobs(EchoBlobs::new( + Some(vec![0xAA, 0xBB]), + [vec![1, 2, 3], vec![4]], + )), + ); + + // Optional> = None plus an empty first array element. + write_request( + output_dir, + "echo_blobs_none.msgpack", + Command::EchoBlobs(EchoBlobs::new(None, [vec![], vec![9]])), + ); + + write_response( + output_dir, + "echo_blobs_response.msgpack", + Response::EchoBlobsResponse(EchoBlobsResponse { + maybe_data: Some(vec![0xAA, 0xBB]), + parts: [vec![1, 2, 3], vec![4]], + }), + ); + + write_request( + output_dir, + "echo_fail_request.msgpack", + Command::EchoFail(EchoFail::new("deliberate failure".to_string())), + ); + + // Empty struct — pins how a fieldless payload map is framed. + write_response( + output_dir, + "echo_fail_response.msgpack", + Response::EchoFailResponse(EchoFailResponse {}), + ); + + // Pins the error variant's wire format. 
+ write_response( + output_dir, + "echo_error_response.msgpack", + Response::EchoErrorResponse(EchoErrorResponse { + message: "deliberate failure".to_string(), + }), + ); + + eprintln!("Generated golden files in {}", output_dir); +} + +fn write_request(dir: &str, name: &str, command: Command) { + let value = vec![command]; + let bytes = rmp_serde::to_vec_named(&value).unwrap(); + let path = Path::new(dir).join(name); + fs::write(&path, &bytes).unwrap(); + eprintln!(" {} ({} bytes)", name, bytes.len()); +} + +fn write_response(dir: &str, name: &str, response: Response) { + let bytes = rmp_serde::to_vec_named(&response).unwrap(); + let path = Path::new(dir).join(name); + fs::write(&path, &bytes).unwrap(); + eprintln!(" {} ({} bytes)", name, bytes.len()); +} + +fn test_hash(base: u8) -> Fr { + Fr::from_bytes(std::array::from_fn(|i| base + i as u8)) +} diff --git a/ipc-codegen/echo_example/rust/src/bin/golden_test.rs b/ipc-codegen/echo_example/rust/src/bin/golden_test.rs new file mode 100644 index 000000000000..b0c37bf9e0a3 --- /dev/null +++ b/ipc-codegen/echo_example/rust/src/bin/golden_test.rs @@ -0,0 +1,342 @@ +//! Golden file wire-format conformance test (Rust). +//! For each golden file, asserts: +//! 1. We can decode the bytes into the expected typed value. +//! 2. Re-encoding the same value produces byte-identical output. +//! The combination pins down the wire format as a binding contract. + +use echo_wire_compat::types_gen::*; +use std::fs; + +fn main() { + let args: Vec = std::env::args().collect(); + let dir = args + .iter() + .position(|a| a == "--golden-dir") + .and_then(|i| args.get(i + 1)) + .expect("Usage: golden_test --golden-dir "); + + let mut pass = 0; + let mut fail = 0; + + // Helpers close over (pass, fail) via outparams. + let bytes_eq = |a: &[u8], b: &[u8]| -> bool { a == b }; + + // ------ Request goldens (wire format: Vec) ------ + + macro_rules! 
check_request { + ($file:expr, $variant:ident, $expect_check:expr) => {{ + let path = format!("{dir}/{}", $file); + match fs::read(&path) { + Err(e) => { eprintln!(" FAIL: {}: read: {e}", $file); fail += 1; } + Ok(bytes) => { + match rmp_serde::from_slice::>(&bytes) { + Err(e) => { eprintln!(" FAIL: {}: decode: {e}", $file); fail += 1; } + Ok(cmds) if cmds.len() != 1 => { + eprintln!(" FAIL: {}: expected 1 command, got {}", $file, cmds.len()); + fail += 1; + } + Ok(cmds) => match cmds.into_iter().next().unwrap() { + Command::$variant(v) => { + let check_fn: fn(&_) -> Result<(), String> = $expect_check; + if let Err(e) = check_fn(&v) { + eprintln!(" FAIL: {}: {e}", $file); + fail += 1; + } else { + // Roundtrip: re-encode and compare bytes. + let re = rmp_serde::to_vec_named(&vec![Command::$variant(v)]).unwrap(); + if !bytes_eq(&re, &bytes) { + eprintln!(" FAIL: {}: roundtrip byte mismatch ({} vs {} bytes)", + $file, re.len(), bytes.len()); + fail += 1; + } else { + eprintln!(" PASS: {}", $file); + pass += 1; + } + } + } + other => { + eprintln!(" FAIL: {}: wrong variant ({:?})", $file, std::mem::discriminant(&other)); + fail += 1; + } + } + } + } + } + }}; + } + + // ------ Response goldens (wire format: Response, NamedUnion) ------ + + macro_rules! 
check_response { + ($file:expr, $variant:ident, $expect_check:expr) => {{ + let path = format!("{dir}/{}", $file); + match fs::read(&path) { + Err(e) => { + eprintln!(" FAIL: {}: read: {e}", $file); + fail += 1; + } + Ok(bytes) => match rmp_serde::from_slice::(&bytes) { + Err(e) => { + eprintln!(" FAIL: {}: decode: {e}", $file); + fail += 1; + } + Ok(Response::$variant(v)) => { + let check_fn: fn(&_) -> Result<(), String> = $expect_check; + if let Err(e) = check_fn(&v) { + eprintln!(" FAIL: {}: {e}", $file); + fail += 1; + } else { + let re = rmp_serde::to_vec_named(&Response::$variant(v)).unwrap(); + if !bytes_eq(&re, &bytes) { + eprintln!(" FAIL: {}: roundtrip byte mismatch", $file); + fail += 1; + } else { + eprintln!(" PASS: {}", $file); + pass += 1; + } + } + } + Ok(_) => { + eprintln!(" FAIL: {}: wrong variant", $file); + fail += 1; + } + }, + } + }}; + } + + // ============ Original happy-path cases ============ + + check_request!("echo_bytes_request.msgpack", EchoBytes, |v: &EchoBytes| { + if v.data != vec![0xDE, 0xAD, 0xBE, 0xEF, 0x42] { + Err("data".into()) + } else { + Ok(()) + } + }); + check_request!( + "echo_fields_request.msgpack", + EchoFields, + |v: &EchoFields| { + if v.a != 42 || v.b != 999999 || v.name != "hello wire compat" { + Err("fields".into()) + } else { + Ok(()) + } + } + ); + check_request!( + "echo_nested_request.msgpack", + EchoNested, + |v: &EchoNested| { + if v.inner.values != vec![vec![1u8, 2, 3], vec![4, 5]] || v.inner.flag != Some(true) { + Err("nested".into()) + } else { + Ok(()) + } + } + ); + check_request!( + "echo_aliases_request.msgpack", + EchoAliases, + |v: &EchoAliases| { + let hash = test_hash(0x10); + let second = test_hash(0x40); + if v.tree_id != 7 + || v.hash != hash + || v.maybe_hash != Some(second.clone()) + || v.hashes != vec![hash, second] + { + Err("aliases".into()) + } else { + Ok(()) + } + } + ); + + check_response!( + "echo_bytes_response.msgpack", + EchoBytesResponse, + |v: &EchoBytesResponse| { + if v.data 
!= vec![0xDE, 0xAD, 0xBE, 0xEF, 0x42] { + Err("data".into()) + } else { + Ok(()) + } + } + ); + check_response!( + "echo_fields_response.msgpack", + EchoFieldsResponse, + |v: &EchoFieldsResponse| { + if v.a != 42 || v.b != 999999 || v.name != "hello wire compat" { + Err("fields".into()) + } else { + Ok(()) + } + } + ); + check_response!( + "echo_nested_response.msgpack", + EchoNestedResponse, + |v: &EchoNestedResponse| { + if v.inner.values != vec![vec![1u8, 2, 3], vec![4, 5]] || v.inner.flag != Some(true) { + Err("nested".into()) + } else { + Ok(()) + } + } + ); + check_response!( + "echo_aliases_response.msgpack", + EchoAliasesResponse, + |v: &EchoAliasesResponse| { + let hash = test_hash(0x10); + let second = test_hash(0x40); + if v.tree_id != 7 + || v.hash != hash + || v.maybe_hash != Some(second.clone()) + || v.hashes != vec![hash, second] + { + Err("aliases".into()) + } else { + Ok(()) + } + } + ); + + // ============ Boundary cases ============ + + check_request!("echo_bytes_empty.msgpack", EchoBytes, |v: &EchoBytes| { + if !v.data.is_empty() { + Err(format!("expected empty, got {} bytes", v.data.len())) + } else { + Ok(()) + } + }); + check_request!("echo_bytes_bin16.msgpack", EchoBytes, |v: &EchoBytes| { + if v.data.len() != 256 || v.data.iter().any(|&b| b != 0xAA) { + Err("expected 256 x 0xAA".into()) + } else { + Ok(()) + } + }); + check_request!("echo_fields_max.msgpack", EchoFields, |v: &EchoFields| { + if v.a != u32::MAX || v.b != u64::MAX || !v.name.is_empty() { + Err("expected u32::MAX/u64::MAX/empty".into()) + } else { + Ok(()) + } + }); + check_request!( + "echo_fields_uint_boundary.msgpack", + EchoFields, + |v: &EchoFields| { + if v.a != 128 || v.b != (u32::MAX as u64) + 1 || v.name != "x" { + Err("expected 128/u32max+1/\"x\"".into()) + } else { + Ok(()) + } + } + ); + check_request!( + "echo_fields_unicode.msgpack", + EchoFields, + |v: &EchoFields| { + if v.name != "héllo τέστ 🚀 mañana" { + Err(format!("unicode mismatch: {:?}", v.name)) + } else 
{ + Ok(()) + } + } + ); + check_request!("echo_fields_str16.msgpack", EchoFields, |v: &EchoFields| { + if v.name.len() != 300 || v.name.chars().any(|c| c != 'a') { + Err("expected 300 x 'a'".into()) + } else { + Ok(()) + } + }); + check_request!( + "echo_nested_flag_none.msgpack", + EchoNested, + |v: &EchoNested| { + if !v.inner.values.is_empty() || v.inner.flag.is_some() { + Err("expected empty values + flag=None".into()) + } else { + Ok(()) + } + } + ); + check_request!( + "echo_nested_flag_false.msgpack", + EchoNested, + |v: &EchoNested| { + if v.inner.values != vec![Vec::::new()] || v.inner.flag != Some(false) { + Err("expected [[]] + flag=Some(false)".into()) + } else { + Ok(()) + } + } + ); + + // ============ Blob / fail / error cases ============ + + check_request!("echo_blobs_request.msgpack", EchoBlobs, |v: &EchoBlobs| { + if v.maybe_data != Some(vec![0xAA, 0xBB]) || v.parts != [vec![1u8, 2, 3], vec![4]] { + Err("blobs".into()) + } else { + Ok(()) + } + }); + check_request!("echo_blobs_none.msgpack", EchoBlobs, |v: &EchoBlobs| { + if v.maybe_data.is_some() || v.parts != [Vec::::new(), vec![9]] { + Err("expected maybeData=None + [[], [9]]".into()) + } else { + Ok(()) + } + }); + check_response!( + "echo_blobs_response.msgpack", + EchoBlobsResponse, + |v: &EchoBlobsResponse| { + if v.maybe_data != Some(vec![0xAA, 0xBB]) || v.parts != [vec![1u8, 2, 3], vec![4]] { + Err("blobs".into()) + } else { + Ok(()) + } + } + ); + check_request!("echo_fail_request.msgpack", EchoFail, |v: &EchoFail| { + if v.message != "deliberate failure" { + Err(format!("message mismatch: {:?}", v.message)) + } else { + Ok(()) + } + }); + check_response!( + "echo_fail_response.msgpack", + EchoFailResponse, + |_v: &EchoFailResponse| Ok(()) + ); + check_response!( + "echo_error_response.msgpack", + EchoErrorResponse, + |v: &EchoErrorResponse| { + if v.message != "deliberate failure" { + Err(format!("message mismatch: {:?}", v.message)) + } else { + Ok(()) + } + } + ); + + 
eprintln!("\nResults: {pass}/{} passed, {fail} failed", pass + fail); + if fail > 0 { + std::process::exit(1); + } +} + +fn test_hash(base: u8) -> Fr { + Fr::from_bytes(std::array::from_fn(|i| base + i as u8)) +} diff --git a/ipc-codegen/echo_example/rust/src/echo_client.rs b/ipc-codegen/echo_example/rust/src/echo_client.rs new file mode 100644 index 000000000000..678d5e9f7fb3 --- /dev/null +++ b/ipc-codegen/echo_example/rust/src/echo_client.rs @@ -0,0 +1,94 @@ +//! Echo IPC client — uses GENERATED typed client (EchoApi) over ipc-runtime. +//! Usage: echo_client --socket /tmp/echo.sock +//! Exits 0 on success, 1 on failure. + +use echo_wire_compat::generated::echo_client::EchoApi; +use echo_wire_compat::generated::echo_types::{EchoInner, Fr}; +use echo_wire_compat::generated::error::{IpcError, Result}; +use ipc_runtime::IpcClient; + +fn main() -> Result<()> { + let args: Vec<String> = std::env::args().collect(); + let socket_path = args + .iter() + .position(|a| a == "--socket") + .and_then(|i| args.get(i + 1)) + .expect("Usage: echo_client --socket <path>"); + + let backend = + IpcClient::from_path(socket_path).map_err(|e| IpcError::Backend(e.to_string()))?; + let mut client = EchoApi::new(backend); + + // Test 1: EchoBytes + let test_data = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x42]; + let resp = client.bytes(&test_data)?; + assert_eq!(resp.data, test_data, "EchoBytes data mismatch"); + eprintln!("echo_client(rust): EchoBytes OK"); + + // Test 2: EchoFields + let resp = client.fields(42, 999999, "hello wire compat".to_string())?; + assert_eq!(resp.a, 42); + assert_eq!(resp.b, 999999); + assert_eq!(resp.name, "hello wire compat"); + eprintln!("echo_client(rust): EchoFields OK"); + + // Test 3: EchoNested + let inner = EchoInner { + values: vec![vec![1, 2, 3], vec![4, 5]], + flag: Some(true), + }; + let resp = client.nested(inner.clone())?; + assert_eq!(resp.inner.values, inner.values); + assert_eq!(resp.inner.flag, inner.flag); + eprintln!("echo_client(rust): EchoNested OK"); + + // 
Test 4: EchoAliases + let hash = Fr::from_bytes(std::array::from_fn(|i| 0x10 + i as u8)); + let second = Fr::from_bytes(std::array::from_fn(|i| 0x40 + i as u8)); + let resp = client.aliases( + 7, + hash.clone(), + Some(second.clone()), + vec![hash.clone(), second.clone()], + )?; + assert_eq!(resp.tree_id, 7); + assert_eq!(resp.hash, hash); + assert_eq!(resp.maybe_hash, Some(second.clone())); + assert_eq!(resp.hashes, vec![hash.clone(), second.clone()]); + eprintln!("echo_client(rust): EchoAliases OK"); + + // Test 5: EchoAliases with maybe_hash = None (optional-absent over live IPC) + let resp = client.aliases(7, hash.clone(), None, vec![hash.clone()])?; + assert_eq!(resp.maybe_hash, None); + eprintln!("echo_client(rust): EchoAliases none OK"); + + // Test 6: EchoFields with b > u32::MAX (uint64 wire encoding over live IPC) + let big = (1u64 << 53) - 1; + let resp = client.fields(42, big, "big".to_string())?; + assert_eq!(resp.b, big); + eprintln!("echo_client(rust): EchoFields u64 OK"); + + // Test 7: EchoBlobs — Option Some/None and [bytes; 2] + let resp = client.blobs(Some(vec![0xAA, 0xBB]), [vec![1, 2, 3], vec![4]])?; + assert_eq!(resp.maybe_data, Some(vec![0xAA, 0xBB])); + assert_eq!(resp.parts, [vec![1, 2, 3], vec![4]]); + let resp = client.blobs(None, [vec![], vec![9]])?; + assert_eq!(resp.maybe_data, None); + assert_eq!(resp.parts, [vec![], vec![9]]); + eprintln!("echo_client(rust): EchoBlobs OK"); + + // Test 8: EchoFail — server error surfaces with its message + match client.fail("deliberate failure".to_string()) { + Err(IpcError::Backend(message)) => { + assert!( + message.contains("deliberate failure"), + "EchoFail message mismatch: {message}" + ); + } + other => panic!("EchoFail expected backend error, got {other:?}"), + } + eprintln!("echo_client(rust): EchoFail OK"); + + eprintln!("echo_client(rust): all tests passed"); + Ok(()) +} diff --git a/ipc-codegen/echo_example/rust/src/echo_server.rs b/ipc-codegen/echo_example/rust/src/echo_server.rs new 
file mode 100644 index 000000000000..0c59196345c9 --- /dev/null +++ b/ipc-codegen/echo_example/rust/src/echo_server.rs @@ -0,0 +1,69 @@ +//! Echo IPC server — uses GENERATED dispatch + types + ipc-runtime transport. +//! Usage: echo_server --socket /tmp/echo.sock + +use echo_wire_compat::generated::echo_server::Handler; +use echo_wire_compat::generated::echo_types::*; +use echo_wire_compat::generated::error::{IpcError, Result}; +use ipc_runtime::IpcServer; +use std::cell::RefCell; + +struct EchoHandler; + +impl Handler for EchoHandler { + fn bytes(&mut self, cmd: EchoBytes) -> Result<EchoBytesResponse> { + Ok(EchoBytesResponse { data: cmd.data }) + } + fn fields(&mut self, cmd: EchoFields) -> Result<EchoFieldsResponse> { + Ok(EchoFieldsResponse { + a: cmd.a, + b: cmd.b, + name: cmd.name, + }) + } + fn nested(&mut self, cmd: EchoNested) -> Result<EchoNestedResponse> { + Ok(EchoNestedResponse { inner: cmd.inner }) + } + fn aliases(&mut self, cmd: EchoAliases) -> Result<EchoAliasesResponse> { + Ok(EchoAliasesResponse { + tree_id: cmd.tree_id, + hash: cmd.hash, + maybe_hash: cmd.maybe_hash, + hashes: cmd.hashes, + }) + } + fn blobs(&mut self, cmd: EchoBlobs) -> Result<EchoBlobsResponse> { + Ok(EchoBlobsResponse { + maybe_data: cmd.maybe_data, + parts: cmd.parts, + }) + } + fn fail(&mut self, cmd: EchoFail) -> Result<EchoFailResponse> { + Err(IpcError::Backend(cmd.message)) + } +} + +fn main() { + let args: Vec<String> = std::env::args().collect(); + let socket_path = args + .iter() + .position(|a| a == "--socket") + .and_then(|i| args.get(i + 1)) + .expect("Usage: echo_server --socket <path>"); + + let _ = std::fs::remove_file(socket_path); + + // Wrap handler in RefCell so the FnMut closure can borrow mutably across + // dispatches. 
+ let handler = RefCell::new(EchoHandler); + + let mut server = IpcServer::from_path(socket_path).expect("IpcServer::from_path"); + server.install_default_signal_handlers(); + server.listen().expect("IpcServer::listen"); + + server.run(|_client_id, payload| { + echo_wire_compat::generated::echo_server::handle_request( + &mut *handler.borrow_mut(), + payload, + ) + }); +} diff --git a/ipc-codegen/echo_example/rust/src/lib.rs b/ipc-codegen/echo_example/rust/src/lib.rs new file mode 100644 index 000000000000..8f7c1acd17e8 --- /dev/null +++ b/ipc-codegen/echo_example/rust/src/lib.rs @@ -0,0 +1,18 @@ +// Generated modules live in src/generated/. Transport comes from the +// `ipc-runtime` crate; the per-language UDS template that used to live +// here (ipc_server.rs / uds_backend.rs) is gone — the runtime is shared. +pub mod generated { + pub mod backend; + pub mod echo_client; + pub mod echo_server; + pub mod echo_types; + pub mod error; + #[cfg(feature = "ffi")] + pub mod ffi_backend; +} + +// Re-export under the names that generated server/client code expects +// (they use `crate::types_gen`, `crate::error`, `crate::backend`) +pub use generated::backend; +pub use generated::echo_types as types_gen; +pub use generated::error; diff --git a/ipc-codegen/echo_example/schema/golden/echo_aliases_request.msgpack b/ipc-codegen/echo_example/schema/golden/echo_aliases_request.msgpack new file mode 100644 index 000000000000..82d73fb28ae1 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_aliases_request.msgpack @@ -0,0 +1 @@ +EchoAliasestreeIdhash  !"#$%&'()*+,-./maybeHash @ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_hashes  !"#$%&'()*+,-./ @ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_ \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_aliases_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_aliases_response.msgpack new file mode 100644 index 000000000000..4a738abb3220 --- /dev/null +++ 
b/ipc-codegen/echo_example/schema/golden/echo_aliases_response.msgpack @@ -0,0 +1 @@ +EchoAliasesResponsetreeIdhash  !"#$%&'()*+,-./maybeHash @ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_hashes  !"#$%&'()*+,-./ @ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_ \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_blobs_none.msgpack b/ipc-codegen/echo_example/schema/golden/echo_blobs_none.msgpack new file mode 100644 index 000000000000..30e3d913167f Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_blobs_none.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_blobs_request.msgpack b/ipc-codegen/echo_example/schema/golden/echo_blobs_request.msgpack new file mode 100644 index 000000000000..caea91e565b9 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_blobs_request.msgpack @@ -0,0 +1 @@ +EchoBlobsmaybeDataparts \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_blobs_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_blobs_response.msgpack new file mode 100644 index 000000000000..99fc821d909e --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_blobs_response.msgpack @@ -0,0 +1 @@ +EchoBlobsResponsemaybeDataparts \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_bytes_bin16.msgpack b/ipc-codegen/echo_example/schema/golden/echo_bytes_bin16.msgpack new file mode 100644 index 000000000000..a24108950f18 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_bytes_bin16.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_bytes_empty.msgpack b/ipc-codegen/echo_example/schema/golden/echo_bytes_empty.msgpack new file mode 100644 index 000000000000..08696c9d133f Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_bytes_empty.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_bytes_request.msgpack 
b/ipc-codegen/echo_example/schema/golden/echo_bytes_request.msgpack new file mode 100644 index 000000000000..fbd6f3e3f32d --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_bytes_request.msgpack @@ -0,0 +1 @@ +EchoBytesdataޭB \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_bytes_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_bytes_response.msgpack new file mode 100644 index 000000000000..f15b07218068 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_bytes_response.msgpack @@ -0,0 +1 @@ +EchoBytesResponsedataޭB \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_error_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_error_response.msgpack new file mode 100644 index 000000000000..022ea01a0330 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_error_response.msgpack @@ -0,0 +1 @@ +EchoErrorResponsemessagedeliberate failure \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_fail_request.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fail_request.msgpack new file mode 100644 index 000000000000..6475b2df0c11 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_fail_request.msgpack @@ -0,0 +1 @@ +EchoFailmessagedeliberate failure \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_fail_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fail_response.msgpack new file mode 100644 index 000000000000..1d6daa31bc9f --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_fail_response.msgpack @@ -0,0 +1 @@ +EchoFailResponse \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_max.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_max.msgpack new file mode 100644 index 000000000000..46fa67c1a04a --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_fields_max.msgpack @@ -0,0 +1 @@ 
+EchoFieldsabname \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_request.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_request.msgpack new file mode 100644 index 000000000000..d72172943bc0 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_fields_request.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_response.msgpack new file mode 100644 index 000000000000..1b2aba194ac4 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_fields_response.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_str16.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_str16.msgpack new file mode 100644 index 000000000000..69649cf16c88 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_fields_str16.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_uint_boundary.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_uint_boundary.msgpack new file mode 100644 index 000000000000..9f17a3ebe15b Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_fields_uint_boundary.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_fields_unicode.msgpack b/ipc-codegen/echo_example/schema/golden/echo_fields_unicode.msgpack new file mode 100644 index 000000000000..0691f2bfbd58 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_fields_unicode.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_nested_flag_false.msgpack b/ipc-codegen/echo_example/schema/golden/echo_nested_flag_false.msgpack new file mode 100644 index 000000000000..30af0cb49794 Binary files /dev/null and b/ipc-codegen/echo_example/schema/golden/echo_nested_flag_false.msgpack differ diff --git a/ipc-codegen/echo_example/schema/golden/echo_nested_flag_none.msgpack 
b/ipc-codegen/echo_example/schema/golden/echo_nested_flag_none.msgpack new file mode 100644 index 000000000000..b5d2addb83d6 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_nested_flag_none.msgpack @@ -0,0 +1 @@ +EchoNestedinnervaluesflag \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_nested_request.msgpack b/ipc-codegen/echo_example/schema/golden/echo_nested_request.msgpack new file mode 100644 index 000000000000..6c8ded7184ed --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_nested_request.msgpack @@ -0,0 +1 @@ +EchoNestedinnervaluesflag \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/golden/echo_nested_response.msgpack b/ipc-codegen/echo_example/schema/golden/echo_nested_response.msgpack new file mode 100644 index 000000000000..c8966a9d4f46 --- /dev/null +++ b/ipc-codegen/echo_example/schema/golden/echo_nested_response.msgpack @@ -0,0 +1 @@ +EchoNestedResponseinnervaluesflag \ No newline at end of file diff --git a/ipc-codegen/echo_example/schema/schema.jsonc b/ipc-codegen/echo_example/schema/schema.jsonc new file mode 100644 index 000000000000..04c9a775dc85 --- /dev/null +++ b/ipc-codegen/echo_example/schema/schema.jsonc @@ -0,0 +1,44 @@ +// echo.schema.jsonc — the entire echo service, human-authored. +{ + "service": "Echo", + + // Named byte aliases (nominal 32-byte types). Only bin32 today. + "aliases": { + "Fr": "bin32" + }, + + // Shared struct types, referenced by name from commands. + "types": { + "EchoInner": { + "values": "bytes[]", + "flag": "bool?" + } + }, + + // Error variant shared by every command. + "error": { "message": "string" }, + + // command -> { request, response }. Names are unprefixed; generated type + // names get the service prefix (EchoBytes), method names do not (bytes). 
+ "commands": { + "Bytes": { "request": { "data": "bytes" }, + "response": { "data": "bytes" } }, + + "Fields": { "request": { "a": "u32", "b": "u64", "name": "string" }, + "response": { "a": "u32", "b": "u64", "name": "string" } }, + + "Nested": { "request": { "inner": "EchoInner" }, + "response": { "inner": "EchoInner" } }, + + "Aliases": { "request": { "treeId": "u32", "hash": "Fr", + "maybeHash": "Fr?", "hashes": "Fr[]" }, + "response": { "treeId": "u32", "hash": "Fr", + "maybeHash": "Fr?", "hashes": "Fr[]" } }, + + "Blobs": { "request": { "maybeData": "bytes?", "parts": "bytes[2]" }, + "response": { "maybeData": "bytes?", "parts": "bytes[2]" } }, + + "Fail": { "request": { "message": "string" }, + "response": {} } + } +} diff --git a/ipc-codegen/echo_example/scripts/run_cross_language_test.sh b/ipc-codegen/echo_example/scripts/run_cross_language_test.sh new file mode 100755 index 000000000000..26881e1c83fd --- /dev/null +++ b/ipc-codegen/echo_example/scripts/run_cross_language_test.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash +# +# Run a single cross-language IPC wire-compat test. +# All binaries are expected to be prebuilt by `ipc-codegen/bootstrap.sh build`. +# +# Usage: +# run_cross_language_test.sh golden <lang> # lang in {rust, ts, cpp, zig} +# run_cross_language_test.sh matrix <server_lang> <client_lang> [transport] +# # langs in {rust, ts, cpp, zig} +# # transport in {uds, shm}, default uds +# +# SHM transport requires ipc-runtime's NAPI addon (built by +# ipc-runtime/bootstrap.sh) when TS is the client. There's no SHM TS server. +# +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +EXAMPLE_DIR="$(dirname "$SCRIPT_DIR")" +cd "$EXAMPLE_DIR" + +# Map language -> server command / client command. Each command is run with +# `--socket <path>` appended. 
+server_cmd_for() { + case "$1" in + rust) echo "rust/target/debug/echo_server" ;; + ts) echo "ts/node_modules/.bin/tsx ts/src/echo_server.ts" ;; + cpp) echo "cpp/build/bin/echo_server" ;; + zig) echo "zig/zig-out/bin/echo_server" ;; + *) echo "unknown lang: $1" >&2; exit 1 ;; + esac +} + +client_cmd_for() { + case "$1" in + rust) echo "rust/target/debug/echo_client" ;; + ts) echo "ts/node_modules/.bin/tsx ts/src/echo_client.ts" ;; + cpp) echo "cpp/build/bin/echo_client" ;; + zig) echo "zig/zig-out/bin/echo_client" ;; + *) echo "unknown lang: $1" >&2; exit 1 ;; + esac +} + +run_golden() { + local lang="$1" + case "$lang" in + rust) + rust/target/debug/golden_test --golden-dir schema/golden + ;; + ts) + ts/node_modules/.bin/tsx ts/src/golden_test.ts + ;; + cpp) + cpp/build/bin/golden_test --golden-dir schema/golden + ;; + zig) + zig/zig-out/bin/golden_test --golden-dir schema/golden + ;; + *) + echo "golden tests only defined for rust, ts, cpp and zig (got: $lang)" >&2 + exit 1 + ;; + esac +} + +run_matrix() { + local server_lang="$1" + local client_lang="$2" + local transport="${3:-uds}" + local server_cmd client_cmd + server_cmd=$(server_cmd_for "$server_lang") + client_cmd=$(client_cmd_for "$client_lang") + + if [ "$transport" = "shm" ] && [ "$server_lang" = "ts" ]; then + echo "shm transport not supported as TS server (no shm_server in ipc-runtime/ts)" >&2 + exit 1 + fi + + local ext path basename + case "$transport" in + uds) ext="sock" ;; + shm) ext="shm" ;; + *) + echo "unknown transport: $transport (expected uds|shm)" >&2; exit 1 ;; + esac + basename="echo-matrix-${server_lang}-${client_lang}-${transport}-$$" + path="${basename}.${ext}" + + # Spawn first, then install cleanup. Servers install SIGTERM handlers where + # the runtime needs graceful shutdown, so waiting lets transport close paths + # unlink their own resources. + $server_cmd --socket "$path" & + server_pid=$! 
+ trap "kill ${server_pid} 2>/dev/null || true; \ + wait ${server_pid} 2>/dev/null || true; \ + rm -f '$path'" EXIT + + if [ "$transport" = "shm" ] && [ "$client_lang" = "ts" ]; then + $client_cmd --socket "$path" --transport shm + else + $client_cmd --socket "$path" + fi +} + +kind="${1:-}" +case "$kind" in + golden) + run_golden "${2:?golden requires <lang>}" + ;; + matrix) + run_matrix "${2:?matrix requires <server_lang>}" "${3:?matrix requires <client_lang>}" "${4:-uds}" + ;; + *) + echo "Usage: $0 golden <lang> | matrix <server_lang> <client_lang> [uds|shm]" >&2 + exit 1 + ;; +esac diff --git a/ipc-codegen/echo_example/ts/.gitignore b/ipc-codegen/echo_example/ts/.gitignore new file mode 100644 index 000000000000..af9c5a6e01db --- /dev/null +++ b/ipc-codegen/echo_example/ts/.gitignore @@ -0,0 +1,2 @@ +node_modules/ +src/generated/ diff --git a/ipc-codegen/echo_example/ts/README.md b/ipc-codegen/echo_example/ts/README.md new file mode 100644 index 000000000000..bacbffdf8664 --- /dev/null +++ b/ipc-codegen/echo_example/ts/README.md @@ -0,0 +1,18 @@ +# TypeScript Echo Example + +Build from this directory: + +```sh +./bootstrap.sh +``` + +The package consumes `@aztec/ipc-runtime` via a repo-relative `file:` +dependency. The bootstrap builds `ipc-runtime/ts` before installing this +example so the file-linked package contains compiled output. + +Run locally: + +```sh +node_modules/.bin/tsx src/echo_server.ts --socket /tmp/echo.sock +node_modules/.bin/tsx src/echo_client.ts --socket /tmp/echo.sock +``` diff --git a/ipc-codegen/echo_example/ts/bootstrap.sh b/ipc-codegen/echo_example/ts/bootstrap.sh new file mode 100755 index 000000000000..59c15eba21d2 --- /dev/null +++ b/ipc-codegen/echo_example/ts/bootstrap.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +CODEGEN="$(cd "$DIR/../.." && pwd)" +REPO_ROOT="$(cd "$CODEGEN/.." 
&& pwd)" +NODE="node --experimental-strip-types --experimental-transform-types --no-warnings" + +$NODE "$CODEGEN/src/generate.ts" \ + --schema "$DIR/../schema/schema.jsonc" \ + --lang ts \ + --server \ + --client \ + --out "$DIR/src/generated" + +(cd "$REPO_ROOT/ipc-runtime" && ./bootstrap.sh) +(cd "$REPO_ROOT/ipc-runtime/ts" && yarn install --immutable && yarn build) +rm -rf "$DIR/node_modules" +(cd "$DIR" && npm install --no-package-lock --quiet) +(cd "$DIR" && node_modules/.bin/tsc --noEmit) diff --git a/ipc-codegen/echo_example/ts/package.json b/ipc-codegen/echo_example/ts/package.json new file mode 100644 index 000000000000..b06729bb57f2 --- /dev/null +++ b/ipc-codegen/echo_example/ts/package.json @@ -0,0 +1,14 @@ +{ + "name": "echo-wire-compat-ts", + "private": true, + "type": "module", + "dependencies": { + "@aztec/ipc-runtime": "file:../../../ipc-runtime/ts", + "msgpackr": "^1.10.0", + "tsx": "^4.19.0" + }, + "devDependencies": { + "@types/node": "^22.0.0", + "typescript": "^5.6.0" + } +} diff --git a/ipc-codegen/echo_example/ts/src/echo_client.ts b/ipc-codegen/echo_example/ts/src/echo_client.ts new file mode 100644 index 000000000000..81758913ce91 --- /dev/null +++ b/ipc-codegen/echo_example/ts/src/echo_client.ts @@ -0,0 +1,170 @@ +/** + * Echo IPC client (TypeScript) — uses the GENERATED AsyncApi client over the + * @aztec/ipc-runtime transport. Defaults to UDS; pass `--transport shm` to + * drive the bundled NAPI SHM client (`createNapiShmAsyncClient`) instead. + * Path suffix follows the same convention as ipc::make_client on the C++ + * side: `.sock` for UDS, `.shm` for MPSC-SHM rings. + * + * Usage: npx tsx echo_client.ts --socket /tmp/echo.sock [--transport uds|shm] + * Exits 0 on success, 1 on failure. 
+ */ +import { + createNapiShmAsyncClient, + UdsIpcClient, + type IpcClientAsync, +} from "@aztec/ipc-runtime"; +import { AsyncApi } from "./generated/async.js"; + +const args = process.argv.slice(2); +const socketIdx = args.indexOf("--socket"); +const socketPath = socketIdx >= 0 ? args[socketIdx + 1] : undefined; +if (!socketPath) { + console.error("Usage: echo_client.ts --socket <path> [--transport uds|shm]"); + process.exit(1); +} + +function testHash(base: number): Uint8Array { + return Uint8Array.from({ length: 32 }, (_v, i) => base + i); +} +const transportIdx = args.indexOf("--transport"); +const transport = transportIdx >= 0 ? args[transportIdx + 1] : "uds"; +if (transport !== "uds" && transport !== "shm") { + console.error(`Unknown --transport '${transport}' (expected uds|shm)`); + process.exit(1); +} + +function assertEqual(actual: unknown, expected: unknown, label: string) { + const a = JSON.stringify(actual); + const e = JSON.stringify(expected); + if (a !== e) throw new Error(`${label}: expected ${e}, got ${a}`); +} + +function assertBytes(actual: Uint8Array, expected: Uint8Array, label: string) { + assertEqual( + Buffer.from(actual).toString("hex"), + Buffer.from(expected).toString("hex"), + label, + ); +} + +async function run() { + // SHM clients identify the shared-memory base name without the `.shm` + // suffix — match ipc::make_client's behaviour on the C++ side. + const client: IpcClientAsync = + transport === "shm" + ? 
createNapiShmAsyncClient(socketPath!.replace(/\.shm$/, "")) + : await UdsIpcClient.connect(socketPath!); + const api = new AsyncApi(client); + + // Test 1: EchoBytes + const testData = Uint8Array.from([0xde, 0xad, 0xbe, 0xef, 0x42]); + const resp1 = await api.bytes({ data: testData }); + assertBytes(resp1.data, testData, "EchoBytes data"); + console.error("echo_client(ts): EchoBytes OK"); + + // Test 2: EchoFields + const resp2 = await api.fields({ + a: 42, + b: 999999, + name: "hello wire compat", + }); + assertEqual(resp2.a, 42, "EchoFields a"); + assertEqual(resp2.b, 999999, "EchoFields b"); + assertEqual(resp2.name, "hello wire compat", "EchoFields name field"); + console.error("echo_client(ts): EchoFields OK"); + + // Test 3: EchoNested + const inner = { + values: [Uint8Array.from([1, 2, 3]), Uint8Array.from([4, 5])], + flag: true, + }; + const resp3 = await api.nested({ inner }); + assertEqual(resp3.inner.flag, true, "EchoNested flag"); + assertEqual(resp3.inner.values.length, 2, "EchoNested values length"); + assertBytes(resp3.inner.values[0]!, inner.values[0]!, "EchoNested values[0]"); + console.error("echo_client(ts): EchoNested OK"); + + // Test 4: EchoAliases + const hash = testHash(0x10); + const second = testHash(0x40); + const resp4 = await api.aliases({ + treeId: 7, + hash, + maybeHash: second, + hashes: [hash, second], + }); + assertEqual(resp4.treeId, 7, "EchoAliases treeId"); + assertBytes(resp4.hash, hash, "EchoAliases hash"); + assertBytes(resp4.maybeHash!, second, "EchoAliases maybeHash"); + assertEqual(resp4.hashes.length, 2, "EchoAliases hashes length"); + assertBytes(resp4.hashes[0]!, hash, "EchoAliases hashes[0]"); + assertBytes(resp4.hashes[1]!, second, "EchoAliases hashes[1]"); + console.error("echo_client(ts): EchoAliases OK"); + + // Test 5: EchoAliases with maybeHash absent (optional over live IPC) + const resp5 = await api.aliases({ + treeId: 7, + hash, + maybeHash: null, + hashes: [hash], + }); + assertEqual(resp5.maybeHash, null, 
"EchoAliases maybeHash none"); + console.error("echo_client(ts): EchoAliases none OK"); + + // Test 6: EchoFields with b > 2^32 (uint64 wire encoding over live IPC) + const big = Number.MAX_SAFE_INTEGER; + const resp6 = await api.fields({ a: 42, b: big, name: "big" }); + assertEqual(resp6.b, big, "EchoFields u64"); + // Values past 2^53 must throw client-side rather than silently lose precision. + let threw = false; + try { + await api.fields({ a: 42, b: 2 ** 60, name: "too big" }); + } catch { + threw = true; + } + assertEqual(threw, true, "EchoFields u64 guard"); + console.error("echo_client(ts): EchoFields u64 OK"); + + // Test 7: EchoBlobs — optional bytes Some/None and fixed [bytes; 2] + const resp7 = await api.blobs({ + maybeData: Uint8Array.from([0xaa, 0xbb]), + parts: [Uint8Array.from([1, 2, 3]), Uint8Array.from([4])], + }); + assertBytes( + resp7.maybeData!, + Uint8Array.from([0xaa, 0xbb]), + "EchoBlobs maybeData", + ); + assertBytes( + resp7.parts[0]!, + Uint8Array.from([1, 2, 3]), + "EchoBlobs parts[0]", + ); + assertBytes(resp7.parts[1]!, Uint8Array.from([4]), "EchoBlobs parts[1]"); + const resp7b = await api.blobs({ + maybeData: null, + parts: [Uint8Array.from([]), Uint8Array.from([9])], + }); + assertEqual(resp7b.maybeData, null, "EchoBlobs maybeData none"); + console.error("echo_client(ts): EchoBlobs OK"); + + // Test 8: EchoFail — server error surfaces with its message + let failMessage = ""; + try { + await api.fail({ message: "deliberate failure" }); + } catch (e: any) { + failMessage = e.message; + } + if (!failMessage.includes("deliberate failure")) { + throw new Error(`EchoFail: expected error message, got '${failMessage}'`); + } + console.error("echo_client(ts): EchoFail OK"); + + await api.destroy(); + console.error("echo_client(ts): all tests passed"); +} + +run().catch((e) => { + console.error(`echo_client(ts): FAILED: ${e.message}`); + process.exit(1); +}); diff --git a/ipc-codegen/echo_example/ts/src/echo_server.ts 
b/ipc-codegen/echo_example/ts/src/echo_server.ts new file mode 100644 index 000000000000..1951a296b67c --- /dev/null +++ b/ipc-codegen/echo_example/ts/src/echo_server.ts @@ -0,0 +1,68 @@ +/** + * Echo IPC server (TypeScript) — uses the GENERATED handleRequest (framing, + * dispatch, and error wrapping) over the @aztec/ipc-runtime UDS transport. + * Usage: npx tsx echo_server.ts --socket /tmp/echo.sock + */ +import { UdsIpcServer } from "@aztec/ipc-runtime"; +import { handleRequest } from "./generated/server.js"; +import type { Handler } from "./generated/server.js"; +import type { + EchoBytes, + EchoBytesResponse, + EchoAliases, + EchoAliasesResponse, + EchoBlobs, + EchoBlobsResponse, + EchoFail, + EchoFailResponse, + EchoFields, + EchoFieldsResponse, + EchoNested, + EchoNestedResponse, +} from "./generated/api_types.js"; + +const args = process.argv.slice(2); +const socketIdx = args.indexOf("--socket"); +const socketPath = socketIdx >= 0 ? args[socketIdx + 1] : undefined; +if (!socketPath) { + console.error("Usage: echo_server.ts --socket <path>"); + process.exit(1); +} + +const handler: Handler = { + async bytes(cmd: EchoBytes): Promise<EchoBytesResponse> { + return { data: cmd.data }; + }, + async fields(cmd: EchoFields): Promise<EchoFieldsResponse> { + return { a: cmd.a, b: cmd.b, name: cmd.name }; + }, + async nested(cmd: EchoNested): Promise<EchoNestedResponse> { + return { inner: cmd.inner }; + }, + async aliases(cmd: EchoAliases): Promise<EchoAliasesResponse> { + return { + treeId: cmd.treeId, + hash: cmd.hash, + maybeHash: cmd.maybeHash, + hashes: cmd.hashes, + }; + }, + async blobs(cmd: EchoBlobs): Promise<EchoBlobsResponse> { + return { maybeData: cmd.maybeData, parts: cmd.parts }; + }, + async fail(cmd: EchoFail): Promise<EchoFailResponse> { + throw new Error(cmd.message); + }, +}; + +async function main() { + await UdsIpcServer.listen(socketPath!, (_clientId, requestBytes) => + handleRequest(handler, requestBytes), + ); + console.error(`ipc-server(ts): listening on ${socketPath}`); +} + +main().catch((e) => { + console.error(`echo_server(ts): FAILED: ${e.message ?? 
e}`); + process.exit(1); +}); diff --git a/ipc-codegen/echo_example/ts/src/golden_test.ts b/ipc-codegen/echo_example/ts/src/golden_test.ts new file mode 100644 index 000000000000..9f52415a0316 --- /dev/null +++ b/ipc-codegen/echo_example/ts/src/golden_test.ts @@ -0,0 +1,268 @@ +/** + * Golden file wire-format conformance test (TypeScript). + * + * For each golden file, asserts: + * 1. We can decode the bytes into the expected typed value. + * 2. Re-encoding the same value produces byte-identical output. + * The combination pins down the wire format as a binding contract. + * + * Usage: npx tsx golden_test.ts + */ + +import * as fs from "node:fs"; +import * as path from "node:path"; +import { Decoder, Encoder } from "msgpackr"; + +const decoder = new Decoder({ useRecords: false }); +// `variableMapSize: true` makes msgpackr emit fixmap (1-byte header) for small +// maps instead of always reaching for map16. Without it the encoder produces +// a semantically-equivalent but byte-different encoding, so round-tripping +// the goldens would fail even though the wire is otherwise correct. +const encoder = new Encoder({ useRecords: false, variableMapSize: true }); +const goldenDir = path.join(import.meta.dirname!, "../../schema", "golden"); + +let pass = 0; +let fail = 0; + +function bytesEqual(a: Uint8Array, b: Uint8Array): boolean { + if (a.length !== b.length) return false; + for (let i = 0; i < a.length; i++) { + if (a[i] !== b[i]) return false; + } + return true; +} + +function deepEqual(a: any, b: any): boolean { + // For our test data: bigints, strings, numbers, plain arrays of u8 (which + // msgpackr decodes as Uint8Array), and nested objects. The JSON-stringify + // trick falls down on bigint and Uint8Array; do a structural walk. 
+ if (a === b) return true; + if (typeof a === "bigint" || typeof b === "bigint") return a === b; + if (a instanceof Uint8Array && b instanceof Uint8Array) + return bytesEqual(a, b); + if (Array.isArray(a) && Array.isArray(b)) { + if (a.length !== b.length) return false; + return a.every((x, i) => deepEqual(x, b[i])); + } + if (a && b && typeof a === "object" && typeof b === "object") { + const ka = Object.keys(a).sort(); + const kb = Object.keys(b).sort(); + if (ka.length !== kb.length || !ka.every((k, i) => k === kb[i])) + return false; + return ka.every((k) => deepEqual(a[k], b[k])); + } + return false; +} + +/** Read golden, decode, check expectation, and (optionally) verify re-encode + * byte-equals golden. Strict roundtrip is the binding wire-format check, but + * msgpackr has one known divergence from rmp-serde: positive bigints are + * encoded as int64 (`d3`) instead of uint64 (`cf`). Both encodings are + * accepted by every msgpack decoder we care about, so the wire is still + * interoperable — we just can't pin the bytes here. */ +function check( + file: string, + expectedDecoded: any, + opts: { strictRoundtrip?: boolean } = {}, +) { + const strictRoundtrip = opts.strictRoundtrip ?? 
true; + try { + const golden = fs.readFileSync(path.join(goldenDir, file)); + const decoded = decoder.unpack(golden); + if (!deepEqual(decoded, expectedDecoded)) { + throw new Error( + `decoded mismatch:\n got: ${stringify(decoded)}\n exp: ${stringify(expectedDecoded)}`, + ); + } + if (strictRoundtrip) { + const reencoded = encoder.encode(decoded); + if (!bytesEqual(reencoded, golden)) { + throw new Error( + `roundtrip byte mismatch (decoded OK but re-encoded ${reencoded.length} bytes vs golden ${golden.length})`, + ); + } + } + console.log(` PASS: ${file}`); + pass++; + } catch (e: any) { + console.log(` FAIL: ${file}: ${e.message}`); + fail++; + } +} + +function stringify(v: any): string { + return JSON.stringify(v, (_k, x) => { + if (typeof x === "bigint") return `${x}n`; + if (x instanceof Uint8Array) return `[${Array.from(x).join(",")}]`; + return x; + }); +} + +console.log("Golden file wire-format conformance tests (TypeScript):\n"); + +// Request format: [[CommandName, {fields}]] +function req(cmdName: string, fields: any) { + return [[cmdName, fields]]; +} +// Response format: [ResponseName, {fields}] +function resp(respName: string, fields: any) { + return [respName, fields]; +} + +function testHash(base: number): Uint8Array { + return Uint8Array.from({ length: 32 }, (_v, i) => base + i); +} + +// ============ Original happy-path cases ============ +check( + "echo_bytes_request.msgpack", + req("EchoBytes", { data: new Uint8Array([0xde, 0xad, 0xbe, 0xef, 0x42]) }), +); +check( + "echo_fields_request.msgpack", + req("EchoFields", { a: 42, b: 999999, name: "hello wire compat" }), +); +check( + "echo_nested_request.msgpack", + req("EchoNested", { + inner: { + values: [new Uint8Array([1, 2, 3]), new Uint8Array([4, 5])], + flag: true, + }, + }), +); +check( + "echo_aliases_request.msgpack", + req("EchoAliases", { + treeId: 7, + hash: testHash(0x10), + maybeHash: testHash(0x40), + hashes: [testHash(0x10), testHash(0x40)], + }), +); + +check( + 
"echo_bytes_response.msgpack", + resp("EchoBytesResponse", { + data: new Uint8Array([0xde, 0xad, 0xbe, 0xef, 0x42]), + }), +); +check( + "echo_fields_response.msgpack", + resp("EchoFieldsResponse", { a: 42, b: 999999, name: "hello wire compat" }), +); +check( + "echo_nested_response.msgpack", + resp("EchoNestedResponse", { + inner: { + values: [new Uint8Array([1, 2, 3]), new Uint8Array([4, 5])], + flag: true, + }, + }), +); +check( + "echo_aliases_response.msgpack", + resp("EchoAliasesResponse", { + treeId: 7, + hash: testHash(0x10), + maybeHash: testHash(0x40), + hashes: [testHash(0x10), testHash(0x40)], + }), +); + +// ============ Boundary cases ============ + +// bin8 empty + bin16 (256 bytes) — bin8/bin16 framing boundary. +check("echo_bytes_empty.msgpack", req("EchoBytes", { data: new Uint8Array() })); +check( + "echo_bytes_bin16.msgpack", + req("EchoBytes", { data: new Uint8Array(256).fill(0xaa) }), +); + +// u32::MAX + u64::MAX + empty string. Largest uint encodings; fixstr-len-0. +// msgpackr decodes u64 fields as bigint when > Number.MAX_SAFE_INTEGER (2^53-1). +// Strict roundtrip is OK here because u64::MAX requires uint64 and msgpackr +// agrees with rmp-serde at the extreme. +check( + "echo_fields_max.msgpack", + req("EchoFields", { a: 4294967295, b: 18446744073709551615n, name: "" }), +); + +// u32 = 128 (fixint → uint8 boundary), u64 above u32::MAX (forces uint64). +// strictRoundtrip: false — see check()'s comment about the bigint/uint64 quirk. +check( + "echo_fields_uint_boundary.msgpack", + req("EchoFields", { a: 128, b: 4294967296n, name: "x" }), + { strictRoundtrip: false }, +); + +// Multi-byte UTF-8 in name. +check( + "echo_fields_unicode.msgpack", + req("EchoFields", { a: 0, b: 0, name: "héllo τέστ 🚀 mañana" }), +); + +// 300-char ASCII string. Crosses fixstr (≤31) and str8 (≤255) into str16. 
+check( + "echo_fields_str16.msgpack", + req("EchoFields", { a: 0, b: 0, name: "a".repeat(300) }), +); + +// Optional `flag` = absent (msgpackr decodes missing-with-nil to undefined or +// strips the key entirely depending on the encoder; rmp-serde emits a nil +// value for None inside a struct, so we expect flag: null here). +check( + "echo_nested_flag_none.msgpack", + req("EchoNested", { inner: { values: [], flag: null } }), +); + +// Optional `flag` = Some(false) with values=[empty inner]. +check( + "echo_nested_flag_false.msgpack", + req("EchoNested", { inner: { values: [new Uint8Array()], flag: false } }), +); + +// ============ Blob / fail / error cases ============ + +// Optional `maybeData` = Some plus fixed-size [bytes; 2] array. +check( + "echo_blobs_request.msgpack", + req("EchoBlobs", { + maybeData: new Uint8Array([0xaa, 0xbb]), + parts: [new Uint8Array([1, 2, 3]), new Uint8Array([4])], + }), +); + +// Optional `maybeData` = None (nil on the wire) plus an empty first array element. +check( + "echo_blobs_none.msgpack", + req("EchoBlobs", { + maybeData: null, + parts: [new Uint8Array(), new Uint8Array([9])], + }), +); + +check( + "echo_blobs_response.msgpack", + resp("EchoBlobsResponse", { + maybeData: new Uint8Array([0xaa, 0xbb]), + parts: [new Uint8Array([1, 2, 3]), new Uint8Array([4])], + }), +); + +check( + "echo_fail_request.msgpack", + req("EchoFail", { message: "deliberate failure" }), +); + +// Empty struct — payload is a fixmap with zero entries. +check("echo_fail_response.msgpack", resp("EchoFailResponse", {})); + +// Pins the error variant's wire format.
+check( + "echo_error_response.msgpack", + resp("EchoErrorResponse", { message: "deliberate failure" }), +); + +console.log(`\nResults: ${pass}/${pass + fail} passed, ${fail} failed`); +if (fail > 0) process.exit(1); diff --git a/ipc-codegen/echo_example/ts/tsconfig.json b/ipc-codegen/echo_example/ts/tsconfig.json new file mode 100644 index 000000000000..ec9149a77f08 --- /dev/null +++ b/ipc-codegen/echo_example/ts/tsconfig.json @@ -0,0 +1,12 @@ +{ + "compilerOptions": { + "target": "es2022", + "module": "nodenext", + "moduleResolution": "nodenext", + "strict": true, + "noEmit": true, + "skipLibCheck": true, + "types": ["node"] + }, + "include": ["src/**/*.ts"] +} diff --git a/ipc-codegen/echo_example/ts_package/.gitignore b/ipc-codegen/echo_example/ts_package/.gitignore new file mode 100644 index 000000000000..2f10a5bc825a --- /dev/null +++ b/ipc-codegen/echo_example/ts_package/.gitignore @@ -0,0 +1,11 @@ +node_modules/ +dest/ +build/ +packages/ +package-lock.json +package.json +tsconfig.json +src/generated/ +src/index.ts +src/platform.ts +scripts/prepare_arch_packages.sh diff --git a/ipc-codegen/echo_example/ts_package/README.md b/ipc-codegen/echo_example/ts_package/README.md new file mode 100644 index 000000000000..6e04a94a70cd --- /dev/null +++ b/ipc-codegen/echo_example/ts_package/README.md @@ -0,0 +1,29 @@ +# @aztec/echo-ipc + +Generated TypeScript IPC package for the Echo service. + +```ts +import { EchoService } from '@aztec/echo-ipc'; + +const service = await EchoService.spawn({ transport: 'uds' }); +try { + const response = await service.bytes({ data: new Uint8Array([1, 2, 3]) }); +} finally { + await service.destroy(); +} +``` + +The package resolves `echo_server` from `ECHO_SERVER_PATH`, +an explicit `binaryPath`, or an installed/prepared arch package. 
+ +## Build + +The package shell (package.json, tsconfig, src/index.ts, scripts/) is +generated; build through the owning project's `./bootstrap.sh`, which +regenerates and then runs `npm install --omit=optional && npm run build`. + +To prepare per-architecture binary packages: + +```sh +npm run prepare_arch_packages -- linux-x64=/path/to/echo_server +``` diff --git a/ipc-codegen/echo_example/ts_package/bootstrap.sh b/ipc-codegen/echo_example/ts_package/bootstrap.sh new file mode 100755 index 000000000000..7b4405b3d2ad --- /dev/null +++ b/ipc-codegen/echo_example/ts_package/bootstrap.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +CODEGEN="$(cd "$DIR/../.." && pwd)" +REPO_ROOT="$(cd "$CODEGEN/.." && pwd)" +NODE="node --experimental-strip-types --experimental-transform-types --no-warnings" + +$NODE "$CODEGEN/src/generate.ts" \ + --schema "$DIR/../schema/schema.jsonc" \ + --lang ts \ + --client \ + --out "$DIR/src/generated" \ + --package "$DIR" \ + --package-name "@aztec/echo-ipc" \ + --binary-name echo_server \ + --package-transports uds,shm \ + --ipc-runtime-dependency "file:../../../ipc-runtime/ts" + +(cd "$DIR/../cpp" && ./bootstrap.sh) +(cd "$REPO_ROOT/ipc-runtime" && ./bootstrap.sh) +(cd "$REPO_ROOT/ipc-runtime/ts" && yarn install --immutable && yarn build) + +platform_dir="$( + node -e "const arch = { x64: 'amd64', arm64: 'arm64' }[process.arch] ?? process.arch; const os = { linux: 'linux', darwin: 'macos' }[process.platform] ?? 
process.platform; console.log(arch + '-' + os);" +)" +mkdir -p "$DIR/build/$platform_dir" +cp "$DIR/../cpp/build/bin/echo_server" "$DIR/build/$platform_dir/echo_server" + +rm -rf "$DIR/node_modules" +(cd "$DIR" && npm install --omit=optional --no-package-lock --quiet) +(cd "$DIR" && npm run build --silent) +(cd "$DIR" && npm run prepare_arch_packages --silent) diff --git a/ipc-codegen/echo_example/ts_package/src/package_test.ts b/ipc-codegen/echo_example/ts_package/src/package_test.ts new file mode 100644 index 000000000000..750b277e1495 --- /dev/null +++ b/ipc-codegen/echo_example/ts_package/src/package_test.ts @@ -0,0 +1,86 @@ +import { EchoService, SyncApi, type EchoTransport } from "./index.js"; +import { createNapiShmSyncClient } from "@aztec/ipc-runtime"; + +const args = process.argv.slice(2); +const transportArg = (args.includes("--transport") ? args[args.indexOf("--transport") + 1] : undefined) ?? "uds"; +if (transportArg !== "uds" && transportArg !== "shm") { + throw new Error(`Unknown --transport '${transportArg}'`); +} +const transport = transportArg as EchoTransport; + +function testHash(base: number): Uint8Array { + return Uint8Array.from({ length: 32 }, (_v, i) => base + i); +} + +function assertEqual(actual: unknown, expected: unknown, label: string) { + const actualJson = JSON.stringify(actual); + const expectedJson = JSON.stringify(expected); + if (actualJson !== expectedJson) { + throw new Error(`${label}: expected ${expectedJson}, got ${actualJson}`); + } +} + +function assertBytes(actual: Uint8Array, expected: Uint8Array, label: string) { + assertEqual( + Buffer.from(actual).toString("hex"), + Buffer.from(expected).toString("hex"), + label, + ); +} + +const service = await EchoService.spawn({ transport }); +try { + const data = Uint8Array.from([0xde, 0xad, 0xbe, 0xef, 0x42]); + const bytes = await service.bytes({ data }); + assertBytes(bytes.data, data, "bytes.data"); + + const fields = await service.fields({ + a: 42, + b: 999999, + name: "hello generated package", + }); +
assertEqual(fields.a, 42, "fields.a"); + assertEqual(fields.b, 999999, "fields.b"); + assertEqual(fields.name, "hello generated package", "fields.name"); + + const inner = { + values: [Uint8Array.from([1, 2, 3]), Uint8Array.from([4, 5])], + flag: true, + }; + const nested = await service.nested({ inner }); + assertEqual(nested.inner.flag, true, "nested.inner.flag"); + assertEqual(nested.inner.values.length, 2, "nested.inner.values.length"); + assertBytes(nested.inner.values[0]!, inner.values[0]!, "nested.inner.values[0]"); + + const hash = testHash(0x10); + const second = testHash(0x40); + const aliases = await service.aliases({ + treeId: 7, + hash, + maybeHash: second, + hashes: [hash, second], + }); + assertEqual(aliases.treeId, 7, "aliases.treeId"); + assertBytes(aliases.hash, hash, "aliases.hash"); + assertBytes(aliases.maybeHash!, second, "aliases.maybeHash"); + assertEqual(aliases.hashes.length, 2, "aliases.hashes.length"); + // The generated SyncApi shares the wire format; exercise it over the + // same spawned service (SHM supports multiple client slots). 
+ if (transport === "shm") { + const shmName = service.getIpcPath().replace(/\.shm$/, ""); + const syncApi = new SyncApi( + createNapiShmSyncClient(shmName, { clientId: 1 }), + ); + try { + const syncResp = syncApi.bytes({ data: Uint8Array.from([7, 8]) }); + assertBytes(syncResp.data, Uint8Array.from([7, 8]), "sync bytes.data"); + console.error("echo ts package: SyncApi over shm OK"); + } finally { + syncApi.destroy(); + } + } +} finally { + await service.destroy(); +} + +console.error(`echo ts package: ${transport} OK`); diff --git a/ipc-codegen/echo_example/ts_package/test.sh b/ipc-codegen/echo_example/ts_package/test.sh new file mode 100755 index 000000000000..682358920371 --- /dev/null +++ b/ipc-codegen/echo_example/ts_package/test.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +transport="${1:-uds}" + +(cd "$DIR" && node --no-warnings dest/package_test.js --transport "$transport") diff --git a/ipc-codegen/echo_example/zig/.gitignore b/ipc-codegen/echo_example/zig/.gitignore new file mode 100644 index 000000000000..e95cfbe47ba5 --- /dev/null +++ b/ipc-codegen/echo_example/zig/.gitignore @@ -0,0 +1,3 @@ +.zig-cache/ +src/generated/ +zig-out/ diff --git a/ipc-codegen/echo_example/zig/README.md b/ipc-codegen/echo_example/zig/README.md new file mode 100644 index 000000000000..bf2c5a845500 --- /dev/null +++ b/ipc-codegen/echo_example/zig/README.md @@ -0,0 +1,18 @@ +# Zig Echo Example + +Build from this directory: + +```sh +./bootstrap.sh +``` + +The Zig project depends on the repo-local `ipc-runtime/zig` package and a +`zig_msgpack` copy vendored at `vendor/zig-msgpack`. Binaries are written to +`zig-out/bin/`. 
+ +Run locally: + +```sh +zig-out/bin/echo_server --socket /tmp/echo.sock +zig-out/bin/echo_client --socket /tmp/echo.sock +``` diff --git a/ipc-codegen/echo_example/zig/bootstrap.sh b/ipc-codegen/echo_example/zig/bootstrap.sh new file mode 100755 index 000000000000..a65642aac36a --- /dev/null +++ b/ipc-codegen/echo_example/zig/bootstrap.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +DIR="$(cd "$(dirname "$0")" && pwd)" +CODEGEN="$(cd "$DIR/../.." && pwd)" +NODE="node --experimental-strip-types --experimental-transform-types --no-warnings" + +$NODE "$CODEGEN/src/generate.ts" \ + --schema "$DIR/../schema/schema.jsonc" \ + --lang zig \ + --server \ + --client \ + --uds \ + --ffi \ + --out "$DIR/src/generated" + +(cd "$DIR" && zig build) diff --git a/ipc-codegen/echo_example/zig/build.zig b/ipc-codegen/echo_example/zig/build.zig new file mode 100644 index 000000000000..1adb9909c35f --- /dev/null +++ b/ipc-codegen/echo_example/zig/build.zig @@ -0,0 +1,68 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const msgpack_dep = b.dependency("zig_msgpack", .{ + .target = target, + .optimize = optimize, + }); + const msgpack_mod = msgpack_dep.module("msgpack"); + + const ipc_runtime_dep = b.dependency("ipc_runtime", .{ + .target = target, + .optimize = optimize, + }); + const ipc_runtime_mod = ipc_runtime_dep.module("ipc_runtime"); + + // Echo server + const server_exe = b.addExecutable(.{ + .name = "echo_server", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/echo_server.zig"), + .target = target, + .optimize = optimize, + }), + }); + server_exe.root_module.addImport("msgpack", msgpack_mod); + server_exe.root_module.addImport("ipc_runtime", ipc_runtime_mod); + b.installArtifact(server_exe); + + // Echo client + const client_exe = b.addExecutable(.{ + .name = "echo_client", + .root_module = b.createModule(.{ + 
.root_source_file = b.path("src/echo_client.zig"), + .target = target, + .optimize = optimize, + }), + }); + client_exe.root_module.addImport("msgpack", msgpack_mod); + client_exe.root_module.addImport("ipc_runtime", ipc_runtime_mod); + b.installArtifact(client_exe); + + // Golden wire-format conformance test (no transport, msgpack only) + const golden_exe = b.addExecutable(.{ + .name = "golden_test", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/golden_test.zig"), + .target = target, + .optimize = optimize, + }), + }); + golden_exe.root_module.addImport("msgpack", msgpack_mod); + b.installArtifact(golden_exe); + + // Compile coverage for the generated FFI backend (stub extern symbol). + const ffi_check_exe = b.addExecutable(.{ + .name = "ffi_check", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/ffi_check.zig"), + .target = target, + .optimize = optimize, + }), + }); + ffi_check_exe.root_module.addImport("msgpack", msgpack_mod); + b.installArtifact(ffi_check_exe); +} diff --git a/ipc-codegen/echo_example/zig/build.zig.zon b/ipc-codegen/echo_example/zig/build.zig.zon new file mode 100644 index 000000000000..6348d0cbe6cd --- /dev/null +++ b/ipc-codegen/echo_example/zig/build.zig.zon @@ -0,0 +1,21 @@ +.{ + .name = .echo_zig, + .version = "0.1.0", + .fingerprint = 0x15539c02bc3573e2, + .minimum_zig_version = "0.14.0", + .dependencies = .{ + .zig_msgpack = .{ + .path = "vendor/zig-msgpack", + }, + // ipc-runtime/zig: UDS + MPSC-SHM transport binding (compiles the + // shared C++ sources via Zig's clang). 
+ .ipc_runtime = .{ + .path = "../../../ipc-runtime/zig", + }, + }, + .paths = .{ + "build.zig", + "build.zig.zon", + "src", + }, +} diff --git a/ipc-codegen/echo_example/zig/src/echo_client.zig b/ipc-codegen/echo_example/zig/src/echo_client.zig new file mode 100644 index 000000000000..ba868a32aa81 --- /dev/null +++ b/ipc-codegen/echo_example/zig/src/echo_client.zig @@ -0,0 +1,187 @@ +/// Echo IPC client (Zig) — uses GENERATED typed client + the ipc-runtime +/// Zig binding for transport. No per-service UDS code in the example. +/// Usage: echo_client --socket /tmp/echo.sock +const std = @import("std"); +const ipc_runtime = @import("ipc_runtime"); +const echo_client = @import("generated/echo_client.zig"); +const types = @import("generated/echo_types.zig"); + +fn testHash(base: u8) types.Fr { + var hash: types.Fr = undefined; + for (&hash, 0..) |*byte, i| { + byte.* = base + @as(u8, @intCast(i)); + } + return hash; +} + +pub fn main() !void { + var args = std.process.args(); + _ = args.next(); + var socket_path: ?[:0]const u8 = null; + while (args.next()) |arg| { + if (std.mem.eql(u8, arg, "--socket")) { + socket_path = args.next(); + } + } + const path = socket_path orelse { + std.debug.print("Usage: echo_client --socket <path>\n", .{}); + std.process.exit(1); + }; + + // Use page_allocator: the codegen-emitted client frees response buffers + // with std.heap.page_allocator, so the runtime Client must allocate with + // the same one.
+ var backend = try ipc_runtime.Client.fromPath(std.heap.page_allocator, path); + defer backend.deinit(); + + const EchoClient = echo_client.Client(ipc_runtime.Client); + var client = EchoClient.init(&backend); + + // Test 1: EchoBytes + { + const cmd = types.EchoBytes{ .data = &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0x42 } }; + const resp = try client.bytes(cmd); + if (!std.mem.eql(u8, resp.data, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0x42 })) { + std.debug.print("echo_client(zig): EchoBytes FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoBytes OK\n", .{}); + } + + // Test 2: EchoFields + { + const cmd = types.EchoFields{ .a = 42, .b = 999999, .name = "hello wire compat" }; + const resp = try client.fields(cmd); + if (resp.a != 42 or resp.b != 999999 or !std.mem.eql(u8, resp.name, "hello wire compat")) { + std.debug.print("echo_client(zig): EchoFields FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoFields OK\n", .{}); + } + + // Test 3: EchoNested + { + const values = &[_][]const u8{ &[_]u8{ 1, 2, 3 }, &[_]u8{ 4, 5 } }; + const cmd = types.EchoNested{ + .inner = types.EchoInner{ .values = values, .flag = true }, + }; + const resp = try client.nested(cmd); + if (resp.inner.values.len != 2) { + std.debug.print("echo_client(zig): EchoNested FAIL\n", .{}); + std.process.exit(1); + } + if (resp.inner.flag != true) { + std.debug.print("echo_client(zig): EchoNested flag FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoNested OK\n", .{}); + } + + // Test 4: EchoAliases + { + const hash = testHash(0x10); + const second = testHash(0x40); + const hashes = &[_]types.Fr{ hash, second }; + const cmd = types.EchoAliases{ + .tree_id = 7, + .hash = hash, + .maybe_hash = second, + .hashes = hashes, + }; + const resp = try client.aliases(cmd); + if (resp.tree_id != 7 or !std.mem.eql(u8, &resp.hash, &hash)) { + std.debug.print("echo_client(zig): EchoAliases FAIL\n", .{}); + std.process.exit(1); + 
} + if (resp.maybe_hash == null or !std.mem.eql(u8, &resp.maybe_hash.?, &second) or resp.hashes.len != 2) { + std.debug.print("echo_client(zig): EchoAliases optional/vector FAIL\n", .{}); + std.process.exit(1); + } + if (!std.mem.eql(u8, &resp.hashes[0], &hash) or !std.mem.eql(u8, &resp.hashes[1], &second)) { + std.debug.print("echo_client(zig): EchoAliases hashes contents FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoAliases OK\n", .{}); + } + + // Test 5: EchoAliases with maybe_hash = null (optional-absent over live IPC) + { + const hash = testHash(0x10); + const cmd = types.EchoAliases{ + .tree_id = 7, + .hash = hash, + .maybe_hash = null, + .hashes = &[_]types.Fr{hash}, + }; + const resp = try client.aliases(cmd); + if (resp.maybe_hash != null) { + std.debug.print("echo_client(zig): EchoAliases none FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoAliases none OK\n", .{}); + } + + // Test 6: EchoFields with b > u32::MAX (uint64 wire encoding over live IPC) + { + const big: u64 = (1 << 53) - 1; + const cmd = types.EchoFields{ .a = 42, .b = big, .name = "big" }; + const resp = try client.fields(cmd); + if (resp.b != big) { + std.debug.print("echo_client(zig): EchoFields u64 FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoFields u64 OK\n", .{}); + } + + // Test 7: EchoBlobs — optional bytes Some/None and fixed [bytes; 2] + { + const cmd = types.EchoBlobs{ + .maybe_data = &[_]u8{ 0xAA, 0xBB }, + .parts = .{ &[_]u8{ 1, 2, 3 }, &[_]u8{4} }, + }; + const resp = try client.blobs(cmd); + if (resp.maybe_data == null or !std.mem.eql(u8, resp.maybe_data.?, &[_]u8{ 0xAA, 0xBB })) { + std.debug.print("echo_client(zig): EchoBlobs maybe_data FAIL\n", .{}); + std.process.exit(1); + } + if (!std.mem.eql(u8, resp.parts[0], &[_]u8{ 1, 2, 3 }) or !std.mem.eql(u8, resp.parts[1], &[_]u8{4})) { + std.debug.print("echo_client(zig): EchoBlobs parts FAIL\n", .{}); + 
std.process.exit(1); + } + const cmd_none = types.EchoBlobs{ + .maybe_data = null, + .parts = .{ &[_]u8{}, &[_]u8{9} }, + }; + const resp_none = try client.blobs(cmd_none); + if (resp_none.maybe_data != null) { + std.debug.print("echo_client(zig): EchoBlobs none FAIL\n", .{}); + std.process.exit(1); + } + std.debug.print("echo_client(zig): EchoBlobs OK\n", .{}); + } + + // Test 8: EchoFail — server error surfaces, message available on the client + { + const cmd = types.EchoFail{ .message = "deliberate failure" }; + if (client.fail(cmd)) |_| { + std.debug.print("echo_client(zig): EchoFail FAIL (no error)\n", .{}); + std.process.exit(1); + } else |err| { + if (err != error.ServerError) { + std.debug.print("echo_client(zig): EchoFail wrong error: {s}\n", .{@errorName(err)}); + std.process.exit(1); + } + const message = client.last_server_error orelse { + std.debug.print("echo_client(zig): EchoFail missing message\n", .{}); + std.process.exit(1); + }; + if (std.mem.indexOf(u8, message, "deliberate failure") == null) { + std.debug.print("echo_client(zig): EchoFail message mismatch: {s}\n", .{message}); + std.process.exit(1); + } + } + std.debug.print("echo_client(zig): EchoFail OK\n", .{}); + } + + std.debug.print("echo_client(zig): all tests passed\n", .{}); +} diff --git a/ipc-codegen/echo_example/zig/src/echo_server.zig b/ipc-codegen/echo_example/zig/src/echo_server.zig new file mode 100644 index 000000000000..edf31fa0fc48 --- /dev/null +++ b/ipc-codegen/echo_example/zig/src/echo_server.zig @@ -0,0 +1,76 @@ +//! Echo IPC server (Zig) — uses the ipc-runtime Zig binding for transport +//! and the GENERATED Dispatcher for framing, dispatch, and error wrapping. +//! 
Usage: echo_server --socket /tmp/echo.sock +const std = @import("std"); +const ipc_runtime = @import("ipc_runtime"); +const types = @import("generated/echo_types.zig"); +const echo_server = @import("generated/echo_server.zig"); + +const EchoHandler = struct { + /// Diagnostic channel for handler failures — the generated dispatcher + /// sends this as the error variant's message when set. + error_message: ?[]const u8 = null, + + pub fn bytes(self: *EchoHandler, cmd: types.EchoBytes) !types.EchoBytesResponse { + _ = self; + return .{ .data = cmd.data }; + } + + pub fn fields(self: *EchoHandler, cmd: types.EchoFields) !types.EchoFieldsResponse { + _ = self; + return .{ .a = cmd.a, .b = cmd.b, .name = cmd.name }; + } + + pub fn nested(self: *EchoHandler, cmd: types.EchoNested) !types.EchoNestedResponse { + _ = self; + return .{ .inner = cmd.inner }; + } + + pub fn aliases(self: *EchoHandler, cmd: types.EchoAliases) !types.EchoAliasesResponse { + _ = self; + return .{ + .tree_id = cmd.tree_id, + .hash = cmd.hash, + .maybe_hash = cmd.maybe_hash, + .hashes = cmd.hashes, + }; + } + + pub fn blobs(self: *EchoHandler, cmd: types.EchoBlobs) !types.EchoBlobsResponse { + _ = self; + return .{ .maybe_data = cmd.maybe_data, .parts = cmd.parts }; + } + + pub fn fail(self: *EchoHandler, cmd: types.EchoFail) !types.EchoFailResponse { + self.error_message = cmd.message; + return error.EchoFailRequested; + } +}; + +const EchoDispatcher = echo_server.Dispatcher(EchoHandler); + +pub fn main() !void { + var args = std.process.args(); + _ = args.next(); + var socket_path: ?[:0]const u8 = null; + while (args.next()) |arg| { + if (std.mem.eql(u8, arg, "--socket")) { + socket_path = args.next(); + } + } + const path = socket_path orelse { + std.debug.print("Usage: echo_server --socket <path>\n", .{}); + return error.InvalidArgument; + }; + + var handler = EchoHandler{}; + var dispatcher = EchoDispatcher.init(&handler); + + var server = try ipc_runtime.Server.fromPath(path); + defer
server.deinit(); + server.installDefaultSignalHandlers(); + try server.listen(); + std.debug.print("ipc-server(zig): listening on {s}\n", .{path}); + + server.run(*EchoDispatcher, &dispatcher, EchoDispatcher.handleRequest); +} diff --git a/ipc-codegen/echo_example/zig/src/ffi_check.zig b/ipc-codegen/echo_example/zig/src/ffi_check.zig new file mode 100644 index 000000000000..9d89b86e07c1 --- /dev/null +++ b/ipc-codegen/echo_example/zig/src/ffi_check.zig @@ -0,0 +1,19 @@ +//! Compile coverage for the generated FFI backend. The real FFI symbol is +//! provided by whatever native library a consumer links; a stub satisfies +//! the linker here so the backend's code is fully analyzed and built. +const std = @import("std"); +const ffi = @import("generated/ffi_backend.zig"); + +export fn ipc_ffi_entry(input: [*]const u8, input_len: usize, output: *[*]u8, output_len: *usize) void { + _ = input; + _ = input_len; + output.* = undefined; + output_len.* = 0; +} + +pub fn main() void { + comptime { + std.testing.refAllDeclsRecursive(ffi); + } + std.debug.print("ffi_check: generated FFI backend compiles\n", .{}); +} diff --git a/ipc-codegen/echo_example/zig/src/golden_test.zig b/ipc-codegen/echo_example/zig/src/golden_test.zig new file mode 100644 index 000000000000..49317643c53b --- /dev/null +++ b/ipc-codegen/echo_example/zig/src/golden_test.zig @@ -0,0 +1,325 @@ +//! Golden file wire-format conformance test (Zig). +//! For each golden file, asserts: +//! 1. We can decode the bytes into the expected typed value. +//! 2. Re-encoding the same value produces byte-identical output. +//! The combination pins down the wire format as a binding contract. +//! +//! 
Usage: golden_test --golden-dir + +const std = @import("std"); +const msgpack = @import("msgpack"); +const Payload = msgpack.Payload; +const types = @import("generated/echo_types.zig"); + +const alloc = std.heap.page_allocator; + +var pass: u32 = 0; +var fail: u32 = 0; + +fn testHash(base: u8) types.Fr { + var hash: types.Fr = undefined; + for (&hash, 0..) |*byte, i| { + byte.* = base + @as(u8, @intCast(i)); + } + return hash; +} + +// --- framing helpers (mirror generated echo_client.zig / echo_server.zig) --- +// +// Re-encoding goes through toPayload() for every field value, but the struct +// maps themselves are emitted with an explicit schema field order: zig-msgpack +// Payload maps are std.HashMap, so toPayload + PackerIO.write alone would emit +// fields in hash-bucket order and the byte-roundtrip against the goldens +// (which use schema declaration order) would spuriously fail. The wire +// contract is name-keyed, so this only fixes ordering — every field's byte +// encoding still comes from the library encoder. + +const FieldSpec = struct { + name: []const u8, + nested: ?[]const FieldSpec = null, +}; + +const inner_fields = [_]FieldSpec{ .{ .name = "values" }, .{ .name = "flag" } }; + +fn writeOrderedMap( + writer: *std.Io.Writer, + packer: *msgpack.PackerIO, + map: Payload, + comptime fields: []const FieldSpec, +) !void { + comptime std.debug.assert(fields.len < 16); + try writer.writeByte(0x80 | @as(u8, fields.len)); + inline for (fields) |spec| { + try packer.write(try Payload.strToPayload(spec.name, alloc)); + const value = (try map.mapGet(spec.name)) orelse return error.MissingField; + if (spec.nested) |nested| { + try writeOrderedMap(writer, packer, value, nested); + } else { + try packer.write(value); + } + } +} + +/// Requests are framed as [[name, payload-map]]. 
+fn encodeRequest(name: []const u8, fields_payload: Payload, comptime fields: []const FieldSpec) ![]u8 { + var allocating_writer = std.Io.Writer.Allocating.init(alloc); + var packer = msgpack.PackerIO.init(undefined, &allocating_writer.writer); + try allocating_writer.writer.writeByte(0x91); // fixarray(1) + try allocating_writer.writer.writeByte(0x92); // fixarray(2) + try packer.write(try Payload.strToPayload(name, alloc)); + try writeOrderedMap(&allocating_writer.writer, &packer, fields_payload, fields); + return try allocating_writer.toOwnedSlice(); +} + +/// Responses are framed as [name, payload-map]. +fn encodeResponse(name: []const u8, fields_payload: Payload, comptime fields: []const FieldSpec) ![]u8 { + var allocating_writer = std.Io.Writer.Allocating.init(alloc); + var packer = msgpack.PackerIO.init(undefined, &allocating_writer.writer); + try allocating_writer.writer.writeByte(0x92); // fixarray(2) + try packer.write(try Payload.strToPayload(name, alloc)); + try writeOrderedMap(&allocating_writer.writer, &packer, fields_payload, fields); + return try allocating_writer.toOwnedSlice(); +} + +fn decodePayload(bytes: []const u8) !Payload { + var reader = std.Io.Reader.fixed(bytes); + var unpacker = msgpack.PackerIO.init(&reader, undefined); + return try unpacker.read(alloc); +} + +// --- per-type schema field orders (must match schema.json declaration order) --- + +const bytes_fields = [_]FieldSpec{.{ .name = "data" }}; +const fields_fields = [_]FieldSpec{ .{ .name = "a" }, .{ .name = "b" }, .{ .name = "name" } }; +const nested_fields = [_]FieldSpec{.{ .name = "inner", .nested = &inner_fields }}; +const aliases_fields = [_]FieldSpec{ .{ .name = "treeId" }, .{ .name = "hash" }, .{ .name = "maybeHash" }, .{ .name = "hashes" } }; +const blobs_fields = [_]FieldSpec{ .{ .name = "maybeData" }, .{ .name = "parts" } }; +const message_fields = [_]FieldSpec{.{ .name = "message" }}; +const empty_fields = [_]FieldSpec{}; + +// --- generic per-file check --- + +fn 
check( + comptime T: type, + comptime is_request: bool, + dir: []const u8, + file: []const u8, + name: []const u8, + comptime fields: []const FieldSpec, + comptime verify: fn (T) bool, +) void { + runCheck(T, is_request, dir, file, name, fields, verify) catch |err| { + std.debug.print(" FAIL: {s}: {s}\n", .{ file, @errorName(err) }); + fail += 1; + return; + }; + std.debug.print(" PASS: {s}\n", .{file}); + pass += 1; +} + +fn runCheck( + comptime T: type, + comptime is_request: bool, + dir: []const u8, + file: []const u8, + name: []const u8, + comptime fields: []const FieldSpec, + comptime verify: fn (T) bool, +) !void { + const path = try std.fs.path.join(alloc, &.{ dir, file }); + defer alloc.free(path); + const golden = try std.fs.cwd().readFileAlloc(alloc, path, 1 << 20); + + const decoded = try decodePayload(golden); + const named = if (is_request) blk: { + if (try decoded.getArrLen() != 1) return error.BadOuterArray; + break :blk try decoded.getArrElement(0); + } else decoded; + if (try named.getArrLen() != 2) return error.BadNamedArray; + const got_name = try (try named.getArrElement(0)).asStr(); + if (!std.mem.eql(u8, got_name, name)) return error.WrongUnionName; + + const value = try T.fromPayload(try named.getArrElement(1)); + if (!verify(value)) return error.DecodedValueMismatch; + + const reencoded = if (is_request) + try encodeRequest(name, try value.toPayload(alloc), fields) + else + try encodeResponse(name, try value.toPayload(alloc), fields); + if (!std.mem.eql(u8, reencoded, golden)) return error.RoundtripByteMismatch; +} + +// --- expected-value predicates --- + +fn verifyBytes(v: types.EchoBytes) bool { + return std.mem.eql(u8, v.data, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0x42 }); +} + +fn verifyFields(v: types.EchoFields) bool { + return v.a == 42 and v.b == 999999 and std.mem.eql(u8, v.name, "hello wire compat"); +} + +fn verifyInnerHappy(inner: types.EchoInner) bool { + return inner.values.len == 2 and + std.mem.eql(u8, inner.values[0], &[_]u8{ 1, 
2, 3 }) and + std.mem.eql(u8, inner.values[1], &[_]u8{ 4, 5 }) and + inner.flag == true; +} + +fn verifyNested(v: types.EchoNested) bool { + return verifyInnerHappy(v.inner); +} + +fn verifyAliases(v: types.EchoAliases) bool { + const hash = testHash(0x10); + const second = testHash(0x40); + return v.tree_id == 7 and + std.mem.eql(u8, &v.hash, &hash) and + v.maybe_hash != null and std.mem.eql(u8, &v.maybe_hash.?, &second) and + v.hashes.len == 2 and + std.mem.eql(u8, &v.hashes[0], &hash) and + std.mem.eql(u8, &v.hashes[1], &second); +} + +fn verifyBytesResponse(v: types.EchoBytesResponse) bool { + return std.mem.eql(u8, v.data, &[_]u8{ 0xDE, 0xAD, 0xBE, 0xEF, 0x42 }); +} + +fn verifyFieldsResponse(v: types.EchoFieldsResponse) bool { + return v.a == 42 and v.b == 999999 and std.mem.eql(u8, v.name, "hello wire compat"); +} + +fn verifyNestedResponse(v: types.EchoNestedResponse) bool { + return verifyInnerHappy(v.inner); +} + +fn verifyAliasesResponse(v: types.EchoAliasesResponse) bool { + const hash = testHash(0x10); + const second = testHash(0x40); + return v.tree_id == 7 and + std.mem.eql(u8, &v.hash, &hash) and + v.maybe_hash != null and std.mem.eql(u8, &v.maybe_hash.?, &second) and + v.hashes.len == 2 and + std.mem.eql(u8, &v.hashes[0], &hash) and + std.mem.eql(u8, &v.hashes[1], &second); +} + +fn verifyBytesEmpty(v: types.EchoBytes) bool { + return v.data.len == 0; +} + +fn verifyBytesBin16(v: types.EchoBytes) bool { + if (v.data.len != 256) return false; + for (v.data) |b| { + if (b != 0xAA) return false; + } + return true; +} + +fn verifyFieldsMax(v: types.EchoFields) bool { + return v.a == std.math.maxInt(u32) and v.b == std.math.maxInt(u64) and v.name.len == 0; +} + +fn verifyFieldsUintBoundary(v: types.EchoFields) bool { + return v.a == 128 and v.b == @as(u64, std.math.maxInt(u32)) + 1 and std.mem.eql(u8, v.name, "x"); +} + +fn verifyFieldsUnicode(v: types.EchoFields) bool { + return std.mem.eql(u8, v.name, "héllo τέστ 🚀 mañana"); +} + +fn 
verifyFieldsStr16(v: types.EchoFields) bool { // name must be exactly 300 bytes, all 'a'
+    if (v.name.len != 300) return false;
+    for (v.name) |c| {
+        if (c != 'a') return false;
+    }
+    return true;
+}
+
+fn verifyNestedFlagNone(v: types.EchoNested) bool { // empty values list and a null flag
+    return v.inner.values.len == 0 and v.inner.flag == null;
+}
+
+fn verifyNestedFlagFalse(v: types.EchoNested) bool { // one empty value and flag == false
+    return v.inner.values.len == 1 and v.inner.values[0].len == 0 and v.inner.flag == false;
+}
+
+fn verifyBlobs(v: types.EchoBlobs) bool { // length is checked first so a short `parts` yields false instead of an index panic
+    return v.parts.len == 2 and v.maybe_data != null and std.mem.eql(u8, v.maybe_data.?, &[_]u8{ 0xAA, 0xBB }) and
+        std.mem.eql(u8, v.parts[0], &[_]u8{ 1, 2, 3 }) and
+        std.mem.eql(u8, v.parts[1], &[_]u8{4});
+}
+
+fn verifyBlobsNone(v: types.EchoBlobs) bool { // null maybe_data; parts are [empty, {9}]
+    return v.parts.len == 2 and v.maybe_data == null and // `and` short-circuits, guarding the indexing below
+        v.parts[0].len == 0 and
+        std.mem.eql(u8, v.parts[1], &[_]u8{9});
+}
+
+fn verifyBlobsResponse(v: types.EchoBlobsResponse) bool { // same expectations as verifyBlobs, on the response type
+    return v.parts.len == 2 and v.maybe_data != null and std.mem.eql(u8, v.maybe_data.?, &[_]u8{ 0xAA, 0xBB }) and
+        std.mem.eql(u8, v.parts[0], &[_]u8{ 1, 2, 3 }) and
+        std.mem.eql(u8, v.parts[1], &[_]u8{4});
+}
+
+fn verifyFail(v: types.EchoFail) bool { // request carries the fixed failure message
+    return std.mem.eql(u8, v.message, "deliberate failure");
+}
+
+fn verifyFailResponse(_: types.EchoFailResponse) bool { // nothing to inspect; checked against an empty field list
+    return true;
+}
+
+fn verifyErrorResponse(v: types.EchoErrorResponse) bool { // error response echoes the failure message
+    return std.mem.eql(u8, v.message, "deliberate failure");
+}
+
+pub fn main() !void { // parses --golden-dir, runs each golden check, exits 1 if any failed
+    var args = std.process.args();
+    _ = args.next(); // skip argv[0]
+    var golden_dir: ?[]const u8 = null;
+    while (args.next()) |arg| {
+        if (std.mem.eql(u8, arg, "--golden-dir")) {
+            golden_dir = args.next(); // stays null when the flag is the last argument
+        }
+    }
+    const dir = golden_dir orelse {
+        std.debug.print("Usage: golden_test --golden-dir DIR\n", .{});
+        std.process.exit(1);
+    };
+
+    // ============ Original happy-path cases ============
+
+    check(types.EchoBytes, true, dir, "echo_bytes_request.msgpack", "EchoBytes", &bytes_fields, verifyBytes);
+    check(types.EchoFields, true, dir, "echo_fields_request.msgpack", "EchoFields", &fields_fields,
verifyFields); + check(types.EchoNested, true, dir, "echo_nested_request.msgpack", "EchoNested", &nested_fields, verifyNested); + check(types.EchoAliases, true, dir, "echo_aliases_request.msgpack", "EchoAliases", &aliases_fields, verifyAliases); + + check(types.EchoBytesResponse, false, dir, "echo_bytes_response.msgpack", "EchoBytesResponse", &bytes_fields, verifyBytesResponse); + check(types.EchoFieldsResponse, false, dir, "echo_fields_response.msgpack", "EchoFieldsResponse", &fields_fields, verifyFieldsResponse); + check(types.EchoNestedResponse, false, dir, "echo_nested_response.msgpack", "EchoNestedResponse", &nested_fields, verifyNestedResponse); + check(types.EchoAliasesResponse, false, dir, "echo_aliases_response.msgpack", "EchoAliasesResponse", &aliases_fields, verifyAliasesResponse); + + // ============ Boundary cases ============ + + check(types.EchoBytes, true, dir, "echo_bytes_empty.msgpack", "EchoBytes", &bytes_fields, verifyBytesEmpty); + check(types.EchoBytes, true, dir, "echo_bytes_bin16.msgpack", "EchoBytes", &bytes_fields, verifyBytesBin16); + check(types.EchoFields, true, dir, "echo_fields_max.msgpack", "EchoFields", &fields_fields, verifyFieldsMax); + check(types.EchoFields, true, dir, "echo_fields_uint_boundary.msgpack", "EchoFields", &fields_fields, verifyFieldsUintBoundary); + check(types.EchoFields, true, dir, "echo_fields_unicode.msgpack", "EchoFields", &fields_fields, verifyFieldsUnicode); + check(types.EchoFields, true, dir, "echo_fields_str16.msgpack", "EchoFields", &fields_fields, verifyFieldsStr16); + check(types.EchoNested, true, dir, "echo_nested_flag_none.msgpack", "EchoNested", &nested_fields, verifyNestedFlagNone); + check(types.EchoNested, true, dir, "echo_nested_flag_false.msgpack", "EchoNested", &nested_fields, verifyNestedFlagFalse); + + // ============ Blob / fail / error cases ============ + + check(types.EchoBlobs, true, dir, "echo_blobs_request.msgpack", "EchoBlobs", &blobs_fields, verifyBlobs); + check(types.EchoBlobs, 
true, dir, "echo_blobs_none.msgpack", "EchoBlobs", &blobs_fields, verifyBlobsNone); + check(types.EchoBlobsResponse, false, dir, "echo_blobs_response.msgpack", "EchoBlobsResponse", &blobs_fields, verifyBlobsResponse); + check(types.EchoFail, true, dir, "echo_fail_request.msgpack", "EchoFail", &message_fields, verifyFail); + check(types.EchoFailResponse, false, dir, "echo_fail_response.msgpack", "EchoFailResponse", &empty_fields, verifyFailResponse); + check(types.EchoErrorResponse, false, dir, "echo_error_response.msgpack", "EchoErrorResponse", &message_fields, verifyErrorResponse); + + std.debug.print("\nResults: {d}/{d} passed, {d} failed\n", .{ pass, pass + fail, fail }); + if (fail > 0) std.process.exit(1); +} diff --git a/ipc-codegen/echo_example/zig/vendor/zig-msgpack/build.zig b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/build.zig new file mode 100644 index 000000000000..bef5fb0aaee1 --- /dev/null +++ b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/build.zig @@ -0,0 +1,10 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + _ = b.standardTargetOptions(.{}); + _ = b.standardOptimizeOption(.{}); + + _ = b.addModule("msgpack", .{ + .root_source_file = b.path("src/msgpack.zig"), + }); +} diff --git a/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/compat.zig b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/compat.zig new file mode 100644 index 000000000000..f96b0a967a37 --- /dev/null +++ b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/compat.zig @@ -0,0 +1,66 @@ +// Compatibility layer for different Zig versions +const std = @import("std"); +const builtin = @import("builtin"); +const current_zig = builtin.zig_version; + +// BufferStream implementation for Zig 0.16+ +// This mimics the behavior of the old FixedBufferStream +pub const BufferStream = if (current_zig.minor >= 16) struct { + buffer: []u8, + pos: usize, + + const Self = @This(); + + pub const WriteError = error{NoSpaceLeft}; + pub const ReadError = 
error{EndOfStream}; + + pub fn init(buffer: []u8) Self { + return .{ + .buffer = buffer, + .pos = 0, + }; + } + + pub fn write(self: *Self, bytes: []const u8) WriteError!usize { + const available = self.buffer.len - self.pos; + if (bytes.len > available) return error.NoSpaceLeft; + @memcpy(self.buffer[self.pos..][0..bytes.len], bytes); + self.pos += bytes.len; + return bytes.len; + } + + pub fn read(self: *Self, dest: []u8) ReadError!usize { + // Read from current position in buffer + const available = self.buffer.len - self.pos; + if (available == 0) return 0; + + const to_read = @min(dest.len, available); + @memcpy(dest[0..to_read], self.buffer[self.pos..][0..to_read]); + self.pos += to_read; + return to_read; + } + + pub fn reset(self: *Self) void { + self.pos = 0; + } + + pub fn seekTo(self: *Self, pos: usize) !void { + if (pos > self.buffer.len) { + return error.OutOfBounds; + } + self.pos = pos; + } + + pub fn getPos(self: Self) usize { + return self.pos; + } + + pub fn getEndPos(self: Self) usize { + return self.buffer.len; + } +} else std.io.FixedBufferStream([]u8); + +pub const fixedBufferStream = if (current_zig.minor >= 16) + BufferStream.init +else + std.io.fixedBufferStream; diff --git a/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/msgpack.zig b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/msgpack.zig new file mode 100644 index 000000000000..51da2d8be316 --- /dev/null +++ b/ipc-codegen/echo_example/zig/vendor/zig-msgpack/src/msgpack.zig @@ -0,0 +1,3273 @@ +//! MessagePack implementation with zig +//! 
https://msgpack.org/ + +const std = @import("std"); +const builtin = @import("builtin"); + +const current_zig = builtin.zig_version; +const Allocator = std.mem.Allocator; +const comptimePrint = std.fmt.comptimePrint; +const native_endian = builtin.cpu.arch.endian(); + +const big_endian = std.builtin.Endian.big; +const little_endian = std.builtin.Endian.little; + +/// Cache line size for prefetch optimization +const CACHE_LINE_SIZE: usize = 64; + +/// Prefetch hint for read-ahead optimization +/// Uses compiler intrinsics to hint CPU to prefetch data +/// This is a performance hint and may be a no-op on some architectures +inline fn prefetchRead(ptr: [*]const u8, comptime locality: u2) void { + // locality: 0=no temporal locality (NTA), 1=low (T2), 2=medium (T1), 3=high (T0) + const arch = comptime builtin.cpu.arch; + + // x86/x64: Check for SSE support (required for PREFETCH instructions) + if (comptime arch.isX86()) { + const has_sse = comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse); + if (has_sse) { + // Use different prefetch instructions based on locality + switch (locality) { + 3 => asm volatile ("prefetcht0 %[ptr]" + : + : [ptr] "m" (@as(*const u8, ptr)), + ), // High locality -> L1+L2+L3 + 2 => asm volatile ("prefetcht1 %[ptr]" + : + : [ptr] "m" (@as(*const u8, ptr)), + ), // Medium -> L2+L3 + 1 => asm volatile ("prefetcht2 %[ptr]" + : + : [ptr] "m" (@as(*const u8, ptr)), + ), // Low -> L3 only + 0 => asm volatile ("prefetchnta %[ptr]" + : + : [ptr] "m" (@as(*const u8, ptr)), + ), // Non-temporal + } + } + } + // ARM64 (Apple Silicon, Linux ARM): Use PRFM instruction + else if (comptime arch.isAARCH64()) { + // ARM PRFM (Prefetch Memory) instruction + // Syntax: prfm , [{, #}] + // prfop encoding: PLD (prefetch for load) + locality hint + switch (locality) { + 3 => asm volatile ("prfm pldl1keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), // Keep in L1 + 2 => asm volatile ("prfm pldl2keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), // Keep in L2 + 
1 => asm volatile ("prfm pldl3keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), // Keep in L3 + 0 => asm volatile ("prfm pldl1strm, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), // Streaming (non-temporal) + } + } + // Other architectures: no-op (compiler optimizes away) + // RISC-V, MIPS, etc. may have their own prefetch extensions but not standard +} + +/// Prefetch data for write operations +inline fn prefetchWrite(ptr: [*]u8, comptime locality: u2) void { + const arch = comptime builtin.cpu.arch; + + // x86/x64: Use PREFETCHW if available (3DNow!/SSE), fallback to read prefetch + if (comptime arch.isX86()) { + // PREFETCHW is part of 3DNow! (AMD) or PRFCHW feature (Intel Broadwell+) + const has_prefetchw = comptime std.Target.x86.featureSetHas(builtin.cpu.features, .prfchw) or + std.Target.x86.featureSetHas(builtin.cpu.features, .@"3dnow"); + const has_sse = comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse); + + if (has_prefetchw) { + // Use write-specific prefetch (ignores locality for simplicity) + asm volatile ("prefetchw %[ptr]" + : + : [ptr] "m" (@as(*u8, ptr)), + ); + } else if (has_sse) { + // Fallback to read prefetch with specified locality + switch (locality) { + 3 => asm volatile ("prefetcht0 %[ptr]" + : + : [ptr] "m" (@as(*u8, ptr)), + ), + 2 => asm volatile ("prefetcht1 %[ptr]" + : + : [ptr] "m" (@as(*u8, ptr)), + ), + 1 => asm volatile ("prefetcht2 %[ptr]" + : + : [ptr] "m" (@as(*u8, ptr)), + ), + 0 => asm volatile ("prefetchnta %[ptr]" + : + : [ptr] "m" (@as(*u8, ptr)), + ), + } + } + } + // ARM64: Use PST (prefetch for store) + else if (comptime arch.isAARCH64()) { + switch (locality) { + 3 => asm volatile ("prfm pstl1keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), + 2 => asm volatile ("prfm pstl2keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), + 1 => asm volatile ("prfm pstl3keep, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), + 0 => asm volatile ("prfm pstl1strm, [%[ptr]]" + : + : [ptr] "r" (ptr), + ), + } + } +} + +/// Prefetch multiple cache lines 
for large data operations +/// Used for arrays/maps/strings >= 256 bytes +inline fn prefetchLarge(ptr: [*]const u8, size: usize) void { + // Prefetch first few cache lines + const lines_to_prefetch = @min(size / CACHE_LINE_SIZE, 4); // Max 4 lines + var i: usize = 0; + while (i < lines_to_prefetch) : (i += 1) { + prefetchRead(ptr + i * CACHE_LINE_SIZE, 2); // Medium locality + } +} + +/// MessagePack format limits for fix types +pub const FixLimits = struct { + pub const POSITIVE_INT_MAX: u8 = 0x7f; + pub const NEGATIVE_INT_MIN: i8 = -32; + pub const STR_LEN_MAX: u8 = 31; + pub const ARRAY_LEN_MAX: u8 = 15; + pub const MAP_LEN_MAX: u8 = 15; +}; + +/// Integer type boundaries +pub const IntBounds = struct { + pub const UINT8_MAX: u64 = 0xff; + pub const UINT16_MAX: u64 = 0xffff; + pub const UINT32_MAX: u64 = 0xffff_ffff; + pub const INT8_MIN: i64 = -128; + pub const INT16_MIN: i64 = -32768; + pub const INT32_MIN: i64 = -2147483648; +}; + +/// Fixed extension type data lengths +pub const FixExtLen = struct { + pub const EXT1: usize = 1; + pub const EXT2: usize = 2; + pub const EXT4: usize = 4; + pub const EXT8: usize = 8; + pub const EXT16: usize = 16; +}; + +/// Timestamp extension type constants +pub const TimestampExt = struct { + pub const TYPE_ID: i8 = -1; + pub const FORMAT32_LEN: usize = 4; + pub const FORMAT64_LEN: usize = 8; + pub const FORMAT96_LEN: usize = 12; + pub const SECONDS_BITS_64: u6 = 34; + pub const SECONDS_MASK_64: u64 = 0x3ffffffff; + pub const NANOSECONDS_MAX: u32 = 999_999_999; + pub const NANOSECONDS_PER_SECOND: f64 = 1_000_000_000.0; +}; + +/// Marker byte base values and masks +pub const MarkerBase = struct { + pub const FIXARRAY: u8 = 0x90; + pub const FIXMAP: u8 = 0x80; + pub const FIXSTR: u8 = 0xa0; + pub const FIXSTR_LEN_MASK: u8 = 0x1f; + pub const FIXSTR_TYPE_MASK: u8 = 0xe0; +}; + +// Backward compatibility aliases (will be deprecated) +const MAX_POSITIVE_FIXINT: u8 = FixLimits.POSITIVE_INT_MAX; +const MIN_NEGATIVE_FIXINT: i8 = 
FixLimits.NEGATIVE_INT_MIN; +const MAX_FIXSTR_LEN: u8 = FixLimits.STR_LEN_MAX; +const MAX_FIXARRAY_LEN: u8 = FixLimits.ARRAY_LEN_MAX; +const MAX_FIXMAP_LEN: u8 = FixLimits.MAP_LEN_MAX; +const TIMESTAMP_EXT_TYPE: i8 = TimestampExt.TYPE_ID; +const MAX_UINT8: u64 = IntBounds.UINT8_MAX; +const MAX_UINT16: u64 = IntBounds.UINT16_MAX; +const MAX_UINT32: u64 = IntBounds.UINT32_MAX; +const MIN_INT8: i64 = IntBounds.INT8_MIN; +const MIN_INT16: i64 = IntBounds.INT16_MIN; +const MIN_INT32: i64 = IntBounds.INT32_MIN; +const FIXEXT1_LEN: usize = FixExtLen.EXT1; +const FIXEXT2_LEN: usize = FixExtLen.EXT2; +const FIXEXT4_LEN: usize = FixExtLen.EXT4; +const FIXEXT8_LEN: usize = FixExtLen.EXT8; +const FIXEXT16_LEN: usize = FixExtLen.EXT16; +const TIMESTAMP32_DATA_LEN: usize = TimestampExt.FORMAT32_LEN; +const TIMESTAMP64_DATA_LEN: usize = TimestampExt.FORMAT64_LEN; +const TIMESTAMP96_DATA_LEN: usize = TimestampExt.FORMAT96_LEN; +const TIMESTAMP64_SECONDS_BITS: u6 = TimestampExt.SECONDS_BITS_64; +const TIMESTAMP64_SECONDS_MASK: u64 = TimestampExt.SECONDS_MASK_64; +const MAX_NANOSECONDS: u32 = TimestampExt.NANOSECONDS_MAX; +const NANOSECONDS_PER_SECOND: f64 = TimestampExt.NANOSECONDS_PER_SECOND; +const FIXARRAY_BASE: u8 = MarkerBase.FIXARRAY; +const FIXMAP_BASE: u8 = MarkerBase.FIXMAP; +const FIXSTR_BASE: u8 = MarkerBase.FIXSTR; +const FIXSTR_MASK: u8 = MarkerBase.FIXSTR_LEN_MASK; +const FIXSTR_TYPE_MASK: u8 = MarkerBase.FIXSTR_TYPE_MASK; + +/// Parse safety limits configuration +pub const ParseLimits = struct { + /// Maximum nesting depth (default 1000 layers) + max_depth: usize = 1000, + + /// Maximum array length (default 1 million elements) + max_array_length: usize = 1_000_000, + + /// Maximum map size (default 1 million key-value pairs) + max_map_size: usize = 1_000_000, + + /// Maximum string data length (default 100MB) + max_string_length: usize = 100 * 1024 * 1024, + + /// Maximum binary data length (default 100MB) + max_bin_length: usize = 100 * 1024 * 1024, + + /// Maximum 
extension data length (default 100MB) + max_ext_length: usize = 100 * 1024 * 1024, +}; + +/// Default parse limits +pub const DEFAULT_LIMITS = ParseLimits{}; + +/// the Str Type +pub const Str = struct { + str: []const u8, + + /// Initialize a new Str instance + pub inline fn init(str: []const u8) Str { + return Str{ .str = str }; + } + + /// get Str values + pub fn value(self: Str) []const u8 { + return self.str; + } +}; + +/// this is for encode str in struct +pub inline fn wrapStr(str: []const u8) Str { + return Str.init(str); +} + +/// the Bin Type +pub const Bin = struct { + bin: []u8, + + /// Initialize a new Bin instance + pub inline fn init(bin: []u8) Bin { + return Bin{ .bin = bin }; + } + + /// get bin values + pub fn value(self: Bin) []u8 { + return self.bin; + } +}; + +/// this is wrapping for bin +pub inline fn wrapBin(bin: []u8) Bin { + return Bin.init(bin); +} + +/// the EXT Type +pub const EXT = struct { + type: i8, + data: []u8, + + /// Initialize a new EXT instance + pub inline fn init(t: i8, data: []u8) EXT { + return EXT{ + .type = t, + .data = data, + }; + } +}; + +/// t is type, data is data +pub inline fn wrapEXT(t: i8, data: []u8) EXT { + return EXT.init(t, data); +} + +/// the Timestamp Type +/// Represents an instantaneous point on the time-line in the world +/// that is independent from time zones or calendars. +/// Maximum precision is nanoseconds. 
+pub const Timestamp = struct { + /// seconds since 1970-01-01 00:00:00 UTC + seconds: i64, + /// nanoseconds (0-999999999) + nanoseconds: u32, + + /// Create a new timestamp + pub inline fn new(seconds: i64, nanoseconds: u32) Timestamp { + return Timestamp{ + .seconds = seconds, + .nanoseconds = nanoseconds, + }; + } + + /// Create timestamp from seconds only (nanoseconds = 0) + pub inline fn fromSeconds(seconds: i64) Timestamp { + return Timestamp{ + .seconds = seconds, + .nanoseconds = 0, + }; + } + + /// Create timestamp from nanoseconds since Unix epoch + /// This is useful for converting from various time sources + /// Example: Timestamp.fromNanos(some_nanosecond_value) + pub fn fromNanos(nanos: i128) Timestamp { + const ns_i64: i64 = @intCast(@divFloor(nanos, std.time.ns_per_s)); + const nano_remainder: i64 = @intCast(@mod(nanos, std.time.ns_per_s)); + const nanoseconds: u32 = @intCast(if (nano_remainder < 0) nano_remainder + std.time.ns_per_s else nano_remainder); + return Timestamp{ + .seconds = if (nano_remainder < 0) ns_i64 - 1 else ns_i64, + .nanoseconds = nanoseconds, + }; + } + + /// Get total seconds as f64 (including fractional nanoseconds) + pub fn toFloat(self: Timestamp) f64 { + return @as(f64, @floatFromInt(self.seconds)) + @as(f64, @floatFromInt(self.nanoseconds)) / NANOSECONDS_PER_SECOND; + } +}; + +/// Key-Value pair for map entries +pub const KeyValuePair = struct { + key: Payload, + value: Payload, +}; + +/// Compute hash for Payload (used for HashMap) +/// Note: For performance, consider using simple types as map keys (int, uint, str) +fn payloadHash(payload: Payload) u64 { + return payloadHashDepth(payload, 0); +} + +/// Internal helper for hashing with depth tracking to prevent infinite recursion +fn payloadHashDepth(payload: Payload, depth: usize) u64 { + // Prevent excessive recursion for deeply nested structures + const MAX_DEPTH = 100; + if (depth > MAX_DEPTH) { + return 0; + } + + const Wyhash = std.hash.Wyhash; + + return switch 
(payload) { + .nil => 0, + .bool => |v| if (v) 1 else 0, + .int => |v| @bitCast(@as(i64, v)), + .uint => |v| v, + .float => |v| @bitCast(v), + .timestamp => |t| { + var h = Wyhash.init(0); + h.update(std.mem.asBytes(&t.seconds)); + h.update(std.mem.asBytes(&t.nanoseconds)); + return h.final(); + }, + .str => |s| { + return Wyhash.hash(0, s.value()); + }, + .bin => |b| { + return Wyhash.hash(0, b.value()); + }, + .ext => |e| { + var h = Wyhash.init(0); + h.update(std.mem.asBytes(&e.type)); + h.update(e.data); + return h.final(); + }, + .arr => |arr| { + var h = Wyhash.init(0); + h.update(std.mem.asBytes(&arr.len)); + for (arr) |item| { + const item_hash = payloadHashDepth(item, depth + 1); + h.update(std.mem.asBytes(&item_hash)); + } + return h.final(); + }, + .map => |m| { + var h = Wyhash.init(0); + const count = m.count(); + h.update(std.mem.asBytes(&count)); + // Hash map entries (order-independent by XOR) + var hash_acc: u64 = 0; + var it = m.map.iterator(); + while (it.next()) |entry| { + const key_hash = payloadHashDepth(entry.key_ptr.*, depth + 1); + const value_hash = payloadHashDepth(entry.value_ptr.*, depth + 1); + // XOR makes hash order-independent + hash_acc ^= key_hash ^ value_hash; + } + h.update(std.mem.asBytes(&hash_acc)); + return h.final(); + }, + }; +} + +/// Detect best SIMD vector size at compile time based on target features +/// Returns the optimal chunk size for SIMD operations +fn detectSIMDChunkSize() comptime_int { + // Check target features for best available SIMD support + const has_avx512 = std.Target.x86.featureSetHas(builtin.cpu.features, .avx512f); + const has_avx2 = std.Target.x86.featureSetHas(builtin.cpu.features, .avx2); + const has_sse2 = std.Target.x86.featureSetHas(builtin.cpu.features, .sse2); + const has_neon = builtin.cpu.arch.isAARCH64(); + + // Prefer larger vectors for better throughput + if (has_avx512) { + return 64; // AVX-512: 512 bits = 64 bytes + } else if (has_avx2) { + return 32; // AVX2: 256 bits = 32 bytes + 
} else if (has_sse2 or has_neon) { + return 16; // SSE2/NEON: 128 bits = 16 bytes + } else { + return 0; // No SIMD support, use scalar only + } +} + +/// SIMD-optimized string equality comparison +/// Automatically uses the best available SIMD instruction set (AVX-512, AVX2, SSE2, or NEON) +/// For strings >= chunk_size bytes, uses vector operations; otherwise falls back to scalar +fn stringEqualSIMD(a: []const u8, b: []const u8) bool { + if (a.len != b.len) return false; + if (a.len == 0) return true; + + const len = a.len; + + // Detect optimal SIMD chunk size at compile time + const chunk_size = comptime detectSIMDChunkSize(); + + // For very short strings or no SIMD support, use scalar comparison + if (chunk_size == 0 or len < chunk_size) { + return std.mem.eql(u8, a, b); + } + + // Use compile-time detected vector size + const VecType = @Vector(chunk_size, u8); + + var i: usize = 0; + + // Process chunks with SIMD + while (i + chunk_size <= len) : (i += chunk_size) { + const vec_a: VecType = a[i..][0..chunk_size].*; + const vec_b: VecType = b[i..][0..chunk_size].*; + + // Compare vectors: returns a vector of bools + const cmp_result = vec_a == vec_b; + + // Check if all elements are equal + // Use @reduce to check if all comparisons are true + if (!@reduce(.And, cmp_result)) { + return false; + } + } + + // Process remaining bytes with scalar comparison + if (i < len) { + return std.mem.eql(u8, a[i..], b[i..]); + } + + return true; +} + +/// SIMD-optimized binary data equality comparison (same as string) +inline fn binaryEqualSIMD(a: []const u8, b: []const u8) bool { + return stringEqualSIMD(a, b); +} + +/// SIMD-optimized memory copy for binary data +/// Uses larger vector operations when available for better throughput +/// Optimized with memory alignment to reduce unaligned access penalties +fn memcpySIMD(dest: []u8, src: []const u8) void { + std.debug.assert(dest.len >= src.len); + + const len = src.len; + const chunk_size = comptime 
detectSIMDChunkSize(); + + // For small copies or no SIMD, use standard memcpy + if (chunk_size == 0 or len < chunk_size * 2) { + @memcpy(dest[0..len], src); + return; + } + + const VecType = @Vector(chunk_size, u8); + var i: usize = 0; + + // Memory alignment optimization: + // Align destination pointer to chunk_size boundary for better SIMD performance + // Unaligned loads/stores can be 2-3x slower on some architectures + const dest_addr = @intFromPtr(dest.ptr); + const alignment_offset = dest_addr & (chunk_size - 1); // Modulo chunk_size + + if (alignment_offset != 0 and len >= chunk_size) { + // Calculate bytes needed to reach alignment + const bytes_to_align = chunk_size - alignment_offset; + if (bytes_to_align < len) { + // Copy unaligned head using scalar operations + @memcpy(dest[0..bytes_to_align], src[0..bytes_to_align]); + i = bytes_to_align; + } + } + + // Process chunks with SIMD + while (i + chunk_size <= len) : (i += chunk_size) { + const vec: VecType = src[i..][0..chunk_size].*; + dest[i..][0..chunk_size].* = vec; + } + + // Copy remaining bytes + if (i < len) { + @memcpy(dest[i..len], src[i..]); + } +} + +// ========== Byte Order Conversion Optimizations ========== + +/// Check if a pointer is aligned to a given boundary at runtime +inline fn isAligned(ptr: [*]const u8, comptime alignment: usize) bool { + return (@intFromPtr(ptr) & (alignment - 1)) == 0; +} + +/// Check if a pointer is aligned to the optimal SIMD boundary +inline fn isAlignedToSIMD(ptr: [*]const u8) bool { + const chunk_size = comptime detectSIMDChunkSize(); + if (chunk_size == 0) return true; // No SIMD, always "aligned" + return isAligned(ptr, chunk_size); +} + +/// Optimized aligned memory read for u32 (faster when data is aligned) +inline fn readU32Aligned(ptr: *align(@alignOf(u32)) const [4]u8) u32 { + // Aligned read can use direct pointer cast for better performance + const val_ptr: *align(@alignOf(u32)) const u32 = @ptrCast(ptr); + return byteSwapU32SIMD(val_ptr.*); +} + 
+/// Optimized aligned memory read for u64 +inline fn readU64Aligned(ptr: *align(@alignOf(u64)) const [8]u8) u64 { + const val_ptr: *align(@alignOf(u64)) const u64 = @ptrCast(ptr); + return byteSwapU64SIMD(val_ptr.*); +} + +/// Optimized aligned memory write for u32 +inline fn writeU32Aligned(ptr: *align(@alignOf(u32)) [4]u8, val: u32) void { + const swapped = byteSwapU32SIMD(val); + const dest_ptr: *align(@alignOf(u32)) u32 = @ptrCast(ptr); + dest_ptr.* = swapped; +} + +/// Optimized aligned memory write for u64 +inline fn writeU64Aligned(ptr: *align(@alignOf(u64)) [8]u8, val: u64) void { + const swapped = byteSwapU64SIMD(val); + const dest_ptr: *align(@alignOf(u64)) u64 = @ptrCast(ptr); + dest_ptr.* = swapped; +} + +/// Large data copy with alignment hints for better optimization +/// Useful for copying strings, binary data >= 64 bytes +inline fn memcpyLarge(dest: []u8, src: []const u8) void { + std.debug.assert(dest.len >= src.len); + + const len = src.len; + + // For very large copies (>= 64 bytes), use SIMD-optimized copy + if (len >= 64) { + memcpySIMD(dest[0..len], src); + } else { + // For smaller sizes, standard memcpy is sufficient + @memcpy(dest[0..len], src); + } +} + +/// Check if byte swap is needed at compile time +inline fn needsByteSwap() bool { + return comptime (native_endian != big_endian); +} + +/// SIMD-accelerated byte swap for u32 (4 bytes) +/// Uses vector operations when available for better throughput +inline fn byteSwapU32SIMD(val: u32) u32 { + if (!needsByteSwap()) { + return val; + } + + const chunk_size = comptime detectSIMDChunkSize(); + + // Use SIMD if available (SSE2+ or NEON) + if (chunk_size >= 16) { + // Zig's @byteSwap is optimized to use BSWAP on x86 or REV on ARM + return @byteSwap(val); + } else { + // Scalar fallback (still efficient) + return @byteSwap(val); + } +} + +/// SIMD-accelerated byte swap for u64 (8 bytes) +inline fn byteSwapU64SIMD(val: u64) u64 { + if (!needsByteSwap()) { + return val; + } + + const chunk_size 
= comptime detectSIMDChunkSize(); + + if (chunk_size >= 16) { + return @byteSwap(val); + } else { + return @byteSwap(val); + } +} + +/// Fast integer write with optimized byte order conversion +/// This replaces the manual std.mem.writeInt for better performance +inline fn writeU32Fast(buffer: *[4]u8, val: u32) void { + const swapped = byteSwapU32SIMD(val); + const bytes: *const [4]u8 = @ptrCast(&swapped); + buffer.* = bytes.*; +} + +/// Fast integer write for u64 +inline fn writeU64Fast(buffer: *[8]u8, val: u64) void { + const swapped = byteSwapU64SIMD(val); + const bytes: *const [8]u8 = @ptrCast(&swapped); + buffer.* = bytes.*; +} + +/// Fast integer read with optimized byte order conversion +inline fn readU32Fast(buffer: *const [4]u8) u32 { + const val: u32 = @bitCast(buffer.*); + return byteSwapU32SIMD(val); +} + +/// Fast integer read for u64 +inline fn readU64Fast(buffer: *const [8]u8) u64 { + const val: u64 = @bitCast(buffer.*); + return byteSwapU64SIMD(val); +} + +/// Batch convert u32 array to big-endian (optimized for array serialization) +/// This is useful when writing arrays of integers with known format +/// Returns the number of bytes written +/// Optimized with alignment-aware fast paths +pub fn batchU32ToBigEndian(values: []const u32, output: []u8) usize { + std.debug.assert(output.len >= values.len * 4); + + if (!needsByteSwap()) { + // Already big-endian, direct copy + @memcpy(output[0 .. 
values.len * 4], std.mem.sliceAsBytes(values)); + return values.len * 4; + } + + const chunk_size = comptime detectSIMDChunkSize(); + + // SIMD optimization for batch conversion + if (chunk_size >= 16) { + // Check if output is aligned for faster writes + const output_aligned = isAligned(output.ptr, @alignOf(u32)); + + // Process 4 u32s at a time (16 bytes = 128 bits) + const VecType = @Vector(4, u32); + var i: usize = 0; + + while (i + 4 <= values.len) : (i += 4) { + const vec: VecType = values[i..][0..4].*; + const swapped = @byteSwap(vec); + + const out_offset = i * 4; + + if (output_aligned and isAligned(output.ptr + out_offset, 16)) { + // Fast path: aligned write (can be faster on some CPUs) + const dest_ptr: *align(16) [16]u8 = @ptrCast(@alignCast(output[out_offset..].ptr)); + const swapped_bytes: *const [16]u8 = @ptrCast(&swapped); + dest_ptr.* = swapped_bytes.*; + } else { + // Standard path: unaligned write + const swapped_bytes: *const [16]u8 = @ptrCast(&swapped); + @memcpy(output[out_offset..][0..16], swapped_bytes); + } + } + + // Handle remaining elements + while (i < values.len) : (i += 1) { + var buffer: [4]u8 = undefined; + writeU32Fast(&buffer, values[i]); + @memcpy(output[i * 4 ..][0..4], &buffer); + } + + return values.len * 4; + } else { + // Scalar fallback + for (values, 0..) |val, i| { + var buffer: [4]u8 = undefined; + writeU32Fast(&buffer, val); + @memcpy(output[i * 4 ..][0..4], &buffer); + } + return values.len * 4; + } +} + +/// Batch convert u64 array to big-endian +/// Optimized with alignment-aware fast paths +pub fn batchU64ToBigEndian(values: []const u64, output: []u8) usize { + std.debug.assert(output.len >= values.len * 8); + + if (!needsByteSwap()) { + @memcpy(output[0 .. 
values.len * 8], std.mem.sliceAsBytes(values)); + return values.len * 8; + } + + const chunk_size = comptime detectSIMDChunkSize(); + + if (chunk_size >= 16) { + // Check if output is aligned for faster writes + const output_aligned = isAligned(output.ptr, @alignOf(u64)); + + // Process 2 u64s at a time (16 bytes) + const VecType = @Vector(2, u64); + var i: usize = 0; + + while (i + 2 <= values.len) : (i += 2) { + const vec: VecType = values[i..][0..2].*; + const swapped = @byteSwap(vec); + + const out_offset = i * 8; + + if (output_aligned and isAligned(output.ptr + out_offset, 16)) { + // Fast path: aligned write + const dest_ptr: *align(16) [16]u8 = @ptrCast(@alignCast(output[out_offset..].ptr)); + const swapped_bytes: *const [16]u8 = @ptrCast(&swapped); + dest_ptr.* = swapped_bytes.*; + } else { + // Standard path: unaligned write + const swapped_bytes: *const [16]u8 = @ptrCast(&swapped); + @memcpy(output[out_offset..][0..16], swapped_bytes); + } + } + + // Handle remaining element + if (i < values.len) { + var buffer: [8]u8 = undefined; + writeU64Fast(&buffer, values[i]); + @memcpy(output[i * 8 ..][0..8], &buffer); + } + + return values.len * 8; + } else { + for (values, 0..) 
|val, i| { + var buffer: [8]u8 = undefined; + writeU64Fast(&buffer, val); + @memcpy(output[i * 8 ..][0..8], &buffer); + } + return values.len * 8; + } +} + +// ========== End of Byte Order Conversion Optimizations ========== + +/// Helper to check if two Payloads are equal (deep equality) +/// Note: For performance, consider limiting the use of arrays/maps as keys +fn payloadEqual(a: Payload, b: Payload) bool { + return payloadEqualDepth(a, b, 0); +} + +/// Internal helper for deep equality checking with depth tracking +/// max_depth prevents infinite recursion for cyclic structures +fn payloadEqualDepth(a: Payload, b: Payload, depth: usize) bool { + // Prevent excessive recursion (e.g., deeply nested structures) + const MAX_DEPTH = 100; + if (depth > MAX_DEPTH) { + return false; + } + + // Compare by type first + if (@as(@typeInfo(@TypeOf(a)).@"union".tag_type.?, a) != @as(@typeInfo(@TypeOf(b)).@"union".tag_type.?, b)) { + return false; + } + + return switch (a) { + .nil => true, + .bool => |av| av == b.bool, + .int => |av| av == b.int, + .uint => |av| av == b.uint, + .float => |av| av == b.float, + .str => |av| stringEqualSIMD(av.value(), b.str.value()), + .bin => |av| binaryEqualSIMD(av.value(), b.bin.value()), + .timestamp => |av| av.seconds == b.timestamp.seconds and av.nanoseconds == b.timestamp.nanoseconds, + .ext => |av| av.type == b.ext.type and binaryEqualSIMD(av.data, b.ext.data), + + // Deep equality for arrays + .arr => |av| { + const bv = b.arr; + if (av.len != bv.len) return false; + for (av, bv) |a_item, b_item| { + if (!payloadEqualDepth(a_item, b_item, depth + 1)) { + return false; + } + } + return true; + }, + + // Deep equality for maps + .map => |av| { + const bv = b.map; + if (av.count() != bv.count()) return false; + + // Check that all entries in 'a' exist in 'b' with same values + var it = av.map.iterator(); + while (it.next()) |a_entry| { + // Look up the key in map b + if (bv.map.get(a_entry.key_ptr.*)) |b_value| { + if 
(!payloadEqualDepth(a_entry.value_ptr.*, b_value, depth + 1)) { + return false; // Key found but value differs + } + } else { + return false; // Key not found in b + } + } + return true; + }, + }; +} + +/// Deep clone a Payload (allocates new memory for dynamic types) +fn clonePayload(payload: Payload, allocator: Allocator) !Payload { + return switch (payload) { + .nil, .bool, .int, .uint, .float, .timestamp => payload, // Value types, no allocation needed + + .str => |s| try Payload.strToPayload(s.value(), allocator), + .bin => |b| try Payload.binToPayload(b.value(), allocator), + .ext => |e| try Payload.extToPayload(e.type, e.data, allocator), + + .arr => |arr| { + const new_arr = try allocator.alloc(Payload, arr.len); + var cloned_count: usize = 0; + errdefer { + // cleanup partial clones on error + for (new_arr[0..cloned_count]) |item| { + item.free(allocator); + } + allocator.free(new_arr); + } + for (arr, 0..) |item, i| { + new_arr[i] = try clonePayload(item, allocator); + cloned_count += 1; + } + return Payload{ .arr = new_arr }; + }, + + .map => |m| { + var new_map = Map.init(allocator); + errdefer (Payload{ .map = new_map }).free(allocator); + + // Clone all entries + var it = m.map.iterator(); + while (it.next()) |entry| { + const cloned_key = try clonePayload(entry.key_ptr.*, allocator); + errdefer cloned_key.free(allocator); + const cloned_value = try clonePayload(entry.value_ptr.*, allocator); + errdefer cloned_value.free(allocator); + + // Use putInternal to insert without additional cloning + try new_map.putInternal(cloned_key, cloned_value); + } + return Payload{ .map = new_map }; + }, + }; +} + +/// HashMap context for Payload keys +const PayloadHashContext = struct { + pub fn hash(_: PayloadHashContext, key: Payload) u64 { + return payloadHash(key); + } + + pub fn eql(_: PayloadHashContext, a: Payload, b: Payload) bool { + return payloadEqual(a, b); + } +}; + +/// Internal HashMap type alias for cleaner code +const PayloadHashMap = 
std.HashMap(Payload, Payload, PayloadHashContext, std.hash_map.default_max_load_percentage); + +/// Map type supporting any Payload as key +/// Now uses HashMap for O(1) average case lookups instead of O(n) linear search +pub const Map = struct { + map: PayloadHashMap, + allocator: Allocator, + + const Self = @This(); + + /// Iterator for Map entries + pub const Iterator = struct { + inner: PayloadHashMap.Iterator, + + pub const Entry = struct { + key_ptr: *const Payload, + value_ptr: *Payload, + }; + + pub fn next(self: *Iterator) ?Entry { + const entry = self.inner.next() orelse return null; + return Entry{ + .key_ptr = entry.key_ptr, + .value_ptr = entry.value_ptr, + }; + } + }; + + pub fn init(allocator: Allocator) Self { + return Self{ + .map = PayloadHashMap.init(allocator), + .allocator = allocator, + }; + } + + pub fn deinit(self: *Self) void { + self.map.deinit(); + } + + pub fn count(self: Self) usize { + return self.map.count(); + } + + /// Get value by Payload key + pub fn get(self: Self, key: Payload) ?Payload { + return self.map.get(key); + } + + /// Get pointer to value by Payload key + pub fn getPtr(self: Self, key: Payload) ?*Payload { + return self.map.getPtr(key); + } + + /// Get value by string key (for backward compatibility) + pub fn getByString(self: Self, key: []const u8) ?Payload { + const key_payload = Payload{ .str = Str.init(key) }; + return self.map.get(key_payload); + } + + /// Put or update a key-value pair (internal, no cloning) + /// Used by deserialization where keys are already allocated + fn putInternal(self: *Self, key: Payload, value: Payload) !void { + const gop = try self.map.getOrPut(key); + if (gop.found_existing) { + // Key already exists, free the old key and update value + gop.key_ptr.free(self.allocator); + gop.key_ptr.* = key; + gop.value_ptr.* = value; + } else { + // New entry, set key and value + gop.key_ptr.* = key; + gop.value_ptr.* = value; + } + } + + /// Put or update a key-value pair + /// Note: The key will 
be deep-cloned to ensure the map owns it + /// Optimization: Uses getOrPut to hash key only once instead of twice + pub fn put(self: *Self, key: Payload, value: Payload) !void { + // Use getOrPut to hash key only once (instead of getPtr + put) + const gop = try self.map.getOrPut(key); + if (gop.found_existing) { + // Key exists, just update the value without cloning + gop.value_ptr.* = value; + } else { + // Key doesn't exist, clone it and insert + const cloned_key = try clonePayload(key, self.allocator); + errdefer cloned_key.free(self.allocator); + gop.key_ptr.* = cloned_key; + gop.value_ptr.* = value; + } + } + + /// Put or update with string key (for backward compatibility) + /// This allocates memory for the key string + /// Optimization: Uses getOrPut to hash key only once instead of twice + pub fn putString(self: *Self, key: []const u8, value: Payload) !void { + const key_payload = Payload{ .str = Str.init(key) }; + + // Use getOrPut to hash key only once (instead of getPtr + put) + const gop = try self.map.getOrPut(key_payload); + if (gop.found_existing) { + // Key exists, just update the value + gop.value_ptr.* = value; + } else { + // Key doesn't exist, allocate and insert + const new_key = try self.allocator.alloc(u8, key.len); + errdefer self.allocator.free(new_key); + @memcpy(new_key, key); + + gop.key_ptr.* = Payload{ .str = Str.init(new_key) }; + gop.value_ptr.* = value; + } + } + + /// Get or create an entry, returning pointers to key and value + pub fn getOrPut(self: *Self, key: []const u8) !struct { found_existing: bool, key_ptr: *[]const u8, value_ptr: *Payload } { + const key_payload = Payload{ .str = Str.init(key) }; + + const gop = try self.map.getOrPut(key_payload); + if (gop.found_existing) { + // Entry exists, return pointers + const key_str_ptr: *[]const u8 = @constCast(&gop.key_ptr.str.str); + return .{ + .found_existing = true, + .key_ptr = key_str_ptr, + .value_ptr = gop.value_ptr, + }; + } else { + // New entry, allocate key and 
initialize + const new_key = try self.allocator.alloc(u8, key.len); + errdefer self.allocator.free(new_key); + @memcpy(new_key, key); + + gop.key_ptr.* = Payload{ .str = Str.init(new_key) }; + gop.value_ptr.* = Payload{ .nil = void{} }; + + const key_str_ptr: *[]const u8 = @constCast(&gop.key_ptr.str.str); + return .{ + .found_existing = false, + .key_ptr = key_str_ptr, + .value_ptr = gop.value_ptr, + }; + } + } + + /// Ensure capacity for at least the specified number of entries + pub fn ensureTotalCapacity(self: *Self, new_capacity: u32) !void { + try self.map.ensureTotalCapacity(new_capacity); + } + + /// Get an iterator over map entries + pub fn iterator(self: *const Self) Iterator { + return Iterator{ + .inner = self.map.iterator(), + }; + } +}; + +/// Entity to store msgpack +/// +/// Note: The payload and its subvalues must have the same allocator +pub const Payload = union(enum) { + /// the error for Payload + pub const Error = error{ + NotMap, + NotArray, + }; + + nil: void, + bool: bool, + int: i64, + uint: u64, + float: f64, + str: Str, + bin: Bin, + arr: []Payload, + map: Map, + ext: EXT, + timestamp: Timestamp, + + /// get array element + pub fn getArrElement(self: Payload, index: usize) !Payload { + if (self != .arr) { + return Error.NotArray; + } + return self.arr[index]; + } + + /// get array length + pub fn getArrLen(self: Payload) !usize { + if (self != .arr) { + return Error.NotArray; + } + return self.arr.len; + } + + /// get map's element by string key (backward compatible) + pub fn mapGet(self: Payload, key: []const u8) !?Payload { + if (self != .map) { + return Error.NotMap; + } + return self.map.getByString(key); + } + + /// get map's element by Payload key (supports any key type) + pub fn mapGetGeneric(self: Payload, key: Payload) !?Payload { + if (self != .map) { + return Error.NotMap; + } + return self.map.get(key); + } + + /// set array element + pub fn setArrElement(self: *Payload, index: usize, val: Payload) !void { + if (self.* != 
.arr) { + return Error.NotArray; + } + self.arr[index] = val; + } + + /// put a new element to map payload with string key (backward compatible) + pub fn mapPut(self: *Payload, key: []const u8, val: Payload) !void { + if (self.* != .map) { + return Error.NotMap; + } + try self.map.putString(key, val); + } + + /// put a new element to map payload with Payload key (supports any key type) + /// Note: The key Payload will be stored directly, caller is responsible for + /// managing key's memory if needed (e.g., for str/bin/ext types) + pub fn mapPutGeneric(self: *Payload, key: Payload, val: Payload) !void { + if (self.* != .map) { + return Error.NotMap; + } + try self.map.put(key, val); + } + + /// deep clone a Payload, allocating owned copies for dynamic data + pub fn deepClone(self: Payload, allocator: Allocator) !Payload { + return clonePayload(self, allocator); + } + + /// get a NIL payload + pub inline fn nilToPayload() Payload { + return Payload{ + .nil = void{}, + }; + } + + /// get a bool payload + pub inline fn boolToPayload(val: bool) Payload { + return Payload{ + .bool = val, + }; + } + + /// get a int payload + pub inline fn intToPayload(val: i64) Payload { + return Payload{ + .int = val, + }; + } + + /// get a uint payload + pub inline fn uintToPayload(val: u64) Payload { + return Payload{ + .uint = val, + }; + } + + /// get a float payload + pub inline fn floatToPayload(val: f64) Payload { + return Payload{ + .float = val, + }; + } + + /// get a str payload + pub fn strToPayload(val: []const u8, allocator: Allocator) !Payload { + // allocate memory + const new_str = try allocator.alloc(u8, val.len); + // copy the value + @memcpy(new_str, val); + return Payload{ + .str = Str.init(new_str), + }; + } + + /// get a bin payload + pub fn binToPayload(val: []const u8, allocator: Allocator) !Payload { + // allocate memory + const new_bin = try allocator.alloc(u8, val.len); + // copy the value + @memcpy(new_bin, val); + return Payload{ + .bin = Bin.init(new_bin), 
+ }; + } + + /// get an array payload + pub fn arrPayload(len: usize, allocator: Allocator) !Payload { + const arr = try allocator.alloc(Payload, len); + // Initialize with nil to ensure safe memory state for free() + // Note: While this adds overhead, it prevents undefined behavior + // when arrays are partially filled or freed before full initialization + + // Optimization: Use pointer arithmetic for faster initialization + // This is significantly faster than a loop for large arrays + const nil_payload = Payload.nilToPayload(); + for (arr) |*item| { + item.* = nil_payload; + } + return Payload{ + .arr = arr, + }; + } + + /// get a map payload + pub fn mapPayload(allocator: Allocator) Payload { + return Payload{ + .map = Map.init(allocator), + }; + } + + /// get an ext payload + pub fn extToPayload(t: i8, data: []const u8, allocator: Allocator) !Payload { + // allocate memory + const new_data = try allocator.alloc(u8, data.len); + // copy the value + @memcpy(new_data, data); + return Payload{ + .ext = EXT.init(t, new_data), + }; + } + + /// get a timestamp payload + pub inline fn timestampToPayload(seconds: i64, nanoseconds: u32) Payload { + return Payload{ + .timestamp = Timestamp.new(seconds, nanoseconds), + }; + } + + /// get a timestamp payload from seconds only + pub inline fn timestampFromSeconds(seconds: i64) Payload { + return Payload{ + .timestamp = Timestamp.fromSeconds(seconds), + }; + } + + /// get a timestamp payload from nanoseconds since Unix epoch + pub inline fn timestampFromNanos(nanos: i128) Payload { + return Payload{ + .timestamp = Timestamp.fromNanos(nanos), + }; + } + + /// free all memory for this payload and sub payloads + /// the allocator is payload's allocator + /// This is an iterative implementation that avoids stack overflow from deep nesting + /// Optimization: Uses stack-allocated buffer for shallow structures to avoid heap allocation during free + pub fn free(self: Payload, allocator: Allocator) void { + // Use stack-allocated 
buffer for shallow structures (up to 256 items) + // This avoids heap allocation during memory cleanup for most common cases + const STACK_BUFFER_SIZE = 256; + var stack_buffer: [STACK_BUFFER_SIZE]Payload = undefined; + var stack_len: usize = 0; + + // Fallback to heap if we exceed stack buffer + var heap_stack: ?std.ArrayList(Payload) = null; + defer if (heap_stack) |*hs| { + if (current_zig.minor == 14) { + hs.deinit(); + } else { + hs.deinit(allocator); + } + }; + + // Helper to push to stack (tries stack first, falls back to heap) + const pushPayload = struct { + fn push( + buffer: []Payload, + len: *usize, + heap: *?std.ArrayList(Payload), + alloc: Allocator, + payload: Payload, + ) void { + if (heap.*) |*h| { + // Already using heap + if (current_zig.minor == 14) { + h.append(payload) catch {}; + } else { + h.append(alloc, payload) catch {}; + } + } else if (len.* < buffer.len) { + // Stack buffer has space + buffer[len.*] = payload; + len.* += 1; + } else { + // Stack buffer full, migrate to heap + var new_heap = if (current_zig.minor == 14) + std.ArrayList(Payload).init(alloc) + else + std.ArrayList(Payload).empty; + + // Copy existing items from stack buffer to heap + for (buffer[0..len.*]) |item| { + if (current_zig.minor == 14) { + new_heap.append(item) catch return; + } else { + new_heap.append(alloc, item) catch return; + } + } + // Add new item + if (current_zig.minor == 14) { + new_heap.append(payload) catch return; + } else { + new_heap.append(alloc, payload) catch return; + } + heap.* = new_heap; + len.* = 0; // Clear stack buffer + } + } + }.push; + + // Helper to pop from stack + const popPayload = struct { + fn pop( + buffer: []Payload, + len: *usize, + heap: *?std.ArrayList(Payload), + ) ?Payload { + if (heap.*) |*h| { + if (h.items.len > 0) { + return h.pop(); + } + } + if (len.* > 0) { + len.* -= 1; + return buffer[len.*]; + } + return null; + } + }.pop; + + // Start with self + pushPayload(&stack_buffer, &stack_len, &heap_stack, allocator, 
self); + + while (popPayload(&stack_buffer, &stack_len, &heap_stack)) |payload_item| { + switch (payload_item) { + .str => |s| allocator.free(s.value()), + .bin => |b| allocator.free(b.value()), + .ext => |e| allocator.free(e.data), + + .arr => |arr| { + defer allocator.free(arr); + // Push children to stack in reverse order + var i = arr.len; + while (i > 0) { + i -= 1; + pushPayload(&stack_buffer, &stack_len, &heap_stack, allocator, arr[i]); + } + }, + + .map => |map| { + var map_copy = map; + defer map_copy.deinit(); + // Push both keys and values to stack for recursive freeing + var it = map_copy.map.iterator(); + while (it.next()) |entry| { + pushPayload(&stack_buffer, &stack_len, &heap_stack, allocator, entry.key_ptr.*); + pushPayload(&stack_buffer, &stack_len, &heap_stack, allocator, entry.value_ptr.*); + } + }, + + else => {}, // nil, bool, int, uint, float, timestamp - no memory to free + } + } + } + + /// get an i64 value from payload + /// Tries to get i64 value, converting uint if it fits within i64 range. + /// This is a lenient conversion method. + pub fn getInt(self: Payload) !i64 { + return switch (self) { + .int => |val| val, + .uint => |val| { + if (val <= std.math.maxInt(i64)) { + return @intCast(val); + } + // Value exceeds i64 range + return MsgPackError.InvalidType; + }, + else => return MsgPackError.InvalidType, + }; + } + + /// get an u64 value from payload + /// Tries to get u64 value, converting positive int if possible. + /// This is a lenient conversion method. + pub fn getUint(self: Payload) !u64 { + return switch (self) { + .int => |val| { + if (val >= 0) { + return @intCast(val); + } + // Negative values cannot be converted to u64 + return MsgPackError.InvalidType; + }, + .uint => |val| val, + else => return MsgPackError.InvalidType, + }; + } + + /// Get i64 value without type conversion (strict mode). + /// Returns error if payload is not exactly an int type. 
+ pub fn asInt(self: Payload) !i64 { + return switch (self) { + .int => |val| val, + else => MsgPackError.InvalidType, + }; + } + + /// Get u64 value without type conversion (strict mode). + /// Returns error if payload is not exactly a uint type. + pub fn asUint(self: Payload) !u64 { + return switch (self) { + .uint => |val| val, + else => MsgPackError.InvalidType, + }; + } + + /// Get f64 value without type conversion. + pub fn asFloat(self: Payload) !f64 { + return switch (self) { + .float => |val| val, + else => MsgPackError.InvalidType, + }; + } + + /// Get bool value. + pub fn asBool(self: Payload) !bool { + return switch (self) { + .bool => |val| val, + else => MsgPackError.InvalidType, + }; + } + + /// Get string slice. The string data is owned by the Payload. + pub fn asStr(self: Payload) ![]const u8 { + return switch (self) { + .str => |s| s.value(), + else => MsgPackError.InvalidType, + }; + } + + /// Get binary data slice. The data is owned by the Payload. + pub fn asBin(self: Payload) ![]u8 { + return switch (self) { + .bin => |b| b.value(), + else => MsgPackError.InvalidType, + }; + } + + /// Check if payload is nil. + pub inline fn isNil(self: Payload) bool { + return self == .nil; + } + + /// Check if payload is a number (int, uint, or float). + pub inline fn isNumber(self: Payload) bool { + return switch (self) { + .int, .uint, .float => true, + else => false, + }; + } + + /// Check if payload is an integer (int or uint). 
+ pub inline fn isInteger(self: Payload) bool { + return switch (self) { + .int, .uint => true, + else => false, + }; + } +}; + +/// markers +const Markers = enum(u8) { + POSITIVE_FIXINT = 0x00, + FIXMAP = 0x80, + FIXARRAY = 0x90, + FIXSTR = 0xa0, + NIL = 0xc0, + FALSE = 0xc2, + TRUE = 0xc3, + BIN8 = 0xc4, + BIN16 = 0xc5, + BIN32 = 0xc6, + EXT8 = 0xc7, + EXT16 = 0xc8, + EXT32 = 0xc9, + FLOAT32 = 0xca, + FLOAT64 = 0xcb, + UINT8 = 0xcc, + UINT16 = 0xcd, + UINT32 = 0xce, + UINT64 = 0xcf, + INT8 = 0xd0, + INT16 = 0xd1, + INT32 = 0xd2, + INT64 = 0xd3, + FIXEXT1 = 0xd4, + FIXEXT2 = 0xd5, + FIXEXT4 = 0xd6, + FIXEXT8 = 0xd7, + FIXEXT16 = 0xd8, + STR8 = 0xd9, + STR16 = 0xda, + STR32 = 0xdb, + ARRAY16 = 0xdc, + ARRAY32 = 0xdd, + MAP16 = 0xde, + MAP32 = 0xdf, + NEGATIVE_FIXINT = 0xe0, +}; + +/// A collection of errors that may occur when reading the payload +pub const MsgPackError = error{ + StrDataLengthTooLong, + BinDataLengthTooLong, + ArrayLengthTooLong, + TupleLengthTooLong, + MapLengthTooLong, + InputValueTooLarge, + FixedValueWriting, + TypeMarkerReading, + TypeMarkerWriting, + DataReading, + DataWriting, + ExtTypeReading, + ExtTypeWriting, + ExtTypeLength, + InvalidType, + LengthReading, + LengthWriting, + Internal, + + // New safety errors for iterative parser + MaxDepthExceeded, // Nesting depth exceeded limit + ArrayTooLarge, // Array has too many elements + MapTooLarge, // Map has too many key-value pairs + StringTooLong, // String exceeds length limit + ExtDataTooLarge, // Extension data exceeds length limit +}; + +/// Create an instance of msgpack_pack with custom limits +pub fn PackWithLimits( + comptime WriteContext: type, + comptime ReadContext: type, + comptime WriteError: type, + comptime ReadError: type, + comptime writeFn: fn (context: WriteContext, bytes: []const u8) WriteError!usize, + comptime readFn: fn (context: ReadContext, arr: []u8) ReadError!usize, + comptime limits: ParseLimits, +) type { + return struct { + write_context: WriteContext, + 
read_context: ReadContext, + + const Self = @This(); + const parse_limits = limits; + + /// init + pub fn init(writeContext: WriteContext, readContext: ReadContext) Self { + return Self{ + .write_context = writeContext, + .read_context = readContext, + }; + } + + /// wrap for writeFn + fn writeTo(self: Self, bytes: []const u8) !usize { + return writeFn(self.write_context, bytes); + } + + /// write one byte + inline fn writeByte(self: Self, byte: u8) !void { + const bytes = [_]u8{byte}; + const len = try self.writeTo(&bytes); + if (len != 1) { + return MsgPackError.LengthWriting; + } + } + + /// write data + inline fn writeData(self: Self, data: []const u8) !void { + const len = try self.writeTo(data); + if (len != data.len) { + return MsgPackError.LengthWriting; + } + } + + /// Generic integer write helper + inline fn writeIntRaw(self: Self, comptime T: type, val: T) !void { + // Use optimized SIMD byte swap for common integer types + if (T == u32) { + var arr: [4]u8 = undefined; + writeU32Fast(&arr, val); + try self.writeData(&arr); + } else if (T == u64) { + var arr: [8]u8 = undefined; + writeU64Fast(&arr, val); + try self.writeData(&arr); + } else { + // Standard path for other types (u8, u16, i8, i16, i32, i64) + var arr: [@sizeOf(T)]u8 = undefined; + std.mem.writeInt(T, &arr, val, big_endian); + try self.writeData(&arr); + } + } + + /// Generic data write with length prefix + inline fn writeDataWithLength(self: Self, comptime LenType: type, data: []const u8) !void { + try self.writeIntRaw(LenType, @intCast(data.len)); + try self.writeData(data); + } + + /// Generic integer value write (without marker) + inline fn writeIntValue(self: Self, comptime T: type, val: T) !void { + if (T == u8 or T == i8) { + try self.writeByte(@bitCast(val)); + } else { + try self.writeIntRaw(T, val); + } + } + + /// Generic integer write with marker + inline fn writeIntWithMarker(self: Self, comptime T: type, marker: Markers, val: T) !void { + try self.writeTypeMarker(marker); + try 
self.writeIntValue(T, val); + } + + /// write type marker + inline fn writeTypeMarker(self: Self, comptime marker: Markers) !void { + switch (marker) { + .POSITIVE_FIXINT, .FIXMAP, .FIXARRAY, .FIXSTR, .NEGATIVE_FIXINT => { + const err_msg = comptimePrint("marker ({}) is wrong, the can not be write directly!", .{marker}); + @compileError(err_msg); + }, + else => {}, + } + try self.writeByte(@intFromEnum(marker)); + } + + /// write nil + fn writeNil(self: Self) !void { + try self.writeTypeMarker(Markers.NIL); + } + + /// write bool + fn writeBool(self: Self, val: bool) !void { + if (val) { + try self.writeTypeMarker(Markers.TRUE); + } else { + try self.writeTypeMarker(Markers.FALSE); + } + } + + /// write positive fix int + inline fn writePfixInt(self: Self, val: u8) !void { + if (val <= MAX_POSITIVE_FIXINT) { + try self.writeByte(val); + } else { + return MsgPackError.InputValueTooLarge; + } + } + + inline fn writeU8Value(self: Self, val: u8) !void { + try self.writeIntValue(u8, val); + } + + /// write u8 int + fn writeU8(self: Self, val: u8) !void { + try self.writeIntWithMarker(u8, .UINT8, val); + } + + inline fn writeU16Value(self: Self, val: u16) !void { + try self.writeIntValue(u16, val); + } + + /// write u16 int + fn writeU16(self: Self, val: u16) !void { + try self.writeIntWithMarker(u16, .UINT16, val); + } + + inline fn writeU32Value(self: Self, val: u32) !void { + try self.writeIntValue(u32, val); + } + + /// write u32 int + fn writeU32(self: Self, val: u32) !void { + try self.writeIntWithMarker(u32, .UINT32, val); + } + + inline fn writeU64Value(self: Self, val: u64) !void { + try self.writeIntValue(u64, val); + } + + /// write u64 int + fn writeU64(self: Self, val: u64) !void { + try self.writeIntWithMarker(u64, .UINT64, val); + } + + /// write negative fix int + inline fn writeNfixInt(self: Self, val: i8) !void { + if (val >= MIN_NEGATIVE_FIXINT and val <= -1) { + try self.writeByte(@bitCast(val)); + } else { + return MsgPackError.InputValueTooLarge; + 
}
+        }
+
+        /// Write the payload of an i8 through `writeIntValue` (the marker-prefixed
+        /// variant is `writeI8`).
+        inline fn writeI8Value(self: Self, val: i8) !void {
+            try self.writeIntValue(i8, val);
+        }
+
+        /// Write an i8 tagged with the INT8 marker.
+        fn writeI8(self: Self, val: i8) !void {
+            try self.writeIntWithMarker(i8, .INT8, val);
+        }
+
+        /// Write the payload of an i16 through `writeIntValue`.
+        inline fn writeI16Value(self: Self, val: i16) !void {
+            try self.writeIntValue(i16, val);
+        }
+
+        /// Write an i16 tagged with the INT16 marker.
+        fn writeI16(self: Self, val: i16) !void {
+            try self.writeIntWithMarker(i16, .INT16, val);
+        }
+
+        /// Write the payload of an i32 through `writeIntValue`.
+        inline fn writeI32Value(self: Self, val: i32) !void {
+            try self.writeIntValue(i32, val);
+        }
+
+        /// Write an i32 tagged with the INT32 marker.
+        fn writeI32(self: Self, val: i32) !void {
+            try self.writeIntWithMarker(i32, .INT32, val);
+        }
+
+        /// Write the payload of an i64 through `writeIntValue`.
+        inline fn writeI64Value(self: Self, val: i64) !void {
+            try self.writeIntValue(i64, val);
+        }
+
+        /// Write an i64 tagged with the INT64 marker.
+        fn writeI64(self: Self, val: i64) !void {
+            try self.writeIntWithMarker(i64, .INT64, val);
+        }
+
+        /// Write an unsigned integer, picking the first writer whose range holds
+        /// `val`: positive fixint, then `writeU8`, `writeU16`, `writeU32`, and
+        /// finally `writeU64`.
+        fn writeUint(self: Self, val: u64) !void {
+            if (val <= MAX_POSITIVE_FIXINT) {
+                try self.writePfixInt(@intCast(val));
+            } else if (val <= MAX_UINT8) {
+                try self.writeU8(@intCast(val));
+            } else if (val <= MAX_UINT16) {
+                try self.writeU16(@intCast(val));
+            } else if (val <= MAX_UINT32) {
+                try self.writeU32(@intCast(val));
+            } else {
+                try self.writeU64(val);
+            }
+        }
+
+        /// Write a signed integer. Non-negative values are delegated to
+        /// `writeUint`; negative values use the first of negative fixint,
+        /// `writeI8`, `writeI16`, `writeI32`, `writeI64` whose range holds `val`.
+        fn writeInt(self: Self, val: i64) !void {
+            if (val >= 0) {
+                try self.writeUint(@intCast(val));
+            } else if (val >= MIN_NEGATIVE_FIXINT) {
+                try self.writeNfixInt(@intCast(val));
+            } else if (val >= MIN_INT8) {
+                try self.writeI8(@intCast(val));
+            } else if (val >= MIN_INT16) {
+                try self.writeI16(@intCast(val));
+            } else if (val >= MIN_INT32) {
+                try self.writeI32(@intCast(val));
+            } else {
+                try self.writeI64(val);
+            }
+        }
+
+        /// Write the 4 payload bytes of an f32: the value is bit-cast to u32 and
+        /// serialized with `writeU32Fast`.
+        inline fn writeF32Value(self: Self, val: f32) !void {
+            const int: u32 = @bitCast(val);
+            var buffer: [4]u8 = undefined;
+            writeU32Fast(&buffer, int);
+            try self.writeData(&buffer);
+        }
+
+        /// Write an f32 tagged with the FLOAT32 marker.
+        fn writeF32(self: Self, val: f32) !void {
+            try self.writeTypeMarker(.FLOAT32);
+            try self.writeF32Value(val);
+        }
+
+        /// Write the 8 payload bytes of an f64: the value is bit-cast to u64 and
+        /// serialized with `std.mem.writeInt` using `big_endian`.
+        inline fn writeF64Value(self: Self, val: f64) !void {
+            const int: u64 = @bitCast(val);
+            var arr: [8]u8 = undefined;
+            std.mem.writeInt(u64, &arr, int, big_endian);
+
+            try self.writeData(&arr);
+        }
+
+        /// Write an f64 tagged with the FLOAT64 marker.
+        fn writeF64(self: Self, val: f64) !void {
+            try self.writeTypeMarker(.FLOAT64);
+            try self.writeF64Value(val);
+        }
+
+        /// Write a float as f32 when the f64 -> f32 -> f64 round trip compares
+        /// equal to the input, otherwise as f64.
+        fn writeFloat(self: Self, val: f64) !void {
+            // A value should only be encoded as f32 if it can be
+            // represented exactly without loss of precision.
+            const val_f32: f32 = @floatCast(val);
+            if (val == @as(f64, val_f32)) {
+                try self.writeF32(val_f32);
+            } else {
+                try self.writeF64(val);
+            }
+        }
+
+        /// Write the bytes of a fixstr body (no header).
+        inline fn writeFixStrValue(self: Self, str: []const u8) !void {
+            try self.writeData(str);
+        }
+
+        /// Write a fixstr: a single header byte (FIXSTR marker plus length)
+        /// followed by the bytes.
+        /// Returns `StrDataLengthTooLong` when `str.len > MAX_FIXSTR_LEN`.
+        fn writeFixStr(self: Self, str: []const u8) !void {
+            const len = str.len;
+            if (len > MAX_FIXSTR_LEN) {
+                return MsgPackError.StrDataLengthTooLong;
+            }
+            const header: u8 = @intFromEnum(Markers.FIXSTR) + @as(u8, @intCast(len));
+            try self.writeByte(header);
+            try self.writeFixStrValue(str);
+        }
+
+        /// Shared implementation of STR8/16/32: writes `marker`, then the length
+        /// and bytes via `writeDataWithLength(LenType, ...)`.
+        /// Returns `StrDataLengthTooLong` when `str.len` exceeds `maxInt(LenType)`.
+        inline fn writeStrGeneric(self: Self, comptime LenType: type, comptime marker: Markers, str: []const u8) !void {
+            const max_len = std.math.maxInt(LenType);
+            if (str.len > max_len) {
+                return MsgPackError.StrDataLengthTooLong;
+            }
+            try self.writeTypeMarker(marker);
+            try self.writeDataWithLength(LenType, str);
+        }
+
+        /// Write a u8-length-prefixed string body without a marker.
+        inline fn writeStr8Value(self: Self, str: []const u8) !void {
+            try self.writeDataWithLength(u8, str);
+        }
+
+        /// Write a STR8-tagged string.
+        fn writeStr8(self: Self, str: []const u8) !void {
+            try self.writeStrGeneric(u8, .STR8, str);
+        }
+
+        /// Write a u16-length-prefixed string body without a marker.
+        inline fn writeStr16Value(self: Self, str: []const u8) !void {
+            try self.writeDataWithLength(u16, str);
+        }
+
+        /// Write a STR16-tagged string.
+        fn writeStr16(self: Self, str: []const u8) !void {
+            try self.writeStrGeneric(u16, .STR16, str);
+        }
+
+        /// Write a u32-length-prefixed string body without a marker.
+        inline fn writeStr32Value(self: Self, str: []const u8) !void {
+            try self.writeDataWithLength(u32, str);
+        }
+
+        /// Write a STR32-tagged string.
+        fn writeStr32(self: Self, str: []const u8) !void {
+            try self.writeStrGeneric(u32, .STR32, str);
+        }
+
+        /// Write a string using the first of fixstr / STR8 / STR16 / STR32 whose
+        /// length field can hold `str.value().len`.
+        fn writeStr(self: Self, str: Str) !void {
+            const len = str.value().len;
+            if (len <= MAX_FIXSTR_LEN) {
+                try self.writeFixStr(str.value());
+            } else if (len <= MAX_UINT8) {
+                try self.writeStr8(str.value());
+            } else if (len <= MAX_UINT16) {
+                try self.writeStr16(str.value());
+            } else {
+                try self.writeStr32(str.value());
+            }
+        }
+
+        /// Shared implementation of BIN8/16/32: writes `marker`, then the length
+        /// and bytes via `writeDataWithLength(LenType, ...)`.
+        /// Returns `BinDataLengthTooLong` when `bin.len` exceeds `maxInt(LenType)`.
+        inline fn writeBinGeneric(self: Self, comptime LenType: type, comptime marker: Markers, bin: []const u8) !void {
+            const max_len = std.math.maxInt(LenType);
+            if (bin.len > max_len) {
+                return MsgPackError.BinDataLengthTooLong;
+            }
+            try self.writeTypeMarker(marker);
+            try self.writeDataWithLength(LenType, bin);
+        }
+
+        /// Write a BIN8-tagged byte string.
+        fn writeBin8(self: Self, bin: []const u8) !void {
+            try self.writeBinGeneric(u8, .BIN8, bin);
+        }
+
+        /// Write a BIN16-tagged byte string.
+        fn writeBin16(self: Self, bin: []const u8) !void {
+            try self.writeBinGeneric(u16, .BIN16, bin);
+        }
+
+        /// Write a BIN32-tagged byte string.
+        fn writeBin32(self: Self, bin: []const u8) !void {
+            try self.writeBinGeneric(u32, .BIN32, bin);
+        }
+
+        /// Write binary data using the first of BIN8 / BIN16 / BIN32 whose length
+        /// field can hold `bin.value().len`.
+        fn writeBin(self: Self, bin: Bin) !void {
+            const len = bin.value().len;
+            if (len <= MAX_UINT8) {
+                try self.writeBin8(bin.value());
+            } else if (len <= MAX_UINT16) {
+                try self.writeBin16(bin.value());
+            } else {
+                try self.writeBin32(bin.value());
+            }
+        }
+
+        /// Write the body of an extension: the i8 type id followed by the data.
+        inline fn writeExtValue(self: Self, ext: EXT) !void {
+            try self.writeI8Value(ext.type);
+            try self.writeData(ext.data);
+        }
+
+        /// Shared implementation of FIXEXT1/2/4/8/16: writes `marker` and the
+        /// extension body.
+        /// Returns `ExtTypeLength` unless `ext.data.len == expected_len`.
+        inline fn writeFixExtGeneric(self: Self, comptime expected_len: usize, comptime marker: Markers, ext: EXT) !void {
+            if (ext.data.len != expected_len) {
+                return MsgPackError.ExtTypeLength;
+            }
+            try self.writeTypeMarker(marker);
+            try self.writeExtValue(ext);
+        }
+
+        /// Write a FIXEXT1 extension.
+        fn writeFixExt1(self: Self, ext: EXT) !void {
+            try self.writeFixExtGeneric(FIXEXT1_LEN, .FIXEXT1, ext);
+        }
+
+        /// Write a FIXEXT2 extension.
+        fn writeFixExt2(self: Self, ext: EXT) !void {
+            try self.writeFixExtGeneric(FIXEXT2_LEN, .FIXEXT2, ext);
+        }
+
+        /// Write a FIXEXT4 extension.
+        fn writeFixExt4(self: Self, ext: EXT) !void {
+            try self.writeFixExtGeneric(FIXEXT4_LEN, .FIXEXT4, ext);
+        }
+
+        /// Write a FIXEXT8 extension.
+        fn writeFixExt8(self: Self, ext: EXT) !void {
+            try self.writeFixExtGeneric(FIXEXT8_LEN, .FIXEXT8, ext);
+        }
+
+        /// Write a FIXEXT16 extension.
+        fn writeFixExt16(self: Self, ext: EXT) !void {
+            try self.writeFixExtGeneric(FIXEXT16_LEN, .FIXEXT16, ext);
+        }
+
+        /// Shared implementation of EXT8/16/32: writes `marker`, the data length
+        /// as `LenType`, then the extension body.
+        /// Returns `ExtTypeLength` when `ext.data.len` exceeds `maxInt(LenType)`.
+        inline fn writeExtGeneric(self: Self, comptime LenType: type, comptime marker: Markers, ext: EXT) !void {
+            const max_len = std.math.maxInt(LenType);
+            if (ext.data.len > max_len) {
+                return MsgPackError.ExtTypeLength;
+            }
+            try self.writeTypeMarker(marker);
+
+            // Write length using appropriate size
+            const len_val: LenType = @intCast(ext.data.len);
+            try self.writeIntValue(LenType, len_val);
+
+            try self.writeExtValue(ext);
+        }
+
+        /// Write an EXT8 extension.
+        fn writeExt8(self: Self, ext: EXT) !void {
+            try self.writeExtGeneric(u8, .EXT8, ext);
+        }
+
+        /// Write an EXT16 extension.
+        fn writeExt16(self: Self, ext: EXT) !void {
+            try self.writeExtGeneric(u16, .EXT16, ext);
+        }
+
+        /// Write an EXT32 extension.
+        fn writeExt32(self: Self, ext: EXT) !void {
+            try self.writeExtGeneric(u32, .EXT32, ext);
+        }
+
+        /// Write an extension. Data lengths that match a FIXEXT size use the
+        /// fixed form; every other length (including 0) uses the first of
+        /// EXT8 / EXT16 / EXT32 that can hold it.
+        /// Returns `ExtTypeLength` when the data does not fit in a u32 length.
+        fn writeExt(self: Self, ext: EXT) !void {
+            const len = ext.data.len;
+            if (len == FIXEXT1_LEN) {
+                try self.writeFixExt1(ext);
+            } else if (len == FIXEXT2_LEN) {
+                try self.writeFixExt2(ext);
+            } else if (len == FIXEXT4_LEN) {
+                try self.writeFixExt4(ext);
+            } else if (len == FIXEXT8_LEN) {
+                try self.writeFixExt8(ext);
+            } else if (len == FIXEXT16_LEN) {
+                try self.writeFixExt16(ext);
+            } else if (len <= std.math.maxInt(u8)) {
+                try self.writeExt8(ext);
+            } else if (len <= std.math.maxInt(u16)) {
+                try self.writeExt16(ext);
+            } else if (len <= std.math.maxInt(u32)) {
+                try self.writeExt32(ext);
+            } else {
+                return MsgPackError.ExtTypeLength;
+            }
+        }
+
+        /// Write a timestamp as an extension of type `TIMESTAMP_EXT_TYPE`, using
+        /// the first layout below that can represent it.
+        /// Returns `InvalidType` when `nanoseconds > MAX_NANOSECONDS` and the
+        /// 32-bit layout does not apply.
+        fn writeTimestamp(self: Self, timestamp: Timestamp) !void {
+            // According to MessagePack spec, timestamp uses extension type -1
+
+            // timestamp 32 format: seconds fit in 32-bit unsigned int and nanoseconds is 0
+            if (timestamp.nanoseconds == 0 and timestamp.seconds >= 0 and timestamp.seconds <= MAX_UINT32) {
+                var data: [TIMESTAMP32_DATA_LEN]u8 = undefined;
+                writeU32Fast(&data, @intCast(timestamp.seconds));
+                const ext = EXT{ .type = TIMESTAMP_EXT_TYPE, .data = &data };
+                try self.writeExt(ext);
+                return;
+            }
+
+            // timestamp 64 format: seconds fit in 34-bit and nanoseconds <= 999999999
+            if (timestamp.seconds >= 0 and (timestamp.seconds >> TIMESTAMP64_SECONDS_BITS) == 0 and timestamp.nanoseconds <= MAX_NANOSECONDS) {
+                // Nanoseconds occupy the bits above the seconds field.
+                const data64: u64 = (@as(u64, timestamp.nanoseconds) << TIMESTAMP64_SECONDS_BITS) | @as(u64, @intCast(timestamp.seconds));
+                var data: [TIMESTAMP64_DATA_LEN]u8 = undefined;
+                writeU64Fast(&data, data64);
+                const ext = EXT{ .type = TIMESTAMP_EXT_TYPE, .data = &data };
+                try self.writeExt(ext);
+                return;
+            }
+
+            // timestamp 96 format: full range with signed 64-bit seconds and 32-bit nanoseconds
+            if (timestamp.nanoseconds <= MAX_NANOSECONDS) {
+                var data: [TIMESTAMP96_DATA_LEN]u8 = undefined;
+                writeU32Fast(data[0..4], timestamp.nanoseconds);
+                // For i64, use standard path (could add writeI64Fast if needed)
+                std.mem.writeInt(i64, data[4..12], timestamp.seconds, big_endian);
+                const ext = EXT{ .type = TIMESTAMP_EXT_TYPE, .data = &data };
+                try self.writeExt(ext);
+                return;
+            }
+
+            return MsgPackError.InvalidType;
+        }
+
+        /// Serialize `payload`, dispatching on its active tag. Arrays and maps
+        /// write their header and then recurse into `write` for each element
+        /// (maps: key, then value, in iterator order).
+        /// Returns `ArrayLengthTooLong` / `MapLengthTooLong` when the element
+        /// count does not fit in a u32, plus any error from the leaf writers.
+        pub fn write(self: Self, payload: Payload) !void {
+            switch (payload) {
+                .nil => {
+                    try self.writeNil();
+                },
+                .bool => |val| {
+                    try self.writeBool(val);
+                },
+                .int => |val| {
+                    try self.writeInt(val);
+                },
+                .uint => |val| {
+                    try self.writeUint(val);
+                },
+                .float => |val| {
+                    try self.writeFloat(val);
+                },
+                .str => |val| {
+                    try self.writeStr(val);
+                },
+                .bin => |val| {
+                    try self.writeBin(val);
+                },
+                .arr => |arr| {
+                    const len = arr.len;
+                    if (len <= MAX_FIXARRAY_LEN) {
+                        const header: u8 = @intFromEnum(Markers.FIXARRAY) + @as(u8, @intCast(len));
+                        try self.writeU8Value(header);
+                    } else if (len <= MAX_UINT16) {
+                        try self.writeTypeMarker(.ARRAY16);
+                        try self.writeU16Value(@as(u16, @intCast(len)));
+                    } else if (len <= MAX_UINT32) {
+                        try self.writeTypeMarker(.ARRAY32);
+                        try self.writeU32Value(@as(u32, @intCast(len)));
+                    } else {
+                        return MsgPackError.ArrayLengthTooLong;
+                    }
+                    for (arr) |val| {
+                        try self.write(val);
+                    }
+                },
+                .map => |map| {
+                    const len = map.count();
+                    if (len <= MAX_FIXMAP_LEN) {
+                        const header: u8 = @intFromEnum(Markers.FIXMAP) + @as(u8, @intCast(len));
+                        try self.writeU8Value(header);
+                    } else if (len <= MAX_UINT16) {
+                        try self.writeTypeMarker(.MAP16);
+                        try self.writeU16Value(@intCast(len));
+                    } else if (len <= MAX_UINT32) {
+                        try self.writeTypeMarker(.MAP32);
+                        try self.writeU32Value(@intCast(len));
+                    } else {
+                        return MsgPackError.MapLengthTooLong;
+                    }
+                    // Write key-value pairs, key can be any Payload type
+                    var itera = map.iterator();
+                    while (itera.next()) |entry| {
+                        try self.write(entry.key_ptr.*);
+                        try self.write(entry.value_ptr.*);
+                    }
+                },
+                .ext => |ext| {
+                    try self.writeExt(ext);
+                },
+                .timestamp => |timestamp| {
+                    try self.writeTimestamp(timestamp);
+                },
+            }
+        }
+
+        /// Forward to the comptime-supplied `readFn`; returns the number of
+        /// bytes it reports as read into `bytes`.
+        fn readFrom(self: Self, bytes: []u8) !usize {
+            return readFn(self.read_context, bytes);
+        }
+
+        /// Read exactly one byte. Returns `LengthReading` on a short read.
+        inline fn readByte(self: Self) !u8 {
+            var res = [1]u8{0};
+            const len = try self.readFrom(&res);
+
+            if (len != 1) {
+                return MsgPackError.LengthReading;
+            }
+
+            return res[0];
+        }
+
+        /// Allocate `len` bytes and fill them from the stream. The caller owns
+        /// the returned slice. The buffer is freed and `LengthReading` returned
+        /// when fewer than `len` bytes were read.
+        inline fn readData(self: Self, allocator: Allocator, len: usize) ![]u8 {
+            const data = try allocator.alloc(u8, len);
+            errdefer allocator.free(data);
+            const data_len = try self.readFrom(data);
+
+            if (data_len != len) {
+                return MsgPackError.LengthReading;
+            }
+
+            return data;
+        }
+
+        /// Read a fixed-width integer of type `T` from the stream.
+        /// Returns `LengthReading` on a short read.
+        inline fn readIntRaw(self: Self, comptime T: type) !T {
+            // u32 and u64 go through the dedicated readU32Fast / readU64Fast helpers
+            if (T == u32) {
+                var buffer: [4]u8 = undefined;
+                const len = try self.readFrom(&buffer);
+                if (len != 4) {
+                    return MsgPackError.LengthReading;
+                }
+                return readU32Fast(&buffer);
+            } else if (T == u64) {
+                var buffer: [8]u8 = undefined;
+                const len = try self.readFrom(&buffer);
+                if (len != 8) {
+                    return MsgPackError.LengthReading;
+                }
+                return readU64Fast(&buffer);
+            } else {
+                // Standard path for other types
+                var buffer: [@sizeOf(T)]u8 = undefined;
+                const len = try self.readFrom(&buffer);
+                if (len != @sizeOf(T)) {
+                    return MsgPackError.LengthReading;
+                }
+                return std.mem.readInt(T, &buffer, big_endian);
+            }
+        }
+
+        /// Read an integer of type `T`. Single-byte types go through
+        /// `readByte` (i8 is a bit-cast of the byte); wider types use `readIntRaw`.
+        inline fn readTypedInt(self: Self, comptime T: type) !T {
+            if (T == u8) {
+                return self.readByte();
+            } else if (T == i8) {
+                const val = try self.readByte();
+                return @bitCast(val);
+            } else {
+                return self.readIntRaw(T);
+            }
+        }
+
+        /// Read the next marker byte without classifying it.
+        fn readTypeMarkerU8(self: Self) !u8 {
+            const val = try self.readByte();
+            return val;
+        }
+
+        /// Precomputed lookup table for marker byte to Markers enum conversion.
+        /// Built at comptime so classification is a single indexed load.
+        /// NOTE(review): 0xc1 is mapped to NIL rather than rejected - confirm
+        /// this leniency is intended.
+        const MARKER_LOOKUP_TABLE: [256]Markers = blk: {
+            var table: [256]Markers = undefined;
+            var i: usize = 0;
+            while (i < 256) : (i += 1) {
+                const byte: u8 = @intCast(i);
+                table[i] = switch (byte) {
+                    0x00...0x7f => .POSITIVE_FIXINT,
+                    0x80...0x8f => .FIXMAP,
+                    0x90...0x9f => .FIXARRAY,
+                    0xa0...0xbf => .FIXSTR,
+                    0xc0 => .NIL,
+                    0xc1 => .NIL, // Reserved byte, treat as NIL
+                    0xc2 => .FALSE,
+                    0xc3 => .TRUE,
+                    0xc4 => .BIN8,
+                    0xc5 => .BIN16,
+                    0xc6 => .BIN32,
+                    0xc7 => .EXT8,
+                    0xc8 => .EXT16,
+                    0xc9 => .EXT32,
+                    0xca => .FLOAT32,
+                    0xcb => .FLOAT64,
+                    0xcc => .UINT8,
+                    0xcd => .UINT16,
+                    0xce => .UINT32,
+                    0xcf => .UINT64,
+                    0xd0 => .INT8,
+                    0xd1 => .INT16,
+                    0xd2 => .INT32,
+                    0xd3 => .INT64,
+                    0xd4 => .FIXEXT1,
+                    0xd5 => .FIXEXT2,
+                    0xd6 => .FIXEXT4,
+                    0xd7 => .FIXEXT8,
+                    0xd8 => .FIXEXT16,
+                    0xd9 => .STR8,
+                    0xda => .STR16,
+                    0xdb => .STR32,
+                    0xdc => .ARRAY16,
+                    0xdd => .ARRAY32,
+                    0xde => .MAP16,
+                    0xdf => .MAP32,
+                    0xe0...0xff => .NEGATIVE_FIXINT,
+                };
+            }
+            break :blk table;
+        };
+
+        /// Classify a marker byte via `MARKER_LOOKUP_TABLE`.
+        inline fn markerU8To(_: Self, marker_u8: u8) Markers {
+            return MARKER_LOOKUP_TABLE[marker_u8];
+        }
+
+        /// Read the next marker byte and classify it.
+        fn readTypeMarker(self: Self) !Markers {
+            const val = try self.readTypeMarkerU8();
+            return self.markerU8To(val);
+        }
+
+        /// Convert a TRUE/FALSE marker to a bool; any other marker yields
+        /// `TypeMarkerReading`.
+        inline fn readBoolValue(_: Self, marker: Markers) !bool {
+            switch (marker) {
+                .TRUE => return true,
+                .FALSE => return false,
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Read a marker and decode it as a bool.
+        fn readBool(self: Self) !bool {
+            const marker = try self.readTypeMarker();
+            return self.readBoolValue(marker);
+        }
+
+        /// Decode a fixint: the marker byte reinterpreted as i8 is the value.
+        inline fn readFixintValue(_: Self, marker_u8: u8) i8 {
+            return @bitCast(marker_u8);
+        }
+
+        /// Read an i8 payload.
+        inline fn readI8Value(self: Self) !i8 {
+            return self.readTypedInt(i8);
+        }
+
+        /// Read a u8 payload. (The "V8" in the name is kept because callers
+        /// outside this section reference it.)
+        inline fn readV8Value(self: Self) !u8 {
+            return self.readTypedInt(u8);
+        }
+
+        /// Read an i16 payload.
+        inline fn readI16Value(self: Self) !i16 {
+            return self.readTypedInt(i16);
+        }
+
+        /// Read a u16 payload.
+        inline fn readU16Value(self: Self) !u16 {
+            return self.readTypedInt(u16);
+        }
+
+        /// Read an i32 payload.
+        inline fn readI32Value(self: Self) !i32 {
+            return self.readTypedInt(i32);
+        }
+
+        /// Read a u32 payload.
+        inline fn readU32Value(self: Self) !u32 {
+            return self.readTypedInt(u32);
+        }
+
+        /// Read an i64 payload.
+        inline fn readI64Value(self: Self) !i64 {
+            return self.readTypedInt(i64);
+        }
+
+        /// Read a u64 payload.
+        inline fn readU64Value(self: Self) !u64 {
+            return self.readTypedInt(u64);
+        }
+
+        /// Decode any integer encoding as i64.
+        /// Returns `InvalidType` for a UINT64 above `maxInt(i64)` and
+        /// `TypeMarkerReading` for a non-integer marker.
+        fn readIntValue(self: Self, marker_u8: u8) !i64 {
+            const marker = self.markerU8To(marker_u8);
+            // Optimized branch order: handle most common cases first
+            // fixint and 8-bit integers are most common in typical data
+            switch (marker) {
+                .NEGATIVE_FIXINT, .POSITIVE_FIXINT => {
+                    const val = self.readFixintValue(marker_u8);
+                    return val;
+                },
+                .INT8 => {
+                    const val = try self.readI8Value();
+                    return val;
+                },
+                .UINT8 => {
+                    const val = try self.readV8Value();
+                    return val;
+                },
+                .INT16 => {
+                    const val = try self.readI16Value();
+                    return val;
+                },
+                .UINT16 => {
+                    const val = try self.readU16Value();
+                    return val;
+                },
+                .INT32 => {
+                    const val = try self.readI32Value();
+                    return val;
+                },
+                .UINT32 => {
+                    const val = try self.readU32Value();
+                    return val;
+                },
+                .INT64 => {
+                    return self.readI64Value();
+                },
+                .UINT64 => {
+                    const val = try self.readU64Value();
+                    if (val <= std.math.maxInt(i64)) {
+                        return @intCast(val);
+                    }
+                    return MsgPackError.InvalidType;
+                },
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Decode any integer encoding as u64.
+        /// Returns `InvalidType` for a negative signed value and
+        /// `TypeMarkerReading` for any other marker (including negative fixint).
+        fn readUintValue(self: Self, marker_u8: u8) !u64 {
+            const marker = self.markerU8To(marker_u8);
+            // Optimized branch order: handle most common cases first
+            // fixint and 8-bit integers are most common in typical data
+            switch (marker) {
+                .POSITIVE_FIXINT => {
+                    return marker_u8;
+                },
+                .UINT8 => {
+                    const val = try self.readV8Value();
+                    return val;
+                },
+                .INT8 => {
+                    const val = try self.readI8Value();
+                    if (val >= 0) {
+                        return @intCast(val);
+                    }
+                    return MsgPackError.InvalidType;
+                },
+                .UINT16 => {
+                    const val = try self.readU16Value();
+                    return val;
+                },
+                .INT16 => {
+                    const val = try self.readI16Value();
+                    if (val >= 0) {
+                        return @intCast(val);
+                    }
+                    return MsgPackError.InvalidType;
+                },
+                .UINT32 => {
+                    const val = try self.readU32Value();
+                    return val;
+                },
+                .INT32 => {
+                    const val = try self.readI32Value();
+                    if (val >= 0) {
+                        return @intCast(val);
+                    }
+                    return MsgPackError.InvalidType;
+                },
+                .UINT64 => {
+                    return self.readU64Value();
+                },
+                .INT64 => {
+                    const val = try self.readI64Value();
+                    if (val >= 0) {
+                        return @intCast(val);
+                    }
+                    return MsgPackError.InvalidType;
+                },
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Read a 4-byte f32 payload (bit-cast of the u32 decoded by
+        /// `readU32Fast`). Returns `LengthReading` on a short read.
+        inline fn readF32Value(self: Self) !f32 {
+            // Use optimized read for u32
+            var buffer: [4]u8 = undefined;
+            const len = try self.readFrom(&buffer);
+            if (len != 4) {
+                return MsgPackError.LengthReading;
+            }
+            const val_int = readU32Fast(&buffer);
+            const val: f32 = @bitCast(val_int);
+            return val;
+        }
+
+        /// Read an 8-byte f64 payload (bit-cast of the u64 decoded by
+        /// `readU64Fast`). Returns `LengthReading` on a short read.
+        inline fn readF64Value(self: Self) !f64 {
+            // Use optimized read for u64
+            var buffer: [8]u8 = undefined;
+            const len = try self.readFrom(&buffer);
+            if (len != 8) {
+                return MsgPackError.LengthReading;
+            }
+            const val_int = readU64Fast(&buffer);
+            const val: f64 = @bitCast(val_int);
+            return val;
+        }
+
+        /// Decode a FLOAT32 or FLOAT64 payload as f64; other markers yield
+        /// `TypeMarkerReading`.
+        fn readFloatValue(self: Self, marker: Markers) !f64 {
+            switch (marker) {
+                .FLOAT32 => {
+                    const val = try self.readF32Value();
+                    return val;
+                },
+                .FLOAT64 => {
+                    return self.readF64Value();
+                },
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Read a fixstr body; the length is the marker byte minus the FIXSTR
+        /// base. The caller owns the returned slice.
+        fn readFixStrValue(self: Self, allocator: Allocator, marker_u8: u8) ![]const u8 {
+            const len: u8 = marker_u8 - @intFromEnum(Markers.FIXSTR);
+            const str = try self.readData(allocator, len);
+
+            return str;
+        }
+
+        /// Shared implementation of STR8/16/32 reads: reads a `LenType` length,
+        /// then that many bytes. The caller owns the returned slice.
+        /// Returns `StringTooLong` when the declared length exceeds
+        /// `parse_limits.max_string_length`.
+        inline fn readStrValueGeneric(self: Self, comptime LenType: type, allocator: Allocator) ![]const u8 {
+            const len = try self.readTypedInt(LenType);
+            // Enforce the limit before allocating, as the BIN and EXT readers do:
+            // the length comes straight from the input, so checking only after
+            // readData would let a header force an allocation of up to 4 GiB.
+            if (len > parse_limits.max_string_length) {
+                return MsgPackError.StringTooLong;
+            }
+            return try self.readData(allocator, len);
+        }
+
+        /// Read a STR8 body (u8 length prefix).
+        fn readStr8Value(self: Self, allocator: Allocator) ![]const u8 {
+            return try self.readStrValueGeneric(u8, allocator);
+        }
+
+        /// Read a STR16 body (u16 length prefix).
+        fn readStr16Value(self: Self, allocator: Allocator) ![]const u8 {
+            return try self.readStrValueGeneric(u16, allocator);
+        }
+
+        /// Read a STR32 body (u32 length prefix).
+        fn readStr32Value(self: Self, allocator: Allocator) ![]const u8 {
+            return try self.readStrValueGeneric(u32, allocator);
+        }
+
+        /// Read a string body for any string marker; other markers yield
+        /// `TypeMarkerReading`. The caller owns the returned slice.
+        fn readStrValue(self: Self, marker_u8: u8, allocator: Allocator) ![]const u8 {
+            const marker = self.markerU8To(marker_u8);
+
+            switch (marker) {
+                .FIXSTR => {
+                    return self.readFixStrValue(allocator, marker_u8);
+                },
+                .STR8 => {
+                    return self.readStr8Value(allocator);
+                },
+                .STR16 => {
+                    return self.readStr16Value(allocator);
+                },
+                .STR32 => {
+                    return self.readStr32Value(allocator);
+                },
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Return `BinDataLengthTooLong` when `len` exceeds
+        /// `parse_limits.max_bin_length`.
+        inline fn validateBinLength(len: usize) !void {
+            // Inline validation for hot path
+            if (len > parse_limits.max_bin_length) {
+                return MsgPackError.BinDataLengthTooLong;
+            }
+        }
+
+        /// Shared implementation of BIN8/16/32 reads: reads a `LenType` length,
+        /// validates it, then reads that many bytes. The caller owns the slice.
+        inline fn readBinValueGeneric(self: Self, comptime LenType: type, allocator: Allocator) ![]u8 {
+            const len = try self.readTypedInt(LenType);
+            try validateBinLength(len);
+            return try self.readData(allocator, len);
+        }
+
+        /// Read a BIN8 body (u8 length prefix).
+        fn readBin8Value(self: Self, allocator: Allocator) ![]u8 {
+            return try self.readBinValueGeneric(u8, allocator);
+        }
+
+        /// Read a BIN16 body (u16 length prefix).
+        fn readBin16Value(self: Self, allocator: Allocator) ![]u8 {
+            return try self.readBinValueGeneric(u16, allocator);
+        }
+
+        /// Read a BIN32 body (u32 length prefix).
+        fn readBin32Value(self: Self, allocator: Allocator) ![]u8 {
+            return try self.readBinValueGeneric(u32, allocator);
+        }
+
+        /// Read a binary body for any BIN marker; other markers yield
+        /// `TypeMarkerReading`. The caller owns the returned slice.
+        fn readBinValue(self: Self, marker: Markers, allocator: Allocator) ![]u8 {
+            switch (marker) {
+                .BIN8 => {
+                    return self.readBin8Value(allocator);
+                },
+                .BIN16 => {
+                    return self.readBin16Value(allocator);
+                },
+                .BIN32 => {
+                    return self.readBin32Value(allocator);
+                },
+                else => return MsgPackError.TypeMarkerReading,
+            }
+        }
+
+        /// Return `ExtDataTooLarge` when `len` exceeds
+        /// `parse_limits.max_ext_length`.
+        inline fn validateExtLength(len: usize) !void {
+            // Inline validation for hot path
+            if (len > parse_limits.max_ext_length) {
+                return MsgPackError.ExtDataTooLarge;
+            }
+        }
+
+        /// Read an extension body of `len` data bytes: validates `len`, reads
+        /// the i8 type id, then the data. The caller owns `data`.
+        inline fn readExtData(self: Self, allocator: Allocator, len: usize) !EXT {
+            try validateExtLength(len);
+            const ext_type = try self.readI8Value();
+            const data = try self.readData(allocator, len);
+            return EXT{
+                .type = ext_type,
+                .data = data,
+            };
+        }
+
+        /// Whether `marker` is one of the three encodings a timestamp can use
+        /// (FIXEXT4, FIXEXT8, EXT8).
+        inline fn isTimestampCandidate(marker: Markers) bool {
+            return marker == .FIXEXT4 or marker == .FIXEXT8 or marker == .EXT8;
+        }
+
+        /// Read timestamp 32-bit format (seconds only, 0 nanoseconds)
+        inline fn readTimestamp32(self: Self) !Timestamp {
+            const seconds = try self.readU32Value();
+            return Timestamp.new(@intCast(seconds), 0);
+        }
+
+        /// Read timestamp 64-bit format (34-bit seconds + 30-bit nanoseconds).
+        /// NOTE(review): the decoded nanoseconds are not range-checked here,
+        /// whereas `writeTimestamp` rejects values above MAX_NANOSECONDS.
+        inline fn readTimestamp64(self: Self) !Timestamp {
+            const data64 = try self.readU64Value();
+            const nanoseconds: u32 = @intCast(data64 >> TIMESTAMP64_SECONDS_BITS);
+            const seconds: i64 = @intCast(data64 & TIMESTAMP64_SECONDS_MASK);
+            return Timestamp.new(seconds, nanoseconds);
+        }
+
+        /// Read timestamp 96-bit format (32-bit nanoseconds + 64-bit seconds)
+        inline fn readTimestamp96(self: Self) !Timestamp {
+            const nanoseconds = try self.readU32Value();
+            const seconds = try self.readI64Value();
+            return Timestamp.new(seconds, nanoseconds);
+        }
+
+        /// Read the data bytes of a non-timestamp extension whose type id has
+        /// already been consumed. The caller owns the payload's data.
+        /// Returns `ExtDataTooLarge` when `len` exceeds the parse limit and
+        /// `LengthReading` when the stream ends before `len` bytes are read.
+        inline fn readRegularExt(self: Self, ext_type: i8, len: usize, allocator: Allocator) !Payload {
+            try validateExtLength(len);
+            // readData checks the byte count; a short read must not hand back a
+            // partially initialised buffer as if it were a complete extension.
+            const ext_data = try self.readData(allocator, len);
+            return Payload{ .ext = EXT{ .type = ext_type, .data = ext_data } };
+        }
+
+        /// Data length implied by a FIXEXT marker. Must only be called with a
+        /// FIXEXT1/2/4/8/16 marker; anything else is `unreachable`.
+        inline fn getExtLength(marker: Markers) usize {
+            return switch (marker) {
+                .FIXEXT1 => FIXEXT1_LEN,
+                .FIXEXT2 => FIXEXT2_LEN,
+                .FIXEXT4 => FIXEXT4_LEN,
+                .FIXEXT8 => FIXEXT8_LEN,
+                .FIXEXT16 => FIXEXT16_LEN,
+                else => unreachable,
+            };
+        }
+
+        /// Read the EXT8 length byte and report whether it matches the only
+        /// timestamp layout carried in EXT8 (timestamp 96, TIMESTAMP96_DATA_LEN bytes).
+        fn readExt8Length(self: Self) !struct { len: usize, is_timestamp_candidate: bool } {
+            const len = try self.readV8Value();
+            return .{ .len = len, .is_timestamp_candidate = len == TIMESTAMP96_DATA_LEN };
+        }
+
+        /// Read timestamp payload based on marker
+        inline fn readTimestampPayload(self: Self, marker: Markers) !Payload {
+            // Size on the wire for each timestamp layout; checked against the
+            // EXT limit before any payload byte is consumed.
+            const wire_len: usize = switch (marker) {
+                .FIXEXT4 => FIXEXT4_LEN,
+                .FIXEXT8 => FIXEXT8_LEN,
+                .EXT8 => TIMESTAMP96_DATA_LEN,
+                else => unreachable,
+            };
+            try validateExtLength(wire_len);
+
+            return switch (marker) {
+                .FIXEXT4 => Payload{ .timestamp = try self.readTimestamp32() },
+                .FIXEXT8 => Payload{ .timestamp = try self.readTimestamp64() },
+                .EXT8 => Payload{ .timestamp = try self.readTimestamp96() },
+                else => unreachable,
+            };
+        }
+
+        /// Decode an extension, yielding a `.timestamp` payload when the type id
+        /// is TIMESTAMP_EXT_TYPE and the size matches a timestamp layout, and an
+        /// `.ext` payload otherwise.
+        fn readExtValueOrTimestamp(self: Self, marker: Markers, allocator: Allocator) !Payload {
+            // Markers that can never carry a timestamp go straight to the plain reader.
+            if (!isTimestampCandidate(marker)) {
+                return Payload{ .ext = try self.readExtValue(marker, allocator) };
+            }
+
+            // EXT8 stores its length on the wire ahead of the type id; the
+            // FIXEXT forms imply it.
+            var payload_len: usize = undefined;
+            var timestamp_shaped = true;
+            if (marker == .EXT8) {
+                const header = try self.readExt8Length();
+                payload_len = header.len;
+                timestamp_shaped = header.is_timestamp_candidate;
+            } else {
+                payload_len = getExtLength(marker);
+            }
+
+            const type_id = try self.readI8Value();
+            if (timestamp_shaped and type_id == TIMESTAMP_EXT_TYPE) {
+                return self.readTimestampPayload(marker);
+            }
+
+            // Anything else is an ordinary extension with `payload_len` data bytes.
+            return self.readRegularExt(type_id, payload_len, allocator);
+        }
+
+        /// Decode an extension as a plain EXT value. The data length is implied
+        /// by FIXEXT markers and read from the stream for EXT8/16/32.
+        /// Returns `InvalidType` for a non-extension marker.
+        fn readExtValue(self: Self, marker: Markers, allocator: Allocator) !EXT {
+            const data_len: usize = switch (marker) {
+                .FIXEXT1 => FIXEXT1_LEN,
+                .FIXEXT2 => FIXEXT2_LEN,
+                .FIXEXT4 => FIXEXT4_LEN,
+                .FIXEXT8 => FIXEXT8_LEN,
+                .FIXEXT16 => FIXEXT16_LEN,
+                .EXT8 => try self.readV8Value(),
+                .EXT16 => try self.readU16Value(),
+                .EXT32 => try self.readU32Value(),
+                else => return MsgPackError.InvalidType,
+            };
+            return self.readExtData(allocator, data_len);
+        }
+
+        // ========== Iterative Parser State Machine ==========
+
+        /// One frame of the explicit parse stack: a container that is still
+        /// waiting for children.
+        const ParseState = struct {
+            container_type: enum {
+                array, // next child is an array element
+                map_key, // next child is a map key
+                map_value, // next child is the value for `current_key`
+            },
+            data: union(enum) {
+                array: ArrayState,
+                map: MapState,
+            },
+        };
+
+        /// Progress of a partially filled array.
+        const ArrayState = struct {
+            items: []Payload,
+            current_index: usize,
+            total_length: usize,
+        };
+
+        /// Progress of a partially filled map.
+        const MapState = struct {
+            map: Map,
+            current_key: ?Payload,
+            remaining_pairs: usize,
+        };
+
+        /// Release everything owned by the frames still on `stack`; used when
+        /// parsing fails part-way through.
+        fn cleanupParseStack(stack: *std.ArrayList(ParseState), allocator: Allocator) void {
+            while (stack.pop()) |frame| {
+                switch (frame.data) {
+                    .array => |pending| {
+                        // Only the slots filled so far hold live payloads.
+                        for (pending.items[0..pending.current_index]) |filled| {
+                            filled.free(allocator);
+                        }
+                        allocator.free(pending.items);
+                    },
+                    .map => |pending| {
+                        // A key that was read but never paired with a value.
+                        if (pending.current_key) |orphan| {
+                            orphan.free(allocator);
+                        }
+
+                        var owned_map = pending.map;
+                        defer owned_map.deinit();
+
+                        var entries = owned_map.map.iterator();
+                        while (entries.next()) |pair| {
+                            // The keys are owned here, so dropping const to free them is sound.
+                            const owned_key: *Payload = @constCast(pair.key_ptr);
+                            owned_key.free(allocator);
+                            pair.value_ptr.free(allocator);
+                        }
+                    },
+                }
+            }
+        }
+
+        /// Decode a container's element count: the fix form stores it in the
+        /// marker byte (offset by `base`), the 16/32 forms read it from the stream.
+        /// Returns `InvalidType` for any other marker.
+        inline fn readContainerLength(
+            self: Self,
+            marker: Markers,
+            marker_u8: u8,
+            comptime fix_marker: Markers,
+            comptime marker_16: Markers,
+            comptime marker_32: Markers,
+            comptime base: u8,
+        ) !usize {
+            if (marker == fix_marker) return marker_u8 - base;
+            if (marker == marker_16) return try self.readU16Value();
+            if (marker == marker_32) return try self.readU32Value();
+            return MsgPackError.InvalidType;
+        }
+
+        /// Element count of an array header.
+        inline fn readArrayLength(self: Self, marker: Markers, marker_u8: u8) !usize {
+            return self.readContainerLength(marker, marker_u8, .FIXARRAY, .ARRAY16, .ARRAY32, FIXARRAY_BASE);
+        }
+
+        /// Pair count of a map header.
+        inline fn readMapLength(self: Self, marker: Markers, marker_u8: u8) !usize {
+            return self.readContainerLength(marker, marker_u8, .FIXMAP, .MAP16, .MAP32, FIXMAP_BASE);
+        }
+
+        /// Push a frame onto the parse stack. `append` takes the allocator on
+        /// every Zig minor version except 14.
+        inline fn appendToStack(stack: *std.ArrayList(ParseState), allocator: Allocator, item: ParseState) !void {
+            if (current_zig.minor == 14) try stack.append(item) else try stack.append(allocator, item);
+        }
+
+        // ========== End of State Machine Helpers ==========
+
+        /// Decode scalars that need neither allocation nor the parse stack.
+        /// Returns null for every other marker so the caller can fall back to
+        /// the full parser.
+        inline fn readSimpleTypeFast(self: Self, marker: Markers, marker_u8: u8) !?Payload {
+            return switch (marker) {
+                .NIL => Payload{ .nil = {} },
+                .TRUE, .FALSE => Payload{ .bool = marker == .TRUE },
+
+                // Fixints carry their value in the marker byte itself.
+                .POSITIVE_FIXINT => Payload{ .uint = marker_u8 },
+                .NEGATIVE_FIXINT => Payload{ .int = @as(i8, @bitCast(marker_u8)) },
+
+                .UINT8 => Payload{ .uint = try self.readV8Value() },
+                .UINT16 => Payload{ .uint = try self.readU16Value() },
+                .UINT32 => Payload{ .uint = try self.readU32Value() },
+                .UINT64 => Payload{ .uint = try self.readU64Value() },
+
+                .INT8 => Payload{ .int = try self.readI8Value() },
+                .INT16 => Payload{ .int = try self.readI16Value() },
+                .INT32 => Payload{ .int = try self.readI32Value() },
+                .INT64 => Payload{ .int = try self.readI64Value() },
+
+                .FLOAT32 => Payload{ .float = try self.readF32Value() },
+                .FLOAT64 => Payload{ .float = try self.readF64Value() },
+
+                // EXT markers are deliberately absent: telling a timestamp from a
+                // plain extension means consuming the type id, and the stream
+                // cannot be rewound, so every EXT form takes the complex path.
+                else => null,
+            };
+        }
+
+        /// Read one payload from the stream. Release the result with
+        /// `payload.free`. Containers are parsed iteratively, so deep nesting
+        /// does not grow the call stack.
+        pub fn read(self: Self, allocator: Allocator) !Payload {
+            const lead_byte = try self.readTypeMarkerU8();
+            const lead_marker = self.markerU8To(lead_byte);
+
+            // Scalars are finished here without touching the allocator.
+            const fast = try self.readSimpleTypeFast(lead_marker, lead_byte);
+            if (fast) |payload| {
+                return payload;
+            }
+
+            // Strings, binaries, extensions and containers use the full parser.
+            return self.readComplex(allocator, lead_marker, lead_byte);
+        }
+
+        /// Internal iterative parser for complex types (arrays, maps, strings, etc.)
+        fn readComplex(self: Self, allocator: Allocator, first_marker: Markers, first_marker_u8: u8) !Payload {
+            // Explicit stack of open containers; its construction and deinit
+            // signatures differ between Zig minor version 14 and later versions.
+            var parse_stack = if (current_zig.minor == 14)
+                std.ArrayList(ParseState).init(allocator)
+            else
+                std.ArrayList(ParseState).empty;
+            defer if (current_zig.minor == 14) parse_stack.deinit() else parse_stack.deinit(allocator);
+            // Registered after the defer above, so on error the frames' contents
+            // are released before the stack storage itself is freed.
+            errdefer cleanupParseStack(&parse_stack, allocator);
+
+            // Root payload to return
+            var root: ?Payload = null;
+
+            // Start with the already-read first marker
+            var marker_u8 = first_marker_u8;
+            var marker = first_marker;
+
+            // Main loop (replaces recursion)
+            // Process first marker directly, then read subsequent markers in loop
+            var is_first = true;
+            while (true) {
+                // Check depth limit. This runs before every value, so it also
+                // fires for a scalar child once max_depth containers are open.
+                if (parse_stack.items.len >= parse_limits.max_depth) {
+                    return MsgPackError.MaxDepthExceeded;
+                }
+
+                // Read next type marker (skip on first iteration)
+                if (!is_first) {
+                    marker_u8 = try self.readTypeMarkerU8();
+                    marker = self.markerU8To(marker_u8);
+                }
+                is_first = false;
+
+                // Current payload being constructed
+                var current_payload: Payload = undefined;
+                // Cleared only by the non-empty container branches, which
+                // `continue` instead of producing a finished value.
+                var needs_parent_fill = true;
+
+                switch (marker) {
+                    // Simple types: construct directly
+                    .NIL => {
+                        current_payload = Payload{ .nil = void{} };
+                    },
+                    .TRUE, .FALSE => {
+                        const val = try self.readBoolValue(marker);
+                        current_payload = Payload{ .bool = val };
+                    },
+                    .POSITIVE_FIXINT, .UINT8, .UINT16, .UINT32, .UINT64 => {
+                        const val = try self.readUintValue(marker_u8);
+                        current_payload = Payload{ .uint = val };
+                    },
+                    .NEGATIVE_FIXINT, .INT8, .INT16, .INT32, .INT64 => {
+                        const val = try self.readIntValue(marker_u8);
+                        current_payload = Payload{ .int = val };
+                    },
+                    .FLOAT32, .FLOAT64 => {
+                        const val = try self.readFloatValue(marker);
+                        current_payload = Payload{ .float = val };
+                    },
+                    .FIXSTR, .STR8, .STR16, .STR32 => {
+                        const val = try self.readStrValue(marker_u8, allocator);
+
+                        // Validate string length. At this point the string has
+                        // already been allocated and read, hence the explicit free.
+                        if (val.len > parse_limits.max_string_length) {
+                            allocator.free(val);
+                            return MsgPackError.StringTooLong;
+                        }
+
+                        current_payload = Payload{ .str = Str.init(val) };
+                    },
+                    .BIN8, .BIN16, .BIN32 => {
+                        const val = try self.readBinValue(marker, allocator);
+
+                        // Validate binary length
+                        if (val.len > parse_limits.max_bin_length) {
+                            allocator.free(val);
+                            return MsgPackError.BinDataLengthTooLong;
+                        }
+
+                        current_payload = Payload{ .bin = Bin.init(val) };
+                    },
+
+                    // Container types: push to stack and continue
+                    .FIXARRAY, .ARRAY16, .ARRAY32 => {
+                        const len = try self.readArrayLength(marker, marker_u8);
+
+                        // Validate array length
+                        if (len > parse_limits.max_array_length) {
+                            return MsgPackError.ArrayTooLarge;
+                        }
+
+                        // Special case: empty array (complete immediately, never pushed)
+                        if (len == 0) {
+                            const arr = try allocator.alloc(Payload, 0);
+                            current_payload = Payload{ .arr = arr };
+                        } else {
+                            // Allocate array
+                            const arr = try allocator.alloc(Payload, len);
+                            // Covers a failed push only; once pushed, the frame
+                            // (and cleanupParseStack) owns `arr`.
+                            errdefer allocator.free(arr);
+
+                            // Push to stack
+                            try appendToStack(&parse_stack, allocator, .{
+                                .container_type = .array,
+                                .data = .{ .array = .{
+                                    .items = arr,
+                                    .current_index = 0,
+                                    .total_length = len,
+                                } },
+                            });
+
+                            needs_parent_fill = false;
+                            continue; // Continue to read first element
+                        }
+                    },
+
+                    .FIXMAP, .MAP16, .MAP32 => {
+                        const len = try self.readMapLength(marker, marker_u8);
+
+                        // Validate map size
+                        if (len > parse_limits.max_map_size) {
+                            return MsgPackError.MapTooLarge;
+                        }
+
+                        // Special case: empty map (complete immediately, never pushed)
+                        if (len == 0) {
+                            current_payload = Payload{ .map = Map.init(allocator) };
+                        } else {
+                            // Initialize map
+                            var map = Map.init(allocator);
+                            // Flips to true once the frame holds the map, so the
+                            // errdefer below does not deinit it a second time.
+                            var map_owned = false;
+                            errdefer if (!map_owned) map.deinit();
+
+                            const capacity = std.math.cast(u32, len) orelse {
+                                return MsgPackError.MapTooLarge;
+                            };
+                            try map.ensureTotalCapacity(capacity);
+
+                            // Push to stack
+                            try appendToStack(&parse_stack, allocator, .{
+                                .container_type = .map_key,
+                                .data = .{ .map = .{
+                                    .map = map,
+                                    .current_key = null,
+                                    .remaining_pairs = len,
+                                } },
+                            });
+                            map_owned = true;
+
+                            needs_parent_fill = false;
+                            continue; // Continue to read first key
+                        }
+                    },
+
+                    // Extension types (may decode to a timestamp payload)
+                    .FIXEXT1, .FIXEXT2, .FIXEXT4, .FIXEXT8, .FIXEXT16, .EXT8, .EXT16, .EXT32 => {
+                        const ext_result = try self.readExtValueOrTimestamp(marker, allocator);
+                        current_payload = ext_result;
+                    },
+                }
+
+                // Fill parent container or set root
+                if (needs_parent_fill) {
+                    // Add errdefer to clean up current_payload if parent fill fails
+                    errdefer current_payload.free(allocator);
+
+                    if (parse_stack.items.len == 0) {
+                        // No parent, this is the root
+                        root = current_payload;
+                        break;
+                    }
+
+                    // Fill parent and check if complete. A completed container is
+                    // itself a child of the frame below it, so this cascades upward.
+                    while (true) {
+                        const parent = &parse_stack.items[parse_stack.items.len - 1];
+                        const finished = try fillParentContainer(parent, current_payload);
+
+                        if (!finished) {
+                            // Parent needs more elements
+                            break;
+                        }
+
+                        // Parent container is complete, pop it
+                        const completed_state = parse_stack.pop() orelse return MsgPackError.Internal;
+                        const completed_payload = containerToPayload(completed_state);
+
+                        if (parse_stack.items.len == 0) {
+                            // This was the root container
+                            root = completed_payload;
+                            break;
+                        }
+
+                        // Continue with completed container as new current
+                        current_payload = completed_payload;
+                    }
+
+                    if (root != null) break;
+                }
+            }
+
+            return root orelse MsgPackError.Internal;
+        }
+
+        /// Fill parent container with child element.
+        /// Returns true if parent container is complete.
+        /// For maps, a key is parked in `current_key` until its value arrives;
+        /// the pair is then inserted with `putInternal`.
+        /// Returns `Internal` if a value arrives while no key is parked.
+        inline fn fillParentContainer(
+            parent: *ParseState,
+            child: Payload,
+        ) !bool {
+            switch (parent.container_type) {
+                .array => {
+                    // Store the child in the next free slot and advance
+                    var arr_state = &parent.data.array;
+                    arr_state.items[arr_state.current_index] = child;
+                    arr_state.current_index += 1;
+                    return arr_state.current_index >= arr_state.total_length;
+                },
+
+                .map_key => {
+                    // Key can be any Payload type (not just string)
+                    parent.data.map.current_key = child;
+                    parent.container_type = .map_value;
+                    return false; // Still need to read value
+                },
+
+                .map_value => {
+                    var map_state = &parent.data.map;
+                    const key = map_state.current_key orelse return MsgPackError.Internal;
+                    // Use putInternal to avoid cloning already-allocated deserialized keys
+                    try map_state.map.putInternal(key, child);
+                    map_state.current_key = null;
+                    map_state.remaining_pairs -= 1;
+
+                    if (map_state.remaining_pairs == 0) {
+                        return true; // Map complete
+                    }
+
+                    parent.container_type = .map_key;
+                    return false; // Continue reading next key
+                },
+            }
+        }
+
+        /// Convert completed ParseState to Payload: an array frame yields its
+        /// `items` slice, a map frame yields its `map`.
+        inline fn containerToPayload(state: ParseState) Payload {
+            return switch (state.data) {
+                .array => |arr_state| Payload{ .arr = arr_state.items },
+                .map => |map_state| Payload{ .map = map_state.map },
+            };
+        }
+    };
+}
+
+/// Create an instance of msgpack_pack with default limits (backward compatible).
+/// Forwards every parameter to `PackWithLimits` and supplies `DEFAULT_LIMITS`
+/// as the parse limits.
+pub fn Pack(
+    comptime WriteContext: type,
+    comptime ReadContext: type,
+    comptime WriteError: type,
+    comptime ReadError: type,
+    comptime writeFn: fn (context: WriteContext, bytes: []const u8) WriteError!usize,
+    comptime readFn: fn (context: ReadContext, arr: []u8) ReadError!usize,
+) type {
+    return PackWithLimits(
+        WriteContext,
+        ReadContext,
+        WriteError,
+        ReadError,
+        writeFn,
+        readFn,
+        DEFAULT_LIMITS,
+    );
+}
+
+// ============================================================================
+// std.io.Reader and std.io.Writer Support (Zig 0.15+)
+// ============================================================================
+
+/// Check if we're using Zig 0.15 or later with the new I/O system
+const has_new_io = current_zig.minor >= 15;
+
+/// Wrapper context for std.io.Writer (Zig 0.15+).
+/// The `writer` field collapses to `void` when `has_new_io` is false.
+const IoWriterContext = struct {
+    writer: if (has_new_io) *std.Io.Writer else void,
+
+    /// Write all of `bytes` with `writeAll` and report the full length as written.
+    fn write(self: IoWriterContext, bytes: []const u8) !usize {
+        if (!has_new_io) @compileError("std.Io.Writer requires Zig 0.15 or later");
+        try self.writer.writeAll(bytes);
+        return bytes.len;
+    }
+};
+
+/// Wrapper context for std.io.Reader (Zig 0.15+).
+/// The `reader` field collapses to `void` when `has_new_io` is false.
+const IoReaderContext = struct {
+    reader: if (has_new_io) *std.Io.Reader else void,
+
+    /// Fill `buf` completely with `readSliceAll` and report the full length as
+    /// read; a failure to fill it surfaces as the error from `readSliceAll`.
+    fn read(self: IoReaderContext, buf: []u8) !usize {
+        if (!has_new_io) @compileError("std.Io.Reader requires Zig 0.15 or later");
+        try self.reader.readSliceAll(buf);
+        return buf.len;
+    }
+};
+
+/// Type alias for the Pack type used with std.io.Reader/Writer
+/// (`void` when `has_new_io` is false).
+const IoPackType = if (has_new_io) Pack(
+    IoWriterContext,
+    IoReaderContext,
+    std.Io.Writer.Error,
+    std.Io.Reader.Error,
+    IoWriterContext.write,
+    IoReaderContext.read,
+) else void;
+
+/// Packer that works with std.io.Reader and std.io.Writer interfaces (Zig 0.15+)
+///
+/// This provides a convenient wrapper around the generic Pack type for working
+/// with Zig's standard I/O interfaces.
+///
+/// Example:
+/// ```zig
+/// const std = @import("std");
+/// const msgpack = @import("msgpack");
+///
+/// pub fn main() !void {
+///     var gpa = std.heap.GeneralPurposeAllocator(.{}){};
+///     defer _ = gpa.deinit();
+///     const allocator = gpa.allocator();
+///
+///     // Create file for I/O
+///     var file = try std.fs.cwd().createFile("data.msgpack", .{ .read = true });
+///     defer file.close();
+///
+///     // Create reader and writer with buffers
+///     var reader_buf: [4096]u8 = undefined;
+///     var reader = file.reader(&reader_buf);
+///     var writer_buf: [4096]u8 = undefined;
+///     var writer = file.writer(&writer_buf);
+///
+///     // Create packer
+///     var packer = try msgpack.PackerIO.init(&reader, &writer);
+///
+///     // Serialize
+///     var payload = msgpack.Payload.mapPayload(allocator);
+///     defer payload.free(allocator);
+///     try payload.mapPut("name", try msgpack.Payload.strToPayload("Alice", allocator));
+///     try packer.write(payload);
+///
+///     // Flush and reset for reading
+///     try writer.flush();
+///     try file.seekTo(0);
+///     reader.seek = 0;
+///     reader.end = 0;
+///
+///     // Deserialize
+///     const
decoded = try packer.read(allocator); +/// defer decoded.free(allocator); +/// } +/// ``` +pub const PackerIO = if (has_new_io) struct { + packer: IoPackType, + write_ctx: IoWriterContext, + read_ctx: IoReaderContext, + + /// Initialize a PackerIO with std.io.Reader and std.io.Writer + /// + /// The Reader and Writer must remain valid for the lifetime of this PackerIO. + pub fn init( + reader: *std.Io.Reader, + writer: *std.Io.Writer, + ) PackerIO { + const write_ctx = IoWriterContext{ .writer = writer }; + const read_ctx = IoReaderContext{ .reader = reader }; + + return .{ + .packer = IoPackType.init(write_ctx, read_ctx), + .write_ctx = write_ctx, + .read_ctx = read_ctx, + }; + } + + /// Write a Payload to the writer + /// + /// The payload must be freed by the caller after writing. + pub fn write(self: *PackerIO, payload: Payload) !void { + return self.packer.write(payload); + } + + /// Read a Payload from the reader + /// + /// The returned payload must be freed by the caller using `payload.free(allocator)`. + /// Note: The read method uses an iterative parser that is safe for deeply nested data. + pub fn read(self: *PackerIO, allocator: Allocator) !Payload { + return self.packer.read(allocator); + } +} else struct { + pub fn init(_: anytype, _: anytype) @This() { + @compileError("PackerIO requires Zig 0.15 or later"); + } +}; + +/// Convenience function to create a PackerIO from std.io.Reader and std.io.Writer +/// +/// This is a shorthand for `PackerIO.init(reader, writer)`. 
+/// +/// Example: +/// ```zig +/// var reader_buf: [4096]u8 = undefined; +/// var reader = file.reader(&reader_buf); +/// var writer_buf: [4096]u8 = undefined; +/// var writer = file.writer(&writer_buf); +/// var packer = msgpack.packIO(&reader, &writer); +/// ``` +pub fn packIO( + reader: if (has_new_io) *std.Io.Reader else void, + writer: if (has_new_io) *std.Io.Writer else void, +) if (has_new_io) PackerIO else void { + if (!has_new_io) @compileError("packIO requires Zig 0.15 or later"); + return PackerIO.init(reader, writer); +} + +// Export compatibility layer for cross-version support +pub const compat = @import("compat.zig"); diff --git a/ipc-codegen/scripts/convert_schema.ts b/ipc-codegen/scripts/convert_schema.ts new file mode 100644 index 000000000000..0f0bed844be8 --- /dev/null +++ b/ipc-codegen/scripts/convert_schema.ts @@ -0,0 +1,173 @@ +// One-shot converter: old positional IPC schema -> friendly JSONC form. +// Usage: node convert_schema.ts [outFriendly.jsonc] +import { readFileSync, writeFileSync } from "fs"; +import { + SchemaVisitor, + friendlyToPositional, +} from "../src/schema_visitor.ts"; + +const PRIM: Record = { + bool: "bool", + int: "u32", + "unsigned int": "u32", + "unsigned short": "u16", + "unsigned long": "u64", + "unsigned long long": "u64", + "unsigned char": "u8", + double: "f64", + string: "string", + bin32: "bin32", +}; + +function convert(oldPath: string, prefix: string) { + const old = JSON.parse(readFileSync(oldPath, "utf-8")); + const aliases: Record = {}; + const types: Record> = {}; + + const typeToShorthand = (t: any): string => { + if (typeof t === "string") { + return PRIM[t] ?? 
t; // primitive or a named (struct) reference + } + if (Array.isArray(t)) { + const [kind, args] = t; + if (kind === "vector") { + const [el] = args; + if (el === "unsigned char") return "bytes"; + return typeToShorthand(el) + "[]"; + } + if (kind === "array") + return typeToShorthand(args[0]) + "[" + args[1] + "]"; + if (kind === "optional") return typeToShorthand(args[0]) + "?"; + if (kind === "shared_ptr") return typeToShorthand(args[0]); + if (kind === "alias") { + const [name, underlying] = args; + aliases[name] = underlying === "bin32" ? "bin32" : PRIM[underlying]; + if (!aliases[name]) + throw new Error(`alias ${name} underlying ${underlying}`); + return name; + } + throw new Error(`unknown type kind: ${kind}`); + } + if (t && typeof t === "object" && t.__typename) { + const tn = t.__typename as string; + if (!(tn in types)) { + types[tn] = {}; // reserve to break cycles + types[tn] = structFields(t); + } + return tn; + } + throw new Error(`cannot convert type ${JSON.stringify(t)}`); + }; + + const structFields = (struct: any): Record => { + const out: Record = {}; + for (const [k, v] of Object.entries(struct)) { + if (k === "__typename") continue; + out[k] = typeToShorthand(v); + } + return out; + }; + + const commandPairs = old.commands[1] as Array<[string, any]>; + const responsePairs = old.responses[1] as Array<[string, any]>; + const respByName = new Map(responsePairs); + + const errEntry = responsePairs.find(([n]) => n.endsWith("ErrorResponse"))!; + const error = structFields(errEntry[1]); + + const commands: Record = {}; + for (const [cmdName, cmdStruct] of commandPairs) { + const key = cmdName.startsWith(prefix) + ? 
cmdName.slice(prefix.length) + : cmdName; + const request = structFields(cmdStruct); + const respStruct = respByName.get(`${cmdName}Response`); + let response: any; + if (respStruct === undefined) { + throw new Error(`No response named ${cmdName}Response`); + } else if (typeof respStruct === "string") { + response = respStruct.startsWith(prefix) + ? respStruct.slice(prefix.length) + : respStruct; + } else { + response = structFields(respStruct); + } + commands[key] = { request, response }; + } + + return { service: prefix, aliases, types, error, commands }; +} + +// Deep structural equality of CompiledSchema, ignoring Map insertion order. +// Struct references are compared by NAME only: generators emit nested structs +// by name from the top-level `structs` map, so the embedded fields on a struct +// ref are irrelevant (and differ benignly between string-ref and inline forms). +function normalize(c: any): any { + const normType = (t: any): any => { + if (t == null || typeof t !== "object") return t; + if (t.kind === "struct") + return { kind: "struct", structName: t.struct?.name }; + if (t.element) + return { kind: t.kind, size: t.size, element: normType(t.element) }; + return { + kind: t.kind, + primitive: t.primitive, + originalName: t.originalName, + }; + }; + const norm = (s: any) => ({ + name: s.name, + fields: s.fields.map((f: any) => ({ + name: f.name, + type: normType(f.type), + })), + }); + return { + structs: Object.fromEntries( + [...c.structs.entries()].map(([k, v]: any) => [k, norm(v)]).sort(), + ), + responses: Object.fromEntries( + [...c.responses.entries()].map(([k, v]: any) => [k, norm(v)]).sort(), + ), + commands: [...c.commands] + .map((x: any) => ({ + name: x.name, + responseType: x.responseType, + fields: x.fields.map((f: any) => ({ + name: f.name, + type: normType(f.type), + })), + })) + .sort((a, b) => a.name.localeCompare(b.name)), + errorTypeName: c.errorTypeName, + }; +} + +const [oldPath, prefix, outPath] = process.argv.slice(2); +const 
friendly = convert(oldPath, prefix); +const friendlyText = JSON.stringify(friendly, null, 2); +if (outPath) writeFileSync(outPath, friendlyText + "\n"); + +const old = JSON.parse(readFileSync(oldPath, "utf-8")); +const oldCompiled = new SchemaVisitor().visit(old.commands, old.responses); +const { commands, responses } = friendlyToPositional(friendly); +const newCompiled = new SchemaVisitor().visit(commands, responses); + +const a = JSON.stringify(normalize(oldCompiled)); +const b = JSON.stringify(normalize(newCompiled)); +if (a === b) { + console.log( + `ROUND-TRIP OK service=${prefix} commands=${friendly.commands && Object.keys(friendly.commands).length} types=${Object.keys(friendly.types).length} aliases=${Object.keys(friendly.aliases).length}`, + ); +} else { + console.log(`ROUND-TRIP MISMATCH for ${prefix}`); + // show first divergence + for (let i = 0; i < Math.max(a.length, b.length); i++) { + if (a[i] !== b[i]) { + console.log("old:", a.slice(Math.max(0, i - 80), i + 80)); + console.log("new:", b.slice(Math.max(0, i - 80), i + 80)); + break; + } + } + process.exit(1); +} diff --git a/ipc-codegen/src/cpp_codegen.ts b/ipc-codegen/src/cpp_codegen.ts new file mode 100644 index 000000000000..0ca7ace75b1d --- /dev/null +++ b/ipc-codegen/src/cpp_codegen.ts @@ -0,0 +1,751 @@ +/** + * C++ IPC Client Code Generator + * + * Generates a C++ IPC client from a CompiledSchema. 
The generated client: + * - Connects to a server over Unix Domain Socket via ipc::IpcClient + * - Serializes each command to the [name, payload] msgpack framing keyed on + * MSGPACK_SCHEMA_NAME, sends, receives, deserializes + * - Has one method per command, returning the typed response + * + * Usage: + * const gen = new CppCodegen({ namespace: 'my_service', prefix: 'MyService' }); + * const header = gen.generateHeader(schema); + * const impl = gen.generateImpl(schema); + */ + +import type { CompiledSchema, Command } from "./schema_visitor.ts"; +import { + toPascalCase, + toSnakeCase, + toAliasName, + dedupeStructsByName, +} from "./naming.ts"; + +export interface CppCodegenOptions { + /** C++ namespace for generated code, e.g. 'my_service' */ + namespace: string; + /** Prefix for command/response types, e.g. 'MyService' */ + prefix: string; + /** Strip the prefix from method names, e.g. MyServiceGetInfo -> get_info */ + stripMethodPrefix?: boolean; + /** + * Override for the generated output directory include path. + */ + generatedIncludeDir?: string; + /** + * Sub-namespace for wire types (e.g. 'wire' → types in ns::wire). + * When set, standalone types are wrapped in this sub-namespace, + * and the server dispatch deserializes into wire types then converts to domain types. 
+ */ + wireNamespace?: string; +} + +export class CppCodegen { + constructor(private opts: CppCodegenOptions) {} + + private primitiveType(type: import("./schema_visitor.ts").Type): string { + switch (type.primitive) { + case "bool": + return "bool"; + case "u8": + return "uint8_t"; + case "u16": + return "uint16_t"; + case "u32": + return "uint32_t"; + case "u64": + return "uint64_t"; + case "f64": + return "double"; + case "string": + return "std::string"; + case "bytes": + return "std::vector"; + case "bin32": + return "std::array"; + } + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + + /** Convert a command name to a C++ method name (snake_case) */ + private methodName(commandName: string): string { + // With stripMethodPrefix: "CdbGetContractInstance" -> "get_contract_instance" + const withoutPrefix = + this.opts.stripMethodPrefix && commandName.startsWith(this.opts.prefix) + ? commandName.slice(this.opts.prefix.length) + : commandName; + return toSnakeCase(withoutPrefix); + } + + /** Check if the response has fields (non-void return) */ + private hasResponseFields(command: Command, schema: CompiledSchema): boolean { + const resp = schema.responses.get(command.responseType); + return !!resp && resp.fields.length > 0; + } + + /** Generate the method signature using command struct types directly */ + private generateMethodSignature( + command: Command, + schema: CompiledSchema, + className?: string, + ): string { + const method = this.methodName(command.name); + const hasFields = this.hasResponseFields(command, schema); + // Wire types use top-level response names (BbFooResponse). + // Command types with nested Response use Cmd::Response. + const retType = hasFields + ? this.opts.wireNamespace + ? command.responseType + : `${command.name}::Response` + : "void"; + + // If the command has fields, take the whole command struct by value + const params = command.fields.length > 0 ? `${command.name} cmd` : ""; + + const prefix = className ? 
`${className}::` : ""; + + return `${retType} ${prefix}${method}(${params})`; + } + + /** Generate the header file */ + generateHeader(schema: CompiledSchema, schemaHash?: string): string { + const { namespace: ns, prefix } = this.opts; + const wireNs = this.opts.wireNamespace; + const className = `${prefix}IpcClient`; + + const methods = schema.commands + .map((cmd) => { + const sig = this.generateMethodSignature(cmd, schema); + return ` ${sig};`; + }) + .join("\n"); + + const hashConstant = schemaHash + ? `\n/** Schema version hash for compatibility checking */\nstatic constexpr const char SCHEMA_HASH[] = "${schemaHash}";\n` + : ""; + + // When wireNamespace is set, include wire types and bring them into scope + const wireInclude = wireNs + ? `#include "${this.generatedInclude(`${toSnakeCase(prefix)}_types.hpp`)}"\n` + : ""; + const wireUsing = wireNs ? `using namespace ${wireNs};\n` : ""; + + const typesInclude = this.generatedInclude( + `${toSnakeCase(prefix)}_types.hpp`, + ); + + return `// AUTOGENERATED FILE - DO NOT EDIT +#pragma once + +#include "${typesInclude}" +#include "ipc_runtime/ipc_client.hpp" + +#include +#include +#include + +namespace ${ns} { +${wireUsing}${hashConstant} +/** + * @brief Auto-generated IPC client. + * + * Each method sends a msgpack-serialized command to the server and returns + * the typed response. Transport (UDS or MPSC-SHM) is selected by the path + * suffix passed to the constructor: ".sock" → UDS, ".shm" → MPSC-SHM. All + * methods block until the response arrives. + */ +class ${className} { + public: + /** + * @param path Transport path (".sock" → UDS, ".shm" → MPSC-SHM). + * @param call_timeout_ns Per-call send/receive timeout in nanoseconds. + * 0 (the default) means wait indefinitely — commands like proving + * can legitimately take minutes. 
+ */ + explicit ${className}(const std::string& path, uint64_t call_timeout_ns = 0); + ~${className}(); + + ${className}(const ${className}&) = delete; + ${className}& operator=(const ${className}&) = delete; + +${methods} + + private: + template + Resp send(Cmd&& cmd) const; + + mutable std::unique_ptr<::ipc::IpcClient> client_; + uint64_t call_timeout_ns_; +}; + +} // namespace ${ns} +`; + } + + /** Generate the implementation file — hand-rolled [name, payload] serialization */ + generateImpl(schema: CompiledSchema): string { + const { namespace: ns, prefix } = this.opts; + const className = `${prefix}IpcClient`; + const errorType = schema.errorTypeName; + + const methods = schema.commands + .map((cmd) => { + return this.generateMethodImpl(cmd, schema, className); + }) + .join("\n"); + + return `// AUTOGENERATED FILE - DO NOT EDIT + +#include "${this.headerIncludePath()}" + +// THROW/RETHROW satisfy msgpack-c builds with -fno-exceptions support. They +// must be defined before is included (transitively via +// ipc_codegen/msgpack_adaptor.hpp). Under BB_NO_EXCEPTIONS THROW aborts; +// the guard lets a consumer predefine its own variant. +#include "ipc_codegen/throw.hpp" +#include "ipc_codegen/msgpack_adaptor.hpp" + +#include +#include +#include + +// Client-side glue is exception-using and transport-using. Under WASM / +// -fno-exceptions consumers that don't need a transport-based client (e.g. +// in-process FFI callers) can skip the whole translation unit so we don't +// have to thread THROW through every site. 
+#ifndef BB_NO_EXCEPTIONS + +namespace ${ns} { + +${className}::${className}(const std::string& path, uint64_t call_timeout_ns) + : client_(::ipc::make_client(path)) + , call_timeout_ns_(call_timeout_ns) +{ + if (!client_) { + throw std::runtime_error("ipc::make_client: unrecognised path suffix (expected .sock or .shm): " + path); + } + if (!client_->connect()) { + throw std::runtime_error("ipc::IpcClient::connect() failed for " + path); + } +} + +${className}::~${className}() = default; + +template +Resp ${className}::send(Cmd&& cmd) const +{ + // Serialize as [[CommandName, {payload}]] + msgpack::sbuffer send_buffer; + msgpack::packer pk(send_buffer); + pk.pack_array(1); + pk.pack_array(2); + pk.pack(std::string(Cmd::MSGPACK_SCHEMA_NAME)); + pk.pack(std::forward(cmd)); + + // Send request, receive response. + if (!client_->send(send_buffer.data(), send_buffer.size(), call_timeout_ns_)) { + throw std::runtime_error("ipc::IpcClient::send failed"); + } + auto response_view = client_->receive(call_timeout_ns_); + if (response_view.empty()) { + throw std::runtime_error("ipc::IpcClient::receive failed or timed out"); + } + // Copy out before release() — for SHM this gives up zero-copy semantics + // but keeps the rest of the code simple. convert() below copies anyway. 
+ std::vector response_bytes(response_view.begin(), response_view.end()); + client_->release(response_view.size()); + + // Parse response: [ResponseName, {payload}] + auto unpacked = msgpack::unpack( + reinterpret_cast(response_bytes.data()), response_bytes.size()); + auto obj = unpacked.get(); + + if (obj.type != msgpack::type::ARRAY || obj.via.array.size != 2 || + obj.via.array.ptr[0].type != msgpack::type::STR) { + throw std::runtime_error("Invalid response format from server"); + } + + std::string resp_name(obj.via.array.ptr[0].via.str.ptr, obj.via.array.ptr[0].via.str.size); + if (resp_name == "${errorType}") { + std::string message; + auto& payload = obj.via.array.ptr[1]; + // Extract message field from the error map + if (payload.type == msgpack::type::MAP) { + for (uint32_t i = 0; i < payload.via.map.size; ++i) { + auto& kv = payload.via.map.ptr[i]; + if (kv.key.type == msgpack::type::STR) { + std::string key(kv.key.via.str.ptr, kv.key.via.str.size); + if (key == "message" && kv.val.type == msgpack::type::STR) { + message = std::string(kv.val.via.str.ptr, kv.val.via.str.size); + } + } + } + } + throw std::runtime_error("Server error: " + message); + } + if (resp_name != Resp::MSGPACK_SCHEMA_NAME) { + throw std::runtime_error("Expected response '" + std::string(Resp::MSGPACK_SCHEMA_NAME) + + "' but got '" + resp_name + "'"); + } + + Resp result; + obj.via.array.ptr[1].convert(result); + return result; +} + +${methods} +} // namespace ${ns} + +#endif // BB_NO_EXCEPTIONS +`; + } + + /** Generate a single method implementation */ + private generateMethodImpl( + command: Command, + schema: CompiledSchema, + className: string, + ): string { + const sig = this.generateMethodSignature(command, schema, className); + const hasFields = this.hasResponseFields(command, schema); + const respType = this.opts.wireNamespace + ? command.responseType + : `${command.name}::Response`; + + const cmdExpr = + command.fields.length > 0 ? 
"std::move(cmd)" : `${command.name}{}`; + + if (!hasFields) { + return `${sig} +{ + send<${command.name}, ${respType}>(${cmdExpr}); +} +`; + } + + return `${sig} +{ + return send<${command.name}, ${respType}>(${cmdExpr}); +} +`; + } + + /** Get the generated/ directory include prefix. + * Returns either the explicit --cpp-include-dir value (e.g. "myservice/generated") + * or empty for callers that include generated files by their bare filename. */ + private generatedDir(): string { + if (this.opts.generatedIncludeDir) { + return this.opts.generatedIncludeDir; + } + return ""; + } + + /** Form an include path: `/` if dir is non-empty, else bare ``. */ + private generatedInclude(filename: string): string { + const dir = this.generatedDir(); + return dir ? `${dir}/${filename}` : filename; + } + + /** Compute the include path for the generated client header */ + private headerIncludePath(): string { + return this.generatedInclude( + `${toSnakeCase(this.opts.prefix)}_ipc_client.hpp`, + ); + } + + // ----------------------------------------------------------------------- + // Standalone types (no external project dependencies) + // ----------------------------------------------------------------------- + + /** Generate standalone C++ types with MSGPACK_DEFINE_MAP — no external project deps */ + generateStandaloneTypes(schema: CompiledSchema): string { + const { namespace: ns, prefix } = this.opts; + + const aliasTypes = new Map< + string, + { underlying: string; schemaName: string } + >(); + const collect = (type: import("./schema_visitor.ts").Type): void => { + if (type.kind === "primitive" && type.originalName) { + aliasTypes.set(toAliasName(type.originalName), { + underlying: this.primitiveType(type), + schemaName: type.originalName, + }); + } else if ( + type.kind === "vector" || + type.kind === "array" || + type.kind === "optional" + ) { + if (type.element) collect(type.element); + } + }; + for (const s of schema.structs.values()) { + for (const f of s.fields) 
collect(f.type); + } + for (const s of schema.responses.values()) { + for (const f of s.fields) collect(f.type); + } + const aliasDecls = [...aliasTypes.entries()] + .sort(([a], [b]) => a.localeCompare(b)) + .map(([name, { underlying, schemaName }]) => { + // bin32 aliases are nominal types (a fixed 32-byte value with a name), + // so they are distinct wrapper structs. Scalar aliases are transparent + // synonyms — consumers static_cast them to/from enums and integers — + // so they are plain `using`. + if (underlying === "std::array") { + return `struct ${name} : ::ipc::Bin32Alias<${name}> { + using ::ipc::Bin32Alias<${name}>::Bin32Alias; + static constexpr const char MSGPACK_SCHEMA_NAME[] = "${schemaName}"; +};`; + } + return `using ${name} = ${underlying};`; + }) + .join("\n"); + + // Map schema types to C++ types + const mapType = (type: import("./schema_visitor.ts").Type): string => { + switch (type.kind) { + case "primitive": + return type.originalName + ? toAliasName(type.originalName) + : this.primitiveType(type); + case "vector": + return `std::vector<${mapType(type.element!)}>`; + case "array": + return `std::array<${mapType(type.element!)}, ${type.size}>`; + case "optional": + return `std::optional<${mapType(type.element!)}>`; + case "struct": + return type.struct!.name; + } + throw new Error(`Unsupported type kind: ${type.kind}`); + }; + + const allStructs = dedupeStructsByName([ + ...schema.structs.values(), + ...schema.responses.values(), + ]); + const structs = allStructs + .map((s) => { + if (s.fields.length > 20) { + throw new Error( + `Struct '${s.name}' has ${s.fields.length} fields; IPC_CODEGEN_SERIALIZATION_FIELDS supports at most 20. 
` + + `Split the struct or extend the macro in ipc_codegen/msgpack_adaptor.hpp.`, + ); + } + const fields = s.fields + .map((f) => ` ${mapType(f.type)} ${f.name};`) + .join("\n"); + const fieldNames = s.fields.map((f) => f.name).join(", "); + const schemaName = ` static constexpr const char MSGPACK_SCHEMA_NAME[] = "${s.name}";`; + const serialization = fieldNames + ? ` IPC_CODEGEN_SERIALIZATION_FIELDS(${fieldNames})` + : ` template void msgpack(_PackFn&& pack_fn) { pack_fn(); }`; + return `struct ${s.name} {\n${schemaName}\n${fields}\n${serialization}\n bool operator==(const ${s.name}&) const = default;\n};`; + }) + .join("\n\n"); + + return `// AUTOGENERATED FILE - DO NOT EDIT +// Standalone types for ${prefix} service. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +// Pull in THROW/RETHROW: \`throw\` natively, abort-on-throw under +// BB_NO_EXCEPTIONS (WASM). Must be in scope before so msgpack-c +// picks up the right variant. +#include "ipc_codegen/msgpack_include.hpp" + +// --------------------------------------------------------------------------- +// Self-contained serialization macro for generated wire types. +// Defines a msgpack() method that enumerates field name/value pairs. +// --------------------------------------------------------------------------- +#ifndef IPC_CODEGEN_SERIALIZATION_FIELDS +#define _SF_E1(x) #x, x +#define _SF_E2(x, ...) #x, x, _SF_E1(__VA_ARGS__) +#define _SF_E3(x, ...) #x, x, _SF_E2(__VA_ARGS__) +#define _SF_E4(x, ...) #x, x, _SF_E3(__VA_ARGS__) +#define _SF_E5(x, ...) #x, x, _SF_E4(__VA_ARGS__) +#define _SF_E6(x, ...) #x, x, _SF_E5(__VA_ARGS__) +#define _SF_E7(x, ...) #x, x, _SF_E6(__VA_ARGS__) +#define _SF_E8(x, ...) #x, x, _SF_E7(__VA_ARGS__) +#define _SF_E9(x, ...) #x, x, _SF_E8(__VA_ARGS__) +#define _SF_E10(x, ...) #x, x, _SF_E9(__VA_ARGS__) +#define _SF_E11(x, ...) #x, x, _SF_E10(__VA_ARGS__) +#define _SF_E12(x, ...) #x, x, _SF_E11(__VA_ARGS__) +#define _SF_E13(x, ...) 
#x, x, _SF_E12(__VA_ARGS__) +#define _SF_E14(x, ...) #x, x, _SF_E13(__VA_ARGS__) +#define _SF_E15(x, ...) #x, x, _SF_E14(__VA_ARGS__) +#define _SF_E16(x, ...) #x, x, _SF_E15(__VA_ARGS__) +#define _SF_E17(x, ...) #x, x, _SF_E16(__VA_ARGS__) +#define _SF_E18(x, ...) #x, x, _SF_E17(__VA_ARGS__) +#define _SF_E19(x, ...) #x, x, _SF_E18(__VA_ARGS__) +#define _SF_E20(x, ...) #x, x, _SF_E19(__VA_ARGS__) +#define _SF_CNT(_1,_2,_3,_4,_5,_6,_7,_8,_9,_10,_11,_12,_13,_14,_15,_16,_17,_18,_19,_20,N,...) N +#define _SF_NUM(...) _SF_CNT(__VA_ARGS__,20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1) +#define _SF_CAT(a, b) a##b +#define _SF_SEL(n) _SF_CAT(_SF_E, n) +#define _SF_NVP(...) _SF_SEL(_SF_NUM(__VA_ARGS__))(__VA_ARGS__) +#define IPC_CODEGEN_SERIALIZATION_FIELDS(...) \\ + template void msgpack(_PackFn pack_fn) { pack_fn(_SF_NVP(__VA_ARGS__)); } +#endif + +// --------------------------------------------------------------------------- +// Wire aliases for primitive schema aliases. bin32 aliases are nominal wrappers +// carrying their alias name as the MSGPACK_SCHEMA_NAME dispatch tag. 
+// --------------------------------------------------------------------------- + +#ifndef IPC_CODEGEN_BIN32_ALIAS_DEFINED +#define IPC_CODEGEN_BIN32_ALIAS_DEFINED +namespace ipc { +template struct Bin32Alias { + using IPC_CODEGEN_BIN32_ALIAS = void; + std::array value{}; + + Bin32Alias() = default; + Bin32Alias(const std::array& bytes) : value(bytes) {} + Bin32Alias(std::array&& bytes) : value(std::move(bytes)) {} + + uint8_t* data() { return value.data(); } + const uint8_t* data() const { return value.data(); } + constexpr std::size_t size() const { return 32; } + + uint8_t& operator[](std::size_t i) { return value[i]; } + const uint8_t& operator[](std::size_t i) const { return value[i]; } + + auto begin() { return value.begin(); } + auto end() { return value.end(); } + auto begin() const { return value.begin(); } + auto end() const { return value.end(); } + + operator std::array&() { return value; } + operator const std::array&() const { return value; } + + void msgpack_pack(auto& packer) const + { + packer.pack_bin(static_cast(value.size())); + packer.pack_bin_body(reinterpret_cast(value.data()), static_cast(value.size())); + } + + void msgpack_unpack(auto object) + { + if constexpr (requires { object.template as>(); }) { + value = object.template as>(); + } else { + value = static_cast>(object); + } + } + + bool operator==(const Bin32Alias&) const = default; +}; +} // namespace ipc +#endif + +namespace ${ns} { + +${aliasDecls} + +${this.opts.wireNamespace ? `namespace ${this.opts.wireNamespace} {` : ""} + +${structs} + +${this.opts.wireNamespace ? `} // namespace ${this.opts.wireNamespace}` : ""} +} // namespace ${ns} +`; + } + + // ----------------------------------------------------------------------- + // Server-side code generation (uses standalone ipc_server.hpp template) + // ----------------------------------------------------------------------- + + /** Generate the dispatch — header-only, template, no transport dependency. 
*/ + generateDispatchHeader(schema: CompiledSchema): string { + const { namespace: ns, prefix } = this.opts; + const errorTypeName = schema.errorTypeName; + const typesHeader = `${toSnakeCase(prefix)}_types.hpp`; + // Handler declarations — template + const handlerDecls = schema.commands + .map((c) => { + const method = toSnakeCase( + c.name.startsWith(prefix) ? c.name.slice(prefix.length) : c.name, + ); + return `template\nwire::${c.responseType} handle_${method}(Ctx& ctx, wire::${c.name}&& cmd);`; + }) + .join("\n\n"); + + // Handler entries for dispatch map + const handlerEntries = schema.commands + .map((cmd) => { + const method = toSnakeCase( + cmd.name.startsWith(prefix) + ? cmd.name.slice(prefix.length) + : cmd.name, + ); + + const deserialize = + cmd.fields.length > 0 + ? `wire::${cmd.name} wire_cmd; payload.convert(wire_cmd);` + : `wire::${cmd.name} wire_cmd;`; + + return ` { "${cmd.name}", [](Ctx& ctx, [[maybe_unused]] const msgpack::object& payload) -> std::vector { + ${deserialize} + auto wire_resp = handle_${method}(ctx, std::move(wire_cmd)); + if constexpr (requires { ctx.error_message; }) { + if (!ctx.error_message.empty()) { + std::string msg = std::move(ctx.error_message); + ctx.error_message.clear(); + return detail::make_error(msg); + } + } + msgpack::sbuffer buf; + msgpack::packer pk(buf); + pk.pack_array(2); pk.pack(std::string("${cmd.responseType}")); pk.pack(wire_resp); + return std::vector(buf.data(), buf.data() + buf.size()); + } }`; + }) + .join(",\n"); + + return `// AUTOGENERATED FILE - DO NOT EDIT +// Header-only dispatch — template for service context. +#pragma once + +#include "${typesHeader}" +#include "ipc_codegen/msgpack_adaptor.hpp" + +// Pull in THROW/RETHROW — 'throw' natively, abort-on-throw under +// BB_NO_EXCEPTIONS (WASM). ipc_codegen/throw.hpp keeps definitions guarded +// with #ifndef THROW, so a parent project that predefines them wins. 
+#include "ipc_codegen/msgpack_include.hpp" + +#include +#include +#include +#include +#include + +namespace ${ns} { + +// Wire types are in the 'wire' sub-namespace (from ${typesHeader}) +// Handler declarations — implement these in your handler file. +// Template specializations must be visible before make_handler() is instantiated. + +${handlerDecls} + +// --------------------------------------------------------------------------- +// Dispatch — template on service context type +// --------------------------------------------------------------------------- + +namespace detail { + +inline std::vector make_error(const std::string& message) +{ + msgpack::sbuffer buf; + msgpack::packer pk(buf); + pk.pack_array(2); + pk.pack(std::string("${errorTypeName}")); + pk.pack_map(1); + pk.pack(std::string("message")); + pk.pack(message); + return std::vector(buf.data(), buf.data() + buf.size()); +} + +} // namespace detail + +// Dispatcher signature — independent of the chosen IPC server backend. 
+using DispatchHandler = std::function(const std::vector&)>; + +template +DispatchHandler make_${toSnakeCase(prefix)}_handler(Ctx& ctx) +{ + using HandlerFn = std::function(Ctx&, const msgpack::object&)>; + static const std::unordered_map table = { +${handlerEntries}, + }; + + return [&ctx](const std::vector& raw_request) -> std::vector { + auto unpacked = msgpack::unpack( + reinterpret_cast(raw_request.data()), raw_request.size()); + auto obj = unpacked.get(); + + if (obj.type != msgpack::type::ARRAY || obj.via.array.size != 1) { + return detail::make_error("malformed request: expected outer array of size 1"); + } + + auto& inner = obj.via.array.ptr[0]; + if (inner.type != msgpack::type::ARRAY || inner.via.array.size != 2 || + inner.via.array.ptr[0].type != msgpack::type::STR) { + return detail::make_error("malformed request: expected [CommandName, {payload}]"); + } + + std::string cmd_name(inner.via.array.ptr[0].via.str.ptr, inner.via.array.ptr[0].via.str.size); + auto& cmd_payload = inner.via.array.ptr[1]; + + auto it = table.find(cmd_name); + if (it == table.end()) { + return detail::make_error("unknown command: " + cmd_name); + } +#ifdef BB_NO_EXCEPTIONS + return it->second(ctx, cmd_payload); +#else + try { + return it->second(ctx, cmd_payload); + } catch (const std::exception& e) { + std::cerr << "Error processing " << cmd_name << ": " << e.what() << '\\n'; + return detail::make_error(e.what()); + } +#endif + }; +} + +} // namespace ${ns} +`; + } + + /** Generate native IPC server glue for a dispatch header. */ + generateServerHeader(): string { + const { namespace: ns, prefix } = this.opts; + const dispatchHeader = `${toSnakeCase(prefix)}_dispatch.hpp`; + + return `// AUTOGENERATED FILE - DO NOT EDIT +// Native IPC server glue for ${prefix}. 
+#pragma once + +#include "${dispatchHeader}" +#include "ipc_runtime/serve_helper.hpp" +#include "ipc_runtime/signal_handlers.hpp" + +#include +#include +#include +#include + +namespace ${ns} { + +template +void serve(const std::string& input_path, Ctx& ctx) +{ + auto server = ::ipc::make_server(input_path); + if (!server) { + throw std::runtime_error("ipc::make_server: unrecognised path suffix (expected .sock or .shm): " + input_path); + } + ::ipc::install_default_signal_handlers(*server); + if (!server->listen()) { + throw std::runtime_error("ipc::IpcServer::listen() failed for " + input_path); + } + auto handler = make_${toSnakeCase(prefix)}_handler(ctx); + server->run([&handler](int /*client_id*/, std::span raw) { + return handler(std::vector(raw.begin(), raw.end())); + }); +} + +} // namespace ${ns} +`; + } +} diff --git a/ipc-codegen/src/generate.ts b/ipc-codegen/src/generate.ts new file mode 100644 index 000000000000..79ef1a15e1b6 --- /dev/null +++ b/ipc-codegen/src/generate.ts @@ -0,0 +1,612 @@ +// CI trigger +/** + * IPC code generation CLI. + * + * Usage: + * generate.ts --schema --lang --out [flags] + * + * Required: + * --schema JSON schema file + * --lang Target language + * --out Output directory for always-regenerated code + * + * Run with no arguments for the full flag reference. + * + * Zero npm dependencies — runs with Node.js 22+ via --experimental-strip-types. 
+ */ + +import { createHash } from "crypto"; +import { + readFileSync, + writeFileSync, + renameSync, + mkdirSync, + cpSync, + rmSync, +} from "fs"; +import { execSync } from "child_process"; +import { dirname, join, resolve } from "path"; +import { fileURLToPath } from "url"; +import { + SchemaVisitor, + friendlyToPositional, + isFriendlySchema, + stripJsonc, + type CompiledSchema, +} from "./schema_visitor.ts"; +import { TypeScriptCodegen } from "./typescript_codegen.ts"; +import { + defaultBinaryEnvVar, + TypeScriptPackageCodegen, +} from "./typescript_package_codegen.ts"; +import { RustCodegen } from "./rust_codegen.ts"; +import { ZigCodegen } from "./zig_codegen.ts"; +import { CppCodegen } from "./cpp_codegen.ts"; +import { toSnakeCase } from "./naming.ts"; + +// @ts-ignore +const __dirname = dirname(fileURLToPath(import.meta.url)); + +// --------------------------------------------------------------------------- +// Argument parsing +// --------------------------------------------------------------------------- + +interface Args { + schema: string; + lang: string; + out: string; + prefix: string; + server: boolean; + client: boolean; + packageDir: string; + packageName: string; + binaryName: string; + binaryEnvVar: string; + packageTransports: string; + packageIpcPathArgs: string; + ipcRuntimeDependency: string; + cppNamespace: string; + cppWireNamespace: string; + cppIncludeDir: string; + uds: boolean; + ffi: boolean; + curveConstants: string; + stripMethodPrefix: boolean; +} + +function usage(): never { + console.error(`Usage: generate.ts --schema --lang --out [flags] + +Required: + --schema JSON schema file + --lang Target language (ts, rust, zig, cpp) + --out Output directory + +Optional: + --server Generate server dispatch + --client Generate client + --package Generate a TS package shell around a spawned IPC service (ts only) + --package-name TS package name for --package + --binary-name Native service binary name for --package + --binary-env-var Env 
var overriding the binary path for --package + --package-transports Comma-separated transports for --package (uds,shm) + --package-ipc-path-args + Comma-separated binary args for IPC path; use {path} + --ipc-runtime-dependency + package.json dependency spec for @aztec/ipc-runtime + --prefix Type prefix (auto-detected when >= 2 commands share one) + --strip-method-prefix Strip the prefix from generated method names in all + languages (e.g. BbCircuitProve -> circuitProve) + --uds Copy UDS backend templates (rust, zig only) + --ffi Copy in-process FFI backend templates (rust, zig only) + --cpp-namespace C++ namespace (e.g. my::ns) + --cpp-wire-namespace Wire types sub-namespace (default: wire) + --cpp-include-dir Include path for generated dir (e.g. myservice/generated) + --curve-constants Generate TS curve constants from JSON at `); + process.exit(1); +} + +function parseArgs(argv: string[]): Args { + const args: Args = { + schema: "", + lang: "", + out: "", + prefix: "", + server: false, + client: false, + packageDir: "", + packageName: "", + binaryName: "", + binaryEnvVar: "", + packageTransports: "uds", + packageIpcPathArgs: "--socket,{path}", + ipcRuntimeDependency: "@aztec/ipc-runtime", + cppNamespace: "", + cppWireNamespace: "wire", + cppIncludeDir: "", + uds: false, + ffi: false, + curveConstants: "", + stripMethodPrefix: false, + }; + + for (let i = 0; i < argv.length; i++) { + const flag = argv[i]; + const takeValue = (): string => { + const value = argv[++i]; + if (value === undefined || value.startsWith("--")) { + console.error(`Flag ${flag} requires a value`); + process.exit(1); + } + return value; + }; + switch (flag) { + case "--schema": + args.schema = takeValue(); + break; + case "--lang": + args.lang = takeValue(); + break; + case "--out": + args.out = takeValue(); + break; + case "--prefix": + args.prefix = takeValue(); + break; + case "--server": + args.server = true; + break; + case "--client": + args.client = true; + break; + case "--package": + 
args.packageDir = takeValue(); + break; + case "--package-name": + args.packageName = takeValue(); + break; + case "--binary-name": + args.binaryName = takeValue(); + break; + case "--binary-env-var": + args.binaryEnvVar = takeValue(); + break; + case "--package-transports": + args.packageTransports = takeValue(); + break; + case "--package-ipc-path-args": + args.packageIpcPathArgs = takeValue(); + break; + case "--ipc-runtime-dependency": + args.ipcRuntimeDependency = takeValue(); + break; + case "--cpp-namespace": + args.cppNamespace = takeValue(); + break; + case "--cpp-wire-namespace": + args.cppWireNamespace = takeValue(); + break; + case "--cpp-include-dir": + args.cppIncludeDir = takeValue(); + break; + case "--uds": + args.uds = true; + break; + case "--ffi": + args.ffi = true; + break; + case "--curve-constants": + args.curveConstants = takeValue(); + break; + case "--strip-method-prefix": + args.stripMethodPrefix = true; + break; + default: + console.error(`Unknown flag: ${flag}`); + process.exit(1); + } + } + + if (!args.schema || !args.lang || !args.out) { + usage(); + } + if (args.packageDir && args.lang !== "ts") { + console.error(`--package is only supported for --lang ts`); + process.exit(1); + } + if ((args.uds || args.ffi) && args.lang !== "rust" && args.lang !== "zig") { + console.error( + `--uds/--ffi copy backend templates and only apply to rust and zig; ` + + `ts and cpp consume transports from ipc-runtime directly`, + ); + process.exit(1); + } + + return args; +} + +// --------------------------------------------------------------------------- +// Schema loading +// --------------------------------------------------------------------------- + +function computeSchemaHash(schemaJson: string): string { + return createHash("sha256").update(schemaJson).digest("hex"); +} + +function loadSchema(schemaPath: string): { + compiled: CompiledSchema; + schemaHash: string; + service?: string; +} { + const rawJson = readFileSync(schemaPath, "utf-8").trim(); 
+ const parsed = JSON.parse(stripJsonc(rawJson)); + let commandsUnion: any; + let responsesUnion: any; + let service: string | undefined; + if (isFriendlySchema(parsed)) { + ({ + commands: commandsUnion, + responses: responsesUnion, + service, + } = friendlyToPositional(parsed)); + } else { + commandsUnion = parsed.commands; + responsesUnion = parsed.responses; + } + const visitor = new SchemaVisitor(); + const compiled = visitor.visit(commandsUnion, responsesUnion); + const schemaHash = computeSchemaHash(rawJson); + return { compiled, schemaHash, service }; +} + +/** Detect common prefix from command names (e.g. WsdbGetTreeInfo, WsdbCreateFork → Wsdb) */ +function detectPrefix(compiled: CompiledSchema): string { + const names = compiled.commands.map((c) => c.name); + // With a single command the longest common prefix is the entire name and + // stripping it would erase the method name; require an explicit --prefix. + if (names.length < 2) return ""; + let prefix = names[0]; + for (const name of names.slice(1)) { + while (prefix && !name.startsWith(prefix)) { + prefix = prefix.slice(0, -1); + } + } + const words = prefix.match(/[A-Z][a-z]*/g) || []; + let result = ""; + for (const word of words) { + const candidate = result + word; + if (names.every((n) => n.startsWith(candidate))) { + result = candidate; + } else { + break; + } + } + return result; +} + +// --------------------------------------------------------------------------- +// Template copying +// --------------------------------------------------------------------------- + +function copyTemplate(lang: string, filename: string, outDir: string) { + const templatePath = join(__dirname, "..", "templates", lang, filename); + const destPath = join(outDir, filename); + // Atomic write — see writeFile() above for the race this guards. 
+ const tmpPath = `${destPath}.${process.pid}.tmp`; + writeFileSync(tmpPath, readFileSync(templatePath, "utf-8")); + renameSync(tmpPath, destPath); + console.log(` ${destPath} (template)`); +} + +function copyTemplateDir(lang: string, dirname: string, outDir: string) { + const templatePath = join(__dirname, "..", "templates", lang, dirname); + const destPath = join(outDir, dirname); + rmSync(destPath, { recursive: true, force: true }); + cpSync(templatePath, destPath, { recursive: true }); + console.log(` ${destPath} (template)`); +} + +// --------------------------------------------------------------------------- +// C++ clang-format +// --------------------------------------------------------------------------- + +function formatCpp(files: string[]) { + if (files.length === 0) return; + try { + execSync(`clang-format-20 -i ${files.join(" ")}`, { stdio: "ignore" }); + } catch { + // clang-format-20 may not be available + } +} + +// --------------------------------------------------------------------------- +// Generation +// --------------------------------------------------------------------------- + +function generate(args: Args) { + const absSchema = resolve(args.schema); + const absOut = resolve(args.out); + mkdirSync(absOut, { recursive: true }); + + const { compiled, schemaHash, service } = loadSchema(absSchema); + // Friendly schemas fold the type prefix and method-prefix stripping into + // `service`: generated type names are `service + command`, method names are + // the bare command. Positional schemas keep the legacy --prefix/--strip flags. + const prefix = service || args.prefix || detectPrefix(compiled); + const stripMethodPrefix = service ? 
true : args.stripMethodPrefix; + + console.log( + `Schema: ${absSchema} (${compiled.commands.length} commands, prefix=${prefix})`, + ); + + function writeFile(name: string, content: string) { + const path = join(absOut, name); + mkdirSync(dirname(path), { recursive: true }); + // Atomic write: write to a sibling tempfile then rename. Multiple build + // trees can invoke this codegen concurrently against the same source-tree + // output dir; non-atomic writeFileSync can leave a half-written file + // visible to a parallel compiler include, showing up as embedded NUL bytes. + const tmpPath = `${path}.${process.pid}.tmp`; + writeFileSync(tmpPath, content); + renameSync(tmpPath, path); + console.log(` ${path}`); + return path; + } + + const cppFiles: string[] = []; + + switch (args.lang) { + case "ts": { + const gen = new TypeScriptCodegen({ + stripMethodPrefix: stripMethodPrefix ? prefix : undefined, + }); + writeFile("api_types.ts", gen.generateTypes(compiled, schemaHash)); + if (args.server) { + writeFile("server.ts", gen.generateServerApi(compiled)); + // No transport template copy — consumers import UdsIpcServer from + // '@aztec/ipc-runtime' (or hand a compatible byte-handler in). + } + if (args.client || args.packageDir) { + writeFile("async.ts", gen.generateAsyncApi(compiled)); + writeFile("sync.ts", gen.generateSyncApi(compiled)); + // No transport template copy — consumers import IpcClient from + // '@aztec/ipc-runtime' (or hand in a compatible byte backend). 
+ } + if (args.curveConstants) { + generateCurveConstants(absOut, resolve(args.curveConstants)); + } + if (args.packageDir) { + const packageDir = resolve(args.packageDir); + const binaryName = + args.binaryName || toSnakeCase(prefix).replace(/_/g, "-"); + const packageName = + args.packageName || `${toSnakeCase(prefix).replace(/_/g, "-")}-ipc`; + const packageGen = new TypeScriptPackageCodegen({ + prefix, + packageName, + binaryName, + binaryEnvVar: args.binaryEnvVar || defaultBinaryEnvVar(binaryName), + ipcRuntimeDependency: args.ipcRuntimeDependency, + transports: args.packageTransports + .split(",") + .map((t) => t.trim()) + .filter(Boolean), + ipcPathArgs: args.packageIpcPathArgs + .split(",") + .map((arg) => arg.trim()) + .filter(Boolean), + }); + const writePackage = ( + name: string, + content: string, + opts?: { executable?: boolean }, + ) => { + const path = join(packageDir, name); + mkdirSync(dirname(path), { recursive: true }); + const tmpPath = `${path}.${process.pid}.tmp`; + writeFileSync(tmpPath, content); + renameSync(tmpPath, path); + if (opts?.executable) { + try { + execSync(`chmod +x ${path}`); + } catch {} + } + console.log(` ${path} (package)`); + }; + writePackage("package.json", packageGen.generatePackageJson()); + writePackage("tsconfig.json", packageGen.generateTsconfig()); + writePackage("README.md", packageGen.generateReadme()); + writePackage("src/index.ts", packageGen.generateIndex()); + writePackage("src/platform.ts", packageGen.generatePlatform()); + for (const manifest of packageGen.generateArchPackageManifests()) { + writePackage(manifest.path, manifest.content); + } + writePackage( + "scripts/prepare_arch_packages.sh", + packageGen.generatePrepareArchPackagesScript(), + { executable: true }, + ); + } + break; + } + case "rust": { + const gen = new RustCodegen({ + prefix, + stripMethodPrefix: stripMethodPrefix, + }); + writeFile( + `${toSnakeCase(prefix)}_types.rs`, + gen.generateTypes(compiled, schemaHash), + ); + if (args.server) 
{ + writeFile( + `${toSnakeCase(prefix)}_server.rs`, + gen.generateServer(compiled), + ); + } + if (args.client) { + writeFile( + `${toSnakeCase(prefix)}_client.rs`, + gen.generateApi(compiled), + ); + } + // Backend templates (force-overwritten on regeneration). The `Backend` trait + // and `IpcError` type stay shared; ipc-runtime is consumed via the + // separate `ipc-runtime` crate. + if (args.uds || args.ffi) { + copyTemplate("rust", "backend.rs", absOut); + copyTemplate("rust", "error.rs", absOut); + } + if (args.ffi) { + copyTemplate("rust", "ffi_backend.rs", absOut); + } + break; + } + case "zig": { + const gen = new ZigCodegen({ + prefix, + clientName: `${prefix}Client`, + stripMethodPrefix: stripMethodPrefix, + }); + writeFile( + `${toSnakeCase(prefix)}_types.zig`, + gen.generateTypes(compiled, schemaHash), + ); + if (args.server) { + writeFile( + `${toSnakeCase(prefix)}_server.zig`, + gen.generateServer(compiled), + ); + // No transport template copy — consumers wire @import("ipc_runtime") + // (the Zig binding shipped from ipc-runtime/zig/) and use its + // Server.fromPath / listen / run loop directly. + } + if (args.client) { + writeFile( + `${toSnakeCase(prefix)}_client.zig`, + gen.generateClient(compiled), + ); + } + // Backend trait — keep so FFI consumers can plug in their own + // implementation. ipc_runtime.Client satisfies the same contract, + // so UDS/SHM consumers don't need a separate backend file. 
+ if (args.uds || args.ffi) { + copyTemplate("zig", "backend.zig", absOut); + } + if (args.ffi) { + copyTemplate("zig", "ffi_backend.zig", absOut); + } + break; + } + case "cpp": { + const ns = args.cppNamespace || prefix.toLowerCase(); + const wireNs = args.cppWireNamespace; + const gen = new CppCodegen({ + namespace: ns, + prefix, + wireNamespace: wireNs, + generatedIncludeDir: args.cppIncludeDir, + stripMethodPrefix: stripMethodPrefix, + }); + + cppFiles.push( + writeFile( + `${toSnakeCase(prefix)}_types.hpp`, + gen.generateStandaloneTypes(compiled), + ), + ); + copyTemplateDir("cpp", "ipc_codegen", absOut); + if (args.server) { + cppFiles.push( + writeFile( + `${toSnakeCase(prefix)}_dispatch.hpp`, + gen.generateDispatchHeader(compiled), + ), + ); + cppFiles.push( + writeFile( + `${toSnakeCase(prefix)}_ipc_server.hpp`, + gen.generateServerHeader(), + ), + ); + } + if (args.client) { + cppFiles.push( + writeFile( + `${toSnakeCase(prefix)}_ipc_client.hpp`, + gen.generateHeader(compiled, schemaHash), + ), + ); + cppFiles.push( + writeFile( + `${toSnakeCase(prefix)}_ipc_client.cpp`, + gen.generateImpl(compiled), + ), + ); + } + + formatCpp(cppFiles); + break; + } + default: + console.error( + `Unknown language: ${args.lang}. Available: ts, rust, zig, cpp`, + ); + process.exit(1); + } + + console.log("Done."); +} + +// --------------------------------------------------------------------------- +// Curve constants +// --------------------------------------------------------------------------- + +function hexToBigInt(hex: string): bigint { + return BigInt("0x" + hex); +} + +function hexToByteList(hex: string): string { + const bytes: number[] = []; + for (let i = 0; i < hex.length; i += 2) + bytes.push(parseInt(hex.substring(i, i + 2), 16)); + return `new Uint8Array([${bytes.join(", ")}])`; +} + +function serializeCoordinate(coord: string | string[]): string { + return Array.isArray(coord) + ? 
`[${coord.map((c) => hexToByteList(c)).join(", ")}]` + : hexToByteList(coord); +} + +function generateCurveConstants(outputDir: string, constantsPath: string) { + const constants = JSON.parse(readFileSync(constantsPath, "utf-8")); + const content = `// AUTOGENERATED FILE - DO NOT EDIT +export const BN254_FR_MODULUS = ${hexToBigInt(constants.bn254_fr_modulus)}n; +export const BN254_FQ_MODULUS = ${hexToBigInt(constants.bn254_fq_modulus)}n; +export const BN254_G1_GENERATOR = { x: ${serializeCoordinate(constants.bn254_g1_generator.x)}, y: ${serializeCoordinate(constants.bn254_g1_generator.y)} } as const; +export const BN254_G2_GENERATOR = { x: ${serializeCoordinate(constants.bn254_g2_generator.x)}, y: ${serializeCoordinate(constants.bn254_g2_generator.y)} } as const; +export const GRUMPKIN_FR_MODULUS = ${hexToBigInt(constants.grumpkin_fr_modulus)}n; +export const GRUMPKIN_FQ_MODULUS = ${hexToBigInt(constants.grumpkin_fq_modulus)}n; +export const GRUMPKIN_G1_GENERATOR = { x: ${serializeCoordinate(constants.grumpkin_g1_generator.x)}, y: ${serializeCoordinate(constants.grumpkin_g1_generator.y)} } as const; +export const SECP256K1_FR_MODULUS = ${hexToBigInt(constants.secp256k1_fr_modulus)}n; +export const SECP256K1_FQ_MODULUS = ${hexToBigInt(constants.secp256k1_fq_modulus)}n; +export const SECP256K1_G1_GENERATOR = { x: ${serializeCoordinate(constants.secp256k1_g1_generator.x)}, y: ${serializeCoordinate(constants.secp256k1_g1_generator.y)} } as const; +export const SECP256R1_FR_MODULUS = ${hexToBigInt(constants.secp256r1_fr_modulus)}n; +export const SECP256R1_FQ_MODULUS = ${hexToBigInt(constants.secp256r1_fq_modulus)}n; +export const SECP256R1_G1_GENERATOR = { x: ${serializeCoordinate(constants.secp256r1_g1_generator.x)}, y: ${serializeCoordinate(constants.secp256r1_g1_generator.y)} } as const; +`; + mkdirSync(outputDir, { recursive: true }); + const path = join(outputDir, "curve_constants.ts"); + const tmpPath = `${path}.${process.pid}.tmp`; + writeFileSync(tmpPath, 
content); + renameSync(tmpPath, path); + console.log(` ${path}`); +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +const args = parseArgs(process.argv.slice(2)); +generate(args); diff --git a/ipc-codegen/src/naming.ts b/ipc-codegen/src/naming.ts new file mode 100644 index 000000000000..5c2c249726b8 --- /dev/null +++ b/ipc-codegen/src/naming.ts @@ -0,0 +1,75 @@ +/** + * Shared naming utilities for code generators + */ + +/** + * Convert camelCase or PascalCase to snake_case + * @example toSnakeCase("Blake2s") -> "blake2s" + * @example toSnakeCase("poseidonHash") -> "poseidon_hash" + */ +export function toSnakeCase(name: string): string { + return name + .replace(/([A-Z])/g, "_$1") + .toLowerCase() + .replace(/^_/, ""); +} + +/** + * Convert snake_case or camelCase to PascalCase. Interior capitals are + * preserved so the mapping does not destroy information. + * @example toPascalCase("blake2s") -> "Blake2s" + * @example toPascalCase("poseidon_hash") -> "PoseidonHash" + * @example toPascalCase("treeId") -> "TreeId" + */ +export function toPascalCase(name: string): string { + if (!name.includes("_")) { + return name.charAt(0).toUpperCase() + name.slice(1); + } + return name + .split("_") + .map((part) => part.charAt(0).toUpperCase() + part.slice(1)) + .join(""); +} + +/** + * Convert snake_case or PascalCase to camelCase. + * @example toCamelCase("tree_id") -> "treeId" + * @example toCamelCase("forkId") -> "forkId" + */ +export function toCamelCase(name: string): string { + // If no underscores, assume already camelCase (e.g. forkId, classId) + if (!name.includes("_")) { + return name.charAt(0).toLowerCase() + name.slice(1); + } + const pascal = toPascalCase(name); + return pascal.charAt(0).toLowerCase() + pascal.slice(1); +} + +/** + * Convert a schema alias name into its language type name. 
Strips a trailing + * `_t` (uint256_t -> Uint256) and PascalCases the rest, so `fr` -> `Fr`, + * `secp256k1_fr` -> `Secp256k1Fr`, `uint256_t` -> `Uint256`. + */ +export function toAliasName(name: string): string { + const trimmed = name.endsWith("_t") ? name.slice(0, -2) : name; + return toPascalCase(trimmed); +} + +/** + * Deduplicate structs by name, preserving first-seen order. A response can + * reference a type that was also discovered inline as a field (the schema + * dedups the second definition to a name string), so the structs and + * responses maps can hold the same type — it must only be emitted once. + */ +export function dedupeStructsByName( + items: T[], +): T[] { + const seen = new Set(); + const out: T[] = []; + for (const item of items) { + if (seen.has(item.name)) continue; + seen.add(item.name); + out.push(item); + } + return out; +} diff --git a/ipc-codegen/src/rust_codegen.ts b/ipc-codegen/src/rust_codegen.ts new file mode 100644 index 000000000000..d0b34c2ce4b8 --- /dev/null +++ b/ipc-codegen/src/rust_codegen.ts @@ -0,0 +1,772 @@ +/** + * Rust Code Generator - String template based + * + * Philosophy: + * - String templates for file structure + * - Simple type mapping + * - Idiomatic Rust conventions + * - No complex abstraction + */ + +import type { CompiledSchema, Type, Struct, Field } from "./schema_visitor.ts"; +import { toSnakeCase, toPascalCase, toAliasName } from "./naming.ts"; + +export interface RustCodegenOptions { + /** Type prefix, e.g. 'Svc' (used for type/file naming) */ + prefix?: string; + /** Strip the prefix from method names, e.g. SvcGetInfo -> get_info */ + stripMethodPrefix?: boolean; + /** API struct name, e.g. 'SvcApi'. Defaults to 'IpcApi' */ + apiStructName?: string; + /** Import path for Backend trait. Defaults to 'crate::backend::Backend' */ + backendImport?: string; + /** Import path for error types. Defaults to 'crate::error::{IpcError, Result}' */ + errorImport?: string; + /** Import path for generated types. 
Defaults to 'crate::types_gen::*' */ + typesImport?: string; + /** Module doc comment for types file */ + typesDocComment?: string; + /** Module doc comment for api file */ + apiDocComment?: string; +} + +export class RustCodegen { + private errorTypeName: string = "ErrorResponse"; + private opts: Required; + + constructor(options?: RustCodegenOptions) { + const prefix = options?.prefix ?? ""; + const name = prefix || "Ipc"; + this.opts = { + prefix, + stripMethodPrefix: options?.stripMethodPrefix ?? false, + apiStructName: options?.apiStructName ?? `${name}Api`, + backendImport: options?.backendImport ?? "super::backend::Backend", + errorImport: options?.errorImport ?? `super::error::{IpcError, Result}`, + typesImport: + options?.typesImport ?? + `super::${toSnakeCase(prefix || "ipc")}_types::*`, + typesDocComment: + options?.typesDocComment ?? `Generated types for ${name} IPC protocol`, + apiDocComment: options?.apiDocComment ?? `${name} IPC client API`, + }; + } + + private primitiveType(type: Type): string { + switch (type.primitive) { + case "bool": + return "bool"; + case "u8": + return "u8"; + case "u16": + return "u16"; + case "u32": + return "u32"; + case "u64": + return "u64"; + case "f64": + return "f64"; + case "string": + return "String"; + case "bytes": + return "Vec"; + case "bin32": + return "Bin32"; + } + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + + // Type mapping: Schema type -> Rust type + private mapType(type: Type): string { + switch (type.kind) { + case "primitive": + return type.originalName + ? toAliasName(type.originalName) + : this.primitiveType(type); + + case "vector": + return `Vec<${this.mapType(type.element!)}>`; + + case "array": + const elemType = this.mapType(type.element!); + // Large arrays become Vec for ergonomics + return type.size! > 32 + ? 
`Vec<${elemType}>` + : `[${elemType}; ${type.size}]`; + + case "optional": + return `Option<${this.mapType(type.element!)}>`; + + case "struct": + // Convert struct names to PascalCase for Rust conventions + return toPascalCase(type.struct!.name); + } + + throw new Error(`Unsupported type kind: ${type.kind}`); + } + + // Check if field needs serde(with = "serde_bytes") + private needsSerdeBytes(type: Type): boolean { + return type.kind === "primitive" && type.primitive === "bytes"; + } + + // Check if field needs serde(with = "serde_vec_bytes") + private needsSerdeVecBytes(type: Type): boolean { + return type.kind === "vector" && this.needsSerdeBytes(type.element!); + } + + // Check if field needs serde(with = "serde_bytes_array") - for [Vec; N]. + // Only applies up to the size-32 cutoff in mapType; larger arrays become + // Vec and take the serde_vec_bytes path. + private needsSerdeBytesArray(type: Type): boolean { + return ( + type.kind === "array" && + type.size! <= 32 && + this.needsSerdeBytes(type.element!) + ); + } + + // Check if field needs serde(with = "serde_vec_bytes") via the large-array + // fallback ([bytes; N>32] maps to Vec>). + private needsSerdeLargeBytesArray(type: Type): boolean { + return ( + type.kind === "array" && + type.size! > 32 && + this.needsSerdeBytes(type.element!) 
+ ); + } + + // Check if field needs serde(with = "serde_opt_bytes") + private needsSerdeOptBytes(type: Type): boolean { + return type.kind === "optional" && this.needsSerdeBytes(type.element!); + } + + // Generate struct field + private generateField(field: Field): string { + const rustName = toSnakeCase(field.name); + const rustType = this.mapType(field.type); + let attrs = ""; + + // Add serde rename if needed + if (field.name !== rustName) { + attrs += ` #[serde(rename = "${field.name}")]\n`; + } + + // Add serde bytes handling + if (this.needsSerdeBytesArray(field.type)) { + attrs += ` #[serde(with = "serde_bytes_array")]\n`; + } else if ( + this.needsSerdeVecBytes(field.type) || + this.needsSerdeLargeBytesArray(field.type) + ) { + attrs += ` #[serde(with = "serde_vec_bytes")]\n`; + } else if (this.needsSerdeOptBytes(field.type)) { + attrs += ` #[serde(with = "serde_opt_bytes")]\n`; + } else if (this.needsSerdeBytes(field.type)) { + attrs += ` #[serde(with = "serde_bytes")]\n`; + } + + return `${attrs} pub ${rustName}: ${rustType},`; + } + + // Generate a struct definition + private generateStruct(struct: Struct, isCommand: boolean): string { + const rustName = toPascalCase(struct.name); + const fields = struct.fields.map((f) => this.generateField(f)).join("\n"); + + // Add serde rename if struct name changed + const serdeRename = + struct.name !== rustName ? `\n#[serde(rename = "${struct.name}")]` : ""; + + // Generate constructor for commands + const constructor = isCommand + ? 
this.generateConstructor(struct, rustName) + : ""; + + return `/// ${struct.name} +#[derive(Debug, Clone, Serialize, Deserialize)]${serdeRename} +pub struct ${rustName} { +${fields} +}${constructor}`; + } + + // Generate constructor for command structs + private generateConstructor(struct: Struct, rustName: string): string { + const params = struct.fields + .map((f) => `${toSnakeCase(f.name)}: ${this.mapType(f.type)}`) + .join(", "); + + const fieldInits = struct.fields + .map((f) => ` ${toSnakeCase(f.name)},`) + .join("\n"); + + return ` + +impl ${rustName} { + pub fn new(${params}) -> Self { + Self { +${fieldInits} + } + } +}`; + } + + // Generate Command enum + private generateCommandEnum(schema: CompiledSchema): string { + const names = schema.commands.map((c) => c.name); + const variants = names + .map((name) => { + const rustName = toPascalCase(name); + return ` ${rustName}(${rustName}),`; + }) + .join("\n"); + + const serializeCases = names + .map((name) => { + const rustName = toPascalCase(name); + return ` Command::${rustName}(data) => { + tuple.serialize_element("${name}")?; + tuple.serialize_element(data)?; + }`; + }) + .join("\n"); + + const deserializeCases = names + .map((name) => { + const rustName = toPascalCase(name); + return ` "${name}" => { + let data = seq.next_element()? 
+ .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + Ok(Command::${rustName}(data)) + }`; + }) + .join("\n"); + + const variantNames = names.map((name) => `"${name}"`).join(", "); + + return `/// Command enum - wraps all possible commands +#[derive(Debug, Clone)] +pub enum Command { +${variants} +} + +impl Serialize for Command { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + use serde::ser::SerializeTuple; + let mut tuple = serializer.serialize_tuple(2)?; + match self { +${serializeCases} + } + tuple.end() + } +} + +impl<'de> Deserialize<'de> for Command { + fn deserialize(deserializer: D) -> Result + where D: serde::Deserializer<'de> { + use serde::de::{SeqAccess, Visitor}; + struct CommandVisitor; + + impl<'de> Visitor<'de> for CommandVisitor { + type Value = Command; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 2-element array [name, payload]") + } + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let name: String = seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + match name.as_str() { +${deserializeCases} + _ => Err(serde::de::Error::unknown_variant(&name, &[${variantNames}])), + } + } + } + deserializer.deserialize_tuple(2, CommandVisitor) + } +}`; + } + + // Generate Response enum + private generateResponseEnum(schema: CompiledSchema): string { + // Include all response types from commands plus ErrorResponse if it exists + const commandResponseTypes = Array.from( + new Set(schema.commands.map((c) => c.responseType)), + ); + const errorName = schema.errorTypeName; + const responseTypes = schema.responses.has(errorName) + ? 
[...commandResponseTypes, errorName] + : commandResponseTypes; + const variants = responseTypes + .map((name) => { + const rustName = toPascalCase(name); + return ` ${rustName}(${rustName}),`; + }) + .join("\n"); + + const serializeCases = responseTypes + .map((name) => { + const rustName = toPascalCase(name); + return ` Response::${rustName}(data) => { + tuple.serialize_element("${name}")?; + tuple.serialize_element(data)?; + }`; + }) + .join("\n"); + + const deserializeCases = responseTypes + .map((name) => { + const rustName = toPascalCase(name); + return ` "${name}" => { + let data = seq.next_element()? + .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + Ok(Response::${rustName}(data)) + }`; + }) + .join("\n"); + + const variantNames = responseTypes.map((name) => `"${name}"`).join(", "); + + return `/// Response enum - wraps all possible responses +#[derive(Debug, Clone)] +pub enum Response { +${variants} +} + +impl Serialize for Response { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + use serde::ser::SerializeTuple; + let mut tuple = serializer.serialize_tuple(2)?; + match self { +${serializeCases} + } + tuple.end() + } +} + +impl<'de> Deserialize<'de> for Response { + fn deserialize(deserializer: D) -> Result + where D: serde::Deserializer<'de> { + use serde::de::{SeqAccess, Visitor}; + struct ResponseVisitor; + + impl<'de> Visitor<'de> for ResponseVisitor { + type Value = Response; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 2-element array [name, payload]") + } + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let name: String = seq.next_element()? 
+ .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + match name.as_str() { +${deserializeCases} + _ => Err(serde::de::Error::unknown_variant(&name, &[${variantNames}])), + } + } + } + deserializer.deserialize_tuple(2, ResponseVisitor) + } +}`; + } + + // Generate serde helper modules + private generateSerdeHelpers(): string { + return `mod serde_bytes { + use serde::{Deserialize, Deserializer, Serializer}; + pub fn serialize(bytes: &Vec, serializer: S) -> Result + where S: Serializer { serializer.serialize_bytes(bytes) } + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where D: Deserializer<'de> { >::deserialize(deserializer) } +} + +mod serde_vec_bytes { + use serde::{Deserialize, Deserializer, Serializer, Serialize}; + use serde::ser::SerializeSeq; + use serde::de::{SeqAccess, Visitor}; + + #[derive(Serialize, Deserialize)] + struct BytesWrapper(#[serde(with = "super::serde_bytes")] Vec); + + pub fn serialize(vec: &Vec>, serializer: S) -> Result + where S: Serializer { + let mut seq = serializer.serialize_seq(Some(vec.len()))?; + for bytes in vec { + seq.serialize_element(&BytesWrapper(bytes.clone()))?; + } + seq.end() + } + pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + where D: Deserializer<'de> { + struct VecVecU8Visitor; + impl<'de> Visitor<'de> for VecVecU8Visitor { + type Value = Vec>; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence of byte arrays") + } + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let mut vec = Vec::new(); + while let Some(wrapper) = seq.next_element::()? 
{ + vec.push(wrapper.0); + } + Ok(vec) + } + } + deserializer.deserialize_seq(VecVecU8Visitor) + } +} + +mod serde_bytes_array { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + use serde::ser::SerializeTuple; + use serde::de::{SeqAccess, Visitor}; + + #[derive(Serialize, Deserialize)] + struct BytesWrapper(#[serde(with = "super::serde_bytes")] Vec); + + pub fn serialize(arr: &[Vec; N], serializer: S) -> Result + where S: Serializer { + let mut tup = serializer.serialize_tuple(N)?; + for bytes in arr { + tup.serialize_element(&BytesWrapper(bytes.clone()))?; + } + tup.end() + } + pub fn deserialize<'de, D, const N: usize>(deserializer: D) -> Result<[Vec; N], D::Error> + where D: Deserializer<'de> { + struct ArrayVisitor; + impl<'de, const N: usize> Visitor<'de> for ArrayVisitor { + type Value = [Vec; N]; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "an array of {N} byte arrays") + } + fn visit_seq(self, mut seq: A) -> Result + where A: SeqAccess<'de> { + let mut arr: [Vec; N] = std::array::from_fn(|_| Vec::new()); + for (i, item) in arr.iter_mut().enumerate() { + *item = seq.next_element::()? 
+ .ok_or_else(|| serde::de::Error::invalid_length(i, &self))?.0; + } + Ok(arr) + } + } + deserializer.deserialize_tuple(N, ArrayVisitor::) + } +} + +mod serde_opt_bytes { + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + #[derive(Serialize, Deserialize)] + struct BytesWrapper(#[serde(with = "super::serde_bytes")] Vec); + + pub fn serialize(opt: &Option>, serializer: S) -> Result + where S: Serializer { + match opt { + Some(bytes) => serializer.serialize_some(&BytesWrapper(bytes.clone())), + None => serializer.serialize_none(), + } + } + pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + where D: Deserializer<'de> { + Ok(Option::::deserialize(deserializer)?.map(|w| w.0)) + } +}`; + } + + // Generate types file + generateTypes(schema: CompiledSchema, schemaHash?: string): string { + this.errorTypeName = schema.errorTypeName; + // Command structs get a generated `new()` constructor; response/shared + // structs do not. + const commandNames = new Set(schema.commands.map((c) => c.name)); + + const aliasTypes = new Map(); + const collect = (type: Type): void => { + if (type.kind === "primitive" && type.originalName) { + aliasTypes.set( + toAliasName(type.originalName), + type.primitive === "bin32" ? 
"Bin32" : this.primitiveType(type), + ); + } else if ( + type.kind === "vector" || + type.kind === "array" || + type.kind === "optional" + ) { + if (type.element) collect(type.element); + } + }; + for (const s of schema.structs.values()) { + for (const f of s.fields) collect(f.type); + } + for (const s of schema.responses.values()) { + for (const f of s.fields) collect(f.type); + } + const aliasDecls = [...aliasTypes.entries()] + .sort(([a], [b]) => a.localeCompare(b)) + .map(([name, underlying]) => `pub type ${name} = ${underlying};`) + .join("\n"); + + // Generate all structs (commands first, then responses) + const commandStructs = Array.from(schema.structs.values()) + .map((s) => this.generateStruct(s, commandNames.has(s.name))) + .join("\n\n"); + + // A response can reference a type also discovered inline as a field + // (registered in structs); emit it only once, from commandStructs. + const responseStructs = Array.from(schema.responses.values()) + .filter((s) => !schema.structs.has(s.name)) + .map((s) => this.generateStruct(s, false)) + .join("\n\n"); + + const hashLine = schemaHash + ? `\n/// Schema version hash for compatibility checking\npub const SCHEMA_HASH: &str = "${schemaHash}";\n` + : ""; + + return `//! AUTOGENERATED - DO NOT EDIT +//! ${this.opts.typesDocComment} + +use serde::{Deserialize, Serialize}; +${hashLine} +/// 32 raw bytes encoded as msgpack bin32. Primitive schema aliases below are +/// zero-cost pub type declarations over either this newtype or a scalar. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Bin32(pub [u8; 32]); + +impl Bin32 { + pub fn from_bytes(bytes: [u8; 32]) -> Self { Self(bytes) } + pub fn to_bytes(&self) -> &[u8; 32] { &self.0 } + pub fn as_slice(&self) -> &[u8] { &self.0 } +} + +impl Serialize for Bin32 { + fn serialize(&self, serializer: S) -> Result + where S: serde::Serializer { + serializer.serialize_bytes(&self.0) + } +} + +impl<'de> Deserialize<'de> for Bin32 { + fn deserialize(deserializer: D) -> Result + where D: serde::Deserializer<'de> { + let bytes: Vec = >::deserialize(deserializer)?; + let arr: [u8; 32] = bytes.try_into() + .map_err(|v: Vec| serde::de::Error::invalid_length(v.len(), &"32 bytes"))?; + Ok(Bin32(arr)) + } +} + +${aliasDecls} + +${this.generateSerdeHelpers()} + +${commandStructs} + +${responseStructs} + +${this.generateCommandEnum(schema)} + +${this.generateResponseEnum(schema)} +`; + } + + /** Convert a command name to a Rust method name (snake_case) */ + private methodName(commandName: string): string { + const withoutPrefix = + this.opts.stripMethodPrefix && + this.opts.prefix && + commandName.startsWith(this.opts.prefix) + ? commandName.slice(this.opts.prefix.length) + : commandName; + return toSnakeCase(withoutPrefix); + } + + // Generate API method + private generateApiMethod(command: { + name: string; + fields: Field[]; + responseType: string; + }): string { + const methodName = this.methodName(command.name); + const cmdRustName = toPascalCase(command.name); + const respRustName = toPascalCase(command.responseType); + + const params = command.fields + .map((f) => { + const rustType = this.mapType(f.type); + // Only convert simple Vec to &[u8], not nested types + const apiType = rustType === "Vec" ? 
"&[u8]" : rustType; + return `${toSnakeCase(f.name)}: ${apiType}`; + }) + .join(", "); + + const paramConversions = command.fields + .map((f) => { + const name = toSnakeCase(f.name); + const rustType = this.mapType(f.type); + // Only convert slices back to Vec + if (rustType === "Vec") { + return `${name}.to_vec()`; + } + return name; + }) + .join(", "); + + // Extract error type name from the error import (e.g., 'IpcError' from 'crate::error::{IpcError, Result}') + const errorType = + this.opts.errorImport.match(/\{(\w+),/)?.[1] ?? "IpcError"; + + return ` /// Execute ${command.name} + pub fn ${methodName}(&mut self, ${params}) -> Result<${respRustName}> { + let cmd = Command::${cmdRustName}(${cmdRustName}::new(${paramConversions})); + match self.execute(cmd)? { + Response::${respRustName}(resp) => Ok(resp), + Response::${toPascalCase(this.errorTypeName)}(err) => Err(${errorType}::Backend( + err.message + )), + _ => Err(${errorType}::InvalidResponse( + "Expected ${command.responseType}".to_string() + )), + } + }`; + } + + // Generate API file + generateApi(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const { + apiStructName, + backendImport, + errorImport, + typesImport, + apiDocComment, + } = this.opts; + + const apiMethods = schema.commands + .map((c) => this.generateApiMethod(c)) + .join("\n\n"); + + const errorType = errorImport.match(/\{(\w+),/)?.[1] ?? "IpcError"; + + return `//! AUTOGENERATED - DO NOT EDIT +//! 
${apiDocComment} + +use ${backendImport}; +use ${errorImport}; +use ${typesImport}; + +/// ${apiDocComment} +pub struct ${apiStructName} { + backend: B, +} + +impl ${apiStructName} { + /// Create API with custom backend + pub fn new(backend: B) -> Self { + Self { backend } + } + + fn execute(&mut self, command: Command) -> Result { + let input_buffer = rmp_serde::to_vec_named(&vec![command]) + .map_err(|e| ${errorType}::Serialization(e.to_string()))?; + + let output_buffer = self.backend.call(&input_buffer)?; + + let response: Response = rmp_serde::from_slice(&output_buffer) + .map_err(|e| ${errorType}::Deserialization(e.to_string()))?; + + Ok(response) + } + +${apiMethods} + /// Destroy backend without shutdown command + pub fn destroy(&mut self) -> Result<()> { + self.backend.destroy() + } +} +`; + } + + // ----------------------------------------------------------------------- + // Server-side code generation + // ----------------------------------------------------------------------- + + /** Generate a Handler trait and serve() function */ + generateServer(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const { prefix, errorImport, typesImport } = this.opts; + const errorRespType = toPascalCase(this.errorTypeName); + + const traitMethods = schema.commands + .map((c) => { + const methodName = this.methodName(c.name); + const cmdRustName = toPascalCase(c.name); + const respRustName = toPascalCase(c.responseType); + return ` fn ${methodName}(&mut self, cmd: ${cmdRustName}) -> Result<${respRustName}>;`; + }) + .join("\n"); + + const dispatchArms = schema.commands + .map((c) => { + const methodName = this.methodName(c.name); + const cmdRustName = toPascalCase(c.name); + const respRustName = toPascalCase(c.responseType); + return ` Command::${cmdRustName}(cmd) => { + match handler.${methodName}(cmd) { + Ok(resp) => Response::${respRustName}(resp), + Err(e) => Response::${errorRespType}(${errorRespType} { message: e.to_string() }), + } 
+ }`; + }) + .join("\n"); + + return `//! AUTOGENERATED - DO NOT EDIT +//! Server-side dispatch for ${prefix || "service"} IPC protocol + +// The import set is shared with the generated client; dispatch only needs +// Result, so tolerate the unused error type. +#[allow(unused_imports)] +use ${errorImport}; +use ${typesImport}; + +/// Handler trait — implement this to serve ${prefix || "service"} commands. +pub trait Handler { +${traitMethods} +} + +/// Dispatch a single command to the handler and return the response. +pub fn dispatch(handler: &mut dyn Handler, command: Command) -> Result { + let response = match command { +${dispatchArms} + }; + Ok(response) +} + +/// Decode a framed request, dispatch it, and encode the framed response. +/// All failures (malformed framing included) produce the schema error +/// variant, so transports can use this directly as their request handler. +pub fn handle_request(handler: &mut dyn Handler, request_bytes: &[u8]) -> Vec { + let response = match rmp_serde::from_slice::>(request_bytes) { + Err(e) => Response::${errorRespType}(${errorRespType} { + message: format!("malformed request: {e}"), + }), + Ok(commands) => match commands.into_iter().next() { + None => Response::${errorRespType}(${errorRespType} { + message: "malformed request: empty command array".to_string(), + }), + Some(command) => match dispatch(handler, command) { + Ok(resp) => resp, + Err(e) => Response::${errorRespType}(${errorRespType} { + message: e.to_string(), + }), + }, + }, + }; + rmp_serde::to_vec_named(&response).unwrap_or_default() +} +`; + } +} diff --git a/ipc-codegen/src/schema_visitor.ts b/ipc-codegen/src/schema_visitor.ts new file mode 100644 index 000000000000..8e98ea9d5424 --- /dev/null +++ b/ipc-codegen/src/schema_visitor.ts @@ -0,0 +1,678 @@ +/** + * Schema Visitor - Minimal abstraction over raw msgpack schema + * + * Philosophy: + * - Keep raw schema structure + * - Resolve type references into a graph + * - No normalization - languages handle 
their own conventions + * - Output is "compiled schema" with resolved types + */ + +import { toSnakeCase, toCamelCase } from "./naming.ts"; + +export type PrimitiveType = + | "bool" + | "u8" + | "u16" + | "u32" + | "u64" + | "f64" + | "string" + | "bytes" + | "bin32"; + +export interface Type { + kind: "primitive" | "vector" | "array" | "optional" | "struct"; + primitive?: PrimitiveType; + element?: Type; // For vector, array, optional + size?: number; // For array + struct?: Struct; // For struct types + originalName?: string; // Alias name from schema, when present. +} + +export interface Field { + name: string; + type: Type; +} + +export interface Struct { + name: string; + fields: Field[]; +} + +export interface Command { + name: string; + fields: Field[]; + responseType: string; +} + +export interface CompiledSchema { + // All unique struct types discovered + structs: Map; + + // Command -> Response mappings + commands: Command[]; + + // Response types + responses: Map; + + // Error response type name (e.g. 'WsdbErrorResponse'). Always present: + // schema validation rejects schemas without an error variant. + errorTypeName: string; +} + +/** + * Words that are keywords (or otherwise unusable as plain identifiers) in at + * least one target language. Field names whose snake_case or camelCase + * projection lands here would generate broken code. 
+ */ +const RESERVED_WORDS = new Set([ + // Rust + "type", + "fn", + "match", + "impl", + "trait", + "mod", + "use", + "ref", + "self", + "super", + "crate", + "move", + "dyn", + "async", + "await", + "loop", + "where", + // Zig + "error", + "var", + "comptime", + "defer", + "errdefer", + "test", + "union", + "undefined", + "unreachable", + "orelse", + "and", + "or", + // C++ + "namespace", + "int", + "char", + "short", + "long", + "float", + "double", + "signed", + "unsigned", + "register", + "template", + "typename", + "operator", + "virtual", + "inline", + "friend", + "mutable", + "explicit", + "export", + "this", + "delete", + // JS/TS + "new", + "class", + "function", + "extends", + "instanceof", + "typeof", + "in", + "void", + "with", + "yield", + "let", + // Shared / common + "const", + "static", + "struct", + "enum", + "if", + "else", + "for", + "while", + "do", + "switch", + "case", + "default", + "break", + "continue", + "return", + "true", + "false", + "null", + "bool", + "throw", + "try", + "catch", + "public", + "private", + "protected", +]); + +function validateNamedUnionShape(schema: any, label: string): void { + if ( + !Array.isArray(schema) || + schema[0] !== "named_union" || + !Array.isArray(schema[1]) || + !schema[1].every( + (entry: any) => Array.isArray(entry) && typeof entry[0] === "string", + ) + ) { + throw new Error( + `Schema '${label}' is not in ["named_union", [[name, schema], ...]] form`, + ); + } +} + +/** + * SchemaVisitor - Walks raw msgpack schema and resolves references + */ +export class SchemaVisitor { + private structs = new Map(); + private responses = new Map(); + + visit(commandsSchema: any, responsesSchema: any): CompiledSchema { + // Reset state + this.structs.clear(); + this.responses.clear(); + + const commands: Command[] = []; + + // Schema format: ["named_union", [[name, schema], ...]] + validateNamedUnionShape(commandsSchema, "commands"); + validateNamedUnionShape(responsesSchema, "responses"); + const commandPairs = 
commandsSchema[1] as Array<[string, any]>; + const responsePairs = responsesSchema[1] as Array<[string, any]>; + + // First, visit all response types (including ErrorResponse). A string + // schema is a reference to a struct defined earlier in the document — + // schema reflection dedups repeated definitions to name strings (e.g. a + // response type that also appears inline as a field of an earlier + // response). It must resolve; a dangling reference means the generators + // would emit a type nothing defines. + for (const [respName, respSchema] of responsePairs) { + if (typeof respSchema === "string") { + const resolved = + this.structs.get(respSchema) ?? this.responses.get(respSchema); + if (!resolved) { + throw new Error( + `Response '${respName}' references '${respSchema}', which is not defined earlier in the schema`, + ); + } + this.responses.set(respName, resolved); + continue; + } + const respStruct = this.visitStruct(respName, respSchema); + this.responses.set(respName, respStruct); + } + + // Find the error response type name (e.g. 'WsdbErrorResponse') + const errorResponses = responsePairs.filter(([name]: [string, any]) => + name.endsWith("ErrorResponse"), + ); + if (errorResponses.length === 0) { + throw new Error( + "Schema has no error response: the responses union must contain a variant named '*ErrorResponse'", + ); + } + const errorTypeName = errorResponses[0][0]; + const errorStruct = this.responses.get(errorTypeName)!; + if ( + errorStruct.fields.length !== 1 || + errorStruct.fields[0].name !== "message" || + errorStruct.fields[0].type.primitive !== "string" + ) { + throw new Error( + `Error response '${errorTypeName}' must have exactly one field 'message: string'`, + ); + } + + // Commands pair with non-error responses by position (the schema is + // reflected from C++ unions declared in matching order, and a command + // may deliberately reuse another command's response shape, so names + // alone cannot pair them). 
Two guards close the silent-mispair hole the + // old unchecked indexing had: the counts must match exactly, and when a + // response named 'Response' exists it must be the one at the + // command's position — anything else means the unions are misordered. + const normalResponses = responsePairs.filter( + ([name]: [string, any]) => !name.endsWith("ErrorResponse"), + ); + if (normalResponses.length !== commandPairs.length) { + throw new Error( + `Schema has ${commandPairs.length} commands but ${normalResponses.length} non-error responses`, + ); + } + const normalResponseNames = new Set( + normalResponses.map(([name]: [string, any]) => name), + ); + const seenCommands = new Set(); + for (let i = 0; i < commandPairs.length; i++) { + const [cmdName, cmdSchema] = commandPairs[i]; + if (seenCommands.has(cmdName)) { + throw new Error(`Duplicate command name: ${cmdName}`); + } + seenCommands.add(cmdName); + + const [respName] = normalResponses[i]; + const conventionalName = `${cmdName}Response`; + if ( + respName !== conventionalName && + normalResponseNames.has(conventionalName) + ) { + throw new Error( + `Command '${cmdName}' pairs with '${respName}' by position, but a response named '${conventionalName}' exists elsewhere — the unions are misordered`, + ); + } + + // Discover command structure + const cmdStruct = this.visitStruct(cmdName, cmdSchema); + this.structs.set(cmdName, cmdStruct); + + // Create command mapping + commands.push({ + name: cmdName, + fields: cmdStruct.fields, + responseType: respName, + }); + } + + const compiled = { + structs: this.structs, + commands, + responses: this.responses, + errorTypeName, + }; + this.validateStructReferences(compiled); + this.validateIdentifiers(compiled); + return compiled; + } + + private visitStruct(name: string, schema: any): Struct { + const fields: Field[] = []; + + // Schema is an object with __typename and fields + for (const [key, value] of Object.entries(schema)) { + if (key === "__typename") continue; + + 
fields.push({ + name: key, + type: this.visitType(value), + }); + } + + return { name, fields }; + } + + private visitType(schema: any): Type { + // Primitive string type + if (typeof schema === "string") { + return this.resolvePrimitive(schema); + } + + // Array type descriptor: ['vector', [elementType]] + if (Array.isArray(schema)) { + const [kind, args] = schema; + + switch (kind) { + case "vector": { + const [elemType] = args as [any]; + // Special case: vector = bytes + if (elemType === "unsigned char") { + return { kind: "primitive", primitive: "bytes" }; + } + return { + kind: "vector", + element: this.visitType(elemType), + }; + } + + case "array": { + const [elemType, size] = args as [any, number]; + return { + kind: "array", + element: this.visitType(elemType), + size, + }; + } + + case "optional": { + const [elemType] = args as [any]; + return { + kind: "optional", + element: this.visitType(elemType), + }; + } + + case "shared_ptr": { + // Dereference shared_ptr - just use inner type + const [innerType] = args as [any]; + return this.visitType(innerType); + } + + case "alias": { + // Aliases carry [aliasName, underlyingKind]. The underlying kind is + // usually a primitive schema string. We preserve the alias name so + // generators can emit named zero-cost aliases over primitive wire + // shapes. 
+ const [aliasName, underlying] = args as [string, string]; + if (underlying === "bin32") { + return { + kind: "primitive", + primitive: "bin32", + originalName: aliasName, + }; + } + return { + ...this.resolvePrimitive(underlying), + originalName: aliasName, + }; + } + + default: + throw new Error(`Unknown type kind: ${kind}`); + } + } + + // Inline struct definition + if (typeof schema === "object" && schema.__typename) { + const structName = schema.__typename as string; + // Check if already visited + if (!this.structs.has(structName)) { + const struct = this.visitStruct(structName, schema); + this.structs.set(structName, struct); + } + return { + kind: "struct", + struct: this.structs.get(structName)!, + }; + } + + throw new Error(`Cannot resolve type: ${JSON.stringify(schema)}`); + } + + private resolvePrimitive(name: string): Type { + const primitiveMap: Record = { + bool: "bool", + int: "u32", + "unsigned int": "u32", + "unsigned short": "u16", + "unsigned long": "u64", + "unsigned long long": "u64", + "unsigned char": "u8", + double: "f64", + string: "string", + bin32: "bin32", + }; + + const primitive = primitiveMap[name]; + if (primitive) { + return { kind: "primitive", primitive }; + } + + const knownStruct = this.structs.get(name); + if (knownStruct) { + return { kind: "struct", struct: knownStruct }; + } + + // Unknown primitive - treat as a forward struct reference. + return { + kind: "struct", + struct: { name, fields: [] }, + }; + } + + /** + * Reject field names that produce broken or colliding identifiers in any + * target language. Field names are emitted as snake_case (Rust/Zig/C++) + * and camelCase (TS), so both projections are checked. 
+ */ + private validateIdentifiers(schema: CompiledSchema): void { + const allStructs = [ + ...schema.structs.values(), + ...schema.responses.values(), + ]; + for (const struct of allStructs) { + const snakeSeen = new Map(); + const camelSeen = new Map(); + for (const field of struct.fields) { + const snake = toSnakeCase(field.name); + const camel = toCamelCase(field.name); + if (RESERVED_WORDS.has(snake) || RESERVED_WORDS.has(camel)) { + throw new Error( + `Field '${struct.name}.${field.name}' maps to a reserved word in a target language`, + ); + } + const snakeClash = snakeSeen.get(snake); + if (snakeClash !== undefined && snakeClash !== field.name) { + throw new Error( + `Fields '${struct.name}.${snakeClash}' and '${struct.name}.${field.name}' both map to '${snake}'`, + ); + } + snakeSeen.set(snake, field.name); + const camelClash = camelSeen.get(camel); + if (camelClash !== undefined && camelClash !== field.name) { + throw new Error( + `Fields '${struct.name}.${camelClash}' and '${struct.name}.${field.name}' both map to '${camel}'`, + ); + } + camelSeen.set(camel, field.name); + } + } + } + + private validateStructReferences(schema: CompiledSchema): void { + const knownNames = new Set([ + ...schema.structs.keys(), + ...schema.responses.keys(), + ]); + const visitType = (type: Type): void => { + if ( + type.kind === "struct" && + type.struct && + !knownNames.has(type.struct.name) + ) { + throw new Error(`Unknown struct reference: ${type.struct.name}`); + } + if ( + (type.kind === "vector" || + type.kind === "array" || + type.kind === "optional") && + type.element + ) { + visitType(type.element); + } + }; + + for (const struct of schema.structs.values()) { + for (const field of struct.fields) { + visitType(field.type); + } + } + for (const struct of schema.responses.values()) { + for (const field of struct.fields) { + visitType(field.type); + } + } + } +} + +// --------------------------------------------------------------------------- +// Friendly 
// (human-authored) schema front-end
//
// The committed schemas are hand-edited. The friendly format is a single
// object per service with shorthand string type references; this front-end
// lowers it to the positional ["named_union", ...] form that SchemaVisitor
// already consumes, so the generators are untouched and the produced
// CompiledSchema is identical to the equivalent positional schema.
// ---------------------------------------------------------------------------

/**
 * Strip line (`//`) and block comments from JSONC text.
 *
 * String literals (double- or single-quoted, with backslash escapes) are
 * copied through untouched, so comment-like sequences inside them survive.
 * Every comment character is replaced by a space, and line breaks inside or
 * after a comment are kept. The result therefore has the same length and
 * line layout as the input: positions in a later JSON parse error map
 * straight back to the schema file, and two tokens separated only by a
 * comment can never fuse into one.
 *
 * @param text JSONC source text.
 * @returns The text with all comments blanked out. An unterminated block
 *          comment blanks everything up to the end of the input.
 */
export function stripJsonc(text: string): string {
  // Keep CR/LF as-is and turn every other UTF-16 code unit into a space.
  const blank = (span: string): string => span.replace(/[^\r\n]/g, " ");

  let out = "";
  let quote = ""; // Delimiter of the string being copied; "" outside strings.
  let i = 0;
  while (i < text.length) {
    const c = text[i];
    const next = text[i + 1];

    if (quote !== "") {
      out += c;
      if (c === "\\") {
        // Copy the escaped character verbatim so an escaped quote cannot
        // terminate the string early.
        out += next ?? "";
        i += 2;
        continue;
      }
      if (c === quote) {
        quote = "";
      }
      i++;
      continue;
    }

    if (c === '"' || c === "'") {
      quote = c;
      out += c;
      i++;
      continue;
    }

    if (c === "/" && next === "/") {
      // Line comment: blank up to, but not including, the line break so the
      // newline itself is still emitted by the default branch below.
      const newline = text.indexOf("\n", i);
      const stop = newline < 0 ? text.length : newline;
      out += blank(text.slice(i, stop));
      i = stop;
      continue;
    }

    if (c === "/" && next === "*") {
      // Block comment: search for the terminator after the opening "/*" so
      // that "/*/" is not mistaken for a complete comment.
      const close = text.indexOf("*/", i + 2);
      const stop = close < 0 ? text.length : close + 2;
      out += blank(text.slice(i, stop));
      i = stop;
      continue;
    }

    out += c;
    i++;
  }
  return out;
}

/**
 * Tell whether a parsed schema document uses the friendly format.
 *
 * A friendly schema is a plain (non-array) object whose top-level `service`
 * property is a string. The check is purely structural; it does not validate
 * the rest of the document.
 *
 * @param parsed Value produced by parsing a schema file.
 * @returns `true` when `parsed` is a non-null, non-array object with a string
 *          `service` key, `false` otherwise.
 */
export function isFriendlySchema(parsed: any): boolean {
  return (
    parsed != null &&
    typeof parsed === "object" &&
    !Array.isArray(parsed) &&
    typeof parsed.service === "string"
  );
}

// u32 etc. -> the positional primitive spellings resolvePrimitive() accepts.
+const PRIMITIVE_SHORTHAND: Record = { + bool: "bool", + u8: "unsigned char", + u16: "unsigned short", + u32: "unsigned int", + u64: "unsigned long", + f64: "double", + string: "string", + bin32: "bin32", +}; + +/** + * Lower a friendly schema object into the positional `{ commands, responses }` + * named_union pair, plus the service name (used as the type prefix). Type names + * are `service + commandKey`; method names are derived by the generators via + * the service prefix. Named `types` are inlined at every reference — visit() + * dedups them by `__typename`, so the resulting CompiledSchema matches the + * positional form exactly. + */ +export function friendlyToPositional(parsed: any): { + commands: any; + responses: any; + service: string; +} { + const service: string = parsed.service; + if (!service) { + throw new Error("Friendly schema requires a non-empty string 'service'"); + } + const aliases: Record = parsed.aliases ?? {}; + const types: Record = parsed.types ?? {}; + // An alias is either a nominal byte type (bin32) or a transparent scalar + // synonym over a primitive (e.g. MerkleTreeId = u32). The underlying is given + // in shorthand; map it to the positional spelling visitType() consumes. 
+ const aliasUnderlying = (name: string): string => { + const u = aliases[name]; + if (u === "bin32") return "bin32"; + const prim = PRIMITIVE_SHORTHAND[u]; + if (!prim) { + throw new Error( + `Alias '${name}' underlying '${u}' is not 'bin32' or a primitive`, + ); + } + return prim; + }; + for (const [name] of Object.entries(aliases)) { + aliasUnderlying(name); // validate up front + } + + const parseTypeRef = (ref: string): any => { + const s = ref.trim(); + if (s.endsWith("?")) { + return ["optional", [parseTypeRef(s.slice(0, -1))]]; + } + if (s.endsWith("]")) { + const lb = s.lastIndexOf("["); + if (lb < 0) throw new Error(`Malformed type reference '${ref}'`); + const inner = s.slice(0, lb); + const n = s.slice(lb + 1, -1).trim(); + if (n === "") { + return ["vector", [parseTypeRef(inner)]]; + } + const size = Number(n); + if (!Number.isInteger(size) || size <= 0) { + throw new Error(`Bad fixed-array size in type reference '${ref}'`); + } + return ["array", [parseTypeRef(inner), size]]; + } + if (s === "bytes") return ["vector", ["unsigned char"]]; + if (s in PRIMITIVE_SHORTHAND) return PRIMITIVE_SHORTHAND[s]; + if (s in aliases) return ["alias", [s, aliasUnderlying(s)]]; + if (s in types) return inlineType(s); + throw new Error( + `Unknown type reference '${ref}' (not a primitive, declared alias, or declared type)`, + ); + }; + + const structBody = (fieldObj: Record, typename: string) => { + const body: any = { __typename: typename }; + for (const [fname, fref] of Object.entries(fieldObj ?? {})) { + body[fname] = parseTypeRef(fref); + } + return body; + }; + + const inlineType = (typeName: string): any => + structBody(types[typeName], typeName); + + const commands: any = ["named_union", []]; + const responses: any = ["named_union", []]; + + for (const [key, def] of Object.entries(parsed.commands ?? 
{})) { + const cmdName = service + key; + commands[1].push([cmdName, structBody(def.request, cmdName)]); + + if (typeof def.response === "string") { + // Reuse another command's response shape: dedup string-ref form. + const respName = service + def.response; + responses[1].push([respName, respName]); + } else { + const respName = `${cmdName}Response`; + responses[1].push([respName, structBody(def.response, respName)]); + } + } + + const errorName = `${service}ErrorResponse`; + responses[1].push([errorName, structBody(parsed.error, errorName)]); + + return { commands, responses, service }; +} diff --git a/ipc-codegen/src/typescript_codegen.ts b/ipc-codegen/src/typescript_codegen.ts new file mode 100644 index 000000000000..cf57283df55f --- /dev/null +++ b/ipc-codegen/src/typescript_codegen.ts @@ -0,0 +1,693 @@ +/** + * TypeScript Code Generator - String template based + * + * Philosophy: + * - String templates for file structure + * - Simple type mapping + * - Idiomatic TypeScript conventions + * - No complex abstraction + */ + +import type { + CompiledSchema, + Type, + Struct, + Field, + Command, +} from "./schema_visitor.ts"; +import { + toPascalCase, + toSnakeCase, + toCamelCase, + toAliasName, + dedupeStructsByName, +} from "./naming.ts"; + +export class TypeScriptCodegen { + private errorTypeName: string = "ErrorResponse"; + /** Prefix to strip from command names when generating method names (e.g. 
"Bb" -> BbCircuitProve becomes circuitProve) */ + private methodPrefix: string = ""; + + constructor(options?: { stripMethodPrefix?: string }) { + if (options?.stripMethodPrefix) { + this.methodPrefix = options.stripMethodPrefix; + } + } + + /** Strip the method prefix and convert to camelCase for API method names */ + private toMethodName(commandName: string): string { + let name = commandName; + if (this.methodPrefix && name.startsWith(this.methodPrefix)) { + name = name.slice(this.methodPrefix.length); + } + return toCamelCase(name); + } + + private primitiveType(type: Type): string { + switch (type.primitive) { + case "bool": + return "boolean"; + case "u8": + case "u16": + case "u32": + case "u64": + case "f64": + return "number"; + case "string": + return "string"; + case "bytes": + case "bin32": + return "Uint8Array"; + } + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + + private isU8Array(type: Type): boolean { + return ( + type.kind === "array" && + type.element?.kind === "primitive" && + type.element.primitive === "u8" + ); + } + + // Type mapping: Schema type -> TypeScript type + private mapType(type: Type): string { + switch (type.kind) { + case "primitive": + return type.originalName + ? toAliasName(type.originalName) + : this.primitiveType(type); + + case "vector": { + const inner = this.mapType(type.element!); + // Wrap union types in parens to avoid precedence issues: (Foo | undefined)[] + return type.element!.kind === "optional" + ? `(${inner})[]` + : `${inner}[]`; + } + + case "array": { + if (this.isU8Array(type)) { + return "Uint8Array"; + } + const inner = this.mapType(type.element!); + return type.element!.kind === "optional" + ? 
`(${inner})[]` + : `${inner}[]`; + } + + case "optional": + return `${this.mapType(type.element!)} | null`; + + case "struct": + return toPascalCase(type.struct!.name); + } + + throw new Error(`Unsupported type kind: ${type.kind}`); + } + + // Type mapping for msgpack interfaces (uses Msgpack* prefix for structs) + private mapMsgpackType(type: Type): string { + switch (type.kind) { + case "primitive": + // u64 crosses the wire as bigint beyond 32 bits (see toWireU64). + return type.primitive === "u64" + ? "number | bigint" + : this.primitiveType(type); + + case "vector": { + const inner = this.mapMsgpackType(type.element!); + // Parenthesize union element types: number | bigint[] != (number | bigint)[] + return inner.includes("|") ? `(${inner})[]` : `${inner}[]`; + } + + case "array": { + if (this.isU8Array(type)) { + return "Uint8Array"; + } + const inner = this.mapMsgpackType(type.element!); + return inner.includes("|") ? `(${inner})[]` : `${inner}[]`; + } + + case "optional": + return `${this.mapMsgpackType(type.element!)} | null`; + + case "struct": + return `Msgpack${toPascalCase(type.struct!.name)}`; + } + + throw new Error(`Unsupported msgpack type kind: ${type.kind}`); + } + + // Check if type needs conversion (has nested structs) + private needsConversion(type: Type): boolean { + switch (type.kind) { + case "primitive": + return false; + case "vector": + case "array": + case "optional": + return this.needsConversion(type.element!); + case "struct": + return true; + } + return false; + } + + // Generate field + private generateField(field: Field): string { + const tsName = toCamelCase(field.name); + const tsType = this.mapType(field.type); + return ` ${tsName}: ${tsType};`; + } + + // Generate msgpack field (original names, uses Msgpack* types for structs) + private generateMsgpackField(field: Field): string { + const tsType = this.mapMsgpackType(field.type); + return ` ${field.name}: ${tsType};`; + } + + // Generate public interface + private 
generateInterface(struct: Struct): string { + const tsName = toPascalCase(struct.name); + const fields = struct.fields.map((f) => this.generateField(f)).join("\n"); + + return `export interface ${tsName} { +${fields} +}`; + } + + // Generate msgpack interface (internal) + private generateMsgpackInterface(struct: Struct): string { + const tsName = toPascalCase(struct.name); + const fields = struct.fields + .map((f) => this.generateMsgpackField(f)) + .join("\n"); + + return `interface Msgpack${tsName} { +${fields} +}`; + } + + // Generate to* conversion function + private generateToFunction(struct: Struct): string { + const tsName = toPascalCase(struct.name); + + if (struct.fields.length === 0) { + return `function to${tsName}(o: Msgpack${tsName}): ${tsName} { + return {}; +}`; + } + + const checks = struct.fields + .filter((f) => f.type.kind !== "optional") + .map( + (f) => + ` if (o.${f.name} === undefined) { throw new Error("Expected ${f.name} in ${tsName} deserialization"); }`, + ) + .join("\n"); + + const conversions = struct.fields + .map((f) => { + const tsFieldName = toCamelCase(f.name); + const converter = this.generateToConverter(f.type, `o.${f.name}`); + return ` ${tsFieldName}: ${converter},`; + }) + .join("\n"); + + return `function to${tsName}(o: Msgpack${tsName}): ${tsName} { +${checks}; + return { +${conversions} + }; +}`; + } + + // Generate from* conversion function + private generateFromFunction(struct: Struct): string { + const tsName = toPascalCase(struct.name); + + if (struct.fields.length === 0) { + return `function from${tsName}(o: ${tsName}): Msgpack${tsName} { + return {}; +}`; + } + + const checks = struct.fields + .filter((f) => f.type.kind !== "optional") + .map((f) => { + const tsFieldName = toCamelCase(f.name); + return ` if (o.${tsFieldName} === undefined) { throw new Error("Expected ${tsFieldName} in ${tsName} serialization"); }`; + }) + .join("\n"); + + const conversions = struct.fields + .map((f) => { + const tsFieldName = 
toCamelCase(f.name); + const converter = this.generateFromConverter( + f.type, + `o.${tsFieldName}`, + ); + return ` ${f.name}: ${converter},`; + }) + .join("\n"); + + return `function from${tsName}(o: ${tsName}): Msgpack${tsName} { +${checks}; + return { +${conversions} + }; +}`; + } + + /** + * Generate a conversion expression for a field in either direction. + * Primitives that can be silently mis-encoded get runtime guards: + * u64 (precision loss past 2^53 until a bigint migration) and bin32 + * (length must be exactly 32 — other languages enforce this). + * Optionals normalize undefined to null so omitted fields are valid. + */ + private generateConverter( + dir: "to" | "from", + type: Type, + value: string, + ): string { + switch (type.kind) { + case "primitive": + if (type.primitive === "u64") { + return dir === "from" + ? `toWireU64(${value}, ${JSON.stringify(value)})` + : `assertU64(${value}, ${JSON.stringify(value)})`; + } + if (type.primitive === "bin32") { + return `assertBin32(${value}, ${JSON.stringify(value)})`; + } + return value; + case "vector": + case "array": { + if (this.isU8Array(type)) { + return value; + } + const elem = this.generateConverter(dir, type.element!, "v"); + return elem === "v" ? value : `${value}.map((v: any) => ${elem})`; + } + case "optional": { + const inner = this.generateConverter(dir, type.element!, value); + return inner === value + ? `${value} ?? null` + : `${value} != null ? 
${inner} : null`; + } + case "struct": + return `${dir}${toPascalCase(type.struct!.name)}(${value})`; + } + return value; + } + + private generateToConverter(type: Type, value: string): string { + return this.generateConverter("to", type, value); + } + + private generateFromConverter(type: Type, value: string): string { + return this.generateConverter("from", type, value); + } + + // Generate types file (api_types.ts) + generateTypes(schema: CompiledSchema, schemaHash?: string): string { + const allStructs = dedupeStructsByName([ + ...schema.structs.values(), + ...schema.responses.values(), + ]); + + const aliasTypes = new Map(); + const collectAliases = (type: Type): void => { + if (type.kind === "primitive" && type.originalName) { + aliasTypes.set( + toAliasName(type.originalName), + this.primitiveType(type), + ); + } else if ( + type.kind === "vector" || + type.kind === "array" || + type.kind === "optional" + ) { + if (type.element) collectAliases(type.element); + } + }; + for (const s of allStructs) { + for (const f of s.fields) collectAliases(f.type); + } + const aliasDecls = [...aliasTypes.entries()] + .sort(([a], [b]) => a.localeCompare(b)) + .map(([name, underlying]) => `export type ${name} = ${underlying};`) + .join("\n"); + + // Public interfaces + const publicInterfaces = allStructs + .map((s) => this.generateInterface(s)) + .join("\n\n"); + + // Msgpack interfaces + const msgpackInterfaces = allStructs + .map((s) => this.generateMsgpackInterface(s)) + .join("\n\n"); + + // Conversion functions + const toFunctions = allStructs + .map((s) => "export " + this.generateToFunction(s)) + .join("\n\n"); + + const fromFunctions = allStructs + .map((s) => "export " + this.generateFromFunction(s)) + .join("\n\n"); + + const asyncApiMethods = schema.commands + .map( + (c) => + ` ${this.toMethodName(c.name)}(command: ${toPascalCase(c.name)}): Promise<${toPascalCase(c.responseType)}>;`, + ) + .join("\n"); + const syncApiMethods = schema.commands + .map( + (c) => + ` 
${this.toMethodName(c.name)}(command: ${toPascalCase(c.name)}): ${toPascalCase(c.responseType)};`, + ) + .join("\n"); + + const hashLine = schemaHash + ? `\n/** Schema version hash for compatibility checking */\nexport const SCHEMA_HASH = '${schemaHash}';\n` + : ""; + + return `// AUTOGENERATED FILE - DO NOT EDIT +${hashLine} +// Runtime guards for wire types that JS cannot represent natively. +// TODO: migrate u64 fields to bigint end-to-end and drop these. +// +// Decode: msgpackr returns uint64/int64 wire values as bigint once they +// exceed 32 bits; values must fit in the JS safe integer range. +function assertU64(value: number | bigint, ctx: string): number { + if (typeof value === "bigint") { + if (value < 0n || value > BigInt(Number.MAX_SAFE_INTEGER)) { + throw new Error(\`\${ctx}: u64 value \${value} is outside JS safe integer range\`); + } + return Number(value); + } + if (!Number.isSafeInteger(value) || value < 0) { + throw new Error(\`\${ctx}: u64 value \${value} is outside JS safe integer range\`); + } + return value; +} + +// Encode: msgpackr encodes JS numbers above 2^32 as float64, which strict +// u64 decoders reject; route them through bigint so the wire type stays uint. +function toWireU64(value: number | bigint, ctx: string): number | bigint { + const checked = assertU64(value, ctx); + return checked > 0xffffffff ? 
BigInt(checked) : checked; +} + +function assertBin32(value: Uint8Array, ctx: string): Uint8Array { + if (value.length !== 32) { + throw new Error(\`\${ctx}: expected 32 bytes, got \${value.length}\`); + } + return value; +} + +// Type aliases for primitive types +${aliasDecls} + +// Public interfaces (exported) +${publicInterfaces} + +// Private Msgpack interfaces (not exported) +${msgpackInterfaces} + +// Conversion functions (exported) +${toFunctions} + +${fromFunctions} + +// Base API interfaces +export interface AsyncApiBase { +${asyncApiMethods} + destroy(): Promise; +} + +export interface SyncApiBase { +${syncApiMethods} + destroy(): void; +} +`; + } + + // Generate API method + private generateAsyncApiMethod(command: Command): string { + const methodName = this.toMethodName(command.name); + const cmdType = toPascalCase(command.name); + const respType = toPascalCase(command.responseType); + + return ` ${methodName}(command: ${cmdType}): Promise<${respType}> { + const msgpackCommand = from${cmdType}(command); + return msgpackCall(this.backend, [["${command.name}", msgpackCommand]]).then(([variantName, result]: [string, any]) => { + if (variantName === '${this.errorTypeName}') { + throw this.createError(result.message || 'Unknown error from server'); + } + if (variantName !== '${command.responseType}') { + throw new Error(\`Expected variant name '${command.responseType}' but got '\${variantName}'\`); + } + return to${respType}(result); + }); + }`; + } + + private generateSyncApiMethod(command: Command): string { + const methodName = this.toMethodName(command.name); + const cmdType = toPascalCase(command.name); + const respType = toPascalCase(command.responseType); + + return ` ${methodName}(command: ${cmdType}): ${respType} { + const msgpackCommand = from${cmdType}(command); + const [variantName, result] = msgpackCall(this.backend, [["${command.name}", msgpackCommand]]); + if (variantName === '${this.errorTypeName}') { + throw this.createError(result.message 
|| 'Unknown error from server'); + } + if (variantName !== '${command.responseType}') { + throw new Error(\`Expected variant name '${command.responseType}' but got '\${variantName}'\`); + } + return to${respType}(result); + }`; + } + + // Generate async API file + generateAsyncApi(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const imports = this.generateApiImports(schema, "AsyncApiBase"); + const methods = schema.commands + .map((c) => this.generateAsyncApiMethod(c)) + .join("\n\n"); + + return `// AUTOGENERATED FILE - DO NOT EDIT + +import { Decoder, Encoder } from 'msgpackr'; +${imports} + +export interface IpcClientAsync { + call(input: Uint8Array): Promise; + destroy(): Promise; +} + +export type IpcErrorFactory = (message: string) => Error; + +async function msgpackCall(backend: IpcClientAsync, input: any[]) { + const inputBuffer = new Encoder({ useRecords: false, variableMapSize: true }).pack(input); + const encodedResult = await backend.call(inputBuffer); + return new Decoder({ useRecords: false }).unpack(encodedResult); +} + +export class AsyncApi implements AsyncApiBase { + constructor( + protected backend: IpcClientAsync, + protected createError: IpcErrorFactory = message => new Error(message), + ) {} + +${methods} + + destroy(): Promise { + return this.backend.destroy(); + } +} +`; + } + + // Generate sync API file + generateSyncApi(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const imports = this.generateApiImports(schema, "SyncApiBase"); + const methods = schema.commands + .map((c) => this.generateSyncApiMethod(c)) + .join("\n\n"); + + return `// AUTOGENERATED FILE - DO NOT EDIT + +import { Decoder, Encoder } from 'msgpackr'; +${imports} + +export interface IpcClientSync { + call(input: Uint8Array): Uint8Array; + destroy(): void; +} + +export type IpcErrorFactory = (message: string) => Error; + +function msgpackCall(backend: IpcClientSync, input: any[]) { + const inputBuffer = new 
Encoder({ useRecords: false, variableMapSize: true }).pack(input); + const encodedResult = backend.call(inputBuffer); + return new Decoder({ useRecords: false }).unpack(encodedResult); +} + +export class SyncApi implements SyncApiBase { + constructor( + protected backend: IpcClientSync, + protected createError: IpcErrorFactory = message => new Error(message), + ) {} + +${methods} + + destroy(): void { + this.backend.destroy(); + } +} +`; + } + + // Generate import statement for API files + private generateApiImports( + schema: CompiledSchema, + baseInterface: string, + ): string { + const types = new Set(); + + // Add command types and their conversion functions + for (const cmd of schema.commands) { + const cmdType = toPascalCase(cmd.name); + const respType = toPascalCase(cmd.responseType); + types.add(cmdType); + types.add(respType); + types.add(`from${cmdType}`); + types.add(`to${respType}`); + } + + types.add(baseInterface); + + const sortedTypes = Array.from(types).sort(); + return `import { ${sortedTypes.join(", ")} } from './api_types.js';`; + } + + // ----------------------------------------------------------------------- + // Server-side code generation + // ----------------------------------------------------------------------- + + /** Generate a server handler interface and dispatch function */ + generateServerApi(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const errorType = toPascalCase(this.errorTypeName); + + // Generate handler interface + const handlerMethods = schema.commands + .map((c) => { + const methodName = this.toMethodName(c.name); + const cmdType = toPascalCase(c.name); + const respType = toPascalCase(c.responseType); + return ` ${methodName}(command: ${cmdType}): Promise<${respType}>;`; + }) + .join("\n"); + + // Generate dispatch switch cases + const dispatchCases = schema.commands + .map((c) => { + const methodName = this.toMethodName(c.name); + const cmdType = toPascalCase(c.name); + const respType = 
toPascalCase(c.responseType); + return ` case '${c.name}': { + const cmd = to${cmdType}(payload); + const result = await handler.${methodName}(cmd); + return ['${c.responseType}', from${respType}(result)]; + }`; + }) + .join("\n"); + + // Collect imports + const importTypes = new Set(); + for (const cmd of schema.commands) { + const cmdType = toPascalCase(cmd.name); + const respType = toPascalCase(cmd.responseType); + importTypes.add(cmdType); + importTypes.add(respType); + importTypes.add(`to${cmdType}`); + importTypes.add(`from${respType}`); + } + const sortedImports = Array.from(importTypes).sort(); + + return `// AUTOGENERATED FILE - DO NOT EDIT +// Server-side dispatch for IPC protocol + +import { Decoder, Encoder } from 'msgpackr'; +import { ${sortedImports.join(", ")} } from './api_types.js'; + +/** Handler interface — implement this to serve commands. */ +export interface Handler { +${handlerMethods} +} + +/** + * Dispatch a [commandName, payload] pair to the handler. + * Returns [responseName, responsePayload] for serialization. + * Handler failures are wrapped into the schema error variant. + */ +export async function dispatch( + handler: Handler, + commandName: string, + payload: any, +): Promise<[string, any]> { + try { + switch (commandName) { +${dispatchCases} + default: + return ['${this.errorTypeName}', { message: \`Unknown command: \${commandName}\` }]; + } + } catch (err: any) { + return ['${this.errorTypeName}', { message: err?.message ?? String(err) }]; + } +} + +const requestDecoder = new Decoder({ useRecords: false }); +const responseEncoder = new Encoder({ useRecords: false, variableMapSize: true }); + +/** + * Decode a framed request, dispatch it, and encode the framed response. + * All failures (malformed framing included) produce a decodable error + * variant rather than a throw, so transports can use this directly as + * their request handler. 
+ */ +export async function handleRequest( + handler: Handler, + requestBytes: Uint8Array, +): Promise { + let commandName: string; + let payload: any; + try { + const request = requestDecoder.unpack(requestBytes) as [[string, any]]; + [[commandName, payload]] = request; + if (typeof commandName !== 'string') { + throw new Error('expected [name, payload] request framing'); + } + } catch (err: any) { + return responseEncoder.pack([ + '${this.errorTypeName}', + { message: \`Malformed request: \${err?.message ?? String(err)}\` }, + ]); + } + const [respName, respPayload] = await dispatch(handler, commandName, payload ?? {}); + return responseEncoder.pack([respName, respPayload]); +} +`; + } +} diff --git a/ipc-codegen/src/typescript_package_codegen.ts b/ipc-codegen/src/typescript_package_codegen.ts new file mode 100644 index 000000000000..18fa48b1c1f1 --- /dev/null +++ b/ipc-codegen/src/typescript_package_codegen.ts @@ -0,0 +1,501 @@ +import { toSnakeCase } from "./naming.ts"; + +export interface TypeScriptPackageOptions { + prefix: string; + packageName: string; + binaryName: string; + binaryEnvVar: string; + ipcRuntimeDependency: string; + ipcPathArgs: string[]; + transports: string[]; +} + +function className(prefix: string): string { + return `${prefix}Service`; +} + +function transportType(prefix: string): string { + return `${prefix}Transport`; +} + +function optionsType(prefix: string): string { + return `${prefix}ServiceOptions`; +} + +function binaryFinderName(prefix: string): string { + return `find${prefix}Binary`; +} + +function envName(binaryName: string): string { + return `${binaryName.replace(/[^a-zA-Z0-9]+/g, "_").toUpperCase()}_PATH`; +} + +function packageStem(packageName: string): string { + return packageName.startsWith("@") + ? packageName.split("/")[1]! 
+ : packageName; +} + +function archPackageNames(packageName: string): Record { + return { + "linux-x64": `${packageName}-linux-x64`, + "darwin-x64": `${packageName}-darwin-x64`, + "linux-arm64": `${packageName}-linux-arm64`, + "darwin-arm64": `${packageName}-darwin-arm64`, + }; +} + +const ARCH_PACKAGES = [ + { buildDir: "amd64-linux", suffix: "linux-x64", os: "linux", cpu: "x64" }, + { buildDir: "arm64-linux", suffix: "linux-arm64", os: "linux", cpu: "arm64" }, + { buildDir: "amd64-macos", suffix: "darwin-x64", os: "darwin", cpu: "x64" }, + { buildDir: "arm64-macos", suffix: "darwin-arm64", os: "darwin", cpu: "arm64" }, +] as const; + +export function defaultBinaryEnvVar(binaryName: string): string { + return envName(binaryName); +} + +export class TypeScriptPackageCodegen { + constructor(private opts: TypeScriptPackageOptions) {} + + generatePackageJson(): string { + const archPackages = archPackageNames(this.opts.packageName); + const scripts: Record = { + clean: "rm -rf dest .tsbuildinfo", + build: "tsc -p tsconfig.json", + prepare_arch_packages: "./scripts/prepare_arch_packages.sh", + }; + + const pkg = { + name: this.opts.packageName, + version: "0.1.0", + type: "module", + exports: { + ".": { + types: "./dest/index.d.ts", + default: "./dest/index.js", + }, + }, + files: ["dest/", "README.md"], + scripts, + dependencies: { + "@aztec/ipc-runtime": this.opts.ipcRuntimeDependency, + msgpackr: "^1.11.2", + tslib: "^2.4.0", + }, + optionalDependencies: Object.fromEntries( + Object.values(archPackages).map((packageName) => [packageName, "0.1.0"]), + ), + devDependencies: { + "@types/node": "^22.15.17", + typescript: "^5.3.3", + }, + }; + return JSON.stringify(pkg, null, 2) + "\n"; + } + + generateArchPackageJson(suffix: string, os: string, cpu: string): string { + const pkg = { + name: `${this.opts.packageName}-${suffix}`, + version: "0.1.0", + description: `Native binary for ${this.opts.packageName} (${suffix})`, + license: "MIT", + os: [os], + cpu: [cpu], + 
files: [this.opts.binaryName], + preferUnplugged: true, + }; + return JSON.stringify(pkg, null, 2) + "\n"; + } + + generateArchPackageManifests(): Array<{ path: string; content: string }> { + const stem = packageStem(this.opts.packageName); + return ARCH_PACKAGES.map(({ suffix, os, cpu }) => ({ + path: `packages/${stem}-${suffix}/package.json`, + content: this.generateArchPackageJson(suffix, os, cpu), + })); + } + + generateTsconfig(): string { + return JSON.stringify( + { + compilerOptions: { + target: "ES2022", + module: "NodeNext", + moduleResolution: "NodeNext", + declaration: true, + declarationMap: true, + composite: true, + outDir: "dest", + rootDir: "src", + tsBuildInfoFile: ".tsbuildinfo", + strict: true, + esModuleInterop: true, + skipLibCheck: true, + forceConsistentCasingInFileNames: true, + }, + include: ["src/**/*.ts"], + }, + null, + 2, + ) + "\n"; + } + + generateIndex(): string { + const prefix = this.opts.prefix; + const serviceClass = className(prefix); + const serviceOptions = optionsType(prefix); + const serviceTransport = transportType(prefix); + const findBinary = binaryFinderName(prefix); + const supportsShm = this.opts.transports.includes("shm"); + const transports = this.opts.transports.map((t) => `'${t}'`).join(" | "); + const ipcPathArgs = JSON.stringify(this.opts.ipcPathArgs); + const defaultTransport = this.opts.transports.includes("uds") + ? "uds" + : this.opts.transports[0]!; + + return `import { spawn, type ChildProcess } from 'node:child_process'; +import { existsSync, unlinkSync } from 'node:fs'; +import { tmpdir } from 'node:os'; +import { join } from 'node:path'; +import { threadId } from 'node:worker_threads'; +import { + ${supportsShm ? 
"createNapiShmAsyncClient,\n " : ""}UdsIpcClient, + type IpcClientAsync, +} from '@aztec/ipc-runtime'; +import { AsyncApi, type IpcErrorFactory } from './generated/async.js'; +import { ${findBinary} } from './platform.js'; + +export * from './generated/api_types.js'; +export { AsyncApi } from './generated/async.js'; +export { SyncApi } from './generated/sync.js'; + +export type ${serviceTransport} = ${transports}; + +export interface ${serviceOptions} { + binaryPath?: string; + transport?: ${serviceTransport}; + logger?: (msg: string) => void; + connectTimeoutMs?: number; + env?: NodeJS.ProcessEnv; + extraArgs?: string[]; + createError?: IpcErrorFactory; +${supportsShm ? " napiPath?: string;\n clientId?: number;\n" : ""}} + +let instanceCounter = 0; +const DEFAULT_CONNECT_TIMEOUT_MS = 30_000; + +class SpawnedBackend implements IpcClientAsync { + private constructor( + private child: ChildProcess, + private client: IpcClientAsync, + private ipcPath: string, + private transport: ${serviceTransport}, + private exitPromise: Promise, + ) {} + + static async spawn(options: ${serviceOptions} = {}): Promise { + const binaryPath = ${findBinary}(options.binaryPath); + if (!binaryPath) { + throw new Error('${this.opts.binaryName} binary not found'); + } + + const transport = options.transport ?? '${defaultTransport}'; + const instanceId = '${toSnakeCase(prefix)}-' + process.pid + '-' + threadId + '-' + instanceCounter++; + const ipcPath = transport === 'shm' + ? instanceId + '.shm' + : join(tmpdir(), instanceId + '.sock'); + + if (transport === 'uds' && existsSync(ipcPath)) { + unlinkSync(ipcPath); + } + + const ipcPathArgs = ${ipcPathArgs}.map((arg: string) => arg === '{path}' ? ipcPath : arg); + const child = spawn(binaryPath, [...ipcPathArgs, ...(options.extraArgs ?? [])], { + stdio: ['ignore', options.logger ? 'pipe' : 'ignore', options.logger ? 'pipe' : 'ignore'], + env: { ...process.env, ...(options.env ?? 
{}) }, + }); + + if (options.logger) { + child.stdout?.on('data', (data: Buffer) => options.logger?.('[${this.opts.binaryName} stdout] ' + data.toString().trimEnd())); + child.stderr?.on('data', (data: Buffer) => options.logger?.('[${this.opts.binaryName} stderr] ' + data.toString().trimEnd())); + } + + const exitPromise = new Promise(resolve => { + child.on('exit', () => resolve()); + }); + + const childReadyFailure = new Promise((_, reject) => { + child.once('error', reject); + child.once('exit', (code, signal) => { + reject( + new Error('${this.opts.binaryName} exited before IPC connection was ready (code=' + code + ', signal=' + signal + ')'), + ); + }); + }); + + const client = await Promise.race([connectClient(child, ipcPath, transport, options), childReadyFailure]); + return new SpawnedBackend(child, client, ipcPath, transport, exitPromise); + } + + getIpcPath(): string { + return this.ipcPath; + } + + call(input: Uint8Array): Promise { + return this.client.call(input); + } + + async destroy(): Promise { + await this.client.destroy(); + if (this.child.exitCode === null) { + this.child.kill('SIGTERM'); + } + await this.exitPromise; + this.child.stdout?.destroy(); + this.child.stderr?.destroy(); + this.child.removeAllListeners(); + cleanupIpcPath(this.ipcPath, this.transport); + } +} + +async function connectClient( + child: ChildProcess, + ipcPath: string, + transport: ${serviceTransport}, + options: ${serviceOptions}, +): Promise { + const timeoutMs = options.connectTimeoutMs ?? DEFAULT_CONNECT_TIMEOUT_MS; + const deadline = Date.now() + timeoutMs; + let lastError: unknown; + + while (Date.now() <= deadline) { + if (child.exitCode !== null) { + throw new Error('${this.opts.binaryName} exited before IPC connection was ready'); + } + try { + if (transport === 'uds') { + return await UdsIpcClient.connect(ipcPath, { connectTimeoutMs: Math.max(1, deadline - Date.now()) }); + } +${supportsShm ? 
` if (transport === 'shm') { + return createNapiShmAsyncClient(ipcPath.replace(/\\.shm$/, ''), { + clientId: options.clientId ?? 0, + customAddonPath: options.napiPath, + }); + } +` : ""} throw new Error('Unsupported transport: ' + transport); + } catch (err) { + lastError = err; + await new Promise(resolve => setTimeout(resolve, 50)); + } + } + + throw new Error('Timed out connecting to ${this.opts.binaryName}: ' + (lastError instanceof Error ? lastError.message : String(lastError))); +} + +function cleanupIpcPath(ipcPath: string, transport: ${serviceTransport}) { + try { + if (transport === 'uds' && existsSync(ipcPath)) { + unlinkSync(ipcPath); + } + if (transport === 'shm') { + const shmName = ipcPath.replace(/\\.shm$/, ''); + for (const suffix of ['_request', '_response']) { + const shmPath = '/dev/shm/' + shmName + suffix; + if (existsSync(shmPath)) { + unlinkSync(shmPath); + } + } + } + } catch {} +} + +export class ${serviceClass} extends AsyncApi { + private constructor(private spawnedBackend: SpawnedBackend, createError?: IpcErrorFactory) { + super(spawnedBackend, createError); + } + + static async spawn(options: ${serviceOptions} = {}): Promise<${serviceClass}> { + const backend = await SpawnedBackend.spawn(options); + return new ${serviceClass}(backend, options.createError); + } + + getIpcPath(): string { + return this.spawnedBackend.getIpcPath(); + } +} +`; + } + + generatePlatform(): string { + const packageName = this.opts.packageName; + const findBinary = binaryFinderName(this.opts.prefix); + const envVar = this.opts.binaryEnvVar; + const stem = packageStem(packageName); + const archPackages = archPackageNames(packageName); + + return `import { createRequire } from 'node:module'; +import * as fs from 'node:fs'; +import * as path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +export type Platform = 'x86_64-linux' | 'x86_64-darwin' | 'aarch64-linux' | 'aarch64-darwin'; + +const PLATFORM_TO_PACKAGE: Record = { + 'x86_64-linux': 
'${archPackages["linux-x64"]}', + 'x86_64-darwin': '${archPackages["darwin-x64"]}', + 'aarch64-linux': '${archPackages["linux-arm64"]}', + 'aarch64-darwin': '${archPackages["darwin-arm64"]}', +}; + +function currentDir(): string { + return path.dirname(fileURLToPath(import.meta.url)); +} + +function detectPlatform(): Platform | null { + if (process.arch === 'x64' && process.platform === 'linux') return 'x86_64-linux'; + if (process.arch === 'x64' && process.platform === 'darwin') return 'x86_64-darwin'; + if (process.arch === 'arm64' && process.platform === 'linux') return 'aarch64-linux'; + if (process.arch === 'arm64' && process.platform === 'darwin') return 'aarch64-darwin'; + return null; +} + +function findArchPackageDir(platform: Platform): string | null { + const packageName = PLATFORM_TO_PACKAGE[platform]; + try { + const require = createRequire(import.meta.url); + return path.dirname(require.resolve(packageName + '/package.json')); + } catch { + const siblingPackageDir = path.join(currentDir(), '..', 'packages', packageName.split('/').pop()!); + return fs.existsSync(path.join(siblingPackageDir, 'package.json')) ? siblingPackageDir : null; + } +} + +export function ${findBinary}(customPath?: string): string | null { + if (customPath) { + return fs.existsSync(customPath) ? path.resolve(customPath) : null; + } + + const envPath = process.env.${envVar}; + if (envPath) { + return fs.existsSync(envPath) ? path.resolve(envPath) : null; + } + + const platform = detectPlatform(); + if (!platform) { + return null; + } + + const archDir = findArchPackageDir(platform); + if (archDir) { + const candidate = path.join(archDir, '${this.opts.binaryName}'); + if (fs.existsSync(candidate)) { + return candidate; + } + } + + return null; +} + +export const ARCH_PACKAGE_STEM = '${stem}'; +`; + } + + generatePrepareArchPackagesScript(): string { + return `#!/usr/bin/env bash +set -euo pipefail + +cd "$(dirname "$0")/.." 
+ +declare -A PLATFORMS=( +${ARCH_PACKAGES.map(({ buildDir, suffix, os, cpu }) => ` ["${buildDir}"]="${suffix} ${os} ${cpu}"`).join("\n")} +) + +version=$(node -p "require('./package.json').version") + +declare -A BINARIES=() +for arg in "$@"; do + case "$arg" in + *=*) + key="\${arg%%=*}" + value="\${arg#*=}" + BINARIES["$key"]="$value" + ;; + *) + echo "Usage: npm run prepare_arch_packages -- [= ...]" >&2 + echo "Platforms: linux-x64, linux-arm64, darwin-x64, darwin-arm64" >&2 + exit 1 + ;; + esac +done + +for build_dir in "\${!PLATFORMS[@]}"; do + read -r suffix os cpu <<< "\${PLATFORMS[$build_dir]}" + pkg_name="${this.opts.packageName}-\${suffix}" + out_dir="packages/${packageStem(this.opts.packageName)}-\${suffix}" + binary_path="\${BINARIES[$suffix]:-\${BINARIES[$build_dir]:-}}" + + if [ -z "$binary_path" ]; then + binary_path="build/\${build_dir}/${this.opts.binaryName}" + fi + + if [ ! -f "$binary_path" ]; then + echo "Skipping \${pkg_name}: no binary at \${binary_path}" + continue + fi + + rm -rf "\${out_dir}" + mkdir -p "\${out_dir}" + cp "$binary_path" "\${out_dir}/${this.opts.binaryName}" + chmod +x "\${out_dir}/${this.opts.binaryName}" 2>/dev/null || true + + cat > "\${out_dir}/package.json" < get_leaf */ + stripMethodPrefix?: boolean; + /** Client struct name (e.g., 'WsdbClient') */ + clientName?: string; +} + +export class ZigCodegen { + private errorTypeName: string = "ErrorResponse"; + private opts: Required; + + constructor(options?: ZigCodegenOptions) { + this.opts = { + prefix: options?.prefix ?? "", + clientName: options?.clientName ?? "Client", + stripMethodPrefix: options?.stripMethodPrefix ?? 
false, + }; + } + + private primitiveType(type: Type): string { + switch (type.primitive) { + case "bool": + return "bool"; + case "u8": + return "u8"; + case "u16": + return "u16"; + case "u32": + return "u32"; + case "u64": + return "u64"; + case "f64": + return "f64"; + case "string": + return "[]const u8"; + case "bytes": + return "[]const u8"; + case "bin32": + return "[32]u8"; + } + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + + /** Map schema type to Zig type */ + private mapType(type: Type): string { + switch (type.kind) { + case "primitive": + return type.originalName + ? toAliasName(type.originalName) + : this.primitiveType(type); + case "vector": + return `[]const ${this.mapType(type.element!)}`; + case "array": + return `[${type.size}]${this.mapType(type.element!)}`; + case "optional": + return `?${this.mapType(type.element!)}`; + case "struct": + return toPascalCase(type.struct!.name); + } + throw new Error(`Unsupported type kind: ${type.kind}`); + } + + /** Generate a Zig field-to-payload conversion expression */ + private fieldToPayload( + fieldExpr: string, + type: import("./schema_visitor.ts").Type, + ): string { + switch (type.kind) { + case "primitive": + switch (type.primitive) { + case "bool": + return `Payload{ .bool = ${fieldExpr} }`; + case "u8": + case "u16": + case "u32": + case "u64": + return `Payload{ .uint = @intCast(${fieldExpr}) }`; + case "f64": + return `Payload{ .float = ${fieldExpr} }`; + case "string": + return `try Payload.strToPayload(${fieldExpr}, allocator)`; + case "bytes": + return `try Payload.binToPayload(${fieldExpr}, allocator)`; + case "bin32": + return `try Payload.binToPayload(&${fieldExpr}, allocator)`; + default: + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + case "optional": + return `if (${fieldExpr}) |v| ${this.fieldToPayload("v", type.element!)} else Payload{ .nil = {} }`; + case "vector": { + // For vectors, build an array payload + return `blk: { + var 
arr = try Payload.arrPayload(${fieldExpr}.len, allocator); + for (${fieldExpr}, 0..) |item, i| { + try arr.setArrElement(i, ${this.fieldToPayload("item", type.element!)}); + } + break :blk arr; + }`; + } + case "struct": + return `try ${fieldExpr}.toPayload(allocator)`; + case "array": + return `blk: { + var arr = try Payload.arrPayload(${fieldExpr}.len, allocator); + for (${fieldExpr}, 0..) |item, i| { + try arr.setArrElement(i, ${this.fieldToPayload("item", type.element!)}); + } + break :blk arr; + }`; + default: + throw new Error(`Unsupported type kind: ${type.kind}`); + } + } + + /** Generate a Zig payload-to-field conversion expression */ + private fieldFromPayload( + payloadExpr: string, + type: import("./schema_visitor.ts").Type, + ): string { + switch (type.kind) { + case "primitive": + switch (type.primitive) { + case "bool": + return `try ${payloadExpr}.asBool()`; + case "u8": + return `try payloadCastUint(u8, ${payloadExpr})`; + case "u16": + return `try payloadCastUint(u16, ${payloadExpr})`; + case "u32": + return `try payloadCastUint(u32, ${payloadExpr})`; + case "u64": + return `try payloadCastUint(u64, ${payloadExpr})`; + case "f64": + return `try ${payloadExpr}.asFloat()`; + case "string": + return `try ${payloadExpr}.asStr()`; + case "bytes": + return `${payloadExpr}.bin.value()`; + case "bin32": + return `${payloadExpr}.bin.value()[0..32].*`; + default: + throw new Error(`Unsupported primitive type: ${type.primitive}`); + } + case "vector": { + const elemConv = this.fieldFromPayload("elem", type.element!); + return `blk: { + const arr_len = try ${payloadExpr}.getArrLen(); + var result = try std.heap.page_allocator.alloc(${this.mapType(type.element!)}, arr_len); + for (0..arr_len) |i| { + const elem = try ${payloadExpr}.getArrElement(i); + result[i] = ${elemConv}; + } + break :blk result; + }`; + } + case "optional": + return `if (${payloadExpr} == .nil) null else ${this.fieldFromPayload(payloadExpr, type.element!)}`; + case "struct": + return `try 
${toPascalCase(type.struct!.name)}.fromPayload(${payloadExpr})`; + case "array": { + const elemConv = this.fieldFromPayload("elem", type.element!); + return `blk: { + var result: ${this.mapType(type)} = undefined; + for (0..${type.size}) |i| { + const elem = try ${payloadExpr}.getArrElement(i); + result[i] = ${elemConv}; + } + break :blk result; + }`; + } + default: + throw new Error(`Unsupported type kind: ${type.kind}`); + } + } + + /** Generate a Zig struct definition with toPayload/fromPayload methods */ + private generateStruct(struct: Struct): string { + const zigName = toPascalCase(struct.name); + const fields = struct.fields + .map((f) => { + const zigFieldName = toSnakeCase(f.name); + const zigType = this.mapType(f.type); + return ` ${zigFieldName}: ${zigType},`; + }) + .join("\n"); + + const hasFields = struct.fields.length > 0; + + // toPayload method + const toPayloadFields = struct.fields + .map((f) => { + const zigFieldName = toSnakeCase(f.name); + return ` try map.mapPut("${f.name}", ${this.fieldToPayload(`self.${zigFieldName}`, f.type)});`; + }) + .join("\n"); + + // fromPayload method + const fromPayloadFields = struct.fields + .map((f) => { + const zigFieldName = toSnakeCase(f.name); + return ` .${zigFieldName} = ${this.fieldFromPayload(`(try payload.mapGet("${f.name}")).?`, f.type)},`; + }) + .join("\n"); + + // Empty structs: suppress unused parameter warnings + if (!hasFields) { + return `/// ${struct.name} +pub const ${zigName} = struct { + + pub fn toPayload(_: ${zigName}, allocator: std.mem.Allocator) !Payload { + return Payload.mapPayload(allocator); + } + + pub fn fromPayload(_: Payload) !${zigName} { + return ${zigName}{}; + } +};`; + } + + return `/// ${struct.name} +pub const ${zigName} = struct { +${fields} + + pub fn toPayload(self: ${zigName}, allocator: std.mem.Allocator) !Payload { + var map = Payload.mapPayload(allocator); +${toPayloadFields} + return map; + } + + pub fn fromPayload(payload: Payload) !${zigName} { + return 
${zigName}{ +${fromPayloadFields} + }; + } +};`; + } + + /** Generate the Command tagged union */ + private generateCommandUnion(schema: CompiledSchema): string { + const variants = schema.commands + .map((c) => { + const zigName = toPascalCase(c.name); + return ` ${toSnakeCase(c.name)}: ${zigName},`; + }) + .join("\n"); + + const nameMap = schema.commands + .map((c) => { + return ` .${toSnakeCase(c.name)} => "${c.name}",`; + }) + .join("\n"); + + return `/// Tagged union of all commands +pub const Command = union(enum) { +${variants} + + pub fn schemaName(self: Command) []const u8 { + return switch (self) { +${nameMap} + }; + } +};`; + } + + /** Generate the Response tagged union */ + private generateResponseUnion(schema: CompiledSchema): string { + const commandResponseTypes = Array.from( + new Set(schema.commands.map((c) => c.responseType)), + ); + const errorName = schema.errorTypeName; + const responseTypes = schema.responses.has(errorName) + ? [...commandResponseTypes, errorName] + : commandResponseTypes; + + const variants = responseTypes + .map((name) => { + const zigName = toPascalCase(name); + return ` ${toSnakeCase(name)}: ${zigName},`; + }) + .join("\n"); + + return `/// Tagged union of all responses +pub const Response = union(enum) { +${variants} +};`; + } + + /** Generate the types file */ + generateTypes(schema: CompiledSchema, schemaHash?: string): string { + this.errorTypeName = schema.errorTypeName; + + const allStructs = dedupeStructsByName([ + ...schema.structs.values(), + ...schema.responses.values(), + ]); + + const aliasTypes = new Map(); + const collect = (type: Type): void => { + if (type.kind === "primitive" && type.originalName) { + aliasTypes.set( + toAliasName(type.originalName), + type.primitive === "bin32" ? 
"[32]u8" : this.primitiveType(type), + ); + } else if ( + type.kind === "vector" || + type.kind === "array" || + type.kind === "optional" + ) { + if (type.element) collect(type.element); + } + }; + for (const s of schema.structs.values()) { + for (const f of s.fields) collect(f.type); + } + for (const s of schema.responses.values()) { + for (const f of s.fields) collect(f.type); + } + const aliasDecls = [...aliasTypes.entries()] + .sort(([a], [b]) => a.localeCompare(b)) + .map(([name, underlying]) => `pub const ${name} = ${underlying};`) + .join("\n"); + + const structDefs = allStructs + .map((s) => this.generateStruct(s)) + .join("\n\n"); + + const hashLine = schemaHash + ? `\n/// Schema version hash for compatibility checking\npub const SCHEMA_HASH = "${schemaHash}";\n` + : ""; + + return `//! AUTOGENERATED - DO NOT EDIT +//! Generated from IPC msgpack schema +//! +//! Each struct has toPayload() and fromPayload() methods that convert +//! to/from zig-msgpack Payload objects for serialization. + +const std = @import("std"); +const msgpack = @import("msgpack"); +const Payload = msgpack.Payload; +const PackerIO = msgpack.PackerIO; +${hashLine} +/// Decode an unsigned wire integer with range checking. Accepts both msgpack +/// uint and non-negative int encodings: some encoders (e.g. msgpackr via +/// bigint) emit positive values with the signed int64 (0xd3) format. +pub fn payloadCastUint(comptime T: type, payload: Payload) !T { + const wide: u64 = switch (payload) { + .uint => |v| v, + .int => |v| if (v >= 0) @as(u64, @intCast(v)) else return error.InvalidType, + else => return error.InvalidType, + }; + return std.math.cast(T, wide) orelse error.InvalidType; +} + +// --------------------------------------------------------------------------- +// Primitive schema aliases. Bin32 aliases use [32]u8 and are encoded as +// msgpack bin32 by fieldToPayload / fieldFromPayload; scalar aliases use +// their scalar wire type directly. 
+// --------------------------------------------------------------------------- + +${aliasDecls} + +// --------------------------------------------------------------------------- +// Type definitions +// --------------------------------------------------------------------------- + +${structDefs} + +// --------------------------------------------------------------------------- +// Command / Response unions +// --------------------------------------------------------------------------- + +${this.generateCommandUnion(schema)} + +${this.generateResponseUnion(schema)} +`; + } + + /** Convert a command name to a Zig method name (snake_case) */ + private methodName(commandName: string): string { + const withoutPrefix = + this.opts.stripMethodPrefix && + this.opts.prefix && + commandName.startsWith(this.opts.prefix) + ? commandName.slice(this.opts.prefix.length) + : commandName; + return toSnakeCase(withoutPrefix); + } + + /** Generate the client wrapper — typed methods parameterized on backend type */ + generateClient(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const { prefix } = this.opts; + const errorRespName = toPascalCase(this.errorTypeName); + const typesFile = `${toSnakeCase(prefix)}_types.zig`; + + const methods = schema.commands + .map((c) => { + const methodName = this.methodName(c.name); + const zigCmdName = toPascalCase(c.name); + const zigRespName = toPascalCase(c.responseType); + return ` pub fn ${methodName}(self: *Self, cmd: types.${zigCmdName}) !types.${zigRespName} { + const request_bytes = try Self.encode("${c.name}", try cmd.toPayload(alloc)); + defer alloc.free(request_bytes); + const response_bytes = try self.backend.call(request_bytes); + defer alloc.free(response_bytes); + const resp_name, const resp_payload = try Self.decode(response_bytes); + if (std.mem.eql(u8, resp_name, "${this.errorTypeName}")) { + self.last_server_error = extractErrorMessage(resp_payload); + return error.ServerError; + } + return try 
types.${zigRespName}.fromPayload(resp_payload); + }`; + }) + .join("\n\n"); + + return `//! AUTOGENERATED - DO NOT EDIT +//! ${prefix} client — typed methods parameterized on a backend type. +//! +//! The backend must satisfy: call(self, request: []const u8) ![]u8 and destroy(self) void. +//! See backend.zig for the interface contract. +//! Implementations: ipc_runtime.Client, FfiBackend (ffi_backend.zig). + +const std = @import("std"); +const msgpack = @import("msgpack"); +const Payload = msgpack.Payload; +const types = @import("${typesFile}"); +const backend_mod = @import("backend.zig"); + +const alloc = std.heap.page_allocator; + +pub fn Client(comptime BackendType: type) type { + comptime backend_mod.assertBackend(BackendType); + + return struct { + const Self = @This(); + backend: *BackendType, + /// Message from the most recent server error response. Zig errors + /// carry no payload, so error.ServerError callers read this for the + /// server's diagnostic. Valid until the next call on this client. 
+ last_server_error: ?[]const u8 = null, + + pub fn init(backend: *BackendType) Self { + return .{ .backend = backend }; + } + + pub fn destroy(self: *Self) void { + self.backend.destroy(); + } + +${methods} + + // --- internal helpers --- + + fn encode(cmd_name: []const u8, cmd_fields: Payload) ![]u8 { + var inner = try Payload.arrPayload(2, alloc); + try inner.setArrElement(0, try Payload.strToPayload(cmd_name, alloc)); + try inner.setArrElement(1, cmd_fields); + var outer = try Payload.arrPayload(1, alloc); + try outer.setArrElement(0, inner); + + var allocating_writer = std.Io.Writer.Allocating.init(alloc); + var packer = msgpack.PackerIO.init(undefined, &allocating_writer.writer); + try packer.write(outer); + return try allocating_writer.toOwnedSlice(); + } + + fn decode(response_bytes: []const u8) !struct { []const u8, Payload } { + var reader = std.Io.Reader.fixed(response_bytes); + var unpacker = msgpack.PackerIO.init(&reader, undefined); + const resp = try unpacker.read(alloc); + const resp_len = try resp.getArrLen(); + if (resp_len != 2) return error.InvalidResponse; + const name = try (try resp.getArrElement(0)).asStr(); + const payload = try resp.getArrElement(1); + return .{ name, payload }; + } + }; +} + +fn extractErrorMessage(payload: Payload) ?[]const u8 { + const msg = (payload.mapGet("message") catch return null) orelse return null; + return msg.asStr() catch null; +} +`; + } + + /** Generate the server wrapper — typed dispatch parameterized on a handler type */ + generateServer(schema: CompiledSchema): string { + this.errorTypeName = schema.errorTypeName; + const { prefix } = this.opts; + const typesFile = `${toSnakeCase(prefix)}_types.zig`; + + const handlerMethodNames = schema.commands.map((c) => + this.methodName(c.name), + ); + + // Dispatch cases: match command name → deserialize → call handler → serialize response + const dispatchCases = schema.commands + .map((c) => { + const methodName = this.methodName(c.name); + const zigCmdName = 
toPascalCase(c.name); + return ` if (std.mem.eql(u8, cmd_name, "${c.name}")) { + const cmd = types.${zigCmdName}.fromPayload(cmd_fields) catch |err| return makeErrorFmt("decode of ${c.name} failed: {s}", .{@errorName(err)}); + const resp = self.handler.${methodName}(cmd) catch |err| return self.handlerError("${c.name}", err); + const resp_payload = resp.toPayload(alloc) catch |err| return makeErrorFmt("encode of ${c.responseType} failed: {s}", .{@errorName(err)}); + return .{ .resp_name = "${c.responseType}", .resp_payload = resp_payload }; + }`; + }) + .join("\n"); + + return `//! AUTOGENERATED - DO NOT EDIT +//! ${prefix} IPC server — typed dispatch parameterized on a handler type. +//! +//! The handler is any type with one method per command: +//! pub fn ${handlerMethodNames[0] ?? "command"}(self: *@This(), cmd: types.${toPascalCase(schema.commands[0]?.name ?? "Command")}) !types.${toPascalCase(schema.commands[0]?.responseType ?? "Response")} +//! Handler failures are wrapped into the schema error variant. +//! +//! Wire it into a transport, e.g. @import("ipc_runtime"): +//! +//! var dispatcher = Dispatcher(MyHandler).init(&handler); +//! var server = try ipc_runtime.Server.fromPath(path); +//! try server.listen(); +//! server.run(*Dispatcher(MyHandler), &dispatcher, Dispatcher(MyHandler).handleRequest); + +const std = @import("std"); +const msgpack = @import("msgpack"); +const Payload = msgpack.Payload; +const types = @import("${typesFile}"); + +const alloc = std.heap.page_allocator; + +/// Result of dispatching one command. +pub const DispatchResult = struct { resp_name: []const u8, resp_payload: Payload }; + +/// Comptime check that HandlerType has every command handler method. 
+pub fn assertHandler(comptime HandlerType: type) void { +${handlerMethodNames + .map( + (m) => ` if (!@hasDecl(HandlerType, "${m}")) { + @compileError(@typeName(HandlerType) ++ " is missing handler method '${m}'"); + }`, + ) + .join("\n")} +} + +pub fn Dispatcher(comptime HandlerType: type) type { + comptime assertHandler(HandlerType); + + return struct { + const Self = @This(); + handler: *HandlerType, + // Per-request response scratch; freed on the next call. The transport + // contract requires the returned slice to stay valid until then. + resp_scratch: ?[]u8 = null, + + pub fn init(handler: *HandlerType) Self { + return .{ .handler = handler }; + } + + /// Typed dispatch of a decoded [name, payload] command. + pub fn dispatch(self: *Self, cmd_name: []const u8, cmd_fields: Payload) DispatchResult { +${dispatchCases} + + return makeErrorFmt("unknown command: {s}", .{cmd_name}); + } + + /// Transport entry point: decode framed request bytes, dispatch, and + /// encode framed response bytes. All failures (malformed framing + /// included) produce the schema error variant. + pub fn handleRequest(self: *Self, client_id: i32, request_bytes: []const u8) []u8 { + _ = client_id; + if (self.resp_scratch) |prev| alloc.free(prev); + self.resp_scratch = null; + + const parsed = parseRequest(request_bytes) catch |err| { + return self.encodeResponse(makeErrorFmt("malformed request: {s}", .{@errorName(err)})); + }; + return self.encodeResponse(self.dispatch(parsed.cmd_name, parsed.cmd_fields)); + } + + /// Build the error variant for a failed handler call. Zig errors + /// carry no payload, so a handler can stash a rich diagnostic in an + /// optional \`error_message: ?[]const u8\` field on itself before + /// returning; otherwise the error name is used. 
+ fn handlerError(self: *Self, command_name: []const u8, err: anyerror) DispatchResult { + if (comptime @hasField(HandlerType, "error_message")) { + if (self.handler.error_message) |message| { + self.handler.error_message = null; + return makeErrorFmt("{s}", .{message}); + } + } + return makeErrorFmt("{s} failed: {s}", .{ command_name, @errorName(err) }); + } + + fn encodeResponse(self: *Self, result: DispatchResult) []u8 { + const bytes = encodeNamed(result.resp_name, result.resp_payload) catch + encodeNamed("${this.errorTypeName}", makeErrorFmt("response encode failed", .{}).resp_payload) catch + @panic("cannot encode error response"); + self.resp_scratch = bytes; + return bytes; + } + }; +} + +const ParsedRequest = struct { cmd_name: []const u8, cmd_fields: Payload }; + +fn parseRequest(request_bytes: []const u8) !ParsedRequest { + var reader = std.Io.Reader.fixed(request_bytes); + var unpacker = msgpack.PackerIO.init(&reader, undefined); + const request = try unpacker.read(alloc); + if (try request.getArrLen() != 1) return error.BadOuterArray; + const inner = try request.getArrElement(0); + if (try inner.getArrLen() != 2) return error.BadInnerArray; + const cmd_name = try (try inner.getArrElement(0)).asStr(); + const cmd_fields = try inner.getArrElement(1); + return .{ .cmd_name = cmd_name, .cmd_fields = cmd_fields }; +} + +fn encodeNamed(name: []const u8, payload: Payload) ![]u8 { + var resp = try Payload.arrPayload(2, alloc); + try resp.setArrElement(0, try Payload.strToPayload(name, alloc)); + try resp.setArrElement(1, payload); + var allocating_writer = std.Io.Writer.Allocating.init(alloc); + var packer = msgpack.PackerIO.init(undefined, &allocating_writer.writer); + try packer.write(resp); + return try allocating_writer.toOwnedSlice(); +} + +fn makeErrorFmt(comptime fmt: []const u8, args: anytype) DispatchResult { + const message = std.fmt.allocPrint(alloc, fmt, args) catch "error"; + var err_map = Payload.mapPayload(alloc); + err_map.mapPut("message", 
Payload.strToPayload(message, alloc) catch Payload{ .nil = {} }) catch {}; + return .{ .resp_name = "${this.errorTypeName}", .resp_payload = err_map }; +} +`; + } +} diff --git a/ipc-codegen/templates/cpp/ipc_codegen/msgpack_adaptor.hpp b/ipc-codegen/templates/cpp/ipc_codegen/msgpack_adaptor.hpp new file mode 100644 index 000000000000..0883871d6b14 --- /dev/null +++ b/ipc-codegen/templates/cpp/ipc_codegen/msgpack_adaptor.hpp @@ -0,0 +1,186 @@ +#pragma once +// +// Struct-map msgpack adaptor: pack/unpack types that declare their fields via +// the SERIALIZATION_FIELDS macro into a JSON-like map. Codegen-emitted C++ +// clients and servers include this support header when serializing wire types. +// +// Only depends on msgpack-c. Standalone-buildable. +// +// Some consumers build msgpack-c with THROW/RETHROW hooks; throw.hpp defines +// guarded defaults so a parent project's predefinition wins. +#include "throw.hpp" + +#include +#include +#include +#include +#include "msgpack_include.hpp" +#include +#include + +#ifndef IPC_CODEGEN_MSGPACK_CONCEPTS_DEFINED +#define IPC_CODEGEN_MSGPACK_CONCEPTS_DEFINED + +namespace msgpack_concepts { + +struct DoNothing { + void operator()(auto...) {} +}; + +template +concept HasMsgPack = requires(T t, DoNothing nop) { t.msgpack(nop); }; + +template +concept MsgpackConstructible = requires(T object, Args... 
args) { T{args...}; }; + +} // namespace msgpack_concepts + +#endif + +#ifndef IPC_CODEGEN_MSGPACK_BIN32_ALIAS_CONCEPT_DEFINED +#define IPC_CODEGEN_MSGPACK_BIN32_ALIAS_CONCEPT_DEFINED + +namespace msgpack_concepts { + +template +concept Bin32Alias = + requires(T t, const T ct) { + typename T::IPC_CODEGEN_BIN32_ALIAS; + { t.data() } -> std::same_as; + { ct.data() } -> std::same_as; + { ct.size() } -> std::convertible_to; + }; + +} // namespace msgpack_concepts + +#endif + +#ifndef IPC_CODEGEN_MSGPACK_DROP_KEYS_DEFINED +#define IPC_CODEGEN_MSGPACK_DROP_KEYS_DEFINED + +namespace msgpack { + +// SERIALIZATION_FIELDS' msgpack() callback receives args interleaved as +// (key0, val0, key1, val1, …). drop_keys strips the keys so we can check +// that the type is constructible from the values. +template +auto drop_keys_impl(Tuple &&tuple, std::index_sequence) { + return std::tie(std::get(std::forward(tuple))...); +} + +template auto drop_keys(std::tuple &&tuple) { + static_assert(sizeof...(Args) % 2 == 0, + "Tuple must contain an even number of elements"); + return drop_keys_impl(tuple, std::make_index_sequence{}); +} + +} // namespace msgpack + +#endif + +#ifndef IPC_CODEGEN_MSGPACK_STRUCT_MAP_ADAPTOR_DEFINED +#define IPC_CODEGEN_MSGPACK_STRUCT_MAP_ADAPTOR_DEFINED + +namespace msgpack::adaptor { + +template <> struct pack> { + template + packer &operator()(msgpack::packer &o, + std::array const &v) const { + o.pack_bin(static_cast(v.size())); + o.pack_bin_body(reinterpret_cast(v.data()), + static_cast(v.size())); + return o; + } +}; + +template <> struct convert> { + msgpack::object const &operator()(msgpack::object const &o, + std::array &v) const { + if (o.type != msgpack::type::BIN || o.via.bin.size != v.size()) { + THROW msgpack::type_error(); + } + std::memcpy(v.data(), o.via.bin.ptr, v.size()); + return o; + } +}; + +template struct pack { + template + packer &operator()(msgpack::packer &o, T const &v) const { + o.pack_bin(static_cast(v.size())); + 
o.pack_bin_body(reinterpret_cast(v.data()), + static_cast(v.size())); + return o; + } +}; + +template struct convert { + msgpack::object const &operator()(msgpack::object const &o, T &v) const { + if (o.type != msgpack::type::BIN || o.via.bin.size != v.size()) { + THROW msgpack::type_error(); + } + std::memcpy(v.data(), o.via.bin.ptr, v.size()); + return o; + } +}; + +// reads structs with msgpack() method from a JSON-like dictionary +template struct convert { + msgpack::object const &operator()(msgpack::object const &o, T &v) const { + static_assert(std::is_default_constructible_v, + "SERIALIZATION_FIELDS requires default-constructible types " + "(used during unpacking)"); + v.msgpack([&](auto &...args) { + auto static_checker = [&](auto &...value_args) { + static_assert( + msgpack_concepts::MsgpackConstructible, + "SERIALIZATION_FIELDS requires a constructor that can take the " + "types listed in " + "SERIALIZATION_FIELDS. " + "Type or arg count mismatch, or member initializer constructor not " + "available."); + }; + // Call static checker to ensure we have a constructor that takes all + // fields - unless we opt-out. 
+ if constexpr (!requires { typename T::MSGPACK_NO_STATIC_CHECK; }) { + std::apply(static_checker, drop_keys(std::tie(args...))); + } + msgpack::type::define_map{args...}.msgpack_unpack(o); + }); + return o; + } +}; + +// converts structs with msgpack() method to a JSON-like dictionary +template struct pack { + template + packer &operator()(msgpack::packer &o, T const &v) const { + static_assert(std::is_default_constructible_v, + "SERIALIZATION_FIELDS requires default-constructible types " + "(used during unpacking)"); + const_cast(v).msgpack([&](auto &...args) { + auto static_checker = [&](auto &...value_args) { + static_assert( + msgpack_concepts::MsgpackConstructible, + "T requires a constructor that can take the fields listed in " + "SERIALIZATION_FIELDS (T will be " + "in template parameters in the compiler stack trace)" + "Check the SERIALIZATION_FIELDS macro usage in T for " + "incompleteness or wrong order. " + "Alternatively, a matching member initializer constructor might " + "not be available for T " + "and should be defined."); + }; + if constexpr (!requires { typename T::MSGPACK_NO_STATIC_CHECK; }) { + std::apply(static_checker, drop_keys(std::tie(args...))); + } + msgpack::type::define_map{args...}.msgpack_pack(o); + }); + return o; + } +}; + +} // namespace msgpack::adaptor + +#endif diff --git a/ipc-codegen/templates/cpp/ipc_codegen/msgpack_include.hpp b/ipc-codegen/templates/cpp/ipc_codegen/msgpack_include.hpp new file mode 100644 index 000000000000..1ac19615fb55 --- /dev/null +++ b/ipc-codegen/templates/cpp/ipc_codegen/msgpack_include.hpp @@ -0,0 +1,24 @@ +#pragma once +/** + * @file msgpack_include.hpp + * @brief The one sanctioned way to include from generated code. + * + * Under BB_NO_EXCEPTIONS (-fno-exceptions, e.g. WASM) msgpack-c's raw + * try/catch blocks do not compile, so they are rewritten to always-taken / + * dead branches for the duration of this include only. 
THROW (see throw.hpp) + * aborts in that mode, so the catch bodies are unreachable. Defining macros + * named after keywords is ill-formed if any standard-library header is + * preprocessed while they are active — hence the tight scope and #undef. + */ + +#include "throw.hpp" + +#ifdef BB_NO_EXCEPTIONS +#define try if (true) +#define catch(...) if (false) +#include +#undef try +#undef catch +#else +#include +#endif diff --git a/ipc-codegen/templates/cpp/ipc_codegen/throw.hpp b/ipc-codegen/templates/cpp/ipc_codegen/throw.hpp new file mode 100644 index 000000000000..478065e32e03 --- /dev/null +++ b/ipc-codegen/templates/cpp/ipc_codegen/throw.hpp @@ -0,0 +1,44 @@ +#pragma once +/** + * @file throw.hpp + * @brief THROW / RETHROW macros for code that compiles in both + * exception-enabled and -fno-exceptions modes. + * + * - Native (default): `THROW x` is equivalent to `throw x` — full exception + * semantics. `RETHROW` is bare `throw`. + * - WASM / -fno-exceptions (`BB_NO_EXCEPTIONS` defined): `THROW x` evaluates + * `x` once (so its constructor still runs, matching observable behaviour) + * and then aborts. `RETHROW` is a bare `std::abort()`. + * + * Use through codegen output that needs to compile in both modes. msgpack-c + * itself uses raw try/catch; include it via msgpack_include.hpp, which scopes + * a try/catch rewrite to that include only (defining keyword macros globally + * is ill-formed once standard headers follow). + * + * The macros are defined inside an `#ifndef THROW` guard so callers that + * predefine their own THROW/RETHROW can do so before this header is reached + * and we yield to whichever variant the parent project wants. 
+ */ + +#ifndef THROW + +#ifdef BB_NO_EXCEPTIONS +#include + +namespace ipc::detail { +struct AbortOnThrow { + template + [[noreturn]] void operator<<(const T & /*ignored*/) const noexcept { + std::abort(); + } +}; +} // namespace ipc::detail + +#define THROW ::ipc::detail::AbortOnThrow() << +#define RETHROW std::abort() +#else +#define THROW throw +#define RETHROW throw +#endif // BB_NO_EXCEPTIONS + +#endif // THROW diff --git a/ipc-codegen/templates/rust/backend.rs b/ipc-codegen/templates/rust/backend.rs new file mode 100644 index 000000000000..3ca7e95ff84a --- /dev/null +++ b/ipc-codegen/templates/rust/backend.rs @@ -0,0 +1,63 @@ +//! Backend trait for msgpack communication +//! +//! This module defines a simple, pluggable interface for byte backends. +//! Users can easily implement custom backends (FFI, WASM, IPC, etc.). + +use super::error::Result; + +/// Simple interface for msgpack backend implementations. +/// +/// Implement this trait to create a custom backend for a generated client. +/// The backend handles msgpack-encoded command/response communication. +/// +/// # Example +/// +/// ```ignore +/// struct MyCustomBackend { +/// // your FFI handle, connection, etc. +/// } +/// +/// impl Backend for MyCustomBackend { +/// fn call(&mut self, input: &[u8]) -> Result> { +/// // Send input to your backend +/// // Return the response +/// } +/// +/// fn destroy(&mut self) -> Result<()> { +/// // Clean up resources +/// Ok(()) +/// } +/// } +/// ``` +pub trait Backend { + /// Execute a msgpack command and return the msgpack response. + /// + /// # Arguments + /// * `input` - Msgpack-encoded command + /// + /// # Returns + /// Msgpack-encoded response + fn call(&mut self, input: &[u8]) -> Result>; + + /// Clean up resources and shutdown the backend. + fn destroy(&mut self) -> Result<()>; +} + +// Bridge impl so ipc_runtime::IpcClient (UDS / MPSC-SHM transport) plugs +// directly into any generated Api as the Backend. 
Gated behind the +// consumer crate's `ipc-runtime` feature so FFI-only consumers don't need +// the ipc-runtime dependency at all: +// +// [features] +// default = ["ipc-runtime"] +// ipc-runtime = ["dep:ipc-runtime"] +#[cfg(feature = "ipc-runtime")] +impl Backend for ipc_runtime::IpcClient { + fn call(&mut self, input: &[u8]) -> Result> { + ipc_runtime::IpcClient::call(self, input) + .map_err(|e| super::error::IpcError::Backend(e.to_string())) + } + fn destroy(&mut self) -> Result<()> { + Ok(()) + } +} diff --git a/ipc-codegen/templates/rust/error.rs b/ipc-codegen/templates/rust/error.rs new file mode 100644 index 000000000000..4a1357f5e3b0 --- /dev/null +++ b/ipc-codegen/templates/rust/error.rs @@ -0,0 +1,32 @@ +//! Error types for generated IPC clients + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum IpcError { + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Deserialization error: {0}")] + Deserialization(String), + + #[error("Backend error: {0}")] + Backend(String), + + #[error("IPC error: {0}")] + Ipc(String), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("Invalid response: {0}")] + InvalidResponse(String), + + #[error("Connection error: {0}")] + Connection(String), + + #[error("WASM error: {0}")] + Wasm(String), +} + +pub type Result = std::result::Result; diff --git a/ipc-codegen/templates/rust/ffi_backend.rs b/ipc-codegen/templates/rust/ffi_backend.rs new file mode 100644 index 000000000000..5137ea99c6fa --- /dev/null +++ b/ipc-codegen/templates/rust/ffi_backend.rs @@ -0,0 +1,128 @@ +//! FFI backend scaffold for direct library linking. +//! +//! Calls a C symbol with msgpack bytes — no IPC overhead. Link against a +//! native library that exports `ipc_ffi_entry`, and add the appropriate +//! `-L` / `-l` directives to your `build.rs`. +//! +//! # Requirements +//! +//! 1. A native library exporting an extern-C function with this signature: +//! ```text +//! void ipc_ffi_entry( +//! 
const uint8_t* input, size_t input_len, +//! uint8_t** output_out, size_t* output_len_out); +//! ``` +//! `*output_out` must be a `malloc`'d buffer the caller is responsible for freeing. +//! 2. Library search path configured (via `.cargo/config.toml`, `RUSTFLAGS`, or +//! `cargo:rustc-link-search` in `build.rs`). +//! +//! # Example +//! +//! ```ignore +//! use my_service_client::{ServiceApi, FfiBackend}; +//! +//! let backend = FfiBackend::new()?; +//! let mut api = ServiceApi::new(backend); +//! let response = api.some_command(args)?; +//! ``` + +use super::backend::Backend; +use super::error::{IpcError, Result}; +use std::ptr; + +extern "C" { + /// Execute a msgpack-encoded command and return msgpack-encoded response. + /// + /// # Safety + /// - `input_in` must point to valid memory of `input_len_in` bytes + /// - `output_out` and `output_len_out` must be valid pointers + /// - Caller must free `*output_out` using `libc::free` + fn ipc_ffi_entry( + input_in: *const u8, + input_len_in: usize, + output_out: *mut *mut u8, + output_len_out: *mut usize, + ); +} + +/// FFI backend that calls a native library directly via its C ABI. +/// +/// Most performant backend (no process spawn, no IPC overhead) but requires +/// linking against the native library at build time. +/// +/// # Thread Safety +/// +/// This backend is **not** thread-safe by default. Each thread should have +/// its own `FfiBackend` instance, or access should be synchronized externally. +pub struct FfiBackend { + _initialized: bool, +} + +impl FfiBackend { + /// Create a new FFI backend. 
+ pub fn new() -> Result { + Ok(Self { _initialized: true }) + } +} + +impl Backend for FfiBackend { + fn call(&mut self, input: &[u8]) -> Result> { + let mut output_ptr: *mut u8 = ptr::null_mut(); + let mut output_len: usize = 0; + + // SAFETY: + // - input.as_ptr() is valid for input.len() bytes + // - output_ptr and output_len are valid stack pointers + // - the FFI entrypoint allocates output using malloc, which we free below + unsafe { + ipc_ffi_entry( + input.as_ptr(), + input.len(), + &mut output_ptr, + &mut output_len, + ); + } + + if output_ptr.is_null() { + return Err(IpcError::Backend( + "FFI entry returned null pointer".to_string(), + )); + } + + if output_len == 0 { + unsafe { + libc::free(output_ptr as *mut libc::c_void); + } + return Err(IpcError::Backend( + "FFI entry returned empty response".to_string(), + )); + } + + // SAFETY: output_ptr is valid for output_len bytes, allocated by malloc + let output = unsafe { std::slice::from_raw_parts(output_ptr, output_len).to_vec() }; + + // SAFETY: output_ptr was allocated by the FFI entrypoint using malloc + unsafe { + libc::free(output_ptr as *mut libc::c_void); + } + + Ok(output) + } + + fn destroy(&mut self) -> Result<()> { + self._initialized = false; + Ok(()) + } +} + +impl Drop for FfiBackend { + fn drop(&mut self) { + let _ = self.destroy(); + } +} + +impl Default for FfiBackend { + fn default() -> Self { + Self::new().expect("Failed to initialize FfiBackend") + } +} diff --git a/ipc-codegen/templates/zig/backend.zig b/ipc-codegen/templates/zig/backend.zig new file mode 100644 index 000000000000..bd26281f8095 --- /dev/null +++ b/ipc-codegen/templates/zig/backend.zig @@ -0,0 +1,27 @@ +/// Backend abstraction — comptime interface for transport. 
+/// +/// A valid backend type must provide: +/// fn call(self: *T, request: []const u8) ![]u8 +/// fn destroy(self: *T) void +/// +/// Implementations: +/// ipc_runtime.Client — UDS / MPSC-SHM transport from ipc-runtime/zig +/// FfiBackend (ffi_backend.zig) — Direct C FFI linking +/// +/// Usage with the generated client: +/// const Client = @import("my_service_client.zig").Client; +/// const ipc_runtime = @import("ipc_runtime"); +/// var backend = try ipc_runtime.Client.fromPath(allocator, "/tmp/my-service.sock"); +/// var client = Client(ipc_runtime.Client){ .backend = &backend }; + +/// Compile-time check that a type satisfies the backend interface. +pub fn assertBackend(comptime T: type) void { + // Must have: fn call(self: *T, request: []const u8) ![]u8 + if (!@hasDecl(T, "call")) { + @compileError("Backend type " ++ @typeName(T) ++ " missing 'call' method"); + } + // Must have: fn destroy(self: *T) void + if (!@hasDecl(T, "destroy")) { + @compileError("Backend type " ++ @typeName(T) ++ " missing 'destroy' method"); + } +} diff --git a/ipc-codegen/templates/zig/ffi_backend.zig b/ipc-codegen/templates/zig/ffi_backend.zig new file mode 100644 index 000000000000..93e16eb915eb --- /dev/null +++ b/ipc-codegen/templates/zig/ffi_backend.zig @@ -0,0 +1,34 @@ +/// FFI backend scaffold for direct library linking. +/// +/// Calls a C symbol with msgpack bytes — no IPC overhead. Link against a +/// native library that exports `ipc_ffi_entry`, and adjust the link +/// configuration in your build.zig to pull that library in. +/// +/// Satisfies the backend interface: call(request) -> response, destroy(). 
+const std = @import("std"); + +/// C entry point exported by the native library. `output` is declared as a +/// pointer to an *optional* many-item pointer so a NULL result written by the +/// C side is representable in Zig instead of being illegal behaviour. +extern fn ipc_ffi_entry(input: [*]const u8, input_len: usize, output: *?[*]u8, output_len: *usize) void; + +/// Allocator contract: callers free returned slices with this allocator +/// (the generated client uses std.heap.page_allocator), so the malloc'd FFI +/// buffer is copied into it and freed with the C allocator here — freeing a +/// malloc'd pointer with a Zig allocator is undefined behaviour. +const alloc = std.heap.page_allocator; + +pub const FfiBackend = struct { + /// Send a msgpack command and receive the response via FFI. + /// + /// Returns a copy of the response owned by `alloc`. Fails with + /// `error.FfiNullResponse` when the entry point leaves the output pointer + /// NULL, or with an allocator error when the copy cannot be allocated. + pub fn call(self: *FfiBackend, request: []const u8) ![]u8 { + _ = self; + // Start from null rather than `undefined`: if the entry point never + // writes the out-pointer we must not hand an indeterminate value to free(). + var out_ptr: ?[*]u8 = null; + var out_len: usize = 0; + ipc_ffi_entry(request.ptr, request.len, &out_ptr, &out_len); + const buf = out_ptr orelse return error.FfiNullResponse; + // Release the FFI buffer on every exit path, including a failed alloc below. + defer std.c.free(buf); + const response = try alloc.alloc(u8, out_len); + @memcpy(response, buf[0..out_len]); + return response; + } + + /// No-op: this backend holds no state of its own to release. + pub fn destroy(self: *FfiBackend) void { + _ = self; + } +}; diff --git a/ipc-codegen/test/schema_visitor.test.ts b/ipc-codegen/test/schema_visitor.test.ts new file mode 100644 index 000000000000..90cfd19e9db6 --- /dev/null +++ b/ipc-codegen/test/schema_visitor.test.ts @@ -0,0 +1,189 @@ +/** + * Schema validation tests. Run with: + * node --experimental-strip-types --no-warnings test/schema_visitor.test.ts + * Exits non-zero on failure.
+ */ +import { + SchemaVisitor, + stripJsonc, + friendlyToPositional, +} from "../src/schema_visitor.ts"; +import * as fs from "node:fs"; +import * as path from "node:path"; + +let failures = 0; + +function expectThrows(label: string, fn: () => void, messagePart: string) { + try { + fn(); + console.error(`FAIL: ${label} did not throw`); + failures++; + } catch (e: any) { + if (!e.message.includes(messagePart)) { + console.error( + `FAIL: ${label} threw wrong error: ${e.message} (expected to include '${messagePart}')`, + ); + failures++; + } else { + console.log(`ok: ${label}`); + } + } +} + +function expectOk(label: string, fn: () => void) { + try { + fn(); + console.log(`ok: ${label}`); + } catch (e: any) { + console.error(`FAIL: ${label} threw: ${e.message}`); + failures++; + } +} + +const errResp = ["FooErrorResponse", { message: "string" }]; + +expectOk("echo schema is valid", () => { + const schemaPath = path.join( + import.meta.dirname, + "../echo_example/schema/schema.jsonc", + ); + const parsed = JSON.parse(stripJsonc(fs.readFileSync(schemaPath, "utf8"))); + const { commands, responses } = friendlyToPositional(parsed); + new SchemaVisitor().visit(commands, responses); +}); + +expectThrows( + "missing error response", + () => + new SchemaVisitor().visit( + ["named_union", [["FooBar", { x: "unsigned int" }]]], + ["named_union", [["FooBarResponse", { y: "unsigned int" }]]], + ), + "no error response", +); + +expectThrows( + "duplicate command", + () => + new SchemaVisitor().visit( + [ + "named_union", + [ + ["FooA", {}], + ["FooA", {}], + ], + ], + ["named_union", [["FooAResponse", {}], ["FooAResponse", {}], errResp]], + ), + "Duplicate command name", +); + +expectOk("response reuse by position is allowed", () => + new SchemaVisitor().visit( + ["named_union", [["FooBar", {}]]], + ["named_union", [["FooSharedResponse", {}], errResp]], + ), +); + +expectThrows( + "misordered unions", + () => + new SchemaVisitor().visit( + [ + "named_union", + [ + ["FooA", {}], + 
["FooB", {}], + ], + ], + ["named_union", [["FooBResponse", {}], ["FooAResponse", {}], errResp]], + ), + "misordered", +); + +expectOk("string response reference resolves to earlier inline struct", () => + new SchemaVisitor().visit( + [ + "named_union", + [ + ["FooMake", {}], + ["FooGet", {}], + ], + ], + [ + "named_union", + [ + [ + "FooMakeResponse", + { + __typename: "FooMakeResponse", + item: { __typename: "FooGetResponse", x: "unsigned int" }, + }, + ], + ["FooGetResponse", "FooGetResponse"], + errResp, + ], + ], + ), +); + +expectThrows( + "dangling string response reference", + () => + new SchemaVisitor().visit( + ["named_union", [["FooBar", {}]]], + ["named_union", [["FooBarResponse", "NeverDefined"], errResp]], + ), + "not defined earlier", +); + +expectThrows( + "bad error struct shape", + () => + new SchemaVisitor().visit( + ["named_union", [["FooBar", {}]]], + [ + "named_union", + [ + ["FooBarResponse", {}], + ["FooErrorResponse", { msg: "string" }], + ], + ], + ), + "exactly one field 'message: string'", +); + +expectThrows( + "reserved word field", + () => + new SchemaVisitor().visit( + ["named_union", [["FooBar", { type: "unsigned int" }]]], + ["named_union", [["FooBarResponse", {}], errResp]], + ), + "reserved word", +); + +expectThrows( + "colliding field projections", + () => + new SchemaVisitor().visit( + [ + "named_union", + [["FooBar", { forkId: "unsigned int", fork_id: "unsigned int" }]], + ], + ["named_union", [["FooBarResponse", {}], errResp]], + ), + "both map to", +); + +expectThrows( + "bad top-level shape", + () => new SchemaVisitor().visit({ commands: [] }, ["named_union", []]), + "named_union", +); + +if (failures > 0) { + console.error(`${failures} test(s) failed`); + process.exit(1); +} +console.log("schema_visitor tests passed"); diff --git a/ipc-runtime/.clang-format b/ipc-runtime/.clang-format new file mode 100644 index 000000000000..14a6b3bac238 --- /dev/null +++ b/ipc-runtime/.clang-format @@ -0,0 +1,26 @@ +PointerAlignment: 
Left +ColumnLimit: 120 +IndentWidth: 4 +BinPackArguments: false +BinPackParameters: false +Cpp11BracedListStyle: false +AlwaysBreakAfterReturnType: None +AlwaysBreakAfterDefinitionReturnType: None +PenaltyReturnTypeOnItsOwnLine: 1000000 +BreakConstructorInitializers: BeforeComma +BreakBeforeBraces: Custom +BraceWrapping: + AfterClass: false + AfterEnum: false + AfterFunction: true + AfterNamespace: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +AllowShortFunctionsOnASingleLine : Inline +SortIncludes: true diff --git a/ipc-runtime/.rebuild_patterns b/ipc-runtime/.rebuild_patterns new file mode 100644 index 000000000000..4dbc9784e676 --- /dev/null +++ b/ipc-runtime/.rebuild_patterns @@ -0,0 +1,13 @@ +^ipc-runtime/cpp/ipc_runtime/.*\.(cpp|hpp)$ +^ipc-runtime/cpp/napi/.*\.(cpp|hpp)$ +^ipc-runtime/cpp/CMakeLists\.txt$ +^ipc-runtime/cpp/CMakePresets\.json$ +^ipc-runtime/cpp/scripts/.*$ +^ipc-runtime/bootstrap\.sh$ +^ipc-runtime/ts/src/.*$ +^ipc-runtime/ts/package\.json$ +^ipc-runtime/ts/tsconfig\.json$ +^ipc-runtime/ts/scripts/.*$ +^ipc-runtime/rust/(src/.*|build\.rs|Cargo\.toml)$ +^ipc-runtime/zig/(src/.*|build\.zig)$ +^ipc-runtime/scripts/.*$ diff --git a/ipc-runtime/README.md b/ipc-runtime/README.md new file mode 100644 index 000000000000..15118999154d --- /dev/null +++ b/ipc-runtime/README.md @@ -0,0 +1,283 @@ +# ipc-runtime + +UDS + MPSC shared-memory transport library for IPC services. + +ipc-runtime is the byte-moving layer underneath +[`/ipc-codegen`](../ipc-codegen). It exposes a small `IpcServer` / +`IpcClient` API that picks the right transport from the path you hand it — +`.sock` → Unix-domain socket, `.shm` → MPSC shared memory — so per-service +code never branches on transport. + +The same C++ implementation is reused from Rust, TypeScript (Node.js) and +Zig via a tiny C ABI. 
Wire types travel as opaque byte arrays; the per- +language codegen output knows how to (de)serialise them. + +## Quick start + +```sh +cd ipc-runtime +./bootstrap.sh # build C++ static lib + tests, NAPI addon and TS package (default) +./bootstrap.sh test # build, then run the C++, Rust and TS test suites +``` + +Per-language bindings build standalone: + +```sh +# Rust crate +cd rust && cargo build + +# TypeScript package (publishes @aztec/ipc-runtime via file: link) +cd ts && yarn install --immutable && yarn build + +# Zig binding (compiles the C++ sources itself; no prebuilt archive) +cd zig && zig build test +``` + +The Rust and Zig bindings each compile the C++ sources themselves (via the +`cc` crate / `zig build`), so there's no separately-built archive to ship +between them and ipc-runtime/cpp. + +## Layout + +``` +ipc-runtime/ + bootstrap.sh # build / test (C++ lib, NAPI addon, TS package) + cpp/ + ipc_runtime/ + constants.hpp # shared limits/defaults (mirrored in ts/src/types.ts) + ipc_client.{hpp,cpp} # abstract IpcClient interface + factories + ipc_server.{hpp,cpp} # abstract IpcServer interface + factories + run() loop + socket_client.{hpp,cpp} # UDS client implementation + socket_server.{hpp,cpp} # UDS server implementation (epoll / kqueue) + shm_client.hpp # single-client (SPSC) SHM client + shm_server.hpp # single-client (SPSC) SHM server + mpsc_shm_client.hpp # multi-client SHM client (one slot per client) + mpsc_shm_server.hpp # multi-client SHM server + shm_common.hpp # length-prefix framing over the rings + shm/ # lock-free SPSC/MPSC ring buffer primitives + serve_helper.{hpp,cpp} # ipc::make_server / make_client (path-suffix dispatch) + signal_handlers.{hpp,cpp} # ipc::install_default_signal_handlers + c_abi.{h,cpp} # C ABI exported to Rust / Zig / NAPI + CMakeLists.txt + rust/ + src/lib.rs # safe Rust wrapper over c_abi.h + build.rs # invokes cc to compile cpp/ sources + Cargo.toml + ts/ + src/ + index.ts # re-exports of the surface below + uds_client.ts # UdsIpcClient (Node net.Socket) + uds_server.ts # UdsIpcServer
+ shm_client.ts # NapiShmSyncClient / NapiShmAsyncClient (NAPI bridge) + types.ts # IpcClientSync / IpcClientAsync interfaces + package.json + zig/ + src/main.zig # Zig binding (Server.fromPath / Client.fromPath) + src/smoke.zig # in-process smoke test + build.zig +``` + +The shared C ABI in `cpp/ipc_runtime/c_abi.h` is the single contract every +non-C++ binding implements. Adding a new language binding is "wrap that +header"; there is no separate cross-language IPC framing to learn. + +## Transport selection + +`ipc::make_server(path)` and `ipc::make_client(path)` pick the transport +from the path's suffix: + +| Path suffix | Transport | Used by | +|-------------|----------------------------|--------------------------------------| +| `*.sock` | Unix-domain socket | Async local clients, CLI tools | +| `*.shm` | MPSC shared-memory rings | Low-latency native clients | + +```cpp +#include "ipc_runtime/serve_helper.hpp" +#include "ipc_runtime/signal_handlers.hpp" + +auto server = ipc::make_server(input_path); // .sock or .shm +ipc::install_default_signal_handlers(*server); // SIGINT/SIGTERM → clean exit +server->listen(); +server->run(make_service_handler(request_ctx)); // codegen-emitted dispatcher +``` + +UDS is the default; SHM is a sub-microsecond hot path for latency-sensitive +local clients (`shm/README.md` covers the ring-buffer internals). + +The suffix helpers use MPSC-SHM for `.shm` because it supports multiple +client slots. If a service only ever needs one producer/client, the lower-level +`IpcServer::create_shm` / `IpcClient::create_shm` factories are also supported; +they use one request ring and one response ring directly. They are not selected +by path suffix to keep `.shm` behavior stable for multi-client services. 
+ +## API surface + +### C++ (`cpp/ipc_runtime/`) + +```cpp +namespace ipc { + +class IpcClient { +public: + static std::unique_ptr create_socket(const std::string& socket_path); + static std::unique_ptr create_shm(const std::string& base_name); + static std::unique_ptr create_mpsc_shm(const std::string& base_name, + std::size_t client_id); + + virtual bool connect() = 0; + virtual bool send(const void* data, size_t len, uint64_t timeout_ns) = 0; + virtual std::span receive(uint64_t timeout_ns) = 0; + virtual void release(size_t message_size) = 0; + virtual void close() = 0; +}; + +class IpcServer { +public: + using Handler = std::function(int client_id, std::span)>; + + static std::unique_ptr create_socket(const std::string& path, int max_clients); + static std::unique_ptr create_shm(const std::string& base_name, + std::size_t request_ring_size = DEFAULT_RING_SIZE, + std::size_t response_ring_size = DEFAULT_RING_SIZE); + static std::unique_ptr create_mpsc_shm(const std::string& base_name, + std::size_t max_clients, + std::size_t request_ring_size = DEFAULT_RING_SIZE, + std::size_t response_ring_size = DEFAULT_RING_SIZE); + + virtual bool listen() = 0; + virtual int wait_for_data(uint64_t timeout_ns) = 0; + virtual std::span receive(int client_id) = 0; + virtual void release(int client_id, size_t message_size) = 0; + virtual bool send(int client_id, const void* data, size_t len) = 0; + virtual void close() = 0; + virtual void request_shutdown(); // NOT signal-safe (wakes waiters) + void request_shutdown_from_signal() noexcept; // signal-safe variant + virtual void run(const Handler& handler); // event loop +}; + +std::unique_ptr make_server(const std::string& path, const ServerOptions& = {}); +std::unique_ptr make_client(const std::string& path, std::size_t shm_client_id = 0); + +void install_default_signal_handlers(IpcServer& server); + +} // namespace ipc +``` + +`receive` returns a zero-copy span (into the SHM ring or the socket's +internal buffer); the caller 
must follow it with `release(span.size())` +before the next `receive` on the same client. The `run()` event loop owns +that pattern so handlers just deal in whole messages. + +Every request gets exactly one response: a handler returning a +zero-length vector sends a zero-length response frame (which clients see as +a valid empty reply, in every language binding). To exit the loop cleanly, +call `request_shutdown()`; `install_default_signal_handlers` wires +SIGINT/SIGTERM to the signal-safe `request_shutdown_from_signal()` so RAII +destructors run normally. + +### Rust (`rust/`, crate `ipc-runtime`) + +```rust +let mut server = ipc_runtime::Server::from_path("/tmp/svc.sock")?; +server.listen()?; +server.install_default_signal_handlers(); +server.run(|client_id, request| handle(client_id, request)); + +let mut client = ipc_runtime::Client::from_path("/tmp/svc.sock")?; +let response: Vec<u8> = client.call(&request_bytes)?; +``` + +`Server::from_path` / `Client::from_path` mirror the C++ helpers (suffix +dispatch). The Rust `Client::call(&[u8]) -> Result<Vec<u8>>` packages the +send + receive + release sequence into a single safe operation. For +multi-slot MPSC-SHM, use `Client::from_path_with_id(path, client_id)`. The +crate compiles the C++ sources via `build.rs` so there's no separate +linker hook for downstream Cargo users. + +### TypeScript (`ts/`, published as `@aztec/ipc-runtime`) + +Three client classes cover the two transports: + +| Class | Transport | Sync / Async | +|----------------------|----------------------------|--------------------------------------------------------------| +| `UdsIpcClient` | Node `net.Socket` | async only | +| `NapiShmSyncClient` | MPSC-SHM via NAPI bridge | sync | +| `NapiShmAsyncClient` | MPSC-SHM via NAPI bridge | async (C++ poll thread + ThreadSafeFunction bridge) | + +`UdsIpcServer` is provided for in-process tests; production servers are in +C++.
+ +### Zig (`zig/`) + +`Server.fromPath(path)` / `Client.fromPath(path)` over the same C ABI; the +Zig `build.zig` compiles the C++ sources directly with the bundled clang + +libc++, so there's no archive shim. + +## Shared constants + +`cpp/ipc_runtime/constants.hpp` (mirrored in `ts/src/types.ts`) is the single +definition of the transport limits and defaults: + +| Constant | Value | Meaning | +|----------|-------|---------| +| `MAX_FRAME_SIZE` | 256 MiB | Max length prefix accepted on receive; larger frames close the connection / fail the ring instead of allocating. | +| `CONNECT_RETRY_BUDGET_MS` | 5000 | Total client connect retry budget (all transports). | +| `DEFAULT_RING_SIZE` | 4 MiB | SHM ring size per direction per client. | +| `SOCKET_BACKLOG` | 10 | Default UDS listen backlog. | +| `DEFAULT_CALL_TIMEOUT_NS` | 0 | Per-call timeout for client send/receive; 0 = infinite. (`IpcServer::wait_for_data(0)` is the documented exception: non-blocking poll.) | + +## Wire framing + +Both transports use a 4-byte little-endian length prefix in front of every +message: + +``` +┌───────────────────────┬────────────────────────┐ +│ Length (uint32 le) │ Payload (Length bytes) │ +└───────────────────────┴────────────────────────┘ +``` + +Framing is handled inside `IpcServer::receive` / `IpcClient::receive`; callers +deal in whole messages. The codegen's `Command` / `Response` NamedUnion +sits inside that payload — see `ipc-codegen/SCHEMA_SPEC.md`.
+ +## Performance characteristics + +| Transport | Round-trip latency | Throughput, 1 client | Notes | +|---------------|---------------------|----------------------|--------------------------------------| +| UDS | 6–15 µs | ~150K msgs/s | One syscall per `send`/`recv` | +| MPSC-SHM (hot)| 0.3–1 µs | ~1M msgs/s | Lock-free; adaptive spin + futex | +| MPSC-SHM (cold)| 3–6 µs | n/a | First message after idle ring | + +The `cpp/ipc_runtime/grind_ipc.sh` script and `ipc_runtime_tests` stress the +SHM implementation; benchmark harnesses can reuse the same runtime APIs. + +## Threading model + +- **UDS server**: single-threaded event loop (`epoll` on Linux, `kqueue` on + macOS). Concurrent + clients are interleaved; handlers run on the loop thread, so heavy work + should be offloaded. +- **MPSC-SHM server**: single consumer pulling from the per-client request + ring. Clients write lock-free in parallel; the server is the sole + reader. +- **UDS client**: each `IpcClient` is single-threaded — share between + threads via your own synchronisation. +- **MPSC-SHM client**: lock-free producer. Multiple clients can hammer the + request ring concurrently. + +## Limitations + +- **POSIX-only**: Linux and macOS are the supported platforms (futex on + Linux, `os_sync_wait_on_address` on macOS; epoll/kqueue for sockets). + Other platforms fail the build with an explicit `#error`. +- **SHM** capacity is fixed at server-create time. Clean shutdown unlinks + the request and response shared-memory objects automatically when + `IpcServer` destructs; fatal signals best-effort unlink them when + `install_default_signal_handlers` is in place. +- **UDS** has the usual `ulimit` for file descriptors and one syscall per + send/recv. Buffer copies on send are unavoidable. + +For deep dives: + +- `cpp/ipc_runtime/shm/README.md` — lock-free ring-buffer architecture. +- `ipc-codegen/SCHEMA_SPEC.md` — wire-format details consumed by callers.
diff --git a/ipc-runtime/bootstrap.sh b/ipc-runtime/bootstrap.sh new file mode 100755 index 000000000000..bd3fe351a3ed --- /dev/null +++ b/ipc-runtime/bootstrap.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash +# ipc-runtime — UDS + MPSC-SHM transport library for IPC services. +# +# Standalone-buildable: cmake + a C++20 compiler + POSIX is all that's +# required. No repo-local deps, no msgpack dep, no tracing or database +# machinery. gtest is fetched via FetchContent the first time tests are +# built (cached locally between runs in cpp/build/_deps/). +# +# Cross-compile via standard CMake toolchain knobs, e.g. with zig: +# CXX=zig-c++ \ +# CXXFLAGS="-target aarch64-linux-gnu" \ +# ./bootstrap.sh +# +# Tests are skipped automatically when cross-compiling. + +source $(git rev-parse --show-toplevel)/ci3/source_bootstrap + +hash=$(cache_content_hash .rebuild_patterns) + +BUILD_DIR=${BUILD_DIR:-cpp/build} + +function build { + echo_header "ipc-runtime build" + cmake -B "$BUILD_DIR" -S cpp -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE:-Release} + cmake --build "$BUILD_DIR" --target ipc_runtime ipc_runtime_tests + + # Native Node addon — host-arch build. Cross-arch builds go through + # `build_cross ` against the cpp/CMakePresets.json presets. + if [ "${SKIP_NAPI:-0}" -ne 1 ]; then + cmake --build "$BUILD_DIR" --target ipc_runtime_napi + local target_dir="ts/build/$(arch)-$(os)" + mkdir -p "$target_dir" + cp "$BUILD_DIR"/lib/ipc_runtime_napi.node "$target_dir/" + echo "Copied NAPI addon → $target_dir/ipc_runtime_napi.node" + fi + + # Build the TS package so file/portal-link consumers find dest/ populated + # before they typecheck. 
+ if [ "${SKIP_TS_BUILD:-0}" -ne 1 ]; then + (cd ts && yarn install --immutable && yarn build) + fi +} + +function test_cmds { + echo "$hash:CPUS=1:TIMEOUT=120s ipc-runtime/cpp/build/ipc_runtime_tests" + echo "$hash:CPUS=4:TIMEOUT=300s ipc-runtime/scripts/run_rust_tests.sh" + echo "$hash:CPUS=1:TIMEOUT=120s ipc-runtime/scripts/run_ts_tests.sh" +} + +function test { + echo_header "ipc-runtime test" + build + test_cmds | filter_test_cmds | parallelize +} + +function clean { + rm -rf "$BUILD_DIR" +} + +function build_cross { + local arch=$1 + echo_header "ipc-runtime build_cross $arch" + (cd cpp && cmake --preset "$arch" && cmake --build --preset "$arch") +} + +function cross_copy { + ./ts/scripts/copy_cross.sh "$@" +} + +function release { + cross_copy + cd ts + retry "deploy_npm ${REF_NAME#v}" +} + +case "$cmd" in + "") + build + ;; + "hash") + echo "$hash" + ;; + *) + default_cmd_handler "$@" + ;; +esac diff --git a/ipc-runtime/cpp/.gitignore b/ipc-runtime/cpp/.gitignore new file mode 100644 index 000000000000..a5309e6b906e --- /dev/null +++ b/ipc-runtime/cpp/.gitignore @@ -0,0 +1 @@ +build*/ diff --git a/ipc-runtime/cpp/CMakeLists.txt b/ipc-runtime/cpp/CMakeLists.txt new file mode 100644 index 000000000000..2c2383fa7773 --- /dev/null +++ b/ipc-runtime/cpp/CMakeLists.txt @@ -0,0 +1,100 @@ +# ipc-runtime — UDS + MPSC-SHM transport for IPC services. +# +# Standalone CMake project with no external dependencies beyond a C++20 +# compiler + POSIX. Works two ways: +# +# 1. As a top-level standalone build: +# cd ipc-runtime/cpp && cmake -B build && cmake --build build +# +# 2. As a sub-project via `add_subdirectory()` from a larger CMake tree. +# In that case the parent's compile options +# (-Wall, -Werror, etc.) and gtest target (if any) are inherited. + +cmake_minimum_required(VERSION 3.20) + +# Only declare the project when this CMakeLists is the top-level entry point. 
+# Skipping project() under add_subdirectory keeps the parent's project name +# and avoids re-running platform detection. +if(NOT DEFINED PROJECT_NAME) + project(ipc_runtime CXX) + set(CMAKE_CXX_STANDARD 20) + set(CMAKE_CXX_STANDARD_REQUIRED ON) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + set(IPC_RUNTIME_IS_TOP_LEVEL ON) +else() + set(IPC_RUNTIME_IS_TOP_LEVEL OFF) +endif() + +# Tests and the NAPI addon are ipc-runtime's own concerns. Consumers that +# add_subdirectory us only want the `ipc_runtime` library target; default +# both off in that case so they don't have to know these options exist. +option(IPC_RUNTIME_BUILD_TESTS "Build ipc-runtime unit tests" ${IPC_RUNTIME_IS_TOP_LEVEL}) +option(IPC_RUNTIME_BUILD_NAPI "Build the Node native addon (ipc_runtime_napi.node) for the TS package" ${IPC_RUNTIME_IS_TOP_LEVEL}) + +# ---- Library --------------------------------------------------------------- + +# Under WASM (emscripten / no-POSIX targets) the transport sources don't +# compile, but downstream no-POSIX consumers still need the headers (the +# IpcServer/IpcClient declarations referenced from generated server code; +# the codegen template helpers themselves now live in ipc-codegen). +# Expose an INTERFACE target with just the include path in that case. +if(WASM) + add_library(ipc_runtime INTERFACE) + target_include_directories(ipc_runtime INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) +else() + add_library(ipc_runtime STATIC + ipc_runtime/c_abi.cpp + ipc_runtime/ipc_client.cpp + ipc_runtime/ipc_server.cpp + ipc_runtime/serve_helper.cpp + ipc_runtime/signal_handlers.cpp + ipc_runtime/socket_client.cpp + ipc_runtime/socket_server.cpp + ipc_runtime/shm/mpsc_shm.cpp + ipc_runtime/shm/spsc_shm.cpp + ) + target_compile_features(ipc_runtime PUBLIC cxx_std_20) + set_target_properties(ipc_runtime PROPERTIES POSITION_INDEPENDENT_CODE ON) + + # `#include "ipc_runtime/..."` resolves from this directory for both + # in-tree code and downstream consumers. 
+ target_include_directories(ipc_runtime PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + + # pthread is the only system dep beyond libc (used by the SHM client thread + # and parent-death kqueue watcher). + find_package(Threads REQUIRED) + target_link_libraries(ipc_runtime PUBLIC Threads::Threads) +endif() + +# ---- Tests ----------------------------------------------------------------- + +if(IPC_RUNTIME_BUILD_TESTS AND NOT WASM) + # Reuse a parent project's GTest target if one is already in scope. + # Otherwise try find_package(GTest) first (FIND_PACKAGE_ARGS, CMake >= 3.24), then fetch a pinned release. + if(NOT TARGET GTest::gtest_main) + include(FetchContent) + FetchContent_Declare( + GTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.13.0 + GIT_SHALLOW TRUE + FIND_PACKAGE_ARGS + ) + set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(GTest) + endif() + + add_executable(ipc_runtime_tests ipc_runtime/shm.test.cpp ipc_runtime/socket.test.cpp) + target_link_libraries(ipc_runtime_tests PRIVATE ipc_runtime GTest::gtest_main) + + include(GoogleTest) + if(NOT CMAKE_CROSSCOMPILING) + gtest_discover_tests(ipc_runtime_tests DISCOVERY_TIMEOUT 30) + endif() +endif() + +# ---- Node native addon (consumed by @aztec/ipc-runtime) -------------------- + +if(IPC_RUNTIME_BUILD_NAPI AND NOT WASM) + add_subdirectory(napi) +endif() diff --git a/ipc-runtime/cpp/CMakePresets.json b/ipc-runtime/cpp/CMakePresets.json new file mode 100644 index 000000000000..f965bc41bd8f --- /dev/null +++ b/ipc-runtime/cpp/CMakePresets.json @@ -0,0 +1,106 @@ +{ + "version": 5, + "cmakeMinimumRequired": { + "major": 3, + "minor": 24, + "patch": 0 + }, + "configurePresets": [ + { + "name": "cross-base", + "displayName": "Zig (base)", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "environment": { + "CC": "zig cc", + "CXX": "zig c++", + "CFLAGS": "-g0 -fvisibility=hidden", + "CXXFLAGS": "-g0 -fvisibility=hidden" + }, + "cacheVariables": { + "CMAKE_AR": "${sourceDir}/scripts/zig-ar.sh", +
"CMAKE_BUILD_TYPE": "Release", + "CMAKE_RANLIB": "${sourceDir}/scripts/zig-ranlib.sh", + "ENABLE_PIC": "ON", + "IPC_RUNTIME_BUILD_TESTS": "OFF" + } + }, + { + "name": "amd64-linux", + "inherits": "cross-base", + "environment": { + "CC": "zig cc -target x86_64-linux-gnu.2.35", + "CXX": "zig c++ -target x86_64-linux-gnu.2.35", + "LDFLAGS": "-s" + }, + "cacheVariables": { + "CMAKE_SYSTEM_NAME": "Linux" + } + }, + { + "name": "arm64-linux", + "inherits": "cross-base", + "environment": { + "CC": "zig cc -target aarch64-linux-gnu.2.35", + "CXX": "zig c++ -target aarch64-linux-gnu.2.35", + "LDFLAGS": "-s" + }, + "cacheVariables": { + "CMAKE_SYSTEM_NAME": "Linux" + } + }, + { + "name": "amd64-macos", + "inherits": "cross-base", + "environment": { + "CC": "zig cc -target x86_64-macos -mcpu=baseline", + "CXX": "zig c++ -target x86_64-macos -mcpu=baseline", + "LDFLAGS": "-s" + }, + "cacheVariables": { + "CMAKE_SYSTEM_NAME": "Darwin", + "CMAKE_SYSTEM_PROCESSOR": "x86_64" + } + }, + { + "name": "arm64-macos", + "inherits": "cross-base", + "environment": { + "CC": "zig cc -target aarch64-macos -mcpu=apple_a14", + "CXX": "zig c++ -target aarch64-macos -mcpu=apple_a14", + "LDFLAGS": "-s" + }, + "cacheVariables": { + "CMAKE_SYSTEM_NAME": "Darwin", + "CMAKE_SYSTEM_PROCESSOR": "aarch64" + } + } + ], + "buildPresets": [ + { + "name": "amd64-linux", + "configurePreset": "amd64-linux", + "inheritConfigureEnvironment": true, + "targets": ["ipc_runtime_napi"] + }, + { + "name": "arm64-linux", + "configurePreset": "arm64-linux", + "inheritConfigureEnvironment": true, + "targets": ["ipc_runtime_napi"] + }, + { + "name": "amd64-macos", + "configurePreset": "amd64-macos", + "inheritConfigureEnvironment": true, + "targets": ["ipc_runtime_napi"] + }, + { + "name": "arm64-macos", + "configurePreset": "arm64-macos", + "inheritConfigureEnvironment": true, + "targets": ["ipc_runtime_napi"] + } + ] +} diff --git a/ipc-runtime/cpp/ipc_runtime/c_abi.cpp b/ipc-runtime/cpp/ipc_runtime/c_abi.cpp new file 
mode 100644 index 000000000000..d02d8db8c1f7 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/c_abi.cpp @@ -0,0 +1,258 @@ +#include "ipc_runtime/c_abi.h" + +#include "ipc_runtime/ipc_client.hpp" +#include "ipc_runtime/ipc_server.hpp" +#include "ipc_runtime/serve_helper.hpp" +#include "ipc_runtime/signal_handlers.hpp" + +#include +#include +#include +#include +#include +#include + +// Opaque structs that wrap the C++ unique_ptrs. Keeping them as distinct types +// (rather than typedefs to the C++ classes) means the C ABI is a true name-only +// surface — no C++ types leak into the header. + +struct ipc_server { + std::unique_ptr impl; +}; + +struct ipc_client { + std::unique_ptr impl; +}; + +namespace { + +inline ipc_server* wrap_server(std::unique_ptr s) +{ + if (!s) { + return nullptr; + } + auto* w = new ipc_server; + w->impl = std::move(s); + return w; +} + +inline ipc_client* wrap_client(std::unique_ptr c) +{ + if (!c) { + return nullptr; + } + auto* w = new ipc_client; + w->impl = std::move(c); + return w; +} + +} // namespace + +extern "C" { + +// -------- Options ---------------------------------------------------------- + +void ipc_server_options_default(ipc_server_options_t* opts) +{ + if (!opts) { + return; + } + ipc::ServerOptions defaults; + opts->max_shm_clients = defaults.max_shm_clients; + opts->shm_request_ring_size = defaults.shm_request_ring_size; + opts->shm_response_ring_size = defaults.shm_response_ring_size; + opts->socket_backlog = defaults.socket_backlog; +} + +// -------- Server ----------------------------------------------------------- + +ipc_server_t* ipc_make_server(const char* path, const ipc_server_options_t* opts) +{ + if (!path) { + return nullptr; + } + ipc::ServerOptions cpp_opts; + if (opts) { + cpp_opts.max_shm_clients = opts->max_shm_clients; + cpp_opts.shm_request_ring_size = opts->shm_request_ring_size; + cpp_opts.shm_response_ring_size = opts->shm_response_ring_size; + cpp_opts.socket_backlog = opts->socket_backlog; + } + 
return wrap_server(ipc::make_server(path, cpp_opts)); +} + +ipc_server_t* ipc_server_create_socket(const char* path, int max_clients) +{ + if (!path) { + return nullptr; + } + return wrap_server(ipc::IpcServer::create_socket(path, max_clients)); +} + +ipc_server_t* ipc_server_create_mpsc_shm(const char* base_name, + size_t max_clients, + size_t request_ring_size, + size_t response_ring_size) +{ + if (!base_name) { + return nullptr; + } + return wrap_server(ipc::IpcServer::create_mpsc_shm(base_name, max_clients, request_ring_size, response_ring_size)); +} + +void ipc_server_destroy(ipc_server_t* server) +{ + delete server; +} + +bool ipc_server_listen(ipc_server_t* server) +{ + return server && server->impl ? server->impl->listen() : false; +} + +void ipc_server_close(ipc_server_t* server) +{ + if (server && server->impl) { + server->impl->close(); + } +} + +void ipc_server_request_shutdown(ipc_server_t* server) +{ + if (server && server->impl) { + server->impl->request_shutdown(); + } +} + +int ipc_server_wait_for_data(ipc_server_t* server, uint64_t timeout_ns) +{ + return server && server->impl ? server->impl->wait_for_data(timeout_ns) : -1; +} + +ipc_status_t ipc_server_receive(ipc_server_t* server, int client_id, const uint8_t** out, size_t* out_len) +{ + if (!server || !server->impl || !out || !out_len) { + return IPC_ERR_RECV; + } + auto view = server->impl->receive(client_id); + // data() == nullptr is error/timeout; a non-null empty view is a valid + // zero-length message. + if (view.data() == nullptr) { + *out = nullptr; + *out_len = 0; + return IPC_ERR_RECV; + } + *out = view.data(); + *out_len = view.size(); + return IPC_OK; +} + +void ipc_server_release(ipc_server_t* server, int client_id, size_t msg_size) +{ + if (server && server->impl) { + server->impl->release(client_id, msg_size); + } +} + +bool ipc_server_send(ipc_server_t* server, int client_id, const uint8_t* data, size_t len) +{ + return server && server->impl ? 
server->impl->send(client_id, data, len) : false; +} + +void ipc_server_run(ipc_server_t* server, ipc_server_handler_fn handler, void* ctx) +{ + if (!server || !server->impl || !handler) { + return; + } + server->impl->run([handler, ctx](int client_id, std::span raw) -> std::vector { + uint8_t* resp_ptr = nullptr; + size_t resp_len = 0; + handler(client_id, raw.data(), raw.size(), &resp_ptr, &resp_len, ctx); + if (!resp_ptr || resp_len == 0) { + return {}; + } + return std::vector(resp_ptr, resp_ptr + resp_len); + }); +} + +void ipc_install_default_signal_handlers(ipc_server_t* server) +{ + if (server && server->impl) { + ipc::install_default_signal_handlers(*server->impl); + } +} + +// -------- Client ----------------------------------------------------------- + +ipc_client_t* ipc_make_client(const char* path, size_t shm_client_id) +{ + if (!path) { + return nullptr; + } + return wrap_client(ipc::make_client(path, shm_client_id)); +} + +ipc_client_t* ipc_client_create_socket(const char* socket_path) +{ + if (!socket_path) { + return nullptr; + } + return wrap_client(ipc::IpcClient::create_socket(socket_path)); +} + +ipc_client_t* ipc_client_create_mpsc_shm(const char* base_name, size_t client_id) +{ + if (!base_name) { + return nullptr; + } + return wrap_client(ipc::IpcClient::create_mpsc_shm(base_name, client_id)); +} + +void ipc_client_destroy(ipc_client_t* client) +{ + delete client; +} + +bool ipc_client_connect(ipc_client_t* client) +{ + return client && client->impl ? client->impl->connect() : false; +} + +void ipc_client_close(ipc_client_t* client) +{ + if (client && client->impl) { + client->impl->close(); + } +} + +bool ipc_client_send(ipc_client_t* client, const uint8_t* data, size_t len, uint64_t timeout_ns) +{ + return client && client->impl ? 
client->impl->send(data, len, timeout_ns) : false; +} + +ipc_status_t ipc_client_receive(ipc_client_t* client, uint64_t timeout_ns, const uint8_t** out, size_t* out_len) +{ + if (!client || !client->impl || !out || !out_len) { + return IPC_ERR_RECV; + } + auto view = client->impl->receive(timeout_ns); + // data() == nullptr is error/timeout; a non-null empty view is a valid + // zero-length response (IPC_OK with *out_len == 0). + if (view.data() == nullptr) { + *out = nullptr; + *out_len = 0; + return IPC_ERR_RECV; + } + *out = view.data(); + *out_len = view.size(); + return IPC_OK; +} + +void ipc_client_release(ipc_client_t* client, size_t msg_size) +{ + if (client && client->impl) { + client->impl->release(msg_size); + } +} + +} // extern "C" diff --git a/ipc-runtime/cpp/ipc_runtime/c_abi.h b/ipc-runtime/cpp/ipc_runtime/c_abi.h new file mode 100644 index 000000000000..ef8641497b85 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/c_abi.h @@ -0,0 +1,143 @@ +#ifndef IPC_RUNTIME_C_ABI_H +#define IPC_RUNTIME_C_ABI_H + +/* + * Plain-C ABI for ipc-runtime. + * + * Non-C++ consumers (Rust, Zig, ...) bind to this header to use the same + * UDS + MPSC-SHM transport that C++ services use. Opaque handles wrap + * the C++ IpcServer / IpcClient objects; functions return status codes + * instead of exceptions; std::span / std::function become explicit + * (ptr, len) pairs and free function pointers. + * + * Lifetimes: + * - `ipc_server_t*` and `ipc_client_t*` are owned. Pass to ipc_server_destroy + * / ipc_client_destroy when done; until then the pointer is non-null. + * - Bytes returned by ipc_server_receive / ipc_client_receive remain valid + * until the matching `_release` call (or transport tear-down). Caller + * must NOT free them. + * - In ipc_server_run, the response buffer the handler writes to *resp_out + * is owned by the handler; the runtime copies it before send and does + * NOT free it. 
Either return a buffer the runtime can memcpy from and + * forget, or manage with a static buffer. + * + * Threading: all functions are blocking; one caller per handle. + * ipc_client_connect retries on the calling thread (no internal threads); + * the only thread the runtime spawns is the macOS parent-death watcher + * installed by ipc_install_default_signal_handlers. + */ + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* --- Status codes ------------------------------------------------------ */ + +/* Only the codes the runtime actually produces. IPC_ERR_RECV covers + * receive timeout and disconnect; a successful zero-length response is + * IPC_OK with *out_len == 0. */ +typedef enum { IPC_OK = 0, IPC_ERR_RECV = -5 } ipc_status_t; + +/* --- Options ----------------------------------------------------------- */ + +typedef struct { + size_t max_shm_clients; /* default: 2 */ + size_t shm_request_ring_size; /* default: 4 MiB */ + size_t shm_response_ring_size; /* default: 4 MiB */ + int socket_backlog; /* default: 1 */ +} ipc_server_options_t; + +/* Populate `opts` with the same defaults ipc::ServerOptions{} provides. */ +void ipc_server_options_default(ipc_server_options_t *opts); + +/* --- Server ------------------------------------------------------------ */ + +typedef struct ipc_server ipc_server_t; + +/* Pick UDS vs MPSC-SHM by suffix. Returns NULL if suffix unrecognised. */ +ipc_server_t *ipc_make_server(const char *path, + const ipc_server_options_t *opts); + +ipc_server_t *ipc_server_create_socket(const char *path, int max_clients); +ipc_server_t *ipc_server_create_mpsc_shm(const char *base_name, + size_t max_clients, + size_t request_ring_size, + size_t response_ring_size); + +void ipc_server_destroy(ipc_server_t *server); + +bool ipc_server_listen(ipc_server_t *server); +void ipc_server_close(ipc_server_t *server); +void ipc_server_request_shutdown(ipc_server_t *server); + +/* Returns client_id ≥ 0, or -1 on timeout/error. 
*/ +int ipc_server_wait_for_data(ipc_server_t *server, uint64_t timeout_ns); + +/* On success: *out / *out_len reference an internal buffer valid until + * ipc_server_release(). Returns IPC_OK or a negative status. */ +ipc_status_t ipc_server_receive(ipc_server_t *server, int client_id, + const uint8_t **out, size_t *out_len); + +void ipc_server_release(ipc_server_t *server, int client_id, size_t msg_size); + +bool ipc_server_send(ipc_server_t *server, int client_id, const uint8_t *data, + size_t len); + +/* Convenience event loop. The handler is called for each incoming message; + * it writes the response into a buffer it owns and stores the pointer + + * length in *resp_out / *resp_len_out. The runtime copies the response + * into its send path and does not free the handler's buffer — the handler + * is responsible (e.g. via a thread-local arena). + * + * Every request gets exactly one response; leaving *resp_out unset (or + * setting *resp_len_out = 0) sends a zero-length response frame. To exit + * the loop, call ipc_server_request_shutdown() from inside the handler. + */ +typedef void (*ipc_server_handler_fn)(int client_id, const uint8_t *req, + size_t req_len, uint8_t **resp_out, + size_t *resp_len_out, void *ctx); + +void ipc_server_run(ipc_server_t *server, ipc_server_handler_fn handler, + void *ctx); + +/* Install SIGTERM/SIGINT graceful-shutdown + SIGBUS/SIGSEGV close + + * parent-death monitoring (prctl on linux, kqueue NOTE_EXIT on macOS) wired to + * `server`. 
*/ +void ipc_install_default_signal_handlers(ipc_server_t *server); + +/* --- Client ------------------------------------------------------------ */ + +typedef struct ipc_client ipc_client_t; + +ipc_client_t *ipc_make_client(const char *path, size_t shm_client_id); + +ipc_client_t *ipc_client_create_socket(const char *socket_path); +ipc_client_t *ipc_client_create_mpsc_shm(const char *base_name, + size_t client_id); + +void ipc_client_destroy(ipc_client_t *client); + +bool ipc_client_connect(ipc_client_t *client); +void ipc_client_close(ipc_client_t *client); + +bool ipc_client_send(ipc_client_t *client, const uint8_t *data, size_t len, + uint64_t timeout_ns); + +/* On success: IPC_OK with *out / *out_len referencing an internal buffer + * valid until ipc_client_release(). A zero-length response is IPC_OK with + * *out_len == 0 (still followed by ipc_client_release(0)). IPC_ERR_RECV + * means timeout or disconnect. timeout_ns == 0 means infinite. */ +ipc_status_t ipc_client_receive(ipc_client_t *client, uint64_t timeout_ns, + const uint8_t **out, size_t *out_len); + +void ipc_client_release(ipc_client_t *client, size_t msg_size); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* IPC_RUNTIME_C_ABI_H */ diff --git a/ipc-runtime/cpp/ipc_runtime/constants.hpp b/ipc-runtime/cpp/ipc_runtime/constants.hpp new file mode 100644 index 000000000000..ddc054a04f16 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/constants.hpp @@ -0,0 +1,58 @@ +#pragma once +/** + * @file constants.hpp + * @brief Shared transport constants for ipc-runtime. + * + * Single definition for the limits and defaults shared across the UDS / + * SPSC-SHM / MPSC-SHM transports and their language bindings, so they stay + * consistent. Mirrored for TypeScript in ts/src/types.ts — keep the two in sync. + */ + +#include +#include + +namespace ipc { + +/** + * Maximum length-prefix value accepted on receive, across all transports and + * languages. 
A frame claiming more than this is treated as corruption: the + * connection is closed (sockets) or the ring is declared corrupt (SHM), + * instead of allocating/awaiting the claimed size. + */ +inline constexpr uint32_t MAX_FRAME_SIZE = 256U * 1024 * 1024; // 256 MiB + +/** + * Total budget for connect() retry loops, covering the window where the + * server process is still starting up. Shared by UDS and SHM clients. + */ +inline constexpr uint64_t CONNECT_RETRY_BUDGET_MS = 5000; +/** Delay between connect() attempts within the retry budget. */ +inline constexpr uint64_t CONNECT_RETRY_DELAY_MS = 10; + +/** Default ring size for SHM transports (per direction, per client). */ +inline constexpr size_t DEFAULT_RING_SIZE = 4 * 1024 * 1024; // 4 MiB + +/** Default listen backlog for UDS servers. */ +inline constexpr int SOCKET_BACKLOG = 10; + +/** + * Default per-call timeout for client send/receive: 0 = infinite. + * + * Timeout semantics, unified across transports: + * - IpcClient::send / IpcClient::receive: 0 = block indefinitely. + * - IpcServer::wait_for_data: 0 = non-blocking poll (documented exception). + * - SHM ring primitives (claim/peek/wait_for_*): 0 = immediate check; the + * client/server layers translate 0 → infinite before reaching the rings. + */ +inline constexpr uint64_t DEFAULT_CALL_TIMEOUT_NS = 0; + +/** Internal representation of an infinite timeout at the ring layer. */ +inline constexpr uint64_t TIMEOUT_INFINITE_NS = UINT64_MAX; + +/** Translate the public "0 = infinite" convention to the ring layer's. */ +inline constexpr uint64_t normalize_call_timeout(uint64_t timeout_ns) +{ + return timeout_ns == 0 ? 
TIMEOUT_INFINITE_NS : timeout_ns; +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/grind_ipc.sh b/ipc-runtime/cpp/ipc_runtime/grind_ipc.sh new file mode 100755 index 000000000000..5426cb701e50 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/grind_ipc.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Stress-grind the SHM ring test in parallel. Usage: grind_ipc.sh [jobs] +source $(git rev-parse --show-toplevel)/ci3/source + +cd "$(dirname "$0")" + +trap 'clean' EXIT + +function clean { + rm -f /dev/shm/shm_wrap_* +} + +jobs=${1:-128} +if [ $# -gt 0 ]; then + shift +fi + +clean +# Copy so a rebuild mid-grind doesn't swap the binary under running jobs. +cp ../build/ipc_runtime_tests ../build/ipc_runtime_tests_live +while true; do + echo "dump_fail '$@ timeout 30s ../build/ipc_runtime_tests_live --gtest_filter=ShmTest.SingleClientSmallRingHighVolume &> >(add_timestamps && date)' >/dev/null" +done | parallel -j$jobs --halt now,fail=1 diff --git a/ipc-runtime/cpp/ipc_runtime/ipc_client.cpp b/ipc-runtime/cpp/ipc_runtime/ipc_client.cpp new file mode 100644 index 000000000000..87d78583fc1f --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/ipc_client.cpp @@ -0,0 +1,26 @@ +#include "ipc_runtime/ipc_client.hpp" +#include "ipc_runtime/mpsc_shm_client.hpp" +#include "ipc_runtime/shm_client.hpp" +#include "ipc_runtime/socket_client.hpp" +#include +#include +#include + +namespace ipc { + +std::unique_ptr IpcClient::create_socket(const std::string& socket_path) +{ + return std::make_unique(socket_path); +} + +std::unique_ptr IpcClient::create_shm(const std::string& base_name) +{ + return std::make_unique(base_name); +} + +std::unique_ptr IpcClient::create_mpsc_shm(const std::string& base_name, size_t client_id) +{ + return std::make_unique(base_name, client_id); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/ipc_client.hpp b/ipc-runtime/cpp/ipc_runtime/ipc_client.hpp new file mode 100644 index 000000000000..5f31a6051dc1 --- /dev/null +++ 
b/ipc-runtime/cpp/ipc_runtime/ipc_client.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief Abstract interface for IPC client + * + * Provides a unified interface for connecting to IPC servers and exchanging messages. + * Implementations handle transport-specific details (Unix domain sockets, shared memory, etc). + */ +class IpcClient { + public: + IpcClient() = default; + virtual ~IpcClient() = default; + + // Abstract interface - no copy or move + IpcClient(const IpcClient&) = delete; + IpcClient& operator=(const IpcClient&) = delete; + IpcClient(IpcClient&&) = delete; + IpcClient& operator=(IpcClient&&) = delete; + + /** + * @brief Connect to the server + * @return true if connection successful, false otherwise + */ + virtual bool connect() = 0; + + /** + * @brief Send a message to the server + * @param data Pointer to message data + * @param len Length of message in bytes + * @param timeout_ns Timeout in nanoseconds (0 = infinite) + * @return true if sent successfully, false on error or timeout + */ + virtual bool send(const void* data, size_t len, uint64_t timeout_ns) = 0; + + /** + * @brief Receive a message from the server (zero-copy for shared memory) + * @param timeout_ns Timeout in nanoseconds (0 = infinite) + * @return Span of message data. data() == nullptr means error/timeout; + * a non-null span of size 0 is a valid zero-length message. + * + * The span remains valid until release() is called or the next receive(). + * For shared memory: direct view into ring buffer (true zero-copy) + * For sockets: view into internal buffer (eliminates one copy) + * + * Must be followed by release() to consume the message. + */ + virtual std::span receive(uint64_t timeout_ns) = 0; + + /** + * @brief Wake any thread blocked in receive()/send() (for shutdown). + * + * Default no-op; SHM transports wake futex waiters on their rings.
+ */ + virtual void wakeup() {} + + /** + * @brief Release the previously received message + * @param message_size Size of the message being released (from span.size()) + * + * Must be called after receive() to consume the message and free resources. + * For shared memory: releases space in the ring buffer + * For sockets: no-op (message already consumed during receive()) + */ + virtual void release(size_t message_size) = 0; + + /** + * @brief Close the connection + */ + virtual void close() = 0; + + // Factory methods. + static std::unique_ptr create_socket(const std::string& socket_path); + // Single-client SHM: one request ring and one response ring. Use this + // directly when the service only needs one producer/client. + static std::unique_ptr create_shm(const std::string& base_name); + // Multi-producer SHM: one request ring per client slot and one response + // ring per client slot. This is what make_client("*.shm") selects. + static std::unique_ptr create_mpsc_shm(const std::string& base_name, size_t client_id); +}; + +/** + * @brief Construct an IpcClient based on the input path's suffix. + * + * Recognised suffixes: + * - "*.sock" → IpcClient::create_socket(path) + * - "*.shm" → IpcClient::create_mpsc_shm(<base_name>, shm_client_id) + * + * Returns nullptr if the suffix is not recognised. `shm_client_id` is only + * consulted for the SHM path; for MPSC-SHM, each connecting client picks a + * distinct slot (0..max_clients-1). + * + * @param input_path Path passed by the caller (often a CLI flag). + * @param shm_client_id Client slot to claim in MPSC-SHM mode. Ignored for UDS.
+ */ +std::unique_ptr make_client(const std::string& input_path, std::size_t shm_client_id = 0); + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/ipc_server.cpp b/ipc-runtime/cpp/ipc_runtime/ipc_server.cpp new file mode 100644 index 000000000000..b57760967e2c --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/ipc_server.cpp @@ -0,0 +1,31 @@ +#include "ipc_runtime/ipc_server.hpp" +#include "ipc_runtime/mpsc_shm_server.hpp" +#include "ipc_runtime/shm_server.hpp" +#include "ipc_runtime/socket_server.hpp" +#include +#include +#include + +namespace ipc { + +std::unique_ptr IpcServer::create_socket(const std::string& socket_path, int max_clients) +{ + return std::make_unique(socket_path, max_clients); +} + +std::unique_ptr IpcServer::create_shm(const std::string& base_name, + size_t request_ring_size, + size_t response_ring_size) +{ + return std::make_unique(base_name, request_ring_size, response_ring_size); +} + +std::unique_ptr IpcServer::create_mpsc_shm(const std::string& base_name, + size_t max_clients, + size_t request_ring_size, + size_t response_ring_size) +{ + return std::make_unique(base_name, max_clients, request_ring_size, response_ring_size); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/ipc_server.hpp b/ipc-runtime/cpp/ipc_runtime/ipc_server.hpp new file mode 100644 index 000000000000..166559f583e3 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/ipc_server.hpp @@ -0,0 +1,212 @@ +#pragma once + +#include "ipc_runtime/constants.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief Abstract interface for IPC server + * + * Provides a unified interface for accepting client connections and exchanging + * messages. Implementations handle transport-specific details (Unix domain + * sockets, shared memory, etc). 
+ */ +class IpcServer { + public: + IpcServer() = default; + virtual ~IpcServer() = default; + + // Abstract interface - no copy or move + IpcServer(const IpcServer&) = delete; + IpcServer& operator=(const IpcServer&) = delete; + IpcServer(IpcServer&&) = delete; + IpcServer& operator=(IpcServer&&) = delete; + + /** + * @brief Start listening for client connections + * @return true if successful, false otherwise + */ + virtual bool listen() = 0; + + /** + * @brief Wait for data from any connected client + * + * @param timeout_ns Maximum time to wait in nanoseconds (0 = non-blocking + * poll) + * @return Client ID that has data available, or -1 on timeout/error + */ + virtual int wait_for_data(uint64_t timeout_ns) = 0; + + /** + * @brief Receive next message from a specific client + * + * Blocks until a complete message is available. Returns a span pointing to + * the message data. For shared memory, this is a zero-copy view directly into + * the ring buffer. For sockets, this is a view into an internal buffer. + * + * The message remains valid until release() is called with the message size. + * + * @param client_id Client to receive from + * @return Span of message data; data() == nullptr on error/disconnect (a non-null, size-0 span is a valid zero-length message) + */ + virtual std::span receive(int client_id) = 0; + + /** + * @brief Release/consume the previously received message + * + * Must be called after receive() to advance to the next message. + * For shared memory, this releases space in the ring buffer. + * For sockets, this is a no-op (message already consumed during receive).
+ * + * @param client_id Client whose message to release + * @param message_size Size of the message being released (from span.size()) + */ + virtual void release(int client_id, size_t message_size) = 0; + + /** + * @brief Send a message to a specific client + * @param client_id Client to send to + * @param data Pointer to message data + * @param len Length of message in bytes + * @return true if sent successfully, false on error + */ + virtual bool send(int client_id, const void* data, size_t len) = 0; + + /** + * @brief Close the server and all client connections + */ + virtual void close() = 0; + + /** + * @brief Request graceful shutdown. + * + * Sets shutdown flag and wakes all blocked threads. After this returns, the + * run() loop will exit on its next iteration. Call close() afterward to clean + * up resources. + */ + virtual void request_shutdown() + { + shutdown_requested_.store(true, std::memory_order_release); + wakeup_all(); + } + + void request_shutdown_from_signal() noexcept { shutdown_requested_.store(true, std::memory_order_release); } + + /** + * @brief Filesystem artifacts to remove if the process dies abnormally. + * + * Used by install_default_signal_handlers() to cache (at install time) the + * socket paths / shared-memory names a fatal-signal handler should + * best-effort unlink. Empty by default. + */ + struct CleanupPaths { + std::vector unlink_paths; // removed via ::unlink() + std::vector shm_unlink_names; // removed via ::shm_unlink() + }; + virtual CleanupPaths cleanup_paths() const { return {}; } + + /** + * @brief High-level request handler function type + * + * Takes client_id and request data, returns response data. + * Every request gets exactly one response; an empty vector is sent as a + * zero-length response frame. 
+ */ + using Handler = std::function(int client_id, std::span request)>; + + /** + * @brief Accept pending client connections without blocking (optional for + * some transports) + * @return Client ID if successful, -1 if no pending connection or error + * + * Note: Some transports (like shared memory) may not need explicit accept + * calls. + */ + virtual int accept() { return -1; } + + /** + * @brief Run server event loop with handler + * + * Continuously waits for client requests and invokes handler. + * Handler is responsible for deserializing request, processing, and + * serializing response. This is a convenience method that encapsulates the + * typical server loop. + * + * Uses a receive/release pattern: + * - receive() returns a span (zero-copy for SHM, internal buffer for sockets) + * - handler processes the request and its result is passed to send() + * - release() explicitly consumes the message + * + * This design ensures no messages are lost and enables zero-copy for shared + * memory. + * + * @param handler Function to process requests and generate responses + */ + virtual void run(const Handler& handler) + { + while (!shutdown_requested_.load(std::memory_order_acquire)) { + // Try to accept new clients (non-blocking for socket servers) + accept(); + + int client_id = wait_for_data(100000000); // 100 ms, expressed in nanoseconds + if (client_id < 0) { + // Timeout or error - check shutdown flag on next iteration + continue; + } + + // Receive message (blocks until complete message available, zero-copy for + // SHM). A null data() means error/timeout; a non-null empty span is a + // valid zero-length request. + auto request = receive(client_id); + if (request.data() == nullptr) { + continue; + } + + // Always send the response frame — a zero-length response is still a + // response, and skipping it would deadlock the waiting client. + auto response = handler(client_id, request); + send(client_id, response.data(), response.size()); + + // Explicitly release/consume the message.
+ release(client_id, request.size()); + } + } + + // Factory methods. + static std::unique_ptr create_socket(const std::string& socket_path, int max_clients); + // Single-client SHM: one request ring and one response ring. Use this + // directly when the service only needs one producer/client. + static std::unique_ptr create_shm(const std::string& base_name, + size_t request_ring_size = DEFAULT_RING_SIZE, + size_t response_ring_size = DEFAULT_RING_SIZE); + // Multi-producer SHM: one request ring per client slot and one response + // ring per client slot. This is what make_server("*.shm") selects. + static std::unique_ptr create_mpsc_shm(const std::string& base_name, + size_t max_clients, + size_t request_ring_size = DEFAULT_RING_SIZE, + size_t response_ring_size = DEFAULT_RING_SIZE); + + protected: + std::atomic shutdown_requested_{ false }; + + /** + * @brief Wake all blocked threads (for graceful shutdown) + * + * Wakes any threads blocked in wait_for_data() or other blocking operations. + * Called by request_shutdown(); request_shutdown_from_signal() only sets the + * flag and does not call this. The base implementation is a no-op. + */ + virtual void wakeup_all() {}; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/mpsc_shm_client.hpp b/ipc-runtime/cpp/ipc_runtime/mpsc_shm_client.hpp new file mode 100644 index 000000000000..039741b1d18c --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/mpsc_shm_client.hpp @@ -0,0 +1,131 @@ +#pragma once + +#include "constants.hpp" +#include "ipc_client.hpp" +#include "shm/mpsc_shm.hpp" +#include "shm/spsc_shm.hpp" +#include "shm_common.hpp" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief IPC client for multi-client shared memory server + * + * Uses MpscProducer for sending requests and a dedicated SPSC ring for + * receiving responses. Each client is assigned a unique client_id.
+ */ +class MpscShmClient : public IpcClient { + public: + MpscShmClient(std::string base_name, size_t client_id) + : base_name_(std::move(base_name)) + , client_id_(client_id) + {} + + ~MpscShmClient() override = default; + + // Non-copyable, non-movable + MpscShmClient(const MpscShmClient&) = delete; + MpscShmClient& operator=(const MpscShmClient&) = delete; + MpscShmClient(MpscShmClient&&) = delete; + MpscShmClient& operator=(MpscShmClient&&) = delete; + + bool connect() override + { + if (producer_.has_value()) { + return true; // Already connected + } + + constexpr size_t max_attempts = CONNECT_RETRY_BUDGET_MS / CONNECT_RETRY_DELAY_MS; + constexpr auto retry_delay = std::chrono::milliseconds(CONNECT_RETRY_DELAY_MS); + + for (size_t attempt = 0; attempt < max_attempts; ++attempt) { + try { + // Connect as producer to the MPSC request system + producer_ = MpscProducer::connect(base_name_ + "_req", client_id_); + + // Connect to our dedicated SPSC response ring + std::string resp_name = base_name_ + "_resp_" + std::to_string(client_id_); + response_ring_ = SpscShm::connect(resp_name); + + return true; + } catch (...) 
{ + producer_.reset(); + response_ring_.reset(); + if (attempt + 1 == max_attempts) { + return false; + } + std::this_thread::sleep_for(retry_delay); + } + } + + return false; + } + + bool send(const void* data, size_t len, uint64_t timeout_ns) override + { + if (!producer_.has_value()) { + return false; + } + + // Claim space for length prefix + data + size_t total_size = sizeof(uint32_t) + len; + void* buf = producer_->claim(total_size, normalize_call_timeout(timeout_ns)); + if (buf == nullptr) { + return false; + } + + // Write length prefix + data + auto len_u32 = static_cast(len); + std::memcpy(buf, &len_u32, sizeof(uint32_t)); + std::memcpy(static_cast(buf) + sizeof(uint32_t), data, len); + + // Publish (rings doorbell to wake server) + producer_->publish(total_size); + return true; + } + + std::span receive(uint64_t timeout_ns) override + { + if (!response_ring_.has_value()) { + return {}; + } + return ring_receive_msg(response_ring_.value(), normalize_call_timeout(timeout_ns)); + } + + void release(size_t message_size) override + { + if (!response_ring_.has_value()) { + return; + } + response_ring_->release(sizeof(uint32_t) + message_size); + } + + void close() override + { + producer_.reset(); + response_ring_.reset(); + } + + void wakeup() override + { + if (response_ring_.has_value()) { + response_ring_->wakeup_all(); + } + } + + private: + std::string base_name_; + size_t client_id_; + std::optional producer_; + std::optional response_ring_; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/mpsc_shm_server.hpp b/ipc-runtime/cpp/ipc_runtime/mpsc_shm_server.hpp new file mode 100644 index 000000000000..1c8b22b7f21b --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/mpsc_shm_server.hpp @@ -0,0 +1,170 @@ +#pragma once + +#include "ipc_server.hpp" +#include "shm/mpsc_shm.hpp" +#include "shm/spsc_shm.hpp" +#include "shm_common.hpp" +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief IPC server 
implementation using shared memory with multi-client support + * + * Uses MPSC (multi-producer single-consumer) for requests and per-client SPSC + * rings for responses. Supports up to max_clients concurrent clients. + * + * Shared memory layout: + * - Request: MPSC consumer with one SPSC ring per client (client writes, server reads) + * - Response: Separate SPSC ring per client (server writes, client reads) + */ +class MpscShmServer : public IpcServer { + public: + MpscShmServer(std::string base_name, + size_t max_clients, + size_t request_ring_size = DEFAULT_RING_SIZE, + size_t response_ring_size = DEFAULT_RING_SIZE) + : base_name_(std::move(base_name)) + , max_clients_(max_clients) + , request_ring_size_(request_ring_size) + , response_ring_size_(response_ring_size) + {} + + ~MpscShmServer() override { close(); } + + // Non-copyable, non-movable + MpscShmServer(const MpscShmServer&) = delete; + MpscShmServer& operator=(const MpscShmServer&) = delete; + MpscShmServer(MpscShmServer&&) = delete; + MpscShmServer& operator=(MpscShmServer&&) = delete; + + bool listen() override + { + if (request_consumer_.has_value()) { + return true; // Already listening + } + + // Clean up any leftover shared memory + MpscConsumer::unlink(base_name_ + "_req", max_clients_); + for (size_t i = 0; i < max_clients_; i++) { + SpscShm::unlink(base_name_ + "_resp_" + std::to_string(i)); + } + + try { + // Create MPSC consumer for requests (one ring per client) + request_consumer_ = MpscConsumer::create(base_name_ + "_req", max_clients_, request_ring_size_); + + // Create per-client SPSC response rings + response_rings_.reserve(max_clients_); + for (size_t i = 0; i < max_clients_; i++) { + std::string resp_name = base_name_ + "_resp_" + std::to_string(i); + response_rings_.push_back(SpscShm::create(resp_name, response_ring_size_)); + } + + return true; + } catch (...) 
{ + close(); + return false; + } + } + + int wait_for_data(uint64_t timeout_ns) override + { + if (!request_consumer_.has_value()) { + return -1; + } + // MpscConsumer::wait_for_data returns ring index = client_id + return request_consumer_->wait_for_data(timeout_ns); + } + + std::span receive(int client_id) override + { + if (!request_consumer_.has_value() || client_id < 0 || static_cast(client_id) >= max_clients_) { + return {}; + } + // Peek on the specific client's request ring via MpscConsumer + void* len_ptr = request_consumer_->peek(static_cast(client_id), sizeof(uint32_t), 100000000); + if (len_ptr == nullptr) { + return {}; + } + uint32_t msg_len = 0; + std::memcpy(&msg_len, len_ptr, sizeof(uint32_t)); + + // A prefix larger than the send side could legally publish means the + // ring is corrupt — error out instead of waiting for it. + if (msg_len > MAX_FRAME_SIZE || msg_len > request_ring_size_ / 2 - sizeof(uint32_t)) { + throw std::runtime_error("MpscShmServer::receive: corrupt length prefix (" + std::to_string(msg_len) + + " bytes exceeds ring/frame limits)"); + } + + void* msg_ptr = request_consumer_->peek(static_cast(client_id), sizeof(uint32_t) + msg_len, 100000000); + if (msg_ptr == nullptr) { + return {}; + } + return std::span(static_cast(msg_ptr) + sizeof(uint32_t), msg_len); + } + + void release(int client_id, size_t message_size) override + { + if (!request_consumer_.has_value() || client_id < 0 || static_cast(client_id) >= max_clients_) { + return; + } + request_consumer_->release(static_cast(client_id), sizeof(uint32_t) + message_size); + } + + bool send(int client_id, const void* data, size_t len) override + { + if (client_id < 0 || static_cast(client_id) >= response_rings_.size()) { + return false; + } + return ring_send_msg(response_rings_[static_cast(client_id)], data, len, 100000000); + } + + void close() override + { + request_consumer_.reset(); + response_rings_.clear(); + + // Clean up shared memory + MpscConsumer::unlink(base_name_ + 
"_req", max_clients_); + for (size_t i = 0; i < max_clients_; i++) { + SpscShm::unlink(base_name_ + "_resp_" + std::to_string(i)); + } + } + + void wakeup_all() override + { + if (request_consumer_.has_value()) { + request_consumer_->wakeup_all(); + } + for (auto& ring : response_rings_) { + ring.wakeup_all(); + } + } + + CleanupPaths cleanup_paths() const override + { + CleanupPaths paths; + paths.shm_unlink_names.push_back(base_name_ + "_req_doorbell"); + for (size_t i = 0; i < max_clients_; i++) { + paths.shm_unlink_names.push_back(base_name_ + "_req_ring_" + std::to_string(i)); + paths.shm_unlink_names.push_back(base_name_ + "_resp_" + std::to_string(i)); + } + return paths; + } + + private: + std::string base_name_; + size_t max_clients_; + size_t request_ring_size_; + size_t response_ring_size_; + std::optional request_consumer_; + std::vector response_rings_; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/serve_helper.cpp b/ipc-runtime/cpp/ipc_runtime/serve_helper.cpp new file mode 100644 index 000000000000..4943e50ba609 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/serve_helper.cpp @@ -0,0 +1,43 @@ +#include "ipc_runtime/serve_helper.hpp" + +#include +#include + +namespace ipc { + +namespace { + +bool ends_with(const std::string& s, const std::string& suffix) +{ + return s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +} // namespace + +std::unique_ptr make_server(const std::string& input_path, const ServerOptions& opts) +{ + if (ends_with(input_path, ".sock")) { + return IpcServer::create_socket(input_path, opts.socket_backlog); + } + if (ends_with(input_path, ".shm")) { + // SHM mode uses the base name (suffix stripped) as the shared-memory key. 
+ std::string base_name = input_path.substr(0, input_path.size() - 4); + return IpcServer::create_mpsc_shm( + base_name, opts.max_shm_clients, opts.shm_request_ring_size, opts.shm_response_ring_size); + } + return nullptr; +} + +std::unique_ptr make_client(const std::string& input_path, std::size_t shm_client_id) +{ + if (ends_with(input_path, ".sock")) { + return IpcClient::create_socket(input_path); + } + if (ends_with(input_path, ".shm")) { + std::string base_name = input_path.substr(0, input_path.size() - 4); + return IpcClient::create_mpsc_shm(base_name, shm_client_id); + } + return nullptr; +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/serve_helper.hpp b/ipc-runtime/cpp/ipc_runtime/serve_helper.hpp new file mode 100644 index 000000000000..8c5d690bf708 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/serve_helper.hpp @@ -0,0 +1,54 @@ +#pragma once +/** + * @file serve_helper.hpp + * @brief Factory helpers for instantiating IpcServer from a path string. + * + * The make_server() helper picks the right transport (Unix domain socket + * vs. MPSC shared-memory) based on the input path's suffix: + * ".sock" → UDS, ".shm" → MPSC-SHM. + * This keeps per-service main() code free of transport-selection logic. + */ + +#include "ipc_runtime/constants.hpp" +#include "ipc_runtime/ipc_client.hpp" +#include "ipc_runtime/ipc_server.hpp" + +#include +#include +#include + +namespace ipc { + +/// Options for make_server(). +struct ServerOptions { + /// Maximum concurrent SHM clients (only used when .shm path is chosen). + /// Default 2: enough for a primary client plus one auxiliary native client. + std::size_t max_shm_clients = 2; + /// SHM request ring size (per-client → server). + std::size_t shm_request_ring_size = DEFAULT_RING_SIZE; + /// SHM response ring size (server → per-client). + std::size_t shm_response_ring_size = DEFAULT_RING_SIZE; + /// Listen backlog for UDS mode. 
+ int socket_backlog = SOCKET_BACKLOG; +}; + +/** + * @brief Construct an IpcServer based on the input path's suffix. + * + * Recognised suffixes: + * - "*.sock" → IpcServer::create_socket(path, opts.socket_backlog) + * - "*.shm" → IpcServer::create_mpsc_shm(base_name, opts.max_shm_clients, + * opts.shm_request_ring_size, + * opts.shm_response_ring_size) + * The single-client SHM factories (`IpcServer::create_shm` / + * `IpcClient::create_shm`) are intentionally not selected by suffix; call + * them directly when a service does not need multiple producers. + * + * Returns nullptr if the suffix is not recognised. + * + * @param input_path Path passed by the caller (often a CLI flag). + * @param opts SHM and socket tuning knobs. + */ +std::unique_ptr make_server(const std::string& input_path, const ServerOptions& opts = {}); + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm.test.cpp b/ipc-runtime/cpp/ipc_runtime/shm.test.cpp new file mode 100644 index 000000000000..08116147488a --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm.test.cpp @@ -0,0 +1,375 @@ +#include "ipc_runtime/ipc_client.hpp" +#include "ipc_runtime/ipc_server.hpp" +#include "ipc_runtime/shm/spsc_shm.hpp" +#include "ipc_runtime/shm_client.hpp" +#include "ipc_runtime/shm_common.hpp" +#include "ipc_runtime/shm_server.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ipc; + +namespace { + +/** + * You can really stress test this with grind_ipc.sh + */ +TEST(ShmTest, SingleClientSmallRingHighVolume) +{ + constexpr size_t RING_SIZE = 2UL * 1024; + constexpr size_t NUM_ITERATIONS = 10000000; + // Sizing ensures that no matter the state of the internal ring buffer, we + // can't deadlock.
+ constexpr size_t MAX_MSG_SIZE = (RING_SIZE / 2) - 4; + + // Use short name for macOS compatibility (31-char limit) + std::string wrap_test_shm = "shm_wrap_" + std::to_string(getpid()); + auto server = IpcServer::create_shm(wrap_test_shm, RING_SIZE, RING_SIZE); + ASSERT_TRUE(server->listen()) << "Wrap test server failed to listen"; + + std::atomic server_running{ true }; + std::atomic corruptions{ 0 }; + + // Echo server with validation + std::thread server_thread([&]() { + size_t iter = 0; + while (server_running.load(std::memory_order_acquire)) { + server->accept(); + + int client_id = server->wait_for_data(10000000); // 10ms + if (client_id < 0) { + continue; + } + + auto request_buf = server->receive(client_id); + // std::cerr << "Server received " << request.size() << " bytes" << '\n'; + + if (request_buf.empty()) { + continue; + } + + // Take a copy of the request so we can release. + std::vector request(request_buf.begin(), request_buf.end()); + server->release(client_id, request.size()); + + // Validate pattern: first byte should be XOR with offsets + // Check a few bytes to detect corruption without slowing down too much + if (request.size() > 0) { + uint8_t first = request[0]; + for (size_t i = 0; i < std::min(request.size(), size_t(16)); i++) { + uint8_t expected = static_cast((first ^ i) & 0xFF); + if (request[i] != expected) { + corruptions.fetch_add(1); + std::cerr << "Pattern mismatch at offset " << i << ": expected=" << (int)expected + << " actual=" << (int)request[i] << '\n'; + break; + } + } + } + + // Retry send until success. + while (!server->send(client_id, request.data(), request.size())) { + // Timeout - retry (response ring might be full) + std::cerr << iter << " Server send size " << request.size() << " timeout, retrying..." 
<< '\n'; + dynamic_cast(server.get())->debug_dump(); + } + // std::cerr << "Server sent response of " << request.size() << " bytes" + // << '\n'; + iter++; + } + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(300)); + + auto client = IpcClient::create_shm(wrap_test_shm); + ASSERT_TRUE(client->connect()); + + // Random message sizes. + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution size_dist(1, MAX_MSG_SIZE); + + // Store sizes for each iteration so receiver knows what to expect + std::vector iteration_sizes(NUM_ITERATIONS); + for (size_t i = 0; i < NUM_ITERATIONS; i++) { + iteration_sizes[i] = size_dist(gen); + // iteration_sizes[i] = MAX_MSG_SIZE - 1; + } + + // Sender thread: continuously send requests + std::thread sender_thread([&]() { + std::vector send_buffer(MAX_MSG_SIZE); + + for (size_t iter = 0; iter < NUM_ITERATIONS; iter++) { + size_t size = iteration_sizes[iter]; + // std::cerr << "Client: Iteration " << iter << ": sending " << size << " + // bytes" << '\n'; + + // Fill buffer with iteration-specific pattern + // First byte is iteration number (mod 256), rest is XOR pattern with + // offset + uint8_t iter_byte = static_cast(iter & 0xFF); + for (size_t i = 0; i < size; i++) { + send_buffer[i] = static_cast((iter_byte ^ i) & 0xFF); + } + + // Retry send until success - timeouts are expected under extreme load + while (!client->send(send_buffer.data(), size, 100000000)) { + // Timeout - retry (ring might be full, server might be slow) + std::cerr << iter << " Client send size " << size << " timeout, retrying..." 
<< '\n'; + dynamic_cast(client.get())->debug_dump(); + } + } + }); + + // Receiver thread: continuously receive and validate responses + std::thread receiver_thread([&]() { + for (size_t iter = 0; iter < NUM_ITERATIONS; iter++) { + size_t expected_size = iteration_sizes[iter]; + + // Retry recv until success - timeouts are expected under extreme load + std::span response; + while ((response = client->receive(100000000)).empty()) { + std::cerr << iter << " Client receive timeout, retrying..." << '\n'; + // Timeout - retry + } + // std::cerr << "Client received response of " << response.size() << " + // bytes" << '\n'; + + ASSERT_EQ(response.size(), expected_size) << "Size mismatch at iteration " << iter; + + // Validate entire response - check iteration byte and pattern + uint8_t iter_byte = static_cast(iter & 0xFF); + if (response.size() > 0) { + ASSERT_EQ(response[0], iter_byte) << "Iteration byte mismatch at iteration " << iter; + for (size_t i = 0; i < response.size(); i++) { + uint8_t expected = static_cast((iter_byte ^ i) & 0xFF); + if (response[i] != expected) { + FAIL() << "Data corruption at iteration " << iter << " offset " << i + << ": expected=" << (int)expected << " actual=" << (int)response[i]; + } + } + } + + client->release(response.size()); + } + }); + + sender_thread.join(); + receiver_thread.join(); + + client->close(); + + server_running.store(false); + server->request_shutdown(); + server_thread.join(); + server->close(); + + EXPECT_EQ(corruptions.load(), 0) << "Corruptions detected in single-threaded wrap test"; +} + +/** + * A handler returning an empty vector must still produce a (zero-length) + * response frame, otherwise the client deadlocks waiting for it. 
+ */ +TEST(ShmTest, ZeroLengthResponseRoundTrip) +{ + constexpr size_t RING_SIZE = 4UL * 1024; + std::string base_name = "shm_zlen_" + std::to_string(getpid()); + auto server = IpcServer::create_shm(base_name, RING_SIZE, RING_SIZE); + ASSERT_TRUE(server->listen()); + + std::thread server_thread( + [&] { server->run([](int, std::span) { return std::vector{}; }); }); + + auto client = IpcClient::create_shm(base_name); + ASSERT_TRUE(client->connect()); + + uint8_t byte = 42; + ASSERT_TRUE(client->send(&byte, 1, 1'000'000'000ULL)); + + // Bounded retry loop: data() == nullptr means timeout, a non-null span of + // size 0 is a successful zero-length response. + std::span resp; + for (int i = 0; i < 50 && resp.data() == nullptr; i++) { + resp = client->receive(100'000'000ULL); + } + EXPECT_NE(resp.data(), nullptr) << "zero-length response should be success, not timeout"; + EXPECT_EQ(resp.size(), 0U); + client->release(resp.size()); + + client->close(); + server->request_shutdown(); + server_thread.join(); + server->close(); +} + +/** + * A timeout ≥ ~4.295s must be honored at full width: the ring API takes a + * uint64 ns timeout, so it must not narrow to uint32 and wrap (e.g. 4.5s → + * ~205ms). Verify a 4.5s wait_for_data survives past the 32-bit-ns wrap point + * and sees data published at the 2s mark. 
+ */ +TEST(ShmTest, TimeoutDoesNotWrapAbove4Seconds) +{ + constexpr size_t RING_SIZE = 4UL * 1024; + std::string base_name = "shm_tmo_" + std::to_string(getpid()); + auto server = IpcServer::create_shm(base_name, RING_SIZE, RING_SIZE); + ASSERT_TRUE(server->listen()); + + auto client = IpcClient::create_shm(base_name); + ASSERT_TRUE(client->connect()); + + std::thread sender([&] { + std::this_thread::sleep_for(std::chrono::seconds(2)); + uint8_t byte = 1; + client->send(&byte, 1, 1'000'000'000ULL); + }); + + auto start = std::chrono::steady_clock::now(); + int client_id = server->wait_for_data(4'500'000'000ULL); // 4.5s — would wrap to ~205ms as uint32 + auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); + + EXPECT_EQ(client_id, 0) << "wait_for_data should see the data sent at the 2s mark"; + EXPECT_GE(elapsed.count(), 1500) << "returned before the 2s publish — timeout wrapped"; + EXPECT_LT(elapsed.count(), 4400); + + sender.join(); + auto request = server->receive(0); + if (!request.empty()) { + server->release(0, request.size()); + } + client->close(); + server->close(); +} + +/** + * A corrupt length prefix in the ring (larger than any message the send side + * could legally publish) must error out rather than waiting forever for the + * bytes to arrive. + */ +TEST(ShmTest, RingRejectsCorruptLengthPrefix) +{ + constexpr size_t RING_SIZE = 4UL * 1024; + std::string ring_name = "shm_corrupt_" + std::to_string(getpid()); + SpscShm::unlink(ring_name); + auto producer = SpscShm::create(ring_name, RING_SIZE); + auto consumer = SpscShm::connect(ring_name); + + // Forge a frame whose length prefix claims far more than capacity/2. 
+ void* buf = producer.claim(8, 1'000'000'000ULL); + ASSERT_NE(buf, nullptr); + uint32_t bogus_len = 0xFFFFFF00; + std::memcpy(buf, &bogus_len, sizeof(bogus_len)); + producer.publish(8); + + EXPECT_THROW((void)ring_receive_msg(consumer, 10'000'000ULL), std::runtime_error); + + SpscShm::unlink(ring_name); +} + +/** + * Sanity check for the MPSC (multi-producer single-consumer) SHM transport: two + * clients concurrently send distinct payloads and each receives back its own + * echoed response. This is the load-bearing property MPSC adds over SPSC — + * multiple producers must not mix up responses or block each other. + */ +TEST(ShmTest, MpscEchoTwoClients) +{ + constexpr size_t NUM_CLIENTS = 2; + constexpr size_t NUM_MESSAGES = 200; + constexpr size_t MSG_SIZE = 64; + constexpr size_t RING_SIZE = 4UL * 1024; + + std::string base_name = "shm_mpsc_" + std::to_string(getpid()); + auto server = IpcServer::create_mpsc_shm(base_name, NUM_CLIENTS, RING_SIZE, RING_SIZE); + ASSERT_TRUE(server->listen()) << "MPSC server failed to listen"; + + std::atomic server_running{ true }; + + // Echo server: poll for any client with data, echo it back to that client. + std::thread server_thread([&]() { + while (server_running.load(std::memory_order_acquire)) { + server->accept(); + int client_id = server->wait_for_data(1000000); // 1ms + if (client_id < 0) { + continue; + } + auto request_buf = server->receive(client_id); + if (request_buf.empty()) { + continue; + } + std::vector request(request_buf.begin(), request_buf.end()); + server->release(client_id, request.size()); + while (!server->send(client_id, request.data(), request.size())) { + // Retry if the client's response ring is full. 
+ } + } + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + auto run_client = [&](size_t client_id) { + auto client = IpcClient::create_mpsc_shm(base_name, client_id); + ASSERT_TRUE(client->connect()) << "Client " << client_id << " failed to connect"; + + for (size_t iter = 0; iter < NUM_MESSAGES; iter++) { + std::vector payload(MSG_SIZE); + // First byte tags the client; remaining bytes encode (client_id, iter, + // offset). + payload[0] = static_cast(client_id); + for (size_t i = 1; i < MSG_SIZE; i++) { + payload[i] = static_cast((client_id ^ iter ^ i) & 0xFF); + } + + while (!client->send(payload.data(), payload.size(), 100000000)) { + // Retry on send timeout. + } + + std::span response; + while ((response = client->receive(100000000)).empty()) { + // Retry on receive timeout. + } + + ASSERT_EQ(response.size(), MSG_SIZE) << "client " << client_id << " iter " << iter; + // The crucial MPSC invariant: client sees its own payload back, not + // another client's. + ASSERT_EQ(response[0], static_cast(client_id)) + << "client " << client_id << " got cross-client response at iter " << iter; + for (size_t i = 1; i < MSG_SIZE; i++) { + uint8_t expected = static_cast((client_id ^ iter ^ i) & 0xFF); + ASSERT_EQ(response[i], expected) << "client " << client_id << " iter " << iter << " offset " << i; + } + client->release(response.size()); + } + client->close(); + }; + + std::vector client_threads; + client_threads.reserve(NUM_CLIENTS); + for (size_t id = 0; id < NUM_CLIENTS; id++) { + client_threads.emplace_back(run_client, id); + } + for (auto& t : client_threads) { + t.join(); + } + + server_running.store(false); + server->request_shutdown(); + server_thread.join(); + server->close(); +} + +} // namespace diff --git a/ipc-runtime/cpp/ipc_runtime/shm/README.md b/ipc-runtime/cpp/ipc_runtime/shm/README.md new file mode 100644 index 000000000000..45a4bba41231 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/README.md @@ -0,0 +1,371 @@ +# Lock-Free 
Shared Memory Ring Buffers (C++) + +Ultra-low-latency shared-memory ring buffers for inter-process communication using modern C++. Built on POSIX `shm_open` + `mmap` with lock-free atomics and efficient futex-based blocking (Linux futex; `os_sync_wait_on_address` on macOS). + +## Features + +- **Zero-copy IPC** between processes via MAP_SHARED +- **Lock-free**: No mutexes, no syscalls in hot path +- **Adaptive blocking**: Brief spin, then futex sleep for power efficiency +- **Single-Producer Single-Consumer (SPSC)**: Lock-free ring buffer building block +- **Multi-Producer Single-Consumer (MPSC)**: Compositional layer using SPSC + doorbell +- **Modern C++**: RAII, move semantics, factory methods +- **Cache-optimized**: Careful alignment to avoid false sharing + +## Performance + +| Operation | Latency | Notes | +|------------------------------------|------------------|--------------------------------| +| SPSC roundtrip (hot) | 0.3–1 µs | No contention, busy loop | +| SPSC roundtrip (cold) | 3–6 µs | After futex wakeup | +| MPSC roundtrip (3 producers, hot) | ~40 µs | 3-way contention | +| Pipe/socket (for comparison) | 6–15 µs | Requires syscalls | + +*Measured on AMD Ryzen 9 5950X, Ubuntu 24.04, small messages (<1KB)* + +## Architecture + +### SPSC (Single-Producer Single-Consumer) + +``` +┌──────────────────────────────────────────────────┐ +│ SpscCtrl (control block) │ +│ ┌────────────────────────────────────────────┐ │ +│ │ head + wrap_head │ │ +│ │ (producer-owned, cacheline-aligned) │ │ +│ │ tail │ │ +│ │ (consumer-owned, cacheline-aligned) │ │ +│ │ capacity, mask (immutable) │ │ +│ └────────────────────────────────────────────┘ │ +│ │ +│ Data buffer (power-of-2 size) │ +│ ┌────────────────────────────────────────────┐ │ +│ │ [producer writes here] [consumer reads] │ │ +│ └────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────┘ +``` + +**Key characteristics:** +- **Lock-free**: Producer and consumer never block 
each other +- **Cache-friendly**: head/tail separated by cache line to avoid false sharing +- **Variable-length messages**: Automatic padding when wrapping around ring +- **Efficient blocking**: Spin briefly, then futex sleep/wake + +### MPSC (Multi-Producer Single-Consumer) + +``` +┌─────────────────────────────────────────────────┐ +│ MPSC System (N producers) │ +│ │ +│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ +│ │ Producer │ │ Producer │ │ Producer │ │ +│ │ 0 │ │ 1 │ │ 2 │ │ +│ └─────┬────┘ └─────┬────┘ └─────┬────┘ │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ +│ │ SPSC │ │ SPSC │ │ SPSC │ │ +│ │ Ring 0 │ │ Ring 1 │ │ Ring 2 │ │ +│ └────┬────┘ └────┬────┘ └────┬────┘ │ +│ │ │ │ │ +│ └─────────────┼─────────────┘ │ +│ │ │ +│ ┌─────▼──────┐ │ +│ │ Doorbell │◄─────────────────┤ +│ │ Futex │ (wake on data) │ +│ └─────┬──────┘ │ +│ │ │ +│ ┌─────▼──────┐ │ +│ │ Consumer │ │ +│ │ (polls all │ │ +│ │ rings) │ │ +│ └────────────┘ │ +└─────────────────────────────────────────────────┘ +``` + +**Key characteristics:** +- Each producer gets dedicated SPSC ring (no contention between producers) +- Consumer polls all rings in round-robin fashion +- Shared doorbell futex: producers ring on empty→non-empty transition +- Per-producer backpressure (full ring blocks only that producer) + +## API Overview + +All timeouts are in nanoseconds. At this layer `timeout_ns == 0` means an +immediate (non-blocking) check; the higher-level `IpcClient`/`IpcServer` +wrappers translate their public "0 = infinite" convention before calling in. 
+ +### SpscShm Class + +```cpp +namespace ipc { + +class SpscShm { +public: + // Factory methods + static SpscShm create(const std::string& name, size_t min_capacity); + static SpscShm connect(const std::string& name); + static bool unlink(const std::string& name); + + // Move-only (RAII) + SpscShm(SpscShm&& other) noexcept; + SpscShm& operator=(SpscShm&& other) noexcept; + ~SpscShm(); + + // Introspection + uint64_t available() const; // bytes ready to read + uint64_t capacity() const; + + // Producer API: claim/publish must be paired, with sizes exactly + // matching the consumer's peek/release pair. + void* claim(size_t want, uint64_t timeout_ns); // nullptr on timeout + void publish(size_t n); + + // Consumer API + void* peek(size_t want, uint64_t timeout_ns); // nullptr on timeout + void release(size_t n); + + // Blocking wait (adaptive spin, then futex) + bool wait_for_data(size_t need, uint64_t timeout_ns); + bool wait_for_space(size_t need, uint64_t timeout_ns); + + // Wake all blocked waiters (graceful shutdown) + void wakeup_all(); +}; + +} // namespace ipc +``` + +The wrap decision is stateless and derived purely from the requested size, +so every `claim(n)`/`publish(n)` by the producer must be matched by a +`peek(n)`/`release(n)` of the same `n` by the consumer (see the header +comment in `spsc_shm.hpp`). 
+ +### MpscConsumer / MpscProducer Classes + +```cpp +namespace ipc { + +class MpscConsumer { +public: + // Factory + static MpscConsumer create(const std::string& name, + size_t num_producers, + size_t ring_capacity); + static bool unlink(const std::string& name, size_t num_producers); + + // Move-only (RAII) + MpscConsumer(MpscConsumer&& other) noexcept; + ~MpscConsumer(); + + // Consumer API + int wait_for_data(uint64_t timeout_ns); // ring index with data, or -1 + void* peek(size_t ring_idx, size_t want, uint64_t timeout_ns); + void release(size_t ring_idx, size_t n); + void wakeup_all(); +}; + +class MpscProducer { +public: + // Factory + static MpscProducer connect(const std::string& name, size_t producer_id); + + // Move-only (RAII) + MpscProducer(MpscProducer&& other) noexcept; + ~MpscProducer(); + + // Producer API + void* claim(size_t want, uint64_t timeout_ns); + void publish(size_t n); // rings the doorbell +}; + +} // namespace ipc +``` + +## Usage Examples + +These use the message framing helpers from `../shm_common.hpp` +(`ring_send_msg` / `ring_receive_msg`), which add a 4-byte length prefix and +take care of the matched claim/peek sizing. + +**Producer process:** +```cpp +#include "ipc_runtime/shm_common.hpp" +#include + +int main() { + // Create ring buffer (1 MiB capacity); consumer connects by name. + auto tx = ipc::SpscShm::create("/demo_ring", 1 << 20); + + std::string msg = "hello from producer"; + while (true) { + // Blocks up to 1s for ring space; false on timeout. + ipc::ring_send_msg(tx, msg.data(), msg.size(), 1'000'000'000); + } +} +``` + +**Consumer process:** +```cpp +#include "ipc_runtime/shm_common.hpp" +#include + +int main() { + auto rx = ipc::SpscShm::connect("/demo_ring"); + + while (true) { + // Blocks up to 1s for a whole message; null data() on timeout.
+ auto msg = ipc::ring_receive_msg(rx, 1'000'000'000); + if (msg.data() == nullptr) { + continue; // timeout + } + std::cout << "Received: " << std::string(msg.begin(), msg.end()) << "\n"; + rx.release(4 + msg.size()); // length prefix + payload + } +} +``` + +**Cleanup:** +```cpp +// When done (from either process) +ipc::SpscShm::unlink("/demo_ring"); +``` + +For multi-producer setups, prefer the higher-level `MpscShmServer` / +`MpscShmClient` (`../mpsc_shm_server.hpp`, `../mpsc_shm_client.hpp`), which +wire `MpscConsumer`/`MpscProducer` together with per-client response rings +and the same framing. + +## Implementation Details + +### Memory Layout + +The shared memory region contains: +1. **SpscCtrl** (control block, 256 bytes) + - Atomic head/tail counters (cache-line aligned) + - Futex sequencers for sleep/wake + - Capacity and mask (immutable) +2. **Data buffer** (power-of-2 size, follows control block) + +Total size: `sizeof(SpscCtrl) + capacity` + +### Padding and Wrapping + +When a message would wrap around the ring boundary, automatic padding is inserted: + +``` +┌────────────────────────────────────────────────┐ +│ [msg1] [msg2] [...............] [padding] │ +│ ^ │ +│ └─ wrap point │ +└────────────────────────────────────────────────┘ + ^ + └─ next message starts at beginning +``` + +The consumer's `peek()` automatically skips padding, so callers never see it. + +### Futex-Based Blocking + +Instead of busy-waiting forever: +1. **Producer**: Spins briefly checking for space, then sleeps on the `tail` futex (armed against the current tail value) +2. **Consumer**: Spins briefly checking for data, then sleeps on the `head` futex (armed against the current head value) +3. **Wakeup**: The other side calls `futex_wake` unconditionally after publishing/releasing. `futex_wait` re-checks the armed value atomically under the futex bucket lock, so a publish/release that races the sleep returns `EAGAIN` instead of sleeping. 
+ +This provides: +- Low latency when active (spin catches transitions) +- Low power when idle (futex sleep) +- No thundering herd (one waker, one sleeper) + +> The wake is intentionally unconditional — do **not** gate it on a +> `consumer_blocked`/`producer_blocked` flag to skip the syscall when no one is +> waiting. That would be a cross-process Dekker handshake between the flag and +> the head/tail word, which races and can drop the wake, stranding a waiter on +> already-published data. A `futex_wake` with no waiter is a cheap no-op, so the +> unconditional wake costs effectively nothing on the idle-consumer path. + +### MPSC Doorbell + +The doorbell is a simple futex counter in shared memory: + +```cpp +struct alignas(64) MpscDoorbell { + // Producer-written (incremented in publish()) + alignas(64) std::atomic<uint32_t> seq; + // (+ cache-line padding) +}; +``` + +**Protocol:** +1. Producer publishes data to its SPSC ring +2. Producer increments the doorbell seq and calls `futex_wake` unconditionally +3. Consumer wakes up, polls all rings in round-robin +4. Consumer sleeps on the doorbell seq only when all rings are empty + +This ensures the consumer wakes promptly when any producer has data, while minimizing futex overhead when rings stay populated.
+ +## Performance Tuning + +### Spin Time + +The busy-wait duration before sleeping is a compile-time constant (`SPIN_NS`, currently 100 µs in `wait_for_data`), not a runtime `spin_ns` parameter, and spinning is adaptive (a call only spins if the previous call found data). When tuning the constant: + +- **Low latency**: Use longer spin (e.g., 100 µs) to avoid futex overhead +- **Power efficiency**: Use shorter spin (e.g., 1 µs) to sleep sooner +- **Recommended**: 10-20 µs balances latency and power + +### Ring Size + +- Always a **power of two**: `SpscShm::create` rounds the requested `min_capacity` up to the next power of two +- Larger rings reduce wrapping overhead but use more memory +- Recommended: 1 MB (1 << 20) for most use cases +- Small messages (<1 KB): Can use smaller rings (256 KB) +- Large messages (>100 KB): Use larger rings (4-16 MB) + +### Number of Producers (MPSC) + +- More producers → more ring poll overhead for consumer +- Recommended: ≤8 producers for best performance +- Beyond that, consider multiple MPSC systems or alternative architecture + +## Thread Safety + +### SPSC +- **One producer thread**, **one consumer thread** +- No internal synchronization needed (lock-free by design) +- Cannot share producer or consumer role across threads + +### MPSC +- **Multiple producer threads** (one per producer instance) +- **One consumer thread** +- Each producer is independent (no contention) +- Consumer must be single-threaded + +## Limitations + +1. **Platform**: Linux and macOS (futex / `os_sync_wait_on_address`); other platforms fail the build +2. **Capacity**: Always a power of two (the requested size is rounded up) +3. **Fixed size**: Cannot resize after creation +4. **No security**: All processes with access can read/write shared memory +5.
**Manual cleanup**: Must call `unlink()` to remove `/dev/shm` objects + +## Comparison with Other IPC Mechanisms + +| Mechanism | Latency | Throughput | Complexity | Use Case | +|--------------------|------------|------------|------------|-------------------------| +| Pipe | 6-15 µs | 150K/s | Low | Simple IPC | +| Unix Socket | 6-15 µs | 150K/s | Low | Network-like API | +| SPSC Ring | 0.3-1 µs | 1M/s | Medium | Ultra-low latency | +| MPSC Ring | ~3 µs | 700K/s | Medium | Multiple producers | +| POSIX MQ | 10-20 µs | 100K/s | Medium | Message queue semantics | + +## See Also + +- Parent IPC module: [`../README.md`](../README.md) +- Tests: [`../shm.test.cpp`](../shm.test.cpp) +- Benchmarks: build a harness against the `ipc_runtime` CMake target. +- Higher-level wrappers: `ShmClient` / `ShmServer` in [`../shm_client.hpp`](../shm_client.hpp) + +## License + +See repository root for license details. diff --git a/ipc-runtime/cpp/ipc_runtime/shm/futex.hpp b/ipc-runtime/cpp/ipc_runtime/shm/futex.hpp new file mode 100644 index 000000000000..a253f5f03376 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/futex.hpp @@ -0,0 +1,90 @@ +/** + * @file futex.hpp + * @brief Cross-platform futex-like synchronization primitives + * + * Provides unified wait/wake operations for cross-process synchronization: + * - macOS: Uses os_sync_wait_on_address_with_timeout / os_sync_wake_by_address_any + * - Linux: Uses futex syscalls + */ +#pragma once + +#include +#include + +#ifdef __APPLE__ +// Darwin's os_sync API (available since macOS 14.4 / iOS 17.4) +// Forward declarations to avoid header dependency +extern "C" { +int os_sync_wait_on_address_with_timeout( + void* addr, uint64_t value, size_t size, uint32_t flags, uint32_t clockid, uint64_t timeout_ns); +int os_sync_wake_by_address_any(void* addr, size_t size, uint32_t flags); +} +#define OS_SYNC_WAIT_ON_ADDRESS_SHARED 1u +#define OS_SYNC_WAKE_BY_ADDRESS_SHARED 1u +#define OS_CLOCK_MACH_ABSOLUTE_TIME 32u +#elif defined(__linux__) +//
Linux futex +#include +#include +#include +#include +#else +#error "ipc-runtime supports Linux and macOS only" +#endif + +namespace ipc { + +/** + * @brief Atomic compare-and-wait operation with timeout + * + * Blocks if the value at addr equals expect, but only for up to timeout_ns nanoseconds. + * Works across process boundaries. + * + * @param addr Pointer to 32-bit value to wait on + * @param expect Expected value - blocks if *addr == expect + * @param timeout_ns Maximum time to wait in nanoseconds (0 = return immediately if value matches) + * @return 0 on wake, -1 on error (check errno for ETIMEDOUT on timeout) + */ +inline int futex_wait_timeout(volatile uint32_t* addr, uint32_t expect, uint64_t timeout_ns) +{ +#ifdef __APPLE__ + // macOS: Use os_sync_wait_on_address_with_timeout with SHARED flag for cross-process + // Uses MACH_ABSOLUTE_TIME clock (monotonic, measures time since boot) + return os_sync_wait_on_address_with_timeout(const_cast(addr), + static_cast(expect), + sizeof(uint32_t), + OS_SYNC_WAIT_ON_ADDRESS_SHARED, + OS_CLOCK_MACH_ABSOLUTE_TIME, + timeout_ns); +#else + // Linux futex with timeout + struct timespec timeout = { .tv_sec = static_cast(timeout_ns / 1000000000ULL), + .tv_nsec = static_cast(timeout_ns % 1000000000ULL) }; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) + return static_cast(syscall(SYS_futex, addr, FUTEX_WAIT, expect, &timeout, nullptr, 0)); +#endif +} + +/** + * @brief Wake waiters blocked on an address + * + * Wakes up to n waiters blocked on addr. Works across process boundaries. 
+ * + * @param addr Pointer to 32-bit value to wake on + * @param n Number of waiters to wake (1 for single, INT_MAX for all); ignored on macOS, where at most one waiter is woken per call + * @return Number of waiters woken, or -1 on error + */ +inline int futex_wake(volatile uint32_t* addr, int n) +{ +#ifdef __APPLE__ + // macOS: Use os_sync_wake_by_address_any with SHARED flag for cross-process; n is discarded, so a wake-all request (INT_MAX) still issues a single wake-any + (void)n; + return os_sync_wake_by_address_any(const_cast(addr), sizeof(uint32_t), OS_SYNC_WAKE_BY_ADDRESS_SHARED); +#else + // Linux futex + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) + return static_cast(syscall(SYS_futex, addr, FUTEX_WAKE, n, nullptr, nullptr, 0)); +#endif +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.cpp b/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.cpp new file mode 100644 index 000000000000..731a399c0f98 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.cpp @@ -0,0 +1,387 @@ +#include "mpsc_shm.hpp" +#include "futex.hpp" +#include "utilities.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +// ----- MpscConsumer Implementation ----- + +MpscConsumer::MpscConsumer(std::vector&& rings, int doorbell_fd, size_t doorbell_len, MpscDoorbell* doorbell) + : rings_(std::move(rings)) + , doorbell_fd_(doorbell_fd) + , doorbell_len_(doorbell_len) + , doorbell_(doorbell) +{} + +MpscConsumer::MpscConsumer(MpscConsumer&& other) noexcept + : rings_(std::move(other.rings_)) + , doorbell_fd_(other.doorbell_fd_) + , doorbell_len_(other.doorbell_len_) + , doorbell_(other.doorbell_) + , last_served_(other.last_served_) +{ + other.doorbell_fd_ = -1; + other.doorbell_len_ = 0; + other.doorbell_ = nullptr; + other.last_served_ = 0; +} + +MpscConsumer& MpscConsumer::operator=(MpscConsumer&& other) noexcept +{ + if (this != &other) { + // Clean up current resources + if (doorbell_ != nullptr) { + munmap(doorbell_, doorbell_len_); + } + if (doorbell_fd_ >= 0) { + ::close(doorbell_fd_); + } + + //
Move from other + rings_ = std::move(other.rings_); + doorbell_fd_ = other.doorbell_fd_; + doorbell_len_ = other.doorbell_len_; + doorbell_ = other.doorbell_; + last_served_ = other.last_served_; + + // Clear other + other.doorbell_fd_ = -1; + other.doorbell_len_ = 0; + other.doorbell_ = nullptr; + other.last_served_ = 0; + } + return *this; +} + +MpscConsumer::~MpscConsumer() +{ + if (doorbell_ != nullptr) { + munmap(doorbell_, doorbell_len_); + } + if (doorbell_fd_ >= 0) { + ::close(doorbell_fd_); + } +} + +MpscConsumer MpscConsumer::create(const std::string& name, size_t num_producers, size_t ring_capacity) +{ + if (name.empty() || num_producers == 0) { + throw std::runtime_error("MpscConsumer::create: invalid arguments"); + } + + // Create doorbell shared memory + std::string doorbell_name = name + "_doorbell"; + size_t doorbell_len = sizeof(MpscDoorbell); + + int doorbell_fd = shm_open(doorbell_name.c_str(), O_RDWR | O_CREAT | O_EXCL, 0600); + if (doorbell_fd < 0) { + throw std::runtime_error("MpscConsumer::create: shm_open doorbell failed: " + + std::string(std::strerror(errno))); + } + + if (ftruncate(doorbell_fd, static_cast(doorbell_len)) != 0) { + int e = errno; + ::close(doorbell_fd); + shm_unlink(doorbell_name.c_str()); + throw std::runtime_error("MpscConsumer::create: ftruncate doorbell failed: " + std::string(std::strerror(e))); + } + + auto* doorbell = + static_cast(mmap(nullptr, doorbell_len, PROT_READ | PROT_WRITE, MAP_SHARED, doorbell_fd, 0)); + if (doorbell == MAP_FAILED) { + int e = errno; + ::close(doorbell_fd); + shm_unlink(doorbell_name.c_str()); + throw std::runtime_error("MpscConsumer::create: mmap doorbell failed: " + std::string(std::strerror(e))); + } + + // Initialize doorbell (use placement new to avoid memset on non-trivial type) + new (doorbell) MpscDoorbell{}; + doorbell->seq.store(0, std::memory_order_release); + + // Create all SPSC rings + std::vector rings; + rings.reserve(num_producers); + + try { + for (size_t i = 0; i < 
num_producers; i++) { + std::string ring_name = name + "_ring_" + std::to_string(i); + rings.push_back(SpscShm::create(ring_name, ring_capacity)); + } + } catch (...) { + // Cleanup on failure + for (size_t i = 0; i < rings.size(); i++) { + std::string ring_name = name + "_ring_" + std::to_string(i); + SpscShm::unlink(ring_name); + } + munmap(doorbell, doorbell_len); + ::close(doorbell_fd); + shm_unlink(doorbell_name.c_str()); + throw; + } + + return MpscConsumer(std::move(rings), doorbell_fd, doorbell_len, doorbell); +} + +bool MpscConsumer::unlink(const std::string& name, size_t num_producers) +{ + std::string doorbell_name = name + "_doorbell"; + shm_unlink(doorbell_name.c_str()); + + for (size_t i = 0; i < num_producers; i++) { + std::string ring_name = name + "_ring_" + std::to_string(i); + SpscShm::unlink(ring_name); + } + + return true; +} + +int MpscConsumer::wait_for_data(uint64_t timeout_ns) +{ + size_t num_rings = rings_.size(); + + // Phase 1: Quick poll - check if data already available + for (size_t i = 0; i < num_rings; i++) { + size_t idx = (last_served_ + 1 + i) % num_rings; + if (rings_[idx].available() > 0) { + last_served_ = idx; + previous_had_data_ = true; // Found data - enable spinning on next call + return static_cast(idx); + } + } + + // Adaptive spinning: only spin if previous call found data + constexpr uint64_t SPIN_NS = 100000; // 100us + uint64_t spin_duration; + uint64_t remaining_timeout; + + if (previous_had_data_) { + // Previous call found data - do full spin (optimistic) + spin_duration = (timeout_ns < SPIN_NS) ? timeout_ns : SPIN_NS; + remaining_timeout = (timeout_ns > SPIN_NS) ? 
(timeout_ns - SPIN_NS) : 0; + } else { + // Previous call timed out - skip spinning (idle channel) + spin_duration = 0; + remaining_timeout = timeout_ns; + } + + // Phase 2: Spin phase (only if previous call found data) + if (spin_duration > 0) { + uint64_t start = mono_ns_now(); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-do-while) + do { + for (size_t i = 0; i < num_rings; i++) { + size_t idx = (last_served_ + 1 + i) % num_rings; + if (rings_[idx].available() > 0) { + last_served_ = idx; + previous_had_data_ = true; // Found data during spin + return static_cast(idx); + } + } + IPC_PAUSE(); + } while ((mono_ns_now() - start) < spin_duration); + + // Check after spin + for (size_t i = 0; i < num_rings; i++) { + size_t idx = (last_served_ + 1 + i) % num_rings; + if (rings_[idx].available() > 0) { + last_served_ = idx; + previous_had_data_ = true; // Found data after spin + return static_cast(idx); + } + } + } + + // No more time or didn't spin - check if we can block + if (remaining_timeout == 0) { + previous_had_data_ = false; // Timeout - disable spinning on next call + return -1; + } + + // About to block. Capture the doorbell seq and arm the futex against it. + // Producers bump seq on every publish and wake unconditionally (see + // MpscProducer::publish), and futex_wait re-checks *seq == seq atomically, so + // a publish that lands before we sleep returns EAGAIN. The arm-value plus the + // unconditional wake are the whole protocol — no "blocked" flag. 
+ uint32_t seq = doorbell_->seq.load(std::memory_order_acquire); + + // Final check before blocking + for (size_t i = 0; i < num_rings; i++) { + size_t idx = (last_served_ + 1 + i) % num_rings; + if (rings_[idx].available() > 0) { + last_served_ = idx; + previous_had_data_ = true; // Found data before blocking + return static_cast(idx); + } + } + + futex_wait_timeout(reinterpret_cast(&doorbell_->seq), seq, remaining_timeout); + + // After waking, poll again + for (size_t i = 0; i < num_rings; i++) { + size_t idx = (last_served_ + 1 + i) % num_rings; + if (rings_[idx].available() > 0) { + last_served_ = idx; + previous_had_data_ = true; // Found data after waking + return static_cast(idx); + } + } + + previous_had_data_ = false; // Timeout - disable spinning on next call + return -1; // No data available (timeout) +} + +void* MpscConsumer::peek(size_t ring_idx, size_t want, uint64_t timeout_ns) +{ + if (ring_idx >= rings_.size()) { + return nullptr; + } + return rings_[ring_idx].peek(want, timeout_ns); +} + +void MpscConsumer::release(size_t ring_idx, size_t n) +{ + if (ring_idx < rings_.size()) { + rings_[ring_idx].release(n); + } +} + +void MpscConsumer::wakeup_all() +{ + // Wake consumer blocked on doorbell + futex_wake(reinterpret_cast(&doorbell_->seq), INT_MAX); + + // Wake all producers blocked on their rings + for (auto& ring : rings_) { + ring.wakeup_all(); + } +} + +// ----- MpscProducer Implementation ----- + +MpscProducer::MpscProducer( + SpscShm&& ring, int doorbell_fd, size_t doorbell_len, MpscDoorbell* doorbell, size_t producer_id) + : ring_(std::move(ring)) + , doorbell_fd_(doorbell_fd) + , doorbell_len_(doorbell_len) + , doorbell_(doorbell) + , producer_id_(producer_id) +{} + +MpscProducer::MpscProducer(MpscProducer&& other) noexcept + : ring_(std::move(other.ring_)) + , doorbell_fd_(other.doorbell_fd_) + , doorbell_len_(other.doorbell_len_) + , doorbell_(other.doorbell_) + , producer_id_(other.producer_id_) +{ + other.doorbell_fd_ = -1; + 
other.doorbell_len_ = 0; + other.doorbell_ = nullptr; + other.producer_id_ = 0; +} + +MpscProducer& MpscProducer::operator=(MpscProducer&& other) noexcept +{ + if (this != &other) { + // Clean up current resources + if (doorbell_ != nullptr) { + munmap(doorbell_, doorbell_len_); + } + if (doorbell_fd_ >= 0) { + ::close(doorbell_fd_); + } + + // Move from other + ring_ = std::move(other.ring_); + doorbell_fd_ = other.doorbell_fd_; + doorbell_len_ = other.doorbell_len_; + doorbell_ = other.doorbell_; + producer_id_ = other.producer_id_; + + // Clear other + other.doorbell_fd_ = -1; + other.doorbell_len_ = 0; + other.doorbell_ = nullptr; + other.producer_id_ = 0; + } + return *this; +} + +MpscProducer::~MpscProducer() +{ + if (doorbell_ != nullptr) { + munmap(doorbell_, doorbell_len_); + } + if (doorbell_fd_ >= 0) { + ::close(doorbell_fd_); + } +} + +MpscProducer MpscProducer::connect(const std::string& name, size_t producer_id) +{ + if (name.empty()) { + throw std::runtime_error("MpscProducer::connect: empty name"); + } + + // Connect to doorbell + std::string doorbell_name = name + "_doorbell"; + size_t doorbell_len = sizeof(MpscDoorbell); + + int doorbell_fd = shm_open(doorbell_name.c_str(), O_RDWR, 0600); + if (doorbell_fd < 0) { + throw std::runtime_error("MpscProducer::connect: shm_open doorbell failed: " + + std::string(std::strerror(errno))); + } + + auto* doorbell = + static_cast(mmap(nullptr, doorbell_len, PROT_READ | PROT_WRITE, MAP_SHARED, doorbell_fd, 0)); + if (doorbell == MAP_FAILED) { + int e = errno; + ::close(doorbell_fd); + throw std::runtime_error("MpscProducer::connect: mmap doorbell failed: " + std::string(std::strerror(e))); + } + + // Connect to assigned ring + std::string ring_name = name + "_ring_" + std::to_string(producer_id); + try { + SpscShm ring = SpscShm::connect(ring_name); + return MpscProducer(std::move(ring), doorbell_fd, doorbell_len, doorbell, producer_id); + } catch (...) 
{ + munmap(doorbell, doorbell_len); + ::close(doorbell_fd); + throw; + } +} + +void* MpscProducer::claim(size_t want, uint64_t timeout_ns) +{ + return ring_.claim(want, timeout_ns); +} + +void MpscProducer::publish(size_t n) +{ + // Publish to ring first + ring_.publish(n); + + // Ring doorbell to wake the consumer. Bump seq (release) so a consumer + // mid-block sees the value change and its futex_wait returns immediately, + // then wake unconditionally — never gated on a "consumer blocked" flag (see + // SpscShm::publish for why that handshake is unsafe across processes). + doorbell_->seq.fetch_add(1, std::memory_order_release); + futex_wake(reinterpret_cast(&doorbell_->seq), 1); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.hpp b/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.hpp new file mode 100644 index 000000000000..23aa3cffc621 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/mpsc_shm.hpp @@ -0,0 +1,152 @@ +/** + * @file mpsc_shm.hpp + * @brief Multi-Producer Single-Consumer via SPSC rings + doorbell futex + * + * Coordinates multiple producers using individual SPSC rings and a shared doorbell. + */ + +#pragma once + +#include "spsc_shm.hpp" +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief Shared doorbell for waking consumer + * + * Producers ring this when publishing data to wake the sleeping consumer. + * Carefully aligned to avoid false sharing between producer and consumer. + */ +struct alignas(64) MpscDoorbell { + // Producer-written (written by producers in publish()). Consumers wait on + // this seq with a futex; producers bump it and wake unconditionally. + alignas(64) std::atomic seq; + std::array _pad0; +}; + +/** + * @brief Multi-producer single-consumer - consumer side + * + * Manages multiple SPSC rings (one per producer) and waits on a shared doorbell. 
+ */ +class MpscConsumer { + public: + /** + * @brief Create MPSC consumer + * @param name Base name for shared memory objects + * @param num_producers Number of producer rings to create + * @param ring_capacity Minimum capacity for each SPSC ring (SpscShm::create rounds it up to a power of two) + * @throws std::runtime_error if creation fails + */ + static MpscConsumer create(const std::string& name, size_t num_producers, size_t ring_capacity); + + /** + * @brief Unlink all shared memory for this MPSC system + * @param name Base name + * @param num_producers Number of producers + * @return Always true; failures of the individual unlink calls are currently ignored + */ + static bool unlink(const std::string& name, size_t num_producers); + + // Move-only + MpscConsumer(MpscConsumer&& other) noexcept; + MpscConsumer& operator=(MpscConsumer&& other) noexcept; + MpscConsumer(const MpscConsumer&) = delete; + MpscConsumer& operator=(const MpscConsumer&) = delete; + + ~MpscConsumer(); + + /** + * @brief Wait for data on any ring + * @param timeout_ns Total timeout in nanoseconds (spins for up to 100us only if the previous call found data, then futex-waits on the doorbell for the remainder) + * @return Ring index with data, or -1 on timeout + */ + int wait_for_data(uint64_t timeout_ns); + + /** + * @brief Peek data from specific ring + * @param ring_idx Ring index + * @param want Minimum bytes required + * @param timeout_ns Timeout in nanoseconds + * @return Pointer to data, or nullptr on timeout or if ring_idx is out of range + */ + void* peek(size_t ring_idx, size_t want, uint64_t timeout_ns); + + /** + * @brief Release data from specific ring + * @param ring_idx Ring index (an out-of-range index is silently ignored) + * @param n Bytes to release + */ + void release(size_t ring_idx, size_t n); + + /** + * @brief Wake all blocked threads (for graceful shutdown) + * Wakes consumer blocked on doorbell and all producers blocked on their rings + */ + void wakeup_all(); + + private: + MpscConsumer(std::vector&& rings, int doorbell_fd, size_t doorbell_len, MpscDoorbell* doorbell); + + std::vector rings_; + int doorbell_fd_ = -1; + size_t doorbell_len_ = 0; + MpscDoorbell* doorbell_ = nullptr; + size_t last_served_ =
0; // Round-robin fairness + bool previous_had_data_ = false; // Adaptive spinning: only spin if previous call found data +}; + +/** + * @brief Multi-producer single-consumer - producer side + * + * Connects to one SPSC ring and rings the shared doorbell when publishing. + */ +class MpscProducer { + public: + /** + * @brief Connect to MPSC system as a producer + * @param name Base name for shared memory objects + * @param producer_id Producer ID (determines which ring to use) + * @throws std::runtime_error if connection fails + */ + static MpscProducer connect(const std::string& name, size_t producer_id); + + // Move-only + MpscProducer(MpscProducer&& other) noexcept; + MpscProducer& operator=(MpscProducer&& other) noexcept; + MpscProducer(const MpscProducer&) = delete; + MpscProducer& operator=(const MpscProducer&) = delete; + + ~MpscProducer(); + + /** + * @brief Claim space in producer's ring + * @param want Bytes wanted + * @param timeout_ns Timeout in nanoseconds + * @return Pointer to buffer, or nullptr on timeout + */ + void* claim(size_t want, uint64_t timeout_ns); + + /** + * @brief Publish data to producer's ring (rings doorbell) + * @param n Bytes to publish + */ + void publish(size_t n); + + private: + MpscProducer(SpscShm&& ring, int doorbell_fd, size_t doorbell_len, MpscDoorbell* doorbell, size_t producer_id); + + SpscShm ring_; + int doorbell_fd_ = -1; + size_t doorbell_len_ = 0; + MpscDoorbell* doorbell_ = nullptr; + size_t producer_id_ = 0; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.cpp b/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.cpp new file mode 100644 index 000000000000..ccb3c1134f66 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.cpp @@ -0,0 +1,543 @@ +#include "spsc_shm.hpp" +#include "futex.hpp" +#include "utilities.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +namespace { + +inline 
uint64_t pow2_ceil_u64(uint64_t x) +{ + if (x < 2) { + return 2; + } + x--; + x |= x >> 1; + x |= x >> 2; + x |= x >> 4; + x |= x >> 8; + x |= x >> 16; + x |= x >> 32; + return x + 1; +} + +} // anonymous namespace + +// ----- SpscShm Implementation ----- + +SpscShm::SpscShm(int fd, size_t map_len, SpscCtrl* ctrl, uint8_t* buf) + : fd_(fd) + , map_len_(map_len) + , ctrl_(ctrl) + , buf_(buf) +{} + +SpscShm::SpscShm(SpscShm&& other) noexcept + : fd_(other.fd_) + , map_len_(other.map_len_) + , ctrl_(other.ctrl_) + , buf_(other.buf_) +{ + other.fd_ = -1; + other.map_len_ = 0; + other.ctrl_ = nullptr; + other.buf_ = nullptr; +} + +SpscShm& SpscShm::operator=(SpscShm&& other) noexcept +{ + if (this != &other) { + // Clean up current resources + if (ctrl_ != nullptr) { + munmap(ctrl_, map_len_); + } + if (fd_ >= 0) { + ::close(fd_); + } + + // Move from other + fd_ = other.fd_; + map_len_ = other.map_len_; + ctrl_ = other.ctrl_; + buf_ = other.buf_; + + // Clear other + other.fd_ = -1; + other.map_len_ = 0; + other.ctrl_ = nullptr; + other.buf_ = nullptr; + } + return *this; +} + +SpscShm::~SpscShm() +{ + if (ctrl_ != nullptr) { + munmap(ctrl_, map_len_); + } + if (fd_ >= 0) { + ::close(fd_); + } +} + +SpscShm SpscShm::create(const std::string& name, size_t min_capacity) +{ + if (name.empty()) { + throw std::runtime_error("SpscShm::create: empty name"); + } + + size_t cap = pow2_ceil_u64(min_capacity); + size_t map_len = sizeof(SpscCtrl) + cap; + + int fd = shm_open(name.c_str(), O_RDWR | O_CREAT | O_EXCL, 0600); + if (fd < 0) { + std::string error_msg = "SpscShm::create: shm_open failed for '" + name + "': " + std::strerror(errno); + if (errno == ENOSPC || errno == ENOMEM) { + error_msg += " (likely /dev/shm is full - check df -h /dev/shm)"; + } + throw std::runtime_error(error_msg); + } + + if (ftruncate(fd, static_cast(map_len)) != 0) { + int e = errno; + std::string error_msg = "SpscShm::create: ftruncate failed for '" + name + + "' (size=" + std::to_string(map_len) + 
"): " + std::strerror(e); + if (e == ENOSPC || e == ENOMEM) { + error_msg += " (likely /dev/shm is full - check df -h /dev/shm)"; + } + ::close(fd); + shm_unlink(name.c_str()); + throw std::runtime_error(error_msg); + } + + void* mem = mmap(nullptr, map_len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (mem == MAP_FAILED) { + int e = errno; + std::string error_msg = "SpscShm::create: mmap failed for '" + name + "' (size=" + std::to_string(map_len) + + "): " + std::strerror(e); + if (e == ENOSPC || e == ENOMEM) { + error_msg += " (likely /dev/shm is full - check df -h /dev/shm)"; + } + ::close(fd); + shm_unlink(name.c_str()); + throw std::runtime_error(error_msg); + } + + std::memset(mem, 0, map_len); + auto* ctrl = static_cast(mem); + + // Initialize non-atomic fields first + ctrl->capacity = cap; + ctrl->mask = cap - 1; + ctrl->wrap_head = UINT64_MAX; + + // Initialize atomics with release ordering to ensure capacity/mask/wrap_head are visible + ctrl->head.store(0ULL, std::memory_order_release); + ctrl->tail.store(0ULL, std::memory_order_release); + + auto* buf = reinterpret_cast(ctrl + 1); + return SpscShm(fd, map_len, ctrl, buf); +} + +SpscShm SpscShm::connect(const std::string& name) +{ + if (name.empty()) { + throw std::runtime_error("SpscShm::connect: empty name"); + } + + int fd = shm_open(name.c_str(), O_RDWR, 0600); + if (fd < 0) { + throw std::runtime_error("SpscShm::connect: shm_open failed: " + std::string(std::strerror(errno))); + } + + struct stat st; + if (fstat(fd, &st) != 0) { + int e = errno; + ::close(fd); + throw std::runtime_error("SpscShm::connect: fstat failed: " + std::string(std::strerror(e))); + } + size_t map_len = static_cast(st.st_size); + + void* mem = mmap(nullptr, map_len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (mem == MAP_FAILED) { + int e = errno; + ::close(fd); + throw std::runtime_error("SpscShm::connect: mmap failed: " + std::string(std::strerror(e))); + } + + auto* ctrl = static_cast(mem); + auto* buf = 
reinterpret_cast(ctrl + 1); + + // Ensure initialization is visible before use (pairs with release in create) + (void)ctrl->head.load(std::memory_order_acquire); + + return SpscShm(fd, map_len, ctrl, buf); +} + +bool SpscShm::unlink(const std::string& name) +{ + return shm_unlink(name.c_str()) == 0; +} + +uint64_t SpscShm::available() const +{ + uint64_t head = ctrl_->head.load(std::memory_order_acquire); + uint64_t tail = ctrl_->tail.load(std::memory_order_acquire); + return head - tail; +} + +void* SpscShm::claim(size_t want, uint64_t timeout_ns) +{ + // Wait for contiguous space to be available + if (!wait_for_space(want, timeout_ns)) { + return nullptr; // Timeout + } + + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + uint64_t head = ctrl_->head.load(std::memory_order_relaxed); + uint64_t pos = head & mask; + uint64_t till_end = cap - pos; + + // Check if it fits contiguously without wrapping + if (want <= till_end) { + // Fits contiguously - no wrap + return buf_ + pos; + } + + // Needs to wrap + return buf_; // Return pointer to beginning of ring +} + +void SpscShm::publish(size_t n) +{ + uint64_t head = ctrl_->head.load(std::memory_order_relaxed); + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + uint64_t pos = head & mask; + uint64_t till_end = cap - pos; + + // Detect if we published wrapped data + // If at current head position we can't fit n bytes, it must have wrapped + uint64_t total_advance = n; + if (n > till_end) { + // We wrote at the beginning after wrapping - skip padding and our data + total_advance += till_end; + ctrl_->wrap_head = head; + } + + // Advance head with release so the consumer's acquire load of head + // synchronizes with the data and wrap_head writes above. + ctrl_->head.store(head + total_advance, std::memory_order_release); + + // Wake any blocked consumer unconditionally: futex_wake with no waiter is a + // cheap no-op, and the consumer's head value-check keeps the fast path + // syscall-free. 
Do not gate the wake on a "consumer blocked" flag — that + // cross-process flag/head handshake races and can drop the wake, stranding + // the consumer asleep on already-published data. + futex_wake(reinterpret_cast(&ctrl_->head), 1); +} + +void* SpscShm::peek(size_t want, uint64_t timeout_ns) +{ + // Wait for contiguous data to be available + if (!wait_for_data(want, timeout_ns)) { + return nullptr; // Timeout + } + + // Read head with acquire to synchronize wrap_head + ctrl_->head.load(std::memory_order_acquire); + + uint64_t tail = ctrl_->tail.load(std::memory_order_relaxed); + + // Check if we're at the position where a message wrapped + // If tail == wrap_head, the message starts at position 0 + if (tail == ctrl_->wrap_head) { + return buf_; + } + + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + uint64_t pos = tail & mask; + [[maybe_unused]] uint64_t till_end = cap - pos; + + // At this point wait_for_data() has guaranteed contiguity from tail + // (or we would have wrapped via wrap_head), so want must fit here. + assert(want <= till_end); + + // Data fits contiguously at current position + return buf_ + pos; +} + +void SpscShm::release(size_t n) +{ + uint64_t tail = ctrl_->tail.load(std::memory_order_relaxed); + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + uint64_t pos = tail & mask; + uint64_t till_end = cap - pos; + + uint64_t total_release = 0; + if (tail == ctrl_->wrap_head) { + // We're releasing data from a wrapped message - skip padding + total_release = till_end + n; + } else { + assert(n <= till_end); + // Normal case: data was contiguous + total_release = n; + } + + uint64_t new_tail = tail + total_release; + ctrl_->tail.store(new_tail, std::memory_order_release); + + // Wake any producer blocked on a full ring, unconditionally — see publish() + // for why the wake is never gated on a "producer blocked" flag. 
+ futex_wake(reinterpret_cast(&ctrl_->tail), 1); +} + +bool SpscShm::wait_for_data(size_t need, uint64_t timeout_ns) +{ + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + + // Check if we need contiguous data that would wrap. + auto check_available = [this, cap, mask, need]() -> bool { + uint64_t head = ctrl_->head.load(std::memory_order_acquire); + uint64_t tail = ctrl_->tail.load(std::memory_order_relaxed); + uint64_t avail = head - tail; + + if (avail < need) { + return false; // Not enough total data + } + + // Check if data is contiguous + uint64_t pos = tail & mask; + uint64_t till_end = cap - pos; + + if (need <= till_end) { + return true; // Fits contiguously + } + + // Would wrap - need padding + actual data available + return avail >= (till_end + need); + }; + + if (check_available()) { + previous_had_data_ = true; // Found data - enable spinning on next call + return true; + } + + // Adaptive spinning: only spin if previous call found data + constexpr uint64_t SPIN_NS = 100000; // 100us + uint64_t spin_duration; + uint64_t remaining_timeout; + + if (previous_had_data_) { + // Previous call found data - do full spin (optimistic) + spin_duration = (timeout_ns < SPIN_NS) ? timeout_ns : SPIN_NS; + remaining_timeout = (timeout_ns > SPIN_NS) ? 
(timeout_ns - SPIN_NS) : 0; + } else { + // Previous call timed out - skip spinning (idle channel) + spin_duration = 0; + remaining_timeout = timeout_ns; + } + + // Spin phase (only if previous call found data) + if (spin_duration > 0) { + uint64_t start = mono_ns_now(); + constexpr uint32_t TIME_CHECK_INTERVAL = 256; // Check time every 256 iterations + uint32_t iterations = 0; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-do-while) + do { + if (check_available()) { + previous_had_data_ = true; // Found data during spin + return true; + } + IPC_PAUSE(); + + // Only check time periodically to avoid syscall overhead + iterations++; + if (iterations >= TIME_CHECK_INTERVAL) { + if ((mono_ns_now() - start) >= spin_duration) { + break; + } + iterations = 0; + } + } while (true); + + // Check after spin + if (check_available()) { + previous_had_data_ = true; // Found data after spin + return true; + } + } + + // No more time or didn't spin - check if we can block + if (remaining_timeout == 0) { + previous_had_data_ = false; // Timeout - disable spinning on next call + return false; + } + + // About to block. Capture the head value we're waiting to change and arm the + // futex against it. The producer wakes unconditionally on every publish (see + // publish()), and futex_wait re-checks *head == head_now atomically under the + // futex bucket lock, so a publish that lands between this load and the + // syscall returns EAGAIN immediately rather than sleeping. No separate + // "blocked" flag is involved — the arm-value and the producer's wake are the + // whole protocol. 
+ uint32_t head_now = static_cast(ctrl_->head.load(std::memory_order_acquire)); + + if (check_available()) { + previous_had_data_ = true; // Found data before blocking + return true; + } + + futex_wait_timeout(reinterpret_cast(&ctrl_->head), head_now, remaining_timeout); + + bool result = check_available(); + previous_had_data_ = result; // Update flag based on final result + return result; +} + +bool SpscShm::wait_for_space(size_t need, uint64_t timeout_ns) +{ + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + + // Check if we need contiguous space that would wrap + auto check_space = [this, cap, mask, need]() -> bool { + uint64_t head = ctrl_->head.load(std::memory_order_relaxed); + uint64_t tail = ctrl_->tail.load(std::memory_order_acquire); + uint64_t freeb = cap - (head - tail); + + // std::cerr << "Checking space: head=" << head << " tail=" << tail << " free=" << freeb << " need=" << need + // << "\n"; + if (freeb < need) { + return false; // Not enough total free space + } + + // Check if space is contiguous + uint64_t pos = head & mask; + uint64_t till_end = cap - pos; + + if (need <= till_end) { + return true; // Fits contiguously + } + + // Would wrap - just check if we have enough total space + // If we have till_end + need bytes free, the ring buffer invariant + // guarantees the beginning is available for writing + return freeb >= (till_end + need); + }; + + if (check_space()) { + previous_had_space_ = true; // Found space - enable spinning on next call + return true; + } + + // Adaptive spinning: only spin if previous call found space + constexpr uint64_t SPIN_NS = 100000; // 100us + uint64_t spin_duration = 0; + uint64_t remaining_timeout = timeout_ns; + + if (previous_had_space_) { + // Previous call found space - do full spin (optimistic) + spin_duration = (timeout_ns < SPIN_NS) ? timeout_ns : SPIN_NS; + remaining_timeout = (timeout_ns > SPIN_NS) ? 
(timeout_ns - SPIN_NS) : 0; + } + + // Spin phase (only if previous call found space) + if (spin_duration > 0) { + uint64_t start = mono_ns_now(); + constexpr uint32_t TIME_CHECK_INTERVAL = 256; // Check time every 256 iterations + uint32_t iterations = 0; + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-do-while) + do { + if (check_space()) { + previous_had_space_ = true; // Found space during spin + return true; + } + IPC_PAUSE(); + + // Only check time periodically to avoid syscall overhead + iterations++; + if (iterations >= TIME_CHECK_INTERVAL) { + if ((mono_ns_now() - start) >= spin_duration) { + break; + } + iterations = 0; + } + } while (true); + + // Check after spin + if (check_space()) { + previous_had_space_ = true; // Found space after spin + return true; + } + } + + // No more time or didn't spin - check if we can block + if (remaining_timeout == 0) { + previous_had_space_ = false; // Timeout - disable spinning on next call + return false; + } + + // About to block. Arm the futex against the current tail; release() wakes + // unconditionally and futex_wait re-checks *tail == tail_now atomically, so a + // release between this load and the syscall returns EAGAIN rather than + // sleeping. Mirrors wait_for_data — the arm-value plus the unconditional wake + // are the whole protocol. 
+ uint32_t tail_now = static_cast(ctrl_->tail.load(std::memory_order_acquire)); + + if (check_space()) { + previous_had_space_ = true; // Found space before blocking + return true; + } + + futex_wait_timeout(reinterpret_cast(&ctrl_->tail), tail_now, remaining_timeout); + + bool result = check_space(); + previous_had_space_ = result; // Update flag based on final result + return result; +} + +void SpscShm::wakeup_all() +{ + futex_wake(reinterpret_cast(&ctrl_->head), INT_MAX); + futex_wake(reinterpret_cast(&ctrl_->tail), INT_MAX); +} + +void SpscShm::debug_dump(const char* prefix) const +{ + uint64_t head = ctrl_->head.load(std::memory_order_acquire); + uint64_t tail = ctrl_->tail.load(std::memory_order_acquire); + uint64_t cap = ctrl_->capacity; + uint64_t mask = ctrl_->mask; + uint64_t wrap_head = ctrl_->wrap_head; + + uint64_t head_pos = head & mask; + uint64_t tail_pos = tail & mask; + uint64_t used = head - tail; + uint64_t free = cap - used; + + std::cerr << "[" << prefix << "] head=" << head << " tail=" << tail << " | head_pos=" << head_pos + << " tail_pos=" << tail_pos << " | used=" << used << " free=" << free << " cap=" << cap + << " | wrap_head=" << (wrap_head == UINT64_MAX ? "NONE" : std::to_string(wrap_head)) << '\n'; +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.hpp b/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.hpp new file mode 100644 index 000000000000..886ec912e701 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/spsc_shm.hpp @@ -0,0 +1,178 @@ +/** + * @file spsc_shm.hpp + * @brief Single-producer/single-consumer shared-memory ring buffer (Linux, x86-64 optimized) + * + * - Zero-copy between processes via MAP_SHARED + * - One producer, one consumer. No locks. 
Hot path has no syscalls + * - Adaptive spin, then futex sleep/wake on empty/full transitions + * - Variable-length message framing + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace ipc { + +constexpr size_t SPSC_CACHELINE = 64; + +/** + * @brief Control structure for SPSC ring buffer + * + * Carefully aligned to avoid false sharing between producer and consumer. + */ +struct alignas(SPSC_CACHELINE) SpscCtrl { + // Producer-owned (written by producer, read by consumer) + alignas(SPSC_CACHELINE) std::atomic head; // bytes written + uint64_t wrap_head; // Head value when last message wrapped (UINT64_MAX = no wrap), synchronized by head + std::array _pad0; + + // Consumer-owned (written by consumer, read by producer) + alignas(SPSC_CACHELINE) std::atomic tail; // bytes consumed + std::array _pad1; + + // Immutable capacity information + alignas(SPSC_CACHELINE) uint64_t capacity; // power of two + alignas(SPSC_CACHELINE) uint64_t mask; // capacity - 1 + + // uint8_t buffer[capacity] follows in memory... +}; + +static_assert(alignof(SpscCtrl) == SPSC_CACHELINE, "SpscCtrl alignment"); +static_assert(sizeof(SpscCtrl) % SPSC_CACHELINE == 0, "SpscCtrl size multiple of cache line"); + +/** + * @brief Lock-free single-producer single-consumer shared memory ring buffer + * + * Provides zero-copy message passing between processes using shared memory. + * Uses futex for efficient blocking when empty/full. + * + * CRITICAL USAGE REQUIREMENT: + * Each claim(n)/publish(n) pair by the producer MUST be perfectly matched by a corresponding + * peek(n)/release(n) pair by the consumer, with the EXACT same sizes. + * + * This is because the wrapping logic is completely stateless - it decides whether to wrap based + * solely on whether the requested size fits in the remaining space before the end of the buffer. + * If the producer and consumer use different sizes, they will make inconsistent wrap decisions + * and data corruption will occur. 
+ * + * CORRECT usage example (framed messages): + * Producer: Consumer: + * claim(4), publish(4) <---> peek(4), release(4) // length prefix + * claim(msg_len), publish(msg_len) <---> peek(msg_len), release(msg_len) // message data + */ +class SpscShm { + public: + /** + * @brief Create a new SPSC ring buffer + * @param name Shared memory object name (without /dev/shm prefix) + * @param min_capacity Minimum capacity (rounded up to power of 2) + * @throws std::runtime_error if creation fails + */ + static SpscShm create(const std::string& name, size_t min_capacity); + + /** + * @brief Connect to existing SPSC ring buffer + * @param name Shared memory object name + * @throws std::runtime_error if connection fails + */ + static SpscShm connect(const std::string& name); + + /** + * @brief Unlink shared memory object (cleanup after close) + * @param name Shared memory object name + * @return true if successful, false otherwise + */ + static bool unlink(const std::string& name); + + // Move-only (no copy) + SpscShm(SpscShm&& other) noexcept; + SpscShm& operator=(SpscShm&& other) noexcept; + SpscShm(const SpscShm&) = delete; + SpscShm& operator=(const SpscShm&) = delete; + + ~SpscShm(); + + // Introspection + uint64_t available() const; // bytes ready to read + + uint64_t capacity() const { return ctrl_->capacity; } + + /** + * Producer API: claim() and publish() must be used in pairs + * + * @brief Claim contiguous space in the ring buffer (blocks until available) + * @param want Number of bytes to claim + * @param timeout_ns Timeout in nanoseconds + * @return Pointer to claimed space, or nullptr on timeout + * + * IMPORTANT: The size passed to claim(want) must exactly match the size passed to the + * corresponding peek(want) call by the consumer. Otherwise wrap decisions will be inconsistent. 
+ */ + void* claim(size_t want, uint64_t timeout_ns); + + /** + * @brief Publish n bytes previously claimed + * @param n Number of bytes to publish (must match what was claimed) + * + * IMPORTANT: The size passed to publish(n) must exactly match the size passed to the + * corresponding release(n) call by the consumer. Otherwise wrap decisions will be inconsistent. + */ + void publish(size_t n); + + /** + * Consumer API: peek() and release() must be used in pairs + * + * @brief Peek contiguous readable region (blocks until available) + * @param want Number of bytes to peek + * @param timeout_ns Timeout in nanoseconds + * @return Pointer to readable data, or nullptr on timeout + * + * IMPORTANT: The size passed to peek(want) must exactly match the size passed to the + * corresponding claim(want) call by the producer. Otherwise wrap decisions will be inconsistent. + */ + void* peek(size_t want, uint64_t timeout_ns); + + /** + * @brief Release n bytes previously peeked + * @param n Number of bytes to release (must match what was peeked) + * + * IMPORTANT: The size passed to release(n) must exactly match the size passed to the + * corresponding publish(n) call by the producer. Otherwise wrap decisions will be inconsistent. + */ + void release(size_t n); + + /** + * @brief Wake all blocked threads (for graceful shutdown) + * + * Wakes both producers blocked on space and consumers blocked on data. + * Used for graceful shutdown of the communication channel. 
+ */ + void wakeup_all(); + + bool wait_for_data(size_t need, uint64_t timeout_ns); + bool wait_for_space(size_t need, uint64_t timeout_ns); + + /** + * @brief Dump internal ring buffer state for debugging + * @param prefix Prefix string for the debug output (e.g., "Client REQ" or "Server RESP") + */ + void debug_dump(const char* prefix) const; + + private: + // Private constructor for create/connect factories + SpscShm(int fd, size_t map_len, SpscCtrl* ctrl, uint8_t* buf); + + int fd_ = -1; + size_t map_len_ = 0; + SpscCtrl* ctrl_ = nullptr; + uint8_t* buf_ = nullptr; + bool previous_had_data_ = false; // Adaptive spinning: consumer only spins if previous call found data + bool previous_had_space_ = false; // Adaptive spinning: producer only spins if previous call found space +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm/utilities.hpp b/ipc-runtime/cpp/ipc_runtime/shm/utilities.hpp new file mode 100644 index 000000000000..a6fdb03cac97 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm/utilities.hpp @@ -0,0 +1,42 @@ +/** + * @file utilities.hpp + * @brief Common utilities for IPC shared memory implementation + * + * Provides timing and CPU pause utilities for spin-wait loops. + */ +#pragma once + +#include +#include // NOLINT(modernize-deprecated-headers) - need POSIX clock_gettime/CLOCK_MONOTONIC + +#if defined(__x86_64__) +#include +#define IPC_PAUSE() _mm_pause() +#elif defined(__aarch64__) +#define IPC_PAUSE() asm volatile("yield") +#else +#define IPC_PAUSE() \ + do { \ + } while (0) +#endif + +namespace ipc { + +/** + * @brief Get current monotonic time in nanoseconds + * + * Uses CLOCK_MONOTONIC which is suitable for measuring elapsed time + * and not affected by system clock adjustments. 
+ * + * @return Current monotonic time in nanoseconds, or 0 on error + */ +inline uint64_t mono_ns_now() +{ + struct timespec ts; + if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { + return 0; + } + return (static_cast(ts.tv_sec) * 1000000000ULL) + static_cast(ts.tv_nsec); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm_client.hpp b/ipc-runtime/cpp/ipc_runtime/shm_client.hpp new file mode 100644 index 000000000000..79722cfe1b49 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm_client.hpp @@ -0,0 +1,133 @@ +#pragma once + +#include "constants.hpp" +#include "ipc_client.hpp" +#include "shm/spsc_shm.hpp" +#include "shm_common.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief IPC client implementation using shared memory + * + * Uses SPSC (single-producer single-consumer) for both requests and responses. + * Simple 1:1 client-server communication. + */ +class ShmClient : public IpcClient { + public: + explicit ShmClient(std::string base_name) + : base_name_(std::move(base_name)) + {} + + ~ShmClient() override = default; + + // Non-copyable, non-movable (owns shared memory resources) + ShmClient(const ShmClient&) = delete; + ShmClient& operator=(const ShmClient&) = delete; + ShmClient(ShmClient&&) = delete; + ShmClient& operator=(ShmClient&&) = delete; + + bool connect() override + { + if (request_ring_.has_value()) { + return true; // Already connected + } + + // Retry within the shared connect budget — the server process may + // still be starting up. 
+ constexpr size_t max_attempts = CONNECT_RETRY_BUDGET_MS / CONNECT_RETRY_DELAY_MS; + constexpr auto retry_delay = std::chrono::milliseconds(CONNECT_RETRY_DELAY_MS); + + for (size_t attempt = 0; attempt < max_attempts; ++attempt) { + try { + // Connect to request ring (client writes, server reads) + std::string req_name = base_name_ + "_request"; + request_ring_ = SpscShm::connect(req_name); + + // Connect to response ring (server writes, client reads) + std::string resp_name = base_name_ + "_response"; + response_ring_ = SpscShm::connect(resp_name); + + return true; + } catch (...) { + request_ring_.reset(); + response_ring_.reset(); + if (attempt + 1 == max_attempts) { + return false; + } + std::this_thread::sleep_for(retry_delay); + } + } + + return false; + } + + bool send(const void* data, size_t len, uint64_t timeout_ns) override + { + if (!request_ring_.has_value()) { + return false; + } + return ring_send_msg(request_ring_.value(), data, len, normalize_call_timeout(timeout_ns)); + } + + std::span receive(uint64_t timeout_ns) override + { + if (!response_ring_.has_value()) { + return {}; + } + return ring_receive_msg(response_ring_.value(), normalize_call_timeout(timeout_ns)); + } + + void release(size_t message_size) override + { + if (!response_ring_.has_value()) { + return; + } + response_ring_->release(sizeof(uint32_t) + message_size); + } + + void close() override + { + request_ring_.reset(); + response_ring_.reset(); + } + + void wakeup() override + { + if (request_ring_.has_value()) { + request_ring_->wakeup_all(); + } + if (response_ring_.has_value()) { + response_ring_->wakeup_all(); + } + } + + void debug_dump() const + { + if (request_ring_.has_value()) { + request_ring_->debug_dump("Client REQ"); + } + if (response_ring_.has_value()) { + response_ring_->debug_dump("Client RESP"); + } + } + + private: + std::string base_name_; + std::optional request_ring_; // Client writes to this + std::optional response_ring_; // Client reads from this +}; + +} 
// namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm_common.hpp b/ipc-runtime/cpp/ipc_runtime/shm_common.hpp new file mode 100644 index 000000000000..ccc273210bd9 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm_common.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include "ipc_runtime/constants.hpp" +#include "ipc_runtime/shm/spsc_shm.hpp" +#include +#include +#include +#include +#include + +namespace ipc { + +inline bool ring_send_msg(SpscShm& ring, const void* data, size_t len, uint64_t timeout_ns) +{ + // Prevent sending messages larger than half the ring buffer capacity. + // This simplifies wrap-around logic. + if (len > ring.capacity() / 2 - 4) { + throw std::runtime_error("ring_send_msg: message too large for ring " + "buffer, must be <= half capacity minus 4 bytes"); + } + + // Atomic send: claim space for entire message (length + data) + size_t total_size = 4 + len; + void* buf = ring.claim(total_size, timeout_ns); + if (buf == nullptr) { + return false; // Timeout or no space - nothing published yet (atomic + // failure) + } + + // Write length prefix and message data together + auto len_u32 = static_cast(len); + std::memcpy(buf, &len_u32, 4); + std::memcpy(static_cast(buf) + 4, data, len); + + // Publish entire message atomically + ring.publish(total_size); + + return true; +} + +inline std::span ring_receive_msg(SpscShm& ring, uint64_t timeout_ns) +{ + // Peek the length prefix (4 bytes) + void* len_ptr = ring.peek(4, timeout_ns); + if (len_ptr == nullptr) { + return {}; // Timeout + } + + // Read message length + uint32_t msg_len = 0; + std::memcpy(&msg_len, len_ptr, 4); + + // Validate before waiting on the claimed size: the send side can never + // legally publish more than capacity/2 - 4 bytes, so a larger prefix + // means the ring is corrupt. Waiting would only ever time out. 
+ if (msg_len > MAX_FRAME_SIZE || msg_len > ring.capacity() / 2 - 4) { + throw std::runtime_error("ring_receive_msg: corrupt length prefix (" + std::to_string(msg_len) + + " bytes exceeds ring/frame limits)"); + } + + // Now peek the message data + void* msg_ptr = ring.peek(4 + msg_len, timeout_ns); + if (msg_ptr == nullptr) { + return {}; // Timeout + } + + // Return span directly into ring buffer (zero-copy!) + return std::span(static_cast(msg_ptr) + 4, msg_len); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/shm_server.hpp b/ipc-runtime/cpp/ipc_runtime/shm_server.hpp new file mode 100644 index 000000000000..3df6851e3bef --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/shm_server.hpp @@ -0,0 +1,156 @@ +#pragma once + +#include "ipc_server.hpp" +#include "shm/spsc_shm.hpp" +#include "shm_common.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief IPC server implementation using shared memory + * + * Uses SPSC (single-producer single-consumer) for both requests and responses. + * Simple 1:1 client-server communication. 
+ */ +class ShmServer : public IpcServer { + public: + ShmServer(std::string base_name, + size_t request_ring_size = DEFAULT_RING_SIZE, + size_t response_ring_size = DEFAULT_RING_SIZE) + : base_name_(std::move(base_name)) + , request_ring_size_(request_ring_size) + , response_ring_size_(response_ring_size) + {} + + ~ShmServer() override { close(); } + + // Non-copyable, non-movable (owns shared memory resources) + ShmServer(const ShmServer&) = delete; + ShmServer& operator=(const ShmServer&) = delete; + ShmServer(ShmServer&&) = delete; + ShmServer& operator=(ShmServer&&) = delete; + + bool listen() override + { + if (request_ring_.has_value()) { + return true; // Already listening + } + + // Clean up any leftover shared memory + std::string req_name = base_name_ + "_request"; + std::string resp_name = base_name_ + "_response"; + SpscShm::unlink(req_name); + SpscShm::unlink(resp_name); + + try { + // Create SPSC ring for requests (client writes, server reads) + request_ring_ = SpscShm::create(req_name, request_ring_size_); + + // Create SPSC ring for responses (server writes, client reads) + response_ring_ = SpscShm::create(resp_name, response_ring_size_); + + return true; + } catch (...) { + close(); // Cleanup on failure + return false; + } + } + + int wait_for_data(uint64_t timeout_ns) override + { + assert(request_ring_); + if (!request_ring_.has_value()) { + return -1; + } + + // Wait for data in request ring, return client ID 0 (always single client) + if (request_ring_->wait_for_data(sizeof(uint32_t), timeout_ns)) { + return 0; // Single client, always ID 0 + } + return -1; // Timeout + } + + std::span receive([[maybe_unused]] int client_id) override + { + if (!request_ring_.has_value()) { + return {}; + } + // TODO: Plumb timeout. 
+ return ring_receive_msg(request_ring_.value(), 100000000); // 100ms timeout + } + + void release([[maybe_unused]] int client_id, size_t message_size) override + { + if (!request_ring_.has_value()) { + return; + } + request_ring_->release(sizeof(uint32_t) + message_size); + } + + bool send([[maybe_unused]] int client_id, const void* data, size_t len) override + { + if (!response_ring_.has_value()) { + return false; + } + return ring_send_msg(response_ring_.value(), data, len, 100000000); + } + + void close() override + { + // Close rings + request_ring_.reset(); + response_ring_.reset(); + + // Clean up shared memory + std::string req_name = base_name_ + "_request"; + std::string resp_name = base_name_ + "_response"; + SpscShm::unlink(req_name); + SpscShm::unlink(resp_name); + } + + void wakeup_all() override + { + // Wake any threads blocked in wait/peek/claim + if (request_ring_.has_value()) { + request_ring_->wakeup_all(); + } + if (response_ring_.has_value()) { + response_ring_->wakeup_all(); + } + } + + CleanupPaths cleanup_paths() const override + { + return CleanupPaths{ .unlink_paths = {}, + .shm_unlink_names = { base_name_ + "_request", base_name_ + "_response" } }; + } + + void debug_dump() const + { + if (request_ring_.has_value()) { + request_ring_->debug_dump("Server REQ"); + } + if (response_ring_.has_value()) { + response_ring_->debug_dump("Server RESP"); + } + } + + private: + std::string base_name_; + size_t request_ring_size_; + size_t response_ring_size_; + std::optional request_ring_; // Server reads from this + std::optional response_ring_; // Server writes to this +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/signal_handlers.cpp b/ipc-runtime/cpp/ipc_runtime/signal_handlers.cpp new file mode 100644 index 000000000000..52cf6aecd650 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/signal_handlers.cpp @@ -0,0 +1,143 @@ +#include "ipc_runtime/signal_handlers.hpp" + +#include +#include +#include +#include +#include +#include 
+#include + +#if defined(__linux__) +#include +#elif defined(__APPLE__) +#include +#include +#else +#error "ipc-runtime supports Linux and macOS only" +#endif + +namespace ipc { + +namespace { + +// File-scope pointer used by signal handlers. Atomic so handler execution +// (which may interrupt main() at any point) observes a consistent value. +std::atomic g_signal_server{ nullptr }; + +// Filesystem artifacts to remove on a fatal signal, cached as fixed-size +// C strings at install time. The fatal handler may only use +// async-signal-safe calls, so no std::string / allocation there. +constexpr size_t MAX_CLEANUP_PATHS = 32; +constexpr size_t MAX_CLEANUP_PATH_LEN = 256; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) +char g_unlink_paths[MAX_CLEANUP_PATHS][MAX_CLEANUP_PATH_LEN]; +size_t g_num_unlink_paths = 0; +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) +char g_shm_unlink_names[MAX_CLEANUP_PATHS][MAX_CLEANUP_PATH_LEN]; +size_t g_num_shm_unlink_names = 0; + +void cache_cleanup_paths(const IpcServer& server) +{ + auto paths = server.cleanup_paths(); + g_num_unlink_paths = 0; + for (const auto& p : paths.unlink_paths) { + if (g_num_unlink_paths >= MAX_CLEANUP_PATHS || p.size() >= MAX_CLEANUP_PATH_LEN) { + continue; + } + std::strncpy(g_unlink_paths[g_num_unlink_paths], p.c_str(), MAX_CLEANUP_PATH_LEN - 1); + g_unlink_paths[g_num_unlink_paths][MAX_CLEANUP_PATH_LEN - 1] = '\0'; + g_num_unlink_paths++; + } + g_num_shm_unlink_names = 0; + for (const auto& p : paths.shm_unlink_names) { + if (g_num_shm_unlink_names >= MAX_CLEANUP_PATHS || p.size() >= MAX_CLEANUP_PATH_LEN) { + continue; + } + std::strncpy(g_shm_unlink_names[g_num_shm_unlink_names], p.c_str(), MAX_CLEANUP_PATH_LEN - 1); + g_shm_unlink_names[g_num_shm_unlink_names][MAX_CLEANUP_PATH_LEN - 1] = '\0'; + g_num_shm_unlink_names++; + } +} + +void write_stderr_signal_safe(const char* message, size_t len) +{ + ssize_t written = ::write(STDERR_FILENO, message, len); + (void)written; +} + +void 
graceful_shutdown_handler([[maybe_unused]] int signal) +{ + constexpr char message[] = "\nReceived shutdown signal\n"; + write_stderr_signal_safe(message, sizeof(message) - 1); + if (auto* s = g_signal_server.load(std::memory_order_acquire); s != nullptr) { + s->request_shutdown_from_signal(); + } +} + +void fatal_error_handler(int signal) +{ + constexpr char message[] = "\nFatal IPC runtime signal\n"; + write_stderr_signal_safe(message, sizeof(message) - 1); + // Best-effort removal of the socket file / SHM segments so a crashed + // server doesn't leak them. unlink() is async-signal-safe; shm_unlink() + // is a thin syscall wrapper and safe in practice (and we are about to + // _Exit anyway). + for (size_t i = 0; i < g_num_unlink_paths; i++) { + ::unlink(g_unlink_paths[i]); + } + for (size_t i = 0; i < g_num_shm_unlink_names; i++) { + ::shm_unlink(g_shm_unlink_names[i]); + } + std::_Exit(128 + signal); +} + +void setup_parent_death_monitoring() +{ +#if defined(__linux__) + if (prctl(PR_SET_PDEATHSIG, SIGTERM) == -1) { + std::cerr << "Warning: Could not set parent death signal" << '\n'; + } +#elif defined(__APPLE__) + pid_t parent_pid = getppid(); + std::thread([parent_pid]() { + int kq = kqueue(); + if (kq == -1) { + std::cerr << "Warning: Could not create kqueue for parent monitoring" << '\n'; + return; + } + struct kevent change; + EV_SET(&change, parent_pid, EVFILT_PROC, EV_ADD | EV_ENABLE, NOTE_EXIT, 0, nullptr); + if (kevent(kq, &change, 1, nullptr, 0, nullptr) == -1) { + std::cerr << "Warning: Could not monitor parent process" << '\n'; + close(kq); + return; + } + struct kevent event; + kevent(kq, nullptr, 0, &event, 1, nullptr); + std::cerr << "Parent process exited, shutting down..." << '\n'; + close(kq); + // Request graceful shutdown so the server's run() loop exits and its + // destructor unlinks the socket/SHM files. (std::exit here would skip + // the stack-allocated server's destructor and leak them.) 
+ if (auto* s = g_signal_server.load(std::memory_order_acquire); s != nullptr) { + s->request_shutdown_from_signal(); + } + }).detach(); +#endif +} + +} // namespace + +void install_default_signal_handlers(IpcServer& server) +{ + g_signal_server.store(&server, std::memory_order_release); + cache_cleanup_paths(server); + (void)std::signal(SIGTERM, graceful_shutdown_handler); + (void)std::signal(SIGINT, graceful_shutdown_handler); + (void)std::signal(SIGBUS, fatal_error_handler); + (void)std::signal(SIGSEGV, fatal_error_handler); + setup_parent_death_monitoring(); +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/signal_handlers.hpp b/ipc-runtime/cpp/ipc_runtime/signal_handlers.hpp new file mode 100644 index 000000000000..83f0442d78a3 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/signal_handlers.hpp @@ -0,0 +1,34 @@ +#pragma once +/** + * @file signal_handlers.hpp + * @brief Default lifecycle signal handlers for IPC servers. + * + * Wires: + * - SIGTERM / SIGINT → IpcServer::request_shutdown_from_signal() + * (graceful drain; the run() loop exits on its next poll iteration) + * - SIGBUS / SIGSEGV → best-effort unlink of the server's socket/SHM + * files (cached at install time) + _Exit(128 + sig) + * - Parent-process death watch via prctl(PR_SET_PDEATHSIG) on Linux + * and a kqueue NOTE_EXIT watcher on macOS — so spawn-and-forget + * services die with their parent rather than turning into orphans. + * + * The reference is stored in a file-scope static, so this is a singleton: + * exactly one IpcServer can be "registered" for the process. Calling + * install_default_signal_handlers() a second time replaces the previous + * registration. + */ + +#include "ipc_runtime/ipc_server.hpp" + +namespace ipc { + +/** + * @brief Install default lifecycle signal handlers + parent-death monitor. + * + * @param server Server instance the handlers control. Must outlive the + * handlers (i.e. live until normal exit). Re-calling + * replaces the registered server. 
+ */ +void install_default_signal_handlers(IpcServer& server); + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/socket.test.cpp b/ipc-runtime/cpp/ipc_runtime/socket.test.cpp new file mode 100644 index 000000000000..ddd81379d488 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/socket.test.cpp @@ -0,0 +1,156 @@ +#include "ipc_runtime/ipc_client.hpp" +#include "ipc_runtime/ipc_server.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ipc; + +namespace { + +std::string test_socket_path(const char* tag) +{ + return "/tmp/ipc_socket_test_" + std::string(tag) + "_" + std::to_string(getpid()) + ".sock"; +} + +TEST(SocketTest, EchoRoundTrip) +{ + std::string path = test_socket_path("echo"); + auto server = IpcServer::create_socket(path, 2); + ASSERT_TRUE(server->listen()); + + std::thread server_thread([&] { + server->run([](int, std::span req) { return std::vector(req.begin(), req.end()); }); + }); + + auto client = IpcClient::create_socket(path); + ASSERT_TRUE(client->connect()); + + std::vector payload = { 1, 2, 3, 4, 5 }; + ASSERT_TRUE(client->send(payload.data(), payload.size(), 1'000'000'000ULL)); + auto resp = client->receive(5'000'000'000ULL); + ASSERT_EQ(resp.size(), payload.size()); + EXPECT_EQ(std::memcmp(resp.data(), payload.data(), payload.size()), 0); + client->release(resp.size()); + + client->close(); + server->request_shutdown(); + server_thread.join(); + server->close(); +} + +// A handler returning an empty vector must still produce a (zero-length) +// response frame, otherwise the client deadlocks waiting for it. 
+TEST(SocketTest, ZeroLengthResponseRoundTrip) +{ + std::string path = test_socket_path("zlen"); + auto server = IpcServer::create_socket(path, 2); + ASSERT_TRUE(server->listen()); + + std::thread server_thread( + [&] { server->run([](int, std::span) { return std::vector{}; }); }); + + auto client = IpcClient::create_socket(path); + ASSERT_TRUE(client->connect()); + + uint8_t byte = 42; + ASSERT_TRUE(client->send(&byte, 1, 1'000'000'000ULL)); + auto resp = client->receive(5'000'000'000ULL); + EXPECT_NE(resp.data(), nullptr) << "zero-length response should be success, not timeout"; + EXPECT_EQ(resp.size(), 0U); + client->release(resp.size()); + + client->close(); + server->request_shutdown(); + server_thread.join(); + server->close(); +} + +// A corrupt/malicious length prefix must cause the server to drop the +// connection, not allocate the claimed amount. +TEST(SocketTest, ServerRejectsOversizedLengthPrefix) +{ + std::string path = test_socket_path("oversize_srv"); + auto server = IpcServer::create_socket(path, 2); + ASSERT_TRUE(server->listen()); + + std::thread server_thread([&] { + server->run([](int, std::span req) { return std::vector(req.begin(), req.end()); }); + }); + + // Raw client so we can write a bogus frame. + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_GE(fd, 0); + struct sockaddr_un addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + std::strncpy(addr.sun_path, path.c_str(), sizeof(addr.sun_path) - 1); + ASSERT_EQ(::connect(fd, reinterpret_cast(&addr), sizeof(addr)), 0); + + uint32_t bogus_len = 0x7FFFFFFF; // ~2 GiB, way over MAX_FRAME_SIZE + ASSERT_EQ(::send(fd, &bogus_len, sizeof(bogus_len), 0), static_cast(sizeof(bogus_len))); + + // Server should close the connection. recv with a timeout so a buggy + // server (waiting for 2 GiB of payload) fails the test instead of hanging. 
+ struct timeval tv = { .tv_sec = 5, .tv_usec = 0 }; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)); + uint8_t buf[4]; + ssize_t n = ::recv(fd, buf, sizeof(buf), 0); + EXPECT_EQ(n, 0) << "server should have closed the connection on oversized frame"; + ::close(fd); + + server->request_shutdown(); + server_thread.join(); + server->close(); +} + +// Same on the client side: a bogus length prefix from the server must be +// rejected (connection closed), not trusted as an allocation size. +TEST(SocketTest, ClientRejectsOversizedLengthPrefix) +{ + std::string path = test_socket_path("oversize_cli"); + + int listen_fd = socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_GE(listen_fd, 0); + ::unlink(path.c_str()); + struct sockaddr_un addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + std::strncpy(addr.sun_path, path.c_str(), sizeof(addr.sun_path) - 1); + ASSERT_EQ(bind(listen_fd, reinterpret_cast(&addr), sizeof(addr)), 0); + ASSERT_EQ(::listen(listen_fd, 1), 0); + + std::thread fake_server([&] { + int conn_fd = ::accept(listen_fd, nullptr, nullptr); + if (conn_fd < 0) { + return; + } + uint32_t bogus_len = 0x7FFFFFFF; + ::send(conn_fd, &bogus_len, sizeof(bogus_len), 0); + // Leave the connection open: a buggy client would block waiting for + // ~2 GiB of payload (bounded by its receive timeout). 
+ uint8_t buf[1]; + ::recv(conn_fd, buf, sizeof(buf), 0); // returns when client closes + ::close(conn_fd); + }); + + auto client = IpcClient::create_socket(path); + ASSERT_TRUE(client->connect()); + auto resp = client->receive(2'000'000'000ULL); + EXPECT_EQ(resp.data(), nullptr) << "oversized frame must be an error"; + + client->close(); + fake_server.join(); + ::close(listen_fd); + ::unlink(path.c_str()); +} + +} // namespace diff --git a/ipc-runtime/cpp/ipc_runtime/socket_client.cpp b/ipc-runtime/cpp/ipc_runtime/socket_client.cpp new file mode 100644 index 000000000000..c99cd8c8e9e6 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/socket_client.cpp @@ -0,0 +1,211 @@ +#include "ipc_runtime/socket_client.hpp" +#include "ipc_runtime/constants.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ipc { + +SocketClient::SocketClient(std::string socket_path) + : socket_path_(std::move(socket_path)) +{} + +SocketClient::~SocketClient() +{ + close_internal(); +} + +bool SocketClient::connect() +{ + if (fd_ >= 0) { + return true; // Already connected + } + + constexpr size_t max_attempts = CONNECT_RETRY_BUDGET_MS / CONNECT_RETRY_DELAY_MS; + constexpr auto retry_delay = std::chrono::milliseconds(CONNECT_RETRY_DELAY_MS); + + for (size_t attempt = 0; attempt < max_attempts; ++attempt) { + // Create socket + fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd_ < 0) { + return false; + } + + // Connect to server + struct sockaddr_un addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1); + + if (::connect(fd_, reinterpret_cast(&addr), sizeof(addr)) == 0) { + applied_recv_timeout_ns_ = 0; + applied_send_timeout_ns_ = 0; + return true; + } + + ::close(fd_); + fd_ = -1; + if (attempt + 1 == max_attempts) { + return false; + } + std::this_thread::sleep_for(retry_delay); + } + + return false; +} + 
+// Applies a socket timeout for `option` (SO_SNDTIMEO or SO_RCVTIMEO), skipping the +// setsockopt call when `applied_ns` already holds `timeout_ns`. Returns false when +// setsockopt fails; `applied_ns` is updated only on success. +bool SocketClient::apply_timeout(int option, uint64_t& applied_ns, uint64_t timeout_ns) +{ + if (applied_ns == timeout_ns) { + return true; + } + // timeout_ns == 0 → {0, 0} which means "no timeout" (infinite) for + // SO_RCVTIMEO / SO_SNDTIMEO. + // Round up to whole microseconds: truncating would turn a non-zero timeout + // below 1us into {0, 0}, i.e. silently infinite instead of (almost) immediate. + const uint64_t timeout_us = timeout_ns / 1000ULL + (timeout_ns % 1000ULL != 0 ? 1ULL : 0ULL); + struct timeval tv; + tv.tv_sec = static_cast<decltype(tv.tv_sec)>(timeout_us / 1000000ULL); + tv.tv_usec = static_cast<decltype(tv.tv_usec)>(timeout_us % 1000000ULL); + if (setsockopt(fd_, SOL_SOCKET, option, &tv, sizeof(tv)) != 0) { + return false; + } + applied_ns = timeout_ns; + return true; +} + +// Writes exactly `len` bytes, retrying on EINTR. Returns 1 on success, -1 on +// timeout/error; `partial` is set on failure to whether some bytes were already sent. +int SocketClient::send_exact(const void* buf, size_t len, bool& partial) +{ + // Where the platform defines MSG_NOSIGNAL, use it so a vanished peer makes + // send() fail with EPIPE instead of raising a process-killing SIGPIPE. +#ifdef MSG_NOSIGNAL + constexpr int send_flags = MSG_NOSIGNAL; +#else + constexpr int send_flags = 0; +#endif + size_t total_sent = 0; + while (total_sent < len) { + ssize_t n = ::send(fd_, static_cast(buf) + total_sent, len - total_sent, send_flags); + if (n < 0) { + if (errno == EINTR) { + continue; // Interrupted, retry + } + partial = total_sent > 0; + return -1; // Timeout (EAGAIN/EWOULDBLOCK) or hard error + } + total_sent += static_cast(n); + } + return 1; +} + +// Reads exactly `len` bytes, retrying on EINTR. Returns 1 on success, 0 when the +// peer closed the connection, -1 on timeout/error; `partial` is set on failure to +// whether some bytes were already read. +int SocketClient::recv_exact(void* buf, size_t len, bool& partial) +{ + size_t total_read = 0; + while (total_read < len) { + ssize_t n = ::recv(fd_, static_cast(buf) + total_read, len - total_read, 0); + if (n < 0) { + if (errno == EINTR) { + continue; // Interrupted, retry + } + partial = total_read > 0; + return -1; // Timeout (EAGAIN/EWOULDBLOCK) or hard error + } + if (n == 0) { + partial = total_read > 0; + return 0; // Server disconnected + } + total_read += static_cast(n); + } + return 1; +} + +bool SocketClient::send(const void* data, size_t len, uint64_t timeout_ns) +{ + if (fd_ < 0) { + errno = EINVAL; + return false; + } + if (len > MAX_FRAME_SIZE) { + errno = EMSGSIZE; + return false; + } + + if (!apply_timeout(SO_SNDTIMEO, applied_send_timeout_ns_, timeout_ns)) { + // Sending under a stale (possibly infinite) timeout could block far + // longer than the caller asked for; errno is left as set by setsockopt. + return false; + } + + // Send length prefix (4 bytes, little-endian), then message data, + // looping on partial writes. 
+ auto msg_len = static_cast(len); + bool partial = false; + if (send_exact(&msg_len, sizeof(msg_len), partial) != 1) { + if (partial) { + // Part of the prefix is on the wire — the stream is desynced and + // unusable. Close rather than silently corrupting later frames. + close_internal(); + } + return false; + } + if (send_exact(data, len, partial) != 1) { + // The whole prefix is already on the wire, so a payload failure desyncs + // the stream even when no payload byte was sent (`partial` only + // describes the failing call). Always close. + close_internal(); + return false; + } + return true; +} + +// Receives one length-prefixed frame into the internal buffer. Returns an empty +// span with null data() when not connected, on timeout, or on error; a zero-length +// message is reported with non-null data(). +std::span SocketClient::receive(uint64_t timeout_ns) +{ + if (fd_ < 0) { + return {}; + } + + if (!apply_timeout(SO_RCVTIMEO, applied_recv_timeout_ns_, timeout_ns)) { + // Never block under a stale (possibly infinite) timeout. + return {}; + } + + // Read length prefix (4 bytes) + uint32_t msg_len = 0; + bool partial = false; + if (recv_exact(&msg_len, sizeof(msg_len), partial) != 1) { + if (partial) { + // Mid-frame failure — stream desynced. + close_internal(); + } + return {}; + } + + // A corrupt/malicious prefix must not drive the allocation below. + if (msg_len > MAX_FRAME_SIZE) { + close_internal(); + return {}; + } + + // Ensure buffer is large enough. Keep at least one byte so data() is + // non-null for zero-length messages (null data() signals failure). + if (recv_buffer_.size() < msg_len || recv_buffer_.empty()) { + recv_buffer_.resize(std::max(msg_len, 1)); + } + + // Read message data into internal buffer + if (recv_exact(recv_buffer_.data(), msg_len, partial) != 1) { + // Prefix consumed but payload incomplete — stream desynced. 
+ close_internal(); + return {}; + } + + // Return span into internal buffer + return std::span(recv_buffer_.data(), msg_len); +} + +void SocketClient::release(size_t /*message_size*/) +{ + // No-op for sockets - data is already consumed from kernel buffer during + // recv() +} + +void SocketClient::close() +{ + close_internal(); +} + +void SocketClient::close_internal() +{ + if (fd_ >= 0) { + ::close(fd_); + fd_ = -1; + } +} + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/socket_client.hpp b/ipc-runtime/cpp/ipc_runtime/socket_client.hpp new file mode 100644 index 000000000000..8d495adf3092 --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/socket_client.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "ipc_runtime/ipc_client.hpp" +#include +#include +#include +#include +#include +#include + +namespace ipc { + +/** + * @brief IPC client implementation using Unix domain sockets + * + * Direct implementation with no wrapper layer - manages socket connection + * directly. Send/receive timeouts are honored via SO_SNDTIMEO / SO_RCVTIMEO + * (timeout_ns == 0 means infinite). + */ +class SocketClient : public IpcClient { + public: + explicit SocketClient(std::string socket_path); + ~SocketClient() override; + + // Non-copyable, non-movable (owns file descriptor) + SocketClient(const SocketClient&) = delete; + SocketClient& operator=(const SocketClient&) = delete; + SocketClient(SocketClient&&) = delete; + SocketClient& operator=(SocketClient&&) = delete; + + bool connect() override; + bool send(const void* data, size_t len, uint64_t timeout_ns) override; + std::span receive(uint64_t timeout_ns) override; + void release(size_t message_size) override; + void close() override; + + private: + void close_internal(); + bool apply_timeout(int option, uint64_t& applied_ns, uint64_t timeout_ns); + // Returns 1 on success, 0 on orderly EOF before any byte, -1 on + // error/timeout. `partial` is set when the stream is desynced (some but + // not all bytes transferred). 
+ int recv_exact(void* buf, size_t len, bool& partial); + int send_exact(const void* buf, size_t len, bool& partial); + + std::string socket_path_; + int fd_ = -1; + std::vector recv_buffer_; // Internal buffer for socket recv + // Last timeouts applied via setsockopt; avoids a syscall per call. + uint64_t applied_recv_timeout_ns_ = 0; + uint64_t applied_send_timeout_ns_ = 0; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/socket_server.cpp b/ipc-runtime/cpp/ipc_runtime/socket_server.cpp new file mode 100644 index 000000000000..d26cfd6c3d1e --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/socket_server.cpp @@ -0,0 +1,606 @@ +#include "ipc_runtime/socket_server.hpp" +#include "ipc_runtime/constants.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Platform-specific event notification includes +#ifdef __APPLE__ +#include // kqueue on macOS/BSD +#elif defined(__linux__) +#include // epoll on Linux +#else +#error "ipc-runtime supports Linux and macOS only" +#endif + +namespace ipc { + +SocketServer::SocketServer(std::string socket_path, int initial_max_clients) + : socket_path_(std::move(socket_path)) + , initial_max_clients_(initial_max_clients) +{ + const size_t reserve_size = initial_max_clients > 0 ? 
static_cast(initial_max_clients) : 10; + client_fds_.reserve(reserve_size); + recv_buffers_.reserve(reserve_size); +} + +SocketServer::~SocketServer() +{ + close_internal(); +} + +void SocketServer::close() +{ + close_internal(); +} + +// Closes every client fd, the kqueue/epoll fd and the listen socket. Safe to call +// repeatedly (close() followed by the destructor). +void SocketServer::close_internal() +{ + // Close all client connections + for (int fd : client_fds_) { + if (fd >= 0) { + ::close(fd); + } + } + client_fds_.clear(); + fd_to_client_id_.clear(); + num_clients_ = 0; + + if (fd_ >= 0) { + ::close(fd_); + fd_ = -1; + } + + if (listen_fd_ >= 0) { + ::close(listen_fd_); + listen_fd_ = -1; + // Clean up the socket file only while this instance still holds the + // listen socket. An unconditional unlink (never-listened server, or the + // destructor after close()) could delete a file this instance does not own. + ::unlink(socket_path_.c_str()); + } +} + +int SocketServer::find_free_slot() +{ + // Look for existing free slot + for (size_t i = 0; i < client_fds_.size(); i++) { + if (client_fds_[i] == -1) { + return static_cast(i); + } + } + + // No free slot found, allocate new one at end + return static_cast(client_fds_.size()); +} + +bool SocketServer::send(int client_id, const void* data, size_t len) +{ + if (client_id < 0 || static_cast(client_id) >= client_fds_.size() || + client_fds_[static_cast(client_id)] < 0) { + errno = EINVAL; + return false; + } + + if (len > MAX_FRAME_SIZE) { + errno = EMSGSIZE; + return false; + } + + int fd = client_fds_[static_cast(client_id)]; + + // Where the platform defines MSG_NOSIGNAL, use it so a client that has gone + // away makes send() fail with EPIPE instead of raising a SIGPIPE that + // terminates the whole server. +#ifdef MSG_NOSIGNAL + constexpr int send_flags = MSG_NOSIGNAL; +#else + constexpr int send_flags = 0; +#endif + + // Send length prefix (4 bytes) then message data, looping on partial + // writes — a short write after the prefix would permanently desync the + // stream for this connection. + auto msg_len = static_cast(len); + const uint8_t* parts[2] = { reinterpret_cast(&msg_len), static_cast(data) }; + size_t part_lens[2] = { sizeof(msg_len), len }; + for (int part = 0; part < 2; part++) { + size_t total_sent = 0; + while (total_sent < part_lens[part]) { + ssize_t n = ::send(fd, parts[part] + total_sent, part_lens[part] - total_sent, send_flags); + if (n < 0) { + if (errno == EINTR) { + continue; // Interrupted, retry + } + if (part > 0 || total_sent > 0) { + // Frame partially on the wire — stream desynced. 
+ disconnect_client(client_id); + } + return false; + } + total_sent += static_cast(n); + } + } + return true; +} + +void SocketServer::release(int client_id, size_t message_size) +{ + // No-op for sockets - message already consumed from kernel buffer during receive() + (void)client_id; + (void)message_size; +} + +std::span SocketServer::receive(int client_id) +{ + if (client_id < 0 || static_cast(client_id) >= client_fds_.size() || + client_fds_[static_cast(client_id)] < 0) { + return {}; + } + + int fd = client_fds_[static_cast(client_id)]; + const auto client_idx = static_cast(client_id); + + // Ensure buffers are sized for this client + if (client_idx >= recv_buffers_.size()) { + recv_buffers_.resize(client_idx + 1); + } + + // Read length prefix (4 bytes) - must loop until all bytes received (MSG_WAITALL unreliable on macOS) + uint32_t msg_len = 0; + size_t total_read = 0; + while (total_read < sizeof(msg_len)) { + ssize_t n = ::recv(fd, reinterpret_cast(&msg_len) + total_read, sizeof(msg_len) - total_read, 0); + if (n < 0) { + if (errno == EINTR) { + continue; // Interrupted, retry + } + return {}; + } + if (n == 0) { + // Client disconnected + disconnect_client(client_id); + return {}; + } + total_read += static_cast(n); + } + + // A corrupt/malicious prefix must not drive the allocation below. 
+ if (msg_len > MAX_FRAME_SIZE) { + disconnect_client(client_id); + return {}; + } + + // Resize buffer if needed to fit length prefix + message + size_t total_size = sizeof(uint32_t) + msg_len; + if (recv_buffers_[client_idx].size() < total_size) { + recv_buffers_[client_idx].resize(total_size); + } + + // Store length prefix in buffer + std::memcpy(recv_buffers_[client_idx].data(), &msg_len, sizeof(uint32_t)); + + // Read message data - must loop until all bytes received (MSG_WAITALL unreliable on macOS) + total_read = 0; + while (total_read < msg_len) { + ssize_t n = + ::recv(fd, recv_buffers_[client_idx].data() + sizeof(uint32_t) + total_read, msg_len - total_read, 0); + if (n < 0) { + if (errno == EINTR) { + continue; // Interrupted, retry + } + disconnect_client(client_id); + return {}; + } + if (n == 0) { + // Client disconnected mid-message + disconnect_client(client_id); + return {}; + } + total_read += static_cast(n); + } + + return std::span(recv_buffers_[client_idx].data() + sizeof(uint32_t), msg_len); +} + +#ifdef __APPLE__ +// ============================================================================ +// macOS Implementation (kqueue, blocking sockets, simple accept) +// ============================================================================ + +bool SocketServer::listen() +{ + if (listen_fd_ >= 0) { + return true; // Already listening + } + + // Remove any existing socket file + ::unlink(socket_path_.c_str()); + + // Create socket + listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd_ < 0) { + return false; + } + + // Set non-blocking mode (required for accept-until-EAGAIN pattern) + int flags = fcntl(listen_fd_, F_GETFL, 0); + if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + // Bind to path + struct sockaddr_un addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + std::strncpy(addr.sun_path, socket_path_.c_str(), 
sizeof(addr.sun_path) - 1); + + if (bind(listen_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + // Restrict socket to owner only, matching the 0600 mode used for SHM transport + ::chmod(socket_path_.c_str(), 0600); + + // Listen with backlog + int backlog = initial_max_clients_ > 0 ? initial_max_clients_ : SOCKET_BACKLOG; + if (::listen(listen_fd_, backlog) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + // Create kqueue instance + fd_ = kqueue(); + if (fd_ < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + // Add listen socket to kqueue + struct kevent ev; + EV_SET(&ev, listen_fd_, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr); + if (kevent(fd_, &ev, 1, nullptr, 0, nullptr) < 0) { + ::close(fd_); + fd_ = -1; + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + return true; +} + +int SocketServer::accept() +{ + if (listen_fd_ < 0) { + errno = EINVAL; + return -1; + } + + // Accept all pending connections (loop until EAGAIN) + // Non-blocking socket ensures this returns immediately + int last_client_id = -1; + + while (true) { + int client_fd = ::accept(listen_fd_, nullptr, nullptr); + + if (client_fd < 0) { + // Check if this is expected (no more connections) or a real error + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // No more pending connections - expected, break + break; + } + // Real error - but if we already accepted some, return success + if (last_client_id >= 0) { + break; + } + // No connections accepted and got real error + return -1; + } + + // Set client socket to BLOCKING mode (inherited non-blocking from listen socket) + // This avoids busy-waiting in recv() - we only recv after kqueue signals data ready + int flags = fcntl(client_fd, F_GETFL, 0); + if (flags >= 0) { + fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK); 
+ } + + // Find free slot (or allocate new one) + int client_id = find_free_slot(); + + // Store client fd + const auto client_id_unsigned = static_cast(client_id); + if (client_id_unsigned >= client_fds_.size()) { + client_fds_.resize(client_id_unsigned + 1, -1); + } + client_fds_[static_cast(client_id)] = client_fd; + fd_to_client_id_[client_fd] = client_id; + num_clients_++; + + // Add client to kqueue + struct kevent kev; + EV_SET(&kev, client_fd, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, nullptr); + if (kevent(fd_, &kev, 1, nullptr, 0, nullptr) < 0) { + disconnect_client(client_id); + // Continue trying to accept other pending connections + continue; + } + + last_client_id = client_id; + } + + return last_client_id; +} + +int SocketServer::wait_for_data(uint64_t timeout_ns) +{ + if (fd_ < 0) { + errno = EINVAL; + return -1; + } + + struct kevent ev; + struct timespec timeout; + struct timespec* timeout_ptr = nullptr; + + if (timeout_ns > 0) { + timeout.tv_sec = static_cast(timeout_ns / 1000000000ULL); + timeout.tv_nsec = static_cast(timeout_ns % 1000000000ULL); + timeout_ptr = &timeout; + } else if (timeout_ns == 0) { + timeout.tv_sec = 0; + timeout.tv_nsec = 0; + timeout_ptr = &timeout; + } + + int n = kevent(fd_, nullptr, 0, &ev, 1, timeout_ptr); + if (n <= 0) { + return -1; + } + + int ready_fd = static_cast(ev.ident); + + // Check if it's listen socket (new connection) or client data + if (ready_fd == listen_fd_) { + errno = EAGAIN; // Signal caller to call accept + return -1; + } + + // Find which client + auto it = fd_to_client_id_.find(ready_fd); + if (it == fd_to_client_id_.end()) { + errno = ENOENT; + return -1; + } + + return it->second; +} + +void SocketServer::disconnect_client(int client_id) +{ + if (client_id < 0 || static_cast(client_id) >= client_fds_.size()) { + return; + } + + int fd = client_fds_[static_cast(client_id)]; + if (fd >= 0) { + // For kqueue, we don't need explicit deletion - closing the fd removes it automatically + // But we can 
explicitly remove it for clarity + struct kevent ev; + EV_SET(&ev, fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + kevent(fd_, &ev, 1, nullptr, 0, nullptr); + + ::close(fd); + fd_to_client_id_.erase(fd); + client_fds_[static_cast(client_id)] = -1; + num_clients_--; + } +} + +#else + +// ============================================================================ +// Linux Implementation (epoll, non-blocking sockets, accept-until-EAGAIN) +// ============================================================================ + +bool SocketServer::listen() +{ + if (listen_fd_ >= 0) { + return true; // Already listening + } + + // Remove any existing socket file + ::unlink(socket_path_.c_str()); + + // Create socket + listen_fd_ = socket(AF_UNIX, SOCK_STREAM, 0); + if (listen_fd_ < 0) { + return false; + } + + // Set non-blocking mode (required for accept-until-EAGAIN pattern) + int flags = fcntl(listen_fd_, F_GETFL, 0); + if (flags < 0 || fcntl(listen_fd_, F_SETFL, flags | O_NONBLOCK) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + // Bind to path + struct sockaddr_un addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + std::strncpy(addr.sun_path, socket_path_.c_str(), sizeof(addr.sun_path) - 1); + + if (bind(listen_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + return false; + } + + // Restrict socket to owner only, matching the 0600 mode used for SHM transport + ::chmod(socket_path_.c_str(), 0600); + + // Listen with backlog + int backlog = initial_max_clients_ > 0 ? 
initial_max_clients_ : SOCKET_BACKLOG; + if (::listen(listen_fd_, backlog) < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + // Create epoll instance + fd_ = epoll_create1(0); + if (fd_ < 0) { + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + // Add listen socket to epoll + struct epoll_event ev; + ev.events = EPOLLIN; + ev.data.fd = listen_fd_; + if (epoll_ctl(fd_, EPOLL_CTL_ADD, listen_fd_, &ev) < 0) { + ::close(fd_); + fd_ = -1; + ::close(listen_fd_); + listen_fd_ = -1; + ::unlink(socket_path_.c_str()); + return false; + } + + return true; +} + +int SocketServer::accept() +{ + if (listen_fd_ < 0) { + errno = EINVAL; + return -1; + } + + // Accept all pending connections (loop until EAGAIN) + // Non-blocking socket ensures this returns immediately + int last_client_id = -1; + + while (true) { + int client_fd = ::accept(listen_fd_, nullptr, nullptr); + + if (client_fd < 0) { + // Check if this is expected (no more connections) or a real error + if (errno == EAGAIN || errno == EWOULDBLOCK) { + // No more pending connections - expected, break + break; + } + // Real error - but if we already accepted some, return success + if (last_client_id >= 0) { + break; + } + // No connections accepted and got real error + return -1; + } + + // Set client socket to BLOCKING mode (inherited non-blocking from listen socket) + // This avoids busy-waiting in recv() - we only recv after epoll signals data ready + int flags = fcntl(client_fd, F_GETFL, 0); + if (flags >= 0) { + fcntl(client_fd, F_SETFL, flags & ~O_NONBLOCK); + } + + // Find free slot (or allocate new one) + int client_id = find_free_slot(); + + // Store client fd + const auto client_id_unsigned = static_cast(client_id); + if (client_id_unsigned >= client_fds_.size()) { + client_fds_.resize(client_id_unsigned + 1, -1); + } + client_fds_[static_cast(client_id)] = client_fd; + fd_to_client_id_[client_fd] = client_id; 
+ num_clients_++; + + // Add client to epoll + struct epoll_event client_ev; + client_ev.events = EPOLLIN; + client_ev.data.fd = client_fd; + if (epoll_ctl(fd_, EPOLL_CTL_ADD, client_fd, &client_ev) < 0) { + disconnect_client(client_id); + // Continue trying to accept other pending connections + continue; + } + + last_client_id = client_id; + } + + return last_client_id; +} + +int SocketServer::wait_for_data(uint64_t timeout_ns) +{ + if (fd_ < 0) { + errno = EINVAL; + return -1; + } + + struct epoll_event ev; + // 0 = non-blocking poll (matches the interface doc and the kqueue + // branch). Sub-millisecond timeouts round up to 1ms; large timeouts + // clamp to INT_MAX ms. + int timeout_ms = 0; + if (timeout_ns > 0) { + uint64_t ms = std::max(1, timeout_ns / 1000000ULL); + timeout_ms = static_cast(std::min(ms, INT_MAX)); + } + int n = epoll_wait(fd_, &ev, 1, timeout_ms); + if (n <= 0) { + return -1; + } + + // Check if it's listen socket (new connection) or client data + if (ev.data.fd == listen_fd_) { + errno = EAGAIN; // Signal caller to call accept + return -1; + } + + // Find which client + auto it = fd_to_client_id_.find(ev.data.fd); + if (it == fd_to_client_id_.end()) { + errno = ENOENT; + return -1; + } + + return it->second; +} + +void SocketServer::disconnect_client(int client_id) +{ + if (client_id < 0 || static_cast(client_id) >= client_fds_.size()) { + return; + } + + int fd = client_fds_[static_cast(client_id)]; + if (fd >= 0) { + epoll_ctl(fd_, EPOLL_CTL_DEL, fd, nullptr); + ::close(fd); + fd_to_client_id_.erase(fd); + client_fds_[static_cast(client_id)] = -1; + num_clients_--; + } +} + +#endif + +} // namespace ipc diff --git a/ipc-runtime/cpp/ipc_runtime/socket_server.hpp b/ipc-runtime/cpp/ipc_runtime/socket_server.hpp new file mode 100644 index 000000000000..5d116a60286f --- /dev/null +++ b/ipc-runtime/cpp/ipc_runtime/socket_server.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "ipc_runtime/ipc_server.hpp" +#include +#include +#include +#include 
+#include +#include + +namespace ipc { + +/** + * @brief IPC server implementation using Unix domain sockets + * + * Platform-specific implementation: + * - Linux: uses epoll for efficient multi-client handling + * - macOS: uses kqueue for efficient multi-client handling + * Dynamic client capacity with no artificial limits. + */ +class SocketServer : public IpcServer { + public: + SocketServer(std::string socket_path, int initial_max_clients); + ~SocketServer() override; + + // Non-copyable, non-movable (owns file descriptors) + SocketServer(const SocketServer&) = delete; + SocketServer& operator=(const SocketServer&) = delete; + SocketServer(SocketServer&&) = delete; + SocketServer& operator=(SocketServer&&) = delete; + + bool listen() override; + int accept() override; + int wait_for_data(uint64_t timeout_ns) override; + std::span receive(int client_id) override; + void release(int client_id, size_t message_size) override; + bool send(int client_id, const void* data, size_t len) override; + void close() override; + + CleanupPaths cleanup_paths() const override + { + return CleanupPaths{ .unlink_paths = { socket_path_ }, .shm_unlink_names = {} }; + } + + private: + void close_internal(); + void disconnect_client(int client_id); + int find_free_slot(); + + std::string socket_path_; + int initial_max_clients_; + int listen_fd_ = -1; + int fd_ = -1; // kqueue or epoll fd + std::vector client_fds_; // client_id -> fd + std::unordered_map fd_to_client_id_; // fd -> client_id (for fast lookup) + std::vector> recv_buffers_; // client_id -> recv buffer + int num_clients_ = 0; +}; + +} // namespace ipc diff --git a/ipc-runtime/cpp/napi/.gitignore b/ipc-runtime/cpp/napi/.gitignore new file mode 100644 index 000000000000..3805b61d94b4 --- /dev/null +++ b/ipc-runtime/cpp/napi/.gitignore @@ -0,0 +1,4 @@ +.pnp.cjs +.pnp.loader.mjs +.yarn/ +node_modules/ diff --git a/ipc-runtime/cpp/napi/.yarnrc.yml b/ipc-runtime/cpp/napi/.yarnrc.yml new file mode 100644 index 
000000000000..1d8b263ddcf5 --- /dev/null +++ b/ipc-runtime/cpp/napi/.yarnrc.yml @@ -0,0 +1,5 @@ +nodeLinker: node-modules + +npmMinimalAgeGate: 7d + +npmPreapprovedPackages: [] diff --git a/ipc-runtime/cpp/napi/CMakeLists.txt b/ipc-runtime/cpp/napi/CMakeLists.txt new file mode 100644 index 000000000000..504a0e53becf --- /dev/null +++ b/ipc-runtime/cpp/napi/CMakeLists.txt @@ -0,0 +1,57 @@ +# Node.js native addon for ipc-runtime's SHM-IPC clients. +# +# Builds a SHARED library `ipc_runtime_napi.node` exporting MsgpackClient / +# MsgpackClientAsync. The header search paths for `node-addon-api` / `napi.h` +# come from the local `node-addon-api` + `node-api-headers` npm packages. +# +# Only built when targeting native (not WASM, not fuzzing). +if(WASM OR FUZZING) + return() +endif() + +# see https://nodejs.org/dist/latest/docs/api/n-api.html#node-api-version-matrix +add_definitions(-DNAPI_VERSION=9) + +execute_process( + COMMAND yarn --immutable + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} +) + +execute_process( + COMMAND node -p "require('node-addon-api').include" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE NODE_ADDON_API_DIR +) + +execute_process( + COMMAND node -p "require('node-api-headers').include_dir" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE NODE_API_HEADERS_DIR +) + +string(REGEX REPLACE "[\r\n\"]" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR}) +string(REGEX REPLACE "[\r\n\"]" "" NODE_API_HEADERS_DIR ${NODE_API_HEADERS_DIR}) + +add_library(ipc_runtime_napi SHARED + init.cpp + msgpack_client_wrapper.cpp + msgpack_client_async.cpp +) +# Pin the output location so the addon lands at /lib/ipc_runtime_napi.node +# regardless of build mode (standalone, cross-preset, or add_subdirectory'd). 
+set_target_properties(ipc_runtime_napi PROPERTIES + PREFIX "" + SUFFIX ".node" + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +# Mark node-addon-api / node-api-headers as SYSTEM includes so vendored +# headers don't trip parent-project warning flags. +target_include_directories(ipc_runtime_napi SYSTEM PRIVATE + ${NODE_API_HEADERS_DIR} + ${NODE_ADDON_API_DIR} +) +target_link_libraries(ipc_runtime_napi PRIVATE ipc_runtime) + +# On macOS, Node N-API symbols are provided by the host runtime, not at link time. +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_link_options(ipc_runtime_napi PRIVATE "-undefined" "dynamic_lookup") +endif() diff --git a/ipc-runtime/cpp/napi/init.cpp b/ipc-runtime/cpp/napi/init.cpp new file mode 100644 index 000000000000..a7b6d13f261f --- /dev/null +++ b/ipc-runtime/cpp/napi/init.cpp @@ -0,0 +1,16 @@ +#include "msgpack_client_async.hpp" +#include "msgpack_client_wrapper.hpp" +#include "napi.h" + +// Node addon entry point for ipc-runtime's SHM-IPC client bindings. +// Exports only the transport-agnostic msgpack clients; service-specific +// bindings live in their own addons and consume @aztec/ipc-runtime separately. 
+static Napi::Object Init(Napi::Env env, Napi::Object exports) +{ + exports.Set(Napi::String::New(env, "MsgpackClient"), ipc::napi::MsgpackClientWrapper::get_class(env)); + exports.Set(Napi::String::New(env, "MsgpackClientAsync"), ipc::napi::MsgpackClientAsync::get_class(env)); + return exports; +} + +// NOLINTNEXTLINE +NODE_API_MODULE(ipc_runtime_napi, Init) diff --git a/ipc-runtime/cpp/napi/msgpack_client_async.cpp b/ipc-runtime/cpp/napi/msgpack_client_async.cpp new file mode 100644 index 000000000000..3c2a031bd6b2 --- /dev/null +++ b/ipc-runtime/cpp/napi/msgpack_client_async.cpp @@ -0,0 +1,196 @@ +#include "msgpack_client_async.hpp" + +#include "ipc_runtime/ipc_client.hpp" +#include "napi.h" + +#include +#include +#include +#include + +namespace ipc::napi { + +MsgpackClientAsync::MsgpackClientAsync(const Napi::CallbackInfo& info) + : ObjectWrap(info) +{ + Napi::Env env = info.Env(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "First argument must be a string (shared memory name)"); + } + std::string shm_name = info[0].As(); + + std::size_t client_id = 0; + if (info.Length() >= 2 && info[1].IsNumber()) { + client_id = static_cast(info[1].As().Uint32Value()); + } + + // MPSC-SHM client — matches ipc::make_server's default transport. + client_ = ipc::IpcClient::create_mpsc_shm(shm_name, client_id); + + if (!client_->connect()) { + throw Napi::Error::New(env, "Failed to connect to shared memory server"); + } +} + +MsgpackClientAsync::~MsgpackClientAsync() +{ + close_internal(); +} + +Napi::Value MsgpackClientAsync::setResponseCallback(const Napi::CallbackInfo& info) +{ + Napi::Env env = info.Env(); + + if (info.Length() < 1 || !info[0].IsFunction()) { + throw Napi::TypeError::New(env, "First argument must be a function"); + } + + // Store callback for lazy TSFN creation in acquire(). + js_callback_ = Napi::Persistent(info[0].As()); + + // Start the response poller. Joined by close(). 
+ poll_thread_ = std::thread(&MsgpackClientAsync::poll_responses, this); + + return env.Undefined(); +} + +void MsgpackClientAsync::poll_responses() +{ + constexpr uint64_t TIMEOUT_NS = 1'000'000'000; // 1s + + while (!shutdown_.load(std::memory_order_acquire)) { + std::span response = client_->receive(TIMEOUT_NS); + // data() == nullptr means timeout; a non-null empty span is a valid + // zero-length response and must be delivered. + if (response.data() == nullptr) { + continue; // timeout — keep polling (and re-check shutdown) + } + + // Copy out — span is invalidated by release(). + auto* response_data = new std::vector(response.begin(), response.end()); + client_->release(response.size()); + + std::lock_guard lock(tsfn_mutex_); + // Only call into the TSFN while a reference is held (ref_count_ > 0). + // Calling a released TSFN — or the default-constructed one before the + // first acquire(), e.g. for a stale response already in the ring — is + // undefined behaviour. + if (ref_count_ == 0) { + delete response_data; + continue; + } + auto status = tsfn_.NonBlockingCall( + response_data, [](Napi::Env env, Napi::Function js_callback, std::vector* data) { + auto js_buffer = Napi::Buffer::Copy(env, data->data(), data->size()); + js_callback.Call({ js_buffer }); + delete data; + }); + if (status != napi_ok) { + // Failed to queue — likely process exiting. Drop the response. 
+ delete response_data; + } + } +} + +Napi::Value MsgpackClientAsync::call(const Napi::CallbackInfo& info) +{ + Napi::Env env = info.Env(); + + if (info.Length() < 1 || !info[0].IsBuffer()) { + throw Napi::TypeError::New(env, "First argument must be a Buffer"); + } + if (shutdown_.load(std::memory_order_acquire)) { + throw Napi::Error::New(env, "Client is closed"); + } + + auto input_buffer = info[0].As>(); + const uint8_t* input_data = input_buffer.Data(); + size_t input_len = input_buffer.Length(); + + // Single non-blocking attempt: claim() treats timeout 1 as an immediate + // check (0 would be normalized to infinite). TS owns the promise queue. + if (!client_->send(input_data, input_len, 1)) { + throw Napi::Error::New(env, "Failed to send request, ring buffer full. Make it bigger?"); + } + + return env.Undefined(); +} + +Napi::Value MsgpackClientAsync::acquire(const Napi::CallbackInfo& info) +{ + Napi::Env env = info.Env(); + std::lock_guard lock(tsfn_mutex_); + + if (ref_count_ == 0) { + // Lazily create TSFN on 0 → 1. + tsfn_ = Napi::ThreadSafeFunction::New(env, + js_callback_.Value(), + "IpcRuntimeShmResponseCallback", + /*max_queue_size*/ 0, + /*initial_thread_count*/ 1); + } + ref_count_++; + return env.Undefined(); +} + +Napi::Value MsgpackClientAsync::release(const Napi::CallbackInfo& info) +{ + std::lock_guard lock(tsfn_mutex_); + if (ref_count_ == 0) { + return info.Env().Undefined(); // Unbalanced release — ignore + } + ref_count_--; + if (ref_count_ == 0) { + tsfn_.Release(); // 1 → 0 + } + return info.Env().Undefined(); +} + +Napi::Value MsgpackClientAsync::close(const Napi::CallbackInfo& info) +{ + close_internal(); + return info.Env().Undefined(); +} + +void MsgpackClientAsync::close_internal() +{ + if (shutdown_.exchange(true, std::memory_order_acq_rel)) { + return; // Already closed + } + // Wake the poll thread out of a blocking receive, then join it. 
+ if (client_) { + client_->wakeup(); + } + if (poll_thread_.joinable()) { + poll_thread_.join(); + } + { + std::lock_guard lock(tsfn_mutex_); + // Release the TSFN reference held on behalf of in-flight calls — + // otherwise the libuv loop stays referenced and Node never exits. + if (ref_count_ > 0) { + ref_count_ = 0; + tsfn_.Release(); + } + } + if (client_) { + client_->close(); + } +} + +Napi::Function MsgpackClientAsync::get_class(Napi::Env env) +{ + return DefineClass( + env, + "MsgpackClientAsync", + { + MsgpackClientAsync::InstanceMethod("setResponseCallback", &MsgpackClientAsync::setResponseCallback), + MsgpackClientAsync::InstanceMethod("call", &MsgpackClientAsync::call), + MsgpackClientAsync::InstanceMethod("acquire", &MsgpackClientAsync::acquire), + MsgpackClientAsync::InstanceMethod("release", &MsgpackClientAsync::release), + MsgpackClientAsync::InstanceMethod("close", &MsgpackClientAsync::close), + }); +} + +} // namespace ipc::napi diff --git a/ipc-runtime/cpp/napi/msgpack_client_async.hpp b/ipc-runtime/cpp/napi/msgpack_client_async.hpp new file mode 100644 index 000000000000..efd11ef43a0f --- /dev/null +++ b/ipc-runtime/cpp/napi/msgpack_client_async.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "ipc_runtime/ipc_client.hpp" +#include "napi.h" +#include +#include +#include +#include + +namespace ipc::napi { + +/** + * @brief Asynchronous NAPI wrapper for msgpack calls over shared-memory IPC. + * + * Provides an asynchronous, pipelined `call(Buffer)` to JavaScript. Multiple + * requests can be in flight simultaneously; responses are matched in FIFO + * order by the TypeScript wrapper. + * + * Architecture: + * - TypeScript: owns the promise queue + matches requests to responses + * - C++ main thread: writes requests to the SHM request ring + * - C++ poll thread: polls the response ring; invokes the JS callback via + * ThreadSafeFunction + * + * TS owns the queue (single-threaded JS makes that natural), so we don't need + * a C++-side mutex/queue. 
+ */ +class MsgpackClientAsync : public Napi::ObjectWrap { + public: + MsgpackClientAsync(const Napi::CallbackInfo& info); + ~MsgpackClientAsync() override; + + /// info[0]: JS Function invoked once per response from the poll thread. + Napi::Value setResponseCallback(const Napi::CallbackInfo& info); + + /// info[0]: Buffer containing the msgpack request. Returns undefined. + Napi::Value call(const Napi::CallbackInfo& info); + + /// Acquire / release a ThreadSafeFunction reference that keeps the + /// libuv loop alive while requests are in flight. + Napi::Value acquire(const Napi::CallbackInfo& info); + Napi::Value release(const Napi::CallbackInfo& info); + + /// Stop the poll thread, release any held TSFN reference and close the + /// underlying client. Idempotent. The TS wrapper calls this from + /// destroy(); also run by the destructor as a safety net. + Napi::Value close(const Napi::CallbackInfo& info); + + static Napi::Function get_class(Napi::Env env); + + private: + /// Background loop: blocks on the response ring, invokes the JS callback + /// per message via tsfn_. Joined by close(). 
+ void poll_responses(); + void close_internal(); + + std::unique_ptr client_; + std::thread poll_thread_; + std::atomic shutdown_{ false }; + + std::mutex tsfn_mutex_; + Napi::FunctionReference js_callback_; + Napi::ThreadSafeFunction tsfn_; + int ref_count_ = 0; +}; + +} // namespace ipc::napi diff --git a/ipc-runtime/cpp/napi/msgpack_client_wrapper.cpp b/ipc-runtime/cpp/napi/msgpack_client_wrapper.cpp new file mode 100644 index 000000000000..9c004fd2895f --- /dev/null +++ b/ipc-runtime/cpp/napi/msgpack_client_wrapper.cpp @@ -0,0 +1,116 @@ +#include "msgpack_client_wrapper.hpp" + +#include "ipc_runtime/ipc_client.hpp" +#include "napi.h" + +#include +#include +#include +#include + +namespace ipc::napi { + +MsgpackClientWrapper::MsgpackClientWrapper(const Napi::CallbackInfo& info) + : ObjectWrap(info) +{ + Napi::Env env = info.Env(); + + if (info.Length() < 1 || !info[0].IsString()) { + throw Napi::TypeError::New(env, "First argument must be a string (shared memory name)"); + } + std::string shm_name = info[0].As(); + + // Optional second arg: MPSC client slot id (defaults to 0). + std::size_t client_id = 0; + if (info.Length() >= 2 && info[1].IsNumber()) { + client_id = static_cast(info[1].As().Uint32Value()); + } + + // MPSC-SHM client — matches ipc::make_server which uses MPSC by default, + // so the same shm_name can host multiple clients (each with a distinct slot). 
+ client_ = ipc::IpcClient::create_mpsc_shm(shm_name, client_id); + + if (!client_->connect()) { + throw Napi::Error::New(env, "Failed to connect to shared memory server"); + } + + connected_ = true; +} + +MsgpackClientWrapper::~MsgpackClientWrapper() +{ + if (client_ && connected_) { + client_->close(); + } +} + +Napi::Value MsgpackClientWrapper::call(const Napi::CallbackInfo& info) +{ + Napi::Env env = info.Env(); + + if (!connected_) { + throw Napi::Error::New(env, "Client is not connected"); + } + + if (info.Length() < 1 || !info[0].IsBuffer()) { + throw Napi::TypeError::New(env, "First argument must be a Buffer"); + } + + auto input_buffer = info[0].As>(); + const uint8_t* input_data = input_buffer.Data(); + size_t input_len = input_buffer.Length(); + + // Retry on backpressure, but with an overall deadline: this is a blocking + // call on the Node main thread, and a dead/wedged server must surface as + // an error rather than hanging the process forever. + constexpr uint64_t TIMEOUT_NS = 1'000'000'000; // 1s per attempt + constexpr uint64_t CALL_DEADLINE_NS = 60'000'000'000; // 60s overall + + auto now_ns = [] { + return static_cast( + std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()) + .count()); + }; + const uint64_t start_ns = now_ns(); + + while (!client_->send(input_data, input_len, TIMEOUT_NS)) { + // request ring full, consumer behind — retry until the deadline + if (now_ns() - start_ns > CALL_DEADLINE_NS) { + throw Napi::Error::New(env, "IPC call timed out sending request (server unresponsive?)"); + } + } + + // data() == nullptr means timeout; a non-null empty span is a valid + // zero-length response. 
+ std::span response; + while ((response = client_->receive(TIMEOUT_NS)).data() == nullptr) { + if (now_ns() - start_ns > CALL_DEADLINE_NS) { + throw Napi::Error::New(env, "IPC call timed out waiting for response (server unresponsive?)"); + } + } + + auto js_buffer = Napi::Buffer::Copy(env, response.data(), response.size()); + client_->release(response.size()); + return js_buffer; +} + +Napi::Value MsgpackClientWrapper::close(const Napi::CallbackInfo& info) +{ + if (client_ && connected_) { + client_->close(); + connected_ = false; + } + return info.Env().Undefined(); +} + +Napi::Function MsgpackClientWrapper::get_class(Napi::Env env) +{ + return DefineClass(env, + "MsgpackClient", + { + MsgpackClientWrapper::InstanceMethod("call", &MsgpackClientWrapper::call), + MsgpackClientWrapper::InstanceMethod("close", &MsgpackClientWrapper::close), + }); +} + +} // namespace ipc::napi diff --git a/ipc-runtime/cpp/napi/msgpack_client_wrapper.hpp b/ipc-runtime/cpp/napi/msgpack_client_wrapper.hpp new file mode 100644 index 000000000000..e095e2de697c --- /dev/null +++ b/ipc-runtime/cpp/napi/msgpack_client_wrapper.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "ipc_runtime/ipc_client.hpp" +#include "napi.h" +#include + +namespace ipc::napi { + +/** + * @brief NAPI wrapper for synchronous msgpack calls over shared-memory IPC. + * + * Wraps an ipc::IpcClient (SHM transport) and exposes a blocking + * `call(Buffer) -> Buffer` to JavaScript. One round-trip per `call`. 
+ */ +class MsgpackClientWrapper : public Napi::ObjectWrap { + public: + MsgpackClientWrapper(const Napi::CallbackInfo& info); + ~MsgpackClientWrapper(); + + Napi::Value call(const Napi::CallbackInfo& info); + Napi::Value close(const Napi::CallbackInfo& info); + + static Napi::Function get_class(Napi::Env env); + + private: + std::unique_ptr client_; + bool connected_ = false; +}; + +} // namespace ipc::napi diff --git a/ipc-runtime/cpp/napi/package.json b/ipc-runtime/cpp/napi/package.json new file mode 100644 index 000000000000..7f2fef56f2bc --- /dev/null +++ b/ipc-runtime/cpp/napi/package.json @@ -0,0 +1,16 @@ +{ + "name": "ipc-runtime-napi", + "private": true, + "version": "0.0.0", + "packageManager": "yarn@4.13.0", + "description": "Node addon source for @aztec/ipc-runtime's SHM-IPC client. Built by CMake (see CMakeLists.txt); these npm deps only supply node-addon-api / node-api headers.", + "dependencies": { + "node-addon-api": "^8.0.0", + "node-api-headers": "^1.1.0" + }, + "binary": { + "napi_versions": [ + 9 + ] + } +} diff --git a/ipc-runtime/cpp/napi/yarn.lock b/ipc-runtime/cpp/napi/yarn.lock new file mode 100644 index 000000000000..6d5334edc1d5 --- /dev/null +++ b/ipc-runtime/cpp/napi/yarn.lock @@ -0,0 +1,212 @@ +# This file is generated by running "yarn install" inside your project. +# Manual changes might be lost - proceed with caution! 
+ +__metadata: + version: 8 + cacheKey: 10c0 + +"@isaacs/fs-minipass@npm:^4.0.0": + version: 4.0.1 + resolution: "@isaacs/fs-minipass@npm:4.0.1" + dependencies: + minipass: "npm:^7.0.4" + checksum: 10c0/c25b6dc1598790d5b55c0947a9b7d111cfa92594db5296c3b907e2f533c033666f692a3939eadac17b1c7c40d362d0b0635dc874cbfe3e70db7c2b07cc97a5d2 + languageName: node + linkType: hard + +"abbrev@npm:^4.0.0": + version: 4.0.0 + resolution: "abbrev@npm:4.0.0" + checksum: 10c0/b4cc16935235e80702fc90192e349e32f8ef0ed151ef506aa78c81a7c455ec18375c4125414b99f84b2e055199d66383e787675f0bcd87da7a4dbd59f9eac1d5 + languageName: node + linkType: hard + +"chownr@npm:^3.0.0": + version: 3.0.0 + resolution: "chownr@npm:3.0.0" + checksum: 10c0/43925b87700f7e3893296c8e9c56cc58f926411cce3a6e5898136daaf08f08b9a8eb76d37d3267e707d0dcc17aed2e2ebdf5848c0c3ce95cf910a919935c1b10 + languageName: node + linkType: hard + +"env-paths@npm:^2.2.0": + version: 2.2.1 + resolution: "env-paths@npm:2.2.1" + checksum: 10c0/285325677bf00e30845e330eec32894f5105529db97496ee3f598478e50f008c5352a41a30e5e72ec9de8a542b5a570b85699cd63bd2bc646dbcb9f311d83bc4 + languageName: node + linkType: hard + +"exponential-backoff@npm:^3.1.1": + version: 3.1.3 + resolution: "exponential-backoff@npm:3.1.3" + checksum: 10c0/77e3ae682b7b1f4972f563c6dbcd2b0d54ac679e62d5d32f3e5085feba20483cf28bd505543f520e287a56d4d55a28d7874299941faf637e779a1aa5994d1267 + languageName: node + linkType: hard + +"fdir@npm:^6.5.0": + version: 6.5.0 + resolution: "fdir@npm:6.5.0" + peerDependencies: + picomatch: ^3 || ^4 + peerDependenciesMeta: + picomatch: + optional: true + checksum: 10c0/e345083c4306b3aed6cb8ec551e26c36bab5c511e99ea4576a16750ddc8d3240e63826cc624f5ae17ad4dc82e68a253213b60d556c11bfad064b7607847ed07f + languageName: node + linkType: hard + +"graceful-fs@npm:^4.2.6": + version: 4.2.11 + resolution: "graceful-fs@npm:4.2.11" + checksum: 
10c0/386d011a553e02bc594ac2ca0bd6d9e4c22d7fa8cfbfc448a6d148c59ea881b092db9dbe3547ae4b88e55f1b01f7c4a2ecc53b310c042793e63aa44cf6c257f2 + languageName: node + linkType: hard + +"ipc-runtime-napi@workspace:.": + version: 0.0.0-use.local + resolution: "ipc-runtime-napi@workspace:." + dependencies: + node-addon-api: "npm:^8.0.0" + node-api-headers: "npm:^1.1.0" + languageName: unknown + linkType: soft + +"isexe@npm:^4.0.0": + version: 4.0.0 + resolution: "isexe@npm:4.0.0" + checksum: 10c0/5884815115bceac452877659a9c7726382531592f43dc29e5d48b7c4100661aed54018cb90bd36cb2eaeba521092570769167acbb95c18d39afdccbcca06c5ce + languageName: node + linkType: hard + +"minipass@npm:^7.0.4, minipass@npm:^7.1.2": + version: 7.1.3 + resolution: "minipass@npm:7.1.3" + checksum: 10c0/539da88daca16533211ea5a9ee98dc62ff5742f531f54640dd34429e621955e91cc280a91a776026264b7f9f6735947629f920944e9c1558369e8bf22eb33fbb + languageName: node + linkType: hard + +"minizlib@npm:^3.1.0": + version: 3.1.0 + resolution: "minizlib@npm:3.1.0" + dependencies: + minipass: "npm:^7.1.2" + checksum: 10c0/5aad75ab0090b8266069c9aabe582c021ae53eb33c6c691054a13a45db3b4f91a7fb1bd79151e6b4e9e9a86727b522527c0a06ec7d45206b745d54cd3097bcec + languageName: node + linkType: hard + +"node-addon-api@npm:^8.0.0": + version: 8.7.0 + resolution: "node-addon-api@npm:8.7.0" + dependencies: + node-gyp: "npm:latest" + checksum: 10c0/31a03b00f6b0753ab08360952fdf80a1abb619dcf8125fa1ab07e3a414da050963440c3a86c77a0334c0be7a71acb5e242dc468b79201ee6151c7b943afe946d + languageName: node + linkType: hard + +"node-api-headers@npm:^1.1.0": + version: 1.8.0 + resolution: "node-api-headers@npm:1.8.0" + checksum: 10c0/544f1d55756dcd7536b4c61b78fd2c2553d87a228dfa35729540466966e73b551edf9611785c8ada706e7ef4737b3ed55a3f735678a8f989a7bc9a45f296049c + languageName: node + linkType: hard + +"node-gyp@npm:latest": + version: 12.3.0 + resolution: "node-gyp@npm:12.3.0" + dependencies: + env-paths: "npm:^2.2.0" + exponential-backoff: "npm:^3.1.1" + 
graceful-fs: "npm:^4.2.6" + nopt: "npm:^9.0.0" + proc-log: "npm:^6.0.0" + semver: "npm:^7.3.5" + tar: "npm:^7.5.4" + tinyglobby: "npm:^0.2.12" + undici: "npm:^6.25.0" + which: "npm:^6.0.0" + bin: + node-gyp: bin/node-gyp.js + checksum: 10c0/9d9032b405cbe42f72a105259d9eb679376470c102df4a2dbaa51e07d59bf741dcffb85897087ea9d8318b9cabb824a8978af51508ae142f0239ae1e6a3c2329 + languageName: node + linkType: hard + +"nopt@npm:^9.0.0": + version: 9.0.0 + resolution: "nopt@npm:9.0.0" + dependencies: + abbrev: "npm:^4.0.0" + bin: + nopt: bin/nopt.js + checksum: 10c0/1822eb6f9b020ef6f7a7516d7b64a8036e09666ea55ac40416c36e4b2b343122c3cff0e2f085675f53de1d2db99a2a89a60ccea1d120bcd6a5347bf6ceb4a7fd + languageName: node + linkType: hard + +"picomatch@npm:^4.0.4": + version: 4.0.4 + resolution: "picomatch@npm:4.0.4" + checksum: 10c0/e2c6023372cc7b5764719a5ffb9da0f8e781212fa7ca4bd0562db929df8e117460f00dff3cb7509dacfc06b86de924b247f504d0ce1806a37fac4633081466b0 + languageName: node + linkType: hard + +"proc-log@npm:^6.0.0": + version: 6.1.0 + resolution: "proc-log@npm:6.1.0" + checksum: 10c0/4f178d4062733ead9d71a9b1ab24ebcecdfe2250916a5b1555f04fe2eda972a0ec76fbaa8df1ad9c02707add6749219d118a4fc46dc56bdfe4dde4b47d80bb82 + languageName: node + linkType: hard + +"semver@npm:^7.3.5": + version: 7.8.0 + resolution: "semver@npm:7.8.0" + bin: + semver: bin/semver.js + checksum: 10c0/8f096ca9b80ffd47b308d03f9ce8c873e27e2983f36023c559cdc92c51e8433fc23ebbfe57ec9623fc155636a6961ee989501099841ae4bb1babc8d2b3f048cd + languageName: node + linkType: hard + +"tar@npm:^7.5.4": + version: 7.5.15 + resolution: "tar@npm:7.5.15" + dependencies: + "@isaacs/fs-minipass": "npm:^4.0.0" + chownr: "npm:^3.0.0" + minipass: "npm:^7.1.2" + minizlib: "npm:^3.1.0" + yallist: "npm:^5.0.0" + checksum: 10c0/8f039edb1d12fdd7df6c6f9877d125afe9f3da3f5f9317df326fdd090d48793d6998cede1506a1471f3e3a250db270a89dace28005eb5e99c5a9132d704ac956 + languageName: node + linkType: hard + +"tinyglobby@npm:^0.2.12": + version: 0.2.16 + 
resolution: "tinyglobby@npm:0.2.16" + dependencies: + fdir: "npm:^6.5.0" + picomatch: "npm:^4.0.4" + checksum: 10c0/f2e09fd93dd95c41e522113b686ff6f7c13020962f8698a864a257f3d7737599afc47722b7ab726e12f8a813f779906187911ff8ee6701ede65072671a7e934b + languageName: node + linkType: hard + +"undici@npm:^6.25.0": + version: 6.25.0 + resolution: "undici@npm:6.25.0" + checksum: 10c0/2597cc6689bdb02c210c557b1f85febbfda65becae6e6fc1061508e2f33734d25207f81cd8af56ada9956329eb3a7bd7431e87dcfeceba20ee87059b57dcf985 + languageName: node + linkType: hard + +"which@npm:^6.0.0": + version: 6.0.1 + resolution: "which@npm:6.0.1" + dependencies: + isexe: "npm:^4.0.0" + bin: + node-which: bin/which.js + checksum: 10c0/7e710e54ea36d2d6183bee2f9caa27a3b47b9baf8dee55a199b736fcf85eab3b9df7556fca3d02b50af7f3dfba5ea3a45644189836df06267df457e354da66d5 + languageName: node + linkType: hard + +"yallist@npm:^5.0.0": + version: 5.0.0 + resolution: "yallist@npm:5.0.0" + checksum: 10c0/a499c81ce6d4a1d260d4ea0f6d49ab4da09681e32c3f0472dee16667ed69d01dae63a3b81745a24bd78476ec4fcf856114cb4896ace738e01da34b2c42235416 + languageName: node + linkType: hard diff --git a/ipc-runtime/cpp/scripts/zig-ar.sh b/ipc-runtime/cpp/scripts/zig-ar.sh new file mode 100755 index 000000000000..f5dbcae843cf --- /dev/null +++ b/ipc-runtime/cpp/scripts/zig-ar.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +exec zig ar "$@" diff --git a/ipc-runtime/cpp/scripts/zig-ranlib.sh b/ipc-runtime/cpp/scripts/zig-ranlib.sh new file mode 100755 index 000000000000..ee4f1852f25f --- /dev/null +++ b/ipc-runtime/cpp/scripts/zig-ranlib.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +exec zig ranlib "$@" diff --git a/ipc-runtime/rust/.gitignore b/ipc-runtime/rust/.gitignore new file mode 100644 index 000000000000..2c96eb1b6517 --- /dev/null +++ b/ipc-runtime/rust/.gitignore @@ -0,0 +1,2 @@ +target/ +Cargo.lock diff --git a/ipc-runtime/rust/Cargo.toml b/ipc-runtime/rust/Cargo.toml new file mode 100644 index 000000000000..0e5a437f2fc9 --- /dev/null +++ 
b/ipc-runtime/rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "ipc-runtime" +version = "0.1.0" +edition = "2021" +description = "Safe Rust bindings to the ipc-runtime C ABI (UDS + MPSC-SHM transport)." +license = "Apache-2.0" + +[lib] +name = "ipc_runtime" +path = "src/lib.rs" + +# We compile the C++ runtime sources ourselves via the `cc` crate. There's +# no prebuilt libipc_runtime.a to find — cargo's toolchain picks the C++ +# stdlib (libstdc++ on most linux toolchains; libc++ on macOS), and the +# resulting archive is internally consistent with whatever the final cargo +# binary links. +[build-dependencies] +cc = "1.0" diff --git a/ipc-runtime/rust/build.rs b/ipc-runtime/rust/build.rs new file mode 100644 index 000000000000..c4cd25a5b302 --- /dev/null +++ b/ipc-runtime/rust/build.rs @@ -0,0 +1,47 @@ +// Build script: compile the ipc-runtime C++ sources directly via cc. +// +// We deliberately don't link a prebuilt libipc_runtime.a — each consumer +// compiles the same .cpp sources with its own toolchain so the resulting +// archive is internally consistent with whatever C++ stdlib the consumer's +// final binary links. For Rust on linux that's typically libstdc++ via +// system clang; macOS gets libc++ via Apple clang. Either way, no external +// IPC_RUNTIME_LIB_DIR dependency. 
+ +use std::path::PathBuf; + +fn main() { + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let cpp_dir = crate_dir.join("../cpp"); + let src_dir = cpp_dir.join("ipc_runtime"); + + let sources = [ + "c_abi.cpp", + "ipc_client.cpp", + "ipc_server.cpp", + "serve_helper.cpp", + "signal_handlers.cpp", + "socket_client.cpp", + "socket_server.cpp", + "shm/mpsc_shm.cpp", + "shm/spsc_shm.cpp", + ]; + + let mut build = cc::Build::new(); + build + .cpp(true) + .std("c++20") + .flag_if_supported("-fPIC") + .include(&cpp_dir); + + for src in sources { + let path = src_dir.join(src); + build.file(&path); + println!("cargo:rerun-if-changed={}", path.display()); + } + println!("cargo:rerun-if-changed=build.rs"); + + build.compile("ipc_runtime"); + + // pthread comes via libc; the cc crate already wires the C++ stdlib link. + println!("cargo:rustc-link-lib=pthread"); +} diff --git a/ipc-runtime/rust/src/lib.rs b/ipc-runtime/rust/src/lib.rs new file mode 100644 index 000000000000..b83ba3cc64b8 --- /dev/null +++ b/ipc-runtime/rust/src/lib.rs @@ -0,0 +1,380 @@ +//! Safe Rust bindings to ipc-runtime — UDS + MPSC-SHM transport. +//! +//! Mirrors the C++ API: `IpcServer` and `IpcClient` types pick the right +//! transport based on the input path's suffix (`.sock` → UDS, +//! `.shm` → MPSC-SHM). Use the `from_path` constructors and the rest of +//! the API is the same across transports. +//! +//! See ipc-runtime/cpp/ipc_runtime/c_abi.h for the underlying C ABI. 
+ +#![allow(non_camel_case_types)] + +use std::ffi::{c_void, CString}; +use std::os::raw::{c_char, c_int}; +use std::ptr::NonNull; + +// --------------------------------------------------------------------------- +// extern "C" declarations (mirror of c_abi.h) +// --------------------------------------------------------------------------- + +mod sys { + use super::*; + + #[repr(C)] + #[derive(Clone, Copy, PartialEq, Eq, Debug)] + pub struct ipc_status_t(pub i32); + + pub const IPC_OK: ipc_status_t = ipc_status_t(0); + + #[repr(C)] + pub struct ipc_server_options_t { + pub max_shm_clients: usize, + pub shm_request_ring_size: usize, + pub shm_response_ring_size: usize, + pub socket_backlog: c_int, + } + + pub enum ipc_server {} + pub enum ipc_client {} + + pub type ipc_server_handler_fn = unsafe extern "C" fn( + client_id: c_int, + req: *const u8, + req_len: usize, + resp_out: *mut *mut u8, + resp_len_out: *mut usize, + ctx: *mut c_void, + ); + + extern "C" { + + pub fn ipc_make_server( + path: *const c_char, + opts: *const ipc_server_options_t, + ) -> *mut ipc_server; + pub fn ipc_server_destroy(server: *mut ipc_server); + pub fn ipc_server_listen(server: *mut ipc_server) -> bool; + pub fn ipc_server_close(server: *mut ipc_server); + pub fn ipc_server_request_shutdown(server: *mut ipc_server); + pub fn ipc_server_run( + server: *mut ipc_server, + handler: ipc_server_handler_fn, + ctx: *mut c_void, + ); + pub fn ipc_install_default_signal_handlers(server: *mut ipc_server); + + pub fn ipc_make_client(path: *const c_char, shm_client_id: usize) -> *mut ipc_client; + pub fn ipc_client_destroy(client: *mut ipc_client); + pub fn ipc_client_connect(client: *mut ipc_client) -> bool; + pub fn ipc_client_close(client: *mut ipc_client); + pub fn ipc_client_send( + client: *mut ipc_client, + data: *const u8, + len: usize, + timeout_ns: u64, + ) -> bool; + pub fn ipc_client_receive( + client: *mut ipc_client, + timeout_ns: u64, + out: *mut *const u8, + out_len: *mut usize, + ) -> 
ipc_status_t; + pub fn ipc_client_release(client: *mut ipc_client, msg_size: usize); + } +} + +// --------------------------------------------------------------------------- +// Errors +// --------------------------------------------------------------------------- + +#[derive(Debug)] +pub enum Error { + InvalidPath(String), + Connect(String), + Listen(String), + Send, + Receive, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::InvalidPath(p) => { + write!( + f, + "ipc-runtime: invalid path (must end in .sock or .shm): {}", + p + ) + } + Error::Connect(p) => write!(f, "ipc-runtime: connect failed for {}", p), + Error::Listen(p) => write!(f, "ipc-runtime: listen failed for {}", p), + Error::Send => write!(f, "ipc-runtime: send failed"), + Error::Receive => write!(f, "ipc-runtime: receive failed"), + } + } +} + +impl std::error::Error for Error {} + +pub type Result = std::result::Result; + +/// 0 = infinite, matching the C ABI's unified timeout semantics. `call` +/// is documented as blocking until the reply arrives. +const DEFAULT_CALL_TIMEOUT_NS: u64 = 0; + +// --------------------------------------------------------------------------- +// IpcServer +// --------------------------------------------------------------------------- + +/// Server handle. Drop closes + releases the underlying C++ object. +pub struct IpcServer { + inner: NonNull, +} + +unsafe impl Send for IpcServer {} + +impl IpcServer { + /// Construct a server from a path. ".sock" → UDS, ".shm" → MPSC-SHM. 
+ pub fn from_path(path: &str) -> Result { + let c_path = CString::new(path).map_err(|_| Error::InvalidPath(path.to_string()))?; + let raw = unsafe { sys::ipc_make_server(c_path.as_ptr(), std::ptr::null()) }; + NonNull::new(raw) + .map(|inner| IpcServer { inner }) + .ok_or_else(|| Error::InvalidPath(path.to_string())) + } + + pub fn listen(&mut self) -> Result<()> { + if unsafe { sys::ipc_server_listen(self.inner.as_ptr()) } { + Ok(()) + } else { + Err(Error::Listen("listen() returned false".to_string())) + } + } + + pub fn request_shutdown(&self) { + unsafe { sys::ipc_server_request_shutdown(self.inner.as_ptr()) }; + } + + /// Install default lifecycle signal handlers (SIGTERM/SIGINT graceful + /// shutdown, SIGBUS/SIGSEGV close+exit, parent-death watch). + pub fn install_default_signal_handlers(&self) { + unsafe { sys::ipc_install_default_signal_handlers(self.inner.as_ptr()) }; + } + + /// Run the event loop. The handler is called for each incoming request + /// with the client id and request bytes; it returns the response bytes. + /// Blocks until shutdown is requested. + pub fn run(&mut self, mut handler: F) + where + F: FnMut(i32, &[u8]) -> Vec, + { + // We pass a fat closure as `*mut c_void`; the shim re-casts it back + // and invokes the closure. The handler's response lives in + // `Ctx::scratch`, which stays valid while the runtime copies it into + // its send path (the runtime never retains the pointer past send()) + // and is dropped/overwritten on the next request. 
+ + struct Ctx<'a> { + handler: &'a mut dyn FnMut(i32, &[u8]) -> Vec, + scratch: Vec, + } + + let handler_obj: &mut dyn FnMut(i32, &[u8]) -> Vec = &mut handler; + let mut ctx = Ctx { + handler: handler_obj, + scratch: Vec::new(), + }; + + unsafe extern "C" fn shim( + client_id: c_int, + req: *const u8, + req_len: usize, + resp_out: *mut *mut u8, + resp_len_out: *mut usize, + ctx_raw: *mut c_void, + ) { + let ctx = &mut *(ctx_raw as *mut Ctx<'_>); + let req_slice = if req_len == 0 { + &[] + } else { + std::slice::from_raw_parts(req, req_len) + }; + let response = (ctx.handler)(client_id, req_slice); + ctx.scratch = response; + *resp_out = ctx.scratch.as_mut_ptr(); + *resp_len_out = ctx.scratch.len(); + } + + unsafe { + sys::ipc_server_run( + self.inner.as_ptr(), + shim, + &mut ctx as *mut Ctx<'_> as *mut c_void, + ); + } + } +} + +impl Drop for IpcServer { + fn drop(&mut self) { + unsafe { + sys::ipc_server_close(self.inner.as_ptr()); + sys::ipc_server_destroy(self.inner.as_ptr()); + } + } +} + +// --------------------------------------------------------------------------- +// IpcClient +// --------------------------------------------------------------------------- + +/// Client handle. Drop closes + releases the underlying C++ object. +pub struct IpcClient { + inner: NonNull, +} + +unsafe impl Send for IpcClient {} + +impl IpcClient { + /// Construct a client and connect. ".sock" → UDS, ".shm" → MPSC-SHM + /// (with `shm_client_id` slot). 
+ pub fn from_path(path: &str) -> Result { + Self::from_path_with_id(path, 0) + } + + pub fn from_path_with_id(path: &str, shm_client_id: usize) -> Result { + let c_path = CString::new(path).map_err(|_| Error::InvalidPath(path.to_string()))?; + let raw = unsafe { sys::ipc_make_client(c_path.as_ptr(), shm_client_id) }; + let inner = NonNull::new(raw).ok_or_else(|| Error::InvalidPath(path.to_string()))?; + let client = IpcClient { inner }; + if !unsafe { sys::ipc_client_connect(client.inner.as_ptr()) } { + return Err(Error::Connect(path.to_string())); + } + Ok(client) + } + + /// Synchronous request/response. Sends `req`, blocks until a reply + /// arrives, copies it out, releases the runtime's buffer. A zero-length + /// reply is `Ok(vec![])`, not an error. + pub fn call(&mut self, req: &[u8]) -> Result> { + if !unsafe { + sys::ipc_client_send( + self.inner.as_ptr(), + req.as_ptr(), + req.len(), + DEFAULT_CALL_TIMEOUT_NS, + ) + } { + return Err(Error::Send); + } + let mut out: *const u8 = std::ptr::null(); + let mut out_len: usize = 0; + let status = unsafe { + sys::ipc_client_receive( + self.inner.as_ptr(), + DEFAULT_CALL_TIMEOUT_NS, + &mut out, + &mut out_len, + ) + }; + if status != sys::IPC_OK { + return Err(Error::Receive); + } + // IPC_OK with out_len == 0 is a valid zero-length response; the + // release must still run (it consumes the frame header for SHM). 
+ let response = if out_len == 0 { + Vec::new() + } else { + unsafe { std::slice::from_raw_parts(out, out_len) }.to_vec() + }; + unsafe { sys::ipc_client_release(self.inner.as_ptr(), out_len) }; + Ok(response) + } +} + +impl Drop for IpcClient { + fn drop(&mut self) { + unsafe { + sys::ipc_client_close(self.inner.as_ptr()); + sys::ipc_client_destroy(self.inner.as_ptr()); + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + /// Spawn `server.run(echo)` on a thread and return a raw handle usable to + /// request shutdown from the test thread (run() holds &mut self, so the + /// safe `request_shutdown(&self)` cannot be called concurrently). + fn spawn_echo_server(path: &str) -> (std::thread::JoinHandle<()>, usize) { + let mut server = IpcServer::from_path(path).expect("make server"); + server.listen().expect("listen"); + let raw = server.inner.as_ptr() as usize; + let handle = std::thread::spawn(move || { + server.run(|_client_id, req| { + if req == b"empty" { + Vec::new() + } else { + req.to_vec() + } + }); + }); + (handle, raw) + } + + fn shutdown_server(raw: usize, handle: std::thread::JoinHandle<()>) { + unsafe { sys::ipc_server_request_shutdown(raw as *mut sys::ipc_server) }; + handle.join().expect("server thread"); + } + + #[test] + fn connect_refused_is_err_not_hang() { + let path = format!("/tmp/ipc_rust_test_refused_{}.sock", std::process::id()); + let start = std::time::Instant::now(); + let result = IpcClient::from_path(&path); + assert!(result.is_err(), "connect to absent server must fail"); + // The connect budget is 5s; anything wildly beyond means a hang. 
+ assert!(start.elapsed() < std::time::Duration::from_secs(30)); + } + + #[test] + fn uds_echo_and_zero_length_response() { + let path = format!("/tmp/ipc_rust_test_uds_{}.sock", std::process::id()); + let _ = std::fs::remove_file(&path); + let (handle, raw) = spawn_echo_server(&path); + + let mut client = IpcClient::from_path(&path).expect("connect"); + let resp = client.call(b"hello").expect("echo call"); + assert_eq!(resp, b"hello"); + + // A handler returning an empty Vec must surface as Ok(empty), not Err. + let resp = client.call(b"empty").expect("zero-length call"); + assert!(resp.is_empty()); + + drop(client); + shutdown_server(raw, handle); + let _ = std::fs::remove_file(&path); + } + + #[test] + fn shm_echo_and_zero_length_response() { + let path = format!("/ipc_rust_test_shm_{}.shm", std::process::id()); + let (handle, raw) = spawn_echo_server(&path); + + let mut client = IpcClient::from_path(&path).expect("connect"); + let resp = client.call(b"hello shm").expect("echo call"); + assert_eq!(resp, b"hello shm"); + + let resp = client.call(b"empty").expect("zero-length call"); + assert!(resp.is_empty()); + + drop(client); + shutdown_server(raw, handle); + } +} diff --git a/ipc-runtime/scripts/run_rust_tests.sh b/ipc-runtime/scripts/run_rust_tests.sh new file mode 100755 index 000000000000..a402b971d1ee --- /dev/null +++ b/ipc-runtime/scripts/run_rust_tests.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +# Run the Rust binding tests (compiles the C++ sources via build.rs). +source $(git rev-parse --show-toplevel)/ci3/source +cd $(dirname $0)/../rust +cargo test diff --git a/ipc-runtime/scripts/run_ts_tests.sh b/ipc-runtime/scripts/run_ts_tests.sh new file mode 100755 index 000000000000..edabce646c92 --- /dev/null +++ b/ipc-runtime/scripts/run_ts_tests.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash +# Run the TS package tests (UDS transport, in-process server + client). 
#!/usr/bin/env bash
# Copies cross-compiled ipc_runtime_napi.node into the per-platform build/
# layout consumed by ipc-runtime/ts/src/native_loader.ts.
#
# Inputs come from ipc-runtime/cpp/build-<target>/lib/, populated by
# `ipc-runtime/bootstrap.sh build_cross <target>` (which uses
# cpp/CMakePresets.json's amd64-linux / arm64-linux / amd64-macos /
# arm64-macos presets).
#
# Usage:
#   copy_cross.sh <target>   copy a single target's addon
#   copy_cross.sh            on a release build (REF_NAME is a semver) running
#                            on amd64, copy + strip + re-sign the cross targets
set -e
# Quoted so a checkout path containing spaces does not word-split.
NO_CD=1 source "$(git rev-parse --show-toplevel)/ci3/source"

cd "$(dirname "$0")/.."

# Copy one target's addon from the C++ build tree into ./build/<target>/.
function copy_addon {
  local target="$1"
  mkdir -p "./build/$target"
  cp "../cpp/build-$target/lib/ipc_runtime_napi.node" "./build/$target/"
}

if [ -n "${1:-}" ]; then
  copy_addon "$1"
elif semver check "${REF_NAME:-}" && [[ "$(arch)" == "amd64" ]]; then
  # Release build on amd64-linux: gather the three cross-compiled targets.
  # The native amd64-linux addon is already in place from
  # ipc-runtime/bootstrap.sh's native build step.
  for target in arm64-linux amd64-macos arm64-macos; do
    copy_addon "$target"
  done

  # The glob covers every addon under build/, so the native one is stripped
  # here as well.
  llvm-strip-20 ./build/*/*

  # Re-sign macOS Mach-O binaries after stripping (stripping invalidates
  # the ad-hoc code signature).
  for target in amd64-macos arm64-macos; do
    for f in ./build/"$target"/*; do
      ldid -S "$f"
    done
  done
else
  echo "copy_cross.sh: no arch arg and not a release build — nothing to do."
fi
/** Platform identifiers this package ships a prebuilt addon for. */
export type Platform =
  | "x86_64-linux"
  | "x86_64-darwin"
  | "aarch64-linux"
  | "aarch64-darwin";

/** Maps each supported platform to its directory name under `build/`. */
const PLATFORM_TO_BUILD_DIR: Record<Platform, string> = {
  "x86_64-linux": "amd64-linux",
  "x86_64-darwin": "amd64-macos",
  "aarch64-linux": "arm64-linux",
  "aarch64-darwin": "arm64-macos",
};

/**
 * Derive the `Platform` for the running process from `process.arch` and
 * `process.platform`.
 *
 * @returns the matching platform, or null for any other arch/OS combination.
 */
function detectPlatform(): Platform | null {
  switch (`${process.arch}/${process.platform}`) {
    case "x64/linux":
      return "x86_64-linux";
    case "x64/darwin":
      return "x86_64-darwin";
    case "arm64/linux":
      return "aarch64-linux";
    case "arm64/darwin":
      return "aarch64-darwin";
    default:
      return null;
  }
}

/**
 * Walk up from this module's directory to the first ancestor that holds both
 * a `package.json` and a `build/` entry.
 *
 * The walk stops when it reaches the filesystem root; the root directory
 * itself is not inspected.
 *
 * @returns the matching directory, or null if none is found.
 */
function findPackageRoot(): string | null {
  const startDir = path.dirname(fileURLToPath(import.meta.url));
  const fsRoot = path.parse(startDir).root;
  for (let dir = startDir; dir !== fsRoot; dir = path.dirname(dir)) {
    const hasManifest = fs.existsSync(path.join(dir, "package.json"));
    // `build/` is only probed when a manifest is present.
    if (hasManifest && fs.existsSync(path.join(dir, "build"))) {
      return dir;
    }
  }
  return null;
}
/**
 * Resolve the path of `ipc_runtime_napi.node` for the current platform.
 *
 * @param customPath  When given, only this path is considered; the bundled
 *                    `build/` layout is not searched.
 * @returns the absolute path for an existing `customPath`, the bundled
 *          candidate path when it exists, or null when the platform is
 *          unsupported, no package root is found, or the file is absent.
 */
export function findIpcRuntimeNapi(customPath?: string): string | null {
  if (customPath) {
    return fs.existsSync(customPath) ? path.resolve(customPath) : null;
  }
  const platform = detectPlatform();
  if (!platform) return null;
  const packageRoot = findPackageRoot();
  if (!packageRoot) return null;
  const candidate = path.join(
    packageRoot,
    "build",
    PLATFORM_TO_BUILD_DIR[platform],
    "ipc_runtime_napi.node",
  );
  return fs.existsSync(candidate) ? candidate : null;
}

/**
 * Load `ipc_runtime_napi.node` and return its native exports
 * (`MsgpackClient`, `MsgpackClientAsync`).
 *
 * @param customPath  Optional explicit addon path, forwarded to
 *                    `findIpcRuntimeNapi`.
 * @throws Error when the addon cannot be located, naming the offending
 *         `customPath` if one was supplied.
 * @throws Error when the addon is found but loading it fails; the message
 *         names the resolved path and the original error is kept as `cause`.
 */
export function loadIpcRuntimeNapi(customPath?: string): {
  MsgpackClient: new (shmName: string, clientId?: number) => any;
  MsgpackClientAsync: new (shmName: string, clientId?: number) => any;
} {
  const addonPath = findIpcRuntimeNapi(customPath);
  if (!addonPath) {
    // Do not tell a caller who already passed customPath to "set customPath".
    throw new Error(
      customPath
        ? `Could not locate ipc_runtime_napi.node at customPath "${customPath}".`
        : "Could not locate ipc_runtime_napi.node. Build with `ipc-runtime/bootstrap.sh` " +
            "or set the optional `customPath` argument to point at a prebuilt addon.",
    );
  }
  // createRequire so this works in both ESM and CJS callers.
  const require = createRequire(import.meta.url);
  try {
    return require(addonPath);
  } catch (err) {
    // Load failures (wrong architecture, missing shared libraries, ...)
    // otherwise surface without saying which file was being loaded.
    const reason = err instanceof Error ? err.message : String(err);
    throw new Error(
      `Failed to load ipc_runtime_napi.node from ${addonPath}: ${reason}`,
      { cause: err },
    );
  }
}
/**
 * Smallest surface a synchronous NAPI msgpack client has to provide.
 *
 * The interface is exported so tests and consumers can inject a mock or an
 * alternative implementation; the standard production path is the
 * `createNapiShm{Sync,Async}Client` factories, which load the prebuilt addon
 * shipped in this package's `build/<arch>-<os>/` directory.
 */
export interface NapiMsgpackClientSync {
  /** Send `input` and return the response bytes. */
  call(input: Buffer): Buffer;
  close(): void;
}

/**
 * Smallest surface an asynchronous NAPI msgpack client has to provide.
 *
 * `call` is fire-and-forget: responses are delivered through the callback
 * registered with `setResponseCallback`, in FIFO order. Matching responses
 * to callers is the job of the TypeScript wrapper, not of this interface.
 */
export interface NapiMsgpackClientAsync {
  setResponseCallback(cb: (response: Buffer) => void): void;
  call(input: Buffer): void;
  acquire(): void;
  release(): void;
  /** Stop the native poll thread, release any held TSFN ref, close the client. */
  close(): void;
}

/** Adapts a sync NAPI msgpack client to the `IpcClientSync` interface. */
export class NapiShmSyncClient implements IpcClientSync {
  constructor(private inner: NapiMsgpackClientSync) {}

  /**
   * Forward `input` to the native client and return its reply.
   *
   * Neither direction copies: a non-Buffer input is wrapped as a Buffer view
   * over the same memory, and the result is a `Uint8Array` view over the
   * Buffer the native client returned.
   */
  call(input: Uint8Array): Uint8Array {
    let request: Buffer;
    if (Buffer.isBuffer(input)) {
      request = input;
    } else {
      request = Buffer.from(input.buffer, input.byteOffset, input.byteLength);
    }
    const { buffer, byteOffset, byteLength } = this.inner.call(request);
    return new Uint8Array(buffer, byteOffset, byteLength);
  }

  /** Close the underlying native client. */
  destroy(): void {
    this.inner.close();
  }
}

/** Settlement hooks for one in-flight async call. */
interface PendingCallback {
  resolve: (data: Uint8Array) => void;
  reject: (error: Error) => void;
}
/**
 * Promise-based adapter over the fire-and-forget async NAPI msgpack client.
 *
 * Each `call()` appends a waiter to a FIFO queue; every response delivered
 * through the callback registered with `setResponseCallback` settles the
 * waiter at the head of that queue.
 *
 * `acquire()` is invoked when the queue goes from empty to non-empty and
 * `release()` when it drains back to empty, so the native side holds its
 * reference only while requests are outstanding.
 */
export class NapiShmAsyncClient implements IpcClientAsync {
  private readonly pending: PendingCallback[] = [];
  private destroyed = false;

  constructor(private inner: NapiMsgpackClientAsync) {
    this.inner.setResponseCallback((response: Buffer) =>
      this.deliver(response),
    );
  }

  /**
   * Send `input` and resolve with the matching response.
   *
   * Rejects if the client was already destroyed, if the native `call`
   * throws, or if `destroy()` runs before the response arrives.
   */
  call(input: Uint8Array): Promise<Uint8Array> {
    if (this.destroyed) {
      return Promise.reject(
        new Error("NapiShmAsyncClient: call() after destroy()"),
      );
    }
    const request = Buffer.isBuffer(input)
      ? input
      : Buffer.from(input.buffer, input.byteOffset, input.byteLength);
    return new Promise<Uint8Array>((resolve, reject) => {
      const queueWasEmpty = this.pending.length === 0;
      if (queueWasEmpty) {
        this.inner.acquire();
      }
      this.pending.push({ resolve, reject });
      try {
        this.inner.call(request);
      } catch (err: any) {
        // The send never happened: drop the waiter that was just queued and
        // give back the reference if that leaves the queue empty.
        this.pending.pop();
        if (this.pending.length === 0) {
          this.inner.release();
        }
        const failure =
          err instanceof Error
            ? err
            : new Error(`SHM async call failed: ${String(err)}`);
        reject(failure);
      }
    });
  }

  /**
   * Reject every in-flight call and close the native client. Idempotent:
   * a second call returns without doing anything.
   */
  async destroy(): Promise<void> {
    if (this.destroyed) {
      return;
    }
    this.destroyed = true;
    for (
      let waiter = this.pending.shift();
      waiter !== undefined;
      waiter = this.pending.shift()
    ) {
      waiter.reject(
        new Error("ipc-runtime SHM client destroyed before response"),
      );
    }
    // close() stops the native poll thread and releases any reference still
    // held for calls that were in flight, so Node can exit.
    this.inner.close();
  }

  /** Route one native response to the oldest waiting caller. */
  private deliver(response: Buffer): void {
    if (this.destroyed) {
      // Late response after destroy(); close() already balanced the
      // reference, so there is nothing to release.
      return;
    }
    const waiter = this.pending.shift();
    if (!waiter) {
      // Protocol desync: a response with no caller. No acquire was taken
      // for it, so no release either.
      console.warn(
        "NapiShmAsyncClient: dropping response with no pending caller",
      );
      return;
    }
    waiter.resolve(new Uint8Array(response));
    if (this.pending.length === 0) {
      this.inner.release();
    }
  }
}
0), + ); +} diff --git a/ipc-runtime/ts/src/types.ts b/ipc-runtime/ts/src/types.ts new file mode 100644 index 000000000000..de67744cd419 --- /dev/null +++ b/ipc-runtime/ts/src/types.ts @@ -0,0 +1,38 @@ +/** + * Minimal byte-in / byte-out interface that the ipc-codegen-emitted + * Api types consume. Both UDS and SHM transports satisfy this. + */ +export interface IpcClientAsync { + call(input: Uint8Array): Promise; + destroy(): Promise; +} + +export interface IpcClientSync { + call(input: Uint8Array): Uint8Array; + destroy(): void; +} + +// Shared transport constants, mirroring cpp/ipc_runtime/constants.hpp — +// keep the two in sync. + +/** + * Maximum length-prefix value accepted on receive. A frame claiming more + * than this is treated as corruption and the connection is closed instead + * of buffering the claimed size. + */ +export const MAX_FRAME_SIZE = 256 * 1024 * 1024; // 256 MiB + +/** + * Total budget (ms) for connect() retry loops, covering the window where + * the server process is still starting up. + */ +export const CONNECT_RETRY_BUDGET_MS = 5000; + +/** Default ring size for SHM transports (per direction, per client). */ +export const DEFAULT_RING_SIZE = 4 * 1024 * 1024; // 4 MiB + +/** Default listen backlog for UDS servers. */ +export const SOCKET_BACKLOG = 10; + +/** Default per-call timeout: 0 = infinite (matches the C++ client APIs). */ +export const DEFAULT_CALL_TIMEOUT_NS = 0; diff --git a/ipc-runtime/ts/src/uds.test.ts b/ipc-runtime/ts/src/uds.test.ts new file mode 100644 index 000000000000..07d36a1e1036 --- /dev/null +++ b/ipc-runtime/ts/src/uds.test.ts @@ -0,0 +1,151 @@ +// In-process UDS transport tests: UdsIpcServer + UdsIpcClient round-trips, +// zero-length responses, disconnect handling and oversized-frame rejection. +// Run via `yarn test` (node --test against the compiled dest/ output). 
+ +import { test } from "node:test"; +import * as assert from "node:assert/strict"; +import * as net from "node:net"; +import * as fs from "node:fs"; +import * as os from "node:os"; +import * as path from "node:path"; +import { UdsIpcClient } from "./uds_client.js"; +import { UdsIpcServer } from "./uds_server.js"; + +function tmpSocketPath(tag: string): string { + return path.join(os.tmpdir(), `ipc_ts_test_${tag}_${process.pid}.sock`); +} + +test("echo round-trip", async () => { + const socketPath = tmpSocketPath("echo"); + const server = await UdsIpcServer.listen(socketPath, (_id, req) => req); + const client = await UdsIpcClient.connect(socketPath); + try { + const payload = new Uint8Array([1, 2, 3, 4, 5]); + const resp = await client.call(payload); + assert.deepEqual(resp, payload); + + // Pipelined calls resolve FIFO. + const [a, b] = await Promise.all([ + client.call(new Uint8Array([7])), + client.call(new Uint8Array([8, 9])), + ]); + assert.deepEqual(a, new Uint8Array([7])); + assert.deepEqual(b, new Uint8Array([8, 9])); + } finally { + await client.destroy(); + await server.close(); + } + assert.equal(fs.existsSync(socketPath), false, "socket unlinked on close"); +}); + +test("socket file is chmod 0600", async () => { + const socketPath = tmpSocketPath("chmod"); + const server = await UdsIpcServer.listen(socketPath, (_id, req) => req); + try { + const mode = fs.statSync(socketPath).mode & 0o777; + assert.equal(mode, 0o600); + } finally { + await server.close(); + } +}); + +test("zero-length response resolves (not a hang/error)", async () => { + const socketPath = tmpSocketPath("zlen"); + const server = await UdsIpcServer.listen(socketPath, () => new Uint8Array(0)); + const client = await UdsIpcClient.connect(socketPath); + try { + const resp = await client.call(new Uint8Array([42])); + assert.equal(resp.length, 0); + } finally { + await client.destroy(); + await server.close(); + } +}); + +test("disconnect rejects pending calls and fails fast afterwards", 
async () => { + const socketPath = tmpSocketPath("disc"); + // Raw server that accepts, reads, then kills the connection without + // responding. + const rawServer = net.createServer((conn) => { + conn.once("data", () => conn.destroy()); + }); + await new Promise((resolve) => + rawServer.listen(socketPath, () => resolve()), + ); + + const client = await UdsIpcClient.connect(socketPath); + try { + await assert.rejects(client.call(new Uint8Array([1]))); + // Socket is dead — further calls fail fast instead of queueing. + await assert.rejects(client.call(new Uint8Array([2])), /closed/); + } finally { + await client.destroy(); + rawServer.close(); + fs.rmSync(socketPath, { force: true }); + } +}); + +test("client rejects oversized frame from server", async () => { + const socketPath = tmpSocketPath("oversize_cli"); + // Raw server that answers any request with a corrupt 0xFFFFFFFF length + // prefix. + const rawServer = net.createServer((conn) => { + conn.once("data", () => { + const bogus = Buffer.allocUnsafe(4); + bogus.writeUInt32LE(0xffffffff, 0); + conn.write(bogus); + }); + }); + await new Promise((resolve) => + rawServer.listen(socketPath, () => resolve()), + ); + + const client = await UdsIpcClient.connect(socketPath); + try { + await assert.rejects(client.call(new Uint8Array([1])), /oversized frame/); + } finally { + await client.destroy(); + rawServer.close(); + fs.rmSync(socketPath, { force: true }); + } +}); + +test("server drops connection on oversized frame", async () => { + const socketPath = tmpSocketPath("oversize_srv"); + const server = await UdsIpcServer.listen(socketPath, (_id, req) => req); + const conn = net.createConnection(socketPath); + try { + await new Promise((resolve, reject) => { + conn.once("connect", () => resolve()); + conn.once("error", reject); + }); + const bogus = Buffer.allocUnsafe(4); + bogus.writeUInt32LE(0xffffffff, 0); + conn.write(bogus); + await new Promise((resolve, reject) => { + const timer = setTimeout( + () => reject(new 
Error("server did not close the connection")), + 5000, + ); + conn.once("close", () => { + clearTimeout(timer); + resolve(); + }); + conn.once("error", () => { + /* RST is fine — close follows */ + }); + }); + } finally { + conn.destroy(); + await server.close(); + } +}); + +test("connect times out against a bound-but-unresponsive path", async () => { + const socketPath = tmpSocketPath("noaccept"); + fs.rmSync(socketPath, { force: true }); + await assert.rejects( + UdsIpcClient.connect(socketPath, { connectTimeoutMs: 300 }), + /timed out/, + ); +}); diff --git a/ipc-runtime/ts/src/uds_client.ts b/ipc-runtime/ts/src/uds_client.ts new file mode 100644 index 000000000000..df53266f161e --- /dev/null +++ b/ipc-runtime/ts/src/uds_client.ts @@ -0,0 +1,202 @@ +import * as net from "node:net"; +import { + IpcClientAsync, + CONNECT_RETRY_BUDGET_MS, + MAX_FRAME_SIZE, +} from "./types.js"; + +interface PendingCall { + resolve: (resp: Uint8Array) => void; + reject: (err: Error) => void; +} + +export interface UdsIpcClientConnectOptions { + /** Mark the socket as unref'd so it doesn't keep the Node event loop alive when idle. */ + unref?: boolean; + /** + * Retry budget (ms) for the initial connect when the server has bound the + * path but not yet called listen(). Set to 0 to fail immediately on + * ECONNREFUSED. Default CONNECT_RETRY_BUDGET_MS (5000). + */ + connectTimeoutMs?: number; +} + +/** + * Async IPC client over a Unix Domain Socket. Wire format matches the C++ + * ipc::IpcServer/IpcClient socket transport: 4-byte little-endian length + * prefix followed by `length` bytes of msgpack payload, per direction. + * + * Supports pipelining: multiple concurrent `call()` invocations are queued + * FIFO and matched with responses in order. Pipelining keeps the server-side + * socket window full and matches the native client behaviour. 
/**
 * Async IPC client over a Unix Domain Socket. Each direction carries frames
 * of a 4-byte little-endian length prefix followed by `length` payload bytes.
 *
 * Supports pipelining: concurrent `call()` invocations are queued FIFO and
 * matched with responses in arrival order.
 */
export class UdsIpcClient implements IpcClientAsync {
  /** Received chunks not yet consumed as complete frames. */
  private chunks: Buffer[] = [];
  /** Total byte count across `chunks`. */
  private buffered = 0;
  /**
   * Bytes required before parsing can make progress: 4 while waiting for a
   * length prefix, `4 + len` once the prefix of a partial frame is known.
   */
  private needed = 4;
  private pending: PendingCall[] = [];
  private destroyed = false;
  /** Set once the socket has errored/closed; new calls fail fast. */
  private closed = false;

  private constructor(private conn: net.Socket) {
    conn.on("data", (chunk) => this.onData(chunk));
    conn.on("error", (err) => this.failAll(err));
    conn.on("close", () => this.failAll(new Error("socket closed")));
  }

  /**
   * Connect to `socketPath`, retrying within `opts.connectTimeoutMs`
   * (default `CONNECT_RETRY_BUDGET_MS`).
   */
  static async connect(
    socketPath: string,
    opts?: UdsIpcClientConnectOptions,
  ): Promise<UdsIpcClient> {
    const conn = await connectWithRetry(
      socketPath,
      opts?.connectTimeoutMs ?? CONNECT_RETRY_BUDGET_MS,
    );
    conn.setNoDelay(true);
    if (opts?.unref) conn.unref();
    return new UdsIpcClient(conn);
  }

  /** Number of in-flight calls awaiting a response. */
  get inflight(): number {
    return this.pending.length;
  }

  /** Underlying socket — exposed for ref/unref control (event-loop tuning). */
  get socket(): net.Socket {
    return this.conn;
  }

  /**
   * Send one request frame and resolve with the matching response payload.
   *
   * Rejects when called after `destroy()`, on a closed/errored socket, with
   * an input larger than `MAX_FRAME_SIZE`, or when the socket fails before
   * the response arrives.
   */
  async call(input: Uint8Array): Promise<Uint8Array> {
    if (this.destroyed) {
      throw new Error("UdsIpcClient: call() after destroy()");
    }
    if (this.closed) {
      throw new Error("UdsIpcClient: call() on a closed/errored socket");
    }
    if (input.length > MAX_FRAME_SIZE) {
      // A peer enforcing the same limit drops the connection on such a
      // frame, which would fail every other pipelined call. Reject here,
      // before anything is queued or written.
      throw new Error(
        `UdsIpcClient: request of ${input.length} bytes exceeds MAX_FRAME_SIZE`,
      );
    }
    return new Promise<Uint8Array>((resolve, reject) => {
      this.pending.push({ resolve, reject });
      const lenBuf = Buffer.allocUnsafe(4);
      lenBuf.writeUInt32LE(input.length, 0);
      this.conn.write(lenBuf);
      this.conn.write(input);
    });
  }

  /** Tear down the socket and reject every in-flight call. */
  async destroy(): Promise<void> {
    this.destroyed = true;
    this.conn.removeAllListeners();
    this.conn.destroy();
    this.failAll(new Error("UdsIpcClient destroyed"));
  }

  /**
   * Accumulate socket data and dispatch every complete frame.
   *
   * Chunks are only joined once enough bytes have arrived to make progress
   * (`needed`), so a large frame arriving in many chunks is copied once
   * rather than once per chunk.
   */
  private onData(chunk: Buffer): void {
    this.chunks.push(chunk);
    this.buffered += chunk.length;
    if (this.buffered < this.needed) return;

    let buf =
      this.chunks.length === 1
        ? this.chunks[0]
        : Buffer.concat(this.chunks, this.buffered);
    while (buf.length >= 4) {
      const len = buf.readUInt32LE(0);
      if (len > MAX_FRAME_SIZE) {
        // Corrupt/malicious frame — close instead of buffering up to the
        // claimed size.
        this.chunks = [];
        this.buffered = 0;
        this.conn.destroy();
        this.failAll(
          new Error(
            `UdsIpcClient: oversized frame (${len} bytes exceeds MAX_FRAME_SIZE)`,
          ),
        );
        return;
      }
      if (buf.length < 4 + len) {
        // Partial frame: wait for the rest without re-joining chunks.
        this.needed = 4 + len;
        break;
      }
      this.needed = 4;
      const payload = buf.subarray(4, 4 + len);
      buf = buf.subarray(4 + len);
      const next = this.pending.shift();
      if (next) {
        // Copy so the caller's bytes do not alias the receive buffer.
        next.resolve(new Uint8Array(payload));
      } else {
        // Protocol desync — every response should match a pending call.
        console.warn(
          `UdsIpcClient: dropping ${len}-byte response with no pending caller`,
        );
      }
    }
    this.chunks = buf.length > 0 ? [buf] : [];
    this.buffered = buf.length;
  }

  /** Mark the client closed and reject every queued call with `err`. */
  private failAll(err: Error): void {
    this.closed = true;
    const pending = this.pending;
    this.pending = [];
    for (const p of pending) p.reject(err);
  }
}
/**
 * Connect to `socketPath`, retrying until `timeoutMs` has elapsed.
 *
 * Only ECONNREFUSED, ENOENT and ETIMEDOUT failures are retried; any other
 * error is rethrown at once as "connect failed". Each attempt is capped at
 * the remaining budget (minimum 1 ms), and retries back off exponentially
 * from 5 ms up to a 50 ms ceiling.
 */
async function connectWithRetry(
  socketPath: string,
  timeoutMs: number,
): Promise<net.Socket> {
  const retryable = new Set(["ECONNREFUSED", "ENOENT", "ETIMEDOUT"]);
  const deadline = Date.now() + timeoutMs;
  for (let attempt = 0; ; attempt++) {
    const budgetMs = Math.max(1, deadline - Date.now());
    try {
      return await attemptConnect(socketPath, budgetMs);
    } catch (err) {
      const failure = err as NodeJS.ErrnoException;
      // An error without a code is treated as non-retryable.
      if (!retryable.has(failure.code ?? "")) {
        throw new Error(`UdsIpcClient: connect failed: ${failure.message}`);
      }
      if (Date.now() >= deadline) {
        throw new Error(`UdsIpcClient: connect timed out: ${failure.message}`);
      }
    }
    const backoffMs = Math.min(50, 5 * 2 ** attempt);
    await new Promise<void>((resolve) => setTimeout(resolve, backoffMs));
  }
}

/**
 * Make a single connection attempt bounded by `timeoutMs`.
 *
 * Resolves with the connected socket. On a socket error or on timeout the
 * socket is destroyed and the promise rejects; a timeout rejects with an
 * error whose `code` is "ETIMEDOUT".
 */
function attemptConnect(
  socketPath: string,
  timeoutMs: number,
): Promise<net.Socket> {
  return new Promise<net.Socket>((resolve, reject) => {
    const conn = net.createConnection(socketPath);
    let timer: NodeJS.Timeout;

    // Remove both listeners and the timer so exactly one outcome wins.
    const detach = () => {
      conn.removeListener("connect", onConnect);
      conn.removeListener("error", onError);
      clearTimeout(timer);
    };
    const abort = (reason: Error) => {
      detach();
      conn.destroy();
      reject(reason);
    };
    function onConnect(): void {
      detach();
      resolve(conn);
    }
    function onError(err: Error): void {
      abort(err);
    }

    timer = setTimeout(() => {
      const err: NodeJS.ErrnoException = new Error(
        `connect attempt timed out after ${timeoutMs}ms`,
      );
      err.code = "ETIMEDOUT";
      abort(err);
    }, timeoutMs);
    conn.once("connect", onConnect);
    conn.once("error", onError);
  });
}
/**
 * Request handler: receives the raw request bytes for one frame and returns
 * (or resolves to) the raw response bytes. Decoding, encoding and command
 * dispatch are the caller's responsibility.
 */
export type IpcServerHandler = (
  clientId: number,
  request: Uint8Array,
) => Promise<Uint8Array> | Uint8Array;

/**
 * UDS server using frames of a 4-byte little-endian length prefix followed
 * by the payload. Accepts multiple concurrent connections; handler
 * invocations are serialised per connection.
 *
 * Signal handling is the caller's responsibility. The socket file is
 * unlinked on `close()` and, best-effort, on process exit.
 */
export class UdsIpcServer {
  private server: net.Server;
  private nextClientId = 0;
  /** Live client sockets, tracked so close() can tear them down. */
  private readonly connections = new Set<net.Socket>();
  private readonly unlinkOnExit = () => {
    try {
      fs.unlinkSync(this.socketPath);
    } catch {
      /* may already be gone */
    }
  };

  private constructor(
    server: net.Server,
    private socketPath: string,
  ) {
    this.server = server;
  }

  /**
   * Bind `socketPath` (removing any existing file first), restrict it to
   * mode 0600 and start dispatching request frames to `handler`.
   *
   * @throws if listening fails, or if the socket cannot be chmod'ed; in the
   *         latter case the server is closed and the socket file removed
   *         before the error propagates.
   */
  static async listen(
    socketPath: string,
    handler: IpcServerHandler,
  ): Promise<UdsIpcServer> {
    try {
      fs.unlinkSync(socketPath);
    } catch {
      /* socket file may not exist; ignore */
    }

    const server = net.createServer();
    const instance = new UdsIpcServer(server, socketPath);
    server.on("connection", (conn) => instance.handleConnection(conn, handler));

    await new Promise<void>((resolve, reject) => {
      const onError = (err: Error) => {
        server.removeListener("listening", onListening);
        reject(err);
      };
      const onListening = () => {
        server.removeListener("error", onError);
        resolve();
      };
      server.once("error", onError);
      server.once("listening", onListening);
      server.listen(socketPath);
    });

    // Restrict the socket to the owner.
    try {
      fs.chmodSync(socketPath, 0o600);
    } catch (err) {
      // Do not leave a listening server (with unrestricted permissions)
      // behind when the caller never receives an instance to close.
      server.close();
      instance.unlinkOnExit();
      throw err;
    }

    // Best-effort cleanup if the process exits without close().
    process.on("exit", instance.unlinkOnExit);

    return instance;
  }

  /**
   * Stop accepting connections, drop every connected client and remove the
   * socket file.
   *
   * `net.Server.close()` only completes once all existing connections have
   * ended, so the tracked sockets are destroyed explicitly; otherwise this
   * would wait for as long as any client stays connected.
   */
  async close(): Promise<void> {
    const closed = new Promise<void>((resolve) =>
      this.server.close(() => resolve()),
    );
    for (const conn of this.connections) {
      conn.destroy();
    }
    await closed;
    process.removeListener("exit", this.unlinkOnExit);
    this.unlinkOnExit();
  }

  /**
   * Parse frames from one client and run `handler` on each, writing the
   * responses back in request order.
   */
  private handleConnection(conn: net.Socket, handler: IpcServerHandler): void {
    const clientId = this.nextClientId++;
    this.connections.add(conn);
    conn.once("close", () => this.connections.delete(conn));

    // Received chunks are joined only when enough bytes have arrived to make
    // progress (`needed`), so a large request arriving in many chunks is
    // copied once rather than once per chunk.
    let chunks: Buffer[] = [];
    let buffered = 0;
    let needed = 4;
    let chain: Promise<void> = Promise.resolve();

    conn.on("data", (chunk: Buffer) => {
      chunks.push(chunk);
      buffered += chunk.length;
      if (buffered < needed) return;

      let buffer =
        chunks.length === 1 ? chunks[0] : Buffer.concat(chunks, buffered);
      while (buffer.length >= 4) {
        const len = buffer.readUInt32LE(0);
        if (len > MAX_FRAME_SIZE) {
          // Corrupt/malicious frame — drop the connection instead of
          // buffering up to the claimed size.
          chunks = [];
          buffered = 0;
          conn.destroy(
            new Error(
              `UdsIpcServer: oversized frame (${len} bytes exceeds MAX_FRAME_SIZE)`,
            ),
          );
          return;
        }
        if (buffer.length < 4 + len) {
          // Partial frame: wait for the rest without re-joining chunks.
          needed = 4 + len;
          break;
        }
        needed = 4;
        // Copy so the handler's view does not alias the receive buffer.
        const payload = new Uint8Array(buffer.subarray(4, 4 + len));
        buffer = buffer.subarray(4 + len);

        // Chain onto the previous request so responses keep request order.
        const prev = chain;
        chain = (async () => {
          await prev;
          try {
            const resp = await handler(clientId, payload);
            const lenBuf = Buffer.allocUnsafe(4);
            lenBuf.writeUInt32LE(resp.length, 0);
            conn.write(lenBuf);
            conn.write(resp);
          } catch (err) {
            conn.destroy(err as Error);
          }
        })();
        void chain.catch(() => {
          /* errors already handled by destroying the connection */
        });
      }
      chunks = buffer.length > 0 ? [buffer] : [];
      buffered = buffer.length;
    });

    conn.on("error", () => {
      /* swallowed — the connection is torn down and untracked on 'close' */
    });
  }
}
"rootDir": "src", + "lib": ["ES2022"], + "types": ["node"] + }, + "include": ["src"] +} diff --git a/ipc-runtime/ts/yarn.lock b/ipc-runtime/ts/yarn.lock new file mode 100644 index 000000000000..db7180039bcb --- /dev/null +++ b/ipc-runtime/ts/yarn.lock @@ -0,0 +1,51 @@ +# This file is generated by running "yarn install" inside your project. +# Manual changes might be lost - proceed with caution! + +__metadata: + version: 8 + cacheKey: 10c0 + +"@aztec/ipc-runtime@workspace:.": + version: 0.0.0-use.local + resolution: "@aztec/ipc-runtime@workspace:." + dependencies: + "@types/node": "npm:^22" + typescript: "npm:^5.6.3" + languageName: unknown + linkType: soft + +"@types/node@npm:^22": + version: 22.19.19 + resolution: "@types/node@npm:22.19.19" + dependencies: + undici-types: "npm:~6.21.0" + checksum: 10c0/402e0f088c94cabda3cd721546bd8e4e75e098e0b342f6e03b90ca1e19c28986f9650112c64fcfd09fc8cebc0f8b20291a513153e90489331cf666e1e5503e16 + languageName: node + linkType: hard + +"typescript@npm:^5.6.3": + version: 5.9.3 + resolution: "typescript@npm:5.9.3" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/6bd7552ce39f97e711db5aa048f6f9995b53f1c52f7d8667c1abdc1700c68a76a308f579cd309ce6b53646deb4e9a1be7c813a93baaf0a28ccd536a30270e1c5 + languageName: node + linkType: hard + +"typescript@patch:typescript@npm%3A^5.6.3#optional!builtin": + version: 5.9.3 + resolution: "typescript@patch:typescript@npm%3A5.9.3#optional!builtin::version=5.9.3&hash=5786d5" + bin: + tsc: bin/tsc + tsserver: bin/tsserver + checksum: 10c0/ad09fdf7a756814dce65bc60c1657b40d44451346858eea230e10f2e95a289d9183b6e32e5c11e95acc0ccc214b4f36289dcad4bf1886b0adb84d711d336a430 + languageName: node + linkType: hard + +"undici-types@npm:~6.21.0": + version: 6.21.0 + resolution: "undici-types@npm:6.21.0" + checksum: 10c0/c01ed51829b10aa72fc3ce64b747f8e74ae9b60eafa19a7b46ef624403508a54c526ffab06a14a26b3120d055e1104d7abe7c9017e83ced038ea5cf52f8d5e04 + languageName: node + linkType: hard diff --git 
a/ipc-runtime/zig/.gitignore b/ipc-runtime/zig/.gitignore
new file mode 100644
index 000000000000..3389c86c9946
--- /dev/null
+++ b/ipc-runtime/zig/.gitignore
@@ -0,0 +1,2 @@
+.zig-cache/
+zig-out/
diff --git a/ipc-runtime/zig/build.zig b/ipc-runtime/zig/build.zig
new file mode 100644
index 000000000000..e93af3f31444
--- /dev/null
+++ b/ipc-runtime/zig/build.zig
@@ -0,0 +1,71 @@
+const std = @import("std");
+
+/// C++ translation units that make up the runtime, relative to `../cpp`.
+const runtime_sources = [_][]const u8{
+    "ipc_runtime/c_abi.cpp",
+    "ipc_runtime/ipc_client.cpp",
+    "ipc_runtime/ipc_server.cpp",
+    "ipc_runtime/serve_helper.cpp",
+    "ipc_runtime/signal_handlers.cpp",
+    "ipc_runtime/socket_client.cpp",
+    "ipc_runtime/socket_server.cpp",
+    "ipc_runtime/shm/mpsc_shm.cpp",
+    "ipc_runtime/shm/spsc_shm.cpp",
+};
+
+/// Declares the ipc-runtime build graph: a static `ipc_runtime` library
+/// compiled from the C++ sources under `../cpp` with Zig's bundled clang +
+/// libc++ (consistent with whatever libc++ the final Zig binary links), the
+/// `ipc_runtime` module rooted at `src/main.zig`, and the installed smoke
+/// executable. No prebuilt artifact, no IPC_RUNTIME_LIB_DIR.
+pub fn build(b: *std.Build) void {
+    const target = b.standardTargetOptions(.{});
+    const optimize = b.standardOptimizeOption(.{});
+    const cpp_root = b.path("../cpp");
+
+    // Static library built from the runtime sources listed above.
+    const cpp_module = b.createModule(.{
+        .target = target,
+        .optimize = optimize,
+        .link_libc = true,
+        .link_libcpp = true,
+    });
+    cpp_module.addIncludePath(cpp_root);
+    cpp_module.addCSourceFiles(.{
+        .root = cpp_root,
+        .files = &runtime_sources,
+        .flags = &.{ "-std=c++20", "-fPIC" },
+    });
+    const runtime_lib = b.addLibrary(.{
+        .name = "ipc_runtime",
+        .linkage = .static,
+        .root_module = cpp_module,
+    });
+
+    // Module others can @import("ipc_runtime") from their build.zig.
+    const binding = b.addModule("ipc_runtime", .{
+        .root_source_file = b.path("src/main.zig"),
+        .target = target,
+        .optimize = optimize,
+        .link_libc = true,
+        .link_libcpp = true,
+    });
+    binding.addIncludePath(cpp_root);
+    binding.linkLibrary(runtime_lib);
+
+    // Smoke executable so `zig build` produces something verifiable.
+    const smoke_module = b.createModule(.{
+        .root_source_file = b.path("src/smoke.zig"),
+        .target = target,
+        .optimize = optimize,
+        .link_libc = true,
+        .link_libcpp = true,
+    });
+    smoke_module.addImport("ipc_runtime", binding);
+    const smoke_exe = b.addExecutable(.{
+        .name = "ipc_runtime_smoke",
+        .root_module = smoke_module,
+    });
+    smoke_exe.linkLibrary(runtime_lib);
+    b.installArtifact(smoke_exe);
+}
diff --git a/ipc-runtime/zig/build.zig.zon b/ipc-runtime/zig/build.zig.zon
new file mode 100644
index 000000000000..6834cc1fcc15
--- /dev/null
+++ b/ipc-runtime/zig/build.zig.zon
@@ -0,0 +1,11 @@
+.{
+    .name = .ipc_runtime,
+    .version = "0.1.0",
+    .fingerprint = 0x177513c7abc06ef4,
+    .minimum_zig_version = "0.14.0",
+    .paths = .{
+        "build.zig",
+        "build.zig.zon",
+        "src",
+    },
+}
diff --git a/ipc-runtime/zig/src/main.zig b/ipc-runtime/zig/src/main.zig
new file mode 100644
index 000000000000..52b954331fd3
--- /dev/null
+++ b/ipc-runtime/zig/src/main.zig
@@ -0,0 +1,144 @@
+//! Zig binding to ipc-runtime — UDS + MPSC-SHM transport.
+//!
+//! `Server.fromPath(path)` / `Client.fromPath(path)` pick UDS vs MPSC-SHM
+//! by the path suffix (`.sock` → UDS, `.shm` → SHM). Same call/listen/run
+//! methods across transports. See ipc-runtime/cpp/ipc_runtime/c_abi.h for
+//! the underlying C ABI.
+//!
+//! Zig's build.zig compiles the C++ sources directly with the bundled
+//! clang + libc++, so there's no prebuilt-archive dependency.
+
+const std = @import("std");
+
+const c = @cImport({
+    @cInclude("ipc_runtime/c_abi.h");
+});
+
+/// 0 = infinite, matching the C ABI's unified timeout semantics. `call`
+/// blocks until the reply arrives.
+const default_call_timeout_ns: u64 = 0;
+
+pub const Error = error{ InvalidPath, Connect, Listen, Send, Receive };
+
+/// Thin wrapper over a C ABI `ipc_server` handle; `deinit` closes and destroys it.
+pub const Server = struct {
+    handle: *c.ipc_server,
+
+    /// Build a server for `path`. A null handle from the C ABI becomes `Error.InvalidPath`.
+    pub fn fromPath(path: [:0]const u8) Error!Server {
+        const created = c.ipc_make_server(path.ptr, null);
+        if (created == null) return Error.InvalidPath;
+        return .{ .handle = created.? };
+    }
+
+    /// Close the server, then destroy the underlying handle.
+    pub fn deinit(self: *Server) void {
+        c.ipc_server_close(self.handle);
+        c.ipc_server_destroy(self.handle);
+    }
+
+    /// Start listening; a `false` result from the C ABI becomes `Error.Listen`.
+    pub fn listen(self: *Server) Error!void {
+        if (!c.ipc_server_listen(self.handle)) return Error.Listen;
+    }
+
+    /// Forward a shutdown request to the C ABI.
+    pub fn requestShutdown(self: *Server) void {
+        c.ipc_server_request_shutdown(self.handle);
+    }
+
+    /// Ask the C ABI to install its default signal handlers for this server.
+    pub fn installDefaultSignalHandlers(self: *Server) void {
+        c.ipc_install_default_signal_handlers(self.handle);
+    }
+
+    /// Run the event loop. `handler` is invoked per request; its return slice
+    /// must remain valid until the next call (a per-context arena works well).
+    pub fn run(
+        self: *Server,
+        comptime Ctx: type,
+        ctx: Ctx,
+        handler: *const fn (ctx: Ctx, client_id: i32, req: []const u8) []u8,
+    ) void {
+        const Trampoline = struct {
+            user_ctx: Ctx,
+            user_handler: *const fn (ctx: Ctx, client_id: i32, req: []const u8) []u8,
+
+            fn invoke(
+                client_id: c_int,
+                req: [*c]const u8,
+                req_len: usize,
+                resp_out: [*c][*c]u8,
+                resp_len_out: [*c]usize,
+                ctx_raw: ?*anyopaque,
+            ) callconv(.c) void {
+                const state: *@This() = @ptrCast(@alignCast(ctx_raw.?));
+                var request: []const u8 = &[_]u8{};
+                if (req_len != 0) request = req[0..req_len];
+                const reply = state.user_handler(state.user_ctx, @intCast(client_id), request);
+                resp_out[0] = @constCast(reply.ptr);
+                resp_len_out[0] = reply.len;
+            }
+        };
+        var trampoline = Trampoline{ .user_ctx = ctx, .user_handler = handler };
+        c.ipc_server_run(self.handle, Trampoline.invoke, &trampoline);
+    }
+};
+
+/// Client handle. `deinit` releases the underlying C++ object.
+pub const Client = struct {
+    handle: *c.ipc_client,
+    allocator: std.mem.Allocator,
+
+    /// Open a client connection. `.sock` → UDS, `.shm` → MPSC-SHM (slot 0;
+    /// use `fromPathWithId` for a different slot).
+    pub fn fromPath(allocator: std.mem.Allocator, path: [:0]const u8) Error!Client {
+        return fromPathWithId(allocator, path, 0);
+    }
+
+    /// `fromPath` with an explicit SHM slot. Null handle → `InvalidPath`; failed connect → `Connect`.
+    pub fn fromPathWithId(allocator: std.mem.Allocator, path: [:0]const u8, shm_client_id: usize) Error!Client {
+        const raw = c.ipc_make_client(path.ptr, shm_client_id);
+        if (raw == null) return Error.InvalidPath;
+        const client = Client{ .handle = raw.?, .allocator = allocator };
+        if (!c.ipc_client_connect(client.handle)) {
+            c.ipc_client_destroy(client.handle);
+            return Error.Connect;
+        }
+        return client;
+    }
+
+    /// Close the connection, then destroy the underlying handle.
+    pub fn deinit(self: *Client) void {
+        c.ipc_client_close(self.handle);
+        c.ipc_client_destroy(self.handle);
+    }
+
+    /// Alias for deinit() so Client satisfies the ipc-codegen Backend
+    /// contract (which expects `destroy(self: *T) void`). Generated typed
+    /// clients call `backend.destroy()` at end-of-life.
+    pub fn destroy(self: *Client) void {
+        self.deinit();
+    }
+
+    /// Synchronous request/response. Returns an owned slice (free with the
+    /// allocator passed at construction). A zero-length reply is a valid
+    /// empty slice, not an error.
+    ///
+    /// Errors: `Error.Send`, `Error.Receive` (any non-`IPC_OK` status), or the
+    /// allocator's error if the reply copy cannot be allocated.
+    pub fn call(self: *Client, request: []const u8) ![]u8 {
+        const sent = c.ipc_client_send(self.handle, request.ptr, request.len, default_call_timeout_ns);
+        if (!sent) return Error.Send;
+        var out_ptr: [*c]const u8 = null;
+        var out_len: usize = 0;
+        const status = c.ipc_client_receive(self.handle, default_call_timeout_ns, &out_ptr, &out_len);
+        if (status != c.IPC_OK) return Error.Receive;
+        // Release on every exit once receive succeeded — including a failed
+        // copy and zero-length replies (it consumes the frame header for SHM).
+        defer c.ipc_client_release(self.handle, out_len);
+        const copied = try self.allocator.alloc(u8, out_len);
+        if (out_len > 0) @memcpy(copied, out_ptr[0..out_len]);
+        return copied;
+    }
+};
diff --git a/ipc-runtime/zig/src/smoke.zig b/ipc-runtime/zig/src/smoke.zig
new file mode 100644
index 000000000000..95e5b7deecd1
--- /dev/null
+++ b/ipc-runtime/zig/src/smoke.zig
@@ -0,0 +1,47 @@
+//! Smoke test: spawn UDS server thread, client connects + round-trips one msg.
+const std = @import("std");
+const ipc = @import("ipc_runtime");
+
+const SocketPath = "/tmp/ipc_runtime_zig_smoke.sock";
+
+fn serverThread(arg: usize) void {
+    _ = arg;
+    std.fs.cwd().deleteFile(SocketPath) catch {};
+    var srv = ipc.Server.fromPath(SocketPath) catch unreachable;
+    defer srv.deinit();
+    srv.listen() catch unreachable;
+
+    const Ctx = struct {
+        scratch: []u8,
+    };
+    var scratch: [16]u8 = undefined;
+    var ctx_struct = Ctx{ .scratch = &scratch };
+
+    srv.run(*Ctx, &ctx_struct, struct {
+        fn h(ctx: *Ctx, _: i32, req: []const u8) []u8 {
+            const n = @min(req.len, ctx.scratch.len);
+            for (0..n) |i| ctx.scratch[i] = req[n - 1 - i];
+            return ctx.scratch[0..n];
+        }
+    }.h);
+}
+
+pub fn main() !void {
+    const thread = try std.Thread.spawn(.{}, serverThread, .{@as(usize, 0)});
+    _ = thread;
+    std.Thread.sleep(100 * std.time.ns_per_ms);
+
+    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
+    defer _ = gpa.deinit();
+    var client = try ipc.Client.fromPath(gpa.allocator(), SocketPath);
+    defer client.deinit();
+
+    const response = try client.call("hello");
+    defer gpa.allocator().free(response);
+    std.debug.print("client got: {s}\n", .{response});
+    if (!std.mem.eql(u8, response, "olleh")) {
+        std.debug.print("mismatch\n", .{});
+        std.process.exit(1);
+    }
+    std.process.exit(0);
+}
diff --git a/yarn-project/bootstrap.sh b/yarn-project/bootstrap.sh
index 6c84b8b7a6d3..0a2b88c4eb67 100755
--- a/yarn-project/bootstrap.sh
+++ b/yarn-project/bootstrap.sh
@@ -298,7 +298,10 @@ case "$cmd" in
     git clean -fdx
;; "clean-lite") - files=$(git ls-files --ignored --others --exclude-standard | grep -vE '(node_modules/|^\.yarn/)' || true) + # Preserve gitignored fixture dirs that are populated by sibling builds and + # consumed concurrently by parallel test commands. Wiping them mid-test + # yanks files out from under readers (see chonk_inputs.sh download path). + files=$(git ls-files --ignored --others --exclude-standard | grep -vE '(node_modules/|^\.yarn/|^end-to-end/example-app-ivc-inputs-out/|^end-to-end/ultrahonk-bench-inputs/|^end-to-end/dumped-avm-circuit-inputs/)' || true) if [ -n "$files" ]; then echo "$files" | xargs rm -rf fi