kazuma8128’s blog

競プロの面白い問題を解きます

CodeChef SPC19F1 Lockied Away

問題概要

N が与えられるので、積がN以上となるような非負整数のマルチセットのうち, 和の最小値を求めよ.
Nは桁数が 2*10^6 以下の非負整数.

解法

4 = 2 * 2 , 5 < 2 * 3 , 6 < 3 * 3 という風に考えていくと, 4以上の値は使う必要がないことがわかる.
当然1もいらなくて (ただしN=1を除く) , 2*2*2 < 3*3 なので, 2も2個以下しか使う必要がない.
そうすると, {3,3,3, ... ,3,3} or {2,3,3, ... ,3,3} or {2,2,3, ... ,3,3} の場合のみ考えれば十分であることがわかる.

あとは long double なんかで適当にやると誤差で無理なので, 実際に整数で計算しないといけない.
値が非常に大きいので繰り返し二乗法でかつ, 10進数で持って積は FFT を使う.
桁数から大体近い3のべき乗を見積もって, 生成したあとは N を超えるまで3をかけていく.

あと, logN 回程度 FFT をするけど小さいサイズから始めて毎回 resize をしていくと計算量が O( Σ (2^i)log(2^i) ) = O(XlogX) (Xは10進数での桁数) になる.
あと NTT (高速剰余変換) を使うと誤差もなく高速でできるので, TLが異常に厳しいけどなんとか間に合った.

ソースコード

https://www.codechef.com/viewsolution/22165745

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

int mod_inv(ll a, ll m) {
    ll b = m, u = 1, v = 0;
    while (b > 0) {
        ll t = a / b;
        a -= t * b; swap(a, b);
        u -= t * v; swap(u, v);
    }
    return (u % m + m) % m;
}

ll mod_pow(ll x, ll y, ll md) {
    ll res = 1;
    while (y) {
        if (y & 1) res = res * x % md;
        x = x * x % md;
        y >>= 1;
    }
    return res;
}

template <int Mod, int PrimitiveRoot>
class fast_modulo_transform {
public:
    static vector<int> fft(vector<int> v, bool inv) {
        const int N = v.size();
        assert((N ^ (N & -N)) == 0);
        int ww = mod_pow(PrimitiveRoot, (Mod - 1) / N, Mod);
        if (inv) ww = mod_inv(ww, Mod);
        for (int m = N; m >= 2; m >>= 1) {
            const int mh = m >> 1;
            int w = 1;
            for (int i = 0; i < mh; ++i) {
                for (int j = i; j < N; j += m) {
                    const int k = j + mh;
                    int x = v[j] - v[k];
                    if (x < 0) x += Mod;
                    v[j] += -Mod + v[k];
                    if (v[j] < 0) v[j] += Mod;
                    v[k] = (1LL * w * x) % Mod;
                }
                w = (1LL * w * ww) % Mod;
            }
            ww = (1LL * ww * ww) % Mod;
        }
        int i = 0;
        for (int j = 1; j < N - 1; ++j) {
            for (int k = N >> 1; k > (i ^= k); k >>= 1);
            if (j < i) swap(v[i], v[j]);
        }
        if (inv) {
            const int inv_n = mod_inv(N, Mod);
            for (auto& x : v) {
                x = (1LL * x * inv_n) % Mod;
                assert(0 <= x && x < Mod);
            }
        }
        return move(v);
    }
    static vector<int> convolution(vector<int> f, vector<int> g) {
        int sz = 1;
        const int m = f.size() + g.size() - 1;
        while (sz < m) sz *= 2;
        f.resize(sz), g.resize(sz);
        f = fast_modulo_transform::fft(move(f), false); g = fast_modulo_transform::fft(move(g), false);
        for (int i = 0; i < sz; ++i) {
            f[i] = (1LL * f[i] * g[i]) % Mod;
        }
        return fast_modulo_transform::fft(move(f), true);
    }
};

const ll fmt_mod = 998244353;

using fmt = fast_modulo_transform<998244353, 3>;

vector<int> prod(const vector<int>& a, const vector<int>& b) {
    auto v = fmt::convolution(a, b);
    int x = v.size();
    int up = 0;
    for (int i = 0; i < x; i++) {
        up += v[i];
        v[i] = up % 10;
        up /= 10;
    }
    while (up > 0) {
        v.push_back(up % 10);
        up /= 10;
    }
    while (v.back() == 0) v.pop_back();
    return move(v);
}

vector<int> calc_pow(int n) {
    vector<int> res = { 1 }, tmp = { 3 };
    while (n) {
        if (n & 1) res = prod(res, tmp);
        n >>= 1;
        if (n) tmp = prod(tmp, tmp);
    }
    return move(res);
}

vector<int> get_v(string s) {
    vector<int> v;
    reverse(s.begin(), s.end());
    for (auto c : s) {
        v.push_back(c - '0');
    }
    return v;
}

void kakeru(vector<int>& v, int d) {
    int n = v.size();
    v.push_back(0);
    int up = 0;
    for (int i = 0; i <= n; i++) {
        up += v[i] * d;
        v[i] = up % 10;
        up /= 10;
    }
    if (v.back() == 0) v.pop_back();
}

bool comp(const vector<int>& a, const vector<int>& b) {
    if (a.size() != b.size()) return a.size() < b.size();
    int n = a.size();
    for (int i = n - 1; i >= 0; i--) if (a[i] != b[i]) {
        return a[i] < b[i];
    }
    return false;
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    string S;
    cin >> S;
    int keta = S.size();
    if (keta <= 6) {
        int N = stoi(S);
        if (N == 1) {
            cout << 1 << endl;
            return 0;
        }
        int res = N;
        for (int i = 2; i <= 4; i++) {
            int val = i, sum = i;
            while (val < N) val *= 3, sum += 3;
            res = min(res, sum);
        }
        cout << res << endl;
        return 0;
    }
    auto vec = get_v(S);
    int x = int(2.0959 * (keta - 1)) - 2;
    auto pw = calc_pow(x);
    int res = 1e9;
    for (int i = 2; i <= 4; i++) {
        auto v = pw;
        kakeru(v, i);
        int sum = x * 3 + i;
        while (comp(v, vec)) {
            kakeru(v, 3);
            sum += 3;
        }
        res = min(res, sum);
    }
    cout << res << endl;
    return 0;
}

感想

これコンテスト中に誰も解いてなかったけど面白かった.