읽는데 약 3분
[백준 11479] 서로 다른 부분 문자열의 개수 2
문제
https://www.acmicpc.net/problem/11479
알고리즘
Suffix Array
풀이
열심히 구한 Suffix Array와 LCP로 우리는 무엇을 할 수 있을까요? 그중 하나인 부분 문자열과 관련된 문제들입니다.
이 문제에선 서로다른 부분 문자열의 개수를 묻고 있습니다. 부분 문자열이란 것은 달리 말하면 Suffix의 Prefix입니다.
Banana를 예로 들어보겠습니다. 모든 Suffix를 나열해보면 {Banana,anana,nana,ana,na,a} 가 될 겁니다. 여기서 부분 문자열을 얻기 위해 각각의 Suffix들에 대해 Prefix를 보면 되겠죠. 모든 Suffix 중 Suffix “anana"에서 얻을 수 있는 부분 문자열은 {a,an,ana,anana,anana} 입니다. 이는 모든 Prefix와 동일합니다.
남은 문제는 중복되는 부분 문자열을 거르는 일이죠. 이때는 LCP를 검사하며 값을 빼주기만 하면 됩니다. 사전 순으로 정렬한 Suffix Array는 그 정의에 따라 인접할수록 문자열이 비슷하다는 것을 의미하며, 문자열이 비슷하다는 것은 중복되는 문자열이 가장 많은 상태입니다. LCP를 빼게 되면 우리는 가장 많은 중복 부분 문자열을 제거해나갈 수 있습니다.
이전 Suffix Array구현에서 $p$가 추가되었습니다. 이는 $t$가 문자열의 길이 $n$보다 짧다 하더라도 구분이 끝나게 되면 반복문을 탈출하기 위함입니다. 구분이 끝났다는 것은 $p==n$ 입니다.
코드
#include <bits/stdc++.h>
#define rep(i,n) for(int i=0;i<n;++i)
#define REP(i,n) for(int i=1;i<=n;++i)
#define FAST cin.tie(NULL);cout.tie(NULL); ios::sync_with_stdio(false)
using namespace std;
string input;
vector<int> g, ng, cnt, ordered, lcp, sfx;
long long ans;
void getsfx(string& s) {
int p = 1;
int n = s.size();
int mx = max(257, n + 1);
g.resize(n + 1), ng.resize(n + 1);
sfx.resize(n), ordered.resize(n);
for (int i = 0;i < n;++i) g[i] = s[i];
for (int t = 1;t < n;t <<= 1) {
cnt.clear(); cnt.resize(mx);
for (int i = 0;i < n;++i) ++cnt[g[min(i + t, n)]];
for (int i = 1;i < mx;++i) cnt[i] += cnt[i - 1];
for (int i = n - 1;i >= 0;--i) ordered[--cnt[g[min(i + t, n)]]] = i;
cnt.clear(); cnt.resize(mx);
for (int i = 0;i < n;++i) ++cnt[g[i]];
for (int i = 1;i < mx;++i) cnt[i] += cnt[i - 1];
for (int i = n - 1;i >= 0;--i) sfx[--cnt[g[ordered[i]]]] = ordered[i];
if (p == n) break;
p = 1;
ng[sfx[0]] = 1;
auto cmp = [&](int i, int j) {
if (g[i] == g[j]) return g[i + t] < g[j + t];
return g[i] < g[j];
};
for (int i = 1;i < n;++i) {
if (cmp(sfx[i - 1], sfx[i])) ++p, ng[sfx[i]] = ng[sfx[i - 1]] + 1;
else ng[sfx[i]] = ng[sfx[i - 1]];
}
g = ng;
}
lcp.resize(n);
for (int i = 0;i < n;++i) g[sfx[i]] = i;
for (int i = 0, k = 0;i < n;++i, k = max(k - 1, 0)) {
if (g[i] == n - 1) continue;
for (int j = sfx[g[i] + 1];s[i + k] == s[j + k];++k);
lcp[g[i]] = k;
ans -= k;
}
ans += ((long long)n * (n + 1) / 2);
}
int main() {
FAST;
cin >> input;
getsfx(input);
cout << ans;
return 0;
}
Read other posts