-
Notifications
You must be signed in to change notification settings - Fork 8
/
register_azureml.py
142 lines (112 loc) · 4.77 KB
/
register_azureml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import argparse
import json
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Mapping, Union

from azure.ai.ml import MLClient, load_component, load_data, load_environment
from azure.ai.ml.entities import Data, Environment
from azure.identity import DefaultAzureCredential
# Configure root logging at INFO level when this module is imported, then
# create the logger used throughout the script (named after this file).
logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__file__)

# Per-directory JSON file describing what should be registered from it.
REG_CONFIG_FILENAME = "registration_config.json"

# Top-level keys recognised inside a registration config file.
ENV_KEY = "environments"  # environment YAML files
COMP_KEY = "components"  # component YAML files
DATA_KEY = "data"  # data entries: a generator script plus data YAMLs
SUBDIR_KEY = "nested_directories"  # child directories to process recursively
def parse_args(argv=None):
    """Parse the command-line arguments of the registration script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        An ``argparse.Namespace`` with ``workspace_config``,
        ``component_config`` and ``base_directory`` string attributes.

    Raises:
        SystemExit: raised by argparse when a required flag is missing or the
            arguments are malformed (after printing a usage message).
    """
    parser = argparse.ArgumentParser()
    # All three paths are consumed unconditionally by main(); making them
    # required turns a late, opaque TypeError on ``None`` into a clear
    # usage error at startup.
    parser.add_argument(
        "--workspace_config",
        type=str,
        required=True,
        help="Path to workspace config.json",
    )
    parser.add_argument(
        "--component_config",
        type=str,
        required=True,
        help="Path to component_config.json",
    )
    parser.add_argument(
        "--base_directory",
        type=str,
        required=True,
        help="Path to base directory",
    )
    return parser.parse_args(argv)
def read_json_path(path: Union[str, os.PathLike]) -> Any:
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Filesystem path to the JSON file, as a ``str`` or a path-like
            object (``process_directory`` passes a ``pathlib.Path``).

    Returns:
        The decoded JSON value.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    _logger.info("Reading JSON file %s", path)
    # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default
    # encoding, which differs across platforms (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def process_file(
    input_file: str, output_file: str, replacements: Mapping[str, str]
) -> None:
    """Copy ``input_file`` to ``output_file``, substituting placeholders.

    Substitution is done line by line, so a placeholder must sit entirely on
    one line to be replaced.

    Args:
        input_file: Path of the template file to read.
        output_file: Path of the file to write; overwritten if it exists.
        replacements: Mapping of placeholder text to replacement text. Pairs
            are applied in the mapping's iteration order, each one to the
            output of the previous replacement.
    """
    # Hoisted: the pairs are the same for every line.
    pairs = list(replacements.items())
    # Explicit UTF-8 on both ends so non-ASCII content in the templates
    # round-trips regardless of the platform's locale encoding.
    with open(input_file, "r", encoding="utf-8") as infile, open(
        output_file, "w", encoding="utf-8"
    ) as outfile:
        for line in infile:
            for placeholder, value in pairs:
                line = line.replace(placeholder, value)
            outfile.write(line)
def _register_environments(
    ml_client: MLClient, env_files: list, replacements: Mapping[str, str]
) -> None:
    """Template each environment YAML into ``<file>.processed`` and register it."""
    for env_file in env_files:
        _logger.info("Registering environment: %s", env_file)
        processed_file = env_file + ".processed"
        process_file(env_file, processed_file, replacements)
        curr_env: Environment = load_environment(processed_file)
        ml_client.environments.create_or_update(curr_env)
        _logger.info("Registered %s", curr_env.name)


def _register_components(
    ml_client: MLClient, component_files: list, replacements: Mapping[str, str]
) -> None:
    """Template each component YAML into ``<file>.processed`` and register it."""
    for component_file in component_files:
        _logger.info("Registering component: %s", component_file)
        processed_file = component_file + ".processed"
        process_file(component_file, processed_file, replacements)
        curr_component = load_component(source=processed_file)
        ml_client.components.create_or_update(curr_component)
        _logger.info("Registered %s", curr_component.name)


def _register_data(
    ml_client: MLClient, data_entries: list, replacements: Mapping[str, str]
) -> None:
    """Run each entry's generator script, then register its data YAMLs.

    Each entry is indexed as a mapping with a ``"script"`` key (a Python file
    to execute) and a ``"data_yamls"`` key (a list of data YAML paths).
    """
    for data_info in data_entries:
        script_file = data_info["script"]
        _logger.info("Running script %s", script_file)
        # Run the script with the interpreter executing this file; a bare
        # "python" resolves via PATH and may be missing or a different env.
        subprocess.run([sys.executable, script_file], check=True)
        for data_yaml in data_info["data_yamls"]:
            _logger.info("Processing %s", data_yaml)
            processed_file = data_yaml + ".processed"
            process_file(data_yaml, processed_file, replacements)
            curr_dataset: Data = load_data(processed_file)
            ml_client.data.create_or_update(curr_dataset)
            _logger.info("Registered %s", curr_dataset.name)


def process_directory(directory: Path, ml_client: MLClient, version: int) -> None:
    """Register every asset listed in ``directory``'s registration config.

    Reads ``directory / REG_CONFIG_FILENAME`` and, in order, registers the
    environments, components and data entries it lists, then recurses into
    each nested directory. Every referenced YAML is first copied to
    ``<name>.processed`` with ``VERSION_REPLACEMENT_STRING`` replaced by
    ``version``; the ``.processed`` files are left on disk.

    The working directory is switched to ``directory`` while processing, so
    relative paths in the config (and in the data scripts) resolve against
    it. It is restored before returning, including when an error is raised.

    Args:
        directory: Absolute path of the directory to process.
        ml_client: Client used for the ``create_or_update`` calls.
        version: Value substituted for ``VERSION_REPLACEMENT_STRING``.

    Raises:
        ValueError: if ``directory`` is not absolute.
        subprocess.CalledProcessError: if a data script exits non-zero.
    """
    _logger.info("Processing: %s", directory)
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if not directory.is_absolute():
        raise ValueError("directory must be an absolute path, got {0}".format(directory))

    registration_file = directory / REG_CONFIG_FILENAME
    reg_config = read_json_path(registration_file.resolve())
    replacements = {"VERSION_REPLACEMENT_STRING": str(version)}

    previous_cwd = os.getcwd()
    _logger.info("Changing directory")
    os.chdir(directory)
    try:
        if ENV_KEY in reg_config:
            _register_environments(ml_client, reg_config[ENV_KEY], replacements)
        else:
            _logger.info("No key for environments")

        if COMP_KEY in reg_config:
            _register_components(ml_client, reg_config[COMP_KEY], replacements)
        else:
            _logger.info("No key for components")

        if DATA_KEY in reg_config:
            _logger.info("Working through data entries")
            _register_data(ml_client, reg_config[DATA_KEY], replacements)
        else:
            _logger.info("No key for datasets")

        if SUBDIR_KEY in reg_config:
            _logger.info("Working through nested directories")
            for subdir in reg_config[SUBDIR_KEY]:
                # The recursive call restores cwd to ``directory`` on exit.
                process_directory((directory / subdir).resolve(), ml_client, version)
        else:
            _logger.info("No subdirectories found for %s", directory)
    finally:
        # Undo the chdir so callers (and error paths) see the original cwd.
        os.chdir(previous_cwd)
def main(args):
    """Register all assets under ``args.base_directory`` in an Azure ML workspace.

    The workspace coordinates (``subscription_id``, ``resource_group``,
    ``workspace_name``) come from the JSON file at ``args.workspace_config``.
    The asset version comes from the ``version`` key of the JSON file at
    ``args.component_config``.
    """
    workspace = read_json_path(args.workspace_config)
    component_settings = read_json_path(args.component_config)

    credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)
    ml_client = MLClient(
        credential=credential,
        subscription_id=workspace["subscription_id"],
        resource_group_name=workspace["resource_group"],
        workspace_name=workspace["workspace_name"],
        logging_enable=True,
    )

    version: int = component_settings["version"]
    base_directory = Path(args.base_directory).resolve()
    process_directory(base_directory, ml_client, version)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then run the registration walk.
    cli_args = parse_args()
    main(cli_args)