Problem

Given a text string and a dictionary of M patterns (which may have different lengths), find every occurrence of every pattern in the text. Report each occurrence by the 0-based index in the text of its last character. The search must stay efficient even when the number of patterns is large.

Examples

Example 1

Input:
patterns = ["he", "she", "his", "hers"]
text = "ahishers"
Output:
[3, 5, 5, 7]
Explanation:
Each entry is the 0-based index in text of the last character of one occurrence. "his" occupies indices 1–3 and ends at 3. "she" (indices 3–5) and "he" (indices 4–5) both end at 5, so 5 appears twice. "hers" (indices 4–7) ends at 7.

Solution

Intuition

When searching for multiple patterns in a long text, running a separate single-pattern search for each pattern is too slow: it costs at least O(M·N) for M patterns and a text of length N. The Aho-Corasick algorithm builds a trie of all patterns and augments it with suffix (failure) links, analogous to the KMP failure function. This lets a single pass over the text search for all patterns simultaneously. The total running time is linear in the text length, the total pattern length, and the number of matches.

Approach

  1. Build a trie from all patterns. Each node represents a prefix of some pattern.
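  2. For each node, compute a suffix (failure) link to the node for the longest proper suffix of its string that is also a prefix in the trie. Compute the links in BFS order, so every shallower node's link is already known when a deeper node needs it.
  3. For each node, compute an output link to the nearest node on its suffix-link chain that ends a pattern. Following output links enumerates every pattern that is a suffix of the current node's string, without copying pattern lists between nodes.
  4. Scan the text character by character:
    • While the current node has no edge for the next character and is not the root, follow its suffix link. Then take the edge if it exists; otherwise stay at the root.
    • Report the patterns that end at the current node, then follow output links to report every other pattern that ends at this text position.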

Code

C++
#include <cstddef>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

/// Node of the Aho-Corasick automaton.
///
/// All pointers are non-owning: the nodes are owned by the pool inside
/// Solution::ahoCorasick and are released together when it returns.
class TrieNode {
public:
  /// Trie edges keyed by the next character.
  std::unordered_map<char, TrieNode*> children;
  /// Failure link: node for the longest proper suffix of this node's string
  /// that is also a trie prefix. The root links to itself.
  TrieNode* fail = nullptr;
  /// Output link: nearest node on the failure chain that ends at least one
  /// pattern, or nullptr if there is none.
  TrieNode* output = nullptr;
  /// Indices of the patterns that end exactly at this node.
  std::vector<int> wordIds;
};

class Solution {
public:
  /// Finds every occurrence of every pattern in `text` in a single pass.
  ///
  /// @param patterns dictionary of patterns; a duplicated pattern is reported
  ///                 once per copy.
  /// @param text     text to scan.
  /// @return one entry per occurrence: the 0-based index in `text` of the
  ///         occurrence's last character, in non-decreasing order. An empty
  ///         pattern is reported once at every index of `text`.
  std::vector<int> ahoCorasick(const std::vector<std::string>& patterns,
                               const std::string& text) {
    // Owns every node; freed automatically on return (or on exception).
    std::vector<std::unique_ptr<TrieNode>> pool;
    auto newNode = [&pool]() {
      pool.push_back(std::make_unique<TrieNode>());
      return pool.back().get();
    };
    TrieNode* root = newNode();

    // 1. Build the trie; each pattern's index is stored at its final node.
    for (std::size_t i = 0; i < patterns.size(); ++i) {
      TrieNode* node = root;
      for (char c : patterns[i]) {
        TrieNode*& next = node->children[c];  // single lookup: find or insert
        if (next == nullptr) next = newNode();
        node = next;
      }
      node->wordIds.push_back(static_cast<int>(i));
    }

    // 2. Failure and output links, breadth-first so that the links of every
    //    shallower node are final before a deeper node consults them.
    root->fail = root;
    std::queue<TrieNode*> bfs;
    bfs.push(root);
    while (!bfs.empty()) {
      TrieNode* node = bfs.front();
      bfs.pop();
      for (auto& [c, child] : node->children) {
        if (node == root) {
          child->fail = root;
        } else {
          // Longest proper suffix of node's string that can be extended by c.
          TrieNode* f = node->fail;
          while (f != root && f->children.count(c) == 0) f = f->fail;
          auto it = f->children.find(c);
          child->fail = (it != f->children.end()) ? it->second : root;
        }
        // Output links replace copying id lists down the failure chain,
        // which keeps preprocessing linear in the total pattern length.
        TrieNode* link = child->fail;
        child->output = link->wordIds.empty() ? link->output : link;
        bfs.push(child);
      }
    }

    // 3. Scan the text once.
    std::vector<int> ans;
    TrieNode* node = root;
    for (std::size_t i = 0; i < text.size(); ++i) {
      const char c = text[i];
      auto it = node->children.find(c);
      while (it == node->children.end() && node != root) {
        node = node->fail;
        it = node->children.find(c);
      }
      if (it != node->children.end()) node = it->second;
      // Every pattern that is a suffix of text[0..i]: those ending at this
      // node, then those reached through the output chain.
      for (TrieNode* hit = node; hit != nullptr; hit = hit->output) {
        ans.insert(ans.end(), hit->wordIds.size(), static_cast<int>(i));
      }
    }
    return ans;
  }
};
Java
class TrieNode {
  Map<Character, TrieNode> children = new HashMap<>();
  TrieNode fail = null;
  List<Integer> wordIds = new ArrayList<>();
}
class Solution {
  public List<Integer> ahoCorasick(List<String> patterns, String text) {
    int m = patterns.size();
    TrieNode root = new TrieNode();
    for (int i = 0; i < m; ++i) {
      TrieNode node = root;
      for (char c : patterns.get(i).toCharArray()) {
        node = node.children.computeIfAbsent(c, k -> new TrieNode());
      }
      node.wordIds.add(i);
    }
    Queue<TrieNode> q = new LinkedList<>();
    root.fail = root;
    for (TrieNode child : root.children.values()) {
      child.fail = root; q.add(child);
    }
    while (!q.isEmpty()) {
      TrieNode node = q.poll();
      for (Map.Entry<Character, TrieNode> entry : node.children.entrySet()) {
        char c = entry.getKey(); TrieNode child = entry.getValue();
        TrieNode f = node.fail;
        while (f != root && !f.children.containsKey(c)) f = f.fail;
        if (f.children.containsKey(c) && f.children.get(c) != child) child.fail = f.children.get(c);
        else child.fail = root;
        child.wordIds.addAll(child.fail.wordIds);
        q.add(child);
      }
    }
    List<Integer> ans = new ArrayList<>();
    TrieNode node = root;
    for (int i = 0; i < text.length(); ++i) {
      char c = text.charAt(i);
      while (node != root && !node.children.containsKey(c)) node = node.fail;
      if (node.children.containsKey(c)) node = node.children.get(c);
      ans.addAll(node.wordIds.stream().map(id -> i).toList());
    }
    return ans;
  }
}
Python
class TrieNode:
    """Node of the Aho-Corasick automaton."""

    def __init__(self):
        # Trie edges keyed by the next character.
        self.children = {}
        # Failure link: node for the longest proper suffix that is also a trie
        # prefix (the root links to itself).
        self.fail = None
        # Output link: nearest node on the failure chain that ends a pattern,
        # or None if there is none.
        self.output = None
        # Indices of the patterns that end exactly at this node.
        self.word_ids = []

class Solution:
    def aho_corasick(self, patterns: list[str], text: str) -> list[int]:
        """Find every occurrence of every pattern in ``text`` in one pass.

        Args:
            patterns: Dictionary of patterns; a duplicated pattern is reported
                once per copy.
            text: Text to scan.

        Returns:
            One entry per occurrence: the 0-based index in ``text`` of the
            occurrence's last character, in non-decreasing order. An empty
            pattern is reported at every index of ``text``.
        """
        from collections import deque

        root = TrieNode()
        # 1. Build the trie; each pattern's index is stored at its final node.
        for idx, pat in enumerate(patterns):
            node = root
            for c in pat:
                if c not in node.children:
                    node.children[c] = TrieNode()
                node = node.children[c]
            node.word_ids.append(idx)

        # 2. Failure and output links, breadth-first so that every shallower
        #    node's links are final before a deeper node consults them.
        root.fail = root
        q = deque([root])
        while q:
            node = q.popleft()
            for c, child in node.children.items():
                if node is root:
                    child.fail = root
                else:
                    # Longest proper suffix of node's string extendable by c.
                    f = node.fail
                    while f is not root and c not in f.children:
                        f = f.fail
                    child.fail = f.children.get(c, root)
                # Output links replace copying id lists down the failure
                # chain, keeping preprocessing linear in total pattern length.
                link = child.fail
                child.output = link if link.word_ids else link.output
                q.append(child)

        # 3. Scan the text once.
        ans = []
        node = root
        for i, c in enumerate(text):
            while node is not root and c not in node.children:
                node = node.fail
            if c in node.children:
                node = node.children[c]
            # Every pattern that is a suffix of text[:i + 1]: those ending at
            # this node, then those reached through the output chain.
            hit = node
            while hit is not None:
                ans.extend([i] * len(hit.word_ids))
                hit = hit.output
        return ans

Complexity

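  • ⏰ Time complexity: O(N + L + Z) expected (hash-map lookups are expected O(1)), where N is the text length, L is the total length of all patterns, and Z is the number of reported matches.
    • Building the trie touches each pattern character once.
    • Computing failure links is amortized O(L).
    • During the scan, every failure-link step strictly decreases the current depth, and the depth grows by at most one per text character, so the scan is amortized O(N).
    • Output links make each reported match cost O(1).
  • 🧺 Space complexity: O(L + Z). The trie has at most L + 1 nodes. Because children are stored in hash maps, each node holds entries only for edges that exist, plus one failure link and one output link. The result list holds Z entries. (A trie with a fixed-size child array per node would instead need O(σ·L) space for an alphabet of size σ.)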