Problem

Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).

The distance between two points on the X-Y plane is the Euclidean distance, i.e. $\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$.

You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).
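The square root is monotonic, so ordering points by squared distance $x^2 + y^2$ gives the same order as ordering them by the true Euclidean distance. That is why the solutions below never take a square root. Here is a small Python check on Example 1. It assumes Python 3.8+ for `math.dist`, and the helper names are illustrative only:

```python
import math

def euclidean(p, q):
    # True Euclidean distance: sqrt((x1 - x2)^2 + (y1 - y2)^2)
    return math.dist(p, q)

def squared(p, q):
    # Squared distance: integer arithmetic only, no square root
    return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2

origin = (0, 0)
points = [(1, 3), (-2, 2)]

by_true = sorted(points, key=lambda p: euclidean(p, origin))
by_squared = sorted(points, key=lambda p: squared(p, origin))
print(by_true == by_squared)                 # True: same order either way
print([squared(p, origin) for p in points])  # [10, 8]
```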

Examples

Example 1:

Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].

Example 2:

Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.

Notice the key requirement here: “K is much smaller than N. N is very large”. The brute-force solution computes the distance of every point and sorts all of them in O(n log n) time. That does more work than necessary. We only care about the first k points of the sorted order, so there is no need to sort all n points.
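For comparison, the brute-force baseline can be sketched in a few lines of Python. It sorts every point by squared distance and keeps the first k. This runs in O(n log n) time and is mainly useful as a reference for testing the faster methods. The name `k_closest_sorted` is illustrative only.

```python
from typing import List

def k_closest_sorted(points: List[List[int]], k: int) -> List[List[int]]:
    # Sort all n points by squared distance to the origin: O(n log n)
    ordered = sorted(points, key=lambda p: p[0] * p[0] + p[1] * p[1])
    # Keep only the k nearest
    return ordered[:k]

print(k_closest_sorted([[1, 3], [-2, 2]], 1))           # [[-2, 2]]
print(k_closest_sorted([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```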

Follow up

What if, instead of the origin, the input also includes some central point? How would you handle that?

Solution

The idea is to compute each point's distance from the origin and then apply the technique from Kth Largest Element in an Array, selecting the k smallest distances instead of the k largest values.

Method 1 - Using Max Heap

An efficient solution is to use a max-heap of size k to maintain the k smallest distances seen so far.

A max-heap keeps its largest element at the root. Whenever we add a point and the heap size exceeds k, we remove the root. This discards the farthest point and keeps the closest ones.

  • Use a max-heap to keep track of the k closest points.
  • Insert points into the heap based on their distance to the origin.
  • If the heap size exceeds k, remove the point with the maximum distance.
  • At the end, the heap contains the k closest points.
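In Python, the standard library's `heapq.nsmallest(n, iterable, key=...)` packages the same "keep only the best k" idea. Its documentation describes it as equivalent to `sorted(iterable, key=key)[:n]` and says it performs best for smaller values of n, which matches the k ≪ n case. Below is a minimal sketch that uses the squared distance as the key. The wrapper name is illustrative only.

```python
import heapq
from typing import List

def k_closest_nsmallest(points: List[List[int]], k: int) -> List[List[int]]:
    # nsmallest returns the k items with the smallest key, in ascending key order;
    # the key is the squared distance to the origin
    return heapq.nsmallest(k, points, key=lambda p: p[0] ** 2 + p[1] ** 2)

print(k_closest_nsmallest([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```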

Code

Java
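import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        // Max-heap by squared distance: the farthest point is at the root
        PriorityQueue<int[]> heap = new PriorityQueue<>(
                (a, b) -> Integer.compare(b[0] * b[0] + b[1] * b[1], a[0] * a[0] + a[1] * a[1]));
        for (int[] point : points) {
            heap.offer(point);
            if (heap.size() > k) {
                heap.poll(); // evict the farthest point so only k remain
            }
        }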
        int[][] ans = new int[k][2];
        for (int i = 0; i < k; i++) {
            ans[i] = heap.poll();
        }
        return ans;
    }
}
Using Point Class
import java.util.PriorityQueue;

class Solution {

    static class Point {
        public int x;
        public int y;

        public Point(final int x, final int y) {
            this.x = x;
            this.y = y;
        }

        // distance from origin
        public long getDist() {
            return (long) x * x + (long) y * y;
        }
    }

    public int[][] kClosest(int[][] points, final int k) {
        // Max heap
        final PriorityQueue<Point> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(o2.getDist(), o1.getDist()));

        for (int[] point : points) {
            Point p = new Point(point[0], point[1]);
            maxHeap.offer(p);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }

        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            Point p = maxHeap.poll();
            ans[index][0] = p.x;
            ans[index][1] = p.y;
            index++;
        }

        return ans;
    }
}
Python
import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        # Negated squared distances turn Python's min-heap into a max-heap
        max_heap = []
        for x, y in points:
            dist = -(x**2 + y**2)
            if len(max_heap) < k:
                heapq.heappush(max_heap, (dist, x, y))
            else:
                # Push the new point, then pop the farthest one (most negative key)
                heapq.heappushpop(max_heap, (dist, x, y))

        return [[x, y] for _, x, y in max_heap]

Complexity

  • ⏰ Time complexity: O(n log k) where n is the number of points and k is the number of closest points to find.
  • 🧺 Space complexity: O(k) to store the k closest points.

On the Follow-up - K closest points to a central point

We only need to change the distance calculation:

private long distance(int[] point, int[] centralPoint) {
    // Widen to long before multiplying so neither term can overflow
    long dx = (long) point[0] - centralPoint[0];
    long dy = (long) point[1] - centralPoint[1];
    return dx * dx + dy * dy;
}

Here is the updated code:

Java
import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int[] centralPoint, int k) {
        // Max heap to keep the k closest points
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(
                distance(o2, centralPoint), distance(o1, centralPoint)
        ));

        for (int[] point : points) {
            maxHeap.offer(point);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }

        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            ans[index++] = maxHeap.poll();
        }

        return ans;
    }

    private long distance(int[] point, int[] centralPoint) {
        // Widen to long before multiplying so neither term can overflow
        long dx = (long) point[0] - centralPoint[0];
        long dy = (long) point[1] - centralPoint[1];
        return dx * dx + dy * dy;
    }
}
Python
import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], central_point: List[int], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            return (point[0] - central_point[0]) ** 2 + (point[1] - central_point[1]) ** 2

        max_heap = []
        for point in points:
            distance = dist(point)
            if len(max_heap) < k:
                heapq.heappush(max_heap, (-distance, point))
            else:
                heapq.heappushpop(max_heap, (-distance, point))
        
        return [point for (_, point) in max_heap]

Method 2 - Quickselect

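Using Quickselect, we can find the kth smallest distance in O(n) average time. We then scan the array of points and output the first k points whose distance is less than or equal to that kth smallest distance. Because the problem guarantees a unique answer, there are no ties at the boundary, so this scan picks exactly the k closest points. If we partition the points array directly and compute distances on the fly, the algorithm runs in O(n) average time (O(n²) in the worst case) and needs only constant extra space apart from the recursion stack.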
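The two-pass idea above can be sketched in Python as follows. This is a minimal sketch under a few assumptions of our own, not the in-place code below. It builds a separate list of distances, which costs O(n) extra space. It picks pivots at random and partitions three ways so that repeated distances are handled. It then fills the answer with points strictly closer than the threshold before any points tied with it, which matches the description when the answer is unique. The names `kth_smallest` and `k_closest_two_pass` are hypothetical.

```python
import random
from typing import List

def kth_smallest(values: List[int], k: int) -> int:
    """Return the k-th smallest value (1-based) with iterative quickselect:
    random pivot, three-way partition, O(n) average time."""
    values = list(values)  # work on a copy so the caller's list is untouched
    target = k - 1
    left, right = 0, len(values) - 1
    while True:
        pivot = values[random.randint(left, right)]
        # Partition values[left..right] into < pivot | == pivot | > pivot
        lt, i, gt = left, left, right
        while i <= gt:
            if values[i] < pivot:
                values[lt], values[i] = values[i], values[lt]
                lt += 1
                i += 1
            elif values[i] > pivot:
                values[i], values[gt] = values[gt], values[i]
                gt -= 1
            else:
                i += 1
        # Keep only the side that contains the target index
        if target < lt:
            right = lt - 1
        elif target > gt:
            left = gt + 1
        else:
            return pivot

def k_closest_two_pass(points: List[List[int]], k: int) -> List[List[int]]:
    dists = [x * x + y * y for x, y in points]
    threshold = kth_smallest(dists, k)  # the k-th minimum squared distance
    # Every point strictly closer than the threshold belongs in the answer
    ans = [p for p, d in zip(points, dists) if d < threshold]
    # Fill the remaining slots with points exactly at the threshold distance
    for p, d in zip(points, dists):
        if len(ans) == k:
            break
        if d == threshold:
            ans.append(p)
    return ans

print(k_closest_two_pass([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```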

Approach

  1. Distance Calculation: For each point, calculate the squared Euclidean distance to the origin. The squared distance avoids dealing with floating-point precision issues.
  2. Quickselect: Use the Quickselect algorithm to partition the points such that the k closest points are in the correct positions.
  3. Partition: After applying Quickselect, the first k points in the array are the closest points.
  4. Output: Return these k closest points.

Code

Java
class Solution {
    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);
        int[][] ans = new int[k][2];
        System.arraycopy(points, 0, ans, 0, k);
        return ans;
    }

    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left < right) {
            int pivotIndex = partition(points, left, right);
            if (pivotIndex == k - 1) {
                return;
            } else if (pivotIndex < k - 1) {
                quickSelect(points, pivotIndex + 1, right, k);
            } else {
                quickSelect(points, left, pivotIndex - 1, k);
            }
        }
    }

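    private int partition(int[][] points, int left, int right) {
        // Random pivot guards against O(n^2) time on already-ordered input
        int r = java.util.concurrent.ThreadLocalRandom.current().nextInt(left, right + 1);
        swap(points, r, right);
        int[] pivot = points[right];
        int pivotDist = distance(pivot);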
        int swapIndex = left;

        for (int i = left; i < right; i++) {
            if (distance(points[i]) <= pivotDist) {
                swap(points, i, swapIndex);
                swapIndex++;
            }
        }

        swap(points, swapIndex, right);
        return swapIndex;
    }

    private void swap(int[][] points, int i, int j) {
        int[] tmp = points[i];
        points[i] = points[j];
        points[j] = tmp;
    }

    private int distance(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }
}
Python
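import random
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            return point[0] ** 2 + point[1] ** 2

        def partition(left: int, right: int) -> int:
            # Move a random pivot to the end to avoid O(n^2) time (and very deep
            # recursion) on already-ordered input
            r = random.randint(left, right)
            points[r], points[right] = points[right], points[r]
            pivot = points[right]
            pivot_dist = dist(pivot)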
            i = left
            for j in range(left, right):
                if dist(points[j]) <= pivot_dist:
                    points[i], points[j] = points[j], points[i]
                    i += 1
            points[i], points[right] = points[right], points[i]
            return i

        def quickSelect(left: int, right: int, k: int) -> None:
            if left < right:
                pivot_index = partition(left, right)
                if pivot_index == k - 1:
                    return
                elif pivot_index < k - 1:
                    quickSelect(pivot_index + 1, right, k)
                else:
                    quickSelect(left, pivot_index - 1, k)

        quickSelect(0, len(points) - 1, k)
        return points[:k]

Complexity

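  • ⏰ Time complexity: O(n) on average, where n is the number of points; the worst case is O(n²) when the pivots split the array badly.
  • 🧺 Space complexity: O(1) extra space for the in-place partitioning, plus the recursion stack, which is O(log n) deep on average and O(n) in the worst case.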