Skip to content

Commit

Permalink
fix auto pin_memory : idist.device().type should be used (pytorch#1131)
Browse files Browse the repository at this point in the history
* fix auto pin_memory : idist.device().type should be used

* fix cuda in device

* fix test

* use idist.device().type to test

* add missing ()

Co-authored-by: Desroziers <sylvain.desroziers@ifpen.fr>
  • Loading branch information
sdesrozis and Desroziers committed Jun 15, 2020
1 parent 3b1dc3a commit db63e94
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion ignite/distributed/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def auto_dataloader(dataset, **kwargs):
)
kwargs["pin_memory"] = False
else:
kwargs["pin_memory"] = kwargs.get("pin_memory", idist.device() == "cuda")
kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type)

logger.info("Use data loader kwargs for dataset '{}': \n\t{}".format(repr(dataset)[:20].strip(), kwargs))
dataloader = DataLoader(dataset, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/distributed/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _test_auto_dataloader(ws, nproc, sampler_name=None, dl_type=DataLoader):
sampler_type = DistributedSampler if sampler is None else DistributedProxySampler
assert isinstance(dataloader.sampler, sampler_type)
if isinstance(dataloader, DataLoader):
assert dataloader.pin_memory == (data.device == "cuda")
assert dataloader.pin_memory == ("cuda" in idist.device().type)


def _test_auto_model_optimizer(ws, device):
Expand Down

0 comments on commit db63e94

Please sign in to comment.