
K Closest Points to Origin

Medium | Updated: Aug 2, 2025

Problem

Given an array of points where points[i] = [xi, yi] represents a point on the X-Y plane and an integer k, return the k closest points to the origin (0, 0).

The distance between two points on the X-Y plane is the Euclidean distance, i.e. $\sqrt{(x_1 - x_2)^2 + (y_1 - y_2)^2}$.

You may return the answer in any order. The answer is guaranteed to be unique (except for the order that it is in).

Examples

Example 1:

Input: points = [[1,3],[-2,2]], k = 1
Output: [[-2,2]]
Explanation:
The distance between (1, 3) and the origin is sqrt(10).
The distance between (-2, 2) and the origin is sqrt(8).
Since sqrt(8) < sqrt(10), (-2, 2) is closer to the origin.
We only want the closest k = 1 points from the origin, so the answer is just [[-2,2]].

Example 2:

Input: points = [[3,3],[5,-1],[-2,4]], k = 2
Output: [[3,3],[-2,4]]
Explanation: The answer [[-2,4],[3,3]] would also be accepted.
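Example 1 compares sqrt(10) with sqrt(8), but comparing 10 with 8 gives the same answer: the square root is monotonically increasing, so ranking points by x² + y² yields the same order as ranking by true Euclidean distance, while staying in exact integer arithmetic. A minimal Python sketch checking this on Example 1; `squared_dist` is a hypothetical helper name, and `math.dist` assumes Python 3.8 or newer.

```python
import math
from typing import Sequence


def squared_dist(p: Sequence[int], c: Sequence[int] = (0, 0)) -> int:
    """Squared Euclidean distance between p and c, using integers only (no sqrt)."""
    dx, dy = p[0] - c[0], p[1] - c[1]
    return dx * dx + dy * dy


a, b = [1, 3], [-2, 2]
print(squared_dist(a), squared_dist(b))  # 10 8
# Comparing squared distances agrees with comparing the true Euclidean distances.
print((squared_dist(b) < squared_dist(a)) == (math.dist(b, (0, 0)) < math.dist(a, (0, 0))))  # True
```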

Notice the key assumption in this variant: "k is much smaller than n, and n is very large". The brute-force solution computes the distance of every point and sorts all of them in O(n log n) time. That is more work than necessary, because we only care about the first k points of the sorted order, not the full ordering.
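For reference, the brute-force baseline fits in a single line of Python. This is a sketch that assumes the squared distance as the sort key; `k_closest_sorted` is a hypothetical name.

```python
from typing import List


def k_closest_sorted(points: List[List[int]], k: int) -> List[List[int]]:
    # Brute force: sort every point by squared distance, then keep the first k.
    # Sorting all n points costs O(n log n) even though only k of them are needed.
    return sorted(points, key=lambda p: p[0] * p[0] + p[1] * p[1])[:k]


print(k_closest_sorted([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```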

Follow up

What if, instead of the origin, a central point is also given as input? How would you handle that?

Solution

The idea is to compute each point's distance from the origin and then apply the technique from [Kth Largest Element in an Array](kth-largest-element-in-an-array), selecting the k smallest distances instead of the k largest.

Method 1 - Using Max Heap

An efficient solution is to keep a max-heap of size k that holds the k smallest distances seen so far.

A max-heap has its largest element at the root. Whenever we add a point and the heap size exceeds k, we remove the root, discarding the farthest point and keeping the closest ones.

  • Use a max-heap to keep track of the k closest points.
  • Insert points into the heap based on their distance to the origin.
  • If the heap size exceeds k, remove the point with the maximum distance.
  • At the end, the heap contains the k closest points.
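To make the eviction rule concrete, here is a hedged Python sketch that runs the same size-k max-heap idea on Example 2 and also records which points get evicted. `trace_max_heap` is a hypothetical helper for illustration, not part of the solution; it simulates a max-heap by storing negated squared distances in Python's min-heap `heapq`.

```python
import heapq
from typing import List, Tuple


def trace_max_heap(points: List[List[int]], k: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Run the size-k max-heap method and record which points get evicted."""
    heap = []  # entries are (-squared_distance, x, y); heapq is a min-heap
    evicted = []
    for x, y in points:
        heapq.heappush(heap, (-(x * x + y * y), x, y))
        if len(heap) > k:
            # The root holds the most negative key, i.e. the farthest point kept so far.
            _, fx, fy = heapq.heappop(heap)
            evicted.append([fx, fy])
    kept = [[x, y] for _, x, y in heap]
    return kept, evicted


kept, evicted = trace_max_heap([[3, 3], [5, -1], [-2, 4]], 2)
print(evicted)  # [[5, -1]]
```

With k = 2, the first two points fit in the heap; inserting (-2, 4) pushes the size to 3, so the farthest point so far, (5, -1) with squared distance 26, is evicted.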

Code

Java
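import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int k) {
        // Max-heap by squared distance: the farthest of the kept points sits at the root.
        // Integer.compare avoids relying on subtraction, which can overflow in general.
        PriorityQueue<int[]> heap = new PriorityQueue<>(
                (a, b) -> Integer.compare(b[0] * b[0] + b[1] * b[1], a[0] * a[0] + a[1] * a[1]));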
        for (int[] point : points) {
            heap.offer(point);
            if (heap.size() > k) {
                heap.poll();
            }
        }
        int[][] ans = new int[k][2];
        for (int i = 0; i < k; i++) {
            ans[i] = heap.poll();
        }
        return ans;
    }
}
Using Point Class
import java.util.PriorityQueue;

class Solution {

    static class Point {
        public int x;
        public int y;

        public Point(final int x, final int y) {
            this.x = x;
            this.y = y;
        }

        // distance from origin
        public long getDist() {
            return (long) x * x + (long) y * y;
        }
    }

    public int[][] kClosest(int[][] points, final int k) {
        // Max heap
        final PriorityQueue<Point> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(o2.getDist(), o1.getDist()));

        for (int[] point : points) {
            Point p = new Point(point[0], point[1]);
            maxHeap.offer(p);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }

        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            Point p = maxHeap.poll();
            ans[index][0] = p.x;
            ans[index][1] = p.y;
            index++;
        }

        return ans;
    }
}
Python
import heapq
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        max_heap = []
        for (x, y) in points:
            dist = -(x**2 + y**2)  # Use negative for max-heap
            if len(max_heap) < k:
                heapq.heappush(max_heap, (dist, x, y))
            else:
                heapq.heappushpop(max_heap, (dist, x, y))
        
        ans = [[x, y] for (dist, x, y) in max_heap]
        return ans

Complexity

  • ⏰ Time complexity: O(n log k) where n is the number of points and k is the number of closest points to find.
  • 🧺 Space complexity: O(k) to store the k closest points.

Method 2 - Quickselect

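Using [Quickselect](kth-largest-element-using-randomized-selection), we can find the kth smallest distance among all distances in O(n) average time. Then we scan the points and output those whose distance is less than or equal to that kth smallest distance; because the answer is guaranteed to be unique, exactly k points qualify. The implementation in this method goes one step further and partitions the points array itself, so the k closest points end up in its first k positions: O(n) average time, in place, with no auxiliary arrays.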
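The two-pass variant (select the kth smallest distance, then scan) can be sketched in Python as follows. Assumptions: `kth_smallest` and `k_closest_by_threshold` are hypothetical names, the pivot is chosen at random, and ties at the threshold are handled explicitly so the function returns exactly k points even without the uniqueness guarantee. Unlike the in-place version, this sketch allocates O(n) extra space for the distance list and its working copy.

```python
import random
from typing import List


def kth_smallest(values: List[int], k: int) -> int:
    """Return the k-th smallest value (1-indexed) with randomized, iterative quickselect."""
    vals = list(values)  # partition a copy, leaving the caller's list untouched
    target = k - 1
    lo, hi = 0, len(vals) - 1
    while lo < hi:
        # Move a random pivot to the end, then Lomuto-partition vals[lo..hi].
        p = random.randint(lo, hi)
        vals[p], vals[hi] = vals[hi], vals[p]
        pivot, store = vals[hi], lo
        for i in range(lo, hi):
            if vals[i] < pivot:
                vals[i], vals[store] = vals[store], vals[i]
                store += 1
        vals[store], vals[hi] = vals[hi], vals[store]
        if store == target:
            return vals[store]
        if store < target:
            lo = store + 1
        else:
            hi = store - 1
    return vals[lo]


def k_closest_by_threshold(points: List[List[int]], k: int) -> List[List[int]]:
    dists = [x * x + y * y for x, y in points]
    kth = kth_smallest(dists, k)  # k-th smallest squared distance
    # Points strictly closer than the threshold always belong to the answer (fewer than k of them).
    closer = [p for p, d in zip(points, dists) if d < kth]
    # Fill the remaining slots with points exactly at the threshold.
    ties = [p for p, d in zip(points, dists) if d == kth]
    return closer + ties[: k - len(closer)]


print(k_closest_by_threshold([[3, 3], [5, -1], [-2, 4]], 2))  # [[3, 3], [-2, 4]]
```

Selecting on a copy of the distances keeps the input untouched; the trade-off is the extra O(n) memory.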

Approach

  1. Distance Calculation: For each point, calculate the squared Euclidean distance to the origin. The squared distance avoids dealing with floating-point precision issues.
  2. Quickselect: Repeatedly partition the points around a pivot distance, recursing only into the side that contains index k - 1, until the k closest points occupy the first k positions (in no particular order).
  3. Partition: After applying Quickselect, the first k points in the array are the closest points.
  4. Output: Return these k closest points.

Code

Java
class Solution {
    public int[][] kClosest(int[][] points, int k) {
        quickSelect(points, 0, points.length - 1, k);
        int[][] ans = new int[k][2];
        System.arraycopy(points, 0, ans, 0, k);
        return ans;
    }

    private void quickSelect(int[][] points, int left, int right, int k) {
        if (left < right) {
            int pivotIndex = partition(points, left, right);
            if (pivotIndex == k - 1) {
                return;
            } else if (pivotIndex < k - 1) {
                quickSelect(points, pivotIndex + 1, right, k);
            } else {
                quickSelect(points, left, pivotIndex - 1, k);
            }
        }
    }

    private int partition(int[][] points, int left, int right) {
        int[] pivot = points[right];
        int pivotDist = distance(pivot);
        int swapIndex = left;

        for (int i = left; i < right; i++) {
            if (distance(points[i]) <= pivotDist) {
                swap(points, i, swapIndex);
                swapIndex++;
            }
        }

        swap(points, swapIndex, right);
        return swapIndex;
    }

    private void swap(int[][] points, int i, int j) {
        int[] tmp = points[i];
        points[i] = points[j];
        points[j] = tmp;
    }

    private int distance(int[] point) {
        return point[0] * point[0] + point[1] * point[1];
    }
}
Python
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            return point[0] ** 2 + point[1] ** 2

        def partition(left: int, right: int) -> int:
            pivot = points[right]
            pivot_dist = dist(pivot)
            i = left
            for j in range(left, right):
                if dist(points[j]) <= pivot_dist:
                    points[i], points[j] = points[j], points[i]
                    i += 1
            points[i], points[right] = points[right], points[i]
            return i

        def quickSelect(left: int, right: int, k: int) -> None:
            if left < right:
                pivot_index = partition(left, right)
                if pivot_index == k - 1:
                    return
                elif pivot_index < k - 1:
                    quickSelect(pivot_index + 1, right, k)
                else:
                    quickSelect(left, pivot_index - 1, k)

        quickSelect(0, len(points) - 1, k)
        return points[:k]

Complexity

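  • ⏰ Time complexity: O(n) on average, where n is the number of points. The worst case is O(n²), e.g. when the last-element pivot is repeatedly the largest remaining distance (as with input already sorted by distance).
  • 🧺 Space complexity: O(1) extra space for the in-place partitioning, plus the recursion stack: O(log n) deep on average and O(n) in the worst case.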

Follow up - K closest points to a central point

We only need to change the distance calculation:

private long distance(int[] point, int[] centralPoint) {
    // Widen to long before multiplying so both products are computed in 64-bit arithmetic
    long dx = (long) point[0] - centralPoint[0];
    long dy = (long) point[1] - centralPoint[1];
    return dx * dx + dy * dy;
}
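In Python, the same change is just a different key function. This hedged sketch uses the standard library's `heapq.nsmallest`, which the Python documentation describes as equivalent to `sorted(iterable, key=key)[:n]` and best suited to small n; `k_closest_to_center` is a hypothetical name.

```python
import heapq
from typing import List, Sequence


def k_closest_to_center(points: List[List[int]], center: Sequence[int], k: int) -> List[List[int]]:
    cx, cy = center
    # Only the key changes: squared distance to `center` instead of to (0, 0).
    return heapq.nsmallest(k, points, key=lambda p: (p[0] - cx) ** 2 + (p[1] - cy) ** 2)


print(k_closest_to_center([[3, 3], [5, -1], [-2, 4]], [5, 0], 2))  # [[5, -1], [3, 3]]
```

With the center at (5, 0), the squared distances are 13, 1 and 65, so (5, -1) and (3, 3) are returned, closest first.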

Code

Here is the updated code:

Java
import java.util.PriorityQueue;

class Solution {
    public int[][] kClosest(int[][] points, int[] centralPoint, int k) {
        // Max heap to keep the k closest points
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((o1, o2) -> Long.compare(
                distance(o2, centralPoint), distance(o1, centralPoint)
        ));

        for (int[] point : points) {
            maxHeap.offer(point);
            if (maxHeap.size() > k) {
                maxHeap.poll();
            }
        }

        int[][] ans = new int[k][2];
        int index = 0;
        while (!maxHeap.isEmpty()) {
            ans[index++] = maxHeap.poll();
        }

        return ans;
    }

    private long distance(int[] point, int[] centralPoint) {
        // Widen to long before multiplying so both products are computed in 64-bit arithmetic
        long dx = (long) point[0] - centralPoint[0];
        long dy = (long) point[1] - centralPoint[1];
        return dx * dx + dy * dy;
    }
}
Python
import heapq
from typing import List


class Solution:
    def kClosest(self, points: List[List[int]], central_point: List[int], k: int) -> List[List[int]]:
        def dist(point: List[int]) -> int:
            return (point[0] - central_point[0]) ** 2 + (point[1] - central_point[1]) ** 2

        max_heap = []
        for point in points:
            distance = dist(point)
            if len(max_heap) < k:
                heapq.heappush(max_heap, (-distance, point))
            else:
                heapq.heappushpop(max_heap, (-distance, point))
        
        return [point for (_, point) in max_heap]
