永続セグ木と区間mexクエリの話

この記事でやること

  • 全永続セグ木のPython (PyPy3) による非再帰実装
  • 区間mexクエリのオフライン処理方法の解説

全永続セグ木って?

セグ木を全永続化したものです。

1点更新をする際にバージョンを名付けておき、以降の1点更新や区間演算を指定したバージョンのセグ木上で行うことができます。計算量はいずれも元のセグ木と変わらず  O(\text{log}N) 。すごい。*1

実現方法は、こちらのブログ の解説が非常に簡潔でわかりやすいと思いました。(丸投げ)

根を選ぶだけでバージョンが指定できるため、そこからは元のセグ木と同様に二分探索もできます。

全永続セグ木の実装例

PyPy提出を想定して非再帰で書いてます。木dpを非再帰で書く方法 に近いイメージで、dfsで通った頂点を \text{tps}に積んでおき、帰りがけは \text{tps}の逆順にまとめて処理しています。

class PersistentSegTree:
    class Node:
        def __init__(self, L, R) -> None:
            self.L = L
            self.R = R
            self.x = None
            self.left = None
            self.right = None
            self._calc = None
    def __init__(self, array, op, e, original_v=0) -> None:
        self.length = len(array)
        self.size = 1<<(self.length-1).bit_length()
        self.op = op
        self.e = e
        self.original_v = original_v
        self.version = {original_v: self.Node(0, self.size)}
        stack = [self.version[self.original_v]]
        tps = []
        while stack:
            node = stack.pop()
            if node.R-node.L > 1:
                tps.append(node)
                node.left = self.Node(node.L, (node.L+node.R)>>1)
                stack.append(node.left)
                node.right = self.Node((node.L+node.R)>>1, node.R)
                stack.append(node.right)
            else:
                node.x = array[node.L] if node.L < self.length else self.e
        for node in reversed(tps):
            node.x = self.op(node.left.x, node.right.x)
    def set(self, old_v, new_v, i, x):
        assert old_v in self.version and new_v not in self.version and 0 <= i < self.length
        self.version[new_v] = self.Node(0, self.size)
        old_node, new_node = self.version[old_v], self.version[new_v]
        tps = []
        while new_node.R-new_node.L > 1:
            tps.append(new_node)
            m = (new_node.L+new_node.R)>>1
            if new_node.L <= i < m:
                new_node.left = self.Node(new_node.L, m)
                new_node.right = old_node.right
                old_node, new_node = old_node.left, new_node.left
            else:
                new_node.left = old_node.left
                new_node.right = self.Node(m, new_node.R)
                old_node, new_node = old_node.right, new_node.right
        new_node.x = x
        for node in reversed(tps):
            node.x = self.op(node.left.x, node.right.x)
    def get(self, v, i):
        assert v in self.version and 0 <= i < self.length
        node = self.version[v]
        while node.R-node.L > 1:
            if node.L <= i < (node.L+node.R)>>1:
                node = node.left
            else:
                node = node.right
        return node.x
    def all_prod(self, v):
        assert v in self.version
        return self.version[v].x
    def prod(self, v, l, r):
        assert v in self.version
        if not 0 <= l < r <= self.length: return self.e
        stack = [(self.version[v], l, r)]
        tps = []
        while stack:
            node, l, r = stack.pop()
            m = (node.L+node.R)>>1
            if node.L == l and r == node.R:
                node._calc = node.x
            elif r <= m:
                tps.append((node, 0))
                stack.append((node.left, l, r))
            elif m <= l:
                tps.append((node, 1))
                stack.append((node.right, l, r))
            else:
                tps.append((node, 2))
                stack.append((node.left, l, m))
                stack.append((node.right, m, r))
        for node, look in reversed(tps):
            if look == 0:
                node._calc = node.left._calc
            elif look == 1:
                node._calc = node.right._calc
            else:
                node._calc = self.op(node.left._calc, node.right._calc)
        return self.version[v]._calc
    def max_right(self, v, l, f):
        assert v in self.version and 0 <= l < self.length and f(self.e)
        tps = []
        node = self.version[v]
        while node.L != l:
            if l < (node.L+node.R)>>1:
                tps.append(node.right)
                node = node.left
            else:
                node = node.right
        tps.append(node)
        x = self.e
        for node in reversed(tps):
            if not f(self.op(x, node.x)):
                while node.R-node.L > 1:
                    if f(self.op(x, node.left.x)):
                        x = self.op(x, node.left.x)
                        node = node.right
                    else:
                        node = node.left
                return node.L
            x = self.op(x, node.x)
        return self.length
    def min_left(self, v, r, f):
        assert v in self.version and 0 < r <= self.length and f(self.e)
        tps = []
        node = self.version[v]
        while node.R != r:
            if (self.L+self.R)>>1 < r:
                tps.append(node.left)
                node = node.right
            else:
                node = node.left
        tps.append(node)
        x = self.e
        for node in reversed(tps):
            if not f(self.op(node.x, x)):
                while node.R-node.L > 1:
                    if f(self.op(node.right.x, x)):
                        x = self.op(node.right.x, x)
                        node = node.left
                    else:
                        node = node.right
                return node.R
            x = self.op(node.x, x)
        return 0

 

区間mexクエリの処理方法

 \text{mex}とは、非負整数の集合を引数に取る関数で、「集合に含まれない最小の非負整数」を返します。

競プロでは、特にGrundy数の定義に登場することでよく知られています。

 

以下、長さ Nの非負整数列 Aに対し、 \text{mex}(A[l, r)) を求めるクエリを考えてみます。モノイドでないので、そのままセグ木に乗ったりはしません。悲しい。

オフラインの場合

 0 \le i \lt N について、 A[0, r)における i の最大のindex(ないなら-∞)を S_r[i] と表します。このとき、

 i \in A[l, r) \Leftrightarrow S_r [i] \ge l

であることに注意すると、 \text{min}(S_r [0, i)) \lt l を満たす最小の非負整数 i \text{mex}(A[l, r)) であるとわかります。

これは、 S_rをRmQセグ木で持つことで二分探索により求められる*2ので、クエリを rについて昇順に見ていきながら適切にセグ木を更新していくことで、全体 O((N+Q)\text{log}N)で解けました。

オンラインの場合

オフラインの場合における S_rを永続セグ木のバージョン rとするだけです。計算量は変わらず O(N+Q)\text{log}N

実装例

class RangeMexQuery:
    def __init__(self, array) -> None:
        self.pst = PersistentSegTree([-inf]*len(array), min, inf, 0)
        for r, a in enumerate(array):
            if a < len(array): self.pst.set(r, r+1, a, r)
    def prod(self, l, r):
        assert 0 <= l < r <= self.pst.length
        return self.pst.max_right(r, 0, lambda x: x >= l)

 

終わりに

そのうち『永続セグ木と区間K-thクエリの話』も書きたい…。

 

参考文献

熨斗袋さんのツイート

永続データ構造 - Wikipedia

永続セグメント木・永続遅延セグメント木 – 37zigenのHP

非再帰セグ木上の任意始点にぶたん - えびちゃんの日記

非再帰BFSでトポソから木DPをする - Qiita

Grundy数(Nim数, Nimber)の理論

*1:空間計算量は  O(N+Q\text{log}N) に悪化します。

*2:非再帰セグ木上の任意始点にぶたん - えびちゃんの日記 の表記を借りると、述語を p(x) := x \ge l とすればよいです。