Skip to content

Commit 980aeff

Browse files
committed
ruff
1 parent d03708d commit 980aeff

File tree

7 files changed

+36
-16
lines changed

7 files changed

+36
-16
lines changed

quick_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import logging
2-
import numpy as np
32

43
from sklearn.datasets import load_breast_cancer, load_diabetes
54
from sklearn.model_selection import train_test_split
65

7-
from tabpfn_client import estimator, UserDataClient, init
6+
from tabpfn_client import UserDataClient, init
87
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor
98

109
logging.basicConfig(level=logging.DEBUG)
@@ -44,4 +43,4 @@
4443
print("predicting reg")
4544
print(tabpfn.predict(X_test))
4645

47-
print(UserDataClient().get_data_summary())
46+
print(UserDataClient().get_data_summary())

tabpfn_client/client.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,13 @@ def upload_train_set(self, X, y) -> str:
100100
train_set_uid = response.json()["train_set_uid"]
101101
return train_set_uid
102102

103-
def predict(self, train_set_uid: str, x_test, task: Literal["classification", "regression"], tabpfn_config: dict | None = None) -> dict[str, np.ndarray]:
103+
def predict(
104+
self,
105+
train_set_uid: str,
106+
x_test,
107+
task: Literal["classification", "regression"],
108+
tabpfn_config: dict | None = None,
109+
) -> dict[str, np.ndarray]:
104110
"""
105111
Predict the class labels for the provided data (test set).
106112
@@ -140,7 +146,7 @@ def predict(self, train_set_uid: str, x_test, task: Literal["classification", "r
140146

141147
if not isinstance(result, dict):
142148
result = {"probas": result}
143-
149+
144150
for k in result:
145151
result[k] = np.array(result[k])
146152

tabpfn_client/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ def reset():
7373
g_tabpfn_config.user_auth_handler.reset_cache()
7474

7575
# remove cache dir
76-
shutil.rmtree(CACHE_DIR, ignore_errors=True)
76+
shutil.rmtree(CACHE_DIR, ignore_errors=True)

tabpfn_client/estimator.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13+
1314
@dataclass(eq=True, frozen=True)
1415
class PreprocessorConfig:
1516
"""
@@ -108,6 +109,7 @@ def to_dict(self):
108109
for k, v in asdict(self).items()
109110
}
110111

112+
111113
class TabPFNClassifier(BaseEstimator, ClassifierMixin):
112114
def __init__(
113115
self,
@@ -128,7 +130,9 @@ def __init__(
128130
feature_shift_decoder: str = "shuffle",
129131
normalize_with_test: bool = False,
130132
average_logits: bool = False,
131-
optimize_metric: Literal["auroc", "roc", "auroc_ovo", "balanced_acc", "acc", "log_loss", None] = "roc",
133+
optimize_metric: Literal[
134+
"auroc", "roc", "auroc_ovo", "balanced_acc", "acc", "log_loss", None
135+
] = "roc",
132136
transformer_predict_kwargs: Optional[dict] = None,
133137
multiclass_decoder="shuffle",
134138
softmax_temperature: Optional[float] = -0.1,
@@ -205,7 +209,9 @@ def predict(self, X):
205209

206210
def predict_proba(self, X):
207211
check_is_fitted(self)
208-
return config.g_tabpfn_config.inference_handler.predict(X, task="classification", config=self.get_params())["probas"]
212+
return config.g_tabpfn_config.inference_handler.predict(
213+
X, task="classification", config=self.get_params()
214+
)["probas"]
209215

210216

211217
class TabPFNRegressor(BaseEstimator, RegressorMixin):
@@ -225,7 +231,9 @@ def __init__(
225231
feature_shift_decoder: str = "shuffle",
226232
normalize_with_test: bool = False,
227233
average_logits: bool = False,
228-
optimize_metric: Literal["mse", "rmse", "mae", "r2", "mean", "median", "mode", "exact_match", None] = "rmse",
234+
optimize_metric: Literal[
235+
"mse", "rmse", "mae", "r2", "mean", "median", "mode", "exact_match", None
236+
] = "rmse",
229237
transformer_predict_kwargs: Optional[Dict] = None,
230238
softmax_temperature: Optional[float] = -0.1,
231239
use_poly_features=False,
@@ -324,7 +332,7 @@ def fit(self, X, y):
324332
"Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)"
325333
)
326334
return self
327-
335+
328336
def predict(self, X):
329337
full_prediction_dict = self.predict_full(X)
330338
if self.optimize_metric in ("mse", "rmse", "r2", "mean", None):
@@ -335,7 +343,9 @@ def predict(self, X):
335343
return full_prediction_dict["mode"]
336344
else:
337345
raise ValueError(f"Optimize metric {self.optimize_metric} not supported")
338-
346+
339347
def predict_full(self, X):
340348
check_is_fitted(self)
341-
return config.g_tabpfn_config.inference_handler.predict(X, task="regression", config=self.get_params())
349+
return config.g_tabpfn_config.inference_handler.predict(
350+
X, task="regression", config=self.get_params()
351+
)

tabpfn_client/service_wrapper.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ def fit(self, X, y) -> None:
180180

181181
self.last_train_set_uid = self.service_client.upload_train_set(X, y)
182182

183-
def predict(self, X, task:Literal["classification","regression"], config=None):
183+
def predict(self, X, task: Literal["classification", "regression"], config=None):
184184
return self.service_client.predict(
185-
train_set_uid=self.last_train_set_uid, x_test=X, tabpfn_config=config, task=task,
185+
train_set_uid=self.last_train_set_uid,
186+
x_test=X,
187+
tabpfn_config=config,
188+
task=task,
186189
)

tabpfn_client/tests/integration/test_tabpfn_classifier.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from tabpfn_client import init, reset
8-
from tabpfn_client import estimator, TabPFNClassifier
8+
from tabpfn_client import TabPFNClassifier
99
from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
1010
from tabpfn_client.service_wrapper import UserAuthenticationClient
1111
from tabpfn_client.client import ServiceClient

tabpfn_client/tests/unit/test_client.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
179179
)
180180

181181
pred = self.client.predict(
182-
train_set_uid=dummy_json["train_set_uid"], x_test=self.X_test, task="classification"
182+
train_set_uid=dummy_json["train_set_uid"],
183+
x_test=self.X_test,
184+
task="classification",
183185
)
184186
self.assertTrue(np.array_equal(pred["probas"], dummy_result["classification"]))
185187

0 commit comments

Comments
 (0)