Skip to content

Commit 63cf6a4

Browse files
authored
Tree enumerator (#17)
* save * update * done * work but slow length * non-binary * update * tree storage fully tested * fix docs * fix project.toml
1 parent 4ce1b4e commit 63cf6a4

12 files changed

+335
-57
lines changed

Project.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "GraphTensorNetworks"
22
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
33
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
7+
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
78
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
89
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
910
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
@@ -35,9 +36,10 @@ Polynomials = "2.0"
3536
Primes = "0.5"
3637
Requires = "1"
3738
SIMDTypes = "0.1"
39+
StatsBase = "0.33"
3840
TropicalNumbers = "0.4, 0.5"
3941
Viznet = "0.3"
40-
StatsBase = "0.33"
42+
AbstractTrees = "0.3"
4143
julia = "1"
4244

4345
[extras]

docs/src/ref.md

+1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ Polynomials.Polynomial
6868
TruncatedPoly
6969
Max2Poly
7070
ConfigEnumerator
71+
TreeConfigEnumerator
7172
ConfigSampler
7273
```
7374

examples/IndependentSet.jl

+11
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ Compose.compose(context(),
127127
# One can use [`ConfigsAll`](@ref) to enumerate all sets satisfying the problem constraint.
128128
all_independent_sets = solve(problem, ConfigsAll())[]
129129

130+
# It is often difficult to store all configurations in a vector.
131+
# A more clever way to store the data is using the [`TreeConfigEnumerator`](@ref) format.
132+
all_independent_sets_tree = solve(problem, ConfigsAll(; tree_storage=true))[]
133+
134+
# The results encode the configurations in the sum-product-tree format. One can count and enumerate them explicitly by typing
135+
length(all_independent_sets_tree)
136+
137+
#
138+
139+
collect(all_independent_sets_tree)
140+
130141
# To save/read a set of configuration to disk, one can type the following
131142
filename = tempname()
132143

examples/MaximalIS.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ counting_min_maximal_independent_set = solve(problem, CountingMin())[]
7171
# ##### finding all maximal independent set
7272
maximal_configs = solve(problem, ConfigsAll())[]
7373

74-
all(c->is_maximal_independent_set(g, i), maximal_configs)
74+
all(c->is_maximal_independent_set(graph, c), maximal_configs)
7575

7676
#
7777

@@ -90,7 +90,7 @@ cliques = maximal_cliques(complement(graph))
9090

9191
# ##### finding minimum maximal independent set
9292
# It is the [`ConfigsMin`](@ref) property in the program.
93-
minimum_maximal_configs = solve(problem, ConfigsMin())[]
93+
minimum_maximal_configs = solve(problem, ConfigsMin())[].c
9494

9595
imgs2 = ntuple(k->show_graph(graph;
9696
locs=locations, scale=0.25,

src/GraphTensorNetworks.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export GreedyMethod, TreeSA, SABipartite, KaHyParBipartite, MergeVectors, MergeG
1414
# Algebras
1515
export StaticBitVector, StaticElementVector, @bv_str
1616
export is_commutative_semiring
17-
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
17+
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator
1818
export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32
1919

2020
# Lower level APIs

src/arithematics.jl

+185-11
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ using Polynomials: Polynomial
22
using TropicalNumbers: Tropical, CountingTropical
33
using Mods, Primes
44
using Base.Cartesian
5+
import AbstractTrees: children, printnode, print_tree
6+
7+
@enum TreeTag LEAF SUM PROD ZERO
58

69
# pirate
710
Base.abs(x::Mod) = x
@@ -275,20 +278,179 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
275278
Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C})
276279
Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})
277280

278-
# A patch to make `Polynomial{ConfigEnumerator}` work
279-
function Base.:*(a::Int, y::ConfigEnumerator)
280-
a == 0 && return zero(y)
281-
a == 1 && return y
282-
error("multiplication between int and config enumerator is not defined.")
281+
# tree config enumerator
282+
"""
283+
TreeConfigEnumerator{N,S,C}
284+
285+
Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
286+
and is often more memory efficient than putting the configurations in a vector.
287+
`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
288+
289+
Fields
290+
-----------------------
291+
* `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
292+
* `data` is the element stored in a `LEAF` node.
293+
* `left` and `right` are two operands of a `SUM` or `PROD` node.
294+
295+
Example
296+
------------------------
297+
```jldoctest; setup=:(using GraphTensorNetworks)
298+
julia> s = TreeConfigEnumerator(bv"00111")
299+
00111
300+
301+
302+
julia> q = TreeConfigEnumerator(bv"10000")
303+
10000
304+
305+
306+
julia> x = s + q
307+
+
308+
├─ 00111
309+
└─ 10000
310+
311+
312+
julia> y = x * x
313+
*
314+
├─ +
315+
│ ├─ 00111
316+
│ └─ 10000
317+
└─ +
318+
├─ 00111
319+
└─ 10000
320+
321+
322+
julia> collect(y)
323+
4-element Vector{StaticBitVector{5, 1}}:
324+
00111
325+
10111
326+
10111
327+
10000
328+
329+
julia> zero(s)
330+
331+
332+
333+
julia> one(s)
334+
00000
335+
336+
337+
```
338+
"""
339+
struct TreeConfigEnumerator{N,S,C}
340+
tag::TreeTag
341+
data::StaticElementVector{N,S,C}
342+
left::TreeConfigEnumerator{N,S,C}
343+
right::TreeConfigEnumerator{N,S,C}
344+
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = new{N,S,C}(tag, zero(StaticElementVector{N,S,C}), left, right)
345+
function TreeConfigEnumerator(data::StaticElementVector{N,S,C}) where {N,S,C}
346+
new{N,S,C}(LEAF, data)
347+
end
348+
function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C}
349+
@assert tag === ZERO
350+
return new{N,S,C}(tag)
351+
end
283352
end
284-
function Base.:*(a::Int, y::ConfigSampler)
285-
a == 0 && return zero(y)
286-
a == 1 && return y
287-
error("multiplication between int and config sampler is not defined.")
353+
354+
# AbstractTree APIs
355+
function children(t::TreeConfigEnumerator)
356+
if isdefined(t, :left)
357+
if isdefined(t, :right)
358+
return [t.left, t.right]
359+
else
360+
return [t.left]
361+
end
362+
else
363+
if isdefined(t, :right)
364+
return [t.right]
365+
else
366+
return typeof(t)[]
367+
end
368+
end
369+
end
370+
function printnode(io::IO, t::TreeConfigEnumerator)
371+
if t.tag === LEAF
372+
print(io, t.data)
373+
elseif t.tag === ZERO
374+
print(io, "")
375+
elseif t.tag === SUM
376+
print(io, "+")
377+
else # PROD
378+
print(io, "*")
379+
end
380+
end
381+
382+
function Base.length(x::TreeConfigEnumerator)
383+
if x.tag === SUM
384+
return length(x.left) + length(x.right)
385+
elseif x.tag === PROD
386+
return length(x.left) * length(x.right)
387+
elseif x.tag === ZERO
388+
return 0
389+
else
390+
return 1
391+
end
392+
end
393+
394+
function num_nodes(x::TreeConfigEnumerator)
395+
x.tag == ZERO && return 1
396+
x.tag == LEAF && return 1
397+
return num_nodes(x.left) + num_nodes(x.right) + 1
398+
end
399+
400+
function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
401+
return Set(collect(x)) == Set(collect(y))
402+
end
403+
404+
Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t)
405+
406+
function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
407+
if x.tag == ZERO
408+
return StaticElementVector{N,S,C}[]
409+
elseif x.tag == LEAF
410+
return StaticElementVector{N,S,C}[x.data]
411+
elseif x.tag == SUM
412+
return vcat(collect(x.left), collect(x.right))
413+
else # PROD
414+
return vec([reduce((x,y)->x|y, si) for si in Iterators.product(collect(x.left), collect(x.right))])
415+
end
416+
end
417+
418+
function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
419+
TreeConfigEnumerator(SUM, x, y)
420+
end
421+
422+
function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C}
423+
TreeConfigEnumerator(PROD, x, y)
424+
end
425+
426+
Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO)
427+
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator(zero(StaticElementVector{N,S,C}))
428+
Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C})
429+
Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C})
430+
# todo, check siblings too?
431+
function Base.iszero(t::TreeConfigEnumerator)
432+
if t.TAG == SUM
433+
iszero(t.left) && iszero(t.right)
434+
elseif t.TAG == ZERO
435+
true
436+
elseif t.TAG == LEAF
437+
false
438+
else
439+
iszero(t.left) || iszero(t.right)
440+
end
441+
end
442+
443+
# A patch to make `Polynomial{ConfigEnumerator}` work
444+
for T in [:ConfigEnumerator, :ConfigSampler, :TreeConfigEnumerator]
445+
@eval function Base.:*(a::Int, y::$T)
446+
a == 0 && return zero(y)
447+
a == 1 && return y
448+
error("multiplication between int and `$(typeof(y))` is not defined.")
449+
end
288450
end
289451

290452
# convert from counting type to bitstring type
291-
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler)]
453+
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler), (:treeset_type, :TreeConfigEnumerator)]
292454
@eval begin
293455
function $F(::Type{T}, n::Int, nflavor::Int) where {OT, K, T<:TruncatedPoly{K,C,OT} where C}
294456
TruncatedPoly{K, $F(n,nflavor),OT}
@@ -312,12 +474,24 @@ end
312474

313475
# utilities for creating onehot vectors
314476
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
477+
onehotv(::Type{TreeConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = TreeConfigEnumerator(onehotv(StaticElementVector{N,S,C}, i, v))
315478
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
479+
# just to make matrix transpose work
316480
Base.transpose(c::ConfigEnumerator) = c
317481
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
482+
Base.transpose(c::TreeConfigEnumerator) = c
483+
function Base.copy(c::TreeConfigEnumerator)
484+
if c.tag == LEAF
485+
TreeConfigEnumerator(c.data)
486+
elseif c.tag == ZERO
487+
TreeConfigEnumerator(c.tag)
488+
else
489+
TreeConfigEnumerator(c.tag, c.left, c.right)
490+
end
491+
end
318492

319493
# Handle boolean, this is a patch for CUDA matmul
320-
for TYPE in [:ConfigEnumerator, :ConfigSampler, :TruncatedPoly]
494+
for TYPE in [:ConfigEnumerator, :ConfigSampler, :TruncatedPoly, :TreeConfigEnumerator]
321495
@eval Base.:*(a::Bool, y::$TYPE) = a ? y : zero(y)
322496
@eval Base.:*(y::$TYPE, a::Bool) = a ? y : zero(y)
323497
end

src/bounding.jl

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
2828
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical(largest_k(mode)-1+1e-12)
2929
push!(masks, mask)
3030
elseif mode isa SingleConfig
31-
A = zeros(eltype(xs[i]), size(xs[i]))
3231
A = einsum(EinCode(nixs, niy), nxs, size_dict)
3332
push!(masks, onehotmask(A, xs[i]))
3433
else

0 commit comments

Comments
 (0)