Skip to content

Commit

Permalink
Merge pull request #82 from Vent1narc/master
Browse files Browse the repository at this point in the history
onnx::Elu support
  • Loading branch information
gmalivenko committed Jun 13, 2019
2 parents c76ba52 + 324b529 commit 73d6ae9
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
[![Readthedocs](https://img.shields.io/readthedocs/pytorch2keras.svg)](https://pytorch2keras.readthedocs.io/en/latest/)


PyTorch to Keras model converter.
PyTorch to Keras model converter.

## Installation

Expand Down Expand Up @@ -207,6 +207,7 @@ Options:
## Supported layers

* Activations:
+ ELU
+ ReLU
+ LeakyReLU
+ SELU
Expand Down Expand Up @@ -319,4 +320,4 @@ Options:
Look at the `tests` directory.

## License
This software is covered by MIT License.
This software is covered by MIT License.
26 changes: 26 additions & 0 deletions pytorch2keras/activation_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@
from .common import random_string


def convert_elu(params, w_name, scope_name, inputs, layers, weights, names):
    """
    Convert an ONNX ELU activation node into a Keras Activation layer.

    Args:
        params: dictionary with layer parameters
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict
        names: naming mode for keras layers ('short', 'keep', or other)
    """
    print('Converting elu ...')

    # Pick the Keras layer name according to the requested naming mode.
    if names == 'short':
        layer_name = 'ELU' + random_string(4)
    else:
        layer_name = w_name if names == 'keep' else w_name + str(random.random())

    activation = keras.layers.Activation('elu', name=layer_name)
    layers[scope_name] = activation(layers[inputs[0]])


def convert_relu(params, w_name, scope_name, inputs, layers, weights, names):
"""
Convert relu layer.
Expand Down
2 changes: 1 addition & 1 deletion pytorch2keras/convolution_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def convolve_lambda(i, k):
return tf.nn.conv2d(i, k, strides=[1, stride_y, stride_x, 1], padding='VALID')

input_groups = tf.split(axis=3, num_or_size_splits=groups, value=x)
weight_groups = tf.split(axis=3, num_or_size_splits=groups, value=W.transpose(0, 1, 2, 3))
weight_groups = tf.split(axis=3, num_or_size_splits=groups, value=np.array(W, dtype=np.float32).transpose(0, 1, 2, 3))
output_groups = [convolve_lambda(i, k) for i, k in zip(input_groups, weight_groups)]

layer = tf.concat(axis=3, values=output_groups)
Expand Down
3 changes: 2 additions & 1 deletion pytorch2keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .elementwise_layers import convert_elementwise_add, convert_elementwise_mul, \
convert_elementwise_div, convert_elementwise_sub
from .activation_layers import convert_relu, convert_lrelu, convert_selu, \
convert_softmax, convert_sigmoid, convert_tanh, convert_hardtanh
convert_softmax, convert_sigmoid, convert_tanh, convert_hardtanh, convert_elu
from .pooling_layers import convert_avgpool, convert_maxpool, convert_maxpool3, \
convert_adaptive_avg_pool2d, convert_adaptive_max_pool2d
from .normalization_layers import convert_batchnorm, convert_instancenorm, convert_dropout
Expand Down Expand Up @@ -45,6 +45,7 @@
'onnx::Sub': convert_elementwise_sub,
'onnx::Sum': convert_sum,
'onnx::Concat': convert_concat,
'onnx::Elu': convert_elu,
'onnx::Relu': convert_relu,
'onnx::LeakyRelu': convert_lrelu,
'onnx::Sigmoid': convert_sigmoid,
Expand Down

0 comments on commit 73d6ae9

Please sign in to comment.