Skip to content

Commit a500ec6

Browse files
authored
Move expense functionality to common (#94)
* Move expense functionality to common
* Ruff clean up
* Fix latency offset
* Update ref
1 parent 0403626 commit a500ec6

File tree

3 files changed

+18
-53
lines changed

3 files changed

+18
-53
lines changed

tabpfn_client/mock_prediction.py

+11-48
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,13 @@
77
import warnings
88
import logging
99
import functools
10+
1011
from tabpfn_client.client import ServiceClient
12+
from tabpfn_client.tabpfn_common_utils.expense_estimation import estimate_duration
13+
14+
15+
COST_ESTIMATION_LATENCY_OFFSET = 1.0
16+
1117

1218
# For seamlessly switching between a mock mode for simulating prediction
1319
# costs and real prediction, use thread-local variables to keep track of the
@@ -49,53 +55,6 @@ def increment_mock_time(seconds: float):
4955
set_mock_time(get_mock_time() + seconds)
5056

5157

52-
def estimate_duration(
53-
num_rows: int,
54-
num_features: int,
55-
task: Literal["classification", "regression"],
56-
tabpfn_config: dict = {},
57-
) -> float:
58-
"""
59-
Estimates the duration of a prediction task.
60-
"""
61-
# Logic comes from _estimate_model_usage in base.py of the TabPFN codebase.
62-
CONSTANT_COMPUTE_OVERHEAD = 8000
63-
NUM_SAMPLES_FACTOR = 4
64-
NUM_SAMPLES_PLUS_FEATURES = 6.5
65-
CELLS_FACTOR = 0.25
66-
CELLS_SQUARED_FACTOR = 1.3e-7
67-
EMBEDDING_SIZE = 192
68-
NUM_HEADS = 6
69-
NUM_LAYERS = 12
70-
FEATURES_PER_GROUP = 2
71-
GPU_FACTOR = 1e-11
72-
LATENCY_OFFSET = 1.0
73-
74-
n_estimators = tabpfn_config.get(
75-
"n_estimators", 4 if task == "classification" else 8
76-
)
77-
78-
num_samples = num_rows
79-
num_feature_groups = int(np.ceil(num_features / FEATURES_PER_GROUP))
80-
81-
num_cells = (num_feature_groups + 1) * num_samples
82-
compute_cost = (EMBEDDING_SIZE**2) * NUM_HEADS * NUM_LAYERS
83-
84-
base_duration = (
85-
n_estimators
86-
* compute_cost
87-
* (
88-
CONSTANT_COMPUTE_OVERHEAD
89-
+ num_samples * NUM_SAMPLES_FACTOR
90-
+ (num_samples + num_feature_groups) * NUM_SAMPLES_PLUS_FEATURES
91-
+ num_cells * CELLS_FACTOR
92-
+ num_cells**2 * CELLS_SQUARED_FACTOR
93-
)
94-
)
95-
96-
return round(base_duration * GPU_FACTOR + LATENCY_OFFSET, 3)
97-
98-
9958
def mock_predict(
10059
X_test,
10160
task: Literal["classification", "regression"],
@@ -116,7 +75,11 @@ def mock_predict(
11675
)
11776

11877
duration = estimate_duration(
119-
X_train.shape[0] + X_test.shape[0], X_test.shape[1], task, config
78+
num_rows=X_train.shape[0] + X_test.shape[0],
79+
num_features=X_test.shape[1],
80+
task=task,
81+
tabpfn_config=config,
82+
latency_offset=COST_ESTIMATION_LATENCY_OFFSET, # To slightly overestimate (safer)
12083
)
12184
increment_mock_time(duration)
12285

tabpfn_client/tests/unit/test_mock_prediction.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
get_mock_cost,
1010
estimate_duration,
1111
is_mock_mode,
12+
COST_ESTIMATION_LATENCY_OFFSET,
1213
)
1314
from tabpfn_client.estimator import TabPFNClassifier, TabPFNRegressor
1415

@@ -50,10 +51,11 @@ def test_mock_mode_behavior(self):
5051

5152
# Verify time increased by the estimated duration
5253
expected_duration = estimate_duration(
53-
self.X_train.shape[0] + self.X_test.shape[0],
54-
self.X_test.shape[1],
55-
"classification",
56-
self.config,
54+
num_rows=self.X_train.shape[0] + self.X_test.shape[0],
55+
num_features=self.X_test.shape[1],
56+
task="classification",
57+
tabpfn_config=self.config,
58+
latency_offset=COST_ESTIMATION_LATENCY_OFFSET,
5759
)
5860
self.assertAlmostEqual(
5961
get_mock_time() - initial_time, expected_duration

0 commit comments

Comments (0)