
Algorithm Study: Union-Find & MST

jjsyeon 2025. 4. 17. 14:15

Union-Find


Concept

  • Determines whether two nodes in a graph belong to the same connected component (that is, whether they are connected)
  • Also called a Disjoint Set
  • Time complexity (with path compression and union by rank)
    • Finding the representative node (root), find : amortized $O(\alpha(N))$ → practically $O(1)$
    • Merging two sets, union : amortized $O(\alpha(N))$ → practically $O(1)$

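Full code

def find(a):
    # Follow parent links to the root, compressing the path along the way
    if parent[a] != a: parent[a] = find(parent[a])
    return parent[a]

def union(a, b):
    pa, pb = find(a), find(b)
    # Compare the roots, not the original nodes; the smaller root becomes the representative
    if pa < pb : parent[pb] = pa
    else : parent[pa] = pb

# If a graph is given (assumed here: graph[i] lists the neighbors of node i)
graph = [] #<- list in list
parent = list(range(len(graph))) #<- one entry per node; each node starts as its own root
for i in range(len(graph)):
    for j in graph[i]:
        union(i, j)
# Collect the root of every node: one root per connected set
result = {find(i) for i in range(len(graph))}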

Functions that make up Union-Find

find : find the representative node (root)

def find(a):
    if parent[a] != a: parent[a] = find(parent[a])
    return parent[a]
  • if parent[a] != a → my parent is not me = I am not the root
    • Keep climbing through the parents until the representative (root) of the set is found
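
To make the "climb through the parents" step concrete, here is a small sketch without any compression. The `parent` values and the name `find_plain` are made up for illustration; the function also counts how many links it follows.

```python
# Hypothetical parent array: 3 -> 2 -> 1 -> 0, and node 4 is alone.
parent = [0, 0, 1, 2, 4]

def find_plain(a):
    """Walk up the parent links until reaching a node that is its own parent."""
    steps = 0
    while parent[a] != a:
        a = parent[a]
        steps += 1
    return a, steps

root, steps = find_plain(3)   # follows 3 -> 2 -> 1 -> 0
```

Node 3 needs three hops to reach its root, while node 4 needs none. That chain length is what the optimizations later in the post try to keep short.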

union : merge two sets

def union(a,b):
    pa, pb = find(a), find(b)
    if pa == pb : return
    elif pa < pb : parent[pb] = pa
    else : parent[pa] = pb
  • if pa == pb → the two nodes already belong to the same set
    • No need to union again, so return
  • elif pa < pb & else : a union is needed
    • The comparison must use the roots pa and pb, not a and b; otherwise only the two nodes are relinked and the rest of their sets stay separate
    • This condition makes the smaller root the representative
    • When pa != pb, pointing either root at the other works just as well
    • But keeping the smaller number as the representative is more practical
      • Representatives come out ordered → easier to read
      • Representatives are deterministic, which helps when debugging cycle checks
      • It does not limit tree depth by itself; union by rank (below) is the technique for that
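
A short run of the "smaller root wins" rule, as a sketch. It assumes six nodes numbered 0 to 5, and this `union` returns True/False, which is an addition for the demo rather than part of the code above.

```python
parent = list(range(6))  # six hypothetical nodes, each its own root

def find(a):
    if parent[a] != a:
        parent[a] = find(parent[a])
    return parent[a]

def union(a, b):
    pa, pb = find(a), find(b)
    if pa == pb:
        return False          # already in the same set
    if pa < pb:
        parent[pb] = pa       # smaller root becomes the representative
    else:
        parent[pa] = pb
    return True

union(4, 5)
union(2, 4)
union(0, 1)
groups = sorted({find(i) for i in range(6)})   # one representative per set
```

After these three calls the sets are {0, 1}, {2, 4, 5} and {3}, so `groups` holds the smallest node of each set.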

Let's improve Union-Find!


Path Compression

  • While doing find, flatten the tree along the way
  • After find locates the root, store that root in parent so the next find reaches it immediately
  • Instead of leaving a node attached to a non-representative parent, attach it directly to the root
def find(a):
    if parent[a] != a: parent[a] = find(parent[a])  # path compression
    return parent[a]
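
The effect is easiest to see on a worst-case chain. This sketch uses a made-up `parent` array and snapshots it before and after a single find.

```python
parent = [0, 0, 1, 2, 3]  # hypothetical chain: 4 -> 3 -> 2 -> 1 -> 0

def find(a):
    if parent[a] != a:
        parent[a] = find(parent[a])  # path compression
    return parent[a]

before = parent[:]
root = find(4)
after = parent[:]   # every node on the path now points straight at the root
```

One call on node 4 rewrites the whole chain, so later finds on 1 through 4 take a single hop. For very long chains the recursive form can hit Python's recursion limit, so an iterative find may be the safer choice.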

Union by Rank

  • Rank : a number used to compare the trees of two sets; it tracks the height of the tree (an upper bound on it once path compression is applied)
  • When two sets are unioned, the tree with the smaller height is attached as a child of the taller tree
    • Attaching the smaller tree to the taller one makes the height grow less
      • Smaller tree rank : s, taller tree rank : b → s ≤ b
      • Rank after attaching the smaller tree under the taller one : max(s+1, b) → b, or s+1 only when s = b
      • Rank after attaching the taller tree under the smaller one : max(s, b+1) → always b+1
    • So find follows fewer parent links and runs more efficiently
def union(a, b):
    # rank is assumed to be initialized as [0] * number_of_nodes
    pa, pb = find(a), find(b)

    if pa == pb: return

    # Union by Rank
    if rank[pa] < rank[pb]:
        parent[pa] = pb
    elif rank[pa] > rank[pb]:
        parent[pb] = pa
    else:
        parent[pb] = pa
        rank[pa] += 1
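
Putting path compression and union by rank together, one possible sketch looks like this. The class name `DisjointSet` and the True/False return value are assumptions made for the example, not something from the code above.

```python
class DisjointSet:
    def __init__(self, n):
        self.parent = list(range(n))  # every node starts as its own root
        self.rank = [0] * n           # height estimate of each root's tree

    def find(self, a):
        if self.parent[a] != a:
            self.parent[a] = self.find(self.parent[a])  # path compression
        return self.parent[a]

    def union(self, a, b):
        pa, pb = self.find(a), self.find(b)
        if pa == pb:
            return False
        if self.rank[pa] < self.rank[pb]:
            pa, pb = pb, pa           # make pa the root with the larger rank
        self.parent[pb] = pa
        if self.rank[pa] == self.rank[pb]:
            self.rank[pa] += 1        # height grows only on a tie
        return True

ds = DisjointSet(4)
ds.union(0, 1)   # tie: rank of root 0 becomes 1
ds.union(2, 3)   # tie: rank of root 2 becomes 1
ds.union(0, 2)   # tie again: rank of root 0 becomes 2
```

The rank only increases when two trees of equal rank are merged, which matches the max(s+1, b) reasoning in the bullets.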


MST : Kruskal & Prim Algorithms

: Kruskal's algorithm and Prim's algorithm both build an MST

MST (Minimum Spanning Tree)

  • Concept
    • Connects every vertex
    • Contains no cycle
    • Has the minimum total edge weight among such graphs
  • Given a weighted graph,
    an MST is built when we want the set of edges that connects every node with the smallest possible sum of weights
  • Kruskal's or Prim's algorithm is used to build such an MST
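
The three conditions can be checked by brute force on a small graph. This sketch is only an illustration of the definition, not a practical method. The edge list is transcribed from the 6-node adjacency list in the Prim section, read as (neighbor, weight), and the helper `connects_all` is a made-up name.

```python
from itertools import combinations

# (node A, node B, weight) for the 6-node example graph
edges = [(0, 1, 1), (0, 3, 4), (1, 2, 3), (1, 4, 2),
         (2, 5, 5), (3, 4, 6), (4, 5, 7)]
n = 6

def connects_all(subset):
    """Return True if the chosen edges reach every node starting from node 0."""
    adj = {i: [] for i in range(n)}
    for a, b, _ in subset:
        adj[a].append(b)
        adj[b].append(a)
    seen, stack = {0}, [0]
    while stack:
        for nxt in adj[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return len(seen) == n

# n-1 edges that connect all n nodes cannot contain a cycle,
# so every subset that passes the check is a spanning tree.
spanning_trees = [s for s in combinations(edges, n - 1) if connects_all(s)]
best = min(spanning_trees, key=lambda s: sum(w for _, _, w in s))
best_weight = sum(w for _, _, w in best)
```

For this graph the minimum is 15, using the edges of weight 1, 2, 3, 4 and 5. Enumerating subsets grows exponentially with the number of edges, which is why Kruskal and Prim are used instead.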

Kruskal's Algorithm

  • Builds the MST edge by edge

  • Implemented with Union-Find

  • Steps

      edges = [(1, 0, 1), (3, 1, 2), ... ] # edge info (weight, node A, node B); remaining edges elided
      parent = [i for i in range(6)]
    
      def find(x):
          if parent[x] != x:
              parent[x] = find(parent[x])
          return parent[x]
    
      def union(x, y):
          x_root = find(x)
          y_root = find(y)
          if x_root != y_root:
              parent[y_root] = x_root
              return True   # merged two different sets
          return False      # already connected: this edge would form a cycle
    
      edges.sort()  # ascending by weight (the first tuple element)
      mst_weight = 0
      edge_cnt = 0
      for w, a, b in edges:
          if union(a, b):
              edge_cnt += 1
              mst_weight += w
          if edge_cnt == len(parent)-1 : break
    
      print("Kruskal MST total weight:", mst_weight)
    1. Store the edges in a list and sort them by weight in ascending order
    2. Take the edges one at a time, starting from the smallest weight
    3. Union the two nodes joined by the edge
      • Already connected → continue
        • Adding another edge between nodes that are already connected creates a cycle
      • Not yet connected → select it : add its weight to mst_weight
    4. Stop once the number of selected edges equals the number of nodes - 1
      • A spanning tree over all the nodes has exactly (number of nodes - 1) edges
  • Time complexity : $O(E \log E)$, dominated by the sort → a good fit when the graph has relatively few edges
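
The steps above can be sketched as one self-contained function. Two things here are assumptions for the example: the full edge list (taken from the 6-node graph in the Prim section, since the list above is elided), and the choice to return None for a disconnected graph. The find uses path halving, an iterative variant of path compression, instead of the recursive version.

```python
def kruskal(n, edges):
    """edges: list of (weight, a, b). Returns (total_weight, chosen_edges),
    or None when the graph is not connected."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving: skip one level per step
            x = parent[x]
        return x

    total, chosen = 0, []
    for w, a, b in sorted(edges):          # smallest weight first
        ra, rb = find(a), find(b)
        if ra == rb:
            continue                       # would close a cycle
        parent[rb] = ra
        total += w
        chosen.append((a, b, w))
        if len(chosen) == n - 1:
            break                          # spanning tree is complete
    return (total, chosen) if len(chosen) == n - 1 else None

edges = [(1, 0, 1), (4, 0, 3), (3, 1, 2), (2, 1, 4),
         (5, 2, 5), (6, 3, 4), (7, 4, 5)]
result = kruskal(6, edges)
```

On this graph the total weight is 15, and the edges of weight 6 and 7 are never needed because the tree is complete after five selections.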


Prim's Algorithm

  • Builds the MST node by node

  • Explores the graph much like BFS, but the frontier is a priority queue ordered by edge weight, so the cheapest edge leaving the current tree is taken first

  • Steps

      import heapq
    
      # Adjacency-list graph: node -> [(neighbor, weight), ...]
      graph = {
          0: [(1, 1), (3, 4)],
          1: [(0, 1), (2, 3), (4, 2)],
          2: [(1, 3), (5, 5)],
          3: [(0, 4), (4, 6)],
          4: [(1, 2), (3, 6), (5, 7)],
          5: [(2, 5), (4, 7)]
      }
    
      visited = [False] * 6
      heap = [(0, 0)]  # (weight, start vertex)
      mst_weight = 0
    
      while heap:
          weight, node = heapq.heappop(heap)
          if visited[node]:
              continue  # stale entry: the node was already reached by a cheaper edge
          visited[node] = True
          mst_weight += weight
          # Each adjacency entry is (neighbor, weight), so unpack in that order
          for neighbor, w in graph[node]:
              if not visited[neighbor]:
                  heapq.heappush(heap, (w, neighbor))
    
      print("Prim MST total weight:", mst_weight)
    1. Start from an arbitrary vertex (push the start vertex onto the priority queue)
    2. Select the shortest edge adjacent to the current tree
      • Pop from the priority queue
      • Add the weight of the popped entry to mst_weight
    3. Add the vertices that have not been visited yet and repeat steps 2 and 3
      • For each edge of the popped node, push the not-yet-visited neighbor and its edge weight onto the queue
    4. Continue until every vertex has been visited
  • Time complexity : $O(E \log V)$ → useful when the graph has many edges or few nodes
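
As a variation on the steps above, this sketch wraps Prim in a function that also records which edges were chosen. The function name `prim`, the three-element heap entries, and the None return for an unreachable node are assumptions made for the example.

```python
import heapq

def prim(graph, start=0):
    """graph: {node: [(neighbor, weight), ...]}. Returns (total_weight, mst_edges),
    or None when some node cannot be reached from start."""
    visited = set()
    heap = [(0, start, start)]           # (weight, from_node, to_node)
    total, mst_edges = 0, []
    while heap and len(visited) < len(graph):
        w, src, node = heapq.heappop(heap)
        if node in visited:
            continue                     # stale entry, node already in the tree
        visited.add(node)
        total += w
        if src != node:                  # skip the dummy entry for the start node
            mst_edges.append((src, node, w))
        for neighbor, nw in graph[node]:
            if neighbor not in visited:
                heapq.heappush(heap, (nw, node, neighbor))
    return (total, mst_edges) if len(visited) == len(graph) else None

graph = {
    0: [(1, 1), (3, 4)],
    1: [(0, 1), (2, 3), (4, 2)],
    2: [(1, 3), (5, 5)],
    3: [(0, 4), (4, 6)],
    4: [(1, 2), (3, 6), (5, 7)],
    5: [(2, 5), (4, 7)],
}
result = prim(graph)
```

Starting from node 0, the tree grows through nodes 1, 4, 2, 3 and 5 in that order, for a total weight of 15. That matches what Kruskal's algorithm gives on the same graph.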