diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml index 9f985667a..33b5b61ef 100644 --- a/baselines/ppo/config/ppo_base_puffer.yaml +++ b/baselines/ppo/config/ppo_base_puffer.yaml @@ -92,10 +92,11 @@ train: checkpoint_path: "./runs" # # # Rendering # # # - render: false # Determines whether to render the environment (note: will slow down training) + render: true # Determines whether to render the environment (note: will slow down training) + render_backend: "array" # Options: matplotlib, array render_3d: true # Render simulator state in 3d or 2d render_interval: 1 # Render every k iterations - render_k_scenarios: 10 # Number of scenarios to render + render_k_scenarios: 2 # Number of scenarios to render render_format: "mp4" # Options: gif, mp4 render_fps: 15 # Frames per second zoom_radius: 50 diff --git a/gpudrive/env/env_puffer.py b/gpudrive/env/env_puffer.py index 811971cf6..ffe5e8fbf 100644 --- a/gpudrive/env/env_puffer.py +++ b/gpudrive/env/env_puffer.py @@ -8,12 +8,10 @@ from gpudrive.env.config import EnvConfig, RenderConfig from gpudrive.env.env_torch import GPUDriveTorchEnv -from gpudrive.datatypes.observation import ( - LocalEgoState, -) from gpudrive.visualize.utils import img_from_fig from gpudrive.env.dataset import SceneDataLoader +from gpudrive.visualize.utils import color_onehot_segmentation_map from pufferlib.environment import PufferEnv from gpudrive import GPU_DRIVE_DATA_DIR @@ -66,6 +64,7 @@ def __init__( render_format="mp4", render_fps=15, zoom_radius=50, + render_backend="array", buf=None, **kwargs, ): @@ -97,6 +96,7 @@ def __init__( self.render_format = render_format self.render_fps = render_fps self.zoom_radius = zoom_radius + self.render_backend = render_backend # VBD self.vbd_model_path = vbd_model_path @@ -422,19 +422,32 @@ def render_env(self): np.where(np.array(list(self.rendering_in_progress.values())))[0] ) time_steps = list(self.episode_lengths[envs_to_render, 0]) - + if len(envs_to_render) > 0: - sim_state_figures = self.env.vis.plot_simulator_state( - env_indices=envs_to_render, - time_steps=time_steps, - zoom_radius=self.zoom_radius, - ) + if self.render_backend == "matplotlib": + # Render bird's eye view using matplotlib + sim_state_figures = self.env.vis.plot_simulator_state( + env_indices=envs_to_render, + time_steps=time_steps, + zoom_radius=self.zoom_radius, + ) - for idx, render_env_idx in enumerate(envs_to_render): + for idx, render_env_idx in enumerate(envs_to_render): + self.frames[render_env_idx].append( + img_from_fig(sim_state_figures[idx]) + ) + else: + # Render bird's eye view using raster scan algorithm + bev = self.env.sim.bev_observation_tensor().to_torch() + # Convert the BEV observation to a colored segmentation map + # If the tensor is one-hot encoded segmentation data + colored_bev = color_onehot_segmentation_map(bev, 'cpu') + agent_idx = 0 + self.frames[render_env_idx].append( - img_from_fig(sim_state_figures[idx]) + colored_bev[agent_idx, :] ) - + def resample_scenario_batch(self): """Sample and set new batch of WOMD scenarios.""" diff --git a/gpudrive/visualize/utils.py b/gpudrive/visualize/utils.py index 094c1d030..c4e2d8c26 100644 --- a/gpudrive/visualize/utils.py +++ b/gpudrive/visualize/utils.py @@ -14,6 +14,36 @@ from gpudrive.visualize.color import ROAD_GRAPH_COLORS, ROAD_GRAPH_TYPE_NAMES +def color_onehot_segmentation_map(onehot_map, device): + """ + Args: + onehot_map: torch.Tensor of shape [num_classes, H, W], dtype=bool or float (one-hot) + device: torch device + Returns: + colored_image: [H, W, 3] uint8 image + """ + color_mapping = torch.tensor([ + [0, 0, 0], # None + [125, 125, 125], # RoadEdge + [120, 120, 120], # RoadLine + [230, 230, 230], # RoadLane + [200, 200, 200], # CrossWalk + [217, 166, 33], # SpeedBump + [255, 0, 0], # StopSign + [0, 255, 255], # Vehicle + [0, 255, 0], # Pedestrian + [128, 0, 128], # Cyclist + [192, 192, 192], # Padding + ], dtype=torch.uint8, device=device) + + # Convert one-hot to class indices: [H, W] + class_map = onehot_map.argmax(dim=0) + + # Map to color image: [H, W, 3] + colored_image = color_mapping[class_map] + + return colored_image + def img_from_fig(fig: matplotlib.figure.Figure) -> np.ndarray: """Returns a [H, W, 3] uint8 np image from fig.canvas.tostring_rgb().""" # Adjusted margins to better accommodate 3D plots