Skip to content

Commit 9200e09

Browse files
committed Aug 10, 2023
Implement minimal local and remote mode
1 parent 0dbc3b5 commit 9200e09

7 files changed

+126 -40 lines changed
 

‎.gitmodules

+3
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
[submodule "tabpfn_client/tabpfn_common_utils"]
2+
path = tabpfn_client/tabpfn_common_utils
3+
url = https://github.com/liam-sbhoo/tabpfn_common_utils.git

‎tabpfn_client/server_spec.yaml

+22 -21
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,28 @@
1-
host: "192.52.42.37"
1+
#host: "192.52.42.37"
2+
host: "0.0.0.0"
23
port: "80"
34
endpoints:
4-
- root:
5-
path: "/"
6-
methods: ["GET"]
7-
description: "Root endpoint"
5+
root:
6+
path: "/"
7+
methods: ["GET"]
8+
description: "Root endpoint"
89

9-
- login:
10-
path: "/auth/login/"
11-
methods: ["POST"]
12-
description: "Login endpoint"
10+
login:
11+
path: "/auth/login/"
12+
methods: ["POST"]
13+
description: "Login endpoint"
1314

14-
- protected_root:
15-
path: "/protected/"
16-
methods: ["GET"]
17-
description: "Protected root endpoint"
15+
protected_root:
16+
path: "/protected/"
17+
methods: ["GET"]
18+
description: "Protected root endpoint"
1819

19-
- upload_train_set:
20-
path: "/upload/train_set/"
21-
methods: ["POST"]
22-
description: "Upload train set endpoint"
20+
upload_train_set:
21+
path: "/upload/train_set/"
22+
methods: ["POST"]
23+
description: "Upload train set endpoint"
2324

24-
- predict:
25-
path: "/predict/"
26-
methods: ["POST"]
27-
description: "Predict endpoint"
25+
predict:
26+
path: "/predict/"
27+
methods: ["POST"]
28+
description: "Predict endpoint"

‎tabpfn_client/tabpfn_classifier.py

+2 -4
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from omegaconf import OmegaConf
44

55
from tabpfn import TabPFNClassifier as TabPFNClassifierLocal
6-
from tabpfn_client.tabpfn_classifier_interface import TabPFNClassifierInterface
6+
from tabpfn_client.tabpfn_classifier_interface import AbstractTabPFNClassifier
77
from tabpfn_client.tabpfn_service_client import TabPFNServiceClient
88

99
SERVER_SPEC_FILE = pathlib.Path(__file__).parent.resolve() / "server_spec.yaml"
@@ -46,9 +46,7 @@ def remove_saved_access_token():
4646
pass
4747

4848

49-
class TabPFNClassifier(TabPFNClassifierInterface):
50-
# TODO: ask Sam/Noah if we could create an interface of TabPFNClassifier instead
51-
49+
class TabPFNClassifier(AbstractTabPFNClassifier):
5250
def __init__(self, device='cpu', base_path=pathlib.Path(__file__).parent.parent.resolve(), model_string='',
5351
N_ensemble_configurations=3, no_preprocess_mode=False, multiclass_decoder='permutation',
5452
feature_shift_decoder=True, only_inference=True, seed=0, no_grad=True, batch_size_inference=32):

‎tabpfn_client/tabpfn_classifier_interface.py

+2 -6
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,14 @@
11
from abc import ABC, abstractmethod
22

33

4-
class TabPFNClassifierInterface(ABC):
4+
class AbstractTabPFNClassifier(ABC):
55

66
@abstractmethod
77
def remove_models_from_memory(self):
88
pass
99

1010
@abstractmethod
11-
def load_result_minimal(self, path, i, e):
12-
pass
13-
14-
@abstractmethod
15-
def fit(self, X, y):
11+
def fit(self, X, y, overwrite_warning=False):
1612
pass
1713

1814
@abstractmethod
tabpfn_client/tabpfn_service_client.py

+56 -9
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,15 @@
11
import os
22
import httpx
3+
from typing import Any
4+
import logging
35

4-
from tabpfn_client.tabpfn_classifier_interface import TabPFNClassifierInterface
6+
from tabpfn_client.tabpfn_classifier_interface import AbstractTabPFNClassifier
7+
from tabpfn_client.tabpfn_common_utils import utils as common_utils
58

69
SERVER_ENDPOINTS_YAML = os.path.join(os.path.dirname(__file__), "server_endpoints.yaml")
710

811

9-
class TabPFNServiceClient(TabPFNClassifierInterface):
12+
class TabPFNServiceClient(AbstractTabPFNClassifier):
1013
def __init__(self, server_spec: dict, access_token: str):
1114
self.host = server_spec["host"]
1215
self.port = server_spec["port"]
@@ -16,22 +19,66 @@ def __init__(self, server_spec: dict, access_token: str):
1619
self.access_token = access_token
1720
self.server_endpoints = server_spec["endpoints"]
1821

22+
self.last_per_user_train_set_id = None
23+
1924
def remove_models_from_memory(self):
2025
raise NotImplementedError
2126

22-
def load_result_minimal(self, path, i, e):
23-
raise NotImplementedError
27+
def fit(self, X: Any, y: Any):
28+
X = common_utils.serialize_to_csv_formatted_bytes(X)
29+
y = common_utils.serialize_to_csv_formatted_bytes(y)
2430

25-
def fit(self, X, y):
26-
pass
31+
response = self.client.post(
32+
url=self.server_endpoints["upload_train_set"]["path"],
33+
headers={"Authorization": f"Bearer {self.access_token}"},
34+
files=common_utils.to_httpx_post_file_format([
35+
("x_file", X),
36+
("y_file", y)
37+
])
38+
)
2739

28-
def predict(self, X):
29-
raise NotImplementedError
40+
if response.status_code != 200:
41+
logging.error(f"Fail to call upload_train_set(), response status: {response.status_code}")
42+
# TODO: error probably doesn't have json() method, check in unit test
43+
logging.error(f"Fail to call fit(), server response: {response.json()}")
44+
raise RuntimeError(f"Fail to call fit(), server response: {response.json()}")
45+
46+
self.last_per_user_train_set_id = response.json()["per_user_train_set_id"]
47+
48+
return self
49+
50+
def predict(self, X, return_winning_class=False, normalize_with_test=False):
51+
52+
# TODO: handle return_winning_class and normalize_with_test
53+
54+
# check if user has already called fit() before
55+
if self.last_per_user_train_set_id is None:
56+
raise RuntimeError("You must call fit() before calling predict()")
57+
58+
X = common_utils.serialize_to_csv_formatted_bytes(X)
59+
60+
response = self.client.post(
61+
url=self.server_endpoints["predict"]["path"],
62+
headers={"Authorization": f"Bearer {self.access_token}"},
63+
params={"per_user_train_set_id": self.last_per_user_train_set_id},
64+
files=common_utils.to_httpx_post_file_format([
65+
("x_file", X)
66+
])
67+
)
68+
69+
if response.status_code != 200:
70+
logging.error(f"Fail to call predict(), response status: {response.status_code}")
71+
raise RuntimeError(f"Fail to call predict(), server response: {response.json()}")
72+
73+
return response.json()
3074

3175
def predict_proba(self, X, return_winning_probability=False, normalize_with_test=False):
3276
pass
3377

3478
def try_root(self):
35-
response = self.client.get("/")
79+
response = self.client.get(
80+
self.server_endpoints["protected_root"]["path"],
81+
headers={"Authorization": f"Bearer {self.access_token}"},
82+
)
3683
print("response:", response.json())
3784
return response
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
1+
import unittest
2+
3+
4+
class TestTabpfnClassifier(unittest.TestCase):
5+
def test_use_local_tabpfn_classifier(self):
6+
pass
7+
8+
def test_use_remote_tabpfn_classifier(self):
9+
pass
10+
11+
12+
class TestInitTabPFNBuilder(unittest.TestCase):
13+
def test_save_access_token_upon_successful_login(self):
14+
pass
15+
16+
def test_remove_saved_access_token(self):
17+
pass
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
import unittest
2+
3+
4+
class TestTabpfnServiceClient(unittest.TestCase):
5+
def test_invalid_auth_token(self):
6+
pass
7+
8+
def test_predict_with_valid_train_set_and_test_set(self):
9+
pass
10+
11+
def test_predict_with_conflicting_test_set(self):
12+
pass
13+
14+
def test_call_predict_without_calling_fit_before(self):
15+
pass
16+
17+
def test_call_predict_proba_without_calling_fit_before(self):
18+
pass
19+
20+
def test_call_predict_after_calling_fit_twice(self):
21+
pass
22+
23+
def test_call_predict_proba_after_calling_fit_twice(self):
24+
pass

0 commit comments

Comments
 (0)
Please sign in to comment.