Sort Matrix by Diagonals
Medium | Updated: Sep 1, 2025
Problem
You are given an n x n square matrix of integers, `grid`. Return the matrix after sorting each of its diagonals (running from top-left to bottom-right) so that:
- The diagonals in the bottom-left triangle (including the middle diagonal) are sorted in non-increasing order.
- The diagonals in the top-right triangle are sorted in non-decreasing order.
Examples
Example 1
$$
\begin{array}{|c|c|c|}
\hline
1 & \colorbox{Turquoise}{7} & 3 \\
\hline
\colorbox{Turquoise}{9} & 8 & \colorbox{Turquoise}{2} \\
\hline
4 & \colorbox{Turquoise}{5} & 6 \\
\hline
\end{array}
\longrightarrow
\begin{array}{|c|c|c|}
\hline
8 & \colorbox{Turquoise}{2} & 3 \\
\hline
\colorbox{Turquoise}{9} & 6 & \colorbox{Turquoise}{7} \\
\hline
4 & \colorbox{Turquoise}{5} & 1 \\
\hline
\end{array}
$$
Input: grid = [[1,7,3],[9,8,2],[4,5,6]]
Output: [[8,2,3],[9,6,7],[4,5,1]]
Explanation:
The diagonals in the bottom-left triangle (including the main diagonal) must be sorted in non-increasing order:
* `[1, 8, 6]` becomes `[8, 6, 1]`.
* `[9, 5]` and `[4]` remain unchanged.
The diagonals in the top-right triangle must be sorted in non-decreasing order:
* `[7, 2]` becomes `[2, 7]`.
* `[3]` remains unchanged.
Example 2
$$
\begin{array}{|c|c|}
\hline
\colorbox{Turquoise}{0} & 1 \\
\hline
1 & \colorbox{Turquoise}{2} \\
\hline
\end{array}
\longrightarrow
\begin{array}{|c|c|}
\hline
\colorbox{Turquoise}{2} & 1 \\
\hline
1 & \colorbox{Turquoise}{0} \\
\hline
\end{array}
$$
Input: grid = [[0,1],[1,2]]
Output: [[2,1],[1,0]]
Explanation:
The main diagonal `[0, 2]` lies in the bottom-left triangle and must be non-increasing, so it becomes `[2, 0]`. The other two diagonals, `[1]` and `[1]`, each have a single element and are already in order.
Example 3
Input: grid = [[1]]
Output: [[1]]
Explanation:
Diagonals with exactly one element are already in order, so no changes are needed.
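To check an output against the rules in these examples, a small validator can regroup the cells by diagonal and test each group's order and contents. This is a minimal sketch; the name `is_diagonally_sorted` is hypothetical and not part of the problem's interface.

```python
from collections import defaultdict


def is_diagonally_sorted(original: list[list[int]], result: list[list[int]]) -> bool:
    """Hypothetical checker: True if `result` is a valid answer for `original`."""
    n = len(original)
    before = defaultdict(list)
    after = defaultdict(list)
    # Collect each top-left-to-bottom-right diagonal, keyed by i - j.
    # Row-major order means each list runs from top-left to bottom-right.
    for i in range(n):
        for j in range(n):
            before[i - j].append(original[i][j])
            after[i - j].append(result[i][j])
    for key, values in after.items():
        # Each diagonal must keep exactly the same values.
        if sorted(values) != sorted(before[key]):
            return False
        if key >= 0:
            # Bottom-left triangle, including the main diagonal: non-increasing.
            if values != sorted(values, reverse=True):
                return False
        elif values != sorted(values):
            # Top-right triangle: non-decreasing.
            return False
    return True


print(is_diagonally_sorted([[1, 7, 3], [9, 8, 2], [4, 5, 6]],
                           [[8, 2, 3], [9, 6, 7], [4, 5, 1]]))  # True
print(is_diagonally_sorted([[0, 1], [1, 2]], [[0, 1], [1, 2]]))  # False
```

The second call returns False because the unchanged main diagonal `[0, 2]` is not non-increasing.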
Constraints
- `grid.length == grid[i].length == n`
- `1 <= n <= 10`
- `-10^5 <= grid[i][j] <= 10^5`
Solution
Method 1 - Diagonal Identification and Sorting
Intuition
Every cell on the same top-left-to-bottom-right diagonal has the same value of i - j, so this difference identifies the diagonal uniquely. For an n x n matrix the key ranges from -(n - 1) to n - 1, giving 2n - 1 diagonals. Diagonals with i - j >= 0 (the bottom-left triangle, including the main diagonal where i - j = 0) are sorted in non-increasing order. Diagonals with i - j < 0 (the top-right triangle) are sorted in non-decreasing order.
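The i - j key can be seen directly on Example 1. This sketch collects each diagonal in row-major order, so every list runs from the top-left end of its diagonal to the bottom-right end. The helper name `group_by_diagonal` is illustrative only.

```python
from collections import defaultdict


def group_by_diagonal(grid: list[list[int]]) -> dict[int, list[int]]:
    """Map each diagonal key i - j to its elements, top-left to bottom-right."""
    groups: dict[int, list[int]] = defaultdict(list)
    for i, row in enumerate(grid):
        for j, value in enumerate(row):
            groups[i - j].append(value)
    return dict(groups)


groups = group_by_diagonal([[1, 7, 3], [9, 8, 2], [4, 5, 6]])
for key in sorted(groups):
    print(key, groups[key])
# -2 [3]
# -1 [7, 2]
# 0 [1, 8, 6]
# 1 [9, 5]
# 2 [4]
```

Keys 0, 1 and 2 are the bottom-left diagonals that get sorted in non-increasing order. Keys -1 and -2 are the top-right diagonals that get sorted in non-decreasing order.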
Approach
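- Group matrix elements by their diagonal index (i - j).
- For each diagonal, extract all of its elements.
- Sort the elements based on the diagonal's position:
  - i - j >= 0: sort in descending order (non-increasing)
  - i - j < 0: sort in ascending order (non-decreasing)
- Place the sorted elements back into the matrix, again visiting cells in row-major order. Row-major order reaches the cells of each diagonal from top-left to bottom-right, so a per-diagonal counter writes the values back in sorted order.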
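As a variant of the same steps, the grouping map can be dropped. Instead, each diagonal is walked directly from its first cell, sorted, and written back. This is a sketch of that alternative rather than the method used in the code below. The function name `sort_matrix_in_place` is hypothetical, and the function overwrites `grid` and returns it for convenience.

```python
def sort_matrix_in_place(grid: list[list[int]]) -> list[list[int]]:
    """Hypothetical variant: sorts each diagonal of `grid` in place and returns it."""
    n = len(grid)

    def fix_diagonal(r: int, c: int, descending: bool) -> None:
        # The diagonal starting at (r, c) has n - max(r, c) cells.
        cells = [(r + k, c + k) for k in range(n - max(r, c))]
        values = sorted((grid[i][j] for i, j in cells), reverse=descending)
        # Write the sorted values back from top-left to bottom-right.
        for (i, j), value in zip(cells, values):
            grid[i][j] = value

    # Diagonals starting in column 0 have i - j >= 0: non-increasing.
    for r in range(n):
        fix_diagonal(r, 0, descending=True)
    # Diagonals starting in row 0 at column c >= 1 have i - j < 0: non-decreasing.
    for c in range(1, n):
        fix_diagonal(0, c, descending=False)
    return grid


print(sort_matrix_in_place([[1, 7, 3], [9, 8, 2], [4, 5, 6]]))
# [[8, 2, 3], [9, 6, 7], [4, 5, 1]]
```

Compared with grouping into a map, this needs only O(n) extra space for one diagonal at a time. The sorting cost is unchanged, and the input matrix is modified.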
Code
C++
```cpp
// Requires C++17 (structured bindings)
#include <vector>
#include <algorithm>
#include <functional>  // std::greater
#include <map>
using namespace std;

vector<vector<int>> sortMatrix(vector<vector<int>>& grid) {
    int n = grid.size();
    map<int, vector<int>> diagonals;
    // Group elements by diagonal (i - j)
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            diagonals[i - j].push_back(grid[i][j]);
        }
    }
    // Sort each diagonal
    for (auto& [diag, elements] : diagonals) {
        if (diag >= 0) {
            // Bottom-left triangle (including main diagonal): non-increasing
            sort(elements.begin(), elements.end(), greater<int>());
        } else {
            // Top-right triangle: non-decreasing
            sort(elements.begin(), elements.end());
        }
    }
    // Place sorted elements back; row-major order visits each diagonal top-left to bottom-right
    vector<vector<int>> result(n, vector<int>(n));
    map<int, int> indices; // Track current index for each diagonal
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            int diag = i - j;
            result[i][j] = diagonals[diag][indices[diag]++];
        }
    }
    return result;
}
```
Go
```go
import "sort"

func sortMatrix(grid [][]int) [][]int {
	n := len(grid)
	diagonals := make(map[int][]int)
	// Group elements by diagonal
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			diag := i - j
			diagonals[diag] = append(diagonals[diag], grid[i][j])
		}
	}
	// Sort each diagonal
	for diag, elements := range diagonals {
		if diag >= 0 {
			// Non-increasing order
			sort.Slice(elements, func(a, b int) bool {
				return elements[a] > elements[b]
			})
		} else {
			// Non-decreasing order
			sort.Ints(elements)
		}
		diagonals[diag] = elements
	}
	// Place sorted elements back
	result := make([][]int, n)
	for i := range result {
		result[i] = make([]int, n)
	}
	indices := make(map[int]int)
	for i := 0; i < n; i++ {
		for j := 0; j < n; j++ {
			diag := i - j
			result[i][j] = diagonals[diag][indices[diag]]
			indices[diag]++
		}
	}
	return result
}
```
Java
```java
import java.util.*;

class Solution {
    public int[][] sortMatrix(int[][] grid) {
        int n = grid.length;
        Map<Integer, List<Integer>> diagonals = new HashMap<>();
        // Group elements by diagonal
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int diag = i - j;
                diagonals.computeIfAbsent(diag, k -> new ArrayList<>()).add(grid[i][j]);
            }
        }
        // Sort each diagonal
        for (Map.Entry<Integer, List<Integer>> entry : diagonals.entrySet()) {
            int diag = entry.getKey();
            List<Integer> elements = entry.getValue();
            if (diag >= 0) {
                // Bottom-left triangle: non-increasing
                elements.sort(Collections.reverseOrder());
            } else {
                // Top-right triangle: non-decreasing
                Collections.sort(elements);
            }
        }
        // Place sorted elements back
        int[][] result = new int[n][n];
        Map<Integer, Integer> indices = new HashMap<>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int diag = i - j;
                int idx = indices.getOrDefault(diag, 0);
                result[i][j] = diagonals.get(diag).get(idx);
                indices.put(diag, idx + 1);
            }
        }
        return result;
    }
}
```
Kotlin
```kotlin
class Solution {
    fun sortMatrix(grid: Array<IntArray>): Array<IntArray> {
        val n = grid.size
        val diagonals = mutableMapOf<Int, MutableList<Int>>()
        // Group elements by diagonal
        for (i in 0 until n) {
            for (j in 0 until n) {
                val diag = i - j
                diagonals.getOrPut(diag) { mutableListOf() }.add(grid[i][j])
            }
        }
        // Sort each diagonal
        for ((diag, elements) in diagonals) {
            if (diag >= 0) {
                elements.sortDescending() // Non-increasing
            } else {
                elements.sort() // Non-decreasing
            }
        }
        // Place sorted elements back
        val result = Array(n) { IntArray(n) }
        val indices = mutableMapOf<Int, Int>()
        for (i in 0 until n) {
            for (j in 0 until n) {
                val diag = i - j
                val idx = indices.getOrDefault(diag, 0)
                result[i][j] = diagonals.getValue(diag)[idx]
                indices[diag] = idx + 1
            }
        }
        return result
    }
}
```
Python
```python
from collections import defaultdict


def sortMatrix(grid: list[list[int]]) -> list[list[int]]:
    n = len(grid)
    diagonals = defaultdict(list)
    # Group elements by diagonal
    for i in range(n):
        for j in range(n):
            diag = i - j
            diagonals[diag].append(grid[i][j])
    # Sort each diagonal
    for diag, elements in diagonals.items():
        if diag >= 0:
            # Bottom-left triangle: non-increasing
            elements.sort(reverse=True)
        else:
            # Top-right triangle: non-decreasing
            elements.sort()
    # Place sorted elements back
    result = [[0] * n for _ in range(n)]
    indices = defaultdict(int)
    for i in range(n):
        for j in range(n):
            diag = i - j
            result[i][j] = diagonals[diag][indices[diag]]
            indices[diag] += 1
    return result
```
Rust
```rust
use std::collections::HashMap;

pub fn sort_matrix(grid: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
    let n = grid.len();
    let mut diagonals: HashMap<i32, Vec<i32>> = HashMap::new();
    // Group elements by diagonal
    for i in 0..n {
        for j in 0..n {
            let diag = i as i32 - j as i32;
            diagonals.entry(diag).or_insert_with(Vec::new).push(grid[i][j]);
        }
    }
    // Sort each diagonal
    for (diag, elements) in diagonals.iter_mut() {
        if *diag >= 0 {
            // Bottom-left triangle: non-increasing
            elements.sort_by(|a, b| b.cmp(a));
        } else {
            // Top-right triangle: non-decreasing
            elements.sort();
        }
    }
    // Place sorted elements back
    let mut result = vec![vec![0; n]; n];
    let mut indices: HashMap<i32, usize> = HashMap::new();
    for i in 0..n {
        for j in 0..n {
            let diag = i as i32 - j as i32;
            let idx = *indices.get(&diag).unwrap_or(&0);
            result[i][j] = diagonals[&diag][idx];
            indices.insert(diag, idx + 1);
        }
    }
    result
}
```
TypeScript
```typescript
function sortMatrix(grid: number[][]): number[][] {
    const n = grid.length;
    const diagonals = new Map<number, number[]>();
    // Group elements by diagonal
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            const diag = i - j;
            if (!diagonals.has(diag)) {
                diagonals.set(diag, []);
            }
            diagonals.get(diag)!.push(grid[i][j]);
        }
    }
    // Sort each diagonal
    for (const [diag, elements] of diagonals) {
        if (diag >= 0) {
            // Bottom-left triangle: non-increasing
            elements.sort((a, b) => b - a);
        } else {
            // Top-right triangle: non-decreasing
            elements.sort((a, b) => a - b);
        }
    }
    // Place sorted elements back
    const result: number[][] = Array.from({ length: n }, () => new Array<number>(n).fill(0));
    const indices = new Map<number, number>();
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++) {
            const diag = i - j;
            const idx = indices.get(diag) ?? 0;
            result[i][j] = diagonals.get(diag)![idx];
            indices.set(diag, idx + 1);
        }
    }
    return result;
}
```
Complexity
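- ⏰ Time complexity: O(n² log n), where n is the matrix dimension. Grouping and writing back touch each of the n² cells once (with an extra O(log n) per lookup for the tree map in C++), and sorting a diagonal of length k costs O(k log k) with k ≤ n.
- 🧺 Space complexity: O(n²) for storing the grouped diagonal elements and the result matrix.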
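The sorting bound can be made explicit by summing over all 2n - 1 diagonals. Let k_d = n - |d| be the length of the diagonal whose key is d = i - j. These lengths add up to n², and each satisfies k_d ≤ n, so:

$$
\sum_{d=-(n-1)}^{n-1} k_d \log k_d \;\le\; \log n \sum_{d=-(n-1)}^{n-1} k_d \;=\; n^2 \log n
$$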