kazuma8128’s blog

競プロの面白い問題を解きます

HL分解で部分木クエリ

HL分解*1でパスクエリができることは知ってたけど同時に部分木クエリも行えることを知ったのでメモ. (もしかして割と常識なのかな?)

これ
https://codeforces.com/blog/entry/53170

弱い方のHL分解

元々持ってたライブラリでは, 一回 dfs で各部分木のサイズを求めてから heavy-edge *2を見つけて, その後 bfs で根に近い heavy-path から順に, その中でも根に近い頂点から順に番号を振る感じでやってました.
これで各 heavy-path 上の頂点は全て連続した番号を持つので, パスクエリには対処できます.
しかしこれだと部分木クエリをやりたくても一つの部分列にはならないので, たぶん無理です.

強い方のHL分解

一回目の dfs までは同じで, 次に番号を振るときに bfs ではなく dfs をします. (なので番号の順序は行きがけ順に変わります)
このとき, 各頂点で最初に必ず heavy-edge から潜るようにします.
そうすることで, 各 heavy-path 上の頂点がオイラーツアー上において一つの連続した部分列になるのでパスクエリができます.
さらに, オイラーツアーなので各部分木の頂点も連続な部分列になって部分木クエリも可能になります.
もちろん LCA とかも求められます.

実装例

class heavy_light_decomposition {
    const int n;
    vector<vector<int>> g;
    vector<int> in, out, size, head, par;
    int it;
    void erase_par(int v, int prev) {
        par[v] = prev;
        for (auto& u : g[v]) {
            if (u == g[v].back()) break;
            if (u == prev) swap(u, g[v].back());
            erase_par(u, v);
        }
        g[v].pop_back();
    }
    void dfs1(int v) {
        for (auto& u : g[v]) {
            dfs1(u);
            size[v] += size[u];
            if (size[u] > size[g[v][0]]) swap(u, g[v][0]);
        }
    }
    void dfs2(int v) {
        in[v] = it++;
        for (auto u : g[v]) {
            head[u] = (u == g[v][0] ? head[v] : u);
            dfs2(u);
        }
        out[v] = it;
    }
public:
    heavy_light_decomposition(int n_)
        : n(n_), g(n), in(n, -1), out(n, -1), size(n, 1), head(n), par(n, -1), it(0) {}
    heavy_light_decomposition(const vector<vector<int>>& G)
        : n(G.size()), g(G), in(n, -1), out(n, -1), size(n, 1), head(n), par(n, -1), it(0) {}
    void add_edge(int u, int v) {
        g[u].push_back(v);
        g[v].push_back(u);
    }
    void build(int rt = 0) {
        for (auto v : g[rt]) erase_par(v, rt);
        dfs1(rt);
        head[rt] = rt;
        dfs2(rt);
    }
    int get_id(int v) {
        return in[v];
    }
    int get_lca(int u, int v) {
        while (true) {
            if (in[u] > in[v]) swap(u, v);
            if (head[u] == head[v]) return u;
            v = par[head[v]];
        }
    }
    void path_query(int u, int v, function<void(int, int)> f) {
        while (true) {
            if (in[u] > in[v]) swap(u, v);
            f(max(in[head[v]], in[u]), in[v] + 1);
            if (head[u] == head[v]) return;
            v = par[head[v]];
        }
    }
    void subtree_query(int v, function<void(int, int)> f) {
        f(in[v], out[v]);
    }
};

パス/部分木クエリではラムダ式とかで f を渡して使います.
f には半開区間 [l, r) が渡されます.

ちなみに木のサイズが 1e6 とかの場合は, スタックオーバーフローする可能性が高いので再帰を展開した方が良いでしょう.

例題

このHL分解を使うと以下の問題が解けます.

Subtrees And Paths | HackerRank

問題概要
サイズN(<=1e5)の根付き木の各頂点が整数値を持っている. これらは全て初期値0. 以下のクエリをQ(<=1e5)個処理せよ.

  • add t value : 頂点 t の部分木に含まれる任意の頂点の整数値に値 value を加算
  • max a b : 頂点 a から頂点 b までのパスに含まれる任意の頂点が持つ整数値の最大値を出力

解法
区間加算とRMQ ができればよいので, Starry Sky Tree または遅延セグメントツリーと HL分解を組み合わせればよいです.
addクエリは O(logN), maxクエリは O( (logN)^2) でできます.

*1:正確には Centroid Path Decomposition

*2:正確には centroid path に含まれる辺ですがこれ以降も heavy-edge と表記します

区間内の x 未満の値の個数, 区間内の k 番目に小さい値

クエリの定義

長さ N の整数列 X に対して, 以下の操作を定義します. (X の値の最大値を M とする)

  • rank(p, x) :X[0], X[1], ... , X[p - 1] の中で x 未満の値の個数を取得
  • rank(l, r, x) :X[l], X[l + 1], ... , X[r - 1] の中で x 未満の値の個数を取得
  • rank(l, r, x, y) :X[l], X[l + 1], ... , X[r - 1] の中で x 以上 y 未満の値の個数を取得
  • kth_min(l, r, k) :X[l], X[l + 1], ... , X[r - 1] を昇順にソートしたときの k 番目の値を取得(0-indexed)

基本的なテクニック

rank はどれか一つができれば残りはそれを使えばできます.
なので大抵は 1つ目の方法のみを考えればよいです.

また, 一応 kth_min はオンラインクエリの rank ができれば値の二分探索で求まります.
ただし logM 倍が余計に付くので最適な方法ではないことが多いです.

値の変更なし ver

オフラインクエリ

rank

解法:座圧+クエリをソートして Fenwick Tree に各値の個数を持たせながら左端から順に追加していく
計算量:均し O(logN) で定数倍かなり高速
メモリ:O(N)

kth_min

オンラインクエリと同じ方法でいい

オンラインクエリ

rank

解法①:前処理で永続 Segment Tree で各値の個数を持って左から順に追加していき, 各位置でバージョンを持っておく. 各クエリには, 適切な位置の木に区間和クエリを投げればよい
計算量:O(logN)
メモリ:O(NlogN)

解法②:ソートされた列を持つ Segment Tree に入れて, さらに Fractional-Cascading で高速化する
計算量:O(logN)
メモリ:O(NlogN)

解法③:Wavelet Matrix を使う
計算量:O(logM) ※座圧すれば O(logN)
メモリ:O(NlogM)

kth_min

解法①:rank の解法①と同じ前処理をする. あとは各クエリで, l の位置と r の位置の木を同時に降りながら二分探索すれば求まる
計算量:O(logN)
メモリ:O(NlogN)

解法②:Wavelet Matrix を使う
計算量:O(logM) ※座圧すれば O(logN)
メモリ:O(NlogM)

まとめ

ライブラリを持ってるなら全部 Wavelet Matrix を使うのが楽そう.

値の変更あり ver

値の一点更新クエリなどと一緒にくるときの場合です.

N * M の動的な二次元 Segment Tree を使います.
(i, X[i]) (0 <= i < N) は 1 , 他は 0 にしておきます.

X[p] の値を v に変更

解法:(p, X[p]) を 0 に, (p, v) を 1 に更新する
計算量:O(logM * logN)

rank

解法:rank(l, r, x, y) では [l, r) * [x, y) の範囲の値の和を求める
計算量:O(logM * logN)

kth_min

解法:[l, r) の範囲に対応する O(logN) 個の一次元 Segment Tree を同時に降りながら, 二分探索していく
計算量:O(logM * logN)

その他

空間計算量は O(N * logN * logM).
全クエリを先読みして座圧できれば logM の部分を log(Q+N) とかにできたりします.
また, 任意の位置での削除クエリなどが来る場合は各位置で削除されたかどうかを Segment Tree で持てば計算量を変えずに行えます. この場合削除クエリは O(logN).

もし, メモリが足りない場合はM個のセグメントツリーを平衡二分木で代用すると定数倍が悪くなる代わりに空間計算量 O(N * logN) にできます.

値の変更/挿入あり ver

列の任意の位置に値の挿入クエリがきたりする場合です.
このときは動的 Wavelet Matrix を使います.
動的 Wavelet Matrix とは, 完備辞書を insert/erase 操作ができる平衡二分木に置き換えた Wavelet Matrix です.

これを使うと insert, erase, rank, kth_min の全て O(logM * logN) でできます.
空間計算量は O(N * logM).

ただし二個目の log は平衡二分木なので非常に重く, 赤黒木を使っても僕の書いたものは 10^5 回のクエリに数秒程かかります.
使える状況はほとんど無いと思った方が良いです.
欲しくなったときは想定解が平方分割とかだったりする気がします.

k 番目に小さい値を取得可能な集合を管理するデータ構造

前提

集合に対して N 個の 追加・削除・k番目に小さい値の取得 のクエリがくるとします.
集合の要素は順序が付けられるものだけ考えます.
要素が非負整数の場合は最大値を X とします.

配列(または bitset)

各値の位置にカウントをもたせます.

用途

k 番目取得が要らない場合は一番使えるので一応.

座標圧縮

  • 要素が非負整数:最大値が小さい場合不要, 大きいと必要
  • その他:必要

時間計算量

  • 追加・削除クエリ:O(1)
  • k 番目取得クエリ:O(N) または O(X)

空間計算量:O(N) または O(X)

Fenwick Tree

各値の位置にカウントをもたせます.
追加・削除でうまく足し引きして, k番目クエリでは根から二分探索しながら降ります.

用途
定数倍が速いので, 全操作が均等にくる場合はこれが最強.

座標圧縮

  • 要素が非負整数:最大値が小さい場合不要, 大きいと必要
  • その他:必要

時間計算量

  • 全クエリ:O(logN) または O(logX)

空間計算量:O(N) または O(X)

実装例

class fenwick_tree_set {
    const int n;
    int cnt;
    vector<int> data;
    int find(int p) const {
        int res = 0;
        while (p > 0) {
            res += data[p];
            p -= p & -p;
        }
        return res;
    }
    void add(int p, int x) {
        ++p;
        while (p <= n) {
            data[p] += x;
            p += p & -p;
        }
    }
public:
    fenwick_tree_set(int maxi)
        : n(maxi + 1), cnt(0), data(n + 1) {}
    int size() const {
        return cnt;
    }
    int count(int val) const {
        assert(0 <= val && val < n);
        return find(val + 1) - find(val);
    }
    void insert(int val) {
        assert(0 <= val && val < n);
        add(val, 1);
        cnt += 1;
    }
    void erase(int val) {
        assert(0 <= val && val < n);
        assert(0 < count(val));
        add(val, -1);
        cnt -= 1;
    }
    int kth_element(int k) const {
        assert(0 <= k && k < cnt);
        int p = 1 << (32 - __builtin_clz(n)), res = 0;
        while (p >>= 1) {
            if (res + p <= n && data[res + p] <= k) {
                k -= data[res + p];
                res += p;
            }
        }
        return res;
    }
};

平方分割

配列の方法 + 各バケット毎の個数の合計 を持ちます.

用途

追加・削除が高速で Mo's Algorithm と相性がよいです.

座標圧縮

  • 要素が非負整数:最大値が小さい場合不要, 大きいと必要
  • その他:必要

時間計算量

  • 追加・削除クエリ:O(1)
  • k 番目取得クエリ:O(sqrt(N)) または O(sqrt(X))

空間計算量:O(N) または O(X)

実装例
sqrtX を二冪で丸めてます.

class sqrt_decomposition_set {
    const int n, b;
    int cnt;
    vector<int> data, sum;
    int get_b(int x) const {
        int t = 0;
        while ((1 << t) < x) ++t;
        return t >> 1;
    }
public:
    sqrt_decomposition_set(int maxi)
        : n(maxi + 1), b(get_b(n)), cnt(0), data(n), sum((n + (1 << b) - 1) >> b) {}
    int size() const {
        return cnt;
    }
    int count(int val) const {
        assert(0 <= val && val < n);
        return data[val];
    }
    void insert(int val) {
        assert(0 <= val && val < n);
        ++data[val];
        ++sum[val >> b];
        ++cnt;
    }
    void erase(int val) {
        assert(0 <= val && val < n);
        assert(0 < data[val]);
        --data[val];
        --sum[val >> b];
        --cnt;
    }
    int kth_element(int k) const {
        assert(0 <= k && k < cnt);
        int it = 0;
        while (sum[it] < k) k -= sum[it++];
        int id = it << b;
        while (data[id] == 0 || data[id] <= k) k -= data[id++];
        return id;
    }
};

Trie(動的 Segment Tree)

各値の位置にカウントをもたせます.
くわしくはこれ
http://kazuma8128.hatenablog.com/entry/2018/05/06/022654

用途

座圧ができないオンラインクエリのときに使えます.
ちなみに, ある値で XOR したときの k 番目の値とかも O(logX) でできます.

座標圧縮

  • 要素が非負整数:不要
  • その他:必要(文字列とかはそのままで可能だけどもはや別物)

時間計算量

  • 全クエリ:O(logN) または O(logX)

空間計算量:O(N logN) または O(N logX)

実装例:上のリンクに載ってるので省略

平衡二分木

std::set では k番目取得クエリができないので自分で書きます.
g++拡張が使える場合は tree とかいうのを使うのが楽そう.

用途

Trie でも MLE するときとかに使います.
定数倍が異常にデカいので速度は期待しない方が良いです.

座標圧縮

  • 要素が非負整数:不要
  • その他:不要

時間計算量

  • 全クエリ:O(logN)

空間計算量:O(N)

実装例

おまけ

趣旨がずれますが, van Emde Boas Tree(通称:謎木)というものを使うと非負整数に対して以下の操作が全て O(log log X) でできます.

  • 要素の追加
  • 要素の削除
  • lower_bound / upper_bound

ちなみに最大/最小値の取得は O(1) です.

これを使えば速い set とか map が作れます.

空間計算量は O(X) です.
unordered_map を内部で使えば O(N) にできますが, 定数倍が遅くなるのでちょっと残念になります.

実装例
雑に書いたので定数倍が遅めかもしれません.
コード

CodeChef ARRQRS Logan and his ARRAY Queries

問題概要

数列に対して以下のクエリを処理せよ

  • 末尾に値 Z を追加
  • 前から X 番目の値を削除
  • 前から X 番目の値を Z に変更
  • 区間 [L, R] に含まれる値の中で小さい方から K 番目の値を出力

クエリ数 : 1e5 以下
Time Limit : 3s

解法

これらのクエリは雑に一般化すると以下になります.

  • ある位置へ値を挿入
  • ある位置の値の削除
  • ある区間の K 番目最小値の取得

これらはすべて, 動的 Wavelet Matrix でできます.
動的 Wavelet Matrix というのは, Wavelet Matrix にさらに値の挿入, 削除, 変更などができるようになった(動的になった)ものです.
やり方はシンプルで, 完備辞書に平衡二分木を使うだけです.
なので計算量は全ての操作に平衡二分木の O(logN) (Nは要素数) がさらに付きます.

というわけで, この問題は動的 Wavelet Matrix の verify にちょうどいいとおもいます.
TLがすこしゆるめ(?)なのでがんばればギリギリ通せます.
たぶんこの問題が通らないようなら大抵の問題で使い物にならないでしょう.

ちなみに, この問題の想定解はおそらく平方分割です.

ソースコード

https://www.codechef.com/viewsolution/18947173
アホみたいに長いのでリンクだけ

Codeforces Round #404 E Anton and Permutation

問題概要

1, 2, ... , n の列が最初にあって l_i, r_i が Q 個与えられるので毎回 l_i 番目の値と r_i 番目の値をスワップしてから全体の反転数を求めよ

AC までの思考過程

Editorial 曰く, 想定解は平方分割らしい.
それなりに面白い.

まあそんなことはおいといて(?), これを計算量 poly log で解きたい.
欲しいクエリは以下

  • ある位置の値の変更
  • ある区間に含まれるある値未満の値の個数の取得

これは動的 Wavelet Matrix でできる. 両方 O( (log n)^2).

しかし, 投げてみるとTLE.
流石に平衡二分木の定数倍がヤバすぎて無理っぽい.
まあしょうがないね・・・.

なので次に, 二次元セグメントツリーを試す.
MLE.
まあ空間計算量 O(q (log n)^2) なのでこれも仕方ない.

一応, 他の人の提出を見たところ二次元セグメントツリーで通ってる人もいるので, たぶん気合が足りないだけなのだろうけど自分には無理だった.

次に, セグメントツリーに平衡二分木を載せてみたところ, ギリギリAC.
やはり wavelet matrix 成分よりセグメントツリー成分の log の方が定数倍が小さいらしい.

ソースコード

http://codeforces.com/contest/785/submission/39274304

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

template <typename T>
class avl_tree {
    struct node {
        T val;
        node *ch[2];
        int dep, size;
        node(T v, node* l = nullptr, node* r = nullptr) : val(v), dep(1), size(1) {
            ch[0] = l; ch[1] = r;
        }
    };
    int depth(node *t) { return t == nullptr ? 0 : t->dep; }
    int count(node *t) { return t == nullptr ? 0 : t->size; }
    node *update(node *t) {
        t->dep = max(depth(t->ch[0]), depth(t->ch[1])) + 1;
        t->size = count(t->ch[0]) + count(t->ch[1]) + 1;
        return t;
    }
    node *rotate(node *t, int b) {
        node *s = t->ch[1 - b];
        t->ch[1 - b] = s->ch[b];
        s->ch[b] = t;
        t = update(t);
        s = update(s);
        return s;
    }
    node *fix(node *t) {
        if (t == nullptr) return t;
        if (depth(t->ch[0]) - depth(t->ch[1]) == 2) {
            if (depth(t->ch[0]->ch[1]) > depth(t->ch[0]->ch[0])) {
                t->ch[0] = rotate(t->ch[0], 0);
            }
            t = rotate(t, 1);
        }
        else if (depth(t->ch[0]) - depth(t->ch[1]) == -2) {
            if (depth(t->ch[1]->ch[0]) > depth(t->ch[1]->ch[1])) {
                t->ch[1] = rotate(t->ch[1], 1);
            }
            t = rotate(t, 0);
        }
        return t;
    }
    node *insert(node *t, int k, T v) {
        if (t == nullptr) return new node(v);
        int c = count(t->ch[0]), b = (k > c);
        t->ch[b] = insert(t->ch[b], k - (b ? (c + 1) : 0), v);
        update(t);
        return fix(t);
    }
    node *erase(node *t) {
        if (t == nullptr) return nullptr;
        if (t->ch[0] == nullptr && t->ch[1] == nullptr) {
            delete t;
            return nullptr;
        }
        if (t->ch[0] == nullptr || t->ch[1] == nullptr) {
            node *res = t->ch[t->ch[0] == nullptr];
            delete t;
            return res;
        }
        node *res = new node(find(t->ch[1], 0)->val, t->ch[0], erase(t->ch[1], 0));
        delete t;
        return fix(update(res));
    }
    node *erase(node *t, int k) {
        if (t == nullptr) return nullptr;
        int c = count(t->ch[0]);
        if (k < c) {
            t->ch[0] = erase(t->ch[0], k);
            t = update(t);
        }
        else if (k > c) {
            t->ch[1] = erase(t->ch[1], k - (c + 1));
            t = update(t);
        }
        else {
            t = erase(t);
        }
        return fix(t);
    }
    node *find(node *t, int k) {
        if (t == nullptr) return t;
        int c = count(t->ch[0]);
        return k < c ? find(t->ch[0], k) : k == c ? t : find(t->ch[1], k - (c + 1));
    }
    int cnt(node *t, T v) {
        if (t == nullptr) return 0;
        if (t->val < v) return count(t->ch[0]) + 1 + cnt(t->ch[1], v);
        if (t->val == v) return count(t->ch[0]);
        return cnt(t->ch[0], v);
    }
    node *root;
public:
    avl_tree() : root(nullptr) {}
    int size() {
        return count(root);
    }
    void insert(T val) {
        root = insert(root, cnt(root, val), val);
    }
    void erase(int k) {
        root = erase(root, k);
    }
    T find(int k) {
        return find(root, k)->val;
    }
    int count(T val) {
        return cnt(root, val);
    }
};

const int MAX = 1 << 18;

avl_tree<int> segs[MAX << 1];

void add(int x, int y, int val, int t = 1, int lb = 0, int ub = MAX) {
    if (x < lb || ub <= x) return;
    if (val > 0) segs[t].insert(y);
    else segs[t].erase(segs[t].count(y));
    if (ub - lb == 1) return;
    int m = (lb + ub) >> 1;
    add(x, y, val, t << 1, lb, m);
    add(x, y, val, (t << 1) | 1, m, ub);
}

int find(int li, int lj, int ri, int rj, int t = 1, int lb = 0, int ub = MAX) {
    if (ri <= lb || ub <= li) return 0;
    if (li <= lb && ub <= ri) return segs[t].count(rj) - segs[t].count(lj);
    int m = (lb + ub) >> 1;
    return find(li, lj, ri, rj, t << 1, lb, m) + find(li, lj, ri, rj, (t << 1) | 1, m, ub);
}

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int n, q;
    cin >> n >> q;
    vector<int> val(n);
    for (int i = 0; i < n; i++) {
        val[i] = i;
        add(i, i, 1);
    }
    ll res = 0;
    while (q--) {
        int l, r;
        cin >> l >> r; l--, r--;
        if (l == r) {
            printf("%lld\n", res);
            continue;
        }
        if (l > r) swap(l, r);
        int lv = val[l], rv = val[r];

        res += lv < rv ? 1 : -1;

        int cnt1 = find(l + 1, 0, r, lv);
        res -= cnt1;
        res += (r - l - 1) - cnt1;

        int cnt2 = find(l + 1, 0, r, rv);
        res += cnt2;
        res -= (r - l - 1) - cnt2;

        val[l] = rv, val[r] = lv;

        add(l, lv, -1);
        add(l, val[l], 1);
        add(r, rv, -1);
        add(r, val[r], 1);

        printf("%lld\n", res);
    }
    return 0;
}

感想

想定解も悪くはないんだけど, やっぱり poly log で解かないと気が済まないよね.

HackerRank Mr. X and His Shots

問題概要

一次元上にN個の線分の集合 X とM個の線分の集合 Y がある.
点または線で重なる線分のペア (x, y) (x ∈ X, y ∈ Y) の個数を求めよ.

解法

各 x ∈ X について, Y に含まれる線分のうち x より左にあるものと, 右にあるものの個数を数えれば O(MlogM + NlogM) で解ける. (Editorial の解法)

あれれ~おかしいな~, Data Structure のタグで問題を解いてたはずなのにな~.

........

というわけで以下はデータ構造パンチ解法.

二次元平面に対する以下のクエリに変換する.

  • 線分 [l, r] の追加 → 二次元平面の座標 (l, r) の値を1加算
  • [l, r] に交差する線分の個数の取得 → 二次元平面の [0, r] x [l, ∞] の範囲にある値の和

こうすると動的な二次元セグメントツリーで O(N(log(N+M))^2 + M(log(N+M))^2) で解ける.

ソースコード

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int MAX = 1 << 18;

struct seg_node1 {
    int val;
    seg_node1 *lch, *rch;
    seg_node1(int v = 0, seg_node1* l = nullptr, seg_node1* r = nullptr) : val(v), lch(l), rch(r) {}
};

int get_val(seg_node1* t) {
    return t ? t->val : 0;
}

int get_sum(seg_node1* t, int l, int r, int lb = 0, int ub = MAX) {
    if (!t || r <= lb || ub <= l) return 0;
    if (l <= lb && ub <= r) return t->val;
    int m = (lb + ub) >> 1;
    return get_sum(t->lch, l, r, lb, m) + get_sum(t->rch, l, r, m, ub);
}

seg_node1* add(seg_node1* t, int p, int val, int lb = 0, int ub = MAX) {
    if (p < lb || ub <= p) return t;
    if (!t) t = new seg_node1;
    if (ub - lb == 1) {
        t->val += val;
        return t;
    }
    int m = (lb + ub) >> 1;
    t->lch = add(t->lch, p, val, lb, m);
    t->rch = add(t->rch, p, val, m, ub);
    t->val = get_val(t->lch) + get_val(t->rch);
    return t;
}

struct seg_node2 {
    seg_node1 *tree;
    seg_node2 *lch, *rch;
    seg_node2(seg_node1* t = nullptr, seg_node2* l = nullptr, seg_node2* r = nullptr)
        : tree(t), lch(l), rch(r) {}
};

int get_sum(seg_node2* t, int li, int lj, int ri, int rj, int lb = 0, int ub = MAX) {
    if (!t || ri <= lb || ub <= li) return 0;
    if (li <= lb && ub <= ri) return get_sum(t->tree, lj, rj);
    int m = (lb + ub) >> 1;
    return get_sum(t->lch, li, lj, ri, rj, lb, m) + get_sum(t->rch, li, lj, ri, rj, m, ub);
}

seg_node2* add(seg_node2* t, int pi, int pj, int val, int lb = 0, int ub = MAX) {
    if (pi < lb || ub <= pi) return t;
    if (!t) t = new seg_node2;
    t->tree = add(t->tree, pj, val);
    if (ub - lb == 1) return t;
    int m = (lb + ub) >> 1;
    t->lch = add(t->lch, pi, pj, val, lb, m);
    t->rch = add(t->rch, pi, pj, val, m, ub);
    return t;
}


int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int N, M;
    cin >> N >> M;
    vector<int> xs, ys;
    vector<int> A(N), B(N);
    for (int i = 0; i < N; i++) {
        cin >> A[i] >> B[i];
        xs.push_back(A[i]);
        ys.push_back(B[i]);
    }
    vector<int> C(M), D(M);
    for (int i = 0; i < M; i++) {
        cin >> C[i] >> D[i]; D[i]++;
        ys.push_back(C[i]);
        xs.push_back(D[i]);
    }
    sort(xs.begin(), xs.end());
    xs.erase(unique(xs.begin(), xs.end()), xs.end());
    sort(ys.begin(), ys.end());
    ys.erase(unique(ys.begin(), ys.end()), ys.end());
    seg_node2* tree = nullptr;
    for (int i = 0; i < N; i++) {
        A[i] = lower_bound(xs.begin(), xs.end(), A[i]) - xs.begin();
        B[i] = lower_bound(ys.begin(), ys.end(), B[i]) - ys.begin();
        tree = add(tree, A[i], B[i], 1);
    }
    ll res = 0;
    for (int i = 0; i < M; i++) {
        C[i] = lower_bound(ys.begin(), ys.end(), C[i]) - ys.begin();
        D[i] = lower_bound(xs.begin(), xs.end(), D[i]) - xs.begin();
        res += get_sum(tree, 0, C[i], D[i], MAX);
    }
    cout << res << endl;
    return 0;
}

HackerRank Cube Summation

問題概要

N x N x N の3次元配列がある. (N <= 100)
一点更新クエリと, 範囲和クエリがくるので処理せよ.

解法

Nが小さいので 3 次元の Fenwick Tree を書けば終わりです.

3次元は流石にライブラリ化してなかったのでコンテストとかで出たらタイピングバトルになって面白そう(いいえ)

ソースコード

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

template <typename T>
class fenwick_tree3 {
    const int a, b, c;
    vector<vector<vector<T>>> data;
public:
    fenwick_tree3(int a_, int b_, int c_)
        : a(a_), b(b_), c(c_), data(a, vector<vector<T>>(b, vector<T>(c))) {}

    // [0, pi] & [0, pj] & [0, pk]
    T find(int pi, int pj, int pk) const {
        T res = 0;
        for (int i = pi; i >= 0; i = (i & (i + 1)) - 1) {
            for (int j = pj; j >= 0; j = (j & (j + 1)) - 1) {
                for (int k = pk; k >= 0; k = (k & (k + 1)) - 1) {
                    res += data[i][j][k];
                }
            }
        }
        return res;
    }

    // [li, ri] & [lj, rj] & [lk, rk]
    T find(int li, int lj, int lk, int ri, int rj, int rk) const {
        --li, --lj, --lk;
        return find(ri, rj, rk) - find(ri, rj, lk) - find(ri, lj, rk) - find(li, rj, rk)
             + find(ri, lj, lk) + find(li, rj, lk) + find(li, lj, rk) - find(li, lj, lk);
    }

    void add(int pi, int pj, int pk, T val) {
        for (int i = pi; i < a; i |= i + 1) {
            for (int j = pj; j < b; j |= j + 1) {
                for (int k = pk; k < c; k |= k + 1) {
                    data[i][j][k] += val;
                }
            }
        }
    }

    void update(int pi, int pj, int pk, T val) {
        add(pi, pj, pk, val - find(pi, pj, pk, pi, pj, pk));
    }
};

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int T;
    cin >> T;
    while (T--) {
        int N, M;
        cin >> N >> M;
        fenwick_tree3<ll> ft(N, N, N);
        while (M--) {
            string com;
            cin >> com;
            if (com == "UPDATE") {
                int x, y, z, w;
                cin >> x >> y >> z >> w;
                x--, y--, z--;
                ft.update(x, y, z, w);
            }
            else {
                int x1, y1, z1, x2, y2, z2;
                cin >> x1 >> y1 >> z1 >> x2 >> y2 >> z2;
                x1--, y1--, z1--, x2--, y2--, z2--;
                cout << ft.find(x1, y1, z1, x2, y2, z2) << endl;
            }
        }
    }
    return 0;
}