Commit 169cb69

sdaulton authored and facebook-github-bot committed
Fix fantasization with FixedNoiseGP and outcome transforms and use FantasizeMixin (#2011)
Summary: This fixes fantasization with FixedNoiseGP when using outcome transforms: previously, already-transformed noise was transformed again during fantasization. It also improves fantasization for batched and batched multi-output models by using the average noise for each batch and output, removes repeated code by reusing the logic in `FantasizeMixin.fantasize` for handling `X` with size 0 on the -2 dimension, and deprecates passing `observation_noise` as a boolean argument to `fantasize`.

Reviewed By: Balandat

Differential Revision: D49200325
1 parent fa51038 commit 169cb69

15 files changed: 259 additions (+), 100 deletions (−)
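
The change at each acquisition call site below is mechanical: the deprecated `observation_noise=True` boolean is dropped. A minimal before/after sketch, assuming a toy `FixedNoiseGP`; the data, shapes, and variable names here are illustrative, not part of the commit:

import torch
from botorch.models import FixedNoiseGP
from botorch.sampling import SobolQMCNormalSampler

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)
train_Yvar = torch.full_like(train_Y, 0.01)  # fixed observation noise
model = FixedNoiseGP(train_X, train_Y, train_Yvar)
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([4]))
X_fant = torch.rand(3, 2, dtype=torch.double)

# Before this commit:
#   fm = model.fantasize(X=X_fant, sampler=sampler, observation_noise=True)
# After this commit, omit the argument to use the average training noise,
# or pass a tensor of noise levels (in the outcome-transformed space):
fm = model.fantasize(X=X_fant, sampler=sampler)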

botorch/acquisition/active_learning.py

+2 −1

@@ -93,7 +93,8 @@ def forward(self, X: Tensor) -> Tensor:
         # Construct the fantasy model (we actually do not use the full model,
         # this is just a convenient way of computing fast posterior covariances
         fantasy_model = self.model.fantasize(
-            X=X, sampler=self.sampler, observation_noise=True
+            X=X,
+            sampler=self.sampler,
         )

         bdims = tuple(1 for _ in X.shape[:-2])

botorch/acquisition/knowledge_gradient.py

+6 −3

@@ -184,7 +184,8 @@ def forward(self, X: Tensor) -> Tensor:

         # construct the fantasy model of shape `num_fantasies x b`
         fantasy_model = self.model.fantasize(
-            X=X_actual, sampler=self.sampler, observation_noise=True
+            X=X_actual,
+            sampler=self.sampler,
         )

         # get the value function

@@ -233,7 +234,8 @@ def evaluate(self, X: Tensor, bounds: Tensor, **kwargs: Any) -> Tensor:

         # construct the fantasy model of shape `num_fantasies x b`
         fantasy_model = self.model.fantasize(
-            X=X, sampler=self.sampler, observation_noise=True
+            X=X,
+            sampler=self.sampler,
         )

         # get the value function

@@ -451,7 +453,8 @@ def forward(self, X: Tensor) -> Tensor:
         # construct the fantasy model of shape `num_fantasies x b`
         # expand X (to potentially add trace observations)
         fantasy_model = self.model.fantasize(
-            X=self.expand(X_eval), sampler=self.sampler, observation_noise=True
+            X=self.expand(X_eval),
+            sampler=self.sampler,
         )
         # get the value function
         value_function = _get_value_function(

botorch/acquisition/max_value_entropy_search.py

+2 −1

@@ -389,7 +389,8 @@ def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
         if X_pending is not None:
             # fantasize the model and use this as the new model
             self.model = init_model.fantasize(
-                X=X_pending, sampler=self.fantasies_sampler, observation_noise=True
+                X=X_pending,
+                sampler=self.fantasies_sampler,
             )
         else:
             self.model = init_model

botorch/acquisition/multi_objective/max_value_entropy_search.py

+2 −1

@@ -146,7 +146,8 @@ def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
         if X_pending is not None:
             # fantasize the model
             fantasy_model = self._init_model.fantasize(
-                X=X_pending, sampler=self.fantasies_sampler, observation_noise=True
+                X=X_pending,
+                sampler=self.fantasies_sampler,
             )
             self.mo_model = fantasy_model
             # convert model to batched single outcome model.

botorch/acquisition/multi_step_lookahead.py

+1 −2

@@ -399,7 +399,7 @@ def _step(
     # construct fantasy model (with batch shape f_{j+1} x ... x f_1 x batch_shape)
     prop_grads = step_index > 0  # need to propagate gradients for steps > 0
     fantasy_model = model.fantasize(
-        X=X, sampler=samplers[0], observation_noise=True, propagate_grads=prop_grads
+        X=X, sampler=samplers[0], propagate_grads=prop_grads
     )

     # augment sample weights appropriately

@@ -585,7 +585,6 @@ def _get_induced_fantasy_model(
     fantasy_model = model.fantasize(
         X=Xs[0],
         sampler=samplers[0],
-        observation_noise=True,
     )

     return _get_induced_fantasy_model(

botorch/models/gp_regression.py

+26 −24

@@ -30,15 +30,14 @@

 from __future__ import annotations

-from typing import Any, List, NoReturn, Optional, Union
+from typing import Any, List, NoReturn, Optional

 import torch
-from botorch import settings
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
 from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import Log, OutcomeTransform
-from botorch.models.utils import fantasize as fantasize_flag, validate_input_scaling
+from botorch.models.utils import validate_input_scaling
 from botorch.models.utils.gpytorch_modules import (
     get_gaussian_likelihood_with_gamma_prior,
     get_matern_kernel_with_gamma_prior,

@@ -164,7 +163,7 @@ def forward(self, x: Tensor) -> MultivariateNormal:
         return MultivariateNormal(mean_x, covar_x)


-class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP):
+class FixedNoiseGP(BatchedMultiOutputGPyTorchModel, ExactGP, FantasizeMixin):
     r"""A single-task exact GP model using fixed noise levels.

     A single-task exact GP that uses fixed observation noise levels, differing from

@@ -270,7 +269,7 @@ def fantasize(
         self,
         X: Tensor,
         sampler: MCSampler,
-        observation_noise: Union[bool, Tensor] = True,
+        observation_noise: Optional[Tensor] = None,
         **kwargs: Any,
     ) -> FixedNoiseGP:
         r"""Construct a fantasy model.

@@ -290,29 +289,32 @@ def fantasize(
                 `batch_shape` is the batch shape (must be compatible with the
                 batch shape of the model).
             sampler: The sampler used for sampling from the posterior at `X`.
-            observation_noise: If True, include the mean across the observation
-                noise in the training data as observation noise in the posterior
-                from which the samples are drawn. If a Tensor, use it directly
-                as the specified measurement noise.
+            observation_noise: The noise level for fantasization if
+                provided. If `None`, the mean across the observation
+                noise in the training data is used as observation noise in
+                the posterior from which the samples are drawn and
+                the fantasized noise level. If observation noise is
+                provided, it is assumed to be in the outcome-transformed
+                space, if an outcome transform is used.

         Returns:
             The constructed fantasy model.
         """
-        propagate_grads = kwargs.pop("propagate_grads", False)
-        with fantasize_flag():
-            with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(
-                    X, observation_noise=observation_noise, **kwargs
-                )
-            Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
-            # Use the mean of the previous noise values (TODO: be smarter here).
-            # noise should be batch_shape x q x m when X is batch_shape x q x d, and
-            # Y_fantasized is num_fantasies x batch_shape x q x m.
-            noise_shape = Y_fantasized.shape[1:]
-            noise = self.likelihood.noise.mean().expand(noise_shape)
-            return self.condition_on_observations(
-                X=self.transform_inputs(X), Y=Y_fantasized, noise=noise
-            )
+        # self.likelihood.noise is a `batch_shape x n x s(m)`-dimensional tensor
+        if observation_noise is None:
+            if self.num_outputs > 1:
+                # make noise ... x n x m
+                observation_noise = self.likelihood.noise.transpose(-1, -2)
+            else:
+                observation_noise = self.likelihood.noise.unsqueeze(-1)
+            observation_noise = observation_noise.mean(dim=-2, keepdim=True)
+
+        return super().fantasize(
+            X=X,
+            sampler=sampler,
+            observation_noise=observation_noise,
+            **kwargs,
+        )

     def forward(self, x: Tensor) -> MultivariateNormal:
         # TODO: reduce redundancy with the 'forward' method of
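
The default noise handling above replaces the old scalar `noise.mean()` with a per-batch, per-output average. A standalone shape sketch of that averaging for a single-output model, with assumed sizes (`likelihood.noise` is `batch_shape x n`); this is illustrative, not code from the commit:

import torch

# likelihood.noise for a single-output model: batch_shape=2, n=3 (assumed)
noise = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
obs_noise = noise.unsqueeze(-1)                   # 2 x 3 x 1 (n x m, m=1)
obs_noise = obs_noise.mean(dim=-2, keepdim=True)  # 2 x 1 x 1, per-batch mean
print(obs_noise.squeeze(-1))                      # tensor([[0.2000], [0.5000]])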

botorch/models/gpytorch.py

+25 −7

@@ -159,7 +159,9 @@ def posterior(
                 jointly.
             observation_noise: If True, add the observation noise from the
                 likelihood to the posterior. If a Tensor, use it directly as the
-                observation noise (must be of shape `(batch_shape) x q`).
+                observation noise (must be of shape `(batch_shape) x q`). It is
+                assumed to be in the outcome-transformed space if an outcome
+                transform is used.
             posterior_transform: An optional PosteriorTransform.

         Returns:

@@ -223,7 +225,8 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
         # pass the transformed data to get_fantasy_model below
         # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
         if not isinstance(self, BatchedMultiOutputGPyTorchModel):
-            Y, Yvar = self.outcome_transform(Y, Yvar)
+            # `noise` is assumed to already be outcome-transformed.
+            Y, _ = self.outcome_transform(Y, Yvar)
         # validate using strict=False, since we cannot tell if Y has an explicit
         # output dimension
         self._validate_tensor_args(X=X, Y=Y, Yvar=Yvar, strict=False)

@@ -373,18 +376,32 @@ def posterior(
             )
         mvn = self(X)
         if observation_noise is not False:
+            if self._num_outputs > 1:
+                noise_shape = X.shape[:-3] + torch.Size(
+                    [self._num_outputs, X.shape[-2]]
+                )
+            else:
+                noise_shape = X.shape[:-1]
             if torch.is_tensor(observation_noise):
                 # TODO: Validate noise shape
                 # make observation_noise `batch_shape x q x n`
                 if self.num_outputs > 1:
                     obs_noise = observation_noise.transpose(-1, -2)
                 else:
                     obs_noise = observation_noise.squeeze(-1)
-                mvn = self.likelihood(mvn, X, noise=obs_noise)
+                mvn = self.likelihood(
+                    mvn,
+                    X,
+                    noise=obs_noise.expand(noise_shape),
+                )
             elif isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
                 # Use the mean of the previous noise values (TODO: be smarter here).
-                noise = self.likelihood.noise.mean().expand(X.shape[:-1])
-                mvn = self.likelihood(mvn, X, noise=noise)
+                observation_noise = self.likelihood.noise.mean(dim=-1, keepdim=True)
+                mvn = self.likelihood(
+                    mvn,
+                    X,
+                    noise=observation_noise.expand(noise_shape),
+                )
             else:
                 mvn = self.likelihood(mvn, X)
         if self._num_outputs > 1:

@@ -443,8 +460,9 @@ def condition_on_observations(
         """
         noise = kwargs.get("noise")
         if hasattr(self, "outcome_transform"):
-            # we need to apply transforms before shifting batch indices around
-            Y, noise = self.outcome_transform(Y, noise)
+            # We need to apply transforms before shifting batch indices around.
+            # `noise` is assumed to already be outcome-transformed.
+            Y, _ = self.outcome_transform(Y)
         self._validate_tensor_args(X=X, Y=Y, Yvar=noise, strict=False)
         inputs = X
         if self._num_outputs > 1:
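
The `noise_shape` logic above ensures the provided or averaged noise broadcasts per batch, output, and test point rather than collapsing to a single scalar. A standalone sketch with assumed sizes (a batched two-output model whose inputs carry an output dimension at this point); not code from the commit:

import torch

num_outputs = 2
# For a batched multi-output model, X reaches this point with an output
# dimension: (model batch) x m x q x d (sizes assumed for illustration).
X = torch.rand(5, 2, 4, 3)
if num_outputs > 1:
    noise_shape = X.shape[:-3] + torch.Size([num_outputs, X.shape[-2]])
else:
    noise_shape = X.shape[:-1]
print(noise_shape)  # torch.Size([5, 2, 4]): noise per batch, output, point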

botorch/models/model.py

+47 −10

@@ -33,7 +33,11 @@
 import numpy as np
 import torch
 from botorch import settings
-from botorch.exceptions.errors import BotorchTensorDimensionError, InputDataError
+from botorch.exceptions.errors import (
+    BotorchTensorDimensionError,
+    DeprecationError,
+    InputDataError,
+)
 from botorch.logging import shape_to_str
 from botorch.models.utils.assorted import fantasize as fantasize_flag
 from botorch.posteriors import Posterior, PosteriorList

@@ -83,7 +87,7 @@ def posterior(
         self,
         X: Tensor,
         output_indices: Optional[List[int]] = None,
-        observation_noise: bool = False,
+        observation_noise: Union[bool, Tensor] = False,
         posterior_transform: Optional[PosteriorTransform] = None,
         **kwargs: Any,
     ) -> Posterior:

@@ -102,7 +106,12 @@ def posterior(
                 Can be used to speed up computation if only a subset of the
                 model's outputs are required for optimization. If omitted,
                 computes the posterior over all model outputs.
-            observation_noise: If True, add observation noise to the posterior.
+            observation_noise: For models with an inferred noise level, if True,
+                include observation noise. For models with an observed noise level,
+                this must be a `model_batch_shape x 1 x m`-dim tensor or
+                a `model_batch_shape x n' x m`-dim tensor containing the average
+                noise for each batch and output. `noise` must be in the
+                outcome-transformed space if an outcome transform is used.
             posterior_transform: An optional PosteriorTransform.

         Returns:

@@ -310,7 +319,7 @@ def fantasize(
         # TODO: see if any of these can be imported only if TYPE_CHECKING
         X: Tensor,
         sampler: MCSampler,
-        observation_noise: bool = True,
+        observation_noise: Optional[Tensor] = None,
         **kwargs: Any,
     ) -> TFantasizeMixin:
         r"""Construct a fantasy model.

@@ -328,12 +337,21 @@ def fantasize(
                 `batch_shape` is the batch shape (must be compatible with the
                 batch shape of the model).
             sampler: The sampler used for sampling from the posterior at `X`.
-            observation_noise: If True, include observation noise.
+            observation_noise: A `model_batch_shape x 1 x m`-dim tensor or
+                a `model_batch_shape x n' x m`-dim tensor containing the average
+                noise for each batch and output, where `m` is the number of outputs.
+                `noise` must be in the outcome-transformed space if an outcome
+                transform is used. If None, then the noise will be the inferred
+                noise level.
             kwargs: Will be passed to `model.condition_on_observations`

         Returns:
             The constructed fantasy model.
         """
+        if not isinstance(observation_noise, Tensor) and observation_noise is not None:
+            raise DeprecationError(
+                "`fantasize` no longer accepts a boolean for `observation_noise`."
+            )
         # if the inputs are empty, expand the inputs
         if X.shape[-2] == 0:
             output_shape = (

@@ -350,8 +368,15 @@ def fantasize(
         propagate_grads = kwargs.pop("propagate_grads", False)
         with fantasize_flag():
             with settings.propagate_grads(propagate_grads):
-                post_X = self.posterior(X, observation_noise=observation_noise)
+                post_X = self.posterior(
+                    X,
+                    observation_noise=True
+                    if observation_noise is None
+                    else observation_noise,
+                )
             Y_fantasized = sampler(post_X)  # num_fantasies x batch_shape x n' x m
+            if observation_noise is not None:
+                kwargs["noise"] = observation_noise.expand(Y_fantasized.shape[1:])
             return self.condition_on_observations(
                 X=self.transform_inputs(X), Y=Y_fantasized, **kwargs
             )

@@ -434,7 +459,9 @@ def posterior(
                 respective likelihoods to the posterior. If a Tensor of shape
                 `(batch_shape) x q x m`, use it directly as the observation
                 noise (with `observation_noise[...,i]` added to the posterior
-                of the `i`-th model).
+                of the `i`-th model). `observation_noise` is assumed
+                to be in the outcome-transformed space, if an outcome transform
+                is used by the model.
             posterior_transform: An optional PosteriorTransform.

         Returns:

@@ -553,7 +580,7 @@ def fantasize(
         self,
         X: Tensor,
         sampler: MCSampler,
-        observation_noise: bool = True,
+        observation_noise: Optional[Tensor] = None,
         evaluation_mask: Optional[Tensor] = None,
         **kwargs: Any,
     ) -> Model:

@@ -573,7 +600,12 @@ def fantasize(
                 batch shape of the model).
             sampler: The sampler used for sampling from the posterior at `X`. If
                 evaluation_mask is not None, this must be a `ListSampler`.
-            observation_noise: If True, include observation noise.
+            observation_noise: A `model_batch_shape x 1 x m`-dim tensor or
+                a `model_batch_shape x n' x m`-dim tensor containing the average
+                noise for each batch and output, where `m` is the number of outputs.
+                `noise` must be in the outcome-transformed space if an outcome
+                transform is used. If None, then the noise will be the inferred
+                noise level.
             evaluation_mask: A `n' x m`-dim tensor of booleans indicating which
                 outputs should be fantasized for a given design. This uses the same
                 evaluation mask for all batches.

@@ -595,6 +627,8 @@ def fantasize(

         fant_models = []
         X_i = X
+        if observation_noise is None:
+            observation_noise_i = observation_noise
         for i in range(self.num_outputs):
             # get the inputs to fantasize at for output i
             if evaluation_mask is not None:

@@ -604,12 +638,15 @@ def fantasize(
                 # samples from a single Sobol sequence or consider requiring that the
                 # sampling is IID to ensure good coverage.
                 sampler_i = sampler.samplers[i]
+                if observation_noise is not None:
+                    observation_noise_i = observation_noise[..., mask_i, i : i + 1]
             else:
                 sampler_i = sampler
+
             fant_model = self.models[i].fantasize(
                 X=X_i,
                 sampler=sampler_i,
-                observation_noise=observation_noise,
+                observation_noise=observation_noise_i,
                 **kwargs,
             )
             fant_models.append(fant_model)
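
With `FantasizeMixin.fantasize` now rejecting booleans, callers either omit `observation_noise` or pass a tensor. A sketch of both paths, reusing the illustrative `model`, `sampler`, and `X_fant` from the first example above; the noise values are assumptions, not from the commit:

import torch
from botorch.exceptions.errors import DeprecationError

try:
    model.fantasize(X=X_fant, sampler=sampler, observation_noise=True)
except DeprecationError as e:
    print(e)  # `fantasize` no longer accepts a boolean for `observation_noise`.

# Supported: a `1 x m` (or `n' x m`) tensor of (outcome-transformed) noise
# levels; it is expanded to the fantasy observations inside `fantasize`.
noise = torch.full((1, 1), 0.01, dtype=torch.double)
fm = model.fantasize(X=X_fant, sampler=sampler, observation_noise=noise)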

botorch/utils/testing.py

+1 −0

@@ -375,6 +375,7 @@ def _get_random_data(
         [torch.linspace(0, 0.95, n, **tkwargs) for _ in range(d)], dim=-1
     )
     train_x = train_x + 0.05 * torch.rand_like(train_x).repeat(rep_shape)
+    train_x[0] += 0.02  # modify the first batch
     train_y = torch.sin(train_x[..., :1] * (2 * math.pi))
     train_y = train_y + 0.2 * torch.randn(n, m, **tkwargs).repeat(rep_shape)
     return train_x, train_y
