遅延評価 Segment Tree の一般的な実装方法
この間, AOJ にある Do use segment tree とかいう問題を解きなおしたときに今まで持っていた遅延評価 Segment Tree で無理やり書こうとしたら,無限にバグらせて結局一から書き直す羽目になったので,そのとき考えたことについて書きます.
間違っているところとかがあれば指摘してもらえると助かります.
そもそも遅延評価 Segment Tree とは
皆さんおなじみの遅延評価 Segment Tree ですが,こいつはいったい何物なのかと考えると,データ列に対して以下の操作ができるデータ構造と考えられます.
- 操作① : ある区間に対して,含まれる全ての値をある演算で順番に連結していった結果の値を求める.
- 操作② : ある区間に対して,含まれる全ての値をある値と結合させた結果に置き換える.
ただしデータ列の値と操作②で作用させる値は両方ともモノイドである(結合律を満たす演算と単位元が存在する)必要があります.
このとき,この二つの値が同じ種類の型である必要はないということに気付きました.(今まで同じ型で実装していたのでツラかった)
そう考えると,必要な型と演算は以下のようになると思います.
- 型(モノイド)① : データ列の各値の型
- 型(モノイド)② : 操作②で作用させる値の型
- 演算① : 型①同士の演算で,データ列を連結させる演算
- 演算② : 操作②で,型①に対して型②を作用させるための演算
- 演算③ : 型②同士の演算(実はこれも無いとダメ)
これを踏まえたうえでいい感じに実装します.
実装例(モノイドと演算)
各モノイドに対して単位元の定義も必要なことに注意してください.
これは int 型でのRMQと区間代入クエリがくる場合の例です.
class RURM { public: using t1 = int; using t2 = int; static t1 id1() { return INT_MAX; } static t2 id2() { return -1; } static t1 op1(const t1& l, const t1& r) { return min(l, r); } static t1 op2(const t1& l, const t2& r) { return r != id2() ? r : l; } static t2 op3(const t2& l, const t2& r) { return r != id2() ? r : l; } };
実装例(本体)
template <typename M> class LazySegmentTree { using T1 = typename M::t1; using T2 = typename M::t2; private: const int n; vector<T1> data; vector<T2> lazy; private: int size(int n) { int res = 1; while (res < n) res <<= 1; return res; } void push(int node) { if (lazy[node] == M::id2()) return; if (node < n) { lazy[node * 2] = M::op3(lazy[node * 2], lazy[node]); lazy[node * 2 + 1] = M::op3(lazy[node * 2 + 1], lazy[node]); } data[node] = M::op2(data[node], lazy[node]); lazy[node] = M::id2(); } T1 sub(int l, int r, int node, int lb, int ub) { if (ub <= l || r <= lb) return M::id1(); if (l <= lb && ub <= r) return M::op2(data[node], lazy[node]); push(node); int c = (lb + ub) / 2; return M::op1(sub(l, r, node * 2, lb, c), sub(l, r, node * 2 + 1, c, ub)); } void suc(int l, int r, int node, int lb, int ub, T2 val) { if (ub <= l || r <= lb) return; if (l <= lb && ub <= r) { lazy[node] = M::op3(lazy[node], val); return; } push(node); int c = (lb + ub) / 2; suc(l, r, node * 2, lb, c, val); suc(l, r, node * 2 + 1, c, ub, val); data[node] = M::op1(M::op2(data[node * 2], lazy[node * 2]) , M::op2(data[node * 2 + 1], lazy[node * 2 + 1])); } public: LazySegmentTree(int n_) : n(size(n_)), data(n * 2, M::id1()), lazy(n * 2, M::id2()) {} LazySegmentTree(int n_, T1 v1) : n(size(n_)), data(n * 2, v1), lazy(n * 2, M::id2()) {} LazySegmentTree(const vector<T1>& data_) : n(size(data_.size())), data(n * 2, M::id1()), lazy(n * 2, M::id2()) { init(data_); } void init() { for (int i = n - 1; i >= 1; i--) data[i] = M::op1(data[i * 2], data[i * 2 + 1]); } void init(const vector<T1>& data_) { for (int i = 0; i < (int)data_.size(); i++) data[i + n] = data_[i]; init(); } T1 find(int l, int r) { return sub(l, r, 1, 0, n); } void update(int l, int r, T2 val) { suc(l, r, 1, 0, n, val); } };
RMQ and RUQ | Aizu Online Judge で verify 済みです.(ソースコード)
Starry Sky(星空) も遅延伝搬しない StarrySkyTree を使わずに一応通りました.(ソースコード)
本来の目的である Do use Segment tree も通りました.(ソースコード)
おまけ
再帰って遅いらしいし*1,非再帰で書いたら速いんじゃないの?と思ったので書いてみました.
非再帰での実装例(本体)
template <typename M> class LazySegmentTree { using T1 = typename M::t1; using T2 = typename M::t2; private: const int h, n; vector<T1> data; vector<T2> lazy; private: void push(int node) { if (lazy[node] == M::id2()) return; if (node < n) { lazy[node * 2] = M::op3(lazy[node * 2], lazy[node]); lazy[node * 2 + 1] = M::op3(lazy[node * 2 + 1], lazy[node]); } data[node] = M::op2(data[node], lazy[node]); lazy[node] = M::id2(); } void update(int node) { data[node] = M::op1(M::op2(data[node * 2], lazy[node * 2]) , M::op2(data[node * 2 + 1], lazy[node * 2 + 1])); } public: LazySegmentTree(int n_) : h(ceil(log2(n_))), n(1 << h), data(n * 2, M::id1()), lazy(n * 2, M::id2()) {} LazySegmentTree(int n_, T1 v1) : h(ceil(log2(n_))), n(1 << h), data(n * 2, v1), lazy(n * 2, M::id2()) {} LazySegmentTree(const vector<T1>& data_) : h(ceil(log2(data_.size()))), n(1 << h), data(n * 2, M::id1()), lazy(n * 2, M::id2()) { init(data_); } void init() { for (int i = n - 1; i >= 1; i--) data[i] = M::op1(data[i * 2], data[i * 2 + 1]); } void init(const vector<T1>& data_) { for (int i = 0; i < (int)data_.size(); i++) data[i + n] = data_[i]; init(); } void update(int l, int r, T2 val) { l += n, r += n - 1; for (int i = h; i > 0; i--) push(l >> i), push(r >> i); int tl = l, tr = r; r++; while (l < r) { if (l & 1) lazy[l] = M::op3(lazy[l], val), l++; if (r & 1) r--, lazy[r] = M::op3(lazy[r], val); l >>= 1; r >>= 1; } while (tl >>= 1, tr >>= 1, tl) { if (lazy[tl] == M::id2()) update(tl); if (lazy[tr] == M::id2()) update(tr); } } T1 find(int l, int r) { l += n, r += n - 1; for (int i = h; i > 0; i--) push(l >> i), push(r >> i); r++; T1 res1 = M::id1(), res2 = M::id1(); while (l < r) { if (l & 1) res1 = M::op1(res1, M::op2(data[l], lazy[l])), l++; if (r & 1) r--, res2 = M::op1(M::op2(data[r], lazy[r]), res2); l >>= 1; r >>= 1; } return M::op1(res1, res2); } };
これも上記の問題で verify 済みです.
再帰と非再帰の速度比較
手元の環境(GNU C++ 4.9.3 で最適化オプションは -O2)でランダムなクエリを生成して速度を計測したところ以下のようになりました.
クエリの種類はさっきのと同じで int 型でのRMQと区間代入クエリです.
1e5回の find 操作 | 1e6回の find 操作 | 1e7回の find 操作 | 1e5回の update 操作 | 1e6回の update 操作 | 1e7回の update 操作 | |
---|---|---|---|---|---|---|
再帰 | 36 ms | 566 ms | 9715 ms | 36 ms | 619 ms | 10274 ms |
非再帰 | 6 ms | 85 ms | 2974 ms | 37 ms | 494 ms | 8082 ms |
find 操作はかなり速くなってますが,update 操作は思っていたほど速くなりませんでした.
*1:要出典