forked from NifTK/NiftyNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rand_flip.py
executable file
·60 lines (49 loc) · 2.05 KB
/
rand_flip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import warnings
import numpy as np
from niftynet.layer.base_layer import RandomisedLayer
# NOTE(review): these filters are interpreter-wide. Importing this module
# suppresses UserWarning and RuntimeWarning for the whole process, not just
# warnings raised by this layer, because the warnings filter list is global.
# The reason for silencing them is not visible in this file; consider
# scoping the suppression with ``warnings.catch_warnings()`` where needed.
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)
class RandomFlipLayer(RandomisedLayer):
    """
    Add a random flipping layer as pre-processing.

    Each call to ``randomise`` draws one independent flip decision per
    spatial axis. Applying the layer then flips the data along every axis
    that is listed in ``flip_axes`` and was sampled as "flip".
    """

    def __init__(self,
                 flip_axes,
                 flip_probability=0.5,
                 name='random_flip'):
        """
        :param flip_axes: a list of indices over which to flip
        :param flip_probability: the probability of flipping each axis,
            drawn independently per axis (default = 0.5)
        :param name: layer name, forwarded to ``RandomisedLayer``
        """
        super(RandomFlipLayer, self).__init__(name=name)
        self._flip_axes = flip_axes
        self._flip_probability = flip_probability
        # Boolean array of per-axis flip decisions; populated by randomise().
        self._rand_flip = None

    def randomise(self, spatial_rank=3):
        """
        Draw a fresh set of per-axis flip decisions.

        :param spatial_rank: number of spatial axes to sample for;
            non-integer values (e.g. 2.5) are floored before use
        """
        spatial_rank = int(np.floor(spatial_rank))
        self._rand_flip = np.random.random(
            size=spatial_rank) < self._flip_probability

    def _apply_transformation(self, image):
        """
        Flip ``image`` according to the most recent ``randomise`` call.

        Only axes ``0 .. spatial_rank - 1`` are candidates. Entries of
        ``flip_axes`` outside that range (including negative indices) are
        never matched, so those axes are never flipped.

        :param image: array-like data to flip
        :return: the result of the applied ``np.flip`` calls, or ``image``
            itself if no axis is flipped
        :raises AssertionError: if ``randomise`` has not been called yet
        """
        # Explicit raise rather than ``assert``: the check must survive -O.
        if self._rand_flip is None:
            raise AssertionError("Flip is unset -- Error!")
        for axis_number, do_flip in enumerate(self._rand_flip):
            if axis_number in self._flip_axes and do_flip:
                image = np.flip(image, axis=axis_number)
        return image

    def layer_op(self, inputs, interp_orders=None, *args, **kwargs):
        """
        Apply the current random flip to ``inputs``.

        :param inputs: ``None`` (returned unchanged), a single array, or a
            dict of arrays keyed by field name. A dict is updated in place
            and the same object is returned.
        :param interp_orders: when ``inputs`` is a dict, a dict mapping each
            field to a sequence of interpolation orders. Orders for a field
            must be all negative or all non-negative; fields with negative
            orders are left untouched.
        :return: the (possibly) flipped inputs
        :raises AssertionError: if a field mixes negative and non-negative
            interpolation orders
        :raises KeyError: if a dict field has no entry in ``interp_orders``
        """
        if inputs is None:
            return inputs
        if isinstance(inputs, dict) and isinstance(interp_orders, dict):
            # Replacing values of existing keys does not resize the dict,
            # so updating ``inputs`` while iterating over it is safe.
            for field, image_data in inputs.items():
                orders = interp_orders[field]
                # Explicit raise rather than ``assert``: under -O an assert
                # would vanish and a mixed field would be silently decided by
                # its first order alone.
                if not (all(i < 0 for i in orders) or
                        all(i >= 0 for i in orders)):
                    raise AssertionError(
                        'Cannot combine interpolatable and '
                        'non-interpolatable data')
                if orders[0] < 0:
                    # Negative orders mark a field that must not be flipped.
                    continue
                inputs[field] = self._apply_transformation(image_data)
        else:
            inputs = self._apply_transformation(inputs)
        return inputs