@@ -1604,6 +1604,10 @@ pytorch.Execution = class extends python.Execution {
1604
1604
}
1605
1605
throw new pytorch . Error ( "Unknown 'torch.ge' expression type." ) ;
1606
1606
} ) ;
1607
+ this . registerFunction ( 'torch.is_floating_point' , function ( tensor ) {
1608
+ const type = tensor . dtype . scalar_type ( ) ;
1609
+ return ( type === 5 || type === 6 || type === 7 ) ;
1610
+ } ) ;
1607
1611
this . registerFunction ( 'torch.jit._pickle.build_boollist' , function ( data ) {
1608
1612
return data ;
1609
1613
} ) ;
@@ -1748,7 +1752,7 @@ pytorch.Execution = class extends python.Execution {
1748
1752
throw new pytorch . Error ( 'Slicing only supports step=1' ) ;
1749
1753
}
1750
1754
start = Math . max ( 0 , start >= 0 ? start : l . length + start ) ;
1751
- end = Math . min ( l . length , end ) ;
1755
+ end = Math . min ( l . length , end || Number . MAX_SAFE_INTEGER ) ;
1752
1756
return l . slice ( start , end ) ;
1753
1757
} ) ;
1754
1758
this . registerFunction ( 'torch.sub' , function ( left , right ) {
@@ -2973,6 +2977,7 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
2973
2977
}
2974
2978
case 'torch.mean' :
2975
2979
case 'torch.mul' :
2980
+ case 'torch.div' :
2976
2981
case 'torch.batch_norm' :
2977
2982
case 'torch.gelu' :
2978
2983
case 'torch.relu' :
@@ -2983,7 +2988,8 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
2983
2988
}
2984
2989
break ;
2985
2990
}
2986
- case 'torch.add' : {
2991
+ case 'torch.add' :
2992
+ case 'torch.sub' : {
2987
2993
const input = this . expression ( args [ 0 ] , context ) ;
2988
2994
if ( pytorch . Utility . isTensor ( input ) && Array . isArray ( input . size ( ) ) ) {
2989
2995
parameter . resize_ ( input . size ( ) ) ;
@@ -2996,6 +3002,13 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
2996
3002
}
2997
3003
break ;
2998
3004
}
3005
+ case 'torch.select' : {
3006
+ const input = this . expression ( args [ 0 ] , context ) ;
3007
+ if ( pytorch . Utility . isTensor ( input ) && Array . isArray ( input . size ( ) ) ) {
3008
+ parameter . resize_ ( Array ( input . size ( ) . length - 1 ) . fill ( NaN ) ) ;
3009
+ }
3010
+ break ;
3011
+ }
2999
3012
case 'torch.layer_norm' : {
3000
3013
const input = this . expression ( args [ 0 ] , context ) ;
3001
3014
const normalized_shape = this . expression ( args [ 1 ] , context ) ;
@@ -3176,6 +3189,29 @@ pytorch.Container.Zip.Execution = class extends pytorch.Execution {
3176
3189
tensor . resize_ ( Array ( number ) . fill ( NaN ) ) ;
3177
3190
}
3178
3191
}
3192
+ // val = torch.slice(torch.size(img), -2)
3193
+ // if torch.eq(torch.len(val), 2):
3194
+ // pass
3195
+ // else:
3196
+ // ops.prim.RaiseException("AssertionError: ")
3197
+ if ( assign . type === '=' &&
3198
+ condition . type === 'if' &&
3199
+ pytorch . Utility . isCall ( assign . expression , 'torch.slice' , 2 ) &&
3200
+ pytorch . Utility . isCall ( assign . expression . arguments [ 0 ] , 'torch.size' , 1 ) &&
3201
+ pytorch . Utility . isCall ( condition . condition , 'torch.eq' , 2 ) &&
3202
+ pytorch . Utility . isCall ( condition . condition . arguments [ 0 ] , 'torch.len' , 1 ) &&
3203
+ pytorch . Utility . isEqual ( condition . condition . arguments [ 0 ] . arguments [ 0 ] , assign . target ) &&
3204
+ condition . else . statements . length == 1 &&
3205
+ pytorch . Utility . isCall ( condition . else . statements [ 0 ] , 'ops.prim.RaiseException' , 1 ) ) {
3206
+ const tensor = this . expression ( assign . expression . arguments [ 0 ] . arguments [ 0 ] , context ) ;
3207
+ if ( pytorch . Utility . isTensor ( tensor ) && tensor . shape === undefined ) {
3208
+ const start = this . expression ( assign . expression . arguments [ 1 ] , context ) ;
3209
+ const value = this . expression ( condition . condition . arguments [ 1 ] , context ) ;
3210
+ if ( Number . isInteger ( start ) && Number . isInteger ( value ) ) {
3211
+ tensor . resize_ ( Array ( value - start ) . fill ( NaN ) ) ;
3212
+ }
3213
+ }
3214
+ }
3179
3215
}
3180
3216
if ( statements . length > 1 ) {
3181
3217
const size = statements [ 0 ] ;
0 commit comments