esc to dismiss
x

信息

题解

构建图

每加入一个边, 就扫描整个图, 更新由这个边联通的节点之间的距离, 复杂度 \(O(n^3)\), 然后超时

pub fn sum_of_distances_in_tree(n: i32, edges: Vec<Vec<i32>>) -> Vec<i32> {
    let n = n as usize;
    let mut dp = vec![vec![i32::MAX; n]; n];
    for edge in edges {
        let (start, end) = (edge[0], edge[1]);
        let (start, end) = (start as usize, end as usize);
        dp[start][end] = 1;
        dp[end][start] = 1;
        for i in 0..n {
            if i == start || i == end{
                continue;
            }
            // 两端并不一定相同, 这个地方会有遗漏, 比如 x -> start -> end -> y 的就更新不到, 导致节点连不起来

                // start --> end --> i or start --> i
                let end2i = dp[end][i];
                let start2i = dp[start][i];
                dp[start][i] = start2i.min(end2i.checked_add(1).unwrap_or(i32::MAX));
                dp[i][start] = dp[start][i];
                // end --> start --> i or end --> i
                dp[end][i] = end2i.min(start2i.checked_add(1).unwrap_or(i32::MAX));
                dp[i][end] = dp[end][i];

            for j in 0..n{
                if i == j || j == start || j == end{
                    continue;
                }
                // i -> start -> end -> y
                let (i2start, start2j) = (dp[i][start], dp[start][j]);
                let (i2end, end2j) = (dp[i][end], dp[end][j]);

                dp[i][j] = dp[i][j].min(
                    i2start.checked_add(start2j).unwrap_or(i32::MAX) // i --> start --> j
                ).min(
                    i2end.checked_add(end2j).unwrap_or(i32::MAX) // i --> end --> j
                );
                dp[j][i] = dp[i][j];
            }
        }
    }
    // dbg!(&dp);
    let mut ans = vec![0; n];
    for i in 0..n {
        let mut cnt = 0;
        for j in 0..n {
            if i == j {
                continue;
            }
            cnt += dp[i][j];
        }
        ans[i] = cnt;
    }
    ans
}

树形DP

递推关系

对于两个相邻节点\(A\)\(B\),将树拆分为两个子树, 根节点分别为\(A\)\(B\). 记

  • \(A\)子树中所有节点到\(A\)节点的距离和\(sum_A\)
  • \(A\)子树的大小(节点数量)为\(cnt_A\)
  • \(A\)节点到所有其他节点的距离和\(ans_{A}\)

同理记\(B\)子树的分别为 \(sum_B\), \(cnt_B\), \(ans_B\)

则有 \(ans_A = sum_A + (sum_B + cnt_B)\)
前半部分为\(A\)子树自身的
后半部分为 \(B\)子树到节点\(B\)的和, 加上\(B\)子树内所有节点通过 A-B 之间距离为1的连接的次数 \(1 \times cnt_B\).

特殊情况

如果\(A\)\(root\), 那么\(B\)对应的为空树, 因此 \(sum_B\)\(cnt_B\) 为0
\(ans_{root} = sum_{root}\)

进而可以推导得到
$$
\begin{align}
ans_A &= sum_A + sum_B + cnt_B \
ans_B &= sum_B + sum_A + cnt_A \
\
ans_A &= ans_B - cnt_A + cnt_B \
&= ans_B - cnt_A + N - cnt_A\
\
ans_B &= ans_A - cnt_B + cnt_A \
&= ans_A - cnt_B + N - cnt_B
\end{align}
$$

\(A=root\), 则可以得到
$$
\begin{align}
ans_B &= ans_{root} - cnt_B + N - cnt_B \
&= sum_{root} - cnt_B + N - cnt_B
\end{align}
$$

因此如果得到 \(sum_{root}\), 则可以计算出与 \(root\) 相连的子节点, 进而递归得到所有节点的数据(先\(root\), 后子节点, 先序遍历)

\(sum_{root}\) 为所有与 \(root\) 相连的子节点的 \(sum\) 之和 加上各自的规模, 即
$$
sum_{root} = \textstyle{ \sum_{i} sum_i + cnt_i }
$$
(先子节点, 后root, 后序遍历)

Tip

题目只给出树的边, 并没有限定谁是 \(root\), 随便找一个\(node\)做根也是可以的

pub fn sum_of_distances_in_tree(n: i32, edges: Vec<Vec<i32>>) -> Vec<i32> {
    fn post_order(
        ans: &mut Vec<i32>,
        cnt: &mut Vec<i32>,
        graph: &Vec<Vec<usize>>,
        child: usize,
        parent: usize,
    ) {
        for i in 0..graph[child].len() {
            if graph[child][i] != parent {
                // 因为是无向边, 不能 "开倒车" 走回到父节点
                post_order(ans, cnt, graph, graph[child][i], child);
                // 计算完子节点的, 开始汇总
                cnt[child] += cnt[graph[child][i]]; // 得到以 child 为根的子树的规模
                ans[child] += ans[graph[child][i]] + cnt[graph[child][i]]; // 暂时用 ans 存 sum 的结果, 最终只有 ans[0] == sum[0]
            }
        }
    }
    fn pre_order(
        ans: &mut Vec<i32>,
        cnt: &mut Vec<i32>,
        graph: &Vec<Vec<usize>>,
        child: usize,
        parent: usize,
    ) {
        for i in 0..graph[child].len() {
            if parent != graph[child][i] {
                // 开始逐层修正 ans 的值
                // 初始只有 ans[0] 结果正确的
                ans[graph[child][i]] =
                    ans[child] - cnt[graph[child][i]] + (ans.len() as i32) - cnt[graph[child][i]];
                // 当前层的子节点处理完毕, 处理以这个子节点为根的子树
                pre_order(ans, cnt, graph, graph[child][i], child);
            }
        }
    }

    let n = n as usize;

    let mut ans = vec![0; n];
    let mut cnt = vec![1; n]; // 统计自己, 默认大小为1
    let mut graph = vec![vec![]; n];

    for edge in edges {
        let (start, end) = (edge[0] as usize, edge[1] as usize);
        graph[start].push(end); // 记录 节点start可以到的其他节点
        graph[end].push(start);
    }

    // 准确讲, 其实是 dfs
    post_order(&mut ans, &mut cnt, &graph, 0, n + 1);
    pre_order(&mut ans, &mut cnt, &graph, 0, n + 1);

    ans
}

也可以通过指针关联节点, 不过拿 rust 不好写

class Node:
    def __init__(self, num):
        self.num = num

        self.children = set()
        self.cnt = 1 # 初始就有自己
        self.sum = 0 # 只有计算 root 时用, 为了过程清晰还是拆出来
        self.ans = 0 # 和其他节点的距离和, 也就是答案

    def add_child(self, child):
        self.children.add(child)

    def sum_and_cnt(self, parent):
        for child in self.children:
            if child == parent:
                continue
            child.sum_and_cnt(self)
            self.cnt += child.cnt
            self.sum += (child.sum + child.cnt)

    def calculate_dist(self, N, parent):
        for child in self.children:
            if child == parent:
                continue
            child.ans = self.ans - child.cnt + N - child.cnt
            child.calculate_dist(N, self)

    def fill_ans(self, ans, parent):
        ans[self.num] = self.ans
        for child in self.children:
            if child == parent:
                continue
            child.fill_ans(ans, self)

class Graph:
    def __init__(self):
        self.node = dict()
    def add_edge(self, edge):
        f_node = self.node.get(edge[0], Node(edge[0]))
        t_node = self.node.get(edge[1], Node(edge[1]))

        f_node.add_child(t_node)
        t_node.add_child(f_node)

        self.node[edge[0]] = f_node
        self.node[edge[1]] = t_node
    def peek_root(self):
        return self.node.get(0)

class Solution:
    def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:        
        graph = Graph()
        for edge in edges:
            graph.add_edge(edge)

        root = graph.peek_root()
        if root is None:
            return [0]

        root.sum_and_cnt(None)

        root.ans = root.sum

        root.calculate_dist(n, None)

        ans = [0] * n
        root.fill_ans(ans, None)

        return ans
x