알고리즘

MST : 최단경로 알고리즘 (크루스칼/프림)

두구둥둥 2021. 5. 30. 10:01

MST = Minimal Spanning Tree

일단 MST를 알기 전에 Spanning Tree를 알아야 한다!

 

Spanning Tree란,

그래프의 모든 정점을 포함하면서 간선의 수가 최소인 트리이다.

Spanning Tree의 조건

  1. N개의 정점
  2. N-1개의 간선
  3. 사이클이 없다.

 

MST는 Spanning Tree 중에서 가중치의 합이 최소인 것을 말한다!

출처 : https://velog.io/@fldfls/최소-신장-트리-MST-크루스칼-프림-알고리즘

'마을과 마을을 잇는 도로들이 주어지고, 도로의 길이가 최소가 되게 모든 마을을 이어라.' 와 같은 문제에 사용되는 알고리즘이다.

 

 

MST 알고리즘 종류

  • 크루스칼
  • 프림

 

크루스칼 vs. 프림

  유형   시간복잡도 따라하세요!
크루스칼 Greedy 간선 중심 O(ElogE) 간적크
프림 정점 중심 O(V^2)/O(ElogV) 간많프

 

 

크루스칼 알고리즘

간선중심으로 최소 비용 신장 트리를 만드는 그리디 기반 알고리즘이다.

모든 간선들이 선택지가 되고, 그 중 가중치가 최소인 것부터 선택된다.

사이클 생기는지 체크를 위해 union-find 알고리즘이 사용된다.

출처 : https://velog.io/@hwi_chance/CS-Algorithm-Part.4-Minimum-Spanning-Tree

구현과정

  1. 모든 간선을 가중치에 대한 오름차순으로 정렬한다.
  2. 사이클이 안생기도록 가중치가 최소인 간선을 선택한다.
  3. 해당 간선을 MST 집합에 추가
  4. 2~3을 반복하고, 집합에 추가된 간선이 V-1개가 되면 탈출!  

 

구현 코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {

    static class Node{
        int start;
        int end;
        long weight;


        public Node(int start, int end, long weight){
            this.start = start;
            this.end = end;
            this.weight = weight;
        }

    }

    private static  boolean union(int[] parent, int a, int b){
        int parentA = findParent(parent,a);
        int parentB = findParent(parent,b);

        if(parentA != parentB) {
            parent[parentA] = parentB;
            return false;
        }

        return true;
    }

    private static int findParent(int[] parent, int a) {

        if(a == parent[a])  return a;
        return parent[a] = findParent(parent, parent[a]);
    }

    public static void main(String[] args) throws IOException {

        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine()," ");

        int V = Integer.parseInt(st.nextToken());
        int E = Integer.parseInt(st.nextToken());

        //make edges
        ArrayList<Node> edges = new ArrayList<>();

        for (int i = 0; i < E; i++) {
            st = new StringTokenizer(br.readLine()," ");

            Node node = new Node(Integer.parseInt(st.nextToken()),Integer.parseInt(st.nextToken()),Long.parseLong(st.nextToken()));
            edges.add(node);
        }

        Collections.sort(edges, new Comparator<Node>() {
            @Override
            public int compare(Node o1, Node o2) {
                return (int)(o1.weight-o2.weight);
            }
        });

        long result = 0;
        int[] parents = new int[V+1];

        for (int i = 1; i <= V; i++) {
            parents[i] = i;
        }

        Node n;

        int cnt = 0;
        for (int i = 0; i < edges.size(); i++) {

            if(cnt==V-1)  break;

            n = edges.get(i);

            if(union(parents, n.start, n.end)) continue;

            result += n.weight;
            ++cnt;
        }


        System.out.println(result);
    }
}

 

프림 알고리즘

노드를 중심으로 MST 집합을 늘려가는 그리디 기반 알고리즘이다.

MST집합에 속한 노드와 연결된 모든 노드들이 최소간선을 찾는 선택지들이 된다!

 

출처 : https://velog.io/@hwi_chance/CS-Algorithm-Part.4-Minimum-Spanning-Tree

 

구현과정

  1. 인접리스트를 구현한다.
  2. 현재까지 MST에 추가된 간선을 담는 PriorityQueue를 만들고, (시작하려는 정점, weight=0)을 넣어준다.
  3. PriorityQueue에서 가중치가 제일 작은 간선을 뺀다.
  4. 해당 간선의 도착 지점이 이미 MST집합에 포함되어있으면 지나친다.
  5. 아니라면 MST집합에 넣고, cnt ++ 한다. (cnt : 현재까지 MST 집합에 포함되어있는 노드 개수)
  6. 해당 간선의 도착 지점과 연결된 노드 중 아직 MST집합에 포함되지 않은 것을 PriorityQueue에 넣는다.
  7. cnt == 노드수가 될때까지 3-6 반복

 

시간복잡도

MST집합에 포함될 수 있는 예비 간선들을,

PriorityQueue 최소 힙으로 구현한다면 최악의 경우   O (ElogV) 

간선 개수만큼 pq에 넣을 수 있고 O(E)
그 내부에서 pq에 삽입, 삭제마다 heap 구조 변경하고, heap에는 최대 V개의 데이터 있으므로 O(logV)

 

단순 배열형태로 구현한다면 최악의 경우  O(V^2) 

MST 집합에 속하지 않은 정점들 중에서 O(V)
그 내부에서 가중치가 가장 낮은 정점 선택 O(V)

 

구현 코드

※ graph를 만드는 방법은 input에 따라 다양하다.

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.PrimitiveIterator;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Main{
    static class Node{
        int end;
        int weight;

        public Node(int end, int weight){
            this.end = end;
            this.weight = weight;
        }
    }
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = null;

        int N = Integer.parseInt(br.readLine());

        //graph 만들기
        ArrayList<ArrayList<Node>> graph = new ArrayList<>();

        for (int i = 0; i <=N; i++) {
            graph.add(new ArrayList<Node>());
        }

        //이차원 배열 형태의 input을 graph로 만들기
        int input = 0;
        for (int i = 1; i <= N; i++) {
            st = new StringTokenizer(br.readLine()," ");
            for (int j = 1; j <= N; j++) {
                input = Integer.parseInt(st.nextToken());
                if(i == j)  continue;

                graph.get(i).add(new Node(j,input));
                graph.get(j).add(new Node(i,input));

            }
        }

        int result = Prim(graph,N);

        System.out.println(result);
    }

    private static int Prim(ArrayList<ArrayList<Node>> graph, int N){
        PriorityQueue<Node> pq = new PriorityQueue<>((o1,o2)->(o1.weight-o2.weight));
        pq.add(new Node(1,0));

        int cnt = 0;
        int result = 0;
        boolean[] visited = new boolean[N+1];

        while(!pq.isEmpty()){
            Node cur = pq.poll();

            if(visited[cur.end])    continue;

            visited[cur.end] = true;
            result += cur.weight;

            for (Node node : graph.get(cur.end)) {
                if(!visited[node.end]){
                    pq.add(node);
                }
            }

            if(++cnt == N){
                break;
            }
        }
        return result;
    }
}