Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
gmalivenko committed Aug 6, 2018
2 parents e2e4c0a + 809415d commit 0984dfe
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions pytorch2keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,8 @@ def convert_sum(
print('Converting Sum ...')

def target_layer(x):
return keras.backend.sum(x)
import keras.backend as K
return K.sum(x)

lambda_layer = keras.layers.Lambda(target_layer)
layers[scope_name] = lambda_layer(layers[inputs[0]])
Expand Down Expand Up @@ -1033,10 +1034,11 @@ def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights, shor
print('Converting reduce_sum ...')

keepdims = params['keepdims'] > 0
axis = np.array(params['axes'])
axis = params['axes']

def target_layer(x, keepdims=keepdims, axis=axis):
return keras.backend.sum(x, keepdims=keepdims, axis=axis)
import keras.backend as K
return K.sum(x, keepdims=keepdims, axis=axis)

lambda_layer = keras.layers.Lambda(target_layer)
layers[scope_name] = lambda_layer(layers[inputs[0]])
Expand All @@ -1057,12 +1059,15 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights, short_
"""
print('Converting constant ...')

# def target_layer(x, params=params):
# return keras.backend.constant(np.float32(params['value']))
params_list = params['value'].numpy().tolist()

# lambda_layer = keras.layers.Lambda(target_layer)
# layers[scope_name] = lambda_layer(layers[inputs[0]])
layers[scope_name] = np.float32(params['value'])
def target_layer(x):
import keras.backend as K
return K.constant(params_list)

lambda_layer = keras.layers.Lambda(target_layer)
layers[scope_name] = lambda_layer(layers['input0']) # Temporary fix for nonexistent input name created by converter.py
# layers[scope_name] = params['value']


def convert_upsample(params, w_name, scope_name, inputs, layers, weights, short_names):
Expand Down

0 comments on commit 0984dfe

Please sign in to comment.