Skip to content

Commit d51b4a2

Browse files
authored
iddict to dict (#31)
* iddict to dict * save remove stack * fix sampling * clean up non-recursive _length * fix sampling * fix docs
1 parent d1d95f7 commit d51b4a2

11 files changed

+261
-165
lines changed

Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphTensorNetworks"
22
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
33
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -21,6 +21,7 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2222
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2323
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
24+
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
2425
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2526
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
2627
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"

docs/src/performancetips.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ Key word argument `optimizer` decides the contraction order optimizer of the ten
1515
Here, we choose the `TreeSA` optimizer to optimize the tensor network contraction tree; it is a local-search-based algorithm.
1616
It is one of the state-of-the-art tensor network contraction order optimizers; one may check [arXiv: 2108.05665](https://arxiv.org/abs/2108.05665) to learn more about the algorithm.
1717
Other optimizers include
18-
* [`GreedyMethod`](@ref) (default, fastest in searching speed but worse in contraction order)
19-
* [`TreeSA`](@ref)
18+
* [`GreedyMethod`](@ref) (default, fastest in searching speed but worst in contraction complexity)
19+
* [`TreeSA`](@ref) (often best in contraction complexity, supports slicing)
2020
* [`KaHyParBipartite`](@ref)
2121
* [`SABipartite`](@ref)
2222

@@ -32,8 +32,8 @@ julia> timespacereadwrite_complexity(problem)
3232
```
3333

3434
The return values are `log2` of the number of iterations, the number of elements in the largest tensor during contraction, and the number of read-write operations to tensor elements.
35-
In this example, the number of `+` and `*` operations are both `\sim 2^{21.9}`
36-
and the number of read-write operations are `\sim 2^{20}`.
35+
In this example, the numbers of `+` and `*` operations are both ``\sim 2^{21.9}``,
36+
and the number of read-write operations is ``\sim 2^{20}``.
3737
The largest tensor size is ``2^{17}``; one can check the element size by typing
3838
```julia
3939
julia> sizeof(TropicalF64)
@@ -136,7 +136,7 @@ julia> lineplot(hamming_distribution(samples, samples))
136136
```
137137

138138
## Multiprocessing
139-
Submodule `GraphTensorNetworks.SimpleMutiprocessing` provides a function [`multiprocess_run`](@ref) function for simple multi-processing jobs.
139+
Submodule `GraphTensorNetworks.SimpleMultiprocessing` provides the function [`GraphTensorNetworks.SimpleMultiprocessing.multiprocess_run`](@ref) for simple multi-processing jobs.
140140
Suppose we want to find the independence polynomial for multiple graphs with 4 processes.
141141
We can create a file, e.g. named `run.jl` with the following content
142142

@@ -190,4 +190,4 @@ CUDA backended properties are
190190
* [`CountingAll`](@ref)
191191
* [`CountingMax`](@ref)
192192
* [`GraphPolynomial`](@ref)
193-
* [`SingleConfigMax`](@ref)
193+
* [`SingleConfigMax`](@ref)

docs/src/ref.md

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ StaticBitVector
9898
StaticElementVector
9999
save_configs
100100
load_configs
101+
save_sumproduct
102+
load_sumproduct
101103
@bv_str
102104
onehotv
103105

examples/IndependentSet.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,12 @@ all_independent_sets_tree = solve(problem, ConfigsAll(; tree_storage=true))[]
135135
# The results encode the configurations in the sum-product-tree format. One can count and enumerate them explicitly by typing
136136
length(all_independent_sets_tree)
137137

138-
#
138+
# Then one can use the `Base.collect` function to create a [`ConfigEnumerator`](@ref), or use [`generate_samples`](@ref) to generate samples from it.
139139

140140
collect(all_independent_sets_tree)
141141

142+
generate_samples(all_independent_sets_tree, 10)
143+
142144
# One can use [`save_configs`](@ref) and [`load_configs`](@ref) to save a set of configurations to disk and to read it back.
143145
filename = tempname()
144146

src/GraphTensorNetworks.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using TropicalNumbers
66
using OMEinsum
77
using OMEinsum: timespace_complexity, getixsv
88
using Graphs, Random
9+
using DelimitedFiles, Serialization
910

1011
# OMEinsum
1112
export timespace_complexity, timespacereadwrite_complexity, @ein_str, getixsv, getiyv
@@ -48,7 +49,7 @@ export is_matching
4849
export solve, SizeMax, SizeMin, CountingAll, CountingMax, CountingMin, GraphPolynomial, SingleConfigMax, SingleConfigMin, ConfigsAll, ConfigsMax, ConfigsMin
4950

5051
# Utilities
51-
export save_configs, load_configs, hamming_distribution
52+
export save_configs, load_configs, hamming_distribution, save_sumproduct, load_sumproduct
5253

5354
# Visualization
5455
export show_graph, spring_layout
@@ -64,6 +65,7 @@ include("configurations.jl")
6465
include("graphs.jl")
6566
include("bounding.jl")
6667
include("visualize.jl")
68+
include("fileio.jl")
6769
include("interfaces.jl")
6870
include("deprecate.jl")
6971
include("multiprocessing.jl")

src/arithematics.jl

+63-32
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,6 @@ function collect_geq!(res, A, B, mB, low)
297297
K = length(A)
298298
k = 1 # TODO: we should tighten mA, mB later!
299299
Ak = A[K-k+1]
300-
Bq = B[K-mB+1]
301300
l = 0
302301
for q = K-mB+1:-1:1
303302
Bq = B[K-q+1]
@@ -587,37 +586,56 @@ function printnode(io::IO, t::SumProductTree{ET}) where {ET}
587586
end
588587
end
589588

590-
# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
591-
Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}())
589+
# it must be mutable; otherwise, `objectid` might be slow and serialization might fail.
590+
# `IdDict` is much slower than a `Dict` keyed by `objectid`, so it is not used here.
591+
Base.length(x::SumProductTree) = _length!(x, Dict{UInt, Float64}())
592592

593593
function _length!(x, d)
594-
haskey(d, x) && return d[x]
594+
id = objectid(x)
595+
haskey(d, id) && return d[id]
595596
if x.tag === SUM
596597
l = _length!(x.left, d) + _length!(x.right, d)
597-
d[x] = l
598+
d[id] = l
598599
return l
599600
elseif x.tag === PROD
600601
l = _length!(x.left, d) * _length!(x.right, d)
601-
d[x] = l
602+
d[id] = l
602603
return l
603604
elseif x.tag === ZERO
604-
return 0
605+
return 0.0
605606
else
606-
return 1
607+
return 1.0
607608
end
608609
end
609610

610-
num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}())
611+
function _find_branch(x, d)
612+
if x.tag === ZERO
613+
return true, 0.0
614+
elseif x.tag === ONE || x.tag === LEAF
615+
return true, 1.0
616+
else
617+
idl = objectid(x.left)
618+
if haskey(d, idl)
619+
return true, d[idl]
620+
else
621+
return false, 0.0
622+
end
623+
end
624+
end
625+
626+
627+
num_nodes(x::SumProductTree) = _num_nodes(x, Dict{UInt, Int}())
611628
function _num_nodes(x, d)
612-
haskey(d, x) && return 0
629+
id = objectid(x)
630+
haskey(d, id) && return 0
613631
if x.tag == ZERO || x.tag == ONE
614632
res = 1
615633
elseif x.tag == LEAF
616634
res = 1
617635
else
618636
res = _num_nodes(x.left, d) + _num_nodes(x.right, d) + 1
619637
end
620-
d[x] = res
638+
d[id] = res
621639
return res
622640
end
623641

@@ -708,33 +726,46 @@ true
708726
function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET}
709727
# get length dict
710728
res = fill(_data_one(ET), nsamples)
711-
d = IdDict{typeof(t), Int}()
729+
d = Dict{UInt, Float64}()
712730
sample_descend!(res, t, d)
713731
return res
714732
end
715733

716-
function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict)
717-
length(res) == 0 && return res
718-
if t.tag == LEAF
719-
res .|= Ref(t.data)
720-
elseif t.tag == SUM
721-
ratio = _length!(t.left, d)/_length!(t, d)
722-
nleft = 0
723-
for _ = 1:length(res)
724-
if rand() < ratio
725-
nleft += 1
734+
function sample_descend!(res::AbstractVector, t::SumProductTree, d::Dict)
735+
res_stack = Any[res]
736+
t_stack = [t]
737+
while !isempty(t_stack) && !isempty(res_stack)
738+
t = pop!(t_stack)
739+
res = pop!(res_stack)
740+
if t.tag == LEAF
741+
res .|= Ref(t.data)
742+
elseif t.tag == SUM
743+
ratio = _length!(t.left, d)/_length!(t, d)
744+
nleft = 0
745+
for _ = 1:length(res)
746+
if rand() < ratio
747+
nleft += 1
748+
end
749+
end
750+
shuffle!(res) # shuffle the `res` to avoid biased sampling, very important.
751+
if nleft >= 1
752+
push!(res_stack, view(res,1:nleft))
753+
push!(t_stack, t.left)
754+
end
755+
if length(res) > nleft
756+
push!(res_stack, view(res,nleft+1:length(res)))
757+
push!(t_stack, t.right)
726758
end
759+
elseif t.tag == PROD
760+
push!(res_stack, res)
761+
push!(res_stack, res)
762+
push!(t_stack, t.left)
763+
push!(t_stack, t.right)
764+
elseif t.tag == ZERO
765+
error("Meet zero when descending.")
766+
else
767+
# pass for 1
727768
end
728-
shuffle!(res) # shuffle the `res` to avoid biased sampling, very important.
729-
sample_descend!(view(res,1:nleft), t.left, d)
730-
sample_descend!(view(res,nleft+1:length(res)), t.right, d)
731-
elseif t.tag == PROD
732-
sample_descend!(res, t.right, d)
733-
sample_descend!(res, t.left, d)
734-
elseif t.tag == ZERO
735-
error("Meet zero when descending.")
736-
else
737-
# pass for 1
738769
end
739770
return res
740771
end

src/fileio.jl

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
save_configs(filename, data::ConfigEnumerator; format=:binary)
3+
4+
Save configurations `data` to file `filename`. The format is `:binary` or `:text`.
5+
"""
6+
function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C}
7+
if format == :binary
8+
write(filename, raw_matrix(data))
9+
elseif format == :text
10+
writedlm(filename, plain_matrix(data))
11+
else
12+
error("format must be `:binary` or `:text`, got `:$format`")
13+
end
14+
end
15+
16+
"""
17+
load_configs(filename; format=:binary, bitlength=nothing, nflavors=2)
18+
19+
Load configurations from file `filename`. The format is `:binary` or `:text`.
20+
If the format is `:binary`, the bitstring length `bitlength` must be specified.
21+
`nflavors` specifies the number of flavors (degrees of freedom) of each element.
22+
"""
23+
function load_configs(filename; bitlength=nothing, format::Symbol=:binary, nflavors=2)
24+
if format == :binary
25+
bitlength === nothing && error("you need to specify `bitlength` for reading configurations from binary files.")
26+
S = ceil(Int, log2(nflavors))
27+
C = _nints(bitlength, S)
28+
return _from_raw_matrix(StaticElementVector{bitlength,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:))
29+
elseif format == :text
30+
return from_plain_matrix(readdlm(filename); nflavors=nflavors)
31+
else
32+
error("format must be `:binary` or `:text`, got `:$format`")
33+
end
34+
end
35+
36+
function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
37+
m = zeros(UInt64, C, length(x))
38+
@inbounds for i=1:length(x), j=1:C
39+
m[j,i] = x.data[i].data[j]
40+
end
41+
return m
42+
end
43+
function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
44+
m = zeros(UInt8, N, length(x))
45+
@inbounds for i=1:length(x), j=1:N
46+
m[j,i] = x.data[i][j]
47+
end
48+
return m
49+
end
50+
51+
function from_raw_matrix(m; bitlength, nflavors=2)
52+
S = ceil(Int,log2(nflavors))
53+
C = size(m, 1)
54+
T = StaticElementVector{bitlength,S,C}
55+
@assert bitlength*S <= C*64
56+
_from_raw_matrix(T, m)
57+
end
58+
function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
59+
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
60+
@inbounds for i=1:size(m, 2)
61+
data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i)))
62+
end
63+
return ConfigEnumerator(data)
64+
end
65+
function from_plain_matrix(m::Matrix; nflavors=2)
66+
S = ceil(Int,log2(nflavors))
67+
N = size(m, 1)
68+
C = _nints(N, S)
69+
T = StaticElementVector{N,S,C}
70+
_from_plain_matrix(T, m)
71+
end
72+
function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
73+
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
74+
@inbounds for i=1:size(m, 2)
75+
data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i))
76+
end
77+
return ConfigEnumerator(data)
78+
end
79+
80+
# convert to Matrix
81+
Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce)
82+
Base.Vector(ce::StaticElementVector) = collect(ce)
83+
84+
########## saving tree ####################
85+
"""
86+
save_sumproduct(filename, t::SumProductTree)
87+
88+
Serialize a sum-product tree into a file.
89+
"""
90+
save_sumproduct(filename::String, t::SumProductTree) = serialize(filename, dict_serialize_tree!(t, Dict{UInt,Any}()))
91+
92+
"""
93+
load_sumproduct(filename)
94+
95+
Deserialize a sum-product tree from a file.
96+
"""
97+
load_sumproduct(filename::String) = dict_deserialize_tree(deserialize(filename)...)
98+
99+
function dict_serialize_tree!(t::SumProductTree, d::Dict)
100+
id = objectid(t)
101+
if !haskey(d, id)
102+
if t.tag === GraphTensorNetworks.LEAF || t.tag === GraphTensorNetworks.ZERO || t.tag == GraphTensorNetworks.ONE
103+
d[id] = t
104+
else
105+
d[id] = (t.tag, objectid(t.left), objectid(t.right))
106+
dict_serialize_tree!(t.left, d)
107+
dict_serialize_tree!(t.right, d)
108+
end
109+
end
110+
return id, d
111+
end
112+
113+
function dict_deserialize_tree(id::UInt, d::Dict)
114+
@assert haskey(d, id)
115+
content = d[id]
116+
if content isa SumProductTree
117+
return content
118+
else
119+
(tag, left, right) = content
120+
t = SumProductTree(tag, dict_deserialize_tree(left, d), dict_deserialize_tree(right, d))
121+
d[id] = t
122+
return t
123+
end
124+
end
125+

0 commit comments

Comments
 (0)