Skip to content

Commit

Permalink
Update conditions on pytorch versions
Browse files Browse the repository at this point in the history
Update conditions to support pytorch sub-versions between 0.4.1 and 1.0.0
  • Loading branch information
EloiZalczer committed Jan 29, 2019
1 parent c08ed38 commit dd961b0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch2keras/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def set_training(model, mode):
model.train(old_mode)


if version.parse('0.4.1') < version.parse(torch.__version__):
if version.parse('1.0.0') <= version.parse(torch.__version__):
from torch._C import ListType

# ONNX can't handle constants that are lists of tensors, which can
Expand All @@ -59,7 +59,7 @@ def _split_tensor_list_constants(g, block):
.setType(ListType.ofTensors()))
node.output().replaceAllUsesWith(lc)

if version.parse('0.4.0') >= version.parse(torch.__version__):
if version.parse('1.0.0') > version.parse(torch.__version__):
def _optimize_graph(graph, aten):
# run dce first to eliminate dead parts of the graph that might have been
# left behind by things like symbolic_override
Expand All @@ -79,7 +79,7 @@ def _optimize_graph(graph, aten):
return graph
else:
def _optimize_graph(graph, operator_export_type=OperatorExportTypes.RAW):
if version.parse('0.4.1') < version.parse(torch.__version__):
if version.parse('1.0.0') <= version.parse(torch.__version__):
torch._C._jit_pass_remove_inplace_ops(graph)
# we record now record some ops like ones/zeros
# into a trace where we previously recorded constants
Expand Down

0 comments on commit dd961b0

Please sign in to comment.