Mae向きなブログ (A Forward-Looking Blog)

I hope to keep sharing information here with a forward-looking (Mae向き) attitude.

kd-tree

I learned about the kd-tree data structure by reading "algorithm - 最近点検索をkd-treeで" (nearest-point search with a kd-tree). It is really interesting.

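I wanted to understand it properly, so I wrote it in Ruby, using the page above as a reference. Actually writing the code does seem to deepen my understanding.
For example, I realized that the pts2kdtree method sorts the array presumably to keep the binary tree balanced, and that even once you can build a kd-tree, searching it for the nearest point is fairly tricky.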
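
To illustrate the point about balance, here is a minimal sketch, separate from the listing in this post. It assumes points are plain [x, y] arrays, and the helper names build and height as well as the sample data are made up for the illustration. Sorting along the current axis and splitting at the median gives each subtree roughly half of the points, so the height stays close to log2(n).

```ruby
# Hypothetical helpers for illustration only; points are plain [x, y] arrays.
def build(points, depth = 0)
  return nil if points.empty?

  axis = depth % 2                          # 0 => split on x, 1 => split on y
  sorted = points.sort_by { |p| p[axis] }   # sort so the median can be picked
  mid = sorted.size / 2
  {
    mid: sorted[mid],
    head: build(sorted[0...mid], depth + 1),
    tail: build(sorted[(mid + 1)..-1], depth + 1)
  }
end

# Number of levels in the tree (0 for an empty tree).
def height(node)
  return 0 if node.nil?

  1 + [height(node[:head]), height(node[:tail])].max
end

# Deterministic, invented sample data: 100 distinct points.
points = Array.new(100) { |i| [i, (i * 37) % 100] }
tree = build(points)
puts height(tree)   # prints 7, the minimum possible height for 100 nodes
```

Without the median split (for example, always taking the first element as the node), the same 100 points could produce a much deeper tree, and the search would degrade accordingly.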

The source is messy, but here it is.

kdtree.rb

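# A point in the 2D plane.
class Point
  attr_accessor :x, :y

  def initialize(x, y)
    @x = x
    @y = y
  end

  def to_s
    "(#{@x}, #{@y})"
  end

  # Squared Euclidean distance to p (no square root is needed for comparisons).
  def dsquare(p)
    dx = @x - p.x
    dy = @y - p.y
    dx * dx + dy * dy
  end

  # Returns whichever of a and b is closer to self (b on a tie).
  def closer(a, b)
    dsquare(a) < dsquare(b) ? a : b
  end
end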

# Builds a kd-tree from an array of Points. Each node is a Hash:
# :mid is the median point, :head and :tail are the subtrees on either side.
def pts2kdtree(ary, depth)
  return nil if ary.empty?
  return { mid: ary[0], head: nil, tail: nil } if ary.size == 1

  # Alternate the splitting axis: x at even depth, y at odd depth.
  axis = depth.odd? ? :y : :x
  # Sorting lets us split at the median, which keeps the tree balanced.
  copy = ary.sort_by { |p| p.public_send(axis) }
  mid = copy.size / 2
  {
    mid: copy[mid],
    head: pts2kdtree(copy[0...mid], depth + 1),
    tail: pts2kdtree(copy[(mid + 1)..-1], depth + 1)
  }
end

# Returns the point in the kd-tree qt that is nearest to pt.
def kdtree_find(qt, pt, depth)
  return nil if qt.nil?
  return qt[:mid] if is_leaf(qt)

  axis = depth.odd? ? :y : :x
  # Signed distance from pt to the splitting line through qt[:mid].
  diff = pt.public_send(axis) - qt[:mid].public_send(axis)
  # Descend first into the side of the splitting line that contains pt.
  near, far = diff < 0 ? [qt[:head], qt[:tail]] : [qt[:tail], qt[:head]]
  if near
    leaf = kdtree_find(near, pt, depth + 1)
    # The far side can only hold a closer point if the splitting line
    # is nearer to pt than the best candidate found so far.
    if far && pt.dsquare(leaf) > diff**2
      leaf = pt.closer(leaf, kdtree_find(far, pt, depth + 1))
    end
  else
    # Only one child exists, so it has to be searched.
    leaf = kdtree_find(far, pt, depth + 1)
  end
  pt.closer(qt[:mid], leaf)
end

# True when the node has no children.
def is_leaf(a)
  a[:head].nil? && a[:tail].nil?
end

# Generates n random points with integer coordinates in 0...xmax and 0...ymax.
def randpoints(n, xmax, ymax)
  Array.new(n) { Point.new(rand(xmax), rand(ymax)) }
end

def points2str(pts)
  pts.map(&:to_s).join(', ')
end

if __FILE__ == $PROGRAM_NAME
  ary = randpoints(100, 1000, 1000)
  puts points2str(ary)
  kdtree = pts2kdtree(ary, 0)
  puts "\nNearest point => #{kdtree_find(kdtree, Point.new(500, 500), 0)}"
end
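
The tricky part of the search is deciding whether the other side of a split still needs to be visited. The check in kdtree_find compares the squared distance to the best candidate so far with the squared gap between the query and the splitting line. A small sketch of that test in isolation follows. The helper name must_search_far? and the split coordinates 520 and 600 are invented for the example; 6353 is the squared distance from (500, 500) to (427, 532).

```ruby
# Hypothetical helper: true when the far side of a split may still
# contain a point closer than the best candidate found so far.
def must_search_far?(best_dsquare, query_coord, split_coord)
  gap = query_coord - split_coord
  best_dsquare > gap * gap
end

# Best candidate so far is at squared distance 6353 from the query.
puts must_search_far?(6353, 500, 520)   # gap squared is 400, prints true
puts must_search_far?(6353, 500, 600)   # gap squared is 10000, prints false
```

When the test is false, every point on the far side is at least as far away as the splitting line, so the whole subtree can be skipped.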

Output

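With 1000 as the upper bound for x and y (rand(1000) yields integers from 0 to 999), I generated 100 points and searched for the point nearest to (500, 500).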

$ ruby kdtree.rb
(948, 355), (210, 392), (967, 836), (183, 634), (413, 56), (318, 686), (233, 750), (267, 375), (397, 303), (837, 500), (476, 212), (377, 828), (897, 143), (953, 467), (61, 77), (119, 673), (785, 668), (350, 697), (696, 549), (783, 607), (345, 798), (151, 23), (46, 921), (939, 237), (757, 50), (913, 940), (144, 113), (617, 998), (214, 502), (359, 128), (519, 994), (845, 407), (98, 42), (596, 994), (299, 233), (219, 127), (120, 244), (839, 206), (43, 747), (341, 634), (305, 476), (701, 101), (260, 569), (549, 601), (60, 620), (333, 395), (612, 212), (276, 35), (144, 335), (543, 701), (902, 822), (250, 903), (196, 849), (806, 338), (747, 43), (840, 643), (323, 485), (114, 677), (427, 532), (274, 587), (282, 796), (305, 920), (368, 67), (641, 671), (823, 349), (253, 868), (782, 398), (137, 564), (35, 779), (629, 741), (207, 587), (289, 193), (984, 779), (736, 60), (393, 234), (656, 435), (696, 90), (163, 600), (438, 107), (897, 67), (472, 717), (582, 90), (876, 710), (826, 747), (766, 434), (723, 246), (861, 372), (362, 908), (599, 880), (133, 153), (93, 797), (970, 835), (767, 310), (608, 618), (142, 915), (475, 254), (263, 637), (252, 527), (996, 635), (452, 979)

Nearest point => (427, 532)
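
One way to sanity-check a result like this is a brute-force scan, which is slow but hard to get wrong. The sketch below is an assumption-laden illustration rather than part of kdtree.rb: Pt and nearest_brute_force are hypothetical names, and the sample holds only a handful of points copied from the run above, not all 100.

```ruby
# Hypothetical stand-in for the post's Point class, to keep this block self-contained.
Pt = Struct.new(:x, :y)

# Linear scan: returns the point with the smallest squared distance to query.
def nearest_brute_force(points, query)
  points.min_by { |p| (p.x - query.x)**2 + (p.y - query.y)**2 }
end

# A few of the points from the run above that lie fairly close to (500, 500).
sample = [[427, 532], [549, 601], [323, 485], [656, 435], [608, 618]]
         .map { |x, y| Pt.new(x, y) }

best = nearest_brute_force(sample, Pt.new(500, 500))
puts "(#{best.x}, #{best.y})"   # prints (427, 532)
```

Running the same scan over the full array of generated points and comparing it with the kd-tree answer would be a more thorough check.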

There is a follow-up to this post.