Problem
Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).
The distance between two points on the X-Y plane is the Euclidean distance, i.e., $\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$.
You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).
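Because the square root is strictly increasing, ranking points by squared distance gives the same order as ranking them by true Euclidean distance. That lets every solution stay in integer arithmetic. A small Python check, assuming Python 3.8+ for `math.dist`; `squared_dist` is a helper name introduced here for illustration:

```python
import math

def squared_dist(point):
    """Squared Euclidean distance from a 2-D point to the origin (0, 0)."""
    x, y = point
    return x * x + y * y

points = [[1, 3], [-2, 2]]
print([squared_dist(p) for p in points])  # [10, 8]

# sqrt is strictly increasing on non-negative numbers, so ordering points by
# squared distance gives the same order as ordering by true Euclidean distance.
by_squared = sorted(points, key=squared_dist)
by_true = sorted(points, key=lambda p: math.dist(p, (0, 0)))
print(by_squared == by_true)  # True
```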
Examples
Example 1:
Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].
Example 2:
Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.
Notice the key requirement here: “K is much smaller than N. N is very large.” The brute-force solution computes the distance of every point and then sorts all of them in O(n log n) time. That is more expensive than necessary: we don't need to sort all n points, because we only care about the first k points of the sorted list.
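For comparison, the brute-force approach fits in a few lines of Python. This is a sketch; `k_closest_sort` is an illustrative name, and it sorts by squared distance instead of calling sqrt:

```python
from typing import List

def k_closest_sort(points: List[List[int]], k: int) -> List[List[int]]:
    # Brute force: sort all n points by squared distance (O(n log n)), keep the first k
    return sorted(points, key=lambda p: p[0] * p[0] + p[1] * p[1])[:k]

print(k_closest_sort([[1, 3], [-2, 2]], 1))           # [[-2, 2]]
print(k_closest_sort([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```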
Follow up
Instead of the origin, what if a central point is also given as part of the input? How would you handle that?
Solution
The idea is to calculate each point's distance from the origin and then apply the approach from the Kth Largest Element in an Array problem, here selecting the k smallest distances.
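In Python, the "keep only the k best candidates" pattern behind that problem is already packaged in the standard library as `heapq.nsmallest`, which is documented as equivalent to `sorted(iterable, key=key)[:n]`. A minimal sketch, with `k_closest_nsmallest` as an illustrative name:

```python
import heapq
from typing import List

def k_closest_nsmallest(points: List[List[int]], k: int) -> List[List[int]]:
    # Same result as sorted(points, key=...)[:k], but CPython avoids a full sort
    # for small k by tracking only the current k best candidates
    return heapq.nsmallest(k, points, key=lambda p: p[0] * p[0] + p[1] * p[1])

print(k_closest_nsmallest([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```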
Method 1 - Using Max Heap
An efficient solution is to use a max-heap of size k to maintain the k smallest distances seen so far.
A max-heap has its largest element at the root. Whenever adding a point makes the heap's size exceed k, we remove the root to discard the farthest point and retain the closest ones.
- Use a max-heap to keep track of the k closest points.
- Insert points into the heap based on their distance to the origin.
- If the heap size exceeds k, remove the point with the maximum distance.
- At the end, the heap contains the k closest points.
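To make the eviction step concrete, here is a Python sketch that follows the four steps and also records which point is discarded each time the heap grows past k. It assumes Python's `heapq`, which is a min-heap, so distances are negated to get max-heap behavior; `k_closest_trace` is a hypothetical helper name:

```python
import heapq
from typing import List, Tuple

def k_closest_trace(points: List[List[int]], k: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Size-k max-heap of points; returns (kept points, evicted points in order)."""
    heap = []     # entries are (-squared_distance, x, y); heap[0] is the farthest kept point
    evicted = []
    for x, y in points:
        heapq.heappush(heap, (-(x * x + y * y), x, y))
        if len(heap) > k:
            # The root holds the largest distance, so popping it drops the farthest point
            _, ex, ey = heapq.heappop(heap)
            evicted.append([ex, ey])
    kept = [[x, y] for _, x, y in heap]
    return kept, evicted

kept, evicted = k_closest_trace([[3, 3], [5, -1], [-2, 4]], 2)
print(evicted)  # [[5, -1]]
```

With Example 2, the heap holds [3, 3] and [5, -1] after two insertions; adding [-2, 4] pushes the size to 3, so the farthest point, [5, -1] (squared distance 26), is evicted.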
Code
Java
import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        // Max-heap ordered by squared distance to the origin (farthest point at the root)
        PriorityQueue<int[]> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(b[0] * b[0] + b[1] * b[1], a[0] * a[0] + a[1] * a[1]));
        for (int[] point : points) {
            heap.offer(point);
            if (heap.size() > k) {
                heap.poll(); // drop the farthest of the k + 1 points
            }
        }
        int[][] ans = new int[k][2];
        for (int i = 0; i < k; i++) {
            ans[i] = heap.poll();
        }
        return ans;
    }
}
Using Point Class
import java.util.PriorityQueue;

class Solution {
    static class Point {
        public int x;
        public int y;

        public Point(final int x, final int y) {
            this.x = x;
            this.y = y;
        }

        // Squared distance from the origin, computed in long arithmetic
        public long getDist() {
            return (long) x * x + (long) y * y;
        }
    }

    public int[][] kClosest(int[][] points, final int k) {
        // Max heap: the farthest of the kept points sits at the root
        final PriorityQueue<Point> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(o2.getDist(), o1.getDist()));
        for (int[] point : points) {
            Point p = new Point(point[0], point[1]);
            maxHeap.offer(p);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }
        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            Point p = maxHeap.poll();
            ans[index][0] = p.x;
            ans[index][1] = p.y;
            index++;
        }
        return ans;
    }
}
Python
import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        max_heap = []
        for (x, y) in points:
            dist = -(x**2 + y**2)  # Use negative for max-heap
            if len(max_heap) < k:
                heapq.heappush(max_heap, (dist, x, y))
            else:
                # Push the new point, then pop the farthest (most negative) entry
                heapq.heappushpop(max_heap, (dist, x, y))
        ans = [[x, y] for (dist, x, y) in max_heap]
        return ans
Complexity
- ⏰ Time complexity: O(n log k), where n is the number of points and k is the number of closest points to find.
- 🧺 Space complexity: O(k) to store the k closest points.
On the Follow-up - K closest points to a central point
We just need to update the distance calculation:
private long distance(int[] point, int[] centralPoint) {
    long dx = (long) point[0] - centralPoint[0];
    long dy = (long) point[1] - centralPoint[1];
    return dx * dx + dy * dy;
}
Here is the updated code:
Java
import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int[] centralPoint, int k) {
        // Max heap to keep the k closest points
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(
            distance(o2, centralPoint), distance(o1, centralPoint)
        ));
        for (int[] point : points) {
            maxHeap.offer(point);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }
        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            ans[index++] = maxHeap.poll();
        }
        return ans;
    }

    // Squared distance between point and centralPoint, computed in long arithmetic
    private long distance(int[] point, int[] centralPoint) {
        long dx = (long) point[0] - centralPoint[0];
        long dy = (long) point[1] - centralPoint[1];
        return dx * dx + dy * dy;
    }
}
Python
import heapq
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], central_point: List[int], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            # Squared distance from point to central_point
            return (point[0] - central_point[0]) ** 2 + (point[1] - central_point[1]) ** 2

        max_heap = []
        for point in points:
            distance = dist(point)
            if len(max_heap) < k:
                heapq.heappush(max_heap, (-distance, point))
            else:
                heapq.heappushpop(max_heap, (-distance, point))
        return [point for (_, point) in max_heap]
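An equivalent way to view the follow-up: translating every point by (-cx, -cy) moves the central point to the origin without changing any distance, so the origin-based solution applies unchanged. A minimal Python sketch of that idea, reusing `heapq.nsmallest` for brevity; `k_closest_to_center` is an illustrative name and `center` is assumed to be a two-element list:

```python
import heapq
from typing import List

def k_closest_to_center(points: List[List[int]], center: List[int], k: int) -> List[List[int]]:
    cx, cy = center
    # Shift every point so the center becomes the origin; keep the original index
    shifted = [(x - cx, y - cy, i) for i, (x, y) in enumerate(points)]
    best = heapq.nsmallest(k, shifted, key=lambda t: t[0] * t[0] + t[1] * t[1])
    # Map the selected shifted points back to the original coordinates
    return [points[i] for _, _, i in best]

print(k_closest_to_center([[1, 3], [-2, 2], [4, 4]], [3, 3], 2))  # [[4, 4], [1, 3]]
```

Keeping the original index avoids having to shift the selected points back by hand.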
Method 2 - Quickselect
Using Quickselect, we can select the k-th smallest distance from the list of distances in O(n) average time. Then we go through the array of points and output the first k points whose distance is less than or equal to that k-th smallest distance. Quickselect runs in O(n) time on average (O(n^2) in the worst case), and partitioning the points array directly, as the code below does, avoids any auxiliary array.
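The select-then-filter variant in the paragraph above differs slightly from the partition-the-points code in this method, so here is a Python sketch of it. The random pivot and the iterative loop are choices made for this sketch, not taken from the original. It copies the distances into a new list, so it uses O(n) extra space; `kth_smallest` and `k_closest_by_threshold` are illustrative names:

```python
import random
from typing import List

def kth_smallest(values: List[int], k: int) -> int:
    """Return the k-th smallest value (1-indexed) with iterative Quickselect."""
    vals = list(values)  # work on a copy, which costs O(n) extra space
    target = k - 1
    left, right = 0, len(vals) - 1
    while True:
        # Move a random pivot to the end, then Lomuto-partition vals[left..right]
        p = random.randint(left, right)
        vals[p], vals[right] = vals[right], vals[p]
        pivot = vals[right]
        store = left
        for i in range(left, right):
            if vals[i] < pivot:
                vals[i], vals[store] = vals[store], vals[i]
                store += 1
        vals[store], vals[right] = vals[right], vals[store]
        # Now vals[left..store-1] < pivot <= vals[store+1..right]
        if store == target:
            return vals[store]
        if store < target:
            left = store + 1
        else:
            right = store - 1

def k_closest_by_threshold(points: List[List[int]], k: int) -> List[List[int]]:
    dists = [x * x + y * y for x, y in points]
    threshold = kth_smallest(dists, k)
    # Keep the first k points whose distance does not exceed the k-th smallest one
    return [p for p, d in zip(points, dists) if d <= threshold][:k]

print(k_closest_by_threshold([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```

The final slice matters only when several points tie at the threshold distance; the problem's uniqueness guarantee means that filtering by the threshold alone already yields exactly k points.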
Approach
- Distance Calculation: For each point, calculate the squared Euclidean distance to the origin. Using the squared distance avoids floating-point precision issues.
- Quickselect: Use the Quickselect algorithm to partition the points so that the k closest points end up in the first k positions.
- Partition: After applying Quickselect, the first k points in the array are the closest points.
- Output: Return these k closest points.
Code
Java
class Solution {
    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);
        int[][] ans = new int[k][2];
        System.arraycopy(points, 0, ans, 0, k);
        return ans;
    }

    // Narrow [left, right] until the (k - 1)-th position holds its final element
    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left < right) {
            int pivotIndex = partition(points, left, right);
            if (pivotIndex == k - 1) {
                return;
            } else if (pivotIndex < k - 1) {
                quickSelect(points, pivotIndex + 1, right, k);
            } else {
                quickSelect(points, left, pivotIndex - 1, k);
            }
        }
    }

    // Lomuto partition around the last point's distance; returns the pivot's final index
    private int partition(int[][] points, int left, int right) {
        int[] pivot = points[right];
        int pivotDist = distance(pivot);
        int swapIndex = left;
        for (int i = left; i < right; i++) {
            if (distance(points[i]) <= pivotDist) {
                swap(points, i, swapIndex);
                swapIndex++;
            }
        }
        swap(points, swapIndex, right);
        return swapIndex;
    }

    private void swap(int[][] points, int i, int j) {
        int[] tmp = points[i];
        points[i] = points[j];
        points[j] = tmp;
    }

    // Squared distance to the origin
    private int distance(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }
}
Python
import random
from typing import List

class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            # Squared distance to the origin
            return point[0] ** 2 + point[1] ** 2

        def partition(left: int, right: int) -> int:
            # Random pivot avoids the O(n^2) behavior on already-sorted input
            r = random.randint(left, right)
            points[r], points[right] = points[right], points[r]
            pivot_dist = dist(points[right])
            i = left
            for j in range(left, right):
                if dist(points[j]) <= pivot_dist:
                    points[i], points[j] = points[j], points[i]
                    i += 1
            points[i], points[right] = points[right], points[i]
            return i

        # Iterative Quickselect: a loop instead of recursion avoids hitting
        # Python's recursion limit when partitions are very unbalanced
        left, right = 0, len(points) - 1
        while left < right:
            pivot_index = partition(left, right)
            if pivot_index == k - 1:
                break
            elif pivot_index < k - 1:
                left = pivot_index + 1
            else:
                right = pivot_index - 1
        return points[:k]
Complexity
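- ⏰ Time complexity: O(n) on average, where n is the number of points; consistently bad pivots degrade Quickselect to O(n^2) in the worst case.
- 🧺 Space complexity: O(1) auxiliary space for the in-place partitioning, not counting the output array or, in a recursive implementation, the recursion stack.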