Word Squares Problem

Problem

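Given a set of words without duplicates, find all [word squares](https://en.wikipedia.org/wiki/Word_square) you can build from them. All words have the same length and consist only of lowercase English letters; the solutions below rely on both assumptions.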

A sequence of words forms a valid word square if the $k$-th row and the $k$-th column read exactly the same string, for every $0 \le k < \max(\text{numRows}, \text{numColumns})$.

For example, the word sequence ["ball","area","lead","lady"] forms a word square because each word reads the same both horizontally and vertically.

```
b a l l
a r e a
l e a d
l a d y
```

Examples

Example 1:

```
Input:
["area","lead","wall","lady","ball"]
Output:
[["wall","area","lead","lady"],["ball","area","lead","lady"]]
```

Explanation:
The output consists of two word squares. The order of output does not matter (just the order of words in each word square matters).

Example 2:

```
Input:
["abat","baba","atan","atal"]
Output:
[["baba","abat","baba","atan"],["baba","abat","baba","atal"]]
```

Solution

Method 1 – Backtracking with Trie

Intuition

To find all word squares efficiently, we use backtracking to build the square row by row, and a Trie to quickly find all words that start with a given prefix. That prefix is the one the next row must have, and it is fixed by the columns already filled.

Approach

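  1. Build a Trie from the word list, where each node stores every word that starts with the prefix that node represents.
  2. For each word, start a backtracking search with that word as the first row.
  3. At each step, compute the prefix the next row must have. Row k must start with the k-th character of each of the k rows already placed, so that column k reads the same as row k. Use the Trie to find all candidate words with that prefix.
  4. When the square has as many rows as the word length, record a copy of it, then backtrack to try the remaining candidates.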

Code

C++
// Trie node: `words` holds every inserted word whose path passes through this
// node, i.e. every word that starts with this node's prefix.
struct TrieNode {
    vector<string> words;
    TrieNode* children[26] = {};

    TrieNode() = default;
    // A node owns its children. A member-wise copy would create two owners of
    // the same pointers (and a double free), so copying is disabled.
    TrieNode(const TrieNode&) = delete;
    TrieNode& operator=(const TrieNode&) = delete;
    // Release the whole subtree allocated by Trie::insert.
    ~TrieNode() {
        for (TrieNode* child : children) delete child;
    }
};
// Prefix tree over lowercase letters. Every character is used directly as an
// index ('a'..'z'), so inputs must be lowercase English letters.
class Trie {
public:
    TrieNode root;
    // Adds `word`, recording it at every node on its path (the root stores none).
    void insert(const string& word) {
        TrieNode* node = &root;
        for (char c : word) {
            TrieNode*& child = node->children[c - 'a'];
            if (!child) child = new TrieNode();
            node = child;
            node->words.push_back(word);
        }
    }
    // Returns all inserted words starting with `prefix`, or an empty list.
    // The list is returned by const reference to avoid copying it on every
    // lookup; the reference stays valid while the trie is not modified.
    // Because the root stores no words, find("") is always empty.
    const vector<string>& find(const string& prefix) const {
        static const vector<string> kNone;
        const TrieNode* node = &root;
        for (char c : prefix) {
            node = node->children[c - 'a'];
            if (!node) return kNone;
        }
        return node->words;
    }
};
class Solution {
public:
    vector<vector<string>> wordSquares(vector<string>& words) {
        int n = words[0].size();
        Trie trie;
        for (auto& w : words) trie.insert(w);
        vector<vector<string>> res;
        vector<string> square;
        function<void()> backtrack = [&]() {
            if (square.size() == n) { res.push_back(square); return; }
            string prefix;
            for (auto& w : square) prefix += w[square.size()];
            for (auto& next : trie.find(prefix)) {
                square.push_back(next);
                backtrack();
                square.pop_back();
            }
        };
        for (auto& w : words) { square = {w}; backtrack(); }
        return res;
    }
};
Go
// TrieNode is a prefix-tree node over 'a'..'z'. words lists every inserted
// word whose path passes through the node, i.e. all words with its prefix.
type TrieNode struct {
    words    []string
    children [26]*TrieNode
}

// Trie wraps the root node; the root itself never records any words.
type Trie struct {
    root *TrieNode
}

// NewTrie returns an empty trie.
func NewTrie() *Trie {
    t := new(Trie)
    t.root = new(TrieNode)
    return t
}

// Insert adds word, appending it to the word list of every node on its path.
// Each character is used as an index relative to 'a'.
func (t *Trie) Insert(word string) {
    cur := t.root
    for _, r := range word {
        next := cur.children[r-'a']
        if next == nil {
            next = &TrieNode{}
            cur.children[r-'a'] = next
        }
        next.words = append(next.words, word)
        cur = next
    }
}

// Find returns the words beginning with prefix, or nil when there are none.
// The returned slice aliases the trie's own storage.
func (t *Trie) Find(prefix string) []string {
    cur := t.root
    for _, r := range prefix {
        if cur = cur.children[r-'a']; cur == nil {
            return nil
        }
    }
    return cur.words
}
// wordSquares returns every word square that can be formed from words, trying
// each word as the first row. All words are assumed to have the length of
// words[0] and to consist of lowercase ASCII letters. An empty input returns
// nil instead of panicking on words[0]; nil is also returned when no square exists.
func wordSquares(words []string) [][]string {
    if len(words) == 0 {
        return nil
    }
    n := len(words[0])
    trie := NewTrie()
    for _, w := range words {
        trie.Insert(w)
    }
    var res [][]string
    square := make([]string, 0, n)
    // Reused buffer for the next row's required prefix; string(prefix) copies it
    // before the recursion, so overwriting it in deeper calls is safe.
    prefix := make([]byte, 0, n)
    var backtrack func()
    backtrack = func() {
        row := len(square)
        if row == n {
            res = append(res, append([]string(nil), square...))
            return
        }
        // Row `row` must start with the row-th letter of every row placed so far.
        prefix = prefix[:0]
        for _, w := range square {
            prefix = append(prefix, w[row])
        }
        for _, next := range trie.Find(string(prefix)) {
            square = append(square, next)
            backtrack()
            square = square[:len(square)-1]
        }
    }
    for _, w := range words {
        square = append(square[:0], w)
        backtrack()
    }
    return res
}
Java
class Solution {
    static class TrieNode {
        List<String> words = new ArrayList<>();
        TrieNode[] children = new TrieNode[26];
    }
    static class Trie {
        TrieNode root = new TrieNode();
        void insert(String word) {
            TrieNode node = root;
            for (char c : word.toCharArray()) {
                if (node.children[c-'a'] == null) node.children[c-'a'] = new TrieNode();
                node = node.children[c-'a'];
                node.words.add(word);
            }
        }
        List<String> find(String prefix) {
            TrieNode node = root;
            for (char c : prefix.toCharArray()) {
                if (node.children[c-'a'] == null) return new ArrayList<>();
                node = node.children[c-'a'];
            }
            return node.words;
        }
    }
    public List<List<String>> wordSquares(String[] words) {
        int n = words[0].length();
        Trie trie = new Trie();
        for (String w : words) trie.insert(w);
        List<List<String>> res = new ArrayList<>();
        List<String> square = new ArrayList<>();
        backtrack(trie, n, res, square);
        return res;
    }
    private void backtrack(Trie trie, int n, List<List<String>> res, List<String> square) {
        if (square.size() == n) { res.add(new ArrayList<>(square)); return; }
        StringBuilder prefix = new StringBuilder();
        for (String w : square) prefix.append(w.charAt(square.size()));
        for (String next : trie.find(prefix.toString())) {
            square.add(next);
            backtrack(trie, n, res, square);
            square.remove(square.size()-1);
        }
    }
}
Kotlin
class Solution {
    /** Trie node; `words` lists every inserted word whose path passes through this node. */
    class TrieNode {
        val words = mutableListOf<String>()
        val children = Array<TrieNode?>(26) { null }
    }

    /** Prefix tree over the lowercase letters 'a'..'z'. */
    class Trie {
        val root = TrieNode()

        /** Adds [word], recording it at every node on its path; the root stores no words. */
        fun insert(word: String) {
            var node = root
            for (c in word) {
                val idx = c - 'a'
                if (node.children[idx] == null) node.children[idx] = TrieNode()
                node = node.children[idx]!!
                node.words.add(word)
            }
        }

        /** Returns the words starting with [prefix] (the trie's own list), or an empty list. */
        fun find(prefix: String): List<String> {
            var node = root
            for (c in prefix) {
                val idx = c - 'a'
                node = node.children[idx] ?: return emptyList()
            }
            return node.words
        }
    }

    /**
     * Returns every word square buildable from [words], trying each word as the first row.
     * Assumes all words have the length of `words[0]` and are lowercase letters.
     * An empty input returns an empty list instead of throwing on `words[0]`.
     */
    fun wordSquares(words: Array<String>): List<List<String>> {
        if (words.isEmpty()) return emptyList()
        val n = words[0].length
        val trie = Trie()
        for (w in words) trie.insert(w)
        val res = mutableListOf<List<String>>()
        val square = mutableListOf<String>()
        fun backtrack() {
            val row = square.size
            if (row == n) { res.add(square.toList()); return }
            // Row `row` must begin with the row-th letter of every row placed so far.
            val prefix = buildString { for (w in square) append(w[row]) }
            for (next in trie.find(prefix)) {
                square.add(next)
                backtrack()
                square.removeAt(square.lastIndex)
            }
        }
        for (w in words) {
            square.clear()
            square.add(w)
            backtrack()
        }
        return res
    }
}
Python
class TrieNode:
    """A prefix-tree node.

    ``children`` maps a character to its child node; ``words`` lists every
    inserted word whose path passes through this node.
    """

    def __init__(self):
        self.words = []
        self.children = {}


class Trie:
    """Prefix tree whose nodes record the words sharing their prefix."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        """Add ``word``, appending it to the word list of every node on its path."""
        cur = self.root
        for ch in word:
            nxt = cur.children.get(ch)
            if nxt is None:
                nxt = cur.children[ch] = TrieNode()
            nxt.words.append(word)
            cur = nxt

    def find(self, prefix):
        """Return the words starting with ``prefix``, or ``[]`` if there are none.

        The returned list is the trie's own storage, not a copy. The root never
        records words, so ``find("")`` returns an empty list.
        """
        cur = self.root
        for ch in prefix:
            cur = cur.children.get(ch)
            if cur is None:
                return []
        return cur.words
class Solution:
    def wordSquares(self, words: list[str]) -> list[list[str]]:
        """Return every word square that can be built from ``words``.

        Each word is tried as the first row. Row ``k`` must then start with the
        ``k``-th letters of the rows already placed, so that column ``k`` equals
        row ``k``. All words are assumed to have the length of ``words[0]``.
        An empty ``words`` returns ``[]`` instead of raising IndexError.
        """
        if not words:
            return []
        n = len(words[0])
        trie = Trie()
        for w in words:
            trie.insert(w)
        res: list[list[str]] = []
        square: list[str] = []

        def backtrack() -> None:
            row = len(square)
            if row == n:
                res.append(square[:])
                return
            prefix = ''.join(w[row] for w in square)
            for candidate in trie.find(prefix):
                square.append(candidate)
                backtrack()
                square.pop()

        for w in words:
            square[:] = [w]  # mutate in place: the closure captured this list
            backtrack()
        return res
Rust
use std::collections::HashMap;
/// Prefix-tree node: `words` lists every inserted word whose path passes
/// through this node, and `children` is keyed by the next character.
struct TrieNode {
    words: Vec<String>,
    children: HashMap<char, Box<TrieNode>>,
}

impl TrieNode {
    fn new() -> Self {
        Self {
            words: Vec::new(),
            children: HashMap::new(),
        }
    }
}

/// Prefix tree whose nodes record the words sharing their prefix.
struct Trie {
    root: TrieNode,
}

impl Trie {
    fn new() -> Self {
        Self { root: TrieNode::new() }
    }

    /// Adds `word`, storing an owned copy of it at every node on its path.
    fn insert(&mut self, word: &str) {
        let mut cur = &mut self.root;
        for ch in word.chars() {
            cur = cur
                .children
                .entry(ch)
                .or_insert_with(|| Box::new(TrieNode::new()));
            cur.words.push(word.to_owned());
        }
    }

    /// Returns a clone of the words starting with `prefix`; empty if none match.
    /// The root stores no words, so `find("")` is empty.
    fn find(&self, prefix: &str) -> Vec<String> {
        let mut cur = &self.root;
        for ch in prefix.chars() {
            match cur.children.get(&ch) {
                Some(child) => cur = child,
                None => return Vec::new(),
            }
        }
        cur.words.clone()
    }
}
impl Solution {
    /// Returns every word square that can be built from `words`, trying each
    /// word as the first row.
    ///
    /// The square size is the *character* count of `words[0]`. Rows are
    /// indexed with `chars().nth` and the trie is keyed by `char`, so a byte
    /// length would disagree for non-ASCII words and panic on `unwrap`.
    /// All words are assumed to have the same length. An empty input returns
    /// an empty result instead of panicking.
    pub fn word_squares(words: Vec<String>) -> Vec<Vec<String>> {
        let mut res = Vec::new();
        let n = match words.first() {
            Some(w) => w.chars().count(),
            None => return res,
        };
        let mut trie = Trie::new();
        for w in &words {
            trie.insert(w);
        }
        let mut square: Vec<String> = Vec::with_capacity(n);

        /// Extends `square` row by row, pushing each completed n-row square.
        fn backtrack(n: usize, trie: &Trie, res: &mut Vec<Vec<String>>, square: &mut Vec<String>) {
            let row = square.len();
            if row == n {
                res.push(square.clone());
                return;
            }
            // Row `row` must begin with the row-th character of every placed row.
            let prefix: String = square.iter().map(|w| w.chars().nth(row).unwrap()).collect();
            for next in trie.find(&prefix) {
                square.push(next);
                backtrack(n, trie, res, square);
                square.pop();
            }
        }

        for w in &words {
            square.clear();
            square.push(w.clone());
            backtrack(n, &trie, &mut res, &mut square);
        }
        res
    }
}
TypeScript
/** Prefix-tree node: `words` lists every inserted word whose path passes through it. */
class TrieNode {
    words: string[] = [];
    children: Map<string, TrieNode> = new Map();
}

/** Prefix tree whose nodes record the words sharing their prefix. */
class Trie {
    root = new TrieNode();

    /** Adds `word`, appending it to the word list of every node on its path. */
    insert(word: string) {
        let cur = this.root;
        for (const ch of word) {
            let next = cur.children.get(ch);
            if (next === undefined) {
                next = new TrieNode();
                cur.children.set(ch, next);
            }
            next.words.push(word);
            cur = next;
        }
    }

    /** Returns the words beginning with `prefix` (the trie's own array), or `[]` if none. */
    find(prefix: string): string[] {
        let cur = this.root;
        for (const ch of prefix) {
            const next = cur.children.get(ch);
            if (next === undefined) return [];
            cur = next;
        }
        return cur.words;
    }
}
class Solution {
    /**
     * Returns every word square that can be built from `words`, trying each word
     * as the first row. Assumes all words have the length of `words[0]`.
     * An empty input returns `[]` instead of throwing on `words[0].length`.
     */
    wordSquares(words: string[]): string[][] {
        const res: string[][] = [];
        if (words.length === 0) return res;
        const n = words[0].length;
        const trie = new Trie();
        for (const w of words) trie.insert(w);
        const square: string[] = [];
        const backtrack = (): void => {
            const row = square.length;
            if (row === n) {
                res.push([...square]);
                return;
            }
            // Row `row` must begin with the row-th letter of every row placed so far.
            const prefix = square.map(w => w[row]).join("");
            for (const next of trie.find(prefix)) {
                square.push(next);
                backtrack();
                square.pop();
            }
        };
        for (const w of words) {
            square.length = 0;
            square.push(w);
            backtrack();
        }
        return res;
    }
}

Complexity

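  • ⏰ Time complexity: O(N·W) node visits to build the Trie, where N is the number of words and W is the word length. The backtracking search is exponential in the worst case. Each of the W rows can have up to N candidates, giving up to O(N^W) partial squares. Each partial square costs O(W) to build and look up its prefix, so the total is O(N^W · W), ignoring the cost of copying candidate lists in versions that return copies. In practice, prefix pruning makes the search far faster.
  • 🧺 Space complexity: O(N·W) for the Trie, since each word is recorded at each of the W nodes on its path. This becomes O(N·W²) characters in versions that store string copies, such as C++ and Rust. Add O(W) for the recursion stack and the current square; the output is not counted.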