Problem

Given an integer array nums, return an integer array counts where counts[i] is the number of smaller elements to the right of nums[i].

Examples

Example 1:

Input: nums = [5,2,6,1]
Output: [2,1,1,0]
Explanation:
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there are 0 smaller elements.

Example 2:

Input: nums = [-1]
Output: [0]

Example 3:

Input: nums = [-1,-1]
Output: [0,0]

Solution

Method 1 - Merge sort

Here is the approach:

  1. Modified Merge Sort:
    • We use merge sort to both sort the array and count the smaller elements to the right.
    • During the merge step, we count how many elements from the right part are moved before an element from the left part. This count is the number of smaller elements to the right for that element.

The algorithm follows these steps:

  1. Initialization: Create a list of tuples where each tuple contains an element and its original index.
  2. Merge Sort: Sort the array while maintaining the original indices and counting the smaller elements.
  3. Count Update: During the merge step, update the counts for each element.

Code

Java
public class Solution {
    public List<Integer> countSmaller(int[] nums) {
        int n = nums.length;
        Integer[] result = new Integer[n];
        List<Pair> arr = new ArrayList<>();

        for (int i = 0; i < n; i++) {
            arr.add(new Pair(nums[i], i));
        }

        mergeSort(arr, result);
        return List.of(result);
    }

    private void mergeSort(List<Pair> arr, Integer[] result) {
        if (arr.size() > 1) {
            int mid = arr.size() / 2;
            List<Pair> left = new ArrayList<>(arr.subList(0, mid));
            List<Pair> right = new ArrayList<>(arr.subList(mid, arr.size()));
            mergeSort(left, result);
            mergeSort(right, result);

            int i = 0, j = 0, k = 0;
            while (i < left.size() || j < right.size()) {
                if (j == right.size()
                    || i < left.size()
                        && left.get(i).value <= right.get(j).value) {
                    arr.set(k, left.get(i));
                    result[left.get(i).index] += j;
                    i++;
                } else {
                    arr.set(k, right.get(j));
                    j++;
                }
                k++;
            }
        }
    }

    private static class Pair {
        int value;
        int index;

        Pair(int value, int index) {
            this.value = value;
            this.index = index;
        }
    }

}
Python
def countSmaller(nums):
    """Return, for each position, how many later elements are strictly smaller.

    Sorts (index, value) pairs with a top-down merge sort that fills each merged
    run from the back. Whenever an item from the left half is placed, every
    right-half item not yet placed is smaller than it and lies to its right, so
    that many are credited to it.
    """
    if not nums:
        return []

    counts = [0] * len(nums)

    def sort_and_count(items):
        size = len(items)
        if size < 2:
            return items
        cut = size // 2
        lo = sort_and_count(items[:cut])
        hi = sort_and_count(items[cut:])
        a, b = cut - 1, len(hi) - 1
        for pos in reversed(range(size)):
            take_lo = b < 0 or (a >= 0 and lo[a][1] > hi[b][1])
            if take_lo:
                # hi[0..b] are all smaller than lo[a] and to its right.
                counts[lo[a][0]] += b + 1
                items[pos] = lo[a]
                a -= 1
            else:
                items[pos] = hi[b]
                b -= 1
        return items

    sort_and_count(list(enumerate(nums)))
    return counts

Complexity

  • ⏰ Time complexity: O(n log n) due to the merge sort algorithm.
  • 🧺 Space complexity: O(n) due to the temporary arrays used during the merge process.

Method 2 - Binary search tree

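We can insert the elements into a binary search tree (BST): either a plain BST, as shown here, or a self-balancing one such as an AVL tree (Method 3). Here is the approach: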

  1. TreeNode Class: Define a TreeNode class that represents each node in the BST along with a count of how many times this value has been inserted.
  2. Insert and Count: Traverse the array from right to left, and for each element, insert it into the BST while counting how many elements are smaller than it.
  3. Count Smaller: Use a helper function to insert elements into the BST and count the smaller elements to the right.

Code

Java
public class Solution {
    class TreeNode {
        TreeNode left;
        TreeNode right;
        int val;
        int count = 1; // Number of occurrences of this value
        int leftCount = 0; // Number of nodes in the left subtree

        public TreeNode(int val) {
            this.val = val;
        }
    }

    private int insert(TreeNode node, int val) {
        int smallerCount = 0;
        while (true) {
            if (val <= node.val) {
                node.leftCount++;
                if (node.left == null) {
                    node.left = new TreeNode(val);
                    break;
                } else {
                    node = node.left;
                }
            } else {
                smallerCount += node.leftCount + node.count;
                if (node.right == null) {
                    node.right = new TreeNode(val);
                    break;
                } else {
                    node = node.right;
                }
            }
        }
        return smallerCount;
    }

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> result = new ArrayList<>();
        if (nums == null || nums.length == 0)
            return result;

        int n = nums.length;
        TreeNode root = new TreeNode(nums[n - 1]);
        result.add(0);

        for (int i = n - 2; i >= 0; i--) {
            result.add(insert(root, nums[i]));
        }

        Collections.reverse(result);
        return result;
    }

}
Python
class TreeNode:
    """Node of an unbalanced BST that stores each distinct value once."""

    def __init__(self, val):
        self.val = val
        self.count = 1  # Number of occurrences of this value
        self.left_count = 0  # Number of inserted values (with multiplicity) in the left subtree
        self.left = None
        self.right = None


# Module-level class: previously it was mis-indented inside TreeNode and
# therefore only reachable as TreeNode.Solution.
class Solution:

    def insert(self, root, val):
        """Insert ``val`` under ``root``; return how many stored values are strictly smaller."""
        smaller_count = 0
        while True:
            if val == root.val:
                # Duplicate: record it on the existing node; only its left subtree is smaller.
                root.count += 1
                return smaller_count + root.left_count
            if val < root.val:
                root.left_count += 1
                if root.left is None:
                    root.left = TreeNode(val)
                    return smaller_count
                root = root.left
            else:
                # This node (all its occurrences) and its whole left subtree are smaller.
                smaller_count += root.left_count + root.count
                if root.right is None:
                    root.right = TreeNode(val)
                    return smaller_count
                root = root.right

    def countSmaller(self, nums):
        """Return, for each index, the number of strictly smaller elements to its right.

        Scans right to left, so the tree holds exactly the elements to the right
        of the current one. Average O(n log n); O(n^2) on sorted input because the
        tree is not rebalanced.
        """
        if not nums:
            return []

        result = []
        n = len(nums)
        root = TreeNode(nums[-1])
        result.append(0)  # nothing lies to the right of the last element

        for i in range(n - 2, -1, -1):
            result.append(self.insert(root, nums[i]))

        return result[::-1]  # Reverse to get correct order

Complexity

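  • ⏰ Time complexity: O(n log n) on average, but O(n^2) in the worst case (for example, already-sorted input), because this BST is not self-balancing.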
  • 🧺 Space complexity: O(n), due to the storage in the BST and the result list.

Method 3 - Using AVL Tree

We can solve the problem more efficiently using a balanced BST such as an AVL tree. Observe that if a key is greater than the root, it is greater than all the nodes in the left subtree of the root.

We treat each number in the array as a key and insert the keys into an AVL tree one by one, from right to left. Every node stores the size of its left subtree. During an insertion, this value is updated on each node along the insertion path. It must also be recomputed for the nodes involved in a rotation. With these sizes, the number of stored keys smaller than a given value can be found by walking a single root-to-leaf path. We also need to rebalance the tree after each insertion.

An AVL tree becomes imbalanced during an insertion when the height difference between the left and right subtrees of a node exceeds 1. Let x be the lowest node on the insertion path whose balance is violated. There are four cases to consider for rebalancing the tree:

  1. Insert in the Left Subtree of the Left Child:
    • Perform a right rotation on node x to balance it.
  2. Insert in the Right Subtree of the Right Child:
    • Perform a left rotation on node x to balance it.
  3. Insert in the Right Subtree of the Left Child:
    • Perform a left rotation on the left child of x, followed by a right rotation on x to balance the tree.
  4. Insert in the Left Subtree of the Right Child:
    • Perform a right rotation on the right child of x, followed by a left rotation on x to balance the tree.

By maintaining the balance of the AVL tree during insertions and updating the size of the left subtree, we can efficiently count the number of smaller elements to the right of each element in the array.

Here is the approach:

  1. Define AVL Tree Node: Each node keeps track of its value, its height, the size of its subtree, and the size of its left subtree.
  2. Insert and Rotate Functions: Implement the necessary AVL tree operations (insert, rotate left, rotate right, update height, and size).
  3. Main Logic: Traverse the array from right to left, inserting elements into the AVL tree and maintaining a count of elements smaller than each inserted element.

Code

Java
class TreeNode {
    TreeNode left, right;
    int val, height, size, leftSize;

    public TreeNode(int val) {
        this.val = val;
        this.height = 1;
        this.size = 1;
        this.leftSize = 0;
    }
}

public class Solution {
    class AVLTree {
        private TreeNode root;

        public TreeNode insert(TreeNode node, int val) {
            if (node == null) {
                return new TreeNode(val);
            }

            if (val < node.val) {
                node.left = insert(node.left, val);
                node.leftSize++;
            } else {
                node.right = insert(node.right, val);
            }

            node.size++;
            node.height =
                1 + Math.max(getHeight(node.left), getHeight(node.right));

            return balance(node);
        }

        public int countSmaller(TreeNode node, int val) {
            if (node == null) {
                return 0;
            }

            if (val <= node.val) {
                return countSmaller(node.left, val);
            } else {
                return node.leftSize + 1 + countSmaller(node.right, val);
            }
        }

        private TreeNode balance(TreeNode node) {
            int balanceFactor = getBalance(node);

            if (balanceFactor > 1) {
                if (getBalance(node.left) < 0) {
                    node.left = rotateLeft(node.left);
                }
                return rotateRight(node);
            }

            if (balanceFactor < -1) {
                if (getBalance(node.right) > 0) {
                    node.right = rotateRight(node.right);
                }
                return rotateLeft(node);
            }

            return node;
        }

        private int getHeight(TreeNode node) {
            return node == null ? 0 : node.height;
        }

        private int getBalance(TreeNode node) {
            return node == null ? 0
                                : getHeight(node.left) - getHeight(node.right);
        }

        private TreeNode rotateRight(TreeNode y) {
            TreeNode x = y.left;
            TreeNode T2 = x.right;

            x.right = y;
            y.left = T2;

            y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
            x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;

            y.size = (y.left != null ? y.left.size : 0)
                + (y.right != null ? y.right.size : 0) + 1;
            x.size = (x.left != null ? x.left.size : 0)
                + (x.right != null ? x.right.size : 0) + 1;

            return x;
        }

        private TreeNode rotateLeft(TreeNode x) {
            TreeNode y = x.right;
            TreeNode T2 = y.left;

            y.left = x;
            x.right = T2;

            x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
            y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;

            x.size = (x.left != null ? x.left.size : 0)
                + (x.right != null ? x.right.size : 0) + 1;
            y.size = (y.left != null ? y.left.size : 0)
                + (y.right != null ? y.right.size : 0) + 1;

            return y;
        }

        public void insert(int val) {
            this.root = insert(root, val);
        }

        public int countSmaller(int val) {
            return countSmaller(root, val);
        }
    }

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> result = new ArrayList<>();
        if (nums.length == 0)
            return result;

        AVLTree avlTree = new AVLTree();

        for (int i = nums.length - 1; i >= 0; i--) {
            avlTree.insert(nums[i]);
            result.add(0, avlTree.countSmaller(nums[i]));
        }

        return result;
    }

    public static void main(String[] args) {
        Solution solution = new Solution();

        int[] nums = {5, 2, 6, 1};
        System.out.println(solution.countSmaller(nums)); // Output: [2, 1, 1, 0]
    }
}
Python
class AVLTreeNode:
    def __init__(self, val):
        self.val = val
        self.height = 1
        self.left = None
        self.right = None
        self.subtree_size = (
            1  # Number of nodes in the subtree rooted at this node
        )
        self.left_size = 0  # Size of the left subtree (number of nodes)


class AVLTree:
    """AVL tree augmented with subtree sizes for counting smaller keys.

    ``insert`` returns the new subtree root together with how many keys that
    were already in the subtree are strictly smaller than the inserted value.
    """

    def __init__(self):
        self.root = None

    def insert(self, root, val):
        """Insert ``val`` under ``root``.

        Returns:
            (new_root, smaller_count): the rebalanced subtree root and the
            number of pre-existing keys in the subtree strictly smaller than val.
        """
        if not root:
            return AVLTreeNode(val), 0

        if val <= root.val:
            # root and its right subtree are >= val, so nothing there counts.
            # Sending equal keys left means left_size is counted only when
            # val > root.val. After rotations the left subtree may contain keys
            # equal to root.val, and they must not be counted.
            root.left, smaller_count = self.insert(root.left, val)
        else:
            # root and its whole left subtree are strictly smaller than val.
            root.right, right_smaller_count = self.insert(root.right, val)
            smaller_count = right_smaller_count + root.left_size + 1

        self._update(root)
        return self._rebalance(root), smaller_count

    def _update(self, node):
        """Recompute height, subtree_size and left_size of node from its children."""
        node.left_size = node.left.subtree_size if node.left else 0
        right_size = node.right.subtree_size if node.right else 0
        node.subtree_size = 1 + node.left_size + right_size
        node.height = 1 + max(
            self.get_height(node.left), self.get_height(node.right)
        )

    def _rebalance(self, root):
        """Restore the AVL property at root using the children's balance factors.

        Balance factors are used instead of comparing ``val`` with the child's
        key, because that comparison is ambiguous when keys are equal.
        """
        balance = self.get_balance(root)
        if balance > 1:
            if self.get_balance(root.left) < 0:
                root.left = self.left_rotate(root.left)
            return self.right_rotate(root)
        if balance < -1:
            if self.get_balance(root.right) > 0:
                root.right = self.right_rotate(root.right)
            return self.left_rotate(root)
        return root

    def get_height(self, node):
        return node.height if node else 0

    def get_balance(self, node):
        if not node:
            return 0
        return self.get_height(node.left) - self.get_height(node.right)

    def right_rotate(self, y):
        x = y.left
        y.left = x.right
        x.right = y
        self._update(y)  # y is now x's child, so refresh it first
        self._update(x)
        return x

    def left_rotate(self, x):
        y = x.right
        x.right = y.left
        y.left = x
        self._update(x)  # x is now y's child, so refresh it first
        self._update(y)
        return y


class Solution:
    def countSmaller(self, nums):
        """Return, for each index, the number of strictly smaller elements to its right."""
        if not nums:
            return []

        avl_tree = AVLTree()
        result = [0] * len(nums)

        # Insert right to left so the tree holds exactly the elements after i.
        for i in range(len(nums) - 1, -1, -1):
            avl_tree.root, result[i] = avl_tree.insert(avl_tree.root, nums[i])

        return result


# Example usage (expected output: [2, 1, 1, 0]):
nums = [5, 2, 6, 1]
solution = Solution()
print(solution.countSmaller(nums))

Method 4 - Binary search with a sorted list

Here is the approach:

  • Use a list to keep track of the sorted elements that have been processed.
  • For each element in the array (iterating from right to left), use binary search to find the index where the current element should be inserted in the sorted list.
  • The index found using binary search gives the count of smaller elements to the right of the current element.
  • Insert the current element into the sorted list at the correct position.

Code

Java
 public List<Integer> countSmaller(int[] nums) {
	Integer[] ans = new Integer[nums.length];
	List<Integer> sorted = new ArrayList<Integer>();
	for (int i = nums.length - 1; i >= 0; i--) {
		int index = findIndex(sorted, nums[i]);
		ans[i] = index;
		sorted.add(index, nums[i]);
	}
	return Arrays.asList(ans);
}
private int findIndex(List<Integer> sorted, int target) {
	if (sorted.size() == 0) return 0;
	int start = 0;
	int end = sorted.size() - 1;
	if (sorted.get(end) < target) return end + 1;
	if (sorted.get(start) >= target) return 0;
	while (start + 1 < end) {
		int mid = start + (end - start) / 2;
		if (sorted.get(mid) < target) {
			start = mid + 1;
		} else {
			end = mid;
		}
	}
	if (sorted.get(start) >= target) return start;
	return end;
}
Python
from bisect import bisect_left


def countSmaller(nums):
    """For each element, count the strictly smaller elements to its right.

    Walks ``nums`` from the end, keeping the values seen so far in ascending
    order. The ``bisect_left`` insertion point of a value equals the number of
    seen values strictly smaller than it.
    """
    seen = []
    counts = []
    for value in reversed(nums):
        pos = bisect_left(seen, value)
        counts.append(pos)
        seen.insert(pos, value)
    counts.reverse()
    return counts

Complexity

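  • ⏰ Time complexity: O(n^2) in the worst case. Each binary search is O(log n), but inserting into the middle of an array-backed list shifts up to O(n) elements. In practice that shift is a fast block memory move, so this approach is often competitive for moderate n.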
  • 🧺 Space complexity: O(n), to store the sorted_list.

Method 5 - Segment tree

Here is the approach:

  1. Handling Negative Numbers: If the array contains negative numbers, shift every value by the minimum so that all values become non-negative indices into the segment tree.
  2. Segment Tree Operations: Implement the segment tree with operations to add elements, delete elements, and query counts within a range.

Code

Java
 public class Solution {
	static class segmentTreeNode {
		int start, end, count;
		segmentTreeNode left, right;
		segmentTreeNode(int start, int end, int count) {
			this.start = start;
			this.end = end;
			this.count = count;
			left = null;
			right = null;
		}
	}
	public static List<Integer> countSmaller(int[] nums) {
		// write your code here
		List<Integer> result = new ArrayList<Integer>();

		int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE;
		for (int i : nums) {
			min = Math.min(min, i);

		}
		if (min < 0) {
			for (int i = 0; i < nums.length; i++) {
				nums[i] -= min;//deal with negative numbers, seems a dummy way
			}
		}
		for (int i : nums) {
			max = Math.max(max, i);
		}
		segmentTreeNode root = build(0, max);
		for (int i = 0; i < nums.length; i++) {
			updateAdd(root, nums[i]);
		}
		for (int i = 0; i < nums.length; i++) {
			updateDel(root, nums[i]);
			result.add(query(root, 0, nums[i] - 1));
		}
		return result;
	}
	public static segmentTreeNode build(int start, int end) {
		if (start > end) return null;
		if (start == end) return new segmentTreeNode(start, end, 0);
		int mid = (start + end) / 2;
		segmentTreeNode root = new segmentTreeNode(start, end, 0);
		root.left = build(start, mid);
		root.right = build(mid + 1, end);
		root.count = root.left.count + root.right.count;
		return root;
	}

	public static int query(segmentTreeNode root, int start, int end) {
		if (root == null) return 0;
		if (root.start == start && root.end == end) return root.count;
		int mid = (root.start + root.end) / 2;
		if (end < mid) {
			return query(root.left, start, end);
		} else if (start > end) {
			return query(root.right, start, end);
		} else {
			return query(root.left, start, mid) + query(root.right, mid + 1, end);
		}
	}

	public static void updateAdd(segmentTreeNode root, int val) {
		if (root == null || root.start > val || root.end < val) return;
		if (root.start == val && root.end == val) {
			root.count ++;
			return;
		}
		int mid = (root.start + root.end) / 2;
		if (val <= mid) {
			updateAdd(root.left, val);
		} else {
			updateAdd(root.right, val);
		}
		root.count = root.left.count + root.right.count;
	}

	public static void updateDel(segmentTreeNode root, int val) {
		if (root == null || root.start > val || root.end < val) return;
		if (root.start == val && root.end == val) {
			root.count --;
			return;
		}
		int mid = (root.start + root.end) / 2;
		if (val <= mid) {
			updateDel(root.left, val);
		} else {
			updateDel(root.right, val);
		}
		root.count = root.left.count + root.right.count;
	}
}
Python
class SegmentTreeNode:
    """Node of a segment tree over the integer range [start, end]."""

    def __init__(self, start, end, count=0):
        self.start = start
        self.end = end
        self.count = count  # number of stored values that fall in [start, end]
        self.left = None
        self.right = None


class Solution:
    """Counts smaller elements to the right with a segment tree over value space."""

    def build(self, start, end):
        """Build an all-zero tree covering [start, end]; None if the range is empty."""
        if start > end:
            return None
        node = SegmentTreeNode(start, end)
        if start == end:
            return node
        mid = (start + end) // 2
        node.left = self.build(start, mid)
        node.right = self.build(mid + 1, end)
        return node

    def _update(self, root, val, delta):
        """Add ``delta`` at the leaf for ``val`` and refresh counts on the way up."""
        if root is None:
            return
        if root.start == root.end == val:
            root.count += delta
            return
        mid = (root.start + root.end) // 2
        if val <= mid:
            self._update(root.left, val, delta)
        else:
            self._update(root.right, val, delta)
        root.count = root.left.count + root.right.count

    def update_add(self, root, val):
        """Record one occurrence of ``val``."""
        self._update(root, val, 1)

    def update_del(self, root, val):
        """Remove one occurrence of ``val``."""
        self._update(root, val, -1)

    def query(self, root, start, end):
        """Return how many stored values fall in [start, end] (0 for an empty range)."""
        if root is None or start > end:
            return 0
        if root.start == start and root.end == end:
            return root.count
        mid = (root.start + root.end) // 2
        if end <= mid:
            return self.query(root.left, start, end)
        elif start > mid:
            return self.query(root.right, start, end)
        else:
            return self.query(root.left, start, mid) + self.query(
                root.right, mid + 1, end
            )

    def countSmaller(self, nums):
        """Return, for each index, the number of strictly smaller elements to its right.

        Memory is proportional to max(nums) - min(nums) + 1.
        """
        if not nums:
            return []

        # Shift so the smallest value maps to index 0 (this also handles negatives).
        min_val = min(nums)
        shifted = [num - min_val for num in nums]
        root = self.build(0, max(shifted))

        result = []
        # Start from an empty tree and scan right to left: before inserting an
        # element, the tree holds exactly the elements to its right. (Inserting
        # everything first and deleting right to left would count the left side.)
        for num in reversed(shifted):
            result.append(self.query(root, 0, num - 1))
            self.update_add(root, num)

        return result[::-1]


# Example usage (expected output: [2, 1, 1, 0]):
nums = [5, 2, 6, 1]
solution = Solution()
print(solution.countSmaller(nums))

Complexity

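  • ⏰ Time complexity: O(n log R + R), where R = max(nums) − min(nums) + 1. Building the tree costs O(R), and each update or prefix query costs O(log R).
  • 🧺 Space complexity: O(R) for the segment tree, which depends on the value range rather than on n. The result list adds O(n).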

Method 6 - Binary Indexed tree

Approach

  1. Normalize the Values: Compress the values of nums to a smaller range to efficiently use the BIT.
  2. Use BIT: Traverse the array from right to left, and for each element, use the BIT to count the number of elements seen so far that are smaller.
  3. Update BIT: After counting for the current element, update the BIT.

Detailed Steps

  1. Coordinate Compression: Use sorted unique values to map the original values to their ranks.
  2. BIT Operations: Implement update and query functions for BIT to support increment and prefix sum operations.

Code

Java
 public class Solution {
    class BIT {
        private int[] tree;

        public BIT(int size) {
            tree = new int[size + 1];
        }

        public void update(int index, int value) {
            while (index < tree.length) {
                tree[index] += value;
                index += index & -index;
            }
        }

        public int query(int index) {
            int sum = 0;
            while (index > 0) {
                sum += tree[index];
                index -= index & -index;
            }
            return sum;
        }
    }

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> result = new ArrayList<>();
        if (nums == null || nums.length == 0)
            return result;

        // Coordinate compression
        Set<Integer> set = new TreeSet<>();
        for (int num : nums) set.add(num);
        Map<Integer, Integer> ranks = new HashMap<>();
        int rank = 1;
        for (int num : set) ranks.put(num, rank++);

        // Initialize BIT
        BIT bit = new BIT(ranks.size());

        // Traverse array from right to left
        for (int i = nums.length - 1; i >= 0; i--) {
            int currRank = ranks.get(nums[i]);
            result.add(bit.query(
                currRank - 1)); // Count of numbers less than current num
            bit.update(currRank, 1);
        }

        Collections.reverse(result);
        return result;
    }
}
Python
class BIT:
    """Fenwick tree over positions 1..n: point update and prefix-sum query."""

    def __init__(self, n):
        self.size = n
        self.bit = [0] * (n + 1)

    def update(self, i, val):
        """Add ``val`` at position ``i`` (1-based)."""
        tree, limit = self.bit, self.size
        while i <= limit:
            tree[i] += val
            i += i & (-i)

    def query(self, i):
        """Return the sum of positions 1..i (0 when i <= 0)."""
        total = 0
        tree = self.bit
        while i > 0:
            total += tree[i]
            i &= i - 1  # clear the lowest set bit
        return total


def countSmaller(nums):
    """Return, for each index, the number of strictly smaller elements to its right.

    Values are compressed to 1-based ranks, then scanned right to left. The BIT
    prefix sum over ranks below the current one counts the smaller values seen
    so far.
    """
    if not nums:
        return []

    rank_of = {value: rank for rank, value in enumerate(sorted(set(nums)), start=1)}
    fenwick = BIT(len(rank_of))

    counts = []
    for value in reversed(nums):
        r = rank_of[value]
        counts.append(fenwick.query(r - 1))
        fenwick.update(r, 1)

    counts.reverse()
    return counts

Complexity

  • ⏰ Time complexity: O(n log n), due to BIT operations and coordinate compression.
  • 🧺 Space complexity: O(n), for the BIT and additional data structures.