std & 大致思路

Laffey 2022-09-14 20:48:27 2022-09-15 18:08:31

思路

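假设把根从 $1$ 换到 $x$。若被修改的节点 $u$ 位于链 $1 \sim x$ 上且 $u \ne x$，那么在新根下修改 $u$ 的子树权值，相当于给整棵树加上 $v$，再给 $u$ 在该链上的儿子（即朝向 $x$ 的那个儿子）在原树中的子树减去 $v$；若 $u$ 不在链上，它在新根下的子树与原树相同，直接修改即可。判断 $u$ 是否在链上只需求 LCA：$u$ 在链上当且仅当 $\operatorname{LCA}(u, x) = u$，而那个儿子可以在倍增求 LCA 的过程中一并求出。注意特判 $u = x$ 的情况：此时修改的是整棵树。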

询问操作同理：若询问节点 $u = x$，答案为整棵树的权值和；若 $u$ 是 $x$ 的真祖先，答案为整棵树的和减去 $u$ 朝向 $x$ 的那个儿子在原树中的子树和；否则答案就是原树中 $u$ 的子树和。

std

来个人挑挑错

#include <cstdio>
using namespace std;

// Exchanges the two ints referred to by a and b.
void swap(int &a, int &b)
{
    const int saved = a;
    a = b;
    b = saved;
}

const int MAXN = 100000 + 10;
const unsigned int MOD = 19260817;

// Adjacency list kept as linked lists of edges ("chain forward star"):
// Head[u] is the index of the most recently added edge leaving u (0 = none),
// Next[e] is the previously added edge leaving the same vertex, to[e] is the
// endpoint of edge e. Edge indices start at 1; tot counts edges added so far.
int Head[MAXN];
int Next[MAXN << 1];
int to[MAXN << 1];
int tot;

// Vertex weights, indexed by vertex id.
int we[MAXN];

// Appends the directed edge x -> y to x's edge list.
void add(const int &x, const int &y)
{
    const int edge = ++tot;
    to[edge] = y;
    Next[edge] = Head[x];
    Head[x] = edge;
}

// Per-vertex data for the tree rooted at vertex 1:
// Size[x] = subtree size, [l[x], r[x]] = DFS-order interval of x's subtree,
// id[pos] = the vertex placed at DFS position pos (inverse of l).
int Size[MAXN];
int l[MAXN];
int r[MAXN];
int id[MAXN];

// Binary lifting: f[x][k] is the 2^k-th ancestor of x (0 once past the root),
// dep[x] is the depth of x (main sets dep[1] = 1, so the sentinel 0 has depth 0).
int f[MAXN][20];
int dep[MAXN];

// Recursive DFS from x. The caller must already have set l[x], dep[x] and
// f[x][0] (the parent, 0 for the root). Fills f[x][1..18], id[l[x]], Size[x]
// and r[x], then does the same for every descendant, giving the children of x
// consecutive DFS positions right after x.
void dfs(int x)
{
    for (int k = 1; k <= 18; ++k) {
        f[x][k] = f[f[x][k - 1]][k - 1];
    }

    id[l[x]] = x;

    int subtree = 1; // x itself plus all children visited so far
    for (int e = Head[x]; e != 0; e = Next[e]) {
        const int child = to[e];
        if (child == f[x][0]) {
            continue; // the edge back to the parent
        }
        f[child][0] = x;
        dep[child] = dep[x] + 1;
        // First free DFS position after x and the subtrees of earlier siblings.
        l[child] = l[x] + subtree;
        dfs(child);
        subtree += Size[child];
    }

    Size[x] = subtree;
    r[x] = l[x] + subtree - 1;
}

struct pii { int a, b; };

// Binary-lifting LCA in the tree rooted at vertex 1 (uses f[][] and dep[] from dfs).
// Returns {a, b} where
//   a = LCA(x, y);
//   b = the child of a on the path toward the deeper of x and y
//       (toward the first argument when both have equal depth), or 0 when x == y.
// In particular, when one node is a proper ancestor of the other, a is that
// ancestor and b is its child on the path down to the other node.
pii lca(int x, int y)
{
    if (dep[x] < dep[y]) {
        swap(x, y); // make x the deeper node
    }
    if (x == y) {
        return {x, 0};
    }

    // Lift x to depth dep[y] + 1. dep[0] == 0 < dep[y], so jumps past the root are rejected.
    for (int i = 18; i >= 0; i--) {
        if (dep[f[x][i]] > dep[y]) {
            x = f[x][i];
        }
    }

    // y is a proper ancestor of the original x; x is now y's child on that path.
    if (f[x][0] == y) {
        return {y, x};
    }

    // The simultaneous climb below is only valid when x and y are at the same
    // depth, so take the last step up now (x != y afterwards, checked above).
    if (dep[x] > dep[y]) {
        x = f[x][0];
    }

    for (int i = 18; i >= 0; i--) {
        if (f[x][i] != f[y][i]) {
            x = f[x][i];
            y = f[y][i];
        }
    }

    // x and y are now distinct children of the LCA.
    return {f[x][0], x};
}

int n; // number of vertices

using ll = long long;

// Segment-tree node covering DFS positions [l, r] (len = r - l + 1 positions).
// val: sum of the weights in the range (taken modulo MOD by the update code);
// add: lazy increment per position that still has to be pushed to both children.
struct Stree {
    ll val;
    ll add;
    int l;
    int r;
    int len;
} tr[MAXN << 2];

// Implicit heap layout (root node is 1): child indices and node shorthands.
#define ls (2 * p)
#define rs (2 * p + 1)
#define lson tr[ls]
#define rson tr[rs]
#define self tr[p]

void build(int p, int l, int r)
{
    self.l = l, self.r = r;
    self.len = r - l + 1;
    if (l == r) {
        self.val = we[id[l]];
        return;
    }

    int mid = l + r >> 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);

    self.val = lson.val + rson.val;
    self.val %= MOD;

    return;
}

void pushdown(int p)
{
    if (self.add) {
        lson.val += self.add * lson.len;
        lson.add += self.add;
        lson.add %= MOD;
        lson.val %= MOD;

        rson.val += self.add * rson.len;
        rson.add += self.add;
        rson.add %= MOD;
        rson.val %= MOD;

        self.add = 0;
    }
}

// Adds v to every DFS position in [l, r] within node p's range, modulo MOD.
// v may be negative (main passes "-v" for the re-rooting correction).
void add(int p, int l, int r, ll v)
{
    // C++ '%' truncates toward zero, so a negative v would leave negative
    // residues in val/add that later show up in printed answers. Bring v into
    // [0, MOD) first so every stored value stays non-negative.
    v %= MOD;
    if (v < 0) {
        v += MOD;
    }

    if (l <= self.l && self.r <= r) {
        self.add = (self.add + v) % MOD;
        self.val = (self.val + self.len * v) % MOD;
        return;
    }

    pushdown(p);
    if (l <= lson.r) add(ls, l, r, v);
    if (rson.l <= r) add(rs, l, r, v);

    self.val = (lson.val + rson.val) % MOD;
}

// Returns the sum of DFS positions [l, r] within node p's range, reduced by
// % MOD whenever partial results are combined.
ll query(int p, int l, int r)
{
    if (l <= self.l && self.r <= r) return self.val;

    pushdown(p);

    ll sum = 0;
    if (l <= lson.r) sum = (sum + query(ls, l, r)) % MOD;
    if (rson.l <= r) sum = (sum + query(rs, l, r)) % MOD;
    return sum;
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("tr3.in", "r", stdin);
    freopen("tr3.ans", "w", stdout);
#endif
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &we[i]);
    }
    for (int i = 1; i < n; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    l[1] = 1, r[1] = n;
    dep[1] = 1;
    dfs(1);
    build(1, 1, n);

    int root = 1;
    int q;
    scanf("%d", &q);
    while (q--) {
        int op, x, v;
        scanf("%d", &op);
        pii p;
        
        switch (op) {
            case 1:
                scanf("%d", &x);
                p = lca(root, x);
                if (root == x) {
                    printf("%lld\n", query(1, 1, n));
                }
                else if (p.a == x) {
                    ll ans = query(1, 1, n) - query(1, l[p.b], r[p.b]);
                    if (ans < 0) ans += MOD;
                    printf("%lld\n", ans);
                }
                else {
                    printf("%lld\n", query(1, l[x], r[x]));
                }
                break;
            case 2:
                scanf("%d%d", &x, &v);
                p = lca(root, x);
                if (root == x) {
                    add(1, 1, n, v);
                }
                else if (p.a == x) {
                    add(1, 1, n, v);
                    add(1, l[p.b], r[p.b], -v);
                }
                else {
                    add(1, l[x], r[x], v);
                }
                break;
            case 3:
                scanf("%d", &x);
                root = x;
                break;
            default:
                printf("ERROR\n");
                return 0;
        }
    }

    return 0;
}