
[NDTensors] Avoid threadid in block sparse multithreading code #1650


Merged
merged 9 commits on Apr 21, 2025
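The pattern this PR removes indexed a vector of per-thread buffers by `threadid()` (see the deleted `contraction_plans = Vector{T}[T[] for _ in 1:nthreads()]` in contract_generic.jl below). Since Julia 1.7, tasks can migrate between threads at yield points, so `threadid()` is no longer a stable key for thread-local storage. A minimal standalone sketch of the problem, not NDTensors code:

```julia
using Base.Threads: @threads, nthreads, threadid

# Sketch of the removed pattern: one buffer per thread, selected by
# threadid(). If a task migrates mid-iteration, threadid() changes and
# two tasks can end up push!-ing to the same buffer concurrently.
function squares_by_threadid(xs)
    buffers = [Int[] for _ in 1:nthreads()]
    @threads for x in xs
        push!(buffers[threadid()], x^2)  # racy under task migration
    end
    return reduce(vcat, buffers)
end
```

The replacement, in the new contract_threaded.jl below, instead has each spawned task build and return its own result vector.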
4 changes: 1 addition & 3 deletions NDTensors/Project.toml
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
version = "0.4.7"
version = "0.4.8"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -19,7 +19,6 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -79,7 +78,6 @@ MacroTools = "0.5"
MappedArrays = "0.4"
Metal = "1"
Octavian = "0.3"
PackageExtensionCompat = "1"
Random = "<0.0.1, 1.10"
SimpleTraits = "0.9.4"
SparseArrays = "<0.0.1, 1.10"
8 changes: 1 addition & 7 deletions NDTensors/src/NDTensors.jl
@@ -72,8 +72,7 @@ include("blocksparse/contract.jl")
include("blocksparse/contract_utilities.jl")
include("blocksparse/contract_generic.jl")
include("blocksparse/contract_sequential.jl")
include("blocksparse/contract_folds.jl")
include("blocksparse/contract_threads.jl")
include("blocksparse/contract_threaded.jl")
include("blocksparse/diagblocksparse.jl")
include("blocksparse/similar.jl")
include("blocksparse/combiner.jl")
@@ -221,9 +220,4 @@ end

function backend_octavian end

using PackageExtensionCompat
function __init__()
@require_extensions
end

end # module NDTensors
2 changes: 0 additions & 2 deletions NDTensors/src/blocksparse/contract.jl
@@ -46,8 +46,6 @@ function contract_blockoffsets(
alg = Algorithm"sequential"()
if using_threaded_blocksparse() && nthreads() > 1
alg = Algorithm"threaded_threads"()
# This code is a bit cleaner but slower:
# alg = Algorithm"threaded_folds"()
end
return contract_blockoffsets(
alg, boffs1, inds1, labels1, boffs2, inds2, labels2, indsR, labelsR
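For reference, the threaded path is opt-in: `using_threaded_blocksparse()` is a runtime flag, and the dispatch above only fires when Julia has more than one thread. A hedged usage sketch assuming the ITensors.jl front end, whose documented toggles flip this flag:

```julia
# Start Julia with multiple threads, e.g. `julia -t 4`.
using ITensors

ITensors.enable_threaded_blocksparse()
# Block sparse contractions now dispatch to Algorithm"threaded_threads"
# whenever Threads.nthreads() > 1.
ITensors.disable_threaded_blocksparse()
```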
60 changes: 0 additions & 60 deletions NDTensors/src/blocksparse/contract_folds.jl

This file was deleted.

38 changes: 2 additions & 36 deletions NDTensors/src/blocksparse/contract_generic.jl
@@ -11,47 +11,14 @@ function contract_blockoffsets(
indsR,
labelsR,
)
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
NR = length(labelsR)
ValNR = ValLength(labelsR)
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
labels1, labels2, labelsR
)

# Contraction plan element type
T = Tuple{Block{N1},Block{N2},Block{NR}}

# Thread-local collections of block contractions.
# Could use:
# ```julia
# FLoops.@reduce(contraction_plans = append!(T[], [(block1, block2, blockR)]))
# ```
# as a simpler alternative but it is slower.

contraction_plans = Vector{T}[T[] for _ in 1:nthreads()]

#
# Reserve some capacity
# In theory the maximum is length(boffs1) * length(boffs2)
# but in practice that is too much
#for contraction_plan in contraction_plans
# sizehint!(contraction_plan, max(length(boffs1), length(boffs2)))
#end
#

contract_blocks!(
alg,
contraction_plans,
boffs1,
boffs2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
contraction_plan = contract_blocks(
alg, boffs1, boffs2, labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR, ValNR
)

contraction_plan = reduce(vcat, contraction_plans)
blockoffsetsR = BlockOffsets{NR}()
nnzR = 0
for (_, _, blockR) in contraction_plan
@@ -60,7 +27,6 @@
nnzR += blockdim(indsR, blockR)
end
end

return blockoffsetsR, contraction_plan
end

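One detail worth noting: `ValNR = ValLength(labelsR)` lifts the number of output indices into the type domain so that `contract_blocks` can build `Block{NR}` values type-stably. A standalone sketch of the `Val` trick (generic Julia, not the NDTensors definitions):

```julia
# With a runtime Int the tuple length cannot be inferred; with Val{N} it can.
unstable(n::Int) = ntuple(_ -> 0, n)                 # length known only at runtime
stable(::Val{N}) where {N} = ntuple(_ -> 0, Val(N))  # inferred as NTuple{N,Int}

# Analogous to ValLength(labelsR) above: recover the length as a Val.
vallength(::NTuple{N,Any}) where {N} = Val(N)

labelsR = ("i", "j", "k")
@assert stable(vallength(labelsR)) === (0, 0, 0)
```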
91 changes: 91 additions & 0 deletions NDTensors/src/blocksparse/contract_threaded.jl
@@ -0,0 +1,91 @@
using .Expose: expose
function contract_blocks(
alg::Algorithm"threaded_threads",
boffs1,
boffs2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR::Val{NR},
) where {NR}
N1 = length(blocktype(boffs1))
N2 = length(blocktype(boffs2))
blocks1 = keys(boffs1)
blocks2 = keys(boffs2)
T = Tuple{Block{N1},Block{N2},Block{NR}}
return if length(blocks1) > length(blocks2)
tasks = map(
Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
) do blocks1_partition
@spawn begin
block_contractions = T[]
for block1 in blocks1_partition
for block2 in blocks2
block_contraction = maybe_contract_blocks(
block1,
block2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
)
if !isnothing(block_contraction)
push!(block_contractions, block_contraction)
end
end
end
return block_contractions
end
end
all_block_contractions = T[]
for task in tasks
append!(all_block_contractions, fetch(task))
end
return all_block_contractions
else
tasks = map(
Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
) do blocks2_partition
@spawn begin
block_contractions = T[]
for block2 in blocks2_partition
for block1 in blocks1
block_contraction = maybe_contract_blocks(
block1,
block2,
labels1_to_labels2,
labels1_to_labelsR,
labels2_to_labelsR,
ValNR,
)
if !isnothing(block_contraction)
push!(block_contractions, block_contraction)
end
end
end
return block_contractions
end
end
all_block_contractions = T[]
for task in tasks
append!(all_block_contractions, fetch(task))
end
return all_block_contractions
end
end
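Both branches above follow the same task-based map-reduce shape: partition the longer block list into roughly `nthreads()` chunks, spawn one task per chunk, let each task accumulate its own local vector, and concatenate the fetched results. A self-contained sketch of that shape (hypothetical names, not the NDTensors API):

```julia
using Base.Threads: @spawn, nthreads

# Each task owns and returns its result vector; nothing is indexed by
# threadid(), so the code stays correct when tasks migrate across threads.
function matching_pairs(pred, xs, ys)
    chunks = Iterators.partition(xs, max(1, length(xs) ÷ nthreads()))
    tasks = map(chunks) do chunk
        @spawn begin
            out = Tuple{eltype(xs),eltype(ys)}[]
            for x in chunk, y in ys
                pred(x, y) && push!(out, (x, y))
            end
            out
        end
    end
    return mapreduce(fetch, vcat, tasks)
end

# Example: all pairs with an even sum.
matching_pairs((x, y) -> iseven(x + y), 1:10, 1:10)
```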

function contract!(
::Algorithm"threaded_folds",
R::BlockSparseTensor,
labelsR,
tensor1::BlockSparseTensor,
labelstensor1,
tensor2::BlockSparseTensor,
labelstensor2,
contraction_plan,
)
executor = ThreadedEx()
return contract!(
R, labelsR, tensor1, labelstensor1, tensor2, labelstensor2, contraction_plan, executor
)
end
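`ThreadedEx` is the threaded executor from the Transducers.jl/FLoops.jl ecosystem; the method above just fixes the executor and forwards to an executor-generic `contract!`. A hedged sketch of the executor pattern in isolation (generic example, not the inner NDTensors method):

```julia
using FLoops: @floop, @reduce, ThreadedEx, SequentialEx

# The executor argument selects the backend; the loop body is identical
# for sequential and threaded execution and never touches threadid().
function total_dim(dims; executor=ThreadedEx())
    @floop executor for d in dims
        @reduce(total = 0 + d)
    end
    return total
end

total_dim(1:100)                            # threaded
total_dim(1:100; executor=SequentialEx())   # same result, one thread
```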