Skip to content

Commit

Permalink
Force state rand vec updates to update the sim (Farama-Foundation#328)
Browse files Browse the repository at this point in the history
* Force state rand vec updates to update the sim

In some environments, the locations of objects are updated
via changes to the underlying simulator.

Those changes don't reflect in the simulator until the simulator
itself has been stepped (a call to sim.step()).

This commit forces sim.step when state rand vecs are changed by calling
reset inside set_task.

closes Farama-Foundation#324
  • Loading branch information
avnishn committed Feb 26, 2021
1 parent b7bd661 commit d9a75c4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
1 change: 1 addition & 0 deletions metaworld/envs/mujoco/sawyer_xyz/sawyer_xyz_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def set_task(self, task):
self._partially_observable = data['partially_observable']
del data['partially_observable']
self._set_task_inner(**data)
self.reset()

def set_xyz_action(self, action):
action = np.clip(action, -1, 1)
Expand Down
41 changes: 41 additions & 0 deletions tests/metaworld/envs/mujoco/sawyer_xyz/test_sawyer_xyz_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np

import random

import metaworld

def test_reset_returns_same_obj_and_goal():
benchmark = metaworld.MT50()
env_dict = benchmark.train_classes
tasks = benchmark.train_tasks
initial_obj_poses = {name: [] for name in env_dict.keys()}
goal_poses = {name: [] for name in env_dict.keys()}

# Execute rollout for each environment in benchmark.
for env_name, env_cls in env_dict.items():

# Create environment and set task.
env = env_cls()
env_tasks = [t for t in tasks if t.env_name == env_name]
env.set_task(random.choice(env_tasks))

# Step through environment for a fixed number of episodes.
for _ in range(2):
# Reset environment and extract initial object position.
obs = env.reset()
goal = obs[-3:]
goal_poses[env_name].append(goal)
initial_obj_pos = obs[3:9]
initial_obj_poses[env_name].append(initial_obj_pos)

# Display initial object positions and find environments with non-unique positions.
violating_envs_obs = []
for env_name, task_initial_pos in initial_obj_poses.items():
if len(np.unique(np.array(task_initial_pos), axis=0)) > 1:
violating_envs_obs.append(env_name)
violating_envs_goals = []
for env_name, target_pos in goal_poses.items():
if len(np.unique(np.array(target_pos), axis=0)) > 1:
violating_envs_goals.append(env_name)
assert not violating_envs_obs
assert not violating_envs_goals

0 comments on commit d9a75c4

Please sign in to comment.