diff --git a/pyproject.toml b/pyproject.toml
index 7d0f872..71254f5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "tabpfn-client"
-version = "0.0.11"
+version = "0.0.12"
 requires-python = ">=3.10"
 dependencies = [
     "httpx>=0.24.1",
diff --git a/quick_test.py b/quick_test.py
index 3f0caac..128b1af 100644
--- a/quick_test.py
+++ b/quick_test.py
@@ -32,7 +32,7 @@
 
     else:
         tabpfn_classifier.init()
-        tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted")
+        tabpfn = TabPFNClassifier(model="latest_tabpfn_hosted", n_estimators=3)
         # print("checking estimator", check_estimator(tabpfn))
         tabpfn.fit(X_train[:99], y_train[:99])
         print("predicting")
diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py
index 155f8ad..f78551c 100644
--- a/tabpfn_client/client.py
+++ b/tabpfn_client/client.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import traceback
 from pathlib import Path
 import httpx
 import logging
@@ -96,7 +99,7 @@ def upload_train_set(self, X, y) -> str:
         train_set_uid = response.json()["train_set_uid"]
         return train_set_uid
 
-    def predict(self, train_set_uid: str, x_test):
+    def predict(self, train_set_uid: str, x_test, tabpfn_config: dict | None=None):
         """
         Predict the class labels for the provided data (test set).
 
@@ -115,9 +118,14 @@ def predict(self, train_set_uid: str, x_test):
 
         x_test = common_utils.serialize_to_csv_formatted_bytes(x_test)
 
+        params = {"train_set_uid": train_set_uid}
+
+        if tabpfn_config is not None:
+            params["tabpfn_config"] = json.dumps(tabpfn_config, default=lambda x: x.to_dict())
+
         response = self.httpx_client.post(
             url=self.server_endpoints.predict.path,
-            params={"train_set_uid": train_set_uid},
+            params=params,
             files=common_utils.to_httpx_post_file_format([
                 ("x_file", "x_test_filename", x_test)
             ])
@@ -125,7 +133,7 @@ def predict(self, train_set_uid: str, x_test):
 
         self._validate_response(response, "predict")
 
-        return np.array(response.json()["y_pred"])
+        return np.array(response.json()["y_pred_proba"])
 
     @staticmethod
     def _validate_response(response, method_name, only_version_check=False):
@@ -198,7 +206,6 @@ def try_connection(self) -> bool:
 
         except httpx.ConnectError as e:
             logger.error(f"Failed to connect to the server with error: {e}")
-            import traceback
             traceback.print_exc()
             found_valid_connection = False
 
diff --git a/tabpfn_client/remote_tabpfn_classifier.py b/tabpfn_client/remote_tabpfn_classifier.py
deleted file mode 100644
index c6024fb..0000000
--- a/tabpfn_client/remote_tabpfn_classifier.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from sklearn.utils.validation import check_is_fitted
-from sklearn.base import BaseEstimator, ClassifierMixin
-
-from tabpfn_client.service_wrapper import InferenceClient
-
-
-class RemoteTabPFNClassifier(BaseEstimator, ClassifierMixin):
-
-    def __init__(
-            self,
-            model=None,
-            device="cpu",
-            base_path=".",
-            model_string="",
-            batch_size_inference=4,
-            fp16_inference=False,
-            inference_mode=True,
-            c=None,
-            N_ensemble_configurations=10,
-            preprocess_transforms=("none", "power_all"),
-            feature_shift_decoder=False,
-            normalize_with_test=False,
-            average_logits=False,
-            categorical_features=tuple(),
-            optimize_metric=None,
-            seed=None,
-            transformer_predict_kwargs_init=None,
-            multiclass_decoder="permutation",
-
-            # dependency injection (for testing)
-            inference_handler=InferenceClient()
-    ):
-        # TODO:
-        #  These configs are ignored at the moment -> all clients share the same (default) on-server TabPFNClassifier.
-        #  In the future version, these configs will be used to create per-user TabPFNClassifier,
-        #    allowing the user to setup the desired TabPFNClassifier on the server.
-        # config for tabpfn
-        self.model = model
-        self.device = device
-        self.base_path = base_path
-        self.model_string = model_string
-        self.batch_size_inference = batch_size_inference
-        self.fp16_inference = fp16_inference
-        self.inference_mode = inference_mode
-        self.c = c
-        self.N_ensemble_configurations = N_ensemble_configurations
-        self.preprocess_transforms = preprocess_transforms
-        self.feature_shift_decoder = feature_shift_decoder
-        self.normalize_with_test = normalize_with_test
-        self.average_logits = average_logits
-        self.categorical_features = categorical_features
-        self.optimize_metric = optimize_metric
-        self.seed = seed
-        self.transformer_predict_kwargs_init = transformer_predict_kwargs_init
-        self.multiclass_decoder = multiclass_decoder
-
-        self.inference_handler = inference_handler
-
-    def fit(self, X, y):
-        self.inference_handler.fit(X, y)
-        self.fitted_ = True
-        return self
-
-    def predict(self, X):
-        check_is_fitted(self)
-        return self.inference_handler.predict(X)
-
-    def predict_proba(self, X):
-        check_is_fitted(self)
-        return self.inference_handler.predict_proba(X)
diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml
index ee05c1e..4923bf2 100644
--- a/tabpfn_client/server_config.yaml
+++ b/tabpfn_client/server_config.yaml
@@ -1,12 +1,13 @@
 ## testing
-#protocol: "http"
-#host: "0.0.0.0"
-#port: "8000"
+protocol: "http"
+host: "localhost"
+port: "8080"
 
 # production
-protocol: "https"
-host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
-port: "443"
+#protocol: "https"
+#host: "tabpfn-server-wjedmz7r5a-ez.a.run.app"
+#host: tabpfn-server-preprod-wjedmz7r5a-ez.a.run.app   # preprod
+#port: "443"
 endpoints:
   root:
     path: "/"
diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py
index 3773d17..77bb322 100644
--- a/tabpfn_client/service_wrapper.py
+++ b/tabpfn_client/service_wrapper.py
@@ -177,16 +177,9 @@ def fit(self, X, y) -> None:
 
         self.last_train_set_uid = self.service_client.upload_train_set(X, y)
 
-    def predict(self, X):
+    def predict(self, X, config=None):
         return self.service_client.predict(
             train_set_uid=self.last_train_set_uid,
-            x_test=X
+            x_test=X,
+            tabpfn_config=config
         )
-
-    def predict_proba(self, X):
-        return self.service_client.predict_proba(
-            train_set_uid=self.last_train_set_uid,
-            x_test=X
-        )
-
-
diff --git a/tabpfn_client/tabpfn_classifier.py b/tabpfn_client/tabpfn_classifier.py
index 362e297..cf30ff2 100644
--- a/tabpfn_client/tabpfn_classifier.py
+++ b/tabpfn_client/tabpfn_classifier.py
@@ -1,12 +1,13 @@
+from typing import Optional, Tuple, Literal
 import logging
-from pathlib import Path
 import shutil
+from dataclasses import dataclass, asdict
 
+import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.utils.validation import check_is_fitted
 
 from tabpfn import TabPFNClassifier as LocalTabPFNClassifier
-from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier
 from tabpfn_client.service_wrapper import UserAuthenticationClient, InferenceClient
 from tabpfn_client.client import ServiceClient
 from tabpfn_client.constants import CACHE_DIR
@@ -78,108 +79,175 @@ def reset():
     shutil.rmtree(CACHE_DIR, ignore_errors=True)
 
 
+@dataclass(eq=True, frozen=True)
+class PreprocessorConfig:
+    """
+    Configuration for data preprocessors.
+
+    Attributes:
+        name (Literal): Name of the preprocessor.
+        categorical_name (Literal): Name of the categorical encoding method. Valid options are "none", "numeric",
+                                "onehot", "ordinal", "ordinal_shuffled". Default is "none".
+        append_original (bool): Whether to append the original features to the transformed features. Default is False.
+        subsample_features (float): Fraction of features to subsample. -1 means no subsampling. Default is -1.
+        global_transformer_name (str): Name of the global transformer to use. Default is None.
+    """
+
+    name: Literal[
+        "per_feature",  # a different transformation for each feature
+        "power",  # a standard sklearn power transformer
+        "safepower",  # a power transformer that prevents some numerical issues
+        "power_box",
+        "safepower_box",
+        "quantile_uni_coarse",  # different quantile transformations with few quantiles up to a lot
+        "quantile_norm_coarse",
+        "quantile_uni",
+        "quantile_norm",
+        "quantile_uni_fine",
+        "quantile_norm_fine",
+        "robust",  # a standard sklearn robust scaler
+        "kdi",
+        "none",  # no transformation (inside the transformer we anyways do a standardization)
+        "kdi_random_alpha",
+        "kdi_uni",
+        "kdi_random_alpha_uni",
+        "adaptive",
+        "norm_and_kdi",
+        # KDI with alpha collection
+        "kdi_alpha_0.3_uni",
+        "kdi_alpha_0.5_uni",
+        "kdi_alpha_0.8_uni",
+        "kdi_alpha_1.0_uni",
+        "kdi_alpha_1.2_uni",
+        "kdi_alpha_1.5_uni",
+        "kdi_alpha_2.0_uni",
+        "kdi_alpha_3.0_uni",
+        "kdi_alpha_5.0_uni",
+        "kdi_alpha_0.3",
+        "kdi_alpha_0.5",
+        "kdi_alpha_0.8",
+        "kdi_alpha_1.0",
+        "kdi_alpha_1.2",
+        "kdi_alpha_1.5",
+        "kdi_alpha_2.0",
+        "kdi_alpha_3.0",
+        "kdi_alpha_5.0",
+    ]
+    categorical_name: Literal[
+        "none",
+        "numeric",
+        "onehot",
+        "ordinal",
+        "ordinal_shuffled",
+        "ordinal_very_common_categories_shuffled",
+    ] = "none"
+    # categorical_name meanings:
+    # "none": categorical features are pretty much treated as ordinal, just not resorted
+    # "numeric": categorical features are treated as numeric, that means they are also power transformed for example
+    # "onehot": categorical features are onehot encoded
+    # "ordinal": categorical features are sorted and encoded as integers from 0 to n_categories - 1
+    # "ordinal_shuffled": categorical features are encoded as integers from 0 to n_categories - 1 in a random order
+    append_original: bool = False
+    subsample_features: Optional[float] = -1
+    global_transformer_name: Optional[str] = None
+    # if True, the transformed features (e.g. power transformed) are appended to the original features
+
+    def __str__(self):
+        return (
+            f"{self.name}_cat:{self.categorical_name}"
+            + ("_and_none" if self.append_original else "")
+            + (
+                "_subsample_feats_" + str(self.subsample_features)
+                if self.subsample_features > 0
+                else ""
+            )
+            + (
+                f"_global_transformer_{self.global_transformer_name}"
+                if self.global_transformer_name is not None
+                else ""
+            )
+        )
+
+    def can_be_cached(self):
+        return not self.subsample_features > 0
+
+    def to_dict(self):
+        return {k: str(v) if not isinstance(v, (str, int, float, list, dict)) else v for k, v in asdict(self).items()}
+
+
+ClassificationOptimizationMetricType = Literal[
+    "auroc", "roc", "auroc_ovo", "balanced_acc", "acc", "log_loss", None
+]
+
+
 class TabPFNClassifier(BaseEstimator, ClassifierMixin):
-    # def __init__(self):
-        # Configuration for TabPFNClassifier is still under development.
-    #     pass
-    
     def __init__(
-            self,
-            model="latest_tabpfn_hosted",
-            device="cpu",
-            base_path=Path(__file__).parent.parent.resolve(),
-            model_string="",
-            batch_size_inference=4,
-            fp16_inference=False,
-            inference_mode=True,
-            c=None,
-            N_ensemble_configurations=10,
-            preprocess_transforms=("none", "power_all"),
-            feature_shift_decoder=False,
-            normalize_with_test=False,
-            average_logits=False,
-            categorical_features=tuple(),
-            optimize_metric=None,
-            seed=None,
-            transformer_predict_kwargs_init=None,
-            multiclass_decoder="permutation",
+        self,
+        model="latest_tabpfn_hosted",
+        n_estimators: int = 4,
+        preprocess_transforms: Tuple[PreprocessorConfig, ...] = (
+            PreprocessorConfig(
+                "quantile_uni_coarse",
+                append_original=True,
+                categorical_name="ordinal_very_common_categories_shuffled",
+                global_transformer_name="svd",
+                subsample_features=-1,
+            ),
+            PreprocessorConfig(
+                "none", categorical_name="numeric", subsample_features=-1
+            ),
+        ),
+        feature_shift_decoder: str = "shuffle",
+        normalize_with_test: bool = False,
+        average_logits: bool = False,
+        optimize_metric: ClassificationOptimizationMetricType = "roc",
+        transformer_predict_kwargs: Optional[dict] = None,
+        multiclass_decoder="shuffle",
+        softmax_temperature: Optional[float] = -0.1,
+        use_poly_features=False,
+        max_poly_features=50,
+        remove_outliers=12.0,
+        add_fingerprint_features=True,
+        subsample_samples=-1,
     ):
-        # config for tabpfn
         self.model = model
-        self.device = device
-        self.base_path = base_path
-        self.model_string = model_string
-        self.batch_size_inference = batch_size_inference
-        self.fp16_inference = fp16_inference
-        self.inference_mode = inference_mode
-        self.c = c
-        self.N_ensemble_configurations = N_ensemble_configurations
+        self.n_estimators = n_estimators
         self.preprocess_transforms = preprocess_transforms
         self.feature_shift_decoder = feature_shift_decoder
         self.normalize_with_test = normalize_with_test
         self.average_logits = average_logits
-        self.categorical_features = categorical_features
         self.optimize_metric = optimize_metric
-        self.seed = seed
-        self.transformer_predict_kwargs_init = transformer_predict_kwargs_init
+        self.transformer_predict_kwargs = transformer_predict_kwargs
         self.multiclass_decoder = multiclass_decoder
+        self.softmax_temperature = softmax_temperature
+        self.use_poly_features = use_poly_features
+        self.max_poly_features = max_poly_features
+        self.remove_outliers = remove_outliers
+        self.add_fingerprint_features = add_fingerprint_features
+        self.subsample_samples = subsample_samples
 
     def fit(self, X, y):
         # assert init() is called
         if not g_tabpfn_config.is_initialized:
-            raise RuntimeError("TabPFNClassifier.init() must be called before using TabPFNClassifier")
-
-        # create classifier if not created yet
-        if not hasattr(self, "classifier"):
-            # arguments that are commented out are not used at the moment
-            # (not supported until new TabPFN interface is released)
-            classifier_cfg = {
-                # "model": self.model,
-                "device": self.device,
-                "base_path": self.base_path,
-                "model_string": self.model_string,
-                "batch_size_inference": self.batch_size_inference,
-                # "fp16_inference": self.fp16_inference,
-                # "inference_mode": self.inference_mode,
-                # "c": self.c,
-                "N_ensemble_configurations": self.N_ensemble_configurations,
-                # "preprocess_transforms": self.preprocess_transforms,
-                "feature_shift_decoder": self.feature_shift_decoder,
-                # "normalize_with_test": self.normalize_with_test,
-                # "average_logits": self.average_logits,
-                # "categorical_features": self.categorical_features,
-                # "optimize_metric": self.optimize_metric,
-                "seed": self.seed,
-                # "transformer_predict_kwargs_init": self.transformer_predict_kwargs_init,
-                "multiclass_decoder": self.multiclass_decoder
-            }
-            #classifier_cfg = {}
-
-            if g_tabpfn_config.use_server:
-                try:
-                    assert self.model == "latest_tabpfn_hosted", "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)"
-                except AssertionError as e:
-                    print(e)
-                self.classifier_ = RemoteTabPFNClassifier(
-                    **classifier_cfg,
-                    inference_handler=g_tabpfn_config.inference_handler
-                )
-            else:
-                try:
-                    assert self.model == "tabpfn_1_local", "Only 'tabpfn_1_local' model is supported at the moment for tabpfn_classifier.init(use_server=False)"
-                except AssertionError as e:
-                    print(e)
-                self.classifier_ = LocalTabPFNClassifier(**classifier_cfg)
-
-        self.classifier_.fit(X, y)
+            raise RuntimeError("tabpfn_client.init() must be called before using TabPFNClassifier")
+
+        if g_tabpfn_config.use_server:
+            try:
+                assert self.model == "latest_tabpfn_hosted", "Only 'latest_tabpfn_hosted' model is supported at the moment for tabpfn_classifier.init(use_server=True)"
+            except AssertionError as e:
+                print(e)
+            g_tabpfn_config.inference_handler.fit(X, y)
+            self.fitted_ = True
+        else:
+            raise NotImplementedError("Only server mode is supported at the moment for tabpfn_classifier.init(use_server=False)")
         return self
 
     def predict(self, X):
-        check_is_fitted(self)
-        return self.classifier_.predict(X)
+        probas = self.predict_proba(X)
+        return np.argmax(probas, axis=1)
 
     def predict_proba(self, X):
         check_is_fitted(self)
-        return self.classifier_.predict_proba(X)
+        return g_tabpfn_config.inference_handler.predict(X, config=self.get_params())
 
 
diff --git a/tabpfn_client/tabpfn_common_utils b/tabpfn_client/tabpfn_common_utils
index a2df122..cb44694 160000
--- a/tabpfn_client/tabpfn_common_utils
+++ b/tabpfn_client/tabpfn_common_utils
@@ -1 +1 @@
-Subproject commit a2df122f2894369a444eb2335776d7dd5eade5d9
+Subproject commit cb4469425eba995b4cefad1357c020878e1a6d02
diff --git a/tabpfn_client/tests/integration/test_tabpfn_classifier.py b/tabpfn_client/tests/integration/test_tabpfn_classifier.py
index 1fe4dbf..52eb80d 100644
--- a/tabpfn_client/tests/integration/test_tabpfn_classifier.py
+++ b/tabpfn_client/tests/integration/test_tabpfn_classifier.py
@@ -19,15 +19,6 @@ def tearDown(self):
         tabpfn_classifier.reset()
         ServiceClient().delete_instance()
 
-    def test_use_local_tabpfn_classifier(self):
-        tabpfn_classifier.init(use_server=False)
-        tabpfn = TabPFNClassifier(device="cpu", model="tabpfn_1_local")
-        tabpfn.fit(self.X_train, self.y_train)
-
-        self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier))
-        pred = tabpfn.predict(self.X_test)
-        self.assertEqual(pred.shape[0], self.X_test.shape[0])
-
     @with_mock_server()
     def test_use_remote_tabpfn_classifier(self, mock_server):
         # create dummy token file
@@ -52,7 +43,7 @@ def test_use_remote_tabpfn_classifier(self, mock_server):
         # mock prediction
         mock_server.router.post(mock_server.endpoints.predict.path).respond(
             200,
-            json={"y_pred": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict(self.X_test).tolist()}
+            json={"y_pred_proba": LocalTabPFNClassifier().fit(self.X_train, self.y_train).predict_proba(self.X_test).tolist()}
         )
         pred = tabpfn.predict(self.X_test)
         self.assertEqual(pred.shape[0], self.X_test.shape[0])
diff --git a/tabpfn_client/tests/unit/test_client.py b/tabpfn_client/tests/unit/test_client.py
index ab2a850..a913896 100644
--- a/tabpfn_client/tests/unit/test_client.py
+++ b/tabpfn_client/tests/unit/test_client.py
@@ -91,7 +91,7 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
 
         self.client.upload_train_set(self.X_train, self.y_train)
 
-        dummy_result = {"y_pred": [1, 2, 3]}
+        dummy_result = {"y_pred_proba": [1, 2, 3]}
         mock_server.router.post(mock_server.endpoints.predict.path).respond(
             200, json=dummy_result)
 
@@ -99,7 +99,7 @@ def test_predict_with_valid_train_set_and_test_set(self, mock_server):
             train_set_uid=dummy_json["train_set_uid"],
             x_test=self.X_test
         )
-        self.assertTrue(np.array_equal(pred, dummy_result["y_pred"]))
+        self.assertTrue(np.array_equal(pred, dummy_result["y_pred_proba"]))
 
     @with_mock_server()
     def test_add_user_information(self, mock_server):
diff --git a/tabpfn_client/tests/unit/test_prompt_agent.py b/tabpfn_client/tests/unit/test_prompt_agent.py
index 14b2b7b..7c17370 100644
--- a/tabpfn_client/tests/unit/test_prompt_agent.py
+++ b/tabpfn_client/tests/unit/test_prompt_agent.py
@@ -1,10 +1,7 @@
 import unittest
 from unittest.mock import patch, MagicMock
-import respx
-from httpx import Response
 from tabpfn_client.prompt_agent import PromptAgent
 from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
-from tabpfn_client.service_wrapper import UserAuthenticationClient, ServiceClient
 
 
 class TestPromptAgent(unittest.TestCase):
diff --git a/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py
deleted file mode 100644
index 5c76b55..0000000
--- a/tabpfn_client/tests/unit/test_remote_tabpfn_classifier.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import unittest
-from unittest.mock import MagicMock, patch
-import shutil
-
-from sklearn.datasets import load_breast_cancer
-from sklearn.model_selection import train_test_split
-from sklearn.exceptions import NotFittedError
-
-from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier
-from tabpfn_client.client import ServiceClient
-from tabpfn_client.service_wrapper import InferenceClient
-from tabpfn_client.constants import CACHE_DIR
-
-
-class TestRemoteTabPFNClassifier(unittest.TestCase):
-
-    def setUp(self):
-        self.dummy_token = "dummy_token"
-        X, y = load_breast_cancer(return_X_y=True)
-        self.X_train, self.X_test, self.y_train, self.y_test = \
-            train_test_split(X, y, test_size=0.33)
-
-        # mock service client
-        self.mock_client = MagicMock(spec=ServiceClient)
-        self.mock_client.is_initialized.return_value = True
-        inference_handler = InferenceClient(service_client=self.mock_client)
-
-        self.remote_tabpfn = RemoteTabPFNClassifier(inference_handler=inference_handler)
-
-    def tearDown(self):
-        patch.stopall()
-        shutil.rmtree(CACHE_DIR, ignore_errors=True)
-
-    def test_fit_and_predict_with_valid_datasets(self):
-        # mock responses
-        self.mock_client.upload_train_set.return_value = "dummy_train_set_uid"
-
-        mock_predict_response = [1, 1, 0]
-        self.mock_client.predict.return_value = mock_predict_response
-
-        self.remote_tabpfn.fit(self.X_train, self.y_train)
-        y_pred = self.remote_tabpfn.predict(self.X_test)
-
-        self.assertEqual(mock_predict_response, y_pred)
-        self.mock_client.upload_train_set.called_once_with(self.X_train, self.y_train)
-        self.mock_client.predict.called_once_with(self.X_test)
-
-    def test_call_predict_without_calling_fit_before(self):
-        self.assertRaises(
-            NotFittedError,
-            self.remote_tabpfn.predict,
-            self.X_test
-        )
-
-    def test_call_predict_proba_without_calling_fit_before(self):
-        self.assertRaises(
-            NotFittedError,
-            self.remote_tabpfn.predict_proba,
-            self.X_test
-        )
-
-    def test_predict_with_conflicting_test_set(self):
-        # TODO: implement this
-        pass
diff --git a/tabpfn_client/tests/unit/test_tabpfn_classifier.py b/tabpfn_client/tests/unit/test_tabpfn_classifier.py
index ccf8a88..5653a25 100644
--- a/tabpfn_client/tests/unit/test_tabpfn_classifier.py
+++ b/tabpfn_client/tests/unit/test_tabpfn_classifier.py
@@ -2,13 +2,13 @@
 from unittest.mock import patch
 import shutil
 
+import numpy as np
 from sklearn.datasets import load_breast_cancer
 from sklearn.model_selection import train_test_split
-from tabpfn import TabPFNClassifier as LocalTabPFNClassifier
+from sklearn.exceptions import NotFittedError
 
 from tabpfn_client import tabpfn_classifier
 from tabpfn_client.tabpfn_classifier import TabPFNClassifier
-from tabpfn_client.remote_tabpfn_classifier import RemoteTabPFNClassifier
 from tabpfn_client.service_wrapper import UserAuthenticationClient
 from tabpfn_client.client import ServiceClient
 from tabpfn_client.tests.mock_tabpfn_server import with_mock_server
@@ -34,11 +34,6 @@ def tearDown(self):
         # remove cache dir
         shutil.rmtree(CACHE_DIR, ignore_errors=True)
 
-    def test_init_local_classifier(self):
-        tabpfn_classifier.init(use_server=False)
-        tabpfn = TabPFNClassifier(model="tabpfn_1_local").fit(self.X_train, self.y_train)
-        self.assertTrue(isinstance(tabpfn.classifier_, LocalTabPFNClassifier))
-
     @with_mock_server()
     @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_and_set_token")
     @patch("tabpfn_client.prompt_agent.PromptAgent.prompt_terms_and_cond",
@@ -54,13 +49,29 @@ def test_init_remote_classifier(self, mock_server, mock_prompt_for_terms_and_con
         )
         mock_server.router.get(mock_server.endpoints.retrieve_greeting_messages.path).respond(
             200, json={"messages": []})
+        mock_predict_response = [[1, 0.],[.9, .1],[0.01, 0.99]]
+        predict_route = mock_server.router.post(mock_server.endpoints.predict.path)
+        predict_route.respond(
+            200, json={"y_pred_proba": mock_predict_response}
+        )
 
         tabpfn_classifier.init(use_server=True)
-        tabpfn = TabPFNClassifier().fit(self.X_train, self.y_train)
-        self.assertTrue(isinstance(tabpfn.classifier_, RemoteTabPFNClassifier))
+
+        tabpfn = TabPFNClassifier(n_estimators=10)
+        self.assertRaises(
+            NotFittedError,
+            tabpfn.predict,
+            self.X_test
+        )
+        tabpfn.fit(self.X_train, self.y_train)
         self.assertTrue(mock_prompt_and_set_token.called)
         self.assertTrue(mock_prompt_for_terms_and_cond.called)
 
+        y_pred = tabpfn.predict(self.X_test)
+        self.assertTrue(np.all(np.argmax(mock_predict_response, axis=1) == y_pred))
+
+        self.assertIn('n_estimators%22%3A%2010', str(predict_route.calls.last.request.url), "check that n_estimators is passed to the server")
+
     @with_mock_server()
     def test_reuse_saved_access_token(self, mock_server):
         # mock connection and authentication
@@ -99,11 +110,6 @@ def test_invalid_saved_access_token(self, mock_server, mock_prompt_for_terms_and
         self.assertRaises(RuntimeError, tabpfn_classifier.init, use_server=True)
         self.assertTrue(mock_prompt_and_set_token.called)
 
-    def test_reset_on_local_classifier(self):
-        tabpfn_classifier.init(use_server=False)
-        tabpfn_classifier.reset()
-        self.assertFalse(tabpfn_classifier.g_tabpfn_config.is_initialized)
-
     @with_mock_server()
     def test_reset_on_remote_classifier(self, mock_server):
         # create dummy token file