Skip to content

Commit

Permalink
Add more imports for lambda layers.
Browse files Browse the repository at this point in the history
  • Loading branch information
gmalivenko committed Dec 31, 2018
1 parent ede034c commit 40cccf1
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pytorch2keras/reshape_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def convert_reshape(params, w_name, scope_name, inputs, layers, weights, names):
print('Cannot deduct batch size! It will be omitted, but result may be wrong.')

def target_layer(x, shape=layers[inputs[1]]):
import tensorflow as tf
return tf.reshape(x, shape)

lambda_layer = keras.layers.Lambda(target_layer)
Expand Down Expand Up @@ -119,6 +120,7 @@ def convert_squeeze(params, w_name, scope_name, inputs, layers, weights, names):
raise AssertionError('Cannot convert squeeze by multiple dimensions')

def target_layer(x, axis=int(params['axes'][0])):
import tensorflow as tf
return tf.squeeze(x, axis=axis)

lambda_layer = keras.layers.Lambda(target_layer)
Expand Down Expand Up @@ -171,6 +173,7 @@ def convert_shape(params, w_name, scope_name, inputs, layers, weights, names):
print('Converting shape ...')

def target_layer(x):
import tensorflow as tf
return tf.shape(x)

lambda_layer = keras.layers.Lambda(target_layer)
Expand Down

0 comments on commit 40cccf1

Please sign in to comment.