알고리즘
MST : 최단경로 알고리즘 (크루스칼/프림)
두구둥둥
2021. 5. 30. 10:01
MST = Minimal Spanning Tree
일단 MST를 알기 전에 Spanning Tree를 알아야 한다!
Spanning Tree란,
그래프의 모든 정점을 포함하면서 간선의 수가 최소인 트리이다.
Spanning Tree의 조건
- N개의 정점
- N-1개의 간선
- 사이클이 없다.
MST는 Spanning Tree 중에서 가중치의 합이 최소인 것을 말한다!

'마을과 마을을 잇는 도로들이 주어지고, 도로의 길이가 최소가 되게 모든 마을을 이어라.' 와 같은 문제에 사용되는 알고리즘이다.
MST 알고리즘 종류
- 크루스칼
- 프림
크루스칼 vs. 프림
유형 | 시간복잡도 | 따라하세요! | ||||
크루스칼 | Greedy | 간선 중심 | O(ElogE) | 간적크 | ||
프림 | 정점 중심 | O(V^2)/O(ElogV) | 간많프 |
크루스칼 알고리즘
간선중심으로 최소 비용 신장 트리를 만드는 그리디 기반 알고리즘이다.
모든 간선들이 선택지가 되고, 그 중 가중치가 최소인 것부터 선택된다.
사이클 생기는지 체크를 위해 union-find 알고리즘이 사용된다.

구현과정
- 모든 간선을 가중치에 대한 오름차순으로 정렬한다.
- 사이클이 안생기도록 가중치가 최소인 간선을 선택한다.
- 해당 간선을 MST 집합에 추가
- 2~3을 반복하고, 집합에 추가된 간선이 V-1개가 되면 탈출!
구현 코드
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;
public class Main {
static class Node{
int start;
int end;
long weight;
public Node(int start, int end, long weight){
this.start = start;
this.end = end;
this.weight = weight;
}
}
private static boolean union(int[] parent, int a, int b){
int parentA = findParent(parent,a);
int parentB = findParent(parent,b);
if(parentA != parentB) {
parent[parentA] = parentB;
return false;
}
return true;
}
private static int findParent(int[] parent, int a) {
if(a == parent[a]) return a;
return parent[a] = findParent(parent, parent[a]);
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine()," ");
int V = Integer.parseInt(st.nextToken());
int E = Integer.parseInt(st.nextToken());
//make edges
ArrayList<Node> edges = new ArrayList<>();
for (int i = 0; i < E; i++) {
st = new StringTokenizer(br.readLine()," ");
Node node = new Node(Integer.parseInt(st.nextToken()),Integer.parseInt(st.nextToken()),Long.parseLong(st.nextToken()));
edges.add(node);
}
Collections.sort(edges, new Comparator<Node>() {
@Override
public int compare(Node o1, Node o2) {
return (int)(o1.weight-o2.weight);
}
});
long result = 0;
int[] parents = new int[V+1];
for (int i = 1; i <= V; i++) {
parents[i] = i;
}
Node n;
int cnt = 0;
for (int i = 0; i < edges.size(); i++) {
if(cnt==V-1) break;
n = edges.get(i);
if(union(parents, n.start, n.end)) continue;
result += n.weight;
++cnt;
}
System.out.println(result);
}
}
프림 알고리즘
노드를 중심으로 MST 집합을 늘려가는 그리디 기반 알고리즘이다.
MST집합에 속한 노드와 연결된 모든 노드들이 최소간선을 찾는 선택지들이 된다!

구현과정
- 인접리스트를 구현한다.
- 현재까지 MST에 추가된 간선을 담는 PriorityQueue를 만들고, (시작하려는 정점, weight=0)을 넣어준다.
- PriorityQueue에서 가중치가 제일 작은 간선을 뺀다.
- 해당 간선의 도착 지점이 이미 MST집합에 포함되어있으면 지나친다.
- 아니라면 MST집합에 넣고, cnt ++ 한다. (cnt : 현재까지 MST 집합에 포함되어있는 노드 개수)
- 해당 간선의 도착 지점과 연결된 노드 중 아직 MST집합에 포함되지 않은 것을 PriorityQueue에 넣는다.
- cnt == 노드수가 될때까지 3-6 반복
시간복잡도
MST집합에 포함될 수 있는 예비 간선들을,
PriorityQueue 최소 힙으로 구현한다면 최악의 경우 O (ElogV)
간선 개수만큼 pq에 넣을 수 있고 O(E)
그 내부에서 pq에 삽입, 삭제마다 heap 구조 변경하고, heap에는 최대 V개의 데이터 있으므로 O(logV)
단순 배열형태로 구현한다면 최악의 경우 O(V^2)
MST 집합에 속하지 않은 정점들 중에서 O(V)
그 내부에서 가중치가 가장 낮은 정점 선택 O(V)
구현 코드
※ graph를 만드는 방법은 input에 따라 다양하다.
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.PrimitiveIterator;
import java.util.PriorityQueue;
import java.util.StringTokenizer;
public class Main{
static class Node{
int end;
int weight;
public Node(int end, int weight){
this.end = end;
this.weight = weight;
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = null;
int N = Integer.parseInt(br.readLine());
//graph 만들기
ArrayList<ArrayList<Node>> graph = new ArrayList<>();
for (int i = 0; i <=N; i++) {
graph.add(new ArrayList<Node>());
}
//이차원 배열 형태의 input을 graph로 만들기
int input = 0;
for (int i = 1; i <= N; i++) {
st = new StringTokenizer(br.readLine()," ");
for (int j = 1; j <= N; j++) {
input = Integer.parseInt(st.nextToken());
if(i == j) continue;
graph.get(i).add(new Node(j,input));
graph.get(j).add(new Node(i,input));
}
}
int result = Prim(graph,N);
System.out.println(result);
}
private static int Prim(ArrayList<ArrayList<Node>> graph, int N){
PriorityQueue<Node> pq = new PriorityQueue<>((o1,o2)->(o1.weight-o2.weight));
pq.add(new Node(1,0));
int cnt = 0;
int result = 0;
boolean[] visited = new boolean[N+1];
while(!pq.isEmpty()){
Node cur = pq.poll();
if(visited[cur.end]) continue;
visited[cur.end] = true;
result += cur.weight;
for (Node node : graph.get(cur.end)) {
if(!visited[node.end]){
pq.add(node);
}
}
if(++cnt == N){
break;
}
}
return result;
}
}