Commit

Update methods.py to allow for TF2 compatibility
sanju99 committed Sep 25, 2023
1 parent 87fb43a commit b6218ba
Showing 1 changed file with 19 additions and 24 deletions.
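Note for readers: every change in this commit follows the same pattern of replacing TF1-only symbols (tf.placeholder, tf.get_default_graph, tf.get_default_session) with their tf.compat.v1 equivalents so the graph-mode code can run on a TensorFlow 2 install. A minimal sketch of that pattern, not part of this commit, assuming TensorFlow 2.x (where v2 eager behavior must be disabled before placeholders can be created):

    import tensorflow as tf

    # Graph-mode code like DeepExplain cannot run eagerly, so switch off
    # v2 behavior before creating any placeholder.
    tf.compat.v1.disable_v2_behavior()

    x = tf.compat.v1.placeholder(tf.float32, [None, 4])
    y = tf.reduce_sum(x, axis=1)
    with tf.compat.v1.Session() as sess:
        print(sess.run(y, feed_dict={x: [[1., 2., 3., 4.]]}))  # [10.]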
deepexplain/tensorflow/methods.py: 43 changes (19 additions, 24 deletions)
@@ -69,17 +69,23 @@ class AttributionMethod(object):
     """
     def __init__(self, T, X, session, keras_learning_phase=None):
         self.T = T  # target Tensor
-        self.X = X  # input Tensor
+
+        if type(X) is list and len(X) == 1:
+            self.X = X[0]  # input Tensor
+        else:
+            self.X = X
+
         self.Y_shape = [None,] + T.get_shape().as_list()[1:]
         # Most often T contains multiple output units. In this case, it is often necessary to select
         # a single unit to compute contributions for. This can be achieved passing 'ys' as weight for the output Tensor.
-        self.Y = tf.placeholder(tf.float32, self.Y_shape)
+        self.Y = tf.compat.v1.placeholder(tf.float32, self.Y_shape)
         # placeholder_from_data(ys) if ys is not None else 1.0  # Tensor that represents weights for T
         self.T = self.T * self.Y
+        self.has_multiple_inputs = type(self.X) is list or type(self.X) is tuple
         self.symbolic_attribution = None
         self.session = session
         self.keras_learning_phase = keras_learning_phase
-        self.has_multiple_inputs = type(self.X) is list or type(self.X) is tuple
 
         logging.info('Model with multiple inputs: %s' % self.has_multiple_inputs)
 
         # Set baseline
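A note on the idiom above: self.Y is a weight placeholder multiplied into the target tensor, so feeding a one-hot array selects the contribution of a single output unit. A minimal sketch with illustrative names (logits, ys), assuming v2 behavior is already disabled as in the sketch above:

    import numpy as np
    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 3.0]])               # stand-in for T
    ys = tf.compat.v1.placeholder(tf.float32, [None, 3])  # stand-in for self.Y
    masked = logits * ys                                  # self.T = self.T * self.Y

    with tf.compat.v1.Session() as sess:
        # A one-hot weight vector keeps only output unit 2.
        print(sess.run(masked, feed_dict={ys: np.eye(3)[[2]]}))  # [[0. 0. 3.]]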
@@ -179,25 +185,19 @@ def _set_check_baseline(self):
             return
 
         if self.baseline is None:
-            if self.has_multiple_inputs:
-                self.baseline = [np.zeros([1,] + xi.get_shape().as_list()[1:]) for xi in self.X]
-            else:
-                self.baseline = np.zeros([1,] + self.X.get_shape().as_list()[1:])
-
+            raise RuntimeError('Please provide a baseline input when using DeepLIFT')
         else:
             if self.has_multiple_inputs:
                 for i, xi in enumerate(self.X):
-                    if list(self.baseline[i].shape) == xi.get_shape().as_list()[1:]:
-                        self.baseline[i] = np.expand_dims(self.baseline[i], 0)
+                    if self.baseline[i].shape[1:] == xi.shape[1:]:
+                        self.baseline[i] = np.expand_dims(np.squeeze(self.baseline[i]), 0)
                     else:
-                        raise RuntimeError('Baseline shape %s does not match expected shape %s'
-                                           % (self.baseline[i].shape, self.X.get_shape().as_list()[1:]))
+                        raise RuntimeError(f'Baseline shape {self.baseline[i].shape[1:]} does not match expected shape {xi.shape[1:]}')
             else:
-                if list(self.baseline.shape) == self.X.get_shape().as_list()[1:]:
-                    self.baseline = np.expand_dims(self.baseline, 0)
+                if self.baseline.shape[1:] == self.X.shape[1:]:
+                    self.baseline = np.expand_dims(np.squeeze(self.baseline), 0)
                 else:
-                    raise RuntimeError('Baseline shape %s does not match expected shape %s'
-                                       % (self.baseline.shape, self.X.get_shape().as_list()[1:]))
+                    raise RuntimeError(f'Baseline shape {self.baseline.shape[1:]} does not match expected shape {self.X.shape[1:]}')
 
 
 class GradientBasedMethod(AttributionMethod):
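The rewritten check above no longer synthesizes a zero baseline; it requires a user-supplied baseline carrying a leading axis of size 1, compares everything past that axis, and rebuilds the batch axis with np.squeeze/np.expand_dims (note that np.squeeze drops every size-1 axis, not just the leading one, so size-1 channel dimensions would be collapsed too). A sketch of the same comparison with a hypothetical helper, not from the commit:

    import numpy as np

    def normalize_baseline(baseline, expected_shape):
        # Mirror the branch above: compare everything past the batch axis,
        # then rebuild a clean leading batch axis of size 1.
        if baseline.shape[1:] == tuple(expected_shape[1:]):
            return np.expand_dims(np.squeeze(baseline), 0)
        raise RuntimeError(f'Baseline shape {baseline.shape[1:]} does not '
                           f'match expected shape {tuple(expected_shape[1:])}')

    print(normalize_baseline(np.zeros((1, 32, 32, 3)), (None, 32, 32, 3)).shape)
    # (1, 32, 32, 3)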
@@ -373,7 +373,7 @@ def _init_references(self):
         sys.stdout.flush()
         self._deeplift_ref.clear()
         ops = []
-        g = tf.get_default_graph()
+        g = tf.compat.v1.get_default_graph()
         for op in g.get_operations():
             if len(op.inputs) > 0 and not op.name.startswith('gradients'):
                 if op.type in SUPPORTED_ACTIVATIONS:
@@ -556,7 +556,7 @@ def deepexplain_grad(op, grad):
 
 class DeepExplain(object):
 
-    def __init__(self, graph=None, session=tf.get_default_session()):
+    def __init__(self, graph=None, session=tf.compat.v1.get_default_session()):
         self.method = None
         self.batch_size = None
         self.session = session
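One caveat worth knowing about the signature above: a default argument is evaluated once, when the def statement runs (at import time), so session=tf.compat.v1.get_default_session() captures whatever session is default at import, usually None; callers typically pass the session explicitly. A small sketch of how the call behaves, assuming v2 behavior is disabled:

    import tensorflow as tf

    tf.compat.v1.disable_v2_behavior()
    print(tf.compat.v1.get_default_session())  # None outside any 'with' block

    with tf.compat.v1.Session() as sess:
        # Session.__enter__ pushes the session onto the default stack.
        assert tf.compat.v1.get_default_session() is sess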
@@ -635,16 +635,11 @@ def _check_ops(self):
         and needs to be passed in feed_dict.
         :return:
         """
-        g = tf.get_default_graph()
+        g = tf.compat.v1.get_default_graph()
         for op in g.get_operations():
             if len(op.inputs) > 0 and not op.name.startswith('gradients'):
                 if op.type in UNSUPPORTED_ACTIVATIONS:
                     warnings.warn('Detected unsupported activation (%s). '
                                   'This might lead to unexpected or wrong results.' % op.type)
                 elif 'keras_learning_phase' in op.name:
                     self.keras_phase_placeholder = op.outputs[0]
-
-
-
-
-
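For context on the _check_ops scan just above: it walks every operation in the default graph and inspects op.type to warn about unsupported activations. A minimal sketch of the same traversal, assuming TF2 with v2 behavior disabled (the tiny graph built here is illustrative):

    import tensorflow as tf

    tf.compat.v1.disable_v2_behavior()

    x = tf.compat.v1.placeholder(tf.float32, [None, 4])
    h = tf.nn.relu(tf.compat.v1.layers.dense(x, 8))

    # Same traversal as _check_ops: print each op's name and type
    # (e.g. 'Relu', 'MatMul'), skipping gradient ops.
    g = tf.compat.v1.get_default_graph()
    for op in g.get_operations():
        if len(op.inputs) > 0 and not op.name.startswith('gradients'):
            print(op.name, op.type)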