kazuma8128’s blog

競プロの面白い問題を解きます

HackerRank Heavy Light 2 White Falcon

問題概要

頂点に値が振られた木に対する以下のクエリを処理せよ.
クエリ1:u-v パス上の各頂点の値に x, 2x, 3x, ... の等差数列を足す
クエリ2:u-v パス上の各頂点の値の和 mod 1e9+7 を答える

解説

問題タイトルからしてあからさまに HL分解をしてほしそうです.
が, なんとなく Link-Cut Tree で解きました.

まあどちらにせよ, 区間に等差数列を足す操作と区間和を求める操作が必要になります.
これは実は遅延評価セグメントツリーでできます.

セグメントツリーの各ノードに区間和と加算したい数列の初項と公差を持たせてうまく子ノードに伝搬させてやればよいです.

あとは Link-Cut Tree でやりたければスプレー木で似たようにやります.

でも evert 操作を付けたいので反転とかを考えて左の子ノードを間違えないようにしないといけないので, 結構バグりやすいと思います.(僕はめちゃくちゃバグりました)
HL分解でも左右の向きを考えてやらないといけないのでどっちみちめんどくさそうです.
この辺でどちらを選ぶかは好みに依りそう.

ソースコード

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int mod = 1e9 + 7;

template<int MOD>
struct mod_int {
    static const int Mod = MOD;
    unsigned x;
    mod_int() : x(0) { }
    mod_int(int sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; }
    mod_int(long long sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; }
    int get() const { return (int)x; }

    mod_int &operator+=(mod_int that) { if ((x += that.x) >= MOD) x -= MOD; return *this; }
    mod_int &operator-=(mod_int that) { if ((x += MOD - that.x) >= MOD) x -= MOD; return *this; }
    mod_int &operator*=(mod_int that) { x = (unsigned long long)x * that.x % MOD; return *this; }
    mod_int &operator/=(mod_int that) { return *this *= that.inverse(); }

    mod_int operator+(mod_int that) const { return mod_int(*this) += that; }
    mod_int operator-(mod_int that) const { return mod_int(*this) -= that; }
    mod_int operator*(mod_int that) const { return mod_int(*this) *= that; }
    mod_int operator/(mod_int that) const { return mod_int(*this) /= that; }

    mod_int inverse() const {
        long long a = x, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        return mod_int(u);
    }
};

using mint = mod_int<mod>;

struct RS {
    using type = mint;
    static type id() { return 0; }
    static type op(const type& l, const type & r) {
        return l + r;
    }
};

class lct_node {
    using M = RS;
    using T = typename M::type;
    using U = pair<mint, mint>;

    lct_node *l, *r, *p;
    bool rev;
    T val, all;
    int size;
    bool flag;
    U lazy;

    int pos() {
        if (p && p->l == this) return 1;
        if (p && p->r == this) return 3;
        return 0;
    }
    void update() {
        size = (l ? l->size : 0) + (r ? r->size : 0) + 1;
        all = M::op(l ? l->all : M::id(), M::op(val, r ? r->all : M::id()));
    }
    void update_lazy(const U& v) {
        if (!flag) lazy = make_pair(0, 0);
        int ls = !rev ? (l ? l->size : 0) : (r ? r->size : 0);
        val += v.first + v.second * ls;
        all += v.first * size + ((v.second * (size - 1)) * size) / 2;
        lazy = make_pair(M::op(lazy.first, v.first), M::op(lazy.second, v.second));
        flag = true;
    }
    void rev_data() {
        lazy = make_pair(lazy.first + lazy.second * (size - 1), mint(0) - lazy.second);
    }
    void push() {
        if (pos()) p->push();
        if (rev) {
            swap(l, r);
            if (l) l->rev ^= true, l->rev_data();
            if (r) r->rev ^= true, r->rev_data();
            rev = false;
        }
        if (flag) {
            if (l) l->update_lazy(lazy);
            if (r) r->update_lazy(make_pair(lazy.first + lazy.second * (l ? l->size + 1 : 1), lazy.second));
            flag = false;
        }
    }
    void rot() {
        lct_node *par = p;
        lct_node *mid;
        if (p->l == this) {
            mid = r;
            r = par;
            par->l = mid;
        }
        else {
            mid = l;
            l = par;
            par->r = mid;
        }
        if (mid) mid->p = par;
        p = par->p;
        par->p = this;
        if (p && p->l == par) p->l = this;
        if (p && p->r == par) p->r = this;
        par->update();
        update();
    }
    void splay() {
        push();
        while (pos()) {
            int st = pos() ^ p->pos();
            if (!st) p->rot(), rot();
            else if (st == 2) rot(), rot();
            else rot();
        }
    }

public:
    lct_node() : l(nullptr), r(nullptr), p(nullptr), rev(false), val(M::id()), all(M::id()), size(1), flag(false) {}
    void expose() {
        for (lct_node *x = this, *y = nullptr; x; y = x, x = x->p) x->splay(), x->r = y, x->update();
        splay();
    }
    void link(lct_node *x) {
        x->expose();
        expose();
        p = x;
    }
    void evert() {
        expose();
        rev = true;
        rev_data();
    }
    T find() {
        expose();
        return all;
    }
    void update(U v) {
        expose();
        update_lazy(v);
    }
};

const int MAX = 5e4;
lct_node lct[MAX];

void build(int v, int prev, const vector<vector<int>>& G) {
    for (int to : G[v]) if (to != prev) {
        lct[to].link(&lct[v]);
        build(to, v, G);
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int N, Q;
    cin >> N >> Q;
    vector<vector<int>> G(N);
    for (int i = 0; i < N - 1; i++) {
        int u, v;
        cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    build(0, -1, G);
    while (Q--) {
        int com, u, v;
        cin >> com >> u >> v;
        if (com == 1) {
            int x;
            cin >> x;
            lct[u].evert();
            lct[v].update(make_pair(mint(x), mint(x)));
        }
        else {
            lct[u].evert();
            printf("%d\n", lct[v].find().get());
        }
    }
    return 0;
}

感想

サクッと通せると思ってたのにめちゃバグって悔しい.
evert ありの時の l, r を反転するタイミングをちゃんとおぼえとかないとダメ.

巨大modでの掛け算の高速化 (Codeforces Round #259 D Little Pony and Elements of Harmony)

解法

大体の方針はアダマール変換を使った XOR の畳み込みでできます.
あとは mod を p * 2^m (<= 1e15 くらい) にしておいて, 最後に 2^m で割れば逆変換ができるのでOK.

ただ, 今回は mod が 10^9 よりだいぶ大きいので long long でオーバーフローしてしまうため普通に掛け算できません.
なのでとりあえずダブリングでやってみます.
コードはこんな感じ.

ll mod_prod(ll a, ll b, ll md) {
    ll res = 0;
    while (b) {
        if (b & 1) res = (res + a) % md;
        a = (a + a) % md;
        b >>= 1;
    }
    return res;
}

ll は long long を typedef したものです.

しかしこれに繰り返し二乗法の log が付いてさらに 10^6 くらい回す必要があるので流石に TLE します.
なのでもっと速い方法は無いかと強い人たちのコードを読んでたらこんなのを見つけました.

ll mod_prod(ll a, ll b, ll md) {
    ll res = (a * b - (ll)((long double)a / md * b) * md) % md;
    return res < 0 ? res + md : res;
}

long double とか使ってて一瞬意味が分かりませんが, やってることは a * b から a * b を md の倍数に丸めたものを引いてるだけです.
両方ともオーバーフローする部分の値が一致するから上手くいくっぽいです.
これなら割り算があるとはいえ O(1) なので間に合うようになります.

ソースコード

http://codeforces.com/contest/453/submission/38935734

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

ll mod_prod(ll a, ll b, ll md) {
    ll res = (a * b - (ll)((long double)a / md * b) * md) % md;
    return res < 0 ? res + md : res;
}

ll mod_pow(ll v, ll x, ll md) {
    ll res = 1;
    while (x) {
        if (x & 1LL) res = mod_prod(res, v, md);
        v = mod_prod(v, v, md);
        x >>= 1;
    }
    return res;
}

ll mod;

template <typename T>
void fwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = x + y - (x + y >= mod ? mod : 0);
                f[j | i] = x - y + (x - y < 0 ? mod : 0);
            }
        }
    }
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int m;
    ll t;
    cin >> m >> t >> mod; mod <<= (ll)m;
    int n = 1 << m;
    vector<ll> e(n);
    for (int i = 0; i < n; i++) {
        cin >> e[i]; e[i] %= mod;
    }
    vector<ll> b(m + 1);
    for (int i = 0; i <= m; i++) {
        cin >> b[i]; b[i] %= mod;
    }
    vector<ll> v(n);
    for (int i = 0; i < n; i++) {
        v[i] = b[__builtin_popcount(i)];
    }
    fwt(e);
    fwt(v);
    for (int i = 0; i < n; i++) {
        e[i] = mod_prod(e[i], mod_pow(v[i], t, mod), mod);
    }
    fwt(e);
    for (int i = 0; i < n; i++) {
        printf("%d\n", (int)(e[i] >> (ll)m));
    }
    return 0;
}

色々な畳み込み

最近覚えたのでメモ.
理論とか仕組みとかの説明はしません.

高速フーリエ変換(FFT)

添え字和での畳み込み

  • Σ(i + j = k) a_i * b_j = c_k

この形で長さ n の列 a, b から長さ 2n の列 c を求めます. ただし n は2の冪乗とします.

列 a, b を離散フーリエ変換して, かけ合わせたものを逆離散フーリエ変換すると c が求まります.
計算量は O(nlogn) です.

コード

長いのでリンクだけ
github.com

ちなみに整数での特殊な mod の場合は 高速剰余変換(NTT) が使えて定数倍高速かつ誤差なく計算できます.

高速ゼータ変換(FZT)

上位/下位集合の畳み込み

  • Σ(j ⊂ i) a_i = b_j

または

  • Σ(i ⊂ j) a_i = b_j

この形で長さ n の列 a から長さ n の列 b を求めます. ただし n は2の冪乗とします.

やってることは実質 bitDP みたいな感じ.
計算量は O(nlogn) です.

コード
template <typename T>
void fzt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j] += f[j | i];
                // この場合上位集合の畳み込みになる
                // 左辺と右辺を逆にすると下位集合の畳み込みになる
            }
        }
    }
}

高速メビウス変換(FMT)

ゼータ変換の逆変換

  • Σ(j ⊂ i) (-1)^|i \ j| * a_i = b_j

または

  • Σ(i ⊂ j) (-1)^|j \ i| * a_i = b_j

この形で長さ n の列 a から長さ n の列 b を求めます. ただし n は2の冪乗とします.
計算量は O(nlogn) です.

コード
template <typename T>
void fmt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j] -= f[j | i];
                // この場合上位集合のゼータ変換の逆変換になる
                // 左辺と右辺を逆にすると下位集合のゼータ変換の逆変換になる
            }
        }
    }
}

高速アダマール変換(FWT)

添え字ANDでの畳み込み

  • Σ(i & j = k) a_i * b_j = c_k

この形で長さ n の列 a, b から長さ n の列 c を求めます. ただし n は2の冪乗とします.

列 a, b を離散アダマール変換をして, かけ合わせてから逆変換すると c が求まります.
上位集合での高速ゼータ変換+高速メビウス変換と同じです.
計算量は O(nlogn) です.

コード
template <typename T>
void fwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j] += f[j | i];
            }
        }
    }
}
template <typename T>
void ifwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j] -= f[j | i];
            }
        }
    }
}

添え字ORでの畳み込み

  • Σ(i | j = k) a_i * b_j = c_k

この形で長さ n の列 a, b から長さ n の列 c を求めます. ただし n は2の冪乗とします.

ANDとほぼ同じですが, 足される方向が逆になってます.
下位集合での高速ゼータ変換+高速メビウス変換と同じです.
計算量は O(nlogn) です.

コード
template <typename T>
void fwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j | i] += f[j];
            }
        }
    }
}
template <typename T>
void ifwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                f[j | i] -= f[j];
            }
        }
    }
}

添え字XORでの畳み込み

  • Σ(i ^ j = k) a_i * b_j = c_k

この形で長さ n の列 a, b から長さ n の列 c を求めます. ただし n は2の冪乗とします.

列 a, b を離散アダマール変換をして, かけ合わせてから逆変換すると c が求まります.
ifwt は2で除算する必要があるので注意.
ちなみに, 逆変換は fwt してから全要素を n で割ることでもできます.
計算量は O(nlogn) です.

コード
template <typename T>
void fwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = x + y, f[j | i] = x - y;
            }
        }
    }
}
template <typename T>
void ifwt(vector<T>& f) {
    int n = f.size();
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j++) {
            if ((j & i) == 0) {
                T x = f[j], y = f[j | i];
                f[j] = (x + y) / 2, f[j | i] = (x - y) / 2;
            }
        }
    }
}

その他

添え字GCDでの畳み込み

kazuma8128.hatenablog.com


HackerRank World CodeSprint 13 参加記

今回の World CodeSprint 13 で初めて HackerRank の T シャツをゲットできてテンションが上がったので調子に乗って参加記的なのを書いてみました.
問題番号順じゃなくて解いた時系列順に書いてるので注意.

7問目:Dynamic Trees

問題概要:木の辺を繋ぎかえたり, 頂点のブール値を反転させたり, パス上の k 番目の true の位置を求める.

これはもう問題文を読む前から問題名で解法がわかります(ほんまか)

あとは Sleator と Tarjan に感謝しながら ST-Tree(Link-Cut Tree) を生やすとAC.

と思ったらめっちゃバグった.
evert 機能つけてるのに上から探索するときにフラグを push するの忘れてた…だめやんけ…


この問題は Link-Cut Tree ゲーに見えるけど, たぶんクエリ平方分割でもいけます.
あんまりちゃんと考えてないけど.

1問目:Find The Absent Students

問題概要:1-n のうち入力に無い番号を出力する.

まあ, はい.
注意点としては答えが空のときに改行出力を忘れないこと.
(なくてもACになるかも?HackerRank のジャッジをよく知らない)

2問目:Group Formation

問題概要:条件を満たすようにマージしまくったした後一番大きい全てのグループの人の名前を列挙するみたいな感じだった気がする.

列挙が楽なように Quick-Find を使った.
寝不足すぎたため, 既にマージされたのを unite するときの continue を忘れる典型的なミスをして 1 WA.

3問目:Watson's Love for Arrays

問題概要:積が k mod m の空でない連続部分列の個数を求める.

0 が含まれると逆元がなくて困るので, 場合分け.

k = 0 なら 0 以外の区間の任意の部分列を求めて n * (n+1) / 2 から引く.

k != 0 なら 0 以外の区間に区切ってそれぞれ処理する.
区間で左から順番に積の累積を持ちながらその逆元を map に入れていってカウントを持っておく.
あとは各位置で ( (k の逆元) * 現在の積) の逆元 の個数を足し合わせると答えみたいな感じだったはず.

n * (n + 1) / 2 を計算するときに n が int なのでオーバーフローして 2 WA. おいおい.

4問目:Balanced Sequence

問題概要:'('と')'の列が与えられて, 区間に対する左右反転 & '('⇔')'反転の操作を使って最小何回でvalidな括弧列にできるか.

この問題はこのセットの中で一番難しい.( ※個人の感想です)
苦手なDP系の問題っぽく見える.
とりあえず, 元々 valid なら 0 でそれ以外は 1 というガバガバを投げてみたところ, 41.25/55 点貰えたため満足して飛ばす.

5問目:Competitive Teams

問題概要:チームのマージクエリと, チームのサイズ差が c (各クエリによって違う値) 以上のチームペアの個数を求めるクエリがくる.

とりあえず Union-Find パートをなくすと, 整数集合に対する 追加 / 削除 / 差がc以上のペアの個数 のクエリになる.

一見不可能に見えるので平方分割とかか?と思う.

が, よく考えるとサイズの合計は n なので, サイズの種類数は √n になることを考えると map で愚直にやって良いことがわかる.

意味不明なバグらせ方をして6WAするも AC.

この辺りから寝不足+バグらせまくりで精神を病み始める.

6問目:Landslide

問題概要:木があり, 辺が通れなくなったり通れるようになったりするクエリと, x-y パスが通れるかの判定クエリがくる.

お, また Link-Cut Tree ゲーじゃ~んとなる.

よく考えると木の構造は変わらないので HL 分解 + RMQ でよいのでは?

でも Link-Cut Tree 大好きマンなのでそのまま実装する(これがよくなかった)

既に繋がってるのに繋ぐクエリがきたり, 繋がってないのに切るクエリがくることを考えてなくて 9 WA して鬱病になる. 問題文をちゃんと読もうな.

やっと AC.

4問目:Balanced Sequence(再)

なんかよく考えたら, 最悪でも2回くらいで valid にできるんでは?という気持ちになる.

色々サブミットして試したところ, 1回目の提出の 1 を 2 に変えると1回目通らなかったケースは全部通るので, さっきの仮説は正しいことがわかる.

あとは実質 Yes/No 判定

思考力がもう残ってないので, 適当にそれっぽい嘘の条件をいっぱい増やして投げて, 1と2のどちらかに完全に包含されていれば確定するという虚無をする.

あと2ケースだけになったけど, もうそれっぽい条件が思いつかない><

randome_device() を使えば確率 1/4 のガチャになるので, 十分な回数投げれば勝ち確(は???)

3回目のガチャでAC(不正ACをやめろ)(Codeforcesだと不可能)(T回のテストケース形式でも無理)(実質0完)

感想

HackerRank は WA のペナルティーがないし, 間違ったケースの番号もわかるので最高だなと思いました.(不正AC者並みの感想)
あと Link-Cut Tree ゲーを出してくれたのは本当に神.

非負整数値を扱う Trie について

はじめに

非負整数を二分木のトライ木で管理するアレに関する日本語記事があんまり無いっぽいので雑にメモ. (というかそもそも専用の呼び名ないのかな?)

とりあえずここでは Binary Trie って呼んどきます.

Binary Trie とは

整数をビット列とみなしてトライ木っぽく持つ set 的なことができるデータ構造です.
正確には要素の重複を許す multiset っぽく実装することが多そう.

整数集合を管理できますが, 平衡二分木よりも実装が楽なので最高です.

こんな感じ
f:id:kazuma8128:20180506004833p:plain

ノードに書いてある数字は部分木に含まれる要素の個数です.

できること

ビット長を B とすると, 以下の操作が全て O(B) でできます.

  • insert(x) : 値 x を集合に(一つ)追加
  • erase(x) : 値 x を集合から(一つ)削除
  • max_element/min_elemet() : 集合内の最大/最小値を取得
  • lower_bound/upper_bound(x) : 集合内で値 x 以上の/より大きい 最小の要素の番号を取得. ここでの番号とは小さい方から何番目かということ.


ここまではただの multiset でもできますが, 次の操作が重要.

  • kth_element(k) : k 番目に小さい要素を取得
  • max_element/min_element(x) : 集合内の値の中で, 値 x と XOR したときに最大/最小になる値の取得
  • kth_element(k, x) : 集合内の値の中で, 値 x で XOR したときの k 番目に小さい値の取得

これも O(B) でできます.


さらに遅延評価を行うと以下の操作が O(1) で可能になります

  • xor_all(x) : 全要素を値 x で XOR した値に変更

これと lower_bound とかを組み合わせることがよくあります.
(追記:よく考えたら遅延評価せずに XOR したときの lower_bound とかもできそうなので, 遅延評価は実装を楽にするため以外の価値がないかも?)


あと木構造なので, 当然永続化も簡単です.
詳しくはこっちを参照.

kazuma8128.hatenablog.com

方法

最低限のノードに持つ値は, 部分木に含まれる要素の個数, 左の子のポインタ, 右の子のポインタ の三つ.

見方を変えると, 各値ごとにカウントを持たせる動的な Segment Tree という風にも言えるかもしれません.

insert, erase ではノードを作成, 削除したり個数を±1したりするだけです.
max_element/min_element では個数が 0 でない部分木に 右/左 ノード優先で潜っていきます.
kth_element では二分探索木みたいな感じで根から降りながら二分探索していく感じでできます.

XOR した優先順位での操作系では各深さに対応するビットが立っていれば左右の優先順位を反転させて同じことをやればよいです.

遅延評価するときはさらに, 部分木に対して XOR したい値を持ちます.
あとは各操作で子ノードに降りる前に伝搬していく.
ついでにその深さのビットが立っていれば左右の子ノードを swap.


よく分からなければ実装を見た方が早いかも.

実装例

通常版
procon-lib/binary_trie.cpp at master · kazuma8128/procon-lib · GitHub

template<typename U = unsigned, int B = 32>
class binary_trie {
    struct node {
        int cnt;
        node *ch[2];
        node() : cnt(0), ch{ nullptr, nullptr } {}
    };
    node* add(node* t, U val, int b = B - 1) {
        if (!t) t = new node;
        t->cnt += 1;
        if (b < 0) return t;
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = add(t->ch[f], val, b - 1);
        return t;
    }
    node* sub(node* t, U val, int b = B - 1) {
        assert(t);
        t->cnt -= 1;
        if (t->cnt == 0) return nullptr;
        if (b < 0) return t;
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = sub(t->ch[f], val, b - 1);
        return t;
    }
    U get_min(node* t, U val, int b = B - 1) const {
        assert(t);
        if (b < 0) return 0;
        bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f];
        return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b);
    }
    U get(node* t, int k, int b = B - 1) const {
        if (b < 0) return 0;
        int m = t->ch[0] ? t->ch[0]->cnt : 0;
        return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b);
    }
    int count_lower(node* t, U val, int b = B - 1) {
        if (!t || b < 0) return 0;
        bool f = (val >> (U)b) & (U)1;
        return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1);
    }
    node *root;
public:
    binary_trie() : root(nullptr) {}
    int size() const {
        return root ? root->cnt : 0;
    }
    bool empty() const {
        return !root;
    }
    void insert(U val) {
        root = add(root, val);
    }
    void erase(U val) {
        root = sub(root, val);
    }
    U max_element(U bias = 0) const {
        return get_min(root, ~bias);
    }
    U min_element(U bias = 0) const {
        return get_min(root, bias);
    }
    int lower_bound(U val) { // return id
        return count_lower(root, val);
    }
    int upper_bound(U val) { // return id
        return count_lower(root, val + 1);
    }
    U operator[](int k) const {
        assert(0 <= k && k < size());
        return get(root, k);
    }
    int count(U val) const {
        if (!root) return 0;
        node *t = root;
        for (int i = B - 1; i >= 0; i--) {
            t = t->ch[(val >> (U)i) & (U)1];
            if (!t) return 0;
        }
        return t->cnt;
    }
};

遅延評価版
procon-lib/lazy_binary_trie.cpp at master · kazuma8128/procon-lib · GitHub

template<typename U = unsigned, int B = 32>
class lazy_binary_trie {
    struct node {
        int cnt;
        U lazy;
        node *ch[2];
        node() : cnt(0), lazy(0), ch{ nullptr, nullptr } {}
    };
    void push(node* t, int b) {
        if ((t->lazy >> (U)b) & (U)1) swap(t->ch[0], t->ch[1]);
        if (t->ch[0]) t->ch[0]->lazy ^= t->lazy;
        if (t->ch[1]) t->ch[1]->lazy ^= t->lazy;
        t->lazy = 0;
    }
    node* add(node* t, U val, int b = B - 1) {
        if (!t) t = new node;
        t->cnt += 1;
        if (b < 0) return t;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = add(t->ch[f], val, b - 1);
        return t;
    }
    node* sub(node* t, U val, int b = B - 1) {
        assert(t);
        t->cnt -= 1;
        if (t->cnt == 0) return nullptr;
        if (b < 0) return t;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        t->ch[f] = sub(t->ch[f], val, b - 1);
        return t;
    }
    U get_min(node* t, U val, int b = B - 1) {
        assert(t);
        if (b < 0) return 0;
        push(t, b);
        bool f = (val >> (U)b) & (U)1; f ^= !t->ch[f];
        return get_min(t->ch[f], val, b - 1) | ((U)f << (U)b);
    }
    U get(node* t, int k, int b = B - 1) {
        if (b < 0) return 0;
        push(t, b);
        int m = t->ch[0] ? t->ch[0]->cnt : 0;
        return k < m ? get(t->ch[0], k, b - 1) : get(t->ch[1], k - m, b - 1) | ((U)1 << (U)b);
    }
    int count_lower(node* t, U val, int b = B - 1) {
        if (!t || b < 0) return 0;
        push(t, b);
        bool f = (val >> (U)b) & (U)1;
        return (f && t->ch[0] ? t->ch[0]->cnt : 0) + count_lower(t->ch[f], val, b - 1);
    }
    node *root;
public:
    lazy_binary_trie() : root(nullptr) {}
    int size() const {
        return root ? root->cnt : 0;
    }
    bool empty() const {
        return !root;
    }
    void insert(U val) {
        root = add(root, val);
    }
    void erase(U val) {
        root = sub(root, val);
    }
    void xor_all(U val) {
        if (root) root->lazy ^= val;
    }
    U max_element(U bias = 0) {
        return get_min(root, ~bias);
    }
    U min_element(U bias = 0) {
        return get_min(root, bias);
    }
    int lower_bound(U val) { // return id
        return count_lower(root, val);
    }
    int upper_bound(U val) { // return id
        return count_lower(root, val + 1);
    }
    U operator[](int k) {
        assert(0 <= k && k < size());
        return get(root, k);
    }
    int count(U val) {
        if (!root) return 0;
        node *t = root;
        for (int i = B - 1; i >= 0; i--) {
            push(t, i);
            t = t->ch[(val >> (U)i) & (U)1];
            if (!t) return 0;
        }
        return t->cnt;
    }
};

存在しない要素を erase すると assert で落ちるので注意.

ノードをプールしてないのでちょっと遅いです. 定数倍で困ったらプールするとよさげ.

永続化したい場合はクラス化しない方が便利ですが, ここではクラス化してます.

たぶん大丈夫だと思いますがバグってたらごめんなさい.

練習問題

Binary Trie を使わなくても解ける問題もありますが, verifyに便利なので載せます.
白文字で方針も軽く書いてます(ネタバレ注意)

https://arc033.contest.atcoder.jp/tasks/arc033_3 (ARC033 C データ構造)
set 操作の verify 用
http://codeforces.com/contest/947/problem/C (Codeforces Round #470 C Perfect Security)
XOR での最小値を使う
http://www.spoj.com/problems/SUBXOR/ (SPOJ SubXor)
遅延で lower_bound
http://codeforces.com/contest/966/problem/C (Codeforces Round #477 C Big Secret)
遅延の upper_bound で殴れる
https://www.codechef.com/problems/GPD (CodeChef Gotham PD)
永続化
https://www.codechef.com/problems/PSHTTR (CodeChef Pishty and tree)
永続化して, 木のノードに部分木の総 XOR も持たせるとオンラインで各クエリO(B)で解ける. 想定解はたぶんオフラインでBIT
http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2270 (AOJ UTPC The L-th Number)
解説では永続 Segment Tree のイメージで説明されてますが, 永続 Binary Trie の問題とも見なせる

RUPC2018 Day3 E Broccoli or Cauliflower

問題概要

日本語なので省略

解法

Euler Tour をとってやると, あとは必要な操作は「区間和」と「区間のブール値反転操作」だけになります.

想定解は遅延評価セグメントツリーを使っています. (解説URL
ですが, 遅延評価セグメントツリーを使わなくても解けます.

区間和をもつ merge/split 可能な平衡二分木を使います.
区間の左右反転操作とイメージは同じです.

元の値を持った平衡二分木と, さらにブール値が反転した平衡二分木をもう一つ用意します.
すると, 反転したい区間を両方splitして, swapしてからそれぞれ merge してやると遅延評価なしで簡単に反転操作が可能になります.

シンプルで面白いですね.

ソースコード

https://onlinejudge.u-aizu.ac.jp/beta/review.html#RitsCamp18Day3/2760632

#include <bits/stdc++.h>
using namespace std;
using graph = vector<vector<int>>;

const int MAX = 1e5;

int it;
int lv[MAX], rv[MAX];
int id[MAX];

void euler_tour(int v, const graph& G) {
    id[it] = v;
    lv[v] = it++;
    for (auto to : G[v]) {
        euler_tour(to, G);
    }
    rv[v] = it;
}

unsigned xor128() {
    static unsigned x = 123456789, y = 362436069, z = 521288629, w = 88675123;
    unsigned t = x ^ (x << 11);
    x = y; y = z; z = w;
    return w = w ^ (w >> 19) ^ (t ^ (t >> 8));
}

struct node {
    int val, sum;
    node *lch, *rch;
    int size;
    node(int v, node* l = nullptr, node* r = nullptr) : val(v), sum(v), lch(l), rch(r), size(1) {}
};

int count(node *t) { return t ? t->size : 0; }
int getsum(node *t) { return t ? t->sum : 0; }

node *update(node *t) {
    t->size = count(t->lch) + count(t->rch) + 1;
    t->sum = getsum(t->lch) + t->val + getsum(t->rch);
    return t;
}

node *merge(node *l, node *r) {
    if (!l) return r;
    if (!r) return l;
    if ((int)xor128() % (l->size + r->size) < l->size) {
        l->rch = merge(l->rch, r);
        return update(l);
    }
    r->lch = merge(l, r->lch);
    return update(r);
}

pair<node*, node*> split(node *t, int k) {
    if (!t) return make_pair(nullptr, nullptr);
    if (k <= count(t->lch)) {
        pair<node*, node*> s = split(t->lch, k);
        t->lch = s.second;
        return make_pair(s.first, update(t));
    }
    pair<node*, node*> s = split(t->rch, k - count(t->lch) - 1);
    t->rch = s.first;
    return make_pair(update(t), s.second);
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int n, q;
    cin >> n >> q;
    graph G(n);
    for (int i = 1; i < n; i++) {
        int p;
        cin >> p; p--;
        G[p].push_back(i);
    }
    euler_tour(0, G);
    vector<char> c(n);
    for (int i = 0; i < n; i++) {
        cin >> c[i];
    }
    node *t = nullptr, *rt = nullptr;
    for (int i = 0; i < n; i++) {
        t = merge(t, new node(c[id[i]] == 'G'));
        rt = merge(rt, new node(c[id[i]] == 'W'));
    }
    while (q--) {
        int v;
        cin >> v; v--;

        auto s1 = split(t, lv[v]);
        auto s2 = split(s1.second, rv[v] - lv[v]);
        auto rs1 = split(rt, lv[v]);
        auto rs2 = split(rs1.second, rv[v] - lv[v]);

        t = merge(s1.first, merge(rs2.first, s2.second));
        rt = merge(rs1.first, merge(s2.first, rs2.second));

        puts(getsum(t) > getsum(rt) ? "broccoli" : "cauliflower");
    }
    return 0;
}

あとがき

実はこの解法はコンテスト終了後にサークルの先輩に教えてもらった方法です.
オンライン参加時は想定解で解きました.

Mo's Algorithm について

はじめに

この記事は Mo's Algorithm について整理するための自分用メモです.


ポエム

Mo's Algorithm とは, 二次元グリッドにおける与えられた点の集合を上下左右方向への移動だけで巡回するときのより効率の良い順番を見つけるためのアルゴリズムといえるのでは.


1次元

自明だけど一応.

一次元座標 (x_i) (0 ≤ x_i < X) (0 ≤ i < Q) なる distinct な点集合を巡回する方法を考える.

やり方

当然 x 座標が小さいものから順に行けばよい.

計算量

x 軸移動 : O(X)

よって全体の計算量は O(X) である


2次元

普通の Mo と大体同じ.

二次元座標 (x_i, y_i) (0 ≤ x_i < X, 0 ≤ y_i < Y) (0 ≤ i < Q) なる distinct な点集合を巡回する方法を考える.

やり方

まず, b_x を使って x 座標によって以下のようにバケット毎に分ける.

バケット数は ceil(X / bx) 個.
バケット毎に独立で, それぞれの点を (y_i) とみなして 1 次元の方法を適用する.

計算量

バケット内の点から点への移動は以下のようになる.

x 軸移動 : O(b_x)
y 軸移動 : O(Y)

バケット全体では以下のようになる. ただし, c_j はそのバケット内の点の数.

x 軸移動 : O(c_j * b_x)
y 軸移動 : O(Y)

バケット間での移動の合計は

x 軸移動 : O(X)
y 軸移動 : O(Y * (X / b_x)) = O(X * Y / b_x)

よって全体で以下.

x 軸移動 : O( (Σc_j) * b_x + X) = O(Q * b_x + X)
y 軸移動 : O(X * Y / b_x + Y)

合計 : O(Q * b_x + X * Y / b_x + X + Y)

このとき, 合計を最小にするバケットサイズは b_x = (XY/Q)^(1/2) である.
ちなみに, このバケットサイズが 1 より小さいことはない. (そうなる条件は Q > XY だが, これは各点が distinct であることに矛盾するため)

合計:O( (XYQ)^(1/2) + X + Y )

b_x が X を超えてしまうことを考慮していないが, そのような状況では Y >> XQ の条件を満たし, y座標でソートすることになるので, O(XQ + Y) = O(Y) となり問題ない.

また実際に, Mo's Algorithm を使う多くの場合は X, Y が同じか近い値である*1ので, 最終的に全体の計算量は O( (XYQ)^(1/2) ) となる.


3次元

時空間 Mo と大体同じ.

三次元座標 (x_i, y_i, z_i) (0 ≤ x_i < X, 0 ≤ y_i < Y, 0 ≤ z_i < Z) (0 ≤ i < Q) なる distinct な点集合を巡回する方法を考える.

やり方

二次元と同様にまず, b_x を使って x 座標でバケット毎に分ける.
バケット数は ceil(X / b_x) (= B とする) 個.
そして, 点を (y_i, z_i) とみなして 2 次元の方法をバケットごとに独立で適用する.

計算量

バケット内の点から点への移動の合計は以下のようになる. ただし, c_j はそのバケット内の点の数.

x 軸移動 : O(b_x * c_j)
y 軸移動 : O( (YZ * c_j)^(1/2) )
z 軸移動 : O( (YZ * c_j)^(1/2) )

バケット同士の間の移動の合計は以下.

x 軸移動 : O(X)
y 軸移動 : O(YB)
z 軸移動 : O(ZB)

全体は以下.

x 軸移動 : O( (Σc_j) * b_x) = O(Q * b_x + X)
y 軸移動 : O( (YZ)^(1/2) * ( Σc_j^(1/2) ) + YB ) = O( (YZ)^(1/2) * (QB)^(1/2) + YB ) *2 = O( (XYZQ)^(1/2) / b_x^(1/2) + XY / b_x )
z 軸方向 : O( hoge + ZB ) = O( hoge + XZ / b_x ) (hogeは y 軸方向の第一項と同じ)

合計 : O(Q * b_x + (XYZQ)^(1/2) / b_x^(1/2) + XY / b_x + XZ / b_x + X)

細かい議論は省略するが X, Y, Z がある程度近い値*3のとき, 後ろの三項がボトルネックになることはない.
そのとき, 合計を最小にするバケットサイズは b_x = (XYZ/Q)^(1/3) である.
よって, 時空間 Mo を使う大抵の場合において全体の計算量 O( (XYZ)^(1/3) * Q^(2/3)) になる.


n 次元

あんまり実用性なさそうだけど一応.
煩雑になるので, 各軸での最大値を統一する.

n 次元座標 (x_0_i, x_1_i, ... , x_(n - 1)_i ) (0 ≤ x_j_i < X (0 ≤ j < n) ) (0 ≤ i < Q) なる distinct な点集合を巡回する方法を考える.
ただし, n ≧ 2 とする

やり方

x_0 座標で B 毎にバケットとして分ける. バケット数は ceil(X / B) (= C とする) 個.
バケットごとに n - 1 次元を適用する.

計算量

全体で O(X * Q^( (n-1)/ n) ) になることを帰納法で示す.

n-1 次元での計算量が O(X * Q^( (n-2) / (n-1) ) ) と仮定する.
すると, 各バケット内での点から点への移動の合計は

x_0 軸移動 : O(B * c_j)
その他の各軸移動 : O(X * c_j^( (n-2) / (n-1) ) )

バケット間の移動の合計は

x_0 軸移動 : O(X)
その他の各軸移動 : O(X * C)

よって全体は

x_0 軸移動 : O(QB + X^2 / B + X)
その他の各軸移動 : O(X * (Σc_j^( (n-2) / (n-1) ) ) + X * C) = O(X * (Q^( (n-2) / (n-1) ) * C^(1 / (n-1) ) ) + X * C)*4 = O(X^(n / (n-1) ) * Q^( (n-2) / (n-1) ) / B^(1 / (n-1) ) + X^2 / B)

合計 : O(QB + X^(n / (n-1) ) * Q^( (n-2) / (n-1) ) / B^(1 / (n-1) ) + X^2 / B + X) = O(QB + X^(n / (n-1) ) * Q^( (n-2) / (n-1) ) / B^(1 / (n-1) ) )

この合計を最小にするバケットサイズは B = X / Q^(1/n) である.
そして, このとき全体の計算量は O(X * Q^( (n-1) / n) ) となり, 帰納法が成立する.

追記

この計算量の解析はガバガバで, nが定数とみなせないサイズになると定数項が爆発するので注意.


まとめ

移動の各方向の計算量が対称でない場合は(ある軸だけ log 付いたりとか)より良くできる場合があるので, これを眺めながら良い感じバケットサイズを計算したりとかに使えるかも(?)

*1:具体的には, max(X, Y) ≤ Q * min(X, Y) を満たすとき

*2: Σc_j = Q から, Schwarz の不等式を利用した

*3:具体的には, max{X, Y, Z} ≤ Q * min{X, Y, Z} を満たすとき

*4:Hölder の不等式の特殊化によって得られる