Problem
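Given a text string and a pattern string, replace every occurrence of the pattern in the text with a given replacement string. Occurrences are matched from left to right and do not overlap: once an occurrence is replaced, scanning resumes immediately after it.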
Examples
Example 1:
Input: text = "aabaabab", pattern = "ab", replaceWith = "cd"
Output: "acdacdcd"
Example 2:
Input: text = "aabaabab", pattern = "ab", replaceWith = "abc"
Output: "aabcaabcabc"
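Python's built-in str.replace also replaces non-overlapping occurrences from left to right. It produces the same results as the expected outputs, so it can serve as a reference oracle when testing an implementation. This is only a cross-check, not part of the solution:

```python
# str.replace scans left to right and never replaces overlapping occurrences.
text = "aabaabab"
print(text.replace("ab", "cd"))     # acdacdcd
print(text.replace("ab", "abc"))    # aabcaabcabc
print("aaaaaa".replace("aa", "X"))  # XXX
```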
Solution
Method 1 - Naive
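The straightforward method is a brute-force search. At each index of the text, compare the pattern character by character. If it matches, append the replacement string and skip past the match. Otherwise, copy the current character and move on. Let n be the length of the text and m the length of the pattern. The search then needs O(n·m) comparisons in the worst case, which becomes O(n^2) when m grows with n.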
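Below is a minimal Python sketch of the brute-force approach. The function name naive_replace is illustrative. Treating an empty pattern as "no change" is an assumption:

```python
def naive_replace(text: str, pattern: str, replace_with: str) -> str:
    """Brute-force replacement: at each index, compare the pattern
    character by character (O(n * m) comparisons in the worst case)."""
    if not pattern:
        return text  # assumption: an empty pattern leaves the text unchanged
    n, m = len(text), len(pattern)
    out = []
    i = 0
    while i < n:
        # Count how many pattern characters match starting at index i.
        j = 0
        while j < m and i + j < n and text[i + j] == pattern[j]:
            j += 1
        if j == m:
            out.append(replace_with)
            i += m  # skip the whole match, so occurrences never overlap
        else:
            out.append(text[i])
            i += 1
    return "".join(out)

print(naive_replace("aabaabab", "ab", "cd"))  # acdacdcd
```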
Method 2 - KMP
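To improve the solution, we can use the Knuth-Morris-Pratt (KMP) algorithm to find all occurrences of the pattern in the text. KMP first preprocesses the pattern into an LPS (longest proper prefix that is also a suffix) array. It then scans the text without ever moving backward. As a result, the search runs in O(n + m) time instead of O(n·m). I have discussed the KMP algorithm in detail in another post.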
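For reference, here is a compact Python sketch of KMP search. It uses the common for-loop form rather than the while-loop form in the Code section. The names compute_lps and kmp_search are illustrative. Note that it reports overlapping matches, which the replacement step has to handle:

```python
from typing import List


def compute_lps(pattern: str) -> List[int]:
    """lps[i] = length of the longest proper prefix of pattern[:i + 1]
    that is also a suffix of it."""
    lps = [0] * len(pattern)
    length = 0
    for i in range(1, len(pattern)):
        while length > 0 and pattern[i] != pattern[length]:
            length = lps[length - 1]  # fall back to a shorter border
        if pattern[i] == pattern[length]:
            length += 1
        lps[i] = length
    return lps


def kmp_search(text: str, pattern: str) -> List[int]:
    """Return the start index of every match, including overlapping ones."""
    if not pattern:
        return []
    lps = compute_lps(pattern)
    positions = []
    j = 0  # number of pattern characters currently matched
    for i, ch in enumerate(text):
        while j > 0 and ch != pattern[j]:
            j = lps[j - 1]
        if ch == pattern[j]:
            j += 1
        if j == len(pattern):
            positions.append(i - j + 1)
            j = lps[j - 1]  # keep going so overlapping matches are found
    return positions


print(compute_lps("aabaaab"))        # [0, 1, 0, 1, 2, 2, 3]
print(kmp_search("aabaabab", "ab"))  # [1, 4, 6]
print(kmp_search("aaaaaa", "aa"))    # [0, 1, 2, 3, 4]
```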
Here’s how we can improve the solution:
- Use the KMP algorithm to find the start index of every occurrence of the pattern in the text, which takes O(n + m) time.
- Store these start indices, in increasing order, in a list (positions).
- Use a StringBuilder to construct the result string.
- Keep a pointer (positionPointer) to the next candidate entry in positions.
- Traverse the text from left to right:
  - If the current index equals positions[positionPointer], append the replacement string, skip ahead by the pattern length, and advance positionPointer.
  - Otherwise, append the current character.
- Before checking for a match, advance positionPointer past every position that lies before the current index. This is necessary because KMP also reports overlapping matches. For example, in "aaaaaa" with pattern "aa", KMP reports matches at 0, 1, 2, 3 and 4. After the match at 0 is replaced, the scan jumps to index 2, so the match at 1 must be skipped. Replacing "aa" with "X" then gives "XXX".
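The replacement pass, including the pointer adjustment for overlapping matches, can be sketched on its own in Python. Here apply_replacements is a hypothetical helper that takes the sorted match positions as input:

```python
from typing import List


def apply_replacements(text: str, positions: List[int], pattern_len: int,
                       replace_with: str) -> str:
    """Rebuild text, replacing the match at each usable position.
    positions must be sorted ascending and may include overlapping matches."""
    out = []
    p = 0  # plays the role of positionPointer
    i = 0
    while i < len(text):
        # Skip matches that start inside a region we already replaced.
        while p < len(positions) and positions[p] < i:
            p += 1
        if p < len(positions) and positions[p] == i:
            out.append(replace_with)
            i += pattern_len
            p += 1
        else:
            out.append(text[i])
            i += 1
    return "".join(out)


# Overlapping matches of "aa" in "aaaaaa" start at 0, 1, 2, 3 and 4.
print(apply_replacements("aaaaaa", [0, 1, 2, 3, 4], 2, "X"))  # XXX
```

Without the inner skip loop, the scan would stall on the stale position 1 and copy the remaining characters unchanged.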
Pseudocode
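positions = KMP_SEARCH(text, pattern)    // sorted start indices, may overlap
result = empty StringBuilder
positionPointer = 0
i = 0
while i < length of text:
    // skip matches that start inside an occurrence already replaced
    while positionPointer < length of positions and positions[positionPointer] < i:
        positionPointer++
    if positionPointer < length of positions and positions[positionPointer] == i:
        append replaceWith to result
        i += length of pattern
        positionPointer++
    else:
        append text[i] to result
        i++
return result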
Code
Java
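import java.util.ArrayList;
import java.util.List;

public class Solution {
    // Replaces every non-overlapping occurrence of pattern in text, scanning left to right.
    public String replace(final String text, final String pattern, final String replaceWith) {
        // Nothing to search for, or no replacement given: return the text unchanged.
        // An empty replaceWith is allowed and simply deletes each occurrence.
        if (text == null || pattern == null || pattern.isEmpty() || replaceWith == null) {
            return text;
        }
        // Start indices of all matches in increasing order; KMP also reports overlapping ones.
        List<Integer> positions = kmpSearch(text, pattern);
        StringBuilder replacedString = new StringBuilder();
        int positionPointer = 0;
        int i = 0;
        while (i < text.length()) {
            // Skip matches that start inside an occurrence we have already replaced.
            while (positionPointer < positions.size() && positions.get(positionPointer) < i) {
                positionPointer++;
            }
            if (positionPointer < positions.size() && i == positions.get(positionPointer)) {
                replacedString.append(replaceWith);
                i += pattern.length();
                positionPointer++;
            } else {
                replacedString.append(text.charAt(i));
                i++;
            }
        }
        return replacedString.toString();
    }

    // KMP search: returns the start index of every (possibly overlapping) match.
    private List<Integer> kmpSearch(String text, String pattern) {
        int[] lps = computeLPSArray(pattern);
        List<Integer> positions = new ArrayList<>();
        int i = 0; // index into text
        int j = 0; // number of pattern characters matched so far
        while (i < text.length()) {
            if (pattern.charAt(j) == text.charAt(i)) {
                i++;
                j++;
            }
            if (j == pattern.length()) {
                positions.add(i - j);
                j = lps[j - 1]; // keep searching for overlapping matches
            } else if (i < text.length() && pattern.charAt(j) != text.charAt(i)) {
                if (j != 0) {
                    j = lps[j - 1]; // fall back in the pattern without moving i
                } else {
                    i++;
                }
            }
        }
        return positions;
    }

    // lps[k] = length of the longest proper prefix of pattern[0..k] that is also a suffix of it.
    private int[] computeLPSArray(String pattern) {
        int length = 0;
        int i = 1;
        int[] lps = new int[pattern.length()];
        lps[0] = 0;
        while (i < pattern.length()) {
            if (pattern.charAt(i) == pattern.charAt(length)) {
                length++;
                lps[i] = length;
                i++;
            } else {
                if (length != 0) {
                    length = lps[length - 1];
                } else {
                    lps[i] = 0;
                    i++;
                }
            }
        }
        return lps;
    }
}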
Python
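from typing import List


class Solution:
    def replace(self, text: str, pattern: str, replace_with: str) -> str:
        # Nothing to search for, or no replacement given: return the text unchanged.
        # An empty replace_with is allowed and simply deletes each occurrence.
        if not pattern or replace_with is None:
            return text
        # Start indices of all matches in increasing order; KMP also reports overlapping ones.
        positions = self.kmp_search(text, pattern)
        replaced_string = []
        position_pointer = 0
        i = 0
        while i < len(text):
            # Skip matches that start inside an occurrence already replaced.
            while position_pointer < len(positions) and positions[position_pointer] < i:
                position_pointer += 1
            if position_pointer < len(positions) and i == positions[position_pointer]:
                replaced_string.append(replace_with)
                i += len(pattern)
                position_pointer += 1
            else:
                replaced_string.append(text[i])
                i += 1
        return "".join(replaced_string)

    def kmp_search(self, text: str, pattern: str) -> List[int]:
        # Returns the start index of every (possibly overlapping) match.
        lps = self.compute_lps_array(pattern)
        positions = []
        i = 0  # index into text
        j = 0  # number of pattern characters matched so far
        while i < len(text):
            if pattern[j] == text[i]:
                i += 1
                j += 1
            if j == len(pattern):
                positions.append(i - j)
                j = lps[j - 1]  # keep searching for overlapping matches
            elif i < len(text) and pattern[j] != text[i]:
                if j != 0:
                    j = lps[j - 1]  # fall back in the pattern without moving i
                else:
                    i += 1
        return positions

    def compute_lps_array(self, pattern: str) -> List[int]:
        # lps[k] = length of the longest proper prefix of pattern[:k + 1] that is also a suffix of it.
        length = 0
        i = 1
        lps = [0] * len(pattern)
        while i < len(pattern):
            if pattern[i] == pattern[length]:
                length += 1
                lps[i] = length
                i += 1
            else:
                if length != 0:
                    length = lps[length - 1]
                else:
                    lps[i] = 0
                    i += 1
        return lps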
Complexity
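- ⏰ Time complexity: O(n + m) for the search, where n is the length of the text and m is the length of the pattern. This breaks down into O(m) to build the LPS array and O(n) for the KMP scan. The rebuilding pass adds time proportional to the length of the result, since each replacement string is copied into it.
- 🧺 Space complexity: O(n + m) auxiliary space for the list of match positions and the LPS array, plus space for the result string.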
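The sketch below makes the gap between the O(n·m) brute-force bound and KMP's linear search concrete by counting character comparisons on a worst-case style input. It is a rough illustration under simple assumptions: only comparisons made during the search are counted, and LPS preprocessing is not. The function names are hypothetical:

```python
from typing import List


def compute_lps(pattern: str) -> List[int]:
    # Standard LPS (failure function) construction; not counted below.
    lps = [0] * len(pattern)
    length = 0
    for i in range(1, len(pattern)):
        while length > 0 and pattern[i] != pattern[length]:
            length = lps[length - 1]
        if pattern[i] == pattern[length]:
            length += 1
        lps[i] = length
    return lps


def naive_comparisons(text: str, pattern: str) -> int:
    """Character comparisons made by brute-force search over every shift."""
    count = 0
    for s in range(len(text) - len(pattern) + 1):
        for j in range(len(pattern)):
            count += 1
            if text[s + j] != pattern[j]:
                break
    return count


def kmp_comparisons(text: str, pattern: str) -> int:
    """Character comparisons made by the KMP search phase."""
    lps = compute_lps(pattern)
    count = 0
    j = 0
    for ch in text:
        while True:
            count += 1
            if ch == pattern[j]:
                j += 1
                break
            if j == 0:
                break
            j = lps[j - 1]  # fall back and compare the same text character again
        if j == len(pattern):
            j = lps[j - 1]  # full match found; continue for overlapping matches
    return count


text = "a" * 1000
pattern = "a" * 9 + "b"
print(naive_comparisons(text, pattern))  # 9910
print(kmp_comparisons(text, pattern))    # 1991
```

On this input brute force makes 9,910 comparisons and KMP makes 1,991. In general, KMP's search phase makes at most 2n comparisons. Each comparison either advances past a text character or shrinks j, and j can shrink no more than it has grown.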