From 0327b0388241c1d788231339f9f9a2026fc9cf01 Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 14 Apr 2022 13:16:51 +0000 Subject: [PATCH] fix tra dataset bug --- examples/benchmarks/TRA/src/dataset.py | 16 +++++++++++----- qlib/data/dataset/__init__.py | 1 + 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/benchmarks/TRA/src/dataset.py b/examples/benchmarks/TRA/src/dataset.py index 50c57d5818..6740b1cbdf 100644 --- a/examples/benchmarks/TRA/src/dataset.py +++ b/examples/benchmarks/TRA/src/dataset.py @@ -6,8 +6,7 @@ import numpy as np import pandas as pd -from qlib.utils import init_instance_by_config -from qlib.data.dataset import DatasetH, DataHandler +from qlib.data.dataset import DatasetH device = "cuda" if torch.cuda.is_available() else "cpu" @@ -95,7 +94,7 @@ def __init__( shuffle=True, pin_memory=False, drop_last=False, - **kwargs + **kwargs, ): assert horizon > 0, "please specify `horizon` to avoid data leakage" @@ -150,8 +149,15 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs): def _prepare_seg(self, slc, **kwargs): fn = _get_date_parse_fn(self._index[0][1]) - start_date = fn(slc.start) - end_date = fn(slc.stop) + + if isinstance(slc, slice): + start, stop = slc.start, slc.stop + elif isinstance(slc, (list, tuple)): + start, stop = slc + else: + raise NotImplementedError(f"This type of input is not supported") + start_date = fn(start) + end_date = fn(stop) obj = copy.copy(self) # shallow copy # NOTE: Seriable will disable copy `self._data` so we manually assign them here obj._data = self._data diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 7262640588..6e0c0ab606 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -171,6 +171,7 @@ def _prepare_seg(self, slc, **kwargs): Parameters ---------- slc : please refer to the docs of `prepare` + NOTE: it may not be an instance of slice. It may be a segment of `segments` from `def prepare` """ if hasattr(self, "fetch_kwargs"): return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)