PS와 개발을 공부하자

트리에서 모든 노드쌍들 간의 거리의 총 합 본문

Algorithm/Problem Solving

트리에서 모든 노드쌍들 간의 거리의 총 합

sgc109 2017.03.20 09:41

naive 하게 모든 노드쌍간의 거리의 총합을 구한다고 하면 총 N개의 노드에 대해 DFS 혹은 BFS를 돌려서 O(N^2) 에 구하는 것이지만 N이 좀더 커지만 사용할 수가 없다.

이 때 발상을 조금 전환해서 모든 가능한 경로상에 각각의 간선이 총 몇번이나 포함되는지를 세어주면 이 것의 총합이 우리가 구하고자 하는 답이 될 것이다.

그렇다면 트리에서 루트 노드를 제외한 다른 모든 노드들은 하나의 부모가 있으므로 부모와 연결된 간선을 하나씩 가질 것이므로 루트를 제외한 하나의 노드는 각각 하나의 간선과 짝지어질 것이다. 그렇다면 루트 노드를 제외한 각각의 노드에 대해 자신의 부모와 연결된 간선을 지나는 모든 경로의 가지수를 계산하여 누적해 준다면 답이 될것인데 이것은 어떻게하냐면, size(i) 를 i번 노드를 루트 노드로 갖는 서브트리의 총 노드수 라고 정의 하겠다. 그렇다면 k번 노드와 매칭된 간선을 지나는 모든 경로의 수는 size(k) * (size(root) - size(k)) 이다. 왜냐하면 이 간선을 기준으로 두 부분으로 나뉘고 이것은 파티션처럼 각각을 노드를 원소로 갖는 집합으로 봤을 때 교집합은 공집합이고 합집합은 전체 트리집합이 된다는 말이다. 그렇기때문에 두 집합의 원소의 수를 서로 곱해주면 이 간선을 지나는 모든 경로의 수가 되는것이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int N,M,a,b;
vi G[200003];
int size[200003];
int distSum = 0;
int dfs(int here, int dad){
    int ret=0;
    for(int there : G[here]){
        if(there == dad) continue;
        int subCnt =  dfs(there, here);
        distSum += subCnt*(N-subCnt);
        ret += subCnt;
    }
    return size[here] = ret+1;
}
int main() {
    inp2(N,M);
    FOR(i,N-1){
        inp2(a,b);
        a--,b--;
        G[a].pb(b);
        G[b].pb(a);
    }
    dfs(0,-1);
    printf("%d",distSum);
    return 0;
}
cs


0 Comments
댓글쓰기 폼