kazuma8128’s blog

競プロの面白い問題を解きます

CodeChef FIBTREE Fibonacci Numbers on Tree

問題概要

サイズNの木がある. 各頂点は値を持っている(初期値0). Q個の以下のクエリをオンラインで処理せよ.

  • クエリ1:頂点 x から頂点 y へのパス上の各頂点に順番に fib(0), fib(1), ... , fib(k) のフィボナッチ数列を足す
  • クエリ2:頂点 x を根としたときに, 頂点 y の部分木に含まれる頂点の値の和 mod 1e9 + 9 を出力
  • クエリ3:頂点 x から頂点 y へのパス上の頂点の値の和 mod 1e9 + 9 を出力
  • クエリ4:木全体を x 回目のクエリの直後の状態に戻す

解法

面白いけど, 流石に引くレベルの激ヤバ問題.
HL分解でできる色んなしんどい要素をとにかく詰め込みまくったみたいな感じ.

解法は一言でいうと, 部分木クエリ対応のHL分解をしてから完全永続な遅延セグメントツリーに入れてやる.

クエリ1

HL分解すれば区間フィボナッチ数列を足すクエリに変換できる.
これは遅延セグメントツリーの各ノードに部分木の和とそのノードに加えたい列の最初の2項を持たせてやれば, そのノードに足す値と下に伝搬する2項が求まる.
このときに最初の2項から長さLの列の和を求めたり, 最初の2項からK個進んだ位置の2項を求める必要があるが, これは最初にN個程度の行列を前処理で求めておいてそれを掛けてやれば重めの O(1) で可能.
よって各クエリ O(logN) できる.

ちなみにこの数列バージョンだけの問題が以下. (これだけでもかなり大変)
http://codeforces.com/contest/446/problem/C

これにHL分解の log が付いてほんとに間に合うの?という気持ちにはなるけど, HL分解の log は速いし TL 5s だからなんとかなる.

さらに, パスクエリなせいで左右の方向も考えないといけないので, 右向きに足される列と左向きに足される列の2個の遅延セグメントツリーを持たないといけなくて大変.

クエリ2

ここでも根を固定しない嫌がらせ.
実は根が固定でなくても結局は根を 1 (1-indexed)にしてHL分解して問題ない.

x = y なら (全体の和) を求めればよい.
x != y で x が y の部分木に含まれないなら普通に (y の部分木の和) を求めればよい.
x != y で x が y の部分木に含まれるなら, y の x 方向への子頂点を z とするとき, (全体の和) - (z の部分木の和) を求めればよい.
この部分木に含まれるかどうかの判定は部分木クエリ対応HL分解(オイラーツアー)なら O(1) で判定可能.
z はHL分解を利用しても求まるし, ダブリングを使っても高々 O(logN) で求まる.
部分木クエリも部分木クエリ対応HL分解ならできる.

部分木クエリ対応HL分解については1つ前の記事を参照.

クエリ3

これはパス上の和を求めるだけなのでHL分解で可能.

クエリ4

ここで唐突な完全永続要素.
まあ先ほどの遅延セグメントツリーを永続に実装すればできる.
一応空間計算量に log が付くけど問題はない.

ソースコード

https://www.codechef.com/viewsolution/19244642

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

template<int MOD>
struct mod_int {
    static const int Mod = MOD;
    unsigned x;
    mod_int() : x(0) { }
    mod_int(int sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; }
    mod_int(long long sig) { int sigt = sig % MOD; if (sigt < 0) sigt += MOD; x = sigt; }
    int get() const { return (int)x; }

    mod_int &operator+=(mod_int that) { if ((x += that.x) >= MOD) x -= MOD; return *this; }
    mod_int &operator-=(mod_int that) { if ((x += MOD - that.x) >= MOD) x -= MOD; return *this; }
    mod_int &operator*=(mod_int that) { x = (unsigned long long)x * that.x % MOD; return *this; }
    mod_int &operator/=(mod_int that) { return *this *= that.inverse(); }

    mod_int operator+(mod_int that) const { return mod_int(*this) += that; }
    mod_int operator-(mod_int that) const { return mod_int(*this) -= that; }
    mod_int operator*(mod_int that) const { return mod_int(*this) *= that; }
    mod_int operator/(mod_int that) const { return mod_int(*this) /= that; }

    mod_int inverse() const {
        long long a = x, b = MOD, u = 1, v = 0;
        while (b) {
            long long t = a / b;
            a -= t * b; swap(a, b);
            u -= t * v; swap(u, v);
        }
        return mod_int(u);
    }
    bool operator==(mod_int that) const { return x == that.x; }
};

const int mod = 1e9 + 9;

using mint = mod_int<mod>;
using pmint = pair<mint, mint>;

using vi = vector<int>;
using vvi = vector<vector<int>>;

vvi prod(const vvi& a, const vvi& b) {
    assert(a.front().size() == b.size());
    int x = a.size(), y = a.front().size(), z = b.front().size();
    vvi c(x, vi(z));
    for (int i = 0; i < x; i++) {
        for (int j = 0; j < y; j++) {
            for (int k = 0; k < z; k++) {
                c[i][k] = (c[i][k] + (int)((ll)a[i][j] * b[j][k] % mod)) % mod;
            }
        }
    }
    return c;
}

const int MAX = 1 << 17;

mint nex_f[MAX + 1][2][2];
mint sum_f[MAX + 1][2];

void init_fib() {
    vvi nex = { { 0, 1 },{ 1, 1 } };
    vvi sum = { { 0, 1, 0 },{ 1, 1, 0 },{ 1, 0, 1 } };

    vvi tmp1 = { { 1, 0 },{ 0, 1 } };
    nex_f[0][0][0] = 1;
    nex_f[0][0][1] = 0;
    nex_f[0][1][0] = 0;
    nex_f[0][1][1] = 1;
    for (int i = 1; i <= MAX; i++) {
        tmp1 = prod(nex, tmp1);
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                nex_f[i][j][k] = tmp1[j][k];
            }
        }

    }
    vvi tmp2 = { { 1, 0, 0 },{ 0, 1, 0 },{ 0, 0, 1 } };
    sum_f[0][0] = sum_f[0][1] = 0;
    for (int i = 1; i <= MAX; i++) {
        tmp2 = prod(sum, tmp2);
        sum_f[i][0] = tmp2[2][0];
        sum_f[i][1] = tmp2[2][1];
    }
}

pmint next_fib(const pmint& a01, int x) {
    return make_pair(nex_f[x][0][0] * a01.first + nex_f[x][0][1] * a01.second
                , nex_f[x][1][0] * a01.first + nex_f[x][1][1] * a01.second);
}

mint sum_fib(const pmint& a01, int x) {
    return sum_f[x][0] * a01.first + sum_f[x][1] * a01.second;
}

const pmint id(0, 0);

void operator+=(pmint& l, const pmint& r) {
    l.first += r.first;
    l.second += r.second;
}

struct seg_node {
    mint sum;
    pmint lazy;
    seg_node *lch, *rch;
    seg_node(mint s = 0, const pmint& la = id, seg_node* l = nullptr, seg_node* r = nullptr)
        : sum(s), lazy(la), lch(l), rch(r) {}
};

seg_node* add(seg_node* t, int l, int r, const pmint& val, int lb = 0, int ub = MAX) {
    if (r <= lb || ub <= l) return t;
    t = t ? new seg_node(t->sum, t->lazy, t->lch, t->rch) : new seg_node;
    if (l <= lb && ub <= r) {
        auto p = next_fib(val, lb - l);
        t->sum += sum_fib(p, ub - lb);
        t->lazy += p;
        return t;
    }
    int m = (lb + ub) >> 1;
    if (t->lazy != id) {
        t->lch = t->lch ? new seg_node(t->lch->sum, t->lch->lazy, t->lch->lch, t->lch->rch) : new seg_node;
        t->lch->sum += sum_fib(t->lazy, m - lb);
        t->lch->lazy += t->lazy;
        auto p = next_fib(t->lazy, m - lb);
        t->rch = t->rch ? new seg_node(t->rch->sum, t->rch->lazy, t->rch->lch, t->rch->rch) : new seg_node;
        t->rch->sum += sum_fib(p, ub - m);
        t->rch->lazy += p;
        t->lazy = id;
    }
    t->lch = add(t->lch, l, r, val, lb, m);
    t->rch = add(t->rch, l, r, val, m, ub);
    t->sum = (t->lch ? t->lch->sum : 0) + (t->rch ? t->rch->sum : 0);
    return t;
}

mint find(seg_node* t, int l, int r, pmint lazy = id, int lb = 0, int ub = MAX) {
    if (ub <= l || r <= lb) return 0;
    if (!t) {
        l = max(l, lb);
        r = min(r, ub);
        return sum_fib(next_fib(lazy, l - lb), r - l);
    }
    if (l <= lb && ub <= r) return t->sum + sum_fib(lazy, ub - lb);
    lazy += t->lazy;
    int m = (lb + ub) >> 1;
    return find(t->lch, l, r, lazy, lb, m) + find(t->rch, l, r, next_fib(lazy, m - lb), m, ub);
}

class heavy_light_decomposition {
    const int n;
    vector<vector<int>> g;
    vector<int> par, heavy, head, in, out;
    void dfs1(int rt) {
        vector<int> size(n, 1);
        vector<size_t> iter(n);
        vector<pair<int, int>> stp;
        stp.reserve(n);
        stp.emplace_back(rt, -1);
        while (!stp.empty()) {
            int v = stp.back().first;
            if (iter[v] < g[v].size()) {
                if (g[v][iter[v]] != stp.back().second) {
                    stp.emplace_back(g[v][iter[v]], v);
                }
                ++iter[v];
                continue;
            }
            par[v] = stp.back().second;
            for (auto& u : g[v]) if (u == par[v]) {
                if (u != g[v].back()) swap(u, g[v].back());
                g[v].pop_back();
                break;
            }
            for (auto& u : g[v]) {
                size[v] += size[u];
                if (size[u] > size[g[v].front()]) swap(u, g[v].front());
            }
            heavy[v] = g[v].empty() ? -1 : g[v].front();
            stp.pop_back();
        }
    }
    void dfs2(int rt) {
        int it = 0;
        vector<size_t> iter(n);
        vector<int> st; st.reserve(n);
        st.push_back(rt);
        while (!st.empty()) {
            int v = st.back();
            if (!iter[v]) in[v] = it++;
            if (iter[v] < g[v].size()) {
                int u = g[v][iter[v]];
                head[u] = iter[v] ? u : head[v];
                st.push_back(u);
                ++iter[v];
                continue;
            }
            out[v] = it;
            st.pop_back();
        }
    }
public:
    heavy_light_decomposition(int n_)
        : n(n_), g(n), par(n), heavy(n), head(n), in(n), out(n) {}
    void add_edge(int u, int v) {
        g[u].push_back(v);
        g[v].push_back(u);
    }
    void build(int rt = 0) {
        dfs1(rt);
        head[rt] = rt;
        dfs2(rt);
    }
    int get_id(int v) {
        return in[v];
    }
    int is_include(int u, int v) {
        return in[u] <= in[v] && out[v] <= out[u];
    }
    int get_child(int u, int v) {
        while (head[v] != head[u]) {
            if (par[head[v]] == u) return head[v];
            v = par[head[v]];
        }
        return heavy[u];
    }
    void path_query(int u, int v, function<void(int, int)> f) {
        while (true) {
            if (in[u] > in[v]) swap(u, v);
            f(max(in[head[v]], in[u]), in[v] + 1);
            if (head[u] == head[v]) return;
            v = par[head[v]];
        }
    }
    void path_query(int u, int v, function<void(int, int, bool)> f) {
        vector<pair<int, int>> lc, rc;
        while (true) {
            if (in[u] <= in[v]) {
                rc.emplace_back(max(in[head[v]], in[u]), in[v] + 1);
                if (head[u] == head[v]) break;
                v = par[head[v]];
            }
            else {
                lc.emplace_back(max(in[head[u]], in[v]), in[u] + 1);
                if (head[u] == head[v]) break;
                u = par[head[u]];
            }
        }
        reverse(rc.begin(), rc.end());
        for (auto& p : lc) f(p.first, p.second, true);
        for (auto& p : rc) f(p.first, p.second, false);
    }
    void subtree_query(int v, function<void(int, int)> f) {
        f(in[v], out[v]);
    }
};

int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    init_fib();
    int N, M;
    cin >> N >> M;
    heavy_light_decomposition hld(N);
    for (int i = 0; i < N - 1; i++) {
        int u, v;
        cin >> u >> v; --u, --v;
        hld.add_edge(u, v);
    }
    hld.build();
    vector<seg_node*> lrt(M), rrt(M);
    seg_node *lr = nullptr, *rr = nullptr;
    int lastans = 0;
    for (int qn = 0; qn < M; qn++) {
        lrt[qn] = lr; rrt[qn] = rr;
        string com;
        int x;
        cin >> com >> x; x ^= lastans;
        if (com == "A") {
            int y;
            cin >> y; --x, --y;
            pmint val(1, 1);
            hld.path_query(x, y, [&](int l, int r, bool rev) {
                if (!rev) {
                    lr = add(lr, l, r, val);
                }
                else {
                    rr = add(rr, N - r, N - l, val);
                }
                val = next_fib(val, r - l);
            });
        }
        else if (com == "QS") {
            int y;
            cin >> y; --x, --y;
            mint res = 0;
            if (x == y) {
                res = find(lr, 0, N) + find(rr, 0, N);
            }
            else if (!hld.is_include(y, x)) {
                hld.subtree_query(y, [&](int l, int r) {
                    res += find(lr, l, r) + find(rr, N - r, N - l);
                });
            }
            else {
                res = find(lr, 0, N) + find(rr, 0, N);
                int z = hld.get_child(y, x);
                hld.subtree_query(z, [&](int l, int r) {
                    res -= find(lr, l, r) + find(rr, N - r, N - l);
                });
            }
            printf("%d\n", lastans = res.get());
        }
        else if (com == "QC") {
            int y;
            cin >> y; --x, --y;
            mint res = 0;
            hld.path_query(x, y, [&](int l, int r) {
                res += find(lr, l, r) + find(rr, N - r, N - l);
            });
            printf("%d\n", lastans = res.get());
        }
        else {
            lr = lrt[x];
            rr = rrt[x];
        }
    }
    return 0;
}