Skip to content

Commit

Permalink
Update id parser in the converter.
Browse files Browse the repository at this point in the history
  • Loading branch information
gmalivenko committed Jan 1, 2019
1 parent 7e393ee commit 5101a0f
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions pytorch2keras/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
from packaging import version
from collections import defaultdict

import torch
import torch.jit
Expand Down Expand Up @@ -134,8 +135,14 @@ def _optimize_graph(graph, operator_export_type=OperatorExportTypes.RAW):

def get_node_id(node):
import re
node_id = re.search(r"[\d]+", node.__str__())
return node_id.group(0)
try:
node_id = re.search(r"\%[\d\w]+", node.__str__())
int(node_id.group(0)[1:])
return node_id.group(0)[1:]
except AttributeError:
return '0'
except ValueError:
return '0'


def pytorch_to_keras(
Expand Down Expand Up @@ -182,11 +189,9 @@ def pytorch_to_keras(
trace.set_graph(_optimize_graph(trace.graph(), False))

trace.graph().lint()
if verbose:
print(trace.graph())

if verbose:
print(list(trace.graph().outputs()))
print(trace.graph())

# Get all graph nodes
nodes = list(trace.graph().nodes())
Expand Down Expand Up @@ -216,7 +221,6 @@ def pytorch_to_keras(
s = i
k += 1
if k == len(seq_to_find):
print('found seq', k, s)
reshape_op = nodes[s + k - 1]
flatten_op = {
'kind': (lambda: 'onnx::Flatten'),
Expand All @@ -226,27 +230,21 @@ def pytorch_to_keras(
'inputs': (lambda: list(reshape_op.inputs())[:1]),
'__str__': (lambda: reshape_op.__str__()),
}
print(flatten_op)
nodes = nodes[:s] + [SimpleNamespace(**flatten_op)] + nodes[s+k:]
# print(nodes)
# exit(0)
break
else:
k = 0
s = -1

print(nodes)
# Collect graph outputs
graph_outputs = [n.uniqueName() for n in trace.graph().outputs()]
print('Graph outputs:', graph_outputs)


graph_inputs = [n.uniqueName() for n in trace.graph().inputs()]
print('Graph inputs:', graph_inputs)

# Collect graph inputs and outputs
graph_outputs = [get_node_id(n) for n in trace.graph().outputs()]
graph_inputs = [get_node_id(n) for n in trace.graph().inputs()]

# Collect model state dict
state_dict = _unique_state_dict(model)
if verbose:
print('Graph inputs:', graph_inputs)
print('Graph outputs:', graph_outputs)
print('State dict:', list(state_dict))

import re
Expand All @@ -267,38 +265,39 @@ def pytorch_to_keras(
input_index = 0
model_inputs = ['input' + i for i in graph_inputs]

group_indices = defaultdict(lambda: 0, {})

for node in nodes:
node_inputs = list(node.inputs())
# print(node_inputs, model_inputs)
node_input_names = []

for node_input in node_inputs:
if node_input.node().scopeName():
node_input_names.append(get_node_id(node_input.node()))
if 'input{0}'.format(get_node_id(node_input.node())) in model_inputs:
node_input_names.append('input{0}'.format(node_input.uniqueName()))
# print(node_input_names)

# if len(node_input_names) == 0:
# if len(node_inputs) > 0:
# if node_inputs[0] in model_inputs:
# node_input_names.append(model_inputs[node_inputs[0]])
# else:
# input_name = 'input{0}'.format(input_index)
# if input_name not in layers:
# continue
# node_input_names.append(input_name)
# input_index += 1
# model_inputs[node_inputs[0]] = input_name
node_input_names.append('input{0}'.format(get_node_id(node_input.node())))
else:
node_input_names.append(get_node_id(node_input.node()))

node_type = node.kind()
# print(dir(node))

node_scope_name = node.scopeName()
node_id = get_node_id(node)
node_weights_name = '.'.join(
re.findall(r'\[([\w\d.\-\[\]\s]+)\]', node_scope_name)
)
node_name_regex = re.findall(r'\[([\w\d.\-\[\]\s]+)\]', node_scope_name)

try:
int(node_name_regex[-1])
node_weigth_group_name = '.'.join(
node_name_regex[:-1]
)
node_weights_name = node_weigth_group_name + '.' + str(group_indices[node_weigth_group_name])
group_indices[node_weigth_group_name] += 1

except ValueError:
node_weights_name = '.'.join(
node_name_regex
)
except IndexError:
node_weights_name = '.'.join(node_input_names)

node_attrs = {k: node[k] for k in node.attributeNames()}

node_outputs = list(node.outputs())
Expand Down

0 comments on commit 5101a0f

Please sign in to comment.