Problem

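Given the root of a binary tree (not necessarily balanced), select k nodes uniformly at random, so that every node has the same probability of being chosen. The total number of nodes may be very large or unknown in advance. Return the selected nodes as a list; if the tree has fewer than k nodes, return all of them.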

Examples

Example 1

        1
       / \
      2   3
     / \
    4   5
Input: root = [1,2,3,4,5], k = 2
Output: [2,5] (any two nodes, uniformly random)
Explanation: Each of the 5 nodes is picked with the same probability, 2/5.

Example 2

       10
      /  \
    20    30
Input: root = [10,20,30], k = 1
Output: [30] (any single node, uniformly random)
Explanation: Each node is picked with probability 1/3.

Solution

Method 1 – Reservoir Sampling (BFS Traversal)

Intuition

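Reservoir sampling selects k items uniformly at random from a stream whose length is not known in advance, using only O(k) extra memory for the sample. Traversing the tree turns it into exactly such a stream of nodes. The key invariant is that after the first m nodes (m >= k) have been processed, each of them is in the reservoir with probability k/m. When the next node arrives, it enters with probability k/(m+1). Each node already in the reservoir is evicted only if the new node enters and picks its slot, which has probability (k/(m+1))·(1/k) = 1/(m+1). It therefore survives with probability (k/m)·(1 − 1/(m+1)) = k/(m+1), which preserves the invariant.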

Approach

  1. Traverse the tree using BFS (level order).
  2. Add the first k nodes directly to the reservoir.
  3. For each subsequent node at 0-based traversal index i (i >= k):
    • Generate a uniformly random integer j in [0, i] (both ends inclusive).
    • If j < k, replace the reservoir element at index j with the current node.
  4. Continue until all nodes are processed.
  5. Return the reservoir, which holds min(k, n) nodes chosen uniformly at random (all n nodes when the tree has fewer than k).

Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
public:
    vector<TreeNode*> randomKSampleTreeNode(TreeNode* root, int k) {
        vector<TreeNode*> ans;
        queue<TreeNode*> q;
        q.push(root);
        int idx = 0;
        while (!q.empty() && idx < k) {
            TreeNode* node = q.front(); q.pop();
            ans.push_back(node);
            idx++;
            if (node->left) q.push(node->left);
            if (node->right) q.push(node->right);
        }
        while (!q.empty()) {
            TreeNode* node = q.front(); q.pop();
            int j = rand() % (idx + 1);
            idx++;
            if (j < k) ans[j] = node;
            if (node->left) q.push(node->left);
            if (node->right) q.push(node->right);
        }
        return ans;
    }
};
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// RandomKSampleTreeNode selects up to k nodes of the tree rooted at root uniformly at
// random, using reservoir sampling (Algorithm R) over a breadth-first traversal.
//
// Each node is kept with probability min(1, k/n), where n is the number of nodes; n does
// not need to be known in advance. A nil root or a non-positive k yields nil. When the
// tree has fewer than k nodes, every node is returned in level order.
func RandomKSampleTreeNode(root *TreeNode, k int) []*TreeNode {
    // Without this guard a nil root would be dereferenced via node.Left and panic.
    if root == nil || k <= 0 {
        return nil
    }
    var ans []*TreeNode
    queue := []*TreeNode{root}
    seen := 0 // number of nodes dequeued so far
    for len(queue) > 0 {
        node := queue[0]
        queue = queue[1:]
        if seen < k {
            // The first k nodes fill the reservoir directly.
            ans = append(ans, node)
        } else if j := rand.Intn(seen + 1); j < k {
            // Node with 0-based index seen replaces a random slot with probability k/(seen+1).
            ans[j] = node
        }
        seen++
        if node.Left != nil {
            queue = append(queue, node.Left)
        }
        if node.Right != nil {
            queue = append(queue, node.Right)
        }
    }
    return ans
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
    /**
     * Selects up to {@code k} nodes of the tree rooted at {@code root} uniformly at random,
     * using reservoir sampling (Algorithm R) over a breadth-first traversal.
     *
     * <p>Each node is kept with probability {@code min(1, k / n)}, where {@code n} is the
     * number of nodes; {@code n} never has to be known in advance.
     *
     * @param root root of the tree; {@code null} denotes an empty tree
     * @param k    number of nodes to select; a non-positive value yields an empty array
     * @return the sampled nodes. The array has exactly {@code min(k, n)} entries, so it
     *         contains no {@code null} padding when the tree has fewer than {@code k} nodes
     */
    public TreeNode[] randomKSampleTreeNode(TreeNode root, int k) {
        // Guards: a null root would throw NPE on node.left, a negative k on array creation.
        if (root == null || k <= 0) return new TreeNode[0];
        TreeNode[] ans = new TreeNode[k];
        Queue<TreeNode> queue = new LinkedList<>();
        queue.offer(root);
        long seen = 0; // nodes dequeued so far; long so very large trees cannot overflow
        while (!queue.isEmpty()) {
            TreeNode node = queue.poll();
            if (seen < k) {
                // The first k nodes fill the reservoir directly.
                ans[(int) seen] = node;
            } else {
                // Node with 0-based index `seen` replaces a random slot with probability k/(seen+1).
                long j = (long) (Math.random() * (seen + 1));
                if (j < k) ans[(int) j] = node;
            }
            seen++;
            if (node.left != null) queue.offer(node.left);
            if (node.right != null) queue.offer(node.right);
        }
        // Trim unused slots when the tree had fewer than k nodes.
        return seen < k ? java.util.Arrays.copyOf(ans, (int) seen) : ans;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
    /**
     * Selects up to [k] nodes of the tree rooted at [root] uniformly at random, using
     * reservoir sampling (Algorithm R) over a breadth-first traversal.
     *
     * Each node ends up in the result with probability `min(1, k / n)`, where `n` is the
     * number of nodes; `n` never has to be known in advance.
     *
     * @param root root of the tree to sample from.
     * @param k number of nodes to select; a non-positive value yields an empty array.
     * @return the sampled nodes. The array has exactly `min(k, n)` entries, so it contains
     *   no `null` padding even when the tree has fewer than [k] nodes.
     */
    fun randomKSampleTreeNode(root: TreeNode, k: Int): Array<TreeNode?> {
        // arrayOfNulls would throw NegativeArraySizeException for k < 0.
        if (k <= 0) return emptyArray()
        val ans = arrayOfNulls<TreeNode>(k)
        val queue = ArrayDeque<TreeNode>()
        queue.add(root)
        var seen = 0 // number of nodes dequeued so far
        while (queue.isNotEmpty()) {
            val node = queue.removeFirst()
            if (seen < k) {
                // The first k nodes fill the reservoir directly.
                ans[seen] = node
            } else {
                // Node with 0-based index `seen` replaces a random slot with probability k/(seen+1).
                val j = (0..seen).random()
                if (j < k) ans[j] = node
            }
            seen++
            node.left?.let { queue.add(it) }
            node.right?.let { queue.add(it) }
        }
        // Trim unused slots when the tree had fewer than k nodes.
        return if (seen < k) ans.copyOf(seen) else ans
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import random
from collections import deque
from typing import List

class Solution:
    def randomKSampleTreeNode(self, root: 'TreeNode', k: int) -> List['TreeNode']:
        """Pick up to ``k`` nodes uniformly at random via reservoir sampling over a BFS.

        Every node is kept with probability ``min(1, k / n)``, where ``n`` is the number
        of nodes; ``n`` never has to be known in advance (Algorithm R).

        Args:
            root: Root of the tree. ``None`` is treated as an empty tree.
            k: Number of nodes to select. A non-positive value yields an empty list.

        Returns:
            A list of ``min(k, n)`` sampled nodes. When the tree has at most ``k``
            nodes, all of them are returned in level order.
        """
        ans: List['TreeNode'] = []
        if root is None or k <= 0:
            return ans
        # deque.popleft() is O(1); list.pop(0) shifts the entire list on every call,
        # which makes the traversal quadratic on wide trees.
        queue = deque([root])
        seen = 0  # number of nodes dequeued so far
        while queue:
            node = queue.popleft()
            if seen < k:
                # The first k nodes fill the reservoir directly.
                ans.append(node)
            else:
                # Node with 0-based index `seen` replaces a random slot with
                # probability k / (seen + 1).
                j = random.randint(0, seen)
                if j < k:
                    ans[j] = node
            seen += 1
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
        return ans
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use rand::Rng;
impl Solution {
    /// Selects up to `k` nodes of the tree rooted at `root` uniformly at random, using
    /// reservoir sampling (Algorithm R) over a breadth-first traversal.
    ///
    /// Each node is kept with probability `min(1, k / n)`, where `n` is the number of
    /// nodes; `n` never has to be known in advance. `k == 0` yields an empty vector, and
    /// a tree with at most `k` nodes is returned in full, in level order.
    pub fn random_k_sample_tree_node(root: &TreeNode, k: usize) -> Vec<&TreeNode> {
        let mut ans: Vec<&TreeNode> = Vec::new();
        if k == 0 {
            return ans;
        }
        // VecDeque::pop_front is O(1); Vec::remove(0) shifts every remaining element,
        // which made the traversal quadratic on wide trees.
        let mut queue: std::collections::VecDeque<&TreeNode> = std::collections::VecDeque::new();
        queue.push_back(root);
        let mut rng = rand::thread_rng();
        let mut seen: usize = 0; // number of nodes dequeued so far
        while let Some(node) = queue.pop_front() {
            if seen < k {
                // The first k nodes fill the reservoir directly.
                ans.push(node);
            } else {
                // Node with 0-based index `seen` replaces a random slot with probability k/(seen+1).
                let j = rng.gen_range(0..=seen);
                if j < k {
                    ans[j] = node;
                }
            }
            seen += 1;
            if let Some(left) = &node.left {
                queue.push_back(left);
            }
            if let Some(right) = &node.right {
                queue.push_back(right);
            }
        }
        ans
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
    /**
     * Selects up to `k` nodes of the tree rooted at `root` uniformly at random, using
     * reservoir sampling (Algorithm R) over a breadth-first traversal.
     *
     * Each node is kept with probability `min(1, k / n)`, where `n` is the number of
     * nodes; `n` never has to be known in advance.
     *
     * @param root Root of the tree; a null/undefined root is treated as an empty tree.
     * @param k Number of nodes to select; a non-positive value yields an empty array.
     * @returns `min(k, n)` sampled nodes (all nodes, in level order, when `n <= k`).
     */
    randomKSampleTreeNode(root: TreeNode, k: number): TreeNode[] {
        const ans: TreeNode[] = [];
        if (!root || k <= 0) return ans;
        const queue: TreeNode[] = [root];
        // Read cursor into `queue`. Array.prototype.shift() is O(n) per call, so advancing
        // an index keeps the traversal linear; consumed entries are dropped in bulk below.
        let head = 0;
        let seen = 0; // number of nodes dequeued so far
        while (head < queue.length) {
            const node = queue[head++];
            if (seen < k) {
                // The first k nodes fill the reservoir directly.
                ans.push(node);
            } else {
                // Node with 0-based index `seen` replaces a random slot with probability k/(seen+1).
                const j = Math.floor(Math.random() * (seen + 1));
                if (j < k) ans[j] = node;
            }
            seen++;
            if (node.left) queue.push(node.left);
            if (node.right) queue.push(node.right);
            // Compact once more than half the array is consumed: amortized O(1), and the
            // queue's memory stays proportional to the tree's width rather than n.
            if (head > 1024 && head * 2 > queue.length) {
                queue.splice(0, head);
                head = 0;
            }
        }
        return ans;
    }
}

Complexity

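  • ⏰ Time complexity: O(n) — Every node is visited exactly once, provided the queue removes from the front in O(1) (e.g. `deque`, `VecDeque`, or a head index). Repeatedly calling `list.pop(0)`, `Array.shift()` or `Vec::remove(0)` instead costs O(w) per removal and can degrade the traversal to O(n²) on wide trees.
  • 🧺 Space complexity: O(k + w) — The reservoir uses O(k) space. The BFS queue holds at most about two adjacent levels, i.e. O(w), where w is the maximum width of the tree (up to roughly n/2 for a complete tree). The queue is bounded by the width, not the height.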