add get_feature_importance to model interpret #444

Merged: 3 commits, May 28, 2021
19 changes: 3 additions & 16 deletions examples/highfreq/workflow.py
@@ -1,24 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys
import fire
from pathlib import Path

import qlib
import pickle
import numpy as np
import pandas as pd
from qlib.config import REG_CN, HIGH_FREQ_CONFIG
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)

from qlib.utils import init_instance_by_config, exists_qlib_data

from qlib.utils import init_instance_by_config
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.ops import Operators
from qlib.data.data import Cal
@@ -96,9 +85,7 @@ def _init_qlib(self):
# use yahoo_cn_1min data
QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF}
provider_uri = QLIB_INIT_CONFIG.get("provider_uri")
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN)
GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True)
qlib.init(**QLIB_INIT_CONFIG)

def _prepare_calender_cache(self):
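In the hunk above, the manual exists_qlib_data check and unconditional download are folded into the download call itself via exists_skip=True. As a rough illustration of that pattern only (a hypothetical standalone helper, not qlib's actual code), such a guard typically looks like this:

from pathlib import Path

def download_qlib_data(target_dir: str, exists_skip: bool = False) -> None:
    # Skip the download when the target directory already holds data.
    target = Path(target_dir).expanduser()
    if exists_skip and target.exists() and any(target.iterdir()):
        print(f"Data already exists in {target}, skipping download.")
        return
    target.mkdir(parents=True, exist_ok=True)
    print(f"Downloading data into {target} ...")
    # ... fetch and unpack the dataset here ...

download_qlib_data("~/.qlib/qlib_data/cn_data", exists_skip=True)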
58 changes: 14 additions & 44 deletions examples/hyperparameter/LightGBM/hyperparameter_158.py
@@ -1,46 +1,9 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna

provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.config import CSI300_DATASET_CONFIG
from qlib.tests.data import GetData


def objective(trial):
@@ -65,12 +28,19 @@ def objective(trial):
},
},
}

evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])


study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":

provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region="cn")

dataset = init_instance_by_config(CSI300_DATASET_CONFIG)

study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
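The Optuna wiring is unchanged by this refactor: objective(trial) trains an LGBModel and returns the best validation score, and the study is attached to a SQLite backend and optimized with n_jobs=6. For reference, a self-contained variant that creates the study itself rather than loading a pre-created one could look like the sketch below; the toy objective stands in for the LightGBM training above, and create_study, load_if_exists, and suggest_float are standard Optuna APIs.

import optuna

def objective(trial):
    # Placeholder objective; the real example trains an LGBModel on CSI300
    # and returns min(evals_result["valid"]).
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2

if __name__ == "__main__":
    study = optuna.create_study(
        study_name="LGBM_158",
        storage="sqlite:///db.sqlite3",
        load_if_exists=True,  # reuse the study if it already exists in the DB
    )
    study.optimize(objective, n_trials=50, n_jobs=6)
    print(study.best_params)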
57 changes: 15 additions & 42 deletions examples/hyperparameter/LightGBM/hyperparameter_360.py
@@ -1,46 +1,11 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
from qlib.config import REG_CN
from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS

provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData

GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS)


def objective(trial):
@@ -72,5 +37,13 @@ def objective(trial):
return min(evals_result["valid"])


study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
if __name__ == "__main__":

provider_uri = "~/.qlib/qlib_data/cn_data"
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
qlib.init(provider_uri=provider_uri, region=REG_CN)

dataset = init_instance_by_config(DATASET_CONFIG)

study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)
32 changes: 32 additions & 0 deletions examples/model_interpreter/feature.py
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


import qlib
from qlib.config import REG_CN

from qlib.utils import init_instance_by_config
from qlib.tests.data import GetData
from qlib.tests.config import CSI300_GBDT_TASK


if __name__ == "__main__":

# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)

qlib.init(provider_uri=provider_uri, region=REG_CN)

###################################
# train model
###################################
# model initialization
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
model.fit(dataset)

# get model feature importance
feature_importance = model.get_feature_importance()
print("feature importance:")
print(feature_importance)
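The new example above is the core of this PR: after fitting, model.get_feature_importance() reports how much each feature contributed to the trained GBDT model. For a LightGBM-backed model, such a method is typically a thin wrapper over the booster's own importance scores keyed by feature name; the following toy sketch shows that general idea and is an illustration only, not qlib's implementation.

import lightgbm as lgb
import numpy as np
import pandas as pd

def get_feature_importance(booster, importance_type="gain"):
    # Pair LightGBM's importance scores with feature names, highest first.
    scores = booster.feature_importance(importance_type=importance_type)
    return pd.Series(scores, index=booster.feature_name()).sort_values(ascending=False)

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(500, 5)), columns=[f"f{i}" for i in range(5)])
y = 2.0 * X["f0"] + rng.normal(scale=0.1, size=500)

booster = lgb.train({"objective": "regression", "verbosity": -1}, lgb.Dataset(X, label=y))
print(get_feature_importance(booster))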
62 changes: 4 additions & 58 deletions examples/model_rolling/task_manager_rolling.py
@@ -17,63 +17,7 @@
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM


data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}

dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}

record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]

# use lgb
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}

# use xgboost
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG


class RollingTaskExample:
@@ -85,11 +29,13 @@ def __init__(
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=[task_xgboost_config, task_lgb_config],
task_config=None,
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
if task_config is None:
task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
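The signature change above, from task_config=[task_xgboost_config, task_lgb_config] to task_config=None with the default resolved inside __init__ (the online example below gets the same treatment for tasks), is the standard fix for Python's mutable-default-argument pitfall: a default list is built once when the function is defined and then shared across calls. A minimal illustration, unrelated to qlib itself:

def append_shared(item, bucket=[]):
    # The default list is created once and reused by every call.
    bucket.append(item)
    return bucket

def append_fresh(item, bucket=None):
    # A new list is created per call unless the caller passes one in.
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket

print(append_shared(1), append_shared(2))  # [1, 2] [1, 2]
print(append_fresh(1), append_fresh(2))    # [1] [2]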
62 changes: 4 additions & 58 deletions examples/online_srv/online_management_simulate.py
@@ -13,63 +13,7 @@
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager


data_handler_config = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}

dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}

record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]

# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}

# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG


class OnlineSimulationExample:
@@ -84,7 +28,7 @@ def __init__(
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=[task_xgboost_config, task_lgb_config],
tasks=None,
):
"""
Init OnlineManagerExample.
@@ -101,6 +45,8 @@ def __init__(
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time