Problem
You are given a starting state start, a list of transition probabilities for a Markov chain, and a number of steps num_steps. Run the Markov chain from start for num_steps steps and count how many times each state is visited. The starting state counts as the first visit, so the counts add up to num_steps.
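The problem assumes the list describes a valid Markov chain, meaning the outgoing probabilities of every source state add up to 1. A minimal sketch of that check, assuming transitions arrive as (from, to, probability) tuples as in Example 1; the helper name `check_rows_sum_to_one` and the tolerance are illustrative assumptions, not part of the original solution:

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def check_rows_sum_to_one(transitions: List[Tuple[str, str, float]]) -> Dict[str, float]:
    """Return each source state's total outgoing probability.

    Hypothetical helper: raises ValueError if any state's total is not ~1.
    """
    totals: Dict[str, float] = defaultdict(float)
    for from_state, _to_state, prob in transitions:
        totals[from_state] += prob
    for state, total in totals.items():
        # Compare with a tolerance: sums of decimal fractions may not be exactly 1.0 in floating point
        if not math.isclose(total, 1.0, abs_tol=1e-9):
            raise ValueError(f"outgoing probabilities of {state!r} sum to {total}, not 1")
    return dict(totals)


# The transitions from Example 1
example = [
    ('a', 'a', 0.9), ('a', 'b', 0.075), ('a', 'c', 0.025),
    ('b', 'a', 0.15), ('b', 'b', 0.8), ('b', 'c', 0.05),
    ('c', 'a', 0.25), ('c', 'b', 0.25), ('c', 'c', 0.5),
]
print(check_rows_sum_to_one(example))
```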
Examples
Example 1
Input:
start = 'a',
num_steps = 5000,
transitions = [
('a', 'a', 0.9),
('a', 'b', 0.075),
('a', 'c', 0.025),
('b', 'a', 0.15),
('b', 'b', 0.8),
('b', 'c', 0.05),
('c', 'a', 0.25),
('c', 'b', 0.25),
('c', 'c', 0.5)
]
Output: { 'a': 3012, 'b': 1656, 'c': 332 }
Explanation: One run of the Markov chain for 5000 steps starting at 'a' produced these visit counts. Because each step is random, a different run gives slightly different counts.
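Over a long run, the fraction of time spent in each state approaches the chain's stationary distribution π, the distribution that satisfies π = πP for the transition matrix P. Solving those equations by hand for this chain gives π = (0.625, 0.3125, 0.0625) for (a, b, c), or roughly 3125, 1562.5 and 312.5 of 5000 visits, which is close to the sample output. As a sanity check that is not part of the original solution, the sketch below approximates π numerically by power iteration (pushing a probability distribution through the chain repeatedly); the helper name `stationary_distribution` and the default of 500 iterations are illustrative assumptions.

```python
from typing import Dict, List, Tuple


def stationary_distribution(transitions: List[Tuple[str, str, float]],
                            iterations: int = 500) -> Dict[str, float]:
    """Approximate the stationary distribution pi (pi = pi P) by power iteration."""
    states = sorted({s for f, t, _ in transitions for s in (f, t)})
    # Start with all probability mass on the first state
    pi = {s: 0.0 for s in states}
    pi[states[0]] = 1.0
    for _ in range(iterations):
        nxt = {s: 0.0 for s in states}
        # One step of the chain: mass at f flows to t with probability p
        for f, t, p in transitions:
            nxt[t] += pi[f] * p
        pi = nxt
    return pi


example = [
    ('a', 'a', 0.9), ('a', 'b', 0.075), ('a', 'c', 0.025),
    ('b', 'a', 0.15), ('b', 'b', 0.8), ('b', 'c', 0.05),
    ('c', 'a', 0.25), ('c', 'b', 0.25), ('c', 'c', 0.5),
]
pi = stationary_distribution(example)
print({s: round(v, 4) for s, v in pi.items()})
```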
Solution
Method 1 - Run the simulation
To solve this problem:
- Build a Transition Dictionary: Convert the list of transitions into a dictionary for easy lookup of transition probabilities based on the current state.
- Simulate the Markov Chain: Start at start; at each step, draw a uniform random number in [0, 1) and move along the first outgoing transition whose cumulative probability exceeds that number.
- Count Visits: Maintain a count of visits to each state as the simulation progresses.
- Return Results: After running the chain for the given number of steps, return the visit counts.
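The "determine the next state" step is an instance of cumulative (inverse-CDF) sampling: draw u uniformly from [0, 1) and pick the first outgoing transition whose running probability total exceeds u. A minimal sketch of the first two steps, with u passed in explicitly so the choice is deterministic and easy to check; the names `build_transition_map` and `pick_next` are hypothetical:

```python
from typing import Dict, List, Tuple


def build_transition_map(transitions: List[Tuple[str, str, float]]) -> Dict[str, List[Tuple[str, float]]]:
    """Group (from, to, prob) triples by source state."""
    transition_map: Dict[str, List[Tuple[str, float]]] = {}
    for from_state, to_state, prob in transitions:
        transition_map.setdefault(from_state, []).append((to_state, prob))
    return transition_map


def pick_next(options: List[Tuple[str, float]], u: float) -> str:
    """Return the first target whose cumulative probability exceeds u (0 <= u < 1)."""
    cumulative = 0.0
    for to_state, prob in options:
        cumulative += prob
        if u < cumulative:
            return to_state
    # Rounding can leave the total just below 1; fall back to the last target
    return options[-1][0]


tmap = build_transition_map([('a', 'a', 0.9), ('a', 'b', 0.075), ('a', 'c', 0.025)])
print(pick_next(tmap['a'], 0.5))   # a  (0.5 < 0.9)
print(pick_next(tmap['a'], 0.95))  # b  (0.9 <= 0.95 < 0.975)
print(pick_next(tmap['a'], 0.99))  # c  (falls in the last 0.025 slice)
```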
Code
Java
import java.util.*;

public class Solution {
    public Map<String, Integer> runMarkovChain(String start, int numSteps, List<Transition> transitions) {
        // Build transition map: source state -> its outgoing transitions
        Map<String, List<Transition>> transitionMap = new HashMap<>();
        for (Transition transition : transitions) {
            transitionMap.computeIfAbsent(transition.from, k -> new ArrayList<>()).add(transition);
        }
        // Initialize visit counts to 0 for every state mentioned in the transitions
        Map<String, Integer> visitCounts = new HashMap<>();
        for (Transition transition : transitions) {
            visitCounts.put(transition.from, 0);
            visitCounts.put(transition.to, 0);
        }
        // Simulate the Markov chain
        Random random = new Random();
        String currentState = start;
        for (int step = 0; step < numSteps; step++) {
            // Count the visit to the current state (merge also covers a start state missing from the list)
            visitCounts.merge(currentState, 1, Integer::sum);
            List<Transition> currentTransitions = transitionMap.get(currentState);
            if (currentTransitions == null) {
                continue; // no outgoing transitions: stay in place (assumption; avoids a NullPointerException)
            }
            double rand = random.nextDouble(); // uniform in [0, 1)
            double cumulativeProbability = 0.0;
            // Default to the last transition in case rounding leaves the total slightly below 1
            String nextState = currentTransitions.get(currentTransitions.size() - 1).to;
            for (Transition transition : currentTransitions) {
                cumulativeProbability += transition.probability;
                if (rand < cumulativeProbability) {
                    nextState = transition.to;
                    break;
                }
            }
            currentState = nextState;
        }
        return visitCounts;
    }

    public static class Transition {
        String from;
        String to;
        double probability;

        public Transition(String from, String to, double probability) {
            this.from = from;
            this.to = to;
            this.probability = probability;
        }
    }

    public static void main(String[] args) {
        Solution sol = new Solution();
        List<Solution.Transition> transitions = Arrays.asList(
            new Solution.Transition("a", "a", 0.9),
            new Solution.Transition("a", "b", 0.075),
            new Solution.Transition("a", "c", 0.025),
            new Solution.Transition("b", "a", 0.15),
            new Solution.Transition("b", "b", 0.8),
            new Solution.Transition("b", "c", 0.05),
            new Solution.Transition("c", "a", 0.25),
            new Solution.Transition("c", "b", 0.25),
            new Solution.Transition("c", "c", 0.5)
        );
        Map<String, Integer> result = sol.runMarkovChain("a", 5000, transitions);
        System.out.println(result);
    }
}
Python
import random
from typing import Dict, List, Tuple


class Solution:
    def runMarkovChain(self, start: str, num_steps: int, transitions: List[Tuple[str, str, float]]) -> Dict[str, int]:
        # Build transition map: source state -> [(target state, probability), ...]
        transition_map: Dict[str, List[Tuple[str, float]]] = {}
        for from_state, to_state, prob in transitions:
            transition_map.setdefault(from_state, []).append((to_state, prob))
        # Initialize visit counts to 0 for every state that appears (including start)
        visit_counts: Dict[str, int] = {start: 0}
        for from_state, to_state, _ in transitions:
            visit_counts[from_state] = 0
            visit_counts[to_state] = 0
        # Simulate the Markov chain
        current_state = start
        for _ in range(num_steps):
            # Count the visit to the current state, then move
            visit_counts[current_state] += 1
            options = transition_map.get(current_state)
            if not options:
                continue  # no outgoing transitions: stay in place (assumption)
            rand = random.random()  # uniform in [0, 1)
            cumulative_probability = 0.0
            # Default to the last target in case rounding leaves the total slightly below 1
            next_state = options[-1][0]
            for to_state, prob in options:
                cumulative_probability += prob
                if rand < cumulative_probability:
                    next_state = to_state
                    break
            current_state = next_state
        return visit_counts
# Example usage:
sol = Solution()
transitions = [
('a', 'a', 0.9),
('a', 'b', 0.075),
('a', 'c', 0.025),
('b', 'a', 0.15),
('b', 'b', 0.8),
('b', 'c', 0.05),
('c', 'a', 0.25),
('c', 'b', 0.25),
('c', 'c', 0.5)
]
result = sol.runMarkovChain('a', 5000, transitions)
print(result)
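The example runs above use the module-level random generator, so the output changes every time. One way to make runs reproducible, shown here as a sketch rather than the method above, is to pass a seeded random.Random instance and let its choices method perform the weighted draw; the function name `run_markov_chain_seeded` and the seed value are illustrative assumptions. Because every step records exactly one visit, the counts always sum to num_steps.

```python
import random
from typing import Dict, List, Optional, Tuple


def run_markov_chain_seeded(start: str, num_steps: int,
                            transitions: List[Tuple[str, str, float]],
                            seed: Optional[int] = None) -> Dict[str, int]:
    rng = random.Random(seed)  # private generator: same seed -> same run
    targets: Dict[str, List[str]] = {}
    weights: Dict[str, List[float]] = {}
    counts: Dict[str, int] = {start: 0}
    for f, t, p in transitions:
        targets.setdefault(f, []).append(t)
        weights.setdefault(f, []).append(p)
        counts[f] = 0
        counts[t] = 0
    state = start
    for _ in range(num_steps):
        counts[state] += 1
        if state not in targets:
            continue  # no outgoing transitions: stay in place (assumption)
        # choices() does the weighted draw; it returns a list of k=1 items
        state = rng.choices(targets[state], weights=weights[state])[0]
    return counts


example = [
    ('a', 'a', 0.9), ('a', 'b', 0.075), ('a', 'c', 0.025),
    ('b', 'a', 0.15), ('b', 'b', 0.8), ('b', 'c', 0.05),
    ('c', 'a', 0.25), ('c', 'b', 0.25), ('c', 'c', 0.5),
]
r1 = run_markov_chain_seeded('a', 5000, example, seed=42)
r2 = run_markov_chain_seeded('a', 5000, example, seed=42)
print(r1 == r2, sum(r1.values()))  # True 5000
```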
Complexity
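- ⏰ Time complexity: $O(E + \text{num\_steps} \cdot T)$, where $E$ is the number of transitions in the input and $T$ is the maximum number of outgoing transitions of any state. Building the transition map and the visit-count map takes $O(E)$, and each simulated step scans at most $T$ outgoing transitions of the current state.
- 🧺 Space complexity: $O(S + E)$, where $S$ is the number of unique states: the visit counts take $O(S)$ and the transition map stores all $E$ transitions.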
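When a state has many outgoing transitions, the linear scan per step can be replaced by a binary search: precompute each state's running probability totals once, then locate the sampled value with bisect, which costs O(log T) per step for a state with T outgoing transitions. A sketch of that variant under the same (from, to, probability) input; `run_markov_chain_bisect` is a hypothetical name, and scaling the draw by each state's total is an added safeguard against rounding. Preprocessing stays O(E), so the total time becomes O(E + num_steps · log T).

```python
import bisect
import random
from itertools import accumulate
from typing import Dict, List, Optional, Tuple


def run_markov_chain_bisect(start: str, num_steps: int,
                            transitions: List[Tuple[str, str, float]],
                            seed: Optional[int] = None) -> Dict[str, int]:
    rng = random.Random(seed)
    targets: Dict[str, List[str]] = {}
    probs: Dict[str, List[float]] = {}
    counts: Dict[str, int] = {start: 0}
    for f, t, p in transitions:
        targets.setdefault(f, []).append(t)
        probs.setdefault(f, []).append(p)
        counts[f] = 0
        counts[t] = 0
    # Running totals per state, computed once: O(E) overall
    cumulative = {s: list(accumulate(ps)) for s, ps in probs.items()}
    state = start
    for _ in range(num_steps):
        counts[state] += 1
        if state not in targets:
            continue  # no outgoing transitions: stay in place (assumption)
        cum = cumulative[state]
        u = rng.random() * cum[-1]  # scale by the state's total probability
        # Index of the first running total strictly greater than u
        i = bisect.bisect_right(cum, u)
        state = targets[state][min(i, len(cum) - 1)]  # clamp guards a rounding edge case
    return counts


example = [
    ('a', 'a', 0.9), ('a', 'b', 0.075), ('a', 'c', 0.025),
    ('b', 'a', 0.15), ('b', 'b', 0.8), ('b', 'c', 0.05),
    ('c', 'a', 0.25), ('c', 'b', 0.25), ('c', 'c', 0.5),
]
print(run_markov_chain_bisect('a', 5000, example, seed=1))
```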