题意
有 n 把价值分别为 a_i 的斧子,河神可能拿走 1 – 3 把,问每种可能的损失价值及其对应方案数。(不计顺序)
思路
生成函数入门题。
考虑设出多项式 A(x),其系数有 A[a_i] = 1,代表选一把的。则答案为 A(x) + A^2(x) + A^3(x)。但是这样是显然不对的。为什么?
因为这样的话同一个元素可能被选两次或三次,对于这种情况定义 B(x) 和 C(x) 满足 B[2a_i] = 1 和 C[3a_i] = 1,代表同时选两次/三次的,减掉这些方案数就可以了。然后需要注意顺序问题:
选一把的答案为 A(x),不难发现选两种的即为 \dfrac{A^2(x) – B(x)}{2},选三种的比较麻烦:不能同时选两种一样的,即减去 3A(x)B(x),,但是选三种同样的又会被多减两次,最后除以 3! 去掉顺序问题,所以最终答案为:
A(x) + \frac{A^2(x) – B(x)}{2} + \frac{A^3(x) – 3A(x)B(x) + 2C(x)}{6}
生成函数的卷积使用 NTT 或 FFT 优化即可。注意此时 NTT 模数要取一个更大的质数。
#include <cstdio>
#include <cctype>
#define FOR(i, a, b) for (int i = a; i <= b; ++i)
int read()
{
char c = getchar();
int s = 0;
while (!isdigit(c))
c = getchar();
while (isdigit(c))
s = 10 * s + c - '0', c = getchar();
return s;
}
typedef long long ll;
const int maxn = 5e5 + 5;
const ll G = 3, mod = 2281701377;
template<typename T> inline void myswap(T &a, T &b)
{
T t = a;
a = b;
b = t;
return;
}
ll pow(ll base, ll p = mod - 2)
{
ll ret = 1;
for (; p; p >>= 1)
{
if (p & 1)
ret = ret * base % mod;
base = base * base % mod;
}
return ret;
}
int rev[maxn];
const ll invG = pow(G);
void NTT(ll *f, int lim, int type)
{
FOR(i, 0, lim - 1)
if (i < rev[i])
myswap(f[i], f[rev[i]]);
for (int p = 2; p <= lim; p <<= 1)
{
int len = p >> 1;
ll tG = pow(type ? G : invG, (mod - 1) / p);
for (int k = 0; k < lim; k += p)
{
ll buf = 1;
for (int l = k; l < k + len; ++l, buf = buf * tG % mod)
{
ll tmp = buf * f[len + l] % mod;
f[len + l] = f[l] - tmp;
if (f[len + l] < 0) f[len + l] += mod;
f[l] = f[l] + tmp;
if (f[l] > mod) f[l] -= mod;
}
}
}
ll invlim = pow(lim);
if (!type)
FOR(i, 0, lim - 1)
f[i] = f[i] * invlim % mod;
return;
}
ll f1[maxn], f2[maxn], f3[maxn], ans[maxn];
ll g[maxn], t[maxn];
int main()
{
int n = read();
while (n--)
{
int tmp = read();
++f1[tmp], ++g[tmp], ++ans[tmp];
++f2[tmp << 1], ++f3[tmp * 3];
}
int lim = 1;
while (lim <= (40000 * 3 + 5)) lim <<= 1;
FOR(i, 0, lim - 1)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
NTT(f1, lim, 1), NTT(g, lim, 1);
FOR(i, 0, lim - 1)
g[i] = f1[i] * g[i] % mod;
NTT(g, lim, 0);
FOR(i, 0, lim - 1)
ans[i] += (g[i] - f2[i]) / 2;
NTT(g, lim, 1);
FOR(i, 0, lim - 1)
g[i] = f1[i] * g[i] % mod;
NTT(g, lim, 0);
NTT(f2, lim, 1);
FOR(i, 0, lim - 1)
f2[i] = f2[i] * f1[i] % mod;
NTT(f2, lim, 0);
FOR(i, 0, lim - 1)
{
ans[i] += (g[i] - 3 * f2[i] + 2 * f3[i]) / 6;
if (ans[i]) printf("%d %lld\n", i, ans[i]);
}
return 0;
}
留言