블로그 옮겼습니다

제 5회 kriiicon UV(Unifying Values) 본문

Algorithm/Problem Solving

제 5회 kriiicon UV(Unifying Values)

sgc109 2017. 5. 1. 00:16

https://oj.uz/problem/view/KRIII5_UV


\(1\leq{N}\leq{10^4}\)

\(-10^{14}\leq{A[i]}\leq{10^{14}}\)

N이 주어지고 N개의 정수로 이루어진 배열 A가 주어진다.

이 배열을 여러개로 나누는데 나눈 각각의 덩어리에 속한 원소의 합이 같도록 나눠야한다.

이렇게 나누는 모든 방법의 가짓수를 구하는 문제이다. 물론 1e7 로 나눈 나머지를 구한다.


이 문제는 dp 문제이다. 각 덩어리의 원소의 합을 몇으로 같게 할지를 먼저 정해주고

K로 정했다면 합이 K 가 되도록 나누는 방법의 수를 dp로 구한다.

dp[i] : i번째 원소부터 합이 원소들의 합이 K인 덩어리들로 나누는 방법의 수

그리고 모든 가능한 K에 대해 메모이제이션한 것을 초기화해주고 다시 부분문제를 구한다.

그러면 가능한 모든 K는 무엇일까? 주어진 모든 N개의 원소의 합의 약수여야한다.

N개의 원소의 합이 K로 나누어 떨어져야 하기 때문이다. 그럼 이 약수의 수가 적다는 것을 어떻게 알 수 있을까

나도 이걸 모르겠다. 나중에 이유를 알아보자. 그럼 이게 꽤 적다는 것을 알았다면

이제 남은건 dp로 O(N^2)로 각각의 정해진 합에 대해 답을 구하는 것이다. 그럼 최소 O(N^2) 일 텐데

왜 N이 최대 10^4인데도 0.5초 안에돌까? 나도 이게 의문이다. 실제로 계산하는 양은 sparse 해서 그렇다던데

나중에 에디토리얼이 올라오면 정확하게 알아보자.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
#define fastio() ios_base::sync_with_stdio(false),cin.tie(NULL)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int N;
ll dp[10003];
ll pSum[10003];
ll A[10003];
vl valid;
 
ll go(int pos, ll sum){
    if(pos > N) return 1;
    ll& cache = dp[pos];
    if(cache != -1return cache;
    cache = 0;
    for(int i = pos; i <= N; i++){
        if(pos==1 && i==N) continue;
        if(pSum[i]-pSum[pos-1!= sum) continue;
        cache = (cache + go(i+1, sum)) % MOD;
    }
 
    return cache;
}
 
int main() {
    inp1(N);
    FOR(i,N){
        scanf("%lld",A+i);
        pSum[i+1= pSum[i] + A[i];
        valid.pb(pSum[i+1]);
    }
    
    sort(all(valid));
    valid.erase(unique(all(valid)),valid.end());
 
    ll ans = 0 ;
    FOR(i,sz(valid)){
        if(valid[i] && pSum[N] % valid[i]) continue;
        memset(dp,-1,sizeof(dp));
        ans = (ans + go(1, valid[i])) % MOD;
    }
 
    printf("%lld",ans);
    return 0;
}
 
cs


Comments