diff --git a/src/caffe2-metadata.json b/src/caffe2-metadata.json index 3e368ff27db..934d808251c 100644 --- a/src/caffe2-metadata.json +++ b/src/caffe2-metadata.json @@ -502,7 +502,7 @@ { "name": "SpatialBN", "schema": { - "attributes ": [ + "attributes": [ { "default": 0, "description": "If set to nonzero, run spatial batch normalization in test mode.", diff --git a/src/keras.js b/src/keras.js index 914767d5421..99f4c4e3fc8 100644 --- a/src/keras.js +++ b/src/keras.js @@ -660,10 +660,6 @@ keras.Node = class { return this._attributes; } - get dependencies() { - return []; - } - get inner() { return this._inner; } diff --git a/src/onnx.js b/src/onnx.js index 598b41880bb..fadccb2d82b 100644 --- a/src/onnx.js +++ b/src/onnx.js @@ -531,10 +531,6 @@ onnx.Node = class { get outputs() { return this._outputs; } - - get dependencies() { - return []; - } }; onnx.Attribute = class { diff --git a/src/tf-metadata.json b/src/tf-metadata.json index 4b69a5be197..6259718b838 100644 --- a/src/tf-metadata.json +++ b/src/tf-metadata.json @@ -26668,6 +26668,10 @@ ], "name": "T", "type": "type" + }, + { + "name": "_gradient_op_type", + "visible": false } ], "description": "*NOTE*: `Maximum` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)", diff --git a/src/tf.js b/src/tf.js index 40887f933b8..ae759311253 100644 --- a/src/tf.js +++ b/src/tf.js @@ -213,7 +213,6 @@ tf.Graph = class { this._metadata = new tf.GraphMetadata(metadata, metaGraph.meta_info_def); this._name = name; this._operators = {}; - this._inputMap = {}; if (metaGraph.graph_def) { var graph = metaGraph.graph_def; if (graph.versions) { @@ -232,75 +231,14 @@ tf.Graph = class { this._operators[node.op] = (this._operators[node.op] || 0) + 1; }); } - } - - get operators() { - return this._operators; - } - - get name() { - return this._name; - } - - get version() { - return this._version; - } - - get tags() { - return this._tags; - } - - get groups() { - return false; - // TODO return true; - } - - get inputs() { - this._update(); - return Object.keys(this._inputMap).map((key) => { - return this._inputMap[key]; - }); - } - - get outputs() { - this._update(); - return []; - } - - get nodes() { - this._update(); - var results = []; - if (this._metaGraph.graph_def) { - this._metaGraph.graph_def.node.forEach((node) => { - if (node.output.filter(output => !output.startsWith('^')) != 0 || - node.input.filter(input => !input.startsWith('^')).length > 0) { - var id = node.name; - if (!this._initializerMap[id] && !this._inputMap[id] /* && node.op != 'NoOp' */) { - results.push(new tf.Node(this, node)); - } - } - }); - } - return results; - } - - get metadata() { - return this._metadata; - } - - get namespaces() { - return this._namespaces; - } - - _update() { - if (!this._nodeMap && this._metaGraph.graph_def.node) { - this._nodeMap = {}; + var nodeMap = {}; + if (this._metaGraph.graph_def.node) { this._namespaces = {}; var nodes = this._metaGraph.graph_def.node; nodes.forEach((node) => { var name = node.name; - this._nodeMap[name] = node; - if (node.op != 'Const') { + nodeMap[name] = node; + if (node.op != 'Const') { var lastIndex = name.lastIndexOf('/'); if (lastIndex != -1) { var namespace = name.substring(0, lastIndex); @@ -310,115 +248,134 @@ tf.Graph = class { node.output = []; }); nodes.forEach((node) => { - for (var i = 0; i < node.input.length; i++) - { - var split = node.input[i].split(':', 2); + var inputs = node.input; + node.input = []; + node.controlDependencies = []; + inputs.forEach((input) => { + var split = input.split(':', 2); var inputName = split[0]; - if (!inputName.startsWith('^')) { - var outputIndex = split.length == 1 ? 0 : parseInt(split[1]); - var outputName = inputName; - var outputNode = this._nodeMap[outputName]; - node.input[i] = outputIndex == 0 ? inputName : inputName + ':' + outputIndex.toString(); - if (outputNode) { - for (var j = outputNode.output.length; j <= outputIndex; j++) { - outputNode.output.push(''); - } - outputNode.output[outputIndex] = node.input[i]; - } + var outputIndex = split.length == 1 ? 0 : parseInt(split[1]); + var outputName = inputName.startsWith('^') ? inputName.substring(1) : inputName; + var outputNode = nodeMap[outputName]; + outputName = outputIndex == 0 ? outputName : outputName + ':' + outputIndex.toString(); + if (inputName.startsWith('^')) { + node.controlDependencies.push(outputName); } else { - var sourceName = inputName.substring(1); - var sourceNode = this._nodeMap[sourceName]; - if (sourceNode) { - if (!sourceNode.dependency) { - sourceNode.dependency = []; - } - sourceNode.dependency.push({ - id: inputName, - name: node.name, - operator: node.op - }); + node.input.push(outputName); + } + if (outputNode) { + for (var j = outputNode.output.length; j <= outputIndex; j++) { + outputNode.output.push(''); } + outputNode.output[outputIndex] = outputName; } - } + }); }); this._nodeOutputCountMap = {}; nodes.forEach((node) => { - this._metadata.getInputs(node).forEach((input) => { - var multiple = input.connections.length > 1; - input.connections.forEach((connection) => { - if (multiple) { - this._nodeOutputCountMap[connection.id] = 'N'; - } - else { - var id = connection.id.startsWith('^') ? connection.id.substring(1) : connection.id; - var count = this._nodeOutputCountMap[id]; - if (count != 'N') { - count = count ? count : 0; - this._nodeOutputCountMap[input] = count + 1; - } - } - }); - }); - node.input.forEach((input) => { - input = input.startsWith('^') ? input.substring(1) : input; - var count = this._nodeOutputCountMap[input]; - if (!count) { - count = 0; - } - this._nodeOutputCountMap[input] = count + 1; + this._nodeOutputCountMap[input] = (this._nodeOutputCountMap[input] || 0) + 1; + }); + node.controlDependencies.forEach((controlDependency) => { + this._nodeOutputCountMap[controlDependency] = (this._nodeOutputCountMap[controlDependency] || 0) + 1; }); }); - this._initializerMap = {}; + var initializers = {}; this._metaGraph.graph_def.node.forEach((node) => { - if (node.op == 'Const' && this._checkEmptyInput(node) && this._checkSingleOutput(node)) { + if (node.op == 'Const' && node.input.length == 0 && node.controlDependencies.length == 0 && this._checkSingleOutput(node)) { var value = node.attr.value; if (value && value.hasOwnProperty('tensor')) { var output = node.output[0]; if (output) { - this._initializerMap[output] = new tf.Tensor(value.tensor, node.name, 'Constant'); + initializers[output] = new tf.Tensor(value.tensor, node.name, 'Constant'); } } } }); this._metaGraph.graph_def.node.forEach((node) => { - if (node.op == 'Identity' && node.input.length == 1 && this._checkSingleOutput(node)) { + if (node.op == 'Identity' && node.input.length == 1 && node.controlDependencies.length == 0 && this._checkSingleOutput(node)) { var input = node.input[0]; - var tensor = this._initializerMap[input]; + var tensor = initializers[input]; if (tensor) { var output = node.output[0]; - this._initializerMap[input] = "-"; + initializers[input] = "-"; tensor.kind = 'Identity Constant'; - this._initializerMap[output] = tensor; + initializers[output] = tensor; } } }); - this._inputMap = {}; + var inputMap = {}; this._metaGraph.graph_def.node.forEach((node) => { - if (node.op == 'Placeholder' && node.input.length == 0 && node.output.length == 1) { + if (node.op == 'Placeholder' && node.input.length == 0 && node.controlDependencies.length == 0 && node.output.length == 1) { var dtype = node.attr.dtype; var shape = node.attr.shape; if (dtype && dtype.type && shape && shape.shape) { var type = new tf.TensorType(dtype.type, shape.shape); var connection = new tf.Connection(node.output[0], type, null); - this._inputMap[node.output[0]] = new tf.Argument(node.name, [ connection ]); + inputMap[node.output[0]] = new tf.Argument(node.name, [ connection ]); + } + } + }); + + this._inputs = Object.keys(inputMap).map((key) => { + return inputMap[key]; + }); + + this._nodes = []; + this._metaGraph.graph_def.node.forEach((node) => { + if (node.output.filter(output => !output.startsWith('^')) != 0 || + node.input.filter(input => !input.startsWith('^')).length > 0) { + var id = node.name; + if (!initializers[id] && !inputMap[id] /* && node.op != 'NoOp' */) { + this._nodes.push(new tf.Node(this, node, initializers)); } } }); } } - _getInitializer(input) { - var initializer = this._initializerMap[input]; - return initializer ? initializer : null; + get operators() { + return this._operators; } - _checkEmptyInput(node) { - var inputs = node.input.filter((input) => !input.startsWith('^')); - return inputs.length == 0; + get name() { + return this._name; + } + + get version() { + return this._version; + } + + get tags() { + return this._tags; + } + + get groups() { + return false; + // TODO return true; + } + + get inputs() { + return this._inputs; + } + + get outputs() { + return []; + } + + get nodes() { + return this._nodes; + } + + get metadata() { + return this._metadata; + } + + get namespaces() { + return this._namespaces; } _checkSingleOutput(node) { @@ -478,37 +435,99 @@ tf.Connection = class { tf.Node = class { - constructor(graph, node) { + constructor(graph, node, initializers) { this._graph = graph; - this._node = node; + this._operator = node.op; + this._name = node.name; + if (node.hasOwnProperty('device')) { + this._device = node.device; + } var metadata = graph.metadata; this._attributes = []; if (node.attr) { Object.keys(node.attr).forEach((name) => { var value = node.attr[name]; - this._attributes.push(new tf.Attribute(name, value, node.op, metadata)); + this._attributes.push(new tf.Attribute(name, value, this._operator, metadata)); }); } - } - get graph() { - return this._graph; + var schema = metadata.getSchema(node.op); + + this._inputs = []; + var inputIndex = 0; + var inputs = node.input.filter(input => !input.startsWith('^')); + if (schema && schema.inputs) { + schema.inputs.forEach((input) => { + var inputCount = 1; + if (input.numberAttr) { + var number = node.attr[input.numberAttr]; + if (number && number.i) { + inputCount = number.i; + } + } + var result = {}; + result.name = input.name; + var connections = inputs.slice(inputIndex, inputIndex + inputCount).map((id) => { + return new tf.Connection(id, null, initializers[id]); + }); + this._inputs.push(new tf.Argument(input.name, connections)); + inputIndex += inputCount; + }); + } + else { + inputs.slice(inputIndex).forEach((input) => { + this._inputs.push(new tf.Argument(inputIndex.toString(), [ + new tf.Connection(input, null, initializers[input]) + ])); + inputIndex++; + }); + } + + this._outputs = []; + var outputIndex = 0; + var outputs = node.output; + if (schema && schema.outputs) { + schema.outputs.forEach((output) => { + var outputCount = 1; + if (output.numberAttr) { + var number = node.attr[output.numberAttr]; + if (number && number.i) { + outputCount = number.i; + } + } + var connections = outputs.slice(outputIndex, outputIndex + outputCount).map((id) => { + return new tf.Connection(id, null, null); + }); + this._outputs.push(new tf.Argument(output.name, connections)); + outputIndex += outputCount; + }); + } + else { + outputs.slice(outputIndex).forEach((output) => { + this._outputs.push(new tf.Argument(outputIndex.toString(), [ + new tf.Connection(output, null, null) + ])); + outputIndex++; + }); + } + + this._controlDependencies = node.controlDependencies; } get operator() { - return this._node.op; + return this._operator; } get name() { - return this._node.name; + return this._name; } get device() { - return this._node.device; + return this._device || null; } get group() { - var name = this._node.name; + var name = this._name; if (this._graph.namespaces[name]) { return name; } @@ -598,7 +617,6 @@ tf.Node = class { return schema; } return null; - } get category() { @@ -607,34 +625,15 @@ tf.Node = class { } get inputs() { - if (this._node.input) { - var inputs = this._graph.metadata.getInputs(this._node); - return inputs.map((input) => { - return new tf.Argument(input.name, input.connections.map((connection) => { - var initializer = this._graph._getInitializer(connection.id); - return new tf.Connection(connection.id, null, initializer); - })); - }); - } - return []; + return this._inputs; } get outputs() { - return this._graph.metadata.getOutputs(this._node).map((output) => { - return new tf.Argument(output.name, output.connections.map((connection) => { - return new tf.Connection(connection.id, null, null); - })); - }); + return this._outputs; } - get dependencies() { - var results = []; - if (this._node.dependency) { - this._node.dependency.forEach((dependency) => { - results.push(dependency); - }); - } - return results; + get controlDependencies() { + return this._controlDependencies; } get attributes() { @@ -733,7 +732,7 @@ tf.Attribute = class { } if (schema) { - if (schema.hasOwnProperty('visible') && !attributeSchema.visible) { + if (schema.hasOwnProperty('visible') && !schema.visible) { this._visible = false; } else if (schema.hasOwnProperty('default')) { @@ -797,7 +796,7 @@ tf.Tensor = class { } get kind() { - return this._kind; + return this._kind || null; } set kind(value) { @@ -1131,96 +1130,6 @@ tf.GraphMetadata = class { return {}; } - getInputs(node) { - var results = []; - var index = 0; - var inputs = node.input.filter(input => !input.startsWith('^')); - var schema = this.getSchema(node.op); - if (schema && schema.inputs) { - schema.inputs.forEach((input) => { - var count = 1; - if (input.numberAttr) { - var number = node.attr[input.numberAttr]; - if (number && number.i) { - count = number.i; - } - } - var result = {}; - result.name = input.name; - if (input.type) { - result.type = tf.Tensor.formatDataType(input.type); - } - else if (input.typeAttr) { - result.type = input.typeAttr; - } - else if (input.typeListAttr) { - result.type = input.typeListAttr; - } - result.connections = inputs.slice(index, index + count).map((id) => { - return { - id: id - }; - }); - results.push(result); - index += count; - }); - } - else { - inputs.slice(index).forEach((input) => { - results.push({ - name: '(' + index.toString() + ')', - connections: [ { id: input } ] - }); - index++; - }); - } - return results; - } - - getOutputs(node) { - var results = []; - var index = 0; - var outputs = node.output; - var schema = this.getSchema(node.op); - if (schema && schema.outputs) { - schema.outputs.forEach((output) => { - var count = 1; - if (output.numberAttr) { - var number = node.attr[output.numberAttr]; - if (number && number.i) { - count = number.i; - } - } - var result = {}; - result.name = output.name; - if (output.type) { - result.type = tf.Tensor.formatDataType(output.type); - } - else if (output.typeAttr) { - result.type = output.typeAttr; - } - else if (output.typeListAttr) { - result.type = output.typeListAttr; - } - result.connections = outputs.slice(index, index + count).map((id) => { - return { id: id }; - }); - results.push(result); - index += count; - }); - } - else { - outputs.slice(index).forEach((output) => { - results.push({ - name: '(' + index.toString() + ')', - connections: [ { id: output } ] - }); - index++; - }); - } - return results; - } - static _formatAttributeValue(value) { if (value == null) { return null; diff --git a/src/tflite.js b/src/tflite.js index 278180c3594..987793a6484 100644 --- a/src/tflite.js +++ b/src/tflite.js @@ -286,10 +286,6 @@ tflite.Node = class { return this._chain; } - get dependencies() { - return []; - } - get attributes() { return this._attributes; } diff --git a/src/view-grapher.css b/src/view-grapher.css index c25f92817af..ca918325d61 100644 --- a/src/view-grapher.css +++ b/src/view-grapher.css @@ -2,8 +2,6 @@ .node path { stroke: #333; fill: none; stroke-width: 1px; } .node line { stroke: #333; fill: none; stroke-width: 1px; } -.node-control-dependency { stroke-dasharray: 4, 1; } - .node-item path { stroke-width: 0; stroke: #000; fill: #fff; } .node-item text { font-family: 'Open Sans', --apple-system, "Helvetica Neue", Helvetica, Arial, sans-serf; font-size: 10px; font-weight: 600; text-rendering: geometricPrecision; } @@ -66,6 +64,7 @@ .edge-label text { font-family: 'Open Sans', --apple-system, "Helvetica Neue", Helvetica, Arial, sans-serf; font-size: 10px; } .edge-path { stroke: #000; stroke-width: 1px; fill: none; } #arrowhead-vee { fill: #000; } +.edge-path-control-dependency { stroke-dasharray: 3, 2; } .cluster rect { stroke: #000; fill: #000; fill-opacity: 0.02; stroke-opacity: 0.06; stroke-width: 1px; } diff --git a/src/view-grapher.js b/src/view-grapher.js index 95bc64647a5..0ea9af21af6 100644 --- a/src/view-grapher.js +++ b/src/view-grapher.js @@ -227,10 +227,6 @@ grapher.NodeElement = class { return this._block; } - setControlDependencies() { - this._controlDependencies = true; - } - format(contextElement) { var rootElement = this.createElement('g'); contextElement.appendChild(rootElement); @@ -254,11 +250,7 @@ grapher.NodeElement = class { }); var borderElement = this.createElement('path'); - var borderClassList = [ 'node', 'border' ]; - if (this._controlDependencies) { - borderClassList.push('node-control-dependency'); - } - borderElement.setAttribute('class', borderClassList.join(' ')); + borderElement.setAttribute('class', [ 'node', 'border' ].join(' ')); borderElement.setAttribute('d', grapher.NodeElement.roundedRect(0, 0, width, height, true, true, true, true)); rootElement.appendChild(borderElement); diff --git a/src/view.js b/src/view.js index 1a172ac3a95..2ff23574bfb 100644 --- a/src/view.js +++ b/src/view.js @@ -516,14 +516,20 @@ view.View = class { var connection = initializer.connections[0]; var type = connection.type; var shape = ''; + var separator = ''; if (type && type.shape && type.shape.dimensions && type.shape.dimensions.hasOwnProperty('length')) { shape = '\u3008' + type.shape.dimensions.join('\u00D7') + '\u3009'; + if (type.shape.dimensions.length == 0 && connection.initializer) { + shape = connection.initializer.toString(); + separator = ' = '; + } } - block.add('initializer-' + connection.id, initializer.name, shape, type ? type.toString() : '', ''); + block.add('initializer-' + connection.id, initializer.name, shape, type ? type.toString() : '', separator); }); if (hiddenInitializers) { block.add(null, '\u3008' + '...' + '\u3009', '', null, ''); } + attributes.forEach((attribute) => { if (attribute.visible) { var attributeValue = view.View.formatAttributeValue(attribute.value, attribute.type); @@ -554,7 +560,7 @@ view.View = class { }); var outputs = node.outputs; if (node.chain && node.chain.length > 0) { - var chainOutputs = node.chain[node.chain.length - 1].outputs + var chainOutputs = node.chain[node.chain.length - 1].outputs; if (chainOutputs.length > 0) { outputs = chainOutputs; } @@ -588,11 +594,21 @@ view.View = class { addNode(element, node, true); - var dependencies = node.dependencies; - if (dependencies && dependencies.length > 0) { - element.setControlDependencies(); + if (node.controlDependencies && node.controlDependencies.length > 0) { + node.controlDependencies.forEach((controlDependency) => { + var tuple = edgeMap[controlDependency]; + if (!tuple) { + tuple = { from: null, to: [] }; + edgeMap[controlDependency] = tuple; + } + tuple.to.push({ + node: nodeId, + name: controlDependency, + controlDependency: true + }); + }); } - + var name = node.name; if (name) { g.setNode(nodeId, { label: element.format(graphElement), id: 'node-' + name }); @@ -702,8 +718,8 @@ view.View = class { text = edge.split('\n').shift(); // custom connection id } - if (to.dependency) { - g.setEdge(tuple.from.node, to.node, { label: text, id: 'edge-' + edge, arrowhead: 'vee', class: 'edge-path-control' } ); + if (to.controlDependency) { + g.setEdge(tuple.from.node, to.node, { label: text, id: 'edge-' + edge, arrowhead: 'vee', class: 'edge-path-control-dependency' } ); } else { g.setEdge(tuple.from.node, to.node, { label: text, id: 'edge-' + edge, arrowhead: 'vee' } ); @@ -980,6 +996,9 @@ view.View = class { return value.map((item) => item.toString()).join(', '); } if (type == 'tensor') { + if (value.type && value.type.shape && value.type.shape.dimensions && value.type.shape.dimensions.length == 0) { + return value.toString(); + } return '[...]'; } if (Array.isArray(value)) {