-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
78 lines (63 loc) · 2.44 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from argparse import ArgumentParser
import os
import logging
import yaml
from time import time, sleep
import subprocess
from omegaconf import OmegaConf
from spkanon_eval.main import main
from spkanon_eval.utils import seed_everything
def _create_exp_folder(log_dir):
    """
    Create and return a fresh ``<log_dir>/<unix-timestamp>`` experiment folder.

    If a folder for the current second already exists (e.g. another run was
    started at the same moment), wait one second and retry with the new
    timestamp. ``os.makedirs`` itself is used as the existence check, so two
    concurrent runs cannot both claim the same folder; checking
    ``os.path.exists`` first would leave a window between check and creation.
    Missing parent directories (including ``log_dir``) are created.
    """
    while True:
        exp_folder = os.path.join(log_dir, str(int(time())))
        try:
            os.makedirs(exp_folder)
        except FileExistsError:
            sleep(1)
        else:
            return exp_folder


def _setup_progress_logger(exp_folder):
    """
    Configure the ``progress`` logger at INFO level.

    Records go both to ``<exp_folder>/progress.log`` (with timestamp, logger
    name and level) and to the console via a default ``StreamHandler``, which
    writes to stderr.
    """
    logger_name = "progress"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(os.path.join(exp_folder, f"{logger_name}.log"))
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
    logger.addHandler(logging.StreamHandler())


def setup(args):
    """
    Prepare an experiment run from the parsed command-line arguments.

    Steps:
    - Load the YAML config at ``args.config``, inlining any ``*_cfg``
      sub-configs via ``load_subconfigs``.
    - Override ``device`` and ``data.config.num_workers`` with the CLI values.
    - Record the short git commit hash.
    - Create a new timestamped folder under ``config.log_dir``.
    - Seed everything if ``config.seed`` is set.
    - Save the resolved config there as ``exp_config.yaml``.
    - Attach file and console handlers to the ``progress`` logger.

    Args:
        args: parsed CLI namespace with ``config`` (path to the YAML file),
            ``device`` and ``num_workers`` attributes.

    Returns:
        ``(config, exp_folder)``: the OmegaConf config and the path of the
        experiment folder that was created.

    Raises:
        subprocess.CalledProcessError: if ``git rev-parse`` fails (e.g. the
            working directory is not inside a git checkout).
    """
    # `with` so the config file handle is closed once parsed
    with open(args.config, encoding="utf-8") as f:
        raw_config = yaml.full_load(f)
    config = OmegaConf.create(load_subconfigs(raw_config))
    config.device = args.device
    config.data.config.num_workers = args.num_workers
    config.commit_hash = (
        subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
        .decode("ascii")
        .strip()
    )

    exp_folder = _create_exp_folder(config.log_dir)
    config.exp_folder = exp_folder

    # `.get` so a config without a `seed` key behaves like `seed: null`
    # instead of raising on attribute access (newer OmegaConf versions raise).
    if config.get("seed") is not None:
        seed_everything(config.seed)

    # dump the fully-resolved config next to the logs for reproducibility
    OmegaConf.save(config, os.path.join(exp_folder, "exp_config.yaml"))
    _setup_progress_logger(exp_folder)
    return config, exp_folder
def load_subconfigs(config: dict) -> dict:
    """
    Return a new dict in which every sub-config reference in ``config`` is inlined.

    A key ending in ``_cfg`` (whose value is not a dict) is treated as the path
    to a YAML file. That file is loaded, processed recursively, and its
    top-level keys are merged into the current level in place of the ``_cfg``
    key. Dict values are processed recursively, even when their key ends in
    ``_cfg``. All other values are copied unchanged.

    Keys are processed in order, so a later key overrides an earlier one with
    the same name. The input dict is not modified.

    Args:
        config: mapping loaded from YAML.

    Returns:
        A new dict with all ``_cfg`` references resolved.
    """
    full_config = {}
    for key, value in config.items():
        if isinstance(value, dict):
            full_config[key] = load_subconfigs(value)
        # YAML allows non-string keys (e.g. ints); only strings can be _cfg refs
        elif isinstance(key, str) and key.endswith("_cfg"):
            # path is resolved relative to the current working directory
            with open(value, encoding="utf-8") as f:
                sub_config = yaml.full_load(f)
            # an empty YAML file loads as None: treat it as an empty config
            if sub_config is not None:
                full_config.update(load_subconfigs(sub_config))
        else:
            full_config[key] = value
    return full_config
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--config")
parser.add_argument("--device", default="cuda")
parser.add_argument("--num_workers", default=10, type=int)
main(*setup(parser.parse_args()))