Skip to content

Commit

Permalink
Update pytorch_alstm_ts.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wendili-cs authored and you-n-g committed Jan 18, 2021
1 parent b4a088e commit 740c297
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions qlib/contrib/model/pytorch_alstm_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def fit(
verbose=True,
save_path=None,
):
dl_train = dataset.prepare("train", data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", data_key=DataHandlerLP.DK_L)
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)

dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
Expand Down Expand Up @@ -260,7 +260,7 @@ def predict(self, dataset):
if not self._fitted:
raise ValueError("model is not fitted yet!")

dl_test = dataset.prepare("test", data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.ALSTM_model.eval()
Expand Down

0 comments on commit 740c297

Please sign in to comment.