블로그 옮겼습니다

BOJ 14657번 준오는 최종 인재야!! 본문

Algorithm/Problem Solving

BOJ 14657번 준오는 최종 인재야!!

sgc109 2017. 7. 23. 15:55
http://boj.kr/14657


문제에서 주어진 세개의 조건을 통해 N개의 문제들은 트리를 이루고 있다는 것을 알 수가 있다.

그럼 가장 많은 문제를 풀기 위해서는 트리에서의 가장 긴 단순 경로를 찾으면 되며,

트리에서의 가장 긴 단순 경로는 바로 트리의 지름이 되기 때문에

트리의 모든 지름 중에 경로상에 포함된 에지들의 가중치합이 최소가 되는 지름을 찾는 문제로 볼 수가 있게된다.

트리의 모든 지름을 체크하기 위해 우선 트리의 중심을 찾고 그 중심을 기준으로 답을 구해보자.

하지만 트리의 중심은 1개 혹은 2개이기 때문에 두개의 케이스를 나누어보자.

1. 트리의 중심이 1개인 경우

트리의 중심이 1개인 경우에는 단순하다 트리의 중심으로 부터 거리가 지름/2 인 노드들까지의 경로들 중

그 경로상에 존재하는 에지들의 가중치의 합이 최소가 되는 경로 두개를 찾아 더하면 된다.

구현할 때 두 경로가 겹치지 않도록 조심해야한다.


2. 트리의 중심이 2개인 경우

트리의 중심이 1개인 경우와 다른점은 지름의 길이가 홀수라는 것인데 그렇기 때문에

두개중 하나의 중심으로 부터의 거리가 지름/2 와 (지름+1)/2 인 경로 두개를 합쳐야 지름이 된다는 것이다.

그렇기 때문에 중심 하나를 잡고 그 중심으로 부터의 거리가 지름/2 와 (지름+1)/2 인 경로의

경로상의 에지들의 가중치 합이 최소가 되는 값을 두 중심에 대해 각각 체크해주면 된다.


그렇게 하면 가장 많은 문제를 푸는 경우에 걸리는 시간 K를 구할 수가 있고

이것을 통해 며칠이 걸리는지는 (K + T - 1) / T 로 쉽게 계산할 수 있다.


(N <= 2 일땐 트리의 중심에서 시작하는 단순 경로의 개수가 1개 이하이기 때문에 따로 예외 처리를 한다.)


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
#define fastio() ios_base::sync_with_stdio(0),cin.tie(0)
using namespace std;
typedef long long ll;
const int mod = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
 
int N, T, a, b, c;
int dist[50003], par[50003];
vector<pair<int,int> > G[50003];
int findFarest(int start){
    memset(dist,-1,sizeof(dist));
    queue<int> q;
    q.push(start);
    dist[start] = 0;
    int farest = 0;
    int farD = 0;
    while(!q.empty()){
        int cur = q.front();
        q.pop();
        for(auto p : G[cur]){
            int next = p.first;
            if(dist[next] != -1continue;
            dist[next] = dist[cur] + 1;
            par[next] = cur;
            q.push(next);
            if(farD < dist[next]) farest = next, farD = dist[next];
        }
    }
    return farest;
}
 
int dfs(int cur, int dad, int d, int goal){
    int ret = INF;
    for(auto p : G[cur]){
        int next = p.first;
        int cost = p.second;
        if(next == dad) continue;
        int r = dfs(next, cur, d + 1, goal);
        ret = min(ret, r + cost);
    }
    if(ret == INF && d == goal) ret = 0;
    return ret;
}
int main() {
    fastio();
    cin >> N >> T;
    int sum = 0;
    for(int i = 0; i < N - 1; i++){
        cin >> a >> b >> c;
        sum += c;
        a--, b--;
        G[a].push_back({b, c});
        G[b].push_back({a, c});
    }
    if(N <= 2return !printf("%d", (sum + T - 1/ T);
    int f1 = findFarest(0);
    int f2 = findFarest(f1);
    vector<int> centers;
    int rad = INF;
    for(int cur = f2; cur != f1; cur = par[cur]){
        int r = max(dist[f2] - dist[cur], dist[cur]);
        rad = min(rad, r);
    }
    for(int cur = f2; cur != f1; cur = par[cur]){
        int r = max(dist[f2] - dist[cur], dist[cur]);
        if(r == rad) centers.push_back(cur);
    }
 
    int ans = 0;
    if((int)centers.size() == 2){
        int r1 = dfs(centers[0], centers[1], 0, dist[f2] / 2);
        int r2 = dfs(centers[1], centers[0], 0, dist[f2] / 2);
        ans = r1 + r2;
        for(auto p : G[centers[0]]){
            int next = p.first;
            int cost = p.second;
            if(next == centers[1]) {
                ans += cost;
                break;
            }
        }
    }
    else {
        vector<int> v;
        for(auto p : G[centers[0]]){
            int next = p.first;
            int cost = p.second;
            v.push_back(cost + dfs(next, centers[0], 0, dist[f2] / 2 - 1));
        }
        sort(v.begin(), v.end());
        ans = v[0+ v[1];
    }
 
    cout << (ans + T - 1/ T;
 
    return 0;
}
cs


사실 좀더 간단하게 푸는 방법도 있다.

굳이 지름을 찾지 않고 dfs 만 잘 구현하면 한방에 끝낼 수도 있다.

그런데 귀찮아서 나중에 짜야겠다..ㅎㅎ

Comments