@@ -375,11 +375,13 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
375
375
row_ix > output. nnz && return
376
376
row_and_ptrs = @inbounds offsets[row_ix]
377
377
row = @inbounds row_and_ptrs[1 ]
378
- args_are_nnz = @inbounds row_and_ptrs[2 ]
378
+ arg_ptrs = @inbounds row_and_ptrs[2 ]
379
379
vals = ntuple (Val (N)) do i
380
380
arg = @inbounds args[i]
381
- arg_is_nnz = @inbounds args_are_nnz[i]
382
- _getindex (arg, row, arg_is_nnz):: Tv
381
+ # ptr is 0 if the sparse vector doesn't have an element at this row
382
+ # ptr is 0 if the arg is a scalar AND f preserves zeros
383
+ ptr = @inbounds arg_ptrs[i]
384
+ _getindex (arg, row, ptr):: Tv
383
385
end
384
386
output_val = f (vals... )
385
387
@inbounds output. iPtr[row_ix] = row
@@ -455,14 +457,16 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
455
457
row_ix > length (output) && return
456
458
row_and_ptrs = @inbounds offsets[row_ix]
457
459
row = @inbounds row_and_ptrs[1 ]
458
- args_are_nnz = @inbounds row_and_ptrs[2 ]
460
+ arg_ptrs = @inbounds row_and_ptrs[2 ]
459
461
vals = ntuple (Val (length (args))) do i
460
462
arg = @inbounds args[i]
461
- arg_is_nnz = @inbounds args_are_nnz[i]
462
- _getindex (arg, row, arg_is_nnz):: Tv
463
+ # ptr is 0 if the sparse vector doesn't have an element at this row
464
+ # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
465
+ # ptr is 0 if the arg is a scalar AND f preserves zeros
466
+ ptr = @inbounds arg_ptrs[i]
467
+ _getindex (arg, row, ptr):: Tv
463
468
end
464
- out_val = f (vals... )
465
- @inbounds output[row] = out_val
469
+ @inbounds output[row] = f (vals... )
466
470
return
467
471
end
468
472
# # COV_EXCL_STOP
0 commit comments