Uniform Random Sampling of k Nodes in a Binary Tree
Medium · Updated: Sep 12, 2025
Problem
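Given the root of a binary tree (which need not be balanced), select k of its nodes uniformly at random, so that every node has the same probability (k/n for a tree of n nodes) of being selected. The total number of nodes may be very large or unknown in advance. Return the selected nodes as a list.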
Examples
Example 1
Input: root = [1,2,3,4,5], k = 2
Output: [2,5] (any two nodes, uniformly random)
Explanation: Each node has equal probability to be picked.
Example 2
Input: root = [10,20,30], k = 1
Output: [30] (any single node, uniformly random)
Explanation: Each node has 1/3 probability to be picked.
Solution
Method 1 – Reservoir Sampling (BFS Traversal)
Intuition
Reservoir sampling allows us to select k nodes from a binary tree with equal probability, even if the total number of nodes is unknown or very large. By maintaining a reservoir and updating it as we traverse, we guarantee uniform selection for each node.
Approach
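- Traverse the tree using BFS (level order), numbering nodes `0, 1, 2, …` in the order they are visited.
- Add the first `k` nodes directly to the reservoir.
- For each subsequent node at index `i` (`i >= k`):
  - Generate a random integer `j` in `[0, i]`.
  - If `j < k`, replace the reservoir element at index `j` with the current node.
- Continue until all nodes are processed.
- Return the reservoir containing `k` uniformly random nodes.

Why this is uniform: node `i` (`i >= k`) enters the reservoir with probability `k/(i+1)`. Each later node `m` evicts it with probability `1/(m+1)`. So node `i` survives to the end with probability $\frac{k}{i+1}\prod_{m=i+1}^{n-1}\frac{m}{m+1} = \frac{k}{n}$. The first `k` nodes reach the same $\frac{k}{n}$ through the same product starting at $m = k$.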
Code
C++
class Solution {
public:
vector<TreeNode*> randomKSampleTreeNode(TreeNode* root, int k) {
vector<TreeNode*> ans;
queue<TreeNode*> q;
q.push(root);
int idx = 0;
while (!q.empty() && idx < k) {
TreeNode* node = q.front(); q.pop();
ans.push_back(node);
idx++;
if (node->left) q.push(node->left);
if (node->right) q.push(node->right);
}
while (!q.empty()) {
TreeNode* node = q.front(); q.pop();
int j = rand() % (idx + 1);
idx++;
if (j < k) ans[j] = node;
if (node->left) q.push(node->left);
if (node->right) q.push(node->right);
}
return ans;
}
};
Go
// RandomKSampleTreeNode returns up to k nodes of the tree rooted at root,
// chosen uniformly at random by reservoir sampling (Algorithm R) over a
// breadth-first traversal. Each node is in the result with probability k/n,
// where n is the number of nodes. If the tree has fewer than k nodes, all of
// them are returned. A nil root or k <= 0 yields a nil slice.
//
// Randomness comes from math/rand's global source, which is auto-seeded only
// since Go 1.20; older toolchains must seed it explicitly.
func RandomKSampleTreeNode(root *TreeNode, k int) []*TreeNode {
	if root == nil || k <= 0 {
		return nil
	}
	var ans []*TreeNode
	queue := []*TreeNode{root}
	seen := 0 // nodes processed before the current one (its BFS index)
	for len(queue) > 0 {
		node := queue[0]
		queue = queue[1:]
		if seen < k {
			// Fill the reservoir with the first k nodes.
			ans = append(ans, node)
		} else if j := rand.Intn(seen + 1); j < k {
			// Keep this node with probability k/(seen+1), evicting a uniform slot.
			ans[j] = node
		}
		seen++
		if node.Left != nil {
			queue = append(queue, node.Left)
		}
		if node.Right != nil {
			queue = append(queue, node.Right)
		}
	}
	return ans
}
Java
class Solution {
    /**
     * Selects up to {@code k} nodes of the tree uniformly at random using reservoir
     * sampling (Algorithm R) over a breadth-first traversal. Each node is in the
     * result with probability k/n, where n is the number of nodes.
     *
     * @param root tree root; {@code null} yields an empty array
     * @param k    sample size; {@code k <= 0} yields an empty array
     * @return an array of min(k, n) distinct nodes (never padded with {@code null})
     */
    public TreeNode[] randomKSampleTreeNode(TreeNode root, int k) {
        if (root == null || k <= 0) return new TreeNode[0];

        // Grows with the tree instead of pre-allocating k slots, so a tree smaller
        // than k does not leave trailing nulls in the result.
        java.util.List<TreeNode> reservoir = new java.util.ArrayList<>();
        Queue<TreeNode> queue = new java.util.ArrayDeque<>();
        queue.offer(root);
        long seen = 0; // BFS index of the current node; long so huge trees cannot overflow it
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (seen < k) {
                reservoir.add(node);
            } else {
                // Keep this node with probability k/(seen+1), evicting a uniform slot.
                long j = java.util.concurrent.ThreadLocalRandom.current().nextLong(seen + 1);
                if (j < k) reservoir.set((int) j, node);
            }
            seen++;
            if (node.left != null) queue.offer(node.left);
            if (node.right != null) queue.offer(node.right);
        }
        return reservoir.toArray(new TreeNode[0]);
    }
}
Kotlin
class Solution {
    /**
     * Selects up to [k] nodes of the tree uniformly at random using reservoir
     * sampling (Algorithm R) over a breadth-first traversal. Each node is in the
     * result with probability k/n, where n is the number of nodes.
     *
     * @param root tree root
     * @param k sample size; `k <= 0` yields an empty array
     * @return an array of min(k, n) distinct nodes, with no trailing `null` padding
     */
    fun randomKSampleTreeNode(root: TreeNode, k: Int): Array<TreeNode?> {
        // arrayOfNulls(k) would throw for negative k.
        if (k <= 0) return arrayOfNulls<TreeNode>(0)

        val reservoir = ArrayList<TreeNode>()
        val queue = ArrayDeque<TreeNode>()
        queue.add(root)
        var seen = 0L // BFS index of the current node; Long so huge trees cannot overflow it
        while (queue.isNotEmpty()) {
            val node = queue.removeFirst()
            if (seen < k) {
                reservoir.add(node)
            } else {
                // Keep this node with probability k/(seen+1), evicting a uniform slot.
                val j = (0L..seen).random()
                if (j < k) reservoir[j.toInt()] = node
            }
            seen++
            node.left?.let { queue.add(it) }
            node.right?.let { queue.add(it) }
        }
        // Sized to the actual sample so a tree smaller than k is not padded with nulls.
        return Array<TreeNode?>(reservoir.size) { reservoir[it] }
    }
}
Python
from typing import List
import random
class Solution:
    def randomKSampleTreeNode(self, root: 'TreeNode', k: int) -> List['TreeNode']:
        """Select up to ``k`` tree nodes uniformly at random in one BFS pass.

        Uses reservoir sampling (Algorithm R), so every node is selected with
        probability k/n, where n is the number of nodes.

        Args:
            root: Tree root; ``None`` yields an empty list.
            k: Sample size; ``k <= 0`` yields an empty list.

        Returns:
            min(k, n) distinct nodes. If the tree has fewer than ``k`` nodes,
            all of them are returned.
        """
        # deque.popleft() is O(1); list.pop(0) is O(n) and made the walk O(n^2).
        from collections import deque

        if root is None or k <= 0:
            return []

        reservoir: List['TreeNode'] = []
        pending = deque([root])
        seen = 0  # BFS index of the node being processed
        while pending:
            node = pending.popleft()
            if seen < k:
                # Fill the reservoir with the first k nodes.
                reservoir.append(node)
            else:
                # Keep this node with probability k/(seen+1), evicting a uniform slot.
                j = random.randint(0, seen)
                if j < k:
                    reservoir[j] = node
            seen += 1
            if node.left is not None:
                pending.append(node.left)
            if node.right is not None:
                pending.append(node.right)
        return reservoir
Rust
use rand::Rng;
impl Solution {
    /// Selects up to `k` nodes of the tree uniformly at random using reservoir
    /// sampling (Algorithm R) over a breadth-first traversal.
    ///
    /// Each node is in the result with probability k/n, where n is the node count.
    /// If the tree has fewer than `k` nodes, all of them are returned; `k == 0`
    /// yields an empty vector.
    pub fn random_k_sample_tree_node(root: &TreeNode, k: usize) -> Vec<&TreeNode> {
        // VecDeque::pop_front is O(1); Vec::remove(0) shifts every element and made
        // the traversal O(n^2).
        use std::collections::VecDeque;

        let mut rng = rand::thread_rng();
        let mut ans: Vec<&TreeNode> = Vec::new();
        let mut queue: VecDeque<&TreeNode> = VecDeque::new();
        queue.push_back(root);
        let mut seen: usize = 0; // BFS index of the node being processed
        while let Some(node) = queue.pop_front() {
            if seen < k {
                // Fill the reservoir with the first k nodes.
                ans.push(node);
            } else {
                // Keep this node with probability k/(seen+1), evicting a uniform slot.
                let j = rng.gen_range(0..=seen);
                if j < k {
                    ans[j] = node;
                }
            }
            seen += 1;
            if let Some(left) = &node.left {
                queue.push_back(left);
            }
            if let Some(right) = &node.right {
                queue.push_back(right);
            }
        }
        ans
    }
}
TypeScript
class Solution {
    /**
     * Selects up to `k` nodes of the tree uniformly at random using reservoir
     * sampling (Algorithm R) over a breadth-first traversal. Each node is in the
     * result with probability k/n, where n is the number of nodes.
     *
     * @param root Tree root; a null/undefined root yields an empty array.
     * @param k Sample size; `k <= 0` yields an empty array.
     * @returns min(k, n) distinct nodes (all nodes if the tree has fewer than k).
     */
    randomKSampleTreeNode(root: TreeNode, k: number): TreeNode[] {
        const ans: TreeNode[] = [];
        if (!root || k <= 0) return ans;

        // Read the queue through a head index: Array.prototype.shift() is O(n)
        // per call and would make the traversal quadratic.
        const queue: TreeNode[] = [root];
        let head = 0;
        let seen = 0; // BFS index of the node being processed
        while (head < queue.length) {
            const node = queue[head++];
            if (seen < k) {
                // Fill the reservoir with the first k nodes.
                ans.push(node);
            } else {
                // Keep this node with probability k/(seen+1), evicting a uniform slot.
                const j = Math.floor(Math.random() * (seen + 1));
                if (j < k) ans[j] = node;
            }
            seen++;
            if (node.left) queue.push(node.left);
            if (node.right) queue.push(node.right);
            // Drop consumed entries once they dominate the array. This keeps memory
            // proportional to the tree width at amortized O(1) cost per node.
            if (head > 1024 && head * 2 > queue.length) {
                queue.splice(0, head);
                head = 0;
            }
        }
        return ans;
    }
}
Complexity
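- ⏰ Time complexity: O(n). Every node in the tree is visited exactly once, and each visit does O(1) work (one enqueue/dequeue and at most one random draw).
  - This assumes an O(1) dequeue. `list.pop(0)` (Python), `Array.prototype.shift()` (JS/TS) and `Vec::remove(0)` (Rust) cost O(n) per call and make the traversal O(n²).
  - Use `collections.deque`, a head index, or `VecDeque` instead.
- 🧺 Space complexity: O(k + w). The reservoir uses O(k) space, and the BFS queue holds O(w) nodes at a time, where w is the maximum width of the tree (the largest number of nodes on one level).
  - w is 1 for a path-shaped tree but about n/2 for a complete tree, so the queue is not bounded by the height h.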