Skip to content

Commit

Permalink
Fix / silence Pyre failures due to latest GPyTorch changes (#1642)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1642

Turns out Pyre really doesn't like GPyTorch and the only thing keeping it sane was a `__getattr__` method with no return type.

Differential Revision: D46409697

fbshipit-source-id: 6a88c83c6ea92c1a80d99e4ba6a389897f4315f4
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jun 3, 2023
1 parent 23008c6 commit 7185f6d
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 22 deletions.
6 changes: 5 additions & 1 deletion ax/models/tests/test_alebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
)
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.mock import fast_botorch_optimize
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.utils.datasets import FixedNoiseDataset
from torch.nn.parameter import Parameter


class ALEBOTest(TestCase):
Expand All @@ -47,7 +49,9 @@ def testALEBOKernel(self) -> None:
self.assertEqual(k.Uvec.shape, torch.Size([3]))

k.Uvec.requires_grad_(False)
k.Uvec.copy_(torch.tensor([1.0, 2.0, 3.0], dtype=torch.double))
checked_cast(Parameter, k.Uvec).copy_(
torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)
)
k.Uvec.requires_grad_(True)
x1 = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
x2 = torch.tensor([[1.0, 1.0], [0.0, 0.0]], dtype=torch.double)
Expand Down
41 changes: 27 additions & 14 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_warping_transform,
)
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.mock import fast_botorch_optimize
from botorch.acquisition.penalized import PenalizedMCObjective
from botorch.models import FixedNoiseGP, SingleTaskGP
Expand Down Expand Up @@ -384,14 +385,20 @@ def test_get_customized_covar_module(self) -> None:
self.assertIsInstance(covar_module, Module)
self.assertIsInstance(covar_module, ScaleKernel)
self.assertIsInstance(covar_module.outputscale_prior, GammaPrior)
self.assertEqual(covar_module.outputscale_prior.concentration, 2.0)
self.assertEqual(covar_module.outputscale_prior.rate, 0.15)
prior = checked_cast(GammaPrior, covar_module.outputscale_prior)
self.assertEqual(prior.concentration, 2.0)
self.assertEqual(prior.rate, 0.15)
self.assertIsInstance(covar_module.base_kernel, MaternKernel)
self.assertIsInstance(covar_module.base_kernel.lengthscale_prior, GammaPrior)
self.assertEqual(covar_module.base_kernel.lengthscale_prior.concentration, 3.0)
self.assertEqual(covar_module.base_kernel.lengthscale_prior.rate, 6.0)
self.assertEqual(covar_module.base_kernel.ard_num_dims, ard_num_dims)
self.assertEqual(covar_module.base_kernel.batch_shape, batch_shape)
base_kernel = checked_cast(MaternKernel, covar_module.base_kernel)
self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior)
self.assertEqual(
checked_cast(GammaPrior, base_kernel.lengthscale_prior).concentration, 3.0
)
self.assertEqual(
checked_cast(GammaPrior, base_kernel.lengthscale_prior).rate, 6.0
)
self.assertEqual(base_kernel.ard_num_dims, ard_num_dims)
self.assertEqual(base_kernel.batch_shape, batch_shape)

covar_module = _get_customized_covar_module(
covar_module_prior_dict={
Expand All @@ -405,14 +412,20 @@ def test_get_customized_covar_module(self) -> None:
self.assertIsInstance(covar_module, Module)
self.assertIsInstance(covar_module, ScaleKernel)
self.assertIsInstance(covar_module.outputscale_prior, GammaPrior)
self.assertEqual(covar_module.outputscale_prior.concentration, 2.0)
self.assertEqual(covar_module.outputscale_prior.rate, 12.0)
prior = checked_cast(GammaPrior, covar_module.outputscale_prior)
self.assertEqual(prior.concentration, 2.0)
self.assertEqual(prior.rate, 12.0)
self.assertIsInstance(covar_module.base_kernel, MaternKernel)
self.assertIsInstance(covar_module.base_kernel.lengthscale_prior, GammaPrior)
self.assertEqual(covar_module.base_kernel.lengthscale_prior.concentration, 12.0)
self.assertEqual(covar_module.base_kernel.lengthscale_prior.rate, 2.0)
self.assertEqual(covar_module.base_kernel.ard_num_dims, ard_num_dims - 1)
self.assertEqual(covar_module.base_kernel.batch_shape, batch_shape)
base_kernel = checked_cast(MaternKernel, covar_module.base_kernel)
self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior)
self.assertEqual(
checked_cast(GammaPrior, base_kernel.lengthscale_prior).concentration, 12.0
)
self.assertEqual(
checked_cast(GammaPrior, base_kernel.lengthscale_prior).rate, 2.0
)
self.assertEqual(base_kernel.ard_num_dims, ard_num_dims - 1)
self.assertEqual(base_kernel.batch_shape, batch_shape)

def test_get_warping_transform(self) -> None:
warp_tf = get_warping_transform(d=4)
Expand Down
8 changes: 5 additions & 3 deletions ax/models/torch/alebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.objective import PosteriorTransform
Expand All @@ -43,6 +44,7 @@
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from scipy.optimize import approx_fprime
from torch import Tensor
from torch.nn.parameter import Parameter


logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -98,7 +100,7 @@ def forward(
# Unpack Uvec into an upper triangular matrix U
shapeU = self.Uvec.shape[:-1] + torch.Size([self.d, self.d])
U_t = torch.zeros(shapeU, dtype=self.B.dtype, device=self.B.device)
U_t[..., self.triu_indx[1], self.triu_indx[0]] = self.Uvec
U_t[..., self.triu_indx[1], self.triu_indx[0]] = checked_cast(Tensor, self.Uvec)
# Compute kernel distance
z1 = torch.matmul(x1, U_t)
z2 = torch.matmul(x2, U_t)
Expand Down Expand Up @@ -394,11 +396,11 @@ def get_batch_model(
m_b.mean_module.raw_constant.requires_grad_(True)
# Set output scale
m_b.covar_module.raw_outputscale.requires_grad_(False)
m_b.covar_module.raw_outputscale.copy_(output_scale_batch)
checked_cast(Parameter, m_b.covar_module.raw_outputscale).copy_(output_scale_batch)
m_b.covar_module.raw_outputscale.requires_grad_(True)
# Set Uvec
m_b.covar_module.base_kernel.Uvec.requires_grad_(False)
m_b.covar_module.base_kernel.Uvec.copy_(Uvec_batch)
checked_cast(Parameter, m_b.covar_module.base_kernel.Uvec).copy_(Uvec_batch)
m_b.covar_module.base_kernel.Uvec.requires_grad_(True)
m_b.eval()
return m_b
Expand Down
5 changes: 5 additions & 0 deletions ax/models/torch/fully_bayesian_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,15 @@ def pyro_sample_input_warping(
def load_mcmc_samples_to_model(model: GPyTorchModel, mcmc_samples: Dict) -> None:
"""Load MCMC samples into GPyTorchModel."""
if "noise" in mcmc_samples:
# pyre-ignore Undefined attribute [16]: `torch._tensor.Tensor` has
# no attribute `noise`.
model.likelihood.noise_covar.noise = (
mcmc_samples["noise"]
.detach()
.clone()
# pyre-ignore Undefined attribute [16]: Item `torch._tensor.Tensor` of
# `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
# has no attribute `noise`.
.view(model.likelihood.noise_covar.noise.shape)
.clamp_min(MIN_INFERRED_NOISE_LEVEL)
)
Expand Down
36 changes: 32 additions & 4 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,16 +722,44 @@ def test_fit_mixed(self) -> None:
self.assertEqual(surrogate.model._ignore_X_dims_scaling_check, [0])
covar_module = checked_cast(Kernel, surrogate.model.covar_module)
self.assertEqual(
covar_module.kernels[0].base_kernel.kernels[1].active_dims.tolist(), [0]
# pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable
# (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase),
# Named(indices, typing.Union[None, typing.List[typing.Any], int, slice,
# torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor.
# Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules.
# module.Module]` is not a function.
covar_module.kernels[0].base_kernel.kernels[1].active_dims.tolist(),
[0],
)
self.assertEqual(
covar_module.kernels[0].base_kernel.kernels[0].active_dims.tolist(), [1, 2]
# pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable
# (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase),
# Named(indices, typing.Union[None, typing.List[typing.Any], int, slice,
# torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor.
# Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules.
# module.Module]` is not a function.
covar_module.kernels[0].base_kernel.kernels[0].active_dims.tolist(),
[1, 2],
)
self.assertEqual(
covar_module.kernels[1].base_kernel.kernels[1].active_dims.tolist(), [0]
# pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable
# (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase),
# Named(indices, typing.Union[None, typing.List[typing.Any], int, slice,
# torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor.
# Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules.
# module.Module]` is not a function.
covar_module.kernels[1].base_kernel.kernels[1].active_dims.tolist(),
[0],
)
self.assertEqual(
covar_module.kernels[1].base_kernel.kernels[0].active_dims.tolist(), [1, 2]
# pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable
# (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase),
# Named(indices, typing.Union[None, typing.List[typing.Any], int, slice,
# torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor.
# Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules.
# module.Module]` is not a function.
covar_module.kernels[1].base_kernel.kernels[0].active_dims.tolist(),
[1, 2],
)
# With modellist.
training_data = [self.training_data[0], self.training_data[0]]
Expand Down

0 comments on commit 7185f6d

Please sign in to comment.