HackerRank Mr. X and His Shots
問題概要
一次元上にN個の線分の集合 X とM個の線分の集合 Y がある.
点または線で重なる線分のペア (x, y) (x ∈ X, y ∈ Y) の個数を求めよ.
解法
各 x ∈ X について, Y に含まれる線分のうち x より左にあるものと, 右にあるものの個数を数えれば O(MlogM + NlogM) で解ける. (Editorial の解法)
あれれ~おかしいな~, Data Structure のタグで問題を解いてたはずなのにな~.
........
というわけで以下はデータ構造パンチ解法.
二次元平面に対する以下のクエリに変換する.
- 線分 [l, r] の追加 → 二次元平面の座標 (l, r) の値を1加算
- [l, r] に交差する線分の個数の取得 → 二次元平面の [0, r] x [l, ∞] の範囲にある値の和
こうすると動的な二次元セグメントツリーで O(N(log(N+M))^2 + M(log(N+M))^2) で解ける.
ソースコード
#include <bits/stdc++.h> using namespace std; using ll = long long; const int MAX = 1 << 18; struct seg_node1 { int val; seg_node1 *lch, *rch; seg_node1(int v = 0, seg_node1* l = nullptr, seg_node1* r = nullptr) : val(v), lch(l), rch(r) {} }; int get_val(seg_node1* t) { return t ? t->val : 0; } int get_sum(seg_node1* t, int l, int r, int lb = 0, int ub = MAX) { if (!t || r <= lb || ub <= l) return 0; if (l <= lb && ub <= r) return t->val; int m = (lb + ub) >> 1; return get_sum(t->lch, l, r, lb, m) + get_sum(t->rch, l, r, m, ub); } seg_node1* add(seg_node1* t, int p, int val, int lb = 0, int ub = MAX) { if (p < lb || ub <= p) return t; if (!t) t = new seg_node1; if (ub - lb == 1) { t->val += val; return t; } int m = (lb + ub) >> 1; t->lch = add(t->lch, p, val, lb, m); t->rch = add(t->rch, p, val, m, ub); t->val = get_val(t->lch) + get_val(t->rch); return t; } struct seg_node2 { seg_node1 *tree; seg_node2 *lch, *rch; seg_node2(seg_node1* t = nullptr, seg_node2* l = nullptr, seg_node2* r = nullptr) : tree(t), lch(l), rch(r) {} }; int get_sum(seg_node2* t, int li, int lj, int ri, int rj, int lb = 0, int ub = MAX) { if (!t || ri <= lb || ub <= li) return 0; if (li <= lb && ub <= ri) return get_sum(t->tree, lj, rj); int m = (lb + ub) >> 1; return get_sum(t->lch, li, lj, ri, rj, lb, m) + get_sum(t->rch, li, lj, ri, rj, m, ub); } seg_node2* add(seg_node2* t, int pi, int pj, int val, int lb = 0, int ub = MAX) { if (pi < lb || ub <= pi) return t; if (!t) t = new seg_node2; t->tree = add(t->tree, pj, val); if (ub - lb == 1) return t; int m = (lb + ub) >> 1; t->lch = add(t->lch, pi, pj, val, lb, m); t->rch = add(t->rch, pi, pj, val, m, ub); return t; } int main() { ios::sync_with_stdio(false), cin.tie(0); int N, M; cin >> N >> M; vector<int> xs, ys; vector<int> A(N), B(N); for (int i = 0; i < N; i++) { cin >> A[i] >> B[i]; xs.push_back(A[i]); ys.push_back(B[i]); } vector<int> C(M), D(M); for (int i = 0; i < M; i++) { cin >> C[i] >> D[i]; D[i]++; ys.push_back(C[i]); xs.push_back(D[i]); } sort(xs.begin(), xs.end()); xs.erase(unique(xs.begin(), xs.end()), xs.end()); sort(ys.begin(), ys.end()); ys.erase(unique(ys.begin(), ys.end()), ys.end()); seg_node2* tree = nullptr; for (int i = 0; i < N; i++) { A[i] = lower_bound(xs.begin(), xs.end(), A[i]) - xs.begin(); B[i] = lower_bound(ys.begin(), ys.end(), B[i]) - ys.begin(); tree = add(tree, A[i], B[i], 1); } ll res = 0; for (int i = 0; i < M; i++) { C[i] = lower_bound(ys.begin(), ys.end(), C[i]) - ys.begin(); D[i] = lower_bound(xs.begin(), xs.end(), D[i]) - xs.begin(); res += get_sum(tree, 0, C[i], D[i], MAX); } cout << res << endl; return 0; }