Problem

You are given a starting state start, a list of transition probabilities for a Markov chain, and a number of steps num_steps. Run the Markov chain from start for num_steps steps and return a count of how many times each state was visited.

Examples

Example 1

Input: 
start = 'a', 
num_steps = 5000, 
transitions = [
	('a', 'a', 0.9),
	('a', 'b', 0.075),
	('a', 'c', 0.025),
	('b', 'a', 0.15),
	('b', 'b', 0.8),
	('b', 'c', 0.05),
	('c', 'a', 0.25),
	('c', 'b', 0.25),
	('c', 'c', 0.5)
]
Output: { 'a': 3012, 'b': 1656, 'c': 332 }
Explanation: One run of the Markov chain for 5000 steps starting at 'a' produced these visit counts. The counts differ between runs because the simulation is random, but their proportions stabilize as num_steps grows.

Solution

Method 1 - Run the simulation

To solve this problem:

  1. Build a Transition Dictionary: Convert the list of transitions into a dictionary for easy lookup of transition probabilities based on the current state.
  2. Simulate the Markov Chain: Initialize the current state, then for each step, determine the next state based on the transition probabilities.
  3. Count Visits: Maintain a count of visits to each state as the simulation progresses.
  4. Return Results: After running the chain for the given number of steps, return the visit counts.
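
The four steps above can also be sketched compactly by delegating the weighted choice to Python's standard-library random.choices, which samples proportionally to given weights. This is an illustrative sketch, not the reference solution in the Code section (the function name run_markov_chain and the parallel target/weight lists are choices of this sketch):

```python
import random
from collections import Counter

def run_markov_chain(start, num_steps, transitions):
    # Step 1: group outgoing transitions by source state into
    # parallel lists of targets and weights.
    targets, weights = {}, {}
    for frm, to, p in transitions:
        targets.setdefault(frm, []).append(to)
        weights.setdefault(frm, []).append(p)

    # Steps 2-3: simulate, counting each visit as we go.
    counts = Counter()
    state = start
    for _ in range(num_steps):
        counts[state] += 1
        # random.choices picks one target proportionally to the weights
        state = random.choices(targets[state], weights=weights[state])[0]
    return dict(counts)  # Step 4
```

Unlike the implementations below, states that are never reached do not appear in the result, since the counter is populated lazily.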

Code

Java
import java.util.*;

public class Solution {
    public Map<String, Integer> runMarkovChain(String start, int num_steps, List<Transition> transitions) {
        // Build transition map
        Map<String, List<Transition>> transitionMap = new HashMap<>();
        for (Transition transition : transitions) {
            transitionMap.computeIfAbsent(transition.from, k -> new ArrayList<>()).add(transition);
        }

        // Initialize state visit counts
        Map<String, Integer> visitCounts = new HashMap<>();
        for (Transition transition : transitions) {
            visitCounts.put(transition.from, 0);
            visitCounts.put(transition.to, 0);
        }

        // Simulate the Markov chain
        Random random = new Random();
        String currentState = start;
        for (int step = 0; step < num_steps; step++) {
            visitCounts.put(currentState, visitCounts.get(currentState) + 1);
            double rand = random.nextDouble();
            List<Transition> currentTransitions = transitionMap.get(currentState);
            double cumulativeProbability = 0.0;
            // Walk the cumulative distribution; the first bucket that
            // contains rand determines the next state.
            for (Transition transition : currentTransitions) {
                cumulativeProbability += transition.probability;
                if (rand < cumulativeProbability) {
                    currentState = transition.to;
                    break;
                }
            }
            // If floating-point error makes a state's probabilities sum to
            // slightly less than 1, rand can exceed every cumulative value;
            // the state then simply stays put for this step.
        }

        return visitCounts;
    }

    public static class Transition {
        String from;
        String to;
        double probability;

        public Transition(String from, String to, double probability) {
            this.from = from;
            this.to = to;
            this.probability = probability;
        }
    }
    
    public static void main(String[] args) {
        Solution sol = new Solution();
        List<Solution.Transition> transitions = Arrays.asList(
            new Solution.Transition("a", "a", 0.9),
            new Solution.Transition("a", "b", 0.075),
            new Solution.Transition("a", "c", 0.025),
            new Solution.Transition("b", "a", 0.15),
            new Solution.Transition("b", "b", 0.8),
            new Solution.Transition("b", "c", 0.05),
            new Solution.Transition("c", "a", 0.25),
            new Solution.Transition("c", "b", 0.25),
            new Solution.Transition("c", "c", 0.5)
        );
        
        Map<String, Integer> result = sol.runMarkovChain("a", 5000, transitions);
        System.out.println(result);
    }
}
Python
import random
from typing import Dict, List, Tuple


class Solution:
    def runMarkovChain(self, start: str, num_steps: int, transitions: List[Tuple[str, str, float]]) -> Dict[str, int]:
        # Build transition map
        transition_map: Dict[str, List[Tuple[str, float]]] = {}
        for from_state, to_state, prob in transitions:
            if from_state not in transition_map:
                transition_map[from_state] = []
            transition_map[from_state].append((to_state, prob))

        # Initialize state visit counts
        visit_counts: Dict[str, int] = {}
        for from_state, to_state, _ in transitions:
            visit_counts[from_state] = 0
            visit_counts[to_state] = 0

        # Simulate the Markov chain
        current_state = start
        for _ in range(num_steps):
            visit_counts[current_state] += 1
            rand = random.random()
            cumulative_probability = 0.0
            for to_state, prob in transition_map[current_state]:
                cumulative_probability += prob
                if rand < cumulative_probability:
                    current_state = to_state
                    break

        return visit_counts

# Example usage:
sol = Solution()
transitions = [
    ('a', 'a', 0.9),
    ('a', 'b', 0.075),
    ('a', 'c', 0.025),
    ('b', 'a', 0.15),
    ('b', 'b', 0.8),
    ('b', 'c', 0.05),
    ('c', 'a', 0.25),
    ('c', 'b', 0.25),
    ('c', 'c', 0.5)
]
result = sol.runMarkovChain('a', 5000, transitions)
print(result)
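
As a sanity check on the example output (this check is an illustration, not part of the problem), the visit fractions should approach the chain's stationary distribution as num_steps grows. A quick power-iteration sketch:

```python
# Estimate the stationary distribution by repeatedly pushing a
# probability vector through the transition matrix (power iteration).
states = ['a', 'b', 'c']
P = {
    'a': {'a': 0.9, 'b': 0.075, 'c': 0.025},
    'b': {'a': 0.15, 'b': 0.8, 'c': 0.05},
    'c': {'a': 0.25, 'b': 0.25, 'c': 0.5},
}
dist = {'a': 1.0, 'b': 0.0, 'c': 0.0}  # start deterministically at 'a'
for _ in range(1000):
    dist = {t: sum(dist[s] * P[s][t] for s in states) for t in states}
print({s: round(dist[s], 4) for s in states})
# prints {'a': 0.625, 'b': 0.3125, 'c': 0.0625}
```

Scaled by 5000 steps this predicts roughly 3125 / 1562 / 312 visits, of the same order as the example's simulated counts, which fluctuate from run to run.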

Complexity

  • ⏰ Time complexity: O(num_steps * T), where T is the maximum number of outgoing transitions from a single state. Each simulation step scans the current state's transition list once to pick the next state.
  • 🧺 Space complexity: O(S + E), where S is the number of unique states and E is the number of transitions, for the visit counts and the transition dictionary.
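
If T is large, the linear scan inside each step can be cut to O(log T) by precomputing per-state cumulative sums once and binary-searching them with bisect. A sketch under that assumption (build_samplers and next_state are illustrative names, not part of the reference solutions above):

```python
import bisect
import itertools
import random

def build_samplers(transitions):
    # Group transitions by source state, then precompute the
    # cumulative probability sums for each state's outgoing edges.
    grouped = {}
    for frm, to, p in transitions:
        grouped.setdefault(frm, []).append((to, p))
    samplers = {}
    for frm, pairs in grouped.items():
        targets = [to for to, _ in pairs]
        cumulative = list(itertools.accumulate(p for _, p in pairs))
        samplers[frm] = (targets, cumulative)
    return samplers

def next_state(samplers, state):
    targets, cumulative = samplers[state]
    # bisect_right finds the first cumulative sum exceeding the draw,
    # matching the "rand < cumulative" scan in O(log T).
    i = bisect.bisect_right(cumulative, random.random())
    return targets[min(i, len(targets) - 1)]  # clamp guards float round-off
```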