diff options
-rw-r--r-- | lib/containers/intervaltree.go | 142 | ||||
-rw-r--r-- | lib/containers/intervaltree_test.go | 81 |
2 files changed, 223 insertions, 0 deletions
diff --git a/lib/containers/intervaltree.go b/lib/containers/intervaltree.go new file mode 100644 index 0000000..424b297 --- /dev/null +++ b/lib/containers/intervaltree.go @@ -0,0 +1,142 @@ +// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package containers + +type intervalKey[K Ordered[K]] struct { + Min, Max K +} + +func (ival intervalKey[K]) ContainsFn(fn func(K) int) bool { + return fn(ival.Min) >= 0 && fn(ival.Max) <= 0 +} + +func (a intervalKey[K]) Cmp(b intervalKey[K]) int { + if d := a.Min.Cmp(b.Min); d != 0 { + return d + } + return a.Max.Cmp(b.Max) +} + +type intervalValue[K Ordered[K], V any] struct { + Val V + SpanOfChildren intervalKey[K] +} + +type IntervalTree[K Ordered[K], V any] struct { + MinFn func(V) K + MaxFn func(V) K + inner RBTree[intervalKey[K], intervalValue[K, V]] +} + +func (t *IntervalTree[K, V]) keyFn(v intervalValue[K, V]) intervalKey[K] { + return intervalKey[K]{ + Min: t.MinFn(v.Val), + Max: t.MaxFn(v.Val), + } +} + +func (t *IntervalTree[K, V]) attrFn(node *RBNode[intervalValue[K, V]]) { + max := t.MaxFn(node.Value.Val) + if node.Left != nil && node.Left.Value.SpanOfChildren.Max.Cmp(max) > 0 { + max = node.Left.Value.SpanOfChildren.Max + } + if node.Right != nil && node.Right.Value.SpanOfChildren.Max.Cmp(max) > 0 { + max = node.Right.Value.SpanOfChildren.Max + } + node.Value.SpanOfChildren.Max = max + + min := t.MinFn(node.Value.Val) + if node.Left != nil && node.Left.Value.SpanOfChildren.Min.Cmp(min) < 0 { + min = node.Left.Value.SpanOfChildren.Min + } + if node.Right != nil && node.Right.Value.SpanOfChildren.Min.Cmp(min) < 0 { + min = node.Right.Value.SpanOfChildren.Min + } + node.Value.SpanOfChildren.Min = min +} + +func (t *IntervalTree[K, V]) init() { + if t.inner.KeyFn == nil { + t.inner.KeyFn = t.keyFn + t.inner.AttrFn = t.attrFn + } +} + +func (t *IntervalTree[K, V]) Delete(min, max K) { + t.init() + t.inner.Delete(intervalKey[K]{ + Min: min, + Max: max, + }) +} + +func (t *IntervalTree[K, V]) Equal(u *IntervalTree[K, V]) bool { + return t.inner.Equal(&u.inner) +} + +func (t *IntervalTree[K, V]) Insert(val V) { + t.init() + t.inner.Insert(intervalValue[K, V]{Val: val}) +} + +func (t *IntervalTree[K, V]) Min() (K, bool) { + if t.inner.root == nil { + var zero K + return zero, false + } + return t.inner.root.Value.SpanOfChildren.Min, true +} + +func (t *IntervalTree[K, V]) Max() (K, bool) { + if t.inner.root == nil { + var zero K + return zero, false + } + return t.inner.root.Value.SpanOfChildren.Max, true +} + +func (t *IntervalTree[K, V]) Lookup(k K) (V, bool) { + return t.Search(k.Cmp) +} + +func (t *IntervalTree[K, V]) Search(fn func(K) int) (V, bool) { + node := t.inner.root + for node != nil { + switch { + case t.keyFn(node.Value).ContainsFn(fn): + return node.Value.Val, true + case node.Left != nil && node.Left.Value.SpanOfChildren.ContainsFn(fn): + node = node.Left + case node.Right != nil && node.Right.Value.SpanOfChildren.ContainsFn(fn): + node = node.Right + default: + node = nil + } + } + var zero V + return zero, false +} + +func (t *IntervalTree[K, V]) searchAll(fn func(K) int, node *RBNode[intervalValue[K, V]], ret *[]V) { + if node == nil { + return + } + if !node.Value.SpanOfChildren.ContainsFn(fn) { + return + } + t.searchAll(fn, node.Left, ret) + if t.keyFn(node.Value).ContainsFn(fn) { + *ret = append(*ret, node.Value.Val) + } + t.searchAll(fn, node.Right, ret) +} + +func (t *IntervalTree[K, V]) SearchAll(fn func(K) int) []V { + var ret []V + t.searchAll(fn, t.inner.root, &ret) + return ret +} + +//func (t *IntervalTree[K, V]) Walk(fn func(*RBNode[V]) error) error diff --git a/lib/containers/intervaltree_test.go b/lib/containers/intervaltree_test.go new file mode 100644 index 0000000..7a4689b --- /dev/null +++ b/lib/containers/intervaltree_test.go @@ -0,0 +1,81 @@ +// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package containers + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func (t *IntervalTree[K, V]) ASCIIArt() string { + return t.inner.ASCIIArt() +} + +func (v intervalValue[K, V]) String() string { + return fmt.Sprintf("%v) ([%v,%v]", + v.Val, + v.SpanOfChildren.Min, + v.SpanOfChildren.Max) +} + +func (v NativeOrdered[T]) String() string { + return fmt.Sprintf("%v", v.Val) +} + +type SimpleInterval struct { + Min, Max int +} + +func (ival SimpleInterval) String() string { + return fmt.Sprintf("[%v,%v]", ival.Min, ival.Max) +} + +func TestIntervalTree(t *testing.T) { + tree := IntervalTree[NativeOrdered[int], SimpleInterval]{ + MinFn: func(ival SimpleInterval) NativeOrdered[int] { return NativeOrdered[int]{ival.Min} }, + MaxFn: func(ival SimpleInterval) NativeOrdered[int] { return NativeOrdered[int]{ival.Max} }, + } + + // CLRS Figure 14.4 + // level 0 + tree.Insert(SimpleInterval{16, 21}) + // level 1 + tree.Insert(SimpleInterval{8, 9}) + tree.Insert(SimpleInterval{25, 30}) + // level 2 + tree.Insert(SimpleInterval{5, 8}) + tree.Insert(SimpleInterval{15, 23}) + tree.Insert(SimpleInterval{17, 19}) + tree.Insert(SimpleInterval{26, 26}) + // level 3 + tree.Insert(SimpleInterval{0, 3}) + tree.Insert(SimpleInterval{6, 10}) + tree.Insert(SimpleInterval{19, 20}) + + t.Log(tree.ASCIIArt()) + + // find intervals that touch [9,20] + intervals := tree.SearchAll(func(k NativeOrdered[int]) int { + if k.Val < 9 { + return 1 + } + if k.Val > 20 { + return -1 + } + return 0 + }) + assert.Equal(t, + []SimpleInterval{ + {6, 10}, + {8, 9}, + {15, 23}, + {16, 21}, + {17, 19}, + {19, 20}, + }, + intervals) +} |