题解 - CodeForces 1902E Collapsing Strings

1 minute read

Published:

题目传送门

分析

要求 \(\sum_i^n\sum_j^n C(i,j)\),直接计算至少需要 \(O(n^2)\) 的时间复杂度,所以考虑使用某种数据结构维护所有字符串的相关信息,使得遍历每个字符串都能计算它对答案的贡献。既然是 \(n\) 个字符串,那么不妨使用 Trie 来维护。

对于每个字符串,不妨考虑它对答案的贡献取决于它的前缀和其他字符串后缀,故考虑建立一个由所有字符串的逆序串组成的 Trie。

对于 Trie 上的每个点,记录该点包含的字符串个数 \(cnt\) 以及该点记录的所有字符串长度之和 \(sz\)。遍历每个字符串,用 Trie 尝试将所有逆序字符串与当前字符串的正序进行匹配,对于当前字符串的某一位,如果匹配到的字符串个数减少,也就说明存在某些字符串不再能与当前字符串首尾抵消,我们可以利用当前匹配到的位数、这些字符串的长度、当前字符串的长度计算当前字符串与它们产生的贡献。具体地,当前匹配的字符串长度为 \(len\),匹配到第 \(j\) 位时从 Trie 上的 \(u\) 节点转移到 \(v\) 节点,产生的贡献值为

\[(len - j) \times (cnt[u] - cnt[v]) + (sz[u] - sz[v]) - (cnt[u] - cnt[v]) \times j\]

其中 \(cnt[u] - cnt[v]\) 即为前文提到的“某些字符串”的个数,\(sz[u] - sz[v]\) 即为“某些字符串”的总长度。时间复杂度 \(O(\sum_i^n\lvert s_i\rvert)\)。

C++20 实现

#include <iostream>
#include <cstdio>
#include <vector>
#include <utility>
#include <algorithm>
#include <ranges>
#include <string>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;

ll read() {
    ll x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {
        if(ch == '-') f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9') {
        x = (x << 3) + (x << 1) + ch - '0';
        ch = getchar();
    }
    return x * f;
}

const int N = 1e6 + 10;
// trie[i][x]记录了节点i转移至字母x对应的节点编号
int n, trie[N][26], tot = 1, cnt[N], sz[N];
string s[N];

void sol() {
    n = read();
    for (int i = 1; i <= n; i++) cin >> s[i];
    for (int i = 1; i <= n; i++) {
        int pos = 1, len = s[i].size();
        sz[1] += len, cnt[1]++;
        // 逆序遍历
        for (auto j : s[i] | views::reverse) {
            if (!trie[pos][j - 'a']) trie[pos][j - 'a'] = ++tot;
            pos = trie[pos][j - 'a'];
            cnt[pos]++;
            sz[pos] += len;
        }
    }
    ll ans = 0;
    for (int i = 1; i <= n; i++) {
        int from, to = 1;
        for (int j = 0; j < s[i].size(); j++) {
            from = to, to = trie[from][s[i][j] - 'a'];
            ans += 1ll * (cnt[from] - cnt[to]) * (s[i].size() - j * 2) + (sz[from] - sz[to]);
        }
        // 特别考虑匹配完当前字符串所有字符之后的情况
        ans += sz[to] - 1ll * cnt[to] * s[i].size();
    }
    printf("%lld\n", ans);
}

int main() {
    sol();
    return 0;
}