From 5101a0fe2a89683ea3df8cb896610aaa3da04fb1 Mon Sep 17 00:00:00 2001 From: Grigory Malivenko Date: Wed, 2 Jan 2019 02:27:05 +0300 Subject: [PATCH] Update id parser in the converter. --- pytorch2keras/converter.py | 77 +++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/pytorch2keras/converter.py b/pytorch2keras/converter.py index ac18086..22b0b4a 100644 --- a/pytorch2keras/converter.py +++ b/pytorch2keras/converter.py @@ -4,6 +4,7 @@ import contextlib from packaging import version +from collections import defaultdict import torch import torch.jit @@ -134,8 +135,14 @@ def _optimize_graph(graph, operator_export_type=OperatorExportTypes.RAW): def get_node_id(node): import re - node_id = re.search(r"[\d]+", node.__str__()) - return node_id.group(0) + try: + node_id = re.search(r"\%[\d\w]+", node.__str__()) + int(node_id.group(0)[1:]) + return node_id.group(0)[1:] + except AttributeError: + return '0' + except ValueError: + return '0' def pytorch_to_keras( @@ -182,11 +189,9 @@ def pytorch_to_keras( trace.set_graph(_optimize_graph(trace.graph(), False)) trace.graph().lint() - if verbose: - print(trace.graph()) if verbose: - print(list(trace.graph().outputs())) + print(trace.graph()) # Get all graph nodes nodes = list(trace.graph().nodes()) @@ -216,7 +221,6 @@ def pytorch_to_keras( s = i k += 1 if k == len(seq_to_find): - print('found seq', k, s) reshape_op = nodes[s + k - 1] flatten_op = { 'kind': (lambda: 'onnx::Flatten'), @@ -226,27 +230,21 @@ def pytorch_to_keras( 'inputs': (lambda: list(reshape_op.inputs())[:1]), '__str__': (lambda: reshape_op.__str__()), } - print(flatten_op) nodes = nodes[:s] + [SimpleNamespace(**flatten_op)] + nodes[s+k:] - # print(nodes) - # exit(0) break else: k = 0 s = -1 - print(nodes) - # Collect graph outputs - graph_outputs = [n.uniqueName() for n in trace.graph().outputs()] - print('Graph outputs:', graph_outputs) - - - graph_inputs = [n.uniqueName() for n in trace.graph().inputs()] - print('Graph inputs:', graph_inputs) - + # Collect graph inputs and outputs + graph_outputs = [get_node_id(n) for n in trace.graph().outputs()] + graph_inputs = [get_node_id(n) for n in trace.graph().inputs()] + # Collect model state dict state_dict = _unique_state_dict(model) if verbose: + print('Graph inputs:', graph_inputs) + print('Graph outputs:', graph_outputs) print('State dict:', list(state_dict)) import re @@ -267,38 +265,39 @@ def pytorch_to_keras( input_index = 0 model_inputs = ['input' + i for i in graph_inputs] + group_indices = defaultdict(lambda: 0, {}) + for node in nodes: node_inputs = list(node.inputs()) - # print(node_inputs, model_inputs) node_input_names = [] for node_input in node_inputs: - if node_input.node().scopeName(): - node_input_names.append(get_node_id(node_input.node())) if 'input{0}'.format(get_node_id(node_input.node())) in model_inputs: - node_input_names.append('input{0}'.format(node_input.uniqueName())) - # print(node_input_names) - - # if len(node_input_names) == 0: - # if len(node_inputs) > 0: - # if node_inputs[0] in model_inputs: - # node_input_names.append(model_inputs[node_inputs[0]]) - # else: - # input_name = 'input{0}'.format(input_index) - # if input_name not in layers: - # continue - # node_input_names.append(input_name) - # input_index += 1 - # model_inputs[node_inputs[0]] = input_name + node_input_names.append('input{0}'.format(get_node_id(node_input.node()))) + else: + node_input_names.append(get_node_id(node_input.node())) node_type = node.kind() - # print(dir(node)) node_scope_name = node.scopeName() node_id = get_node_id(node) - node_weights_name = '.'.join( - re.findall(r'\[([\w\d.\-\[\]\s]+)\]', node_scope_name) - ) + node_name_regex = re.findall(r'\[([\w\d.\-\[\]\s]+)\]', node_scope_name) + + try: + int(node_name_regex[-1]) + node_weigth_group_name = '.'.join( + node_name_regex[:-1] + ) + node_weights_name = node_weigth_group_name + '.' + str(group_indices[node_weigth_group_name]) + group_indices[node_weigth_group_name] += 1 + + except ValueError: + node_weights_name = '.'.join( + node_name_regex + ) + except IndexError: + node_weights_name = '.'.join(node_input_names) + node_attrs = {k: node[k] for k in node.attributeNames()} node_outputs = list(node.outputs())