Problem

Implement a SnapshotArray that supports the following interface:

  • SnapshotArray(int length) initializes an array-like data structure with the given length. Initially, each element equals 0.
  • void set(index, val) sets the element at the given index to be equal to val.
  • int snap() takes a snapshot of the array and returns the snap_id: the total number of times we called snap() minus 1.
  • int get(index, snap_id) returns the value at the given index at the time the snapshot with the given snap_id was taken.
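Before optimizing, it helps to pin down these semantics with a brute-force model. The sketch below is not from the original solution: the class name `NaiveSnapshotArray` is hypothetical, and it copies the whole array on every `snap()`, so each snap costs O(length) time and memory. Its main use is as a reference to check a faster implementation against.

```python
class NaiveSnapshotArray:
    """Hypothetical reference model: copies the entire array on every snap."""

    def __init__(self, length: int) -> None:
        self.current = [0] * length             # live, mutable values (all start at 0)
        self.snapshots: list[list[int]] = []    # snapshots[snap_id] is a frozen copy

    def set(self, index: int, val: int) -> None:
        self.current[index] = val

    def snap(self) -> int:
        self.snapshots.append(self.current.copy())  # O(length) copy
        return len(self.snapshots) - 1              # number of snap() calls minus 1

    def get(self, index: int, snap_id: int) -> int:
        return self.snapshots[snap_id][index]
```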

Examples

Example 1:

Input:
["SnapshotArray","set","snap","set","get"]
[[3],[0,5],[],[0,6],[0,0]]
Output:
[null,null,0,null,5]
Explanation: 
SnapshotArray snapshotArr = new SnapshotArray(3); // set the length to be 3
snapshotArr.set(0,5);  // Set array[0] = 5
snapshotArr.snap();  // Take a snapshot, return snap_id = 0
snapshotArr.set(0,6);
snapshotArr.get(0,0);  // Get the value of array[0] with snap_id = 0, return 5
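The Input/Output lines follow the LeetCode convention: the first list names the calls, the second holds their arguments, and calls with no return value produce null. A small replay driver makes that mapping concrete. This is an illustrative sketch: `replay` is a hypothetical helper, and the compact `SnapshotArray` inside it is a stand-in built on the per-index history idea described under Solution.

```python
import bisect
from typing import Any


class SnapshotArray:
    """Compact stand-in: per-index list of (snap_id, value) changes."""

    def __init__(self, length: int) -> None:
        # Sentinel (-1, 0) per index so get() always finds an entry (initial value 0).
        self.history = [[(-1, 0)] for _ in range(length)]
        self.snap_id = 0

    def set(self, index: int, val: int) -> None:
        self.history[index].append((self.snap_id, val))

    def snap(self) -> int:
        self.snap_id += 1
        return self.snap_id - 1

    def get(self, index: int, snap_id: int) -> int:
        h = self.history[index]
        # Last change recorded at or before snap_id.
        i = bisect.bisect_right(h, (snap_id, float("inf"))) - 1
        return h[i][1]


def replay(ops: list[str], args: list[list[int]]) -> list[Any]:
    """Hypothetical driver: run parallel operation/argument lists, collect results."""
    obj = SnapshotArray(*args[0])      # first operation is always the constructor
    out: list[Any] = [None]            # the constructor contributes null
    for op, a in zip(ops[1:], args[1:]):
        out.append(getattr(obj, op)(*a))  # set() yields None; snap()/get() yield ints
    return out


print(replay(["SnapshotArray", "set", "snap", "set", "get"],
             [[3], [0, 5], [], [0, 6], [0, 0]]))
# prints [None, None, 0, None, 5]
```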

Solution

Intuition

To efficiently support snapshots and queries, we store a history of value changes for each index, along with the snapshot id when the change occurred. This allows us to quickly retrieve the value at any index for any snapshot using binary search.
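For a single index, the history is just a short list sorted by snap id, and a snap in which the index is not written adds nothing to it. The sketch below uses Python's `bisect` module on a hypothetical history to answer every snapshot from only two entries; `value_at` is an illustrative name, not part of the original solution.

```python
import bisect

# Hypothetical history for one index: set to 5 before snap 0, then to 6 before snap 3.
# Snaps 1 and 2 saw no write to this index, so they add no entries.
history = [(0, 5), (3, 6)]


def value_at(history: list[tuple[int, int]], snap_id: int) -> int:
    """Value at snap_id: the latest change whose snap id is <= snap_id, else 0."""
    # (snap_id, inf) sorts after every pair recorded at snap_id, so bisect_right
    # lands just past the last change made at or before snap_id.
    i = bisect.bisect_right(history, (snap_id, float("inf")))
    return history[i - 1][1] if i > 0 else 0  # no change yet: initial value 0


print([value_at(history, s) for s in range(5)])  # [5, 5, 5, 6, 6]
print(value_at([], 2))                           # 0 (index never set)
```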

Approach

  1. For each index, maintain a list of (snap_id, value) pairs representing changes.
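  2. On set, append (current snap_id, val) to that index's list. Because snap_id never decreases, each list stays sorted by snap_id, which is what makes binary search possible.
  3. On snap, return the current snap_id and then increment it.
  4. On get, use binary search to find the last change at the index whose snap_id is at or before the requested snap_id; if there is none, the value is still the initial 0.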
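Step 4 can also be written without a library call. The hand-rolled upper-bound search below mirrors the loop in the C++ solution: it keeps moving right while entries are at or before the target, so when one snap holds several writes to the same index, the last write wins. A minimal Python sketch; the helper name `last_value_at_or_before` is hypothetical.

```python
def last_value_at_or_before(history: list[tuple[int, int]], snap_id: int) -> int:
    """Hypothetical helper: value of the last (snap, value) entry with snap <= snap_id."""
    lo, hi = 0, len(history)          # search window is [lo, hi)
    while lo < hi:
        mid = (lo + hi) // 2
        if history[mid][0] <= snap_id:
            lo = mid + 1              # mid is visible at snap_id; look further right
        else:
            hi = mid                  # mid is too new; drop it and everything after
    # lo is now the first entry recorded after snap_id.
    return history[lo - 1][1] if lo > 0 else 0


h = [(0, 5), (0, 7), (1, 6)]          # two writes during snap 0, one during snap 1
print(last_value_at_or_before(h, 0))  # 7 (last write in snap 0 wins)
print(last_value_at_or_before(h, 1))  # 6
```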

Complexity

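  • ⏰ Time complexity: O(log k) for get, where k is the number of changes recorded at that index. snap is O(1); set is amortized O(1) with the append-only lists (C++, Python) and O(log k) with the TreeMap (Java).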
  • 🧺 Space complexity: O(n + m), where n is the array length and m is the total number of set operations.
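With the append-only scheme, m counts every set call, so repeatedly setting the same index between two snaps keeps growing its history. A common refinement, not part of the original C++ and Python solutions (the Java TreeMap gets it for free because put overwrites), is to overwrite the last entry when it already belongs to the current snap. That bounds each history by the number of distinct snaps in which the index was written. A minimal Python sketch; the class name `CompactSnapshotArray` is hypothetical.

```python
import bisect


class CompactSnapshotArray:
    """Hypothetical variant: at most one history entry per (index, snap) pair."""

    def __init__(self, length: int) -> None:
        # Sentinel (-1, 0) per index so get() always finds an entry.
        self.history: list[list[tuple[int, int]]] = [[(-1, 0)] for _ in range(length)]
        self.snap_id = 0

    def set(self, index: int, val: int) -> None:
        h = self.history[index]
        if h[-1][0] == self.snap_id:
            h[-1] = (self.snap_id, val)    # same snap: overwrite instead of appending
        else:
            h.append((self.snap_id, val))  # first write to this index in this snap

    def snap(self) -> int:
        self.snap_id += 1
        return self.snap_id - 1

    def get(self, index: int, snap_id: int) -> int:
        h = self.history[index]
        i = bisect.bisect_right(h, (snap_id, float("inf"))) - 1
        return h[i][1]


a = CompactSnapshotArray(2)
for v in range(1000):
    a.set(0, v)              # 1000 writes in snap 0 leave a single entry
print(len(a.history[0]))     # 2 (sentinel + one entry)
```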

Code

C++
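#include <vector>
#include <utility>
using namespace std;

class SnapshotArray {
    // arr[i] holds (snap_id, value) pairs for index i, in non-decreasing snap_id order.
    vector<vector<pair<int, int>>> arr;
    int snap_id = 0;  // id that the next snap() call will return
public:
    SnapshotArray(int length) : arr(length) {}

    void set(int index, int val) {
        arr[index].emplace_back(snap_id, val);  // record the write under the current snap
    }

    int snap() {
        return snap_id++;  // return the current id, then advance
    }

    // The parameter shadows the member: here snap_id is the id being queried.
    int get(int index, int snap_id) {
        const auto& v = arr[index];
        // Cast before subtracting: v.size() is unsigned, so v.size() - 1 wraps when v is empty.
        int l = 0, r = static_cast<int>(v.size()) - 1, ans = 0;  // ans stays 0 if never set
        while (l <= r) {
            int m = l + (r - l) / 2;
            if (v[m].first <= snap_id) {
                ans = v[m].second;  // visible at snap_id; a later match overrides it
                l = m + 1;
            } else {
                r = m - 1;
            }
        }
        return ans;
    }
};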
Java
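import java.util.*;

class SnapshotArray {
    // One TreeMap per index: snap id -> last value written during that snap.
    private final List<TreeMap<Integer, Integer>> arr;
    private int snapId = 0;  // id that the next snap() call will return

    public SnapshotArray(int length) {
        arr = new ArrayList<>(length);
        for (int i = 0; i < length; i++) arr.add(new TreeMap<>());
    }

    public void set(int index, int val) {
        arr.get(index).put(snapId, val);  // repeated sets within one snap overwrite
    }

    public int snap() {
        return snapId++;
    }

    public int get(int index, int snap_id) {
        // floorEntry returns the greatest key <= snap_id: the latest write visible at that snap.
        Map.Entry<Integer, Integer> entry = arr.get(index).floorEntry(snap_id);
        return entry == null ? 0 : entry.getValue();  // never written: initial value 0
    }
}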
Python
import bisect

class SnapshotArray:
    def __init__(self, length: int) -> None:
        # Seed each index with (0, 0) so every lookup finds an entry at or before snap_id.
        self.arr: list[list[tuple[int, int]]] = [[(0, 0)] for _ in range(length)]
        self.snap_id: int = 0

    def set(self, index: int, val: int) -> None:
        self.arr[index].append((self.snap_id, val))

    def snap(self) -> int:
        self.snap_id += 1
        return self.snap_id - 1

    def get(self, index: int, snap_id: int) -> int:
        v = self.arr[index]
        # (snap_id, inf) sorts after every entry recorded at snap_id, so
        # bisect_right(...) - 1 is the last change made at or before snap_id.
        i = bisect.bisect_right(v, (snap_id, float('inf'))) - 1
        return v[i][1]