diff --git a/README.md b/README.md index aa3aae59..8c242647 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ MultiTask Suite This suite contains a collection of environments centered around dexterous manipulation. Standard [TCDM benchmarks](https://pregrasps.github.io/) are a part of this suite ## - ROBEL Suite (Coming soon) - This suite contains a collection of environments centered around real-world locomotion and manipulation. Standard [ROBEL benchmarks](http://roboticsbenchmarks.org/) are a part of this suite + This suite contains a collection of environments centered around real-world locomotion and manipulation. Standard [ROBEL benchmarks](https://sites.google.com/view/roboticsbenchmarks) are a part of this suite # Citation If you find `RoboHive` useful in your research, diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py index 1072b132..efc2d540 100644 --- a/robohive/envs/env_base.py +++ b/robohive/envs/env_base.py @@ -495,8 +495,8 @@ def get_input_seed(self): def _reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs): """ - Reset the environment - Default implemention provided. Override if env needs custom reset + Reset the environment (Default implemention provided). + Override if env needs custom reset. 
Carefully handle return type for gym/gymnasium compatibility """ qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel diff --git a/robohive/envs/hands/baoding_v1.py b/robohive/envs/hands/baoding_v1.py index 1b7cafec..1b6d97da 100644 --- a/robohive/envs/hands/baoding_v1.py +++ b/robohive/envs/hands/baoding_v1.py @@ -264,8 +264,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy() # reset scene - obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) - return obs + return super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs) def create_goal_trajectory(self, time_step=.1, time_period=6): len_of_goals = 1000 # assumes that its greator than env horizon @@ -326,5 +325,4 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): class BaodingRandomEnvV1(BaodingFixedEnvV1): def reset(self, **kwargs): - obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs) - return obs + return super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs) \ No newline at end of file diff --git a/robohive/envs/hands/door_v1.py b/robohive/envs/hands/door_v1.py index 8cd9f5e1..4e756e9f 100644 --- a/robohive/envs/hands/door_v1.py +++ b/robohive/envs/hands/door_v1.py @@ -96,17 +96,13 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def reset(self, **kwargs): self.sim.reset() - qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos - qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel - self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs) - self.sim.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2) self.sim.model.body_pos[self.door_bid, 1] = 
self.np_random.uniform(low=0.25, high=0.35) self.sim.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35) self.sim.forward() - return self.get_obs() + return super().reset(**kwargs) def get_env_state(self): diff --git a/robohive/envs/hands/hammer_v1.py b/robohive/envs/hands/hammer_v1.py index 3ef29033..cbdccb92 100644 --- a/robohive/envs/hands/hammer_v1.py +++ b/robohive/envs/hands/hammer_v1.py @@ -120,15 +120,12 @@ def get_obs_dict(self, sim): return obs_dict - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def reset(self, **kwargs): self.sim.reset() - qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos - qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel - self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs) - self.sim.model.body_pos[self.target_bid,2] = self.np_random.uniform(low=0.1, high=0.25) self.sim.forward() - return self.get_obs() + return super().reset(**kwargs) + def get_env_state(self): """ diff --git a/robohive/envs/hands/pen_v1.py b/robohive/envs/hands/pen_v1.py index 54539d34..f38c1c1d 100644 --- a/robohive/envs/hands/pen_v1.py +++ b/robohive/envs/hands/pen_v1.py @@ -108,19 +108,14 @@ def get_reward_dict(self, obs_dict): return rwd_dict - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def reset(self, **kwargs): self.sim.reset() - qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos - qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel - self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs) - desired_orien = np.zeros(3) desired_orien[0] = self.np_random.uniform(low=-1, high=1) desired_orien[1] = self.np_random.uniform(low=-1, high=1) self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien) self.sim.forward() - - return self.get_obs() + return super().reset(**kwargs) def get_env_state(self): diff --git a/robohive/envs/hands/relocate_v1.py b/robohive/envs/hands/relocate_v1.py index 88e273e2..be2d8ea0 100644 --- 
a/robohive/envs/hands/relocate_v1.py +++ b/robohive/envs/hands/relocate_v1.py @@ -136,20 +136,16 @@ def get_obs_dict(self, sim): return obs_dict - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def reset(self, **kwargs): self.sim.reset() - qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos - qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel - self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs) - - self.sim.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15) self.sim.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3) self.sim.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2) self.sim.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2) self.sim.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35) self.sim.forward() - return self.get_obs() + return super().reset(**kwargs) + def get_env_state(self): """ diff --git a/robohive/envs/multi_task/README.md b/robohive/envs/multi_task/README.md index 4e365916..dc0ee157 100644 --- a/robohive/envs/multi_task/README.md +++ b/robohive/envs/multi_task/README.md @@ -2,7 +2,7 @@ This suite is designed to study generalization in multi-task settings. RoboHive's multi-task suite builds from `FrankaKitchen` environements originally studied in the [Relay Policy Learning](https://relay-policy-learning.github.io/) project. # Franka Kitchen -`FrankaKitchen` domain offers a challenging set of manipulation problems in an unstructured environment with many possible tasks to perform. The original set consisted of a franka robot in a kitchen domain. Overtime, Franka Kitchen has became a popular choice of environments for studying multi-task generalization. Its widespread use has led to a few different publically available variations. 
To help navigate these changes, we name these variations and document its evolution of across various versions below - +`FrankaKitchen` domain offers a challenging set of manipulation problems in an unstructured environment with many possible tasks to perform. The original set consisted of a franka robot in a kitchen domain. Overtime, Franka Kitchen has became a popular choice of environments for studying multi-task generalization. Its widespread use has led to a few different publically available variations. To help navigate these changes, we name these variations and document its evolution across various versions below - ## Change log/ History diff --git a/robohive/envs/multi_task/multi_task_base_v1.py b/robohive/envs/multi_task/multi_task_base_v1.py index a18adda4..4122593e 100644 --- a/robohive/envs/multi_task/multi_task_base_v1.py +++ b/robohive/envs/multi_task/multi_task_base_v1.py @@ -200,9 +200,9 @@ def get_reward_dict(self, obs_dict): # Optional Keys ("obj_goal", -np.sum(goal_dist, axis=-1)), ("bonus", - 1.0*np.product(goal_dist < 5 * self.obj["dof_proximity"], axis=-1) + 1.0*np.prod(goal_dist < 5 * self.obj["dof_proximity"], axis=-1) # np.product(goal_dist < 0.75 * self.obj["dof_ranges"], axis=-1) - + 1.0*np.product(goal_dist < 1.67 * self.obj["dof_proximity"], axis=-1), + + 1.0*np.prod(goal_dist < 1.67 * self.obj["dof_proximity"], axis=-1), # + np.product(goal_dist < 0.25 * self.obj["dof_ranges"], axis=-1), ), ("pose", -np.sum(np.abs(obs_dict["pose_err"]), axis=-1)), @@ -272,4 +272,4 @@ def set_obj_goal(self, obj_goal=None, interact_site=None): elif type(interact_site) is str: # overwrite using name self.interact_sid = self.sim.model.site_name2id(interact_site) elif type(interact_site) is int: # overwrite using id - self.interact_sid = interact_site \ No newline at end of file + self.interact_sid = interact_site diff --git a/robohive/logger/grouped_datasets.py b/robohive/logger/grouped_datasets.py index 3450af9e..0ba24e65 100644 --- 
a/robohive/logger/grouped_datasets.py +++ b/robohive/logger/grouped_datasets.py @@ -19,7 +19,6 @@ # access pattern for pickle and h5 backbone post load isn't the same # - Should we get rid of pickle support and double down on h5? # - other way would to make the default container (trace.trace) h5 container instead of a dict -# Should we explicitely keep tract if the trace has been flattened/ stacked/ closed etc? class TraceType(enum.Enum): @@ -32,7 +31,7 @@ def get_type(input_type): """ A more robust way of getting trace type. Supports strings """ - if type(input_type) == str: + if isinstance(input_type, str): if input_type.lower() == "robohive": return TraceType.ROBOHIVE elif input_type.lower() == "roboset": @@ -49,22 +48,23 @@ def __init__(self, name): self.trace = self.root[name] self.index = 0 self.type = TraceType.ROBOHIVE + self.closed = False # False: Trace is open for edits. True: Trace can be analyzed but not edited. # Create a group in your logs def create_group(self, name): self.trace[name] = {} - # Directly add a full dataset to a given group + # Directly add a full dataset to a given group. If data appending is needed, use create_datum() instead def create_dataset(self, group_key, dataset_key, dataset_val): if group_key not in self.trace.keys(): self.create_group(name=group_key) - self.trace[group_key][dataset_key] = [dataset_val] + self.trace[group_key][dataset_key] = dataset_val # Remove dataset from an existing group(s) def remove_dataset(self, group_keys:list, dataset_key:str): - if type(group_keys)==str: + if isinstance(group_keys, str): if group_keys==":": group_keys = self.trace.keys() else: @@ -76,6 +76,13 @@ def remove_dataset(self, group_keys:list, dataset_key:str): del self.trace[group_key][dataset_key] + # Create the first datum of an existing group. 
Use append_datum() to append more elements + def create_datum(self, group_key, dataset_key, dataset_val): + if group_key not in self.trace.keys(): + self.create_group(name=group_key) + self.trace[group_key][dataset_key] = [dataset_val] + + # Append dataset datum to an existing group def append_datum(self, group_key, dataset_key, dataset_val): assert group_key in self.trace.keys(), "Group:{} does not exist".format(group_key) @@ -114,7 +121,21 @@ def set(self, group_key, dataset_key, dataset_ind=None, dataset_val=None): # verify if a data can be a part of an existing datasets def verify_type(self, dataset, data): dataset_type = type(dataset[0]) - assert type(data) == dataset_type, TypeError("Type mismatch while appending. Datum should be {}".format(dataset_type)) + assert isinstance(data, dataset_type), TypeError("Type mismatch while appending. Datum should be {}".format(dataset_type)) + + # check for array + if isinstance(data, np.ndarray): + assert data.shape == dataset[0].shape, ValueError(f"Data dimenstion({data.shape}) not compatible with dataset dimensions({dataset[0].shape})") + # check for list + if isinstance(data, list): + assert len(data) == len(dataset[0]), ValueError(f"Data dimenstion({len(data)}) not compatible with dataset dimensions({len(dataset[0])})") + # check for dictionary + if isinstance(data, dict): + flattened_data = flatten_dict(data) + flattened_dataset = flatten_dict(dataset[0]) + assert flattened_data.keys() == flattened_dataset.keys(), ValueError(f"Data keys {flattened_data.keys()} not compatible with dataset keys {flattened_dataset.keys()}") + for key in flattened_data: + assert np.array(flattened_data[key]).shape == np.array(flattened_dataset[key]).shape, ValueError(f"Data dimension for key '{key}' ({np.array(flattened_data[key]).shape}) not compatible with dataset dimensions ({np.array(flattened_dataset[key]).shape})") # Verify that all datasets in each groups are of same length. 
Helpful for time synced traces @@ -126,11 +147,14 @@ def verify_len(self): trace_len = len(self.trace[grp_k][key]) else: key_len = len(self.trace[grp_k][key]) - assert trace_len == key_len, ValueError("len({}[{}]={}, should be {}".format(grp_k, key, key_len, trace_len)) + assert trace_len == key_len, ValueError("Dataset length mismatch: len({}[{}]={}, should be {}".format(grp_k, key, key_len, trace_len)) # Very if trace is stacked and flattened. Useful for utilities like render, save etc def verify_stacked_flattened(self): + if self.closed: + True + for grp_k, grp_v in self.trace.items(): for dst_k, dst_v in grp_v.items(): # Check if stacked @@ -141,6 +165,120 @@ def verify_stacked_flattened(self): return False return True + # plot data + def plot(self, output_dir, output_format, groups:list, datasets:list, x_dataset:str='time'): + # Plot dataset traces using the groups and datasets keys list. T + # ARGUMENTS: + # output_dir: path for output + # output_format: pdf/png/None(for onscreen) + # groups: - list(Groups)_ng to plot: + # - ng = len(groups) == number of subplots + # - ":" to consider each group once + # - None entry in the list can will leave the subplot empty + # datasets: - list(list(datasets))_ng to plot([['left',], ['right', 'top']]), + # - ng = len(groups) == len(datasets) == number of subplots + # - ":" to plot each dataset once + # x_dataset: - dataset key to use as x-axis if available + # EXAMPLES + # 1. plot(..., groups=":", data=":") + # produces a plot with len(groups) subplots + # 2. plot(...,groups=['traj1', 'traj1', 'traj2'], data=[['qpos'], ['qpos','qvel'], ['qvel']]) + # produces a plot with three subplots + + if not self.closed: + prompt("Trace is still open for edits. 
Close the trace to enable plotting", type=Prompt.WARN) + return + + import matplotlib as mpl + # mpl.use('Agg') + import matplotlib.pyplot as plt + plt.rcParams.update({'font.size': 5}) + h_fig = plt.figure(self.name) + plt.clf() + + # Resolve groups + if isinstance(groups, str) and groups==":": + groups = list(self.trace.keys()) + elif isinstance(groups, str): + groups = [groups] + else: + assert isinstance(groups, list), TypeError(f"Expected a list of groups. Got {groups}") + + # number of subplots + n_subplot = len(groups) + + # Check for datasets + if isinstance(datasets, str) and datasets==":": + datasets = n_subplot*[":"] + elif isinstance(datasets, str): + datasets = [datasets] + else: + assert (isinstance(datasets, list)), TypeError(f"Dataset keys needs to be a list. Got {datasets}") + + # Check for group and datasets sizes + assert len(datasets)==n_subplot, ValueError(f"len(groups):{n_subplot} has to match len(datasets):{len(datasets)}") + # print(groups) + # print(datasets) + + # Run through all groups + for i_grp, grp_key in enumerate(groups): + + # Leave empty if requested + if grp_key is None: + continue + + # process group / subplot + assert isinstance(grp_key, str), TypeError(f"Dataset key needs to be a string. Got {grp_key}") + assert grp_key in self.trace.keys(), "Unknown group {}. Available groups {}".format(grp_key, self.trace.keys()) + grp_val = self.trace[grp_key] + + # print('selected group', grp_key) + + # Resolve datasets within existing group + if isinstance(datasets, str) and datasets==":": + i_grp_datasets = list(grp_val.keys()) + elif isinstance(datasets[i_grp], str) and datasets[i_grp]==":": + i_grp_datasets = list(grp_val.keys()) + else: + i_grp_datasets = datasets[i_grp] + + assert isinstance(i_grp_datasets, list) and isinstance(i_grp_datasets[0], str), TypeError(f"Unrecognized dataset input for group:{grp_key}. Expected ':', or a list from {grp_val.keys()}. 
Got: {i_grp_datasets}") + + # Run through all dataset requests within the group + for ds_key in i_grp_datasets: + assert ds_key in grp_val.keys(), f"Group: {grp_key} :> Unknown dataset {ds_key}. Available datasets {grp_val.keys()}" + ds_val = grp_val[ds_key] + + assert isinstance(ds_val, np.ndarray), ValueError(f"Dataset for plotting needs to be an array. Provided data:{ds_val}, type:{type(ds_val)}") + assert np.issubdtype(ds_val.dtype, np.number), ValueError(f"Dataset for plotting needs to of numerical dtype. Provided dtype: {ds_val.dtype}") + assert len(ds_val.shape)<3, ValueError(f"Plotting is only supported for 1D and 2D Dataset. Provided data dims: {ds_val.shape}") + + # print(f"g:{grp_key}/ d:{ds_key}") + h_axis = plt.subplot(n_subplot, 1, i_grp+1) + # h_axis.set_prop_cycle(None) + + if x_dataset in grp_val.keys(): + plt.plot(grp_val[x_dataset][:], ds_val, label=f"{ds_key}", marker='') + h_axis.set_xlabel(x_dataset) + elif x_dataset in self.trace.keys(): + plt.plot(self.trace[x_dataset], ds_val, label=f"{ds_key}", marker='') + h_axis.set_xlabel(x_dataset) + else: + plt.plot(ds_val, label=f"{grp_key}/{ds_key}", marker='*') + h_axis.set_title(grp_key) + h_axis.legend() + + + # show/save plot + if output_format is None: + plt.show() + return False + else: + file_name = os.path.join(output_dir, f"{self.name}_{grp_key}_{ds_key}_{output_format}".replace("/", "_")) + # plt.savefig(file_name) + print("saved ", file_name) + return h_fig + # Render frames/videos def render(self, output_dir, output_format, groups:list, datasets:list, input_fps:int=25): @@ -249,10 +387,10 @@ def items(self): return zip(self.trace.keys(), self) # return length - """ - returns the number of groups in the trace - """ def __len__(self) -> str: + """ + returns the number of groups in the trace + """ return len(self.trace.keys()) @@ -337,6 +475,8 @@ def close(self, if verify_length: self.verify_len() + self.closed = True + # Save def save(self, @@ -391,34 +531,109 @@ def load(trace_path, 
trace_type=TraceType.UNSET): trace.trace = file_data[trace.name] # load data trace.root = file_data # build root trace.trace_type=TraceType.get_type(trace_type) + trace.closed = True return trace +def test_trace_plot(): + trace = Trace("root_name") + + data1 = np.sin(np.arange(0,100)) + data2 = np.cos(np.arange(0,100)) + data3 = np.sin(np.arange(0,200))+np.cos(np.arange(0,200)) + time = 0.01*np.arange(0,200) + + trace.create_group("grp1") + trace.create_dataset(group_key="grp1", dataset_key="dst1", dataset_val=data1) + trace.create_dataset(group_key="grp1", dataset_key="dst2", dataset_val=data2) + + trace.create_group("grp2") + trace.create_dataset(group_key="grp2", dataset_key="time", dataset_val=time) + trace.create_dataset(group_key="grp2", dataset_key="dst3", dataset_val=data3) + trace.close() + + trace.plot(output_format='plot0.pdf', output_dir=".", groups=["grp1",], datasets=[["dst1",],], x_dataset="dst1") + trace.plot(output_format='plot1.pdf', output_dir=".", groups=":", datasets=":") + trace.plot(output_format='plot2.pdf', output_dir=".", groups=":", datasets=[":", ":"]) + trace.plot(output_format='plot3.pdf', output_dir=".", groups=":", datasets=[["dst2",], ":"]) + + # Catch issues plotting string array + try: + trace = Trace("string") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array(["v1", "v2", "v3"])) + trace.close() + trace.plot(output_format=None, output_dir=".", groups=["grp1",], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot array of strings: {e}", type=Prompt.WARN) + + # Catch issues plotting list(strings) + try: + trace = Trace("string") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=["v1", "v2", "v3"]) + trace.close() + trace.plot(output_format='plot4.pdf', output_dir=".", groups=["grp1",], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot list of strings: {e}", 
type=Prompt.WARN) + + + # plot complex dicts + trace = Trace("root_dict") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":1, "two":2.0, "three":"3"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":11, "two":22.0, "three":"33"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":111, "two":222.0, "three":"333"}) + trace.close() + trace.plot(output_format='plot5.pdf', output_dir=".", groups=["grp1"], datasets=[["dst_k1/one","dst_k1/two"],]) + trace.plot(output_format='plot6.pdf', output_dir=".", groups=["grp1","grp1"], datasets=[["dst_k1/one"],["dst_k1/two"]]) + trace.plot(output_format='plot7.pdf', output_dir=".", groups=["grp1","grp1"], datasets=[["dst_k1/one"],["dst_k1/one","dst_k1/two",]]) + trace.plot(output_format='plot8.pdf', output_dir=".", groups=[None, "grp1"], datasets=[None, ["dst_k1/one","dst_k1/two"]]) + # catch trying to plot strings + try: + trace.plot(output_format=None, output_dir=".", groups=["grp1"], datasets=[":"]) + except Exception as e: + prompt(f"EXPECTED: Caught exception while trying to plot dict with strings: {e}", type=Prompt.WARN) + + + # Catch trying to plot >2D array + trace = Trace("root_3darray") + trace.create_group("grp1") + trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val=np.ones([4, 2, 4])) + trace.close() + try: + trace.plot(output_format='plot9.pdf', output_dir=".", groups=["grp1"], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"EXPECTED: Caught expected exception during plotting >2D dataset: {e}", type=Prompt.WARN) + + # Test trace def test_trace(): trace = Trace("Root_name") # Create a group: append and verify trace.create_group("grp1") - trace.create_dataset(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v11") - 
trace.create_dataset(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v2") + trace.create_datum(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v2") trace.append_datum(group_key="grp1", dataset_key="dst_k2", dataset_val="dst_v22") trace.verify_len() # Add another group trace.create_group("grp2") - trace.create_dataset(group_key="grp2", dataset_key="dst_k3", dataset_val={"dst_v3":[3]}) - trace.create_dataset(group_key="grp2", dataset_key="dst_k4", dataset_val={"dst_v4":[4]}) + trace.create_datum(group_key="grp2", dataset_key="dst_k3", dataset_val={"dst_v3":[3]}) + trace.create_datum(group_key="grp2", dataset_key="dst_k4", dataset_val={"dst_v4":[4]}) print(trace) # get set methods datum = "dst_v111" trace.set('grp1','dst_k1', 0, datum) assert datum == trace.get('grp1','dst_k1', 0), "Get-Set error" - datum = {"dst_v33":[33]} + datum = {"dst_v4":[0]} trace.set('grp2','dst_k4', 0, datum) assert datum == trace.get('grp2','dst_k4', 0), "Get-Set error" + try: + datum = {"dst_diff_name":[33]} + trace.set('grp2','dst_k4', 0, datum) + except Exception as e: + prompt(f"Caught expected exception trying to insert an inconsistent datum: {e}", type=Prompt.WARN) # save-load methods trace.save(trace_name='test_trace.pickle', verify_length=True) @@ -432,10 +647,77 @@ def test_trace(): print("PKL trace") print(pkl_trace) -if __name__ == '__main__': - test_trace() +def test_trace_append(): + # Create a group: append str + trace = Trace("string") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v1") + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v11") + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val="dst_v111") + # print(trace) + trace.close() + print(trace) + # Create a group: append list(string) + trace = Trace("list(string)") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v1","dst_v2"]) + 
trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v11","dst_v22"]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=["dst_v111","dst_v222"]) + # print(trace) + trace.close() + print(trace) + # Create a group: append list(float) + trace = Trace("list(float)") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[1, 2]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[11, 22]) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=[111, 222]) + # print(trace) + trace.close(i_res=np.int16) + print(trace) + # Create a group: append dict + trace = Trace("dict") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":1, "two":2.0, "three":"3"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":11, "two":22.0, "three":"33"}) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val={"one":111, "two":222.0, "three":"333"}) + # print(trace) + trace.close() + print(trace) + # Create a group: append ndarray + trace = Trace("ndarray") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([1, 2])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([11, 22])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([111, 222])) + print(trace) + trace.close(i_res=np.int16) + print(trace) + + # Create a group: append ndarray + trace = Trace("ndarray_stack") + trace.create_group("grp1") + trace.create_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.ones([4, 2])) + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.zeros([4, 2])) + try: + trace.append_datum(group_key="grp1", dataset_key="dst_k1", dataset_val=np.array([11, 22])) + except Exception as e: + prompt(f"Caught expected exception during 
append_datum: {e}", type=Prompt.WARN) + trace.close(i_res=np.int16) + assert trace['grp1']['dst_k1'].shape==(2, 4 ,2), ValueError("Check ndarray concatenation") + try: + trace.plot(output_format=None, output_dir=".", groups=["grp1"], datasets=[["dst_k1",],]) + except Exception as e: + prompt(f"Caught expected exception during plotting >2D dataset: {e}", type=Prompt.WARN) + + + +if __name__ == '__main__': + test_trace() + test_trace_append() + test_trace_plot() diff --git a/robohive/robot/README.md b/robohive/robot/README.md index 7dfa2036..18ac300c 100644 --- a/robohive/robot/README.md +++ b/robohive/robot/README.md @@ -21,7 +21,7 @@ RoboHive provides a unified [base class](hardware_base.py) for all hardware inte RoboHive uses [Polymetis](https://facebookresearch.github.io/fairo/polymetis/) for interfacing with Franka Arms. Please follow the install instructions provided by the Polymetis authors. 2. [Dynamixel](http://www.dynamixel.com/) -RoboHive supports all dynamixel based robots such as [ROBEL](http://roboticsbenchmarks.org/). Please follow driver install instructions [here](https://github.com/vikashplus/dynamixel). +RoboHive supports all dynamixel based robots such as [ROBEL](https://sites.google.com/view/roboticsbenchmarks). Please follow driver install instructions [here](https://github.com/vikashplus/dynamixel). 3. [OptiTrack](https://optitrack.com/) RoboHive supports optitrack motion tracking system. 
Please follow install instructions [here](https://github.com/vikashplus/OptiTrack) diff --git a/robohive/simhive/franka_sim b/robohive/simhive/franka_sim index 82aaf3be..2e6542c3 160000 --- a/robohive/simhive/franka_sim +++ b/robohive/simhive/franka_sim @@ -1 +1 @@ -Subproject commit 82aaf3bebfa29e00133a6eebc7684e793c668fc1 +Subproject commit 2e6542c36114e92803adfc3016f2c69cd9a8ee8c diff --git a/robohive/tests/test_envs.py b/robohive/tests/test_envs.py index 78137191..cb01fcae 100644 --- a/robohive/tests/test_envs.py +++ b/robohive/tests/test_envs.py @@ -8,6 +8,7 @@ import unittest from robohive.utils import gym +from robohive.utils.implement_for import implement_for import numpy as np import pickle import copy @@ -66,7 +67,8 @@ def check_env(self, environment_id, input_seed): rwd_dict1 = env1.get_reward_dict(obs_dict1) assert len(rwd_dict1) > 0 # reset env - env1.reset() + reset_data = env1.reset() + self.check_reset(reset_data) # serialize / deserialize env ------------ env2w = pickle.loads(pickle.dumps(env1w)) @@ -102,6 +104,20 @@ def check_env(self, environment_id, input_seed): del(env1) del(env2) + + @implement_for("gym", None, "0.26") + def check_reset(self, reset_data): + assert isinstance(reset_data, np.ndarray), "Reset should return the observation vector" + + @implement_for("gym", "0.26", None) + def check_reset(self, reset_data): + assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2" + assert isinstance(reset_data[1], dict), "second element returned should be a dict" + @implement_for("gymnasium") + def check_reset(self, reset_data): + assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2" + assert isinstance(reset_data[1], dict), "second element returned should be a dict" + def check_old_envs(self, module_name, env_names, lite=False, seed=1234): print("\nTesting module:: ", module_name) for env_name in env_names: diff --git 
a/robohive/tests/test_logger.py b/robohive/tests/test_logger.py index d07e0010..7cbf5787 100644 --- a/robohive/tests/test_logger.py +++ b/robohive/tests/test_logger.py @@ -2,18 +2,42 @@ import click import click.testing -from robohive.logger.grouped_datasets import test_trace +from robohive.logger.grouped_datasets import test_trace, test_trace_append, test_trace_plot from robohive.logger.examine_logs import examine_logs from robohive.utils.examine_env import main as examine_env import os import re +import glob class TestTrace(unittest.TestCase): - def teast_trace(self): + def test_trace(self): # Call your function and test its output/assertions print("Testing Trace Basics") test_trace() + def test_trace_append(self): + # Call your function and test its output/assertions + print("Testing Trace complex appends") + test_trace_append() + + def test_trace_plot(self): + # Call your function and test its output/assertions + print("Testing Trace plotting") + test_trace_plot() + # Define the pattern for the files you want to delete + pattern = "./*plot*.pdf" + + # Use glob to find all files matching the pattern + files_to_delete = glob.glob(pattern) + + # Iterate over the list of files and delete each one + for file_path in files_to_delete: + try: + os.remove(file_path) + print(f"Deleted: {file_path}") + except Exception as e: + print(f"Error deleting file {file_path}: {e}") + class TestExamineTrace(unittest.TestCase): def test_logs_playback(self): diff --git a/robohive/tutorials/3_get_obs_proprio_extero.ipynb b/robohive/tutorials/3_get_obs_proprio_extero.ipynb index d39e8436..ad9614f1 100644 --- a/robohive/tutorials/3_get_obs_proprio_extero.ipynb +++ b/robohive/tutorials/3_get_obs_proprio_extero.ipynb @@ -135,6 +135,60 @@ "print(f\"visual_dict = {env.visual_dict.keys()}\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. 
convert depth into a point cloud\n", + "\n", + "If you need a point cloud observation, you can convert the depth image into the point cloud in the following way.\n", + "\n", + "First, use env.get_visuals to obtain the depth image in the obs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "obs = env.get_visuals(visual_keys=['rgb:vil_camera:224x224:2d', 'd:vil_camera:224x224:2d'])\n", + "color = obs['rgb:vil_camera:224x224:2d']\n", + "depth = obs['d:vil_camera:224x224:2d'][0]\n", + "plt.imshow(color)\n", + "plt.show()\n", + "plt.imshow(depth)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we convert the depth image into a point cloud based on camera intrinsics and extrinsics." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from robohive.utils.pointcloud_utils import get_point_cloud, visualize_point_cloud_from_nparray\n", + "pcd = get_point_cloud(env, obs['d:vil_camera:224x224:2d'][0], 'vil_camera')\n", + "\n", + "# # Interactive point cloud visualization with open3d (requires open3d installation)\n", + "# visualize_point_cloud_from_nparray(pcd, obs['rgb:vil_camera:224x224:2d'], vis_coordinate=True)\n", + "\n", + "# Visualize the point cloud with matplotlib (if you are too lazy to install open3d)\n", + "import matplotlib.pyplot as plt\n", + "fig = plt.figure()\n", + "ax = fig.add_subplot(projection='3d')\n", + "ax.scatter(pcd[:,0], pcd[:,1], pcd[:,2], c=obs['rgb:vil_camera:224x224:2d'].reshape(-1, 3)/256)\n", + "plt.show()\n" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -202,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.10.13" }, "orig_nbformat": 4 }, diff --git a/robohive/tutorials/render_cams.py b/robohive/tutorials/render_cams.py index 
c5ef5df9..551978d9 100644 --- a/robohive/tutorials/render_cams.py +++ b/robohive/tutorials/render_cams.py @@ -9,29 +9,18 @@ DESC = ''' Helper script to render images offscreen and save using a mujoco model.\n USAGE:\n - $ python render_cams.py --model_path franka_sim.xml --cam_names top_cam --cam_names left_cam --cam_names right_cam \n + $ python render_cams.py --model_path <../model.xml> --cam_names <cam1> --cam_names <cam2> \n EXAMPLE:\n - $ python utils/render_cams.py -m envs/fm/assets/franka_microwave.xml -c top_cam -c left_cam -c right_cam + $ python robohive/tutorials/render_cams.py -m robohive/envs/arms/franka/assets/franka_reach_v0.xml -c left_cam -c top_cam + $ python robohive/tutorials/render_cams.py -m robohive/simhive/robel_sim/dkitty/kitty-v2.1.xml -c "A:trackingY" -c "A:trackingZ" + ''' -from mujoco_py import load_model_from_path, MjSim +import mujoco + from PIL import Image -import numpy as np import click - -def render_camera_offscreen(cameras:list, width:int=640, height:int=480, device_id:int=0, sim=None): - """ - Render images(widthxheight) from a list_of_cameras on the specified device_id. 
- """ - imgs = np.zeros((len(cameras), height, width, 3), dtype=np.uint8) - for ind, cam in enumerate(cameras) : - img = sim.render(width=width, height=height, mode='offscreen', camera_name=cam, device_id=device_id) - img = img[::-1, :, : ] # Image has to be flipped - imgs[ind, :, :, :] = img - return imgs - - @click.command(help=DESC) @click.option('-m', '--model_path', required=True, type=str, help='model file') @click.option('-c', '--cam_names', required=True, multiple=True, help=('Camera names for rendering')) @@ -40,14 +29,23 @@ def render_camera_offscreen(cameras:list, width:int=640, height:int=480, device_ @click.option('-d', '--device_id', type=int, default=0, help='device id for rendering') def main(model_path, cam_names, width, height, device_id): - # render images - model = load_model_from_path(model_path) - sim = MjSim(model) - imgs = render_camera_offscreen(cameras=cam_names, width=width, height=height, device_id=device_id, sim=sim) + + # prepare model, data, scene + mj_model = mujoco.MjModel.from_xml_path(model_path) + mj_data = mujoco.MjData(mj_model) + mujoco.mj_forward(mj_model, mj_data) + + # prepare the renderer + renderer = mujoco.Renderer(mj_model, height=height, width=width) # save images for i, cam in enumerate(cam_names): - image = Image.fromarray(imgs[i]) + # update the scene + renderer.update_scene(mj_data, camera=cam) + # render the rgb_array + rgb_arr = renderer.render() + # save the image + image = Image.fromarray(rgb_arr) image.save(cam+".jpeg") print("saved "+cam+".jpeg") diff --git a/robohive/utils/pointcloud_utils.py b/robohive/utils/pointcloud_utils.py new file mode 100644 index 00000000..7ada4c33 --- /dev/null +++ b/robohive/utils/pointcloud_utils.py @@ -0,0 +1,108 @@ +import numpy as np +from robohive.utils.quat_math import mulQuat, euler2quat, quat2mat + + +# ------ MuJoCo specific functions ------ + +def get_point_cloud(env, depth, camera_name): + # Make sure the depth values are in meters. 
If the depth comes + # from robohive, it is already in meters. If it directly comes + # from mujoco, you need to use the convert_depth function below. + # Output is flattened. Each row is a point in 3D space. + fovy = env.sim.model.cam_fovy[env.sim.model.camera_name2id(camera_name)] + K = get_intrinsics(fovy, depth.shape[0], depth.shape[1]) + pc = depth2xyz(depth, K) + pc = pc.reshape(-1, 3) + + transform = get_extrinsics(env, camera_name=camera_name) + new_pc = np.ones((pc.shape[0], 4)) + new_pc[:, :3] = pc + new_pc = (transform @ new_pc.transpose()).transpose() + return new_pc[:, :-1] + + +def convert_depth(env, depth): + # Convert raw depth values into meters + # Check this as well: https://github.com/deepmind/dm_control/blob/master/dm_control/mujoco/engine.py#L734 + extent = env.sim.model.stat.extent + near = env.sim.model.vis.map.znear * extent + far = env.sim.model.vis.map.zfar * extent + depth_m = depth * 2 - 1 + depth_m = (2 * near * far) / (far + near - depth_m * (far - near)) + return depth_m + + +def get_extrinsics(env, camera_name): + # Transformation from camera frame to world frame + cam_id = env.sim.model.camera_name2id(camera_name) + cam_pos = env.sim.model.cam_pos[cam_id] + cam_quat = env.sim.model.cam_quat[cam_id] + cam_quat = mulQuat(cam_quat, euler2quat([np.pi, 0, 0])) + return get_transformation_matrix(cam_pos, cam_quat) + + +def get_transformation_matrix(pos, quat): + # Convert the pose from MuJoCo format to a 4x4 transformation matrix + arr = np.identity(4) + arr[:3, :3] = quat2mat(quat) + arr[:3, 3] = pos + return arr + + +# ------ General functions ------ + +def get_intrinsics(fovy, img_width, img_height): + # Get the camera intrinsics matrix + aspect = float(img_width) / img_height + fovx = 2 * np.arctan(np.tan(np.deg2rad(fovy) * 0.5) * aspect) + fovx = np.rad2deg(fovx) + cx = img_width / 2. + cy = img_height / 2. 
+ fx = cx / np.tan(np.deg2rad(fovx / 2.)) + fy = cy / np.tan(np.deg2rad(fovy / 2.)) + K = np.zeros((3,3), dtype=np.float64) + K[2][2] = 1 + K[0][0] = fx + K[1][1] = fy + K[0][2] = cx + K[1][2] = cy + return K + + +def depth2xyz(depth, cam_K): + # Convert depth image to point cloud + h, w = depth.shape + ymap, xmap = np.meshgrid(np.arange(w), np.arange(h)) + + x = ymap + y = xmap + z = depth + + x = (x - cam_K[0,2]) * z / cam_K[0,0] + y = (y - cam_K[1,2]) * z / cam_K[1,1] + + xyz = np.stack([x, y, z], axis=2) + return xyz + + +def visualize_point_cloud_from_nparray(d, c=None, vis_coordinate=False): + # Visualize a point cloud using open3d + if c is not None: + if len(c.shape) == 3: + c = c.reshape(-1, 3) + if c.max() > 1: + c = c.astype(np.float64)/256 + + import open3d as o3d + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(d) + if c is not None: + pcd.colors = o3d.utility.Vector3dVector(c) + + if vis_coordinate: + # Visualize coordinate frame + mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5) + o3d.visualization.draw_geometries([mesh, pcd]) + else: + o3d.visualization.draw_geometries([pcd]) + diff --git a/robohive/utils/prompt_utils.py b/robohive/utils/prompt_utils.py index 9d59d632..5d1752f7 100644 --- a/robohive/utils/prompt_utils.py +++ b/robohive/utils/prompt_utils.py @@ -7,7 +7,7 @@ """ Utility script to help with information verbosity produced by RoboHive -To control verbosity set env variable ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS +To control verbosity set env variable ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT """ from termcolor import cprint @@ -19,9 +19,9 @@ class Prompt(enum.IntEnum): """Prompt verbosity types""" ALL = 0 # print everything (lowest priority) - INFO = 1 - WARN = 2 - ERROR = 3 + INFO = 1 # useful info + WARN = 2 # warnings (default) + ERROR = 3 # errors ONCE = 4 # print: once and higher ALWAYS = 5 # print: only always (highest priority) SILENT = 6 # Supress all 
prints @@ -33,7 +33,7 @@ class Prompt(enum.IntEnum): # Infer verbose mode to be used VERBOSE_MODE = os.getenv('ROBOHIVE_VERBOSITY') -if VERBOSE_MODE==None: +if VERBOSE_MODE is None: VERBOSE_MODE = Prompt.WARN else: VERBOSE_MODE = VERBOSE_MODE.upper() diff --git a/robohive/utils/tensor_utils.py b/robohive/utils/tensor_utils.py index 5f54d284..6de21549 100644 --- a/robohive/utils/tensor_utils.py +++ b/robohive/utils/tensor_utils.py @@ -1,6 +1,4 @@ -# Source: https://github.dev/aravindr93/mjrl/tree/master/mjrl -import operator - +# Adapted from Source: https://github.com/aravindr93/mjrl/tree/master/mjrl import numpy as np diff --git a/setup.py b/setup.py index b19980c0..7af1f1d7 100644 --- a/setup.py +++ b/setup.py @@ -19,74 +19,84 @@ (2) OSX: brew install ffmpeg""" raise ModuleNotFoundError(help) -if sys.version_info.major < 3 or (sys.version_info.major == 3 and sys.version_info.minor < 8): - print("This library requires Python 3.8 or higher, but you are running " - "Python {}.{}. The installation will likely fail.".format(sys.version_info.major, sys.version_info.minor)) +if sys.version_info.major < 3 or ( + sys.version_info.major == 3 and sys.version_info.minor < 8 +): + print( + "This library requires Python 3.8 or higher, but you are running " + "Python {}.{}. 
The installation will likely fail.".format( + sys.version_info.major, sys.version_info.minor + ) + ) + def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() + def package_files(directory): paths = [] - for (path, directories, filenames) in os.walk(directory): + for path, directories, filenames in os.walk(directory): for filename in filenames: - paths.append(os.path.join('..', path, filename)) + paths.append(os.path.join("..", path, filename)) return paths -extra_files = package_files('robohive') + +extra_files = package_files("robohive") setup( - name='robohive', - version='0.7.0', - license='Apache 2.0', + name="robohive", + version="0.7.0", + license="Apache 2.0", packages=find_packages(), - package_data={"": extra_files+['../robohive_init.py']}, + package_data={"": extra_files + ["../robohive_init.py"]}, include_package_data=True, - description='A Unified Framework for Robot Learning', - long_description=read('README.md'), + description="A Unified Framework for Robot Learning", + long_description=read("README.md"), long_description_content_type="text/markdown", - url='https://github.com/vikashplus/robohive.git', - author='Vikash Kumar', + url="https://github.com/vikashplus/robohive.git", + author="Vikash Kumar", author_email="vikahsplus@gmail.com", install_requires=[ - 'click', + "click", # 'gym==0.13', # default to this stable point if caught in gym issues. 
- 'gymnasium==0.29.1', - 'mujoco==3.1.3', - 'dm-control==1.0.16', - 'termcolor', - 'sk-video', - 'flatten_dict', - 'matplotlib', - 'ffmpeg', - 'absl-py', - 'torch', - 'h5py==3.7.0', - 'pink-noise-rl', - 'gitpython' + "gymnasium==0.29.1", + "mujoco==3.1.3", + "numpy>=2", + "dm-control==1.0.16", + "termcolor", + "sk-video", + "flatten_dict", + "matplotlib", + "ffmpeg", + "absl-py", + "torch", + "h5py>=3.11.0", + "pink-noise-rl", + "gitpython", ], extras_require={ - # To use mujoco bindings, run (pip install -e ".[mujoco]") and set sim_backend=MUJOCO - 'mujoco_py':[ - 'free-mujoco-py', + # To use mujoco bindings, run (pip install -e ".[mujoco]") and set sim_backend=MUJOCO + "mujoco_py": [ + "free-mujoco-py", + ], + # To use hardware dependencies, run (pip install -e ".[a0]") and follow install instructions inside robot + "a0": [ + "pycapnp>=1.1.1", + "alephzero", # real_sense subscribers dependency ], - # To use hardware dependencies, run (pip install -e ".[a0]") and follow install instructions inside robot - 'a0': [ - 'pycapnp>=1.1.1', - 'alephzero', # real_sense subscribers dependency + "encoder": [ + "torchvision", + # Unlike pypi, Git dependencies can be directly installed in editable mode. + # To use r3m/vc encoders, uncomment below and run (pip install -e ".[encoder]") + # 'r3m @ git+https://github.com/facebookresearch/r3m.git', + # 'vc_models @ git+https://github.com/facebookresearch/eai-vc.git@9958b278666bcbde193d665cc0df9ccddcdb8a5a#egg=vc_models&subdirectory=vc_models', ], - 'encoder':[ - 'torchvision', - # Unlike pypi, Git dependencies can be directly installed in editable mode. 
- # To use r3m/vc encoders, uncomment below and run (pip install -e ".[encoder]") - # 'r3m @ git+https://github.com/facebookresearch/r3m.git', - # 'vc_models @ git+https://github.com/facebookresearch/eai-vc.git@9958b278666bcbde193d665cc0df9ccddcdb8a5a#egg=vc_models&subdirectory=vc_models', - ] }, entry_points={ - 'console_scripts': [ - 'robohive_init = robohive_init:fetch_simhive', - 'robohive_clean = robohive_init:clean_simhive', + "console_scripts": [ + "robohive_init = robohive_init:fetch_simhive", + "robohive_clean = robohive_init:clean_simhive", ], }, )