Efficient, easy, and formally verified segment trees

I’ve been spending the past month or so formalizing a folklore optimized implementation of segment trees, a workhorse data structure in competitive programming. Compared to the implementation I learned when I was starting out (similar to the one in ‘Implementation’), this one only needs to store 2 * n elements instead of 4 * n, and it is conceptually much simpler than the ‘Memory efficient implementation’ linked above.

Implementation

The common implementation that explicitly allocates a node (whether on the heap, or within a struct-of-arrays data structure) can be modeled as follows:

inductive SegmentTree (α : Type) where
  | Leaf (element : α) : SegmentTree α
  | Internal (length : Nat) (total : α) (left : SegmentTree α) (right : SegmentTree α) : SegmentTree α

Reasoning with this could be easier, since we gain the ability to do proper structural induction. However, the array indexing logic is a good part of what made the implementation so inscrutable to me¹, so I decided to instead model a segment tree as a flat array of values:

structure SegmentTree (α : Type) : Type where
  origLen : Nat
  elements : Array α
  helementsLength : elements.size = 2 * origLen

helementsLength is required to discharge some of the bounds checks in our operations later, so we include it in the definition of the data structure.
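To make the flat layout concrete, here is a small executable sketch. It assumes the layout implied by query and by the TreeCorrect invariant further down: position i of the original array lives at index i + origLen, node j combines nodes 2 * j and 2 * j + 1, and index 0 is never read by queries. The name buildElements is hypothetical (the post itself starts from empty_tree), and the sketch uses Nat with + and 0 in place of a general Monoid so it runs without Mathlib.

```lean
/-- Hypothetical builder, not part of the formalization: lays out `xs` in a
    `2 * n` array, leaves at indices `n .. 2n-1`, node `v` = node `2v` + node `2v+1`. -/
def buildElements (xs : Array Nat) : Array Nat := Id.run do
  let n := xs.size
  let mut a : Array Nat := (List.replicate (2 * n) 0).toArray
  -- copy the original array into the leaves, indices n .. 2n-1
  for i in List.range n do
    a := a.set! (i + n) (xs[i]!)
  -- fill internal nodes bottom-up: v = n-1, n-2, ..., 1 (slot 0 stays unused)
  for v in (List.range n).reverse do
    if v > 0 then
      a := a.set! v (a[2 * v]! + a[2 * v + 1]!)
  return a

#eval buildElements #[3, 1, 4, 1, 5]
-- #[0, 14, 9, 5, 6, 3, 1, 4, 1, 5]
```

For n = 5, node 2 holds 9 = (1 + 5) + 3, combining original positions 3, 4 and 0. When n is not a power of two, a node does not always cover a contiguous stretch of the original array, which is part of why this layout is harder to reason about than it looks.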

The query function is a simple translation of the for loop in the C++ code to a recursive function. The only new additions are:

  • arr_index_to_segtree_index, a helper to translate bounds in terms of the original array’s length to bounds in terms of the elements array’s length, which includes internal nodes.
  • a requirement that the underlying type α has a Monoid instance, i.e. that it supports an associative operation * with an identity 1, which generalizes over almost² all operations that can be aggregated using a segment tree.

def query' [Monoid α] (tree : SegmentTree α) (i j : Nat) (accl accr : α) (hi : i < tree.elements.size := by get_elem_tactic) (hj : j <= tree.elements.size := by get_elem_tactic)  : α :=
  if j ≤ i
  then accl * accr
  else
    let new_accl :=
      if i % 2 = 1
      then accl * tree.elements[i]
      else accl
    let new_accr :=
      if j % 2 = 1
      then tree.elements[j - 1] * accr
      else accr
    query' tree ((i + 1) / 2) (j / 2) new_accl new_accr

def query [Monoid α] (tree : SegmentTree α) (i j : Nat) (hi : i < tree.origLen := by get_elem_tactic) (hj : j <= tree.origLen := by get_elem_tactic) : α :=
  let hi := arr_index_to_segtree_index tree i hi
  let hj := arr_index_to_segtree_index' tree j hj
  query' tree (i + tree.origLen) (j + tree.origLen) 1 1
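For experimenting outside the proofs, the recursion in query' can be turned back into the imperative loop it came from. This is a sketch, not the formalized code: rangeQuery is a hypothetical name, op and e stand in for the Monoid's * and 1, and reads use a[i]! (which panics on an out-of-range index) instead of carrying bounds proofs. Because accl only ever grows on the right and accr only on the left, the result preserves element order even for a non-commutative operation.

```lean
/-- Hypothetical imperative mirror of `query'` (not the formalized code).
    `op`/`e` stand in for the monoid's `*` and `1`; `a` is a `2 * n` elements
    array with position `p` of the original array stored at index `p + n`. -/
def rangeQuery {α : Type} [Inhabited α] (op : α → α → α) (e : α)
    (a : Array α) (n l r : Nat) : α := Id.run do
  let mut i := l + n
  let mut j := r + n
  let mut accl := e   -- product of the blocks taken from the left end
  let mut accr := e   -- product of the blocks taken from the right end
  while i < j do
    -- an odd `i` is a right child, so its parent reaches left of the range
    if i % 2 = 1 then
      accl := op accl (a[i]!)
    -- an odd `j` means `j - 1` is a left child whose parent reaches past the right end
    if j % 2 = 1 then
      accr := op (a[j - 1]!) accr
    i := (i + 1) / 2
    j := j / 2
  return op accl accr

-- Sums over the layout of [3, 1, 4, 1, 5] (n = 5): positions [1, 4) give 1 + 4 + 1
#eval rangeQuery (α := Nat) (· + ·) 0 #[0, 14, 9, 5, 6, 3, 1, 4, 1, 5] 5 1 4  -- 6

-- A non-commutative monoid: lists under `++`, same layout for [10, 11, 12, 13, 14]
#eval rangeQuery (α := List Nat) (· ++ ·) []
  #[[], [13, 14, 10, 11, 12], [13, 14, 10], [11, 12], [13, 14],
    [10], [11], [12], [13], [14]] 5 0 5  -- [10, 11, 12, 13, 14]
```

In the list example the root slot a[1] holds [13, 14, 10, 11, 12], which is not the array in order. The query for [0, 5) never reads it and instead combines nodes 5, 3 and 4, in that order.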

The update function is a slightly less straightforward, but still simple, translation of the for loop:

def rebuild [Monoid α] (segtree : SegmentTree α) (i : Nat) (hi : i < segtree.elements.size) :=
  if i = 0
  then segtree
  else
    let j := i / 2
    let j_c2 := 2 * j + 1
    let hj_c2 : j_c2 < segtree.elements.size := by
      unfold j_c2 j
      have hsize : segtree.elements.size % 2 = 0 := by
        apply segtree_elements_even_size
      omega
    let value_to_set := segtree.elements[2 * j] * segtree.elements[j_c2]
    rebuild (set_index j value_to_set segtree) j (by unfold set_index; simp; omega)

def update [Monoid α] (tree : SegmentTree α) (i : Nat) (hi : i < tree.origLen) (f : α -> α) :=
  let hin : i + tree.origLen < tree.elements.size := arr_index_to_segtree_index tree i hi
  let current_value := tree.elements[i + tree.origLen]
  let tree := set_index (i + tree.origLen) (f current_value) tree hin
  let hin : i + tree.origLen < tree.elements.size := arr_index_to_segtree_index tree i hi
  rebuild tree (i + tree.origLen) hin

Here we have to do a bit more legwork to get all the bounds checks to pass. We use the fact that the elements array has even length to prove that 2 * j + 1 is in bounds, and we use a helper, set_index, to set an element in the underlying array while preserving the length constraint.
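The same update can be written as a hypothetical imperative sketch over Nat sums (pointUpdate is not from the post): write the new leaf, then recompute each ancestor from its two children until reaching the root at index 1. There is one small difference, as far as I can tell from the code above. On its last step (from i = 1), rebuild also sets index 0 to elements[0] * elements[1]. This is harmless because TreeCorrect only constrains j > 0. The loop below simply stops at the root.

```lean
/-- Hypothetical imperative mirror of `update` + `rebuild` over `Nat` sums.
    Position `i` of the original array lives at index `i + n`. -/
def pointUpdate (a : Array Nat) (n i : Nat) (f : Nat → Nat) : Array Nat := Id.run do
  let mut a := a
  let mut k := i + n
  a := a.set! k (f (a[k]!))           -- update the leaf itself
  while k > 1 do                      -- walk up until the root, index 1
    k := k / 2
    a := a.set! k (a[2 * k]! + a[2 * k + 1]!)  -- recompute this ancestor
  return a

-- Layout of [3, 1, 4, 1, 5] (n = 5); set position 3 to 10
#eval pointUpdate #[0, 14, 9, 5, 6, 3, 1, 4, 1, 5] 5 3 (fun _ => 10)
-- #[0, 23, 18, 5, 15, 3, 1, 4, 10, 5]
```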

Basic Concepts

This is the main structural invariant of this segment tree. It states that for every internal node j (that is, every j > 0 whose children 2 * j and 2 * j + 1 are in bounds), the value we store is the product of the values of its children. All other correctness proofs will assume this:

def TreeCorrect [Monoid α] (segtree : SegmentTree α) : Prop :=
  ∀ j : Nat,
  ∀ (hj : 2 * j + 1 < segtree.elements.size),
  j > 0 -> segtree.elements[j] = segtree.elements[2 * j] * segtree.elements[2 * j + 1]

We will also need a Prop to represent the segment tree “storing” an array, which corresponds to the elements of the array being stored in the leaves of the tree:

def Representative [Monoid α] (segtree : SegmentTree α) (xs : List α) : Prop :=
  TreeCorrect segtree ∧
  xs.length = segtree.origLen ∧
  ∀ (i : Fin xs.length) (heq : i + segtree.origLen < segtree.elements.size),
    segtree.elements[i + segtree.origLen] = xs[i]

Correctness of updates

The main correctness theorem for update should state that if a segment tree stores a particular array xs, the array stored after the update should be xs, with the relevant element modified. In Lean, this looks like this:

theorem update_representative [Monoid α] : 
  ∀ (tree : SegmentTree α) (i : Nat) (hi : i < tree.origLen) (f : α -> α) (xs : List α),
    Representative tree xs -> Representative (update tree i hi f) (List.modify xs i f)

update modifies one element of the tree, then walks up the tree via rebuild to restore the TreeCorrect invariant. Representative is a conjunction of three Props:

  • The second is that the list’s length matches the tree’s origLen. List.modify preserves the list’s length, and update preserves origLen, which follows from simple induction.
  • The third is that the tree’s leaf elements correspond to the array. This follows from the fact that rebuild does not modify these leaf elements, which also follows from simple induction.
  • The remaining one is that the tree is TreeCorrect. To prove this, we prove that rebuild maintains a temporary invariant: that the tree is TreeCorrect aside from one element, a ‘hole’:
def TreeCorrectHole [Monoid α] (segtree : SegmentTree α) (i : Nat) : Prop :=
  ∀ j : Nat,
  ∀ (hj : 2 * j + 1 < segtree.elements.size),
  j > 0 -> j ≠ i -> segtree.elements[j] = segtree.elements[2 * j] * segtree.elements[2 * j + 1]
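A quick way to see TreeCorrect and TreeCorrectHole in action is a hypothetical checker that lists every index where the invariant fails, again using Nat sums instead of a general Monoid. An empty result corresponds to TreeCorrect, and a result contained in [h] corresponds to TreeCorrectHole segtree h. Writing a leaf without rebuilding leaves exactly one hole at its parent, and fixing that parent moves the hole one level up:

```lean
/-- Hypothetical checker, not part of the formalization: every `j > 0` with both
    children in bounds whose stored value is not the sum of its children. -/
def violations (a : Array Nat) : List Nat :=
  (List.range a.size).filter fun j =>
    decide (0 < j) && decide (2 * j + 1 < a.size) && a[j]! != a[2 * j]! + a[2 * j + 1]!

-- Layout of [3, 1, 4, 1, 5] (n = 5)
#eval violations #[0, 14, 9, 5, 6, 3, 1, 4, 1, 5]    -- []   (correct tree)
-- Leaf at index 8 overwritten with 10, nothing else touched
#eval violations #[0, 14, 9, 5, 6, 3, 1, 4, 10, 5]   -- [4]  (hole at 8 / 2)
-- Node 4 recomputed as 10 + 5
#eval violations #[0, 14, 9, 5, 15, 3, 1, 4, 10, 5]  -- [2]  (hole at 4 / 2)
```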

Then, rebuild’s correctness can be stated as follows: given a tree that is correct except for a ‘hole’ at the parent of index i, rebuild tree i returns a fully correct tree:

lemma rebuild_correctness [Monoid α] :
  ∀ (tree : SegmentTree α) (i : Nat) (hi : i < tree.elements.size),
      TreeCorrectHole tree (i / 2) -> TreeCorrect (rebuild tree i hi)

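This is enough to prove that the resulting tree from update is TreeCorrect. Writing the leaf at index i + origLen can only break the invariant at its parent, (i + origLen) / 2, since a leaf has no children in bounds and appears only in its parent’s equation. So the tree passed to rebuild satisfies TreeCorrectHole at exactly that index, and rebuild_correctness applies.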

Correctness of queries

The main correctness theorem for query should state that the result of query equals the result of taking the product over the relevant subarray of the array it represents:

theorem query_correctness [Monoid α] : ∀ (tree : SegmentTree α) (xs : List α) (i j : Nat) hi hj,
  Representative tree xs -> query tree i j hi hj = List.prod (List.take (j - i) (xs.drop i))

Note that we use List.prod rather than Mathlib’s Finset.prod, since the latter requires α to be a commutative monoid, which does not necessarily hold in this setting.

To prove this, we first generalize from the stored element array to arbitrary functions Nat -> α with a TreeCorrect-like structure. We then define both the query loop and the naive left-to-right product on such functions. (I’m sure the following can also be done directly on the element array, but dealing with all the bounds obligations really pissed me off.)

def isTreeMapping [Monoid α] (f : Nat -> α) : Prop :=
  ∀ i, 0 < i -> f i = f (2 * i) * f (2 * i + 1)

def takeRangeSum' [Monoid α] (f : Nat -> α) (i j : Nat) (accl accr : α) : α :=
  if j ≤ i
  then accl * accr
  else
    let new_accl :=
      if i % 2 = 1
      then accl * f i
      else accl
    let new_accr :=
      if j % 2 = 1
      then f (j - 1) * accr
      else accr
    takeRangeSum' f ((i + 1) / 2) (j / 2) new_accl new_accr

def takeRangeSum [Monoid α] (f : Nat -> α) (i j : Nat) : α :=
  takeRangeSum' f i j 1 1

def takeLineSum [Monoid α] (f : Nat -> α) (i j : Nat) : α :=
  let indices := List.range' i (j - i)
  List.foldr (fun x y => x * y) 1 (List.map f indices)

theorem takeRangeSum_correctness [Monoid α] (f : Nat -> α) :
    isTreeMapping f ->
    ∀ i j, j <= 2 * i -> takeRangeSum f i j = takeLineSum f i j

The proof proceeds by induction on j, with the help of a few legwork lemmas about takeLineSum. To be completely honest, this was the most interesting part of this implementation, so having its correctness follow from simple induction felt really anticlimactic to me. This was especially true because the proof did not reflect the intuition I’ve built thinking about this data structure at all!

Anyway, with a few helper lemmas, we can reduce query_correctness to takeRangeSum_correctness:

lemma list_range_sum [Monoid α] : ∀ (tree : SegmentTree α) (xs : List α) (i j : Nat),
  i < tree.origLen -> j <= tree.origLen ->
  Representative tree xs -> List.prod (List.take (j - i) (xs.drop i)) = takeLineSum (extended_values tree.elements) (i + tree.origLen) (j + tree.origLen)

lemma query_equivalent_takeRangeSum [Monoid α] : ∀ (tree : SegmentTree α) (i j : Nat) hi hj,
  query tree i j hi hj = takeRangeSum (extended_values tree.elements) (i + tree.origLen) (j + tree.origLen)

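Together, these complete the proof of correctness for queries. The side condition j + origLen <= 2 * (i + origLen) needed by takeRangeSum_correctness holds because j <= origLen.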

What’s really happening

As I mentioned earlier, over the course of this project I developed a different intuition for how the data structure works than the one I ended up formalizing.

I’ve known for a while that it’s possible to split an interval [l, r) of integers into O(log n) ranges. Each of these ranges can be described as the set of numbers x such that x >> k == p for some p, k. For instance, the range [77, 102] (that is, [77, 103) in half-open notation) is partitioned as follows:

range start   range end   range mask (binary)
77            77          1001101
78            79          100111?
80            95          101????
96            99          11000??
100           101         110010?
102           102         1100110

If you look at the structure of the internal nodes in the segment tree, you will notice that the range of leaves covered by each internal node can be written in this form: node v covers exactly the leaf indices x with x >> k == v for some k. In fact, all query does is partition the input range [l + n, r + n) into subranges of this form, then use the precomputed sums of each subrange to get the sum of the whole range! This interested me because the way I usually derive such a partition is more complicated than query. However, in this setting we know that the interval [l + n, r + n) being partitioned satisfies r + n <= 2 * (l + n) (because r <= n). This helps simplify things: every number in the interval has a bit length of either x or x + 1, for some x, so all the leaves involved sit on at most two adjacent levels of the tree.
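The partition in the table can be reproduced by running the query loop on the raw endpoints, without the + n offset, and recording the block that each visited node stands for. Node p at level k covers the x with x >> k == p, i.e. p * 2^k through (p + 1) * 2^k - 1. The helper blocks below is hypothetical and only meant to reproduce this example (which does satisfy r <= 2 * l). It is not part of the formalization.

```lean
/-- Hypothetical helper: the query loop on raw `[l, r)`, returning the inclusive
    block `(start, end)` covered by each node it would combine, left to right. -/
def blocks (l r : Nat) : List (Nat × Nat) := Id.run do
  let mut i := l
  let mut j := r
  let mut k := 0                                 -- current level
  let mut left : List (Nat × Nat) := []          -- blocks taken from the left end
  let mut right : List (Nat × Nat) := []         -- blocks taken from the right end
  while i < j do
    if i % 2 = 1 then
      left := left ++ [(i * 2 ^ k, (i + 1) * 2 ^ k - 1)]
    if j % 2 = 1 then
      right := ((j - 1) * 2 ^ k, j * 2 ^ k - 1) :: right
    i := (i + 1) / 2
    j := j / 2
    k := k + 1
  return left ++ right

#eval blocks 77 103
-- [(77, 77), (78, 79), (80, 95), (96, 99), (100, 101), (102, 102)]
```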

This especially interested me because, in the special case where n is a power of 2 and l = 0, the partition is fully determined by the set bits of r. In this special case, our data structure becomes essentially identical to a binary-indexed tree (BIT), the other usual choice for solving this type of problem³. This also partly applies to the general case. There, the partition is determined by the set bits of l, which produce blocks from left to right, and the set bits of r, which produce blocks from right to left. This gives a view of segment trees as a collection of mini-BITs, and a BIT can be seen as compressing this collection to keep only the largest one. I should probably write more about what this means later on.
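For comparison, here is a minimal BIT prefix query, written as a sketch over Nat sums (bitPrefix and the example array are illustrative, not from the post). Each step reads one slot, then clears the lowest set bit of i with i &&& (i - 1), so the blocks it combines correspond to the set bits of r. For r = 5, it reads slots 5 and 4, covering 0-indexed positions [4, 5) and [0, 4). A segment tree with n = 8 answers [0, 5) by combining node 12 (position 4) and node 2 (positions 0 through 3), which are the same two blocks.

```lean
/-- Minimal Fenwick-tree prefix sum (sketch). With 1-indexed slots, `t[i]` holds
    the sum of the `lowbit i` original positions ending at `i`; slot 0 is unused. -/
def bitPrefix (t : Array Nat) (r : Nat) : Nat := Id.run do
  let mut i := r
  let mut s := 0
  while i > 0 do
    s := s + t[i]!
    i := i &&& (i - 1)   -- clear the lowest set bit of `i`
  return s

-- Fenwick array for [3, 1, 4, 1, 5, 9, 2, 6]
#eval bitPrefix #[0, 3, 4, 4, 9, 5, 14, 2, 31] 5   -- 14, i.e. 3 + 1 + 4 + 1 + 5
```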

Usage

Anyway, the data structure itself is usable. Internally, it uses Lean arrays, which are contiguous in memory, leading to better performance for our random access pattern compared to the typical linked lists you see in functional languages. I’m unsure if it’s guaranteed that the array is never copied assuming the caller only keeps one reference to the segment tree; this is something I’ll need to verify once I’m better aware of Lean’s runtime model. I did end up defining this helper type to make the size of the tree immediately inferrable for purposes of proving bounds obligations, as opposed to having to simplify the tree back to the original definition with empty_tree:

def BoundedSegmentTree (α : Type) (n : Nat) := { tree : ArrayTree.SegmentTree α // tree.origLen = n }

This is how it can be used in practice, along with a demonstration of how to implement custom Monoid instances to change the internal operation:

structure Elem where
  val : Int
deriving Repr

instance : Mul Elem where
  mul (p q : Elem) := ⟨p.val + q.val⟩

instance : One Elem where
  one := ⟨0⟩

instance : Monoid Elem where
  one := ⟨0⟩
  mul_assoc := by
    intro a b c
    simp [HMul.hMul, Mul.mul]
    ac_rfl
  mul_one := by
    intro a
    simp [HMul.hMul, Mul.mul, OfNat.ofNat]
    congr
    ring
  one_mul := by
    intro a
    simp [HMul.hMul, Mul.mul, OfNat.ofNat]
    congr
    ring

def main : IO Unit :=
  let tree : BoundedSegmentTree Elem _ := empty_tree 10
  let tree := update tree 3 (Function.const Elem ⟨2⟩)
  let tree := update tree 5 (Function.const Elem ⟨5⟩)
  let ans := query tree 2 6
  IO.println s!"{ans.val}"

Takeaways

I could probably have used a better object language

As with the majority of folklore algorithms in competitive programming parlance, the ‘reference’ implementation is written in C++. Although it is still possible to write code closely mirroring the imperative style of C++ in Lean, I felt like I was going against the grain of the language a lot of the time. Having an explicitly defined object language (as opposed to shallowly embedding the algorithm in Lean) would have taken a lot more work.

On a related note: the formalization does not prove that either operation is O(log n). Here, being able to wrap the computation in some sort of monad that could track operation costs would be helpful.
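As a rough illustration of that idea (not something the formalization does), the query loop can be run in Lean’s StateM Nat with a hypothetical tick action that counts iterations. This only measures the cost of a particular call. Proving the O(log n) bound would still require a separate theorem about the counter.

```lean
/-- Hypothetical cost counter: one tick per loop iteration. -/
def tick : StateM Nat Unit := modify (· + 1)

/-- The query loop over `Nat` sums, ticking once per iteration. -/
def countedQuery (a : Array Nat) (n l r : Nat) : StateM Nat Nat := do
  let mut i := l + n
  let mut j := r + n
  let mut accl := 0
  let mut accr := 0
  while i < j do
    tick
    if i % 2 = 1 then
      accl := accl + a[i]!
    if j % 2 = 1 then
      accr := a[j - 1]! + accr
    i := (i + 1) / 2
    j := j / 2
  return accl + accr

-- (result, iterations) for positions [1, 4) of [3, 1, 4, 1, 5]
#eval Id.run ((countedQuery #[0, 14, 9, 5, 6, 3, 1, 4, 1, 5] 5 1 4).run 0)  -- (6, 2)
```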

Unfortunately, we are not dependent types yet

Somehow, the ergonomics of using dependent types in my object code could still use some work. I remember having a lot of trouble with the proof objects needed to index into arrays, and how they could cause trouble when rewriting because I couldn’t guarantee that the two proof objects for the statement i < xs.length were definitionally equal, and not just equal by proof equivalence. They also caused me trouble because one statement needed i < xs.length while the other involved j < xs.length. I couldn’t rewrite with i = j using rw the way I knew how, because doing so caused all the dependently typed expressions to mismatch.

That, or I just didn’t pay attention to the tutorial book explaining all the subtleties between rw and simp and apply. Oh well…

I also tried modelling the element array as a Subtype, instead of having the length constraint be a separate field of the tree structure (i.e. in helementsLength), but having to coerce such an element array back to a normal Array got in the way of both my code and my proofs so many times that I simply ended up with the current approach.

Conclusion

This is my first time seriously trying out Lean for something that isn’t a tutorial, so I definitely ended up doing some things the unidiomatic way. Regardless, I feel like I learned quite a bit by working through this formalization, including finally learning why this algorithm works in ways I could connect to previous knowledge, and I’m definitely interested in contributing to similar projects in the future.

The complete proof script is available at this repo.


  1. Probably also why it’s a lot more cache-friendly. ↩︎

  2. Technically, the only requirement is Semigroup α, but such cases are easy to extend to the Monoid case anyway by adjoining an identity element (Mathlib’s WithOne α does exactly this). ↩︎

  3. Its implementation is a lot simpler than that of a node-based segment tree. The downside is that a typical BIT answers a range query [l, r) as prefix(r) minus prefix(l), so it needs a commutative group rather than a monoid, making it less general. ↩︎

astrallexicon

the embers of a falling star