I’ve been spending the past month or so formalizing a folklore optimized implementation of
segment trees, a workhorse data structure in competitive programming.
Compared to the implementation I learned when I was starting out (similar to the one modeled in the Implementation section below), this one has the advantage
of only needing to store 2 * n elements instead of 4 * n, and of being conceptually much simpler than the ‘Memory efficient implementation’
linked above.
Implementation
The common implementation that explicitly allocates a node (whether on the heap, or within a struct-of-arrays data structure) can be modeled as follows:
```lean
inductive SegmentTree (α : Type) where
  | Leaf (element : α) : SegmentTree α
  | Internal (length : Nat) (total : α) (left : SegmentTree α) (right : SegmentTree α) : SegmentTree α
```
Reasoning about this could be easier, since we gain the ability to do proper structural induction, but the array indexing logic is a good part of what made the optimized implementation so inscrutable to me[1], so I decided to instead model a segment tree as a flat array of values:
```lean
structure SegmentTree (α : Type) : Type where
  origLen : Nat
  elements : Array α
  helementsLength : elements.size = 2 * origLen
```
helementsLength is required to discharge some of the bounds checks in our operations later, so we
include it in the definition of the data structure.
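The Usage section builds trees with empty_tree, whose definition isn't shown in this post. Below is a hypothetical sketch of how helementsLength can be discharged when a tree is constructed. It uses core Lean only and repeats the structure so the block compiles on its own. The name emptyTreeSketch and the explicit identity argument e are assumptions that stand in for the Monoid instance.

```lean
structure SegmentTree (α : Type) : Type where
  origLen : Nat
  elements : Array α
  helementsLength : elements.size = 2 * origLen

/-- Hypothetical stand-in for `empty_tree`: fill all `2 * n` slots with an
    identity element `e`, passed explicitly here instead of via `Monoid`. -/
def emptyTreeSketch {α : Type} (e : α) (n : Nat) : SegmentTree α where
  origLen := n
  elements := (List.replicate (2 * n) e).toArray
  -- the length obligation is closed once, at construction time
  helementsLength := by simp

#eval (emptyTreeSketch 0 4).elements.size  -- 8
```

An all-identity array already satisfies the node equation everywhere, since e * e = e for an identity.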
The query function is a simple translation of the for loop in the C++ code to a recursive function. The only new additions are:
- arr_index_to_segtree_index, a helper that translates bounds in terms of the original array’s length into bounds in terms of the elements array’s length, which includes the internal nodes.
- a requirement that the underlying type α has a Monoid instance, i.e. that it supports an associative operation * with an identity. This generalizes over almost[2] all operations that can be aggregated using a segment tree.
```lean
-- accl collects nodes taken from the left edge, accr those taken from the right edge.
def query' [Monoid α] (tree : SegmentTree α) (i j : Nat) (accl accr : α) (hi : i < tree.elements.size := by get_elem_tactic) (hj : j <= tree.elements.size := by get_elem_tactic) : α :=
  if j ≤ i
  then accl * accr
  else
    -- an odd left bound is a right child: take it and move right
    let new_accl :=
      if i % 2 = 1
      then accl * tree.elements[i]
      else accl
    -- an odd right bound means j - 1 is a left child: take it and move left
    let new_accr :=
      if j % 2 = 1
      then tree.elements[j - 1] * accr
      else accr
    query' tree ((i + 1) / 2) (j / 2) new_accl new_accr

-- Shift both bounds to the leaf level (leaves start at origLen) and
-- start with identity accumulators.
def query [Monoid α] (tree : SegmentTree α) (i j : Nat) (hi : i < tree.origLen := by get_elem_tactic) (hj : j <= tree.origLen := by get_elem_tactic) : α :=
  let hi := arr_index_to_segtree_index tree i hi
  let hj := arr_index_to_segtree_index' tree j hj
  query' tree (i + tree.origLen) (j + tree.origLen) 1 1
```
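To watch the index arithmetic actually run, here is a hypothetical executable sketch in core Lean (no Mathlib). It keeps the same flat layout and the same two-accumulator recursion as query'. It differs in two ways: it passes the monoid explicitly as op and e, and it replaces the bounds proofs with panicking ! indexing. The names build, queryGo and queryRange are illustrative, not the repo's API.

```lean
/-- Leaves live at indices n to 2n - 1; internal node k (1 ≤ k < n) stores
    `op t[2k] t[2k+1]`. Slot 0 is unused and keeps `e`. -/
def build {α : Type} [Inhabited α] (op : α → α → α) (e : α) (xs : Array α) : Array α := Id.run do
  let n := xs.size
  let mut t : Array α := (List.replicate (2 * n) e).toArray
  for i in List.range n do
    t := t.set! (i + n) xs[i]!
  -- fill parents after their children: k runs from n - 1 down to 1
  for k in (List.range n).reverse do
    if k > 0 then
      t := t.set! k (op t[2 * k]! t[2 * k + 1]!)
  return t

/-- Same recursion as `query'`: `accl` gathers nodes from the left edge and
    `accr` from the right edge, so the order of the operands is preserved. -/
partial def queryGo {α : Type} [Inhabited α] (op : α → α → α) (t : Array α)
    (i j : Nat) (accl accr : α) : α :=
  if j ≤ i then op accl accr
  else
    let accl := if i % 2 = 1 then op accl t[i]! else accl
    let accr := if j % 2 = 1 then op t[j - 1]! accr else accr
    queryGo op t ((i + 1) / 2) (j / 2) accl accr

/-- Product of `xs[l], ..., xs[r - 1]`, where `n = xs.size`. -/
def queryRange {α : Type} [Inhabited α] (op : α → α → α) (e : α) (t : Array α)
    (n l r : Nat) : α :=
  queryGo op t (l + n) (r + n) e e

#eval queryRange (fun (a b : String) => a ++ b) ""
  (build (fun (a b : String) => a ++ b) "" #["a", "b", "c", "d", "e"]) 5 1 4
-- "bcd"
```

String concatenation is used on purpose: it is associative but not commutative, which is why query' keeps accl and accr apart instead of folding everything into one accumulator. Using n = 5 also shows that the layout does not need a power-of-two length.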
The update function is a slightly less straightforward, but still simple, translation of the for loop:
```lean
-- Walk from index i up to the root, recomputing each parent from its two children.
def rebuild [Monoid α] (segtree : SegmentTree α) (i : Nat) (hi : i < segtree.elements.size) :=
  if i = 0
  then segtree
  else
    let j := i / 2
    let j_c2 := 2 * j + 1
    -- the right child 2 * j + 1 is in bounds because the array has even length
    let hj_c2 : j_c2 < segtree.elements.size := by
      unfold j_c2 j
      have hsize : segtree.elements.size % 2 = 0 := by
        apply segtree_elements_even_size
      omega
    let value_to_set := segtree.elements[2 * j] * segtree.elements[j_c2]
    rebuild (set_index j value_to_set segtree) j (by unfold set_index; simp; omega)

-- Apply f to the leaf for array index i (stored at i + origLen), then rebuild above it.
def update [Monoid α] (tree : SegmentTree α) (i : Nat) (hi : i < tree.origLen) (f : α -> α) :=
  let hin : i + tree.origLen < tree.elements.size := arr_index_to_segtree_index tree i hi
  let current_value := tree.elements[i + tree.origLen]
  let tree := set_index (i + tree.origLen) (f current_value) tree hin
  -- re-derive the bound for the updated tree (set_index preserves the length)
  let hin : i + tree.origLen < tree.elements.size := arr_index_to_segtree_index tree i hi
  rebuild tree (i + tree.origLen) hin
```
Here we have to do a bit more legwork to get all the bounds checks to pass. We use the fact
that the elements array has even length to prove that 2 * j + 1 is in bounds, and we use a helper,
set_index, to set an element in the underlying array while preserving the length constraint.
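The walk up from a modified leaf can be sketched in the same hypothetical style: core Lean, an explicit op, and ! indexing. rebuildArr and updateArr are illustrative names, not from the repo. Starting from an all-identity array and calling updateArr once per element is also a simple way to build a tree.

```lean
/-- Mirrors `rebuild`: recompute the parent `j = i / 2` from its children
    `2j` and `2j + 1`, then continue from `j` until reaching index 0. As in the
    Lean version, the last step also writes slot 0, which the invariant ignores. -/
partial def rebuildArr {α : Type} [Inhabited α] (op : α → α → α) (t : Array α) (i : Nat) : Array α :=
  if i = 0 then t
  else
    let j := i / 2
    rebuildArr op (t.set! j (op t[2 * j]! t[2 * j + 1]!)) j

/-- Mirrors `update`: apply `f` to the leaf for array index `i` (slot `i + n`),
    then rebuild the path above it. -/
def updateArr {α : Type} [Inhabited α] (op : α → α → α) (t : Array α) (n i : Nat)
    (f : α → α) : Array α :=
  let leaf := i + n
  rebuildArr op (t.set! leaf (f t[leaf]!)) leaf

-- Start from an all-identity tree with n = 4 and write the leaves one at a time.
#eval
  let cat := fun (a b : String) => a ++ b
  let t : Array String := (List.replicate 8 "").toArray
  let t := updateArr cat t 4 0 (fun _ => "a")
  let t := updateArr cat t 4 1 (fun _ => "b")
  let t := updateArr cat t 4 2 (fun _ => "c")
  let t := updateArr cat t 4 3 (fun _ => "d")
  t[1]!  -- the root: "abcd"
```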
Basic Concepts
This is the main structural invariant of this segment tree. It essentially states that for every internal node (every index j > 0 whose children 2 * j and 2 * j + 1 are in bounds), the value we store is the product of the values of its children. All other correctness proofs will assume this:
```lean
def TreeCorrect [Monoid α] (segtree : SegmentTree α) : Prop :=
  ∀ j : Nat,
  ∀ (hj : 2 * j + 1 < segtree.elements.size),
  j > 0 -> segtree.elements[j] = segtree.elements[2 * j] * segtree.elements[2 * j + 1]
```
We will also need a Prop to represent the segment tree “storing” an array, which corresponds
to the elements of the array being stored in the leaves of the tree:
```lean
def Representative [Monoid α] (segtree : SegmentTree α) (xs : List α) : Prop :=
  TreeCorrect segtree ∧
  xs.length = segtree.origLen ∧
  ∀ (i : Fin xs.length) (heq : i + segtree.origLen < segtree.elements.size),
    segtree.elements[i + segtree.origLen] = xs[i]
```
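To test an implementation before proving anything about it, both predicates have straightforward Boolean counterparts. This is a hypothetical sketch: treeCorrectB and representativeB are not from the repo, and they work over a raw Array with the monoid passed as op. Because a raw array carries no helementsLength field, representativeB checks the size explicitly.

```lean
/-- Node equation for every `0 < j < t.size / 2`; for an even-length array this
    is exactly the set of `j > 0` with `2 * j + 1 < t.size`. -/
def treeCorrectB {α : Type} [BEq α] [Inhabited α] (op : α → α → α) (t : Array α) : Bool :=
  (List.range (t.size / 2)).all fun j =>
    j == 0 || t[j]! == op t[2 * j]! t[2 * j + 1]!

/-- Tree is correct, lengths agree, and leaf `i + n` holds `xs[i]`. -/
def representativeB {α : Type} [BEq α] [Inhabited α] (op : α → α → α)
    (t : Array α) (n : Nat) (xs : List α) : Bool :=
  treeCorrectB op t && xs.length == n && t.size == 2 * n &&
    (List.range n).all fun i => t[i + n]! == xs[i]!

#eval representativeB (fun (a b : String) => a ++ b)
  #["", "abcd", "ab", "cd", "a", "b", "c", "d"] 4 ["a", "b", "c", "d"]  -- true
```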
Correctness of updates
The main correctness theorem for update should state that if a segment tree stores a
particular array xs, then after the update it stores xs with the
relevant element modified. In Lean:
```lean
theorem update_representative [Monoid α] :
    ∀ (tree : SegmentTree α) (i : Nat) (hi : i < tree.origLen) (f : α -> α) (xs : List α),
    Representative tree xs -> Representative (update tree i hi f) (List.modify xs i f)
```
update modifies one element of the tree, then walks up the tree via rebuild to restore
the TreeCorrect invariant. Representative is a conjunction of three Props:
- The second is that the lengths still agree (xs.length = segtree.origLen). This follows from simple induction.
- The third is that the tree’s leaf elements correspond to the array. This follows from the fact that rebuild does not modify these leaf elements, which also follows from simple induction.
- The remaining one is that the tree is TreeCorrect. To prove this, we prove that rebuild maintains a temporary invariant: that the tree is TreeCorrect aside from one element, a ‘hole’:
```lean
def TreeCorrectHole [Monoid α] (segtree : SegmentTree α) (i : Nat) : Prop :=
  ∀ j : Nat,
  ∀ (hj : 2 * j + 1 < segtree.elements.size),
  j > 0 -> j ≠ i -> segtree.elements[j] = segtree.elements[2 * j] * segtree.elements[2 * j + 1]
```
Then rebuild’s correctness can be stated as follows: given an almost-correct tree whose ‘hole’
is at the parent of index i, rebuild tree i hi returns a correct tree:
```lean
lemma rebuild_correctness [Monoid α] :
    ∀ (tree : SegmentTree α) (i : Nat) (hi : i < tree.elements.size),
    TreeCorrectHole tree (i / 2) -> TreeCorrect (rebuild tree i hi)
```
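The hole invariant can be illustrated the same way. In this hypothetical Boolean version of TreeCorrectHole (not from the repo), overwriting one leaf of a correct tree leaves it correct everywhere except at that leaf's parent. That is exactly the shape rebuild_correctness expects.

```lean
/-- Node equation everywhere except at index 0 and at the hole `h`. -/
def treeCorrectHoleB {α : Type} [BEq α] [Inhabited α] (op : α → α → α)
    (t : Array α) (h : Nat) : Bool :=
  (List.range (t.size / 2)).all fun j =>
    j == 0 || j == h || t[j]! == op t[2 * j]! t[2 * j + 1]!

-- Overwrite leaf slot 5 (array index 1 when n = 4) of a correct tree. Only its
-- parent, 5 / 2 = 2, now violates the equation, matching the
-- `TreeCorrectHole tree (i / 2)` precondition of `rebuild_correctness` with i = 5.
-- (A hole at 0 exempts nothing extra, so the first check is plain TreeCorrect.)
#eval
  let cat := fun (a b : String) => a ++ b
  let t := #["", "abcd", "ab", "cd", "a", "X", "c", "d"]
  (treeCorrectHoleB cat t 0, treeCorrectHoleB cat t 2)  -- (false, true)
```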
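Setting the leaf at index i + origLen can only break the node equation at its parent, (i + origLen) / 2, so the intermediate tree is TreeCorrectHole at exactly the spot rebuild_correctness expects. This is enough to prove that the resulting tree from update is TreeCorrect.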
Correctness of queries
The main correctness theorem for query should state that the result of query equals
the result of taking the product over the relevant subarray of the array it represents:
```lean
theorem query_correctness [Monoid α] : ∀ (tree : SegmentTree α) (xs : List α) (i j : Nat) hi hj,
    Representative tree xs -> query tree i j hi hj = List.prod (List.take (j - i) (xs.drop i))
```
Note that we use List.prod rather than Mathlib’s Finset.prod (or finprod), since those require α to be
a commutative monoid, which does not necessarily hold in this setting.
To prove this, we first generalize from the stored element array to arbitrary functions Nat -> α
that have a TreeCorrect-like structure. (I’m sure the following can also be done directly on the element array, but dealing with all the bounds
obligations really pissed me off.)
```lean
def isTreeMapping [Monoid α] (f : Nat -> α) : Prop :=
  ∀ i, 0 < i -> f i = f (2 * i) * f (2 * i + 1)

def takeRangeSum' [Monoid α] (f : Nat -> α) (i j : Nat) (accl accr : α) : α :=
  if j ≤ i
  then accl * accr
  else
    let new_accl :=
      if i % 2 = 1
      then accl * f i
      else accl
    let new_accr :=
      if j % 2 = 1
      then f (j - 1) * accr
      else accr
    takeRangeSum' f ((i + 1) / 2) (j / 2) new_accl new_accr

def takeRangeSum [Monoid α] (f : Nat -> α) (i j : Nat) : α :=
  takeRangeSum' f i j 1 1

def takeLineSum [Monoid α] (f : Nat -> α) (i j : Nat) : α :=
  let indices := List.range' i (j - i)
  List.foldr (fun x y => x * y) 1 (List.map f indices)

theorem takeRangeSum_correctness [Monoid α] (f : Nat -> α) :
    isTreeMapping f ->
    ∀ i j, j <= 2 * i -> takeRangeSum f i j = takeLineSum f i j
```
The proof proceeds by induction on j, with the help of a few legwork lemmas about takeLineSum.
To be completely honest, this was the most interesting part of this implementation, so having
its correctness follow from simple induction felt really anti-climactic to me, especially because
it did not reflect the intuition I’ve built thinking about this data structure at all!
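The proof offers little intuition, but a quick executable sanity check of the statement can still be reassuring. This hypothetical sketch (not from the repo) specializes α to String with ++ and "". It builds one concrete tree mapping, treeVal, and compares both sides for every 0 < i ≤ j ≤ min (2 * i) (2 * n). It tests single instances and is no substitute for the proof.

```lean
/-- A tree mapping on 1 to 2n - 1: leaves k ≥ n are labelled "<k>", and each
    internal node k ≥ 1 is the concatenation of its children 2k and 2k + 1. -/
partial def treeVal (n k : Nat) : String :=
  if k = 0 then ""
  else if n ≤ k then s!"<{k}>"
  else treeVal n (2 * k) ++ treeVal n (2 * k + 1)

partial def takeRangeSum' (f : Nat → String) (i j : Nat) (accl accr : String) : String :=
  if j ≤ i then accl ++ accr
  else
    let accl := if i % 2 = 1 then accl ++ f i else accl
    let accr := if j % 2 = 1 then f (j - 1) ++ accr else accr
    takeRangeSum' f ((i + 1) / 2) (j / 2) accl accr

def takeRangeSum (f : Nat → String) (i j : Nat) : String := takeRangeSum' f i j "" ""

def takeLineSum (f : Nat → String) (i j : Nat) : String :=
  List.foldr (fun x y => x ++ y) "" (List.map f (List.range' i (j - i)))

/-- Compare both sides for all 0 < i < 2n and i ≤ j ≤ min (2 * i) (2 * n). -/
def checkAll (n : Nat) : Bool :=
  (List.range (2 * n)).all fun i =>
    i == 0 || (List.range' i (min (2 * i) (2 * n) - i + 1)).all fun j =>
      takeRangeSum (treeVal n) i j == takeLineSum (treeVal n) i j

#eval (checkAll 8, checkAll 5)  -- (true, true)
```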
Anyway, with a few helper lemmas, we can reduce query_correctness to takeRangeSum_correctness:
```lean
lemma list_range_sum [Monoid α] : ∀ (tree : SegmentTree α) (xs : List α) (i j : Nat),
    i < tree.origLen -> j <= tree.origLen ->
    Representative tree xs -> List.prod (List.take (j - i) (xs.drop i)) = takeLineSum (extended_values tree.elements) (i + tree.origLen) (j + tree.origLen)

lemma query_equivalent_takeRangeSum [Monoid α] : ∀ (tree : SegmentTree α) (i j : Nat) hi hj,
    query tree i j hi hj = takeRangeSum (extended_values tree.elements) (i + tree.origLen) (j + tree.origLen)
```
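Together with takeRangeSum_correctness, these complete the proof of correctness for queries. Its side condition j ≤ 2 * i holds here because j + origLen ≤ 2 * origLen ≤ 2 * (i + origLen).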
What’s really happening
As I mentioned earlier, over the course of this project I ended up developing a different intuition for how the data structure works compared to what I ended up formalizing.
I’ve known for a while that it’s possible to split an interval [l, r) of integers into O(log n) ranges, each of which can be described as
the set of numbers x such that x >> k == m for some m, k. For instance, partitioning the range [77, 102] (i.e. [77, 103)) is done as follows, where ? marks a bit that can be either 0 or 1:
| range start | range end | range mask (binary) |
|---|---|---|
| 77 | 77 | 1001101 |
| 78 | 79 | 100111? |
| 80 | 95 | 101???? |
| 96 | 99 | 11000?? |
| 100 | 101 | 110010? |
| 102 | 102 | 1100110 |
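For comparison, the textbook way to compute such a partition is a greedy left-to-right scan that always takes the largest aligned block that still fits. This hypothetical sketch is not what query does, and blockExp and dyadicPartition are illustrative names. It reproduces the start and end columns of the table above:

```lean
/-! Greedy split of [l, r) into maximal aligned blocks [m * 2^k, (m + 1) * 2^k),
    i.e. the sets {x | x >>> k = m}. -/

/-- Largest `k` such that `2^k` divides `l` and `l + 2^k ≤ r`, assuming `l < r`.
    Both conditions only get harder as `k` grows, so stepping up from 0 is enough. -/
partial def blockExp (l r k : Nat) : Nat :=
  if l % 2 ^ (k + 1) = 0 ∧ l + 2 ^ (k + 1) ≤ r then blockExp l r (k + 1) else k

/-- Inclusive `(start, end)` pairs covering [l, r). -/
partial def dyadicPartition (l r : Nat) : List (Nat × Nat) :=
  if r ≤ l then []
  else
    let k := blockExp l r 0
    (l, l + 2 ^ k - 1) :: dyadicPartition (l + 2 ^ k) r

#eval dyadicPartition 77 103
-- [(77, 77), (78, 79), (80, 95), (96, 99), (100, 101), (102, 102)]
```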
If you look at the structure of the internal nodes in the segment tree, you will notice that the ranges of array elements covered by each internal
node can be written in this form. In fact, all query does is partition the input range [l + n, r + n) into subranges with this form, then
use the precomputed sums of each subrange to get the sum of the whole range! This interested me because the way I usually derive
such a partition is more complicated than query. However, in this setting we know that, given the input interval [l, r) to partition,
r <= 2 * l, which simplifies things (e.g. every number in the range has a bit length of either x or x + 1, for some x).
This especially interested me because, in the special case where n is a power of 2 and l = 0, the partitions are fully decided by
the set bits in r. In this special case, our data structure becomes essentially identical to a binary-indexed tree (BIT), the other usual choice
for solving this type of problem[3]. This also partly applies to the general case, where the partitions are decided by the set bits of l, which are
used to create partitions in increasing order, and the set bits of r, which are used to create partitions in decreasing order. This gives
a view of segment trees as a collection of mini-BITs, and a BIT can be seen as compressing this collection down to only the largest one. I should
probably write more about what this means later on.
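To see the BIT connection concretely, here is a hypothetical instrumented copy of the query' loop. queryNodes is an illustrative name, not from the repo. Instead of their product, it returns the indices of the nodes it combines, in left-to-right order:

```lean
/-- Same control flow as `query'`, but collects node indices instead of values. -/
partial def queryNodes (i j : Nat) (accl accr : List Nat) : List Nat :=
  if j ≤ i then accl ++ accr
  else
    let accl := if i % 2 = 1 then accl ++ [i] else accl
    let accr := if j % 2 = 1 then (j - 1) :: accr else accr
    queryNodes ((i + 1) / 2) (j / 2) accl accr

-- n = 16 (a power of 2), l = 0, r = 11 = 0b1011: the loop runs on [16, 27).
#eval queryNodes (0 + 16) (11 + 16) [] []  -- [2, 12, 26]
```

Node 2 covers array indices 0 to 7, node 12 covers 8 to 9, and node 26 is the leaf for index 10. That gives blocks of sizes 8, 2 and 1, one per set bit of 11, which matches how a BIT decomposes the prefix [0, 11).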
Usage
Anyway, the data structure itself is usable. Internally, it uses Lean arrays, which are contiguous in memory, leading to better performance for
our random access pattern compared to the typical linked lists you see in functional languages.
I’m unsure whether it’s guaranteed that the array is never copied as long as
the caller keeps only one reference to the segment tree; this is something I’ll need to verify once I understand Lean’s runtime model better.
I did end up defining this helper type to make the size of the tree immediately inferrable for purposes of proving bounds obligations,
as opposed to having to simplify the tree back to the original definition with empty_tree:
```lean
def BoundedSegmentTree (α : Type) (n : Nat) := { tree : ArrayTree.SegmentTree α // tree.origLen = n }
```
This is how it can be used in practice, along with a demonstration of how to implement custom Monoid instances
to change the internal operation:
```lean
structure Elem where
  val : Int
  deriving Repr

-- "Multiplication" on Elem is integer addition, so queries compute range sums.
instance : Mul Elem where
  mul (p q : Elem) := ⟨p.val + q.val⟩

-- The identity for addition is 0.
instance : One Elem where
  one := ⟨0⟩

instance : Monoid Elem where
  one := ⟨0⟩
  mul_assoc := by
    intro a b c
    simp [HMul.hMul, Mul.mul]
    ac_rfl
  mul_one := by
    intro a
    simp [HMul.hMul, Mul.mul, OfNat.ofNat]
    congr
    ring
  one_mul := by
    intro a
    simp [HMul.hMul, Mul.mul, OfNat.ofNat]
    congr
    ring

def main : IO Unit :=
  -- an "empty" tree over 10 elements
  let tree : BoundedSegmentTree Elem _ := empty_tree 10
  -- overwrite index 3 with 2 and index 5 with 5
  let tree := update tree 3 (Function.const Elem ⟨2⟩)
  let tree := update tree 5 (Function.const Elem ⟨5⟩)
  -- combine the elements at indices 2 through 5
  let ans := query tree 2 6
  IO.println s!"{ans.val}"
```
Takeaways
I could probably have used a better object language
As with most folklore algorithms in competitive programming, the ‘reference’ implementation is written in C++. Although it is still possible to write code closely mirroring the imperative style of C++ in Lean, I felt like I was going against the grain of the language a lot of the time. Having an explicitly defined object language (as opposed to shallowly embedding the algorithm in Lean) would have taken a lot more work.
On a related note: the formalization does not prove that either operation is O(log n). Here, being able
to wrap the computation in some sort of monad that could track operation costs would be helpful.
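Here is a hypothetical sketch of what that could look like; it is not part of the formalization. The query' loop runs in Lean's StateM Nat and increments a counter on each iteration, so a cost proof would need to bound the final counter by O(log n). queryCounted and sample are illustrative names, and the monoid is fixed to Nat addition. sample is the 2n layout for the leaves 1 through 8.

```lean
/-- `query'` over `Nat` addition, charging one unit per iteration. `fuel` only
    makes the recursion structural; `t.size` is plenty, since `j` at least halves
    on every step. -/
def queryCounted (t : Array Nat) : (fuel i j accl accr : Nat) → StateM Nat Nat
  | 0, _, _, accl, accr => pure (accl + accr)  -- out of fuel (not reached here)
  | fuel + 1, i, j, accl, accr => do
    modify (· + 1)  -- charge one unit for this iteration
    if j ≤ i then
      return accl + accr
    let accl := if i % 2 = 1 then accl + t[i]! else accl
    let accr := if j % 2 = 1 then t[j - 1]! + accr else accr
    queryCounted t fuel ((i + 1) / 2) (j / 2) accl accr

-- Leaves 1..8 (n = 8) with precomputed internal sums; slot 0 is unused.
def sample : Array Nat := #[0, 36, 10, 26, 3, 7, 11, 15, 1, 2, 3, 4, 5, 6, 7, 8]

-- Query array indices [1, 7): returns (sum, iterations).
#eval Id.run ((queryCounted sample sample.size (1 + 8) (7 + 8) 0 0).run 0)  -- (27, 3)
```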
Unfortunately, we are not dependent types yet
Somehow, the ergonomics of using dependent types in my object code could still use some work. I remember having
a lot of trouble with the proof objects needed to index into arrays, and how they could cause trouble when
rewriting because I couldn’t guarantee that the two proof objects for the statement i < xs.length were
definitionally equal, and not just equal by proof equivalence. They also caused me trouble because one statement
needed i < xs.length while the other involved j < xs.length, and I couldn’t rewrite with i = j using rw the way I knew how,
because doing so made all the dependently typed expressions mismatch.
That, or I just didn’t pay attention to the tutorial book explaining all the subtleties between rw and simp
and apply. Oh well…
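For what it's worth, Lean 4 has definitional proof irrelevance, so two proofs of i < xs.size cannot make array accesses differ. The usual obstacle is that rw must also abstract i inside the bounds proof. One workaround is subst, which replaces the variable everywhere at once. This is a sketch, not necessarily what the repo does:

```lean
-- Two bounds proofs for the same index are interchangeable, so `rfl` closes this.
example (xs : Array Nat) (i : Nat) (h₁ h₂ : i < xs.size) : xs[i]'h₁ = xs[i]'h₂ := rfl

-- With different indices, `rw [hij]` can fail ("motive is not type correct")
-- because the bounds proof mentions `i`. `subst` replaces `j` by `i` everywhere,
-- proofs included, after which both sides agree up to proof irrelevance.
example (xs : Array Nat) (i j : Nat) (hij : i = j)
    (hi : i < xs.size) (hj : j < xs.size) : xs[i] = xs[j] := by
  subst hij
  rfl
```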
I also tried modelling the element array as a Subtype, instead of having the length constraint be a separate field of
the tree structure (i.e. in helementsLength), but having to coerce such an element array back to a normal Array
got in the way of both my code and my proofs so many times that I simply ended up with the current approach.
Conclusion
This is my first time seriously trying out Lean for something that isn’t a tutorial, so I definitely ended up doing some things the unidiomatic way. Regardless, I feel like I learned quite a bit by working through this formalization, including finally learning why this algorithm works in ways I could connect to previous knowledge, and I’m definitely interested in contributing to similar projects in the future.
The complete proof script is available at this repo.
- [1] Probably also why it’s a lot more cache-friendly.
- [2] Technically, the only requirement is Semigroup α, but it is easy to extend such cases to the Monoid case anyway.
- [3] Its implementation is a lot simpler than that of a node-based segment tree. The downside is that a typical BIT is only applicable to commutative groups, making it less general.