Problem

Design a map that allows you to do the following:

  • Maps a string key to a given value.
  • Returns the sum of the values that have a key with a prefix equal to a given string.

Implement the MapSum class:

  • MapSum() Initializes the MapSum object.
  • void insert(String key, int val) Inserts the key-val pair into the map. If the key already exists, its original value is overridden with the new one.
  • int sum(String prefix) Returns the sum of the values of all pairs whose key starts with the prefix.
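To pin down these semantics before any optimisation, here is a brute-force baseline sketch in Python. The class name `BruteForceMapSum` is hypothetical and is not part of the problem. It keeps the pairs in a plain dict, so overwriting a key simply replaces its value. `sum` scans every key with `str.startswith`, which costs O(n · p) per query for n keys and a prefix of length p.

```python
from typing import Dict


class BruteForceMapSum:
    """Hypothetical brute-force baseline: no trie, just a scan over a dict."""

    def __init__(self) -> None:
        self.pairs: Dict[str, int] = {}

    def insert(self, key: str, val: int) -> None:
        # Assigning to an existing key overrides its old value (no accumulation).
        self.pairs[key] = val

    def sum(self, prefix: str) -> int:
        # Inside the method body, `sum` still refers to the builtin:
        # class-level names are not visible from method scopes.
        return sum(v for k, v in self.pairs.items() if k.startswith(prefix))


m = BruteForceMapSum()
m.insert("apple", 3)
m.insert("apple", 5)  # override, not 3 + 5
print(m.sum("ap"))    # 5
```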

Examples

Example 1:

**Input**
["MapSum", "insert", "sum", "insert", "sum"]
[[], ["apple", 3], ["ap"], ["app", 2], ["ap"]]
**Output**
[null, null, 3, null, 5]

**Explanation**
MapSum mapSum = new MapSum();
mapSum.insert("apple", 3);
mapSum.sum("ap");           // return 3 (apple = 3)
mapSum.insert("app", 2);
mapSum.sum("ap");           // return 5 (apple + app = 3 + 2 = 5)
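The Python sketch below follows the definition directly. It lists, for every prefix of "apple", the total of the values whose key starts with that prefix, after each insert in Example 1. The helper name `prefix_totals` is hypothetical. These are the totals a prefix-sum trie would keep on the nodes along the path a → p → p → l → e.

```python
from typing import Dict


def prefix_totals(pairs: Dict[str, int], word: str) -> Dict[str, int]:
    """For each non-empty prefix of word, sum the values of keys starting with it."""
    return {
        word[:i]: sum(v for k, v in pairs.items() if k.startswith(word[:i]))
        for i in range(1, len(word) + 1)
    }


pairs: Dict[str, int] = {}
pairs["apple"] = 3
print(prefix_totals(pairs, "apple"))
# {'a': 3, 'ap': 3, 'app': 3, 'appl': 3, 'apple': 3}
pairs["app"] = 2
print(prefix_totals(pairs, "apple"))
# {'a': 5, 'ap': 5, 'app': 5, 'appl': 3, 'apple': 3}
```

The call sum("ap") reads the `ap` entry, which is 3 after the first insert and 5 after the second. This matches the expected output.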

Constraints:

  • 1 <= key.length, prefix.length <= 50
  • key and prefix consist of only lowercase English letters.
  • 1 <= val <= 1000
  • At most 50 calls will be made to insert and sum.
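These limits also bound the numbers involved, which is why a 32-bit int is enough for every sum in the Java solution. At most 50 distinct keys can be stored, and each contributes its current value of at most 1000. Each insert creates at most one trie node per character of its key.

```latex
\text{sum}(\text{prefix}) \le \#\text{keys} \cdot \max(\text{val}) \le 50 \cdot 1000 = 50{,}000 < 2^{31} - 1
\qquad
\#\text{trie nodes} \le \#\text{inserts} \cdot \max(\text{key.length}) \le 50 \cdot 50 = 2{,}500
```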

Solution

Method 1 - Using a HashMap and a Trie with prefix sums

To solve the problem, we need to design a class that supports two primary operations: inserting a key-value pair and calculating the sum of all values associated with keys that share a given prefix.

  1. Data Structure:
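    • Use a HashMap to store each key's current value.
    • Use a Trie (prefix tree) in which every node stores the sum of the values of all keys that start with the prefix the node represents.
  2. Insert Operation:
    • Compute diff = val - oldVal, where oldVal is the key's current value in the HashMap (0 if the key is new), then store val in the HashMap.
    • Walk the key's characters through the Trie, creating missing nodes and adding diff to the sum of every node on the path. Adding the difference rather than val makes an overwrite replace the key's old contribution instead of double-counting it.
  3. Sum Operation:
    • Walk the Trie along the characters of the prefix. If a character is missing, no key has that prefix, so return 0. Otherwise, return the sum stored at the node reached by the prefix's last character.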
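The Python sketch below isolates the diff update by swapping the trie for a plain dict keyed by every prefix string of each inserted key. The class name `PrefixDictMapSum` is hypothetical. This is a different technique from the trie described here, not the article's implementation. Each insert builds k prefix strings, so it costs O(k²) rather than O(k), while `sum` becomes a single dictionary lookup.

```python
from collections import defaultdict
from typing import DefaultDict, Dict


class PrefixDictMapSum:
    """Hypothetical variant: a dict from every prefix to its running total."""

    def __init__(self) -> None:
        self.values: Dict[str, int] = {}                        # key -> current value
        self.totals: DefaultDict[str, int] = defaultdict(int)  # prefix -> total

    def insert(self, key: str, val: int) -> None:
        # Apply only the change, so an overwritten key is not counted twice.
        diff = val - self.values.get(key, 0)
        self.values[key] = val
        for i in range(1, len(key) + 1):
            self.totals[key[:i]] += diff

    def sum(self, prefix: str) -> int:
        # .get avoids inserting unseen prefixes into the defaultdict.
        return self.totals.get(prefix, 0)


m = PrefixDictMapSum()
m.insert("apple", 3)
m.insert("app", 2)
m.insert("apple", 1)  # overwrite: diff = -2 on a, ap, app, appl, apple
print(m.sum("ap"))    # 3 (apple 1 + app 2)
```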

Code

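import java.util.HashMap;
import java.util.Map;

class Solution {

    static class MapSum {

        // key -> current value, used to compute the delta when a key is overwritten
        private final Map<String, Integer> map;
        // Root of the Trie; each node's sum covers all keys with that node's prefix
        private final TrieNode trie;

        public MapSum() {
            map = new HashMap<>();
            trie = new TrieNode();
        }

        public void insert(String key, int val) {
            // Only the change is propagated, so overwrites are not double-counted
            int diff = val - map.getOrDefault(key, 0);
            map.put(key, val);

            TrieNode node = trie;
            for (char ch : key.toCharArray()) {
                // Create the child on first use, then update its prefix sum
                node = node.children.computeIfAbsent(ch, c -> new TrieNode());
                node.sum += diff;
            }
        }

        public int sum(String prefix) {
            TrieNode node = trie;
            for (char ch : prefix.toCharArray()) {
                if (!node.children.containsKey(ch)) {
                    return 0; // no stored key starts with this prefix
                }
                node = node.children.get(ch);
            }
            return node.sum;
        }

        static class TrieNode {

            Map<Character, TrieNode> children;
            int sum; // sum of the values of all keys passing through this node

            TrieNode() {
                children = new HashMap<>();
                sum = 0;
            }
        }
    }
}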
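from typing import Any, Dict


class Solution:
    class MapSum:
        def __init__(self) -> None:
            # key -> current value, used to compute the delta on overwrite
            self.map: Dict[str, int] = {}
            # Trie node layout: {"sum": total under this prefix, "next": {char: node}}
            self.trie: Dict[str, Any] = {"sum": 0, "next": {}}

        def insert(self, key: str, val: int) -> None:
            # Propagate only the change so an overwritten key is not double-counted
            diff = val - self.map.get(key, 0)
            self.map[key] = val
            node = self.trie
            for ch in key:
                children = node["next"]
                if ch not in children:
                    children[ch] = {"sum": 0, "next": {}}
                node = children[ch]
                node["sum"] += diff

        def sum(self, prefix: str) -> int:
            node = self.trie
            for ch in prefix:
                children = node["next"]
                if ch not in children:
                    return 0  # no stored key starts with this prefix
                node = children[ch]
            # The node for the prefix's last character holds the answer
            return node["sum"]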

Complexity

  • ⏰ Time complexity:
    • insert: O(k) where k is the length of the key being inserted.
    • sum: O(p) where p is the length of the prefix provided.
  • 🧺 Space complexity:
    • O(n * k) overall, where n is the number of distinct keys and k is the average key length. The Trie has at most one node per character of the inserted keys, and the HashMap stores each key once.
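Keys that share a prefix also share trie nodes, so the O(n * k) figure is an upper bound. The trie holds exactly one non-root node per distinct non-empty prefix of the inserted keys. The short Python check below counts those prefixes for the keys in Example 1. The function name `trie_node_count` is hypothetical.

```python
from typing import Iterable


def trie_node_count(keys: Iterable[str]) -> int:
    """Number of non-root trie nodes = number of distinct non-empty prefixes."""
    return len({key[:i] for key in keys for i in range(1, len(key) + 1)})


keys = ["apple", "app"]
print(trie_node_count(keys))      # 5, since "app" reuses a -> p -> p
print(sum(len(k) for k in keys))  # 8, the n * k style upper bound
```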