summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/containers/intervaltree.go142
-rw-r--r--lib/containers/intervaltree_test.go81
2 files changed, 223 insertions, 0 deletions
diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go
new file mode 100644
index 0000000..424b297
--- /dev/null
+++ b/lib/containers/intervaltree.go
@@ -0,0 +1,142 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package containers
+
+type intervalKey[K Ordered[K]] struct {
+ Min, Max K
+}
+
+func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool {
+ return fn(ival.Min) >= 0 && fn(ival.Max) <= 0
+}
+
+func (a intervalKey[K]) Cmp(b intervalKey[K]) int {
+ if d := a.Min.Cmp(b.Min); d != 0 {
+ return d
+ }
+ return a.Max.Cmp(b.Max)
+}
+
+type intervalValue[K Ordered[K], V any] struct {
+ Val V
+ SpanOfChildren intervalKey[K]
+}
+
+type IntervalTree[K Ordered[K], V any] struct {
+ MinFn func(V) K
+ MaxFn func(V) K
+ inner RBTree[intervalKey[K], intervalValue[K, V]]
+}
+
+func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] {
+ return intervalKey[K]{
+ Min: t.MinFn(v.Val),
+ Max: t.MaxFn(v.Val),
+ }
+}
+
+func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) {
+ max := t.MaxFn(node.Value.Val)
+ if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Cmp(max) > 0 {
+ max = node.Left.Value.SpanOfChildren.Max
+ }
+ if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Cmp(max) > 0 {
+ max = node.Right.Value.SpanOfChildren.Max
+ }
+ node.Value.SpanOfChildren.Max = max
+
+ min := t.MinFn(node.Value.Val)
+ if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Cmp(min) < 0 {
+ min = node.Left.Value.SpanOfChildren.Min
+ }
+ if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Cmp(min) < 0 {
+ min = node.Right.Value.SpanOfChildren.Min
+ }
+ node.Value.SpanOfChildren.Min = min
+}
+
+func (t *IntervalTree[K, V]) init() {
+ if t.inner.KeyFn == nil {
+ t.inner.KeyFn = t.keyFn
+ t.inner.AttrFn = t.attrFn
+ }
+}
+
+func (t *IntervalTree[K, V]) Delete(min, max K) {
+ t.init()
+ t.inner.Delete(intervalKey[K]{
+ Min: min,
+ Max: max,
+ })
+}
+
+func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool {
+ return t.inner.Equal(&u.inner)
+}
+
+func (t *IntervalTree[K, V]) Insert(val V) {
+ t.init()
+ t.inner.Insert(intervalValue[K, V]{Val: val})
+}
+
+func (t *IntervalTree[K, V]) Min() (K, bool) {
+ if t.inner.root == nil {
+ var zero K
+ return zero, false
+ }
+ return t.inner.root.Value.SpanOfChildren.Min, true
+}
+
+func (t *IntervalTree[K, V]) Max() (K, bool) {
+ if t.inner.root == nil {
+ var zero K
+ return zero, false
+ }
+ return t.inner.root.Value.SpanOfChildren.Max, true
+}
+
+func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) {
+ return t.Search(k.Cmp)
+}
+
+func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) {
+ node := t.inner.root
+ for node != nil {
+ switch {
+ case t.keyFn(node.Value).ContainsFn(fn):
+ return node.Value.Val, true
+ case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn):
+ node = node.Left
+ case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn):
+ node = node.Right
+ default:
+ node = nil
+ }
+ }
+ var zero V
+ return zero, false
+}
+
+func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) {
+ if node == nil {
+ return
+ }
+ if !node.Value.SpanOfChildren.ContainsFn(fn) {
+ return
+ }
+ t.searchAll(fn, node.Left, ret)
+ if t.keyFn(node.Value).ContainsFn(fn) {
+ *ret = append(*ret, node.Value.Val)
+ }
+ t.searchAll(fn, node.Right, ret)
+}
+
+func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V {
+ var ret []V
+ t.searchAll(fn, t.inner.root, &ret)
+ return ret
+}
+
+//func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error
diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go
new file mode 100644
index 0000000..7a4689b
--- /dev/null
+++ b/lib/containers/intervaltree_test.go
@@ -0,0 +1,81 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package containers
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func (t *IntervalTree[K, V]) ASCIIArt() string {
+ return t.inner.ASCIIArt()
+}
+
+func (v intervalValue[K, V]) String() string {
+ return fmt.Sprintf("%v) ([%v,%v]",
+ v.Val,
+ v.SpanOfChildren.Min,
+ v.SpanOfChildren.Max)
+}
+
+func (v NativeOrdered[T]) String() string {
+ return fmt.Sprintf("%v", v.Val)
+}
+
+type SimpleInterval struct {
+ Min, Max int
+}
+
+func (ival SimpleInterval) String() string {
+ return fmt.Sprintf("[%v,%v]", ival.Min, ival.Max)
+}
+
+func TestIntervalTree(t *testing.T) {
+ tree := IntervalTree[NativeOrdered[int], SimpleInterval]{
+ MinFn: func(ival SimpleInterval) NativeOrdered[int] { return NativeOrdered[int]{ival.Min} },
+ MaxFn: func(ival SimpleInterval) NativeOrdered[int] { return NativeOrdered[int]{ival.Max} },
+ }
+
+ // CLRS Figure 14.4
+ // level 0
+ tree.Insert(SimpleInterval{16, 21})
+ // level 1
+ tree.Insert(SimpleInterval{8, 9})
+ tree.Insert(SimpleInterval{25, 30})
+ // level 2
+ tree.Insert(SimpleInterval{5, 8})
+ tree.Insert(SimpleInterval{15, 23})
+ tree.Insert(SimpleInterval{17, 19})
+ tree.Insert(SimpleInterval{26, 26})
+ // level 3
+ tree.Insert(SimpleInterval{0, 3})
+ tree.Insert(SimpleInterval{6, 10})
+ tree.Insert(SimpleInterval{19, 20})
+
+ t.Log(tree.ASCIIArt())
+
+ // find intervals that touch [9,20]
+ intervals := tree.SearchAll(func(k NativeOrdered[int]) int {
+ if k.Val < 9 {
+ return 1
+ }
+ if k.Val > 20 {
+ return -1
+ }
+ return 0
+ })
+ assert.Equal(t,
+ []SimpleInterval{
+ {6, 10},
+ {8, 9},
+ {15, 23},
+ {16, 21},
+ {17, 19},
+ {19, 20},
+ },
+ intervals)
+}