kazuma8128’s blog

競プロの面白い問題を解きます

巨大modでの掛け算の高速化 (Codeforces Round #259 D Little Pony and Elements of Harmony)

解法

大体の方針はアダマール変換を使った XOR の畳み込みでできます.
あとは mod を p * 2^m (<= 1e15 くらい) にしておいて, 最後に 2^m で割れば逆変換ができるのでOK.

ただ, 今回は mod が 10^9 よりだいぶ大きいので long long でオーバーフローしてしまうため普通に掛け算できません.
なのでとりあえずダブリングでやってみます.
コードはこんな感じ.

ll mod_prod(ll a, ll b, ll md) {
    ll res = 0;
    while (b) {
        if (b & 1) res = (res + a) % md;
        a = (a + a) % md;
        b >>= 1;
    }
    return res;
}

ll は long long を typedef したものです.

しかしこれに繰り返し二乗法の log が付いてさらに 10^6 くらい回す必要があるので流石に TLE します.
なのでもっと速い方法は無いかと強い人たちのコードを読んでたらこんなのを見つけました.

ll mod_prod(ll a, ll b, ll md) {
    ll res = (a * b - (ll)((long double)a / md * b) * md) % md;
    return res < 0 ? res + md : res;
}

long double とか使ってて一瞬意味が分かりませんが, やってることは a * b から a * b を md の倍数に丸めたものを引いてるだけです.
両方ともオーバーフローする部分の値が一致するから上手くいくっぽいです.
これなら割り算があるとはいえ O(1) なので間に合うようになります.

ソースコード

http://codeforces.com/contest/453/submission/38935734

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

ll mod_prod(ll a, ll b, ll md) {
    ll res = (a * b - (ll)((long double)a / md * b) * md) % md;
    return res < 0 ? res + md : res;
}

ll mod_pow(ll v, ll x, ll md) {
    ll res = 1;
    while (x) {
        if (x & 1LL) res = mod_prod(res, v, md);
        v = mod_prod(v, v, md);
        x >>= 1;
    }
    return res;
}

ll mod;

template <typename T>
void fwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = x + y - (x + y >= mod ? mod : 0);
                f[j | i] = x - y + (x - y < 0 ? mod : 0);
            }
        }
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int m;
    ll t;
    cin >> m >> t >> mod; mod <<= (ll)m;
    int n = 1 << m;
    vector<ll> e(n);
    for (int i = 0; i < n; i++) {
        cin >> e[i]; e[i] %= mod;
    }
    vector<ll> b(m + 1);
    for (int i = 0; i <= m; i++) {
        cin >> b[i]; b[i] %= mod;
    }
    vector<ll> v(n);
    for (int i = 0; i < n; i++) {
        v[i] = b[__builtin_popcount(i)];
    }
    fwt(e);
    fwt(v);
    for (int i = 0; i < n; i++) {
        e[i] = mod_prod(e[i], mod_pow(v[i], t, mod), mod);
    }
    fwt(e);
    for (int i = 0; i < n; i++) {
        printf("%d\n", (int)(e[i] >> (ll)m));
    }
    return 0;
}