From 0f0ad00f2325e8429bed171de9d0dbbea88823ae Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Fri, 9 Sep 2022 10:43:01 +0430 Subject: [PATCH] Fix rare test_binary_input failure in Precision & Recall --- tests/ignite/metrics/test_precision.py | 6 +++--- tests/ignite/metrics/test_recall.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 36872e49157..43689845419 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -127,9 +127,9 @@ def _test(y_pred, y, batch_size): assert isinstance(pr.compute(), torch.Tensor if not average else float) pr_compute = pr.compute().numpy() if not average else pr.compute() sk_average_parameter = ignite_average_to_scikit_average(average, "binary") - assert precision_score(np_y, np_y_pred, average=sk_average_parameter, zero_division=0) == pytest.approx( - pr_compute - ) + assert precision_score( + np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1], zero_division=0 + ) == pytest.approx(pr_compute) def get_test_cases(): diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index 94f4fc160ef..8aae0df95ec 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -131,7 +131,7 @@ def _test(y_pred, y, batch_size): assert isinstance(re.compute(), torch.Tensor if not average else float) re_compute = re.compute().numpy() if not average else re.compute() sk_average_parameter = ignite_average_to_scikit_average(average, "binary") - assert recall_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(re_compute) + assert recall_score(np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1]) == pytest.approx(re_compute) def get_test_cases():