魔改 KMP 大概是不行的,卷积在这里出现的很妙。

卷积处理匹配

定义匹配函数

$$ d(x,y) = [x = y] $$

给定文本串 $B$ 和长为 $m$ 的模式串 $A$,要找出 $A$ 在 $B$ 的所有出现,可以定义

$$ f(k) = \sum_{i=0}^{m-1} d(A_{i}, B_{i-m+k}) $$

即 $f(k) = 0$ 时有完全匹配 $A = B[k-m\ldots k-1]$。考虑让 $d$ 函数更 “数学” 一点,以更好的计算

$$ d(x,y) = (x - y)^2 $$

再提供模式串 $A$ 的翻转 $S$,即 $A_i = S_{m-i}$,展开有

$$ \begin{aligned} f(k) &= \sum_{i=0}^{m-1}A_{i}^2 + \sum_{i=0}^{m-1}B_{i-m+k}^2 - 2\sum_{i=0}^{m-1} A_{i}B_{i-m+k}\\ &= \sum_{i=0}^{m-1}A_{i}^2 + \sum_{i=0}^{m-1}B_{i-m+k}^2 - 2\sum_{i=0}^{m-1} S_{m-i}B_{i-m+k} \end{aligned} $$

前面两项能够预处理,最后一项是卷积。于是最终复杂度是 $O(n \log n)$。

考虑通配符

仅令通配符的字符值为 $0$,再搓个匹配函数

$$ d(x,y) = xy(x-y)^2 $$

然后大力展开

$$ \begin{aligned} f(k) &= \sum_{i=0}^{m-1}A_{i}^3B_{i-m+k} + \sum_{i=0}^{m-1}A_{i}B_{i-m+k}^3 - 2\sum_{i=0}^{m-1} A_{i}^2B_{i-m+k}^2\\ &= \sum_{i=0}^{m-1}S_{m-i}^3B_{i-m+k} + \sum_{i=0}^{m-1}S_{m-i}B_{i-m+k}^3 - 2\sum_{i=0}^{m-1} S_{m-i}^2B_{i-m+k}^2 \end{aligned} $$

注意到三个都是卷积,于是最终复杂度是 $O(n \log n)$。

记得优化取模。我换 NTT 之后 TLE 了好几次,最后发现是 NTT 里取模写多了。。

#define ACM_MOD 998244353
const int mod = ACM_MOD;

#include "template/basic/qpow.hpp"
#include "template/basic/inv.hpp"
#include "template/basic/mint.hpp"

using poly_t = vector<Mint>;

poly_t w;

#include "template/poly-ntt/pre_w.hpp"
#include "template/poly-ntt/ntt.hpp"

const int maxn = 1 << 21;

char a[maxn], b[maxn];

int stk[maxn], cnt = 0;

int main() {
	int m = rr(), n = rr();
	int lim = get_lim(m, n);
	w = pre_w(lim);
	poly_t B1(lim), B2(lim), B3(lim), S1(lim), S2(lim), S3(lim);

	scanf("%s %s", a, b);

	for (int i = 1; i <= m; i++) {
		int j = m - i, t = a[j] - 'a' + 1;
		if (a[j] == '*')
			t = 0;
		S1[i] = t;
		S2[i] = S1[i] * S1[i];
		S3[i] = S2[i] * S1[i];
	}

	for (int i = 0; i < n; i++) {
		int t = b[i] - 'a' + 1;
		if (b[i] == '*')
			t = 0;
		B1[i] = t;
		B2[i] = B1[i] * B1[i];
		B3[i] = B2[i] * B1[i];
	}

	ntt(S1), ntt(S2), ntt(S3);
	ntt(B1), ntt(B2), ntt(B3);

	for (int i = 0; i < lim; i++)
		S1[i] = S1[i] * B3[i] + S3[i] * B1[i] - S2[i] * B2[i] * 2;

	intt(S1);
	for (int i = m; i <= n; i++) {
		if (S1[i].v == 0)
			stk[++cnt] = i - m + 1;
	}
	printf("%d\n", cnt);
	for (int i = 1; i <= cnt; i++) {
		printf("%d ", stk[i]);
	}

	return 0;
}