The diameter of a tree is the number of edges in the longest path in that tree.
There is an undirected tree of n nodes labeled from 0 to n - 1. You are given a 2D array edges where edges.length == n - 1 and edges[i] = [a_i, b_i] indicates that there is an undirected edge between nodes a_i and b_i in the tree. Return the diameter of the tree.
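Approach: build an adjacency list for the tree, then run a breadth-first search (a DFS works equally well) twice. The first search starts from an arbitrary node (node 0) and finds the node farthest from it. The second search starts from that farthest node, and the largest distance it finds is the diameter. This works because the node farthest from any starting node is always an endpoint of some longest path (and every longest path in a tree with at least two nodes runs between two leaf nodes), so the second search measures a longest path exactly. Each search is O(n), so the whole algorithm runs in O(n) time and space.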
#include <queue>
#include <utility>
#include <vector>
using namespace std;
class Solution {
public:
    /**
     * Returns the diameter of the tree: the number of edges on a longest path.
     *
     * `edges` holds the n - 1 undirected edges {a, b} of a tree whose nodes are
     * labeled 0..n-1 (so n = edges.size() + 1). A single-node tree (no edges)
     * has diameter 0.
     *
     * Two-pass BFS: the node farthest from node 0 is an endpoint of some longest
     * path, so the largest distance from that node is the diameter.
     */
    int treeDiameter(vector<vector<int>>& edges) {
        const int n = static_cast<int>(edges.size()) + 1;
        vector<vector<int>> g(n);
        for (const auto& e : edges) {
            g[e[0]].push_back(e[1]);
            g[e[1]].push_back(e[0]);
        }

        // BFS from `start`; returns {farthest node, its distance in edges}.
        auto bfs = [&](int start) {
            vector<int> dist(n, -1);  // -1 marks "not yet visited"
            queue<int> q;
            q.push(start);
            dist[start] = 0;
            int far = start;
            while (!q.empty()) {
                int u = q.front();
                q.pop();
                for (int v : g[u]) {
                    if (dist[v] == -1) {
                        dist[v] = dist[u] + 1;
                        q.push(v);
                        if (dist[v] > dist[far]) far = v;
                    }
                }
            }
            return make_pair(far, dist[far]);
        };

        int u = bfs(0).first;
        return bfs(u).second;
    }
};
import java.util.*;
class Solution {
    /**
     * Returns the diameter of the tree, i.e. the number of edges on a longest path.
     *
     * <p>Runs BFS twice: the node farthest from node 0 is an endpoint of some
     * longest path, so the largest distance from that node is the diameter.
     *
     * @param edges the {@code n - 1} undirected edges {@code [a, b]} of a tree whose
     *     nodes are labeled {@code 0 .. n - 1}
     * @return the number of edges on a longest path ({@code 0} for a single-node tree)
     */
    public int treeDiameter(int[][] edges) {
        int n = edges.length + 1;
        List<List<Integer>> g = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            g.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            g.get(e[0]).add(e[1]);
            g.get(e[1]).add(e[0]);
        }
        int[] res = bfs(g, 0);
        res = bfs(g, res[0]);
        return res[1];
    }

    /**
     * Breadth-first search over the adjacency list {@code g} from {@code start}.
     *
     * @return a two-element array {@code {farthestNode, distanceInEdges}}
     */
    private static int[] bfs(List<List<Integer>> g, int start) {
        int n = g.size();
        int[] dist = new int[n];
        Arrays.fill(dist, -1); // -1 marks "not yet visited"
        Queue<Integer> q = new ArrayDeque<>();
        q.offer(start);
        dist[start] = 0;
        int far = start;
        while (!q.isEmpty()) {
            int u = q.poll();
            for (int v : g.get(u)) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.offer(v);
                    if (dist[v] > dist[far]) {
                        far = v;
                    }
                }
            }
        }
        return new int[] {far, dist[far]};
    }
}
from collections import deque, defaultdict
class Solution:
    def treeDiameter(self, edges):
        """Return the diameter of the tree: the number of edges on a longest path.

        Args:
            edges: the n - 1 undirected edges ``[a, b]`` of a tree whose nodes
                are labeled ``0 .. n - 1``.

        Returns:
            The number of edges on a longest path; 0 for a single-node tree.
        """
        n = len(edges) + 1
        g = defaultdict(list)
        for a, b in edges:
            g[a].append(b)
            g[b].append(a)

        def bfs(start):
            """Return ``(farthest node from start, its distance in edges)``."""
            dist = [-1] * n  # -1 marks "not yet visited"
            q = deque([start])
            dist[start] = 0
            far = start
            while q:
                u = q.popleft()
                for v in g[u]:
                    if dist[v] == -1:
                        dist[v] = dist[u] + 1
                        q.append(v)
                        if dist[v] > dist[far]:
                            far = v
            return far, dist[far]

        # The node farthest from any start node is an endpoint of some longest
        # path, so a second BFS from it measures the diameter.
        u, _ = bfs(0)
        _, d = bfs(u)
        return d