PS와 개발을 공부하자

Codeforces Round #405 (Div2.) D. Bear and Tree Jumps 본문

Algorithm/Problem Solving

Codeforces Round #405 (Div2.) D. Bear and Tree Jumps

sgc109 2017.03.21 09:30

우선 이 문제를 풀기 위해서 한가지 선행 지식이 필요하다. N개의 노드를 가진 트리에서 모든 정점 쌍의 거리의 합을 O(N)에 구할 수가 있는데 size(i) = i번 노드를 루트로 갖는 서브트리의 노드의 수. 로 정의하여 재귀적으로 모든 노드에 대해 구하면서 각 간선을 지나는 경로의 수를 dfs로 구하는것이다. 이 전, 전 포스트에 설명이 되어있다.


하지만 이 문제는 점프 거리 k 라는 입력이 주어져서 한번에 k개의 간선을 뛰어넘을 수가 있고,  단순히 모든 정점쌍 사이의 거리의 합이 아니라 모든 정점 쌍 사이의 경로를 가기위한 점프 횟수의 합을 요구한다. 그렇기 때문에 최대 점프거리 k가 3일때, 거리가 3인 경로는 3이 아니라 1이 더해져야하고 4인 경로는 4가 아니라 2가 더해져야하는 식이다.


여기서 파악해야 할 것은 거리가 k의 배수인 경로에 대해서는 단순히 3으로 나눈 몫이 더해진다는것과 거리가 k의 배수가 아닌 경로에 대해서는 원래의 거리 d 보다 큰 가장 작은 k의 배수로 만들어서 k로 나눈 몫을 더해야 한다는 것이다. 그렇다면 문제에서 요구하는 답은 결국 (모든 정점쌍 사이의 거리의 합 + 경로의 길이가 k의 배수가 아닌 경로에 대해 증가시켜주어야 할 거리들의 합) / k 라는 것을 알 수가 있다. 그렇다면 두가지 부분으로 나누어 구한뒤 마지막에 k로 나누면 답이 나올 것이다.


우선 첫 번째 부분인 '모든 정점 쌍 사이의 거리의 합' 은 이 글의 가장 처음에 언급했던 선행 지식을 사용하여 간단하게 O(N) 에 구할 수가 있다. 그렇다면 이제 남은건 두번째 부분이다.


이것은 트리 dp를 이용해 구할 수가 있다. 우선 부분 문제를 정의하겠다.


dp(i,m) : i번 노드를 루트로 가지는 서브트리의 노드들에서 i번 노드까지 오는 노드(경로)들중 k로 나눈 나머지가 m 인 노드(경로)들의 수.


점화식을 어떻게 구할까. 처음에 나는 어떤 노드 n 이 있고 이 노드의 자식노드들 c1, c2, c3, c4 ... cn  이 있을때 nC2 가지의 노드쌍에 대해 계산을 해줘야 하는 줄 알았다. 왜냐하면 두 자식까지 오는 경로에 두개의 간선을 추가하면 노드 n을 사이에 두고 가로지르는 경로가 생기기 때문이다. 하지만 트리 dp 에서 이러한 높은 복잡도와 복잡한 구현을 없애주는 좋은 테크닉이 있다. 바로 왼쪽 부터 차례대로 자식들을 봐주는 것이다. 지금까지 어떤 자식까지를 봐줬을 때에 대한 부분문제 값이 dp 배열에 저장 되어있을 때 새로운 자식하나를 더 봐줌으로써 이 새로운 자식까지 봐준 부분문제를 계산하고 이것을 마지막 자식까지 반복해주는 것이다. 어떤 경우에는 어떤 자식까지 봐주었는지로 부분문제를 나눠 한차원이 증가하게되는데 이 문제에서는 굳이 나눌 필요가 없다.

지금까지 본 노드들과 새로 보는 자식의 서브트리 노드들로 계산을 한뒤 두개를 병합하고 나서 다음 자식에 대해 똑같은 행위를 반복한다고 생각하면 좀더 이해가 쉬울 것이다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int N,M,a,b;
vi G[200003];
ll size[200003];
ll dp[200003][5];
ll ans;
void dfs(int here, int dad){
    size[here] = 0;
    dp[here][0]++;
    for(int there : G[here]){
        if(there == dad) continue;
        dfs(there, here);
        FOR(i,M){
            FOR(j,M){
                int k = i+j+1;
                ans += (M-((k%M)?(k%M):M)) * dp[here][i] * dp[there][j];
            }
        }
        FOR(i,M) dp[here][(i+1)%M] += dp[there][i];
        ans += size[there]*(N-size[there]);
        size[here] += size[there];
    }
    size[here]++;
}
int main() {
    inp2(N,M);
    FOR(i,N-1){
        inp2(a,b);
        a--,b--;
        G[a].pb(b);
        G[b].pb(a);
    }
    dfs(0,-1);
    printf("%lld",ans/M);
    return 0;
}
cs


0 Comments
댓글쓰기 폼