PS와 개발을 공부하자

제5회 kriiicon 연습 세션 C번 다항식 계산 본문

Algorithm/Problem Solving

제5회 kriiicon 연습 세션 C번 다항식 계산

sgc109 2017. 4. 30. 09:06

https://oj.uz/problem/view/KRIII5P_2


\(0\le{N}\le{10^6},\,1\le{P}\le{10^3}\)

N차 다항식 \(f(x) = a_{N}x^{N}+\cdots+a_{1}x+a_{0}\) 과 소수가 주어진다.

이 때 \(f(0)\,mod\,P,\,f(1)\,mod\,P,\,\cdots,\,f(P-1)\,mod\,P\) 를 각각 구하는 것이다.


우선 가장 naive 하게 생각을 해보면

f(x) 를 구하는데에 걸리는 시간을 생각해 보면 빠른 N제곱을 하는데에 O(lgN) 이 걸리기 때문에

각 항을 계산하는데에는 O(lgN) 이걸리는데 항의 수가 N개 이기 때문에 O(NlgN) 이 걸린다는것을 알 수가 있다.

그런데 P가지의 x 에 대해 계산을 하기 때문에 O(NPlgN) 의 시간이 걸린다는 것을 알 수가 있다.

그렇기 때문에 이걸론 너무 느리다.

그럼 어떻게 할까? 사실 매번 N개에 대해 구할 필요가 없다. 왜냐하면 페르마의 소정리를 생각해 보면

소수 P에 대하여

\(X^{P}\equiv{X}\pmod P\) 이기 때문에 N개의 항들에 대해 따로 계산하는 것이아니라 P-1개로 합쳐서

한꺼번에 계산할 수가 있다는 것이다. 그러면 지수가 P-1만큼 차이나는 모든 항들의 계수를 더해주어

지수를 % (P-1) 을 해준다. 그러면 O(P^2lgP) 로 시간복잡도를 줄일 수가 있다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <bits/stdc++.h>
#define REP(i,a,b) for(int i=a;i<=b;++i)
#define FOR(i,n) for(int i=0;i<n;++i)
#define pb push_back
#define all(v) (v).begin(),(v).end()
#define sz(v) ((int)(v).size())
#define inp1(a) scanf("%d",&a)
#define inp2(a,b) scanf("%d%d",&a,&b)
#define inp3(a,b,c) scanf("%d%d%d",&a,&b,&c)
#define inp4(a,b,c,d) scanf("%d%d%d%d",&a,&b,&c,&d)
#define inp5(a,b,c,d,e) scanf("%d%d%d%d%d",&a,&b,&c,&d,&e)
#define fastio() ios_base::sync_with_stdio(false),cin.tie(NULL)
using namespace std;
typedef long long ll;
typedef pair<ll,ll> pll;
typedef vector<int> vi;    
typedef vector<ll> vl;
typedef pair<int,int> pii;
typedef vector<pii> vii;
typedef vector<pll> vll;
typedef vector<vector<int> > vvi;
typedef pair<int,pair<int,int> > piii;
typedef vector<piii> viii;
const double EPSILON = 1e-9;
const double PI = acos(-1);
const int MOD = 1e9+7;
const int INF = 0x3c3c3c3c;
const long long INFL = 0x3c3c3c3c3c3c3c3c;
const int MAX_N = 102;
 
int N,P;
ll A[1000003];
ll sum[1003];
ll go(int x, int n){
    if(!x) return 0;
    if(n==0return 1;
    if(n%2){
        ll tmp = go(x,(n-1)/2);
        return (x*tmp*tmp)%P;
    }
    ll tmp = go(x,n/2);
    return (tmp*tmp)%P;
}
 
int main() {
    inp2(N,P);
    FOR(i,N+1scanf("%lld",A+i);
    if(P==1){
        printf("%lld",A[N]);
        return 0;
    }
    FOR(i,N+1) sum[i%(P-1)] += A[N-i];
    sum[0-= A[N];
    FOR(i,P){
        ll ans = 0;
        FOR(j,P){
            ans = (ans + go(i, j) * sum[j]) % P;
        }
        printf("%lld\n",(ans+A[N])%P);
    }
 
    return 0;
}
 
cs


0 Comments
댓글쓰기 폼