코딩테스트/SWEA_Java

다익스트라(Dijkstra) 구현_Java

Ski_ 2023. 2. 25. 14:55

import java.util.PriorityQueue;

/**
 * Single-source shortest paths with Dijkstra's algorithm on a graph stored as an
 * adjacency matrix.
 *
 * <p>Two implementations are provided: {@link #dijkstra} picks the next node with a
 * binary heap ({@link PriorityQueue}); {@link #dijkstra2} picks it with a linear scan.
 * Both assume non-negative edge weights, treat a matrix entry {@code >= INF} as
 * "no edge", and leave unreachable nodes at {@code INF}. The sample {@link #graph}
 * is symmetric, i.e. undirected.
 */
class Solution {

    /** Sentinel for "no edge" / "unreachable"; small enough that INF + INF still fits in an int. */
    static int INF = 987654321;
    /** Number of vertices in the sample {@link #graph}. */
    static int MAX_N = 6;

    /**
     * Heap entry: {@code x} is the tentative distance, {@code y} is the node.
     * Ordered by distance first, then by node index.
     */
    static class Pair implements Comparable<Pair> {
        int x, y;

        Pair(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public int compareTo(Pair o) {
            // Integer.compare instead of subtraction: `a - b` overflows for large or
            // opposite-sign operands and would invert the ordering.
            if (this.x != o.x) {
                return Integer.compare(this.x, o.x);
            }
            return Integer.compare(this.y, o.y);
        }
    }

    /** Sample graph: graph[u][v] is the weight of edge u-v, INF when there is no edge. */
    static int graph[][] = new int[][]{
        {0,2,5,1,INF,INF},
        {2,0,3,2,INF,INF},
        {5,3,0,3,1,5},
        {1,2,3,0,1,INF},
        {INF,INF,1,1,0,2},
        {INF,INF,5,INF,2,0}
    };

    /**
     * Returns the index of the unvisited node with the smallest value in {@code nodes}.
     * Ties go to the lowest index.
     *
     * @param nodes   tentative distance per node
     * @param visited which nodes are already settled
     * @return index of the minimum unvisited node, or -1 if every node is visited
     */
    static int getMinIdx(int nodes[], boolean visited[]) {
        int n = Math.min(nodes.length, visited.length);
        int min = -1;
        for (int i = 0; i < n; i++) {
            if (visited[i]) {
                continue;
            }
            if (min < 0 || nodes[min] > nodes[i]) {
                min = i;
            }
        }
        return min;
    }

    /**
     * Heap-based Dijkstra. Each node is expanded at most once, and each expansion scans
     * a full matrix row, so this runs in O(V^2 + E log V) on the adjacency matrix.
     *
     * @param arr   square adjacency matrix; entries {@code >= INF} mean "no edge"
     * @param start source node
     * @param dist  output array of length at least {@code arr.length}; receives the
     *              shortest distance from {@code start} (INF if unreachable)
     */
    static void dijkstra(int arr[][], int start, int dist[]) {
        int n = arr.length;
        for (int i = 0; i < n; i++) {
            dist[i] = INF;
        }
        dist[start] = 0;

        // Java's PriorityQueue is a min-heap, so distances are stored as-is (not negated):
        // the closest unsettled node is always polled first.
        PriorityQueue<Pair> pq = new PriorityQueue<>();
        pq.add(new Pair(0, start)); // {distance, node}

        while (!pq.isEmpty()) {
            Pair cur = pq.poll();
            int curDist = cur.x;
            int curNode = cur.y;
            // Stale entry: a shorter path to curNode was already found and expanded.
            if (curDist > dist[curNode]) {
                continue;
            }

            for (int next = 0; next < n; next++) {
                int w = arr[curNode][next];
                if (w >= INF) {
                    continue; // no edge; also avoids adding INF into the sum
                }
                int nxtDist = curDist + w;
                if (nxtDist < dist[next]) {
                    dist[next] = nxtDist;
                    pq.add(new Pair(nxtDist, next));
                }
            }
        }
    }

    /**
     * Linear-scan Dijkstra, O(V^2): each round settles the closest unvisited node and
     * relaxes its edges.
     *
     * @param arr   square adjacency matrix; entries {@code >= INF} mean "no edge"
     * @param start source node
     * @param dist  output array of length at least {@code arr.length}; receives the
     *              shortest distance from {@code start} (INF if unreachable)
     */
    static void dijkstra2(int arr[][], int start, int dist[]) {
        int n = arr.length;
        boolean visited[] = new boolean[n];

        // Seed with the direct edges from start; the source itself is at distance 0.
        for (int i = 0; i < n; i++) {
            dist[i] = arr[start][i];
        }
        dist[start] = 0;
        visited[start] = true;

        for (int round = 0; round < n - 1; round++) {
            int u = getMinIdx(dist, visited);
            if (u < 0) {
                break;
            }
            visited[u] = true;
            if (dist[u] >= INF) {
                break; // every remaining node is unreachable; nothing left to relax
            }
            for (int v = 0; v < n; v++) {
                if (visited[v] || arr[u][v] >= INF) {
                    continue;
                }
                int cand = dist[u] + arr[u][v];
                if (cand < dist[v]) {
                    dist[v] = cand;
                }
            }
        }
    }

    /** Prints the shortest distance from node 0 to every node of the sample graph. */
    public static void main(String args[]) throws Exception {
        int dist[] = new int[MAX_N];
        int start = 0;

        // dijkstra(graph, start, dist) computes the same result with a heap.
        dijkstra2(graph, start, dist);

        for (int i = 0; i < MAX_N; i++) {
            System.out.printf("%d->%d : %d\n", start, i, dist[i]);
        }
        System.out.println();
    }

}
반응형