블로그 옮겼습니다

BOJ 7812번 중앙 트리(약간 다른 더 좋은 풀이) 본문

Algorithm/Problem Solving

BOJ 7812번 중앙 트리(약간 다른 더 좋은 풀이)

sgc109 2017. 3. 21. 20:39

사실 원래의 풀이와 아주 큰 차이는 없다. 두번의 dfs를 돌리는 것까지는 똑같지만 변수가 하나 줄었고, 점화식이 좀 더 단순해 졌다.


우선 size(i) 는 i를 루트로 갖는 서브트리의 총 노드수라고 보면

노드 i 부터 자손들까지의 거리의 총합 S(i)는,

노드 i의 자식이 c1, c2, c3..... cn 이고, ck 와 연결된 간선의 가중치는 wk 일때

size(c1)*w1 + size(c2)*w2 +....+size(cn)*wn +

S(c1) + S(c2) + S(c3) +.....+S(cn) 이다.

즉 각각의 간선에 대해 몇번 지남당하는지를 계산한다고 보면된다. 이렇게 각각의 간선에 대해 몇 개의 경로에 대해 지남 당하느냐로

거리 의 총합을 구하는 경우가 많은 것같다. 트리상에 존재하는 모든 정점쌍간의 거리의 총합을 구할때도 사용되고, 이렇게 특정 노드를 루트로

갖는 서브트리에서 루트에서 모든 자손노드까지의 거리의 총합을 구할 때도 사용된다.

왜냐하면 정점쌍에서 출발점이 될 수 있는 노드의 개수와 도착점이 될 수 있는 노드의 개수를 알면 두 개를 곱하면 되기 때문이다.


또 다른 풀이의 핵심은 이 것이다.

중앙 정점이 v 에서 v의 자식인 u 로 옮겨 갈 때 총 거리의 합의 '변화량' 은 몇인가?

이 변화량만 잘 계산해주면 결국 한 정점에서의 총 거리의 합을 알면 그 다음 노드를 중앙 정점으로 했을 때 총 거리의 합도 쉽게 구할 수가 있다. 점화식은 이렇다.


dp[there] = dp[here] - size[there] * cost + (N-size[there]) * cost

 = dp[here] - (N-2*size[there]) * cost


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int T,N;
int a,b,c;
vii G[10003];
ll dp[10003];
ll size[10003];
void dfs1(int here, int dad){
    dp[here] = 0;
    size[here] = 1;
    for(auto p : G[here]){
        int there = p.first;
        int cost = p.second;
        if(there == dad) continue;
        dfs1(there, here);
        size[here] += size[there];
        dp[here] += dp[there] + size[there] * cost;
    }
}
void dfs2(int here, int dad){
    for(auto p : G[here]){
        int there = p.first;
        int cost = p.second;
        if(there == dad) continue;    
        dp[there] = dp[here] + (N-2*size[there]) * cost;
        dfs2(there, here);
    }
}    
int main() {
    while(1){
        FOR(i,10003) G[i].clear();
        memset(size,0,sizeof(size));
        memset(dp,0,sizeof(dp));
        inp1(N);
        if(!N) break;
        FOR(i,N-1){
            inp3(a,b,c);
            G[a].pb({b,c});
            G[b].pb({a,c});
        }
        dfs1(0,-1);
        dfs2(0,-1);
        ll ans = INFL;
        FOR(i,N) ans = min(ans, dp[i]);
        printf("%lld\n",ans);
    }
    return 0;
}
cs


Comments