diff --git a/ax/models/tests/test_alebo.py b/ax/models/tests/test_alebo.py index 6600c3666ea..2b82ba87f4c 100644 --- a/ax/models/tests/test_alebo.py +++ b/ax/models/tests/test_alebo.py @@ -22,11 +22,13 @@ ) from ax.models.torch_base import TorchOptConfig from ax.utils.common.testutils import TestCase +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.mock import fast_botorch_optimize from botorch.acquisition.analytic import ExpectedImprovement from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement from botorch.models.model_list_gp_regression import ModelListGP from botorch.utils.datasets import FixedNoiseDataset +from torch.nn.parameter import Parameter class ALEBOTest(TestCase): @@ -47,7 +49,9 @@ def testALEBOKernel(self) -> None: self.assertEqual(k.Uvec.shape, torch.Size([3])) k.Uvec.requires_grad_(False) - k.Uvec.copy_(torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)) + checked_cast(Parameter, k.Uvec).copy_( + torch.tensor([1.0, 2.0, 3.0], dtype=torch.double) + ) k.Uvec.requires_grad_(True) x1 = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double) x2 = torch.tensor([[1.0, 1.0], [0.0, 0.0]], dtype=torch.double) diff --git a/ax/models/tests/test_botorch_defaults.py b/ax/models/tests/test_botorch_defaults.py index f2b73d2a20f..4a1fc5847f4 100644 --- a/ax/models/tests/test_botorch_defaults.py +++ b/ax/models/tests/test_botorch_defaults.py @@ -16,6 +16,7 @@ get_warping_transform, ) from ax.utils.common.testutils import TestCase +from ax.utils.common.typeutils import checked_cast from ax.utils.testing.mock import fast_botorch_optimize from botorch.acquisition.penalized import PenalizedMCObjective from botorch.models import FixedNoiseGP, SingleTaskGP @@ -384,14 +385,20 @@ def test_get_customized_covar_module(self) -> None: self.assertIsInstance(covar_module, Module) self.assertIsInstance(covar_module, ScaleKernel) self.assertIsInstance(covar_module.outputscale_prior, GammaPrior) - self.assertEqual(covar_module.outputscale_prior.concentration, 2.0) - self.assertEqual(covar_module.outputscale_prior.rate, 0.15) + prior = checked_cast(GammaPrior, covar_module.outputscale_prior) + self.assertEqual(prior.concentration, 2.0) + self.assertEqual(prior.rate, 0.15) self.assertIsInstance(covar_module.base_kernel, MaternKernel) - self.assertIsInstance(covar_module.base_kernel.lengthscale_prior, GammaPrior) - self.assertEqual(covar_module.base_kernel.lengthscale_prior.concentration, 3.0) - self.assertEqual(covar_module.base_kernel.lengthscale_prior.rate, 6.0) - self.assertEqual(covar_module.base_kernel.ard_num_dims, ard_num_dims) - self.assertEqual(covar_module.base_kernel.batch_shape, batch_shape) + base_kernel = checked_cast(MaternKernel, covar_module.base_kernel) + self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior) + self.assertEqual( + checked_cast(GammaPrior, base_kernel.lengthscale_prior).concentration, 3.0 + ) + self.assertEqual( + checked_cast(GammaPrior, base_kernel.lengthscale_prior).rate, 6.0 + ) + self.assertEqual(base_kernel.ard_num_dims, ard_num_dims) + self.assertEqual(base_kernel.batch_shape, batch_shape) covar_module = _get_customized_covar_module( covar_module_prior_dict={ @@ -405,14 +412,20 @@ def test_get_customized_covar_module(self) -> None: self.assertIsInstance(covar_module, Module) self.assertIsInstance(covar_module, ScaleKernel) self.assertIsInstance(covar_module.outputscale_prior, GammaPrior) - self.assertEqual(covar_module.outputscale_prior.concentration, 2.0) - self.assertEqual(covar_module.outputscale_prior.rate, 12.0) + prior = checked_cast(GammaPrior, covar_module.outputscale_prior) + self.assertEqual(prior.concentration, 2.0) + self.assertEqual(prior.rate, 12.0) self.assertIsInstance(covar_module.base_kernel, MaternKernel) - self.assertIsInstance(covar_module.base_kernel.lengthscale_prior, GammaPrior) - self.assertEqual(covar_module.base_kernel.lengthscale_prior.concentration, 12.0) - self.assertEqual(covar_module.base_kernel.lengthscale_prior.rate, 2.0) - self.assertEqual(covar_module.base_kernel.ard_num_dims, ard_num_dims - 1) - self.assertEqual(covar_module.base_kernel.batch_shape, batch_shape) + base_kernel = checked_cast(MaternKernel, covar_module.base_kernel) + self.assertIsInstance(base_kernel.lengthscale_prior, GammaPrior) + self.assertEqual( + checked_cast(GammaPrior, base_kernel.lengthscale_prior).concentration, 12.0 + ) + self.assertEqual( + checked_cast(GammaPrior, base_kernel.lengthscale_prior).rate, 2.0 + ) + self.assertEqual(base_kernel.ard_num_dims, ard_num_dims - 1) + self.assertEqual(base_kernel.batch_shape, batch_shape) def test_get_warping_transform(self) -> None: warp_tf = get_warping_transform(d=4) diff --git a/ax/models/torch/alebo.py b/ax/models/torch/alebo.py index 1c90570e3db..0443a03fed1 100644 --- a/ax/models/torch/alebo.py +++ b/ax/models/torch/alebo.py @@ -24,6 +24,7 @@ from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger +from ax.utils.common.typeutils import checked_cast from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import ExpectedImprovement from botorch.acquisition.objective import PosteriorTransform @@ -43,6 +44,7 @@ from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood from scipy.optimize import approx_fprime from torch import Tensor +from torch.nn.parameter import Parameter logger: Logger = get_logger(__name__) @@ -98,7 +100,7 @@ def forward( # Unpack Uvec into an upper triangular matrix U shapeU = self.Uvec.shape[:-1] + torch.Size([self.d, self.d]) U_t = torch.zeros(shapeU, dtype=self.B.dtype, device=self.B.device) - U_t[..., self.triu_indx[1], self.triu_indx[0]] = self.Uvec + U_t[..., self.triu_indx[1], self.triu_indx[0]] = checked_cast(Tensor, self.Uvec) # Compute kernel distance z1 = torch.matmul(x1, U_t) z2 = torch.matmul(x2, U_t) @@ -394,11 +396,11 @@ def get_batch_model( m_b.mean_module.raw_constant.requires_grad_(True) # Set output scale m_b.covar_module.raw_outputscale.requires_grad_(False) - m_b.covar_module.raw_outputscale.copy_(output_scale_batch) + checked_cast(Parameter, m_b.covar_module.raw_outputscale).copy_(output_scale_batch) m_b.covar_module.raw_outputscale.requires_grad_(True) # Set Uvec m_b.covar_module.base_kernel.Uvec.requires_grad_(False) - m_b.covar_module.base_kernel.Uvec.copy_(Uvec_batch) + checked_cast(Parameter, m_b.covar_module.base_kernel.Uvec).copy_(Uvec_batch) m_b.covar_module.base_kernel.Uvec.requires_grad_(True) m_b.eval() return m_b diff --git a/ax/models/torch/fully_bayesian_model_utils.py b/ax/models/torch/fully_bayesian_model_utils.py index 949f9b419f3..a9cbda6a996 100644 --- a/ax/models/torch/fully_bayesian_model_utils.py +++ b/ax/models/torch/fully_bayesian_model_utils.py @@ -177,10 +177,15 @@ def pyro_sample_input_warping( def load_mcmc_samples_to_model(model: GPyTorchModel, mcmc_samples: Dict) -> None: """Load MCMC samples into GPyTorchModel.""" if "noise" in mcmc_samples: + # pyre-ignore Undefined attribute [16]: `torch._tensor.Tensor` has + # no attribute `noise`. model.likelihood.noise_covar.noise = ( mcmc_samples["noise"] .detach() .clone() + # pyre-ignore Undefined attribute [16]: Item `torch._tensor.Tensor` of + # `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` + # has no attribute `noise`. .view(model.likelihood.noise_covar.noise.shape) .clamp_min(MIN_INFERRED_NOISE_LEVEL) ) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 5a3b015363b..3bab0756885 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -722,16 +722,44 @@ def test_fit_mixed(self) -> None: self.assertEqual(surrogate.model._ignore_X_dims_scaling_check, [0]) covar_module = checked_cast(Kernel, surrogate.model.covar_module) self.assertEqual( - covar_module.kernels[0].base_kernel.kernels[1].active_dims.tolist(), [0] + # pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable + # (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase), + # Named(indices, typing.Union[None, typing.List[typing.Any], int, slice, + # torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor. + # Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules. + # module.Module]` is not a function. + covar_module.kernels[0].base_kernel.kernels[1].active_dims.tolist(), + [0], ) self.assertEqual( - covar_module.kernels[0].base_kernel.kernels[0].active_dims.tolist(), [1, 2] + # pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable + # (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase), + # Named(indices, typing.Union[None, typing.List[typing.Any], int, slice, + # torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor. + # Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules. + # module.Module]` is not a function. + covar_module.kernels[0].base_kernel.kernels[0].active_dims.tolist(), + [1, 2], ) self.assertEqual( - covar_module.kernels[1].base_kernel.kernels[1].active_dims.tolist(), [0] + # pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable + # (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase), + # Named(indices, typing.Union[None, typing.List[typing.Any], int, slice, + # torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor. + # Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules. + # module.Module]` is not a function. + covar_module.kernels[1].base_kernel.kernels[1].active_dims.tolist(), + [0], ) self.assertEqual( - covar_module.kernels[1].base_kernel.kernels[0].active_dims.tolist(), [1, 2] + # pyre-ignore Call error [29]: `typing.Union[BoundMethod[typing.Callable + # (torch._C._TensorBase.__getitem__)[[Named(self, torch._C._TensorBase), + # Named(indices, typing.Union[None, typing.List[typing.Any], int, slice, + # torch._tensor.Tensor, typing.Tuple[typing.Any, ...]])], torch._tensor. + # Tensor], torch._tensor.Tensor], torch._tensor.Tensor, torch.nn.modules. + # module.Module]` is not a function. + covar_module.kernels[1].base_kernel.kernels[0].active_dims.tolist(), + [1, 2], ) # With modellist. training_data = [self.training_data[0], self.training_data[0]]