Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor collector #68

Merged
merged 10 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:

- name: Test data downloads and examples
run: |
python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
# cd examples
# estimator -c estimator/estimator_config.yaml
# jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html
4 changes: 2 additions & 2 deletions examples/estimator/analyze_from_estimator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(CUR_DIR.parent.parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
Expand Down Expand Up @@ -219,4 +219,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
2 changes: 1 addition & 1 deletion examples/train_and_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
from get_data import GetData

GetData().qlib_data_cn(target_dir=provider_uri)
GetData().qlib_data(target_dir=provider_uri)

qlib.init(provider_uri=provider_uri, region=REG_CN)

Expand Down
4 changes: 2 additions & 2 deletions examples/train_backtest_analyze.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
" print(f\"Qlib data is not found in {provider_uri}\")\n",
" sys.path.append(str(Path.cwd().parent.joinpath(\"scripts\")))\n",
" from get_data import GetData\n",
" GetData().qlib_data_cn(target_dir=provider_uri)\n",
" GetData().qlib_data(target_dir=provider_uri)\n",
"qlib.init(provider_uri=provider_uri, region=REG_CN)"
]
},
Expand Down Expand Up @@ -335,4 +335,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
144 changes: 144 additions & 0 deletions scripts/check_dump_bin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path
from concurrent.futures import ProcessPoolExecutor

import qlib
from qlib.data import D

import fire
import datacompy
import pandas as pd
from tqdm import tqdm
from loguru import logger


class CheckBin:

NOT_IN_FEATURES = "not in features"
COMPARE_FALSE = "compare False"
COMPARE_TRUE = "compare True"
COMPARE_ERROR = "compare error"

def __init__(
self,
qlib_dir: str,
csv_path: str,
check_fields: str = None,
freq: str = "day",
symbol_field_name: str = "symbol",
date_field_name: str = "date",
file_suffix: str = ".csv",
max_workers: int = 16,
):
"""

Parameters
----------
qlib_dir : str
qlib dir
csv_path : str
origin csv path
check_fields : str, optional
check fields, by default None, check qlib_dir/features/<first_dir>/*.<freq>.bin
freq : str, optional
freq, value from ["day", "1m"]
symbol_field_name: str, optional
symbol field name, by default "symbol"
date_field_name: str, optional
date field name, by default "date"
file_suffix: str, optional
csv file suffix, by default ".csv"
max_workers: int, optional
max workers, by default 16
"""
self.qlib_dir = Path(qlib_dir).expanduser()
bin_path_list = list(self.qlib_dir.joinpath("features").iterdir())
self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list))
qlib.init(
provider_uri=str(self.qlib_dir.resolve()),
mount_path=str(self.qlib_dir.resolve()),
auto_mount=False,
redis_port=-1,
)
csv_path = Path(csv_path).expanduser()
self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])

if check_fields is None:
check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin")))
else:
check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
self.check_fields = list(map(lambda x: x.strip(), check_fields))
self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields))
self.max_workers = max_workers
self.symbol_field_name = symbol_field_name
self.date_field_name = date_field_name
self.freq = freq
self.file_suffix = file_suffix

def _compare(self, file_path: Path):
symbol = file_path.name.strip(self.file_suffix)
if symbol.lower() not in self.qlib_symbols:
return self.NOT_IN_FEATURES
# qlib data
qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq)
qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True)
# csv data
origin_df = pd.read_csv(file_path)
origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name])
if self.symbol_field_name not in origin_df.columns:
origin_df[self.symbol_field_name] = symbol
origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
origin_df.index.names = qlib_df.index.names
try:
compare = datacompy.Compare(
origin_df,
qlib_df,
on_index=True,
abs_tol=1e-08, # Optional, defaults to 0
rel_tol=1e-05, # Optional, defaults to 0
df1_name="Original", # Optional, defaults to 'df1'
df2_name="New", # Optional, defaults to 'df2'
)
_r = compare.matches(ignore_extra_columns=True)
return self.COMPARE_TRUE if _r else self.COMPARE_FALSE
except Exception as e:
logger.warning(f"{symbol} compare error: {e}")
return self.COMPARE_ERROR

def check(self):
"""Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data

"""
logger.info("start check......")

error_list = []
not_in_features = []
compare_false = []
with tqdm(total=len(self.csv_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)):
symbol = file_path.name.strip(self.file_suffix)
if _check_res == self.NOT_IN_FEATURES:
not_in_features.append(symbol)
elif _check_res == self.COMPARE_ERROR:
error_list.append(symbol)
elif _check_res == self.COMPARE_FALSE:
compare_false.append(symbol)
p_bar.update()

logger.info("end of check......")
if error_list:
logger.warning(f"compare error: {error_list}")
if not_in_features:
logger.warning(f"not in features: {not_in_features}")
if compare_false:
logger.warning(f"compare False: {compare_false}")
logger.info(
f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false"
)


if __name__ == "__main__":
fire.Fire(CheckBin)
22 changes: 22 additions & 0 deletions scripts/data_collector/cn_index/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# CSI300/CSI100 History Companies Collection

## Requirements

```bash
pip install -r requirements.txt
```

## Collector Data

```bash
# parse instruments, using in qlib/instruments.
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments

# parse new companies
python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies

# index_name support: CSI300, CSI100
# help
python collector.py --help
```

Loading