Skip to content

Commit

Permalink
Unroll infos for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
jsuarez5341 committed Jun 22, 2023
1 parent 29d9389 commit 12c1e4a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 9 additions & 1 deletion clean_pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@
import pufferlib.vectorization.serial
import wandb

def unroll_nested_dict(d):
for k, v in d.items():
if isinstance(v, dict):
for k2, v2 in unroll_nested_dict(v):
yield f'{k}/{k2}', v2
else:
yield k, v

@dataclass
class CleanPuffeRL:
binding: pufferlib.emulation.Binding
Expand Down Expand Up @@ -245,7 +253,7 @@ def evaluate(self, agent, data):

for item in i:
for agent_info in item.values():
for name, stat in agent_info.items():
for name, stat in unroll_nested_dict(agent_info):
try:
stat = float(stat)
stats[name].append(stat)
Expand Down
2 changes: 1 addition & 1 deletion pufferlib/emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def _poststep(self, obs, rewards, dones, infos):
if team not in infos:
infos[team] = {}

infos[team ]= self._handle_infos(rewards[team], dones[team], infos[team], team)
infos[team]= self._handle_infos(rewards[team], dones[team], infos[team], team)

# Observation shape test
if __debug__:
Expand Down

0 comments on commit 12c1e4a

Please sign in to comment.