Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metrics: add SSIM #1217

Merged
merged 18 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Complete list of metrics
- :class:`~ignite.metrics.Recall`
- :class:`~ignite.metrics.RootMeanSquaredError`
- :class:`~ignite.metrics.RunningAverage`
- :class:`~ignite.metrics.SSIM`
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
- :class:`~ignite.metrics.VariableAccumulation`

Expand Down Expand Up @@ -278,6 +279,8 @@ Complete list of metrics

.. autoclass:: RunningAverage

.. autoclass:: SSIM

.. autoclass:: TopKCategoricalAccuracy

.. autoclass:: VariableAccumulation
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ignite.metrics.recall import Recall
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy

__all__ = [
Expand All @@ -39,4 +40,5 @@
"RunningAverage",
"VariableAccumulation",
"Frequency",
"SSIM",
]
170 changes: 170 additions & 0 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Callable, Sequence, Union

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

# Names exported by ``from ignite.metrics.ssim import *``.
__all__ = ["SSIM"]


class SSIM(Metric):
    """
    Computes Structural Similarity Index Measure

    Args:
        data_range (int or float): Range of the image. Typically, ``1.0`` or ``255``.
        kernel_size (int or list or tuple of int): Size of the kernel. Default: (11, 11)
        sigma (float or list or tuple of float): Standard deviation of the gaussian kernel.
            Argument is used if ``gaussian=True``. Default: (1.5, 1.5)
        k1 (float): Parameter of SSIM. Default: 0.01
        k2 (float): Parameter of SSIM. Default: 0.03
        gaussian (bool): ``True`` to use gaussian kernel, ``False`` to use uniform kernel
        output_transform (callable, optional): A callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric.

    Example:

    To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
    The output of the engine's ``process_function`` needs to be in the format of
    ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

    ``y_pred`` and ``y`` can be un-normalized or normalized image tensors. Depending on that, the user might need
    to adjust ``data_range``. ``y_pred`` and ``y`` should have the same shape.

    .. code-block:: python

        def process_function(engine, batch):
            # ...
            return y_pred, y
        engine = Engine(process_function)
        metric = SSIM(data_range=1.0)
        metric.attach(engine, "ssim")
    """

    def __init__(
        self,
        data_range: Union[int, float],
        kernel_size: Union[int, Sequence[int]] = (11, 11),
        sigma: Union[float, Sequence[float]] = (1.5, 1.5),
        k1: float = 0.01,
        k2: float = 0.03,
        gaussian: bool = True,
        output_transform: Callable = lambda x: x,
    ):
        """Validate the arguments and precompute the SSIM constants and kernel.

        Raises:
            ValueError: if ``kernel_size`` is neither an int nor a sequence, if
                ``sigma`` is neither a float nor a sequence, if any kernel
                dimension is even or non-positive, or if any sigma is
                non-positive.
        """
        # A scalar kernel size is used for both dimensions.
        if isinstance(kernel_size, int):
            self.kernel_size = [kernel_size, kernel_size]
        elif isinstance(kernel_size, Sequence):
            self.kernel_size = kernel_size
        else:
            raise ValueError("Argument kernel_size should be either int or a sequence of int.")

        # Only a real ``float`` scalar is accepted here; an ``int`` sigma falls
        # through to the ValueError below.
        if isinstance(sigma, float):
            self.sigma = [sigma, sigma]
        elif isinstance(sigma, Sequence):
            self.sigma = sigma
        else:
            raise ValueError("Argument sigma should be either float or a sequence of float.")

        if any(x % 2 == 0 or x <= 0 for x in self.kernel_size):
            raise ValueError("Expected kernel_size to have odd positive number. Got {}.".format(kernel_size))

        if any(y <= 0 for y in self.sigma):
            raise ValueError("Expected sigma to have positive number. Got {}.".format(sigma))

        self.gaussian = gaussian
        # Stabilising constants of the SSIM formula.
        self.c1 = (k1 * data_range) ** 2
        self.c2 = (k2 * data_range) ** 2
        # Padding that keeps the filtered maps at the input's spatial size.
        self.pad_h = (self.kernel_size[0] - 1) // 2
        self.pad_w = (self.kernel_size[1] - 1) // 2
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
        super(SSIM, self).__init__(output_transform=output_transform)

    @reinit__is_reduced
    def reset(self) -> None:
        """Clear the accumulated SSIM sum and example count and rebuild the kernel.

        ``update`` replaces ``_kernel`` with a 4-D, channel-expanded version;
        rebuilding the 2-D kernel here lets the next ``update`` expand it again.
        """
        self._sum_of_batchwise_ssim = 0.0
        self._num_examples = 0
        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)

    def _uniform(self, kernel_size: int) -> torch.Tensor:
        """Return a ``(1, kernel_size)`` box filter.

        Taps whose centred coordinate lies in ``[-2.5, 2.5]`` get the weight
        ``1 / 5``; every other tap is zero.
        """
        # NOTE(review): because the support is fixed at [-2.5, 2.5], the taps sum
        # to 1 only when kernel_size >= 5 - confirm this is intended.
        upper, lower = 2.5, -2.5
        coords = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32)
        inside = (coords >= lower) & (coords <= upper)
        weights = torch.where(inside, torch.full_like(coords, 1 / (upper - lower)), torch.zeros_like(coords))

        return weights.unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian(self, kernel_size: int, sigma: float) -> torch.Tensor:
        """Return a normalised ``(1, kernel_size)`` gaussian filter with std ``sigma``."""
        kernel = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32)
        gauss = torch.exp(-kernel.pow(2) / (2 * pow(sigma, 2)))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    def _gaussian_or_uniform_kernel(self, kernel_size: Sequence[int], sigma: Sequence[float]) -> torch.Tensor:
        """Build the 2-D kernel as the outer product of two 1-D filters.

        The filter type is selected by ``self.gaussian``; ``sigma`` is ignored
        for the uniform kernel.
        """
        if self.gaussian:
            kernel_x = self._gaussian(kernel_size[0], sigma[0])
            kernel_y = self._gaussian(kernel_size[1], sigma[1])
        else:
            kernel_x = self._uniform(kernel_size[0])
            kernel_y = self._uniform(kernel_size[1])

        return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate the SSIM of one batch.

        Args:
            output: pair ``(y_pred, y)`` of tensors with identical dtype and
                identical ``BxCxHxW`` shape.

        Raises:
            TypeError: if ``y_pred`` and ``y`` have different dtypes.
            ValueError: if the shapes differ or are not 4-dimensional.
        """
        y_pred, y = output
        if y_pred.dtype != y.dtype:
            raise TypeError(
                "Expected y_pred and y to have the same data type. Got y_pred: {} and y: {}.".format(
                    y_pred.dtype, y.dtype
                )
            )

        if y_pred.shape != y.shape:
            raise ValueError(
                "Expected y_pred and y to have the same shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
            )

        if len(y_pred.shape) != 4 or len(y.shape) != 4:
            raise ValueError(
                "Expected y_pred and y to have BxCxHxW shape. Got y_pred: {} and y: {}.".format(y_pred.shape, y.shape)
            )

        channel = y_pred.size(1)
        # Expand the 2-D kernel once into one filter per channel (depthwise conv).
        if len(self._kernel.shape) < 4:
            self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)

        y_pred = F.pad(y_pred, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect")
        y = F.pad(y, (self.pad_w, self.pad_w, self.pad_h, self.pad_h), mode="reflect")

        # Filter the five required maps in a single convolution call by stacking
        # them along the batch dimension.
        input_list = torch.cat([y_pred, y, y_pred * y_pred, y * y, y_pred * y])
        outputs = F.conv2d(input_list, self._kernel, groups=channel)

        # Split the stacked result back into its five batch-sized chunks.
        batch_size = y_pred.size(0)
        output_list = [outputs[i * batch_size : (i + 1) * batch_size] for i in range(5)]

        mu_pred_sq = output_list[0].pow(2)
        mu_target_sq = output_list[1].pow(2)
        mu_pred_target = output_list[0] * output_list[1]

        sigma_pred_sq = output_list[2] - mu_pred_sq
        sigma_target_sq = output_list[3] - mu_target_sq
        sigma_pred_target = output_list[4] - mu_pred_target

        a1 = 2 * mu_pred_target + self.c1
        a2 = 2 * sigma_pred_target + self.c2
        b1 = mu_pred_sq + mu_target_sq + self.c1
        b2 = sigma_pred_sq + sigma_target_sq + self.c2

        ssim_idx = (a1 * a2) / (b1 * b2)
        # Reduce the per-sample means to a scalar before accumulating: keeping a
        # (B,) accumulator breaks (or silently broadcasts) as soon as two
        # batches have different sizes, e.g. a smaller final batch.
        self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum()
        self._num_examples += y.shape[0]

    @sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
    def compute(self) -> torch.Tensor:
        """Return the mean SSIM over all examples seen since the last reset.

        Raises:
            NotComputableError: if no example has been accumulated.
        """
        if self._num_examples == 0:
            raise NotComputableError("SSIM must have at least one example before it can be computed.")
        return self._sum_of_batchwise_ssim / self._num_examples
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ neptune-client
tensorboard
pynvml; python_version > '3.5'
trains>=0.15.1
scikit-image>=0.15.0
# Examples dependencies
pandas
gym
177 changes: 177 additions & 0 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import os

import pytest
import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import SSIM

try:
from skimage.metrics import structural_similarity as ski_ssim
except ImportError:
from skimage.measure import compare_ssim as ski_ssim


def test_zero_div():
    """``compute`` on a metric that has received no data raises ``NotComputableError``."""
    metric = SSIM(data_range=1.0)
    with pytest.raises(NotComputableError):
        metric.compute()


def test_invalid_ssim():
    """Invalid ``kernel_size`` / ``sigma`` arguments are rejected with ``ValueError``."""
    # The constructor is what raises, so only the constructor goes inside
    # ``pytest.raises``; any statement placed after it would never execute.
    invalid_cases = [
        ({"kernel_size": 10}, r"Expected kernel_size to have odd positive number. Got 10."),
        ({"kernel_size": -1}, r"Expected kernel_size to have odd positive number. Got -1."),
        ({"kernel_size": 1.0}, r"Argument kernel_size should be either int or a sequence of int."),
        # Integer sigmas are rejected by the type check before the sign check.
        ({"sigma": -1}, r"Argument sigma should be either float or a sequence of float."),
        ({"sigma": 1}, r"Argument sigma should be either float or a sequence of float."),
    ]

    for kwargs, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            SSIM(data_range=1.0, **kwargs)


def test_ssim():
    """Compare ``SSIM`` with scikit-image for a gaussian and a uniform kernel."""
    # (metric kwargs, image side length, skimage win_size, skimage gaussian_weights)
    scenarios = [
        ({"data_range": 1.0}, 64, 11, True),
        ({"data_range": 1.0, "gaussian": False, "kernel_size": 7}, 227, 7, False),
    ]

    for metric_kwargs, side, win_size, gaussian_weights in scenarios:
        metric = SSIM(**metric_kwargs)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        y_pred = torch.rand(16, 3, side, side, device=device)
        y = y_pred * 0.65
        metric.update((y_pred, y))

        # scikit-image expects channels-last numpy arrays.
        np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
        np_y = np_pred * 0.65
        np_ssim = ski_ssim(
            np_pred, np_y, win_size=win_size, multichannel=True, gaussian_weights=gaussian_weights, data_range=1.0
        )
        expected = torch.tensor(np_ssim, dtype=torch.float64, device=device)

        assert isinstance(metric.compute(), torch.Tensor)
        assert torch.allclose(metric.compute(), expected, atol=1e-4)


def _test_distrib_integration(device, tol=1e-4):
    """Run ``SSIM`` through an ``Engine`` and compare the result with scikit-image.

    Each rank feeds its own slice of a shared random tensor; the metric value
    is checked against scikit-image's SSIM of the full tensor, first with the
    default gaussian kernel and then with a uniform 7x7 kernel.
    """
    from ignite.engine import Engine

    rank = idist.get_rank()
    n_iters = 100
    s = 10
    offset = n_iters * s

    y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device)
    y = y_pred * 0.65

    def update(engine, i):
        # Batch ``i`` of size ``s`` inside this rank's ``offset``-sized slice.
        start = i * s + offset * rank
        stop = (i + 1) * s + offset * rank
        return y_pred[start:stop], y[start:stop]

    # (metric kwargs, skimage win_size, skimage gaussian_weights)
    scenarios = [
        ({"data_range": 1.0}, 11, True),
        ({"data_range": 1.0, "gaussian": False, "kernel_size": 7}, 7, False),
    ]

    for metric_kwargs, win_size, gaussian_weights in scenarios:
        engine = Engine(update)
        SSIM(**metric_kwargs).attach(engine, "ssim")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=1)

        assert "ssim" in engine.state.metrics
        res = engine.state.metrics["ssim"]

        np_pred = y_pred.permute(0, 2, 3, 1).cpu().numpy()
        np_true = np_pred * 0.65
        true_res = ski_ssim(
            np_pred, np_true, win_size=win_size, multichannel=True, gaussian_weights=gaussian_weights, data_range=1.0
        )

        assert pytest.approx(res, abs=tol) == true_res


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):
    """Run the integration check on the CUDA device indexed by ``local_rank``."""
    _test_distrib_integration("cuda:{}".format(local_rank))


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):
    """Run the integration check on CPU under the single-node gloo fixture."""
    _test_distrib_integration("cpu")


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
    """Run the integration check on CPU under the multi-node gloo fixture."""
    _test_distrib_integration("cpu")


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
    """Run the integration check on the CUDA device given by the fixture's ``local_rank``."""
    local_rank = distributed_context_multi_node_nccl["local_rank"]
    _test_distrib_integration("cuda:{}".format(local_rank))


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
    """Run the integration check on ``idist.device()`` with a looser tolerance."""
    _test_distrib_integration(idist.device(), tol=1e-3)


def _test_distrib_xla_nprocs(index):
    """Per-process entry point; ``index`` is accepted but not used."""
    _test_distrib_integration(idist.device(), tol=1e-3)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
    """Launch the integration check in ``NUM_TPU_WORKERS`` processes via ``xmp_executor``."""
    num_workers = int(os.environ["NUM_TPU_WORKERS"])
    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=num_workers)