PS와 개발을 공부하자

Topcoder SRM 519 Div1(Medium) RequiredSubstrings 본문

Algorithm/Problem Solving

Topcoder SRM 519 Div1(Medium) RequiredSubstrings

sgc109 2017.05.03 21:40

https://community.topcoder.com/stat?c=problem_statement&pm=11514


\(1\leq{N}\leq{50}\)

\(1\leq{C}\leq{N}\)

\(1\leq{L}\leq{50}\)

문자열의 개수 N개와 N개의 알파벳 소문자로만 이루어진 문자열이 주어지고,

C와 L이 주어지는데, 길이 L의 문자열을 만들고 싶은데, 이 만들어진 문자열 내에 N개의 문자열들 중

딱 C가지만 등장하도록 만드는 모든 문자열의 가짓수를 구하는 문제이다.

우선 N가지의 문자열들 중 딱 C가지만 등장하는지 여부를 판단하려면 N개의 문자열을 동시에 하나의 긴 문자열에서

등장하는지 찾아주는 알고리즘인 아호코라식 알고리즘이 필요하다는 것을 느낄 수가 있으며,

N가지 중 지금까지 몇가지가 등장했는지를 알아야 하고 현재 몇글자까지 만들었는지도 알아야 하며

현재까지 N가지의 문자열들에 대해 일치된 state를 알아야 한다. 이 세가지로 부분문제를 정의하여

dp로 구할 수 있을 것이다. 매 글자를 결정할 때 a~z 에 대해 변하는 인자값에 대해 서브호출하는 식으로 한다.

그리고 현재까지 N가지의 문자열들과 어떻게 매칭이되어있는지에 대한 상태는 트라이 노드로써 알 수가 있다.

그럼 지금까지 만든 문자열길이 len, 트라이노드 node, 지금까지 나온 문자열 정보 state 로 메모이제이션을 하는데

state 는 어차피 문자열이 6개밖에없으므로 비트마스크를 하면되고 len은 그냥 정수로 하는데 node는 뭘로할까?

조금만 생각해보면 최대 6개의 문자열들의 길이가 모두 50이고

이들의의 prefix들이 같은경우가 없다고 쳐도 300개의 노드밖에 생기지 않는다.

그렇기 때문에 트라이의 노드를 만들 때 id를 순서대로 할당하여, 함수 인자로 노드를 주면 그노드의 id로 하면된다.

그럼 결국 부분 문제의 수는 \(O(HL2^N)\) 인데, N은 주어지는 문자열의 수이고, L은 우리가 만들어야하는 문자열길이,

H는 주어진 모든 문자열들의 길이의 합이다. 

재귀함수를 통해 구현해보면 함수 내 로직은 간단하다 현재 만들어야하는 문자열의 위치에서 a~z에 대해

변화되는 노드로 재귀호출한다. 이 때 실패함수로 계산된 terminal 문자열들로 state를 갱신해준다.

그럼 K개의 문자들에 대해 (K=26) 실패함수는 최대 문자열의 길이만큼 도므로 최대 문자열의 길이를 M이라고하면

\(O(KM)\) 인 것이다. 그럼 결과적으로 O(KMHL2^N) 인데 최악의 경우 계산해보니 12억정도가 나온다. 

근데 이게 왜 실제로 돌렸을 때 모든 경우에 0.01초도 안나오는지 잘 모르겠다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <bits/stdc++.h>
using namespace std;
const int MOD = 1e9+9;
 
class RequiredSubstrings {
public:
    int solve(vector<string>intint);
};
 
int cnt;
struct Node {
    int id;
    Node *next[26];
    Node *fail;
    vector<int> end;
    Node() {
        id = cnt++;
        memset(next, 0sizeof(next));
        fail = 0;
    }
};
typedef Node *pNode;
 
pNode start;
long long dp[53][303][1<<6];
int C, L;
 
void push(pNode cur, string s, int pos, int id) {
    if (pos == s.size()) {
        cur->end.push_back(id);
        return;
    }
    auto& next = cur->next[s[pos] - 'a'];
    if (!next) {
        next = new Node;
    }
    push(next, s, pos + 1, id);
}
 
long long go(int pos, pNode cur, int include){
    for(auto e : cur->end) include |= (1<<e);
    if(pos == L) {
        int c = 0;
        for(;include;include>>=1) {
            if(include&1) c++;
        }

        return c == C;
    }
    long long& cache = dp[pos][cur->id][include];
    if(cache != -1return cache;
    cache = 0;
    for(int i = 0 ; i < 26; i++){
        pNode now = cur;
        if (now->next[i]) now = now->next[i];
        else{
            while (now != start) {
                if (now->fail->next[i]) {
                    now = now->fail->next[i];
                    break;
                }
                now = now->fail;
            }
        }
        cache = (cache + go(pos+1, now, include)) % MOD;
    }
    return cache;
}
 
void initFailFunc(){
    queue<pNode> q;
    q.push(start);
    start->fail = start;
    while (!q.empty()) {
        pNode par = q.front();
        q.pop();
        for (int i = 0; i < 26; i++) {
            pNode& child = par->next[i];
            if (!child) continue;
            if (par == start) child->fail = par;
            else {
                pNode cur = par;
                while (cur != start) {
                    if (cur->fail->next[i]) {
                        child->fail = cur->fail->next[i];
                        break;
                    }
                    cur = cur->fail;
                }
                if (child->fail == 0) child->fail = start;
            }
            child->end.insert(child->end.end(), child->fail->end.begin(), child->fail->end.end());
            q.push(child);
        }
    }
}
 
int RequiredSubstrings::solve(vector<string> words, int c, int l) {
    memset(dp, -1sizeof(dp));
    C = c, L = l;
    cnt = 0;
    start = new Node;
    for(int i = 0 ; i < words.size(); i++) push(start, words[i], 0, i);
 
    initFailFunc();
    return go(0, start, 0);
}
cs


0 Comments
댓글쓰기 폼