Problem

You are given an n x n square matrix of integers grid. Return the matrix such that:

  • The diagonals in the bottom-left triangle (including the middle diagonal) are sorted in non-increasing order.
  • The diagonals in the top-right triangle are sorted in non-decreasing order.

Example 1

Input: grid = [[1,7,3],[9,8,2],[4,5,6]]
Output: [[8,2,3],[9,6,7],[4,5,1]]
Explanation:
![](https://assets.leetcode.com/uploads/2024/12/29/4052example1drawio.png)
The diagonals with a black arrow (bottom-left triangle) should be sorted in
non-increasing order:
* `[1, 8, 6]` becomes `[8, 6, 1]`.
* `[9, 5]` and `[4]` remain unchanged.
The diagonals with a blue arrow (top-right triangle) should be sorted in
non-decreasing order:
* `[7, 2]` becomes `[2, 7]`.
* `[3]` remains unchanged.

Example 2

Input: grid = [[0,1],[1,2]]
Output: [[2,1],[1,0]]
Explanation:
![](https://assets.leetcode.com/uploads/2024/12/29/4052example2adrawio.png)
The diagonals with a black arrow must be non-increasing, so `[0, 2]` is
changed to `[2, 0]`. The other diagonals are already in the correct order.

Example 3

Input: grid = [[1]]
Output: [[1]]
Explanation:
Diagonals with exactly one element are already in order, so no changes are
needed.

Constraints

  • grid.length == grid[i].length == n
  • 1 <= n <= 10
  • -10^5 <= grid[i][j] <= 10^5

Examples

Solution

Method 1 - Diagonal Identification and Sorting

Intuition: We need to identify diagonals and sort them based on their position. Diagonals can be identified by the difference i - j. For diagonals where i - j >= 0 (bottom-left triangle including main diagonal), sort in non-increasing order. For diagonals where i - j < 0 (top-right triangle), sort in non-decreasing order.

Approach:

  1. Group matrix elements by their diagonal index (i - j)
  2. For each diagonal, extract all elements
  3. Sort elements based on diagonal position:
    • i - j >= 0: sort in descending order (non-increasing)
    • i - j < 0: sort in ascending order (non-decreasing)
  4. Place sorted elements back into the matrix

Code

C++
#include <vector>
#include <algorithm>
#include <map>
using namespace std;

// Builds a new n x n matrix from `grid` (n = grid.size()) in which every
// top-left -> bottom-right diagonal has been sorted independently:
//   * diagonals with row - col >= 0 (main diagonal and below) in
//     non-increasing order,
//   * diagonals with row - col < 0 (above the main diagonal) in
//     non-decreasing order.
// `grid` itself is only read; the sorted matrix is returned by value.
vector<vector<int>> sortMatrixByDiagonals(vector<vector<int>>& grid) {
    const int size = static_cast<int>(grid.size());
    vector<vector<int>> sorted(size, vector<int>(size));

    // offset = row - col identifies one diagonal; walk them from the
    // top-right corner (1 - size) to the bottom-left corner (size - 1).
    for (int offset = 1 - size; offset < size; ++offset) {
        const int startRow = max(offset, 0);
        const int startCol = max(-offset, 0);
        const int length = size - startRow - startCol;

        // Collect the diagonal from its top-left end to its bottom-right end.
        vector<int> line;
        line.reserve(length);
        for (int step = 0; step < length; ++step) {
            line.push_back(grid[startRow + step][startCol + step]);
        }

        if (offset >= 0) {
            // Sorting through reverse iterators leaves the vector descending.
            sort(line.rbegin(), line.rend());
        } else {
            sort(line.begin(), line.end());
        }

        // Write the sorted values back along the same cells.
        for (int step = 0; step < length; ++step) {
            sorted[startRow + step][startCol + step] = line[step];
        }
    }

    return sorted;
}
Go
import "sort"

// sortMatrixByDiagonals returns a new n x n matrix (n = len(grid)) in which
// each top-left -> bottom-right diagonal of grid is sorted on its own:
// diagonals with row-col >= 0 (main diagonal and below) in non-increasing
// order, diagonals with row-col < 0 in non-decreasing order.
// The input slice is only read.
func sortMatrixByDiagonals(grid [][]int) [][]int {
    size := len(grid)

    out := make([][]int, size)
    for r := range out {
        out[r] = make([]int, size)
    }

    // offset = row - col names a diagonal; visit every one of them.
    for offset := 1 - size; offset < size; offset++ {
        startRow, startCol := 0, 0
        if offset >= 0 {
            startRow = offset
        } else {
            startCol = -offset
        }
        length := size - startRow - startCol

        // Copy the diagonal out, top-left end first.
        line := make([]int, 0, length)
        for step := 0; step < length; step++ {
            line = append(line, grid[startRow+step][startCol+step])
        }

        if offset >= 0 {
            sort.Sort(sort.Reverse(sort.IntSlice(line)))
        } else {
            sort.Ints(line)
        }

        // Lay the sorted values back down along the same cells.
        for step, value := range line {
            out[startRow+step][startCol+step] = value
        }
    }

    return out
}
Java
import java.util.*;

class Solution {
    /**
     * Returns a new n x n matrix (n = grid.length) in which each
     * top-left to bottom-right diagonal of {@code grid} is sorted independently.
     * Diagonals with {@code row - col >= 0} (main diagonal and below) end up in
     * non-increasing order; diagonals with {@code row - col < 0} end up in
     * non-decreasing order. The input array is only read.
     *
     * @param grid square matrix to read from
     * @return a freshly allocated matrix holding the diagonally sorted values
     */
    public int[][] sortMatrixByDiagonals(int[][] grid) {
        final int size = grid.length;
        final int[][] sorted = new int[size][size];

        // offset = row - col identifies a diagonal.
        for (int offset = 1 - size; offset < size; offset++) {
            final int startRow = Math.max(offset, 0);
            final int startCol = Math.max(-offset, 0);
            final int length = size - startRow - startCol;

            // Pull the diagonal into a primitive array and sort it ascending.
            final int[] line = new int[length];
            for (int step = 0; step < length; step++) {
                line[step] = grid[startRow + step][startCol + step];
            }
            Arrays.sort(line);

            // Lower-left diagonals read the ascending array back to front,
            // which yields non-increasing order; the others read it as is.
            final boolean descending = offset >= 0;
            for (int step = 0; step < length; step++) {
                final int value = descending ? line[length - 1 - step] : line[step];
                sorted[startRow + step][startCol + step] = value;
            }
        }

        return sorted;
    }
}
Kotlin
class Solution {
    /**
     * Returns a new n x n matrix (n = grid.size) in which each top-left to
     * bottom-right diagonal of [grid] is sorted independently: diagonals with
     * `row - col >= 0` (main diagonal and below) in non-increasing order,
     * diagonals with `row - col < 0` in non-decreasing order.
     * The input array is only read.
     */
    fun sortMatrixByDiagonals(grid: Array<IntArray>): Array<IntArray> {
        val size = grid.size
        val sorted = Array(size) { IntArray(size) }

        // offset = row - col identifies a diagonal.
        for (offset in (1 - size) until size) {
            val startRow = maxOf(offset, 0)
            val startCol = maxOf(-offset, 0)
            val length = size - startRow - startCol

            // Copy the diagonal out, top-left end first, then sort it in place.
            val line = IntArray(length) { step -> grid[startRow + step][startCol + step] }
            if (offset >= 0) line.sortDescending() else line.sort()

            // Write the sorted values back along the same cells.
            line.forEachIndexed { step, value ->
                sorted[startRow + step][startCol + step] = value
            }
        }

        return sorted
    }
}
Python
from collections import defaultdict

def sortMatrixByDiagonals(grid: list[list[int]]) -> list[list[int]]:
    """Return a new matrix with every diagonal of ``grid`` sorted.

    Diagonals run from top-left to bottom-right and are identified by
    ``row - col``. Those with ``row - col >= 0`` (the main diagonal and
    everything below it) are sorted in non-increasing order; those with
    ``row - col < 0`` are sorted in non-decreasing order.

    Args:
        grid: Square matrix; only the first ``len(grid)`` columns of each
            row are read, and ``grid`` itself is not modified.

    Returns:
        A new ``len(grid)`` x ``len(grid)`` list of lists.
    """
    size = len(grid)
    sorted_grid = [[0] * size for _ in range(size)]

    for offset in range(1 - size, size):
        # Cells of this diagonal, top-left end first (col = row - offset).
        cells = [
            (row, row - offset)
            for row in range(max(offset, 0), min(size, size + offset))
        ]
        values = sorted(
            (grid[row][col] for row, col in cells),
            reverse=offset >= 0,
        )
        for (row, col), value in zip(cells, values):
            sorted_grid[row][col] = value

    return sorted_grid
Rust
use std::collections::HashMap;

/// Returns a new `n x n` matrix (`n = grid.len()`) in which each top-left to
/// bottom-right diagonal of `grid` is sorted independently.
///
/// Diagonals are keyed by `row - col`: keys `>= 0` (main diagonal and below)
/// are sorted in non-increasing order, keys `< 0` in non-decreasing order.
pub fn sort_matrix_by_diagonals(grid: Vec<Vec<i32>>) -> Vec<Vec<i32>> {
    let size = grid.len();

    // Bucket the values by diagonal key. Scanning row-major means each bucket
    // lists its diagonal from the top-left end to the bottom-right end.
    let mut buckets: HashMap<i32, Vec<i32>> = HashMap::new();
    for row in 0..size {
        for col in 0..size {
            buckets
                .entry(row as i32 - col as i32)
                .or_default()
                .push(grid[row][col]);
        }
    }

    // Sort every bucket in its required direction and turn it into a queue
    // that hands the values out front to back.
    let mut queues: HashMap<i32, std::vec::IntoIter<i32>> = buckets
        .into_iter()
        .map(|(offset, mut values)| {
            if offset >= 0 {
                values.sort_unstable_by(|a, b| b.cmp(a));
            } else {
                values.sort_unstable();
            }
            (offset, values.into_iter())
        })
        .collect();

    // Revisit the cells in the same row-major order; each cell takes the next
    // value from its diagonal's queue (one value was queued per cell above).
    let mut sorted = vec![vec![0; size]; size];
    for (row, out_row) in sorted.iter_mut().enumerate() {
        for (col, cell) in out_row.iter_mut().enumerate() {
            let offset = row as i32 - col as i32;
            if let Some(value) = queues.get_mut(&offset).and_then(|queue| queue.next()) {
                *cell = value;
            }
        }
    }

    sorted
}
TypeScript
/**
 * Returns a new n x n matrix (n = grid.length) in which each top-left to
 * bottom-right diagonal of `grid` is sorted independently: diagonals with
 * `row - col >= 0` (main diagonal and below) in non-increasing order,
 * diagonals with `row - col < 0` in non-decreasing order.
 * The input matrix is only read.
 *
 * @param grid square matrix to read from
 * @returns a freshly allocated matrix holding the diagonally sorted values
 */
function sortMatrixByDiagonals(grid: number[][]): number[][] {
    const size = grid.length;
    const sorted: number[][] = Array.from({ length: size }, () => new Array<number>(size).fill(0));

    // offset = row - col identifies a diagonal.
    for (let offset = 1 - size; offset < size; offset++) {
        const startRow = Math.max(offset, 0);
        const startCol = Math.max(-offset, 0);
        const length = size - startRow - startCol;

        // Copy the diagonal out, top-left end first.
        const line: number[] = [];
        for (let step = 0; step < length; step++) {
            line.push(grid[startRow + step][startCol + step]);
        }

        // Numeric comparators are required: the default sort compares as strings.
        if (offset >= 0) {
            line.sort((a, b) => b - a);
        } else {
            line.sort((a, b) => a - b);
        }

        // Write the sorted values back along the same cells.
        line.forEach((value, step) => {
            sorted[startRow + step][startCol + step] = value;
        });
    }

    return sorted;
}

Complexity

  • ⏰ Time complexity: O(n² log n) where n is the matrix dimension (sorting each diagonal)
  • 🧺 Space complexity: O(n²) for storing diagonal elements and result matrix