티스토리 뷰
최소 신장 트리 알고리즘
신장 트리 중에서 최소 비용으로 만들 수 있는 신장 트리를 찾는 알고리즘.
그 중에서 대표적인 최소 신장 트리 알고리즘이 크루스칼 알고리즘
그리디 알고리즘으로 분류
신장 트리의 개념
하나의 그래프가 있을 때 모든 노드를 포함하고, 연결되며 사이클이 존재하지 않는(tree) 그래프
신장 트리는 구글에 치기만 해도 예시가 나온다.
연결 관계에서 사이클을 형성하지 않으므로, 정점의 개수가 n개일 때, 간선이 n-1개가 된다.
최소 신장 트리
해당 신장 트리들 중에서 간선에 부여된 가중치의 합이 최소가 되는 신장 트리를 최소 신장 트리라 한다.
신장 트리의 조건을 만족하면서 최소 가중치(비용)을 들이는 신장 트리가 최소 신장 트리이다.
크루스칼 알고리즘
크루스칼 알고리즘의 구체적인 동작 과정
1. 간선 데이터를 비용에 따라 오름차순으로 정렬
2. 간선을 하나씩 확인하며 현재의 간선이 사이클을 발생시키는지 확인
2-1 사이클이 발생하지 않는 경우 최소 신장 트리에 포함
2-2 사이클이 발생하는 경우 최소 신장 트리에 미포함
3. 모든 간선에 대하여 2번 과정을 반복
선택된 간선의 갯수가 정점의 갯수 -1만큼 되면 알고리즘을 종료한다.
사이클인지 아닌지 판단할때는 union-find를 활용한다.
Union-Find란?
Disjoint set(서로소 집합)을 표현하는 자료구조
서로 다른 두 집합을 형합하는 union 연산, 집합 원소가 어떤 집합에 속해있는지 찾는 find 연산을 지원한다.
Union-Find에 대해서는 아래 글들 참조. 그림과 함께 설명이 자세하게 나와 있다.
https://todaycode.tistory.com/108
https://chanhuiseok.github.io/posts/algo-33/
코드
import sys
def find(parent, u):
# 노드 u의 부모 노드를 찾는 함수
if parent[u] != u:
parent[u] = find(parent, parent[u]) # 재귀를 통해 부모를 갱신하여 최적화
#즉, 부모노드를 찾아자다가, u가 parent와 같다면 그게 바로 부모 노드
return parent[u]
def union(parent, rank, u, v):
# 두 개의 집합을 합치는 함수
u_root = find(parent, u) # u의 루트 노드 찾기
v_root = find(parent, v) # v의 루트 노드 찾기
#그래서 u의 부모 노트와 v의 부모 노드를 구해서
if u_root == v_root:
return # 이미 같은 집합에 속해있으면 합치지 않는다.
#부모노드가 같다 = 같은 집합
if rank[u_root] > rank[v_root]: #
parent[v_root] = u_root # 낮은 랭크의 트리를 높은 랭크의 트리에 붙임
#v 의 부모 노드를 u_root로 갱신
else:
parent[u_root] = v_root #아니라면 u 의 부모 노드를 v_root로 갱신
if rank[u_root] == rank[v_root]:
#만약 같은 랭크라면 v의 랭크를 증가시키며, depth가 늘어났음을 표현해준다.
rank[v_root] += 1 # 같은 랭크일 경우 v의 랭크를 증가시킴
def kruskal(nodes, n):
# 크루스칼 알고리즘을 통해 최소 비용 신장 트리를 구하는 함수
nodes.sort(key=lambda x: x[2]) # 간선을 가중치를 기준으로 오름차순 정렬
parent = [i for i in range(n)] # 각 노드의 부모를 자기 자신으로 초기화
rank = [0] * n # 각 노드의 랭크 초기화
total_cost = 0 # 총 비용 초기화
for node in nodes:
u, v, cost = node
if find(parent, u) != find(parent, v): # 싸이클을 형성하지 않는다면,
# 즉, 부모 노드가 둘이 같지 않다면
union(parent, rank, u, v) # 두 집합을 합침
total_cost += cost # 비용 추가
return total_cost
n, m = map(int, sys.stdin.readline().split())
nodes = []
for _ in range(m):
a, b, c = map(int, sys.stdin.readline().split())
nodes.append((a - 1, b - 1, c)) # 인덱스를 0부터 시작하기 위해 -1
min_cost = kruskal(nodes, n)
print(min_cost)
코드 2
import sys
class union_find:
def __init__(self, n):
self.parent = [i for i in range(n)] #n이 들어오면 parent 변수 할당
self.rank = [0] * n #n이 들어오면 rank 변수 할당
def find(self, u):
if self.parent[u] != u: #
self.parent[u] = self.find(self.parent[u])
return self.parent[u]
def union(self, u, v):
u_root = self.find(u)
v_root = self.find(v)
if u_root == v_root:
return
if self.rank[u_root] > self.rank[v_root]:
self.parent[v_root] = u_root
else:
self.parent[u_root] = v_root
if self.rank[u_root] == self.rank[v_root]:
self.rank[v_root] += 1
def kruskal(nodes, n): #크루스칼 알고리즘
nodes.sort(key=lambda x: x[2])
print(nodes)
union = union_find(n)
total_cost = 0
for node in nodes:
u, v, cost = node
if union.find(u) != union.find(v):
union.union(u, v)
total_cost += cost
return total_cost
def main():
N, M = map(int, sys.stdin.readline().split())
node = []
for _ in range(M):
a, b, c = map(int, input().split())
node.append((a - 1, b - 1, c)) # 인덱스를 0부터 시작하기 위해 -1
print(node)
min_cost = kruskal(node, N)
print(min_cost)
main()