『アルゴリズムC〈第3巻〉グラフ・数理・トピックス』のp34で紹介されている「合併-発見アルゴリズム」をPythonで作ってみました。
「合併-発見アルゴリズム」は、以下のような用途で使われているようです。
- グラフにおいて、その具体的な道を求める必要はないが、頂点xが頂点yと繋がっているかどうかを知りたい
- 一般の集合演算の処理に利用
(1) 単純(ナイーブ)な方法
配列dad
の添え字で親子関係を表現している。例えば、dad[4]=>2
なら、4の親が2となる。dad[2]=>-1
なら要素2
が集合の根(負数)を表す。要素x
と要素y
が同じ集合に属しているかを調べる場合は、x
が属している集合の根の値とy
が属している根の値を比較することで調べることができる。
class UnionFind: def __init__(self, n): self.dad = [-1 for _ in range(n)] def find(self, x, y, doit=False): i, j = x, y while self.dad[i] >= 0: i = self.dad[i] while self.dad[j] >= 0: j = self.dad[j] if doit and i != j: self.dad[j] = i return i == j
(2) 道短縮法(path compression)
例えば、dad[8]=>4, dad[4]=>2, dad[2]=>-1
のとき、8が属している集合の根は2となるが、3回配列を調べる必要がある。これをdad[8]=>2
とすることで、探索回数を減らすことができる。
class UnionFind: def __init__(self, n): self.dad = [-1 for _ in range(n)] def find(self, x, y, doit=False): i, j = x, y while self.dad[i] >= 0: i = self.dad[i] while self.dad[j] >= 0: j = self.dad[j] # 道短縮法(path compression) while self.dad[x] >= 0: t = x x = self.dad[x] self.dad[t] = i while self.dad[y] >= 0: t = y y = self.dad[y] self.dad[t] = j if doit and i != j: self.dad[j] = i return i == j
(3) 重さ均衡法(weight balancing)
(1),(2)の方法では、2つの集合を合併するとき、常にdad[j] = i
としていたが、根に重みを持たせることにより、子孫の数が多い方を根とする。
class UnionFind: def __init__(self, n): self.dad = [-1 for _ in range(n)] def find(self, x, y, doit=False): i, j = x, y while self.dad[i] >= 0: i = self.dad[i] while self.dad[j] >= 0: j = self.dad[j] # 道短縮法(path compression) while self.dad[x] >= 0: t = x x = self.dad[x] self.dad[t] = i while self.dad[y] >= 0: t = y y = self.dad[y] self.dad[t] = j # 重さ均衡法(weight balancing) if doit and i != j: if self.dad[j] < self.dad[i]: self.dad[j] += self.dad[i] - 1 self.dad[i] = j else: self.dad[i] += self.dad[j] - 1 self.dad[j] = i return i == j
実行の様子について
(1)〜(3)について、以下を実行してみます。
if __name__ == "__main__": import pprint uf = UnionFind(10) # 0, 2, 4を同じグループ(A)に所属させる。 uf.find(0, 2, True) pprint.pprint(uf.dad) uf.find(0, 4, True) pprint.pprint(uf.dad) # 6, 8を同じグループ(B)に所属させる。 uf.find(6, 8, True) pprint.pprint(uf.dad) # 1, 3, 5, 7, 9を同じグループ(C)に所属させる。奇数グループ for i in range(3, 10, 2): uf.find(i - 2, i, True) pprint.pprint(uf.dad) uf.find(0, 8, True) # (A)と(B)を統合する。偶数グループ pprint.pprint(uf.find(0, 8)) pprint.pprint(uf.dad) pprint.pprint(uf.find(1, 8)) pprint.pprint(uf.dad) uf.find(1, 8, True) # 偶奇統合。自然数グループへ pprint.pprint(uf.find(1, 8)) pprint.pprint(uf.dad)
ナイーブな方法
[-1, -1, 0, -1, -1, -1, -1, -1, -1, -1] # 0,2が同じグループへ [-1, -1, 0, -1, 0, -1, -1, -1, -1, -1] # 0,2,4が同じグループへ [-1, -1, 0, -1, 0, -1, -1, -1, 6, -1] # 6,8が同じグループへ [-1, -1, 0, 1, 0, -1, -1, -1, 6, -1] # 1,3,5,7,9が同じグループへ [-1, -1, 0, 1, 0, 1, -1, -1, 6, -1] [-1, -1, 0, 1, 0, 1, -1, 1, 6, -1] [-1, -1, 0, 1, 0, 1, -1, 1, 6, 1] # {0,2,4}, {6,8}, {1,3,5,7,9}の3グループ True # {0,2,4}+{6,8}後、0,8は同じグループか? [-1, -1, 0, 1, 0, 1, 0, 1, 6, 1] False # 1,8は同じグループか? [-1, -1, 0, 1, 0, 1, 0, 1, 6, 1] True # 偶奇統合後、1,8は同じグループか? [1, -1, 0, 1, 0, 1, 0, 1, 6, 1] # 8の親は6,6の親は1 (8→6→1)
道短縮法
[-1, -1, 0, -1, -1, -1, -1, -1, -1, -1] [-1, -1, 0, -1, 0, -1, -1, -1, -1, -1] [-1, -1, 0, -1, 0, -1, -1, -1, 6, -1] [-1, -1, 0, 1, 0, -1, -1, -1, 6, -1] [-1, -1, 0, 1, 0, 1, -1, -1, 6, -1] [-1, -1, 0, 1, 0, 1, -1, 1, 6, -1] [-1, -1, 0, 1, 0, 1, -1, 1, 6, 1] True [-1, -1, 0, 1, 0, 1, 0, 1, 0, 1] False [-1, -1, 0, 1, 0, 1, 0, 1, 0, 1] True [1, -1, 0, 1, 0, 1, 0, 1, 1, 1] # 8の親は1(8→1) 短縮されている!
重さ均衡法
[-3, -1, 0, -1, -1, -1, -1, -1, -1, -1] [-5, -1, 0, -1, 0, -1, -1, -1, -1, -1] [-5, -1, 0, -1, 0, -1, -3, -1, 6, -1] [-5, -3, 0, 1, 0, -1, -3, -1, 6, -1] [-5, -5, 0, 1, 0, 1, -3, -1, 6, -1] [-5, -7, 0, 1, 0, 1, -3, 1, 6, -1] [-5, -9, 0, 1, 0, 1, -3, 1, 6, 1] True [-9, -9, 0, 1, 0, 1, 0, 1, 0, 1] False [-9, -9, 0, 1, 0, 1, 0, 1, 0, 1] True [1, -19, 0, 1, 0, 1, 0, 1, 1, 1] # 根(値が負)に重みを持たせている。
データ数を増やして実行時間を計測
データの数を100000
として実行速度を調べてみました。
import random import time import union_find1 import union_find2 import union_find3 N = 10**5 print("#### (1)ナイーブな方法") uf = union_find1.UnionFind(N) start = time.perf_counter() for i in range(N): x = random.randint(0, N - 1) y = random.randint(0, N - 1) uf.find(x, y, True) print(time.perf_counter() - start) print("#### (2)道短縮法(path compression)") uf = union_find2.UnionFind(N) start = time.perf_counter() for i in range(N): x = random.randint(0, N - 1) y = random.randint(0, N - 1) uf.find(x, y, True) print(time.perf_counter() - start) print("#### (3)道短縮法(path compression) + 重さ均衡法(weight balancing)") uf = union_find3.UnionFind(N) start = time.perf_counter() for i in range(N): x = random.randint(0, N - 1) y = random.randint(0, N - 1) uf.find(x, y, True) print(time.perf_counter() - start)
実行結果
(2)と(3)には顕著な速度差が見られませんでしたが、(1)の遅さは際立っていました。
#### (1)ナイーブな方法 18.461921999999998 #### (2)道短縮法(path compression) 0.2709850830000029 #### (3)道短縮法(path compression) + 重さ均衡法(weight balancing) 0.2369107500000034
AtCoder Beginner Contest 269 D
作成したUnionFind
を使って以下の問題を解いてみました。