Skip to content

Commit badd7e6

Browse files
committed
Replace type strings with tensor types (#71)
1 parent 9dc7b85 commit badd7e6

10 files changed

+139
-98
lines changed

src/caffe-model.js

+4-4
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class CaffeGraph {
118118
this._inputs.push({
119119
id: input,
120120
name: input,
121-
type: 'T'
121+
type: null
122122
});
123123
});
124124
}
@@ -155,15 +155,15 @@ class CaffeGraph {
155155
this._outputs.push({
156156
id: keys[0],
157157
name: keys[0],
158-
type: 'T'
158+
type: null
159159
});
160160
}
161161
else if (outputs.length == 1) {
162162
outputs[0]._outputs = [ 'output' ];
163163
this._outputs.push({
164164
id: 'output',
165165
name: 'output',
166-
type: 'T'
166+
type: null
167167
});
168168
}
169169
}
@@ -294,7 +294,7 @@ class CaffeNode {
294294
input.connections.forEach((connection) => {
295295
if (connection.id instanceof CaffeTensor) {
296296
connection.initializer = connection.id;
297-
connection.type = connection.initializer.type.toString();
297+
connection.type = connection.initializer.type;
298298
connection.id = '';
299299
}
300300
});

src/caffe2-model.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class Caffe2Graph {
117117
this._inputs.push({
118118
id: input,
119119
name: input,
120-
type: 'T'
120+
type: null
121121
});
122122
}
123123
});
@@ -127,7 +127,7 @@ class Caffe2Graph {
127127
this._outputs.push({
128128
id: output,
129129
name: output,
130-
type: 'T'
130+
type: null
131131
});
132132
});
133133
}
@@ -214,7 +214,7 @@ class Caffe2Node {
214214
var initializer = this._initializers[connection.id];
215215
if (initializer) {
216216
connection.initializer = initializer;
217-
connection.type = initializer.type.toString();
217+
connection.type = initializer.type;
218218
}
219219
});
220220
});

src/coreml-model.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,13 @@ class CoreMLGraph {
378378
result = 'image(' + CoreMLGraph.formatColorSpace(type.imageType.colorSpace) + ',' + type.imageType.width.toString() + 'x' + type.imageType.height.toString() + ')';
379379
break;
380380
case 'dictionaryType':
381-
result = 'map<' + type.dictionaryType.KeyType.replace('KeyType', '') + ',double>';
381+
result = 'map<' + type.dictionaryType.KeyType.replace('KeyType', '') + ',float64>';
382382
break;
383383
case 'stringType':
384384
result = 'string';
385385
break;
386386
case 'doubleType':
387-
result = 'double';
387+
result = 'float64';
388388
break;
389389
case 'int64Type':
390390
result = 'int64';
@@ -477,7 +477,7 @@ class CoreMLNode {
477477
name: initializer.name,
478478
connections: [ {
479479
id: '',
480-
type: initializer.type.toString(),
480+
type: initializer.type,
481481
initializer: initializer, } ]
482482
};
483483
if (!CoreMLOperatorMetadata.operatorMetadata.getInputVisible(this._operator, initializer.name)) {

src/keras-model.js

+9-8
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class KerasGraph {
281281
if (addGraphOutput) {
282282
this._outputs.push({
283283
id: inputName,
284-
type: 'T',
284+
type: null,
285285
name: inputName
286286
});
287287
}
@@ -340,7 +340,7 @@ class KerasGraph {
340340
if (connection) {
341341
this._outputs.push({
342342
id: connection,
343-
type: 'T',
343+
type: null,
344344
name: connection
345345
});
346346
}
@@ -363,19 +363,20 @@ class KerasGraph {
363363
}
364364

365365
_loadInput(layer, input) {
366-
input.type = '';
366+
input.type = null;
367367
if (layer && layer.config) {
368+
var dataType = '?';
369+
var shape = [];
368370
var config = layer.config;
369371
if (config.dtype) {
370-
input.type = config.dtype;
372+
dataType = config.dtype;
371373
delete config.dtype;
372374
}
373375
if (config.batch_input_shape) {
374-
var shape = config.batch_input_shape;
375-
shape = shape.map(s => s == null ? '?' : s).join(',');
376-
input.type = input.type + '[' + shape + ']';
376+
shape = config.batch_input_shape.map(s => s == null ? '?' : s);
377377
delete config.batch_input_shape;
378378
}
379+
input.type = new KerasTensorType(dataType, shape);
379380
}
380381
}
381382
}
@@ -471,7 +472,7 @@ class KerasNode {
471472
input.connections.forEach((connection) => {
472473
var initializer = this._initializers[connection.id];
473474
if (initializer) {
474-
connection.type = initializer.type.toString();
475+
connection.type = initializer.type;
475476
connection.initializer = initializer;
476477
}
477478
});

src/mxnet-model.js

+6-6
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,10 @@ class MXNetGraph {
293293
var output = {};
294294
output.id = MXNetGraph._updateOutput(nodes, head);
295295
output.name = nodes[output.id[0]] ? nodes[output.id[0]].name : ('output' + ((index == 0) ? '' : (index + 1).toString()));
296-
output.type = 'T';
296+
output.type = null;
297297
var outputSignature = outputs[output.name];
298298
if (outputSignature && outputSignature.data_shape) {
299-
output.type = '?' + '[' + outputSignature.data_shape.toString() + ']';
299+
output.type = new MXNetTensorType(null, outputSignature.data_shape);
300300
}
301301
this._outputs.push(output);
302302
});
@@ -315,10 +315,10 @@ class MXNetGraph {
315315
var input = {};
316316
input.id = argument.outputs[0];
317317
input.name = argument.name;
318-
input.type = 'T';
318+
input.type = null;
319319
var inputSignature = inputs[input.name];
320320
if (inputSignature && inputSignature.data_shape) {
321-
input.type = '?' + '[' + inputSignature.data_shape.toString() + ']';
321+
input.type = new MXNetTensorType(null, inputSignature.data_shape);
322322
}
323323
this._inputs.push(input);
324324
}
@@ -464,7 +464,7 @@ class MXNetNode {
464464
var initializer = this._initializers[connection.id];
465465
if (initializer) {
466466
connection.id = initializer.name || connection.id;
467-
connection.type = initializer.type.toString();
467+
connection.type = initializer.type;
468468
connection.initializer = initializer;
469469
}
470470
});
@@ -674,7 +674,7 @@ class MXNetTensorType {
674674
}
675675

676676
toString() {
677-
return this.dataType + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
677+
return (this.dataType || '?') + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
678678
}
679679
}
680680

src/onnx-model.js

+72-43
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ class OnnxGraph {
252252
var initializer = this._initializerMap[connection.id];
253253
if (initializer) {
254254
connection.initializer = initializer;
255-
connection.type = connection.type || initializer.type.toString();
255+
connection.type = connection.type || initializer.type;
256256
}
257257
return connection;
258258
});
@@ -499,6 +499,7 @@ class OnnxTensor {
499499
this._tensor = tensor;
500500
this._id = id;
501501
this._kind = kind || null;
502+
this._type = new OnnxTensorType(this._tensor.dataType, this._tensor.dims.map((dim) => dim), null);
502503
}
503504

504505
get id() {
@@ -514,7 +515,7 @@ class OnnxTensor {
514515
}
515516

516517
get type() {
517-
return new OnnxTensorType(this._tensor);
518+
return this._type;
518519
}
519520

520521
get value() {
@@ -744,37 +745,8 @@ class OnnxTensor {
744745

745746
static _formatType(type, imageFormat) {
746747
if (!type) {
747-
return { value: '?' };
748+
return null;
748749
}
749-
var value = {};
750-
switch (type.value) {
751-
case 'tensorType':
752-
var tensorType = type.tensorType;
753-
var text = OnnxTensor._formatElementType(tensorType.elemType);
754-
if (tensorType.shape && tensorType.shape.dim) {
755-
text += '[' + tensorType.shape.dim.map((dimension) => {
756-
if (dimension.dimParam) {
757-
return dimension.dimParam;
758-
}
759-
return dimension.dimValue.toString();
760-
}).join(',') + ']';
761-
}
762-
value = text;
763-
break;
764-
case 'mapType':
765-
var keyType = OnnxTensor._formatElementType(type.mapType.keyType);
766-
var valueType = OnnxTensor._formatType(type.mapType.valueType);
767-
value = 'map<' + keyType + ',' + valueType.value + '>';
768-
break;
769-
case 'sequenceType':
770-
var elemType = OnnxTensor._formatType(type.sequenceType.elemType);
771-
value = 'sequence<' + elemType.value + '>';
772-
break;
773-
default:
774-
// debugger
775-
value = '?';
776-
break;
777-
}
778750
var denotation = '';
779751
switch (type.denotation) {
780752
case 'TENSOR':
@@ -790,21 +762,30 @@ class OnnxTensor {
790762
denotation = 'Text';
791763
break;
792764
}
793-
return { value: value, denotation: denotation };
765+
switch (type.value) {
766+
case 'tensorType':
767+
var shape = [];
768+
if (type.tensorType.shape && type.tensorType.shape.dim) {
769+
shape = type.tensorType.shape.dim.map((dim) => {
770+
return dim.dimParam ? dim.dimParam : dim.dimValue;
771+
});
772+
}
773+
return new OnnxTensorType(type.tensorType.elemType, shape, denotation);
774+
case 'mapType':
775+
return new OnnxMapType(type.mapType.keyType, OnnxTensor._formatType(type.mapType.valueType, imageFormat), denotation);
776+
case 'sequenceType':
777+
return new OnnxSequenceType(OnnxTensor._formatType(type.sequenceType.elemType, imageFormat), denotation);
778+
}
779+
return null;
794780
}
795781
}
796782

797783
class OnnxTensorType {
798784

799-
constructor(tensor) {
800-
this._dataType = '?';
801-
if (tensor.hasOwnProperty('dataType')) {
802-
this._dataType = OnnxTensor._formatElementType(tensor.dataType);
803-
}
804-
this._shape = [];
805-
if (tensor.hasOwnProperty('dims')) {
806-
this._shape = tensor.dims.map((dimension) => dimension);
807-
}
785+
constructor(dataType, shape, denotation) {
786+
this._dataType = OnnxTensor._formatElementType(dataType);
787+
this._shape = shape;
788+
this._denotation = denotation || null;
808789
}
809790

810791
get dataType() {
@@ -815,10 +796,58 @@ class OnnxTensorType {
815796
return this._shape;
816797
}
817798

799+
get denotation() {
800+
return this._denotation;
801+
}
802+
818803
toString() {
819-
return this.dataType + (this._shape ? ('[' + this._shape.map((dimension) => dimension.toString()).join(',') + ']') : '');
804+
return this.dataType + ((this._shape && this._shape.length) ? ('[' + this._shape.join(',') + ']') : '');
820805
}
806+
}
807+
808+
class OnnxSequenceType {
809+
810+
constructor(elementType, denotation) {
811+
this._elementType = elementType;
812+
this._denotation = denotation;
813+
}
814+
815+
get elementType() {
816+
return this._elementType;
817+
}
818+
819+
get dennotation() {
820+
return this._dennotation;
821+
}
822+
823+
toString() {
824+
return 'sequence<' + this._elementType.toString() + '>';
825+
}
826+
}
821827

828+
class OnnxMapType {
829+
830+
constructor(keyType, valueType, denotation) {
831+
this._keyType = OnnxTensor._formatElementType(keyType);
832+
this._valueType = valueType;
833+
this._denotation = denotation;
834+
}
835+
836+
get keyType() {
837+
return this._keyType;
838+
}
839+
840+
get valueType() {
841+
return this._valueType;
842+
}
843+
844+
get denotation() {
845+
return this._denotation;
846+
}
847+
848+
toString() {
849+
return 'map<' + this._keyType + ',' + this._valueType.toString() + '>';
850+
}
822851
}
823852

824853
class OnnxGraphOperatorMetadata {

src/tf-model.js

+3-5
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ class TensorFlowNode {
414414
input.connections.forEach((connection) => {
415415
var initializer = this._graph._getInitializer(connection.id);
416416
if (initializer) {
417-
connection.type = initializer.type.toString();
417+
connection.type = initializer.type;
418418
connection.initializer = initializer;
419419
}
420420
});
@@ -503,9 +503,6 @@ class TensorFlowAttribute {
503503
return TensorFlowTensor.formatTensorShape(value.shape);
504504
}
505505
else if (value.hasOwnProperty('s')) {
506-
if (value.s.length == 0) {
507-
return '';
508-
}
509506
if (value.s.filter(c => c <= 32 && c >= 128).length == 0) {
510507
return '"' + TensorFlowOperatorMetadata.textDecoder.decode(value.s) + '"';
511508
}
@@ -581,6 +578,7 @@ class TensorFlowTensor {
581578
if (kind) {
582579
this._kind = kind;
583580
}
581+
this._type = new TensorFlowTensorType(this._tensor.dtype, this._tensor.tensorShape);
584582
}
585583

586584
get id() {
@@ -592,7 +590,7 @@ class TensorFlowTensor {
592590
}
593591

594592
get type() {
595-
return new TensorFlowTensorType(this._tensor.dtype, this._tensor.tensorShape);
593+
return this._type;
596594
}
597595

598596
get kind() {

0 commit comments

Comments
 (0)