kazuma8128’s blog

競プロの面白い問題を解きます

SPOJ ZQUERY Zero Query

問題概要

N 個の 1 または -1 の列が与えられて, Q 個の区間に対してそれぞれ値の合計が0になる連続する最長の部分列の長さを求めよ

制約:1 ≤ N ≤ 5 * 10^4, 1 ≤ Q ≤ 5 * 10^4

解法

とりあえず累積和をとってみます.
すると, クエリの区間に含まれる同じ値同士のペアの中で最も離れたものの距離を求める問題になります.

Mo's Algorithm*1 で解きたい気持ちになります.

区間を伸ばすときは, 追加する値と同じ値で最も遠い位置との距離を, 最大値の候補に追加します.
その最も遠い位置は, 各値に対して現在の区間に含まれるすべての出現位置を deque で持っておけば簡単に求まります.
区間を縮めるときは伸ばすときに追加したものを逆順で削除してくだけです.

最大値の候補の管理に multiset とかを使うと区間の伸縮が O(logN) になり, 全体で O(N * sqrt(Q) * logN) とかになります.

かなり重いです. 定数倍高速化ゲーはクソ, みたいな気分になれます. (実装に依っては普通に通るかもしれませんが)


というわけでどうにかして log を落としましょう!

Rollback 平方分割というものを使います.
snuke.hatenablog.com

これを使うと, 区間を伸ばす操作だけで Mo's Algorithm と同等のことが出来ます.
ただし, 削除が不要になる代わりに, Rollback(行った操作を巻き戻す)操作ができる必要があります.
Rollbak操作は, 配列の場合は変更が行われるたびに変更したインデックスと前の値をスタックに積んでいけば簡単にできます.

削除操作が要らなくなったので, 候補は最大値の整数一個持っておけば十分になります.
区間に含まれる各値の出現位置は, 端っこだけ分かればいいので, 各値につき両端の値二つだけですみます.

これで multiset が不要になり, 区間を伸ばす操作が O(1)でできるようになったので, 全体で O( (Q+N) * sqrt(N) ) とかになるので, 余裕をもって間に合います.

ソースコード

#include <bits/stdc++.h>
using namespace std;

template<typename T>
class retroactive_array {
    vector<T> data;
    vector<pair<int, T>> hist;
public:
    retroactive_array(int N) : data(N) {}
    retroactive_array(int N, T val) : data(N, val) {}
    retroactive_array(const vector<T>& v) : data(v) {}
    retroactive_array(vector<T>&& v) : data(move(v)) {}
    void change(int k, T val) {
        hist.emplace_back(k, data[k]);
        data[k] = val;
    }
    T operator[](int k) const {
        return data[k];
    }
    int get_version() const {
        return hist.size();
    }
    void rollback(int ver) {
        assert(0 <= ver && ver <= (int)hist.size());
        int cnt = hist.size() - ver;
        while (cnt--) {
            data[hist.back().first] = hist.back().second;
            hist.pop_back();
        }
    }
};

const int B = 256;

struct query {
    int l, r, id;
    query(int l_, int r_, int i) : l(l_), r(r_), id(i) {}
    bool operator<(const query& q) const {
        return l / B == q.l / B ? r < q.r : l / B < q.l / B;
    }
};

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int N, M;
    cin >> N >> M;
    vector<int> a(N + 1); a[0] = N;
    for (int i = 0; i < N; i++) {
        cin >> a[i + 1];
        a[i + 1] += a[i];
    }
    vector<query> qs;
    vector<int> res(M);
    retroactive_array<int> fs(N * 2 + 1, -1), bs(N * 2 + 1, -1);
    int ma = 0;
    for (int i = 0; i < M; i++) {
        int l, r;
        cin >> l >> r; l--;
        if (r - l > B) {
            qs.emplace_back(l, r, i);
            continue;
        }
        fs.change(a[l], l);
        for (int j = l; j < r; j++) {
            if (fs[a[j + 1]] == -1) {
                fs.change(a[j + 1], j + 1);
            }
            else {
                ma = max(ma, j + 1 - fs[a[j + 1]]);
            }
        }
        res[i] = ma;
        fs.rollback(0);
        ma = 0;
    }
    sort(qs.begin(), qs.end());
    int tl = 0, tr = 0;
    for (auto& que : qs) {
        int l = que.l, r = que.r;
        if ((l / B + 1) * B != tl) {
            tl = tr = (l / B + 1) * B;
            fs.rollback(0);
            bs.rollback(0);
            ma = 0;
            fs.change(a[tl], tl);
            bs.change(a[tl], tl);
        }
        while (tr < r) {
            tr++;
            if (fs[a[tr]] == -1) {
                fs.change(a[tr], tr);
            }
            else {
                ma = max(ma, tr - fs[a[tr]]);
            }
            bs.change(a[tr], tr);
        }
        int fver = fs.get_version();
        int bver = bs.get_version();
        int pma = ma, ptl = tl;
        while (tl > l) {
            tl--;
            fs.change(a[tl], tl);
            if (bs[a[tl]] == -1) {
                bs.change(a[tl], tl);
            }
            else {
                ma = max(ma, bs[a[tl]] - tl);
            }
        }
        res[que.id] = ma;
        tl = ptl;
        fs.rollback(fver);
        bs.rollback(bver);
        ma = pma;
    }
    for (int i = 0; i < M; i++) {
        printf("%d\n", res[i]);
    }
    return 0;
}