@@ -221,42 +221,106 @@ function Base.:*(a::ExtendedTropical{K,TO}, b::ExtendedTropical{K,TO}) where {K,
221
221
return ExtendedTropical {K,TO} (sorted_sum_combination! (res, a. orders, b. orders))
222
222
end
223
223
224
+ # 1. bisect over summed value and find the critical value `c`,
225
+ # 2. collect the values with sum combination `≥ c`,
226
+ # 3. sort the collected values
224
227
function sorted_sum_combination! (res:: AbstractVector{TO} , A:: AbstractVector{TO} , B:: AbstractVector{TO} ) where TO
225
228
K = length (res)
226
229
@assert length (B) == length (A) == K
227
- @inbounds maxval = A[K] * B[K]
228
- ptr = K
229
- @inbounds res[ptr] = maxval
230
- @inbounds queue = [(K,K- 1 ,A[K]* B[K- 1 ]), (K- 1 ,K,A[K- 1 ]* B[K])]
231
- for k = 1 : K- 1
232
- @inbounds (i, j, res[K- k]) = _pop_max_sum! (queue) # TODO : do not enumerate, use better data structures
233
- _push_if_not_exists! (queue, i, j- 1 , A, B)
234
- _push_if_not_exists! (queue, i- 1 , j, A, B)
230
+ @inbounds high = A[K] * B[K]
231
+
232
+ mA = findfirst (! iszero, A)
233
+ mB = findfirst (! iszero, B)
234
+ if mA === nothing || mB === nothing
235
+ res .= Ref (zero (TO))
236
+ return res
237
+ end
238
+ @inbounds low = A[mA] * B[mB]
239
+ # count number bigger than x
240
+ c, _ = count_geq (A, B, mB, low, true )
241
+ @inbounds if c <= K # return
242
+ res[K- c+ 1 : K] .= sort! (collect_geq! (view (res,1 : c), A, B, mB, low))
243
+ if c < K
244
+ res[1 : K- c] .= zero (TO)
245
+ end
246
+ return res
247
+ end
248
+ # calculate by bisection for at most 30 times.
249
+ @inbounds for _ = 1 : 30
250
+ mid = mid_point (high, low)
251
+ c, nB = count_geq (A, B, mB, mid, true )
252
+ if c > K
253
+ low = mid
254
+ mB = nB
255
+ elseif c == K # return
256
+ # NOTE: this is the bottleneck
257
+ return sort! (collect_geq! (res, A, B, mB, mid))
258
+ else
259
+ high = mid
260
+ end
235
261
end
262
+ clow, _ = count_geq (A, B, mB, low, false )
263
+ @inbounds res .= sort! (collect_geq! (similar (res, clow), A, B, mB, low))[end - K+ 1 : end ]
236
264
return res
237
265
end
238
266
239
- function _push_if_not_exists! (queue, i, j, A, B)
240
- @inbounds if j>= 1 && i>= 1 && ! any (x-> x[1 ] >= i && x[2 ] >= j, queue)
241
- push! (queue, (i, j, A[i]* B[j]))
267
+ # count the number of sum-combinations with the sum >= low
268
+ function count_geq (A, B, mB, low, earlybreak)
269
+ K = length (A)
270
+ k = 1 # TODO : we should tighten mA, mB later!
271
+ @inbounds Ak = A[K- k+ 1 ]
272
+ @inbounds Bq = B[K- mB+ 1 ]
273
+ c = 0
274
+ nB = mB
275
+ @inbounds for q = K- mB+ 1 : - 1 : 1
276
+ Bq = B[K- q+ 1 ]
277
+ while k < K && Ak * Bq >= low
278
+ k += 1
279
+ Ak = A[K- k+ 1 ]
280
+ end
281
+ if Ak * Bq >= low
282
+ c += k
283
+ else
284
+ c += (k- 1 )
285
+ if k== 1
286
+ nB += 1
287
+ end
288
+ end
289
+ if earlybreak && c > K
290
+ return c, nB
291
+ end
242
292
end
293
+ return c, nB
243
294
end
244
295
245
- function _pop_max_sum! (queue)
246
- maxsum = first (queue)[3 ]
247
- maxloc = 1
248
- @inbounds for i= 2 : length (queue)
249
- m = queue[i][3 ]
250
- if m > maxsum
251
- maxsum = m
252
- maxloc = i
296
+ function collect_geq! (res, A, B, mB, low)
297
+ K = length (A)
298
+ k = 1 # TODO : we should tighten mA, mB later!
299
+ Ak = A[K- k+ 1 ]
300
+ Bq = B[K- mB+ 1 ]
301
+ l = 0
302
+ for q = K- mB+ 1 : - 1 : 1
303
+ Bq = B[K- q+ 1 ]
304
+ while k < K && Ak * Bq >= low
305
+ k += 1
306
+ Ak = A[K- k+ 1 ]
307
+ end
308
+ # push data
309
+ ck = Ak * Bq >= low ? k : k- 1
310
+ for j= 1 : ck
311
+ l += 1
312
+ res[l] = Bq * A[end - j+ 1 ]
253
313
end
254
314
end
255
- @inbounds data = queue[maxloc]
256
- deleteat! (queue, maxloc)
257
- return data
315
+ return res
258
316
end
259
317
318
+ # for bisection
319
+ mid_point (a:: Tropical{T} , b:: Tropical{T} ) where T = Tropical {T} ((a. n + b. n) / 2 )
320
+ mid_point (a:: CountingTropical{T,CT} , b:: CountingTropical{T,CT} ) where {T,CT} = CountingTropical {T,CT} ((a. n + b. n) / 2 , a. c)
321
+ mid_point (a:: Tropical{T} , b:: Tropical{T} ) where T<: Integer = Tropical {T} ((a. n + b. n) ÷ 2 )
322
+ mid_point (a:: CountingTropical{T,CT} , b:: CountingTropical{T,CT} ) where {T<: Integer ,CT} = CountingTropical {T,CT} ((a. n + b. n) ÷ 2 , a. c)
323
+
260
324
function Base.:+ (a:: ExtendedTropical{K,TO} , b:: ExtendedTropical{K,TO} ) where {K,TO}
261
325
res = Vector {TO} (undef, K)
262
326
ptr1, ptr2 = K, K
0 commit comments