@@ -41,6 +41,7 @@ class HyperFastClassifier(BaseEstimator, ClassifierMixin):
41
41
n_ensemble (int): Number of ensemble models to use.
42
42
batch_size (int): Size of the batch for weight prediction and ensembling.
43
43
nn_bias (bool): Whether to use nearest neighbor bias.
44
+ nn_bias_mini_batches (bool): Whether to use mini-batches of size 128 for nearest neighbor bias.
44
45
optimization (str or None): Strategy for optimization, can be None, 'optimize', or 'ensemble_optimize'.
45
46
optimize_steps (int): Number of optimization steps.
46
47
torch_pca (bool): Whether to use PyTorch-based PCA optimized for GPU (fast) or scikit-learn PCA (slower).
@@ -58,6 +59,7 @@ def __init__(
58
59
n_ensemble : int = 16 ,
59
60
batch_size : int = 2048 ,
60
61
nn_bias : bool = False ,
62
+ nn_bias_mini_batches : bool = True ,
61
63
optimization : str | None = "ensemble_optimize" ,
62
64
optimize_steps : int = 64 ,
63
65
torch_pca : bool = True ,
@@ -72,6 +74,7 @@ def __init__(
72
74
self .n_ensemble = n_ensemble
73
75
self .batch_size = batch_size
74
76
self .nn_bias = nn_bias
77
+ self .nn_bias_mini_batches = nn_bias_mini_batches
75
78
self .optimization = optimization
76
79
self .optimize_steps = optimize_steps
77
80
self .torch_pca = torch_pca
@@ -112,7 +115,7 @@ def _initialize_model(self, cfg: SimpleNamespace) -> HyperFast:
112
115
flush = True ,
113
116
)
114
117
model .load_state_dict (
115
- torch .load (cfg .model_path , map_location = torch .device (cfg .device ))
118
+ torch .load (cfg .model_path , map_location = torch .device (cfg .device ), weights_only = True )
116
119
)
117
120
print (
118
121
f"Model loaded from { cfg .model_path } on { cfg .device } device." ,
@@ -208,9 +211,7 @@ def _preprocess_fitting_data(
208
211
y = column_or_1d (y , warn = True )
209
212
self .n_features_in_ = x .shape [1 ]
210
213
self .classes_ , y = np .unique (y , return_inverse = True )
211
- return torch .tensor (x , dtype = torch .float ).to (self .device ), torch .tensor (
212
- y , dtype = torch .long
213
- ).to (self .device )
214
+ return torch .tensor (x , dtype = torch .float ), torch .tensor (y , dtype = torch .long )
214
215
215
216
def _preprocess_test_data (
216
217
self ,
@@ -240,7 +241,7 @@ def _preprocess_test_data(
240
241
x_test = check_array (x_test )
241
242
# Standardize data
242
243
x_test = self ._scaler .transform (x_test )
243
- return torch .tensor (x_test , dtype = torch .float ). to ( self . device )
244
+ return torch .tensor (x_test , dtype = torch .float )
244
245
245
246
def _initialize_fit_attributes (self ) -> None :
246
247
self ._rfs = []
@@ -314,6 +315,7 @@ def fit(
314
315
315
316
for n in range (self .n_ensemble ):
316
317
X_pred , y_pred = self ._sample_data (X , y )
318
+ X_pred , y_pred = X_pred .to (self .device ), y_pred .to (self .device )
317
319
self .n_classes_ = len (torch .unique (y_pred ).cpu ().numpy ())
318
320
319
321
rf , pca , main_network = self ._model (X_pred , y_pred , self .n_classes_ )
@@ -362,57 +364,59 @@ def fit(
362
364
def predict_proba(self, X: np.ndarray | pd.DataFrame) -> np.ndarray:
    """Return ensemble-averaged class probabilities for ``X``.

    The preprocessed input is iterated in chunks of ``self.batch_size`` rows
    (in order, no shuffling). Each chunk is moved to ``self.device``, passed
    through every fitted main network, and the per-network softmax outputs
    are averaged. The per-chunk averages are converted to NumPy and
    concatenated along the first axis.

    Args:
        X: Samples to score.

    Returns:
        Array holding one row of averaged softmax outputs per input row.
    """
    check_is_fitted(self)
    X = self._preprocess_test_data(X)
    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(X),
        batch_size=self.batch_size,
        shuffle=False,
    )

    batch_probas = []
    with torch.no_grad():
        # TensorDataset yields one-element batches; unpack the single tensor.
        for (X_batch,) in loader:
            X_batch = X_batch.to(self.device)
            member_probas = []
            for jj, main_network in enumerate(self._main_networks):
                rf = self._rfs[jj]
                pca = self._pcas[jj]
                X_pred = self._X_preds[jj]
                y_pred = self._y_preds[jj]

                # With feature bagging each ensemble member only sees its
                # own subset of columns.
                if self.feature_bagging:
                    X_member = X_batch[:, self.selected_features[jj]]
                else:
                    X_member = X_batch

                X_transformed = transform_data_for_main_network(
                    X=X_member, cfg=self._cfg, rf=rf, pca=pca
                )
                outputs, activations = forward_main_network(
                    X_transformed, main_network
                )

                if self.nn_bias:
                    # NOTE(review): this pass over the stored X_pred does not
                    # depend on the test batch, yet it is repeated for every
                    # batch; presumably accepted to avoid caching activations
                    # for all ensemble members -- confirm before hoisting.
                    X_pred_transformed = transform_data_for_main_network(
                        X=X_pred, cfg=self._cfg, rf=rf, pca=pca
                    )
                    _, activations_pred = forward_main_network(
                        X_pred_transformed, main_network
                    )
                    for bb, bias in enumerate(self._model.nn_bias):
                        if bb == 0:
                            # First bias: compares the untransformed batch
                            # against the stored X_pred.
                            outputs = nn_bias_logits(
                                outputs,
                                X_member,
                                X_pred,
                                y_pred,
                                bias,
                                self.n_classes_,
                                self.nn_bias_mini_batches,
                            )
                        elif bb == 1:
                            # Second bias: compares intermediate activations
                            # of the batch and of X_pred.
                            outputs = nn_bias_logits(
                                outputs,
                                activations,
                                activations_pred,
                                y_pred,
                                bias,
                                self.n_classes_,
                                self.nn_bias_mini_batches,
                            )

                member_probas.append(F.softmax(outputs, dim=1))

            # Average over ensemble members, then move the result off the
            # device before processing the next batch.
            averaged = torch.stack(member_probas).mean(dim=0)
            batch_probas.append(averaged.cpu().numpy())

    return np.concatenate(batch_probas, axis=0)
416
420
417
421
def predict (self , X : np .ndarray | pd .DataFrame ) -> np .ndarray :
418
422
outputs = self .predict_proba (X )
0 commit comments