Problem

Given two sparse matrices mat1 of size m x k and mat2 of size k x n, return the result of mat1 x mat2. You may assume that multiplication is always possible.

Examples

Example 1:

$$ \begin{bmatrix} 1 & 0 & 0 \\ -1 & 0 & 3 \end{bmatrix} \times \begin{bmatrix} 7 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 1 \end{bmatrix} = \begin{bmatrix} 7 & 0 & 0 \\ -7 & 0 & 3 \end{bmatrix} $$

Input: mat1 = [[1,0,0],[-1,0,3]], mat2 = [[7,0,0],[0,0,0],[0,0,1]]
Output: [[7,0,0],[-7,0,3]]
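Each output entry is the dot product of a row of mat1 with a column of mat2. Using 0-based indices, two entries of Example 1 work out as follows; only the terms where both factors are non-zero affect the sum.

$$ \begin{aligned}
(\mathrm{mat1} \times \mathrm{mat2})_{il} &= \sum_{j=0}^{k-1} \mathrm{mat1}_{ij}\, \mathrm{mat2}_{jl} \\
(\mathrm{mat1} \times \mathrm{mat2})_{10} &= (-1)(7) + (0)(0) + (3)(0) = -7 \\
(\mathrm{mat1} \times \mathrm{mat2})_{12} &= (-1)(0) + (0)(0) + (3)(1) = 3
\end{aligned} $$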

Solution

Method 1 - Skip zero elements

The problem asks for the product of two sparse matrices mat1 and mat2. In a sparse matrix most entries are zero, and any term mat1[i][j] * mat2[j][l] with a zero factor contributes nothing to the result, so the work can be restricted to the non-zero entries instead of visiting every (i, j, l) combination.

Approach

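  1. Create the output matrix: initialize ans with dimensions m x n, filled with zeros.
  2. Skip zero entries of mat1: for each row i, visit every entry mat1[i][j] and ignore it if it is 0, because every product it would form is 0.
  3. Accumulate the products: for each non-zero mat1[i][j], add mat1[i][j] * mat2[j][l] to ans[i][l] for every column l where mat2[j][l] is non-zero.
  4. Optionally, store only the non-zero entries of each matrix (for example as per-row lists or dictionaries) so the loops never visit zeros at all. The code below skips zeros with inline checks instead.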
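The per-row list representation of non-zero entries can be sketched as follows. This is a minimal sketch, not the author's code: `nonzero_rows` and `sparse_multiply` are hypothetical helper names introduced for illustration, and the inputs are assumed to be plain lists of lists of ints. Each row of both matrices is compressed once into (column, value) pairs, so the innermost loop touches only non-zero entries of mat2 instead of scanning all n columns.

```python
from typing import List, Tuple


# Hypothetical helper: for each row, keep only (column, value) pairs with a non-zero value.
def nonzero_rows(mat: List[List[int]]) -> List[List[Tuple[int, int]]]:
    return [[(c, v) for c, v in enumerate(row) if v != 0] for row in mat]


# Hypothetical helper: multiply using only the compressed non-zero entries.
def sparse_multiply(mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
    m, n = len(mat1), len(mat2[0])
    rows1 = nonzero_rows(mat1)  # non-zero entries of each row of mat1
    rows2 = nonzero_rows(mat2)  # non-zero entries of each row of mat2
    ans = [[0] * n for _ in range(m)]
    for i, entries in enumerate(rows1):
        for j, a in entries:        # a = mat1[i][j], known to be non-zero
            for l, b in rows2[j]:   # b = mat2[j][l], known to be non-zero
                ans[i][l] += a * b
    return ans


print(sparse_multiply([[1, 0, 0], [-1, 0, 3]], [[7, 0, 0], [0, 0, 0], [0, 0, 1]]))
# [[7, 0, 0], [-7, 0, 3]]
```

Compressing the rows costs O(m * k + k * n) once; afterwards each non-zero mat1[i][j] does work proportional to the number of non-zeros in row j of mat2 rather than to n.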

Code

Java
class Solution {
    public int[][] multiply(int[][] mat1, int[][] mat2) {
        int m = mat1.length, k = mat1[0].length, n = mat2[0].length;
        int[][] ans = new int[m][n];
        
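        // ans[i][l] = sum over j of mat1[i][j] * mat2[j][l];
        // zero entries of mat1 and mat2 are skipped because they contribute nothing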
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < k; j++) {
                if (mat1[i][j] != 0) {
                    for (int l = 0; l < n; l++) {
                        if (mat2[j][l] != 0) {
                            ans[i][l] += mat1[i][j] * mat2[j][l];
                        }
                    }
                }
            }
        }
        
        return ans;
    }
}
Python
from typing import List

class Solution:
    def multiply(self, mat1: List[List[int]], mat2: List[List[int]]) -> List[List[int]]:
        m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
        ans: List[List[int]] = [[0] * n for _ in range(m)]
        
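        # ans[i][l] = sum over j of mat1[i][j] * mat2[j][l];
        # zero entries of mat1 and mat2 are skipped because they contribute nothing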
        for i in range(m):
            for j in range(k):
                if mat1[i][j] != 0:
                    for l in range(n):
                        if mat2[j][l] != 0:
                            ans[i][l] += mat1[i][j] * mat2[j][l]
        
        return ans

Complexity

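  • ⏰ Time complexity: O(m * k * n) in the worst case (dense inputs). More precisely, O(m * k + nnz(mat1) * n), where nnz(mat1) is the number of non-zero entries in mat1, because the innermost loop runs only for non-zero mat1[i][j].
  • 🧺 Space complexity: O(m * n) for the output matrix ans; apart from the output, only O(1) extra space is used.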
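To see how far the zero-skipping loop falls below the m * k * n bound, the sketch below counts the work on Example 1. `count_work` is a hypothetical helper written only for this illustration; it mirrors the loop structure of the solution above and reports three numbers: the iterations a plain triple loop would perform, the innermost-loop iterations the zero-skipping loop actually runs, and the multiplications it performs.

```python
from typing import List, Tuple


# Hypothetical helper: measure work done by the zero-skipping loop.
def count_work(mat1: List[List[int]], mat2: List[List[int]]) -> Tuple[int, int, int]:
    m, k, n = len(mat1), len(mat1[0]), len(mat2[0])
    dense = m * k * n  # a plain triple loop visits every (i, j, l) triple
    inner = 0          # innermost iterations of the zero-skipping loop
    mults = 0          # multiplications it actually performs
    for i in range(m):
        for j in range(k):
            if mat1[i][j] != 0:
                for l in range(n):
                    inner += 1
                    if mat2[j][l] != 0:
                        mults += 1
    return dense, inner, mults


print(count_work([[1, 0, 0], [-1, 0, 3]], [[7, 0, 0], [0, 0, 0], [0, 0, 1]]))
# (18, 9, 3)
```

Here mat1 has 3 non-zero entries and n = 3, so the innermost loop runs nnz(mat1) * n = 9 times, matching the refined bound.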