Skip to content

Commit

Permalink
Fixes #991 (#1047)
Browse files Browse the repository at this point in the history
- average output in RunningAverage
  • Loading branch information
vfdev-5 committed May 16, 2020
1 parent 25d38d1 commit 8886948
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

Expand Down Expand Up @@ -106,7 +107,9 @@ def _get_metric_value(self) -> Union[torch.Tensor, float]:

@sync_all_reduce("src")
def _get_output_value(self) -> Union[torch.Tensor, float]:
return self.src
# we need to compute average instead of sum produced by @sync_all_reduce("src")
output = self.src / idist.get_world_size()
return output

def _metric_iteration_completed(self, engine: Engine) -> None:
self.src.started(engine)
Expand Down
1 change: 1 addition & 0 deletions tests/ignite/metrics/test_running_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def running_avg_output_init(engine):
def running_avg_output_update(engine):
i = engine.state.iteration - 1
o = sum([all_loss_values[i + j * k] for j in range(idist.get_world_size())]).item()
o /= idist.get_world_size()
if engine.state.running_avg_output is None:
engine.state.running_avg_output = o
else:
Expand Down

0 comments on commit 8886948

Please sign in to comment.