Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

metrics: add SSIM #1217

Merged
merged 18 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ Complete list of metrics
- :class:`~ignite.metrics.Recall`
- :class:`~ignite.metrics.RootMeanSquaredError`
- :class:`~ignite.metrics.RunningAverage`
- :class:`~ignite.metrics.SSIM`
- :class:`~ignite.metrics.TopKCategoricalAccuracy`
- :class:`~ignite.metrics.VariableAccumulation`

Expand Down Expand Up @@ -278,6 +279,8 @@ Complete list of metrics

.. autoclass:: RunningAverage

.. autoclass:: SSIM

.. autoclass:: TopKCategoricalAccuracy

.. autoclass:: VariableAccumulation
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ignite.metrics.recall import Recall
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy

__all__ = [
Expand All @@ -39,4 +40,5 @@
"RunningAverage",
"VariableAccumulation",
"Frequency",
"SSIM",
]
171 changes: 171 additions & 0 deletions ignite/metrics/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from typing import Callable, Sequence

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["SSIM"]


class SSIM(Metric):
"""
Computes Structual Similarity Index Measure

Args:
kernel_size: Size of the gaussian kernel. Default: (11, 11)
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03
output_transform: A callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric.

Returns:
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
A Tensor with SSIM

Example:

To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. code-block:: python

def process_function(engine, batch):
# ...
return y_pred, y
engine = Engine(process_function)
metric = SSIM()
metric.attach(engine, "ssim")

If the output of the engine is not in the format above, ``output_transform`` argument can be used to transform it.
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

def process_function(engine, batch):
# ...
return {'prediction': y_pred, 'target': y, ...}

engine = Engine(process_function)

def output_transform(output):
# `output` variable is returned by above `process_function`
y_pred = output['prediction']
y = output['target']
return y_pred, y # output format is according to `Accuracy` docs

metric = SSIM(output_transform=output_transform)
metric.attach(engine, "ssim")

The user even can use the metric with ``update`` and ``compute`` methods.

.. code-block:: python

>>> y_pred = torch.rand([16, 1, 16, 16])
>>> y = y_pred * 1.25
>>> ssim = SSIM()
>>> ssim.update((y_pred, y))
>>> ssim.compute()
tensor(0.9520)
"""

def __init__(
self,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
data_range: float = None,
k1: float = 0.01,
k2: float = 0.03,
output_transform: Callable = lambda x: x,
):
if len(kernel_size) != 2 or len(sigma) != 2:
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two."
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
)

if any(x % 2 == 0 or x <= 0 for x in kernel_size):
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

if any(y <= 0 for y in sigma):
raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

self.kernel_size = kernel_size
self.sigma = sigma
self.data_range = data_range
self.k1 = k1
self.k2 = k2
super(SSIM, self).__init__(output_transform=output_transform)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_batchwise_ssim = 0.0
self._num_examples = 0

def _gaussian_kernel(self, channel, kernel_size, sigma, device):
def gaussian(kernel_size, sigma, device):
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
gauss = torch.arange(
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)

gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device)
gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output
if y_pred.dtype != y.dtype:
raise TypeError(
f"Expected `y_pred` and `y` to have the same data type. Got y_pred: {y_pred.dtype} and y: {y.dtype}."
)

if y_pred.shape != y.shape:
raise ValueError(
f"Expected `y_pred` and `y` to have the same shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)

if len(y_pred.shape) != 4 or len(y.shape) != 4:
raise ValueError(
f"Expected `y_pred` and `y` to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
)

if self.data_range is None:
self.data_range = max(y_pred.max() - y_pred.min(), y.max() - y.min())
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved

C1 = pow(self.k1 * self.data_range, 2)
C2 = pow(self.k2 * self.data_range, 2)
device = y_pred.device

channel = y_pred.size(1)
kernel = self._gaussian_kernel(channel, self.kernel_size, self.sigma, device)
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
mu_pred = F.conv2d(y_pred, kernel, groups=channel)
mu_target = F.conv2d(y, kernel, groups=channel)

mu_pred_sq = mu_pred.pow(2)
mu_target_sq = mu_target.pow(2)
mu_pred_target = mu_pred * mu_target

sigma_pred_sq = F.conv2d(y_pred * y_pred, kernel, groups=channel) - mu_pred_sq
sigma_target_sq = F.conv2d(y * y, kernel, groups=channel) - mu_target_sq
sigma_pred_target = F.conv2d(y_pred * y, kernel, groups=channel) - mu_pred_target

UPPER = 2 * sigma_pred_target + C2
LOWER = sigma_pred_sq + sigma_target_sq + C2

ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)
self._sum_of_batchwise_ssim += torch.mean(ssim_idx, (1, 2, 3))
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_batchwise_ssim", "_num_examples")
def compute(self) -> torch.Tensor:
if self._num_examples == 0:
raise NotComputableError("SSIM must have at least one example before it can be computed.")
return torch.sum(self._sum_of_batchwise_ssim / self._num_examples)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ neptune-client
tensorboard
pynvml; python_version > '3.5'
trains>=0.15.1
scikit-image
# Examples dependencies
pandas
gym
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ignore = E402, E721
max_line_length = 120

[isort]
known_third_party=dill,matplotlib,numpy,pytest,setuptools,sklearn,torch,torchvision,trains
known_third_party=dill,matplotlib,numpy,pytest,setuptools,skimage,sklearn,torch,torchvision,trains
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
Expand Down
119 changes: 119 additions & 0 deletions tests/ignite/metrics/test_ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os

import numpy as np
import pytest
import torch
from skimage.metrics import structural_similarity as ski_ssim

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import SSIM


def test_zero_div():
ssim = SSIM()
with pytest.raises(NotComputableError):
ssim.compute()


def test_ssim():
ssim = SSIM()
device = "cuda" if torch.cuda.is_available() else "cpu"
y_pred = torch.rand(16, 3, 32, 32, device=device)
y = y_pred + 0.125
ssim.update((y_pred, y))

np_pred = np.random.rand(16, 32, 32, 3)
np_y = np.add(np_pred, 0.125)
np_ssim = ski_ssim(np_pred, np_y, win_size=11, multichannel=True, gaussian_weights=True)

assert isinstance(ssim.compute(), torch.Tensor)
assert torch.allclose(
ssim.compute(), torch.tensor(np_ssim, dtype=torch.float32, device=device), atol=1e-3, rtol=1e-3
)


def _test_distrib_integration(device, tol=1e-3):
from ignite.engine import Engine

rank = idist.get_rank()
n_iters = 100
s = 10
offset = n_iters * s

y_pred = torch.rand(offset * idist.get_world_size(), 3, 28, 28, dtype=torch.float, device=device)
y = y_pred + 0.125

def update(engine, i):
return (
y_pred[i * s + offset * rank : (i + 1) * s + offset * rank],
y[i * s + offset * rank : (i + 1) * s + offset * rank],
)

engine = Engine(update)
SSIM().attach(engine, "ssim")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

assert "ssim" in engine.state.metrics
res = engine.state.metrics["ssim"]

np_pred = np.random.rand(offset * idist.get_world_size(), 28, 28, 3)
ydcjeff marked this conversation as resolved.
Show resolved Hide resolved
np_true = np.add(np_pred, 0.125)
true_res = ski_ssim(np_pred, np_true, win_size=11, multichannel=True, gaussian_weights=True)

assert pytest.approx(res, rel=tol) == true_res


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(local_rank, distributed_context_single_node_nccl):

device = "cuda:{}".format(local_rank)
_test_distrib_integration(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):
device = "cpu"
_test_distrib_integration(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = "cpu"
_test_distrib_integration(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])
_test_distrib_integration(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration(device, tol=1e-3)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration(device, tol=1e-3)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)