この記事でやること
- 全永続セグ木のPython (PyPy3) による非再帰実装
- 区間mexクエリのオフライン処理方法の解説
全永続セグ木って?
セグ木を全永続化したものです。
1点更新をする際にバージョンを名付けておき、以降の1点更新や区間演算を指定したバージョンのセグ木上で行うことができます。計算量はいずれも元のセグ木と変わらず 。すごい。*1
実現方法は、こちらのブログ の解説が非常に簡潔でわかりやすいと思いました。(丸投げ)
根を選ぶだけでバージョンが指定できるため、そこからは元のセグ木と同様に二分探索もできます。
全永続セグ木の実装例
PyPy提出を想定して非再帰で書いてます。木dpを非再帰で書く方法 に近いイメージで、dfsで通った頂点をに積んでおき、帰りがけはの逆順にまとめて処理しています。
class PersistentSegTree:
class Node:
def __init__(self, L, R) -> None:
self.L = L
self.R = R
self.x = None
self.left = None
self.right = None
self._calc = None
def __init__(self, array, op, e, original_v=0) -> None:
self.length = len(array)
self.size = 1<<(self.length-1).bit_length()
self.op = op
self.e = e
self.original_v = original_v
self.version = {original_v: self.Node(0, self.size)}
stack = [self.version[self.original_v]]
tps = []
while stack:
node = stack.pop()
if node.R-node.L > 1:
tps.append(node)
node.left = self.Node(node.L, (node.L+node.R)>>1)
stack.append(node.left)
node.right = self.Node((node.L+node.R)>>1, node.R)
stack.append(node.right)
else:
node.x = array[node.L] if node.L < self.length else self.e
for node in reversed(tps):
node.x = self.op(node.left.x, node.right.x)
def set(self, old_v, new_v, i, x):
assert old_v in self.version and new_v not in self.version and 0 <= i < self.length
self.version[new_v] = self.Node(0, self.size)
old_node, new_node = self.version[old_v], self.version[new_v]
tps = []
while new_node.R-new_node.L > 1:
tps.append(new_node)
m = (new_node.L+new_node.R)>>1
if new_node.L <= i < m:
new_node.left = self.Node(new_node.L, m)
new_node.right = old_node.right
old_node, new_node = old_node.left, new_node.left
else:
new_node.left = old_node.left
new_node.right = self.Node(m, new_node.R)
old_node, new_node = old_node.right, new_node.right
new_node.x = x
for node in reversed(tps):
node.x = self.op(node.left.x, node.right.x)
def get(self, v, i):
assert v in self.version and 0 <= i < self.length
node = self.version[v]
while node.R-node.L > 1:
if node.L <= i < (node.L+node.R)>>1:
node = node.left
else:
node = node.right
return node.x
def all_prod(self, v):
assert v in self.version
return self.version[v].x
def prod(self, v, l, r):
assert v in self.version
if not 0 <= l < r <= self.length: return self.e
stack = [(self.version[v], l, r)]
tps = []
while stack:
node, l, r = stack.pop()
m = (node.L+node.R)>>1
if node.L == l and r == node.R:
node._calc = node.x
elif r <= m:
tps.append((node, 0))
stack.append((node.left, l, r))
elif m <= l:
tps.append((node, 1))
stack.append((node.right, l, r))
else:
tps.append((node, 2))
stack.append((node.left, l, m))
stack.append((node.right, m, r))
for node, look in reversed(tps):
if look == 0:
node._calc = node.left._calc
elif look == 1:
node._calc = node.right._calc
else:
node._calc = self.op(node.left._calc, node.right._calc)
return self.version[v]._calc
def max_right(self, v, l, f):
assert v in self.version and 0 <= l < self.length and f(self.e)
tps = []
node = self.version[v]
while node.L != l:
if l < (node.L+node.R)>>1:
tps.append(node.right)
node = node.left
else:
node = node.right
tps.append(node)
x = self.e
for node in reversed(tps):
if not f(self.op(x, node.x)):
while node.R-node.L > 1:
if f(self.op(x, node.left.x)):
x = self.op(x, node.left.x)
node = node.right
else:
node = node.left
return node.L
x = self.op(x, node.x)
return self.length
def min_left(self, v, r, f):
assert v in self.version and 0 < r <= self.length and f(self.e)
tps = []
node = self.version[v]
while node.R != r:
if (self.L+self.R)>>1 < r:
tps.append(node.left)
node = node.right
else:
node = node.left
tps.append(node)
x = self.e
for node in reversed(tps):
if not f(self.op(node.x, x)):
while node.R-node.L > 1:
if f(self.op(node.right.x, x)):
x = self.op(node.right.x, x)
node = node.left
else:
node = node.right
return node.R
x = self.op(node.x, x)
return 0
区間mexクエリの処理方法
とは、非負整数の集合を引数に取る関数で、「集合に含まれない最小の非負整数」を返します。
競プロでは、特にGrundy数の定義に登場することでよく知られています。
以下、長さの非負整数列に対し、 を求めるクエリを考えてみます。モノイドでないので、そのままセグ木に乗ったりはしません。悲しい。
オフラインの場合
各 について、における の最大のindex(ないなら-∞)を と表します。このとき、
であることに注意すると、 を満たす最小の非負整数 が であるとわかります。
これは、をRmQセグ木で持つことで二分探索により求められる*2ので、クエリをについて昇順に見ていきながら適切にセグ木を更新していくことで、全体で解けました。
オンラインの場合
オフラインの場合におけるを永続セグ木のバージョンとするだけです。計算量は変わらず。
実装例
class RangeMexQuery:
def __init__(self, array) -> None:
self.pst = PersistentSegTree([-inf]*len(array), min, inf, 0)
for r, a in enumerate(array):
if a < len(array): self.pst.set(r, r+1, a, r)
def prod(self, l, r):
assert 0 <= l < r <= self.pst.length
return self.pst.max_right(r, 0, lambda x: x >= l)
終わりに
そのうち『永続セグ木と区間K-thクエリの話』も書きたい…。
参考文献
熨斗袋さんのツイート
永続データ構造 - Wikipedia
永続セグメント木・永続遅延セグメント木 – 37zigenのHP
非再帰セグ木上の任意始点にぶたん - えびちゃんの日記
非再帰BFSでトポソから木DPをする - Qiita
Grundy数(Nim数, Nimber)の理論