Skip to content

Commit 5cdf318

Browse files
authored
New sum combination (#30)
* save * update * clean up * add one more comment
1 parent 477d550 commit 5cdf318

File tree

4 files changed

+99
-24
lines changed

4 files changed

+99
-24
lines changed

examples/IndependentSet.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ show_graph(graph; locs=locations)
3232
# x_i^{w_i}
3333
# \end{matrix}\right),
3434
# ```
35-
# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and `w_i` is the weight of vertex ``i``.
35+
# where ``W(x_i)_0=1`` is the first element associated with ``s_i=0`` and ``W(x_i)_1=x_i^{w_i}`` is the second element associated with ``s_i=1``, and ``w_i`` is the weight of vertex ``i``.
3636
# Similarly, on each edge ``(u, v)``, we define a matrix ``B`` indexed by ``s_u`` and ``s_v`` as
3737
# ```math
3838
# B = \left(\begin{matrix}

src/arithematics.jl

+86-22
Original file line numberDiff line numberDiff line change
@@ -221,42 +221,106 @@ function Base.:*(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,
221221
return ExtendedTropical{K,TO}(sorted_sum_combination!(res, a.orders, b.orders))
222222
end
223223

224+
# 1. bisect over summed value and find the critical value `c`,
225+
# 2. collect the values with sum combination `≥ c`,
226+
# 3. sort the collected values
224227
function sorted_sum_combination!(res::AbstractVector{TO}, A::AbstractVector{TO}, B::AbstractVector{TO}) where TO
225228
K = length(res)
226229
@assert length(B) == length(A) == K
227-
@inbounds maxval = A[K] * B[K]
228-
ptr = K
229-
@inbounds res[ptr] = maxval
230-
@inbounds queue = [(K,K-1,A[K]*B[K-1]), (K-1,K,A[K-1]*B[K])]
231-
for k = 1:K-1
232-
@inbounds (i, j, res[K-k]) = _pop_max_sum!(queue) # TODO: do not enumerate, use better data structures
233-
_push_if_not_exists!(queue, i, j-1, A, B)
234-
_push_if_not_exists!(queue, i-1, j, A, B)
230+
@inbounds high = A[K] * B[K]
231+
232+
mA = findfirst(!iszero, A)
233+
mB = findfirst(!iszero, B)
234+
if mA === nothing || mB === nothing
235+
res .= Ref(zero(TO))
236+
return res
237+
end
238+
@inbounds low = A[mA] * B[mB]
239+
# count number bigger than x
240+
c, _ = count_geq(A, B, mB, low, true)
241+
@inbounds if c <= K # return
242+
res[K-c+1:K] .= sort!(collect_geq!(view(res,1:c), A, B, mB, low))
243+
if c < K
244+
res[1:K-c] .= zero(TO)
245+
end
246+
return res
247+
end
248+
# calculate by bisection for at most 30 times.
249+
@inbounds for _ = 1:30
250+
mid = mid_point(high, low)
251+
c, nB = count_geq(A, B, mB, mid, true)
252+
if c > K
253+
low = mid
254+
mB = nB
255+
elseif c == K # return
256+
# NOTE: this is the bottleneck
257+
return sort!(collect_geq!(res, A, B, mB, mid))
258+
else
259+
high = mid
260+
end
235261
end
262+
clow, _ = count_geq(A, B, mB, low, false)
263+
@inbounds res .= sort!(collect_geq!(similar(res, clow), A, B, mB, low))[end-K+1:end]
236264
return res
237265
end
238266

239-
function _push_if_not_exists!(queue, i, j, A, B)
240-
@inbounds if j>=1 && i>=1 && !any(x->x[1] >= i && x[2] >= j, queue)
241-
push!(queue, (i, j, A[i]*B[j]))
267+
# count the number of sum-combinations with the sum >= low
268+
function count_geq(A, B, mB, low, earlybreak)
269+
K = length(A)
270+
k = 1 # TODO: we should tighten mA, mB later!
271+
@inbounds Ak = A[K-k+1]
272+
@inbounds Bq = B[K-mB+1]
273+
c = 0
274+
nB = mB
275+
@inbounds for q = K-mB+1:-1:1
276+
Bq = B[K-q+1]
277+
while k < K && Ak * Bq >= low
278+
k += 1
279+
Ak = A[K-k+1]
280+
end
281+
if Ak * Bq >= low
282+
c += k
283+
else
284+
c += (k-1)
285+
if k==1
286+
nB += 1
287+
end
288+
end
289+
if earlybreak && c > K
290+
return c, nB
291+
end
242292
end
293+
return c, nB
243294
end
244295

245-
function _pop_max_sum!(queue)
246-
maxsum = first(queue)[3]
247-
maxloc = 1
248-
@inbounds for i=2:length(queue)
249-
m = queue[i][3]
250-
if m > maxsum
251-
maxsum = m
252-
maxloc = i
296+
function collect_geq!(res, A, B, mB, low)
297+
K = length(A)
298+
k = 1 # TODO: we should tighten mA, mB later!
299+
Ak = A[K-k+1]
300+
Bq = B[K-mB+1]
301+
l = 0
302+
for q = K-mB+1:-1:1
303+
Bq = B[K-q+1]
304+
while k < K && Ak * Bq >= low
305+
k += 1
306+
Ak = A[K-k+1]
307+
end
308+
# push data
309+
ck = Ak * Bq >= low ? k : k-1
310+
for j=1:ck
311+
l += 1
312+
res[l] = Bq * A[end-j+1]
253313
end
254314
end
255-
@inbounds data = queue[maxloc]
256-
deleteat!(queue, maxloc)
257-
return data
315+
return res
258316
end
259317

318+
# for bisection
319+
mid_point(a::Tropical{T}, b::Tropical{T}) where T = Tropical{T}((a.n + b.n) / 2)
320+
mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T,CT} = CountingTropical{T,CT}((a.n + b.n) / 2, a.c)
321+
mid_point(a::Tropical{T}, b::Tropical{T}) where T<:Integer = Tropical{T}((a.n + b.n) ÷ 2)
322+
mid_point(a::CountingTropical{T,CT}, b::CountingTropical{T,CT}) where {T<:Integer,CT} = CountingTropical{T,CT}((a.n + b.n) ÷ 2, a.c)
323+
260324
function Base.:+(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,TO}
261325
res = Vector{TO}(undef, K)
262326
ptr1, ptr2 = K, K

src/bounding.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct CacheTree{T}
6060
end
6161
function cached_einsum(se::SlicedEinsum, @nospecialize(xs), size_dict)
6262
if length(se.slicing) != 0
63-
@warn "Slicing is not supported for caching! Fallback to `NestedEinsum`."
63+
@warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
6464
end
6565
return cached_einsum(se.eins, xs, size_dict)
6666
end

test/arithematics.jl

+11
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ end
153153
end
154154
end
155155

156+
@testset "count geq" begin
157+
A = collect(1:10)
158+
B = collect(2:2:20)
159+
low = 20
160+
c, _ = GraphTensorNetworks.count_geq(A, B, 1, low, false)
161+
@test c == count(x->x>=low, vec([a*b for a in A, b in B]))
162+
res = similar(A, c)
163+
@test sort!(GraphTensorNetworks.collect_geq!(res, A, B, 1, low)) == sort!(filter(x->x>=low, vec([a*b for a in A, b in B])))
164+
end
165+
166+
156167
# check the correctness of sampling
157168
@testset "generate samples" begin
158169
Random.seed!(2)

0 commit comments

Comments
 (0)