Skip to content

Commit

Permalink
pylint code refine & Fix nested example (#848)
Browse files Browse the repository at this point in the history
* refine code by CI

* fix argument error

* fix nested example
  • Loading branch information
you-n-g authored Jan 14, 2022
1 parent c399695 commit d0113ea
Show file tree
Hide file tree
Showing 26 changed files with 65 additions and 68 deletions.
8 changes: 5 additions & 3 deletions examples/nested_decision_execution/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ class NestedDecisionExecutionWorkflow:
},
}

exp_name = "nested"

port_analysis_config = {
"executor": {
"class": "NestedExecutor",
Expand Down Expand Up @@ -230,7 +232,7 @@ def _init_qlib(self):
qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None)

def _train_model(self, model, dataset):
with R.start(experiment_name="train"):
with R.start(experiment_name=self.exp_name):
R.log_params(**flatten_dict(self.task))
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
Expand All @@ -257,7 +259,7 @@ def backtest(self):
self.port_analysis_config["strategy"] = strategy_config
self.port_analysis_config["backtest"]["benchmark"] = self.benchmark

with R.start(experiment_name="backtest"):
with R.start(experiment_name=self.exp_name, resume=True):
recorder = R.get_recorder()
par = PortAnaRecord(
recorder,
Expand Down Expand Up @@ -382,7 +384,7 @@ def backtest_only_daily(self):
}
pa_conf["backtest"]["benchmark"] = self.benchmark

with R.start(experiment_name="backtest"):
with R.start(experiment_name=self.exp_name, resume=True):
recorder = R.get_recorder()
par = PortAnaRecord(recorder, pa_conf)
par.generate()
Expand Down
6 changes: 2 additions & 4 deletions qlib/backtest/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def generate_order_for_target_amount_position(self, target_position, current_pos
deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor)
if deal_amount == 0:
continue
elif deal_amount > 0:
if deal_amount > 0:
# buy stock
buy_order_list.append(
Order(
Expand Down Expand Up @@ -687,9 +687,7 @@ def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> int:
orig_deal_amount = order.deal_amount
order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0)
if vol_limit_min < orig_deal_amount:
self.logger.debug(
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
)
self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}")

def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
"""return the real order amount after cash limit for buying.
Expand Down
2 changes: 1 addition & 1 deletion qlib/backtest/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
return return_value.get("execute_result")

@abstractclassmethod
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
"""
Please refer to the doc of collect_data
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
Expand Down
2 changes: 1 addition & 1 deletion qlib/backtest/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BasePosition:
Please refer to the `Position` class for the position
"""

def __init__(self, cash=0.0, *args, **kwargs):
def __init__(self, *args, cash=0.0, **kwargs):
self._settle_type = self.ST_NO

def skip_update(self) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions qlib/backtest/profit_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,16 @@ def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df):
group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df)

group_ret = {}
for group_key in stock_weight_in_group:
stock_weight_in_group_start_date = min(stock_weight_in_group[group_key].index)
stock_weight_in_group_end_date = max(stock_weight_in_group[group_key].index)
for group_key, val in stock_weight_in_group.items():
stock_weight_in_group_start_date = min(val.index)
stock_weight_in_group_end_date = max(val.index)

temp_stock_ret_df = stock_ret_df[
(stock_ret_df.index >= stock_weight_in_group_start_date)
& (stock_ret_df.index <= stock_weight_in_group_end_date)
]

group_ret[group_key] = (temp_stock_ret_df * stock_weight_in_group[group_key]).sum(axis=1)
group_ret[group_key] = (temp_stock_ret_df * val).sum(axis=1)
# If no weight is assigned, then the return of group will be np.nan
group_ret[group_key][group_weight[group_key] == 0.0] = np.nan

Expand Down
3 changes: 2 additions & 1 deletion qlib/backtest/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ def load_portfolio_metrics(self, path):
path: str/ pathlib.Path()
"""
path = pathlib.Path(path)
r = pd.read_csv(open(path, "rb"), index_col=0)
with path.open("rb") as f:
r = pd.read_csv(f, index_col=0)
r.index = pd.DatetimeIndex(r.index)

index = r.index
Expand Down
5 changes: 1 addition & 4 deletions qlib/backtest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,7 @@ def get(self, infra_name):
warnings.warn(f"infra {infra_name} is not found!")

def has(self, infra_name):
if infra_name in self.get_support_infra() and hasattr(self, infra_name):
return True
else:
return False
return infra_name in self.get_support_infra() and hasattr(self, infra_name)

def update(self, other):
support_infra = other.get_support_infra()
Expand Down
6 changes: 2 additions & 4 deletions qlib/contrib/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def _get_date_parse_fn(target):
get_date_parse_fn('20120101')('2017-01-01') => '20170101'
get_date_parse_fn(20120101)('2017-01-01') => 20170101
"""
if isinstance(target, pd.Timestamp):
_fn = lambda x: pd.Timestamp(x) # Timestamp('2020-01-01')
elif isinstance(target, int):
if isinstance(target, int):
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
elif isinstance(target, str) and len(target) == 8:
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
Expand Down Expand Up @@ -158,7 +156,7 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs):
try:
df = self.handler._learn.copy() # use copy otherwise recorder will fail
# FIXME: currently we cannot support switching from `_learn` to `_infer` for inference
except:
except Exception:
warnings.warn("cannot access `_learn`, will load raw data")
df = self.handler._data.copy()
df.index = df.index.swaplevel()
Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def long_short_backtest(

def t_run():
pred_FN = "./check_pred.csv"
pred = pd.read_csv(pred_FN)
pred: pd.DataFrame = pd.read_csv(pred_FN)
pred["datetime"] = pd.to_datetime(pred["datetime"])
pred = pred.set_index([pred.columns[0], pred.columns[1]])
pred = pred.iloc[:9000]
Expand Down
2 changes: 1 addition & 1 deletion qlib/contrib/model/pytorch_adarnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def predict(self, x):
return fc_out


class TransferLoss(object):
class TransferLoss:
def __init__(self, loss_type="cosine", input_dim=512):
"""
Supported loss_type: mmd(mmd_lin), mmd_rbf, coral, cosine, kl, js, mine, adv
Expand Down
6 changes: 2 additions & 4 deletions qlib/contrib/model/pytorch_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def __init__(
"\nlr_decay_steps : {}"
"\noptimizer : {}"
"\nloss_type : {}"
"\neval_steps : {}"
"\nseed : {}"
"\ndevice : {}"
"\nuse_GPU : {}"
Expand All @@ -113,7 +112,6 @@ def __init__(
lr_decay_steps,
optimizer,
loss,
eval_steps,
seed,
self.device,
self.use_gpu,
Expand Down Expand Up @@ -331,8 +329,8 @@ def __init__(self, input_dim, output_dim, layers=(256, 512, 768, 512, 256, 128,
dnn_layers = []
drop_input = nn.Dropout(0.05)
dnn_layers.append(drop_input)
for i, (input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(input_dim, hidden_units)
for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
fc = nn.Linear(_input_dim, hidden_units)
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
bn = nn.BatchNorm1d(hidden_units)
seq = nn.Sequential(fc, bn, activation)
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/model/pytorch_tra.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

try:
from torch.utils.tensorboard import SummaryWriter
except:
except ImportError:
SummaryWriter = None

from tqdm import tqdm
Expand Down Expand Up @@ -257,7 +257,7 @@ def train_epoch(self, epoch, data_set, is_pretrain=False):
total_loss += loss.item()
total_count += 1

if self.use_daily_transport and len(P_all):
if self.use_daily_transport and len(P_all) > 0:
P_all = pd.concat(P_all, axis=0)
prob_all = pd.concat(prob_all, axis=0)
choice_all = pd.concat(choice_all, axis=0)
Expand Down
5 changes: 2 additions & 3 deletions qlib/contrib/report/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@


class BaseGraph:
""" """

_name = None

Expand Down Expand Up @@ -297,8 +296,8 @@ def _init_sub_graph_data(self):
:return:
"""
self._sub_graph_data = list()
self._subplot_titles = list()
self._sub_graph_data = []
self._subplot_titles = []

for i, column_name in enumerate(self._df.columns):
row = math.ceil((i + 1) / self.__cols)
Expand Down
2 changes: 1 addition & 1 deletion qlib/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
flt_data = super()._prepare_seg(ext_slice, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
Expand Down
6 changes: 3 additions & 3 deletions qlib/data/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,14 +1407,14 @@ def get_longest_back_rolling(self):
)

def get_extended_window_size(self):
ll, lr = self.feature_left.get_extended_window_size()
rl, rr = self.feature_right.get_extended_window_size()
if self.N == 0:
get_module_logger(self.__class__.__name__).warning(
"The PairRolling(ATTR, 0) will not be accurately calculated"
)
return self.feature.get_extended_window_size()
return -np.inf, max(lr, rr)
else:
ll, lr = self.feature_left.get_extended_window_size()
rl, rr = self.feature_right.get_extended_window_size()
return max(ll, rl) + self.N - 1, max(lr, rr)


Expand Down
2 changes: 1 addition & 1 deletion qlib/data/storage/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def data(self) -> List[CalVT]:
# If cache is enabled, then return cache directly
if self.enable_read_cache:
key = "orig_file" + str(self.uri)
if not key in H["c"]:
if key not in H["c"]:
H["c"][key] = self._read_calendar()
_calendar = H["c"][key]
else:
Expand Down
2 changes: 1 addition & 1 deletion qlib/model/riskmodel/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, factor_model: str = "pca", num_factors: int = 10, **kwargs):
num_factors (int): number of components to keep.
kwargs: see `RiskModel` for more information
"""
if "nan_option" in kwargs.keys():
if "nan_option" in kwargs:
assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format(
kwargs["nan_option"]
)
Expand Down
24 changes: 12 additions & 12 deletions qlib/model/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,21 @@ def train(self, tasks: list, train_func: Callable = None, experiment_name: str =
recs.append(rec)
return recs

def end_train(self, recs: list, **kwargs) -> List[Recorder]:
def end_train(self, models: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
models (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
if isinstance(recs, Recorder):
recs = [recs]
for rec in recs:
if isinstance(models, Recorder):
models = [models]
for rec in models:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
return models


class DelayTrainerR(TrainerR):
Expand All @@ -289,32 +289,32 @@ def __init__(self, experiment_name: str = None, train_func=begin_task_train, end
self.end_train_func = end_train_func
self.delay = True

def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
Args:
recs (list): a list of Recorder, the tasks have been saved to them
models (list): a list of Recorder, the tasks have been saved to them
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if isinstance(recs, Recorder):
recs = [recs]
if isinstance(models, Recorder):
models = [models]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
for rec in recs:
for rec in models:
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
continue
end_train_func(rec, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
return models


class TrainerRM(Trainer):
Expand Down
3 changes: 2 additions & 1 deletion qlib/tests/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

import re
import sys
import qlib
import shutil
import zipfile
Expand Down Expand Up @@ -101,7 +102,7 @@ def _delete_qlib_data(file_dir: Path):
f"\nAre you sure you want to delete, yes(Y/y), no (N/n):"
)
if str(flag) not in ["Y", "y"]:
exit()
sys.exit()
for _p in rm_dirs:
logger.warning(f"delete: {_p}")
shutil.rmtree(_p)
Expand Down
17 changes: 7 additions & 10 deletions qlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,16 +654,13 @@ def exists_qlib_data(qlib_dir):
def check_qlib_data(qlib_config):
inst_dir = Path(qlib_config["provider_uri"]).joinpath("instruments")
for _p in inst_dir.glob("*.txt"):
try:
assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
f"\n\tIf you are using the data provided by qlib: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
f"\n\tIf you are using your own data, please dump the data again: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
)
except AssertionError:
raise
assert len(pd.read_csv(_p, sep="\t", nrows=0, header=None).columns) == 3, (
f"\nThe {str(_p.resolve())} of qlib data is not equal to 3 columns:"
f"\n\tIf you are using the data provided by qlib: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#qlib-format-dataset"
f"\n\tIf you are using your own data, please dump the data again: "
f"https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format"
)


def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
Expand Down
3 changes: 1 addition & 2 deletions qlib/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

# Base exception class
class QlibException(Exception):
def __init__(self, message):
super(QlibException, self).__init__(message)
pass


class RecorderInitializationError(QlibException):
Expand Down
3 changes: 1 addition & 2 deletions qlib/utils/paral.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def run(self):
data = self._q.get()
if data == self.STOP_MARK:
break
else:
data()
data()

def __call__(self, func, *args, **kwargs):
self._q.put(partial(func, *args, **kwargs))
Expand Down
Loading

0 comments on commit d0113ea

Please sign in to comment.