diff --git a/examples/contrib/cifar10/main.py b/examples/contrib/cifar10/main.py index 32a5d1c45bb..e62d1b41d01 100644 --- a/examples/contrib/cifar10/main.py +++ b/examples/contrib/cifar10/main.py @@ -1,5 +1,6 @@ from datetime import datetime from pathlib import Path +from typing import Any, Optional import fire import torch @@ -136,27 +137,27 @@ def _(): def run( - seed=543, - data_path="/tmp/cifar10", - output_path="/tmp/output-cifar10/", - model="resnet18", - batch_size=512, - momentum=0.9, - weight_decay=1e-4, - num_workers=12, - num_epochs=24, - learning_rate=0.4, - num_warmup_epochs=4, - validate_every=3, - checkpoint_every=1000, - backend=None, - resume_from=None, - log_every_iters=15, - nproc_per_node=None, - stop_iteration=None, - with_clearml=False, - with_amp=False, - **spawn_kwargs, + seed: int = 543, + data_path: str = "/tmp/cifar10", + output_path: str = "/tmp/output-cifar10/", + model: str = "resnet18", + batch_size: int = 512, + momentum: float = 0.9, + weight_decay: float = 1e-4, + num_workers: int = 12, + num_epochs: int = 24, + learning_rate: float = 0.4, + num_warmup_epochs: int = 4, + validate_every: int = 3, + checkpoint_every: int = 1000, + backend: Optional[str] = None, + resume_from: Optional[str] = None, + log_every_iters: int = 15, + nproc_per_node: Optional[int] = None, + stop_iteration: Optional[int] = None, + with_clearml: bool = False, + with_amp: bool = False, + **spawn_kwargs: Any, ): """Main entry to train an model on CIFAR10 dataset.