PS와 개발을 공부하자

OJ.UZ) 초음속철도 본문

Algorithm/Problem Solving

OJ.UZ) 초음속철도

sgc109 2017.02.27 10:21

https://oj.uz/problem/view/OJUZ11_rail


2 <= n <= 10억

1 <= m <= 20만


인 n과 m 이 주어지는데


n은 역의 개수이고 m은 철도 노선의 개수이다.


역의 번호는 1번 부터 n번까지 부여되고 하나의 철도 노선은 시작점과 도착점으로 이루어져 있다.


하나의 노선이 지나는 어떤 역에서든 하차할수있으며 승차할수있다.


m개의 철도의 노선들이 주어질때 몇개는 없애도 1번역부터 n번역까지 갈 수 있을것이다.


1번역부터 n번역까지 갈 수 있도록 노선을 선택하는 모든 경우의 수를 구하는 문제이다.

(가는 방법의 수가 아니다. 1번역에서 n번역까지 갈 수 있으면서 고른 노선이 하나라도 다르면 둘다 별개의 답인것이다.)


일단 n이 비정상적으로 크기때문에 좌표압축을 해준다.

그 다음에 최적 부분 구조를 위해 노선을 시작점의좌표를 기준으로 오름차순 정렬한다. 그리고 노선을 차례대로 보면서 1. 고르거나, 2. 안고르거나 둘중 하나를 하면서 전 노선까지 사용했을때의 경우의 수들을 가지고 몇을 곱하거나 서로 더해서 이번 노선까지 사용했을 때의 경우의 수를 구하는 것이다. 이것을 점화식으로 나타낼수가있다.


$ DP(i,j) = $ i번째 노선까지 사용하여 딱 j번 역까지 도달가능할때의 경우의 수


$ i)(if, E_{i} < j), DP(i,j) = 2 \times DP(i-1,j) $

$ ii) (if, E_{i} = j), DP(i,j) = DP(i-1,j) + \sum_{k=S_{i}}^{j}DP(i-1,k) $

$ iii) (if, E_{i} > j), DP(i,j) = DP(i-1,j) $


각각의 케이스에 대해 설명을 하자면,

1번 경우는 1~j 가 도달 가능한 부분 문제를 구하려고 하는건데 이미 지금 보고있는 경로를 포함하는 범위이기 때문에 선택을 하거나 하는경우와 안 하는경우 둘다 답에 추가 될 수가 있으므로 2를 곱한다.


2번 경우는 지금 보는 노선을 선택을 하지않을 수도있기 때문에 $DP(i,j)$를 더하고, $E_{i} = j $ 이기 때문에 지금 보는 노선도 선택을 할 수도 있기때문에 전 노선까지 사용했을때 도착지와 겹치는 부분이 있는 경우들에 대해서만 이번 노선을 사용하여 j까지 도달 할 수 있으므로 다 더한다.


3번 경우는 지금 보는 노선이 j를 넘어간다. 우리가 구하는 부분 문제의 정의를 보면 딱 j까지만 덮는 노선들의 조합의 경우의수를 구하는 것이기 때문에 이 부분 문제에서는 이 노선은 절대 선택을 할 수가 없는것이다. 그렇기 때문에 $ DP(i-1,j) $ 를 그대로 가져오는 것이다.


이렇게 하면 O(m^2) 이라는 시간 복잡도가 나올 것이다.

이것만으로는 불충분하다. 어떻게 더 시간을 줄일 수가 있을까?

점화식을 잘 보면 $j$와 $E_{j}$ 의 대소 관계에 따라 범위마다 점화식이 다른데

이번 노선까지 고려한 부분 문제들을 구할때에

구간의 합을 구하는 연산과 구간에 2를 곱하는 연산 그리고 전 부분문제에서 그대로 가져오기때문에 구간을 그대로 두는행위가 전부이다.


그렇기 때문에 segment tree lazy propagation를 사용하면 mlgm 에 계산이 가능하다.

여기서 중요한 것은 lazy 값만큼 2를 곱해야 하는데 그럼 $2^{lazy[node]} $ 인데 이 연산은 빠른 n제곱(분할정복) 으로 미리 modular 한 값을 구해놓아서 mlgm 에 할 수 있고 그때 그때 구한다면 mlg^2m 가 걸릴 것이다. 그냥 그때그때 구해도 시간안에 충분히 돌아갔지만 그래도 미리 구해놓고 레이지를 돌렸다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)z
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const int MOD = 1000000007;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
struct Range{
    int l,r;
    bool operator<(Range& rhs){
        return l<rhs.l;
    }
};
int N,M,a,b,E;
Range ranges[200003];
ll dp[1610003];
int lz[1610003];
ll poww[500000];
ll pow2(ll x, int n){
    if(!n) return 1;
    ll memo;
    if(n%2) {
        memo = pow2(x,(n-1)/2);
        return x*memo%MOD*memo%MOD;
    }
    memo = pow2(x,n/2);
    return memo*memo%MOD;
}
unordered_set<int> us;
ll query(int nl, int nr, int l, int r, int nd){
    if(lz[nd]) (dp[nd]*=poww[lz[nd]])%=MOD,(nl!=nr?lz[2*nd]+=lz[nd],lz[2*nd+1]+=lz[nd]:0),lz[nd]=0;
    if(l<=nl&&nr<=r) return dp[nd];
    if(r<nl||nr<l) return 0;
    return (query(nl,(nl+nr)/2,l,r,2*nd)+query((nl+nr)/2+1,nr,l,r,2*nd+1))%MOD;
} ll query(int l, int r){return query(0,E,l,r,1);}
 
void update(int nl, int nr, int l, int r, int nd){
    if(lz[nd]) (dp[nd]*=poww[lz[nd]])%=MOD,(nl!=nr?lz[2*nd]+=lz[nd],lz[2*nd+1]+=lz[nd]:0),lz[nd]=0;
    if(l<=nl&&nr<=r) {(dp[nd]*=2)%=MOD,(nl!=nr?lz[2*nd]++,lz[2*nd+1]++:0);return;}
    if(r<nl||nr<l) return;
    update(nl,(nl+nr)/2,l,r,2*nd),update((nl+nr)/2+1,nr,l,r,2*nd+1),dp[nd]=(dp[2*nd]+dp[2*nd+1])%MOD;
void update(int l, int r){update(0,E,l,r,1);}
 
void pUpdate(int nl, int nr, int nd, int pos, ll val){
    if(lz[nd]) (dp[nd]*=poww[lz[nd]])%=MOD,(nl!=nr?lz[2*nd]+=lz[nd],lz[2*nd+1]+=lz[nd]:0),lz[nd]=0;
    if(nl==nr&&nl==pos) {dp[nd]=val;return;}
    if(nr<pos||pos<nl) return;
    pUpdate(nl,(nl+nr)/2,2*nd,pos,val),pUpdate((nl+nr)/2+1,nr,2*nd+1,pos,val),dp[nd]=(dp[2*nd]+dp[2*nd+1])%MOD;
void pUpdate(int pos, ll val){pUpdate(0,E,1,pos,val);}
 
int main() {
    vi sorted;
    inp2(N,M);
    us.insert(1),sorted.pb(1);
    us.insert(N),sorted.pb(N);
    FOR(i,M){
        inp2(a,b);
        ranges[i]=Range{a,b};
        if(!us.count(a)) us.insert(a),sorted.pb(a);
        if(!us.count(b)) us.insert(b),sorted.pb(b);
    }
    sort(all(sorted));
    E=sz(sorted)-1;
    FOR(i,M) {
        ranges[i].l=lower_bound(all(sorted),ranges[i].l)-sorted.begin();
        ranges[i].r=lower_bound(all(sorted),ranges[i].r)-sorted.begin();
    }
    sort(ranges,ranges+M);
    // 좌표압축 + 정렬까지 완료
    FOR(i,500000) poww[i]=pow2(2,i);
    pUpdate(0,1);
    FOR(i,M){
        pUpdate(ranges[i].r,(query(ranges[i].l,ranges[i].r)+query(ranges[i].r,ranges[i].r))%MOD);
        update(ranges[i].r+1,E);
    }
    printf("%lld",query(E,E));
    return 0;
}
cs


0 Comments
댓글쓰기 폼