kazuma8128’s blog

競プロの面白い問題を解きます

HackerRank Mr. X and His Shots

問題概要

一次元上にN個の線分の集合 X とM個の線分の集合 Y がある.
点または線で重なる線分のペア (x, y) (x ∈ X, y ∈ Y) の個数を求めよ.

解法

各 x ∈ X について, Y に含まれる線分のうち x より左にあるものと, 右にあるものの個数を数えれば O(MlogM + NlogM) で解ける. (Editorial の解法)

あれれ~おかしいな~, Data Structure のタグで問題を解いてたはずなのにな~.

........

というわけで以下はデータ構造パンチ解法.

二次元平面に対する以下のクエリに変換する.

  • 線分 [l, r] の追加 → 二次元平面の座標 (l, r) の値を1加算
  • [l, r] に交差する線分の個数の取得 → 二次元平面の [0, r] x [l, ∞] の範囲にある値の和

こうすると動的な二次元セグメントツリーで O(N(log(N+M))^2 + M(log(N+M))^2) で解ける.

ソースコード

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

const int MAX = 1 << 18;

struct seg_node1 {
    int val;
    seg_node1 *lch, *rch;
    seg_node1(int v = 0, seg_node1* l = nullptr, seg_node1* r = nullptr) : val(v), lch(l), rch(r) {}
};

int get_val(seg_node1* t) {
    return t ? t->val : 0;
}

int get_sum(seg_node1* t, int l, int r, int lb = 0, int ub = MAX) {
    if (!t || r <= lb || ub <= l) return 0;
    if (l <= lb && ub <= r) return t->val;
    int m = (lb + ub) >> 1;
    return get_sum(t->lch, l, r, lb, m) + get_sum(t->rch, l, r, m, ub);
}

seg_node1* add(seg_node1* t, int p, int val, int lb = 0, int ub = MAX) {
    if (p < lb || ub <= p) return t;
    if (!t) t = new seg_node1;
    if (ub - lb == 1) {
        t->val += val;
        return t;
    }
    int m = (lb + ub) >> 1;
    t->lch = add(t->lch, p, val, lb, m);
    t->rch = add(t->rch, p, val, m, ub);
    t->val = get_val(t->lch) + get_val(t->rch);
    return t;
}

struct seg_node2 {
    seg_node1 *tree;
    seg_node2 *lch, *rch;
    seg_node2(seg_node1* t = nullptr, seg_node2* l = nullptr, seg_node2* r = nullptr)
        : tree(t), lch(l), rch(r) {}
};

int get_sum(seg_node2* t, int li, int lj, int ri, int rj, int lb = 0, int ub = MAX) {
    if (!t || ri <= lb || ub <= li) return 0;
    if (li <= lb && ub <= ri) return get_sum(t->tree, lj, rj);
    int m = (lb + ub) >> 1;
    return get_sum(t->lch, li, lj, ri, rj, lb, m) + get_sum(t->rch, li, lj, ri, rj, m, ub);
}

seg_node2* add(seg_node2* t, int pi, int pj, int val, int lb = 0, int ub = MAX) {
    if (pi < lb || ub <= pi) return t;
    if (!t) t = new seg_node2;
    t->tree = add(t->tree, pj, val);
    if (ub - lb == 1) return t;
    int m = (lb + ub) >> 1;
    t->lch = add(t->lch, pi, pj, val, lb, m);
    t->rch = add(t->rch, pi, pj, val, m, ub);
    return t;
}


int main()
{
    ios::sync_with_stdio(false), cin.tie(0);
    int N, M;
    cin >> N >> M;
    vector<int> xs, ys;
    vector<int> A(N), B(N);
    for (int i = 0; i < N; i++) {
        cin >> A[i] >> B[i];
        xs.push_back(A[i]);
        ys.push_back(B[i]);
    }
    vector<int> C(M), D(M);
    for (int i = 0; i < M; i++) {
        cin >> C[i] >> D[i]; D[i]++;
        ys.push_back(C[i]);
        xs.push_back(D[i]);
    }
    sort(xs.begin(), xs.end());
    xs.erase(unique(xs.begin(), xs.end()), xs.end());
    sort(ys.begin(), ys.end());
    ys.erase(unique(ys.begin(), ys.end()), ys.end());
    seg_node2* tree = nullptr;
    for (int i = 0; i < N; i++) {
        A[i] = lower_bound(xs.begin(), xs.end(), A[i]) - xs.begin();
        B[i] = lower_bound(ys.begin(), ys.end(), B[i]) - ys.begin();
        tree = add(tree, A[i], B[i], 1);
    }
    ll res = 0;
    for (int i = 0; i < M; i++) {
        C[i] = lower_bound(ys.begin(), ys.end(), C[i]) - ys.begin();
        D[i] = lower_bound(xs.begin(), xs.end(), D[i]) - xs.begin();
        res += get_sum(tree, 0, C[i], D[i], MAX);
    }
    cout << res << endl;
    return 0;
}