Codeforces Round #404 E Anton and Permutation
問題概要
1, 2, ... , n の列が最初にあって l_i, r_i が Q 個与えられるので毎回 l_i 番目の値と r_i 番目の値をスワップしてから全体の反転数を求めよ
AC までの思考過程
Editorial 曰く, 想定解は平方分割らしい.
それなりに面白い.
まあそんなことはおいといて(?), これを計算量 poly log で解きたい.
欲しいクエリは以下
- ある位置の値の変更
- ある区間に含まれるある値未満の値の個数の取得
これは動的 Wavelet Matrix でできる. 両方 O( (log n)^2).
しかし, 投げてみるとTLE.
流石に平衡二分木の定数倍がヤバすぎて無理っぽい.
まあしょうがないね・・・.
なので次に, 二次元セグメントツリーを試す.
MLE.
まあ空間計算量 O(q (log n)^2) なのでこれも仕方ない.
一応, 他の人の提出を見たところ二次元セグメントツリーで通ってる人もいるので, たぶん気合が足りないだけなのだろうけど自分には無理だった.
次に, セグメントツリーに平衡二分木を載せてみたところ, ギリギリAC.
やはり wavelet matrix 成分よりセグメントツリー成分の log の方が定数倍が小さいらしい.
ソースコード
http://codeforces.com/contest/785/submission/39274304
#include <bits/stdc++.h> using namespace std; using ll = long long; template <typename T> class avl_tree { struct node { T val; node *ch[2]; int dep, size; node(T v, node* l = nullptr, node* r = nullptr) : val(v), dep(1), size(1) { ch[0] = l; ch[1] = r; } }; int depth(node *t) { return t == nullptr ? 0 : t->dep; } int count(node *t) { return t == nullptr ? 0 : t->size; } node *update(node *t) { t->dep = max(depth(t->ch[0]), depth(t->ch[1])) + 1; t->size = count(t->ch[0]) + count(t->ch[1]) + 1; return t; } node *rotate(node *t, int b) { node *s = t->ch[1 - b]; t->ch[1 - b] = s->ch[b]; s->ch[b] = t; t = update(t); s = update(s); return s; } node *fix(node *t) { if (t == nullptr) return t; if (depth(t->ch[0]) - depth(t->ch[1]) == 2) { if (depth(t->ch[0]->ch[1]) > depth(t->ch[0]->ch[0])) { t->ch[0] = rotate(t->ch[0], 0); } t = rotate(t, 1); } else if (depth(t->ch[0]) - depth(t->ch[1]) == -2) { if (depth(t->ch[1]->ch[0]) > depth(t->ch[1]->ch[1])) { t->ch[1] = rotate(t->ch[1], 1); } t = rotate(t, 0); } return t; } node *insert(node *t, int k, T v) { if (t == nullptr) return new node(v); int c = count(t->ch[0]), b = (k > c); t->ch[b] = insert(t->ch[b], k - (b ? (c + 1) : 0), v); update(t); return fix(t); } node *erase(node *t) { if (t == nullptr) return nullptr; if (t->ch[0] == nullptr && t->ch[1] == nullptr) { delete t; return nullptr; } if (t->ch[0] == nullptr || t->ch[1] == nullptr) { node *res = t->ch[t->ch[0] == nullptr]; delete t; return res; } node *res = new node(find(t->ch[1], 0)->val, t->ch[0], erase(t->ch[1], 0)); delete t; return fix(update(res)); } node *erase(node *t, int k) { if (t == nullptr) return nullptr; int c = count(t->ch[0]); if (k < c) { t->ch[0] = erase(t->ch[0], k); t = update(t); } else if (k > c) { t->ch[1] = erase(t->ch[1], k - (c + 1)); t = update(t); } else { t = erase(t); } return fix(t); } node *find(node *t, int k) { if (t == nullptr) return t; int c = count(t->ch[0]); return k < c ? find(t->ch[0], k) : k == c ? t : find(t->ch[1], k - (c + 1)); } int cnt(node *t, T v) { if (t == nullptr) return 0; if (t->val < v) return count(t->ch[0]) + 1 + cnt(t->ch[1], v); if (t->val == v) return count(t->ch[0]); return cnt(t->ch[0], v); } node *root; public: avl_tree() : root(nullptr) {} int size() { return count(root); } void insert(T val) { root = insert(root, cnt(root, val), val); } void erase(int k) { root = erase(root, k); } T find(int k) { return find(root, k)->val; } int count(T val) { return cnt(root, val); } }; const int MAX = 1 << 18; avl_tree<int> segs[MAX << 1]; void add(int x, int y, int val, int t = 1, int lb = 0, int ub = MAX) { if (x < lb || ub <= x) return; if (val > 0) segs[t].insert(y); else segs[t].erase(segs[t].count(y)); if (ub - lb == 1) return; int m = (lb + ub) >> 1; add(x, y, val, t << 1, lb, m); add(x, y, val, (t << 1) | 1, m, ub); } int find(int li, int lj, int ri, int rj, int t = 1, int lb = 0, int ub = MAX) { if (ri <= lb || ub <= li) return 0; if (li <= lb && ub <= ri) return segs[t].count(rj) - segs[t].count(lj); int m = (lb + ub) >> 1; return find(li, lj, ri, rj, t << 1, lb, m) + find(li, lj, ri, rj, (t << 1) | 1, m, ub); } int main() { ios::sync_with_stdio(false), cin.tie(0); int n, q; cin >> n >> q; vector<int> val(n); for (int i = 0; i < n; i++) { val[i] = i; add(i, i, 1); } ll res = 0; while (q--) { int l, r; cin >> l >> r; l--, r--; if (l == r) { printf("%lld\n", res); continue; } if (l > r) swap(l, r); int lv = val[l], rv = val[r]; res += lv < rv ? 1 : -1; int cnt1 = find(l + 1, 0, r, lv); res -= cnt1; res += (r - l - 1) - cnt1; int cnt2 = find(l + 1, 0, r, rv); res += cnt2; res -= (r - l - 1) - cnt2; val[l] = rv, val[r] = lv; add(l, lv, -1); add(l, val[l], 1); add(r, rv, -1); add(r, val[r], 1); printf("%lld\n", res); } return 0; }
感想
想定解も悪くはないんだけど, やっぱり poly log で解かないと気が済まないよね.