每加入一个边, 就扫描整个图, 更新由这个边联通的节点之间的距离, 复杂度 \(O(n^3)\), 然后超时
pub fn sum_of_distances_in_tree(n: i32, edges: Vec<Vec<i32>>) -> Vec<i32> {
let n = n as usize;
let mut dp = vec![vec![i32::MAX; n]; n];
for edge in edges {
let (start, end) = (edge[0], edge[1]);
let (start, end) = (start as usize, end as usize);
dp[start][end] = 1;
dp[end][start] = 1;
for i in 0..n {
if i == start || i == end{
continue;
}
// 两端并不一定相同, 这个地方会有遗漏, 比如 x -> start -> end -> y 的就更新不到, 导致节点连不起来
// start --> end --> i or start --> i
let end2i = dp[end][i];
let start2i = dp[start][i];
dp[start][i] = start2i.min(end2i.checked_add(1).unwrap_or(i32::MAX));
dp[i][start] = dp[start][i];
// end --> start --> i or end --> i
dp[end][i] = end2i.min(start2i.checked_add(1).unwrap_or(i32::MAX));
dp[i][end] = dp[end][i];
for j in 0..n{
if i == j || j == start || j == end{
continue;
}
// i -> start -> end -> y
let (i2start, start2j) = (dp[i][start], dp[start][j]);
let (i2end, end2j) = (dp[i][end], dp[end][j]);
dp[i][j] = dp[i][j].min(
i2start.checked_add(start2j).unwrap_or(i32::MAX) // i --> start --> j
).min(
i2end.checked_add(end2j).unwrap_or(i32::MAX) // i --> end --> j
);
dp[j][i] = dp[i][j];
}
}
}
// dbg!(&dp);
let mut ans = vec![0; n];
for i in 0..n {
let mut cnt = 0;
for j in 0..n {
if i == j {
continue;
}
cnt += dp[i][j];
}
ans[i] = cnt;
}
ans
}
对于两个相邻节点\(A\)和\(B\),将树拆分为两个子树, 根节点分别为\(A\)和\(B\). 记
同理记\(B\)子树的分别为 \(sum_B\), \(cnt_B\), \(ans_B\)
则有 \(ans_A = sum_A + (sum_B + cnt_B)\)
前半部分为\(A\)子树自身的
后半部分为 \(B\)子树到节点\(B\)的和, 加上\(B\)子树内所有节点通过 A-B
之间距离为1
的连接的次数 \(1 \times cnt_B\).
如果\(A\)为\(root\), 那么\(B\)对应的为空树, 因此 \(sum_B\) 和 \(cnt_B\) 为0
即 \(ans_{root} = sum_{root}\)
进而可以推导得到
$$
\begin{align}
ans_A &= sum_A + sum_B + cnt_B \
ans_B &= sum_B + sum_A + cnt_A \
\
ans_A &= ans_B - cnt_A + cnt_B \
&= ans_B - cnt_A + N - cnt_A\
\
ans_B &= ans_A - cnt_B + cnt_A \
&= ans_A - cnt_B + N - cnt_B
\end{align}
$$
令 \(A=root\), 则可以得到
$$
\begin{align}
ans_B &= ans_{root} - cnt_B + N - cnt_B \
&= sum_{root} - cnt_B + N - cnt_B
\end{align}
$$
因此如果得到 \(sum_{root}\), 则可以计算出与 \(root\) 相连的子节点, 进而递归得到所有节点的数据(先\(root\), 后子节点, 先序遍历)
而 \(sum_{root}\) 为所有与 \(root\) 相连的子节点的 \(sum\) 之和 加上各自的规模, 即
$$
sum_{root} = \textstyle{ \sum_{i} sum_i + cnt_i }
$$
(先子节点, 后root, 后序遍历)
题目只给出树的边, 并没有限定谁是 \(root\), 随便找一个\(node\)做根也是可以的
pub fn sum_of_distances_in_tree(n: i32, edges: Vec<Vec<i32>>) -> Vec<i32> {
fn post_order(
ans: &mut Vec<i32>,
cnt: &mut Vec<i32>,
graph: &Vec<Vec<usize>>,
child: usize,
parent: usize,
) {
for i in 0..graph[child].len() {
if graph[child][i] != parent {
// 因为是无向边, 不能 "开倒车" 走回到父节点
post_order(ans, cnt, graph, graph[child][i], child);
// 计算完子节点的, 开始汇总
cnt[child] += cnt[graph[child][i]]; // 得到以 child 为根的子树的规模
ans[child] += ans[graph[child][i]] + cnt[graph[child][i]]; // 暂时用 ans 存 sum 的结果, 最终只有 ans[0] == sum[0]
}
}
}
fn pre_order(
ans: &mut Vec<i32>,
cnt: &mut Vec<i32>,
graph: &Vec<Vec<usize>>,
child: usize,
parent: usize,
) {
for i in 0..graph[child].len() {
if parent != graph[child][i] {
// 开始逐层修正 ans 的值
// 初始只有 ans[0] 结果正确的
ans[graph[child][i]] =
ans[child] - cnt[graph[child][i]] + (ans.len() as i32) - cnt[graph[child][i]];
// 当前层的子节点处理完毕, 处理以这个子节点为根的子树
pre_order(ans, cnt, graph, graph[child][i], child);
}
}
}
let n = n as usize;
let mut ans = vec![0; n];
let mut cnt = vec![1; n]; // 统计自己, 默认大小为1
let mut graph = vec![vec![]; n];
for edge in edges {
let (start, end) = (edge[0] as usize, edge[1] as usize);
graph[start].push(end); // 记录 节点start可以到的其他节点
graph[end].push(start);
}
// 准确讲, 其实是 dfs
post_order(&mut ans, &mut cnt, &graph, 0, n + 1);
pre_order(&mut ans, &mut cnt, &graph, 0, n + 1);
ans
}
也可以通过指针关联节点, 不过拿 rust
不好写
class Node:
def __init__(self, num):
self.num = num
self.children = set()
self.cnt = 1 # 初始就有自己
self.sum = 0 # 只有计算 root 时用, 为了过程清晰还是拆出来
self.ans = 0 # 和其他节点的距离和, 也就是答案
def add_child(self, child):
self.children.add(child)
def sum_and_cnt(self, parent):
for child in self.children:
if child == parent:
continue
child.sum_and_cnt(self)
self.cnt += child.cnt
self.sum += (child.sum + child.cnt)
def calculate_dist(self, N, parent):
for child in self.children:
if child == parent:
continue
child.ans = self.ans - child.cnt + N - child.cnt
child.calculate_dist(N, self)
def fill_ans(self, ans, parent):
ans[self.num] = self.ans
for child in self.children:
if child == parent:
continue
child.fill_ans(ans, self)
class Graph:
def __init__(self):
self.node = dict()
def add_edge(self, edge):
f_node = self.node.get(edge[0], Node(edge[0]))
t_node = self.node.get(edge[1], Node(edge[1]))
f_node.add_child(t_node)
t_node.add_child(f_node)
self.node[edge[0]] = f_node
self.node[edge[1]] = t_node
def peek_root(self):
return self.node.get(0)
class Solution:
def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
graph = Graph()
for edge in edges:
graph.add_edge(edge)
root = graph.peek_root()
if root is None:
return [0]
root.sum_and_cnt(None)
root.ans = root.sum
root.calculate_dist(n, None)
ans = [0] * n
root.fill_ans(ans, None)
return ans