diff --git a/examples/wuji/pyproject.toml b/examples/wuji/pyproject.toml index 1ba06643..05c37407 100644 --- a/examples/wuji/pyproject.toml +++ b/examples/wuji/pyproject.toml @@ -6,6 +6,16 @@ readme = "README.md" requires-python = ">=3.12" dependencies = ["genelab"] +[project.optional-dependencies] +# Deploy core (real2sim + policy control). Pure-software pieces tested headlessly. +deploy = ["onnxruntime", "pyzmq", "scipy"] +# Vision observer (Hikvision camera -> ZMQ). Hardware-side, not exercised in CI. +# Also needs the Hikvision MVS SDK (system install, not pip) — see deploy/README.md. +deploy-vision = ["opencv-contrib-python", "pupil-apriltags", "pyyaml"] +# Real Wuji hand SDK (compiled wheel; physical hardware). Imported lazily, kept out +# of `deploy` so the headless core stays binary-free. Pinned to match wuji-mjlab. +deploy-hand = ["wujihandpy==1.5.1"] + [project.entry-points."genelab.extensions"] genelab_wuji = "genelab_wuji.tasks:register" diff --git a/examples/wuji/src/genelab_wuji/deploy/README.md b/examples/wuji/src/genelab_wuji/deploy/README.md new file mode 100644 index 00000000..038614d5 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/README.md @@ -0,0 +1,108 @@ +# Wuji-hand deploy (Genesis-native) + +A Genesis-native port of the `wuji-mjlab/deploy/reorient` pipeline. Two deliverables: + +1. **real2sim** — reproduce the real cube's pose inside the Genesis sim. +2. **policy deploy** — run an exported ONNX policy to control the (real or mock) hand. 
+ +The pieces are decoupled via ZMQ (localhost): + +``` + cube_world_observer ──cube pose (5555)──▶ play_real (controls the hand) + (Hikvision camera) │ toreal_viewer (mirrors cube in sim) + └──────────────▶ + toreal_viewer ──goal (5556)──▶ play_real +``` + +## Architecture + +| Module | Responsibility | Tested | +|---|---|---| +| `frame_transform.py` | wxyz quat math + `cube_cam_to_tag` (camera→wrist-tag lift) | ✅ | +| `real2sim.py` | `tag_pose_in_world`, `cube_pose_in_tag_to_world` (sim reproduction) | ✅ | +| `zmq_bridge.py` | cube/goal pub-sub + xyzw↔wxyz + last-valid cache | ✅ | +| `obs.py` | `DeployObsBuilder` (207-dim policy obs + 3-step history) | ✅ | +| `action.py` | `ActionProcessor` (offset + clamp + EMA + warmup) | ✅ | +| `onnx_policy.py` | `ONNXPolicy` (GeneLab metadata format) | ✅ | +| `hand_driver.py` | `HandDriverBase` / `MockHandDriver` / `WujiHandDriver` | ✅ (mock) | +| `controller.py` | `DeployController` (closed-loop step) | ✅ | +| `camera_config.py` | Hikvision intrinsics/ROI/capture from `config/camera.yaml` | glue (hardware) | +| `cube_geom.py` | cube_tags JSON resolution (`config/cube_tags.json`) | glue | +| `scripts/hand_utils.py` | `check` (read-only bridge test) / `home` (3s ramp to grasp pose) | glue (hardware) | +| `scripts/calib_check.py` | static calib viewer: live hand (encoders) + cube vs. digital twin | glue (hardware) | +| `scripts/play_real.py` | deploy control loop + goal modes + success monitor + Genesis mirror (real/mock) | glue | +| `scripts/toreal_viewer.py` | real2sim Genesis viewer | glue | +| `scripts/cube_world_observer.py` | Hikvision camera → ArUco board + SO3 Kalman → ZMQ cube pose | glue (hardware) | + +The pure-software core is numpy-only and runs headlessly (no Genesis, no hardware), +so all frame/obs/action/policy logic is unit-tested in `tests/test_examples_wuji_deploy_*.py`. 
+ +### Key conventions + +- **Quaternions**: wxyz everywhere internally; the cube wire format is scipy xyzw and + is converted at the ZMQ boundary (`cube_pose_from_msg` / `cube_msg_from_pose`). +- **Tag frame**: the observer reports the cube already in the wrist-AprilTag frame — + the exact frame the policy was trained on — so the deploy obs needs **no forward + kinematics**. +- **6D goal error**: matches the GeneLab training encoding (`matrix_to_rotation_6d`, + first two matrix rows), pinned against the real training math in the tests. +- **Joint order**: hardware is finger-major (`finger1_joint1..4, finger2...` = `wujihandpy` (5,4) row-major) but the policy is joint-major, so + `DeployController` remaps between the two via `ENC_TO_POLICY` (see `config.py`). + +## Install + +```bash +uv pip install -e 'examples/wuji[deploy]' # core (real2sim + control) +uv pip install -e 'examples/wuji[deploy,deploy-vision]' # + camera observer +uv pip install -e 'examples/wuji[deploy,deploy-hand]' # + real Wuji hand SDK (wujihandpy) +``` + +The cube observer also needs the **Hikvision MVS SDK** (system install, not pip — same +as wuji-mjlab). 
Install it from Hikvision's MVS download page (default install location `/opt/MVS`) and source +its environment before running the observer: + +```bash +export MVCAM_COMMON_RUNENV=/opt/MVS/lib +export LD_LIBRARY_PATH=/opt/MVS/lib/64:/opt/MVS/lib/32:$LD_LIBRARY_PATH +# (or: source /opt/MVS/bin/set_env_path.sh /opt/MVS) +# If MvImport lives elsewhere: export MVS_PYTHON_PATH=/path/to/dir/containing/MvImport +``` + +## Run + +```bash +# 0) export a trained policy to ONNX +genelab export Genelab-Reorient-Wuji-Hand-v0 PATH/model.pt --format onnx --out policy.onnx + +# 1) smoke-test the control loop, no hardware, no ZMQ, no viewer +python -m genelab_wuji.deploy.scripts.play_real --ckpt policy.onnx --mock --no-zmq --no-viewer --steps 100 + +# 1.5) bring up the real hand bridge (needs wujihandpy): check first, then home +python -m genelab_wuji.deploy.scripts.hand_utils check # READ-ONLY: connection + encoder sanity +python -m genelab_wuji.deploy.scripts.hand_utils home # 3s ease-in-out ramp to the grasp pose + +# 2) vision: detect the cube and publish its tag-frame pose on ZMQ:5555 (needs MVS env) +python -m genelab_wuji.deploy.scripts.cube_world_observer --preview # terminal A +python -m genelab_wuji.deploy.scripts.toreal_viewer # terminal B (real2sim mirror) + +# 2.5) calibration check: home the hand, render live hand + observed cube in the twin +python -m genelab_wuji.deploy.scripts.calib_check # (needs the observer running) + +# 3) drive the real hand from the live observer feed (Genesis mirror viewer on by default, +# showing the live hand + observed cube + goal; pass --no-viewer for headless). +# goal modes: --goal-mode random (uniform-SO3, resampled on success) | +# fixed --goal-quat w,x,y,z | external (goal from toreal_viewer ZMQ) +python -m genelab_wuji.deploy.scripts.play_real --ckpt policy.onnx --real --goal-mode random +``` + +`play_real` mirrors the live hand (encoders) + observed cube + goal in a Genesis viewer +by default (`--no-viewer` to disable). 
It reuses the same kinematic, physics-free refresh +as `calib_check`, so the mirror just reflects reality. The control core itself is numpy-only +and runs headlessly under `--no-viewer`. + +The cube observer is a faithful port of the production wuji-mjlab pipeline (Hikvision MVS +capture, multi-face ArUco board fusion, SO3 Kalman + position low-pass + corner EMA, world +auto-sampling, fast ROI, OpenCV preview). It publishes the cube pose in the wrist-tag frame +in the exact same ZMQ schema GeneLab's `CubeReceiver` consumes. Tuning lives in +`config/observer.yaml`; camera intrinsics/ROI in `config/camera.yaml`; cube tag layout in +`config/cube_tags.json`. For a non-Hikvision camera, swap the MVS capture in `run()`. diff --git a/examples/wuji/src/genelab_wuji/deploy/__init__.py b/examples/wuji/src/genelab_wuji/deploy/__init__.py new file mode 100644 index 00000000..17e9973b --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/__init__.py @@ -0,0 +1,8 @@ +"""Genesis-native deployment for the Wuji-hand reorientation policy. + +Ports the deployment pipeline from ``wuji-mjlab/deploy/reorient`` onto the GeneLab +(Genesis) stack: real2sim cube-pose reproduction in sim and ONNX-policy control of +the hand. The pure-numpy core (frame transforms, ZMQ bridge, obs assembly, ONNX +wrapper, hand-driver abstraction) is simulator- and hardware-agnostic so it runs +and tests headlessly; Genesis viewers and real hardware live in ``deploy.scripts``. +""" diff --git a/examples/wuji/src/genelab_wuji/deploy/action.py b/examples/wuji/src/genelab_wuji/deploy/action.py new file mode 100644 index 00000000..56fc748e --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/action.py @@ -0,0 +1,52 @@ +"""Action post-processing for deploy (numpy port of JointPositionOffsetEMAAction). 
+ +The policy emits raw actions ~[-1, 1]; the joint target is +``default + action_scale * clamp(action)``, clamped to joint limits, EMA-smoothed +against the previous target, and held at the default pose for ``warmup_steps`` after +each reset. Training-only terms (encoder_bias, action noise) are dropped. +""" + +from __future__ import annotations + +import numpy as np + + +class ActionProcessor: + """Turn a raw policy action into a smoothed, limit-clamped joint target.""" + + def __init__( + self, + default_joint_pos: np.ndarray, + action_scale: float = 0.5, + ema_alpha: float = 0.5, + warmup_steps: int = 8, + joint_pos_limits: tuple[np.ndarray, np.ndarray] | None = None, + ) -> None: + self._default = np.asarray(default_joint_pos, dtype=float) + self.action_scale = action_scale + self.ema_alpha = ema_alpha + self.warmup_steps = warmup_steps + if joint_pos_limits is None: + self._lo = np.full_like(self._default, -np.inf) + self._hi = np.full_like(self._default, np.inf) + else: + self._lo = np.asarray(joint_pos_limits[0], dtype=float) + self._hi = np.asarray(joint_pos_limits[1], dtype=float) + self._prev_target = self._default.copy() + self._step = 0 + + def reset(self) -> None: + """Reset the EMA state and warmup counter (call on episode boundaries).""" + self._prev_target = self._default.copy() + self._step = 0 + + def process(self, action: np.ndarray) -> np.ndarray: + """Return the joint target for this control step (JOINT_NAMES_20 order).""" + action = np.clip(np.asarray(action, dtype=float), -1.0, 1.0) + raw_target = self._default + self.action_scale * action + raw_target = np.clip(raw_target, self._lo, self._hi) + smoothed = self.ema_alpha * raw_target + (1.0 - self.ema_alpha) * self._prev_target + target = self._default.copy() if self._step < self.warmup_steps else smoothed + self._prev_target = target.copy() + self._step += 1 + return target diff --git a/examples/wuji/src/genelab_wuji/deploy/camera_config.py 
b/examples/wuji/src/genelab_wuji/deploy/camera_config.py new file mode 100644 index 00000000..bd7448a5 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/camera_config.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Wuji Technology Co., Ltd. +# Ported into GeneLab from wuji-mjlab deploy/reorient/lib/camera_config.py +"""Camera Configuration Loader. + +Centralized camera parameters from config/camera.yaml. +Provides functions for loading camera intrinsics, distortion, and ROI settings. + +Example: + >>> from camera_config import get_camera_matrix, get_dist_coeffs + >>> K = get_camera_matrix() + >>> dist = get_dist_coeffs() +""" + +from __future__ import annotations + +import os +from typing import Any + +import numpy as np +import yaml + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +CONFIG_FILE = os.path.join(SCRIPT_DIR, "config", "camera.yaml") + +def load_camera_config(config_file: str | None = None) -> dict[str, Any]: + """Load camera configuration from YAML file. + + Args: + config_file: Path to configuration file. If None, uses default. + + Returns: + Camera configuration dictionary. + + Raises: + FileNotFoundError: If configuration file doesn't exist. + """ + if config_file is None: + config_file = CONFIG_FILE + + if not os.path.exists(config_file): + raise FileNotFoundError(f"Camera config not found: {config_file}") + + with open(config_file, 'r') as f: + cfg = yaml.safe_load(f) + + return cfg + + +def get_camera_matrix(cfg: dict[str, Any] | None = None) -> np.ndarray: + """Get camera intrinsic matrix K, adjusted for ROI offset. + + When ROI is set, cx and cy are shifted by the ROI offset so that the + intrinsics remain valid for the cropped image. + + Args: + cfg: Camera configuration dict. If None, loads from file. + + Returns: + 3x3 camera intrinsic matrix. 
+ """ + if cfg is None: + cfg = load_camera_config() + + intr = cfg['intrinsics'] + roi = cfg['roi'] + K = np.array([ + [intr['fx'], 0, intr['cx'] - roi['offset_x']], + [0, intr['fy'], intr['cy'] - roi['offset_y']], + [0, 0, 1] + ], dtype=np.float64) + return K + + +def get_dist_coeffs(cfg: dict[str, Any] | None = None) -> np.ndarray: + """Get distortion coefficients. + + Args: + cfg: Camera configuration dict. If None, loads from file. + + Returns: + Distortion coefficients array [k1, k2, p1, p2, k3]. + """ + if cfg is None: + cfg = load_camera_config() + + dist = cfg['distortion'] + return np.array([ + dist['k1'], dist['k2'], dist['p1'], dist['p2'], dist['k3'] + ], dtype=np.float64) + + +def get_roi(cfg: dict[str, Any] | None = None) -> tuple[int, int, int, int]: + """Get ROI parameters. + + Args: + cfg: Camera configuration dict. If None, loads from file. + + Returns: + Tuple of (offset_x, offset_y, width, height). + """ + if cfg is None: + cfg = load_camera_config() + + roi = cfg['roi'] + return roi['offset_x'], roi['offset_y'], roi['width'], roi['height'] + + +def get_capture_settings(cfg: dict[str, Any] | None = None) -> dict[str, Any]: + """Get camera capture settings. + + Args: + cfg: Camera configuration dict. If None, loads from file. + + Returns: + Dictionary with exposure_time, gain, and frame_rate. + """ + if cfg is None: + cfg = load_camera_config() + + cap = cfg['capture'] + return { + 'exposure_time': cap['exposure_time'], + 'gain': cap['gain'], + 'frame_rate': cap.get('frame_rate', 0), + } + + +def setup_camera_roi(cam: Any, cfg: dict[str, Any] | None = None) -> tuple[int, int]: + """Setup camera ROI from config. + + Args: + cam: MvCamera instance. + cfg: Camera configuration dict. If None, loads from file. + + Returns: + Tuple of (width, height) of the configured ROI. 
+ """ + offset_x, offset_y, width, height = get_roi(cfg) + cam.MV_CC_SetIntValueEx("OffsetX", offset_x) + cam.MV_CC_SetIntValueEx("OffsetY", offset_y) + cam.MV_CC_SetIntValueEx("Width", width) + cam.MV_CC_SetIntValueEx("Height", height) + print(f"Camera ROI: {width}x{height} @ ({offset_x}, {offset_y})") + return width, height + + +def setup_camera_capture(cam: Any, cfg: dict[str, Any] | None = None) -> None: + """Setup camera capture settings from config. + + Args: + cam: MvCamera instance. + cfg: Camera configuration dict. If None, loads from file. + """ + settings = get_capture_settings(cfg) + cam.MV_CC_SetFloatValue("ExposureTime", settings['exposure_time']) + cam.MV_CC_SetFloatValue("Gain", settings['gain']) + # Frame rate: enable explicit control and set target + frame_rate = settings.get('frame_rate', 0) + if frame_rate and frame_rate > 0: + ret1 = cam.MV_CC_SetBoolValue("AcquisitionFrameRateEnable", True) + ret2 = cam.MV_CC_SetFloatValue("AcquisitionFrameRate", float(frame_rate)) + # Read back actual resulting frame rate + from ctypes import c_float, byref + actual_fps = c_float(0) + ret3 = cam.MV_CC_GetFloatValue("ResultingFrameRate", actual_fps) + if ret3 == 0: + actual_str = f", actual={actual_fps.value:.1f}Hz" + else: + actual_str = ", actual=unknown" + print(f"Camera capture: exposure={settings['exposure_time']}us, gain={settings['gain']}, " + f"frame_rate={frame_rate}Hz (enable_ret=0x{ret1:X}, set_ret=0x{ret2:X}{actual_str})") + else: + print(f"Camera capture: exposure={settings['exposure_time']}us, gain={settings['gain']}, frame_rate=default") + + +if __name__ == "__main__": + print("=" * 50) + print("Camera Config Test") + print("=" * 50) + + cfg = load_camera_config() + print("\nCamera Config loaded:") + print(f" ROI: {get_roi(cfg)}") + print(f" K:\n{get_camera_matrix(cfg)}") + print(f" Dist: {get_dist_coeffs(cfg)}") + print(f" Capture: {get_capture_settings(cfg)}") diff --git a/examples/wuji/src/genelab_wuji/deploy/config.py 
b/examples/wuji/src/genelab_wuji/deploy/config.py new file mode 100644 index 00000000..6f8dddc1 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/config.py @@ -0,0 +1,45 @@ +"""Shared deploy constants: joint ordering, remap, and home pose. + +Two DIFFERENT joint orders are in play and MUST be remapped between (this was the +real-hand 0%-success bug): + +* **Encoder / hardware order** = ``JOINT_NAMES_20`` = ``REORIENT_JOINT_POS`` keys = + ``wujihandpy``'s (5, 4) row-major flatten: **finger-major** (finger1_joint1..4, + finger2_joint1..4, ...). This is what ``read_encoders`` / ``write_target`` speak. +* **Policy / Genesis articulation order** = ``POLICY_JOINT_NAMES``: **joint-major** + (finger1..5_joint1, then finger1..5_joint2, ...). Genesis orders the articulation + this way regardless of the MJCF element order, so the trained policy's obs and + action are joint-major. + +``DeployController`` remaps encoder->policy on read and policy->encoder on write via +``ENC_TO_POLICY`` / its inverse. ``tests/test_examples_wuji_deploy_joint_order.py`` +pins ``POLICY_JOINT_NAMES`` against the actual built env so the constant can't drift. +""" + +from __future__ import annotations + +import numpy as np + +from genelab_wuji.reorient.constants import REORIENT_JOINT_POS + +JOINT_NAMES_20: tuple[str, ...] = tuple(REORIENT_JOINT_POS) +"""The 20 hand joint names in encoder / hardware order (finger-major).""" + +N_JOINTS: int = len(JOINT_NAMES_20) +"""Hand DOF count (20 = 5 fingers x 4 joints).""" + +ENC_TO_POLICY: tuple[int, ...] = (0, 4, 8, 12, 16, 1, 5, 9, 13, 17, 2, 6, 10, 14, 18, 3, 7, 11, 15, 19) +"""``policy_order[i] = encoder_order[ENC_TO_POLICY[i]]`` (encoder = ``JOINT_NAMES_20``).""" + +POLICY_JOINT_NAMES: tuple[str, ...] 
= tuple(JOINT_NAMES_20[i] for i in ENC_TO_POLICY) +"""Hand joint names in policy / Genesis articulation order (joint-major).""" + + +def default_joint_pos() -> np.ndarray: + """Home grasp keyframe as a ``(20,)`` array, in ``JOINT_NAMES_20`` (encoder) order.""" + return np.array([REORIENT_JOINT_POS[name] for name in JOINT_NAMES_20], dtype=float) + + +def default_joint_pos_policy() -> np.ndarray: + """Home grasp keyframe ``(20,)`` in policy / articulation order (joint-major).""" + return default_joint_pos()[list(ENC_TO_POLICY)] diff --git a/examples/wuji/src/genelab_wuji/deploy/config/camera.yaml b/examples/wuji/src/genelab_wuji/deploy/config/camera.yaml new file mode 100644 index 00000000..07912cc7 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/config/camera.yaml @@ -0,0 +1,28 @@ +sensor: + full_width: 1280 + full_height: 1024 +roi: + offset_x: 0 + offset_y: 0 + width: 1280 + height: 1024 +fast_roi: + offset_x: 472 + offset_y: 216 + width: 552 + height: 776 +intrinsics: + fx: 1694.09 + fy: 1692.69 + cx: 644.7 + cy: 477.7 +distortion: + k1: -0.071205 + k2: 0.129295 + p1: 0.000104 + p2: 0.0001 + k3: -0.128663 +capture: + exposure_time: 5000 + gain: 10.0 + frame_rate: 90 diff --git a/examples/wuji/src/genelab_wuji/deploy/config/cube_tags.json b/examples/wuji/src/genelab_wuji/deploy/config/cube_tags.json new file mode 100644 index 00000000..e0e7579e --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/config/cube_tags.json @@ -0,0 +1,156 @@ +{ + "description": "Cube ArUco tag configuration", + "cube_size": 0.054, + "tag_size": 0.013, + "tag_center_offset": 0.018, + "face_rotations": { + "TOP": 0, + "BOTTOM": 180, + "FRONT": 90, + "BACK": 90, + "LEFT": 90, + "RIGHT": 270 + }, + "faces_config": { + "TOP": { + "0": "T", + "2": "R", + "3": "B", + "1": "L" + }, + "BOTTOM": { + "11": "T", + "9": "R", + "8": "B", + "10": "L" + }, + "FRONT": { + "22": "T", + "23": "R", + "21": "B", + "20": "L" + }, + "BACK": { + "18": "T", + "19": "R", + "17": "B", + "16": "L" + 
}, + "LEFT": { + "15": "R", + "14": "T", + "13": "B", + "12": "L" + }, + "RIGHT": { + "5": "T", + "4": "R", + "7": "L", + "6": "B" + } + }, + "face_axes": { + "TOP": { + "center": [ + 0, + 0, + 1 + ], + "u": [ + 1, + 0, + 0 + ], + "v": [ + 0, + 1, + 0 + ] + }, + "BOTTOM": { + "center": [ + 0, + 0, + -1 + ], + "u": [ + 1, + 0, + 0 + ], + "v": [ + 0, + -1, + 0 + ] + }, + "FRONT": { + "center": [ + 0, + -1, + 0 + ], + "u": [ + 1, + 0, + 0 + ], + "v": [ + 0, + 0, + 1 + ] + }, + "BACK": { + "center": [ + 0, + 1, + 0 + ], + "u": [ + -1, + 0, + 0 + ], + "v": [ + 0, + 0, + 1 + ] + }, + "LEFT": { + "center": [ + -1, + 0, + 0 + ], + "u": [ + 0, + -1, + 0 + ], + "v": [ + 0, + 0, + 1 + ] + }, + "RIGHT": { + "center": [ + 1, + 0, + 0 + ], + "u": [ + 0, + 1, + 0 + ], + "v": [ + 0, + 0, + 1 + ] + } + } +} diff --git a/examples/wuji/src/genelab_wuji/deploy/config/observer.yaml b/examples/wuji/src/genelab_wuji/deploy/config/observer.yaml new file mode 100644 index 00000000..1f2df69d --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/config/observer.yaml @@ -0,0 +1,34 @@ +# Cube World Observer Configuration +# Adjust these parameters to tune tracking behavior +# Note: ZMQ port is configured in control.yaml (zmq.cube_port) + +# SO3 Kalman Filter for rotation +rotation_filter: + process_noise: 0.5 # Higher = more agile, allows faster rotation changes + measurement_noise: 0.1 # Lower = trust PnP measurement more + +# Low-pass filter for position +position_filter: + alpha: 0.8 # Higher = faster tracking (0-1, 1=no filter) + +# Presets (uncomment to use): +# +# Agile (fast response, more noise): +# process_noise: 0.5 +# measurement_noise: 0.1 +# alpha: 0.8 +# +# Smooth (stable, slower response): +# process_noise: 0.01 +# measurement_noise: 2.0 +# alpha: 0.2 + +# PnP solver +pnp: + reproj_threshold: 6.0 # px; mean reprojection error after PnP refinement; treated as failure above this threshold + +# Image preprocessing +preprocess: + enable_clahe: true # if disabled, use min-channel 
grayscale directly without contrast enhancement + clahe_clip: 2.0 + clahe_tile: [8, 8] diff --git a/examples/wuji/src/genelab_wuji/deploy/controller.py b/examples/wuji/src/genelab_wuji/deploy/controller.py new file mode 100644 index 00000000..1d880072 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/controller.py @@ -0,0 +1,123 @@ +"""Closed-loop deploy controller (hardware/sim-agnostic). + +Wires the deploy pieces into one control step: + + read encoders -> build policy obs (cube/goal from the observer feed) + -> ONNX policy -> EMA action -> write joint target to the hand + +It depends only on small protocols (a hand driver, a cube source, a goal source, +a callable policy), so it runs headlessly with mocks in tests and with the real +hand + ZMQ + Genesis viewer in ``scripts/play_real.py``. +""" + +from __future__ import annotations + +from typing import Any, Protocol + +import numpy as np + +from genelab_wuji.deploy.action import ActionProcessor +from genelab_wuji.deploy.config import N_JOINTS +from genelab_wuji.deploy.obs import DeployObsBuilder + + +class _Policy(Protocol): + def __call__(self, obs: np.ndarray) -> np.ndarray: ... + + +class _CubeSource(Protocol): + def latest(self) -> tuple[np.ndarray, np.ndarray]: ... + + +class _GoalSource(Protocol): + def latest(self) -> np.ndarray: ... + + +class _Driver(Protocol): + def home(self) -> None: ... + def write_target(self, qpos: np.ndarray) -> None: ... + def read_encoders(self) -> np.ndarray: ... 
+ + +class DeployController: + """Run the policy in closed loop against a hand driver and observer feeds.""" + + def __init__( + self, + policy: _Policy, + driver: _Driver, + cube_source: _CubeSource, + goal_source: _GoalSource, + *, + default_joint_pos: np.ndarray, + control_dt: float = 0.05, + action_scale: float = 0.5, + ema_alpha: float = 0.5, + warmup_steps: int = 8, + joint_pos_limits: tuple[np.ndarray, np.ndarray] | None = None, + enc_to_policy: np.ndarray | None = None, + ) -> None: + self.policy = policy + self.driver = driver + self.cube_source = cube_source + self.goal_source = goal_source + self.control_dt = control_dt + # Joint-order remap between the driver (encoder/hardware order) and the policy + # (Genesis articulation order). ``None`` = identity. ``default_joint_pos`` must be + # in the SAME order the policy uses (policy order when a remap is given). + self._enc_to_policy = None if enc_to_policy is None else np.asarray(enc_to_policy) + self._policy_to_enc = ( + None if self._enc_to_policy is None else np.argsort(self._enc_to_policy) + ) + self._default = np.asarray(default_joint_pos, dtype=float) + self._obs = DeployObsBuilder(self._default) + self._action_proc = ActionProcessor( + self._default, + action_scale=action_scale, + ema_alpha=ema_alpha, + warmup_steps=warmup_steps, + joint_pos_limits=joint_pos_limits, + ) + self._last_action = np.zeros(N_JOINTS) + self._prev_joint_pos = self._default.copy() + + def _to_policy(self, v: np.ndarray) -> np.ndarray: + """Reorder an encoder/hardware-order vector into policy order.""" + return v if self._enc_to_policy is None else v[self._enc_to_policy] + + def _to_hardware(self, v: np.ndarray) -> np.ndarray: + """Reorder a policy-order vector into encoder/hardware order.""" + return v if self._policy_to_enc is None else v[self._policy_to_enc] + + def reset(self) -> None: + """Home the hand and clear obs/action/velocity state.""" + self.driver.home() + self._obs.reset() + self._action_proc.reset() + 
self._last_action = np.zeros(N_JOINTS) + self._prev_joint_pos = self._to_policy(self.driver.read_encoders()) + + def step(self) -> dict[str, Any]: + """Run one control step; return ``{action, target, obs, joint_pos}``.""" + encoder = self.driver.read_encoders() # hardware (encoder) order + joint_pos = self._to_policy(encoder) # policy order + joint_vel = (joint_pos - self._prev_joint_pos) / self.control_dt + self._prev_joint_pos = joint_pos + + cube_pos_tag, cube_quat_tag = self.cube_source.latest() + goal_quat_tag = self.goal_source.latest() + + obs = self._obs.compute( + joint_pos=joint_pos, + joint_vel=joint_vel, + cube_pos_tag=cube_pos_tag, + cube_quat_tag=cube_quat_tag, + goal_quat_tag=goal_quat_tag, + last_action=self._last_action, + ) + action = np.asarray(self.policy(obs), dtype=float) # policy order + target = self._action_proc.process(action) # policy order + self.driver.write_target(self._to_hardware(target)) # back to hardware order + self._last_action = action + # ``joint_pos`` returned in encoder/hardware order (for the viewer's name-based remap). + return {"action": action, "target": target, "obs": obs, "joint_pos": encoder} diff --git a/examples/wuji/src/genelab_wuji/deploy/cube_geom.py b/examples/wuji/src/genelab_wuji/deploy/cube_geom.py new file mode 100644 index 00000000..cf0b5f91 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/cube_geom.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2026 Wuji Technology Co., Ltd. +# Ported into GeneLab from wuji-mjlab deploy/reorient/lib/cube_geom.py +"""Shared helpers for cube geometry: cube_tags JSON resolution + runtime scaling. + +Used by both the vision pipeline (cube_world_observer.py — needs cube/tag sizes +for AprilTag/ArUco PnP) and the sim/benchmark pipeline (play_real.py — needs to +patch the MuJoCo cube body so visualization and physics match the real cube). 
+""" +from __future__ import annotations + +import json +import os +from typing import Any + +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +DEFAULT_CUBE_CONFIG_FILE = os.path.join(ROOT_DIR, "config", "cube_tags.json") + + +def resolve_cube_config_path(arg: str | None) -> str: + """Resolve a --cube CLI argument to an absolute cube_tags JSON path. + + Accepts: + - None / "" / "default" -> ``config/cube_tags.json`` (the 54mm baseline). + - An existing path (absolute or relative to cwd) -> used as-is. + - A short size suffix like ``"36"`` / ``"40_5"`` -> ``config/cube_tags{suffix}.json``. + + Raises FileNotFoundError if the resolved path does not exist. + """ + if not arg or arg == "default": + return DEFAULT_CUBE_CONFIG_FILE + if os.path.exists(arg): + return os.path.abspath(arg) + candidate = os.path.join(ROOT_DIR, "config", f"cube_tags{arg}.json") + if os.path.exists(candidate): + return candidate + raise FileNotFoundError( + f"--cube={arg!r}: not a path nor a known size suffix. Tried {candidate!r}." + ) + + +def load_cube_config(path: str) -> dict[str, Any]: + """Load a cube_tags JSON file.""" + with open(path) as f: + return json.load(f) diff --git a/examples/wuji/src/genelab_wuji/deploy/frame_transform.py b/examples/wuji/src/genelab_wuji/deploy/frame_transform.py new file mode 100644 index 00000000..aa00a9dd --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/frame_transform.py @@ -0,0 +1,70 @@ +"""Frame transforms for the deploy pipeline (numpy, wxyz quaternion convention). + +Two concerns live here: + +* **Quaternion math** (``quat_apply`` / ``quat_mul`` / ``quat_conjugate``) in the + wxyz Hamilton convention used across the MuJoCo / mjlab / GeneLab stack. +* **real2sim lift** (``cube_cam_to_tag``): the vision pipeline reports the cube + pose and the wrist-AprilTag pose both in the *camera* frame; ``cube_cam_to_tag`` + expresses the cube in the *tag* frame, which is the frame the policy observes. 
+ +All rotations are 3x3 matrices whose columns are the source-frame axes written in +the target frame (``R_a_b`` maps a vector in frame ``a`` into frame ``b``). +""" + +from __future__ import annotations + +import numpy as np + + +def quat_apply(quat_wxyz: np.ndarray, vec: np.ndarray) -> np.ndarray: + """Rotate ``vec`` by the unit quaternion ``quat_wxyz`` (Hamilton, wxyz).""" + w, x, y, z = quat_wxyz + R = np.array([ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [ 2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [ 2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ]) + return R @ vec + + +def quat_mul(q1: np.ndarray, q2: np.ndarray) -> np.ndarray: + """Hamilton product of two wxyz quaternions: ``q1 ∘ q2``.""" + w1, x1, y1, z1 = q1 + w2, x2, y2, z2 = q2 + return np.array([ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + ]) + + +def quat_conjugate(quat_wxyz: np.ndarray) -> np.ndarray: + """Conjugate (== inverse for a unit quat) of a wxyz quaternion.""" + w, x, y, z = quat_wxyz + return np.array([w, -x, -y, -z]) + + +def cube_cam_to_tag( + R_tag_cam: np.ndarray, + t_tag_cam: np.ndarray, + R_cube_cam: np.ndarray, + t_cube_cam: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Express the cube pose in the wrist-tag frame. + + Args: + R_tag_cam: ``(3, 3)`` tag rotation in the camera frame. + t_tag_cam: ``(3,)`` tag origin in the camera frame. + R_cube_cam: ``(3, 3)`` cube rotation in the camera frame. + t_cube_cam: ``(3,)`` cube center in the camera frame. + + Returns: + ``(R_cube_tag, t_cube_tag)`` — the cube pose in the tag frame. 
+ """ + R_cam_tag = R_tag_cam.T + t_cam_tag = -R_cam_tag @ t_tag_cam + R_cube_tag = R_cam_tag @ R_cube_cam + t_cube_tag = R_cam_tag @ t_cube_cam + t_cam_tag + return R_cube_tag, t_cube_tag diff --git a/examples/wuji/src/genelab_wuji/deploy/hand_driver.py b/examples/wuji/src/genelab_wuji/deploy/hand_driver.py new file mode 100644 index 00000000..9638c205 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/hand_driver.py @@ -0,0 +1,130 @@ +"""Hand-driver abstraction: hardware-agnostic interface + mock + real (wujihandpy). + +The control loop depends only on ``HandDriverBase``. ``MockHandDriver`` echoes +written targets so the full pipeline runs and tests headlessly; ``WujiHandDriver`` +talks to the real hand and is imported lazily so the dependency is optional. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import numpy as np + +from genelab_wuji.deploy.config import JOINT_NAMES_20, N_JOINTS, default_joint_pos + + +def _home_ramp(current: np.ndarray, target: np.ndarray, steps: int) -> np.ndarray: + """Ease-in-out (smoothstep) interpolation from ``current`` to ``target``. + + Returns ``(steps, 20)`` intermediate targets; the last row equals ``target`` + exactly (smoothstep ``3t²-2t³`` reaches 1 at ``t=1``). Pure/numpy so the ramp + math is testable without hardware. 
+ """ + current = np.asarray(current, dtype=float) + target = np.asarray(target, dtype=float) + steps = max(1, int(steps)) + t = (np.arange(1, steps + 1, dtype=float) / steps)[:, None] # (steps, 1), ends at 1.0 + t_smooth = t * t * (3.0 - 2.0 * t) + return current[None, :] + t_smooth * (target - current)[None, :] + + +class HandDriverBase(ABC): + """Interface every hand backend implements (targets/encoders flattened to 20).""" + + @abstractmethod + def home(self, duration_s: float = 3.0) -> None: + """Drive the hand to the home grasp keyframe (ease-in-out ramp over ``duration_s``).""" + + @abstractmethod + def write_target(self, qpos: np.ndarray) -> None: + """Command a ``(20,)`` joint position target (JOINT_NAMES_20 order).""" + + @abstractmethod + def read_encoders(self) -> np.ndarray: + """Read the actual ``(20,)`` joint positions (JOINT_NAMES_20 order).""" + + def joint_names_in_encoder_order(self) -> tuple[str, ...]: + """Joint names matching ``read_encoders`` / ``write_target`` indexing.""" + return JOINT_NAMES_20 + + +class MockHandDriver(HandDriverBase): + """In-memory hand: ``read_encoders`` echoes the last ``write_target``. + + Starts at the home grasp pose so a fresh driver reads a sensible state. + """ + + def __init__(self) -> None: + self._state = default_joint_pos() + + def home(self, duration_s: float = 3.0) -> None: + # No hardware to ease; the ramp is a real-driver safety concern only. + self._state = default_joint_pos() + + def write_target(self, qpos: np.ndarray) -> None: + qpos = np.asarray(qpos, dtype=float) + if qpos.shape != (N_JOINTS,): + raise ValueError(f"qpos shape {qpos.shape}, expected ({N_JOINTS},)") + self._state = qpos.copy() + + def read_encoders(self) -> np.ndarray: + return self._state.copy() + + +class WujiHandDriver(HandDriverBase): + """Real Wuji hand via ``wujihandpy`` (imported lazily; untested in CI). 
+ + The hardware exposes a (5, 4) array (5 fingers x 4 joints); we flatten to + (20,) at the boundary, which matches ``JOINT_NAMES_20`` row-major order. + Use as a context manager so joints are enabled on enter / disabled on exit. + """ + + def __init__(self, effort_limit_nm: float = 0.5) -> None: + import wujihandpy # noqa: F401 (fail loudly if the dep is missing) + + self._wujihandpy = wujihandpy + self.effort_limit_nm = effort_limit_nm + self._hand: Any = None + + def __enter__(self) -> "WujiHandDriver": + self._hand = self._wujihandpy.Hand() + self._hand.write_joint_effort_limit(self.effort_limit_nm) + self._hand.write_joint_enabled(True) + return self + + def __exit__(self, *exc: object) -> None: + if self._hand is not None: + self._hand.write_joint_enabled(False) + self._hand = None + + def home(self, duration_s: float = 3.0) -> None: + """Smoothly ramp from the current pose to the home grasp keyframe. + + Ease-in-out interpolation at 50 Hz over ``duration_s`` so the hand eases + in rather than snapping (a single instant write can jerk the joints). + ``duration_s <= 0`` does one immediate write. 
+ """ + import time + + target = default_joint_pos() + if duration_s <= 0: + self.write_target(target) + return + steps = max(1, int(duration_s * 50.0)) # 50 Hz smoothing + dt = duration_s / steps + for frame in _home_ramp(self.read_encoders(), target, steps): + self.write_target(frame) + time.sleep(dt) + + def write_target(self, qpos: np.ndarray) -> None: + assert self._hand is not None, "enter the WujiHandDriver context first" + qpos = np.asarray(qpos, dtype=float) + if qpos.shape != (N_JOINTS,): + raise ValueError(f"qpos shape {qpos.shape}, expected ({N_JOINTS},)") + self._hand.write_joint_target_position(qpos.reshape(5, 4)) + + def read_encoders(self) -> np.ndarray: + assert self._hand is not None, "enter the WujiHandDriver context first" + return np.asarray(self._hand.read_joint_actual_position(), dtype=float).reshape(N_JOINTS) diff --git a/examples/wuji/src/genelab_wuji/deploy/obs.py b/examples/wuji/src/genelab_wuji/deploy/obs.py new file mode 100644 index 00000000..7bed480b --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/obs.py @@ -0,0 +1,97 @@ +"""Assemble the policy observation vector for deploy (pure numpy). + +Reproduces the GeneLab training policy obs group so an exported ONNX policy +receives exactly what it was trained on. No forward kinematics is required: joint +state comes from the hand encoders, cube/goal poses arrive from the observer +already in the wrist-tag frame, and the last action is tracked by the action term. + +Term order, per-term history, and the 6D goal-error encoding mirror +``genelab_wuji.reorient.mdp.observations`` and ``...mdp._math``. 
def _matrix_from_quat(quat_wxyz: np.ndarray) -> np.ndarray:
    """Rotation matrix from a wxyz quaternion (columns are the rotated basis)."""
    return np.stack(
        [
            quat_apply(quat_wxyz, np.array([1.0, 0.0, 0.0])),
            quat_apply(quat_wxyz, np.array([0.0, 1.0, 0.0])),
            quat_apply(quat_wxyz, np.array([0.0, 0.0, 1.0])),
        ],
        axis=1,
    )


def goal_rot_err_6d(cube_quat_tag: np.ndarray, goal_quat_tag: np.ndarray) -> np.ndarray:
    """6D rotation error (first two matrix rows) of cube-to-goal, tag frame.

    Mirrors ``genelab_wuji.reorient.mdp.observations.goal_rot_err_6d``:
    ``matrix_to_rotation_6d(matrix_from_quat(cube_quat ∘ goal_quat*))``.
    """
    err_quat = quat_mul(cube_quat_tag, quat_conjugate(goal_quat_tag))
    rot = _matrix_from_quat(err_quat)
    return rot[:2, :].reshape(6)


class DeployObsBuilder:
    """Build the 207-dim policy obs with per-term 3-step history.

    Each term keeps a ``(history_len, dim)`` buffer, term-major oldest->newest.
    ``reset()`` clears the buffers; the first ``compute`` after a reset backfills
    every history slot with the current frame (matches the training CircularBuffer).
    """

    # (name, dim) in the order the GeneLab policy obs group concatenates them.
    _TERMS: tuple[tuple[str, int], ...] = (
        ("joint_pos_rel", 20),
        ("joint_vel_rel", 20),
        ("cube_pos_in_tag", 3),
        ("goal_rot_err_6d", 6),
        ("last_action", 20),
    )

    def __init__(self, default_joint_pos: np.ndarray, history_len: int = 3) -> None:
        self.default_joint_pos = np.asarray(default_joint_pos, dtype=float)
        self.history_len = history_len
        self._buffers: dict[str, np.ndarray] = {}

    def reset(self) -> None:
        """Clear history so the next ``compute`` backfills."""
        self._buffers = {}

    @staticmethod
    def _as_vector(name: str, value: np.ndarray, dim: int) -> np.ndarray:
        """Coerce ``value`` to a float array and require shape ``(dim,)``.

        Raises:
            ValueError: if the shape differs. Without this check a mis-shaped
                input would silently yield an obs of the wrong length, or (for
                e.g. a ``(1, 20)`` input) re-backfill the history every frame.
        """
        vec = np.asarray(value, dtype=float)
        if vec.shape != (dim,):
            raise ValueError(f"obs term {name!r} has shape {vec.shape}, expected ({dim},)")
        return vec

    def _push(self, name: str, value: np.ndarray) -> np.ndarray:
        """Append ``value`` to the term buffer (backfill on first frame); flatten."""
        value = np.asarray(value, dtype=float)
        buf = self._buffers.get(name)
        if buf is None or buf.shape != (self.history_len, value.shape[0]):
            # First frame after construction/reset (or a changed layout):
            # fill every history slot with the current value.
            buf = np.repeat(value[None, :], self.history_len, axis=0)
        else:
            # Drop the oldest row, append the newest.
            buf = np.concatenate([buf[1:], value[None, :]], axis=0)
        self._buffers[name] = buf
        return buf.reshape(-1)

    def compute(
        self,
        joint_pos: np.ndarray,
        joint_vel: np.ndarray,
        cube_pos_tag: np.ndarray,
        cube_quat_tag: np.ndarray,
        goal_quat_tag: np.ndarray,
        last_action: np.ndarray,
    ) -> np.ndarray:
        """Return the flat policy obs vector for this control step.

        Raises:
            ValueError: if any term does not match the dimension declared in
                ``_TERMS``.
        """
        dims = dict(self._TERMS)
        # Validate raw inputs up front so a bad vector is reported by name.
        joint_pos = self._as_vector("joint_pos", joint_pos, dims["joint_pos_rel"])
        joint_vel = self._as_vector("joint_vel", joint_vel, dims["joint_vel_rel"])
        cube_pos_tag = self._as_vector("cube_pos_tag", cube_pos_tag, dims["cube_pos_in_tag"])
        last_action = self._as_vector("last_action", last_action, dims["last_action"])
        frame = {
            "joint_pos_rel": joint_pos - self.default_joint_pos,
            "joint_vel_rel": joint_vel,
            "cube_pos_in_tag": cube_pos_tag,
            "goal_rot_err_6d": goal_rot_err_6d(cube_quat_tag, goal_quat_tag),
            "last_action": last_action,
        }
        # Re-check the derived terms too (default-pose broadcast, 6D error).
        blocks = [
            self._push(name, self._as_vector(name, frame[name], dim))
            for name, dim in self._TERMS
        ]
        return np.concatenate(blocks).astype(np.float32)
metadata). + +Obs assembly and action post-processing live elsewhere (``DeployObsBuilder`` / +the action term), so this is just: load the session, introspect dims, optionally +read the sibling ``.metadata.json`` GeneLab's exporter writes, and run a +single forward pass. Normalization is baked into the graph, so no normalizer is +needed at deploy time. +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import onnxruntime as ort + + +class ONNXPolicy: + """Single-step ONNX policy. + + Args: + onnx_path: Path to the ``policy.onnx`` exported by ``genelab export``. + metadata_path: Optional metadata sidecar. When ``None``, looks for + ``.metadata.json`` then ``/metadata.json``. + """ + + def __init__( + self, + onnx_path: str | Path, + metadata_path: Optional[str | Path] = None, + ) -> None: + onnx_path = str(onnx_path) + if not os.path.exists(onnx_path): + raise FileNotFoundError(f"ONNX not found: {onnx_path}") + + self.onnx_path: str = onnx_path + self.session = ort.InferenceSession( + onnx_path, providers=["CPUExecutionProvider"] + ) + inp = self.session.get_inputs()[0] + out = self.session.get_outputs()[0] + self.input_name: str = inp.name + self.output_name: str = out.name + # Shapes are (batch, N); axis 0 may be a dynamic symbol — take the last dim. 
+ self.input_dim: int = int(inp.shape[-1]) + self.action_dim: int = int(out.shape[-1]) + self.metadata: dict[str, Any] = self._load_metadata(onnx_path, metadata_path) + + @staticmethod + def _load_metadata( + onnx_path: str, metadata_path: Optional[str | Path] + ) -> dict[str, Any]: + candidates = ( + [str(metadata_path)] + if metadata_path is not None + else [ + onnx_path + ".metadata.json", + os.path.join(os.path.dirname(onnx_path), "metadata.json"), + ] + ) + for candidate in candidates: + if os.path.exists(candidate): + with open(candidate) as f: + return json.load(f) + return {} + + def __call__(self, obs: np.ndarray) -> np.ndarray: + """Single forward pass; accepts ``(input_dim,)`` or ``(1, input_dim)``.""" + if obs.ndim == 1: + obs = obs[None, :] + if obs.shape != (1, self.input_dim): + raise ValueError( + f"obs shape {obs.shape}, expected (1, {self.input_dim})" + ) + obs = obs.astype(np.float32, copy=False) + result = self.session.run([self.output_name], {self.input_name: obs})[0] + return result.squeeze(0) diff --git a/examples/wuji/src/genelab_wuji/deploy/real2sim.py b/examples/wuji/src/genelab_wuji/deploy/real2sim.py new file mode 100644 index 00000000..8c410da9 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/real2sim.py @@ -0,0 +1,54 @@ +"""Reproduce the real cube inside the Genesis sim from its tag-frame pose. + +The vision pipeline reports the cube pose in the wrist-AprilTag frame (the frame +the policy observes). To visualize it in sim we need the tag's pose in sim-world +coordinates, then lift the cube through it. For the fixed-base hand the tag world +pose is constant and derived from the palm pose via the ``TAG_IN_PALM`` rigid +offset (mirrors ``genelab_wuji.reorient.mdp.observations._tag_pose``). + +``cube_pose_in_tag_to_world`` (viewer) and ``cube_pose_world_to_tag`` (obs) are +exact inverses so what the policy sees and what the viewer draws agree. 
def tag_pose_in_world(
    palm_pos_w: np.ndarray, palm_quat_w: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Compose the palm world pose with the fixed tag-in-palm offset.

    Returns ``(tag_pos_w, tag_quat_w)``: the offset position rotated into the
    world by the palm orientation and added to the palm position, and the palm
    quaternion multiplied on the right by the offset quaternion.
    """
    offset_pos = np.asarray(TAG_IN_PALM_POS, dtype=float)
    offset_quat = np.asarray(TAG_IN_PALM_QUAT_WXYZ, dtype=float)
    return (
        palm_pos_w + quat_apply(palm_quat_w, offset_pos),
        quat_mul(palm_quat_w, offset_quat),
    )


def cube_pose_in_tag_to_world(
    tag_pos_w: np.ndarray,
    tag_quat_w: np.ndarray,
    cube_pos_tag: np.ndarray,
    cube_quat_tag: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Map a cube pose given in the tag frame to world coordinates (viewer side)."""
    world_pos = tag_pos_w + quat_apply(tag_quat_w, cube_pos_tag)
    world_quat = quat_mul(tag_quat_w, cube_quat_tag)
    return world_pos, world_quat


def cube_pose_world_to_tag(
    tag_pos_w: np.ndarray,
    tag_quat_w: np.ndarray,
    cube_pos_w: np.ndarray,
    cube_quat_w: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Map a world cube pose into the tag frame.

    Undoes ``cube_pose_in_tag_to_world``: subtract the tag position, then rotate
    position and orientation by the conjugate of the tag quaternion.
    """
    world_to_tag = quat_conjugate(tag_quat_w)
    offset_w = cube_pos_w - tag_pos_w
    return quat_apply(world_to_tag, offset_w), quat_mul(world_to_tag, cube_quat_w)
def build_reorient_env(num_envs: int | None = None) -> Any:
    """Create the play-mode reorient environment with auto-reset switched off.

    ``num_envs``, when given, replaces ``cfg.simulation.num_envs``; when ``None``
    the value from the play config is kept. The env is reset once before it is
    returned.
    """
    from genelab.envs.manager_based_rl_env import ManagerBasedRlEnv
    from genelab_wuji.reorient.env_cfg import wuji_hand_reorient_env_cfg

    env_cfg = wuji_hand_reorient_env_cfg(play=True)
    # Poses are driven by hand; a "done" signal must never teleport the scene.
    env_cfg.auto_reset = False
    if num_envs is not None:
        env_cfg.simulation.num_envs = num_envs

    sim_env = ManagerBasedRlEnv(env_cfg)
    sim_env.reset()
    return sim_env


def tag_world_pose(env: Any) -> tuple[np.ndarray, np.ndarray]:
    """Return env 0's wrist-tag world pose as numpy ``(position, quaternion)``.

    Delegates to the task's own ``_tag_pose`` so the viewer and the observation
    use the same frame definition.
    """
    from genelab_wuji.reorient.mdp.observations import _tag_pose

    positions, quats = _tag_pose(env)  # batched torch tensors; row 0 is used
    first_pos = positions[0].detach().cpu().numpy()
    first_quat = quats[0].detach().cpu().numpy()
    return first_pos, first_quat


def set_cube_pose(env: Any, pos_w: np.ndarray, quat_w: np.ndarray) -> None:
    """Place the sim cube at a world pose and zero its linear/angular velocity.

    Each setter is looked up on the Genesis handle and skipped when absent.
    """
    import torch

    handle = env.scene["object"].gs_handle
    device = env.device
    pos = torch.tensor(pos_w, dtype=torch.float, device=device).unsqueeze(0)
    quat = torch.tensor(quat_w, dtype=torch.float, device=device).unsqueeze(0)
    zeros = torch.zeros(1, 3, device=device)

    # Insertion order is the call order: position, orientation, then velocities.
    updates = {"set_pos": pos, "set_quat": quat, "set_vel": zeros, "set_ang": zeros}
    for attr, payload in updates.items():
        setter = getattr(handle, attr, None)
        if setter is None:
            continue
        setter(payload)


def set_goal_marker(env: Any, quat_w: np.ndarray) -> None:
    """Rotate the ``goal_marker`` entity to ``quat_w`` (viewer only).

    Only the orientation is written; the marker's position is left alone.
    Returns silently when the scene has no ``goal_marker`` (or no handle), or
    when the handle has no ``set_quat``.
    """
    import torch

    try:
        handle = env.scene["goal_marker"].gs_handle
    except (KeyError, AttributeError):
        return

    quat = torch.tensor(quat_w, dtype=torch.float, device=env.device).unsqueeze(0)
    setter = getattr(handle, "set_quat", None)
    if setter is None:
        return
    setter(quat)


def set_hand_joints(env: Any, qpos_encoder_order: np.ndarray) -> None:
    """Write encoder readings onto the sim robot as joint state (zero velocity).

    ``qpos_encoder_order`` is indexed like ``JOINT_NAMES_20``; it is permuted by
    joint name into ``robot.joint_names`` order before being written.

    Raises:
        ValueError: if the robot has a joint that is not in ``JOINT_NAMES_20``.
    """
    import torch

    from genelab_wuji.deploy.config import JOINT_NAMES_20

    robot = env.scene["robot"]
    encoder_slot = {name: idx for idx, name in enumerate(JOINT_NAMES_20)}
    unknown = [name for name in robot.joint_names if name not in encoder_slot]
    if unknown:
        raise ValueError(f"robot joints absent from encoder order JOINT_NAMES_20: {unknown}")

    # order[i] is the encoder index feeding robot joint i.
    order = [encoder_slot[name] for name in robot.joint_names]
    qpos_robot_order = np.asarray(qpos_encoder_order, dtype=float)[order]

    device = env.device
    joint_pos = torch.tensor(qpos_robot_order, dtype=torch.float, device=device).unsqueeze(0)
    joint_vel = torch.zeros_like(joint_pos)
    all_envs = torch.arange(env.num_envs, device=device)
    robot.write_joint_state(joint_pos, joint_vel, all_envs)
def _loop(env: Any, drv: HandDriverBase, cube_recv: CubeReceiver | None, rate_hz: float) -> None:
    """Mirror the hand encoders (and the cube, if received) into the viewer.

    Runs until ``env.viewer_closed`` becomes true or the user hits Ctrl+C.
    ``rate_hz`` values below 1 are treated as 1 Hz.
    """
    # The tag pose is read once: the hand base is fixed, so it does not move.
    tag_pos_w, tag_quat_w = tag_world_pose(env)
    period_s = 1.0 / max(rate_hz, 1.0)
    try:
        while not env.viewer_closed:
            encoders = drv.read_encoders()
            set_hand_joints(env, encoders)
            if cube_recv is not None:
                cube_pos_tag, cube_quat_tag = cube_recv.latest()
                cube_pos_w, cube_quat_w = cube_pose_in_tag_to_world(
                    tag_pos_w, tag_quat_w, cube_pos_tag, cube_quat_tag
                )
                set_cube_pose(env, cube_pos_w, cube_quat_w)
            # Refresh the visualizer only — no physics step, so gravity never
            # drags the teleported cube between frames.
            env.scene.refresh_visualizer()
            time.sleep(period_s)
    except KeyboardInterrupt:
        print("\n[calib-check] interrupted by user.")


def main() -> int:
    """Parse CLI flags, bring up the viewer and run the calibration loop.

    Returns ``0``. The cube receiver (if any) and the env are closed on the way
    out, including when the loop or the hand driver raises.
    """
    summary = (__doc__ or "").splitlines()[0]
    parser = argparse.ArgumentParser(description=summary)
    parser.add_argument("--cube-port", type=int, default=DEFAULT_CUBE_PORT)
    parser.add_argument("--host", default="localhost")
    parser.add_argument(
        "--no-cube-zmq",
        action="store_true",
        help="skip CubeReceiver — render only the hand (diagnose publisher issues).",
    )
    parser.add_argument("--rate-hz", type=float, default=20.0, help="viewer refresh rate")
    parser.add_argument(
        "--effort-limit", type=float, default=0.5, help="per-joint Nm for the real hand"
    )
    parser.add_argument(
        "--mock", action="store_true", help="use the mock hand (no hardware; renders home pose)"
    )
    args = parser.parse_args()

    cube_recv: CubeReceiver | None = None
    if args.no_cube_zmq:
        print("[calib-check] cube ZMQ disabled (--no-cube-zmq)")
    else:
        cube_recv = CubeReceiver(port=args.cube_port, host=args.host)
        print(f"[calib-check] CubeReceiver listening on tcp://{args.host}:{args.cube_port}")

    env = build_reorient_env(num_envs=1)  # one viewer, no parallel copies
    try:
        if not args.mock:
            # Imported only on the hardware path so --mock needs no hand SDK.
            from genelab_wuji.deploy.hand_driver import WujiHandDriver

            with WujiHandDriver(effort_limit_nm=args.effort_limit) as drv:
                print("[calib-check] homing hand ...")
                drv.home(duration_s=3.0)
                print(
                    "[calib-check] homed — viewer up. Compare the rendered cube against the "
                    "real cube. Ctrl+C / close window to exit."
                )
                _loop(env, drv, cube_recv, args.rate_hz)
        else:
            drv: HandDriverBase = MockHandDriver()
            drv.home()
            print("[calib-check] mock hand at home — viewer up. Ctrl+C / close window to exit.")
            _loop(env, drv, cube_recv, args.rate_hz)
    finally:
        if cube_recv is not None:
            cube_recv.close()
        env.close()
    return 0
+_mvs_python_path = os.environ.get("MVS_PYTHON_PATH", "/opt/MVS/Samples/64/Python") +if not os.path.isdir(os.path.join(_mvs_python_path, "MvImport")): + raise RuntimeError( + f"MvImport not found at {_mvs_python_path}/MvImport. " + "Install Hikvision MVS SDK (https://www.hikrobotics.com) or set " + "MVS_PYTHON_PATH env var to the dir containing MvImport/." + ) +sys.path.insert(0, _mvs_python_path) + +os.environ["OPENCV_LOG_LEVEL"] = "SILENT" + +from MvImport.MvCameraControl_class import * +from genelab_wuji.deploy.camera_config import ( + load_camera_config, get_camera_matrix, get_dist_coeffs, + setup_camera_roi, setup_camera_capture +) +from genelab_wuji.deploy.zmq_bridge import DEFAULT_CUBE_PORT, CubePublisher + +try: + from pupil_apriltags import Detector as AprilTagDetector +except ImportError: + print("ERROR: pupil_apriltags not installed. Run: pip install pupil-apriltags") + sys.exit(1) + +import zmq + +# Load camera config +_cam_cfg = load_camera_config() +K = get_camera_matrix(_cam_cfg) +DIST_COEFFS = get_dist_coeffs(_cam_cfg) + +# Config files +OBSERVER_CONFIG_FILE = os.path.join(os.path.dirname(SCRIPT_DIR), "config", "observer.yaml") # deploy/config/observer.yaml + +from genelab_wuji.deploy.cube_geom import ( + resolve_cube_config_path, + DEFAULT_CUBE_CONFIG_FILE as CUBE_CONFIG_FILE, +) + +# World origin = AprilTag ID 0 on the wrist. +WORLD_TAG_ID = 0 +WORLD_TAG_SIZE = 0.048 # 48mm +WORLD_SAMPLE_FRAMES = 100 # Number of frames to sample for world frame averaging + +# Optional world-frame correction; None = use AprilTag frame as-is. +# AprilTag detector (X-right, Y-down, Z-into-tag) -> MuJoCo wrist tag (right-handed) +# Pure handedness flip: same X, flipped Y and Z (printed tag X aligns with MuJoCo wrist tag X). 
WORLD_FRAME_CORRECTION = np.array([
    [ 1.0,  0.0,  0.0],
    [ 0.0, -1.0,  0.0],
    [ 0.0,  0.0, -1.0],
])
# WORLD_FRAME_CORRECTION = "+x +z -y"  # Example: remap axes
# WORLD_FRAME_CORRECTION = np.array([[1,0,0], [0,0,1], [0,-1,0]])  # Same as above


def parse_axis_remap(remap_str):
    """Turn an axis-remap string such as ``"+x +z -y"`` into a 3x3 matrix.

    Args:
        remap_str: Three whitespace-separated tokens, each one of
            ``+x -x +y -y +z -z`` (case-insensitive). Token ``i`` names the
            source axis (with sign) that becomes output axis ``i``.

    Returns:
        3x3 float matrix ``R`` whose row ``i`` is the signed unit vector named by
        token ``i``, so ``new_point = R @ apriltag_point``.

    Raises:
        ValueError: wrong token count, unknown token, repeated axis
            (``|det| != 1``) or a left-handed result (``det < 0``).
    """
    axis_index = {'x': 0, 'y': 1, 'z': 2}

    tokens = remap_str.lower().split()
    if len(tokens) != 3:
        raise ValueError(f"Axis remap must have 3 parts, got: {remap_str}")

    # Row i of the matrix says which source axis output axis i is taken from.
    rot = np.zeros((3, 3))
    for row, part in enumerate(tokens):
        well_formed = len(part) == 2 and part[0] in '+-' and part[1] in axis_index
        if not well_formed:
            raise ValueError(f"Invalid axis: {part}. Use +x,-x,+y,-y,+z,-z")
        rot[row, axis_index[part[1]]] = 1.0 if part[0] == '+' else -1.0

    # A proper rotation has det = +1; 0 means a repeated axis, -1 a mirror.
    det = np.linalg.det(rot)
    if not np.isclose(abs(det), 1.0):
        raise ValueError(f"Invalid axis remap: axes not orthogonal (det={det:.3f})")
    if det < 0:
        raise ValueError(f"Invalid axis remap: forms left-handed system (det={det:.3f}). "
                         "Hint: flip one axis sign to make it right-handed.")

    return rot

# No silent default — pass --cube to override config/cube_tags.json.
# Cube frame correction rotation matrix
# Corrects the difference between ArUco board coordinate system and MuJoCo mesh coordinate system
# Format: same as WORLD_FRAME_CORRECTION (matrix or axis remap string)
# Set to None if ArUco board axes match MuJoCo mesh axes
CUBE_FRAME_CORRECTION = None  # None = no correction
# Example: if ArUco X,Y,Z maps to MuJoCo Y,Z,X: CUBE_FRAME_CORRECTION = "+y +z +x"

# Face colors (matching MuJoCo dex_cube)
FACE_COLORS = {
    'TOP': ('Cyan', (255, 255, 0)),  # BGR
    'BOTTOM': ('Blue', (255, 0, 0)),
    'FRONT': ('Red', (0, 0, 255)),
    'BACK': ('White', (255, 255, 255)),
    'LEFT': ('Green', (0, 255, 0)),
    'RIGHT': ('Yellow', (0, 255, 255)),
}


def load_observer_config():
    """Load observer configuration from YAML, layered over built-in defaults.

    Returns a dict with the sections ``rotation_filter``, ``position_filter``,
    ``pnp`` and ``preprocess``. Keys present in the YAML file override the
    defaults section by section; unknown top-level sections are ignored.

    The merge is all-or-nothing: if the file cannot be read or any section is
    malformed, a warning is printed and the untouched defaults are returned
    (previously sections merged before the failure were kept even though the
    warning said "using defaults"). An empty file means "no overrides".
    """
    defaults = {
        'rotation_filter': {
            'process_noise': 0.1,
            'measurement_noise': 0.3,
        },
        'position_filter': {
            'alpha': 0.6,
        },
        'pnp': {
            'reproj_threshold': 6.0,
        },
        'preprocess': {
            'enable_clahe': True,
            'clahe_clip': 2.0,
            'clahe_tile': [8, 8],
        },
    }

    if not os.path.exists(OBSERVER_CONFIG_FILE):
        return defaults

    try:
        with open(OBSERVER_CONFIG_FILE, 'r') as f:
            cfg = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        overrides = {} if cfg is None else cfg
        # Build the merged result on the side so a bad section cannot leave
        # ``defaults`` half-updated.
        merged = {}
        for key, section in defaults.items():
            if key in overrides:
                merged[key] = {**section, **overrides[key]}
            else:
                merged[key] = dict(section)
    except Exception as e:
        # Best effort by design: the observer must still start without a config.
        print(f"Warning: Failed to load observer config: {e}, using defaults")
        return defaults

    print(f"Loaded observer config from {OBSERVER_CONFIG_FILE}")
    return merged


class SO3KalmanFilter:
    """SO(3) rotation Kalman filter in tangent space.

    The state is a 3-vector (axis-angle) relative to ``reference_rot``; the
    reference is re-centred on the filtered rotation whenever the state norm
    exceeds 1.5 rad.
    """

    def __init__(self, process_noise=0.01, measurement_noise=0.1):
        self.state = np.zeros(3)
        self.covariance = np.eye(3) * 0.1
        self.Q = np.eye(3) * process_noise
        self.R_noise = np.eye(3) * measurement_noise
        self.is_initialized = False
        self.reference_rot = np.eye(3)
        self.filtered_rot = np.eye(3)

    def _rotation_to_axis_angle(self, R):
        """Convert rotation matrix to axis-angle (more stable than logm near 180°)."""
        # Use Rodrigues formula for stability
        rvec, _ = cv2.Rodrigues(R)
        return rvec.flatten()

    def _axis_angle_to_rotation(self, rvec):
        """Convert axis-angle to rotation matrix."""
        R, _ = cv2.Rodrigues(rvec.reshape(3, 1))
        return R

    def update(self, rotation_matrix):
        """Fuse one measured rotation matrix and return the filtered rotation.

        The first call after construction/reset adopts the measurement as both
        reference and output, and returns it unchanged.
        """
        if not self.is_initialized:
            self.reference_rot = rotation_matrix.copy()
            self.filtered_rot = rotation_matrix.copy()
            self.is_initialized = True
            return rotation_matrix

        # Compute relative rotation
        R_relative = rotation_matrix @ self.reference_rot.T

        # Use axis-angle instead of logm for stability near 180°
        z_local = self._rotation_to_axis_angle(R_relative)

        # Kalman update (identity motion and measurement models)
        self.covariance = self.covariance + self.Q
        S = self.covariance + self.R_noise
        K_gain = self.covariance @ inv(S)
        self.state = self.state + K_gain @ (z_local - self.state)
        self.covariance = (np.eye(3) - K_gain) @ self.covariance

        # Convert back to rotation matrix
        R_filtered_local = self._axis_angle_to_rotation(self.state)
        R_filtered_global = R_filtered_local @ self.reference_rot

        # Re-center reference when state gets too large
        if np.linalg.norm(self.state) > 1.5:
            self.reference_rot = R_filtered_global.copy()
            self.state = np.zeros(3)

        self.filtered_rot = R_filtered_global
        return self.filtered_rot

    def reset(self):
        """Forget state, covariance and reference; ``filtered_rot`` is left as is."""
        self.state = np.zeros(3)
        self.covariance = np.eye(3) * 0.1
        self.is_initialized = False
        self.reference_rot = np.eye(3)


class VectorLowPassFilter:
    """Simple low-pass filter for position.

    ``alpha`` is the weight of the newest sample (1.0 = no smoothing).
    """

    def __init__(self, alpha=0.3):
        self.alpha = alpha
        self.filtered_val = None

    def update(self, val):
        """Blend ``val`` into the running value; the first sample is adopted as is."""
        if self.filtered_val is None:
            self.filtered_val = val.copy()
            return self.filtered_val
        self.filtered_val = self.alpha * val + (1 - self.alpha) * self.filtered_val
        return self.filtered_val

    def reset(self):
        """Drop the running value so the next sample re-seeds the filter."""
        self.filtered_val = None


# Corner EMA filter alpha (1.0 = no smoothing, acts as per-ID state cache for reset)
CORNER_FILTER_ALPHA = 1.0


class CornerEMAFilter:
    """Per-marker-ID corner EMA filter.

    Maintains a dict {marker_id: (4,2) filtered corners}.
    Each frame, for every detected marker, the 4 corner positions are
    exponentially-smoothed with the previous frame's value.
    Markers not seen for >max_age frames are evicted.
    """

    def __init__(self, alpha=CORNER_FILTER_ALPHA, max_age=5):
        self.alpha = alpha
        self.max_age = max_age
        self._state = {}  # marker id -> last filtered (4, 2) float32 corners
        self._age = {}    # marker id -> consecutive frames without a detection

    def _age_unseen(self, seen):
        """Age every tracked marker not in ``seen``; evict those past ``max_age``."""
        for mid in list(self._age):
            if mid in seen:
                continue
            self._age[mid] += 1
            if self._age[mid] > self.max_age:
                del self._state[mid]
                del self._age[mid]

    def update(self, corners, ids):
        """Smooth the detected corners and return ``(filtered_corners, ids)``.

        The inputs are not modified. With detections, ``filtered_corners`` is a
        new list of ``(1, 4, 2)`` float32 arrays in the same order as ``ids``.
        With no detections (``ids`` is ``None`` or empty) tracked markers are
        aged and ``corners``/``ids`` are returned unchanged.
        """
        if ids is None or len(ids) == 0:
            self._age_unseen(set())
            return corners, ids

        seen = set()
        filtered = []
        for i, mid in enumerate(ids.flatten()):
            mid = int(mid)
            seen.add(mid)
            pts = corners[i].reshape(4, 2).astype(np.float32)
            if mid in self._state:
                pts = self.alpha * pts + (1 - self.alpha) * self._state[mid]
            self._state[mid] = pts.copy()
            self._age[mid] = 0
            filtered.append(pts.reshape(1, 4, 2).astype(np.float32))

        self._age_unseen(seen)
        return filtered, ids

    def reset(self):
        """Forget every tracked marker."""
        self._state.clear()
        self._age.clear()


# --- Buffer backlog detection constants ---
BACKLOG_LATENCY_S = 30.0e-3  # 30ms; headless grab ≈ 20ms (waiting for camera frame)
BACKLOG_COUNT = 5            # consecutive slow grabs before flush
BACKLOG_MAX_FLUSH = 20       # safety cap on flush loop
print("Initializing camera...") + MvCamera.MV_CC_Initialize() + deviceList = MV_CC_DEVICE_INFO_LIST() + MvCamera.MV_CC_EnumDevices(MV_GIGE_DEVICE | MV_USB_DEVICE, deviceList) + + if deviceList.nDeviceNum == 0: + raise RuntimeError("No camera found!") + + self.cam = MvCamera() + stDevice = cast(deviceList.pDeviceInfo[0], POINTER(MV_CC_DEVICE_INFO)).contents + self.cam.MV_CC_CreateHandle(stDevice) + self.cam.MV_CC_OpenDevice(MV_ACCESS_Exclusive, 0) + self.cam.MV_CC_SetEnumValue("TriggerMode", MV_TRIGGER_MODE_OFF) + self.cam.MV_CC_SetEnumValue("PixelFormat", PixelType_Gvsp_BayerGB8) + setup_camera_capture(self.cam, _cam_cfg) + setup_camera_roi(self.cam, _cam_cfg) + self.cam.MV_CC_StartGrabbing() + print("Camera ready!") + + # AprilTag detector for world frame + self.apriltag_detector = AprilTagDetector( + families="tag36h11", nthreads=4, quad_decimate=1.0, + quad_sigma=0.0, decode_sharpening=0.25, + ) + + # ArUco detector for cube + self.aruco_dict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_50) + # Support both old and new OpenCV API + if hasattr(cv2.aruco, 'DetectorParameters_create'): + # Old API (OpenCV < 4.7) + self.aruco_params = cv2.aruco.DetectorParameters_create() + self.aruco_params.cornerRefinementMethod = cv2.aruco.CORNER_REFINE_SUBPIX + self.aruco_detector = None # Use old-style detectMarkers + else: + # New API (OpenCV >= 4.7) + aruco_params = cv2.aruco.DetectorParameters() + aruco_params.cornerRefinementMethod = cv2.aruco.CORNER_REFINE_SUBPIX + self.aruco_detector = cv2.aruco.ArucoDetector(self.aruco_dict, aruco_params) + + # Load cube config and build board + self._load_config() + self._build_aruco_board() + + # Filters + self.filter_R = SO3KalmanFilter(process_noise=process_noise, measurement_noise=measurement_noise) + self.filter_t = VectorLowPassFilter(alpha=alpha) + + # ZMQ publisher + self.publisher = CubePublisher(port=zmq_port) + print(f"ZMQ publisher on port {zmq_port}") + + # State + self.stOutFrame = MV_FRAME_OUT() + 
self.world_pose = None + self.frame_count = 0 + self.last_print_time = 0 + self.last_frame_count = 0 + self._display_fps = 0.0 + self.filt_R = np.eye(3) + self.filt_t = np.zeros(3) + self.prev_quat = None # For quaternion sign continuity + self._R_cube_world = None # Cube rotation in world frame (for visualization) + self._t_cube_world = None # Cube position in world frame (for visualization) + self._dominant_face = None # Currently visible face + + # World frame sampling state + self._world_samples_R = [] # Collected rotation samples + self._world_samples_t = [] # Collected translation samples + self._world_fixed = False # Whether world frame is fixed + self._world_sample_target = world_sample_frames # Target sample count + + # --- New observer state for IPPE+ITERATIVE migration --- + _cfg = load_observer_config() + self._reproj_threshold = float(_cfg['pnp']['reproj_threshold']) + + self._enable_clahe = bool(_cfg['preprocess']['enable_clahe']) + _clip = float(_cfg['preprocess']['clahe_clip']) + _tile = tuple(int(x) for x in _cfg['preprocess']['clahe_tile']) + self._clahe = cv2.createCLAHE(clipLimit=_clip, tileGridSize=_tile) if self._enable_clahe else None + + self.corner_filter = CornerEMAFilter(alpha=CORNER_FILTER_ALPHA) + + # IPPE disambiguation state + self._ippe_locked_idx = 0 + self._lost_frames = 0 + self._prev_dominant_face = None + self._active_faces = set() + self._reproj_err = 0.0 + + # Backlog detection state + self._grab_slow_count = 0 + + def _load_config(self): + """Load cube configuration from file. + + Raises: + FileNotFoundError: if the cube_tags*.json path does not exist — + cube geometry must be specified explicitly (no silent defaults). + KeyError: if the JSON is missing required cube_size/tag_size keys. 
+ """ + self._tag_map = None + self._face_axes_cfg = None + self._face_rotations = {'TOP': 0, 'BOTTOM': 0, 'FRONT': 0, 'BACK': 0, 'LEFT': 0, 'RIGHT': 0} + + if not os.path.exists(self._cube_config_path): + raise FileNotFoundError( + f"cube tags JSON not found: {self._cube_config_path}. " + "Specify with --cube (e.g. --cube 36 / --cube 40_5 / default 54mm)." + ) + + try: + with open(self._cube_config_path, 'r') as f: + cfg = json.load(f) + # Required: cube_size + tag_size + tag_center_offset. No silent defaults. + try: + self._cube_size = float(cfg['cube_size']) + self._tag_size = float(cfg['tag_size']) + self._tag_offset = float(cfg['tag_center_offset']) + except KeyError as e: + raise KeyError( + f"{self._cube_config_path} is missing required key {e}; " + "cube_size, tag_size and tag_center_offset are not allowed to be defaulted." + ) from None + faces_cfg = cfg.get('faces_config', {}) + self._tag_map = {face: {int(k): v for k, v in tags.items()} for face, tags in faces_cfg.items()} + self._face_axes_cfg = cfg.get('face_axes', None) + for face, rot in cfg.get('face_rotations', {}).items(): + self._face_rotations[face] = rot + print(f"Loaded cube config from {self._cube_config_path}") + print(f" cube_size={self._cube_size*1000:.1f}mm " + f"tag_size={self._tag_size*1000:.2f}mm " + f"tag_center_offset={self._tag_offset*1000:.2f}mm") + except json.JSONDecodeError as e: + print(f"Warning: Failed to parse config JSON: {e}") + + # Build tag to face mapping + self._tag_to_face = {} + if self._tag_map: + for face, tags in self._tag_map.items(): + for tid in tags.keys(): + self._tag_to_face[tid] = face + + def _build_aruco_board(self): + """Build ArUco Board with all cube tags' 3D positions.""" + half = self._cube_size / 2 + ht = self._tag_size / 2 + off = self._tag_offset + + def rotate_corners(corners, rotation): + n = (rotation // 90) % 4 + if n == 0: return corners + elif n == 1: return np.array([corners[3], corners[0], corners[1], corners[2]]) + elif n == 2: return 
np.array([corners[2], corners[3], corners[0], corners[1]]) + else: return np.array([corners[1], corners[2], corners[3], corners[0]]) + + def face_tags(face_center, u_axis, v_axis, rotation=0): + tags = {} + for pos, center in [('T', face_center + off * v_axis), ('B', face_center - off * v_axis), + ('L', face_center - off * u_axis), ('R', face_center + off * u_axis)]: + corners = np.array([ + center - ht * u_axis + ht * v_axis, center + ht * u_axis + ht * v_axis, + center + ht * u_axis - ht * v_axis, center - ht * u_axis - ht * v_axis, + ], dtype=np.float32) + tags[pos] = rotate_corners(corners, rotation) + return tags + + if self._face_axes_cfg: + faces = {name: (np.array(axes['center'], dtype=np.float64) * half, + np.array(axes['u'], dtype=np.float64), + np.array(axes['v'], dtype=np.float64)) + for name, axes in self._face_axes_cfg.items()} + else: + X, Y, Z = np.array([1,0,0]), np.array([0,1,0]), np.array([0,0,1]) + faces = {'TOP': (half*Z, X, Y), 'BOTTOM': (-half*Z, X, -Y), 'FRONT': (-half*Y, X, Z), + 'BACK': (half*Y, -X, Z), 'LEFT': (-half*X, -Y, Z), 'RIGHT': (half*X, Y, Z)} + + tag_map = self._tag_map or { + 'TOP': {0:'L',1:'B',2:'T',3:'R'}, 'BOTTOM': {8:'R',9:'T',10:'B',11:'L'}, + 'FRONT': {16:'R',17:'T',18:'B',19:'L'}, 'BACK': {20:'B',21:'R',22:'L',23:'T'}, + 'LEFT': {4:'R',5:'T',6:'B',7:'L'}, 'RIGHT': {12:'B',13:'R',14:'L',15:'T'}, + } + + board_corners, board_ids = [], [] + for face_name, (center, u, v) in faces.items(): + tags = face_tags(center, u, v, self._face_rotations.get(face_name, 0)) + for tid, pos in tag_map[face_name].items(): + board_corners.append(tags[pos]) + board_ids.append([tid]) + + sorted_idx = np.argsort([b[0] for b in board_ids]) + board_corners = [board_corners[i] for i in sorted_idx] + board_ids = np.array([board_ids[i] for i in sorted_idx], dtype=np.int32) + + # Support both old and new OpenCV API + if hasattr(cv2.aruco, 'Board_create'): + # Old API (OpenCV < 4.7) + self.cube_board = cv2.aruco.Board_create(board_corners, 
self.aruco_dict, board_ids) + else: + # New API (OpenCV >= 4.7) + self.cube_board = cv2.aruco.Board(board_corners, self.aruco_dict, board_ids) + print(f"ArUco Board: {len(board_ids)} tags") + + def _match_image_points(self, corners, ids): + """Match detected corners/ids to board - compatibility wrapper for old/new API.""" + if hasattr(self.cube_board, 'matchImagePoints'): + # New API (OpenCV >= 4.7) + return self.cube_board.matchImagePoints(corners, ids) + else: + # Old API (OpenCV < 4.7) - manually match + obj_pts = [] + img_pts = [] + if ids is None or len(ids) == 0: + return None, None + + board_ids_flat = self.cube_board.ids.flatten() + for i, marker_id in enumerate(ids.flatten()): + # Find this marker in the board + board_idx = np.where(board_ids_flat == marker_id)[0] + if len(board_idx) > 0: + board_idx = board_idx[0] + # Add all 4 corners of this marker + obj_pts.extend(self.cube_board.objPoints[board_idx]) + img_pts.extend(corners[i][0]) + + if len(obj_pts) == 0: + return None, None + + return np.array(obj_pts, dtype=np.float32), np.array(img_pts, dtype=np.float32) + + def detect_world_tag(self, gray): + """Detect world AprilTag and return its pose in camera frame.""" + results = self.apriltag_detector.detect( + gray, estimate_tag_pose=True, + camera_params=(K[0, 0], K[1, 1], K[0, 2], K[1, 2]), + tag_size=WORLD_TAG_SIZE + ) + for r in results: + if r.tag_id == WORLD_TAG_ID: + return r.pose_R, r.pose_t.flatten(), r.corners + return None, None, None + + def _average_rotations(self, rotations): + """Average multiple rotation matrices using quaternion averaging.""" + if len(rotations) == 0: + return np.eye(3) + + # Convert to quaternions + quats = [] + for R in rotations: + rot = Rotation.from_matrix(R) + q = rot.as_quat() # (x, y, z, w) + quats.append(q) + + quats = np.array(quats) + + # Ensure quaternion sign consistency (all pointing same hemisphere) + for i in range(1, len(quats)): + if np.dot(quats[i], quats[0]) < 0: + quats[i] = -quats[i] + + # Average 
quaternions and normalize + avg_quat = quats.mean(axis=0) + avg_quat /= np.linalg.norm(avg_quat) + + return Rotation.from_quat(avg_quat).as_matrix() + + def start_world_sampling(self): + """Start/restart world frame sampling.""" + self._world_samples_R = [] + self._world_samples_t = [] + self._world_fixed = False + self.world_pose = None + print(f"\n[World Sampling] Starting... (collecting {self._world_sample_target} frames)") + + def _finalize_world_frame(self): + """Finalize world frame from collected samples.""" + if len(self._world_samples_R) < 10: + print(f"[World Sampling] Failed: only {len(self._world_samples_R)} samples collected") + return False + + # Average rotation matrices + avg_R = self._average_rotations(self._world_samples_R) + + # Average translations + avg_t = np.mean(self._world_samples_t, axis=0) + + # Apply world frame correction if specified + if WORLD_FRAME_CORRECTION is not None: + # Parse string format or use matrix directly + if isinstance(WORLD_FRAME_CORRECTION, str): + correction_R = parse_axis_remap(WORLD_FRAME_CORRECTION) + print(f"[World Sampling] Axis remap: {WORLD_FRAME_CORRECTION}") + else: + correction_R = np.array(WORLD_FRAME_CORRECTION) + + # R_corrected transforms points from corrected world frame to camera frame + # If R_apriltag transforms from AprilTag frame to camera frame, + # and correction_R transforms from AprilTag frame to new world frame, + # then: R_new_to_cam = R_apriltag @ correction_R.T + avg_R = avg_R @ correction_R.T + print(f"[World Sampling] Applied world frame correction (det={np.linalg.det(correction_R):.1f})") + + self.world_pose = (avg_R, avg_t) + self._world_fixed = True + + print(f"[World Sampling] Complete! Averaged {len(self._world_samples_R)} samples") + print(f"[World Sampling] World frame is now FIXED. 
Press 'w' to resample.") + + # Switch to hardware fast ROI (headless only; preview keeps full frame) + if not self.visualize: + self._switch_to_fast_roi() + + return True + + def _switch_to_fast_roi(self): + """Switch camera to hardware fast_roi for high-speed cube tracking (headless).""" + global K + fast_roi = _cam_cfg.get('fast_roi') + if fast_roi is None: + return + cur_roi = _cam_cfg['roi'] + if (fast_roi['width'] == cur_roi['width'] + and fast_roi['height'] == cur_roi['height'] + and fast_roi['offset_x'] == cur_roi['offset_x'] + and fast_roi['offset_y'] == cur_roi['offset_y']): + return # already at fast ROI + + print(f"[Fast ROI] Switching to {fast_roi['width']}x{fast_roi['height']} " + f"@ ({fast_roi['offset_x']}, {fast_roi['offset_y']}) ...") + self.cam.MV_CC_StopGrabbing() + self.cam.MV_CC_SetIntValueEx("OffsetX", 0) + self.cam.MV_CC_SetIntValueEx("OffsetY", 0) + self.cam.MV_CC_SetIntValueEx("Width", fast_roi['width']) + self.cam.MV_CC_SetIntValueEx("Height", fast_roi['height']) + self.cam.MV_CC_SetIntValueEx("OffsetX", fast_roi['offset_x']) + self.cam.MV_CC_SetIntValueEx("OffsetY", fast_roi['offset_y']) + self.cam.MV_CC_StartGrabbing() + + # Update global K for new ROI offset + intr = _cam_cfg['intrinsics'] + K[0, 2] = intr['cx'] - fast_roi['offset_x'] + K[1, 2] = intr['cy'] - fast_roi['offset_y'] + print(f"[Fast ROI] Active. K updated: cx={K[0,2]:.1f}, cy={K[1,2]:.1f}") + + def detect_cube_pose(self, corners, ids): + """Detect cube pose via IPPE + ITERATIVE hybrid with dominant-face strategy. + + Pipeline: + 1. Dominant-face selection with hysteresis. + 2. IPPE (coplanar analytical) returns two candidate solutions. + 3. Disambiguate with locked-index hysteresis: switch only if other + solution is strictly better on BOTH reproj and geodesic distance. + 4. ITERATIVE refinement using the chosen IPPE solution as guess. + 5. Reprojection-error gate; filter reset on reacquire. 
+ """ + if ids is None or len(ids) == 0: + self._dominant_face = None + self._lost_frames += 1 + return None, None, 0 + + # --- Count markers per face --- + face_counts = {} + for tid in ids.flatten(): + if int(tid) in self._tag_to_face: + face = self._tag_to_face[int(tid)] + face_counts[face] = face_counts.get(face, 0) + 1 + + if not face_counts: + self._dominant_face = None + self._lost_frames += 1 + return None, None, 0 + + # --- Dominant face with hysteresis --- + best_face = max(face_counts, key=face_counts.get) + if (self._prev_dominant_face is not None + and self._prev_dominant_face in face_counts + and face_counts.get(self._prev_dominant_face, 0) >= face_counts[best_face]): + best_face = self._prev_dominant_face + self._dominant_face = best_face + self._prev_dominant_face = best_face + self._active_faces = {best_face} + + valid_indices = [i for i, tid in enumerate(ids.flatten()) + if int(tid) in self._tag_to_face and self._tag_to_face[int(tid)] == best_face] + if valid_indices: + corners = [corners[i] for i in valid_indices] + ids = ids[valid_indices] + + obj_pts, img_pts = self._match_image_points(corners, ids) + if obj_pts is None or len(obj_pts) < 4: + self._lost_frames += 1 + return None, None, 0 + + # --- Step 1: IPPE returns both coplanar solutions (sol 0 has lower reproj) --- + n_sol, rvecs_ippe, tvecs_ippe, reproj_errors = cv2.solvePnPGeneric( + obj_pts, img_pts, K, DIST_COEFFS, flags=cv2.SOLVEPNP_IPPE) + + if n_sol == 0: + self._lost_frames += 1 + return None, None, 0 + + # --- Step 2: Disambiguate IPPE solutions --- + # Lock onto current pick; switch only if other is clearly better + # on BOTH reproj (<80%) AND geodesic distance (<33%). 
+ if n_sol == 1: + best_idx = 0 + elif not self.filter_R.is_initialized or self._lost_frames > 0: + best_idx = 0 + else: + R_prev = self.filt_R + dists = [] + for i in range(n_sol): + R_i, _ = cv2.Rodrigues(rvecs_ippe[i]) + diff = cv2.Rodrigues(R_prev.T @ R_i)[0] + dists.append(np.linalg.norm(diff)) + + locked = self._ippe_locked_idx + other = 1 - locked + re_locked = reproj_errors[locked].item() + re_other = reproj_errors[other].item() + + if (re_other < re_locked * 0.8) and (dists[other] < dists[locked] * 0.33): + best_idx = other + else: + best_idx = locked + + self._ippe_locked_idx = best_idx + pick_rvec, pick_tvec = rvecs_ippe[best_idx], tvecs_ippe[best_idx] + + # --- Step 3: ITERATIVE refinement with IPPE pick as initial guess --- + success, rvec, tvec = cv2.solvePnP( + obj_pts, img_pts, K, DIST_COEFFS, + rvec=pick_rvec.copy(), tvec=pick_tvec.copy(), + useExtrinsicGuess=True, + flags=cv2.SOLVEPNP_ITERATIVE, + ) + + if not success: + self._lost_frames += 1 + return None, None, 0 + + # --- Step 4: Reprojection-error gate --- + reproj_pts, _ = cv2.projectPoints(obj_pts, rvec, tvec, K, DIST_COEFFS) + reproj_err = float(np.mean(np.linalg.norm( + img_pts.reshape(-1, 2) - reproj_pts.reshape(-1, 2), axis=1))) + self._reproj_err = reproj_err + if reproj_err > self._reproj_threshold: + self._lost_frames += 1 + return None, None, 0 + + # --- Step 5: Reset filters on reacquire so stale state doesn't contaminate --- + if self._lost_frames > 0: + self.corner_filter.reset() + self.filter_R.reset() + self.filter_t.reset() + self.prev_quat = None + + self._lost_frames = 0 + + # --- Update filters --- + R, _ = cv2.Rodrigues(rvec) + self.filt_R = self.filter_R.update(R) + self.filt_t = self.filter_t.update(tvec.flatten()) + + return self.filt_R, self.filt_t, len(ids) + + def transform_to_world_frame(self, R_cube_cam, t_cube_cam): + """Transform cube pose from camera frame to world frame.""" + if self.world_pose is None: + return None, None + R_world_cam, t_world_cam = 
self.world_pose + R_cam_world = R_world_cam.T + t_cam_world = -R_cam_world @ t_world_cam + R_cube_world = R_cam_world @ R_cube_cam + t_cube_world = R_cam_world @ t_cube_cam + t_cam_world + return R_cube_world, t_cube_world + + def _draw_world_axes(self, img, axis_length=0.03, line_width=4): + """Draw world coordinate axes with RGB colors for XYZ.""" + if self.world_pose is None: + return + + R_world, t_world = self.world_pose + + # Project origin and axis endpoints to image + origin = t_world.reshape(3, 1) + x_end = origin + axis_length * R_world[:, 0:1] + y_end = origin + axis_length * R_world[:, 1:2] + z_end = origin + axis_length * R_world[:, 2:3] + + # Project to 2D + origin_2d, _ = cv2.projectPoints(origin.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + x_2d, _ = cv2.projectPoints(x_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + y_2d, _ = cv2.projectPoints(y_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + z_2d, _ = cv2.projectPoints(z_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + + origin_pt = tuple(origin_2d[0, 0].astype(int)) + x_pt = tuple(x_2d[0, 0].astype(int)) + y_pt = tuple(y_2d[0, 0].astype(int)) + z_pt = tuple(z_2d[0, 0].astype(int)) + + # Draw axes: X=Red, Y=Green, Z=Blue (BGR format) + cv2.arrowedLine(img, origin_pt, x_pt, (0, 0, 255), line_width, tipLength=0.3) # X - Red + cv2.arrowedLine(img, origin_pt, y_pt, (0, 255, 0), line_width, tipLength=0.3) # Y - Green + cv2.arrowedLine(img, origin_pt, z_pt, (255, 0, 0), line_width, tipLength=0.3) # Z - Blue + + # Draw axis labels at arrow tips + cv2.putText(img, "+X", x_pt, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2) + cv2.putText(img, "+Y", y_pt, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) + cv2.putText(img, "+Z", z_pt, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2) + + # Draw origin marker + cv2.circle(img, origin_pt, 5, (255, 255, 255), -1) + + def _draw_cube_axes_in_world(self, img, R_cube_world, t_cube_world, axis_length=0.025, line_width=2): + """Draw cube coordinate axes in world 
frame (transformed to camera for display).""" + if self.world_pose is None: + return + + R_world_cam, t_world_cam = self.world_pose + + # self.world_pose = (R_world_cam, t_world_cam) where: + # - R_world_cam: rotation matrix whose columns are world axes in camera frame + # - t_world_cam: world origin position in camera frame + + # For a point P_world in world coordinates, its camera coordinates: + # P_cam = R_world_cam @ P_world + t_world_cam (if R_world_cam rotates world->cam) + # But actually from the code, R_world_cam columns are world axes in camera frame + # So R_world_cam @ P_world gives P_cam (without translation consideration for rotation) + + # Cube position in camera frame + t_cube_cam = R_world_cam @ t_cube_world + t_world_cam + + # Cube axes in camera frame + # R_cube_world columns are cube axes in world frame + # R_world_cam @ (cube axis in world) = cube axis in camera + R_cube_cam = R_world_cam @ R_cube_world + + # Draw axes at cube position + origin = t_cube_cam.reshape(3, 1) + x_end = origin + axis_length * R_cube_cam[:, 0:1] + y_end = origin + axis_length * R_cube_cam[:, 1:2] + z_end = origin + axis_length * R_cube_cam[:, 2:3] + + # Project to 2D + origin_2d, _ = cv2.projectPoints(origin.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + x_2d, _ = cv2.projectPoints(x_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + y_2d, _ = cv2.projectPoints(y_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + z_2d, _ = cv2.projectPoints(z_end.T, np.zeros(3), np.zeros(3), K, DIST_COEFFS) + + origin_pt = tuple(origin_2d[0, 0].astype(int)) + x_pt = tuple(x_2d[0, 0].astype(int)) + y_pt = tuple(y_2d[0, 0].astype(int)) + z_pt = tuple(z_2d[0, 0].astype(int)) + + # Draw with lighter colors to distinguish from world axes + cv2.arrowedLine(img, origin_pt, x_pt, (100, 100, 255), line_width, tipLength=0.3) # X - Light Red + cv2.arrowedLine(img, origin_pt, y_pt, (100, 255, 100), line_width, tipLength=0.3) # Y - Light Green + cv2.arrowedLine(img, origin_pt, z_pt, (255, 100, 
100), line_width, tipLength=0.3) # Z - Light Blue + + def run(self): + """Main detection loop.""" + print("\nCube World Observer running...") + print(" Press 'q' to quit, 'r' to reset filters, 'w' to resample world frame\n") + + # Start world frame sampling on startup + self.start_world_sampling() + + while True: + _tA = time.perf_counter() + ret = self.cam.MV_CC_GetImageBuffer(self.stOutFrame, 100) + if ret != 0: + continue + grab_dt = time.perf_counter() - _tA + + # --- Buffer backlog detection & recovery --- + if grab_dt > BACKLOG_LATENCY_S: + self._grab_slow_count += 1 + if self._grab_slow_count >= BACKLOG_COUNT: + # Flush: free current frame, then drain queued frames (bounded) + self.cam.MV_CC_FreeImageBuffer(self.stOutFrame) + flushed = 0 + while flushed < BACKLOG_MAX_FLUSH: + r = self.cam.MV_CC_GetImageBuffer(self.stOutFrame, 1) + if r != 0: + break + self.cam.MV_CC_FreeImageBuffer(self.stOutFrame) + flushed += 1 + # Re-grab a fresh frame + ret = self.cam.MV_CC_GetImageBuffer(self.stOutFrame, 100) + if ret != 0: + self._grab_slow_count = 0 + continue + print(f"[FLUSH] buffer backlog detected (grab={grab_dt*1000:.1f}ms), " + f"drained {flushed} stale frames") + self._grab_slow_count = 0 + else: + self._grab_slow_count = 0 + + self.frame_count += 1 + nH = self.stOutFrame.stFrameInfo.nHeight + nW = self.stOutFrame.stFrameInfo.nWidth + data = string_at(self.stOutFrame.pBufAddr, self.stOutFrame.stFrameInfo.nFrameLen) + bayer = np.frombuffer(data, dtype=np.uint8).reshape(nH, nW) + + # Demosaic to BGR (always, used both for min-channel gray and visualization) + bgr = cv2.cvtColor(bayer, cv2.COLOR_BayerGB2BGR) + + # Min-channel: white→255, any color→≈0; robust for dark ArUco on white cube + gray_min = np.minimum(np.minimum(bgr[:, :, 0], bgr[:, :, 1]), bgr[:, :, 2]) + + # Optional CLAHE contrast enhancement (full-frame, before any ROI crop) + if self._clahe is not None: + gray = self._clahe.apply(gray_min) + else: + gray = gray_min + + color = bgr if self.visualize 
else None + + # Detect world AprilTag (skip when world frame is already fixed) + if not self._world_fixed: + R_world, t_world, world_corners = self.detect_world_tag(gray) + world_detected = R_world is not None + # Sampling mode: collect samples + if world_detected: + self._world_samples_R.append(R_world) + self._world_samples_t.append(t_world) + if len(self._world_samples_R) >= self._world_sample_target: + self._finalize_world_frame() + else: + world_detected = True + world_corners = None + + # Detect cube ArUco tags + # In preview mode with world fixed, use software ROI crop for speed + fast_roi = _cam_cfg.get('fast_roi') + _use_sw_roi = (self.visualize and self._world_fixed + and fast_roi is not None) + if _use_sw_roi: + rx, ry = fast_roi['offset_x'], fast_roi['offset_y'] + rw, rh = fast_roi['width'], fast_roi['height'] + gray_roi = gray[ry:ry+rh, rx:rx+rw] + else: + gray_roi = gray + rx, ry = 0, 0 + + if self.aruco_detector is None: + corners, ids, _ = cv2.aruco.detectMarkers(gray_roi, self.aruco_dict, parameters=self.aruco_params) + else: + corners, ids, _ = self.aruco_detector.detectMarkers(gray_roi) + + # Map corners back to full-frame coordinates + if _use_sw_roi and ids is not None and len(ids) > 0: + corners = [c + np.array([[[rx, ry]]], dtype=c.dtype) for c in corners] + cube_quat_world = None + cube_pos_world = None + n_tags = 0 + + if ids is not None and len(ids) > 0: + mask = (ids.flatten() >= 0) & (ids.flatten() <= 23) + if mask.any(): + corners = [corners[i] for i in range(len(corners)) if mask[i]] + ids = ids[mask] + else: + corners, ids = [], None + + # Corner-level EMA filter (state cache + reset hook before PnP) + if ids is not None and len(ids) > 0: + corners, ids = self.corner_filter.update(corners, ids) + else: + self.corner_filter.update([], None) + + R_cube_cam = None + if ids is not None and len(ids) > 0: + R_cube_cam, t_cube_cam, n_tags = self.detect_cube_pose(corners, ids) + # _lost_frames managed inside detect_cube_pose on the success / 
fail branches + else: + # No cube markers this frame — still counts as a lost frame so the + # IPPE disambiguation reset logic sees a fresh reacquire next time. + self._lost_frames += 1 + + if R_cube_cam is not None and self.world_pose is not None: + R_cube_world, t_cube_world = self.transform_to_world_frame(R_cube_cam, t_cube_cam) + if R_cube_world is not None: + # Apply cube frame correction if specified + # This corrects the difference between ArUco board axes and MuJoCo mesh axes + if CUBE_FRAME_CORRECTION is not None: + if isinstance(CUBE_FRAME_CORRECTION, str): + cube_correction_R = parse_axis_remap(CUBE_FRAME_CORRECTION) + else: + cube_correction_R = np.array(CUBE_FRAME_CORRECTION) + # R_cube_world_corrected = R_cube_world @ cube_correction_R.T + R_cube_world = R_cube_world @ cube_correction_R.T + + # Frame: wrist-tag frame (observer-native). Consumers + # (RealHandEnv via CubeReceiver) treat this as tag-frame + # cube pose, fed directly to policy obs via the deploy-side + # override funcs in lib/real_hand_obs.py. + # Store for visualization + self._R_cube_world = R_cube_world + self._t_cube_world = t_cube_world + + # R_cube_world: transforms from cube frame TO world frame + # For MuJoCo mocap body, we need the quaternion that represents + # cube orientation in world frame (rotation from cube to world) + rot = Rotation.from_matrix(R_cube_world) + quat = rot.as_quat() # (x, y, z, w) + + # Quaternion sign continuity: q and -q represent same rotation + # Choose sign to minimize distance from previous quaternion + if self.prev_quat is not None: + if np.dot(quat, self.prev_quat) < 0: + quat = -quat + self.prev_quat = quat.copy() + + cube_quat_world = quat + cube_pos_world = t_cube_world + + # Publish via ZMQ. + # Frame: wrist-tag frame (observer-native, from transform_to_world_frame). 
+ # Consumers (RealHandEnv via CubeReceiver) treat this as tag-frame + # cube pose, fed directly to policy obs via the deploy-side override + # funcs in lib/real_hand_obs.py — no mjworld round-trip. + # Quaternion order: (x, y, z, w) — scipy convention. CubeReceiver + # converts to MuJoCo (w, x, y, z) on receive (see zmq_bridge.py). + if cube_quat_world is not None: + q_xyzw = cube_quat_world # scipy xyzw, as wuji computed it + quat_wxyz = np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]]) + self.publisher.publish(cube_pos_world, quat_wxyz, world_fixed=self._world_fixed, cube_size=float(self._cube_size)) + + # Visualization + if self.visualize and color is not None: + # Draw world frame axes (RGB = XYZ) + if self._world_fixed and self.world_pose is not None: + self._draw_world_axes(color) + + if world_corners is not None: + cv2.polylines(color, [world_corners.astype(int)], True, (255, 0, 255), 3) + + # Draw detection ROI when active + if _use_sw_roi: + cv2.rectangle(color, (rx, ry), (rx + rw, ry + rh), (0, 200, 200), 2) + + if ids is not None and len(ids) > 0: + cv2.aruco.drawDetectedMarkers(color, corners, ids) + + if n_tags > 0 and self._R_cube_world is not None: + # Draw cube axes in world frame (lighter colors) + self._draw_cube_axes_in_world(color, self._R_cube_world, self._t_cube_world) + + # World frame status display + if not self._world_fixed: + # Sampling mode: show progress bar + n_samples = len(self._world_samples_R) + progress = n_samples / self._world_sample_target + bar_width = 200 + bar_height = 20 + cv2.rectangle(color, (10, 10), (10 + bar_width, 10 + bar_height), (50, 50, 50), -1) + cv2.rectangle(color, (10, 10), (10 + int(bar_width * progress), 10 + bar_height), (0, 255, 255), -1) + cv2.rectangle(color, (10, 10), (10 + bar_width, 10 + bar_height), (255, 255, 255), 1) + cv2.putText(color, f"World Sampling: {n_samples}/{self._world_sample_target}", + (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2) + else: + # Fixed mode: show status 
+ cv2.putText(color, "WORLD FIXED", (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) + if world_detected: + cv2.putText(color, "(tag visible)", (180, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (100, 255, 100), 1) + + cv2.putText(color, f"Tags: {n_tags}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + + # Display dominant face with color + if self._dominant_face and self._dominant_face in FACE_COLORS: + face_name = self._dominant_face + color_name, face_bgr = FACE_COLORS[face_name] + # Draw color block + text + cv2.rectangle(color, (10, 75), (40, 105), face_bgr, -1) + cv2.rectangle(color, (10, 75), (40, 105), (255, 255, 255), 1) + cv2.putText(color, f"{face_name} ({color_name})", (50, 98), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, face_bgr, 2) + + if cube_quat_world is not None: + rpy = Rotation.from_quat(cube_quat_world).as_euler('xyz', degrees=True) + cv2.putText(color, f"RPY: ({rpy[0]:+.1f}, {rpy[1]:+.1f}, {rpy[2]:+.1f})", (10, 130), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2) + + # FPS overlay (top-right) + fps_color = (0, 255, 0) if self._display_fps >= 20 else (0, 165, 255) + fps_text = f"FPS: {self._display_fps:.1f}" + (tw, th), _ = cv2.getTextSize(fps_text, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2) + cv2.putText(color, fps_text, (color.shape[1] - tw - 10, th + 10), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, fps_color, 2) + + # Key hints + cv2.putText(color, "q:quit r:reset w:resample world s:select ROI", (10, 755), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1) + + cv2.imshow('Cube World Observer', cv2.resize(color, (960, 768))) + + self.cam.MV_CC_FreeImageBuffer(self.stOutFrame) + + # Print status periodically + now = time.time() + if now - self.last_print_time > 2.0: + elapsed = now - self.last_print_time if self.last_print_time > 0 else 1.0 + fps = (self.frame_count - self.last_frame_count) / elapsed + self._display_fps = fps + self.last_frame_count = self.frame_count + if not self._world_fixed: + n_samples = len(self._world_samples_R) + 
print(f"[{self.frame_count:6d}] FPS: {fps:5.1f} | World Sampling: {n_samples}/{self._world_sample_target}") + elif cube_quat_world is not None: + rpy = Rotation.from_quat(cube_quat_world).as_euler('xyz', degrees=True) + # cube_quat_world is (x,y,z,w) from scipy + qx, qy, qz, qw = cube_quat_world + px, py, pz = cube_pos_world + print(f"[{self.frame_count:6d}] FPS: {fps:5.1f} | World: FIXED | Tags: {n_tags} | " + f"Pos: ({px:+.4f}, {py:+.4f}, {pz:+.4f}) | " + f"Quat(wxyz): ({qw:+.4f}, {qx:+.4f}, {qy:+.4f}, {qz:+.4f}) | " + f"Quat(xyzw): ({qx:+.4f}, {qy:+.4f}, {qz:+.4f}, {qw:+.4f}) | " + f"RPY: ({rpy[0]:+6.1f}, {rpy[1]:+6.1f}, {rpy[2]:+6.1f})") + else: + print(f"[{self.frame_count:6d}] FPS: {fps:5.1f} | World: FIXED | Cube: NOT DETECTED") + self.last_print_time = now + + if self.visualize: + key = cv2.pollKey() & 0xFF + if key == ord('q'): + break + elif key == ord('r'): + # Reset cube filters only (not world frame) + self.filter_R.reset() + self.filter_t.reset() + self.prev_quat = None + print("Cube filters reset!") + elif key == ord('w'): + # Resample world frame + self.start_world_sampling() + # Also reset cube filters since world frame changed + self.filter_R.reset() + self.filter_t.reset() + self.prev_quat = None + elif key == ord('s'): + self._select_and_save_fast_roi(bgr) + + def _select_and_save_fast_roi(self, current_frame): + """Open selectROI dialog, save selection to config/camera.yaml and apply live. + + Draws a rectangle on the full-resolution frame. Coordinates are in full + sensor coordinates (offset_x/offset_y from 0,0). Persists to camera.yaml. + """ + import os + + if not isinstance(current_frame, np.ndarray) or current_frame.ndim != 3: + print(f"[ROI] ERROR: expected BGR image, got {type(current_frame).__name__}") + return + + print("\n[ROI] Drag a rectangle on the frame. 
ENTER/SPACE to confirm, C to cancel.") + # Use a resized preview for selection (same size as main imshow) + display_size = (960, 768) + display = cv2.resize(current_frame, display_size) + scale_x = current_frame.shape[1] / display_size[0] + scale_y = current_frame.shape[0] / display_size[1] + + x, y, w, h = cv2.selectROI( + "Select ROI (ENTER/SPACE=confirm, C=cancel)", display, + showCrosshair=True, fromCenter=False, + ) + cv2.destroyWindow("Select ROI (ENTER/SPACE=confirm, C=cancel)") + + if w == 0 or h == 0: + print("[ROI] Selection cancelled.") + return + + # Scale back to full resolution + offset_x = int(round(x * scale_x)) + offset_y = int(round(y * scale_y)) + width = int(round(w * scale_x)) + height = int(round(h * scale_y)) + + # Some Hikvision cameras require width/height to be multiples of 4 or 8 + width = (width // 8) * 8 + height = (height // 8) * 8 + offset_x = (offset_x // 8) * 8 + offset_y = (offset_y // 8) * 8 + if width < 64 or height < 64: + print(f"[ROI] Selection too small ({width}x{height}), ignored.") + return + + print(f"[ROI] New fast_roi: offset=({offset_x},{offset_y}) size={width}x{height}") + + # Persist to config/camera.yaml via yaml load → modify → atomic write + yaml_path = os.path.join(ROOT_DIR, "config", "camera.yaml") + try: + with open(yaml_path, "r") as f: + cfg = yaml.safe_load(f) + cfg["fast_roi"] = { + "offset_x": offset_x, + "offset_y": offset_y, + "width": width, + "height": height, + } + tmp_path = yaml_path + ".tmp" + with open(tmp_path, "w") as f: + yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False) + os.replace(tmp_path, yaml_path) + print(f"[ROI] Saved to {yaml_path}") + except Exception as exc: + print(f"[ROI] Failed to save: {exc}") + # Clean up stray tmp file if any + try: + os.remove(yaml_path + ".tmp") + except OSError: + pass + return + + # Apply live: update in-memory config so software ROI uses new values + _cam_cfg["fast_roi"] = { + "offset_x": offset_x, + "offset_y": offset_y, + "width": width, + 
"height": height, + } + print("[ROI] Applied. Next frames will use the new fast_roi.") + + def cleanup(self): + """Release resources.""" + if self.visualize: + cv2.destroyAllWindows() + self.cam.MV_CC_StopGrabbing() + self.cam.MV_CC_CloseDevice() + self.cam.MV_CC_DestroyHandle() + MvCamera.MV_CC_Finalize() + self.publisher.close() + print("Cleanup done.") + + +def main(): + import argparse + + # Load config from file first + cfg = load_observer_config() + + parser = argparse.ArgumentParser(description="Cube World Observer") + parser.add_argument('--preview', action='store_true', help="Show preview window") + parser.add_argument('--port', type=int, default=None, help="ZMQ port (override config)") + parser.add_argument('--process-noise', type=float, default=None, help="SO3 Kalman Q (override config)") + parser.add_argument('--measurement-noise', type=float, default=None, help="SO3 Kalman R (override config)") + parser.add_argument('--alpha', type=float, default=None, help="Position LP alpha (override config)") + parser.add_argument('--world-samples', type=int, default=WORLD_SAMPLE_FRAMES, + help=f"Number of frames to sample for world frame (default: {WORLD_SAMPLE_FRAMES})") + parser.add_argument('--cube', type=str, default=None, + help="Cube tags config: a size suffix (e.g. '36', '40_5') " + "resolving to config/cube_tags.json, or a literal " + "path. 
Default: config/cube_tags.json (54mm).") + args = parser.parse_args() + cube_config_path = resolve_cube_config_path(args.cube) + + # Use config values, CLI args override + # ZMQ port from unified config_loader (control.yaml) + port = args.port if args.port is not None else DEFAULT_CUBE_PORT + process_noise = args.process_noise if args.process_noise is not None else cfg['rotation_filter']['process_noise'] + measurement_noise = args.measurement_noise if args.measurement_noise is not None else cfg['rotation_filter']['measurement_noise'] + alpha = args.alpha if args.alpha is not None else cfg['position_filter']['alpha'] + + print(f"Filter params: process_noise={process_noise}, measurement_noise={measurement_noise}, alpha={alpha}") + print(f"World frame sampling: {args.world_samples} frames") + + observer = CubeWorldObserver( + visualize=args.preview, + zmq_port=port, + process_noise=process_noise, + measurement_noise=measurement_noise, + alpha=alpha, + world_sample_frames=args.world_samples, + cube_config_path=cube_config_path, + ) + try: + observer.run() + except KeyboardInterrupt: + print("\nInterrupted.") + finally: + observer.cleanup() + + +if __name__ == "__main__": + main() diff --git a/examples/wuji/src/genelab_wuji/deploy/scripts/hand_utils.py b/examples/wuji/src/genelab_wuji/deploy/scripts/hand_utils.py new file mode 100644 index 00000000..a33ab864 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/scripts/hand_utils.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +"""Wuji-hand hardware utilities — the first sim2real step (Genesis-native port). + +Single entry-point for low-level hand operations, mirroring wuji-mjlab's +``deploy/reorient/scripts/hand_utils.py``. Two subcommands: + + check READ-ONLY connection + encoder sanity check. Does not write any + targets; the hand stays where it is. Run this FIRST. + home Ramp all 20 joints to the reorient home grasp pose via a 3s ease-in-out + (``WujiHandDriver.home``). 
Enables joints on entry, disables on exit, + and reports tracking error. + +Usage: + python -m genelab_wuji.deploy.scripts.hand_utils check + python -m genelab_wuji.deploy.scripts.hand_utils home + +The home pose is ``REORIENT_JOINT_POS`` (via ``config.default_joint_pos``) — the +same grasp keyframe the policy and ``MockHandDriver`` start from, so after ``home`` +the real hand matches the sim reset state. Needs the optional ``wujihandpy`` hand +SDK; it is imported lazily so this module stays importable without hardware. +""" + +from __future__ import annotations + +import argparse +import sys + +import numpy as np + +from genelab_wuji.deploy.config import JOINT_NAMES_20, N_JOINTS, default_joint_pos + + +def _print_finger_major(label: str, qpos_rad: np.ndarray) -> None: + """Print a flat (20,) qpos as 5 fingers x 4 joints, in degrees.""" + print(label) + for i in range(5): + finger_deg = np.rad2deg(qpos_rad[4 * i : 4 * (i + 1)]) + print(f" finger{i + 1}: " + ", ".join(f"{v:+6.1f}°" for v in finger_deg)) + + +def cmd_home(_args: argparse.Namespace) -> int: + """Ramp the real hand to the home grasp pose over a 3s ease-in-out. + + Safety: + - ``WujiHandDriver`` context manager: enable joints on enter, disable on exit. + - ``home()`` does a 50 Hz smoothstep ramp from the current pose (no snap). + - Reads back the actual position and reports max / RMS error vs target. + + After this the hand is at ``default_joint_pos()`` — the pose the sim resets to. 
+ """ + from genelab_wuji.deploy.hand_driver import WujiHandDriver + + home_qpos = default_joint_pos() + print("=" * 60) + print("Wuji hand HOME — 3s smooth ramp to the reorient grasp pose") + print("=" * 60) + _print_finger_major("\nTarget pose (default_joint_pos, degrees, finger-major):", home_qpos) + + with WujiHandDriver() as drv: + print("\nReading current pose...") + current = drv.read_encoders() + max_diff = float(np.abs(current - home_qpos).max()) + print(f" Max diff from target: {max_diff:.3f} rad ({np.rad2deg(max_diff):.1f}°)") + + print("\nRamping to home pose over 3s...") + drv.home(duration_s=3.0) + + print("\nReading actual after ramp...") + actual = drv.read_encoders() + err = np.abs(actual - home_qpos) + max_err = float(err.max()) + rms_err = float(np.sqrt(np.mean((actual - home_qpos) ** 2))) + print(f" Max err: {max_err:.3f} rad ({np.rad2deg(max_err):.2f}°)") + print(f" RMS err: {rms_err:.3f} rad ({np.rad2deg(rms_err):.2f}°)") + if max_err < np.deg2rad(2): + print(" ✓ Within 2° — home reached") + elif max_err < np.deg2rad(5): + print(" ⚠ Within 5° — hand may not be tracking perfectly") + else: + print(" ✗ Over 5° error — investigate") + + print("\n✓ hand_utils home complete; joints disabled.") + return 0 + + +def cmd_check(_args: argparse.Namespace) -> int: + """READ-ONLY connection + encoder sanity check. + + Connects to ``wujihandpy.Hand()``, reads joint limits + encoders, prints the + current pose against the expected joint order, and compares to the home pose. + Does NOT write any targets; the hand stays where it is. This is the safe first + step before any closed-loop deployment. 
+ """ + try: + import wujihandpy + except ImportError: + print("ERROR: wujihandpy not installed (the optional hardware SDK).") + return 1 + + print("=" * 60) + print("Wuji hand connection check (READ-ONLY)") + print("=" * 60) + + print("\n[1/4] Connecting to wujihandpy.Hand()...") + try: + hand = wujihandpy.Hand() + except Exception as e: # noqa: BLE001 — surface any hardware error to the operator + print(f" FAIL: {e}") + print(" Check: USB connected? udev rules? hand powered?") + return 1 + print(" ✓ connected") + + print("\n[2/4] Reading joint limits (no enable / no write)...") + try: + upper = hand.read_joint_upper_limit() # (5, 4) + lower = hand.read_joint_lower_limit() + except Exception as e: # noqa: BLE001 + print(f" FAIL reading limits: {e}") + return 1 + print(f" shape: upper={upper.shape}, lower={lower.shape}") + print(f" upper range: [{upper.min():.2f}, {upper.max():.2f}] rad") + print(f" lower range: [{lower.min():.2f}, {lower.max():.2f}] rad") + + print("\n[3/4] Reading encoder actual position...") + try: + actual = hand.read_joint_actual_position() # (5, 4) + except Exception as e: # noqa: BLE001 + print(f" FAIL: {e}") + return 1 + flat = np.asarray(actual, dtype=float).flatten() + print(f" shape: {actual.shape} → flat ({flat.shape[0]},)") + print(f" range: [{flat.min():.3f}, {flat.max():.3f}] rad") + _print_finger_major(" values (degrees, finger-major):", flat) + + print("\n[4/4] Validating joint order + home offset...") + print(f" Expected joint order ({N_JOINTS} joints):") + for i, name in enumerate(JOINT_NAMES_20): + print(f" [{i:2d}] {name}: {np.rad2deg(flat[i]):+6.1f}°") + + expected_home = default_joint_pos() + diff = flat - expected_home + print("\n Current vs home pose (default_joint_pos):") + print( + f" Max abs diff: {np.abs(diff).max():.3f} rad " + f"({np.rad2deg(np.abs(diff).max()):.1f}°)" + ) + print(f" RMS diff: {np.sqrt(np.mean(diff**2)):.3f} rad") + print(" (A large diff just means the hand isn't at home yet — not an error.)") + + 
print("\n" + "=" * 60) + print("✓ All read operations succeeded. Hand bridge healthy.") + print("✓ Next: ramp to home with `python -m genelab_wuji.deploy.scripts.hand_utils home`") + print("=" * 60) + return 0 + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Wuji hand hardware utilities (home / check).") + sub = parser.add_subparsers(dest="cmd", required=True) + p_home = sub.add_parser("home", help="Ramp all 20 joints to the home grasp pose (writes targets).") + p_home.set_defaults(func=cmd_home) + p_check = sub.add_parser("check", help="READ-ONLY connection + encoder sanity check.") + p_check.set_defaults(func=cmd_check) + args = parser.parse_args(argv) + return int(args.func(args)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/wuji/src/genelab_wuji/deploy/scripts/play_real.py b/examples/wuji/src/genelab_wuji/deploy/scripts/play_real.py new file mode 100644 index 00000000..bed26f4f --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/scripts/play_real.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +"""Deploy an exported reorient policy to control the (real or mock) Wuji hand. + +Wires the tested deploy core into a closed loop: + + cube/goal (ZMQ from cube_world_observer / toreal_viewer) + -> DeployObsBuilder -> ONNX policy -> EMA action -> hand driver + -> success monitor (geodesic(cube, goal) < threshold, held) -> resample + +Goal modes (``--goal-mode``): + external Goal comes from the goal ZMQ feed (toreal_viewer mocap drag). [default] + fixed Hold a single goal quat (``--goal-quat w,x,y,z``). + random Uniform-SO(3) goal; resampled each time the cube achieves it. + +Defaults to ``--mock`` (no hardware, no ZMQ required) so the loop can be smoke-run +anywhere; pass ``--real`` to drive the hand via ``wujihandpy``. The control logic +is covered headlessly by ``tests/test_examples_wuji_deploy_controller.py``. 
+ +Usage: + # Smoke run without hardware (mock hand, random goals): + python -m genelab_wuji.deploy.scripts.play_real --ckpt policy.onnx --goal-mode random --steps 200 + + # Real hand, random goals resampled on success: + python -m genelab_wuji.deploy.scripts.play_real --ckpt policy.onnx --real --goal-mode random + + # Real hand, goal driven by toreal_viewer over ZMQ: + python -m genelab_wuji.deploy.scripts.play_real --ckpt policy.onnx --real --goal-mode external +""" + +from __future__ import annotations + +import argparse +import time + +import numpy as np + +from genelab_wuji.deploy.config import ENC_TO_POLICY, default_joint_pos_policy +from genelab_wuji.deploy.controller import DeployController +from genelab_wuji.deploy.hand_driver import HandDriverBase, MockHandDriver +from genelab_wuji.deploy.onnx_policy import ONNXPolicy +from genelab_wuji.deploy.zmq_bridge import ( + DEFAULT_CUBE_PORT, + DEFAULT_GOAL_PORT, + CubeReceiver, + GoalReceiver, +) + + +def _quat_geodesic(q1_wxyz: np.ndarray, q2_wxyz: np.ndarray) -> float: + """Angle (rad) between two unit wxyz quaternions.""" + dot = abs(float(np.dot(q1_wxyz, q2_wxyz))) + return 2.0 * float(np.arccos(min(1.0, max(-1.0, dot)))) + + +def _random_unit_quat_wxyz() -> np.ndarray: + """Uniform random rotation over SO(3) (scipy; xyzw -> wxyz).""" + from scipy.spatial.transform import Rotation + + q_xyzw = Rotation.random().as_quat() + return np.array([q_xyzw[3], q_xyzw[0], q_xyzw[1], q_xyzw[2]], dtype=float) + + +def _parse_quat_wxyz(s: str) -> np.ndarray: + """argparse type: 'w,x,y,z' -> normalized wxyz quaternion.""" + parts = [float(x) for x in s.split(",")] + if len(parts) != 4: + raise argparse.ArgumentTypeError(f"--goal-quat expects 4 floats w,x,y,z, got {s!r}") + q = np.array(parts, dtype=float) + n = float(np.linalg.norm(q)) + if n < 1e-9: + raise argparse.ArgumentTypeError("--goal-quat has zero norm") + return q / n + + +class _GoalStub: + """Local goal source (drop-in for ``GoalReceiver``): ``latest()`` returns 
the + current target quat; ``set`` swaps it (used by fixed / random goal modes).""" + + def __init__(self, quat_wxyz: np.ndarray) -> None: + self._quat = np.asarray(quat_wxyz, dtype=float) + + def set(self, quat_wxyz: np.ndarray) -> None: + self._quat = np.asarray(quat_wxyz, dtype=float) + + def latest(self) -> np.ndarray: + return self._quat.copy() + + def close(self) -> None: # parity with GoalReceiver for the cleanup path + pass + + +def _make_driver(real: bool) -> HandDriverBase: + if not real: + return MockHandDriver() + from genelab_wuji.deploy.hand_driver import WujiHandDriver + + driver = WujiHandDriver() + driver.__enter__() # caller exits via the finally block in main() + return driver + + +def _make_goal_source(args: argparse.Namespace): + """Return the goal source for the chosen ``--goal-mode``.""" + if args.goal_mode == "fixed": + if args.goal_quat is None: + raise SystemExit("--goal-mode fixed requires --goal-quat w,x,y,z") + return _GoalStub(args.goal_quat) + if args.goal_mode == "random": + return _GoalStub(_random_unit_quat_wxyz()) + # external + if args.no_zmq: + return GoalReceiver(connect=False) + return GoalReceiver(port=args.goal_port) + + +class _SimMirror: + """Genesis digital-twin viewer for play_real: live hand + observed cube + goal. + + Renders kinematically each control step (no physics — see + ``InteractiveScene.refresh_visualizer``), so it just shows reality, never fights + it. All heavy imports (Genesis / the env) are deferred to construction so a + ``--no-viewer`` run stays numpy-only and headless-safe. 
+ """ + + def __init__(self) -> None: + from genelab_wuji.deploy.frame_transform import quat_mul + from genelab_wuji.deploy.real2sim import cube_pose_in_tag_to_world + from genelab_wuji.deploy.scripts._env import ( + build_reorient_env, + set_cube_pose, + set_goal_marker, + set_hand_joints, + tag_world_pose, + ) + + self._to_world = cube_pose_in_tag_to_world + self._quat_mul = quat_mul + self._set_cube = set_cube_pose + self._set_goal = set_goal_marker + self._set_hand = set_hand_joints + self._env = build_reorient_env(num_envs=1) + self._tag_pos_w, self._tag_quat_w = tag_world_pose(self._env) + + @property + def closed(self) -> bool: + return bool(self._env.viewer_closed) + + def update( + self, + joint_pos: np.ndarray, + cube_pos_tag: np.ndarray, + cube_quat_tag: np.ndarray, + goal_quat_tag: np.ndarray, + ) -> None: + self._set_hand(self._env, joint_pos) + cube_pos_w, cube_quat_w = self._to_world( + self._tag_pos_w, self._tag_quat_w, cube_pos_tag, cube_quat_tag + ) + self._set_cube(self._env, cube_pos_w, cube_quat_w) + self._set_goal(self._env, self._quat_mul(self._tag_quat_w, goal_quat_tag)) + self._env.scene.refresh_visualizer() + + def close(self) -> None: + self._env.close() + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--ckpt", required=True, help="exported policy.onnx") + parser.add_argument("--metadata", default=None, help="policy metadata.json (auto-detected)") + parser.add_argument("--real", action="store_true", help="drive the real hand (wujihandpy)") + parser.add_argument("--mock", action="store_true", help="use the mock hand (default)") + parser.add_argument("--no-zmq", action="store_true", help="skip ZMQ; zeros cube / stub goal") + parser.add_argument("--cube-port", type=int, default=DEFAULT_CUBE_PORT) + parser.add_argument("--goal-port", type=int, default=DEFAULT_GOAL_PORT) + parser.add_argument( + "--goal-mode", choices=("external", "fixed", "random"), default="external" + ) + 
parser.add_argument( + "--goal-quat", type=_parse_quat_wxyz, default=None, help="fixed goal 'w,x,y,z'" + ) + parser.add_argument( + "--success-threshold", type=float, default=0.2, help="success geodesic err (rad)" + ) + parser.add_argument( + "--success-hold-sec", type=float, default=0.5, help="hold time under threshold for success" + ) + parser.add_argument("--control-dt", type=float, default=0.05, help="policy step period (s)") + parser.add_argument("--steps", type=int, default=0, help="stop after N steps (0 = forever)") + parser.add_argument( + "--viewer", + action=argparse.BooleanOptionalAction, + default=True, + help="mirror the live hand + cube + goal in a Genesis viewer " + "(default on; pass --no-viewer for headless / mock smoke runs)", + ) + args = parser.parse_args() + + if args.control_dt <= 0: + raise SystemExit("--control-dt must be > 0 (used for joint velocity + success timing)") + + policy = ONNXPolicy(args.ckpt, metadata_path=args.metadata) + driver = _make_driver(real=args.real) + cube = CubeReceiver(connect=False) if args.no_zmq else CubeReceiver(port=args.cube_port) + goal = _make_goal_source(args) + + controller = DeployController( + policy=policy, + driver=driver, + cube_source=cube, + goal_source=goal, + default_joint_pos=default_joint_pos_policy(), # policy (articulation) order + control_dt=args.control_dt, + enc_to_policy=np.asarray(ENC_TO_POLICY), # remap encoder<->policy joint order + ) + controller.reset() + mirror = _SimMirror() if args.viewer else None + + hold_steps = max(1, round(args.success_hold_sec / args.control_dt)) + print( + f"[play_real] obs_dim={policy.input_dim} action_dim={policy.action_dim} " + f"driver={type(driver).__name__} goal_mode={args.goal_mode} " + f"viewer={'on' if mirror else 'off'} " + f"success<{args.success_threshold:.2f}rad held {hold_steps} steps" + ) + + step = 0 + hold = 0 + successes = 0 + try: + while args.steps == 0 or step < args.steps: + t0 = time.time() + info = controller.step() + step += 1 + + 
cube_pos_tag, cube_quat_tag = cube.latest() + goal_quat_tag = goal.latest() + + # Success monitor: geodesic(cube, goal) below threshold, sustained. + err = _quat_geodesic(cube_quat_tag, goal_quat_tag) + hold = hold + 1 if err < args.success_threshold else 0 + if hold >= hold_steps: + successes += 1 + print(f"[play_real] ✓ success #{successes} (err {np.degrees(err):.1f}°)") + hold = 0 + if args.goal_mode == "random" and isinstance(goal, _GoalStub): + goal.set(_random_unit_quat_wxyz()) + goal_quat_tag = goal.latest() + print("[play_real] new random goal") + + if mirror is not None: + mirror.update(info["joint_pos"], cube_pos_tag, cube_quat_tag, goal_quat_tag) + if mirror.closed: + break + + sleep = args.control_dt - (time.time() - t0) + if sleep > 0: + time.sleep(sleep) + except KeyboardInterrupt: + pass + finally: + driver_exit = getattr(driver, "__exit__", None) + if args.real and driver_exit is not None: + driver_exit(None, None, None) + cube.close() + goal.close() + if mirror is not None: + mirror.close() + print(f"[play_real] ran {step} control steps, {successes} successes") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/wuji/src/genelab_wuji/deploy/scripts/toreal_viewer.py b/examples/wuji/src/genelab_wuji/deploy/scripts/toreal_viewer.py new file mode 100644 index 00000000..adf3a99e --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/scripts/toreal_viewer.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +"""real2sim viewer: reproduce the real cube's pose inside the Genesis sim. + +Subscribes to the cube-pose ZMQ feed (published by ``cube_world_observer``), lifts +each tag-frame pose into sim-world coordinates, and places the sim cube there every +frame so the Genesis viewer mirrors reality. Optionally publishes a goal orientation +(drag-free fixed goal) on the goal port for ``play_real``. 
+ +Usage: + # Terminal 1: real camera -> ZMQ (needs hardware; see cube_world_observer.py) + python -m genelab_wuji.deploy.scripts.cube_world_observer + + # Terminal 2: mirror the cube in the Genesis sim + python -m genelab_wuji.deploy.scripts.toreal_viewer + +Run on a host with a GPU + display (Genesis viewer). The transform math itself is +covered by the headless deploy tests. +""" + +from __future__ import annotations + +import argparse +import time + +from genelab_wuji.deploy.real2sim import cube_pose_in_tag_to_world +from genelab_wuji.deploy.scripts._env import build_reorient_env, set_cube_pose, tag_world_pose +from genelab_wuji.deploy.zmq_bridge import DEFAULT_CUBE_PORT, CubeReceiver + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--cube-port", type=int, default=DEFAULT_CUBE_PORT) + parser.add_argument("--host", default="localhost") + parser.add_argument("--fps", type=float, default=60.0, help="viewer refresh rate") + args = parser.parse_args() + + env = build_reorient_env() + tag_pos_w, tag_quat_w = tag_world_pose(env) # fixed-base hand -> constant + cube = CubeReceiver(port=args.cube_port, host=args.host) + + print(f"[toreal_viewer] mirroring cube from tcp://{args.host}:{args.cube_port}") + dt = 1.0 / max(1e-3, args.fps) + try: + while not env.viewer_closed: + cube_pos_tag, cube_quat_tag = cube.latest() + cube_pos_w, cube_quat_w = cube_pose_in_tag_to_world( + tag_pos_w, tag_quat_w, cube_pos_tag, cube_quat_tag + ) + set_cube_pose(env, cube_pos_w, cube_quat_w) + env.scene.step(update_visualizer=True) + time.sleep(dt) + except KeyboardInterrupt: + pass + finally: + cube.close() + env.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/wuji/src/genelab_wuji/deploy/zmq_bridge.py b/examples/wuji/src/genelab_wuji/deploy/zmq_bridge.py new file mode 100644 index 00000000..4c384c22 --- /dev/null +++ b/examples/wuji/src/genelab_wuji/deploy/zmq_bridge.py @@ -0,0 +1,300 
@@ +"""ZMQ pub/sub bridge between deploy processes (ported from wuji-mjlab). + +Topology (localhost IPC): + +* port 5555 — cube pose: ``cube_world_observer`` -> ``play_real`` / ``toreal_viewer`` +* port 5556 — goal orientation: ``toreal_viewer`` -> ``play_real`` + +The observer publishes orientation in scipy **xyzw** order; everything downstream +runs in mujoco **wxyz**. ``cube_pose_from_msg`` / ``goal_from_msg`` own that +conversion as pure functions so it is testable without sockets; the receiver +classes wrap them with a background thread and a last-valid cache. +""" + +from __future__ import annotations + +import json +import threading +import time +from typing import Any + +import numpy as np + +DEFAULT_CUBE_PORT: int = 5555 +"""Default ZMQ port for cube pose (cube_world_observer).""" + +DEFAULT_GOAL_PORT: int = 5556 +"""Default ZMQ port for goal orientation (toreal_viewer).""" + + +def cube_pose_from_msg( + msg: dict[str, Any], +) -> tuple[np.ndarray, np.ndarray, bool, float | None]: + """Parse a cube-pose message into ``(pos, quat_wxyz, world_fixed, cube_size)``. + + The wire orientation is scipy xyzw; the returned quaternion is mujoco wxyz. + ``cube_size`` is ``None`` when the publisher does not announce it. + """ + cube = msg["cube1"] + p = cube["position"] + q = cube["orientation"] # scipy xyzw on the wire + pos = np.array([p["x"], p["y"], p["z"]]) + quat_wxyz = np.array([q["w"], q["x"], q["y"], q["z"]]) + world_fixed = bool(msg.get("world_fixed", False)) + cube_size = msg.get("cube_size") + return pos, quat_wxyz, world_fixed, (float(cube_size) if cube_size is not None else None) + + +def cube_msg_from_pose( + pos: np.ndarray, + quat_wxyz: np.ndarray, + *, + world_fixed: bool, + cube_size: float | None = None, +) -> dict[str, Any]: + """Serialize a cube pose to the wire format (orientation as scipy xyzw). + + Inverse of ``cube_pose_from_msg``; used by the observer / any cube feeder. 
+ """ + msg: dict[str, Any] = { + "world_fixed": bool(world_fixed), + "cube1": { + "position": {"x": float(pos[0]), "y": float(pos[1]), "z": float(pos[2])}, + "orientation": { + "x": float(quat_wxyz[1]), + "y": float(quat_wxyz[2]), + "z": float(quat_wxyz[3]), + "w": float(quat_wxyz[0]), + }, + }, + } + if cube_size is not None: + msg["cube_size"] = float(cube_size) + return msg + + +def goal_from_msg(msg: dict[str, Any]) -> np.ndarray: + """Parse a goal message into a wxyz quaternion (already wxyz on the wire).""" + q = msg["goal"]["orientation"] + return np.array([q["w"], q["x"], q["y"], q["z"]]) + + +_RECEIVER_POLL_INTERVAL: float = 0.001 +"""Polling interval for receiver background threads (seconds).""" + + +class CubeReceiver: + """Subscribe to cube pose (port 5555) and expose the latest valid pose. + + The background thread parses each message via ``cube_pose_from_msg`` and feeds + it to ``_update_from_msg``; ``latest()`` returns the last sample seen while + ``world_fixed`` was True (so a momentary loss of calibration keeps the last + good pose), defaulting to ``(zeros, identity)`` until the first valid sample. + + Pass ``connect=False`` to skip the socket/thread (used in tests and when a + process drives ``_update_from_msg`` itself). 
+ """ + + def __init__( + self, + port: int = DEFAULT_CUBE_PORT, + host: str = "localhost", + *, + connect: bool = True, + ) -> None: + self.port = port + self._latest_pos = np.zeros(3) + self._latest_quat = np.array([1.0, 0.0, 0.0, 0.0]) + self._cached_valid_pos = np.zeros(3) + self._cached_valid_quat = np.array([1.0, 0.0, 0.0, 0.0]) + self.cube_count = 0 + self.world_fixed = False + self.cube_size: float | None = None + self._lock = threading.Lock() + self._running = connect + self._socket: Any = None + self._context: Any = None + self._thread: Any = None + if connect: + import zmq + + self._context = zmq.Context() + self._socket = self._context.socket(zmq.SUB) + self._socket.connect(f"tcp://{host}:{port}") + self._socket.setsockopt_string(zmq.SUBSCRIBE, "") + self._socket.setsockopt(zmq.RCVHWM, 1) + self._socket.setsockopt(zmq.CONFLATE, 1) + self._thread = threading.Thread(target=self._receiver_loop, daemon=True) + self._thread.start() + + def _update_from_msg(self, msg: dict[str, Any]) -> None: + """Ingest one parsed cube message (also the unit-test entry point).""" + pos, quat_wxyz, world_fixed, cube_size = cube_pose_from_msg(msg) + with self._lock: + self._latest_pos = pos + self._latest_quat = quat_wxyz + self.world_fixed = world_fixed + if cube_size is not None: + self.cube_size = cube_size + self.cube_count += 1 + + def latest(self) -> tuple[np.ndarray, np.ndarray]: + """Return ``(cube_pos, cube_quat_wxyz)``, caching the last valid sample.""" + with self._lock: + if self.cube_count > 0 and self.world_fixed: + self._cached_valid_pos = self._latest_pos.copy() + self._cached_valid_quat = self._latest_quat.copy() + return self._cached_valid_pos.copy(), self._cached_valid_quat.copy() + + def close(self) -> None: + self._running = False + if self._socket is not None: + self._socket.close() + if self._context is not None: + self._context.term() + + def _receiver_loop(self) -> None: + import zmq + + while self._running: + try: + data = 
self._socket.recv(zmq.NOBLOCK) + msg = json.loads(data.decode("utf-8")) + if "cube1" in msg: + self._update_from_msg(msg) + except zmq.Again: + pass + except (json.JSONDecodeError, KeyError): + pass + time.sleep(_RECEIVER_POLL_INTERVAL) + + +class GoalReceiver: + """Subscribe to goal orientation (port 5556) and expose the latest goal. + + ``latest()`` returns identity ``[1, 0, 0, 0]`` until a goal arrives. Pass + ``connect=False`` to skip the socket/thread (tests / self-driven feeds). + """ + + def __init__( + self, + port: int = DEFAULT_GOAL_PORT, + host: str = "localhost", + *, + connect: bool = True, + ) -> None: + self.port = port + self._latest_goal = np.array([1.0, 0.0, 0.0, 0.0]) + self.goal_count = 0 + self._lock = threading.Lock() + self._running = connect + self._socket: Any = None + self._context: Any = None + self._thread: Any = None + if connect: + import zmq + + self._context = zmq.Context() + self._socket = self._context.socket(zmq.SUB) + self._socket.connect(f"tcp://{host}:{port}") + self._socket.setsockopt_string(zmq.SUBSCRIBE, "") + self._socket.setsockopt(zmq.RCVHWM, 1) + self._socket.setsockopt(zmq.CONFLATE, 1) + self._thread = threading.Thread(target=self._receiver_loop, daemon=True) + self._thread.start() + + def _update_from_msg(self, msg: dict[str, Any]) -> None: + """Ingest one parsed goal message (also the unit-test entry point).""" + goal = goal_from_msg(msg) + with self._lock: + self._latest_goal = goal + self.goal_count += 1 + + def latest(self) -> np.ndarray: + """Return the latest goal quaternion (wxyz); identity until one arrives.""" + with self._lock: + return self._latest_goal.copy() + + def close(self) -> None: + self._running = False + if self._socket is not None: + self._socket.close() + if self._context is not None: + self._context.term() + + def _receiver_loop(self) -> None: + import zmq + + while self._running: + try: + data = self._socket.recv(zmq.NOBLOCK) + msg = json.loads(data.decode("utf-8")) + if "goal" in msg: + 
self._update_from_msg(msg) + except zmq.Again: + pass + except (json.JSONDecodeError, KeyError): + pass + time.sleep(_RECEIVER_POLL_INTERVAL) + + +class GoalPublisher: + """Publish goal orientation (wxyz) on port 5556 (toreal_viewer -> play_real).""" + + def __init__(self, port: int = DEFAULT_GOAL_PORT) -> None: + import zmq + + self.port = port + self._context = zmq.Context() + self._socket = self._context.socket(zmq.PUB) + self._socket.bind(f"tcp://*:{port}") + + def publish(self, goal_quat_wxyz: np.ndarray) -> None: + msg = { + "timestamp": time.time(), + "goal": { + "orientation": { + "w": float(goal_quat_wxyz[0]), + "x": float(goal_quat_wxyz[1]), + "y": float(goal_quat_wxyz[2]), + "z": float(goal_quat_wxyz[3]), + } + }, + } + self._socket.send_string(json.dumps(msg)) + + def close(self) -> None: + self._socket.close() + self._context.term() + + +class CubePublisher: + """Publish cube pose on port 5555 (cube_world_observer -> consumers).""" + + def __init__(self, port: int = DEFAULT_CUBE_PORT) -> None: + import zmq + + self.port = port + self._context = zmq.Context() + self._socket = self._context.socket(zmq.PUB) + self._socket.bind(f"tcp://*:{port}") + + def publish( + self, + pos: np.ndarray, + quat_wxyz: np.ndarray, + *, + world_fixed: bool, + cube_size: float | None = None, + ) -> None: + import zmq + + msg = cube_msg_from_pose( + pos, quat_wxyz, world_fixed=world_fixed, cube_size=cube_size + ) + msg["timestamp"] = time.time() + self._socket.send_string(json.dumps(msg), flags=zmq.NOBLOCK) + + def close(self) -> None: + self._socket.close() + self._context.term() diff --git a/examples/wuji/src/genelab_wuji/reorient/asset.py b/examples/wuji/src/genelab_wuji/reorient/asset.py index 1c1dd1bc..375a44ae 100644 --- a/examples/wuji/src/genelab_wuji/reorient/asset.py +++ b/examples/wuji/src/genelab_wuji/reorient/asset.py @@ -11,11 +11,14 @@ import xml.etree.ElementTree as ET from functools import lru_cache from pathlib import Path -from typing import Final from 
genelab.actuator import ImplicitPDActuatorCfg +from genelab.asset_zoo.wuji_hand import ( + WUJI_CUBE_SPEC as _CUBE, + WUJI_HAND_REORIENT_SPEC as _MESHES, +) from genelab.entity import ArticulationCfg -from genelab.utils.download import AssetSpec, fetch_asset +from genelab.utils.download import fetch_asset from genelab_wuji.reorient.constants import ( REORIENT_JOINT_POS, @@ -23,26 +26,12 @@ REORIENT_ROBOT_ROOT_ROT, ) +# Reorient meshes (``_MESHES``) and the viewer cube (``_CUBE``) are declared in the central +# asset zoo (genelab.asset_zoo.wuji_hand) so ``genelab asset list``/``download`` discover +# them; imported here as the example's handles. _PACKAGE_ROOT = Path(__file__).resolve().parent _MJCF_TEMPLATE = _PACKAGE_ROOT / "mjcf" / "right_mjlab.xml" -_MESHES: Final = AssetSpec( - name="wuji_hand_reorient", - url="https://raw.githubusercontent.com/KraHsu/genelab-assets/main/wuji_hand_reorient/wuji_hand_reorient.tar.gz", - md5="68bed6d8f0fe4adc81ac8aa7f62cdfbe", - filename="wuji_hand_reorient.tar.gz", - archive_member="wuji_hand_reorient/meshes/right/right_palm_link.STL", -) - -# 54 mm UV-textured cube (visible faces) for the viewer: the held object and the goal marker. 
-_CUBE: Final = AssetSpec( - name="wuji_cube", - url="https://raw.githubusercontent.com/KraHsu/genelab-assets/main/wuji_cube/wuji_cube.tar.gz", - md5="f77eff83a9ca8ade2966d5202b5d337d", - filename="wuji_cube.tar.gz", - archive_member="wuji_cube/dex_cube.obj", -) - def resolve_cube_mesh() -> str: """Download (once) and return the path to the 54 mm textured cube OBJ.""" diff --git a/examples/wuji/src/genelab_wuji/reorient/env_cfg.py b/examples/wuji/src/genelab_wuji/reorient/env_cfg.py index 1fec08e3..61379a6f 100644 --- a/examples/wuji/src/genelab_wuji/reorient/env_cfg.py +++ b/examples/wuji/src/genelab_wuji/reorient/env_cfg.py @@ -137,6 +137,11 @@ def wuji_hand_reorient_env_cfg(play: bool = False, num_envs: int = 8192) -> Mana substeps=1, vis=play, gpu=not play, + # Per-env model params so dof/link DR (PD gains, frictionloss, mass/inertia) + # actually applies per environment. Genesis defaults these off, which silently + # no-ops per-env dof DR; training-only (eval/play uses nominal params). + batch_dofs_info=not play, + batch_links_info=not play, ), scene=InteractiveSceneCfg( env_spacing=(0.75, 0.75), @@ -275,6 +280,18 @@ def _events_cfg(play: bool) -> dict[str, EventTermCfg]: } if play: return cfg + # Per-episode gravity-direction tilt — the Genesis-native equivalent of mjlab's + # hand-pitch DR (Genesis can't tilt a fixed-base hand per env, so we tilt gravity + # instead: same gravity-in-palm physics, hand stays fixed-base, tag frame unchanged). + # Makes the policy robust to a tilted hardware mount (e.g. the real ~10 deg down-tilt). + cfg["gravity_tilt"] = EventTermCfg( + func=mdp.dr.gravity_tilt, + mode="reset", + # 0.2 rad (~11 deg) cone covers the real ~10 deg mount tilt with margin. A larger + # full-azimuth cone (tried 0.4) was too hard — the cube rolled off before the + # policy could bootstrap a grasp (training stalled at goals ~0.3). 
+ params={"max_tilt_rad": 0.2}, + ) cfg.update( { "robot_friction": EventTermCfg( @@ -318,6 +335,22 @@ def _events_cfg(play: bool) -> dict[str, EventTermCfg]: mode="startup", params={"asset_cfg": robot, "bias_range": (-0.01, 0.01)}, ), + # Per-env joint dry-friction (stiction): the MJCF has none, so the policy + # never learns to overcome the real hand's static friction and reorients too + # slowly on hardware. Real per-env DR now that batch_dofs_info is on (the + # earlier global-baseline attempt hurt; this samples per env/joint). + "dof_frictionloss": EventTermCfg( + func=mdp.dr.dof_frictionloss, + mode="startup", + params={"asset_cfg": robot, "friction_range": (0.0, 0.03)}, + ), + # Per-env joint armature (rotor inertia) scale — robustness to inertia + # calibration error (mjlab parity). Now per-env via batch_dofs_info. + "dof_armature": EventTermCfg( + func=mdp.dr.dof_armature, + mode="startup", + params={"asset_cfg": robot, "scale_range": (0.75, 1.3)}, + ), "object_disturbance": EventTermCfg( func=events.apply_velocity_disturbance, mode="interval", diff --git a/examples/wuji/src/genelab_wuji/reorient/scripts/sim2sim_mjlab.py b/examples/wuji/src/genelab_wuji/reorient/scripts/sim2sim_mjlab.py index b8c43b69..f24751be 100644 --- a/examples/wuji/src/genelab_wuji/reorient/scripts/sim2sim_mjlab.py +++ b/examples/wuji/src/genelab_wuji/reorient/scripts/sim2sim_mjlab.py @@ -44,10 +44,21 @@ _P_TO_M = [POLICY_JOINTS.index(m) for m in MJLAB_JOINTS] # policy-order array -> mjlab order -def run(policy_path: str, trials: int, seed: int, timeout: float = 14.0) -> dict[str, float]: +def run( + policy_path: str, trials: int, seed: int, timeout: float = 14.0, tilt_deg: float = 0.0 +) -> dict[str, float]: torch.manual_seed(seed) np.random.seed(seed) scene = build_reorient_scene(sim_dt=0.01, ctrl_dt=0.05, cube_edge_m=0.054) + # Emulate a tilted hardware mount by tilting gravity in the (level-hand) eval scene + # about +X. 
Same gravity-in-palm effect as pitching the hand; measures robustness to + # a mount tilt (e.g. the real ~10 deg) without moving the fixed-base hand. + if tilt_deg: + import math + + g0 = float(np.linalg.norm(scene.model.opt.gravity)) or 9.81 + rad = math.radians(tilt_deg) + scene.model.opt.gravity[:] = [0.0, g0 * math.sin(rad), -g0 * math.cos(rad)] policy = load_policy(policy_path) tag_pos, tag_quat = tag_frame() default_p = default_policy_joint_pos() @@ -150,10 +161,16 @@ def main() -> None: p.add_argument("--policy", required=True, help="rsl_rl checkpoint .pt") p.add_argument("--trials", type=int, default=100) p.add_argument("--seed", type=int, default=0) + p.add_argument( + "--gravity-tilt", + type=float, + default=0.0, + help="tilt gravity by this many degrees (emulates a tilted hardware mount)", + ) args = p.parse_args() - r = run(args.policy, args.trials, args.seed) + r = run(args.policy, args.trials, args.seed, tilt_deg=args.gravity_tilt) print( - f"sim2sim (wuji-mjlab env) over {r['trials']} trials: " + f"sim2sim (wuji-mjlab env, tilt={args.gravity_tilt:.0f}deg) over {r['trials']} trials: " f"success_rate={r['success_rate']:.2f} drop_rate={r['drop_rate']:.2f} " f"timeout_rate={r['timeout_rate']:.2f} mean_goal_reaches={r['mean_goal_reaches']:.2f}" ) diff --git a/examples/wuji/src/genelab_wuji/wuji_hand/assets.py b/examples/wuji/src/genelab_wuji/wuji_hand/assets.py index 2b2f19fa..4622096f 100644 --- a/examples/wuji/src/genelab_wuji/wuji_hand/assets.py +++ b/examples/wuji/src/genelab_wuji/wuji_hand/assets.py @@ -11,7 +11,8 @@ import numpy as np from numpy.typing import NDArray -from genelab.utils.download import AssetSpec, fetch_asset +from genelab.asset_zoo.wuji_hand import WUJI_HAND_SPEC as WUJI_HAND_DESCRIPTION +from genelab.utils.download import fetch_asset type FloatArray = NDArray[np.float32] @@ -20,16 +21,10 @@ DEFAULT_TRAJECTORY = PACKAGE_ROOT / "data" / "wave.npy" SIDES = ("left", "right") -# Full left+right hand description (MJCF + ~5 MB of STL 
meshes) hosted as a .tar.gz in the -# genelab-assets repo, so the source tree stays lean. ``archive_member`` points at the -# right-hand MJCF; the description directory is its grandparent in the extracted tree. -WUJI_HAND_DESCRIPTION = AssetSpec( - name="wuji_hand", - url="https://raw.githubusercontent.com/KraHsu/genelab-assets/main/wuji_hand/wuji_hand.tar.gz", - md5="46827dfc417773d469a75347a072e82e", - filename="wuji_hand.tar.gz", - archive_member="wuji_hand/description/mjcf/right.xml", -) +# Hand description spec (``wuji_hand``: MJCF + ~5 MB STL meshes) lives in the central asset +# zoo (genelab.asset_zoo.wuji_hand) so the asset CLI discovers it; aliased here for the +# example's existing call sites. ``archive_member`` points at the right-hand MJCF; the +# description directory is its grandparent in the extracted tree. def fetch_description_dir() -> Path: diff --git a/policy.onnx b/policy.onnx new file mode 100644 index 00000000..e6be5183 Binary files /dev/null and b/policy.onnx differ diff --git a/src/genelab/asset_zoo/wuji_hand.py b/src/genelab/asset_zoo/wuji_hand.py index 7cb66a40..5d310db9 100644 --- a/src/genelab/asset_zoo/wuji_hand.py +++ b/src/genelab/asset_zoo/wuji_hand.py @@ -1,4 +1,4 @@ -"""WUJI Hand asset zoo entry — 20-DoF dexterous hand (left + right). +"""WUJI asset zoo entries — 20-DoF dexterous hand (left + right) + reorient assets. Five fingers (``finger1``–``finger5``), each with four joints (``joint1``–``joint4``), for 20 actuated DoF per hand. The canonical hardware description is mirrored from @@ -9,6 +9,12 @@ A single implicit-PD actuator group spans all 20 joints with uniform nominal gains — soft position control sized to the hand's small fingers. Downstream manipulation tasks (e.g. in-hand reorientation) override these with hardware-calibrated per-joint gains. + +This module is the single source of truth for every WUJI ``AssetSpec`` — the base hand +plus the reorient-task meshes and viewer cube. 
They are declared at module level so the +asset zoo discovers them (``genelab asset list`` / ``asset download`` walk this package +for module-level ``AssetSpec`` instances); the bundled Wuji example imports them rather +than redeclaring its own. """ from typing import Final @@ -23,8 +29,37 @@ ) _MD5: Final = "46827dfc417773d469a75347a072e82e" +WUJI_HAND_SPEC: Final = AssetSpec( + name="wuji_hand", + url=_URL, + md5=_MD5, + filename="wuji_hand.tar.gz", + archive_member="wuji_hand/description/mjcf/right.xml", +) +"""20-DoF dexterous hand description (MJCF + meshes), left + right.""" + +WUJI_HAND_REORIENT_SPEC: Final = AssetSpec( + name="wuji_hand_reorient", + url="https://raw.githubusercontent.com/KraHsu/genelab-assets/main/wuji_hand_reorient/wuji_hand_reorient.tar.gz", + md5="68bed6d8f0fe4adc81ac8aa7f62cdfbe", + filename="wuji_hand_reorient.tar.gz", + archive_member="wuji_hand_reorient/meshes/right/right_palm_link.STL", +) +"""mjlab-tuned reorient right-hand meshes (~4 MB), paired with the in-tree ``right_mjlab.xml``.""" + +WUJI_CUBE_SPEC: Final = AssetSpec( + name="wuji_cube", + url="https://raw.githubusercontent.com/KraHsu/genelab-assets/main/wuji_cube/wuji_cube.tar.gz", + md5="f77eff83a9ca8ade2966d5202b5d337d", + filename="wuji_cube.tar.gz", + archive_member="wuji_cube/dex_cube.obj", +) +"""54 mm UV-textured cube (held object + goal marker) for the reorient viewer.""" + def _spec(side: str) -> AssetSpec: + if side == "right": + return WUJI_HAND_SPEC return AssetSpec( name="wuji_hand", url=_URL, diff --git a/src/genelab/configs.py b/src/genelab/configs.py index a51dd126..e9c3b8b3 100644 --- a/src/genelab/configs.py +++ b/src/genelab/configs.py @@ -93,6 +93,13 @@ class SimulationCfg: integrator: str | None = ( None # gs.integrator.: Euler / implicitfast / approximate_implicitfast ) + # Store per-DOF / per-link model params (kp/kv/frictionloss/damping/armature, mass/ + # inertia) PER-ENV so they can be domain-randomized per environment. 
Genesis defaults + # both to False (params shared across the batch), which silently no-ops per-env dof DR + # like ``mdp.dr.randomize_joint_stiffness_damping`` on implicit-PD actuators. Costs a + # little memory (per-env copies of the model arrays); enable for sim2real DR. + batch_dofs_info: bool | None = None # RigidOptions.batch_dofs_info + batch_links_info: bool | None = None # RigidOptions.batch_links_info def rigid_options_kwargs(self) -> dict[str, Any]: """Map the *set* rigid-solver fields to ``gs.options.RigidOptions`` kwargs. @@ -111,6 +118,8 @@ def rigid_options_kwargs(self) -> dict[str, Any]: "tolerance": self.solver_tolerance, "constraint_timeconst": self.constraint_timeconst, "integrator": self.integrator, + "batch_dofs_info": self.batch_dofs_info, + "batch_links_info": self.batch_links_info, } return {k: v for k, v in mapping.items() if v is not None} diff --git a/src/genelab/mdp/dr/__init__.py b/src/genelab/mdp/dr/__init__.py index 96c2b9e7..b6fdcd09 100644 --- a/src/genelab/mdp/dr/__init__.py +++ b/src/genelab/mdp/dr/__init__.py @@ -16,13 +16,22 @@ from genelab.mdp.dr.actuator import randomize_actuator_deadzone from genelab.mdp.dr.body import body_com_offset, body_mass_offset from genelab.mdp.dr.geom import geom_friction -from genelab.mdp.dr.joint import encoder_bias, randomize_joint_stiffness_damping +from genelab.mdp.dr.gravity import gravity_tilt +from genelab.mdp.dr.joint import ( + dof_armature, + dof_frictionloss, + encoder_bias, + randomize_joint_stiffness_damping, +) __all__ = [ "body_com_offset", "body_mass_offset", + "dof_armature", + "dof_frictionloss", "encoder_bias", "geom_friction", + "gravity_tilt", "randomize_actuator_deadzone", "randomize_joint_stiffness_damping", ] diff --git a/src/genelab/mdp/dr/gravity.py b/src/genelab/mdp/dr/gravity.py new file mode 100644 index 00000000..2e73ff84 --- /dev/null +++ b/src/genelab/mdp/dr/gravity.py @@ -0,0 +1,58 @@ +"""Per-environment gravity-direction domain randomization.""" + +import math +from 
typing import TYPE_CHECKING + +import torch + +from genelab.mdp.dr._common import normalise_env_ids + +if TYPE_CHECKING: + from genelab.contracts import EnvContext + + +def _rigid_solver(env: "EnvContext"): + """Best-effort access to the Genesis rigid solver (per-env gravity lives there).""" + scene = getattr(env, "scene", None) + gs_scene = getattr(scene, "_gs_scene", None) + sim = getattr(gs_scene, "sim", None) + return getattr(sim, "rigid_solver", None) + + +def gravity_tilt( + env: "EnvContext", + env_ids: torch.Tensor | None, + max_tilt_rad: float = 0.4, + magnitude: float = 9.81, +) -> None: + """Per-env gravity-direction DR: tilt gravity by a random polar angle in a random azimuth. + + The Genesis-native equivalent of randomizing a fixed-base hand's mount pitch (mjlab's + ``reset_root_state`` pitch DR). Genesis refuses per-env orientation on a fixed-base link + with geometry, but tilting **gravity** per env gives the SAME gravity-in-palm physics + while keeping the hand fixed-base and the wrist-tag world frame (``tag_w``) unchanged — + so the deploy obs pipeline needs no frame changes. Makes the policy robust to a tilted + hardware mount (e.g. a ~10 deg down-tilt). A full cone (random azimuth) covers the mount + tilt regardless of axis. Use ``mode="reset"`` (re-sample per episode, like mjlab). + + Per-env gravity requires no batch flag — gravity is already a per-env solver field + (``rigid_solver.set_gravity(..., envs_idx=...)``). Guarded for the fake-env test scaffold. 
+ """ + env_ids = normalise_env_ids(env, env_ids) + if env_ids.numel() == 0: + return + setter = getattr(_rigid_solver(env), "set_gravity", None) + if setter is None: + return + n = int(env_ids.numel()) + theta = torch.empty(n, device=env.device).uniform_(0.0, max_tilt_rad) + phi = torch.empty(n, device=env.device).uniform_(0.0, 2.0 * math.pi) + horizontal = magnitude * torch.sin(theta) + g = torch.empty(n, 3, device=env.device) + g[:, 0] = horizontal * torch.cos(phi) + g[:, 1] = horizontal * torch.sin(phi) + g[:, 2] = -magnitude * torch.cos(theta) + try: + setter(g, envs_idx=env_ids) + except Exception: + pass diff --git a/src/genelab/mdp/dr/joint.py b/src/genelab/mdp/dr/joint.py index d43c411d..53268c7b 100644 --- a/src/genelab/mdp/dr/joint.py +++ b/src/genelab/mdp/dr/joint.py @@ -62,6 +62,87 @@ def randomize_joint_stiffness_damping( pass +def dof_frictionloss( + env: "EnvContext", + env_ids: torch.Tensor | None, + friction_range: tuple[float, float] = (0.0, 0.02), + asset_cfg: "SceneEntityCfg | None" = None, +) -> None: + """Per-env, per-joint absolute joint dry-friction (frictionloss, Nm) DR. + + The hand MJCF declares zero joint frictionloss, so a Genesis-trained policy never + learns to overcome the *real* hand's static friction — it under-drives the fingers and + reorients too slowly on hardware (the cube is held but the goal times out). Sampling an + absolute frictionloss per env/joint makes the policy robust to a range of real stiction. + + REQUIRES ``SimulationCfg.batch_dofs_info=True`` — otherwise Genesis stores frictionloss + shared across the batch (``set_dofs_frictionloss`` rejects a per-env ``(n_env, n_joint)`` + tensor) and this call silently no-ops. Written via ``set_dofs_frictionloss``; guarded + for the fake-env test scaffolding. 
+ """ + env_ids = normalise_env_ids(env, env_ids) + if env_ids.numel() == 0: + return + n = int(env_ids.numel()) + handle = asset_handle(env, asset_cfg) + setter = getattr(handle, "set_dofs_frictionloss", None) or getattr( + handle, "set_dofs_friction", None + ) + if setter is None: + return + for actuator in asset_articulation(env, asset_cfg).actuators.values(): + n_joints = actuator.num_joints + if n_joints == 0: + continue + vals = torch.empty(n, n_joints, device=env.device).uniform_(*friction_range) + try: + setter(vals, dofs_idx_local=actuator.dof_ids, envs_idx=env_ids) + except Exception: + pass + + +def dof_armature( + env: "EnvContext", + env_ids: torch.Tensor | None, + scale_range: tuple[float, float] = (0.75, 1.3), + asset_cfg: "SceneEntityCfg | None" = None, +) -> None: + """Per-env, per-joint multiplicative DR on joint armature (rotor reflected inertia). + + Armature shapes the effective joint inertia / actuator response; randomizing it makes + the policy robust to inertia-calibration error (mjlab uses scale 0.75-1.3). The MJCF + declares non-zero armature, so a multiplicative scale is meaningful (unlike frictionloss). + + REQUIRES ``SimulationCfg.batch_dofs_info=True`` (else armature is shared across the batch + and per-env writes no-op). Reads the nominal via ``get_dofs_armature``; guarded for the + fake-env test scaffolding. 
+ """ + env_ids = normalise_env_ids(env, env_ids) + if env_ids.numel() == 0: + return + n = int(env_ids.numel()) + handle = asset_handle(env, asset_cfg) + getter = getattr(handle, "get_dofs_armature", None) + setter = getattr(handle, "set_dofs_armature", None) + if setter is None or getter is None: + return + for actuator in asset_articulation(env, asset_cfg).actuators.values(): + n_joints = actuator.num_joints + if n_joints == 0: + continue + try: + nominal = getter(dofs_idx_local=actuator.dof_ids, envs_idx=env_ids) + nominal = nominal[0] if nominal.dim() == 2 else nominal + except Exception: + continue + mult = torch.empty(n, n_joints, device=env.device).uniform_(*scale_range) + vals = nominal.to(env.device).unsqueeze(0) * mult + try: + setter(vals, dofs_idx_local=actuator.dof_ids, envs_idx=env_ids) + except Exception: + pass + + def encoder_bias( env: "EnvContext", env_ids: torch.Tensor | None, diff --git a/src/genelab/scene/interactive_scene.py b/src/genelab/scene/interactive_scene.py index 06490885..db64a42b 100644 --- a/src/genelab/scene/interactive_scene.py +++ b/src/genelab/scene/interactive_scene.py @@ -774,6 +774,30 @@ def step(self, *, update_visualizer: bool = True) -> None: return raise + def refresh_visualizer(self) -> None: + """Refresh the viewer from the current kinematic state — no physics step. + + For kinematic viewers (deploy real2sim / calibration) that set entity state + each frame and want to render it exactly as set: this runs forward kinematics + + a viewer update but **no time integration**, so gravity / contacts never move + the teleported bodies (the Genesis analogue of MuJoCo's ``mj_forward``). A no-op + once the viewer is closed; sets :py:attr:`viewer_closed` if the user closes it. 
+ """ + if self._viewer_closed: + return + vis = getattr(self._gs_scene, "visualizer", None) + if vis is None: + return + exc_cls = self._gs_exception_cls + try: + vis.update_visual_states() + vis.update() + except Exception as exc: + if exc_cls is not None and isinstance(exc, exc_cls) and str(exc) == "Viewer closed.": + self._viewer_closed = True + return + raise + def refresh_state(self) -> None: for art in self.articulations.values(): art.refresh() diff --git a/tests/test_examples_wuji_deploy_action.py b/tests/test_examples_wuji_deploy_action.py new file mode 100644 index 00000000..1c520234 --- /dev/null +++ b/tests/test_examples_wuji_deploy_action.py @@ -0,0 +1,54 @@ +"""Action post-processing for deploy (pure numpy). + +Mirrors ``JointPositionOffsetEMAAction`` from the reorient task: the policy emits +raw actions ~[-1, 1]; the joint target is ``default + scale * clamp(action)``, +clamped to joint limits, EMA-smoothed against the previous target, and held at the +default pose for a warmup window after each reset. (Deploy drops the training-only +``encoder_bias`` / ``action_noise`` terms.) 
+""" + +import numpy as np + +from genelab_wuji.deploy.action import ActionProcessor + +_N = 20 +_DEFAULT = np.linspace(0.1, 0.9, _N) + + +def _proc(**kw: object) -> ActionProcessor: + return ActionProcessor(default_joint_pos=_DEFAULT, **kw) # type: ignore[arg-type] + + +def test_warmup_holds_default_pose() -> None: + proc = _proc(action_scale=0.5, ema_alpha=0.5, warmup_steps=3) + proc.reset() + for _ in range(3): + target = proc.process(np.ones(_N)) # large action ignored during warmup + assert np.allclose(target, _DEFAULT) + + +def test_first_post_warmup_step_applies_scaled_offset_under_ema() -> None: + proc = _proc(action_scale=0.5, ema_alpha=0.5, warmup_steps=0) + proc.reset() + action = np.full(_N, 0.4) + target = proc.process(action) + # prev == default on the first step, so: + # smoothed = alpha*(default + scale*action) + (1-alpha)*default + # = default + alpha*scale*action + expected = _DEFAULT + 0.5 * 0.5 * action + assert np.allclose(target, expected) + + +def test_target_clamped_to_joint_limits() -> None: + lo = _DEFAULT - 0.05 + hi = _DEFAULT + 0.05 + proc = _proc( + action_scale=1.0, + ema_alpha=1.0, # no smoothing, so the raw clamp is directly observable + warmup_steps=0, + joint_pos_limits=(lo, hi), + ) + proc.reset() + target = proc.process(np.ones(_N)) # would push +1.0 past hi without clamping + assert np.all(target <= hi + 1e-9) + assert np.allclose(target, hi) diff --git a/tests/test_examples_wuji_deploy_controller.py b/tests/test_examples_wuji_deploy_controller.py new file mode 100644 index 00000000..aea9f4a2 --- /dev/null +++ b/tests/test_examples_wuji_deploy_controller.py @@ -0,0 +1,84 @@ +"""End-to-end deploy control loop, headless (mock hand + self-driven ZMQ sources). + +``DeployController`` ties the pieces together: read encoders -> build policy obs +(cube/goal from the observer feed) -> ONNX policy -> EMA action -> write target. 
+This is the "model deploy controls the hand" path; ``play_real.py`` wires the same +controller to the real hand, real ZMQ, and a Genesis viewer. +""" + +import json +from pathlib import Path + +import numpy as np +import pytest + +pytest.importorskip("onnxruntime") +pytest.importorskip("zmq") +torch = pytest.importorskip("torch") + +from genelab_wuji.deploy.config import default_joint_pos # noqa: E402 +from genelab_wuji.deploy.controller import DeployController # noqa: E402 +from genelab_wuji.deploy.hand_driver import MockHandDriver # noqa: E402 +from genelab_wuji.deploy.onnx_policy import ONNXPolicy # noqa: E402 +from genelab_wuji.deploy.zmq_bridge import CubeReceiver, GoalReceiver # noqa: E402 + +_OBS_DIM = 207 +_N = 20 + + +def _export_zero_policy(tmp_path: Path) -> Path: + """A policy that always outputs zeros (so post-warmup target == default).""" + lin = torch.nn.Linear(_OBS_DIM, _N) + with torch.no_grad(): + lin.weight.zero_() + lin.bias.zero_() + lin.eval() + onnx_path = tmp_path / "policy.onnx" + torch.onnx.export( + lin, + torch.zeros(1, _OBS_DIM), + str(onnx_path), + input_names=["obs"], + output_names=["actions"], + dynamic_axes={"obs": {0: "batch"}, "actions": {0: "batch"}}, + opset_version=17, + ) + (tmp_path / "policy.onnx.metadata.json").write_text( + json.dumps({"obs_dim": _OBS_DIM, "action_dim": _N}) + ) + return onnx_path + + +def _make_controller(tmp_path: Path, **kw) -> tuple[DeployController, MockHandDriver]: + policy = ONNXPolicy(_export_zero_policy(tmp_path)) + driver = MockHandDriver() + cube = CubeReceiver(connect=False) + goal = GoalReceiver(connect=False) + ctrl = DeployController( + policy=policy, + driver=driver, + cube_source=cube, + goal_source=goal, + default_joint_pos=default_joint_pos(), + **kw, + ) + return ctrl, driver + + +def test_step_returns_action_and_writes_target_to_hand(tmp_path: Path) -> None: + ctrl, driver = _make_controller(tmp_path, warmup_steps=0) + ctrl.reset() + info = ctrl.step() + assert info["action"].shape 
== (_N,) + assert info["target"].shape == (_N,) + # The hand received exactly the processed target. + assert np.allclose(driver.read_encoders(), info["target"]) + + +def test_zero_policy_post_warmup_drives_hand_to_default(tmp_path: Path) -> None: + # A zero action -> target == default joint pos once warmup has elapsed. + ctrl, driver = _make_controller(tmp_path, warmup_steps=0) + ctrl.reset() + for _ in range(5): + ctrl.step() + assert np.allclose(driver.read_encoders(), default_joint_pos(), atol=1e-5) diff --git a/tests/test_examples_wuji_deploy_frame_transform.py b/tests/test_examples_wuji_deploy_frame_transform.py new file mode 100644 index 00000000..f02deef7 --- /dev/null +++ b/tests/test_examples_wuji_deploy_frame_transform.py @@ -0,0 +1,34 @@ +"""wxyz quaternion helpers used by the deploy pipeline (pure numpy).""" + +import numpy as np + +from genelab_wuji.deploy.frame_transform import ( + quat_apply, + quat_conjugate, + quat_mul, +) + + +def _quat_z(angle: float) -> np.ndarray: + return np.array([np.cos(angle / 2), 0.0, 0.0, np.sin(angle / 2)]) + + +def test_quat_apply_rotates_x_to_y_about_z() -> None: + q = _quat_z(np.pi / 2) # +90deg about z + out = quat_apply(q, np.array([1.0, 0.0, 0.0])) + assert np.allclose(out, [0.0, 1.0, 0.0], atol=1e-9) + + +def test_quat_mul_composes_rotations() -> None: + q45 = _quat_z(np.pi / 4) + q90 = _quat_z(np.pi / 2) + composed = quat_mul(q45, q45) # 45 + 45 = 90 + out = quat_apply(composed, np.array([1.0, 0.0, 0.0])) + expected = quat_apply(q90, np.array([1.0, 0.0, 0.0])) + assert np.allclose(out, expected, atol=1e-9) + + +def test_quat_conjugate_inverts_rotation() -> None: + q = _quat_z(np.pi / 3) + identity = quat_mul(q, quat_conjugate(q)) + assert np.allclose(identity, [1.0, 0.0, 0.0, 0.0], atol=1e-9) diff --git a/tests/test_examples_wuji_deploy_hand_driver.py b/tests/test_examples_wuji_deploy_hand_driver.py new file mode 100644 index 00000000..ea049d57 --- /dev/null +++ b/tests/test_examples_wuji_deploy_hand_driver.py @@ 
-0,0 +1,65 @@ +"""Hand-driver abstraction for deploy: hardware-agnostic interface + mock. + +The control loop talks to ``HandDriverBase``; ``MockHandDriver`` lets the whole +pipeline run (and be tested) without the real ``wujihandpy`` hand. The encoder +joint order matches the GeneLab policy order (finger1_joint1..4, finger2...), +which is also ``wujihandpy``'s (5, 4) row-major flatten — so no remap is needed. +""" + +import numpy as np + +from genelab_wuji.deploy.config import default_joint_pos +from genelab_wuji.deploy.hand_driver import MockHandDriver, _home_ramp # pyright: ignore[reportPrivateUsage] +from genelab_wuji.reorient.constants import REORIENT_JOINT_POS + + +def test_mock_driver_echoes_written_target() -> None: + driver = MockHandDriver() + target = np.linspace(-0.3, 0.3, 20) + driver.write_target(target) + assert np.allclose(driver.read_encoders(), target) + + +def test_mock_driver_encoder_order_matches_policy_joint_order() -> None: + driver = MockHandDriver() + assert tuple(driver.joint_names_in_encoder_order()) == tuple(REORIENT_JOINT_POS) + + +def test_mock_driver_home_sets_grasp_keyframe() -> None: + driver = MockHandDriver() + driver.write_target(np.zeros(20)) + driver.home() + expected = np.array(list(REORIENT_JOINT_POS.values())) + assert np.allclose(driver.read_encoders(), expected) + + +def test_mock_driver_home_accepts_duration_arg() -> None: + # The 3s ramp is a real-driver concern; the mock ignores duration_s but must + # accept it so the shared HandDriverBase / DeployController.reset call works. 
+ driver = MockHandDriver() + driver.write_target(np.zeros(20)) + driver.home(duration_s=3.0) + assert np.allclose(driver.read_encoders(), default_joint_pos()) + + +def test_home_ramp_is_monotone_smoothstep_ending_at_target() -> None: + current = np.zeros(20) + target = default_joint_pos() + ramp = _home_ramp(current, target, steps=150) # 3s @ 50 Hz + assert ramp.shape == (150, 20) + # Smoothstep: starts eased-in near current, ends exactly at target. + assert np.allclose(ramp[-1], target) + assert not np.allclose(ramp[0], target) + # Each joint moves monotonically from current toward target (no overshoot). + deltas = np.diff(ramp, axis=0) + signs = np.sign(target - current) + assert np.all(deltas * signs[None, :] >= -1e-9) + assert ramp.max() <= max(target.max(), current.max()) + 1e-9 + assert ramp.min() >= min(target.min(), current.min()) - 1e-9 + + +def test_home_ramp_single_step_hits_target() -> None: + target = default_joint_pos() + ramp = _home_ramp(np.zeros(20), target, steps=1) + assert ramp.shape == (1, 20) + assert np.allclose(ramp[0], target) diff --git a/tests/test_examples_wuji_deploy_joint_order.py b/tests/test_examples_wuji_deploy_joint_order.py new file mode 100644 index 00000000..b2f59513 --- /dev/null +++ b/tests/test_examples_wuji_deploy_joint_order.py @@ -0,0 +1,53 @@ +"""Pin the deploy joint-order remap against the real Genesis articulation order. + +The encoder / wujihandpy order (``JOINT_NAMES_20``) is finger-major; the policy / +Genesis articulation order (``POLICY_JOINT_NAMES``) is joint-major. ``DeployController`` +remaps between them. If they drift, the real hand gets scrambled joint obs + actions and +twitches without manipulating the cube (the real-hand 0%-success bug). 
+""" + +import numpy as np +import pytest + +from genelab_wuji.deploy.config import ( + ENC_TO_POLICY, + JOINT_NAMES_20, + POLICY_JOINT_NAMES, + default_joint_pos, + default_joint_pos_policy, +) + + +def test_enc_to_policy_is_a_valid_permutation() -> None: + assert sorted(ENC_TO_POLICY) == list(range(20)) + assert set(POLICY_JOINT_NAMES) == set(JOINT_NAMES_20) + # Policy order is joint-major: the first five are every finger's joint1. + assert POLICY_JOINT_NAMES[:5] == tuple(f"right_finger{f}_joint1" for f in range(1, 6)) + # Encoder order is finger-major: the first four are finger1's joints 1..4. + assert JOINT_NAMES_20[:4] == tuple(f"right_finger1_joint{j}" for j in range(1, 5)) + + +def test_default_policy_is_default_reordered() -> None: + d = default_joint_pos() + dp = default_joint_pos_policy() + assert np.allclose(dp, d[list(ENC_TO_POLICY)]) + # Round-trips back to encoder order via the inverse permutation. + assert np.allclose(dp[np.argsort(ENC_TO_POLICY)], d) + + +def test_policy_joint_order_matches_env() -> None: + """Drift guard: POLICY_JOINT_NAMES must equal the built env's articulation order.""" + pytest.importorskip("genesis") + from genelab_wuji.deploy.scripts._env import build_reorient_env + + env = None + try: + env = build_reorient_env(num_envs=1) + assert list(env.scene["robot"].joint_names) == list(POLICY_JOINT_NAMES) + except Exception as exc: # asset download / GPU / display unavailable in minimal CI + if env is None: + pytest.skip(f"reorient env unavailable: {exc}") + raise + finally: + if env is not None: + env.close() diff --git a/tests/test_examples_wuji_deploy_obs.py b/tests/test_examples_wuji_deploy_obs.py new file mode 100644 index 00000000..0d29e891 --- /dev/null +++ b/tests/test_examples_wuji_deploy_obs.py @@ -0,0 +1,114 @@ +"""Policy observation assembly for deploy (pure numpy, no simulator). 
+ +The deployed obs needs no forward kinematics: joint state comes from encoders, +cube/goal poses come from the observer (already in the tag frame), and the last +action is tracked. ``DeployObsBuilder`` reproduces the GeneLab training policy obs +(term order, per-term 3-step history, 6D goal-error encoding) so an exported ONNX +policy receives exactly what it was trained on. + +Policy obs layout (matches ``genelab_wuji.reorient.env_cfg`` policy group): + joint_pos_rel_history 20 * 3 = 60 + joint_vel_rel_history 20 * 3 = 60 + cube_pos_in_tag_history 3 * 3 = 9 + goal_rot_err_6d_history 6 * 3 = 18 + last_action_history 20 * 3 = 60 + = 207 +""" + +import numpy as np +import pytest + +from genelab_wuji.deploy.obs import DeployObsBuilder, goal_rot_err_6d + +_N_JOINTS = 20 +_HIST = 3 +_OBS_DIM = 207 +_JP, _JV, _CUBE, _GOAL, _ACT = 60, 60, 9, 18, 60 + + +def _default_qpos() -> np.ndarray: + return np.linspace(0.1, 0.9, _N_JOINTS) + + +def test_obs_dim_is_207() -> None: + builder = DeployObsBuilder(default_joint_pos=_default_qpos(), history_len=_HIST) + builder.reset() + obs = builder.compute( + joint_pos=_default_qpos(), + joint_vel=np.zeros(_N_JOINTS), + cube_pos_tag=np.zeros(3), + cube_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + goal_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + last_action=np.zeros(_N_JOINTS), + ) + assert obs.shape == (_OBS_DIM,) + + +def test_first_frame_backfills_history() -> None: + # After reset, the first compute should fill all 3 history slots with the same + # frame (mirrors the training CircularBuffer backfill on reset). 
+ builder = DeployObsBuilder(default_joint_pos=_default_qpos(), history_len=_HIST) + builder.reset() + + delta = 0.05 * np.ones(_N_JOINTS) + obs = builder.compute( + joint_pos=_default_qpos() + delta, + joint_vel=np.zeros(_N_JOINTS), + cube_pos_tag=np.zeros(3), + cube_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + goal_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + last_action=np.zeros(_N_JOINTS), + ) + + # joint_pos_rel block = (joint_pos - default) repeated across 3 frames. + jp_block = obs[:_JP].reshape(_HIST, _N_JOINTS) + assert np.allclose(jp_block[0], delta) + assert np.allclose(jp_block[1], delta) + assert np.allclose(jp_block[2], delta) + + +def test_history_rolls_oldest_to_newest() -> None: + builder = DeployObsBuilder(default_joint_pos=_default_qpos(), history_len=_HIST) + builder.reset() + + act_start = _JP + _JV + _CUBE + _GOAL # last_action block offset + + def step(action_val: float) -> np.ndarray: + obs = builder.compute( + joint_pos=_default_qpos(), + joint_vel=np.zeros(_N_JOINTS), + cube_pos_tag=np.zeros(3), + cube_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + goal_quat_tag=np.array([1.0, 0.0, 0.0, 0.0]), + last_action=action_val * np.ones(_N_JOINTS), + ) + return obs[act_start : act_start + _ACT].reshape(_HIST, _N_JOINTS)[:, 0] + + step(1.0) # backfill -> [1, 1, 1] + step(2.0) # roll -> [1, 1, 2] + frames = step(3.0) # roll -> [1, 2, 3] + assert np.allclose(frames, [1.0, 2.0, 3.0]) # oldest -> newest + + +def test_goal_rot_err_6d_matches_genelab_training_math() -> None: + # Pin the 6D encoding against the *actual* GeneLab training code path + # (genelab.utils.math + reorient.mdp._math), the policy was trained on this. 
+ torch = pytest.importorskip("torch") + from genelab.utils.math import matrix_from_quat, quat_conjugate, quat_mul + from genelab_wuji.reorient.mdp._math import matrix_to_rotation_6d + + rng = np.random.default_rng(3) + for _ in range(8): + cube = rng.standard_normal(4) + cube /= np.linalg.norm(cube) + goal = rng.standard_normal(4) + goal /= np.linalg.norm(goal) + + ours = goal_rot_err_6d(cube, goal) + + c = torch.tensor(cube, dtype=torch.float).unsqueeze(0) + g = torch.tensor(goal, dtype=torch.float).unsqueeze(0) + err = quat_mul(c, quat_conjugate(g)) + ref = matrix_to_rotation_6d(matrix_from_quat(err)).squeeze(0).numpy() + + assert np.allclose(ours, ref, atol=1e-5), f"{ours} != {ref}" diff --git a/tests/test_examples_wuji_deploy_onnx_policy.py b/tests/test_examples_wuji_deploy_onnx_policy.py new file mode 100644 index 00000000..89fab7e6 --- /dev/null +++ b/tests/test_examples_wuji_deploy_onnx_policy.py @@ -0,0 +1,73 @@ +"""ONNX policy wrapper for deploy: load a GeneLab-exported policy and run it. + +GeneLab's exporter writes an ONNX with input ``obs`` / output ``actions`` (batch +axis dynamic) plus a sibling ``.metadata.json`` recording ``obs_dim`` and +``action_dim`` with normalization baked in. These tests build a tiny real ONNX so +the load / dim-introspection / single-forward path is exercised end to end. 
+""" + +import json +from pathlib import Path + +import numpy as np +import pytest + +pytest.importorskip("onnxruntime") +torch = pytest.importorskip("torch") + +from genelab_wuji.deploy.onnx_policy import ONNXPolicy # noqa: E402 + +_OBS_DIM = 207 +_ACTION_DIM = 20 + + +def _export_tiny_policy(tmp_path: Path, obs_dim: int, action_dim: int) -> Path: + """Export a 1-layer Linear policy to ONNX + sibling metadata.json (GeneLab fmt).""" + model = torch.nn.Linear(obs_dim, action_dim) + model.eval() + onnx_path = tmp_path / "policy.onnx" + torch.onnx.export( + model, + torch.zeros(1, obs_dim), + str(onnx_path), + input_names=["obs"], + output_names=["actions"], + dynamic_axes={"obs": {0: "batch"}, "actions": {0: "batch"}}, + opset_version=17, + ) + meta = { + "obs_dim": obs_dim, + "action_dim": action_dim, + "action_range": [-1.0, 1.0], + "normalization_baked": True, + } + (tmp_path / "policy.onnx.metadata.json").write_text(json.dumps(meta)) + return onnx_path + + +def test_loads_and_reports_dims(tmp_path: Path) -> None: + onnx_path = _export_tiny_policy(tmp_path, _OBS_DIM, _ACTION_DIM) + policy = ONNXPolicy(onnx_path) + assert policy.input_dim == _OBS_DIM + assert policy.action_dim == _ACTION_DIM + + +def test_single_forward_returns_action_vector(tmp_path: Path) -> None: + onnx_path = _export_tiny_policy(tmp_path, _OBS_DIM, _ACTION_DIM) + policy = ONNXPolicy(onnx_path) + action = policy(np.zeros(_OBS_DIM, dtype=np.float32)) + assert action.shape == (_ACTION_DIM,) + + +def test_wrong_obs_dim_raises(tmp_path: Path) -> None: + onnx_path = _export_tiny_policy(tmp_path, _OBS_DIM, _ACTION_DIM) + policy = ONNXPolicy(onnx_path) + with pytest.raises(ValueError, match="expected"): + policy(np.zeros(_OBS_DIM - 1, dtype=np.float32)) + + +def test_reads_metadata_sidecar(tmp_path: Path) -> None: + onnx_path = _export_tiny_policy(tmp_path, _OBS_DIM, _ACTION_DIM) + policy = ONNXPolicy(onnx_path) + assert policy.metadata["obs_dim"] == _OBS_DIM + assert 
policy.metadata["normalization_baked"] is True diff --git a/tests/test_examples_wuji_deploy_real2sim.py b/tests/test_examples_wuji_deploy_real2sim.py new file mode 100644 index 00000000..923906f0 --- /dev/null +++ b/tests/test_examples_wuji_deploy_real2sim.py @@ -0,0 +1,48 @@ +"""real2sim coordinate transforms for the Wuji-hand deploy pipeline. + +The vision pipeline detects the cube pose in the *camera* frame and the wrist +AprilTag pose in the *camera* frame. ``cube_cam_to_tag`` lifts the cube into the +wrist-tag frame — the exact frame the policy was trained on. These tests pin that +math (pure numpy, no simulator) so the real cube position reproduces correctly. +""" + +import numpy as np + +from genelab_wuji.deploy.frame_transform import cube_cam_to_tag + + +def _rand_rot(rng: np.random.Generator) -> np.ndarray: + R, _ = np.linalg.qr(rng.standard_normal((3, 3))) + if np.linalg.det(R) < 0: # keep it a proper rotation (det = +1) + R[:, 0] = -R[:, 0] + return R + + +def test_cube_sitting_on_tag_maps_to_tag_origin() -> None: + # When the cube pose in camera frame equals the tag pose in camera frame, the + # cube sits exactly at the tag origin: tag-frame position ~0, rotation ~identity. + rng = np.random.default_rng(0) + R_tag_cam = _rand_rot(rng) # random proper rotation (det = +1), not a reflection + t_tag_cam = rng.standard_normal(3) + + R_cube_tag, t_cube_tag = cube_cam_to_tag(R_tag_cam, t_tag_cam, R_tag_cam, t_tag_cam) + + assert np.allclose(t_cube_tag, np.zeros(3), atol=1e-9) + assert np.allclose(R_cube_tag, np.eye(3), atol=1e-9) + + +def test_known_cube_pose_in_tag_is_recovered_through_camera() -> None: + # Define a known cube pose in the tag frame, push it out to the camera frame, + # then lift it back: cube_cam_to_tag must recover the original tag-frame pose. 
+ rng = np.random.default_rng(7) + R_tag_cam, t_tag_cam = _rand_rot(rng), rng.standard_normal(3) + R_cube_tag_true, t_cube_tag_true = _rand_rot(rng), np.array([0.01, -0.02, 0.03]) + + # Compose the cube pose into the camera frame. + R_cube_cam = R_tag_cam @ R_cube_tag_true + t_cube_cam = R_tag_cam @ t_cube_tag_true + t_tag_cam + + R_cube_tag, t_cube_tag = cube_cam_to_tag(R_tag_cam, t_tag_cam, R_cube_cam, t_cube_cam) + + assert np.allclose(t_cube_tag, t_cube_tag_true, atol=1e-9) + assert np.allclose(R_cube_tag, R_cube_tag_true, atol=1e-9) diff --git a/tests/test_examples_wuji_deploy_sim_viz.py b/tests/test_examples_wuji_deploy_sim_viz.py new file mode 100644 index 00000000..6abd1fa4 --- /dev/null +++ b/tests/test_examples_wuji_deploy_sim_viz.py @@ -0,0 +1,63 @@ +"""Reproduce the real cube in the Genesis sim world from its tag-frame pose. + +The observer reports the cube pose in the wrist-tag frame; to draw it in sim we +lift it back to sim-world coordinates given the tag's world pose. These tests pin +that lift as the exact inverse of the obs-side ``cube_*_world_to_tag`` transform, +so what the policy sees and what the viewer draws agree. 
+""" + +import numpy as np + +from genelab_wuji.deploy.real2sim import ( + cube_pose_in_tag_to_world, + cube_pose_world_to_tag, + tag_pose_in_world, +) + + +def _quat_z(angle: float) -> np.ndarray: + return np.array([np.cos(angle / 2), 0.0, 0.0, np.sin(angle / 2)]) + + +def test_cube_world_to_tag_and_back_round_trips() -> None: + tag_pos_w = np.array([0.02, 0.0, 0.55]) + tag_quat_w = _quat_z(np.pi / 5) + cube_pos_w = np.array([-0.05, 0.01, 0.56]) + cube_quat_w = _quat_z(-np.pi / 7) + + cube_pos_tag, cube_quat_tag = cube_pose_world_to_tag( + tag_pos_w, tag_quat_w, cube_pos_w, cube_quat_w + ) + back_pos, back_quat = cube_pose_in_tag_to_world( + tag_pos_w, tag_quat_w, cube_pos_tag, cube_quat_tag + ) + + assert np.allclose(back_pos, cube_pos_w, atol=1e-9) + # quaternion equality up to sign + assert np.allclose(back_quat, cube_quat_w, atol=1e-9) or np.allclose( + back_quat, -cube_quat_w, atol=1e-9 + ) + + +def test_cube_at_tag_origin_lifts_to_tag_world_position() -> None: + # A cube reported at the tag origin must render exactly at the tag's world pose. + tag_pos_w = np.array([0.1, -0.2, 0.5]) + tag_quat_w = _quat_z(0.3) + + cube_pos_w, cube_quat_w = cube_pose_in_tag_to_world( + tag_pos_w, tag_quat_w, np.zeros(3), np.array([1.0, 0.0, 0.0, 0.0]) + ) + assert np.allclose(cube_pos_w, tag_pos_w, atol=1e-9) + assert np.allclose(cube_quat_w, tag_quat_w, atol=1e-9) + + +def test_tag_pose_in_world_applies_tag_in_palm_offset() -> None: + # With the palm at the origin (identity), the tag world pose equals the + # constant TAG_IN_PALM transform from the reorient constants. 
+ from genelab_wuji.reorient.constants import TAG_IN_PALM_POS, TAG_IN_PALM_QUAT_WXYZ + + tag_pos_w, tag_quat_w = tag_pose_in_world( + np.zeros(3), np.array([1.0, 0.0, 0.0, 0.0]) + ) + assert np.allclose(tag_pos_w, TAG_IN_PALM_POS, atol=1e-9) + assert np.allclose(tag_quat_w, TAG_IN_PALM_QUAT_WXYZ, atol=1e-9) diff --git a/tests/test_examples_wuji_deploy_zmq_bridge.py b/tests/test_examples_wuji_deploy_zmq_bridge.py new file mode 100644 index 00000000..701c4232 --- /dev/null +++ b/tests/test_examples_wuji_deploy_zmq_bridge.py @@ -0,0 +1,121 @@ +"""ZMQ bridge for the deploy pipeline: message parsing + latest()-cache semantics. + +The cube observer publishes orientation in scipy xyzw order; the deploy stack runs +on mujoco wxyz. These tests pin the conversion and the receiver's last-valid cache +(so a momentary loss of ``world_fixed`` keeps the last good pose). +""" + +from typing import Any + +import numpy as np +import pytest + +zmq = pytest.importorskip("zmq") + +from genelab_wuji.deploy.zmq_bridge import ( # noqa: E402 + CubeReceiver, + GoalReceiver, + cube_msg_from_pose, + cube_pose_from_msg, + goal_from_msg, +) + + +def test_cube_pose_from_msg_converts_xyzw_to_wxyz() -> None: + msg = { + "world_fixed": True, + "cube_size": 0.054, + "cube1": { + "position": {"x": 0.1, "y": -0.2, "z": 0.3}, + # scipy xyzw on the wire (w last) + "orientation": {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0}, + }, + } + pos, quat_wxyz, world_fixed, cube_size = cube_pose_from_msg(msg) + + assert np.allclose(pos, [0.1, -0.2, 0.3]) + assert np.allclose(quat_wxyz, [1.0, 0.0, 0.0, 0.0]) # w first + assert world_fixed is True + assert cube_size == pytest.approx(0.054) + + +def test_cube_pose_from_msg_preserves_quat_component_mapping() -> None: + # Distinct components so a w<->x swap (or any mislabel) would be caught. 
+ msg = { + "world_fixed": False, + "cube1": { + "position": {"x": 0.0, "y": 0.0, "z": 0.0}, + "orientation": {"x": 0.1, "y": 0.2, "z": 0.3, "w": 0.4}, + }, + } + _pos, quat_wxyz, world_fixed, cube_size = cube_pose_from_msg(msg) + + assert np.allclose(quat_wxyz, [0.4, 0.1, 0.2, 0.3]) # [w, x, y, z] + assert world_fixed is False + assert cube_size is None # absent in this message + + +def _cube_msg( + pos: list[float], quat_wxyz: list[float], world_fixed: bool +) -> dict[str, Any]: + w, x, y, z = quat_wxyz + return { + "world_fixed": world_fixed, + "cube1": { + "position": {"x": pos[0], "y": pos[1], "z": pos[2]}, + "orientation": {"x": x, "y": y, "z": z, "w": w}, + }, + } + + +def test_cube_receiver_latest_defaults_to_zeros_identity_before_valid() -> None: + # No socket: connect=False keeps it inert so we can drive it by hand. + recv = CubeReceiver(connect=False) + pos, quat = recv.latest() + assert np.allclose(pos, np.zeros(3)) + assert np.allclose(quat, [1.0, 0.0, 0.0, 0.0]) + + +def test_cube_receiver_caches_last_valid_through_world_unfixed() -> None: + recv = CubeReceiver(connect=False) + + good_pos = [0.1, 0.2, 0.3] + good_quat = [0.0, 1.0, 0.0, 0.0] # wxyz + recv._update_from_msg(_cube_msg(good_pos, good_quat, world_fixed=True)) + + pos, quat = recv.latest() + assert np.allclose(pos, good_pos) + assert np.allclose(quat, good_quat) + + # A later sample arrives while calibration is momentarily lost: keep the cache. 
+ recv._update_from_msg(_cube_msg([9.0, 9.0, 9.0], [1.0, 0.0, 0.0, 0.0], world_fixed=False)) + pos2, quat2 = recv.latest() + assert np.allclose(pos2, good_pos) + assert np.allclose(quat2, good_quat) + + +def test_goal_from_msg_reads_wxyz_orientation() -> None: + msg = {"goal": {"orientation": {"w": 0.4, "x": 0.1, "y": 0.2, "z": 0.3}}} + assert np.allclose(goal_from_msg(msg), [0.4, 0.1, 0.2, 0.3]) + + +def test_goal_receiver_defaults_to_identity_then_tracks_latest() -> None: + recv = GoalReceiver(connect=False) + assert np.allclose(recv.latest(), [1.0, 0.0, 0.0, 0.0]) + + recv._update_from_msg({"goal": {"orientation": {"w": 0.0, "x": 0.0, "y": 1.0, "z": 0.0}}}) + assert np.allclose(recv.latest(), [0.0, 0.0, 1.0, 0.0]) + + +def test_cube_msg_round_trips_through_parser() -> None: + # The publisher serializes a wxyz pose to the wire (scipy xyzw); the parser + # converts back to wxyz. Round-trip must be exact so observer and consumer agree. + pos = np.array([0.12, -0.03, 0.55]) + quat_wxyz = np.array([0.4, 0.1, 0.2, 0.3]) + msg = cube_msg_from_pose(pos, quat_wxyz, world_fixed=True, cube_size=0.054) + + got_pos, got_quat, world_fixed, cube_size = cube_pose_from_msg(msg) + assert np.allclose(got_pos, pos) + assert np.allclose(got_quat, quat_wxyz) + assert world_fixed is True + assert cube_size == pytest.approx(0.054)