题解(树形DP) by Star

Star 2022-02-12 16:14:08 2022-02-12 16:15:39

险些把@Laffey的题抢了

注意审题!!因为审题出了很多问题。

这题一眼看下去很无从下手,简单来说就是距离为2的两个点的权值之积为联合权值,让求最大联合权值所有联合权值的总和

仔细思考一下,一棵树中存在两个距离为2的点只有两种情况,一条线一个角

image-20220211140921603

数据有 60% 在 2000 以内, 这使我萌生了暴力的想法,先拿60再说

暴搜思想(DFS)

I. 一条线

这种很好想,只需在dfs的基础上记录一个前驱pre,若当前父节点为x,子节点为y,那么pre与y的距离为2,w[pre] * w[j]w[j] * w[pre]是两个联合权值。

代码

void dfs(int x, int pre) {
    v[x] = 1;
    for(int i = head[x]; i; i = nxt[i]) {
        int & y = to[i];
        if(!v[y]) {
            dfs(y, x);
            maxn = max(maxn, w[y] * w[pre]);
            sum += w[y] *w[pre] * 2;
            sum %= MOD;
        }
    }
}

这个时候,样例过了,因为样例是单链!!

II.一个角

这种情况难度就稍微大了一些,需统计同一层所有子节点,对其进行排列组合,每种配对的情况都是两组联合权值。 (可以对比之前做过的一道深搜 选数)

举个栗子:

image-20220211184217103

这是一个以①为根的子树

搜到②时,无法配对,继续搜索

搜得到③时,可与②组合

搜到④时,可与②和③组合

刚好不重不漏

我们的思路是搜到每一个子节点后向回搜索同层的子节点,分别计算,为此我们要记录当前父节点下的已搜过的子节点,直接开N * N的数组会爆too large所以我们开vector<int> son[N]

做法是每搜到一个节点,先回搜同一层的子节点进行操作,然后将其编号存入son数组中为下一个节点使用,在原dfs加入

for(auto i : son[x]) {
	maxn = max(maxn, w[y] * w[i]);
	sum += w[y] * w[i] * 2;
	sum %= MOD;
}
son[x].push_back(y);

这样,一个暴力的做法就完成了。

暴力(70分)代码

#include <cstdio>
#include <iostream>
#include <vector>

using namespace std;

const int N = 4000010;
const int MOD = 10007;
typedef long long ll; //记得开long long不然取模之前容易爆
ll maxn;
ll sum;
bool v[N];
int to[N], head[N], nxt[N];
ll w[N];
vector<int> son[N];
int tot = 0;

void add(int x, int y) {
    to[++tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void dfs(int x, int pre) {
    v[x] = 1;
    for(int i = head[x]; i; i = nxt[i]) {
        int & y = to[i];
        if(!v[y]) {
            dfs(y, x);
            for(auto i : son[x]) {
                maxn = max(maxn, w[y] * w[i]);
                sum += w[y] * w[i] * 2;
                sum %= MOD;
            }
            son[x].push_back(y);
            maxn = max(maxn, w[y] * w[pre]);
            sum += w[y] * w[pre] * 2;
            sum %= MOD;
        }
    }
}

int main() {
    int n;
    scanf("%d", &n);
    for(int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    for(int i = 1; i <= n; ++i)
        scanf("%d", &w[i]);

    dfs(1, 0);

    printf("%lld %lld", maxn, sum);
    return 0;
}

优化 DFS -> 树形DP

I.一个角

一条线的时间复杂度已经不大了,所以我们先来考虑一个角的情况

再举个相同的栗子:

image-20220211184217103

回顾一下,以④为例

sum += 2 * w[2] *w[4] + 2 * w[3] *w[4] = 2 * w[4] * (w[2] + w[3])

乘法分配律,明白吧

推广到一般的父节点x,子节点y

sum = sum + 2 * w[y] * s(w[y前所有子节点])

同样最值也可以这么操作

maxn = max(maxn, w[y] * max(w[y前所有的子节点]))

所以处理这种情况只需维护一个前缀和s[N]和一个最值m[N]就可以完成状态转移,无需遍历一遍。

sum += 2 * s[x] * w[y];
sum %= MOD;
maxn = max(maxn, m[x] * w[y]);
//注意先处理sum和maxn再转移状态,因为反之w[y]会 * w[y]自己造成问题
s[x] += w[y];
m[x] = max(m[x], w[y]);

把之前的循环换掉就可以a掉了

II.一条线

image-20220212141139984

当dfs回溯到x时

s[y] = sum(w[y的所有孩子])\\ m[y] = max(w[y的所有孩子])

你发现了什么?

其实在做上一次优化的时候一条线的情况也已经预处理好了,只需

sum += 2 * s[y] * w[x];
maxn = max(maxn, m[y] * w[x]);

就ok了

这样也不用再统计前驱了

Ac代码

#include <cstdio>
#include <iostream>
#include <vector>

using namespace std;

const int N = 400010;
const int MOD = 10007;
typedef long long ll;
bool v[N];
int to[N], head[N], nxt[N]; //前向星
ll s[N], m[N], w[N];
/**
* s[i]表示以i为顶的一层子树节点权值和
* m[i]表示以i为顶的一层子树节点权值最值
*/
ll maxn = 0, sum = 0;
int tot = 0;

void add(int x, int y) {
    to[++tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void dp(int x) {
    v[x] = 1;
    for(int i = head[x]; i; i = nxt[i]) {
        int & y = to[i];
        if(!v[y]) {
            dp(y);
            // case corner
            sum += 2 * s[x] * w[y];
            sum %= MOD;
            maxn = max(maxn, m[x] * w[y]);

            //case line 这个放到哪里都行,与当前层的y权值无关
            sum += 2 * s[y] * w[x];
            sum %= MOD;
            maxn = max(maxn, m[y] * w[x]);

            //dp
            s[x] += w[y];
            m[x] = max(m[x], w[y]);
        }
    }
}

int main() {
    int n;

    //read
    scanf("%d", &n);
    for(int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    for(int i = 1; i <= n; ++i)
        scanf("%d", &w[i]);
    
    //dp
    dp(1);

    printf("%lld %lld", maxn, sum);
    return 0;
}