Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions robohive/utils/examine_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@ def __init__(self, env, seed):

def get_action(self, obs):
# return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low)
return self.env.action_space.sample(), {'mode': 'random samples'}
return self.env.action_space.sample(), {'mode': 'random samples', 'evaluation':self.env.action_space.sample()}

def load_class_from_str(module_name, class_name):
try:
m = __import__(module_name, globals(), locals(), class_name)
return getattr(m, class_name)
except (ImportError, AttributeError):
return None

# MAIN =========================================================
@click.command(help=DESC)
Expand All @@ -57,13 +64,19 @@ def main(env_name, policy_path, mode, seed, num_episodes, render, camera_name, o

# resolve policy and outputs
if policy_path is not None:
pi = pickle.load(open(policy_path, 'rb'))
if output_dir == './': # overide the default
output_dir, pol_name = os.path.split(policy_path)
output_name = os.path.splitext(pol_name)[0]
if output_name is None:
pol_name = os.path.split(policy_path)[1]
output_name = os.path.splitext(pol_name)[0]
policy_tokens = policy_path.split('.')
pi = load_class_from_str('.'.join(policy_tokens[:-1]), policy_tokens[-1])

if pi is not None:
pi = pi(env, seed)
else:
pi = pickle.load(open(policy_path, 'rb'))
if output_dir == './': # overide the default
output_dir, pol_name = os.path.split(policy_path)
output_name = os.path.splitext(pol_name)[0]
if output_name is None:
pol_name = os.path.split(policy_path)[1]
output_name = os.path.splitext(pol_name)[0]
else:
pi = rand_policy(env, seed)
mode = 'exploration'
Expand Down