Problem

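Given a text string and a pattern string, replace every occurrence of the pattern in the text with a given replacement string. Occurrences are matched from left to right and do not overlap: once a match is replaced, scanning resumes right after it.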

Examples

Example 1:

Input: text = "aabaabab", pattern = "ab", replaceWith = "cd"
Output: "acdacdcd"

Example 2:

Input: text = "aabaabab", pattern = "ab", replaceWith = "abc"
Output: "aabcaabcabc"
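As a quick cross-check of the expected outputs, the sketch below uses Python's built-in str.replace, which replaces non-overlapping occurrences from left to right (the same semantics as the examples). The wrapper name builtin_replace is hypothetical. It only provides a reference result and is not the method discussed below.

```python
def builtin_replace(text: str, pattern: str, replace_with: str) -> str:
    """Reference result: str.replace substitutes every non-overlapping
    occurrence of pattern, scanning left to right."""
    return text.replace(pattern, replace_with)


print(builtin_replace("aabaabab", "ab", "cd"))   # acdacdcd
print(builtin_replace("aabaabab", "ab", "abc"))  # aabcaabcabc
```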

Solution

Method 1 - Naive

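The straightforward method is a brute-force search. Try every alignment of the pattern against the text and compare character by character. Whenever the pattern matches, append the replacement and skip past the match; otherwise copy one character and move on. Each alignment can cost up to m comparisons, so the time complexity is O(n * m), where n is the length of the text and m the length of the pattern. This becomes O(n^2) when m grows proportionally to n.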
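A minimal Python sketch of the brute-force method follows. It assumes non-overlapping left-to-right replacement and that an empty pattern leaves the text unchanged. The function name naive_replace is illustrative. At every index it compares the pattern character by character, which is where the O(n * m) cost comes from.

```python
def naive_replace(text: str, pattern: str, replace_with: str) -> str:
    """Brute-force replacement: try the pattern at every index of text.

    Each attempt costs up to m character comparisons, so the scan is
    O(n * m) in the worst case.
    """
    if not pattern:
        return text  # assumption: an empty pattern leaves the text unchanged
    n, m = len(text), len(pattern)
    out = []
    i = 0
    while i < n:
        # Compare pattern with text[i:i+m] character by character.
        j = 0
        while j < m and i + j < n and text[i + j] == pattern[j]:
            j += 1
        if j == m:
            out.append(replace_with)  # full match: emit the replacement
            i += m                    # skip the matched characters
        else:
            out.append(text[i])       # no match here: copy one character
            i += 1
    return "".join(out)


print(naive_replace("aabaabab", "ab", "cd"))  # acdacdcd
```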

Method 2 - KMP

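To improve the solution, we can use the Knuth-Morris-Pratt (KMP) algorithm to find all occurrences of the pattern in the text. KMP precomputes an LPS array. For each prefix of the pattern, this array stores the length of the longest proper prefix that is also a suffix. On a mismatch, KMP falls back using this array instead of moving backwards in the text, so the search runs in O(n + m) time rather than O(n * m). I have discussed the KMP algorithm in detail in another post.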
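For reference, here is a compact Python sketch of KMP search written in the common for-loop style. The name kmp_find_all is hypothetical, and the Code section has its own while-loop version. Like that version, this sketch keeps going after each match by falling back through the LPS array, so it also reports overlapping matches.

```python
from typing import List


def kmp_find_all(text: str, pattern: str) -> List[int]:
    """Return the start index of every match of pattern in text,
    including overlapping matches, in O(n + m) time."""
    m = len(pattern)
    if m == 0:
        return []
    # lps[q] = length of the longest proper prefix of pattern[:q+1]
    # that is also a suffix of it.
    lps = [0] * m
    k = 0
    for q in range(1, m):
        while k > 0 and pattern[q] != pattern[k]:
            k = lps[k - 1]
        if pattern[q] == pattern[k]:
            k += 1
        lps[q] = k

    matches = []
    k = 0  # number of pattern characters currently matched
    for i, ch in enumerate(text):
        while k > 0 and ch != pattern[k]:
            k = lps[k - 1]  # fall back instead of re-reading the text
        if ch == pattern[k]:
            k += 1
        if k == m:
            matches.append(i - m + 1)
            k = lps[k - 1]  # continue so overlapping matches are found
    return matches
```

For example, kmp_find_all("aabaabab", "ab") returns [1, 4, 6], and kmp_find_all("aaaaaa", "aa") returns [0, 1, 2, 3, 4].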

Here’s how we can improve the solution:

  1. Use the KMP algorithm to find the start index of every occurrence of the pattern in the text, which takes O(n + m) time.
  2. Store these positions, in increasing order, in an array (let’s call it positionArray).
  3. Use a StringBuilder to construct the final string with replacements.
  4. Maintain a pointer (positionPointer) to the next candidate index in positionArray.
  5. Traverse the text from left to right with an index i.
    • If i equals positionArray[positionPointer], append the replacement string, advance i by the pattern length, and increment positionPointer. Otherwise, append the character at i and advance i by one.
  6. KMP reports overlapping matches, so after each replacement also skip every entry of positionArray that starts inside the text just replaced. For example, in “aaaaaa” the pattern “aa” matches at 0, 1, 2, 3 and 4; replacing it with “X” should use only 0, 2 and 4 and produce “XXX”.
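Steps 2 to 6 can be sketched in Python as follows. The sketch assumes that the match positions are already available in ascending order (for example, from a KMP search) and that the pattern is non-empty. The function apply_replacements and its parameters are hypothetical names; the list out stands in for the StringBuilder and pointer for positionPointer.

```python
from typing import List


def apply_replacements(text: str, pattern_len: int,
                       positions: List[int], replace_with: str) -> str:
    """Rebuild text, replacing non-overlapping matches left to right.

    positions must be sorted ascending and may contain overlapping matches
    (as KMP reports them); pattern_len is assumed to be at least 1.
    """
    out = []      # plays the role of the StringBuilder
    pointer = 0   # positionPointer: next candidate in positions
    i = 0
    while i < len(text):
        if pointer < len(positions) and positions[pointer] == i:
            out.append(replace_with)
            i += pattern_len
            pointer += 1
            # Step 6: discard matches that start inside the span just replaced.
            while pointer < len(positions) and positions[pointer] < i:
                pointer += 1
        else:
            out.append(text[i])
            i += 1
    return "".join(out)


# KMP reports "aa" at 0, 1, 2, 3, 4 in "aaaaaa"; only 0, 2 and 4 are used.
print(apply_replacements("aaaaaa", 2, [0, 1, 2, 3, 4], "X"))  # XXX
```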

Pseudocode

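positions = KMP_SEARCH(text, pattern)    // sorted start indices, may overlap
StringBuilder result
positionPointer = 0
i = 0
while i < length of text:
    if positionPointer < length of positions and positions[positionPointer] == i:
        append replacement to result
        i = i + length of pattern
        positionPointer++
        while positionPointer < length of positions and positions[positionPointer] < i:
            positionPointer++            // drop matches overlapping the one just replaced
    else:
        append text[i] to result
        i = i + 1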

Code

Java
import java.util.ArrayList;
import java.util.List;

public class Solution {
    public String replace(final String text, final String pattern, final String replaceWith) {
        // Nothing to search for: return the text unchanged.
        // An empty replaceWith is valid and simply deletes each match.
        if (text == null || pattern == null || pattern.isEmpty() || replaceWith == null) {
            return text;
        }

        List<Integer> positions = KMPSearch(text, pattern);
        StringBuilder replacedString = new StringBuilder();
        int positionPointer = 0;
        int i = 0;

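        while (i < text.length()) {
            if (positionPointer < positions.size() && i == positions.get(positionPointer)) {
                replacedString.append(replaceWith);
                i += pattern.length();
                positionPointer++;
                // Skip matches that overlap the one just replaced (e.g. "aa" in "aaaaaa").
                while (positionPointer < positions.size() && positions.get(positionPointer) < i) {
                    positionPointer++;
                }
            } else {
                replacedString.append(text.charAt(i));
                i++;
            }
        }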

        return replacedString.toString();
    }

    private List<Integer> KMPSearch(String text, String pattern) {
        int[] lps = computeLPSArray(pattern);
        List<Integer> positions = new ArrayList<>();
        int i = 0;
        int j = 0;

        while (i < text.length()) {
            if (pattern.charAt(j) == text.charAt(i)) {
                i++;
                j++;
            }
            if (j == pattern.length()) {
                positions.add(i - j);
                j = lps[j - 1];
            } else if (i < text.length() && pattern.charAt(j) != text.charAt(i)) {
                if (j != 0) {
                    j = lps[j - 1];
                } else {
                    i++;
                }
            }
        }
        return positions;
    }

    private int[] computeLPSArray(String pattern) {
        int length = 0;
        int i = 1;
        int[] lps = new int[pattern.length()];
        lps[0] = 0;

        while (i < pattern.length()) {
            if (pattern.charAt(i) == pattern.charAt(length)) {
                length++;
                lps[i] = length;
                i++;
            } else {
                if (length != 0) {
                    length = lps[length - 1];
                } else {
                    lps[i] = 0;
                    i++;
                }
            }
        }
        return lps;
    }
}
Python
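from typing import List


class Solution:
    def replace(self, text: str, pattern: str, replace_with: str) -> str:
        # Nothing to search for: return the text unchanged.
        # An empty replace_with is valid and simply deletes each match.
        if not pattern or replace_with is None:
            return text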

        positions = self.kmp_search(text, pattern)
        replaced_string = []
        position_pointer = 0
        i = 0

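        while i < len(text):
            if position_pointer < len(positions) and i == positions[position_pointer]:
                replaced_string.append(replace_with)
                i += len(pattern)
                position_pointer += 1
                # Skip matches that overlap the one just replaced (e.g. "aa" in "aaaaaa").
                while position_pointer < len(positions) and positions[position_pointer] < i:
                    position_pointer += 1
            else:
                replaced_string.append(text[i])
                i += 1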

        return "".join(replaced_string)

    def kmp_search(self, text: str, pattern: str) -> List[int]:
        lps = self.compute_lps_array(pattern)
        positions = []
        i = 0
        j = 0

        while i < len(text):
            if pattern[j] == text[i]:
                i += 1
                j += 1

            if j == len(pattern):
                positions.append(i - j)
                j = lps[j - 1]
            elif i < len(text) and pattern[j] != text[i]:
                if j != 0:
                    j = lps[j - 1]
                else:
                    i += 1
        return positions

    def compute_lps_array(self, pattern: str) -> List[int]:
        length = 0
        i = 1
        lps = [0] * len(pattern)

        while i < len(pattern):
            if pattern[i] == pattern[length]:
                length += 1
                lps[i] = length
                i += 1
            else:
                if length != 0:
                    length = lps[length - 1]
                else:
                    lps[i] = 0
                    i += 1
        return lps

Complexity

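  • ⏰ Time complexity: O(n + m), where n is the length of the text and m the length of the pattern. Building the LPS array takes O(m), the KMP search takes O(n), and a single pass over the text builds the result (appending the replacements adds time proportional to the output length).
  • 🧺 Space complexity: O(n + m) for the positions list and the LPS array, plus the output string.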
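To make the gap between O(n * m) and O(n + m) concrete, the sketch below counts character comparisons made by brute-force search and by KMP search on an adversarial input: a text of 1000 "a" characters and the pattern "aaaaaaaaab". Both counting functions are hypothetical helpers written only for this illustration.

```python
def naive_comparisons(text: str, pattern: str) -> int:
    """Count character comparisons made by brute-force search over every alignment."""
    count = 0
    for start in range(len(text) - len(pattern) + 1):
        for j in range(len(pattern)):
            count += 1
            if text[start + j] != pattern[j]:
                break  # mismatch: move on to the next alignment
    return count


def kmp_comparisons(text: str, pattern: str) -> int:
    """Count text-vs-pattern character comparisons made by KMP search."""
    m = len(pattern)
    # Build the LPS array (longest proper prefix that is also a suffix).
    lps = [0] * m
    k = 0
    for q in range(1, m):
        while k > 0 and pattern[q] != pattern[k]:
            k = lps[k - 1]
        if pattern[q] == pattern[k]:
            k += 1
        lps[q] = k

    count = 0
    k = 0  # pattern characters currently matched
    for ch in text:
        while True:
            count += 1
            if ch == pattern[k]:
                k += 1
                break
            if k == 0:
                break  # no shorter prefix to fall back to
            k = lps[k - 1]  # fall back without re-reading the text
        if k == m:
            k = lps[k - 1]  # full match: keep looking for later matches
    return count


text = "a" * 1000
pattern = "a" * 9 + "b"
print(naive_comparisons(text, pattern))  # 9910
print(kmp_comparisons(text, pattern))    # 1991
```

The brute-force count grows with n * m, while the KMP count stays within 2n. Every KMP comparison either consumes a text character or shortens the current match, and a match can only grow by one character per text character.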