Skip to content

Commit

Permalink
flake8 fixes (#1752)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1752

Looks like flake8 became a bit more strict with type comparisons. This fixes the failures raised in https://github.com/facebook/Ax/actions/runs/5719204578/job/15496614136

Reviewed By: Balandat

Differential Revision: D47931988

fbshipit-source-id: b5674e2035d2edc51545ee96a3ac8ae9a0eabb9b
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Aug 1, 2023
1 parent 1cea3d1 commit feab1f3
Show file tree
Hide file tree
Showing 11 changed files with 17 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/problems/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if not type(other) == type(self):
if type(other) is not type(self):
return False

# Checking the whole datasets' equality here would be too expensive to be
Expand Down
2 changes: 1 addition & 1 deletion ax/core/optimization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _validate_optimization_config(
outcome_constraints: Constraints to validate.
risk_measure: An optional risk measure to validate.
"""
if type(objective) == MultiObjective:
if type(objective) is MultiObjective:
# Raise error on exact equality; `ScalarizedObjective` is OK
raise ValueError(
(
Expand Down
6 changes: 3 additions & 3 deletions ax/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def validate_single_metric_data(data: SingleMetricData) -> None:

def validate_trial_evaluation(evaluation: TTrialEvaluation) -> None:
for key, value in evaluation.items():
if type(key) != str:
if not isinstance(key, str):
raise TypeError(f"Keys must be strings in TTrialEvaluation, found {key}.")

validate_single_metric_data(data=value)
Expand All @@ -160,15 +160,15 @@ def validate_param_value(param_value: TParamValue) -> None:

def validate_parameterization(parameterization: TParameterization) -> None:
for key, value in parameterization.items():
if type(key) != str:
if not isinstance(key, str):
raise TypeError(f"Keys must be strings in TParameterization, found {key}.")

validate_param_value(param_value=value)


def validate_map_dict(map_dict: TMapDict) -> None:
for key, value in map_dict.items():
if type(key) != str:
if not isinstance(key, str):
raise TypeError(f"Keys must be strings in TMapDict, found {key}.")

if not isinstance(value, Hashable):
Expand Down
2 changes: 0 additions & 2 deletions ax/models/random/sobol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ class SobolGenerator(RandomModel):
"""

engine: Optional[SobolEngine] = None

def __init__(
self,
seed: Optional[int] = None,
Expand Down
15 changes: 6 additions & 9 deletions ax/models/tests/test_botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from copy import deepcopy
from unittest import mock
from unittest.mock import Mock

import torch
from ax.models.torch.botorch_defaults import (
Expand Down Expand Up @@ -47,7 +48,7 @@ def test_get_model(self) -> None:
)
self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 6.0)
model = _get_model(X=x, Y=y, Yvar=unknown_var, task_feature=1)
self.assertTrue(type(model) == MultiTaskGP) # Don't accept subclasses.
self.assertIs(type(model), MultiTaskGP) # Don't accept subclasses.
model = _get_model(X=x, Y=y, Yvar=var, task_feature=1)
self.assertIsInstance(model, FixedNoiseMultiTaskGP)
model = _get_model(X=x, Y=y, Yvar=partial_var.clone(), task_feature=1)
Expand Down Expand Up @@ -153,7 +154,7 @@ def test_get_model(self) -> None:
task_feature=1,
**deepcopy(kwargs6), # pyre-ignore
)
self.assertTrue(type(model) == MultiTaskGP)
self.assertIs(type(model), MultiTaskGP)
self.assertEqual(
model.covar_module.base_kernel.lengthscale_prior.concentration, 12.0
)
Expand Down Expand Up @@ -252,9 +253,7 @@ def test_task_feature(self, get_model_mock):

@mock.patch("ax.models.torch.botorch_defaults._get_model", wraps=_get_model)
@fast_botorch_optimize
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def test_pass_customized_prior(self, get_model_mock):
def test_pass_customized_prior(self, get_model_mock: Mock) -> None:
x = [torch.zeros(2, 2)]
y = [torch.zeros(2, 1)]
yvars = [torch.ones(2, 1)]
Expand All @@ -277,14 +276,12 @@ def test_pass_customized_prior(self, get_model_mock):
refit_model=False,
**kwarg, # pyre-ignore
)
self.assertTrue(type(model) == FixedNoiseGP)
self.assertIs(type(model), FixedNoiseGP)
self.assertEqual(
# pyre-ignore
model.covar_module.base_kernel.lengthscale_prior.concentration,
12.0,
)
self.assertEqual(model.covar_module.base_kernel.lengthscale_prior.rate, 2.0)
# pyre-ignore
self.assertEqual(model.covar_module.outputscale_prior.concentration, 2.0)
self.assertEqual(model.covar_module.outputscale_prior.rate, 12.0)

Expand All @@ -300,7 +297,7 @@ def test_pass_customized_prior(self, get_model_mock):
**kwarg, # pyre-ignore
)
for m in model.models:
self.assertTrue(type(m) == FixedNoiseMultiTaskGP)
self.assertIs(type(m), FixedNoiseMultiTaskGP)
self.assertEqual(
m.covar_module.base_kernel.lengthscale_prior.concentration,
12.0,
Expand Down
2 changes: 1 addition & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ def _validate_runner_and_implemented_metrics(self, experiment: Experiment) -> No
raise UnsupportedError(msg)
else:
base_metrics = {
m_name for m_name, m in experiment.metrics.items() if type(m) == Metric
m_name for m_name, m in experiment.metrics.items() if type(m) is Metric
}
if base_metrics:
msg += f" Metrics {base_metrics} do not implement fetching logic."
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def opt_config_and_tracking_metrics_from_sqa(
if objective is None:
return None, tracking_metrics

if objective_thresholds or type(objective) == MultiObjective:
if objective_thresholds or type(objective) is MultiObjective:
optimization_config = MultiObjectiveOptimizationConfig(
objective=objective,
outcome_constraints=outcome_constraints,
Expand Down
2 changes: 1 addition & 1 deletion ax/storage/sqa_store/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def copy_db_ids(source: Any, target: Any, path: Optional[List[str]] = None) -> N
# introducing infinite loops
raise SQADecodeError(error_message_prefix + "Encountered path of length > 10.")

if type(source) != type(target):
if type(source) is not type(target):
if not issubclass(type(target), type(source)):
if source is None and isinstance(target, SearchSpace):
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/common/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def object_attribute_dicts_find_unequal_fields(
one_val = numpy_type_to_python_type(one_val)
other_val = numpy_type_to_python_type(other_val)
skip_type_check = skip_db_id_check and field == "_db_id"
if not skip_type_check and (type(one_val) != type(other_val)):
if not skip_type_check and (type(one_val) is not type(other_val)):
unequal_type[field] = (one_val, other_val)
if fast_return:
return unequal_type, unequal_value
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/common/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _is_named_tuple(x: Any) -> bool:
f = getattr(t, "_fields", None)
if not isinstance(f, tuple):
return False # pragma nocover
return all(type(n) == str for n in f)
return all(isinstance(n, str) for n in f)


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
Expand Down
2 changes: 1 addition & 1 deletion ax/utils/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def generic_equals(first: Any, second: Any) -> bool:
sorted(first.items()), sorted(second.items())
)
if isinstance(first, (tuple, list)):
if type(first) != type(second) or len(first) != len(second):
if type(first) is not type(second) or len(first) != len(second):
return False
for f, s in zip(first, second):
if not generic_equals(f, s):
Expand Down

0 comments on commit feab1f3

Please sign in to comment.