自学内容网 自学内容网

第十四届蓝桥杯省赛 Python B 组 I 题——异或和(AC)

1. 异或和

  • 前置知识点:树状数组,线段树,DFS 序。

1. 问题描述

给一棵含有 n n n 个结点的有根树,根结点为 1 1 1,编号为 i i i 的点有点权 a i a_i ai i ∈ [ 1 , n ] i \in [1,n] i[1,n])。现在有两种操作,格式如下:

  • 1   x   y 1\ x\ y 1 x y:该操作表示将点 x x x 的点权改为 y y y
  • 2   x 2\ x 2 x:该操作表示查询以结点 x x x 为根的子树内的所有点的点权的异或和。

现有长度为 m m m 的操作序列,请对于每个第二类操作给出正确的结果。

2 .输入格式

输入的第一行包含两个正整数 n , m n,m n,m,用一个空格分隔。

第二行包含 n n n 个整数 a 1 , a 2 , … , a n a_1,a_2,\ldots,a_n a1,a2,,an,相邻整数之间使用一个空格分隔。

接下来 n − 1 n-1 n1 行,每行包含两个正整数 u i , v i u_i,v_i ui,vi,表示结点 u i u_i ui v i v_i vi 之间有一条边。

接下来 m m m 行,每行包含一个操作。

3. 输出格式

输出若干行,每行对应一个查询操作的答案。

4. 样例输入

4 4
1 2 3 4
1 2
1 3
2 4
2 1
1 1 0
2 1
2 2

5. 样例输出

4
5
6

6. 评测用例规模与约定

对于 30 30 30% 的评测用例, n , m ≤ 1000 n,m \leq 1000 n,m1000

对于所有评测用例, 1 ≤ n , m ≤ 100000 1 \leq n,m \leq 100000 1n,m100000 0 ≤ a i , y ≤ 100000 0 \leq a_i,y \leq 100000 0ai,y100000 1 ≤ u i , v i , x ≤ n 1 \leq u_i,v_i,x \leq n 1ui,vi,xn

7. 原题链接

异或和

2. 解题思路

考虑第二个操作,查询以节点 x x x 为根的子树内的所有点的点权的异或和。

类似这种子树查询问题,我们通常使用 DFS 序对树进行预处理。具体地说,在 DFS 遍历中,我们从根节点开始,依次遍历它的每个子节点。对于每个子节点,我们首先遍历它的子树,然后回溯到该子节点,继续遍历它的兄弟节点。在遍历的过程中,我们可以记录每个节点在 DFS 序中的遍历顺序,即第一次遍历到该节点的时间戳和最后一次遍历到该节点的时间戳。这里的时间戳可以使用一个计数器来实现,每次遍历到一个新节点时,计数器加 1 1 1,表示当前节点的时间戳。当回溯到该节点时,表示当前节点的最后一次遍历时间戳。

这样操作有什么作用呢?假设我们有一个长度大于 n n n 的数组 a a a,我们记进入每个点 i i i 的时间戳为 in [ i ] \text{in}[i] in[i],回溯到点 i i i 的时间戳为 out [ i ] \text{out}[i] out[i],同时将每个点的点权赋值到 a [ in [ i ] ] a[\text{in}[i]] a[in[i]] 上。这样对于一个根为 x x x 的子树内所有点的点权异或和就等价于 a a a 数组区间 [ in [ x ] , out [ x ] ] [\text{in}[x],\text{out}[x]] [in[x],out[x]] 的异或和。这样我们就将复杂的树上询问,转化为我们熟悉的数组区间查询问题。

接下来考虑操作 1 1 1,将点 x x x 的点权改为 y y y

结合上述分析,该操作即是将 a [ in [ x ] ] a[\text{in}[x]] a[in[x]] 改为 y y y

综上所述,我们需要对 a a a 数组进行一个单点修改和区间查询的操作,这个经典操作我们可以使用树状数组或者线段树来维护,代码中使用的是树状数组。具体地说,我们使用一个树状数组 a a a 来维护树的 DFS 序的前缀异或序列, a i a_i ai 表示区间 [ 1 , i ] [1,i] [1,i] 的异或和。

  • 操作 1 1 1:将 a [ in [ x ] ] a[\text{in}[x]] a[in[x]] 修改为 y y y
  • 操作 2 2 2:求解 [ in [ x ] , out [ x ] ] [\text{in}[x],\text{out}[x]] [in[x],out[x]] 的异或和,根据异或性质 [ 1 , r ] ⊕ [ 1 , l − 1 ] = [ l , r ] [1,r] \oplus[1,l-1]=[l,r] [1,r][1,l1]=[l,r],我们只需要求解 a in [ x ] − 1 ⊕ a r a_{\text{in}[x]-1} \oplus a_r ain[x]1ar 即可求得答案。

时间复杂度为 O ( n log ⁡ n ) O(n \log n) O(nlogn)

3. AC_Code

  • C++
#include<bits/stdc++.h>
using namespace std;
const int N = 100010;

template <typename T>
struct Fenwick {
int n;
std::vector<T> a;

Fenwick(int n = 0) {
init(n);
}

void init(int n) {
this->n = n;
a.assign(n + 1, T());
}

void add(int x, T v) {
for (; x <= n; x += x & (-x)) {
a[x] ^= v;
}
}

T sum(int x) {
auto ans = T();
for (; x; x -= x & (-x)) {
ans ^= a[x];
}
return ans;
}

T rangeSum(int l, int r) {
return sum(r) ^ sum(l);
}
};
int n, m, tot;
int a[N], in[N], out[N];
std::vector<int> e[N];
void dfs(int u, int fa) {
in[u] = ++tot;
for (auto v : e[u]) {
if (v == fa) continue;
dfs(v, u);
}
out[u] = tot;
}
int main()
{
ios_base :: sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n >> m;
Fenwick<int> tr(n);
for (int i = 1; i <= n; ++i) cin >> a[i];
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
for (int i = 1; i <= n; ++i) tr.add(in[i], a[i]);
int op, x, y;
for (int i = 0; i < m; ++i) {
cin >> op >> x;
if (op == 1) {
cin >> y;
int v = tr.rangeSum(in[x] - 1, in[x]);
tr.add(in[x], y ^ v);
} else {
cout << tr.rangeSum(in[x] - 1, out[x]) << '\n';
}
}
return 0;
}
  • Java
import java.util.*;
import java.io.*;
 
public class Main {
    static int N = 100010;
    static int n, m, tot;
    static int[] a = new int[N], in = new int[N], out = new int[N], b = new int[N];
    static List<Integer>[] e = new List[N];
 
    static void add(int x, int v) {
        for (; x <= n; x += x & (-x)) {
            b[x] ^= v;
        }
    }
 
    static int sum(int x) {
        int ans = 0;
        if (x == 0) return 0;
        for (; x > 0; x -= x & (-x)) {
            ans ^= b[x];
        }
        return ans;
    }
 
    static int rangeSum(int l, int r) {
        return sum(r) ^ sum(l);
    }
 
    static void dfs(int u, int fa) {
        in[u] = ++tot;
        for (int v : e[u]) {
            if (v == fa) continue;
            dfs(v, u);
        }
        out[u] = tot;
    }
 
    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out));
        String[] temp = reader.readLine().split(" ");
        n = Integer.parseInt(temp[0]);
        m = Integer.parseInt(temp[1]);
        temp = reader.readLine().split(" ");
        for (int i = 1; i <= n; ++i) {
            a[i] = Integer.parseInt(temp[i - 1]);
            e[i]=new ArrayList<>();
        }
        for (int i = 0; i < n - 1; ++i) {
            temp = reader.readLine().split(" ");
            int u = Integer.parseInt(temp[0]);
            int v = Integer.parseInt(temp[1]);
            e[u].add(v);
            e[v].add(u);
        }
        dfs(1, 0);
        for (int i = 1; i <= n; ++i) {
            add(in[i], a[i]);
        }
        for (int i = 0; i < m; ++i) {
            temp = reader.readLine().split(" ");
            int op = Integer.parseInt(temp[0]);
            int x = Integer.parseInt(temp[1]);
            if (op == 1) {
                int y = Integer.parseInt(temp[2]);
                int v = rangeSum(in[x] - 1, in[x]);
                add(in[x], y ^ v);
            } else {
                writer.write(rangeSum(in[x] - 1, out[x]) + "\n");
            }
        }
        reader.close();
        writer.flush();
        writer.close();
    }
}
  • Python
import sys

N = 100010
n, m, tot = 0, 0, 0
a = [0]*N
in_ = [0]*N
out = [0]*N
b = [0]*N
e = [[] for _ in range(N)]

def add(x, v):
    while x <= n:
        b[x] ^= v
        x += x & (-x)

def sum_(x):
    ans = 0
    if x == 0:
        return ans
    while x > 0:
        ans ^= b[x]
        x -= x & (-x)
    return ans

def rangeSum(l, r):
    return sum_(r) ^ sum_(l)

def dfs(u, fa):
    global tot
    in_[u] = tot = tot + 1
    for v in e[u]:
        if v == fa:
            continue
        dfs(v, u)
    out[u] = tot

def main():
    global n, m, tot
    n, m = map(int, sys.stdin.readline().split())
    a[1:n+1] = map(int, sys.stdin.readline().split())
    for _ in range(n - 1):
        u, v = map(int, sys.stdin.readline().split())
        e[u].append(v)
        e[v].append(u)
    dfs(1, 0)
    for i in range(1, n+1):
        add(in_[i], a[i])
    for _ in range(m):
        op, x, *extra = map(int, sys.stdin.readline().split())
        if op == 1:
            y = extra[0]
            v = rangeSum(in_[x] - 1, in_[x])
            add(in_[x], y ^ v)
        else:
            print(rangeSum(in_[x] - 1, out[x]))

if __name__ == "__main__":
    main()

原文地址:https://blog.csdn.net/m0_57487901/article/details/135619052

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!