Problem

Given a list of lists of elements, write a function that produces all combinations of the elements such that each combination contains exactly one element from each list and the result contains no duplicate combinations.

Examples

Example 1:

Input: lists = [ [a1, b1, c1], [a2, b2] ]
Output: [ [a1, a2], [a1, b2], [b1, a2], [b1, b2], [c1, a2], [c1, b2] ]
Explanation: Note that [a1, a2] and [a2, a1] are the same combination, although they are distinct permutations, so only one of them appears in the output.
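As a quick sanity check on the expected output, the number of combinations is the product of the list sizes (3 x 2 = 6 in Example 1). A minimal sketch, where countCombinations is a hypothetical helper used only for illustration and not part of the solution:

```java
import java.util.List;

class CombinationCount {
    // Hypothetical helper: the number of one-per-list selections
    // is the product of the sizes of the input lists.
    static long countCombinations(List<? extends List<?>> lists) {
        if (lists == null || lists.isEmpty()) {
            return 0; // no lists to pick from, so no combinations
        }
        long count = 1;
        for (List<?> list : lists) {
            count *= list.size(); // an empty inner list makes the product 0
        }
        return count;
    }

    public static void main(String[] args) {
        List<List<String>> lists = List.of(
                List.of("a1", "b1", "c1"),
                List.of("a2", "b2"));
        System.out.println(countCombinations(lists));
    }
}
```

Comparing this count with the size of the generated result is a cheap way to spot a missing or extra combination.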

Solution

Method 1 - Backtracking

This is a combination problem, and we can solve it using backtracking (a depth-first search over the lists).

Read more: Permutation vs Combination. We start from the 0th list in lists, in the listCombinations method.

The helper method iterates through each element in the current list (indexed by depth), adds it to the current combination, and recursively calls itself to fill the next position.

So, for each element in the first list, we recurse into all the remaining lists to build the combinations. We add the current combination to the result ans whenever we move past the last list, i.e. depth = n, at which point the combination contains one element from each of the n lists.

The method then backtracks by removing the last added element, allowing the next element of the current list to be considered.
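Because backtracking keeps mutating one shared combination list, a finished combination has to be stored as a copy rather than as a reference to that list. The following is a small sketch of the difference; SnapshotDemo and sizesAfterBacktrack are hypothetical names used only for this illustration:

```java
import java.util.ArrayList;
import java.util.List;

class SnapshotDemo {
    // Returns { size seen through the stored reference, size of the stored copy }
    // after one backtracking step has removed the last element.
    static int[] sizesAfterBacktrack() {
        List<Character> combination = new ArrayList<>();
        combination.add('a');
        combination.add('c');

        List<List<Character>> byReference = new ArrayList<>();
        List<List<Character>> byCopy = new ArrayList<>();
        byReference.add(combination);                        // stores the live list
        byCopy.add(new ArrayList<Character>(combination));   // stores a snapshot

        combination.remove(combination.size() - 1);          // backtrack: drop the last element

        return new int[] { byReference.get(0).size(), byCopy.get(0).size() };
    }

    public static void main(String[] args) {
        int[] sizes = sizesAfterBacktrack();
        System.out.println(sizes[0] + " " + sizes[1]);
    }
}
```

The entry stored by reference shrinks together with the working list, while the copied entry keeps both elements.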

This implementation does not check for duplicates in the lists or in the results. As long as no input list contains a repeated element, the generated combinations are unique, because any two of them differ in the element chosen from at least one list.
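If an input list might repeat a value, one possible approach (an assumption here, not part of the original solution) is to deduplicate each list before generating combinations. A minimal sketch, where dedupEach is a hypothetical helper:

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

class DedupLists {
    // Hypothetical helper: removes repeated values inside each list,
    // keeping the first-seen order of the remaining elements.
    static List<List<Character>> dedupEach(List<List<Character>> lists) {
        List<List<Character>> cleaned = new ArrayList<>();
        for (List<Character> list : lists) {
            // LinkedHashSet drops duplicates and preserves insertion order.
            cleaned.add(new ArrayList<Character>(new LinkedHashSet<Character>(list)));
        }
        return cleaned;
    }

    public static void main(String[] args) {
        List<List<Character>> lists = List.of(
                List.of('a', 'a', 'b'),
                List.of('c'));
        System.out.println(dedupEach(lists));
    }
}
```

Note that this only handles repeats within a single list. The same value appearing in two different lists still produces distinct selections by position.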

Code

Java
import java.util.ArrayList;
import java.util.List;

public class Solution {

	public List<List<Character>> listCombinations(List<List<Character>> lists) {
		List<List<Character>> ans = new ArrayList<>();

		if (lists == null || lists.isEmpty()) {
			return ans;
		}

		backtrack(lists, ans, new ArrayList<>(), 0);
		return ans;
	}

	private static void backtrack(List<List<Character>> lists, List<List<Character>> ans, List<Character> combination, int depth) {
		// One element has been chosen from every list: record a copy of the combination.
		if (depth == lists.size()) {
			ans.add(new ArrayList<>(combination));
			return;
		}

		for (Character element : lists.get(depth)) {
			combination.add(element);                      // choose an element from the current list
			backtrack(lists, ans, combination, depth + 1); // fill the next position
			combination.remove(combination.size() - 1);    // backtrack: undo the choice
		}
	}
}

Here is the runner code:

// Place this method inside the Solution class.
public static void main(String[] args) {
	List<List<Character>> lists = new ArrayList<>();
	lists.add(List.of('a', 'b'));
	lists.add(List.of('c', 'd'));
	lists.add(List.of('e', 'f'));

	// listCombinations is an instance method, so call it on a Solution object.
	List<List<Character>> result = new Solution().listCombinations(lists);

	for (List<Character> combination : result) {
		System.out.println(combination);
	}
}