Skip to content

Commit 2bc5987

Browse files
committed
Cleanup
1 parent df65249 commit 2bc5987

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

lib/cusparse/broadcast.jl

+12-8
Original file line numberDiff line numberDiff line change
@@ -375,11 +375,13 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
375375
row_ix > output.nnz && return
376376
row_and_ptrs = @inbounds offsets[row_ix]
377377
row = @inbounds row_and_ptrs[1]
378-
args_are_nnz = @inbounds row_and_ptrs[2]
378+
arg_ptrs = @inbounds row_and_ptrs[2]
379379
vals = ntuple(Val(N)) do i
380380
arg = @inbounds args[i]
381-
arg_is_nnz = @inbounds args_are_nnz[i]
382-
_getindex(arg, row, arg_is_nnz)::Tv
381+
# ptr is 0 if the sparse vector doesn't have an element at this row
382+
# ptr is 0 if the arg is a scalar AND f preserves zeros
383+
ptr = @inbounds arg_ptrs[i]
384+
_getindex(arg, row, ptr)::Tv
383385
end
384386
output_val = f(vals...)
385387
@inbounds output.iPtr[row_ix] = row
@@ -455,14 +457,16 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
455457
row_ix > length(output) && return
456458
row_and_ptrs = @inbounds offsets[row_ix]
457459
row = @inbounds row_and_ptrs[1]
458-
args_are_nnz = @inbounds row_and_ptrs[2]
460+
arg_ptrs = @inbounds row_and_ptrs[2]
459461
vals = ntuple(Val(length(args))) do i
460462
arg = @inbounds args[i]
461-
arg_is_nnz = @inbounds args_are_nnz[i]
462-
_getindex(arg, row, arg_is_nnz)::Tv
463+
# ptr is 0 if the sparse vector doesn't have an element at this row
464+
# ptr is row if the arg is dense OR a scalar with non-zero-preserving f
465+
# ptr is 0 if the arg is a scalar AND f preserves zeros
466+
ptr = @inbounds arg_ptrs[i]
467+
_getindex(arg, row, ptr)::Tv
463468
end
464-
out_val = f(vals...)
465-
@inbounds output[row] = out_val
469+
@inbounds output[row] = f(vals...)
466470
return
467471
end
468472
## COV_EXCL_STOP

test/Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
55
CUDA_Driver_jll = "4ee394cb-3365-5eb0-8335-949819d2adfc"
66
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
98
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
109
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1110
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

0 commit comments

Comments
 (0)