kazuma8128’s blog

競プロの面白い問題を解きます

CodeChef GPD Gotham PD

問題概要

頂点数 N の根付き木が与えられる. 各頂点には正整数のキーが付いている.
以下の二種類の Q 個のクエリをオンラインで処理せよ.

クエリ1:頂点 v の子に, キー k の付いた新しい頂点 u を増やす
クエリ2:頂点 v から根までのパスに含まれる頂点に付いたキーの中で, k でXORしたときの最大の値と最小の値を求める


制約:1 ≤ N ≤ 10^5, 1 ≤ Q ≤ 2*10^5, 1 ≤ (各キー) ≤ 2^31-1

解法

まず, トライ木を使うと, 要素の挿入 (insert) , 要素の削除 (erase) , ある値でXORしたときの最大値/最小値の取得が O(ビット数) でできます. 詳しくは以下のページなどを参照
www.geeksforgeeks.org

これを使って根からDFSして, ある頂点に行くときはそのキーを insert して, 抜けるときに erase する, という感じでやってやるとオフラインクエリ(クエリを先読みできる場合)なら解けます.

しかし, 今回の問題は謎の暗号化で先読みができないのでこれでは解けません.
というわけで, トライ木を完全永続化しましょう!(唐突)

オフラインと同様に DFS しながら, 各頂点において, 親頂点のトライ木(の根)にその頂点のキーを insert した新しいトライ木(の根)を保存していきます.

トライ木の永続化の方法は非常に簡単で, insert をするときにノードの情報をそのまま書き換えるのではなく, 新しいノードを作ってからそれを書き換える, という風にするだけです.

トライ木やセグメントツリーのような木構造を永続化するときは, 元の実装で変更クエリを再帰で書いて, さらに返り値にノードのポインタを返すような実装にしておくと修正が少なくて済むので楽です.

ソースコード

https://www.codechef.com/viewsolution/17455234

#include <bits/stdc++.h>
using namespace std;
using graph = vector<vector<int>>;

int compress(int key) {
	static map<int, int> dic;
	static int cnt = 0;
	if (dic.count(key)) return dic[key];
	return dic[key] = cnt++;
}

const int BIT_SIZE = 31;

struct t_node {
	int size;
	t_node *lch, *rch;
	t_node(int s = 1, t_node *l = nullptr, t_node *r = nullptr)
		: size(s), lch(l), rch(r) {}
};

t_node* new_node(t_node *x) {
	return x ? new t_node(x->size + 1, x->lch, x->rch) : new t_node();
}

t_node* insert(t_node *x, int val, int h = BIT_SIZE - 1) {
	if (h < 0) return new t_node();
	x = new_node(x);
	if ((val >> h) & 1) {
		x->rch = insert(x->rch, val, h - 1);
	}
	else {
		x->lch = insert(x->lch, val, h - 1);
	}
	return x;
}

int get_min(t_node *x, int val, int h = BIT_SIZE - 1) {
	if (h < 0) return 0;
	if ((val >> h) & 1) {
		if (x->rch) {
			return get_min(x->rch, val, h - 1) | (1 << h);
		}
		else {
			return get_min(x->lch, val, h - 1);
		}
	}
	else {
		if (x->lch) {
			return get_min(x->lch, val, h - 1);
		}
		else {
			return get_min(x->rch, val, h - 1) | (1 << h);
		}
	}
}

int get_max(t_node *x, int val, int h = BIT_SIZE - 1) {
	if (h < 0) return 0;
	if ((val >> h) & 1) {
		if (x->lch) {
			return get_max(x->lch, val, h - 1);
		}
		else {
			return get_max(x->rch, val, h - 1) | (1 << h);
		}
	}
	else {
		if (x->rch) {
			return get_max(x->rch, val, h - 1) | (1 << h);
		}
		else {
			return get_max(x->lch, val, h - 1);
		}
	}
}

void dfs(int v, int prev, const graph& G, const vector<int>& keys, vector<t_node*>& rts) {
	rts[v] = insert(prev == -1 ? nullptr : rts[prev], keys[v]);
	for (auto to : G[v]) {
		dfs(to, v, G, keys, rts);
	}
}

int main()
{
	ios::sync_with_stdio(false), cin.tie(0);
	int N, Q, R, key;
	cin >> N >> Q >> R >> key;
	graph G(N);
	vector<int> keys(N);
	R = compress(R);
	keys[R] = key;
	for (int i = 0; i < N - 1; i++) {
		int u, v, k;
		cin >> u >> v >> k;
		u = compress(u); v = compress(v);
		G[v].push_back(u);
		keys[u] = k;
	}
	vector<t_node*> rts(N + Q);
	dfs(R, -1, G, keys, rts);
	int last = 0;
	while (Q--) {
		int t, u, v, k;
		cin >> t;
		t ^= last;
		if (t == 0) {
			cin >> v >> u >> k; v ^= last; u ^= last; k ^= last;
			v = compress(v); u = compress(u);
			rts[u] = insert(rts[v], k);
		}
		else {
			cin >> v >> k; v ^= last; k ^= last;
			v = compress(v);
			int mi = get_min(rts[v], k);
			int ma = get_max(rts[v], k);
			cout << (mi ^ k) << ' ' << (ma ^ k) << endl;
			last = mi ^ ma;
		}
	}
	return 0;
}

蛇足

なぜか血迷って C言語でも書いてみた.
https://www.codechef.com/viewsolution/17444288