diff --git a/mujoco_py/mjsim.pyx b/mujoco_py/mjsim.pyx index cfe91349..81b9dcfd 100644 --- a/mujoco_py/mjsim.pyx +++ b/mujoco_py/mjsim.pyx @@ -56,8 +56,10 @@ cdef class MjSim(object): cdef public int nsubsteps # User defined state. cdef readonly dict udd_state + # Schema for udd_state. This schema shouldn't change between steps. + cdef readonly dict _schema_example # User defined dynamics callback - cdef readonly object _udd_callback + cdef public object udd_callback # Allows to store extra information in MjSim. cdef readonly dict extras # Function pointer for substep callback, stored as uintptr @@ -82,7 +84,8 @@ cdef class MjSim(object): self.render_contexts = [] self._render_context_offscreen = None self._render_context_window = None - self.udd_state = None + self.udd_state = {} + self._schema_example = None self.udd_callback = udd_callback self.render_callback = render_callback self.extras = {} @@ -95,9 +98,9 @@ cdef class MjSim(object): with wrap_mujoco_warning(): mj_resetData(self.model.ptr, self.data.ptr) - self.udd_state = None + self.udd_state = {} + self._schema_example = None self.step_udd() - def forward(self): """ Computes the forward kinematics. Calls ``mj_forward`` internally. @@ -176,16 +179,6 @@ cdef class MjSim(object): elif not render_context.offscreen and self._render_context_window is None: self._render_context_window = render_context - @property - def udd_callback(self): - return self._udd_callback - - @udd_callback.setter - def udd_callback(self, value): - self._udd_callback = value - self.udd_state = None - self.step_udd() - cpdef substep_callback(self): if self.substep_callback_ptr: (self.substep_callback_ptr)(self.model.ptr, self.data.ptr) @@ -217,25 +210,26 @@ cdef class MjSim(object): raise TypeError('invalid: {}'.format(type(substep_callback))) def step_udd(self): - if self._udd_callback is None: - self.udd_state = {} - else: - schema_example = self.udd_state - self.udd_state = self._udd_callback(self) - # Check to make sure the udd_state has consistent keys and dimension across steps - if schema_example is not None: - keys = set(schema_example.keys()) | set(self.udd_state.keys()) - for key in keys: - assert key in schema_example, "Keys cannot be added to udd_state between steps." - assert key in self.udd_state, "Keys cannot be dropped from udd_state between steps." - if isinstance(schema_example[key], Number): - assert isinstance(self.udd_state[key], Number), \ - "Every value in udd_state must be either a number or a numpy array" - else: - assert isinstance(self.udd_state[key], np.ndarray), \ - "Every value in udd_state must be either a number or a numpy array" - assert self.udd_state[key].shape == schema_example[key].shape, \ - "Numpy array values in udd_state must keep the same dimension across steps." + if self.udd_callback is None: + return + if len(self.udd_state) > 0 and self._schema_example is None: + self._schema_example = self.udd_state + self.udd_state = self.udd_callback(self) + + # Check to make sure the udd_state has consistent keys and dimension across steps + if self._schema_example is not None: + keys = set(self._schema_example.keys()) | set(self.udd_state.keys()) + for key in keys: + assert key in self._schema_example, "Keys cannot be added to udd_state between steps." + assert key in self.udd_state, "Keys cannot be dropped from udd_state between steps." + if isinstance(self._schema_example[key], Number): + assert isinstance(self.udd_state[key], Number), \ + "Every value in udd_state must be either a number or a numpy array" + else: + assert isinstance(self.udd_state[key], np.ndarray), \ + "Every value in udd_state must be either a number or a numpy array" + assert self.udd_state[key].shape == self._schema_example[key].shape, \ + "Numpy array values in udd_state must keep the same dimension across steps." def get_state(self): """ Returns a copy of the simulator state. """ diff --git a/mujoco_py/tests/test_cymj.py b/mujoco_py/tests/test_cymj.py index 8f11df5b..7b396d18 100644 --- a/mujoco_py/tests/test_cymj.py +++ b/mujoco_py/tests/test_cymj.py @@ -114,7 +114,7 @@ def udd_callback(sim): return d sim = MjSim(model, nsubsteps=2, udd_callback=udd_callback) - + sim.step_udd() assert(sim.udd_state is not None) assert(sim.udd_state["foo"] == foo) assert(sim.udd_state["foo_2"].shape[0] == 2) @@ -158,6 +158,8 @@ def udd_callback(sim): return {"foo": foo} sims = [MjSim(model, udd_callback=udd_callback) for _ in range(2)] + for sim in sims: + sim.step_udd() sim_pool = MjSimPool(sims, nsubsteps=2) for i in range(len(sim_pool.sims)): @@ -260,6 +262,7 @@ def udd_callback(sim): return d sim = MjSim(model, nsubsteps=2, udd_callback=udd_callback) + sim.step_udd() state = sim.get_state() assert np.array_equal(state.time, sim.data.time) diff --git a/mujoco_py/version.py b/mujoco_py/version.py index b7203e1a..171a9cde 100644 --- a/mujoco_py/version.py +++ b/mujoco_py/version.py @@ -1,6 +1,6 @@ __all__ = ['__version__', 'get_version'] -version_info = (1, 50, 1, 26) +version_info = (1, 50, 1, 27) # format: # ('mujoco_major', 'mujoco_minor', 'mujoco_py_major', 'mujoco_py_minor')