Skip to content

Commit 96c2b51

Browse files
authored
[NDTensors] Avoid threadid in block sparse multithreading code (#1650)
1 parent 3583101 commit 96c2b51

16 files changed

+98
-396
lines changed

Diff for: NDTensors/Project.toml

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NDTensors"
22
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
33
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>"]
4-
version = "0.4.7"
4+
version = "0.4.8"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -19,7 +19,6 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
1919
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2020
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2121
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
22-
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
2322
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2423
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
2524
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -79,7 +78,6 @@ MacroTools = "0.5"
7978
MappedArrays = "0.4"
8079
Metal = "1"
8180
Octavian = "0.3"
82-
PackageExtensionCompat = "1"
8381
Random = "<0.0.1, 1.10"
8482
SimpleTraits = "0.9.4"
8583
SparseArrays = "<0.0.1, 1.10"

Diff for: NDTensors/src/NDTensors.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ include("blocksparse/contract.jl")
7272
include("blocksparse/contract_utilities.jl")
7373
include("blocksparse/contract_generic.jl")
7474
include("blocksparse/contract_sequential.jl")
75-
include("blocksparse/contract_folds.jl")
76-
include("blocksparse/contract_threads.jl")
75+
include("blocksparse/contract_threaded.jl")
7776
include("blocksparse/diagblocksparse.jl")
7877
include("blocksparse/similar.jl")
7978
include("blocksparse/combiner.jl")
@@ -221,9 +220,4 @@ end
221220

222221
function backend_octavian end
223222

224-
using PackageExtensionCompat
225-
function __init__()
226-
@require_extensions
227-
end
228-
229223
end # module NDTensors

Diff for: NDTensors/src/blocksparse/contract.jl

-2
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ function contract_blockoffsets(
4646
alg = Algorithm"sequential"()
4747
if using_threaded_blocksparse() && nthreads() > 1
4848
alg = Algorithm"threaded_threads"()
49-
# This code is a bit cleaner but slower:
50-
# alg = Algorithm"threaded_folds"()
5149
end
5250
return contract_blockoffsets(
5351
alg, boffs1, inds1, labels1, boffs2, inds2, labels2, indsR, labelsR

Diff for: NDTensors/src/blocksparse/contract_folds.jl

-60
This file was deleted.

Diff for: NDTensors/src/blocksparse/contract_generic.jl

+2-36
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,14 @@ function contract_blockoffsets(
1111
indsR,
1212
labelsR,
1313
)
14-
N1 = length(blocktype(boffs1))
15-
N2 = length(blocktype(boffs2))
1614
NR = length(labelsR)
1715
ValNR = ValLength(labelsR)
1816
labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR = contract_labels(
1917
labels1, labels2, labelsR
2018
)
21-
22-
# Contraction plan element type
23-
T = Tuple{Block{N1},Block{N2},Block{NR}}
24-
25-
# Thread-local collections of block contractions.
26-
# Could use:
27-
# ```julia
28-
# FLoops.@reduce(contraction_plans = append!(T[], [(block1, block2, blockR)]))
29-
# ```
30-
# as a simpler alternative but it is slower.
31-
32-
contraction_plans = Vector{T}[T[] for _ in 1:nthreads()]
33-
34-
#
35-
# Reserve some capacity
36-
# In theory the maximum is length(boffs1) * length(boffs2)
37-
# but in practice that is too much
38-
#for contraction_plan in contraction_plans
39-
# sizehint!(contraction_plan, max(length(boffs1), length(boffs2)))
40-
#end
41-
#
42-
43-
contract_blocks!(
44-
alg,
45-
contraction_plans,
46-
boffs1,
47-
boffs2,
48-
labels1_to_labels2,
49-
labels1_to_labelsR,
50-
labels2_to_labelsR,
51-
ValNR,
19+
contraction_plan = contract_blocks(
20+
alg, boffs1, boffs2, labels1_to_labels2, labels1_to_labelsR, labels2_to_labelsR, ValNR
5221
)
53-
54-
contraction_plan = reduce(vcat, contraction_plans)
5522
blockoffsetsR = BlockOffsets{NR}()
5623
nnzR = 0
5724
for (_, _, blockR) in contraction_plan
@@ -60,7 +27,6 @@ function contract_blockoffsets(
6027
nnzR += blockdim(indsR, blockR)
6128
end
6229
end
63-
6430
return blockoffsetsR, contraction_plan
6531
end
6632

Diff for: NDTensors/src/blocksparse/contract_threaded.jl

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using .Expose: expose
2+
function contract_blocks(
  alg::Algorithm"threaded_threads",
  boffs1,
  boffs2,
  labels1_to_labels2,
  labels1_to_labelsR,
  labels2_to_labelsR,
  ValNR::Val{NR},
) where {NR}
  N1 = length(blocktype(boffs1))
  N2 = length(blocktype(boffs2))
  blocks1 = keys(boffs1)
  blocks2 = keys(boffs2)
  # Element type of the contraction plan:
  # (block of tensor 1, block of tensor 2, resulting block of the output).
  T = Tuple{Block{N1},Block{N2},Block{NR}}
  # Collect the contractions of the `(block1, block2)` pairs in `block_pairs`
  # that are compatible; `maybe_contract_blocks` returns `nothing` for
  # incompatible pairs, which are dropped.
  function collect_contractions(block_pairs)
    block_contractions = T[]
    for (block1, block2) in block_pairs
      block_contraction = maybe_contract_blocks(
        block1,
        block2,
        labels1_to_labels2,
        labels1_to_labelsR,
        labels2_to_labelsR,
        ValNR,
      )
      if !isnothing(block_contraction)
        push!(block_contractions, block_contraction)
      end
    end
    return block_contractions
  end
  # Wait on all spawned tasks and concatenate their partial plans.
  function fetch_contractions(tasks)
    all_block_contractions = T[]
    for task in tasks
      append!(all_block_contractions, fetch(task))
    end
    return all_block_contractions
  end
  # Partition the larger collection of blocks across tasks so that each task
  # receives a roughly equal share of the pairwise block contractions. The
  # partitioned collection drives the outer loop in each case, and results are
  # gathered with `fetch` (no reliance on `threadid`, which is not stable
  # under task migration).
  return if length(blocks1) > length(blocks2)
    tasks = map(
      Iterators.partition(blocks1, max(1, length(blocks1) ÷ nthreads()))
    ) do blocks1_partition
      @spawn collect_contractions(
        (block1, block2) for block1 in blocks1_partition for block2 in blocks2
      )
    end
    fetch_contractions(tasks)
  else
    tasks = map(
      Iterators.partition(blocks2, max(1, length(blocks2) ÷ nthreads()))
    ) do blocks2_partition
      @spawn collect_contractions(
        (block1, block2) for block2 in blocks2_partition for block1 in blocks1
      )
    end
    fetch_contractions(tasks)
  end
end
76+
77+
# Threaded block-sparse contraction: dispatch for the `"threaded_folds"`
# algorithm. Forwards to the executor-based `contract!` method, passing a
# `ThreadedEx` executor (presumably the Folds.jl/FLoops.jl threaded
# executor — confirm against the package imports).
function contract!(
  ::Algorithm"threaded_folds",
  R::BlockSparseTensor,
  labelsR,
  tensor1::BlockSparseTensor,
  labelstensor1,
  tensor2::BlockSparseTensor,
  labelstensor2,
  contraction_plan,
)
  return contract!(
    R,
    labelsR,
    tensor1,
    labelstensor1,
    tensor2,
    labelstensor2,
    contraction_plan,
    ThreadedEx(),
  )
end

0 commit comments

Comments (0)