Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DDG-DA paper code #743

Merged
merged 47 commits into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
48f8694
Merge data selection to main
wendili-cs Jul 1, 2021
5bb06cd
Update trainer for reweighter
wendili-cs Jul 1, 2021
4f442f5
Typos fixed.
wendili-cs Jul 8, 2021
da013fd
Merge branch 'main' into ds
you-n-g Jul 30, 2021
81b4383
update data selection interface
you-n-g Aug 9, 2021
aa2699f
successfully run exp after refactor some interface
you-n-g Aug 13, 2021
d17aaac
data selection share handler & trainer
you-n-g Aug 20, 2021
82b4115
fix meta model time series bug
you-n-g Aug 22, 2021
5b118c4
fix online workflow set_uri bug
you-n-g Sep 13, 2021
3b073f7
fix set_uri bug
you-n-g Sep 26, 2021
384b670
Merge remote-tracking branch 'origin/main' into ds
you-n-g Sep 26, 2021
b0850b0
update ds docs and delay trainer bug
you-n-g Sep 27, 2021
051b261
Merge remote-tracking branch 'wd_ds/ds' into ds
you-n-g Oct 9, 2021
f10d726
Merge branch 'main' into ds
you-n-g Nov 14, 2021
cdcfe30
Merge remote-tracking branch 'origin/main' into ds
you-n-g Nov 14, 2021
6d61ad0
Merge remote-tracking branch 'origin/main' into ds
you-n-g Nov 16, 2021
f32a7ad
docs
you-n-g Nov 16, 2021
8fb37b6
resume reweighter
you-n-g Nov 16, 2021
21baead
add reweighting result
you-n-g Nov 16, 2021
12afe61
fix qlib model import
you-n-g Nov 17, 2021
1d9732b
make recorder more friendly
you-n-g Nov 17, 2021
20a8fe5
fix experiment workflow bug
you-n-g Nov 18, 2021
faf3e03
commit for merging master in case of conflicts
you-n-g Dec 9, 2021
76d1bd9
Merge remote-tracking branch 'origin/main' into ds
you-n-g Dec 9, 2021
3bc4030
Successful run DDG-DA with a single command
you-n-g Dec 11, 2021
49c4074
remove unused code
you-n-g Dec 11, 2021
ce66d9a
add more docs
you-n-g Dec 13, 2021
cea134d
Update README.md
you-n-g Dec 13, 2021
a4a2b32
Update & fix some bugs.
demon143 Jan 8, 2022
8241832
Update configuration & remove debug functions
wendili-cs Jan 8, 2022
e1b079d
Update README.md
wendili-cs Jan 9, 2022
6a3f471
Modify horizon from code rather than yaml
wendili-cs Jan 9, 2022
c3364cd
Update performance in README.md
wendili-cs Jan 9, 2022
b3d1081
Merge remote-tracking branch 'origin/main' into ds
you-n-g Jan 9, 2022
fa2d047
fix part comments
you-n-g Jan 9, 2022
efab5cb
Remove unfinished TCTS.
wendili-cs Jan 10, 2022
5a184eb
Fix some details.
wendili-cs Jan 10, 2022
8fee1b4
Update meta docs
wendili-cs Jan 10, 2022
a31a4d5
Update README.md of the benchmarks_dynamic
wendili-cs Jan 10, 2022
ca3fe76
Merge branch 'main' into ds
you-n-g Jan 10, 2022
97f61d5
Update README.md files
wendili-cs Jan 10, 2022
2726560
Merge branch 'ds' of wd_git:you-n-g/qlib into ds
wendili-cs Jan 10, 2022
da68103
Add README.md to the rolling_benchmark baseline.
wendili-cs Jan 10, 2022
7e1183b
Refine the docs and link
you-n-g Jan 10, 2022
b0857c2
Rename README.md in benchmarks_dynamic.
wendili-cs Jan 10, 2022
38b83dd
Remove comments.
wendili-cs Jan 10, 2022
34f5bd2
auto download data
you-n-g Jan 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
data selection share handler & trainer
  • Loading branch information
you-n-g committed Aug 20, 2021
commit d17aaac659ca91c445b42c5cc6b460568263ff7d
125 changes: 68 additions & 57 deletions qlib/contrib/meta/data_selection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,66 @@ def __init__(
self.max_epoch = max_epoch
self.fitted = False

def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False):
    """
    Run one pass over ``task_list`` and log the mean loss and the IC of the phase.

    Parameters
    ----------
    phase : str
        Name of the phase. Only the exact value ``"train"`` switches ``self.tn`` to
        training mode, enables gradients and updates the parameters through ``opt``;
        any other name is evaluated in eval mode without gradients.
        The name is also used as the key in ``loss_l`` and in the logged metric names.
    task_list :
        Sized collection of meta tasks; ``get_meta_input()`` is called on each task.
        It must not be empty (the mean loss divides by ``len(task_list)``).
    epoch : int
        Logged as the ``step`` of the metrics.
    opt :
        Optimizer; ``zero_grad()`` and ``step()`` are only called in the ``"train"`` phase.
    loss_l : dict
        Updated in place: the mean loss of this pass is appended to ``loss_l[phase]``.
    ignore_weight : bool
        Forwarded unchanged to ``self.tn``.

    Raises
    ------
    ValueError
        If ``self.criterion`` is neither ``"mse"`` nor ``"ic_loss"``, or if the loss
        of a task is NaN.
    """
    # Fail early with a clear message instead of hitting an unbound `loss` inside the loop.
    if self.criterion not in ("mse", "ic_loss"):
        raise ValueError(f"Unsupported criterion: {self.criterion!r}")

    is_train = phase == "train"
    if is_train:
        self.tn.train()
    else:
        self.tn.eval()
    # NOTE: this sets the grad mode globally and it stays set after this method returns.
    torch.set_grad_enabled(is_train)

    # The loss module is the same for every task, so build it once.
    criterion = nn.MSELoss() if self.criterion == "mse" else ICLoss()

    running_loss = 0.0
    pred_y_all = []
    for task in tqdm(task_list, desc=f"{phase} Task", leave=False):
        meta_input = task.get_meta_input()
        # NOTE (translated from the original comment): `pred` may be None here;
        # the original comment did not record the reasons.
        pred, _ = self.tn(
            meta_input["X"],
            meta_input["y"],
            meta_input["time_perf"],
            meta_input["time_belong"],
            meta_input["X_test"],
            ignore_weight=ignore_weight,
        )
        if self.criterion == "mse":
            loss = criterion(pred, meta_input["y_test"])
        else:
            loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"], skip_size=50)

        loss_value = loss.detach().item()
        if np.isnan(loss_value):
            # A NaN loss would silently corrupt the parameters on the next optimizer step.
            raise ValueError(f"NaN loss encountered in phase `{phase}` at epoch {epoch}")

        if is_train:
            opt.zero_grad()
            loss.backward()
            opt.step()

        pred_y_all.append(
            pd.DataFrame(
                {
                    "pred": pd.Series(pred.detach().cpu().numpy(), index=meta_input["test_idx"]),
                    "label": pd.Series(
                        meta_input["y_test"].detach().cpu().numpy(), index=meta_input["test_idx"]
                    ),
                }
            )
        )
        running_loss += loss_value

    running_loss = running_loss / len(task_list)
    loss_l.setdefault(phase, []).append(running_loss)

    pred_y_all = pd.concat(pred_y_all)
    # Mean of the per-group Spearman correlation between prediction and label.
    # Assumes `test_idx` carries an index level named "datetime" -- TODO confirm.
    ic = (
        pred_y_all.groupby("datetime")
        .apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
        .mean()
    )

    R.log_metrics(**{f"loss/{phase}": running_loss, "step": epoch})
    R.log_metrics(**{f"ic/{phase}": ic, "step": epoch})

def fit(self, meta_dataset: MetaDatasetHDS):
"""
The meta-learning-based data selection interacts directly with meta-dataset due to the close-form proxy measurement.
Expand All @@ -81,67 +141,18 @@ def fit(self, meta_dataset: MetaDatasetHDS):
step=self.step, hist_step_n=self.hist_step_n, clip_weight=self.clip_weight, clip_method=self.clip_method
)

train_step = 0
opt = optim.Adam(self.tn.parameters(), lr=self.lr)

# run weight with no weight
for phase, task_list in zip(phases, meta_tasks_l):
self.run_epoch(f"{phase}_noweight", task_list, 0, opt, {}, ignore_weight=True)
self.run_epoch(f"{phase}_init", task_list, 0, opt, {})

# run training
loss_l = {}
for epoch in tqdm(range(self.max_epoch), desc="epoch"):
for phase, task_list in zip(phases, meta_tasks_l):
if phase == "train": # phase 0 for training, 1 for inference
self.tn.train()
torch.set_grad_enabled(True)
else:
self.tn.eval()
torch.set_grad_enabled(False)
running_loss = 0.0
pred_y_all = []
for task in tqdm(task_list, desc=f"{phase} Task", leave=False):
meta_input = task.get_meta_input()
pred, weights = self.tn(
meta_input["X"],
meta_input["y"],
meta_input["time_perf"],
meta_input["time_belong"],
meta_input["X_test"],
)
if self.criterion == "mse":
criterion = nn.MSELoss()
loss = criterion(pred, meta_input["y_test"])
elif self.criterion == "ic_loss":
criterion = ICLoss()
loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"])

if phase == "train":
opt.zero_grad()
norm_loss = nn.MSELoss()
loss.backward()
opt.step()
train_step += 1
elif phase == "test":
pass

pred_y_all.append(
pd.DataFrame(
{
"pred": pd.Series(pred.detach().cpu().numpy(), index=meta_input["test_idx"]),
"label": pd.Series(
meta_input["y_test"].detach().cpu().numpy(), index=meta_input["test_idx"]
),
}
)
)
running_loss += loss.detach().item()
running_loss = running_loss / len(task_list)
loss_l.setdefault(phase, []).append(running_loss)

pred_y_all = pd.concat(pred_y_all)
ic = (
pred_y_all.groupby("datetime")
.apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
.mean()
)

R.log_metrics(**{f"loss/{phase}": running_loss, "step": epoch})
R.log_metrics(**{f"ic/{phase}": ic, "step": epoch})
self.run_epoch(phase, task_list, epoch, opt, loss_l)
R.save_objects(**{"model.pkl": self.tn})
self.fitted = True

Expand Down
9 changes: 7 additions & 2 deletions qlib/contrib/meta/data_selection/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_sim_mat_idx(i_sim_mat, outsample_period):


class ICLoss(nn.Module):
def forward(self, pred, y, idx):
def forward(self, pred, y, idx, skip_size=50):
"""forward.

:param pred:
Expand All @@ -41,15 +41,20 @@ def forward(self, pred, y, idx):
diff_point.append(None)

ic_all = 0.0
skip_n = 0
for start_i, end_i in zip(diff_point, diff_point[1:]):
pred_focus = pred[start_i:end_i] # TODO: just for fake
if pred_focus.shape[0] < skip_size:
# skip some days which have very small amount of stock.
skip_n += 1
continue
y_focus = y[start_i:end_i]
ic_day = torch.dot(
(pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(),
(y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(),
)
ic_all += ic_day
ic_mean = ic_all / (len(diff_point) - 1)
ic_mean = ic_all / (len(diff_point) - 1 - skip_n)
return -ic_mean # ic loss


Expand Down
70 changes: 39 additions & 31 deletions qlib/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,38 @@
from qlib.data.dataset.weight import Reweighter


def _log_task_info(task_config: dict):
    """
    Record the task description on the recorder accessed through ``R``.

    The flattened config is logged as parameters, the untouched config is saved
    as the object ``"task"``, and the host name is stored as the tag ``"hostname"``.

    Parameters
    ----------
    task_config : dict
        the config of the task
    """
    flat_params = flatten_dict(task_config)
    R.log_params(**flat_params)
    # the raw config is saved as well so its original format and datatype are kept
    R.save_objects(task=task_config)
    host = socket.gethostname()
    R.set_tags(hostname=host)


def _exe_task(task_config: dict):
    """
    Execute a task: build the model and dataset, fit the model, save both and
    generate the configured records.

    Parameters
    ----------
    task_config : dict
        the config of the task; the keys ``"model"`` and ``"dataset"`` are required,
        ``"reweighter"`` and ``"record"`` are optional.
    """
    recorder = R.get_recorder()

    # instantiate the model and the dataset from their configs
    model: Model = init_instance_by_config(task_config["model"])
    dataset: Dataset = init_instance_by_config(task_config["dataset"])
    reweighter: Reweighter = task_config.get("reweighter")

    # train the model; `fit` is wrapped by `auto_filter_kwargs` before `reweighter` is passed
    fit = auto_filter_kwargs(model.fit)
    fit(dataset, reweighter=reweighter)
    R.save_objects(**{"params.pkl": model})

    # the dataset is stored for online inference, so the concrete data should not be dumped
    dataset.config(dump_all=False, recursive=True)
    R.save_objects(**{"dataset": dataset})

    # produce the records: prediction, backtest, and analysis
    record_confs = task_config.get("record", [])
    if isinstance(record_confs, dict):
        # a single record config is accepted as well
        record_confs = [record_confs]
    for record_conf in record_confs:
        record_cls, record_kwargs = get_cls_kwargs(record_conf, default_module="qlib.workflow.record_temp")
        if record_cls is SignalRecord:
            # SignalRecord additionally receives the trained model and the dataset
            extra_kwargs = {"model": model, "dataset": dataset, "recorder": recorder}
        else:
            extra_kwargs = {"recorder": recorder}
        record_cls(**record_kwargs, **extra_kwargs).generate()


def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
Begin task training to start a recorder and save the task config.
Expand All @@ -38,11 +70,8 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
_log_task_info(task_config)
return R.get_recorder()


def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
Expand All @@ -58,29 +87,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
reweighter: Reweighter = task_config.get("reweighter", None)
# model training
auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:
rconf = {"recorder": rec}
r = cls(**kwargs, **rconf)
r.generate()

_exe_task(task_config)
return rec


Expand All @@ -101,9 +108,10 @@ def task_train(task_config: dict, experiment_name: str, recorder_name: str = Non
----------
Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name, recorder_name=recorder_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
_log_task_info(task_config)
_exe_task(task_config)
return R.get_recorder()


class Trainer:
Expand Down
4 changes: 2 additions & 2 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
sorted dataframe
"""
idx = df.index if axis == 0 else df.columns
if idx.is_monotonic_increasing:
if idx.is_monotonic_increasing and (not isinstance(idx, pd.MultiIndex) or not idx.is_lexsorted()):
return df
else:
return df.sort_index(axis=axis)
Expand Down Expand Up @@ -657,7 +657,7 @@ def _func(*args, **kwargs):
for k, v in kwargs.items():
# if `func` don't accept variable keyword arguments like `**kwargs` and have not according named arguments
if spec.varkw is None and k not in spec.args:
log.warn(f"The parameter `{k}` with value `{v}` is ignored.")
log.warning(f"The parameter `{k}` with value `{v}` is ignored.")
else:
new_kwargs[k] = v
return func(*args, **new_kwargs)
Expand Down
21 changes: 19 additions & 2 deletions qlib/workflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from contextlib import contextmanager
from typing import Text, Optional
from .expm import MLflowExpManager
from .expm import ExpManager
from .exp import Experiment
from .recorder import Recorder
from ..utils import Wrapper
Expand All @@ -15,7 +15,7 @@ class QlibRecorder:
A global system that helps to manage the experiments.
"""

def __init__(self, exp_manager):
def __init__(self, exp_manager: ExpManager):
self.exp_manager = exp_manager

def __repr__(self):
Expand Down Expand Up @@ -334,6 +334,23 @@ def set_uri(self, uri: Optional[Text]):
"""
self.exp_manager.set_uri(uri)

@contextmanager
def uri_context(self, uri: Text):
    """
    Context manager that switches the exp_manager to `uri` for the duration of the block.

    On exit (normal or by exception) the exp_manager's ``_current_uri`` is
    assigned the value it had before entering.

    Parameters
    ----------
    uri : Text
        the uri used inside the context
    """
    saved_uri = self.exp_manager._current_uri
    self.exp_manager.set_uri(uri)
    try:
        yield
    finally:
        # restore by direct assignment of the attribute, not through `set_uri`
        self.exp_manager._current_uri = saved_uri

def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:
Expand Down
5 changes: 3 additions & 2 deletions qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .recorder import Recorder
from ..log import get_module_logger

logger = get_module_logger("workflow", logging.INFO)
logger = get_module_logger("workflow")


class ExpManager:
Expand Down Expand Up @@ -258,7 +258,7 @@ def set_uri(self, uri: Optional[Text] = None):

"""
if uri is None:
logger.info("No tracking URI is provided. Use the default tracking URI.")
logger.debug("No tracking URI is provided. Use the default tracking URI.")
self._current_uri = self.default_uri
else:
# Temporarily re-set the current uri as the uri argument.
Expand All @@ -269,6 +269,7 @@ def set_uri(self, uri: Optional[Text] = None):
def _set_uri(self):
"""
Customized features for subclasses' set_uri function.
This method is designed for the underlying experiment backend storage.
"""
raise NotImplementedError(f"Please implement the `_set_uri` method.")

Expand Down
Loading