Skip to content

Very rough implementation of bcast for CuSparseVector #2733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 24, 2025
213 changes: 174 additions & 39 deletions lib/cusparse/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ end
end
end
end
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
@inline function _capturescalars(arg)
# this definition is just an optimization (to bottom out the recursion slightly sooner)
if scalararg(arg)
return (), () -> (arg,) # add scalararg
elseif scalarwrappedarg(arg)
Expand All @@ -103,7 +104,6 @@ end

## COV_EXCL_START
## iteration helpers

"""
CSRIterator{Ti}(row, args...)

Expand Down Expand Up @@ -288,15 +288,20 @@ end
end

# helpers to index a sparse or dense array
# Fetch the value that broadcast argument `arg` contributes at index `I`.
#
# For sparse device containers, `ptr` is the position of the stored entry inside
# `arg.nzVal`; `ptr == 0` means no entry is stored, so a zero of the element type
# is returned instead.
@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
                                      CuSparseDeviceMatrixCSC{Tv},
                                      CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
    if ptr == 0
        return zero(Tv)
    else
        return @inbounds arg.nzVal[ptr]::Tv
    end
end

# Dense device arrays are indexed directly with `I`; `ptr` is not used.
@inline function _getindex(arg::CuDeviceArray{Tv}, I, ptr)::Tv where {Tv}
    return @inbounds arg[I]::Tv
end

# Any other argument defers to Base's broadcast indexing; `ptr` is not used.
@inline _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)

## sparse broadcast implementation

Expand All @@ -305,8 +310,46 @@ iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
# Select the iterator type for a device-side sparse matrix type, parameterized by the
# index type `Ti`: CSC matrices use `CSCIterator{Ti}`, CSR matrices use `CSRIterator{Ti}`.
iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}

# `_has_row(A, offsets, row, fpreszeros)` returns the "pointer" that argument `A`
# contributes for `row`; a result of `0i32` means "nothing stored at this row".
# The `offsets` argument is accepted but not used by any of these methods.

# Fallback for all other arguments: report the row itself, unless `fpreszeros`
# is set, in which case report zero.
_has_row(A, offsets, row::Int32, fpreszeros::Bool) = ifelse(fpreszeros, 0i32, row)

# Dense device arrays always report the row itself.
_has_row(A::CuDeviceArray, offsets, row::Int32, ::Bool) = row

# Sparse device vectors: scan `A.iPtr` front to back and return the position whose
# stored index equals `row`. The scan stops at the first stored index larger than `row`.
function _has_row(A::CuSparseDeviceVector, offsets, row::Int32, ::Bool)::Int32
    for pos in 1i32:length(A.iPtr)
        stored_row = @inbounds A.iPtr[pos]
        if stored_row == row
            return pos
        elseif stored_row > row
            return 0i32
        end
    end
    return 0i32
end

# Translate the calling thread's global (1-based) index into a row number, with
# thread 1 mapped onto `first_row`. The result is converted to `Int32`.
function _get_my_row(first_row)::Int32
    global_thread = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    return first_row + global_thread - 1i32
end

# Kernel that records, for each row in `first_row:last_row`, what every broadcast
# argument contributes at that row.
#
# Each thread handles one row and stores `key => ptrs` at position
# `row - first_row + 1` of `offsets`, where `ptrs[i]` is the `_has_row` result for
# `args[i]`. `key` is the row itself, or `typemax(Ti)` when every entry of `ptrs`
# is zero (no argument contributes anything at this row).
function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
                                fpreszeros::Bool,
                                offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
                                args...) where {Ti, N}
    row = _get_my_row(first_row)
    row > last_row && return

    # TODO load arg.iPtr slices into shared memory
    arg_row_is_nnz = ntuple(Val(N)) do i
        arg = @inbounds args[i]
        _has_row(arg, offsets, row, fpreszeros)::Int32
    end

    # OR the per-argument results together; the combined value is zero only when
    # all of them are zero.
    row_is_nnz = 0i32
    for i in 1:N
        row_is_nnz |= @inbounds arg_row_is_nnz[i]
    end

    key = (row_is_nnz == 0i32) ? typemax(Ti) : row
    @inbounds offsets[row - first_row + 1i32] = key => arg_row_is_nnz
    return
end

# kernel to count the number of non-zeros in a row, to determine the row offsets
function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
offsets::AbstractVector{Ti},
args...) where Ti
# every thread processes an entire row
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
Expand All @@ -331,8 +374,30 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
return
end

# broadcast kernels that iterate the elements of sparse arrays
function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
# Broadcast kernel writing into a sparse device vector.
#
# Thread `k` (for `k <= output.nnz`) reads the `k`-th entry of `offsets`, a pair of
# `row => ptrs`, gathers one value per argument through `_getindex`, applies `f`,
# and stores the row index and the result at position `k` of the output.
function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
                                           offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
                                           args...) where {Tv, Ti, N, F}
    slot = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
    if slot <= output.nnz
        entry = @inbounds offsets[slot]
        row = entry.first
        ptrs = entry.second
        # a pointer of 0 means the sparse vector has no element at this row, or the
        # argument is a scalar and `f` preserves zeros
        inputs = ntuple(Val(N)) do k
            _getindex(@inbounds(args[k]), row, @inbounds(ptrs[k]))::Tv
        end
        result = f(inputs...)
        @inbounds output.iPtr[slot] = row
        @inbounds output.nzVal[slot] = result
    end
    return
end

function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
CuSparseDeviceMatrixCSC{<:Any,Ti}}}
# every thread processes an entire row
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
Expand All @@ -345,7 +410,7 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
# fetch the row offset, and write it to the output
@inbounds begin
output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
if leading_dim == leading_dim_size
if leading_dim == leading_dim_size
output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
end
end
Expand All @@ -368,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract

return
end
function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
CuSparseMatrixCSC{Tv, Ti}}}, f,
output::CuDeviceArray, args...) where {Tv, Ti}
# every thread processes an entire row
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
Expand All @@ -392,6 +458,28 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,

return
end

# Broadcast kernel writing into a dense device array when the sparse inputs are
# sparse vectors.
#
# Every thread processes one entry of `offsets`, a pair of `row => ptrs`: it gathers
# one value per argument through `_getindex`, applies `f`, and stores the result at
# `output[row]`. Threads without a matching `offsets` entry, and entries whose row is
# not a valid index into `output` (such as a `typemax(Ti)` key), do nothing, so the
# corresponding elements of `output` keep whatever value they already hold.
function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
                                          output::CuDeviceArray{Tv},
                                          offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
                                          args...) where {Tv, F, N, Ti}
    row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
    # `offsets` may hold fewer entries than `output` has elements, so bound the
    # thread index by `offsets` to keep the unchecked read below in range
    row_ix > length(offsets) && return
    row_and_ptrs = @inbounds offsets[row_ix]
    row = @inbounds row_and_ptrs[1]
    arg_ptrs = @inbounds row_and_ptrs[2]
    # never use an out-of-range row as an unchecked write index
    (row < 1 || row > length(output)) && return
    vals = ntuple(Val(N)) do i
        arg = @inbounds args[i]
        # ptr is 0 if the sparse vector doesn't have an element at this row
        # ptr is row if the arg is dense OR a scalar with non-zero-preserving f
        # ptr is 0 if the arg is a scalar AND f preserves zeros
        ptr = @inbounds arg_ptrs[i]
        _getindex(arg, row, ptr)::Tv
    end
    @inbounds output[row] = f(vals...)
    return
end
## COV_EXCL_STOP

function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
Expand All @@ -405,12 +493,14 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
error("broadcast with multiple types of sparse arrays ($(join(sparse_types, ", "))) is not supported")
end
sparse_typ = typeof(bc.args[first(sparse_args)])
sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC} ||
error("broadcast with sparse arrays is currently only implemented for CSR and CSC matrices")
sparse_typ <: Union{CuSparseMatrixCSR,CuSparseMatrixCSC,CuSparseVector} ||
error("broadcast with sparse arrays is currently only implemented for vectors and CSR and CSC matrices")
Ti = if sparse_typ <: CuSparseMatrixCSR
reduce(promote_type, map(i->eltype(bc.args[i].rowPtr), sparse_args))
elseif sparse_typ <: CuSparseMatrixCSC
reduce(promote_type, map(i->eltype(bc.args[i].colPtr), sparse_args))
elseif sparse_typ <: CuSparseVector
reduce(promote_type, map(i->eltype(bc.args[i].iPtr), sparse_args))
end

# determine the output type
Expand All @@ -433,23 +523,32 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl

# the kernels below parallelize across rows or cols, not elements, so it's unlikely
# we'll launch many threads. to maximize utilization, parallelize across blocks first.
rows, cols = size(bc)
rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) # `size(bc, ::Int)` is missing
function compute_launch_config(kernel)
config = launch_configuration(kernel.fun)
if sparse_typ <: CuSparseMatrixCSR
threads = min(rows, config.threads)
blocks = max(cld(rows, threads), config.blocks)
blocks = max(cld(rows, threads), config.blocks)
threads = cld(rows, blocks)
elseif sparse_typ <: CuSparseMatrixCSC
threads = min(cols, config.threads)
blocks = max(cld(cols, threads), config.blocks)
blocks = max(cld(cols, threads), config.blocks)
threads = cld(cols, blocks)
elseif sparse_typ <: CuSparseVector
threads = 512
blocks = max(cld(rows, threads), config.blocks)
end
(; threads, blocks)
end

# for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
# but the only rows present in any sparse vector input are between 2 and 128, no need to
# launch massive threads.
# TODO: use the difference here to set the thread count
overall_first_row = one(Ti)
overall_last_row = Ti(rows)
offsets = nothing
# allocate the output container
if !fpreszeros
if !fpreszeros && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
# either we have dense inputs, or the function isn't preserving zeros,
# so use a dense output to broadcast into.
output = CuArray{Tv}(undef, size(bc))
Expand All @@ -466,20 +565,20 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
end
end
broadcast!(bc.f, output, nonsparse_args...)
elseif length(sparse_args) == 1
elseif length(sparse_args) == 1 && sparse_typ <: Union{CuSparseMatrixCSR, CuSparseMatrixCSC}
# we only have a single sparse input, so we can reuse its structure for the output.
# this avoids a kernel launch and costly synchronization.
sparse_arg = bc.args[first(sparse_args)]
if sparse_typ <: CuSparseMatrixCSR
offsets = rowPtr = sparse_arg.rowPtr
colVal = similar(sparse_arg.colVal)
nzVal = similar(sparse_arg.nzVal, Tv)
output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
colVal = similar(sparse_arg.colVal)
nzVal = similar(sparse_arg.nzVal, Tv)
output = CuSparseMatrixCSR(rowPtr, colVal, nzVal, size(bc))
elseif sparse_typ <: CuSparseMatrixCSC
offsets = colPtr = sparse_arg.colPtr
rowVal = similar(sparse_arg.rowVal)
nzVal = similar(sparse_arg.nzVal, Tv)
output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
rowVal = similar(sparse_arg.rowVal)
nzVal = similar(sparse_arg.nzVal, Tv)
output = CuSparseMatrixCSC(colPtr, rowVal, nzVal, size(bc))
end
# NOTE: we don't use CUSPARSE's similar, because that copies the structure arrays,
# while we do that in our kernel (for consistency with other code paths)
Expand All @@ -490,43 +589,79 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
CuArray{Ti}(undef, rows+1)
elseif sparse_typ <: CuSparseMatrixCSC
CuArray{Ti}(undef, cols+1)
elseif sparse_typ <: CuSparseVector
CUDA.@allowscalar begin
arg_first_rows = ntuple(Val(length(bc.args))) do i
bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[1]
return one(Ti)
end
arg_last_rows = ntuple(Val(length(bc.args))) do i
bc.args[i] isa CuSparseVector && return bc.args[i].iPtr[end]
return Ti(rows)
end
end
overall_first_row = min(arg_first_rows...)
overall_last_row = max(arg_last_rows...)
CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
end
let
args = (sparse_typ, offsets, bc.args...)
args = if sparse_typ <: CuSparseVector
(sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
else
(sparse_typ, offsets, bc.args...)
end
kernel = @cuda launch=false compute_offsets_kernel(args...)
threads, blocks = compute_launch_config(kernel)
kernel(args...; threads, blocks)
end

# accumulate these values so that we can use them directly as row pointer offsets,
# as well as to get the total nnz count to allocate the sparse output array.
# cusparseXcsrgeam2Nnz computes this in one go, but it doesn't seem worth the effort
accumulate!(Base.add_sum, offsets, offsets)
total_nnz = @allowscalar last(offsets[end]) - 1

if !(sparse_typ <: CuSparseVector)
accumulate!(Base.add_sum, offsets, offsets)
total_nnz = @allowscalar last(offsets[end]) - 1
else
sort!(offsets; by=first)
total_nnz = mapreduce(x->first(x) != typemax(first(x)), +, offsets)
end
output = if sparse_typ <: CuSparseMatrixCSR
colVal = CuArray{Ti}(undef, total_nnz)
nzVal = CuArray{Tv}(undef, total_nnz)
nzVal = CuArray{Tv}(undef, total_nnz)
CuSparseMatrixCSR(offsets, colVal, nzVal, size(bc))
elseif sparse_typ <: CuSparseMatrixCSC
rowVal = CuArray{Ti}(undef, total_nnz)
nzVal = CuArray{Tv}(undef, total_nnz)
nzVal = CuArray{Tv}(undef, total_nnz)
CuSparseMatrixCSC(offsets, rowVal, nzVal, size(bc))
elseif sparse_typ <: CuSparseVector && !fpreszeros
CuArray{Tv}(undef, size(bc))
elseif sparse_typ <: CuSparseVector && fpreszeros
iPtr = CUDA.zeros(Ti, total_nnz)
nzVal = CUDA.zeros(Tv, total_nnz)
CuSparseVector(iPtr, nzVal, rows)
end
if sparse_typ <: CuSparseVector && !fpreszeros
nonsparse_args = map(bc.args) do arg
# NOTE: this assumes the broadcast is flattened, but not yet preprocessed
if arg isa AbstractCuSparseArray
zero(eltype(arg))
else
arg
end
end
broadcast!(bc.f, output, nonsparse_args...)
end
end

# perform the actual broadcast
if output isa AbstractCuSparseArray
args = (bc.f, output, offsets, bc.args...)
args = (bc.f, output, offsets, bc.args...)
kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
threads, blocks = compute_launch_config(kernel)
kernel(args...; threads, blocks)
else
args = (sparse_typ, bc.f, output, bc.args...)
args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
(sparse_typ, bc.f, output, bc.args...)
kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
threads, blocks = compute_launch_config(kernel)
kernel(args...; threads, blocks)
end
threads, blocks = compute_launch_config(kernel)
kernel(args...; threads, blocks)

return output
end
2 changes: 1 addition & 1 deletion lib/cusparse/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct CuSparseDeviceVector{Tv,Ti, A} <: AbstractSparseVector{Tv,Ti}
nnz::Ti
end

# The length of a sparse device vector is its `len` field, the same field `size` reports.
Base.length(g::CuSparseDeviceVector) = g.len
Base.size(g::CuSparseDeviceVector) = (g.len,)
SparseArrays.nnz(g::CuSparseDeviceVector) = g.nnz

Expand Down
Loading