
Commit 96a5eee

Add TorchScript test file (#842) (#851)
1 parent 941048a commit 96a5eee


3 files changed: +48 -5 lines


Diff for: source/python.js

+1 -1

@@ -2800,7 +2800,7 @@ python.Execution = class {
                         break;
                     }
                     case 'var': {
-                        context.set(statement.name, undefined);
+                        context.set(statement.name, statement.initializer ? this.expression(statement.initializer, context) : undefined);
                         break;
                     }
                     case '=': {
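
Note: the python.js change makes a TorchScript 'var' statement bind the value of its initializer instead of always binding undefined. A minimal sketch of the effect; Context and evaluate are toy stand-ins here, not netron's real python.Execution machinery:

    // Toy stand-ins for the execution context and expression evaluation.
    class Context {
        constructor() { this._values = new Map(); }
        set(name, value) { this._values.set(name, value); }
        get(name) { return this._values.get(name); }
    }

    // Constant-only evaluator for the sketch; netron's this.expression() handles the full grammar.
    const evaluate = (expression) => expression.type === 'number' ? Number(expression.value) : undefined;

    const executeVar = (statement, context) => {
        // After this commit: evaluate the initializer when one is present.
        context.set(statement.name, statement.initializer ? evaluate(statement.initializer) : undefined);
    };

    const context = new Context();
    executeVar({ name: 'x', initializer: { type: 'number', value: '3' } }, context);
    console.log(context.get('x')); // 3 (previously undefined)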

Diff for: source/pytorch.js

+38 -2

@@ -1604,6 +1604,10 @@ pytorch.Execution = class extends python.Execution {
             }
             throw new pytorch.Error("Unknown 'torch.ge' expression type.");
         });
+        this.registerFunction('torch.is_floating_point', function(tensor) {
+            const type = tensor.dtype.scalar_type();
+            return (type === 5 || type === 6 || type === 7);
+        });
         this.registerFunction('torch.jit._pickle.build_boollist', function(data) {
             return data;
         });
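
Note: the scalar type codes 5, 6 and 7 correspond to float16, float32 and float64 in PyTorch's ScalarType enumeration (bfloat16, code 15, is not covered by this check). A small standalone sketch of the same test with the codes named; the dtype object is a plain stand-in for netron's tensor.dtype:

    // ScalarType codes as in PyTorch's enum; the FLOAT* names are only for readability.
    const FLOAT16 = 5, FLOAT32 = 6, FLOAT64 = 7;

    const isFloatingPoint = (dtype) => {
        const type = dtype.scalar_type();
        return type === FLOAT16 || type === FLOAT32 || type === FLOAT64;
    };

    console.log(isFloatingPoint({ scalar_type: () => 6 })); // true  (float32)
    console.log(isFloatingPoint({ scalar_type: () => 4 })); // false (int64)
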
@@ -1748,7 +1752,7 @@ pytorch.Execution = class extends python.Execution {
                 throw new pytorch.Error('Slicing only supports step=1');
             }
             start = Math.max(0, start >= 0 ? start : l.length + start);
-            end = Math.min(l.length, end);
+            end = Math.min(l.length, end || Number.MAX_SAFE_INTEGER);
             return l.slice(start, end);
         });
         this.registerFunction('torch.sub', function(left, right) {
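
Note: the torch.slice fallback handles calls that omit the end index. Without it, Math.min(l.length, undefined) evaluates to NaN, and Array.prototype.slice treats a NaN end as 0, so the slice came back empty. A standalone illustration on a plain array:

    const l = [1, 3, 224, 224];
    const end = undefined;  // torch.slice invoked without an explicit end
    const before = l.slice(0, Math.min(l.length, end));                           // Math.min(4, undefined) -> NaN -> []
    const after = l.slice(0, Math.min(l.length, end || Number.MAX_SAFE_INTEGER)); // -> [1, 3, 224, 224]
    console.log(before, after);
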
@@ -2973,6 +2977,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                     }
                     case 'torch.mean':
                     case 'torch.mul':
+                    case 'torch.div':
                     case 'torch.batch_norm':
                     case 'torch.gelu':
                     case 'torch.relu':
@@ -2983,7 +2988,8 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                         }
                         break;
                     }
-                    case 'torch.add': {
+                    case 'torch.add':
+                    case 'torch.sub': {
                         const input = this.expression(args[0], context);
                         if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
                             parameter.resize_(input.size());
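
Note: for torch.add and torch.sub the output placeholder is simply resized to the shape of the first tensor input; broadcasting against a smaller second operand is not modeled by this heuristic. A toy sketch of that propagation, with FakeTensor standing in for netron's tensor wrapper:

    // Toy tensor exposing only the size()/resize_() surface the heuristic uses.
    class FakeTensor {
        constructor(shape) { this._shape = shape; }
        size() { return this._shape; }
        resize_(shape) { this._shape = shape; return this; }
    }

    const propagateElementwise = (output, input) => {
        if (Array.isArray(input.size())) {
            output.resize_(input.size());
        }
        return output;
    };

    const out = propagateElementwise(new FakeTensor(undefined), new FakeTensor([1, 3, 224, 224]));
    console.log(out.size()); // [1, 3, 224, 224]
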
@@ -2996,6 +3002,13 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                         }
                         break;
                     }
+                    case 'torch.select': {
+                        const input = this.expression(args[0], context);
+                        if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) {
+                            parameter.resize_(Array(input.size().length - 1).fill(NaN));
+                        }
+                        break;
+                    }
                     case 'torch.layer_norm': {
                         const input = this.expression(args[0], context);
                         const normalized_shape = this.expression(args[1], context);
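
Note: torch.select(input, dim, index) drops the selected dimension, so only the rank of the result is known here: one less than the input's. The individual sizes are left as NaN placeholders. In isolation:

    // Rank drops by one; concrete sizes are unknown at this point, hence NaN.
    const selectOutputShape = (inputShape) => Array(inputShape.length - 1).fill(NaN);
    console.log(selectOutputShape([1, 3, 224, 224])); // [ NaN, NaN, NaN ]
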
@@ -3176,6 +3189,29 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
                         tensor.resize_(Array(number).fill(NaN));
                     }
                 }
+                // val = torch.slice(torch.size(img), -2)
+                // if torch.eq(torch.len(val), 2):
+                //   pass
+                // else:
+                //   ops.prim.RaiseException("AssertionError: ")
+                if (assign.type === '=' &&
+                    condition.type === 'if' &&
+                    pytorch.Utility.isCall(assign.expression, 'torch.slice', 2) &&
+                    pytorch.Utility.isCall(assign.expression.arguments[0], 'torch.size', 1) &&
+                    pytorch.Utility.isCall(condition.condition, 'torch.eq', 2) &&
+                    pytorch.Utility.isCall(condition.condition.arguments[0], 'torch.len', 1) &&
+                    pytorch.Utility.isEqual(condition.condition.arguments[0].arguments[0], assign.target) &&
+                    condition.else.statements.length == 1 &&
+                    pytorch.Utility.isCall(condition.else.statements[0], 'ops.prim.RaiseException', 1)) {
+                    const tensor = this.expression(assign.expression.arguments[0].arguments[0], context);
+                    if (pytorch.Utility.isTensor(tensor) && tensor.shape === undefined) {
+                        const start = this.expression(assign.expression.arguments[1], context);
+                        const value = this.expression(condition.condition.arguments[1], context);
+                        if (Number.isInteger(start) && Number.isInteger(value)) {
+                            tensor.resize_(Array(value - start).fill(NaN));
+                        }
+                    }
+                }
             }
             if (statements.length > 1) {
                 const size = statements[0];
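
Note: this block pattern-matches the TorchScript assertion idiom shown in the inline comment in order to recover a tensor's rank. When a model slices torch.size(img) and asserts the length of the slice, the tensor is resized to value - start unknown dimensions (NaN). Reading the resize_ call for the commented example, where start = -2 and the asserted length is 2:

    const start = -2;
    const value = 2;
    console.log(Array(value - start).fill(NaN)); // [ NaN, NaN, NaN, NaN ] -> img treated as a 4-D tensor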

Diff for: test/models.json

+9 -2

@@ -4326,8 +4326,8 @@
     {
         "type": "pytorch",
         "target": "fasterrcnn_resnet50_fpn.pt",
-        "source": "https://github.com/lutzroeder/netron/files/6040364/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
-        "error": "Unsupported function 'torch.full' in 'fasterrcnn_resnet50_fpn.pt'.",
+        "source": "https://github.com/lutzroeder/netron/files/7677467/fasterrcnn_resnet50_fpn.pt.zip[fasterrcnn_resnet50_fpn.pt]",
+        "error": "Unknown torch.add expression type in 'fasterrcnn_resnet50_fpn.pt'.",
         "link": "https://github.com/lutzroeder/netron/issues/689"
     },
     {
@@ -4859,6 +4859,13 @@
         "source": "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
         "format": "PyTorch v0.1.1"
     },
+    {
+        "type": "pytorch",
+        "target": "ssdlite320_mobilenet_v3_large.pt",
+        "source": "https://github.com/lutzroeder/netron/files/7677468/ssdlite320_mobilenet_v3_large.pt.zip[ssdlite320_mobilenet_v3_large.pt]",
+        "error": "l.slice is not a function in 'ssdlite320_mobilenet_v3_large.pt'.",
+        "link": "https://github.com/lutzroeder/netron/issues/842"
+    },
     {
         "type": "pytorch",
         "target": "superpoint_v1.pth",
