Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

performance mprovement #921

Merged
merged 2 commits into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 16 additions & 26 deletions qlib/contrib/model/pytorch_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections import defaultdict

import os
import gc
import numpy as np
import pandas as pd
from typing import Callable, Optional, Text, Union
Expand All @@ -32,7 +33,6 @@
from ...workflow import R
from qlib.contrib.meta.data_selection.utils import ICLoss
from torch.nn import DataParallel
from torch.utils.data import DataLoader, SequentialSampler


class DNNModelPytorch(Model):
Expand Down Expand Up @@ -201,7 +201,7 @@ def fit(
seg, col_set=["feature", "label"], data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L
)
all_df["x"][seg] = df["feature"]
all_df["y"][seg] = df["label"]
all_df["y"][seg] = df["label"].copy() # We have to use copy to remove the reference to release mem
if reweighter is None:
all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index)
elif isinstance(reweighter, Reweighter):
Expand All @@ -216,6 +216,10 @@ def fit(
all_t[v][seg] = all_t[v][seg].to(self.device) # This will consume a lot of memory !!!!

evals_result[seg] = []
# free memory
del df
del all_df["x"]
gc.collect()

save_path = get_or_create_path(save_path)
stop_steps = 0
Expand Down Expand Up @@ -266,7 +270,7 @@ def fit(
loss_val = cur_loss_val.item()
metric_val = (
self.get_metric(
preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["x"]["valid"].index
preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["y"]["valid"].index
)
.detach()
.cpu()
Expand All @@ -281,7 +285,7 @@ def fit(
self.get_metric(
self._nn_predict(all_t["x"]["train"], return_cpu=False),
all_t["y"]["train"].reshape(-1),
all_df["x"]["train"].index,
all_df["y"]["train"].index,
)
.detach()
.cpu()
Expand Down Expand Up @@ -351,31 +355,17 @@ def _nn_predict(self, data, return_cpu=True):
1) test inference (data may come from CPU and expect the output data is on CPU)
2) evaluation on training (data may come from GPU)
"""
if isinstance(data, torch.Tensor) and data.device.type != "cpu":
# GPU data
# CUDA data don't support pin_memory and multi-processing workers
num_workers = 0
pin_memory = False
else:
# CPU data
if not isinstance(data, torch.Tensor):
if isinstance(data, pd.DataFrame):
data = data.values
# else: CPU Tensor
num_workers = 8
pin_memory = True
data_loader = DataLoader(
data,
sampler=SequentialSampler(data),
batch_size=self.batch_size,
drop_last=False,
num_workers=num_workers,
pin_memory=pin_memory,
)
if not isinstance(data, torch.Tensor):
if isinstance(data, pd.DataFrame):
data = data.values
data = torch.Tensor(data)
data = data.to(self.device)
preds = []
self.dnn_model.eval()
with torch.no_grad():
for x in data_loader:
batch_size = 8096
for i in range(0, len(data), batch_size):
x = data[i : i + batch_size]
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
if return_cpu:
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
Expand Down
4 changes: 2 additions & 2 deletions qlib/contrib/report/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

"""
import pandas as pd
from blocks.utils.log import logt
from qlib.log import TimeInspector
from qlib.contrib.report.utils import sub_fig_generator


class FeaAnalyser:
def __init__(self, dataset: pd.DataFrame):
self._dataset = dataset
with logt("calc_stat_values"):
with TimeInspector.logt("calc_stat_values"):
self.calc_stat_values()

def calc_stat_values(self):
Expand Down