kazuma8128’s blog

競プロの面白い問題を解きます

Codeforces Round #404 E Anton and Permutation

問題概要

1, 2, ... , n の列が最初にあって l_i, r_i が Q 個与えられるので毎回 l_i 番目の値と r_i 番目の値をスワップしてから全体の反転数を求めよ

AC までの思考過程

Editorial 曰く, 想定解は平方分割らしい.
それなりに面白い.

まあそんなことはおいといて(?), これを計算量 poly log で解きたい.
欲しいクエリは以下

  • ある位置の値の変更
  • ある区間に含まれるある値未満の値の個数の取得

これは動的 Wavelet Matrix でできる. 両方 O( (log n)^2).

しかし, 投げてみるとTLE.
流石に平衡二分木の定数倍がヤバすぎて無理っぽい.
まあしょうがないね・・・.

なので次に, 二次元セグメントツリーを試す.
MLE.
まあ空間計算量 O(q (log n)^2) なのでこれも仕方ない.

一応, 他の人の提出を見たところ二次元セグメントツリーで通ってる人もいるので, たぶん気合が足りないだけなのだろうけど自分には無理だった.

次に, セグメントツリーに平衡二分木を載せてみたところ, ギリギリAC.
やはり wavelet matrix 成分よりセグメントツリー成分の log の方が定数倍が小さいらしい.

ソースコード

http://codeforces.com/contest/785/submission/39274304

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

template <typename T>
class avl_tree {
    struct node {
        T val;
        node *ch[2];
        int dep, size;
        node(T v, node* l = nullptr, node* r = nullptr) : val(v), dep(1), size(1) {
            ch[0] = l; ch[1] = r;
        }
    };
    int depth(node *t) { return t == nullptr ? 0 : t->dep; }
    int count(node *t) { return t == nullptr ? 0 : t->size; }
    node *update(node *t) {
        t->dep = max(depth(t->ch[0]), depth(t->ch[1])) + 1;
        t->size = count(t->ch[0]) + count(t->ch[1]) + 1;
        return t;
    }
    node *rotate(node *t, int b) {
        node *s = t->ch[1 - b];
        t->ch[1 - b] = s->ch[b];
        s->ch[b] = t;
        t = update(t);
        s = update(s);
        return s;
    }
    node *fix(node *t) {
        if (t == nullptr) return t;
        if (depth(t->ch[0]) - depth(t->ch[1]) == 2) {
            if (depth(t->ch[0]->ch[1]) > depth(t->ch[0]->ch[0])) {
                t->ch[0] = rotate(t->ch[0], 0);
            }
            t = rotate(t, 1);
        }
        else if (depth(t->ch[0]) - depth(t->ch[1]) == -2) {
            if (depth(t->ch[1]->ch[0]) > depth(t->ch[1]->ch[1])) {
                t->ch[1] = rotate(t->ch[1], 1);
            }
            t = rotate(t, 0);
        }
        return t;
    }
    node *insert(node *t, int k, T v) {
        if (t == nullptr) return new node(v);
        int c = count(t->ch[0]), b = (k > c);
        t->ch[b] = insert(t->ch[b], k - (b ? (c + 1) : 0), v);
        update(t);
        return fix(t);
    }
    node *erase(node *t) {
        if (t == nullptr) return nullptr;
        if (t->ch[0] == nullptr && t->ch[1] == nullptr) {
            delete t;
            return nullptr;
        }
        if (t->ch[0] == nullptr || t->ch[1] == nullptr) {
            node *res = t->ch[t->ch[0] == nullptr];
            delete t;
            return res;
        }
        node *res = new node(find(t->ch[1], 0)->val, t->ch[0], erase(t->ch[1], 0));
        delete t;
        return fix(update(res));
    }
    node *erase(node *t, int k) {
        if (t == nullptr) return nullptr;
        int c = count(t->ch[0]);
        if (k < c) {
            t->ch[0] = erase(t->ch[0], k);
            t = update(t);
        }
        else if (k > c) {
            t->ch[1] = erase(t->ch[1], k - (c + 1));
            t = update(t);
        }
        else {
            t = erase(t);
        }
        return fix(t);
    }
    node *find(node *t, int k) {
        if (t == nullptr) return t;
        int c = count(t->ch[0]);
        return k < c ? find(t->ch[0], k) : k == c ? t : find(t->ch[1], k - (c + 1));
    }
    int cnt(node *t, T v) {
        if (t == nullptr) return 0;
        if (t->val < v) return count(t->ch[0]) + 1 + cnt(t->ch[1], v);
        if (t->val == v) return count(t->ch[0]);
        return cnt(t->ch[0], v);
    }
    node *root;
public:
    avl_tree() : root(nullptr) {}
    int size() {
        return count(root);
    }
    void insert(T val) {
        root = insert(root, cnt(root, val), val);
    }
    void erase(int k) {
        root = erase(root, k);
    }
    T find(int k) {
        return find(root, k)->val;
    }
    int count(T val) {
        return cnt(root, val);
    }
};

const int MAX = 1 << 18;

avl_tree<int> segs[MAX << 1];

void add(int x, int y, int val, int t = 1, int lb = 0, int ub = MAX) {
    if (x < lb || ub <= x) return;
    if (val > 0) segs[t].insert(y);
    else segs[t].erase(segs[t].count(y));
    if (ub - lb == 1) return;
    int m = (lb + ub) >> 1;
    add(x, y, val, t << 1, lb, m);
    add(x, y, val, (t << 1) | 1, m, ub);
}

int find(int li, int lj, int ri, int rj, int t = 1, int lb = 0, int ub = MAX) {
    if (ri <= lb || ub <= li) return 0;
    if (li <= lb && ub <= ri) return segs[t].count(rj) - segs[t].count(lj);
    int m = (lb + ub) >> 1;
    return find(li, lj, ri, rj, t << 1, lb, m) + find(li, lj, ri, rj, (t << 1) | 1, m, ub);
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int n, q;
    cin >> n >> q;
    vector<int> val(n);
    for (int i = 0; i < n; i++) {
        val[i] = i;
        add(i, i, 1);
    }
    ll res = 0;
    while (q--) {
        int l, r;
        cin >> l >> r; l--, r--;
        if (l == r) {
            printf("%lld\n", res);
            continue;
        }
        if (l > r) swap(l, r);
        int lv = val[l], rv = val[r];

        res += lv < rv ? 1 : -1;

        int cnt1 = find(l + 1, 0, r, lv);
        res -= cnt1;
        res += (r - l - 1) - cnt1;

        int cnt2 = find(l + 1, 0, r, rv);
        res += cnt2;
        res -= (r - l - 1) - cnt2;

        val[l] = rv, val[r] = lv;

        add(l, lv, -1);
        add(l, val[l], 1);
        add(r, rv, -1);
        add(r, val[r], 1);

        printf("%lld\n", res);
    }
    return 0;
}

感想

想定解も悪くはないんだけど, やっぱり poly log で解かないと気が済まないよね.