10
10
11
11
logger = logging .getLogger (__name__ )
12
12
13
+
13
14
@dataclass (eq = True , frozen = True )
14
15
class PreprocessorConfig :
15
16
"""
@@ -108,6 +109,7 @@ def to_dict(self):
108
109
for k , v in asdict (self ).items ()
109
110
}
110
111
112
+
111
113
class TabPFNClassifier (BaseEstimator , ClassifierMixin ):
112
114
def __init__ (
113
115
self ,
@@ -128,7 +130,9 @@ def __init__(
128
130
feature_shift_decoder : str = "shuffle" ,
129
131
normalize_with_test : bool = False ,
130
132
average_logits : bool = False ,
131
- optimize_metric : Literal ["auroc" , "roc" , "auroc_ovo" , "balanced_acc" , "acc" , "log_loss" , None ] = "roc" ,
133
+ optimize_metric : Literal [
134
+ "auroc" , "roc" , "auroc_ovo" , "balanced_acc" , "acc" , "log_loss" , None
135
+ ] = "roc" ,
132
136
transformer_predict_kwargs : Optional [dict ] = None ,
133
137
multiclass_decoder = "shuffle" ,
134
138
softmax_temperature : Optional [float ] = - 0.1 ,
@@ -205,7 +209,9 @@ def predict(self, X):
205
209
206
210
def predict_proba (self , X ):
207
211
check_is_fitted (self )
208
- return config .g_tabpfn_config .inference_handler .predict (X , task = "classification" , config = self .get_params ())["probas" ]
212
+ return config .g_tabpfn_config .inference_handler .predict (
213
+ X , task = "classification" , config = self .get_params ()
214
+ )["probas" ]
209
215
210
216
211
217
class TabPFNRegressor (BaseEstimator , RegressorMixin ):
@@ -225,7 +231,9 @@ def __init__(
225
231
feature_shift_decoder : str = "shuffle" ,
226
232
normalize_with_test : bool = False ,
227
233
average_logits : bool = False ,
228
- optimize_metric : Literal ["mse" , "rmse" , "mae" , "r2" , "mean" , "median" , "mode" , "exact_match" , None ] = "rmse" ,
234
+ optimize_metric : Literal [
235
+ "mse" , "rmse" , "mae" , "r2" , "mean" , "median" , "mode" , "exact_match" , None
236
+ ] = "rmse" ,
229
237
transformer_predict_kwargs : Optional [Dict ] = None ,
230
238
softmax_temperature : Optional [float ] = - 0.1 ,
231
239
use_poly_features = False ,
@@ -324,7 +332,7 @@ def fit(self, X, y):
324
332
"Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)"
325
333
)
326
334
return self
327
-
335
+
328
336
def predict (self , X ):
329
337
full_prediction_dict = self .predict_full (X )
330
338
if self .optimize_metric in ("mse" , "rmse" , "r2" , "mean" , None ):
@@ -335,7 +343,9 @@ def predict(self, X):
335
343
return full_prediction_dict ["mode" ]
336
344
else :
337
345
raise ValueError (f"Optimize metric { self .optimize_metric } not supported" )
338
-
346
+
339
347
def predict_full (self , X ):
340
348
check_is_fitted (self )
341
- return config .g_tabpfn_config .inference_handler .predict (X , task = "regression" , config = self .get_params ())
349
+ return config .g_tabpfn_config .inference_handler .predict (
350
+ X , task = "regression" , config = self .get_params ()
351
+ )
0 commit comments