Skip to content

Commit 7c1331f

Browse files
authored
Merge pull request #7 from salcc/main
Optimize memory and improve code
2 parents 9ec0033 + 860e8b5 commit 7c1331f

File tree

3 files changed

+67
-66
lines changed

3 files changed

+67
-66
lines changed

hyperfast/hyperfast.py

+58 -54
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ class HyperFastClassifier(BaseEstimator, ClassifierMixin):
4141
n_ensemble (int): Number of ensemble models to use.
4242
batch_size (int): Size of the batch for weight prediction and ensembling.
4343
nn_bias (bool): Whether to use nearest neighbor bias.
44+
nn_bias_mini_batches (bool): Whether to use mini-batches of size 128 for nearest neighbor bias.
4445
optimization (str or None): Strategy for optimization, can be None, 'optimize', or 'ensemble_optimize'.
4546
optimize_steps (int): Number of optimization steps.
4647
torch_pca (bool): Whether to use PyTorch-based PCA optimized for GPU (fast) or scikit-learn PCA (slower).
@@ -58,6 +59,7 @@ def __init__(
5859
n_ensemble: int = 16,
5960
batch_size: int = 2048,
6061
nn_bias: bool = False,
62+
nn_bias_mini_batches: bool = True,
6163
optimization: str | None = "ensemble_optimize",
6264
optimize_steps: int = 64,
6365
torch_pca: bool = True,
@@ -72,6 +74,7 @@ def __init__(
7274
self.n_ensemble = n_ensemble
7375
self.batch_size = batch_size
7476
self.nn_bias = nn_bias
77+
self.nn_bias_mini_batches = nn_bias_mini_batches
7578
self.optimization = optimization
7679
self.optimize_steps = optimize_steps
7780
self.torch_pca = torch_pca
@@ -112,7 +115,7 @@ def _initialize_model(self, cfg: SimpleNamespace) -> HyperFast:
112115
flush=True,
113116
)
114117
model.load_state_dict(
115-
torch.load(cfg.model_path, map_location=torch.device(cfg.device))
118+
torch.load(cfg.model_path, map_location=torch.device(cfg.device), weights_only=True)
116119
)
117120
print(
118121
f"Model loaded from {cfg.model_path} on {cfg.device} device.",
@@ -208,9 +211,7 @@ def _preprocess_fitting_data(
208211
y = column_or_1d(y, warn=True)
209212
self.n_features_in_ = x.shape[1]
210213
self.classes_, y = np.unique(y, return_inverse=True)
211-
return torch.tensor(x, dtype=torch.float).to(self.device), torch.tensor(
212-
y, dtype=torch.long
213-
).to(self.device)
214+
return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.long)
214215

215216
def _preprocess_test_data(
216217
self,
@@ -240,7 +241,7 @@ def _preprocess_test_data(
240241
x_test = check_array(x_test)
241242
# Standardize data
242243
x_test = self._scaler.transform(x_test)
243-
return torch.tensor(x_test, dtype=torch.float).to(self.device)
244+
return torch.tensor(x_test, dtype=torch.float)
244245

245246
def _initialize_fit_attributes(self) -> None:
246247
self._rfs = []
@@ -314,6 +315,7 @@ def fit(
314315

315316
for n in range(self.n_ensemble):
316317
X_pred, y_pred = self._sample_data(X, y)
318+
X_pred, y_pred = X_pred.to(self.device), y_pred.to(self.device)
317319
self.n_classes_ = len(torch.unique(y_pred).cpu().numpy())
318320

319321
rf, pca, main_network = self._model(X_pred, y_pred, self.n_classes_)
@@ -362,57 +364,59 @@ def fit(
362364
def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
363365
check_is_fitted(self)
364366
X = self._preprocess_test_data(X)
365-
with torch.no_grad():
366-
orig_X = X
367-
yhats = []
368-
for jj in range(len(self._main_networks)):
369-
main_network = self._main_networks[jj]
370-
rf = self._rfs[jj]
371-
pca = self._pcas[jj]
372-
X_pred = self._X_preds[jj]
373-
y_pred = self._y_preds[jj]
374-
if self.feature_bagging:
375-
X_ = X[:, self.selected_features[jj]]
376-
orig_X_ = orig_X[:, self.selected_features[jj]]
377-
else:
378-
X_ = X
379-
orig_X_ = orig_X
380-
381-
X_transformed = transform_data_for_main_network(
382-
X=X_, cfg=self._cfg, rf=rf, pca=pca
383-
)
384-
outputs, intermediate_activations = forward_main_network(
385-
X_transformed, main_network
386-
)
387-
388-
if self.nn_bias:
389-
X_pred_ = transform_data_for_main_network(
390-
X=X_pred, cfg=self._cfg, rf=rf, pca=pca
367+
X_dataset = torch.utils.data.TensorDataset(X)
368+
X_loader = torch.utils.data.DataLoader(X_dataset, batch_size=self.batch_size, shuffle=False)
369+
all_yhats = []
370+
for X_batch in X_loader:
371+
X_batch = X_batch[0].to(self.device)
372+
with torch.no_grad():
373+
orig_X = X_batch
374+
yhats = []
375+
for jj in range(len(self._main_networks)):
376+
main_network = self._main_networks[jj]
377+
rf = self._rfs[jj]
378+
pca = self._pcas[jj]
379+
X_pred = self._X_preds[jj]
380+
y_pred = self._y_preds[jj]
381+
if self.feature_bagging:
382+
X_ = X_batch[:, self.selected_features[jj]]
383+
orig_X_ = orig_X[:, self.selected_features[jj]]
384+
else:
385+
X_ = X_batch
386+
orig_X_ = orig_X
387+
388+
X_transformed = transform_data_for_main_network(
389+
X=X_, cfg=self._cfg, rf=rf, pca=pca
391390
)
392-
outputs_pred, intermediate_activations_pred = forward_main_network(
393-
X_pred_, main_network
391+
outputs, intermediate_activations = forward_main_network(
392+
X_transformed, main_network
394393
)
395-
for bb, bias in enumerate(self._model.nn_bias):
396-
if bb == 0:
397-
outputs = nn_bias_logits(
398-
outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_
399-
)
400-
elif bb == 1:
401-
outputs = nn_bias_logits(
402-
outputs,
403-
intermediate_activations,
404-
intermediate_activations_pred,
405-
y_pred,
406-
bias,
407-
self.n_classes_,
408-
)
409-
410-
predicted = F.softmax(outputs, dim=1)
411-
yhats.append(predicted)
412-
413-
yhats = torch.stack(yhats)
414-
yhats = torch.mean(yhats, axis=0)
415-
return yhats.cpu().numpy()
394+
395+
if self.nn_bias:
396+
X_pred_ = transform_data_for_main_network(
397+
X=X_pred, cfg=self._cfg, rf=rf, pca=pca
398+
)
399+
outputs_pred, intermediate_activations_pred = forward_main_network(
400+
X_pred_, main_network
401+
)
402+
for bb, bias in enumerate(self._model.nn_bias):
403+
if bb == 0:
404+
outputs = nn_bias_logits(
405+
outputs, orig_X_, X_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches
406+
)
407+
elif bb == 1:
408+
outputs = nn_bias_logits(
409+
outputs, intermediate_activations, intermediate_activations_pred, y_pred, bias, self.n_classes_, self.nn_bias_mini_batches,
410+
)
411+
412+
predicted = F.softmax(outputs, dim=1)
413+
yhats.append(predicted)
414+
415+
yhats = torch.stack(yhats)
416+
yhats = torch.mean(yhats, axis=0)
417+
yhats = yhats.cpu().numpy()
418+
all_yhats.append(yhats)
419+
return np.concatenate(all_yhats, axis=0)
416420

417421
def predict(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
418422
outputs = self.predict_proba(X)

hyperfast/model.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44
from sklearn.decomposition import PCA
5-
from .utils import *
5+
from .utils import TorchPCA, get_main_weights, forward_linear_layer
66

77

88
class HyperFast(nn.Module):

hyperfast/utils.py

+8 -11
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import torch.optim as optim
88
import torch.nn.functional as F
99
from torch.utils.data import DataLoader, TensorDataset
10-
from types import SimpleNamespace
1110

1211

1312
def seed_everything(seed: int):
@@ -22,11 +21,11 @@ def seed_everything(seed: int):
2221

2322

2423
def nn_bias_logits(
25-
test_logits, test_samples, train_samples, train_labels, bias_param, n_classes
24+
test_logits, test_samples, train_samples, train_labels, bias_param, n_classes, mini_batches
2625
):
2726
with torch.no_grad():
2827
nn = NN(train_samples, train_labels)
29-
preds = nn.predict(test_samples)
28+
preds = nn.predict(test_samples, mini_batches=mini_batches)
3029
preds_onehot = F.one_hot(preds, n_classes)
3130
test_logits[preds_onehot.bool()] += bias_param
3231
return test_logits
@@ -240,6 +239,7 @@ def fine_tune_main_network(
240239

241240
for step in range(optimize_steps):
242241
for inputs, targets in dataloader:
242+
inputs, targets = inputs.to(device), targets.to(device)
243243
optimizer.zero_grad()
244244
outputs = main_model(inputs, targets)
245245
loss = criterion(outputs, targets)
@@ -294,7 +294,7 @@ def transform_data_for_main_network(X, cfg, rf, pca):
294294

295295

296296
def distance_matrix(x, y=None, p=2):
297-
y = x if type(y) == type(None) else y
297+
y = x if y is None else y
298298

299299
n = x.size(0)
300300
m = y.size(0)
@@ -321,11 +321,8 @@ def train(self, X, Y):
321321
self.train_pts = X
322322
self.train_label = Y
323323

324-
def __call__(self, x, mini_batches=True):
325-
return self.predict(x)
326-
327324
def predict(self, x, mini_batches=True):
328-
if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
325+
if self.train_pts is None or self.train_label is None:
329326
name = self.__class__.__name__
330327
raise RuntimeError(
331328
f"{name} wasn't trained. Need to execute {name}.train() first"
@@ -341,7 +338,7 @@ def predict(self, x, mini_batches=True):
341338
num_batches = math.ceil(x.shape[0] / batch_size)
342339
labels = []
343340
for ii in range(num_batches):
344-
x_ = x[batch_size * ii : batch_size * (ii + 1), :]
341+
x_ = x[batch_size * ii:batch_size * (ii + 1), :]
345342
dist = distance_matrix(x_, self.train_pts, self.p)
346343
labels_ = torch.argmin(dist, dim=1)
347344
labels.append(labels_)
@@ -350,7 +347,7 @@ def predict(self, x, mini_batches=True):
350347
return self.train_label[labels]
351348

352349
def predict_from_training_with_LOO(self, mini_batches=True):
353-
if type(self.train_pts) == type(None) or type(self.train_label) == type(None):
350+
if self.train_pts is None or self.train_label is None:
354351
name = self.__class__.__name__
355352
raise RuntimeError(
356353
f"{name} wasn't trained. Need to execute {name}.train() first"
@@ -365,7 +362,7 @@ def predict_from_training_with_LOO(self, mini_batches=True):
365362
num_batches = math.ceil(self.train_pts.shape[0] / batch_size)
366363
labels = []
367364
for ii in range(num_batches):
368-
x_ = self.train_pts[batch_size * ii : batch_size * (ii + 1), :]
365+
x_ = self.train_pts[batch_size * ii:batch_size * (ii + 1), :]
369366
dist = distance_matrix(x_, self.train_pts, self.p)
370367
dist.fill_diagonal_(float("inf"))
371368
labels_ = torch.argmin(dist, dim=1)

0 commit comments

Comments
 (0)