Skip to content

Commit

Permalink
update tests for MergeRepeatedMeasurements (facebook#1607)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#1607

updates per comments after landing D45558427

Reviewed By: esantorella

Differential Revision: D45608725

fbshipit-source-id: 2ee32e87e575af7b8dbe5ecba6688a8104170e7f
  • Loading branch information
sdaulton authored and facebook-github-bot committed May 5, 2023
1 parent 50b9951 commit 4f36140
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
22 changes: 16 additions & 6 deletions ax/modelbridge/tests/test_merge_repeated_measurements_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ def compare_obs(
test.assertEqual(obs1.data.metric_names, obs2.data.metric_names)
test.assertTrue(np.array_equal(obs1.data.means, obs2.data.means))
discrep = np.max(np.abs(obs1.data.covariance - obs2.data.covariance))
test.assertTrue(discrep <= discrepancy_tol)
test.assertTrue(obs1.features.parameters == obs2.features.parameters)
test.assertLessEqual(discrep, discrepancy_tol)
test.assertEqual(obs1.features.parameters, obs2.features.parameters)


class MergeRepeatedMeasurementsTransformTest(TestCase):
def testTransform(self) -> None:
obs_feats1 = ObservationFeatures(parameters={"a": 0.0})
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(
RuntimeError, "MergeRepeatedMeasurements requires observations"
):
# test that observations are required
MergeRepeatedMeasurements()
# test nan in covariance
Expand All @@ -37,7 +39,9 @@ def testTransform(self) -> None:
),
features=obs_feats1,
)
with self.assertRaises(NotImplementedError):
with self.assertRaisesRegex(
NotImplementedError, "All metrics must have noise observations."
):
MergeRepeatedMeasurements(observations=[observation])
# test full covariance
observation = Observation(
Expand All @@ -48,7 +52,9 @@ def testTransform(self) -> None:
),
features=obs_feats1,
)
with self.assertRaises(NotImplementedError):
with self.assertRaisesRegex(
NotImplementedError, "Only independent metrics are currently supported."
):
MergeRepeatedMeasurements(observations=[observation])

# test noiseless, different means
Expand All @@ -71,7 +77,11 @@ def testTransform(self) -> None:
features=obs_feats1,
),
]
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
ValueError,
"All repeated arms with noiseless measurements "
"must have the same means.",
):
MergeRepeatedMeasurements(observations=observations)
# test noiseless, same means
observations = [
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/transforms/merge_repeated_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
str, DefaultDict[str, DefaultDict[str, List[float]]]
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
observation_features, observation_data = separate_observations(observations)
#
for j, obsd in enumerate(observation_data):
# This intentionally ignores the trial index
key = Arm.md5hash(observation_features[j].parameters)
Expand Down

0 comments on commit 4f36140

Please sign in to comment.