Problem
Given an integer array `nums`, return an integer array `counts` where `counts[i]` is the number of smaller elements to the right of `nums[i]`.
Examples
Example 1:
Input: nums = [5,2,6,1]
Output: [2,1,1,0]
Explanation:
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there are 0 smaller elements.
Example 2:
Input: nums = [-1]
Output: [0]
Example 3:
Input: nums = [-1,-1]
Output: [0,0]
Solution
Method 1 - Merge sort
Here is the approach:
- Modified Merge Sort:
- We use merge sort to both sort the array and count the smaller elements to the right.
- During the merge step, we count how many elements from the right part are placed before an element from the left part; all of them are strictly smaller than it. Summed over every merge level, these counts give the number of smaller elements to the right of that element.
The algorithm uses the following steps:
- Initialization: Create a list of tuples where each tuple contains an element and its original index.
- Merge Sort: Sort the array while maintaining the original indices and counting the smaller elements.
- Count Update: During the merge step, update the counts for each element.
Code
Java
public class Solution {
    /**
     * Counts, for every position, the elements to its right that are strictly smaller.
     *
     * @param nums input values; {@code null} or empty yields an empty list
     * @return list whose i-th entry is the number of j > i with nums[j] < nums[i]
     */
    public List<Integer> countSmaller(int[] nums) {
        List<Integer> result = new ArrayList<>();
        if (nums == null || nums.length == 0) {
            return result;
        }
        int n = nums.length;
        // Primitive counters start at 0 (a boxed Integer[] would start as null and break "+=").
        int[] counts = new int[n];
        List<Pair> arr = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            arr.add(new Pair(nums[i], i));
        }
        mergeSort(arr, counts);
        for (int c : counts) {
            result.add(c);
        }
        return result;
    }

    /**
     * Stably sorts {@code arr} by value. For each left-half element, adds to
     * {@code counts[index]} the number of right-half elements merged ahead of it.
     */
    private void mergeSort(List<Pair> arr, int[] counts) {
        if (arr.size() <= 1) {
            return;
        }
        int mid = arr.size() / 2;
        List<Pair> left = new ArrayList<>(arr.subList(0, mid));
        List<Pair> right = new ArrayList<>(arr.subList(mid, arr.size()));
        mergeSort(left, counts);
        mergeSort(right, counts);
        int i = 0;
        int j = 0;
        int k = 0;
        while (i < left.size() || j < right.size()) {
            // Prefer the left element on ties so equal right-half values are not counted as smaller.
            if (j == right.size()
                    || (i < left.size() && left.get(i).value <= right.get(j).value)) {
                Pair p = left.get(i++);
                // The j right-half elements already placed are all strictly smaller than p.
                counts[p.index] += j;
                arr.set(k++, p);
            } else {
                arr.set(k++, right.get(j++));
            }
        }
    }

    /** A value together with its position in the original array. */
    private static final class Pair {
        final int value;
        final int index;

        Pair(int value, int index) {
            this.value = value;
            this.index = index;
        }
    }
}
Python
def countSmaller(nums):
    """Return a list whose i-th entry counts elements right of nums[i] that are smaller."""
    if not nums:
        return []
    counts = [0] * len(nums)

    def merge_sort(pairs):
        """Sort (index, value) pairs by value in place and return them.

        The merge fills positions from the back. When a left-half pair is placed,
        every right-half pair not yet placed is strictly smaller, so that many are
        credited to its original index.
        """
        mid = len(pairs) // 2
        if mid == 0:
            return pairs
        lo = merge_sort(pairs[:mid])
        hi = merge_sort(pairs[mid:])
        a, b = len(lo) - 1, len(hi) - 1
        for pos in reversed(range(len(pairs))):
            take_left = b < 0 or (a >= 0 and lo[a][1] > hi[b][1])
            if take_left:
                counts[lo[a][0]] += b + 1
                pairs[pos] = lo[a]
                a -= 1
            else:
                pairs[pos] = hi[b]
                b -= 1
        return pairs

    merge_sort(list(enumerate(nums)))
    return counts
Complexity
- ⏰ Time complexity: $O(n \log n)$, due to the merge sort algorithm.
- 🧺 Space complexity: $O(n)$, due to the temporary arrays used during the merge process.
Method 2 - Binary search tree
We can use a plain binary search tree (BST) or a balanced BST such as an AVL tree. Here is the approach for a plain BST:
- TreeNode Class: Define a TreeNode class that represents each node in the BST, storing how many times its value has been inserted and how many values lie in its left subtree.
- Insert and Count: Traverse the array from right to left, and for each element, insert it into the BST while counting how many elements are smaller than it.
- Count Smaller: Use a helper function to insert elements into the BST and count the smaller elements to the right.
Code
Java
public class Solution {
class TreeNode {
TreeNode left;
TreeNode right;
int val;
int count = 1; // Number of occurrences of this value
int leftCount = 0; // Number of nodes in the left subtree
public TreeNode(int val) {
this.val = val;
}
}
private int insert(TreeNode node, int val) {
int smallerCount = 0;
while (true) {
if (val <= node.val) {
node.leftCount++;
if (node.left == null) {
node.left = new TreeNode(val);
break;
} else {
node = node.left;
}
} else {
smallerCount += node.leftCount + node.count;
if (node.right == null) {
node.right = new TreeNode(val);
break;
} else {
node = node.right;
}
}
}
return smallerCount;
}
public List<Integer> countSmaller(int[] nums) {
List<Integer> result = new ArrayList<>();
if (nums == null || nums.length == 0)
return result;
int n = nums.length;
TreeNode root = new TreeNode(nums[n - 1]);
result.add(0);
for (int i = n - 2; i >= 0; i--) {
result.add(insert(root, nums[i]));
}
Collections.reverse(result);
return result;
}
}
Python
class TreeNode:
    """Node of an unbalanced BST that stores each distinct value once."""

    def __init__(self, val):
        self.val = val
        self.count = 1  # Number of occurrences of this value
        self.left_count = 0  # Number of values (with multiplicity) in the left subtree
        self.left = None
        self.right = None


class Solution:
    def insert(self, root, val):
        """Insert val into the non-empty BST rooted at root.

        Returns how many previously inserted values are strictly smaller than val.
        Equal values are merged into the existing node through ``count``, so
        repeated values do not deepen the tree.
        """
        smaller_count = 0
        node = root
        while True:
            if val == node.val:
                node.count += 1
                # only the left subtree holds values smaller than this node's value
                return smaller_count + node.left_count
            if val < node.val:
                node.left_count += 1
                if node.left is None:
                    node.left = TreeNode(val)
                    return smaller_count
                node = node.left
            else:
                # this node (every copy) and its whole left subtree are smaller
                smaller_count += node.left_count + node.count
                if node.right is None:
                    node.right = TreeNode(val)
                    return smaller_count
                node = node.right

    def countSmaller(self, nums):
        """Return, for each i, how many elements right of nums[i] are strictly smaller.

        Average O(n log n); worst case O(n^2) because the tree is not balanced.
        """
        if not nums:
            return []
        root = TreeNode(nums[-1])
        result = [0]
        for i in range(len(nums) - 2, -1, -1):
            result.append(self.insert(root, nums[i]))
        return result[::-1]  # Reverse to get correct order
Complexity
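- ⏰ Time complexity: $O(n \log n)$ on average, but $O(n^2)$ in the worst case (for example, already-sorted input), because a plain BST is not balanced.
- 🧺 Space complexity: $O(n)$, due to the storage in the BST and the result list.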
Method 3 - Using AVL Tree
We can solve the problem more efficiently using a balanced BST such as an AVL tree. Observe that if a key is greater than the root, it is greater than all the nodes in the left subtree of the root.
We will treat each number in the array as a node in a BST and insert them into an AVL tree one by one, from right to left. During each insertion, we update the stored subtree sizes (including the left-subtree size) of every node along the insertion path. This allows us to keep track of the number of smaller elements. Additionally, we need to rebalance the tree during insertion, and every rotation must recompute the heights and subtree sizes of the nodes it moves.
An AVL tree becomes imbalanced during an insertion when the heights of the left and right subtrees of some node $x$ differ by more than 1. There are four cases to consider for rebalancing the tree:
- Insert in the left subtree of the left child of $x$: perform a right rotation on $x$ to balance it.
- Insert in the right subtree of the right child of $x$: perform a left rotation on $x$ to balance it.
- Insert in the right subtree of the left child of $x$: perform a left rotation on the left child of $x$, followed by a right rotation on $x$ to balance the tree.
- Insert in the left subtree of the right child of $x$: perform a right rotation on the right child of $x$, followed by a left rotation on $x$ to balance the tree.
By maintaining the balance of the AVL tree during insertions and updating the size of the left subtree, we can efficiently count the number of smaller elements to the right of each element in the array.
Here is the approach:
- Define AVL Tree Node: Each node keeps track of its value, its height, the size of its subtree, and the size of its left subtree.
- Insert and Rotate Functions: Implement the necessary AVL tree operations (insert, rotate left, rotate right, update height, and size).
- Main Logic: Traverse the array from right to left, inserting elements into the AVL tree and maintaining a count of elements smaller than each inserted element.
Code
Java
class TreeNode {
TreeNode left, right;
int val, height, size, leftSize;
public TreeNode(int val) {
this.val = val;
this.height = 1;
this.size = 1;
this.leftSize = 0;
}
}
public class Solution {
class AVLTree {
private TreeNode root;
public TreeNode insert(TreeNode node, int val) {
if (node == null) {
return new TreeNode(val);
}
if (val < node.val) {
node.left = insert(node.left, val);
node.leftSize++;
} else {
node.right = insert(node.right, val);
}
node.size++;
node.height =
1 + Math.max(getHeight(node.left), getHeight(node.right));
return balance(node);
}
public int countSmaller(TreeNode node, int val) {
if (node == null) {
return 0;
}
if (val <= node.val) {
return countSmaller(node.left, val);
} else {
return node.leftSize + 1 + countSmaller(node.right, val);
}
}
private TreeNode balance(TreeNode node) {
int balanceFactor = getBalance(node);
if (balanceFactor > 1) {
if (getBalance(node.left) < 0) {
node.left = rotateLeft(node.left);
}
return rotateRight(node);
}
if (balanceFactor < -1) {
if (getBalance(node.right) > 0) {
node.right = rotateRight(node.right);
}
return rotateLeft(node);
}
return node;
}
private int getHeight(TreeNode node) {
return node == null ? 0 : node.height;
}
private int getBalance(TreeNode node) {
return node == null ? 0
: getHeight(node.left) - getHeight(node.right);
}
private TreeNode rotateRight(TreeNode y) {
TreeNode x = y.left;
TreeNode T2 = x.right;
x.right = y;
y.left = T2;
y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
y.size = (y.left != null ? y.left.size : 0)
+ (y.right != null ? y.right.size : 0) + 1;
x.size = (x.left != null ? x.left.size : 0)
+ (x.right != null ? x.right.size : 0) + 1;
return x;
}
private TreeNode rotateLeft(TreeNode x) {
TreeNode y = x.right;
TreeNode T2 = y.left;
y.left = x;
x.right = T2;
x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
x.size = (x.left != null ? x.left.size : 0)
+ (x.right != null ? x.right.size : 0) + 1;
y.size = (y.left != null ? y.left.size : 0)
+ (y.right != null ? y.right.size : 0) + 1;
return y;
}
public void insert(int val) {
this.root = insert(root, val);
}
public int countSmaller(int val) {
return countSmaller(root, val);
}
}
public List<Integer> countSmaller(int[] nums) {
List<Integer> result = new ArrayList<>();
if (nums.length == 0)
return result;
AVLTree avlTree = new AVLTree();
for (int i = nums.length - 1; i >= 0; i--) {
avlTree.insert(nums[i]);
result.add(0, avlTree.countSmaller(nums[i]));
}
return result;
}
public static void main(String[] args) {
Solution solution = new Solution();
int[] nums = {5, 2, 6, 1};
System.out.println(solution.countSmaller(nums)); // Output: [2, 1, 1, 0]
}
}
Python
class AVLTreeNode:
    """Node of an AVL tree augmented with subtree sizes."""

    def __init__(self, val):
        self.val = val
        self.height = 1  # height of the subtree rooted here; a leaf has height 1
        self.left = None
        self.right = None
        self.subtree_size = 1  # Number of nodes in the subtree rooted at this node
        self.left_size = 0  # Size of the left subtree (number of nodes)


class AVLTree:
    """AVL tree whose insert also reports how many stored values are strictly smaller."""

    def __init__(self):
        self.root = None

    def insert(self, root, val):
        """Insert val into the subtree rooted at root.

        Returns ``(new_root, smaller_count)``: the rebalanced subtree root and the
        number of values already in the subtree that are strictly less than val.
        """
        if not root:
            return AVLTreeNode(val), 0
        # The in-order sequence is sorted, but after rotations a value equal to
        # root.val may sit in either subtree. Values are counted only when
        # val > root.val strictly, so an equal value is never counted as smaller.
        if val <= root.val:
            root.left, smaller_count = self.insert(root.left, val)
        else:
            root.right, right_smaller_count = self.insert(root.right, val)
            # root and its whole left subtree are <= root.val < val
            smaller_count = right_smaller_count + self._size(root.left) + 1
        self._update(root)
        return self._rebalance(root), smaller_count

    def get_height(self, node):
        return node.height if node else 0

    def get_balance(self, node):
        """Left height minus right height (0 for an empty subtree)."""
        if node is None:
            return 0
        return self.get_height(node.left) - self.get_height(node.right)

    def _size(self, node):
        return node.subtree_size if node else 0

    def _update(self, node):
        """Recompute height, subtree_size and left_size of node from its children."""
        node.height = 1 + max(self.get_height(node.left), self.get_height(node.right))
        node.left_size = self._size(node.left)
        node.subtree_size = 1 + node.left_size + self._size(node.right)

    def _rebalance(self, node):
        """Restore the AVL property at node.

        Uses child balance factors rather than comparing values, which stays
        correct when duplicates are present.
        """
        balance = self.get_balance(node)
        if balance > 1:
            if self.get_balance(node.left) < 0:  # left-right case
                node.left = self.left_rotate(node.left)
            return self.right_rotate(node)
        if balance < -1:
            if self.get_balance(node.right) > 0:  # right-left case
                node.right = self.right_rotate(node.right)
            return self.left_rotate(node)
        return node

    def right_rotate(self, y):
        """Lift y's left child above y and return it as the new subtree root."""
        x = y.left
        y.left = x.right
        x.right = y
        self._update(y)  # y is now below x, so refresh it first
        self._update(x)
        return x

    def left_rotate(self, x):
        """Lift x's right child above x and return it as the new subtree root."""
        y = x.right
        x.right = y.left
        y.left = x
        self._update(x)  # x is now below y, so refresh it first
        self._update(y)
        return y
class Solution:
    def countSmaller(self, nums):
        """Return, for each i, the number of strictly smaller elements to the right of nums[i]."""
        if not nums:
            return []
        tree = AVLTree()
        counts = [0] * len(nums)
        # Right to left: when nums[idx] is inserted, the tree holds exactly the values to its right.
        for idx in reversed(range(len(nums))):
            tree.root, counts[idx] = tree.insert(tree.root, nums[idx])
        return counts
# Example usage (expected output: [2, 1, 1, 0]):
nums = [5, 2, 6, 1]
solution = Solution()
print(solution.countSmaller(nums))
Method 4 - Binary search
Here is the approach:
- Use a list to keep track of the sorted elements that have been processed.
- For each element in the array (iterating from right to left), use binary search to find the index where the current element should be inserted in the sorted list.
- The index found using binary search gives the count of smaller elements to the right of the current element.
- Insert the current element into the sorted list at the correct position.
Code
Java
/**
 * Returns, for each index, how many elements to its right are strictly smaller.
 * Walks right to left, keeping already-seen values in ascending order.
 */
public List<Integer> countSmaller(int[] nums) {
    Integer[] ans = new Integer[nums.length];
    List<Integer> seen = new ArrayList<Integer>(); // ascending values to the right of i
    int i = nums.length;
    while (--i >= 0) {
        int pos = findIndex(seen, nums[i]);
        ans[i] = pos; // every value before pos is smaller than nums[i]
        seen.add(pos, nums[i]);
    }
    return Arrays.asList(ans);
}
/**
 * Lower bound: returns the first position in the ascending list {@code sorted} whose
 * value is >= {@code target}, or {@code sorted.size()} if every value is smaller.
 */
private int findIndex(List<Integer> sorted, int target) {
    int lo = 0;
    int hi = sorted.size(); // half-open search window [lo, hi)
    while (lo < hi) {
        int mid = (lo + hi) >>> 1;
        if (sorted.get(mid) < target) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return lo;
}
Python
from bisect import bisect_left
def countSmaller(nums):
    """Return, for each i, the number of strictly smaller elements to the right of nums[i]."""
    seen = []  # values to the right of the current position, kept in ascending order
    counts = []
    for value in reversed(nums):
        pos = bisect_left(seen, value)  # number of seen values strictly less than value
        counts.append(pos)
        seen.insert(pos, value)
    counts.reverse()
    return counts
Complexity
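- ⏰ Time complexity: $O(n^2)$ in the worst case. Each binary search takes $O(\log n)$, but inserting into the sorted list shifts up to $O(n)$ elements. The shift is a fast contiguous memory move, so this approach is often quick in practice.
- 🧺 Space complexity: $O(n)$, to store `sorted_list`.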
Method 5 - Segment tree
Here is the approach:
- Handling Negative Numbers: Normalize the array by offsetting negative values.
- Segment Tree Operations: Implement the segment tree with operations to add elements, delete elements, and query counts within a range.
Code
Java
public class Solution {
static class segmentTreeNode {
int start, end, count;
segmentTreeNode left, right;
segmentTreeNode(int start, int end, int count) {
this.start = start;
this.end = end;
this.count = count;
left = null;
right = null;
}
}
public static List<Integer> countSmaller(int[] nums) {
// write your code here
List<Integer> result = new ArrayList<Integer>();
int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
for (int i : nums) {
min = Math.min(min, i);
}
if (min < 0) {
for (int i = 0; i < nums.length; i++) {
nums[i] -= min;//deal with negative numbers, seems a dummy way
}
}
for (int i : nums) {
max = Math.max(max, i);
}
segmentTreeNode root = build(0, max);
for (int i = 0; i < nums.length; i++) {
updateAdd(root, nums[i]);
}
for (int i = 0; i < nums.length; i++) {
updateDel(root, nums[i]);
result.add(query(root, 0, nums[i] - 1));
}
return result;
}
public static segmentTreeNode build(int start, int end) {
if (start > end) return null;
if (start == end) return new segmentTreeNode(start, end, 0);
int mid = (start + end) / 2;
segmentTreeNode root = new segmentTreeNode(start, end, 0);
root.left = build(start, mid);
root.right = build(mid + 1, end);
root.count = root.left.count + root.right.count;
return root;
}
public static int query(segmentTreeNode root, int start, int end) {
if (root == null) return 0;
if (root.start == start && root.end == end) return root.count;
int mid = (root.start + root.end) / 2;
if (end < mid) {
return query(root.left, start, end);
} else if (start > end) {
return query(root.right, start, end);
} else {
return query(root.left, start, mid) + query(root.right, mid + 1, end);
}
}
public static void updateAdd(segmentTreeNode root, int val) {
if (root == null || root.start > val || root.end < val) return;
if (root.start == val && root.end == val) {
root.count ++;
return;
}
int mid = (root.start + root.end) / 2;
if (val <= mid) {
updateAdd(root.left, val);
} else {
updateAdd(root.right, val);
}
root.count = root.left.count + root.right.count;
}
public static void updateDel(segmentTreeNode root, int val) {
if (root == null || root.start > val || root.end < val) return;
if (root.start == val && root.end == val) {
root.count --;
return;
}
int mid = (root.start + root.end) / 2;
if (val <= mid) {
updateDel(root.left, val);
} else {
updateDel(root.right, val);
}
root.count = root.left.count + root.right.count;
}
}
Python
class SegmentTreeNode:
    """Segment-tree node over the inclusive integer range [start, end]."""

    def __init__(self, start, end, count=0):
        self.start = start
        self.end = end
        self.count = count  # number of stored values inside [start, end]
        self.left = None
        self.right = None


class Solution:
    def build(self, start, end):
        """Build a zero-count tree covering [start, end]; returns None if start > end."""
        if start > end:
            return None
        node = SegmentTreeNode(start, end)
        if start == end:
            return node
        mid = (start + end) // 2
        node.left = self.build(start, mid)
        node.right = self.build(mid + 1, end)
        return node

    def _update(self, root, val, delta):
        """Add delta at leaf val and refresh counts upward; out-of-range val is ignored."""
        if root is None or not root.start <= val <= root.end:
            return
        if root.start == root.end:
            root.count += delta
            return
        mid = (root.start + root.end) // 2
        if val <= mid:
            self._update(root.left, val, delta)
        else:
            self._update(root.right, val, delta)
        root.count = root.left.count + root.right.count

    def update_add(self, root, val):
        """Record one occurrence of val."""
        self._update(root, val, 1)

    def update_del(self, root, val):
        """Remove one occurrence of val."""
        self._update(root, val, -1)

    def query(self, root, start, end):
        """Return how many stored values lie in [start, end] (clipped to the tree's range)."""
        if root is None:
            return 0
        start = max(start, root.start)
        end = min(end, root.end)
        if start > end:
            return 0
        if root.start == start and root.end == end:
            return root.count
        mid = (root.start + root.end) // 2
        if end <= mid:
            return self.query(root.left, start, end)
        if start > mid:
            return self.query(root.right, start, end)
        return self.query(root.left, start, mid) + self.query(root.right, mid + 1, end)

    def countSmaller(self, nums):
        """Return, for each i, the number of strictly smaller elements to the right of nums[i]."""
        if not nums:
            return []
        # Handle negative numbers by shifting keys so the smallest is 0 (nums itself is untouched).
        lowest = min(nums)
        offset = -lowest if lowest < 0 else 0
        keys = [x + offset for x in nums]
        root = self.build(0, max(keys))
        result = []
        # Right to left: the tree holds exactly the elements to the right of the current one,
        # so query first, then insert the current key.
        for key in reversed(keys):
            result.append(self.query(root, 0, key - 1))
            self.update_add(root, key)
        return result[::-1]
# Example usage (expected output: [2, 1, 1, 0]):
nums = [5, 2, 6, 1]
solution = Solution()
print(solution.countSmaller(nums))
Complexity
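- ⏰ Time complexity: $O(R + n \log R)$, where $R$ is the size of the value range the tree covers (from $\min(0, \min(\text{nums}))$ to $\max(\text{nums})$). Building the tree costs $O(R)$, and each update or query costs $O(\log R)$.
- 🧺 Space complexity: $O(R)$ for the tree nodes, plus $O(n)$ for the result.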
Method 6 - Binary Indexed tree
Approach
- Normalize the Values: Compress the values of `nums` into ranks $1..k$ (where $k$ is the number of distinct values) so the BIT stays small.
- Use BIT: Traverse the array from right to left, and for each element, use the BIT to count how many elements seen so far are smaller.
- Update BIT: After counting for the current element, update the BIT.
Detailed Steps
- Coordinate Compression: Use sorted unique values to map the original values to their ranks.
- BIT Operations: Implement `update` and `query` functions for the BIT to support point increments and prefix sums.
Code
Java
public class Solution {
    /** Fenwick (binary indexed) tree over 1-based positions 1..size: point add, prefix sum. */
    static class BIT {
        private final int[] tree;

        public BIT(int size) {
            tree = new int[size + 1]; // slot 0 is unused
        }

        /** Adds {@code value} at {@code index}. */
        public void update(int index, int value) {
            for (int i = index; i < tree.length; i += i & -i) {
                tree[i] += value;
            }
        }

        /** Returns the sum over positions 1..index. */
        public int query(int index) {
            int sum = 0;
            for (int i = index; i > 0; i -= i & -i) {
                sum += tree[i];
            }
            return sum;
        }
    }

    /** Returns, for each index i, how many elements to the right of nums[i] are strictly smaller. */
    public List<Integer> countSmaller(int[] nums) {
        List<Integer> result = new ArrayList<>();
        if (nums == null || nums.length == 0) {
            return result;
        }
        // Coordinate compression: distinct values in ascending order receive ranks 1, 2, ...
        Set<Integer> distinct = new TreeSet<>();
        for (int num : nums) {
            distinct.add(num);
        }
        Map<Integer, Integer> rankOf = new HashMap<>();
        for (int value : distinct) {
            rankOf.put(value, rankOf.size() + 1);
        }
        BIT bit = new BIT(rankOf.size());
        for (int i = nums.length - 1; i >= 0; i--) {
            int r = rankOf.get(nums[i]);
            result.add(bit.query(r - 1)); // already-seen (right-hand) values with a smaller rank
            bit.update(r, 1);
        }
        Collections.reverse(result);
        return result;
    }
}
Python
class BIT:
    """Fenwick (binary indexed) tree over positions 1..n: point add and prefix sum."""

    def __init__(self, n):
        self.size = n
        self.bit = [0 for _ in range(n + 1)]  # slot 0 is unused

    def update(self, i, val):
        """Add val at position i (1-based)."""
        while i <= self.size:
            self.bit[i] += val
            i += i & (-i)

    def query(self, i):
        """Return the sum over positions 1..i."""
        total = 0
        while i > 0:
            total += self.bit[i]
            i &= i - 1  # clear the lowest set bit (same as i -= i & -i)
        return total
def countSmaller(nums):
    """Return, for each i, the number of strictly smaller elements to the right of nums[i]."""
    if not nums:
        return []
    # Coordinate compression: each distinct value gets its 1-based rank in ascending order.
    ranks = {}
    for rank, value in enumerate(sorted(set(nums)), start=1):
        ranks[value] = rank
    tree = BIT(len(ranks))
    out = []
    for value in reversed(nums):
        r = ranks[value]
        out.append(tree.query(r - 1))  # already-seen values with a smaller rank
        tree.update(r, 1)
    out.reverse()
    return out
Complexity
- ⏰ Time complexity:
O(n log n)
, due to BIT operations and coordinate compression. - 🧺 Space complexity:
O(n)
, for the BIT and additional data structures.