Word Squares
Hard · Updated: Aug 2, 2025
Practice on:
Word Squares Problem
Problem
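Given a set of words without duplicates, find all [word squares](https://en.wikipedia.org/wiki/Word_square) you can build from them. All words have the same length and consist only of lowercase English letters; the same word may be used in more than one row of a square (see Example 2).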
A sequence of words forms a valid word square if the k-th row and the k-th column read exactly the same string, where 0 ≤ k < max(numRows, numColumns).
For example, the word sequence ["ball","area","lead","lady"] forms a word square because each word reads the same both horizontally and vertically.
b a l l
a r e a
l e a d
l a d y
Examples
Example 1:
Input:
["area","lead","wall","lady","ball"]
Output:
[["wall","area","lead","lady"],["ball","area","lead","lady"]]
Explanation:
The output consists of two word squares. The order of output does not matter (just the order of words in each word square matters).
Example 2:
Input:
["abat","baba","atan","atal"]
Output:
[["baba","abat","baba","atan"],["baba","abat","baba","atal"]]
Solution
Method 1 – Backtracking with Trie
Intuition
To efficiently find all word squares, we use backtracking to build the square row by row, and a Trie to quickly find all words matching a given prefix (the next column to fill).
Approach
- Build a Trie from the word list, where each node stores all words with the current prefix.
- For each word, start a backtracking search with that word as the first row.
- At each step, build the prefix for the next row k from the k-th character of every row placed so far (i.e., column k of the partial square), and use the Trie to find all candidate words with that prefix.
- Continue until the square is complete.
Code
C++
// One trie position. `words` lists every inserted word whose prefix ends at
// this node, so a prefix lookup returns its candidates without walking a subtree.
// Children are indexed by letter; words are assumed to be lowercase 'a'-'z'.
struct TrieNode {
    vector<string> words;
    TrieNode* children[26] = {};
};

// Prefix trie over the word list. The trie owns every node it allocates and
// frees them in its destructor.
class Trie {
public:
    TrieNode root;

    Trie() = default;
    // Nodes are owned through raw pointers; a shallow copy would double-free.
    Trie(const Trie&) = delete;
    Trie& operator=(const Trie&) = delete;
    ~Trie() {
        for (TrieNode* child : root.children) destroy(child);
    }

    // Appends `word` to the list of every node along its path. The root's own
    // list is never filled.
    void insert(const string& word) {
        TrieNode* node = &root;
        for (char c : word) {
            if (!node->children[c - 'a']) node->children[c - 'a'] = new TrieNode();
            node = node->children[c - 'a'];
            node->words.push_back(word);
        }
    }

    // Returns (a copy of) all inserted words starting with `prefix`; empty if none.
    vector<string> find(const string& prefix) {
        TrieNode* node = &root;
        for (char c : prefix) {
            if (!node->children[c - 'a']) return {};
            node = node->children[c - 'a'];
        }
        return node->words;
    }

private:
    // Post-order release of a heap-allocated subtree; recursion depth is
    // bounded by the longest inserted word.
    static void destroy(TrieNode* node) {
        if (!node) return;
        for (TrieNode* child : node->children) destroy(child);
        delete node;
    }
};
// Word Squares solver: fills the square row by row, using a prefix Trie to list
// the candidate words for each row.
class Solution {
public:
    // Returns every word square that can be built from `words`; a word may be
    // reused across rows. An empty `words` yields no squares.
    vector<vector<string>> wordSquares(vector<string>& words) {
        vector<vector<string>> res;
        if (words.empty()) return res;  // words[0] below would be undefined behavior
        const size_t n = words[0].size();  // every word has this length
        Trie trie;
        for (const auto& w : words) trie.insert(w);
        vector<string> square;
        // Row k must begin with column k of the rows already placed.
        function<void()> backtrack = [&]() {
            const size_t k = square.size();
            if (k == n) { res.push_back(square); return; }
            string prefix;
            prefix.reserve(k);
            for (const auto& w : square) prefix += w[k];
            for (const auto& next : trie.find(prefix)) {
                square.push_back(next);
                backtrack();
                square.pop_back();
            }
        };
        for (const auto& w : words) {
            square = {w};
            backtrack();
        }
        return res;
    }
};
Go
// TrieNode is one prefix position: words lists every inserted word sharing the
// prefix that leads here, and children is indexed by letter ('a'..'z').
type TrieNode struct {
	words    []string
	children [26]*TrieNode
}

// Trie indexes words by prefix so candidates can be fetched in one walk.
type Trie struct{ root *TrieNode }

// NewTrie returns an empty Trie.
func NewTrie() *Trie {
	return &Trie{root: &TrieNode{}}
}

// Insert records word on every node along its path; the root stays empty.
func (t *Trie) Insert(word string) {
	cur := t.root
	for _, r := range word {
		child := cur.children[r-'a']
		if child == nil {
			child = &TrieNode{}
			cur.children[r-'a'] = child
		}
		child.words = append(child.words, word)
		cur = child
	}
}

// Find returns the words starting with prefix, or nil when none do.
func (t *Trie) Find(prefix string) []string {
	cur := t.root
	for _, r := range prefix {
		if cur = cur.children[r-'a']; cur == nil {
			return nil
		}
	}
	return cur.words
}
// wordSquares returns every word square buildable from words; a word may be
// reused across rows. Rows are filled one at a time: the prefix for row k is
// the k-th letter of each row already placed, and the Trie supplies all words
// with that prefix. An empty input yields nil.
func wordSquares(words []string) [][]string {
	if len(words) == 0 {
		return nil // words[0] below would panic
	}
	n := len(words[0])
	trie := NewTrie()
	for _, w := range words {
		trie.Insert(w)
	}
	var res [][]string
	square := make([]string, 0, n)
	var backtrack func()
	backtrack = func() {
		k := len(square)
		if k == n {
			// Copy: square's backing array is reused by later branches.
			res = append(res, append([]string(nil), square...))
			return
		}
		// Column k of the rows placed so far is the required prefix of row k.
		prefix := make([]byte, k)
		for i, w := range square {
			prefix[i] = w[k]
		}
		for _, next := range trie.Find(string(prefix)) {
			square = append(square, next)
			backtrack()
			square = square[:k]
		}
	}
	for _, w := range words {
		square = append(square[:0], w)
		backtrack()
	}
	return res
}
Java
class Solution {
static class TrieNode {
List<String> words = new ArrayList<>();
TrieNode[] children = new TrieNode[26];
}
static class Trie {
TrieNode root = new TrieNode();
void insert(String word) {
TrieNode node = root;
for (char c : word.toCharArray()) {
if (node.children[c-'a'] == null) node.children[c-'a'] = new TrieNode();
node = node.children[c-'a'];
node.words.add(word);
}
}
List<String> find(String prefix) {
TrieNode node = root;
for (char c : prefix.toCharArray()) {
if (node.children[c-'a'] == null) return new ArrayList<>();
node = node.children[c-'a'];
}
return node.words;
}
}
public List<List<String>> wordSquares(String[] words) {
int n = words[0].length();
Trie trie = new Trie();
for (String w : words) trie.insert(w);
List<List<String>> res = new ArrayList<>();
List<String> square = new ArrayList<>();
backtrack(trie, n, res, square);
return res;
}
private void backtrack(Trie trie, int n, List<List<String>> res, List<String> square) {
if (square.size() == n) { res.add(new ArrayList<>(square)); return; }
StringBuilder prefix = new StringBuilder();
for (String w : square) prefix.append(w.charAt(square.size()));
for (String next : trie.find(prefix.toString())) {
square.add(next);
backtrack(trie, n, res, square);
square.remove(square.size()-1);
}
}
}
Kotlin
/**
 * Word Squares solver: fills the square row by row, with a prefix [Trie]
 * supplying the candidate words for each row.
 */
class Solution {
    /** Trie node; [words] lists every inserted word whose prefix ends at this node. */
    class TrieNode {
        val words = mutableListOf<String>()
        val children = arrayOfNulls<TrieNode>(26)
    }

    /** Prefix trie over words made of lowercase letters `'a'..'z'`. */
    class Trie {
        val root = TrieNode()

        /** Appends [word] to every node along its path; the root's own list is never filled. */
        fun insert(word: String) {
            var node = root
            for (c in word) {
                val idx = c - 'a'
                val child = node.children[idx] ?: TrieNode().also { node.children[idx] = it }
                child.words.add(word)
                node = child
            }
        }

        /** Returns the words starting with [prefix], or an empty list if none do. */
        fun find(prefix: String): List<String> {
            var node = root
            for (c in prefix) node = node.children[c - 'a'] ?: return emptyList()
            return node.words
        }
    }

    /**
     * Returns every word square that can be built from [words]; a word may be used in
     * more than one row. An empty [words] yields an empty list.
     */
    fun wordSquares(words: Array<String>): List<List<String>> {
        if (words.isEmpty()) return emptyList() // words[0] would throw
        val n = words[0].length
        val trie = Trie()
        words.forEach(trie::insert)
        val res = mutableListOf<List<String>>()
        val square = mutableListOf<String>()

        // Row k must begin with column k of the rows already placed.
        fun backtrack() {
            if (square.size == n) {
                res.add(square.toList())
                return
            }
            val prefix = buildString { for (w in square) append(w[square.size]) }
            for (next in trie.find(prefix)) {
                square.add(next)
                backtrack()
                square.removeAt(square.lastIndex)
            }
        }

        for (w in words) {
            square.clear()
            square.add(w)
            backtrack()
        }
        return res
    }
}
Python
class TrieNode:
    """A trie position.

    ``words`` lists every inserted word that shares the prefix leading here;
    ``children`` maps the next character to the child node.
    """

    def __init__(self):
        self.words = []
        self.children = {}


class Trie:
    """Prefix index used to fetch candidate rows by prefix."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        """Record ``word`` on every node along its path (the root is left empty)."""
        cur = self.root
        for ch in word:
            nxt = cur.children.get(ch)
            if nxt is None:
                nxt = cur.children[ch] = TrieNode()
            nxt.words.append(word)
            cur = nxt

    def find(self, prefix):
        """Return the words that start with ``prefix``, or ``[]`` if none do.

        The returned list is the node's own storage, not a copy.
        """
        cur = self.root
        for ch in prefix:
            cur = cur.children.get(ch)
            if cur is None:
                return []
        return cur.words
class Solution:
    def wordSquares(self, words: list[str]) -> list[list[str]]:
        """Return every word square that can be built from ``words``.

        A word may appear in several rows. Squares are built row by row: the
        prefix required for row ``k`` is the ``k``-th character of each row
        already placed, and the trie lists every word with that prefix.
        An empty ``words`` yields ``[]``.
        """
        if not words:
            return []  # words[0] below would raise IndexError
        n = len(words[0])
        trie = Trie()
        for w in words:
            trie.insert(w)
        res = []
        square = []

        def backtrack():
            k = len(square)
            if k == n:
                res.append(square[:])
                return
            prefix = ''.join(row[k] for row in square)
            for next_word in trie.find(prefix):
                square.append(next_word)
                backtrack()
                square.pop()

        for w in words:
            square[:] = [w]  # reset in place; backtrack closes over this list
            backtrack()
        return res
Rust
use std::collections::HashMap;
/// One trie position: `words` holds every inserted word that passes through
/// this node; `children` maps the next character to its child.
struct TrieNode {
    words: Vec<String>,
    children: HashMap<char, Box<TrieNode>>,
}

impl TrieNode {
    fn new() -> Self {
        Self { words: Vec::new(), children: HashMap::new() }
    }
}

/// Prefix index over the word list.
struct Trie {
    root: TrieNode,
}

impl Trie {
    fn new() -> Self {
        Self { root: TrieNode::new() }
    }

    /// Records `word` on every node along its path; the root's list stays empty.
    fn insert(&mut self, word: &str) {
        let mut cur = &mut self.root;
        for ch in word.chars() {
            let child: &mut TrieNode = cur
                .children
                .entry(ch)
                .or_insert_with(|| Box::new(TrieNode::new()));
            child.words.push(word.to_owned());
            cur = child;
        }
    }

    /// Returns a copy of the words starting with `prefix` (empty if none).
    fn find(&self, prefix: &str) -> Vec<String> {
        prefix
            .chars()
            .try_fold(&self.root, |node, ch| node.children.get(&ch).map(|child| &**child))
            .map(|node| node.words.clone())
            .unwrap_or_default()
    }
}
impl Solution {
    /// Returns every word square that can be built from `words`; a word may be
    /// reused across rows. Row `k` must start with the `k`-th character of every
    /// row already placed, and the trie supplies all words with that prefix.
    /// Empty input yields an empty result.
    pub fn word_squares(words: Vec<String>) -> Vec<Vec<String>> {
        if words.is_empty() {
            return Vec::new(); // indexing words[0] would panic
        }
        // Count characters, not bytes: rows are indexed with `chars().nth`, so a
        // byte length would overrun words containing multi-byte characters.
        let n = words[0].chars().count();
        let mut trie = Trie::new();
        for w in &words {
            trie.insert(w);
        }

        fn backtrack(n: usize, trie: &Trie, res: &mut Vec<Vec<String>>, square: &mut Vec<String>) {
            if square.len() == n {
                res.push(square.clone());
                return;
            }
            let k = square.len();
            // Column k of the rows placed so far is the required prefix of row k.
            let prefix: String = square
                .iter()
                .map(|w| w.chars().nth(k).expect("all words must have the same length"))
                .collect();
            for next in trie.find(&prefix) {
                square.push(next);
                backtrack(n, trie, res, square);
                square.pop();
            }
        }

        let mut res = Vec::new();
        let mut square = Vec::with_capacity(n);
        for w in &words {
            square.clear();
            square.push(w.clone());
            backtrack(n, &trie, &mut res, &mut square);
        }
        res
    }
}
TypeScript
/** Trie node: `words` lists every inserted word passing through this node. */
class TrieNode {
    words: string[] = [];
    children: Map<string, TrieNode> = new Map();
}

/** Prefix index over the word list (keys are single code points). */
class Trie {
    root = new TrieNode();

    /** Records `word` on every node along its path; the root's list stays empty. */
    insert(word: string) {
        let cur = this.root;
        for (const ch of word) {
            let child = cur.children.get(ch);
            if (child === undefined) {
                child = new TrieNode();
                cur.children.set(ch, child);
            }
            child.words.push(word);
            cur = child;
        }
    }

    /** Returns the words beginning with `prefix` (the node's own array), or [] if none. */
    find(prefix: string): string[] {
        let cur = this.root;
        for (const ch of prefix) {
            const child = cur.children.get(ch);
            if (child === undefined) return [];
            cur = child;
        }
        return cur.words;
    }
}
/** Word Squares solver: fills the square row by row using a prefix Trie. */
class Solution {
    /**
     * Returns every word square that can be built from `words`; a word may be
     * reused across rows. Returns [] for empty input.
     */
    wordSquares(words: string[]): string[][] {
        const res: string[][] = [];
        if (words.length === 0) return res; // words[0] would be undefined
        const n = words[0].length;
        const trie = new Trie();
        for (const w of words) trie.insert(w);
        const square: string[] = [];
        // Row k must begin with column k of the rows already placed.
        const backtrack = (): void => {
            const k = square.length;
            if (k === n) {
                res.push([...square]);
                return;
            }
            const prefix = square.map(w => w[k]).join("");
            for (const next of trie.find(prefix)) {
                square.push(next);
                backtrack();
                square.pop();
            }
        };
        for (const w of words) {
            square.length = 0;
            square.push(w);
            backtrack();
        }
        return res;
    }
}
Complexity
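- ⏰ Time complexity: O(N · W) to build the Trie, where N is the number of words and W is the word length. The backtracking search is exponential in the worst case. Each of the W rows can have up to N candidates, so up to O(N^W) partial squares may be explored. Each step costs at least O(W) to build the prefix and walk the Trie. Versions that return a copy of the candidate list, such as C++ and Rust, also pay to copy it.
- 🧺 Space complexity: O(N · W) for the Trie, because every word is recorded at each of its W prefix nodes. That becomes O(N · W²) characters where the nodes store string copies. Add O(W) for the recursion stack and the current square, not counting the output.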