@@ -2,6 +2,9 @@ using Polynomials: Polynomial
2
2
using TropicalNumbers: Tropical, CountingTropical
3
3
using Mods, Primes
4
4
using Base. Cartesian
5
+ import AbstractTrees: children, printnode, print_tree
6
+
7
+ @enum TreeTag LEAF SUM PROD ZERO
5
8
6
9
# pirate
7
10
Base. abs (x:: Mod ) = x
@@ -275,20 +278,179 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
275
278
Base. zero (:: ConfigSampler{N,S,C} ) where {N,S,C} = zero (ConfigSampler{N,S,C})
276
279
Base. one (:: ConfigSampler{N,S,C} ) where {N,S,C} = one (ConfigSampler{N,S,C})
277
280
278
- # A patch to make `Polynomial{ConfigEnumerator}` work
279
- function Base.:* (a:: Int , y:: ConfigEnumerator )
280
- a == 0 && return zero (y)
281
- a == 1 && return y
282
- error (" multiplication between int and config enumerator is not defined." )
281
+ # tree config enumerator
282
+ """
283
+ TreeConfigEnumerator{N,S,C}
284
+
285
+ Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
286
+ and is often more memory efficient than putting the configurations in a vector.
287
+ `N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.
288
+
289
+ Fields
290
+ -----------------------
291
+ * `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
292
+ * `data` is the element stored in a `LEAF` node.
293
+ * `left` and `right` are two operands of a `SUM` or `PROD` node.
294
+
295
+ Example
296
+ ------------------------
297
+ ```jldoctest; setup=:(using GraphTensorNetworks)
298
+ julia> s = TreeConfigEnumerator(bv"00111")
299
+ 00111
300
+
301
+
302
+ julia> q = TreeConfigEnumerator(bv"10000")
303
+ 10000
304
+
305
+
306
+ julia> x = s + q
307
+ +
308
+ ├─ 00111
309
+ └─ 10000
310
+
311
+
312
+ julia> y = x * x
313
+ *
314
+ ├─ +
315
+ │ ├─ 00111
316
+ │ └─ 10000
317
+ └─ +
318
+ ├─ 00111
319
+ └─ 10000
320
+
321
+
322
+ julia> collect(y)
323
+ 4-element Vector{StaticBitVector{5, 1}}:
324
+ 00111
325
+ 10111
326
+ 10111
327
+ 10000
328
+
329
+ julia> zero(s)
330
+
331
+
332
+
333
+ julia> one(s)
334
+ 00000
335
+
336
+
337
+ ```
338
+ """
339
+ struct TreeConfigEnumerator{N,S,C}
340
+ tag:: TreeTag
341
+ data:: StaticElementVector{N,S,C}
342
+ left:: TreeConfigEnumerator{N,S,C}
343
+ right:: TreeConfigEnumerator{N,S,C}
344
+ TreeConfigEnumerator (tag:: TreeTag , left:: TreeConfigEnumerator{N,S,C} , right:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = new {N,S,C} (tag, zero (StaticElementVector{N,S,C}), left, right)
345
+ function TreeConfigEnumerator (data:: StaticElementVector{N,S,C} ) where {N,S,C}
346
+ new {N,S,C} (LEAF, data)
347
+ end
348
+ function TreeConfigEnumerator {N,S,C} (tag:: TreeTag ) where {N,S,C}
349
+ @assert tag === ZERO
350
+ return new {N,S,C} (tag)
351
+ end
283
352
end
284
- function Base.:* (a:: Int , y:: ConfigSampler )
285
- a == 0 && return zero (y)
286
- a == 1 && return y
287
- error (" multiplication between int and config sampler is not defined." )
353
+
354
+ # AbstractTree APIs
355
+ function children (t:: TreeConfigEnumerator )
356
+ if isdefined (t, :left )
357
+ if isdefined (t, :right )
358
+ return [t. left, t. right]
359
+ else
360
+ return [t. left]
361
+ end
362
+ else
363
+ if isdefined (t, :right )
364
+ return [t. right]
365
+ else
366
+ return typeof (t)[]
367
+ end
368
+ end
369
+ end
370
+ function printnode (io:: IO , t:: TreeConfigEnumerator )
371
+ if t. tag === LEAF
372
+ print (io, t. data)
373
+ elseif t. tag === ZERO
374
+ print (io, " " )
375
+ elseif t. tag === SUM
376
+ print (io, " +" )
377
+ else # PROD
378
+ print (io, " *" )
379
+ end
380
+ end
381
+
382
+ function Base. length (x:: TreeConfigEnumerator )
383
+ if x. tag === SUM
384
+ return length (x. left) + length (x. right)
385
+ elseif x. tag === PROD
386
+ return length (x. left) * length (x. right)
387
+ elseif x. tag === ZERO
388
+ return 0
389
+ else
390
+ return 1
391
+ end
392
+ end
393
+
394
+ function num_nodes (x:: TreeConfigEnumerator )
395
+ x. tag == ZERO && return 1
396
+ x. tag == LEAF && return 1
397
+ return num_nodes (x. left) + num_nodes (x. right) + 1
398
+ end
399
+
400
+ function Base.:(== )(x:: TreeConfigEnumerator{N,S,C} , y:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
401
+ return Set (collect (x)) == Set (collect (y))
402
+ end
403
+
404
+ Base. show (io:: IO , t:: TreeConfigEnumerator ) = print_tree (io, t)
405
+
406
+ function Base. collect (x:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
407
+ if x. tag == ZERO
408
+ return StaticElementVector{N,S,C}[]
409
+ elseif x. tag == LEAF
410
+ return StaticElementVector{N,S,C}[x. data]
411
+ elseif x. tag == SUM
412
+ return vcat (collect (x. left), collect (x. right))
413
+ else # PROD
414
+ return vec ([reduce ((x,y)-> x| y, si) for si in Iterators. product (collect (x. left), collect (x. right))])
415
+ end
416
+ end
417
+
418
+ function Base.:+ (x:: TreeConfigEnumerator{N,S,C} , y:: TreeConfigEnumerator{N,S,C} ) where {N,S,C}
419
+ TreeConfigEnumerator (SUM, x, y)
420
+ end
421
+
422
+ function Base.:* (x:: TreeConfigEnumerator{L,S,C} , y:: TreeConfigEnumerator{L,S,C} ) where {L,S,C}
423
+ TreeConfigEnumerator (PROD, x, y)
424
+ end
425
+
426
+ Base. zero (:: Type{TreeConfigEnumerator{N,S,C}} ) where {N,S,C} = TreeConfigEnumerator {N,S,C} (ZERO)
427
+ Base. one (:: Type{TreeConfigEnumerator{N,S,C}} ) where {N,S,C} = TreeConfigEnumerator (zero (StaticElementVector{N,S,C}))
428
+ Base. zero (:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = zero (TreeConfigEnumerator{N,S,C})
429
+ Base. one (:: TreeConfigEnumerator{N,S,C} ) where {N,S,C} = one (TreeConfigEnumerator{N,S,C})
430
+ # todo, check siblings too?
431
+ function Base. iszero (t:: TreeConfigEnumerator )
432
+ if t. TAG == SUM
433
+ iszero (t. left) && iszero (t. right)
434
+ elseif t. TAG == ZERO
435
+ true
436
+ elseif t. TAG == LEAF
437
+ false
438
+ else
439
+ iszero (t. left) || iszero (t. right)
440
+ end
441
+ end
442
+
443
+ # A patch to make `Polynomial{ConfigEnumerator}` work
444
+ for T in [:ConfigEnumerator , :ConfigSampler , :TreeConfigEnumerator ]
445
+ @eval function Base.:* (a:: Int , y:: $T )
446
+ a == 0 && return zero (y)
447
+ a == 1 && return y
448
+ error (" multiplication between int and `$(typeof (y)) ` is not defined." )
449
+ end
288
450
end
289
451
290
452
# convert from counting type to bitstring type
291
- for (F,TP) in [(:set_type , :ConfigEnumerator ), (:sampler_type , :ConfigSampler )]
453
+ for (F,TP) in [(:set_type , :ConfigEnumerator ), (:sampler_type , :ConfigSampler ), ( :treeset_type , :TreeConfigEnumerator ) ]
292
454
@eval begin
293
455
function $F (:: Type{T} , n:: Int , nflavor:: Int ) where {OT, K, T<: TruncatedPoly{K,C,OT} where C}
294
456
TruncatedPoly{K, $ F (n,nflavor),OT}
@@ -312,12 +474,24 @@ end
312
474
313
475
# utilities for creating onehot vectors
314
476
onehotv (:: Type{ConfigEnumerator{N,S,C}} , i:: Integer , v) where {N,S,C} = ConfigEnumerator ([onehotv (StaticElementVector{N,S,C}, i, v)])
477
+ onehotv (:: Type{TreeConfigEnumerator{N,S,C}} , i:: Integer , v) where {N,S,C} = TreeConfigEnumerator (onehotv (StaticElementVector{N,S,C}, i, v))
315
478
onehotv (:: Type{ConfigSampler{N,S,C}} , i:: Integer , v) where {N,S,C} = ConfigSampler (onehotv (StaticElementVector{N,S,C}, i, v))
479
+ # just to make matrix transpose work
316
480
Base. transpose (c:: ConfigEnumerator ) = c
317
481
Base. copy (c:: ConfigEnumerator ) = ConfigEnumerator (copy (c. data))
482
+ Base. transpose (c:: TreeConfigEnumerator ) = c
483
+ function Base. copy (c:: TreeConfigEnumerator )
484
+ if c. tag == LEAF
485
+ TreeConfigEnumerator (c. data)
486
+ elseif c. tag == ZERO
487
+ TreeConfigEnumerator (c. tag)
488
+ else
489
+ TreeConfigEnumerator (c. tag, c. left, c. right)
490
+ end
491
+ end
318
492
319
493
# Handle boolean, this is a patch for CUDA matmul
320
- for TYPE in [:ConfigEnumerator , :ConfigSampler , :TruncatedPoly ]
494
+ for TYPE in [:ConfigEnumerator , :ConfigSampler , :TruncatedPoly , :TreeConfigEnumerator ]
321
495
@eval Base.:* (a:: Bool , y:: $TYPE ) = a ? y : zero (y)
322
496
@eval Base.:* (y:: $TYPE , a:: Bool ) = a ? y : zero (y)
323
497
end
0 commit comments