Data Structures & Algorithms Refresher
Comprehensive quick-reference for DSA patterns used in coding interviews — arrays, trees, graphs, dynamic programming, and common problem-solving templates
1. Complexity Analysis
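Big-O gives an asymptotic upper bound on how runtime or space grows with input size n as n becomes large, ignoring constant factors and lower-order terms. In interviews it is usually quoted for the worst case unless stated otherwise (hash lookups, for example, are O(1) on average).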
Big-O Reference Table
| Notation | Name | Example | n=10 | n=1,000 | n=1,000,000 |
|---|---|---|---|---|---|
| O(1) | Constant | Hash lookup (average), array index | 1 | 1 | 1 |
| O(log n) | Logarithmic | Binary search, balanced BST operation | 3 | 10 | 20 |
| O(n) | Linear | Linear scan, single loop | 10 | 1K | 1M |
| O(n log n) | Linearithmic | Merge sort, heap sort | 33 | 10K | 20M |
| \(O(n^2)\) | Quadratic | Nested loops, bubble sort | 100 | 1M | \(10^{12}\) |
| \(O(2^n)\) | Exponential | Subset enumeration, naive recursion | 1,024 | \(2^{1000}\) | \(2^{10^6}\) |
| O(n!) | Factorial | Permutations, TSP brute force | 3.6M | unusable | unusable |
Space Complexity
Auxiliary space is the extra space used beyond the input. Total space includes input storage. Interview problems usually ask about auxiliary space.
| Situation | Space | Example |
|---|---|---|
| A few variables | O(1) | Two-pointer, in-place sort |
| Recursion depth | O(depth) | DFS on balanced tree = O(log n); skewed = O(n) |
| Output / copy of input | O(n) | Prefix sum array, result list |
| 2D grid / DP table | \(O(n^2)\) | Edit distance DP (optimizable to O(n)) |
| Recursion + data per frame | O(n × k) | DFS with path copy at each level |
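As a small illustration of the constant-space and recursion-depth rows above, here is a sketch in which both functions return the same total. The recursive one keeps one stack frame per element. The helper names `sum_iterative` and `sum_recursive` are hypothetical.

```python
def sum_iterative(nums: list[int]) -> int:
    """O(n) time, O(1) auxiliary space: a single accumulator."""
    total = 0
    for x in nums:
        total += x
    return total


def sum_recursive(nums: list[int], i: int = 0) -> int:
    """O(n) time, O(n) auxiliary space: one stack frame per element."""
    if i == len(nums):
        return 0
    return nums[i] + sum_recursive(nums, i + 1)


print(sum_iterative([3, 1, 4, 1, 5]))  # 14
print(sum_recursive([3, 1, 4, 1, 5]))  # 14
```

In CPython the recursive version also fails on long inputs once it exceeds the default recursion limit (about 1,000 frames). That failure is the O(n) stack cost showing up in practice.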
Analyzing Loops and Recursion
Nested loops
# O(n^2): each loop runs over 0..n-1
for i in range(n):
    for j in range(n):
        process(i, j)
# O(n^2): the inner loop still averages about n/2 iterations
for i in range(n):
    for j in range(i, n):  # triangle: n*(n+1)/2 iterations in total
        process(i, j)
# O(n log n): the inner loop doubles j, so it runs about log2(n) times per i
for i in range(n):
    j = 1
    while j < n:
        process(i, j)
        j *= 2
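To sanity-check the iteration counts claimed above, here is a sketch that counts loop-body executions instead of calling `process`. The counter functions are illustrative and not part of the original.

```python
def triangle_count(n: int) -> int:
    """Iterations of: for i in range(n): for j in range(i, n)."""
    count = 0
    for i in range(n):
        for _ in range(i, n):
            count += 1
    return count


def doubling_count(n: int) -> int:
    """Iterations of: for i in range(n): j = 1; while j < n: j *= 2."""
    count = 0
    for _ in range(n):
        j = 1
        while j < n:
            count += 1
            j *= 2
    return count


print(triangle_count(10))  # 55, which is 10 * 11 // 2
print(doubling_count(16))  # 64, which is 16 * log2(16)
```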
Recurrence relations (Master Theorem shortcut)
# T(n) = a * T(n/b) + O(n^d)
#
# Compare d vs log_b(a):
# d > log_b(a) => O(n^d)
# d == log_b(a) => O(n^d * log n)
# d < log_b(a) => O(n^log_b(a))
#
# Merge sort: a=2, b=2, d=1 => log_2(2)=1 == d => O(n log n)
# Binary search: a=1, b=2, d=0 => log_2(1)=0 == d => O(log n)
# Strassen: a=7, b=2, d=2 => log_2(7)~2.81 > d => O(n^2.81)
- Divide & conquer halving: \(T(n) = 2T(n/2) + O(n) \rightarrow O(n \log n)\)
- Single recursive call halving: \(T(n) = T(n/2) + O(1) \rightarrow O(\log n)\)
- Linear recursion: \(T(n) = T(n{-}1) + O(1) \rightarrow O(n)\)
- Two branches, no memoization: \(T(n) = 2T(n{-}1) + O(1) \rightarrow O(2^n)\)
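The shortcut can be encoded directly. This sketch assumes the standard preconditions a ≥ 1, b > 1, d ≥ 0. The function name `master_theorem` is hypothetical, and the output strings are just a readable format.

```python
import math


def _power(e: float) -> str:
    """Format n^e compactly as 1, n, or n^e."""
    if e == 0:
        return "1"
    if e == 1:
        return "n"
    return f"n^{e:g}"


def master_theorem(a: int, b: int, d: float) -> str:
    """Classify T(n) = a*T(n/b) + O(n^d) by comparing d with log_b(a)."""
    crit = math.log(a, b)  # log_b(a)
    if math.isclose(d, crit, abs_tol=1e-9):
        return "O(log n)" if d == 0 else f"O({_power(d)} log n)"
    if d > crit:
        return f"O({_power(d)})"
    return f"O(n^{round(crit, 2):g})"


print(master_theorem(2, 2, 1))  # O(n log n)  merge sort
print(master_theorem(1, 2, 0))  # O(log n)    binary search
print(master_theorem(7, 2, 2))  # O(n^2.81)   Strassen
```

The helper only covers the divide-by-b form. Subtract-one recurrences such as T(n) = 2T(n-1) + O(1) need unrolling instead.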
Identify the Time Complexity
What does this function print? Think about how many times the loop runs as n doubles.
i takes the values 1, 2, 4, 8, 16, 32, so the body runs 6 times before i reaches 64. In general the loop runs ⌈log₂ n⌉ times for n ≥ 1, which is O(log n).
def mystery(n):
count = 0
i = 1
while i < n:
count += 1
i *= 2
return count
print(mystery(64)) # Output: 6
public class Solution {
public static int mystery(int n) {
int count = 0;
int i = 1;
while (i < n) {
count++;
i *= 2;
}
return count;
}
public static void main(String[] args) {
System.out.println(mystery(64)); // Output: 6
}
}
Nested Loop Complexity
What does this function print? Count how many times count is incremented across all iterations of both loops.
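The inner loop runs i times for each i, so count = 0 + 1 + 2 + 3 + 4 = 10 for n = 5. In general the total is n(n-1)/2, which is O(n²).
def mystery(n):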
count = 0
for i in range(n):
for j in range(i):
count += 1
return count
print(mystery(5)) # Output: 10
public class Solution {
public static int mystery(int n) {
int count = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < i; j++) {
count++;
}
}
return count;
}
public static void main(String[] args) {
System.out.println(mystery(5)); // Output: 10
}
}
Recursion Complexity
What does this function print? Trace the recursive calls — each call spawns two more.
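Each call with n > 1 makes two calls on n - 1, so mystery(n) = 2 · mystery(n - 1) = 2^(n-1), which prints 16 for n = 5. The call tree has 2^n - 1 nodes (31 calls for n = 5), so the running time is O(2^n).
def mystery(n):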
if n <= 1:
return 1
return mystery(n - 1) + mystery(n - 1)
print(mystery(5)) # Output: 16
public class Solution {
public static int mystery(int n) {
if (n <= 1) {
return 1;
}
return mystery(n - 1) + mystery(n - 1);
}
public static void main(String[] args) {
System.out.println(mystery(5)); // Output: 16
}
}
2. Arrays & Hashing
Two Pointers
Two pointers eliminates an inner loop by maintaining a structural invariant as pointers move. Choose the variant based on whether the answer involves a pair (opposite direction) or a subarray property (same direction).
Opposite-direction — sorted array pair problems
def two_sum_sorted(nums: list[int], target: int) -> tuple[int, int]:
"""Find indices of two numbers that sum to target in a sorted array.
Time: O(n) Space: O(1)
"""
left, right = 0, len(nums) - 1
while left < right:
s = nums[left] + nums[right]
if s == target:
return left, right
elif s < target:
left += 1 # need bigger sum
else:
right -= 1 # need smaller sum
return -1, -1 # not found
Same-direction — remove duplicates / partition
def remove_duplicates(nums: list[int]) -> int:
"""Remove duplicates in-place from sorted array; return new length.
Time: O(n) Space: O(1)
Pattern: slow pointer marks the 'write' position.
"""
if not nums:
return 0
slow = 0
for fast in range(1, len(nums)):
if nums[fast] != nums[slow]:
slow += 1
nums[slow] = nums[fast]
return slow + 1
def partition(nums: list[int], pivot: int) -> list[int]:
"""Dutch National Flag — 3-way partition around pivot.
Time: O(n) Space: O(1)
"""
low, mid, high = 0, 0, len(nums) - 1
while mid <= high:
if nums[mid] < pivot:
nums[low], nums[mid] = nums[mid], nums[low]
low += 1
mid += 1
elif nums[mid] == pivot:
mid += 1
else:
nums[mid], nums[high] = nums[high], nums[mid]
high -= 1
return nums
Sliding Window
Sliding window avoids recomputing aggregates from scratch on every subarray. The window expands at the right and contracts at the left based on a validity condition.
Fixed-size window
def max_sum_subarray(nums: list[int], k: int) -> int:
"""Maximum sum of any subarray of length exactly k.
Time: O(n) Space: O(1)
"""
if len(nums) < k:
return 0
# Build first window
window_sum = sum(nums[:k])
best = window_sum
for i in range(k, len(nums)):
window_sum += nums[i] - nums[i - k] # slide: add right, drop left
best = max(best, window_sum)
return best
Variable-size window — find shortest/longest subarray satisfying condition
def min_subarray_len(target: int, nums: list[int]) -> int:
    """Minimum length subarray with sum >= target.
    Assumes target > 0 and all nums are positive, so shrinking the window never increases its sum.
    Time: O(n)  Space: O(1)
    """
    left = 0
    current_sum = 0
    min_len = float('inf')
    for right in range(len(nums)):
        current_sum += nums[right]      # expand window
        while current_sum >= target:    # window valid: try to shrink
            min_len = min(min_len, right - left + 1)
            current_sum -= nums[left]
            left += 1
    return min_len if min_len != float('inf') else 0
def longest_substring_k_distinct(s: str, k: int) -> int:
"""Longest substring with at most k distinct characters.
Time: O(n) Space: O(k)
"""
from collections import defaultdict
freq: dict[str, int] = defaultdict(int)
left = 0
best = 0
for right, ch in enumerate(s):
freq[ch] += 1
while len(freq) > k: # window invalid: shrink
left_ch = s[left]
freq[left_ch] -= 1
if freq[left_ch] == 0:
del freq[left_ch]
left += 1
best = max(best, right - left + 1)
return best
Prefix Sums
def build_prefix(nums: list[int]) -> list[int]:
"""prefix[i] = sum of nums[0..i-1]. prefix[0] = 0 (sentinel).
Time: O(n) Space: O(n)
"""
prefix = [0] * (len(nums) + 1)
for i, x in enumerate(nums):
prefix[i + 1] = prefix[i] + x
return prefix
def range_sum(prefix: list[int], left: int, right: int) -> int:
"""Sum of nums[left..right] inclusive. O(1) per query."""
return prefix[right + 1] - prefix[left]
def subarray_sum_equals_k(nums: list[int], k: int) -> int:
"""Count subarrays with sum exactly k.
Key insight: sum(i..j) = prefix[j+1] - prefix[i] = k
=> prefix[i] = prefix[j+1] - k
Time: O(n) Space: O(n)
"""
from collections import defaultdict
count = 0
running = 0
seen: dict[int, int] = defaultdict(int)
seen[0] = 1 # empty prefix
for x in nums:
running += x
count += seen[running - k]
seen[running] += 1
return count
Hash Map Patterns
Two Sum (unsorted)
def two_sum(nums: list[int], target: int) -> tuple[int, int]:
"""Time: O(n) Space: O(n)"""
seen: dict[int, int] = {}
for i, x in enumerate(nums):
complement = target - x
if complement in seen:
return seen[complement], i
seen[x] = i
return -1, -1
Frequency counting / grouping
from collections import defaultdict, Counter
def group_anagrams(words: list[str]) -> list[list[str]]:
"""Group words that are anagrams of each other.
Time: O(n * m log m) where m = avg word length. Space: O(n*m)
"""
groups: dict[tuple, list[str]] = defaultdict(list)
for word in words:
key = tuple(sorted(word)) # canonical form
groups[key].append(word)
return list(groups.values())
def top_k_frequent(nums: list[int], k: int) -> list[int]:
"""Return k most frequent elements.
Time: O(n log k) Space: O(n)
"""
import heapq
freq = Counter(nums)
return heapq.nlargest(k, freq, key=freq.get)
Kadane's Algorithm — Maximum Subarray
def max_subarray(nums: list[int]) -> int:
    """Maximum sum of any contiguous subarray (Kadane's algorithm).
    Time: O(n)  Space: O(1)
    Invariant: current is the max sum of any subarray ending at index i.
    If the best sum ending at i-1 is negative, extending it only hurts, so restart at nums[i].
    Assumes nums is non-empty.
    """
    best = nums[0]
    current = nums[0]
    for i in range(1, len(nums)):  # index loop avoids the O(n) copy made by nums[1:]
        current = max(nums[i], current + nums[i])  # extend or restart
        best = max(best, current)
    return best
def max_subarray_with_indices(nums: list[int]) -> tuple[int, int, int]:
"""Returns (max_sum, start, end) — inclusive indices."""
best_sum = nums[0]
best_start = best_end = 0
current_sum = nums[0]
current_start = 0
for i in range(1, len(nums)):
if nums[i] > current_sum + nums[i]:
current_sum = nums[i]
current_start = i
else:
current_sum += nums[i]
if current_sum > best_sum:
best_sum = current_sum
best_start = current_start
best_end = i
return best_sum, best_start, best_end
Merge Intervals
def merge_intervals(intervals: list[list[int]]) -> list[list[int]]:
"""Merge all overlapping intervals.
Time: O(n log n) Space: O(n)
"""
if not intervals:
return []
intervals.sort(key=lambda x: x[0]) # sort by start
merged = [intervals[0]]
for start, end in intervals[1:]:
if start <= merged[-1][1]: # overlaps with last merged
merged[-1][1] = max(merged[-1][1], end)
else:
merged.append([start, end])
return merged
def insert_interval(intervals: list[list[int]], new: list[int]) -> list[list[int]]:
"""Insert a new interval, merging as needed. Assumes sorted input.
Time: O(n) Space: O(n)
"""
result = []
i, n = 0, len(intervals)
# All intervals that end before new starts
while i < n and intervals[i][1] < new[0]:
result.append(intervals[i])
i += 1
# Merge overlapping intervals into new
while i < n and intervals[i][0] <= new[1]:
new[0] = min(new[0], intervals[i][0])
new[1] = max(new[1], intervals[i][1])
i += 1
result.append(new)
result.extend(intervals[i:])
return result
Two Sum
Given an array of integers and a target, return the indices of the two numbers that add up to the target. Use a hash map for O(n) time.
def two_sum(nums, target):
seen = {}
for i, num in enumerate(nums):
complement = target - num
# YOUR CODE HERE
seen[num] = i
return []
print(two_sum([2, 7, 11, 15], 9)) # [0, 1]
print(two_sum([3, 2, 4], 6)) # [1, 2]
import java.util.*;
public class Solution {
public static int[] twoSum(int[] nums, int target) {
Map<Integer, Integer> seen = new HashMap<>();
for (int i = 0; i < nums.length; i++) {
int complement = target - nums[i];
// YOUR CODE HERE
seen.put(nums[i], i);
}
return new int[]{};
}
public static void main(String[] args) {
System.out.println(Arrays.toString(twoSum(new int[]{2, 7, 11, 15}, 9))); // [0, 1]
System.out.println(Arrays.toString(twoSum(new int[]{3, 2, 4}, 6))); // [1, 2]
}
}
Compute complement = target - num and check whether complement is already in seen. If yes, return [seen[complement], i]; otherwise store num → i in the map.
def two_sum(nums, target):
seen = {}
for i, num in enumerate(nums):
complement = target - num
if complement in seen:
return [seen[complement], i]
seen[num] = i
return []
print(two_sum([2, 7, 11, 15], 9)) # [0, 1]
print(two_sum([3, 2, 4], 6)) # [1, 2]
import java.util.*;
public class Solution {
public static int[] twoSum(int[] nums, int target) {
Map<Integer, Integer> seen = new HashMap<>();
for (int i = 0; i < nums.length; i++) {
int complement = target - nums[i];
if (seen.containsKey(complement)) {
return new int[]{seen.get(complement), i};
}
seen.put(nums[i], i);
}
return new int[]{};
}
public static void main(String[] args) {
System.out.println(Arrays.toString(twoSum(new int[]{2, 7, 11, 15}, 9))); // [0, 1]
System.out.println(Arrays.toString(twoSum(new int[]{3, 2, 4}, 6))); // [1, 2]
}
}
Contains Duplicate
Return True if any value appears at least twice in the array, False if all elements are distinct. Aim for O(n) time using a set.
def contains_duplicate(nums):
seen = set()
for num in nums:
# YOUR CODE HERE
seen.add(num)
return False
print(contains_duplicate([1, 2, 3, 1])) # True
print(contains_duplicate([1, 2, 3, 4])) # False
import java.util.*;
public class Solution {
public static boolean containsDuplicate(int[] nums) {
Set<Integer> seen = new HashSet<>();
for (int num : nums) {
// YOUR CODE HERE
seen.add(num);
}
return false;
}
public static void main(String[] args) {
System.out.println(containsDuplicate(new int[]{1, 2, 3, 1})); // true
System.out.println(containsDuplicate(new int[]{1, 2, 3, 4})); // false
}
}
Before adding each number, check if num in seen. If so, return True immediately.
def contains_duplicate(nums):
seen = set()
for num in nums:
if num in seen:
return True
seen.add(num)
return False
print(contains_duplicate([1, 2, 3, 1])) # True
print(contains_duplicate([1, 2, 3, 4])) # False
import java.util.*;
public class Solution {
public static boolean containsDuplicate(int[] nums) {
Set<Integer> seen = new HashSet<>();
for (int num : nums) {
if (seen.contains(num)) {
return true;
}
seen.add(num);
}
return false;
}
public static void main(String[] args) {
System.out.println(containsDuplicate(new int[]{1, 2, 3, 1})); // true
System.out.println(containsDuplicate(new int[]{1, 2, 3, 4})); // false
}
}
Maximum Subarray Sum
Return the largest sum of any contiguous subarray (Kadane's algorithm). The array may contain negative numbers.
def max_subarray(nums):
current = nums[0]
best = nums[0]
for num in nums[1:]:
# YOUR CODE HERE: update current and best
pass
return best
print(max_subarray([-2, 1, -3, 4, -1, 2, 1, -5, 4])) # 6
print(max_subarray([1])) # 1
print(max_subarray([-1, -2, -3])) # -1
public class Solution {
public static int maxSubarray(int[] nums) {
int current = nums[0];
int best = nums[0];
for (int i = 1; i < nums.length; i++) {
// YOUR CODE HERE: update current and best
}
return best;
}
public static void main(String[] args) {
System.out.println(maxSubarray(new int[]{-2, 1, -3, 4, -1, 2, 1, -5, 4})); // 6
System.out.println(maxSubarray(new int[]{1})); // 1
System.out.println(maxSubarray(new int[]{-1, -2, -3})); // -1
}
}
Update current = max(num, current + num), then update best = max(best, current).
def max_subarray(nums):
current = nums[0]
best = nums[0]
for num in nums[1:]:
current = max(num, current + num)
best = max(best, current)
return best
print(max_subarray([-2, 1, -3, 4, -1, 2, 1, -5, 4])) # 6
print(max_subarray([1])) # 1
print(max_subarray([-1, -2, -3])) # -1
public class Solution {
public static int maxSubarray(int[] nums) {
int current = nums[0];
int best = nums[0];
for (int i = 1; i < nums.length; i++) {
current = Math.max(nums[i], current + nums[i]);
best = Math.max(best, current);
}
return best;
}
public static void main(String[] args) {
System.out.println(maxSubarray(new int[]{-2, 1, -3, 4, -1, 2, 1, -5, 4})); // 6
System.out.println(maxSubarray(new int[]{1})); // 1
System.out.println(maxSubarray(new int[]{-1, -2, -3})); // -1
}
}
3. Strings
Palindrome Detection
def is_palindrome(s: str) -> bool:
"""O(n) time, O(1) space — two pointers."""
left, right = 0, len(s) - 1
while left < right:
if s[left] != s[right]:
return False
left += 1
right -= 1
return True
def longest_palindrome(s: str) -> str:
"""Expand from center — handles both odd and even length palindromes.
Time: O(n^2) Space: O(1)
"""
def expand(left: int, right: int) -> str:
while left >= 0 and right < len(s) and s[left] == s[right]:
left -= 1
right += 1
return s[left + 1:right] # slice after over-expansion
best = ""
for i in range(len(s)):
odd = expand(i, i) # center at single character
even = expand(i, i + 1) # center between two characters
if len(odd) > len(best):
best = odd
if len(even) > len(best):
best = even
return best
Anagram Patterns
from collections import Counter
def is_anagram(s: str, t: str) -> bool:
    """Time: O(n)  Space: O(1) for a fixed alphabet (e.g. 26 lowercase letters), else O(k) for k distinct characters."""
    return Counter(s) == Counter(t)
def find_all_anagrams(s: str, pattern: str) -> list[int]:
"""Return start indices of all anagrams of pattern in s.
Sliding window with frequency map.
Time: O(n) Space: O(1)
"""
result = []
p_count = Counter(pattern)
w_count: Counter = Counter()
k = len(pattern)
for right in range(len(s)):
w_count[s[right]] += 1
# Drop character leaving the window
if right >= k:
left_ch = s[right - k]
w_count[left_ch] -= 1
if w_count[left_ch] == 0:
del w_count[left_ch]
if w_count == p_count:
result.append(right - k + 1)
return result
KMP Algorithm
Knuth-Morris-Pratt finds all occurrences of a pattern in a text in O(n + m) time by precomputing a failure function (also called LPS — longest proper prefix that is also a suffix). This lets the algorithm skip re-examining characters after a mismatch.
def build_lps(pattern: str) -> list[int]:
"""Build LPS (failure function) array.
lps[i] = length of longest proper prefix of pattern[0..i]
that is also a suffix.
Time: O(m) Space: O(m)
"""
m = len(pattern)
lps = [0] * m
length = 0 # length of previous longest prefix-suffix
i = 1
while i < m:
if pattern[i] == pattern[length]:
length += 1
lps[i] = length
i += 1
else:
if length != 0:
length = lps[length - 1] # fall back (don't increment i)
else:
lps[i] = 0
i += 1
return lps
def kmp_search(text: str, pattern: str) -> list[int]:
"""Return all start indices where pattern occurs in text.
Time: O(n + m) Space: O(m)
"""
n, m = len(text), len(pattern)
if m == 0:
return []
lps = build_lps(pattern)
result = []
i = j = 0 # i = text index, j = pattern index
while i < n:
if text[i] == pattern[j]:
i += 1
j += 1
if j == m:
result.append(i - j)
j = lps[j - 1] # shift pattern using LPS
elif i < n and text[i] != pattern[j]:
if j != 0:
j = lps[j - 1] # skip already-matched prefix
else:
i += 1
return result
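The same failure table is often written as the "prefix function". The sketch below is an alternative formulation, not the code above. It computes `pi` (identical to `lps`) with a shorter loop, then finds matches by running it over pattern + separator + text. It assumes the separator character `"\x00"` appears in neither string, and the names `prefix_function` and `find_all` are hypothetical.

```python
def prefix_function(p: str) -> list[int]:
    """pi[i] = length of the longest proper prefix of p[:i+1] that is also its suffix."""
    pi = [0] * len(p)
    for i in range(1, len(p)):
        k = pi[i - 1]                  # best border of the previous prefix
        while k > 0 and p[i] != p[k]:
            k = pi[k - 1]              # fall back to the next-shorter border
        if p[i] == p[k]:
            k += 1
        pi[i] = k
    return pi


def find_all(text: str, pattern: str, sep: str = "\x00") -> list[int]:
    """All start indices of pattern in text. Assumes sep occurs in neither string."""
    m = len(pattern)
    if m == 0:
        return []
    pi = prefix_function(pattern + sep + text)
    # A border of length m ending at combined index i is a match that ends at
    # text index i - m - 1, so it starts at i - 2m.
    return [i - 2 * m for i, length in enumerate(pi) if length == m]


print(prefix_function("ababaca"))  # [0, 0, 1, 2, 3, 0, 1]
print(find_all("abababa", "aba"))  # [0, 2, 4]
```

The separator caps every border at length m, so a value of exactly m can only occur inside the text part.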
Common String Tricks
# 1. Reverse words in a sentence
def reverse_words(s: str) -> str:
return ' '.join(s.split()[::-1])
# 2. Check if string is a rotation of another
# "abcde" is a rotation of "cdeab"
def is_rotation(s1: str, s2: str) -> bool:
    if len(s1) != len(s2):
        return False
    return s2 in (s1 + s1)  # O(n) with KMP; O(n^2) worst case for a naive substring search
# 3. Longest common prefix
def longest_common_prefix(strs: list[str]) -> str:
if not strs:
return ""
prefix = strs[0]
for s in strs[1:]:
while not s.startswith(prefix):
prefix = prefix[:-1]
if not prefix:
return ""
return prefix
# 4. Encode / decode strings (interview favorite)
def encode(strs: list[str]) -> str:
# Format: "length#string" per word — handles all characters
return ''.join(f"{len(s)}#{s}" for s in strs)
def decode(data: str) -> list[str]:
result, i = [], 0
while i < len(data):
j = data.index('#', i)
length = int(data[i:j])
result.append(data[j + 1: j + 1 + length])
i = j + 1 + length
return result
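Building a string with += inside a loop can cost \(O(n^2)\) in total, because each concatenation may copy the whole string built so far. Accumulate the pieces in a list and call ''.join() once at the end: \(O(n)\).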
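A minimal sketch of the list-and-join string-building pattern. `squares_csv` is a made-up helper for illustration.

```python
def squares_csv(n: int) -> str:
    """Comma-separated squares of 0..n-1, built by list accumulation and one join."""
    parts: list[str] = []
    for i in range(n):
        parts.append(str(i * i))  # amortized O(1) append, no string copying
    return ",".join(parts)        # one pass over the total length


print(squares_csv(5))  # 0,1,4,9,16
```

The loop can also be written as the one-liner `",".join(str(i * i) for i in range(n))`.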
Valid Palindrome
Return True if a string is a palindrome, considering only alphanumeric characters and ignoring case. Use two pointers.
def is_palindrome(s):
left, right = 0, len(s) - 1
while left < right:
while left < right and not s[left].isalnum():
left += 1
while left < right and not s[right].isalnum():
right -= 1
# YOUR CODE HERE: compare and move pointers
return True
print(is_palindrome("A man, a plan, a canal: Panama")) # True
print(is_palindrome("race a car")) # False
public class Solution {
public static boolean isPalindrome(String s) {
int left = 0, right = s.length() - 1;
while (left < right) {
while (left < right && !Character.isLetterOrDigit(s.charAt(left))) left++;
while (left < right && !Character.isLetterOrDigit(s.charAt(right))) right--;
// YOUR CODE HERE: compare and move pointers
}
return true;
}
public static void main(String[] args) {
System.out.println(isPalindrome("A man, a plan, a canal: Panama")); // true
System.out.println(isPalindrome("race a car")); // false
}
}
Compare s[left].lower() with s[right].lower(). If they differ, return False; otherwise move both pointers inward: left += 1; right -= 1.
def is_palindrome(s):
left, right = 0, len(s) - 1
while left < right:
while left < right and not s[left].isalnum():
left += 1
while left < right and not s[right].isalnum():
right -= 1
if s[left].lower() != s[right].lower():
return False
left += 1
right -= 1
return True
print(is_palindrome("A man, a plan, a canal: Panama")) # True
print(is_palindrome("race a car")) # False
public class Solution {
public static boolean isPalindrome(String s) {
int left = 0, right = s.length() - 1;
while (left < right) {
while (left < right && !Character.isLetterOrDigit(s.charAt(left))) left++;
while (left < right && !Character.isLetterOrDigit(s.charAt(right))) right--;
if (Character.toLowerCase(s.charAt(left)) != Character.toLowerCase(s.charAt(right))) {
return false;
}
left++;
right--;
}
return true;
}
public static void main(String[] args) {
System.out.println(isPalindrome("A man, a plan, a canal: Panama")); // true
System.out.println(isPalindrome("race a car")); // false
}
}
Reverse String In-Place
Reverse a list of characters in-place using two pointers. Do not allocate extra space for another array.
def reverse_string(s):
left, right = 0, len(s) - 1
while left < right:
# YOUR CODE HERE: swap s[left] and s[right], then move pointers
pass
chars = ['h', 'e', 'l', 'l', 'o']
reverse_string(chars)
print(chars) # ['o', 'l', 'l', 'e', 'h']
import java.util.*;
public class Solution {
public static void reverseString(char[] s) {
int left = 0, right = s.length - 1;
while (left < right) {
// YOUR CODE HERE: swap s[left] and s[right], then move pointers
}
}
public static void main(String[] args) {
char[] chars = {'h', 'e', 'l', 'l', 'o'};
reverseString(chars);
System.out.println(Arrays.toString(chars)); // [o, l, l, e, h]
}
}
Swap s[left] and s[right] using a temporary variable (or Python's tuple swap: s[left], s[right] = s[right], s[left]), then increment left and decrement right.
def reverse_string(s):
left, right = 0, len(s) - 1
while left < right:
s[left], s[right] = s[right], s[left]
left += 1
right -= 1
chars = ['h', 'e', 'l', 'l', 'o']
reverse_string(chars)
print(chars) # ['o', 'l', 'l', 'e', 'h']
import java.util.*;
public class Solution {
public static void reverseString(char[] s) {
int left = 0, right = s.length - 1;
while (left < right) {
char temp = s[left];
s[left] = s[right];
s[right] = temp;
left++;
right--;
}
}
public static void main(String[] args) {
char[] chars = {'h', 'e', 'l', 'l', 'o'};
reverseString(chars);
System.out.println(Arrays.toString(chars)); // [o, l, l, e, h]
}
}
First Unique Character
Given a string, return the index of the first non-repeating character. Return -1 if none exists. Aim for O(n) using a frequency map.
def first_uniq_char(s):
freq = {}
for ch in s:
freq[ch] = freq.get(ch, 0) + 1
# YOUR CODE HERE: find index of first char with freq == 1
return -1
print(first_uniq_char("leetcode")) # 0
print(first_uniq_char("loveleet")) # 1
print(first_uniq_char("aabb")) # -1
import java.util.*;
public class Solution {
public static int firstUniqChar(String s) {
Map<Character, Integer> freq = new HashMap<>();
for (char ch : s.toCharArray()) {
freq.put(ch, freq.getOrDefault(ch, 0) + 1);
}
// YOUR CODE HERE: find index of first char with freq == 1
return -1;
}
public static void main(String[] args) {
System.out.println(firstUniqChar("leetcode")); // 0
System.out.println(firstUniqChar("loveleet")); // 1
System.out.println(firstUniqChar("aabb")); // -1
}
}
Make a second pass over the string with enumerate(s) and return the index i as soon as freq[s[i]] == 1.
def first_uniq_char(s):
freq = {}
for ch in s:
freq[ch] = freq.get(ch, 0) + 1
for i, ch in enumerate(s):
if freq[ch] == 1:
return i
return -1
print(first_uniq_char("leetcode")) # 0
print(first_uniq_char("loveleet")) # 1
print(first_uniq_char("aabb")) # -1
import java.util.*;
public class Solution {
public static int firstUniqChar(String s) {
Map<Character, Integer> freq = new HashMap<>();
for (char ch : s.toCharArray()) {
freq.put(ch, freq.getOrDefault(ch, 0) + 1);
}
for (int i = 0; i < s.length(); i++) {
if (freq.get(s.charAt(i)) == 1) {
return i;
}
}
return -1;
}
public static void main(String[] args) {
System.out.println(firstUniqChar("leetcode")); // 0
System.out.println(firstUniqChar("loveleet")); // 1
System.out.println(firstUniqChar("aabb")); // -1
}
}
4. Linked Lists
class ListNode:
def __init__(self, val: int = 0, next: 'ListNode | None' = None):
self.val = val
self.next = next
Fast / Slow Pointer (Floyd's)
def has_cycle(head: ListNode | None) -> bool:
"""Detect cycle. Time: O(n) Space: O(1)"""
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
if slow is fast:
return True
return False
def find_cycle_start(head: ListNode | None) -> ListNode | None:
"""Find the node where a cycle begins.
Time: O(n) Space: O(1)
Insight: after detecting meeting point, reset one pointer to head.
Move both one step at a time — they meet at cycle start.
"""
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
if slow is fast:
# Found meeting point
slow = head
while slow is not fast:
slow = slow.next
fast = fast.next
return slow # cycle start
return None
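This derivation shows why resetting one pointer to head works. The symbols μ, λ, s, t are notation introduced here. Let μ be the number of steps from head to the cycle start, λ the cycle length, and s the number of steps slow has taken when the pointers meet. Fast has taken 2s steps, and both pointers are on the cycle, so the extra s steps fast took must be whole laps:

\[ 2s - s = t\lambda \quad\Rightarrow\quad s = t\lambda \quad \text{for some integer } t \ge 1 \]

Slow sits s steps from head. Advancing it μ more steps puts it at

\[ s + \mu = \mu + t\lambda \]

That is the cycle start plus t full laps, which is the cycle start itself. A pointer restarted at head also reaches the cycle start after exactly μ steps, so the two meet there.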
def find_middle(head: ListNode | None) -> ListNode | None:
"""Find middle node (for even length: returns second middle).
Time: O(n) Space: O(1)
"""
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
return slow
Reverse Linked List
def reverse_iterative(head: ListNode | None) -> ListNode | None:
"""Time: O(n) Space: O(1)"""
prev = None
curr = head
while curr:
next_node = curr.next # save next before overwriting
curr.next = prev # reverse the link
prev = curr # advance prev
curr = next_node # advance curr
return prev # prev is the new head
def reverse_recursive(head: ListNode | None) -> ListNode | None:
"""Time: O(n) Space: O(n) — call stack depth."""
if head is None or head.next is None:
return head
new_head = reverse_recursive(head.next)
head.next.next = head # reverse the link
head.next = None # don't forget to nullify
return new_head
def reverse_between(head: ListNode | None, left: int, right: int) -> ListNode | None:
"""Reverse sublist from position left to right (1-indexed).
Time: O(n) Space: O(1)
"""
dummy = ListNode(0, head)
prev = dummy
# Move prev to node just before 'left'
for _ in range(left - 1):
prev = prev.next
curr = prev.next
for _ in range(right - left):
next_node = curr.next
curr.next = next_node.next
next_node.next = prev.next
prev.next = next_node
return dummy.next
Merge Two Sorted Lists
def merge_two_sorted(l1: ListNode | None, l2: ListNode | None) -> ListNode | None:
"""Time: O(n + m) Space: O(1)
Dummy head avoids edge cases when the result head is uncertain.
"""
dummy = ListNode(0)
curr = dummy
while l1 and l2:
if l1.val <= l2.val:
curr.next = l1
l1 = l1.next
else:
curr.next = l2
l2 = l2.next
curr = curr.next
curr.next = l1 or l2 # attach remaining
return dummy.next
Dummy Head Trick & Other Patterns
def remove_nth_from_end(head: ListNode | None, n: int) -> ListNode | None:
"""Remove the nth node from the end in one pass.
Time: O(L) Space: O(1)
"""
dummy = ListNode(0, head)
fast = slow = dummy
# Advance fast by n+1 steps so the gap is n
for _ in range(n + 1):
fast = fast.next
# Move both until fast reaches end
while fast:
fast = fast.next
slow = slow.next
slow.next = slow.next.next # unlink nth from end
return dummy.next
def reorder_list(head: ListNode | None) -> None:
"""Reorder L0 -> L1 -> ... -> Ln into L0 -> Ln -> L1 -> Ln-1 -> ...
In-place. Time: O(n) Space: O(1)
Three steps: find middle, reverse second half, merge.
"""
if not head:
return
# Step 1: find middle
slow = fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
# Step 2: reverse second half
second = slow.next
slow.next = None # cut the list
prev = None
while second:
tmp = second.next
second.next = prev
prev = second
second = tmp
# Step 3: interleave
first, second = head, prev
while second:
tmp1, tmp2 = first.next, second.next
first.next = second
second.next = tmp1
first, second = tmp1, tmp2
- Always use a dummy head when the result head might change
- Draw the pointer diagram before coding: one missed .next breaks the whole list
- Check: empty list, single node, two nodes, cycle at tail vs. mid
- Fast/slow pointer: fast moves 2 steps, slow moves 1; if there is a cycle they meet within O(n) steps, otherwise fast reaches the end
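The sketch below is an edge-case harness for the checklist above. It re-declares `ListNode` and `has_cycle` so it runs standalone. `build` and its `cycle_pos` parameter are a hypothetical helper, not from the original, that links the tail back to a given index.

```python
from typing import Optional


class ListNode:
    def __init__(self, val: int = 0, next: "Optional[ListNode]" = None):
        self.val = val
        self.next = next


def build(vals: list[int], cycle_pos: int = -1) -> Optional[ListNode]:
    """Build a list from vals; if cycle_pos >= 0, point the tail back at that index."""
    dummy = ListNode()
    tail = dummy
    nodes = []
    for v in vals:
        tail.next = ListNode(v)
        tail = tail.next
        nodes.append(tail)
    if cycle_pos >= 0 and nodes:
        tail.next = nodes[cycle_pos]
    return dummy.next


def has_cycle(head: Optional[ListNode]) -> bool:
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            return True
    return False


cases = {
    "empty": build([]),
    "single": build([1]),
    "two nodes": build([1, 2]),
    "self-loop": build([1], cycle_pos=0),
    "cycle at tail": build([1, 2, 3, 4], cycle_pos=3),
    "cycle mid": build([1, 2, 3, 4], cycle_pos=1),
}
for name, head in cases.items():
    print(f"{name}: {has_cycle(head)}")
```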
Reverse Linked List
Reverse a singly linked list iteratively. Fill in the body of reverse_list.
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def to_list(head):
result = []
while head:
result.append(head.val)
head = head.next
return result
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def reverse_list(head):
prev = None
curr = head
while curr:
# YOUR CODE HERE: save next, point curr.next to prev, advance
pass
return prev
print(to_list(reverse_list(from_list([1, 2, 3, 4, 5])))) # [5, 4, 3, 2, 1]
print(to_list(reverse_list(from_list([1, 2])))) # [2, 1]
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
static String toList(ListNode head) {
StringBuilder sb = new StringBuilder("[");
while (head != null) {
sb.append(head.val);
if (head.next != null) sb.append(", ");
head = head.next;
}
return sb.append("]").toString();
}
public static ListNode reverseList(ListNode head) {
ListNode prev = null;
ListNode curr = head;
while (curr != null) {
// YOUR CODE HERE: save next, point curr.next to prev, advance
}
return prev;
}
public static void main(String[] args) {
System.out.println(toList(reverseList(fromList(new int[]{1, 2, 3, 4, 5})))); // [5, 4, 3, 2, 1]
System.out.println(toList(reverseList(fromList(new int[]{1, 2})))); // [2, 1]
}
}
Save next_node = curr.next, then set curr.next = prev to reverse the pointer. Advance: prev = curr, curr = next_node. Repeat until curr is None.
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def to_list(head):
result = []
while head:
result.append(head.val)
head = head.next
return result
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def reverse_list(head):
prev = None
curr = head
while curr:
next_node = curr.next
curr.next = prev
prev = curr
curr = next_node
return prev
print(to_list(reverse_list(from_list([1, 2, 3, 4, 5])))) # [5, 4, 3, 2, 1]
print(to_list(reverse_list(from_list([1, 2])))) # [2, 1]
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
static String toList(ListNode head) {
StringBuilder sb = new StringBuilder("[");
while (head != null) {
sb.append(head.val);
if (head.next != null) sb.append(", ");
head = head.next;
}
return sb.append("]").toString();
}
public static ListNode reverseList(ListNode head) {
ListNode prev = null;
ListNode curr = head;
while (curr != null) {
ListNode nextNode = curr.next;
curr.next = prev;
prev = curr;
curr = nextNode;
}
return prev;
}
public static void main(String[] args) {
System.out.println(toList(reverseList(fromList(new int[]{1, 2, 3, 4, 5})))); // [5, 4, 3, 2, 1]
System.out.println(toList(reverseList(fromList(new int[]{1, 2})))); // [2, 1]
}
}
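The iterative loop above is the version to reach for first. For comparison, here is a sketch of the same reversal written recursively. It is correct but uses O(n) call-stack space, so Python's default recursion limit (about 1000 frames) rules it out for long lists. The name `reverse_recursive` is illustrative, not from the original exercises.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def reverse_recursive(head):
    # Base case: an empty list or a single node is already reversed.
    if head is None or head.next is None:
        return head
    # Reverse everything after head; new_head is the old tail.
    new_head = reverse_recursive(head.next)
    # head.next is now the last node of the reversed sublist: hang head after it.
    head.next.next = head
    head.next = None
    return new_head

print(to_list(reverse_recursive(from_list([1, 2, 3, 4, 5]))))  # [5, 4, 3, 2, 1]
```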
Middle of Linked List
Return the middle node of a linked list. For even-length lists, return the second middle node. Use the fast/slow pointer technique.
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def middle(head):
slow = head
fast = head
while fast and fast.next:
# YOUR CODE HERE: advance slow by 1, fast by 2
pass
return slow
print(middle(from_list([1, 2, 3, 4, 5])).val) # 3
print(middle(from_list([1, 2, 3, 4])).val) # 3 (second middle)
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
public static ListNode middle(ListNode head) {
ListNode slow = head;
ListNode fast = head;
while (fast != null && fast.next != null) {
// YOUR CODE HERE: advance slow by 1, fast by 2
}
return slow;
}
public static void main(String[] args) {
System.out.println(middle(fromList(new int[]{1, 2, 3, 4, 5})).val); // 3
System.out.println(middle(fromList(new int[]{1, 2, 3, 4})).val); // 3
}
}
Hint: Each iteration, move slow = slow.next and fast = fast.next.next. When fast reaches the end, slow is at the middle (the second middle for even lengths).
Solution:
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def middle(head):
slow = head
fast = head
while fast and fast.next:
slow = slow.next
fast = fast.next.next
return slow
print(middle(from_list([1, 2, 3, 4, 5])).val) # 3
print(middle(from_list([1, 2, 3, 4])).val) # 3 (second middle)
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
public static ListNode middle(ListNode head) {
ListNode slow = head;
ListNode fast = head;
while (fast != null && fast.next != null) {
slow = slow.next;
fast = fast.next.next;
}
return slow;
}
public static void main(String[] args) {
System.out.println(middle(fromList(new int[]{1, 2, 3, 4, 5})).val); // 3
System.out.println(middle(fromList(new int[]{1, 2, 3, 4})).val); // 3
}
}
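A common variant returns the *first* middle node of an even-length list. That is the split point you want when cutting a list into two halves. A minimal sketch, assuming the same ListNode shape: starting `fast` one node ahead makes `slow` stop one node earlier on even lengths. The name `first_middle` is illustrative.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def first_middle(head):
    # For even lengths this returns the first of the two middle nodes.
    if head is None:
        return None
    slow, fast = head, head.next  # fast starts one step ahead
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    return slow

print(first_middle(from_list([1, 2, 3, 4])).val)     # 2
print(first_middle(from_list([1, 2, 3, 4, 5])).val)  # 3
```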
Merge Two Sorted Lists
Merge two sorted linked lists and return the sorted result. Use a dummy head node to simplify the logic.
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def to_list(head):
result = []
while head:
result.append(head.val)
head = head.next
return result
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def merge(l1, l2):
dummy = ListNode(0)
curr = dummy
while l1 and l2:
# YOUR CODE HERE: attach smaller node, advance that pointer
pass
# YOUR CODE HERE: attach remaining nodes
return dummy.next
print(to_list(merge(from_list([1, 3, 5]), from_list([2, 4, 6])))) # [1, 2, 3, 4, 5, 6]
print(to_list(merge(from_list([1, 2]), from_list([3, 4])))) # [1, 2, 3, 4]
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
static String toList(ListNode head) {
StringBuilder sb = new StringBuilder("[");
while (head != null) {
sb.append(head.val);
if (head.next != null) sb.append(", ");
head = head.next;
}
return sb.append("]").toString();
}
public static ListNode merge(ListNode l1, ListNode l2) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
while (l1 != null && l2 != null) {
// YOUR CODE HERE: attach smaller node, advance that pointer
}
// YOUR CODE HERE: attach remaining nodes
return dummy.next;
}
public static void main(String[] args) {
System.out.println(toList(merge(fromList(new int[]{1, 3, 5}), fromList(new int[]{2, 4, 6})))); // [1, 2, 3, 4, 5, 6]
System.out.println(toList(merge(fromList(new int[]{1, 2}), fromList(new int[]{3, 4})))); // [1, 2, 3, 4]
}
}
Hint: Compare l1.val and l2.val. Attach the smaller one to curr.next and advance that list's pointer. After the loop, attach whichever list still has nodes with curr.next = l1 or l2.
Solution:
class ListNode:
def __init__(self, val=0, next=None):
self.val = val
self.next = next
def to_list(head):
result = []
while head:
result.append(head.val)
head = head.next
return result
def from_list(arr):
dummy = ListNode(0)
curr = dummy
for v in arr:
curr.next = ListNode(v)
curr = curr.next
return dummy.next
def merge(l1, l2):
dummy = ListNode(0)
curr = dummy
while l1 and l2:
if l1.val <= l2.val:
curr.next = l1
l1 = l1.next
else:
curr.next = l2
l2 = l2.next
curr = curr.next
curr.next = l1 if l1 else l2
return dummy.next
print(to_list(merge(from_list([1, 3, 5]), from_list([2, 4, 6])))) # [1, 2, 3, 4, 5, 6]
print(to_list(merge(from_list([1, 2]), from_list([3, 4])))) # [1, 2, 3, 4]
public class Solution {
static class ListNode {
int val;
ListNode next;
ListNode(int val) { this.val = val; }
}
static ListNode fromList(int[] arr) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
return dummy.next;
}
static String toList(ListNode head) {
StringBuilder sb = new StringBuilder("[");
while (head != null) {
sb.append(head.val);
if (head.next != null) sb.append(", ");
head = head.next;
}
return sb.append("]").toString();
}
public static ListNode merge(ListNode l1, ListNode l2) {
ListNode dummy = new ListNode(0);
ListNode curr = dummy;
while (l1 != null && l2 != null) {
if (l1.val <= l2.val) {
curr.next = l1;
l1 = l1.next;
} else {
curr.next = l2;
l2 = l2.next;
}
curr = curr.next;
}
curr.next = (l1 != null) ? l1 : l2;
return dummy.next;
}
public static void main(String[] args) {
System.out.println(toList(merge(fromList(new int[]{1, 3, 5}), fromList(new int[]{2, 4, 6})))); // [1, 2, 3, 4, 5, 6]
System.out.println(toList(merge(fromList(new int[]{1, 2}), fromList(new int[]{3, 4})))); // [1, 2, 3, 4]
}
}
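This merge routine is the combine step of merge sort, which suits linked lists because neither splitting nor merging needs random access. Below is a sketch under that framing, assuming the ListNode shape used above; `sort_list` is an illustrative name. It finds the first middle with slow/fast pointers, cuts the list there, sorts each half recursively, and merges. Time is O(n log n) and recursion depth is O(log n).

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def merge(l1, l2):
    dummy = ListNode(0)  # dummy head avoids special-casing the first node
    curr = dummy
    while l1 and l2:
        if l1.val <= l2.val:
            curr.next = l1
            l1 = l1.next
        else:
            curr.next = l2
            l2 = l2.next
        curr = curr.next
    curr.next = l1 if l1 else l2
    return dummy.next

def sort_list(head):
    # Lists of length 0 or 1 are already sorted.
    if head is None or head.next is None:
        return head
    # Slow/fast pointers stop slow at the first middle node.
    slow, fast = head, head.next
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    second = slow.next
    slow.next = None  # cut the list into two halves
    return merge(sort_list(head), sort_list(second))

print(to_list(sort_list(from_list([4, 2, 1, 3]))))  # [1, 2, 3, 4]
```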
5. Stacks & Queues
Monotonic Stack
A monotonic stack maintains elements in increasing (or decreasing) order by popping elements that violate the invariant. The classic application is Next Greater Element: for each element, find the first element to its right that is strictly larger.
def next_greater_element(nums: list[int]) -> list[int]:
"""For each nums[i], find the next greater element to its right.
Returns -1 if none exists.
Time: O(n) Space: O(n)
Invariant: the stack holds indices of elements whose next-greater
has NOT yet been found, in DECREASING order of value.
"""
n = len(nums)
result = [-1] * n
stack: list[int] = [] # indices, monotonically decreasing values
for i in range(n):
# Current element is the "next greater" for everything it beats
while stack and nums[i] > nums[stack[-1]]:
idx = stack.pop()
result[idx] = nums[i]
stack.append(i)
return result
def largest_rectangle_histogram(heights: list[int]) -> int:
"""Largest rectangle in histogram.
Time: O(n) Space: O(n)
Stack holds indices of bars in non-decreasing height order.
When a shorter bar arrives, each taller bar popped is the limiting height of a
rectangle spanning from just after the new stack top up to index i - 1.
"""
stack: list[int] = []
max_area = 0
heights = heights + [0] # sentinel forces flush at end
for i, h in enumerate(heights):
while stack and heights[stack[-1]] > h:
idx = stack.pop()
width = i - (stack[-1] + 1 if stack else 0)
max_area = max(max_area, heights[idx] * width)
stack.append(i)
return max_area
def daily_temperatures(temps: list[int]) -> list[int]:
"""Days until a warmer temperature; 0 if none.
Classic monotonic stack application.
Time: O(n) Space: O(n)
"""
result = [0] * len(temps)
stack: list[int] = [] # indices, decreasing temperatures
for i, t in enumerate(temps):
while stack and t > temps[stack[-1]]:
idx = stack.pop()
result[idx] = i - idx
stack.append(i)
return result
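Flipping the comparison gives an increasing stack. As a sketch, here is the previous-smaller-element query: for each position, find the nearest value to its left that is strictly smaller, or -1 if there is none. The name `previous_smaller` is illustrative. Unlike the examples above, the answer is read from the stack top after popping rather than assigned to the popped indices.

```python
def previous_smaller(nums: list[int]) -> list[int]:
    """For each nums[i], the nearest strictly smaller value to its left, else -1.
    Time: O(n)  Space: O(n)
    Invariant: stack values are strictly increasing from bottom to top.
    """
    result = [-1] * len(nums)
    stack: list[int] = []  # values, strictly increasing bottom -> top
    for i, x in enumerate(nums):
        # A value >= x can never be the answer for a later element:
        # x is at least as small and closer.
        while stack and stack[-1] >= x:
            stack.pop()
        if stack:
            result[i] = stack[-1]
        stack.append(x)
    return result

print(previous_smaller([3, 7, 8, 4]))  # [-1, 3, 7, 3]
```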
Valid Parentheses
def is_valid(s: str) -> bool:
"""Check balanced brackets: ()[]{}
Time: O(n) Space: O(n)
"""
stack: list[str] = []
matching = {')': '(', ']': '[', '}': '{'}
for ch in s:
if ch in matching: # closing bracket
if not stack or stack[-1] != matching[ch]:
return False
stack.pop()
else: # opening bracket
stack.append(ch)
return len(stack) == 0
def min_remove_for_valid(s: str) -> str:
"""Remove minimum parentheses to make string valid.
Time: O(n) Space: O(n)
"""
stack: list[int] = [] # indices of unmatched '('
to_remove: set[int] = set()
for i, ch in enumerate(s):
if ch == '(':
stack.append(i)
elif ch == ')':
if stack:
stack.pop()
else:
to_remove.add(i) # unmatched ')'
to_remove.update(stack) # unmatched '('
return ''.join(ch for i, ch in enumerate(s) if i not in to_remove)
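When the string contains only one bracket type, the stack can shrink to a counter, because every unmatched '(' is interchangeable. With mixed types you still need the stack, as `is_valid` shows. A sketch computing the minimum number of insertions that balance such a string, in O(1) extra space; `min_add_to_make_valid` is an illustrative name.

```python
def min_add_to_make_valid(s: str) -> int:
    """Minimum '(' or ')' insertions to balance a string of parentheses.
    Time: O(n)  Space: O(1)
    """
    open_count = 0   # unmatched '(' seen so far
    insertions = 0   # ')' with no partner; each needs an inserted '('
    for ch in s:
        if ch == '(':
            open_count += 1
        elif ch == ')':
            if open_count:
                open_count -= 1
            else:
                insertions += 1
    # Every leftover '(' needs an inserted ')'.
    return insertions + open_count

print(min_add_to_make_valid("())"))  # 1
```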
Sliding Window Maximum (Monotonic Deque)
from collections import deque
def sliding_window_maximum(nums: list[int], k: int) -> list[int]:
"""Maximum of each window of size k.
Time: O(n) Space: O(k)
Deque stores indices in DECREASING order of value.
Front is always the index of the current window maximum.
"""
result: list[int] = []
dq: deque[int] = deque() # indices, values decreasing front->back
for i, x in enumerate(nums):
# Remove indices outside the window
while dq and dq[0] < i - k + 1:
dq.popleft()
# Maintain decreasing invariant — pop smaller elements from back
while dq and nums[dq[-1]] < x:
dq.pop()
dq.append(i)
if i >= k - 1: # window is fully formed
result.append(nums[dq[0]])
return result
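Reversing the comparison turns the deque into an increasing one whose front is the window minimum. A sketch of that mirror image; `sliding_window_minimum` is an illustrative name. It keeps the same O(n) bound, because each index is appended and popped at most once.

```python
from collections import deque

def sliding_window_minimum(nums: list[int], k: int) -> list[int]:
    """Minimum of each window of size k.  Time: O(n)  Space: O(k)"""
    result: list[int] = []
    dq: deque[int] = deque()  # indices, values increasing front -> back
    for i, x in enumerate(nums):
        # The front index leaves once it falls outside the window [i-k+1, i].
        if dq and dq[0] <= i - k:
            dq.popleft()
        # Larger-or-equal values at the back can never be a future minimum.
        while dq and nums[dq[-1]] >= x:
            dq.pop()
        dq.append(i)
        if i >= k - 1:  # window is fully formed
            result.append(nums[dq[0]])
    return result

print(sliding_window_minimum([1, 3, -1, -3, 5, 3, 6, 7], 3))  # [-1, -3, -3, -3, 3, 3]
```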
Min Stack
class MinStack:
"""Stack that supports O(1) getMin().
Strategy: pair each pushed value with the minimum at that point.
Time: O(1) all ops Space: O(n)
"""
def __init__(self):
self._stack: list[tuple[int, int]] = [] # (value, current_min)
def push(self, val: int) -> None:
current_min = min(val, self._stack[-1][1]) if self._stack else val
self._stack.append((val, current_min))
def pop(self) -> None:
self._stack.pop()
def top(self) -> int:
return self._stack[-1][0]
def getMin(self) -> int:
return self._stack[-1][1]
Valid Parentheses
Given a string containing only ()[]{}, return True if the brackets are correctly matched and nested.
def is_valid(s):
stack = []
matching = {')': '(', ']': '[', '}': '{'}
for ch in s:
if ch in '([{':
stack.append(ch)
else:
# YOUR CODE HERE: check stack and matching closer
pass
return len(stack) == 0
print(is_valid("({[]})")) # True
print(is_valid("()[]{}")) # True
print(is_valid("({[)")) # False
import java.util.*;
public class Solution {
public static boolean isValid(String s) {
Deque<Character> stack = new ArrayDeque<>();
for (char ch : s.toCharArray()) {
if (ch == '(' || ch == '[' || ch == '{') {
stack.push(ch);
} else {
// YOUR CODE HERE: check stack and matching closer
}
}
return stack.isEmpty();
}
public static void main(String[] args) {
System.out.println(isValid("({[]})")); // true
System.out.println(isValid("()[]{}")); // true
System.out.println(isValid("({[)")); // false
}
}
Hint: For each closing bracket, if the stack is empty or stack[-1] != matching[ch], return False. Otherwise stack.pop(). At the end, the stack must be empty.
Solution:
def is_valid(s):
stack = []
matching = {')': '(', ']': '[', '}': '{'}
for ch in s:
if ch in '([{':
stack.append(ch)
else:
if not stack or stack[-1] != matching[ch]:
return False
stack.pop()
return len(stack) == 0
print(is_valid("({[]})")) # True
print(is_valid("()[]{}")) # True
print(is_valid("({[)")) # False
import java.util.*;
public class Solution {
public static boolean isValid(String s) {
Deque<Character> stack = new ArrayDeque<>();
for (char ch : s.toCharArray()) {
if (ch == '(' || ch == '[' || ch == '{') {
stack.push(ch);
} else {
if (stack.isEmpty()) return false;
char top = stack.pop();
if (ch == ')' && top != '(') return false;
if (ch == ']' && top != '[') return false;
if (ch == '}' && top != '{') return false;
}
}
return stack.isEmpty();
}
public static void main(String[] args) {
System.out.println(isValid("({[]})")); // true
System.out.println(isValid("()[]{}")); // true
System.out.println(isValid("({[)")); // false
}
}
Min Stack
Implement a stack that supports push, pop, and get_min all in O(1) time. Track the running minimum with a second stack.
class MinStack:
def __init__(self):
self.stack = []
self.min_stack = []
def push(self, val):
self.stack.append(val)
# YOUR CODE HERE: push to min_stack
def pop(self):
self.stack.pop()
# YOUR CODE HERE: pop from min_stack
def get_min(self):
# YOUR CODE HERE: return current minimum
pass
ms = MinStack()
ms.push(5)
ms.push(2)
ms.push(7)
print(ms.get_min()) # 2
ms.pop()
print(ms.get_min()) # 2
ms.pop()
print(ms.get_min()) # 5
import java.util.*;
public class Solution {
static class MinStack {
private Deque<Integer> stack = new ArrayDeque<>();
private Deque<Integer> minStack = new ArrayDeque<>();
public void push(int val) {
stack.push(val);
// YOUR CODE HERE: push to minStack
}
public void pop() {
stack.pop();
// YOUR CODE HERE: pop from minStack
}
public int getMin() {
// YOUR CODE HERE: return current minimum
return 0;
}
}
public static void main(String[] args) {
MinStack ms = new MinStack();
ms.push(5);
ms.push(2);
ms.push(7);
System.out.println(ms.getMin()); // 2
ms.pop();
System.out.println(ms.getMin()); // 2
ms.pop();
System.out.println(ms.getMin()); // 5
}
}
Hint: When pushing, also push min(val, min_stack[-1]) to min_stack (or just val if min_stack is empty). When popping, pop from both stacks. get_min returns min_stack[-1].
Solution:
class MinStack:
def __init__(self):
self.stack = []
self.min_stack = []
def push(self, val):
self.stack.append(val)
current_min = val if not self.min_stack else min(val, self.min_stack[-1])
self.min_stack.append(current_min)
def pop(self):
self.stack.pop()
self.min_stack.pop()
def get_min(self):
return self.min_stack[-1]
ms = MinStack()
ms.push(5)
ms.push(2)
ms.push(7)
print(ms.get_min()) # 2
ms.pop()
print(ms.get_min()) # 2
ms.pop()
print(ms.get_min()) # 5
import java.util.*;
public class Solution {
static class MinStack {
private Deque<Integer> stack = new ArrayDeque<>();
private Deque<Integer> minStack = new ArrayDeque<>();
public void push(int val) {
stack.push(val);
int currentMin = minStack.isEmpty() ? val : Math.min(val, minStack.peek());
minStack.push(currentMin);
}
public void pop() {
stack.pop();
minStack.pop();
}
public int getMin() {
return minStack.peek();
}
}
public static void main(String[] args) {
MinStack ms = new MinStack();
ms.push(5);
ms.push(2);
ms.push(7);
System.out.println(ms.getMin()); // 2
ms.pop();
System.out.println(ms.getMin()); // 2
ms.pop();
System.out.println(ms.getMin()); // 5
}
}
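A common refinement pushes onto `min_stack` only when the new value is less than or equal to the current minimum. This saves space when most pushes are not new minimums, although the worst case (strictly decreasing pushes) is still O(n). A sketch; the class name `LeanMinStack` is illustrative. The `<=` matters: with `<`, a duplicated minimum would be removed from `min_stack` too early.

```python
class LeanMinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []  # only values that were a new (or tied) minimum

    def push(self, val):
        self.stack.append(val)
        if not self.min_stack or val <= self.min_stack[-1]:
            self.min_stack.append(val)

    def pop(self):
        val = self.stack.pop()
        # Pop min_stack only when removing the value it recorded.
        if val == self.min_stack[-1]:
            self.min_stack.pop()
        return val

    def get_min(self):
        return self.min_stack[-1]

ms = LeanMinStack()
for v in (5, 2, 7):
    ms.push(v)
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 5
```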
Next Greater Element
For each element in the array, find the next greater element to its right. If none exists, use -1. Use a monotonic stack for O(n) time.
def next_greater(nums):
result = [-1] * len(nums)
stack = [] # stores indices
for i, num in enumerate(nums):
# YOUR CODE HERE: while stack and nums[stack[-1]] < num, resolve stack top
stack.append(i)
return result
print(next_greater([4, 1, 2, 7])) # [7, 2, 7, -1]
print(next_greater([1, 3, 2, 4])) # [3, 4, 4, -1]
import java.util.*;
public class Solution {
public static int[] nextGreater(int[] nums) {
int[] result = new int[nums.length];
Arrays.fill(result, -1);
Deque<Integer> stack = new ArrayDeque<>(); // stores indices
for (int i = 0; i < nums.length; i++) {
// YOUR CODE HERE: while stack not empty and nums[stack.peek()] < nums[i], resolve
stack.push(i);
}
return result;
}
public static void main(String[] args) {
System.out.println(Arrays.toString(nextGreater(new int[]{4, 1, 2, 7}))); // [7, 2, 7, -1]
System.out.println(Arrays.toString(nextGreater(new int[]{1, 3, 2, 4}))); // [3, 4, 4, -1]
}
}
Hint: For each index i, pop all stack indices where nums[index] < nums[i]; the answer for those positions is nums[i]. Then push i.
Solution:
def next_greater(nums):
result = [-1] * len(nums)
stack = [] # stores indices
for i, num in enumerate(nums):
while stack and nums[stack[-1]] < num:
idx = stack.pop()
result[idx] = num
stack.append(i)
return result
print(next_greater([4, 1, 2, 7])) # [7, 2, 7, -1]
print(next_greater([1, 3, 2, 4])) # [3, 4, 4, -1]
import java.util.*;
public class Solution {
public static int[] nextGreater(int[] nums) {
int[] result = new int[nums.length];
Arrays.fill(result, -1);
Deque<Integer> stack = new ArrayDeque<>(); // stores indices
for (int i = 0; i < nums.length; i++) {
while (!stack.isEmpty() && nums[stack.peek()] < nums[i]) {
result[stack.pop()] = nums[i];
}
stack.push(i);
}
return result;
}
public static void main(String[] args) {
System.out.println(Arrays.toString(nextGreater(new int[]{4, 1, 2, 7}))); // [7, 2, 7, -1]
System.out.println(Arrays.toString(nextGreater(new int[]{1, 3, 2, 4}))); // [3, 4, 4, -1]
}
}
6. Binary Trees
class TreeNode:
def __init__(self, val: int = 0,
left: 'TreeNode | None' = None,
right: 'TreeNode | None' = None):
self.val = val
self.left = left
self.right = right
Traversals — Iterative & Recursive
from collections import deque
# ── Recursive (clean but uses O(h) call stack) ──────────────────────────────
def inorder_recursive(root: TreeNode | None) -> list[int]:
"""Left → Root → Right. Gives sorted order for BST."""
result: list[int] = []
def dfs(node: TreeNode | None) -> None:
if not node:
return
dfs(node.left)
result.append(node.val)
dfs(node.right)
dfs(root)
return result
def preorder_recursive(root: TreeNode | None) -> list[int]:
"""Root → Left → Right. Useful for tree copy / serialization."""
result: list[int] = []
def dfs(node: TreeNode | None) -> None:
if not node:
return
result.append(node.val)
dfs(node.left)
dfs(node.right)
dfs(root)
return result
def postorder_recursive(root: TreeNode | None) -> list[int]:
"""Left → Right → Root. Useful for deletion / bottom-up aggregation."""
result: list[int] = []
def dfs(node: TreeNode | None) -> None:
if not node:
return
dfs(node.left)
dfs(node.right)
result.append(node.val)
dfs(root)
return result
# ── Iterative (avoids stack overflow on deep trees) ─────────────────────────
def inorder_iterative(root: TreeNode | None) -> list[int]:
result: list[int] = []
stack: list[TreeNode] = []
curr = root
while curr or stack:
while curr: # go as far left as possible
stack.append(curr)
curr = curr.left
curr = stack.pop() # process node
result.append(curr.val)
curr = curr.right # move to right subtree
return result
def preorder_iterative(root: TreeNode | None) -> list[int]:
if not root:
return []
result: list[int] = []
stack = [root]
while stack:
node = stack.pop()
result.append(node.val)
if node.right: # push right first so left is processed first
stack.append(node.right)
if node.left:
stack.append(node.left)
return result
def postorder_iterative(root: TreeNode | None) -> list[int]:
"""Trick: reverse of modified preorder (Root→Right→Left reversed = Left→Right→Root)."""
if not root:
return []
result: list[int] = []
stack = [root]
while stack:
node = stack.pop()
result.append(node.val)
if node.left:
stack.append(node.left)
if node.right:
stack.append(node.right)
return result[::-1]
# ── BFS level-order ──────────────────────────────────────────────────────────
def level_order(root: TreeNode | None) -> list[list[int]]:
"""Returns list of levels. Time: O(n) Space: O(w) where w = max width."""
if not root:
return []
result: list[list[int]] = []
queue: deque[TreeNode] = deque([root])
while queue:
level_size = len(queue)
level: list[int] = []
for _ in range(level_size):
node = queue.popleft()
level.append(node.val)
if node.left:
queue.append(node.left)
if node.right:
queue.append(node.right)
result.append(level)
return result
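To make the four orders concrete, the sketch below runs compact recursive versions of each traversal on the sample tree used in the exercises later in this section. In that tree, 3 has children 9 and 20, and 20 has children 15 and 7. The helper names are illustrative, and the expected lists follow directly from the definitions above.

```python
from collections import deque

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(n):
    return inorder(n.left) + [n.val] + inorder(n.right) if n else []

def preorder(n):
    return [n.val] + preorder(n.left) + preorder(n.right) if n else []

def postorder(n):
    return postorder(n.left) + postorder(n.right) + [n.val] if n else []

def levels(root):
    out, queue = [], deque([root] if root else [])
    while queue:
        level = []
        # range(len(queue)) is fixed up front, so only this level is consumed.
        for _ in range(len(queue)):
            node = queue.popleft()
            level.append(node.val)
            queue.extend(c for c in (node.left, node.right) if c)
        out.append(level)
    return out

root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(inorder(root))    # [9, 3, 15, 20, 7]
print(preorder(root))   # [3, 9, 20, 15, 7]
print(postorder(root))  # [9, 15, 7, 20, 3]
print(levels(root))     # [[3], [9, 20], [15, 7]]
```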
DFS Patterns: Top-Down vs. Bottom-Up
# ── Top-Down: pass information from parent to children ──────────────────────
def max_depth_top_down(root: TreeNode | None) -> int:
"""Classic top-down: carry current depth, update global max."""
result = [0]
def dfs(node: TreeNode | None, depth: int) -> None:
if not node:
return
result[0] = max(result[0], depth)
dfs(node.left, depth + 1)
dfs(node.right, depth + 1)
dfs(root, 1)
return result[0]
def has_path_sum(root: TreeNode | None, target: int) -> bool:
"""Does any root-to-leaf path sum to target?"""
def dfs(node: TreeNode | None, remaining: int) -> bool:
if not node:
return False
remaining -= node.val
if not node.left and not node.right: # leaf
return remaining == 0
return dfs(node.left, remaining) or dfs(node.right, remaining)
return dfs(root, target)
# ── Bottom-Up: aggregate from children, return to parent ────────────────────
def max_depth_bottom_up(root: TreeNode | None) -> int:
"""Bottom-up is usually cleaner for aggregation problems."""
if not root:
return 0
return 1 + max(max_depth_bottom_up(root.left),
max_depth_bottom_up(root.right))
def diameter(root: TreeNode | None) -> int:
"""Longest path between any two nodes, counted in edges (may not pass through root).
Time: O(n) Space: O(h)
Key insight: at each node, diameter through it = left_depth + right_depth.
"""
best = [0]
def depth(node: TreeNode | None) -> int:
if not node:
return 0
left = depth(node.left)
right = depth(node.right)
best[0] = max(best[0], left + right) # update diameter
return 1 + max(left, right) # return depth to parent
depth(root)
return best[0]
def max_path_sum(root: TreeNode | None) -> int:
"""Maximum path sum between any two nodes. Values can be negative.
Time: O(n) Space: O(h)
"""
best = [float('-inf')]
def gain(node: TreeNode | None) -> int:
if not node:
return 0
left = max(gain(node.left), 0) # ignore negative branches
right = max(gain(node.right), 0)
best[0] = max(best[0], node.val + left + right)
return node.val + max(left, right) # can only extend in one direction
gain(root)
return best[0]
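Checking whether a tree is height-balanced (at every node, subtree heights differ by at most 1) is another bottom-up aggregation. In the sketch below, `is_balanced` is an illustrative name. The helper returns the height, or -1 as a sentinel meaning "already unbalanced". The whole check is therefore one O(n) pass instead of recomputing heights at every node.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_balanced(root) -> bool:
    def height(node) -> int:
        if node is None:
            return 0
        left = height(node.left)
        if left == -1:
            return -1  # propagate failure upward without further work
        right = height(node.right)
        if right == -1 or abs(left - right) > 1:
            return -1
        return 1 + max(left, right)
    return height(root) != -1

balanced = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
skewed = TreeNode(1, TreeNode(2, TreeNode(3)))
print(is_balanced(balanced))  # True
print(is_balanced(skewed))    # False
```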
Lowest Common Ancestor (LCA)
def lca(root: TreeNode | None,
p: TreeNode, q: TreeNode) -> TreeNode | None:
"""LCA for general binary tree (not necessarily BST).
Time: O(n) Space: O(h)
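Logic: if the current node is p or q, return it. Either the other target lies
below it (so this node is the LCA) or an ancestor will combine it with the other side.
Otherwise the LCA is the node where both left and right recursion return non-None.
Assumes both p and q are present in the tree.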
"""
if not root or root is p or root is q:
return root
left = lca(root.left, p, q)
right = lca(root.right, p, q)
if left and right:
return root # p and q are in different subtrees
return left or right
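When the tree is a BST (section 7), the ordering removes the need for a full search. Walk down from the root until p and q fall on different sides, or one of them equals the current node. A sketch under that assumption, with both values present; `lca_bst` is an illustrative name. Time is O(h) and extra space is O(1).

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def lca_bst(root, p, q):
    node = root
    while node:
        if p.val < node.val and q.val < node.val:
            node = node.left   # both targets are in the left subtree
        elif p.val > node.val and q.val > node.val:
            node = node.right  # both targets are in the right subtree
        else:
            return node        # split point, or node equals p or q
    return None

# BST: 6 -> (2 -> (0, 4), 8 -> (7, 9))
n0, n4, n7, n9 = TreeNode(0), TreeNode(4), TreeNode(7), TreeNode(9)
n2, n8 = TreeNode(2, n0, n4), TreeNode(8, n7, n9)
root = TreeNode(6, n2, n8)
print(lca_bst(root, n2, n8).val)  # 6
print(lca_bst(root, n0, n4).val)  # 2
print(lca_bst(root, n2, n4).val)  # 2
```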
Maximum Depth of Binary Tree
Return the maximum depth of a binary tree. Depth is the number of nodes along the longest path from root to leaf. Fill in the recursive body.
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def max_depth(root):
if root is None:
return ___
left = max_depth(root.left)
right = max_depth(root.right)
return ___ + max(left, right)
root = TreeNode(3,
TreeNode(9),
TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth(root)) # 3
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static int maxDepth(TreeNode root) {
if (root == null) return ___;
int left = maxDepth(root.left);
int right = maxDepth(root.right);
return ___ + Math.max(left, right);
}
public static void main(String[] args) {
TreeNode root = new TreeNode(3,
new TreeNode(9),
new TreeNode(20, new TreeNode(15), new TreeNode(7)));
System.out.println(maxDepth(root)); // 3
}
}
Hint: Base case: None (or null) returns 0. For any other node, return 1 + max(leftDepth, rightDepth).
Solution:
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def max_depth(root):
if root is None:
return 0
left = max_depth(root.left)
right = max_depth(root.right)
return 1 + max(left, right)
root = TreeNode(3,
TreeNode(9),
TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth(root)) # 3
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static int maxDepth(TreeNode root) {
if (root == null) return 0;
int left = maxDepth(root.left);
int right = maxDepth(root.right);
return 1 + Math.max(left, right);
}
public static void main(String[] args) {
TreeNode root = new TreeNode(3,
new TreeNode(9),
new TreeNode(20, new TreeNode(15), new TreeNode(7)));
System.out.println(maxDepth(root)); // 3
}
}
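The recursive solution uses O(h) stack space, so on a degenerate tree it can exceed Python's default recursion limit. A sketch of an iterative alternative that counts BFS levels instead; the name `max_depth_bfs` is illustrative.

```python
from collections import deque

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth_bfs(root):
    if root is None:
        return 0
    depth = 0
    queue = deque([root])
    while queue:
        depth += 1  # each outer iteration consumes exactly one level
        for _ in range(len(queue)):
            node = queue.popleft()
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
    return depth

root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth_bfs(root))  # 3

# A 2000-node left-leaning chain: too deep for the default recursion limit.
chain = None
for v in range(2000):
    chain = TreeNode(v, chain)
print(max_depth_bfs(chain))  # 2000
```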
Invert Binary Tree
Mirror a binary tree by swapping left and right children at every node. Fill in the swap and the recursive calls.
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def invert(root):
if root is None:
return None
root.left, root.right = ___, ___
return root
def inorder(node, result=None):
if result is None:
result = []
if node:
inorder(node.left, result)
result.append(node.val)
inorder(node.right, result)
return result
root = TreeNode(3,
TreeNode(9),
TreeNode(20, TreeNode(15), TreeNode(7)))
invert(root)
print(inorder(root)) # [7, 20, 15, 3, 9]
import java.util.*;
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static TreeNode invert(TreeNode root) {
if (root == null) return null;
TreeNode tmp = root.left;
root.left = ___;
root.right = ___;
invert(root.left);
invert(root.right);
return root;
}
static List<Integer> inorder(TreeNode node) {
List<Integer> res = new ArrayList<>();
inorderHelper(node, res);
return res;
}
static void inorderHelper(TreeNode node, List<Integer> res) {
if (node == null) return;
inorderHelper(node.left, res);
res.add(node.val);
inorderHelper(node.right, res);
}
public static void main(String[] args) {
TreeNode root = new TreeNode(3,
new TreeNode(9),
new TreeNode(20, new TreeNode(15), new TreeNode(7)));
invert(root);
System.out.println(inorder(root)); // [7, 20, 15, 3, 9]
}
}
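Hint: Swap the children and recurse in one step: root.left, root.right = invert(root.right), invert(root.left). Python evaluates the whole right-hand side before assigning, so no temporary is needed (Java needs tmp).
Solution:
class TreeNode: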
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def invert(root):
if root is None:
return None
root.left, root.right = invert(root.right), invert(root.left)
return root
def inorder(node, result=None):
if result is None:
result = []
if node:
inorder(node.left, result)
result.append(node.val)
inorder(node.right, result)
return result
root = TreeNode(3,
TreeNode(9),
TreeNode(20, TreeNode(15), TreeNode(7)))
invert(root)
print(inorder(root)) # [7, 20, 15, 3, 9]
import java.util.*;
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static TreeNode invert(TreeNode root) {
if (root == null) return null;
TreeNode tmp = root.left;
root.left = invert(root.right);
root.right = invert(tmp);
return root;
}
static List<Integer> inorder(TreeNode node) {
List<Integer> res = new ArrayList<>();
inorderHelper(node, res);
return res;
}
static void inorderHelper(TreeNode node, List<Integer> res) {
if (node == null) return;
inorderHelper(node.left, res);
res.add(node.val);
inorderHelper(node.right, res);
}
public static void main(String[] args) {
TreeNode root = new TreeNode(3,
new TreeNode(9),
new TreeNode(20, new TreeNode(15), new TreeNode(7)));
invert(root);
System.out.println(inorder(root)); // [7, 20, 15, 3, 9]
}
}
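Inversion does not depend on visit order, so any traversal that reaches every node once will do. A sketch of an iterative version that uses an explicit stack and so avoids recursion-depth limits; `invert_iterative` is an illustrative name.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def invert_iterative(root):
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        # Swap this node's children; the processing order does not matter.
        node.left, node.right = node.right, node.left
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    return root

def inorder(node):
    return inorder(node.left) + [node.val] + inorder(node.right) if node else []

root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(inorder(invert_iterative(root)))  # [7, 20, 15, 3, 9]
```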
Same Tree
Check whether two binary trees are structurally identical with the same node values. Fill in the three conditions.
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def is_same(p, q):
if p is None and q is None:
return ___
if p is None or q is None:
return ___
if p.val != q.val:
return ___
return is_same(p.left, q.left) and is_same(p.right, q.right)
t1 = TreeNode(1, TreeNode(2), TreeNode(3))
t2 = TreeNode(1, TreeNode(2), TreeNode(3))
t3 = TreeNode(1, TreeNode(2), TreeNode(4))
print(is_same(t1, t2)) # True
print(is_same(t1, t3)) # False
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static boolean isSame(TreeNode p, TreeNode q) {
if (p == null && q == null) return ___;
if (p == null || q == null) return ___;
if (p.val != q.val) return ___;
return isSame(p.left, q.left) && isSame(p.right, q.right);
}
public static void main(String[] args) {
TreeNode t1 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
TreeNode t2 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
TreeNode t3 = new TreeNode(1, new TreeNode(2), new TreeNode(4));
System.out.println(isSame(t1, t2)); // true
System.out.println(isSame(t1, t3)); // false
}
}
Hint: Both null → True. Exactly one null → False. Values differ → False. Otherwise recurse into both subtrees.
Solution:
class TreeNode:
def __init__(self, val=0, left=None, right=None):
self.val = val
self.left = left
self.right = right
def is_same(p, q):
if p is None and q is None:
return True
if p is None or q is None:
return False
if p.val != q.val:
return False
return is_same(p.left, q.left) and is_same(p.right, q.right)
t1 = TreeNode(1, TreeNode(2), TreeNode(3))
t2 = TreeNode(1, TreeNode(2), TreeNode(3))
t3 = TreeNode(1, TreeNode(2), TreeNode(4))
print(is_same(t1, t2)) # True
print(is_same(t1, t3)) # False
public class Solution {
static class TreeNode {
int val;
TreeNode left, right;
TreeNode(int v) { val = v; }
TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
}
static boolean isSame(TreeNode p, TreeNode q) {
if (p == null && q == null) return true;
if (p == null || q == null) return false;
if (p.val != q.val) return false;
return isSame(p.left, q.left) && isSame(p.right, q.right);
}
public static void main(String[] args) {
TreeNode t1 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
TreeNode t2 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
TreeNode t3 = new TreeNode(1, new TreeNode(2), new TreeNode(4));
System.out.println(isSame(t1, t2)); // true
System.out.println(isSame(t1, t3)); // false
}
}
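The same three checks, with the recursive arguments crossed, test whether a tree is symmetric: its left subtree must mirror its right. A sketch; `is_mirror` and `is_symmetric` are illustrative names.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_mirror(a, b):
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    # Outer children must match each other, and so must inner children.
    return (a.val == b.val and
            is_mirror(a.left, b.right) and
            is_mirror(a.right, b.left))

def is_symmetric(root):
    return root is None or is_mirror(root.left, root.right)

sym = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)),
                  TreeNode(2, TreeNode(4), TreeNode(3)))
print(is_symmetric(sym))                                    # True
print(is_symmetric(TreeNode(1, TreeNode(2), TreeNode(3))))  # False
```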
7. Binary Search Trees
A BST satisfies: for every node, all left descendants < node.val < all right descendants. This gives O(h) search/insert/delete — O(log n) when balanced, O(n) when degenerate.
BST Operations
def search_bst(root: TreeNode | None, val: int) -> TreeNode | None:
    """Time: O(h) Space: O(1) iterative, O(h) recursive."""
    while root:
        if val == root.val:
            return root
        root = root.left if val < root.val else root.right
    return None
def insert_bst(root: TreeNode | None, val: int) -> TreeNode:
    """Time: O(h) Space: O(h) recursive."""
    if not root:
        return TreeNode(val)
    if val < root.val:
        root.left = insert_bst(root.left, val)
    elif val > root.val:
        root.right = insert_bst(root.right, val)
    return root # val == root.val: duplicate, no insert
def delete_bst(root: TreeNode | None, val: int) -> TreeNode | None:
    """Time: O(h) Space: O(h)
    Three cases:
    1. Node has no children: just remove it.
    2. Node has one child: replace with child.
    3. Node has two children: replace val with inorder successor
       (smallest in right subtree), then delete successor.
    """
    if not root:
        return None
    if val < root.val:
        root.left = delete_bst(root.left, val)
    elif val > root.val:
        root.right = delete_bst(root.right, val)
    else:
        if not root.left:
            return root.right
        if not root.right:
            return root.left
        # Find inorder successor (min of right subtree)
        successor = root.right
        while successor.left:
            successor = successor.left
        root.val = successor.val
        root.right = delete_bst(root.right, successor.val)
    return root
BST Validation
def is_valid_bst(root: TreeNode | None) -> bool:
    """Validate using min/max bounds passed down the tree.
    Time: O(n) Space: O(h)
    Common mistake: only checking parent-child, missing grandparent constraint.
    """
    def validate(node: TreeNode | None,
                 low: float, high: float) -> bool:
        if not node:
            return True
        if node.val <= low or node.val >= high:
            return False
        return (validate(node.left, low, node.val) and
                validate(node.right, node.val, high))
    return validate(root, float('-inf'), float('inf'))
def kth_smallest_bst(root: TreeNode | None, k: int) -> int:
    """Inorder traversal gives sorted order; return kth element.
    Time: O(h + k) Space: O(h)
    """
    stack: list[TreeNode] = []
    curr = root
    count = 0
    while curr or stack:
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        count += 1
        if count == k:
            return curr.val
        curr = curr.right
    return -1 # k out of range
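The "common mistake" in the validation docstring is easiest to see with a concrete counterexample. The sketch below contrasts a deliberately wrong checker with the bounds-based one. The wrong checker is the hypothetical `is_valid_naive`, which compares each node only with its direct children. The tree puts 4 in the right subtree of 5. Every parent-child pair looks fine, but 4 < 5 breaks the BST property.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_valid_naive(node) -> bool:
    """WRONG on purpose: only compares a node with its direct children."""
    if node is None:
        return True
    if node.left and node.left.val >= node.val:
        return False
    if node.right and node.right.val <= node.val:
        return False
    return is_valid_naive(node.left) and is_valid_naive(node.right)


def is_valid_bounds(node, low=float('-inf'), high=float('inf')) -> bool:
    """Correct: every node must lie strictly inside (low, high)."""
    if node is None:
        return True
    if not (low < node.val < high):
        return False
    return (is_valid_bounds(node.left, low, node.val) and
            is_valid_bounds(node.right, node.val, high))


#       5
#      / \
#     1   6
#        / \
#       4   7     <- 4 is in 5's right subtree, but 4 < 5
tree = TreeNode(5, TreeNode(1), TreeNode(6, TreeNode(4), TreeNode(7)))
print(is_valid_naive(tree))   # True  (misses the grandparent constraint)
print(is_valid_bounds(tree))  # False
```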
Balanced BST from Sorted Array
def sorted_array_to_bst(nums: list[int]) -> TreeNode | None:
    """Build height-balanced BST from sorted array.
    Choose middle element as root at each step.
    Time: O(n log n), because each recursion level copies O(n) elements
    into slices (passing lo/hi indices instead gives O(n)).
    Space: O(n) for the slices alive along one recursion path,
    plus O(log n) recursion stack.
    """
    if not nums:
        return None
    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    root.left = sorted_array_to_bst(nums[:mid])
    root.right = sorted_array_to_bst(nums[mid + 1:])
    return root
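Slicing copies elements at every level. A common alternative passes index bounds into the same array instead. The sketch below is such a variant. The names `sorted_array_to_bst_idx` and `build` are illustrative. It visits each element once, so it runs in O(n) time with O(log n) extra stack. It picks the lower middle `(lo + hi) // 2`, so for even-length ranges its root can differ from the slicing version's, but the tree is still height-balanced.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sorted_array_to_bst_idx(nums: list[int]) -> "TreeNode | None":
    """Height-balanced BST from sorted nums, using index bounds (no slicing)."""
    def build(lo: int, hi: int) -> "TreeNode | None":
        if lo > hi:                # empty range
            return None
        mid = (lo + hi) // 2       # middle element becomes the subtree root
        node = TreeNode(nums[mid])
        node.left = build(lo, mid - 1)
        node.right = build(mid + 1, hi)
        return node
    return build(0, len(nums) - 1)


def inorder(node, out: list[int]) -> list[int]:
    if node:
        inorder(node.left, out)
        out.append(node.val)
        inorder(node.right, out)
    return out


root = sorted_array_to_bst_idx([-10, -3, 0, 5, 9])
print(inorder(root, []))  # [-10, -3, 0, 5, 9]
print(root.val)           # 0
```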
| Operation | Balanced BST | Unbalanced BST | Sorted Array |
|---|---|---|---|
| Search | O(log n) | O(n) | O(log n) |
| Insert | O(log n) | O(n) | O(n) |
| Delete | O(log n) | O(n) | O(n) |
| Min/Max | O(log n) | O(n) | O(1) |
| In-order | O(n) | O(n) | O(n) |
Validate BST
Return True if the tree is a valid BST. Every node must satisfy: all values in its left subtree are strictly less, all values in its right subtree are strictly greater. Fill in the bounds checks.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def is_valid_bst(root, lo=float('-inf'), hi=float('inf')):
    if root is None:
        return True
    if root.val <= lo or root.val >= hi:
        return ___
    return (is_valid_bst(root.left, ___, root.val) and
            is_valid_bst(root.right, root.val, ___))
valid = TreeNode(2, TreeNode(1), TreeNode(3))
invalid = TreeNode(1, TreeNode(2), TreeNode(3))
print(is_valid_bst(valid)) # True
print(is_valid_bst(invalid)) # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }
    static boolean validate(TreeNode node, long lo, long hi) {
        if (node == null) return true;
        if (node.val <= lo || node.val >= hi) return ___;
        return validate(node.left, ___, node.val) &&
               validate(node.right, node.val, ___);
    }
    static boolean isValidBST(TreeNode root) {
        return validate(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }
    public static void main(String[] args) {
        TreeNode valid = new TreeNode(2, new TreeNode(1), new TreeNode(3));
        TreeNode invalid = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        System.out.println(isValidBST(valid)); // true
        System.out.println(isValidBST(invalid)); // false
    }
}
The bounds lo and hi are passed down the recursion. When going left, the current node's value becomes the new upper bound (hi); when going right, it becomes the new lower bound (lo).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def is_valid_bst(root, lo=float('-inf'), hi=float('inf')):
    if root is None:
        return True
    if root.val <= lo or root.val >= hi:
        return False
    return (is_valid_bst(root.left, lo, root.val) and
            is_valid_bst(root.right, root.val, hi))
valid = TreeNode(2, TreeNode(1), TreeNode(3))
invalid = TreeNode(1, TreeNode(2), TreeNode(3))
print(is_valid_bst(valid)) # True
print(is_valid_bst(invalid)) # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }
    static boolean validate(TreeNode node, long lo, long hi) {
        if (node == null) return true;
        if (node.val <= lo || node.val >= hi) return false;
        return validate(node.left, lo, node.val) &&
               validate(node.right, node.val, hi);
    }
    static boolean isValidBST(TreeNode root) {
        return validate(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }
    public static void main(String[] args) {
        TreeNode valid = new TreeNode(2, new TreeNode(1), new TreeNode(3));
        TreeNode invalid = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        System.out.println(isValidBST(valid)); // true
        System.out.println(isValidBST(invalid)); // false
    }
}
Search in BST
Search the BST for a target value. Return the node's value if found, or -1 if not. Fill in the directional recursive calls.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def search_bst(root, target):
    if root is None:
        return -1
    if target == root.val:
        return root.val
    if target < root.val:
        return search_bst(___, target)
    return search_bst(___, target)
root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(7))
print(search_bst(root, 2)) # 2
print(search_bst(root, 5)) # -1
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }
    static int searchBST(TreeNode root, int target) {
        if (root == null) return -1;
        if (target == root.val) return root.val;
        if (target < root.val) return searchBST(___, target);
        return searchBST(___, target);
    }
    public static void main(String[] args) {
        TreeNode root = new TreeNode(4,
            new TreeNode(2, new TreeNode(1), new TreeNode(3)),
            new TreeNode(7));
        System.out.println(searchBST(root, 2)); // 2
        System.out.println(searchBST(root, 5)); // -1
    }
}
If the target is smaller than root.val, go left (root.left). If it is greater, go right (root.right). The BST property guarantees that only one path needs to be checked.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def search_bst(root, target):
    if root is None:
        return -1
    if target == root.val:
        return root.val
    if target < root.val:
        return search_bst(root.left, target)
    return search_bst(root.right, target)
root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(7))
print(search_bst(root, 2)) # 2
print(search_bst(root, 5)) # -1
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }
    static int searchBST(TreeNode root, int target) {
        if (root == null) return -1;
        if (target == root.val) return root.val;
        if (target < root.val) return searchBST(root.left, target);
        return searchBST(root.right, target);
    }
    public static void main(String[] args) {
        TreeNode root = new TreeNode(4,
            new TreeNode(2, new TreeNode(1), new TreeNode(3)),
            new TreeNode(7));
        System.out.println(searchBST(root, 2)); // 2
        System.out.println(searchBST(root, 5)); // -1
    }
}
Sorted Array to Height-Balanced BST
Convert a sorted array into a height-balanced BST. An inorder traversal of the result should reproduce the original array. Fill in the mid-point selection and recursive calls.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def sorted_array_to_bst(nums):
    if not nums:
        return None
    mid = len(nums) ___
    node = TreeNode(nums[mid])
    node.left = sorted_array_to_bst(nums[___])
    node.right = sorted_array_to_bst(nums[___])
    return node
def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result
root = sorted_array_to_bst([-10, -3, 0, 5, 9])
print(inorder(root)) # [-10, -3, 0, 5, 9]
import java.util.*;
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
    }
    static TreeNode sortedArrayToBST(int[] nums, int lo, int hi) {
        if (lo > hi) return null;
        int mid = lo + (hi - lo) ___;
        TreeNode node = new TreeNode(nums[mid]);
        node.left = sortedArrayToBST(nums, ___, mid - 1);
        node.right = sortedArrayToBST(nums, mid + 1, ___);
        return node;
    }
    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }
    public static void main(String[] args) {
        int[] nums = {-10, -3, 0, 5, 9};
        TreeNode root = sortedArrayToBST(nums, 0, nums.length - 1);
        System.out.println(inorder(root)); // [-10, -3, 0, 5, 9]
    }
}
Pick the middle element (mid = len(nums) // 2) as the root. Recurse on nums[:mid] for the left subtree and nums[mid+1:] for the right. Splitting at the middle keeps the two subtree sizes within one of each other, so the tree stays balanced.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def sorted_array_to_bst(nums):
    if not nums:
        return None
    mid = len(nums) // 2
    node = TreeNode(nums[mid])
    node.left = sorted_array_to_bst(nums[:mid])
    node.right = sorted_array_to_bst(nums[mid+1:])
    return node
def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result
root = sorted_array_to_bst([-10, -3, 0, 5, 9])
print(inorder(root)) # [-10, -3, 0, 5, 9]
import java.util.*;
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
    }
    static TreeNode sortedArrayToBST(int[] nums, int lo, int hi) {
        if (lo > hi) return null;
        int mid = lo + (hi - lo) / 2;
        TreeNode node = new TreeNode(nums[mid]);
        node.left = sortedArrayToBST(nums, lo, mid - 1);
        node.right = sortedArrayToBST(nums, mid + 1, hi);
        return node;
    }
    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }
    public static void main(String[] args) {
        int[] nums = {-10, -3, 0, 5, 9};
        TreeNode root = sortedArrayToBST(nums, 0, nums.length - 1);
        System.out.println(inorder(root)); // [-10, -3, 0, 5, 9]
    }
}
8. Heaps / Priority Queues
A min-heap guarantees the smallest element is always at the top (index 0). Python's heapq is a min-heap. For a max-heap, negate values on push/pop.
import heapq
# Min-heap basics
heap: list[int] = []
heapq.heappush(heap, 5)
heapq.heappush(heap, 1)
heapq.heappush(heap, 3)
print(heap[0]) # 1 — peek min, O(1)
smallest = heapq.heappop(heap) # 1, O(log n)
# Build from list in O(n)
nums = [3, 1, 4, 1, 5, 9]
heapq.heapify(nums) # transforms in-place
# Max-heap: negate values
max_heap: list[int] = []
heapq.heappush(max_heap, -5)
heapq.heappush(max_heap, -1)
largest = -heapq.heappop(max_heap) # 5
# Heap with tuples (priority, value)
tasks: list[tuple[int, str]] = []
heapq.heappush(tasks, (1, "low priority"))
heapq.heappush(tasks, (0, "urgent"))
priority, task = heapq.heappop(tasks) # (0, "urgent")
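It helps to see what heapq maintains under the hood. The list is a complete binary tree stored level by level: the children of index i sit at 2i+1 and 2i+2, and every parent is <= its children. That is why the minimum is always at index 0. The sketch below hand-rolls the core operations. The names `sift_down`, `build_min_heap`, `pop_min`, and `is_min_heap` are illustrative. It is not heapq's actual source, whose internal sifting strategy differs. Building bottom-up is O(n) because most nodes sit near the leaves and sift down only a short distance.

```python
def sift_down(a: list[int], i: int, n: int) -> None:
    """Move a[i] down until both children are >= it (min-heap order)."""
    while True:
        smallest = i
        left, right = 2 * i + 1, 2 * i + 2
        if left < n and a[left] < a[smallest]:
            smallest = left
        if right < n and a[right] < a[smallest]:
            smallest = right
        if smallest == i:
            return
        a[i], a[smallest] = a[smallest], a[i]
        i = smallest


def build_min_heap(a: list[int]) -> None:
    """Bottom-up heapify: sift down each non-leaf, last one first. O(n) total."""
    for i in range(len(a) // 2 - 1, -1, -1):
        sift_down(a, i, len(a))


def pop_min(a: list[int]) -> int:
    """Remove and return a[0]: move the last leaf to the root, then sift it down."""
    last = a.pop()
    if not a:
        return last
    top, a[0] = a[0], last
    sift_down(a, 0, len(a))
    return top


def is_min_heap(a: list[int]) -> bool:
    """Heap invariant: every parent a[(i - 1) // 2] <= child a[i]."""
    return all(a[(i - 1) // 2] <= a[i] for i in range(1, len(a)))


nums = [3, 1, 4, 1, 5, 9]
build_min_heap(nums)
print(is_min_heap(nums), nums[0])                 # True 1
print([pop_min(nums) for _ in range(len(nums))])  # [1, 1, 3, 4, 5, 9]
```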
Top-K Pattern
def kth_largest(nums: list[int], k: int) -> int:
    """Kth largest element in unsorted array.
    Min-heap of size k: the top is always the kth largest so far.
    Time: O(n log k) Space: O(k)
    Alternative: quickselect, O(n) average time (O(n^2) worst case).
    """
    heap: list[int] = []
    for x in nums:
        heapq.heappush(heap, x)
        if len(heap) > k:
            heapq.heappop(heap) # evict the smallest; heap keeps the k largest
    return heap[0] # smallest of the k largest = kth largest
def top_k_elements(nums: list[int], k: int) -> list[int]:
    """Return k largest elements. Time: O(n log k) Space: O(k)"""
    return heapq.nlargest(k, nums) # built-in uses heap internally
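The kth_largest docstring mentions quickselect as an alternative. Here is a minimal sketch of it, assuming 1 <= k <= len(nums). It uses a random pivot with a Lomuto partition and is named `quickselect_kth_largest` for illustration. Each partition places the pivot at its final sorted index. The search then continues only on the side containing index len(nums) - k. That gives O(n) average time but O(n^2) in the worst case.

```python
import random


def quickselect_kth_largest(nums: list[int], k: int) -> int:
    """Kth largest via quickselect. Assumes 1 <= k <= len(nums)."""
    a = list(nums)              # work on a copy; the input is left untouched
    target = len(a) - k         # kth largest sits at this index in ascending order
    lo, hi = 0, len(a) - 1
    while True:
        # Random pivot, swapped to the end for a Lomuto partition
        p = random.randint(lo, hi)
        a[p], a[hi] = a[hi], a[p]
        pivot = a[hi]
        store = lo
        for i in range(lo, hi):
            if a[i] < pivot:    # smaller elements move to the front
                a[store], a[i] = a[i], a[store]
                store += 1
        a[store], a[hi] = a[hi], a[store]  # pivot lands at its sorted index
        if store == target:
            return a[store]
        if store < target:
            lo = store + 1      # answer is right of the pivot
        else:
            hi = store - 1      # answer is left of the pivot


print(quickselect_kth_largest([3, 2, 1, 5, 6, 4], 2))           # 5
print(quickselect_kth_largest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
```

The heap version is usually preferred when k is small or the data arrives as a stream. Quickselect needs the whole array in memory, and it mutates a copy of it here.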
Merge K Sorted Lists
def merge_k_sorted(lists: list[ListNode | None]) -> ListNode | None:
    """Merge k sorted linked lists into one sorted list.
    Time: O(n log k) where n = total nodes, k = number of lists.
    Space: O(k) for the heap.
    """
    dummy = ListNode(0)
    curr = dummy
    heap: list[tuple[int, int, ListNode]] = []
    # Initialize heap with head of each list
    # Use list index as tiebreaker to avoid comparing ListNodes
    for i, node in enumerate(lists):
        if node:
            heapq.heappush(heap, (node.val, i, node))
    while heap:
        val, i, node = heapq.heappop(heap)
        curr.next = node
        curr = curr.next
        if node.next:
            heapq.heappush(heap, (node.next.val, i, node.next))
    return dummy.next
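The same k-way merge works on plain Python lists. The sketch below uses a hypothetical `merge_k_lists`. Each heap entry is (value, list_index, element_index), and the two indices break ties and say where to fetch the next element. For comparison, the standard library's `heapq.merge` performs a lazy k-way merge of sorted iterables.

```python
import heapq


def merge_k_lists(lists: list[list[int]]) -> list[int]:
    """Heap-based k-way merge of sorted Python lists. O(n log k) time."""
    # One entry per non-empty list: (value, which list, position in that list)
    heap = [(lst[0], i, 0) for i, lst in enumerate(lists) if lst]
    heapq.heapify(heap)
    out: list[int] = []
    while heap:
        val, i, j = heapq.heappop(heap)
        out.append(val)
        if j + 1 < len(lists[i]):            # advance within the same list
            heapq.heappush(heap, (lists[i][j + 1], i, j + 1))
    return out


lists = [[1, 4, 5], [1, 3, 4], [2, 6]]
print(merge_k_lists(lists))       # [1, 1, 2, 3, 4, 4, 5, 6]
print(list(heapq.merge(*lists)))  # [1, 1, 2, 3, 4, 4, 5, 6]
```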
Two Heaps — Median of a Stream
class MedianFinder:
    """Maintain running median using two heaps.
    - max_heap: lower half of numbers (negate for max-heap in Python)
    - min_heap: upper half of numbers
    Invariant: len(max_heap) == len(min_heap) or
               len(max_heap) == len(min_heap) + 1
    Time: O(log n) add, O(1) find_median. Space: O(n).
    """
    def __init__(self):
        self.max_heap: list[int] = [] # lower half, negate for max
        self.min_heap: list[int] = [] # upper half
    def add_num(self, num: int) -> None:
        # Always push to max_heap first
        heapq.heappush(self.max_heap, -num)
        # Order: ensure max_heap top <= min_heap top
        if (self.min_heap and
                -self.max_heap[0] > self.min_heap[0]):
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        # Balance sizes: max_heap can have at most 1 more element
        if len(self.max_heap) > len(self.min_heap) + 1:
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        elif len(self.min_heap) > len(self.max_heap):
            val = heapq.heappop(self.min_heap)
            heapq.heappush(self.max_heap, -val)
    def find_median(self) -> float:
        if len(self.max_heap) > len(self.min_heap):
            return float(-self.max_heap[0])
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0
Kth Largest Element
Find the kth largest element in an unsorted array using a min-heap of size k. Fill in the heap push, the overflow pop condition, and the return value.
import heapq
def kth_largest(nums, k):
    heap = []
    for n in nums:
        heapq.heappush(heap, ___)
        if len(heap) > k:
            heapq.heappop(___)
    return heap[___]
print(kth_largest([3, 2, 1, 5, 6, 4], 2)) # 5
print(kth_largest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4)) # 4
import java.util.*;
public class Solution {
    static int kthLargest(int[] nums, int k) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) {
            minHeap.offer(___);
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }
        return minHeap.peek();
    }
    public static void main(String[] args) {
        System.out.println(kthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2)); // 5
        System.out.println(kthLargest(new int[]{3, 2, 3, 1, 2, 4, 5, 5, 6}, 4)); // 4
    }
}
Push each element; whenever the heap grows past k, pop its minimum. At the end the heap holds the k largest elements, so its top (heap[0]) is the kth largest.
import heapq
def kth_largest(nums, k):
    heap = []
    for n in nums:
        heapq.heappush(heap, n)
        if len(heap) > k:
            heapq.heappop(heap)
    return heap[0]
print(kth_largest([3, 2, 1, 5, 6, 4], 2)) # 5
print(kth_largest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4)) # 4
import java.util.*;
public class Solution {
    static int kthLargest(int[] nums, int k) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) {
            minHeap.offer(n);
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }
        return minHeap.peek();
    }
    public static void main(String[] args) {
        System.out.println(kthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2)); // 5
        System.out.println(kthLargest(new int[]{3, 2, 3, 1, 2, 4, 5, 5, 6}, 4)); // 4
    }
}
Sort Using Heap
Sort an array in ascending order using heapq. Fill in the heapify call and the pop loop.
import heapq
def heap_sort(nums):
    heapq.___(nums)
    result = []
    while nums:
        result.append(heapq.___(nums))
    return result
print(heap_sort([5, 3, 8, 1, 2])) # [1, 2, 3, 5, 8]
print(heap_sort([9, 0, -3, 7])) # [-3, 0, 7, 9]
import java.util.*;
public class Solution {
    static int[] heapSort(int[] nums) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) minHeap.offer(n);
        int[] result = new int[nums.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = ___;
        }
        return result;
    }
    public static void main(String[] args) {
        System.out.println(Arrays.toString(heapSort(new int[]{5, 3, 8, 1, 2}))); // [1, 2, 3, 5, 8]
        System.out.println(Arrays.toString(heapSort(new int[]{9, 0, -3, 7}))); // [-3, 0, 7, 9]
    }
}
Call heapq.heapify(nums) to turn the list into a min-heap in place (O(n)). Then repeatedly call heapq.heappop(nums) to extract the smallest remaining element. The Python version empties the input list as a side effect.
import heapq
def heap_sort(nums):
    heapq.heapify(nums)
    result = []
    while nums:
        result.append(heapq.heappop(nums))
    return result
print(heap_sort([5, 3, 8, 1, 2])) # [1, 2, 3, 5, 8]
print(heap_sort([9, 0, -3, 7])) # [-3, 0, 7, 9]
import java.util.*;
public class Solution {
    static int[] heapSort(int[] nums) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) minHeap.offer(n);
        int[] result = new int[nums.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = minHeap.poll();
        }
        return result;
    }
    public static void main(String[] args) {
        System.out.println(Arrays.toString(heapSort(new int[]{5, 3, 8, 1, 2}))); // [1, 2, 3, 5, 8]
        System.out.println(Arrays.toString(heapSort(new int[]{9, 0, -3, 7}))); // [-3, 0, 7, 9]
    }
}
Last Stone Weight
Smash the two heaviest stones together each round. If they differ, push the difference back. Repeat until at most one stone remains. Return its weight, or 0 if none left. Fill in the max-heap simulation using negated values.
import heapq
def last_stone(stones):
    heap = [___ for s in stones] # negate for max-heap
    heapq.heapify(heap)
    while len(heap) > 1:
        a = -heapq.heappop(heap) # largest
        b = -heapq.heappop(heap) # second largest
        if a != b:
            heapq.heappush(heap, ___)
    return -heap[0] if heap else ___
print(last_stone([2, 7, 4, 1, 8, 1])) # 1
print(last_stone([1, 1])) # 0
import java.util.*;
public class Solution {
    static int lastStone(int[] stones) {
        // Java PriorityQueue is min-heap; use Collections.reverseOrder() for max-heap
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int s : stones) maxHeap.offer(s);
        while (maxHeap.size() > 1) {
            int a = ___; // largest
            int b = ___; // second largest
            if (a != b) maxHeap.offer(a - b);
        }
        return maxHeap.isEmpty() ? 0 : maxHeap.peek();
    }
    public static void main(String[] args) {
        System.out.println(lastStone(new int[]{2, 7, 4, 1, 8, 1})); // 1
        System.out.println(lastStone(new int[]{1, 1})); // 0
    }
}
Python's heapq is a min-heap. Store negated weights so the most negative value (the largest weight) is always at the top. Push -(a - b) when the stones differ. Return -heap[0] if one stone remains, else 0.
import heapq
def last_stone(stones):
    heap = [-s for s in stones]
    heapq.heapify(heap)
    while len(heap) > 1:
        a = -heapq.heappop(heap)
        b = -heapq.heappop(heap)
        if a != b:
            heapq.heappush(heap, -(a - b))
    return -heap[0] if heap else 0
print(last_stone([2, 7, 4, 1, 8, 1])) # 1
print(last_stone([1, 1])) # 0
import java.util.*;
public class Solution {
    static int lastStone(int[] stones) {
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int s : stones) maxHeap.offer(s);
        while (maxHeap.size() > 1) {
            int a = maxHeap.poll();
            int b = maxHeap.poll();
            if (a != b) maxHeap.offer(a - b);
        }
        return maxHeap.isEmpty() ? 0 : maxHeap.peek();
    }
    public static void main(String[] args) {
        System.out.println(lastStone(new int[]{2, 7, 4, 1, 8, 1})); // 1
        System.out.println(lastStone(new int[]{1, 1})); // 0
    }
}
9. Binary Search
Binary search works on any monotone predicate — not just sorted arrays. The key is identifying the invariant: at the end, left and right have converged to the answer boundary.
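A minimal sketch of the "monotone predicate" view follows. The helper `first_true` is a hypothetical generic function. It finds the first x in a half-open range [lo, hi) where a predicate flips from False to True, returning hi if it never does. Any concrete search then reduces to choosing a predicate. For example, the integer square root of n is one less than the first x with x * x > n. For real code, Python's `math.isqrt` already provides that particular result.

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi) with pred(x) True, or hi if there is none.
    Assumes pred is monotone on the range: False ... False True ... True."""
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if pred(mid):
            hi = mid          # mid might be the first True; keep it in range
        else:
            lo = mid + 1      # mid is False; the boundary is to its right
    return lo


def int_sqrt(n: int) -> int:
    """Largest x with x * x <= n, i.e. (first x with x * x > n) - 1."""
    return first_true(0, n + 1, lambda x: x * x > n) - 1


print(first_true(0, 10, lambda x: x >= 4))  # 4
print(int_sqrt(15), int_sqrt(16))           # 3 4
```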
Standard Template
def binary_search(nums: list[int], target: int) -> int:
    """Return index of target, or -1 if not found.
    Time: O(log n) Space: O(1)
    Invariant: target is in nums[left..right] if it exists.
    """
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) // 2 # overflow-safe idiom in Java/C++; Python ints never overflow
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
Leftmost / Rightmost Occurrence
def search_leftmost(nums: list[int], target: int) -> int:
    """First (leftmost) index of target. Returns -1 if not found.
    Time: O(log n) Space: O(1)
    """
    left, right = 0, len(nums) - 1
    result = -1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            result = mid
            right = mid - 1 # keep searching left
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result
def search_rightmost(nums: list[int], target: int) -> int:
    """Last (rightmost) index of target. Returns -1 if not found."""
    left, right = 0, len(nums) - 1
    result = -1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            result = mid
            left = mid + 1 # keep searching right
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result
def count_occurrences(nums: list[int], target: int) -> int:
    left = search_leftmost(nums, target)
    right = search_rightmost(nums, target)
    if left == -1:
        return 0
    return right - left + 1
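Python's standard `bisect` module implements these boundary searches directly. `bisect_left` returns the first index whose element is >= target, which is where the leftmost occurrence would be. `bisect_right` returns the first index whose element is > target. Their difference is the count. The wrapper name `count_occurrences_bisect` is illustrative.

```python
from bisect import bisect_left, bisect_right


def count_occurrences_bisect(nums: list[int], target: int) -> int:
    # bisect_left: first i with nums[i] >= target
    # bisect_right: first i with nums[i] > target
    return bisect_right(nums, target) - bisect_left(nums, target)


nums = [1, 2, 2, 2, 3, 5]
print(bisect_left(nums, 2), bisect_right(nums, 2))  # 1 4
print(count_occurrences_bisect(nums, 2))            # 3
print(count_occurrences_bisect(nums, 4))            # 0
```

To get the leftmost index itself, take `i = bisect_left(nums, target)` and check `i < len(nums) and nums[i] == target` before trusting it.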
Binary Search on Answer
When the problem asks for the minimum/maximum value satisfying a condition and the feasibility function is monotone, binary search on the answer space directly.
def min_days_to_make_bouquets(bloom_day: list[int],
                              m: int, k: int) -> int:
    """Minimum days to make m bouquets, each needing k adjacent flowers.
    Binary search on the answer: day range [min(bloom_day), max(bloom_day)].
    Time: O(n log(max_day)) Space: O(1)
    """
    def can_make(day: int) -> bool:
        bouquets = consecutive = 0
        for bd in bloom_day:
            if bd <= day:
                consecutive += 1
                if consecutive == k:
                    bouquets += 1
                    consecutive = 0
            else:
                consecutive = 0
        return bouquets >= m
    if len(bloom_day) < m * k:
        return -1
    left, right = min(bloom_day), max(bloom_day)
    while left < right:
        mid = (left + right) // 2
        if can_make(mid):
            right = mid # valid, try smaller
        else:
            left = mid + 1 # invalid, need more days
    return left
def split_array_largest_sum(nums: list[int], k: int) -> int:
    """Minimize the largest sum when splitting array into k subarrays.
    Binary search on answer: sum range [max(nums), sum(nums)].
    Time: O(n log(sum)) Space: O(1)
    """
    def can_split(limit: int) -> bool:
        parts = 1
        running = 0
        for x in nums:
            if running + x > limit:
                parts += 1
                running = 0
            running += x
        return parts <= k
    left, right = max(nums), sum(nums)
    while left < right:
        mid = (left + right) // 2
        if can_split(mid):
            right = mid
        else:
            left = mid + 1
    return left
Rotated Sorted Array
def search_rotated(nums: list[int], target: int) -> int:
    """Search in rotated sorted array (no duplicates).
    Time: O(log n) Space: O(1)
    Key insight: one half of the array is always sorted.
    Determine which half, check if target falls in it.
    """
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            return mid
        if nums[left] <= nums[mid]: # left half is sorted
            if nums[left] <= target < nums[mid]:
                right = mid - 1
            else:
                left = mid + 1
        else: # right half is sorted
            if nums[mid] < target <= nums[right]:
                left = mid + 1
            else:
                right = mid - 1
    return -1
def find_min_rotated(nums: list[int]) -> int:
    """Find minimum element in rotated sorted array.
    Time: O(log n) Space: O(1)
    """
    left, right = 0, len(nums) - 1
    while left < right:
        mid = (left + right) // 2
        if nums[mid] > nums[right]:
            left = mid + 1 # min is in right half
        else:
            right = mid # mid could be the min
    return nums[left]
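The two functions above can also be combined into a two-pass alternative to `search_rotated`. First find the rotation point, which is the index of the minimum, using the same loop as find_min_rotated. Then run an ordinary binary search on whichever sorted run could contain the target. The sketch below assumes distinct values, like search_rotated. The name `search_rotated_two_pass` is illustrative. It uses the standard `bisect.bisect_left` with its `lo`/`hi` arguments to restrict the second search.

```python
from bisect import bisect_left


def search_rotated_two_pass(nums: list[int], target: int) -> int:
    """Index of target in a rotated sorted array (distinct values), else -1."""
    if not nums:
        return -1
    # Pass 1: index of the minimum element (the rotation point)
    left, right = 0, len(nums) - 1
    while left < right:
        mid = (left + right) // 2
        if nums[mid] > nums[right]:
            left = mid + 1
        else:
            right = mid
    pivot = left
    # Pass 2: nums[pivot:] and nums[:pivot] are each sorted; pick the right run
    if nums[pivot] <= target <= nums[-1]:
        lo, hi = pivot, len(nums)
    else:
        lo, hi = 0, pivot
    i = bisect_left(nums, target, lo, hi)
    return i if i < hi and nums[i] == target else -1


print(search_rotated_two_pass([4, 5, 6, 7, 0, 1, 2], 0))  # 4
print(search_rotated_two_pass([4, 5, 6, 7, 0, 1, 2], 3))  # -1
```

Both versions are O(log n). The one-pass version does less work, while this one reuses two simpler building blocks.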
Classic Binary Search
Find a target in a sorted array and return its index, or -1 if not present. Fill in the mid calculation and the narrowing conditions.
def binary_search(nums, target):
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) ___
        if nums[mid] == target:
            return ___
        elif nums[mid] < target:
            left = ___
        else:
            right = ___
    return -1
print(binary_search([-1, 0, 3, 5, 9, 12], 9)) # 4
print(binary_search([-1, 0, 3, 5, 9, 12], 2)) # -1
print(binary_search([-1, 0, 3, 5, 9, 12], -1)) # 0
public class Solution {
    static int binarySearch(int[] nums, int target) {
        int left = 0, right = nums.length - 1;
        while (left <= right) {
            int mid = left + (right - left) ___;
            if (nums[mid] == target) return ___;
            else if (nums[mid] < target) left = ___;
            else right = ___;
        }
        return -1;
    }
    public static void main(String[] args) {
        int[] nums = {-1, 0, 3, 5, 9, 12};
        System.out.println(binarySearch(nums, 9)); // 4
        System.out.println(binarySearch(nums, 2)); // -1
        System.out.println(binarySearch(nums, -1)); // 0
    }
}
Compute mid as left + (right - left) // 2; in Java this avoids int overflow (Python integers never overflow, but the idiom is harmless). If nums[mid] < target, the answer lies to the right, so set left = mid + 1. Otherwise set right = mid - 1.
def binary_search(nums, target):
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
print(binary_search([-1, 0, 3, 5, 9, 12], 9)) # 4
print(binary_search([-1, 0, 3, 5, 9, 12], 2)) # -1
print(binary_search([-1, 0, 3, 5, 9, 12], -1)) # 0
public class Solution {
    static int binarySearch(int[] nums, int target) {
        int left = 0, right = nums.length - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (nums[mid] == target) return mid;
            else if (nums[mid] < target) left = mid + 1;
            else right = mid - 1;
        }
        return -1;
    }
    public static void main(String[] args) {
        int[] nums = {-1, 0, 3, 5, 9, 12};
        System.out.println(binarySearch(nums, 9)); // 4
        System.out.println(binarySearch(nums, 2)); // -1
        System.out.println(binarySearch(nums, -1)); // 0
    }
}
Search Insert Position
Given a sorted array and a target, return the index if found, or the index where it would be inserted to keep the array sorted. Fill in the loop body and return value.
def search_insert(nums, target):
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if nums[mid] == target:
            return ___
        elif nums[mid] < target:
            left = ___
        else:
            right = ___
    return ___ # insertion point when not found
print(search_insert([1, 3, 5, 6], 5)) # 2
print(search_insert([1, 3, 5, 6], 2)) # 1
print(search_insert([1, 3, 5, 6], 7)) # 4
print(search_insert([1, 3, 5, 6], 0)) # 0
public class Solution {
    static int searchInsert(int[] nums, int target) {
        int left = 0, right = nums.length - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (nums[mid] == target) return ___;
            else if (nums[mid] < target) left = ___;
            else right = ___;
        }
        return ___; // insertion point
    }
    public static void main(String[] args) {
        int[] nums = {1, 3, 5, 6};
        System.out.println(searchInsert(nums, 5)); // 2
        System.out.println(searchInsert(nums, 2)); // 1
        System.out.println(searchInsert(nums, 7)); // 4
        System.out.println(searchInsert(nums, 0)); // 0
    }
}
The loop is classic binary search. When the target is absent, return left: after the loop, left sits exactly at the position where the target would be inserted.
def search_insert(nums, target):
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return left
print(search_insert([1, 3, 5, 6], 5)) # 2
print(search_insert([1, 3, 5, 6], 2)) # 1
print(search_insert([1, 3, 5, 6], 7)) # 4
print(search_insert([1, 3, 5, 6], 0)) # 0
public class Solution {
    static int searchInsert(int[] nums, int target) {
        int left = 0, right = nums.length - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (nums[mid] == target) return mid;
            else if (nums[mid] < target) left = mid + 1;
            else right = mid - 1;
        }
        return left;
    }
    public static void main(String[] args) {
        int[] nums = {1, 3, 5, 6};
        System.out.println(searchInsert(nums, 5)); // 2
        System.out.println(searchInsert(nums, 2)); // 1
        System.out.println(searchInsert(nums, 7)); // 4
        System.out.println(searchInsert(nums, 0)); // 0
    }
}
First Bad Version
Versions 1 through n are released. Starting from some version, all subsequent versions are bad. Given is_bad(version), find the first bad version using as few calls as possible. Fill in the narrowing logic.
is_bad = lambda v: v >= 4 # first bad version is 4
def first_bad_version(n):
    left, right = 1, n
    while left < right:
        mid = left + (right - left) // 2
        if is_bad(mid):
            right = ___ # mid could be first bad, keep it
        else:
            left = ___ # mid is good, first bad is after
    return ___
print(first_bad_version(7)) # 4
print(first_bad_version(1)) # 1 (with n == 1 the loop never runs; a bad version is assumed to exist)
public class Solution {
    static boolean isBad(int version) {
        return version >= 4; // first bad version is 4
    }
    static int firstBadVersion(int n) {
        int left = 1, right = n;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (isBad(mid)) {
                right = ___; // could be first bad; keep it
            } else {
                left = ___; // mid is good; look right
            }
        }
        return ___;
    }
    public static void main(String[] args) {
        System.out.println(firstBadVersion(7)); // 4
    }
}
When is_bad(mid) is true, set right = mid (not mid - 1), because mid itself might be the first bad version. When it is false, set left = mid + 1. The loop stops when left == right, which is the answer.
is_bad = lambda v: v >= 4
def first_bad_version(n):
    left, right = 1, n
    while left < right:
        mid = left + (right - left) // 2
        if is_bad(mid):
            right = mid
        else:
            left = mid + 1
    return left
print(first_bad_version(7)) # 4
public class Solution {
    static boolean isBad(int version) {
        return version >= 4;
    }
    static int firstBadVersion(int n) {
        int left = 1, right = n;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (isBad(mid)) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
    public static void main(String[] args) {
        System.out.println(firstBadVersion(7)); // 4
    }
}
10. Graphs — BFS & DFS
Representations
from collections import defaultdict, deque
from collections.abc import Iterator
# Adjacency list (most common in interviews)
# Undirected graph
graph: dict[int, list[int]] = defaultdict(list)
edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]
for u, v in edges:
    graph[u].append(v)
    graph[v].append(u)
# Weighted graph: store (neighbor, weight) pairs
weighted: dict[int, list[tuple[int, int]]] = defaultdict(list)
for u, v, w in [(0,1,4), (0,2,1), (1,3,2)]:
    weighted[u].append((v, w))
    weighted[v].append((u, w))
# Grid as implicit graph — 4-directional neighbors (a generator, hence Iterator)
def neighbors_4(r: int, c: int,
                rows: int, cols: int) -> Iterator[tuple[int, int]]:
    for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
        nr, nc = r + dr, c + dc
        if 0 <= nr < rows and 0 <= nc < cols:
            yield nr, nc
BFS Template
def bfs(graph: dict[int, list[int]], start: int) -> list[int]:
    """BFS: explores nodes level by level (shortest path in unweighted graph).
    Time: O(V + E) Space: O(V)
    """
    visited: set[int] = {start}
    queue: deque[int] = deque([start])
    order: list[int] = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return order
def bfs_shortest_path(graph: dict[int, list[int]],
                      start: int, end: int) -> int:
    """Shortest path length in unweighted graph. Returns -1 if unreachable.
    Time: O(V + E) Space: O(V)
    """
    if start == end:
        return 0
    visited: set[int] = {start}
    queue: deque[tuple[int, int]] = deque([(start, 0)])
    while queue:
        node, dist = queue.popleft()
        for neighbor in graph[node]:
            if neighbor == end:
                return dist + 1
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    return -1
def bfs_grid(grid: list[list[int]],
             start: tuple[int, int]) -> list[list[int]]:
    """BFS on a grid. Returns the distance from start to every cell
    (-1 where unreachable). 0 = open, 1 = wall.
    Time: O(rows * cols) Space: O(rows * cols)
    """
    rows, cols = len(grid), len(grid[0])
    dist = [[-1] * cols for _ in range(rows)]
    sr, sc = start
    dist[sr][sc] = 0
    queue: deque[tuple[int, int]] = deque([(sr, sc)])
    while queue:
        r, c = queue.popleft()
        for nr, nc in neighbors_4(r, c, rows, cols):
            if grid[nr][nc] == 0 and dist[nr][nc] == -1:
                dist[nr][nc] = dist[r][c] + 1
                queue.append((nr, nc))
    return dist
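bfs_shortest_path returns only the length of the shortest path. A common extension also recovers the path itself. Record, for each node, the node it was first discovered from; BFS discovers each node along a shortest route. Then walk those parent links back from the end. The sketch below is such an extension, with the illustrative name `bfs_path`. It assumes no node is literally `None`, since `None` marks the start's parent.

```python
from collections import defaultdict, deque


def bfs_path(graph: dict[int, list[int]], start: int, end: int) -> list[int]:
    """Shortest path as a node list in an unweighted graph, or [] if unreachable."""
    parent: dict[int, int | None] = {start: None}  # also serves as the visited set
    queue: deque[int] = deque([start])
    while queue:
        node = queue.popleft()
        if node == end:
            break
        for nb in graph[node]:
            if nb not in parent:
                parent[nb] = node      # first discovery = shortest route to nb
                queue.append(nb)
    if end not in parent:
        return []
    path: list[int] = []
    cur: int | None = end
    while cur is not None:             # follow parent links back to start
        path.append(cur)
        cur = parent[cur]
    return path[::-1]


graph: dict[int, list[int]] = defaultdict(list)
for u, v in [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]:
    graph[u].append(v)
    graph[v].append(u)
print(bfs_path(graph, 0, 4))  # [0, 1, 3, 4]
```

The path length in edges is `len(path) - 1`, which matches what bfs_shortest_path reports.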
DFS Template
def dfs_recursive(graph: dict[int, list[int]],
node: int,
visited: set[int] | None = None,
order: list[int] | None = None) -> list[int]:
"""DFS recursive. Time: O(V + E) Space: O(V) call stack.
One shared order list avoids the O(V^2) cost of extending child lists up the call chain.
"""
if visited is None:
visited = set()
if order is None:
order = []
visited.add(node)
order.append(node)
for neighbor in graph[node]:
if neighbor not in visited:
dfs_recursive(graph, neighbor, visited, order)
return order
def dfs_iterative(graph: dict[int, list[int]], start: int) -> list[int]:
"""DFS iterative using an explicit stack.
Pushing neighbors in reverse and marking nodes on pop reproduces the recursive preorder.
Time: O(V + E) Space: O(V + E), since a node can be pushed once per incoming edge
"""
visited: set[int] = set()
stack: list[int] = [start]
order: list[int] = []
while stack:
node = stack.pop()
if node in visited:
continue
visited.add(node)
order.append(node)
# Push neighbors in reverse so original order is preserved
for neighbor in reversed(graph[node]):
if neighbor not in visited:
stack.append(neighbor)
return order
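The equivalence between the two traversal orders is easy to check directly. Below, both are restated compactly under hypothetical names (dfs_rec, dfs_iter) so the block runs on its own; on a graph with a cycle and a node reachable along two routes, they produce the same preorder:

```python
def dfs_rec(graph: dict[int, list[int]], start: int) -> list[int]:
    visited: set[int] = set()
    order: list[int] = []

    def visit(node: int) -> None:
        visited.add(node)
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                visit(neighbor)

    visit(start)
    return order


def dfs_iter(graph: dict[int, list[int]], start: int) -> list[int]:
    visited: set[int] = set()
    order: list[int] = []
    stack = [start]
    while stack:
        node = stack.pop()
        if node in visited:  # a node may be pushed more than once; skip repeats
            continue
        visited.add(node)
        order.append(node)
        for neighbor in reversed(graph[node]):  # first neighbor ends up on top
            if neighbor not in visited:
                stack.append(neighbor)
    return order


graph = {0: [1, 2, 3], 1: [2], 2: [0, 3], 3: [1]}
print(dfs_rec(graph, 0), dfs_iter(graph, 0))  # [0, 1, 2, 3] [0, 1, 2, 3]
```

One practical reason to prefer the iterative form in Python: the recursive one is bounded by CPython's default recursion limit of 1000 frames, which a long path graph can exceed.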
Topological Sort (Kahn's Algorithm)
def topological_sort(num_nodes: int,
edges: list[tuple[int, int]]) -> list[int]:
"""Kahn's algorithm — BFS-based topological sort.
Time: O(V + E) Space: O(V + E)
Returns empty list if graph has a cycle (not a DAG).
"""
graph: dict[int, list[int]] = defaultdict(list)
in_degree: list[int] = [0] * num_nodes
for u, v in edges:
graph[u].append(v)
in_degree[v] += 1
# Start with all nodes that have no incoming edges
queue: deque[int] = deque(
node for node in range(num_nodes) if in_degree[node] == 0
)
order: list[int] = []
while queue:
node = queue.popleft()
order.append(node)
for neighbor in graph[node]:
in_degree[neighbor] -= 1
if in_degree[neighbor] == 0:
queue.append(neighbor)
# If not all nodes processed, cycle exists
return order if len(order) == num_nodes else []
def course_schedule(num_courses: int,
prerequisites: list[list[int]]) -> bool:
"""Can all courses be finished? True if no cycle exists.
Classic topological sort application.
Time: O(V + E) Space: O(V + E)
"""
order = topological_sort(num_courses,
[(b, a) for a, b in prerequisites])
return len(order) == num_courses
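Topological order can also come from DFS: a node is appended only after everything reachable from it has finished, so the reversed postorder places every edge's source before its target. A minimal sketch that reuses the three-state idea (0 = unvisited, 1 = on the current path, 2 = finished) to reject cyclic inputs; topological_sort_dfs is a hypothetical name:

```python
from collections import defaultdict


def topological_sort_dfs(num_nodes: int,
                         edges: list[tuple[int, int]]) -> list[int]:
    """DFS-based topological sort. Returns [] if the graph has a cycle.
    Time: O(V + E)  Space: O(V + E)
    """
    graph: dict[int, list[int]] = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
    state = [0] * num_nodes  # 0=unvisited, 1=on current path, 2=finished
    postorder: list[int] = []

    def dfs(node: int) -> bool:
        """Return False if a cycle is reachable from node."""
        state[node] = 1
        for neighbor in graph[node]:
            if state[neighbor] == 1:
                return False  # back edge: not a DAG
            if state[neighbor] == 0 and not dfs(neighbor):
                return False
        state[node] = 2
        postorder.append(node)  # every descendant was appended earlier
        return True

    for node in range(num_nodes):
        if state[node] == 0 and not dfs(node):
            return []
    return postorder[::-1]


print(topological_sort_dfs(4, [(0, 1), (0, 2), (1, 3), (2, 3)]))  # [0, 2, 1, 3]
```

Like Kahn's version, this returns [] for a cyclic graph; unlike it, the recursion depth can reach V.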
Cycle Detection
def has_cycle_directed(num_nodes: int,
edges: list[tuple[int, int]]) -> bool:
"""Detect cycle in directed graph using DFS coloring.
Colors: 0=unvisited, 1=in current path, 2=fully processed.
Time: O(V + E) Space: O(V)
"""
graph: dict[int, list[int]] = defaultdict(list)
for u, v in edges:
graph[u].append(v)
color = [0] * num_nodes
def dfs(node: int) -> bool:
color[node] = 1 # mark as in-progress
for neighbor in graph[node]:
if color[neighbor] == 1:
return True # back edge = cycle
if color[neighbor] == 0 and dfs(neighbor):
return True
color[node] = 2 # fully processed
return False
return any(color[n] == 0 and dfs(n) for n in range(num_nodes))
def has_cycle_undirected(graph: dict[int, list[int]]) -> bool:
"""Detect cycle in undirected graph.
Track parent to avoid treating the edge back to parent as a cycle.
Time: O(V + E) Space: O(V)
"""
visited: set[int] = set()
def dfs(node: int, parent: int) -> bool:
visited.add(node)
for neighbor in graph[node]:
if neighbor not in visited:
if dfs(neighbor, node):
return True
elif neighbor != parent:
return True # back edge to non-parent = cycle
return False
return any(dfs(n, -1) for n in graph if n not in visited)
Island Problems / Flood Fill
def num_islands(grid: list[list[str]]) -> int:
"""Count connected components of '1's.
Time: O(m * n) Space: O(m * n) worst case DFS stack.
"""
if not grid:
return 0
rows, cols = len(grid), len(grid[0])
count = 0
def dfs(r: int, c: int) -> None:
if r < 0 or r >= rows or c < 0 or c >= cols:
return
if grid[r][c] != '1':
return
grid[r][c] = '#' # mark visited in-place (restore if needed)
dfs(r + 1, c); dfs(r - 1, c)
dfs(r, c + 1); dfs(r, c - 1)
for r in range(rows):
for c in range(cols):
if grid[r][c] == '1':
dfs(r, c)
count += 1
return count
def max_area_of_island(grid: list[list[int]]) -> int:
"""Maximum area of an island. 1=land, 0=water.
Time: O(m * n) Space: O(m * n)
"""
rows, cols = len(grid), len(grid[0])
def dfs(r: int, c: int) -> int:
if r < 0 or r >= rows or c < 0 or c >= cols or grid[r][c] == 0:
return 0
grid[r][c] = 0 # mark visited
return 1 + dfs(r+1,c) + dfs(r-1,c) + dfs(r,c+1) + dfs(r,c-1)
return max((dfs(r, c)
for r in range(rows)
for c in range(cols)
if grid[r][c] == 1), default=0) # default=0: max() of an empty iterable would raise
def flood_fill(image: list[list[int]],
sr: int, sc: int, color: int) -> list[list[int]]:
"""Flood fill: repaint connected region starting from (sr, sc).
Time: O(m * n) Space: O(m * n)
"""
original = image[sr][sc]
if original == color:
return image
rows, cols = len(image), len(image[0])
def dfs(r: int, c: int) -> None:
if r < 0 or r >= rows or c < 0 or c >= cols:
return
if image[r][c] != original:
return
image[r][c] = color
dfs(r+1,c); dfs(r-1,c); dfs(r,c+1); dfs(r,c-1)
dfs(sr, sc)
return image
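The recursive grid DFS above can recurse m * n levels deep, and CPython's default recursion limit is 1000, so a large all-land grid can raise RecursionError. A sketch of the same island count with an explicit stack; the name num_islands_iterative is hypothetical, and it tracks visited cells in a set so the input grid is left unmodified:

```python
def num_islands_iterative(grid: list[list[str]]) -> int:
    """Count 4-directionally connected groups of '1' without recursion.
    Time: O(m * n)  Space: O(m * n)
    """
    if not grid or not grid[0]:
        return 0
    rows, cols = len(grid), len(grid[0])
    seen: set[tuple[int, int]] = set()
    count = 0
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != '1' or (r, c) in seen:
                continue
            count += 1  # new island: flood it with an explicit stack
            seen.add((r, c))
            stack = [(r, c)]
            while stack:
                cr, cc = stack.pop()
                for nr, nc in ((cr + 1, cc), (cr - 1, cc), (cr, cc + 1), (cr, cc - 1)):
                    if (0 <= nr < rows and 0 <= nc < cols
                            and grid[nr][nc] == '1' and (nr, nc) not in seen):
                        seen.add((nr, nc))
                        stack.append((nr, nc))
    return count


grid = [list("11000"), list("11000"), list("00100"), list("00011")]
print(num_islands_iterative(grid))  # 3
```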
| Use BFS when | Use DFS when |
|---|---|
| Shortest path in unweighted graph | Detecting cycles |
| Level-order processing | Topological sort (DFS version) |
| Finding closest/nearest node | Exhaustive search / backtracking |
| Word ladder, 0-1 BFS | Connected components, flood fill |
| The target is likely close to the source (shallow search) | The graph is wide and memory is tight (DFS stores only the current path) |
BFS Shortest Path
Find the shortest path length (number of edges) between two nodes in an unweighted undirected graph. Fill in the BFS queue operations and distance tracking.
from collections import deque
def shortest_path(graph, start, end):
if start == end:
return 0
visited = {start}
queue = deque([(start, 0)]) # (node, distance)
while queue:
node, dist = queue.___( )
for neighbor in graph.get(node, []):
if neighbor == end:
return ___
if neighbor not in visited:
visited.add(neighbor)
queue.append((neighbor, ___))
return -1 # unreachable
graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}
print(shortest_path(graph, 0, 4)) # 3
print(shortest_path(graph, 0, 3)) # 2
print(shortest_path(graph, 0, 0)) # 0
import java.util.*;
public class Solution {
static int shortestPath(Map<Integer, List<Integer>> graph, int start, int end) {
if (start == end) return 0;
Set<Integer> visited = new HashSet<>();
Queue<int[]> queue = new LinkedList<>(); // {node, dist}
queue.offer(new int[]{start, 0});
visited.add(start);
while (!queue.isEmpty()) {
int[] curr = ___;
int node = curr[0], dist = curr[1];
for (int neighbor : graph.getOrDefault(node, Collections.emptyList())) {
if (neighbor == end) return ___;
if (!visited.contains(neighbor)) {
visited.add(neighbor);
queue.offer(new int[]{neighbor, ___});
}
}
}
return -1;
}
public static void main(String[] args) {
Map<Integer, List<Integer>> graph = new HashMap<>();
graph.put(0, Arrays.asList(1, 2));
graph.put(1, Arrays.asList(0, 3));
graph.put(2, Arrays.asList(0, 3));
graph.put(3, Arrays.asList(1, 2, 4));
graph.put(4, Arrays.asList(3));
System.out.println(shortestPath(graph, 0, 4)); // 3
System.out.println(shortestPath(graph, 0, 3)); // 2
System.out.println(shortestPath(graph, 0, 0)); // 0
}
}
Hint: Dequeue with queue.popleft() in Python (or queue.poll() in Java). When you find the destination, return dist + 1 (one more edge). Push neighbors with dist + 1.
Solution:
from collections import deque
def shortest_path(graph, start, end):
if start == end:
return 0
visited = {start}
queue = deque([(start, 0)])
while queue:
node, dist = queue.popleft()
for neighbor in graph.get(node, []):
if neighbor == end:
return dist + 1
if neighbor not in visited:
visited.add(neighbor)
queue.append((neighbor, dist + 1))
return -1
graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}
print(shortest_path(graph, 0, 4)) # 3
print(shortest_path(graph, 0, 3)) # 2
print(shortest_path(graph, 0, 0)) # 0
import java.util.*;
public class Solution {
static int shortestPath(Map<Integer, List<Integer>> graph, int start, int end) {
if (start == end) return 0;
Set<Integer> visited = new HashSet<>();
Queue<int[]> queue = new LinkedList<>();
queue.offer(new int[]{start, 0});
visited.add(start);
while (!queue.isEmpty()) {
int[] curr = queue.poll();
int node = curr[0], dist = curr[1];
for (int neighbor : graph.getOrDefault(node, Collections.emptyList())) {
if (neighbor == end) return dist + 1;
if (!visited.contains(neighbor)) {
visited.add(neighbor);
queue.offer(new int[]{neighbor, dist + 1});
}
}
}
return -1;
}
public static void main(String[] args) {
Map<Integer, List<Integer>> graph = new HashMap<>();
graph.put(0, Arrays.asList(1, 2));
graph.put(1, Arrays.asList(0, 3));
graph.put(2, Arrays.asList(0, 3));
graph.put(3, Arrays.asList(1, 2, 4));
graph.put(4, Arrays.asList(3));
System.out.println(shortestPath(graph, 0, 4)); // 3
System.out.println(shortestPath(graph, 0, 3)); // 2
System.out.println(shortestPath(graph, 0, 0)); // 0
}
}
Number of Connected Components
Given n nodes (labeled 0 to n-1) and a list of undirected edges, count the number of connected components. Fill in the BFS/DFS traversal and component counter.
from collections import deque, defaultdict
def count_components(n, edges):
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
graph[v].append(u)
visited = set()
components = ___
for node in range(n):
if node not in visited:
components ___
queue = deque([node])
visited.add(node)
while queue:
curr = queue.popleft()
for neighbor in graph[curr]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(___)
return components
print(count_components(5, [[0,1],[1,2],[3,4]])) # 2
print(count_components(5, [[0,1],[1,2],[2,3],[3,4]])) # 1
print(count_components(4, [])) # 4
import java.util.*;
public class Solution {
static int countComponents(int n, int[][] edges) {
List<List<Integer>> graph = new ArrayList<>();
for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
for (int[] e : edges) {
graph.get(e[0]).add(e[1]);
graph.get(e[1]).add(e[0]);
}
boolean[] visited = new boolean[n];
int components = ___;
for (int i = 0; i < n; i++) {
if (!visited[i]) {
components___;
Queue<Integer> queue = new LinkedList<>();
queue.offer(i);
visited[i] = true;
while (!queue.isEmpty()) {
int curr = queue.poll();
for (int neighbor : graph.get(curr)) {
if (!visited[neighbor]) {
visited[neighbor] = true;
queue.offer(___);
}
}
}
}
}
return components;
}
public static void main(String[] args) {
System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{3,4}})); // 2
System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{2,3},{3,4}})); // 1
System.out.println(countComponents(4, new int[][]{})); // 4
}
}
Hint: Initialize components = 0. Each time you start a BFS/DFS from an unvisited node, increment it by 1. The traversal marks every node in that component as visited, so they are skipped on later iterations.
Solution:
from collections import deque, defaultdict
def count_components(n, edges):
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
graph[v].append(u)
visited = set()
components = 0
for node in range(n):
if node not in visited:
components += 1
queue = deque([node])
visited.add(node)
while queue:
curr = queue.popleft()
for neighbor in graph[curr]:
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
return components
print(count_components(5, [[0,1],[1,2],[3,4]])) # 2
print(count_components(5, [[0,1],[1,2],[2,3],[3,4]])) # 1
print(count_components(4, [])) # 4
import java.util.*;
public class Solution {
static int countComponents(int n, int[][] edges) {
List<List<Integer>> graph = new ArrayList<>();
for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
for (int[] e : edges) {
graph.get(e[0]).add(e[1]);
graph.get(e[1]).add(e[0]);
}
boolean[] visited = new boolean[n];
int components = 0;
for (int i = 0; i < n; i++) {
if (!visited[i]) {
components++;
Queue<Integer> queue = new LinkedList<>();
queue.offer(i);
visited[i] = true;
while (!queue.isEmpty()) {
int curr = queue.poll();
for (int neighbor : graph.get(curr)) {
if (!visited[neighbor]) {
visited[neighbor] = true;
queue.offer(neighbor);
}
}
}
}
}
return components;
}
public static void main(String[] args) {
System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{3,4}})); // 2
System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{2,3},{3,4}})); // 1
System.out.println(countComponents(4, new int[][]{})); // 4
}
}
Detect Cycle in Directed Graph
Return True if the directed graph contains a cycle. Use DFS with three node states: unvisited (0), currently in the DFS stack (1), and fully processed (2). A back edge to a state-1 node indicates a cycle. Fill in the state transitions and cycle detection condition.
from collections import defaultdict
def has_cycle(num_nodes, edges):
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
# 0 = unvisited, 1 = in-progress, 2 = done
state = [0] * num_nodes
def dfs(node):
state[node] = ___ # mark in-progress
for neighbor in graph[node]:
if state[neighbor] == ___: # back edge = cycle
return True
if state[neighbor] == 0 and dfs(neighbor):
return True
state[node] = ___ # mark done
return False
for node in range(num_nodes):
if state[node] == 0:
if dfs(node):
return True
return False
print(has_cycle(4, [[0,1],[1,2],[2,0],[2,3]])) # True
print(has_cycle(4, [[0,1],[1,2],[2,3]])) # False
import java.util.*;
public class Solution {
static List<List<Integer>> graph;
static int[] state; // 0=unvisited, 1=in-progress, 2=done
static boolean dfs(int node) {
state[node] = ___; // in-progress
for (int neighbor : graph.get(node)) {
if (state[neighbor] == ___) return true; // back edge
if (state[neighbor] == 0 && dfs(neighbor)) return true;
}
state[node] = ___; // done
return false;
}
static boolean hasCycle(int numNodes, int[][] edges) {
graph = new ArrayList<>();
state = new int[numNodes];
for (int i = 0; i < numNodes; i++) graph.add(new ArrayList<>());
for (int[] e : edges) graph.get(e[0]).add(e[1]);
for (int i = 0; i < numNodes; i++) {
if (state[i] == 0 && dfs(i)) return true;
}
return false;
}
public static void main(String[] args) {
System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,0},{2,3}})); // true
System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,3}})); // false
}
}
Hint: Set state[node] = 1 when entering DFS. If you reach a neighbor with state 1, you have found a back edge, so return True. After exploring all neighbors, set state[node] = 2 (done). Nodes with state 2 are safe and can be skipped.
Solution:
from collections import defaultdict
def has_cycle(num_nodes, edges):
graph = defaultdict(list)
for u, v in edges:
graph[u].append(v)
state = [0] * num_nodes # 0=unvisited, 1=in-progress, 2=done
def dfs(node):
state[node] = 1
for neighbor in graph[node]:
if state[neighbor] == 1:
return True
if state[neighbor] == 0 and dfs(neighbor):
return True
state[node] = 2
return False
for node in range(num_nodes):
if state[node] == 0:
if dfs(node):
return True
return False
print(has_cycle(4, [[0,1],[1,2],[2,0],[2,3]])) # True
print(has_cycle(4, [[0,1],[1,2],[2,3]])) # False
import java.util.*;
public class Solution {
static List<List<Integer>> graph;
static int[] state;
static boolean dfs(int node) {
state[node] = 1;
for (int neighbor : graph.get(node)) {
if (state[neighbor] == 1) return true;
if (state[neighbor] == 0 && dfs(neighbor)) return true;
}
state[node] = 2;
return false;
}
static boolean hasCycle(int numNodes, int[][] edges) {
graph = new ArrayList<>();
state = new int[numNodes];
for (int i = 0; i < numNodes; i++) graph.add(new ArrayList<>());
for (int[] e : edges) graph.get(e[0]).add(e[1]);
for (int i = 0; i < numNodes; i++) {
if (state[i] == 0 && dfs(i)) return true;
}
return false;
}
public static void main(String[] args) {
System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,0},{2,3}})); // true
System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,3}})); // false
}
}
11. Shortest Path Algorithms
Weighted graphs require algorithms that account for edge costs. Choose based on constraints: negative weights, number of sources, and graph density all matter.
Dijkstra's Algorithm
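Single-source shortest paths for graphs with non-negative edge weights. A min-heap settles nodes in order of tentative distance: the node popped with the smallest distance is final, and its outgoing edges are then relaxed. Non-negative weights are what guarantee a popped distance can never improve later.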
import heapq
from collections import defaultdict
def dijkstra(n: int,
edges: list[tuple[int, int, int]],
src: int) -> list[float]:
"""Single-source shortest paths from src.
edges: list of (u, v, weight) — undirected here.
Returns dist[] where dist[i] = shortest distance from src to i,
or float('inf') if unreachable.
Time: O(E log V) Space: O(V + E)
"""
graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w)) # omit for directed graph
dist = [float('inf')] * n
dist[src] = 0
heap: list[tuple[float, int]] = [(0, src)] # (cost, node)
while heap:
cost, u = heapq.heappop(heap)
if cost > dist[u]: # stale entry — skip
continue
for v, w in graph[u]:
if dist[u] + w < dist[v]:
dist[v] = dist[u] + w
heapq.heappush(heap, (dist[v], v))
return dist
def dijkstra_with_path(n: int,
edges: list[tuple[int, int, int]],
src: int,
dst: int) -> tuple[float, list[int]]:
"""Returns (distance, path). Path is empty if dst unreachable."""
graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w))
dist = [float('inf')] * n
prev = [-1] * n
dist[src] = 0
heap = [(0, src)]
while heap:
cost, u = heapq.heappop(heap)
if cost > dist[u]:
continue
for v, w in graph[u]:
if dist[u] + w < dist[v]:
dist[v] = dist[u] + w
prev[v] = u
heapq.heappush(heap, (dist[v], v))
if dist[dst] == float('inf'):
return float('inf'), []
path: list[int] = []
cur = dst
while cur != -1:
path.append(cur)
cur = prev[cur]
return dist[dst], path[::-1]
Bellman-Ford
Handles negative edge weights. Relaxes every edge V-1 times, which is enough for any shortest path (at most V-1 edges) to propagate. If one more pass still reduces some distance, a negative cycle is reachable from the source.
def bellman_ford(n: int,
edges: list[tuple[int, int, int]],
src: int) -> tuple[list[float], bool]:
"""Single-source shortest path supporting negative weights.
edges: list of (u, v, weight) — directed.
Returns (dist[], has_negative_cycle).
Time: O(V * E) Space: O(V)
"""
dist = [float('inf')] * n
dist[src] = 0
for _ in range(n - 1): # relax all edges V-1 times
updated = False
for u, v, w in edges:
if dist[u] != float('inf') and dist[u] + w < dist[v]:
dist[v] = dist[u] + w
updated = True
if not updated: # early exit: already converged
break
# Check for negative cycle: one more pass still reduces?
for u, v, w in edges:
if dist[u] != float('inf') and dist[u] + w < dist[v]:
return dist, True # negative cycle reachable from src
return dist, False
Floyd-Warshall
All-pairs shortest paths. Works with negative weights (but not negative cycles). DP over intermediate nodes.
def floyd_warshall(n: int,
edges: list[tuple[int, int, int]]
) -> list[list[float]]:
"""All-pairs shortest paths.
Time: O(V^3) Space: O(V^2)
Use when n is small (n <= 500 typically).
"""
INF = float('inf')
dist = [[INF] * n for _ in range(n)]
for i in range(n):
dist[i][i] = 0
for u, v, w in edges:
dist[u][v] = min(dist[u][v], w) # directed; duplicate edges: take min
for k in range(n): # k = intermediate node
for i in range(n):
for j in range(n):
if dist[i][k] + dist[k][j] < dist[i][j]:
dist[i][j] = dist[i][k] + dist[k][j]
# Negative cycle exists if any dist[i][i] < 0
return dist
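floyd_warshall leaves the negative-cycle check as a comment. A sketch that runs the same triple loop and returns the flag alongside the matrix; floyd_warshall_checked is a hypothetical name, and edges are directed as above:

```python
def floyd_warshall_checked(n: int,
                           edges: list[tuple[int, int, int]]
                           ) -> tuple[list[list[float]], bool]:
    """All-pairs shortest paths plus a negative-cycle flag (directed edges)."""
    INF = float('inf')
    dist = [[INF] * n for _ in range(n)]
    for i in range(n):
        dist[i][i] = 0
    for u, v, w in edges:
        dist[u][v] = min(dist[u][v], w)  # keep the cheapest parallel edge
    for k in range(n):  # allow k as an intermediate node
        for i in range(n):
            for j in range(n):
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    # A node that reaches itself with negative total weight lies on a negative cycle.
    has_negative_cycle = any(dist[i][i] < 0 for i in range(n))
    return dist, has_negative_cycle


dist, neg = floyd_warshall_checked(3, [(0, 1, 4), (1, 2, -2), (0, 2, 5)])
print(dist[0][2], neg)  # 2 False
_, neg = floyd_warshall_checked(2, [(0, 1, 1), (1, 0, -3)])
print(neg)  # True
```

When the flag is set, distances that pass through the cycle are not meaningful; treat it as a signal to stop rather than as a correction.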
0-1 BFS
When edges have weight 0 or 1 only, use a deque instead of a heap: push 0-weight neighbors to front, 1-weight to back. O(V + E) — faster than Dijkstra for this special case.
from collections import deque
def bfs_01(grid: list[list[int]],
src: tuple[int, int],
dst: tuple[int, int]) -> int:
"""Minimum cost path in grid where moving to 0-cell costs 0,
moving to 1-cell costs 1.
Time: O(rows * cols) Space: O(rows * cols)
"""
rows, cols = len(grid), len(grid[0])
dist = [[float('inf')] * cols for _ in range(rows)]
sr, sc = src
dist[sr][sc] = 0
dq: deque[tuple[int, int]] = deque([(sr, sc)])
while dq:
r, c = dq.popleft()
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
nr, nc = r + dr, c + dc
if 0 <= nr < rows and 0 <= nc < cols:
w = grid[nr][nc] # edge weight: 0 or 1
new_dist = dist[r][c] + w
if new_dist < dist[nr][nc]:
dist[nr][nc] = new_dist
if w == 0:
dq.appendleft((nr, nc)) # cheaper: goes to front
else:
dq.append((nr, nc)) # costlier: goes to back
dr, dc = dst
return dist[dr][dc] if dist[dr][dc] != float('inf') else -1
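The grid version bakes in one cost model (entering a cell costs its value). The same deque discipline applies to any graph whose edge weights are only 0 or 1. A sketch on an adjacency list of (neighbor, weight) pairs; bfs_01_graph is a hypothetical name:

```python
from collections import deque


def bfs_01_graph(n: int, adj: list[list[tuple[int, int]]], src: int) -> list[float]:
    """Shortest distances from src when every edge weight is 0 or 1.
    Time: O(V + E)  Space: O(V)
    """
    dist = [float('inf')] * n
    dist[src] = 0
    dq: deque[int] = deque([src])
    while dq:
        u = dq.popleft()
        for v, w in adj[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                if w == 0:
                    dq.appendleft(v)  # same distance as u: process next
                else:
                    dq.append(v)      # one more than u: process later
    return dist


# Directed edges: 0->1 (1), 0->2 (0), 2->1 (0), 1->3 (1)
adj = [[(1, 1), (2, 0)], [(3, 1)], [(1, 0)], []]
print(bfs_01_graph(4, adj, 0))  # [0, 0, 0, 1]
```

The deque stays sorted by distance with at most two distinct values in it at once, which is why it can replace the heap here.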
| Situation | Algorithm | Complexity |
|---|---|---|
| Non-negative weights, single source | Dijkstra (min-heap) | O(E log V) |
| Negative weights, single source | Bellman-Ford | O(VE) |
| All pairs, small n | Floyd-Warshall | \(O(V^3)\) |
| Unweighted graph | BFS | O(V + E) |
| 0/1 edge weights | 0-1 BFS (deque) | O(V + E) |
| DAG, any weights | Topo sort + relax | O(V + E) |
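The last row of the table, topological sort followed by relaxation, is not implemented elsewhere in this section. A minimal sketch: compute a topological order with Kahn's algorithm, then relax each node's outgoing edges once in that order. Every predecessor of a node comes earlier, so its distance is final when it is reached, and negative weights are fine because a DAG cannot contain a negative cycle. The name dag_shortest_paths is hypothetical:

```python
from collections import deque


def dag_shortest_paths(n: int,
                       edges: list[tuple[int, int, int]],
                       src: int) -> list[float]:
    """Single-source shortest paths in a DAG; weights may be negative.
    edges: (u, v, weight), directed. Time: O(V + E)  Space: O(V + E)
    """
    adj: list[list[tuple[int, int]]] = [[] for _ in range(n)]
    in_degree = [0] * n
    for u, v, w in edges:
        adj[u].append((v, w))
        in_degree[v] += 1
    # Kahn's algorithm for a topological order
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    order: list[int] = []
    while queue:
        u = queue.popleft()
        order.append(u)
        for v, _ in adj[u]:
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)
    if len(order) != n:
        raise ValueError("graph has a cycle")
    dist = [float('inf')] * n
    dist[src] = 0
    for u in order:  # all predecessors of u appear earlier, so dist[u] is final
        if dist[u] == float('inf'):
            continue  # unreachable from src: nothing to relax
        for v, w in adj[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
    return dist


edges = [(0, 1, 5), (0, 2, 3), (2, 1, -4), (1, 3, 2)]
print(dag_shortest_paths(4, edges, 0))  # [0, -1, 3, 1]
```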
Dijkstra Simplified
Given a weighted undirected graph as a list of edges (u, v, w), find the shortest distance from node 0 to every other node. Fill in the Dijkstra implementation.
Expected output: [0, 3, 1, 4]
import heapq
def dijkstra(edges, n):
graph = [[] for _ in range(n)]
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w))
dist = [float('inf')] * n
dist[0] = 0
heap = [(0, 0)] # (cost, node)
while heap:
cost, node = heapq.heappop(heap)
# __ Skip if we already found a better path
if cost > dist[node]:
continue
for neighbor, weight in graph[node]:
new_cost = cost + weight
# __ Relax the edge if cheaper
if new_cost < dist[neighbor]:
dist[neighbor] = new_cost
heapq.heappush(heap, (new_cost, neighbor))
return dist
edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 1), (2, 3, 5)]
print(dijkstra(edges, 4)) # [0, 3, 1, 4]
import java.util.*;
public class Solution {
public static int[] dijkstra(int[][] edges, int n) {
List<int[]>[] graph = new List[n];
for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
for (int[] e : edges) {
graph[e[0]].add(new int[]{e[1], e[2]});
graph[e[1]].add(new int[]{e[0], e[2]});
}
int[] dist = new int[n];
Arrays.fill(dist, Integer.MAX_VALUE);
dist[0] = 0;
PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
pq.offer(new int[]{0, 0}); // {cost, node}
while (!pq.isEmpty()) {
int[] curr = pq.poll();
int cost = curr[0], node = curr[1];
// Skip if we already found a better path
if (cost > dist[node]) continue;
for (int[] next : graph[node]) {
int neighbor = next[0], weight = next[1];
int newCost = cost + weight;
// Relax the edge if cheaper
if (newCost < dist[neighbor]) {
dist[neighbor] = newCost;
pq.offer(new int[]{newCost, neighbor});
}
}
}
return dist;
}
public static void main(String[] args) {
int[][] edges = {{0,1,4},{0,2,1},{2,1,2},{1,3,1},{2,3,5}};
System.out.println(Arrays.toString(dijkstra(edges, 4)));
// [0, 3, 1, 4]
}
}
Hint: Use a min-heap of (cost, node). Pop the cheapest node and skip it if its cost exceeds the recorded best. Relax edges to neighbors and push improved distances.
Solution:
import heapq
def dijkstra(edges, n):
graph = [[] for _ in range(n)]
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w))
dist = [float('inf')] * n
dist[0] = 0
heap = [(0, 0)]
while heap:
cost, node = heapq.heappop(heap)
if cost > dist[node]:
continue
for neighbor, weight in graph[node]:
new_cost = cost + weight
if new_cost < dist[neighbor]:
dist[neighbor] = new_cost
heapq.heappush(heap, (new_cost, neighbor))
return dist
edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 1), (2, 3, 5)]
print(dijkstra(edges, 4)) # [0, 3, 1, 4]
import java.util.*;
public class Solution {
public static int[] dijkstra(int[][] edges, int n) {
List<int[]>[] graph = new List[n];
for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
for (int[] e : edges) {
graph[e[0]].add(new int[]{e[1], e[2]});
graph[e[1]].add(new int[]{e[0], e[2]});
}
int[] dist = new int[n];
Arrays.fill(dist, Integer.MAX_VALUE);
dist[0] = 0;
PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
pq.offer(new int[]{0, 0});
while (!pq.isEmpty()) {
int[] curr = pq.poll();
int cost = curr[0], node = curr[1];
if (cost > dist[node]) continue;
for (int[] next : graph[node]) {
int neighbor = next[0], weight = next[1];
int newCost = cost + weight;
if (newCost < dist[neighbor]) {
dist[neighbor] = newCost;
pq.offer(new int[]{newCost, neighbor});
}
}
}
return dist;
}
public static void main(String[] args) {
int[][] edges = {{0,1,4},{0,2,1},{2,1,2},{1,3,1},{2,3,5}};
System.out.println(Arrays.toString(dijkstra(edges, 4)));
// [0, 3, 1, 4]
}
}
Network Delay Time
A signal is sent from node k. Given weighted directed edges, find how long it takes until all nodes receive the signal (max of all shortest paths). Return -1 if any node is unreachable.
Expected output: 3
import heapq
def network_delay(times, n, k):
graph = [[] for _ in range(n + 1)]
for u, v, w in times:
graph[u].append((v, w))
dist = [float('inf')] * (n + 1)
dist[k] = 0
heap = [(0, k)]
while heap:
cost, node = heapq.heappop(heap)
if cost > dist[node]:
continue
for neighbor, weight in graph[node]:
new_cost = cost + weight
# __ Relax if cheaper
if new_cost < dist[neighbor]:
dist[neighbor] = new_cost
heapq.heappush(heap, (new_cost, neighbor))
# __ Check reachability and return answer
max_dist = max(dist[1:])
return max_dist if max_dist < float('inf') else -1
edges = [[1, 2, 1], [2, 3, 2], [1, 3, 4]]
print(network_delay(edges, 3, 1)) # 3
import java.util.*;
public class Solution {
public static int networkDelay(int[][] times, int n, int k) {
List<int[]>[] graph = new List[n + 1];
for (int i = 0; i <= n; i++) graph[i] = new ArrayList<>();
for (int[] t : times) graph[t[0]].add(new int[]{t[1], t[2]});
int[] dist = new int[n + 1];
Arrays.fill(dist, Integer.MAX_VALUE);
dist[k] = 0;
PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
pq.offer(new int[]{0, k});
while (!pq.isEmpty()) {
int[] curr = pq.poll();
int cost = curr[0], node = curr[1];
if (cost > dist[node]) continue;
for (int[] next : graph[node]) {
int newCost = cost + next[1];
// Relax if cheaper
if (newCost < dist[next[0]]) {
dist[next[0]] = newCost;
pq.offer(new int[]{newCost, next[0]});
}
}
}
// Check reachability and return answer
int maxDist = 0;
for (int i = 1; i <= n; i++) {
if (dist[i] == Integer.MAX_VALUE) return -1;
maxDist = Math.max(maxDist, dist[i]);
}
return maxDist;
}
public static void main(String[] args) {
int[][] edges = {{1,2,1},{2,3,2},{1,3,4}};
System.out.println(networkDelay(edges, 3, 1)); // 3
}
}
Hint: Run Dijkstra from node k. The answer is max(distances). Return -1 if any distance is still infinity after the algorithm completes.
Solution:
import heapq
def network_delay(times, n, k):
graph = [[] for _ in range(n + 1)]
for u, v, w in times:
graph[u].append((v, w))
dist = [float('inf')] * (n + 1)
dist[k] = 0
heap = [(0, k)]
while heap:
cost, node = heapq.heappop(heap)
if cost > dist[node]:
continue
for neighbor, weight in graph[node]:
new_cost = cost + weight
if new_cost < dist[neighbor]:
dist[neighbor] = new_cost
heapq.heappush(heap, (new_cost, neighbor))
max_dist = max(dist[1:])
return max_dist if max_dist < float('inf') else -1
edges = [[1, 2, 1], [2, 3, 2], [1, 3, 4]]
print(network_delay(edges, 3, 1)) # 3
import java.util.*;
public class Solution {
public static int networkDelay(int[][] times, int n, int k) {
List<int[]>[] graph = new List[n + 1];
for (int i = 0; i <= n; i++) graph[i] = new ArrayList<>();
for (int[] t : times) graph[t[0]].add(new int[]{t[1], t[2]});
int[] dist = new int[n + 1];
Arrays.fill(dist, Integer.MAX_VALUE);
dist[k] = 0;
PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
pq.offer(new int[]{0, k});
while (!pq.isEmpty()) {
int[] curr = pq.poll();
int cost = curr[0], node = curr[1];
if (cost > dist[node]) continue;
for (int[] next : graph[node]) {
int newCost = cost + next[1];
if (newCost < dist[next[0]]) {
dist[next[0]] = newCost;
pq.offer(new int[]{newCost, next[0]});
}
}
}
int maxDist = 0;
for (int i = 1; i <= n; i++) {
if (dist[i] == Integer.MAX_VALUE) return -1;
maxDist = Math.max(maxDist, dist[i]);
}
return maxDist;
}
public static void main(String[] args) {
int[][] edges = {{1,2,1},{2,3,2},{1,3,4}};
System.out.println(networkDelay(edges, 3, 1)); // 3
}
}
Cheapest Flights with K Stops
Find the cheapest price from src to dst using at most k stops. Use the Bellman-Ford approach: relax edges k+1 times.
Expected output: 200
def cheapest(n, flights, src, dst, k):
prices = [float('inf')] * n
prices[src] = 0
for i in range(k + 1):
# __ Work on a copy so we don't reuse this round's updates
temp = prices[:]
for u, v, w in flights:
if prices[u] != float('inf'):
new_price = prices[u] + w
# __ Relax if cheaper
if new_price < temp[v]:
temp[v] = new_price
prices = temp
return prices[dst] if prices[dst] != float('inf') else -1
flights = [[0, 1, 100], [1, 2, 100], [0, 2, 500]]
print(cheapest(3, flights, 0, 2, 1)) # 200
import java.util.*;
public class Solution {
public static int cheapest(int n, int[][] flights, int src, int dst, int k) {
int[] prices = new int[n];
Arrays.fill(prices, Integer.MAX_VALUE);
prices[src] = 0;
for (int i = 0; i <= k; i++) {
// Work on a copy so we don't reuse this round's updates
int[] temp = Arrays.copyOf(prices, n);
for (int[] flight : flights) {
int u = flight[0], v = flight[1], w = flight[2];
if (prices[u] != Integer.MAX_VALUE) {
int newPrice = prices[u] + w;
// Relax if cheaper
if (newPrice < temp[v]) temp[v] = newPrice;
}
}
prices = temp;
}
return prices[dst] == Integer.MAX_VALUE ? -1 : prices[dst];
}
public static void main(String[] args) {
int[][] flights = {{0,1,100},{1,2,100},{0,2,500}};
System.out.println(cheapest(3, flights, 0, 2, 1)); // 200
}
}
Hint: Run k+1 relaxation rounds. Each round allows one more edge (hop), so k+1 rounds cover every route with at most k stops. Copy the prices array before each round so updates from the current round don't bleed into same-round relaxations.
Solution:
def cheapest(n, flights, src, dst, k):
prices = [float('inf')] * n
prices[src] = 0
for i in range(k + 1):
temp = prices[:]
for u, v, w in flights:
if prices[u] != float('inf'):
new_price = prices[u] + w
if new_price < temp[v]:
temp[v] = new_price
prices = temp
return prices[dst] if prices[dst] != float('inf') else -1
flights = [[0, 1, 100], [1, 2, 100], [0, 2, 500]]
print(cheapest(3, flights, 0, 2, 1)) # 200
import java.util.*;
public class Solution {
public static int cheapest(int n, int[][] flights, int src, int dst, int k) {
int[] prices = new int[n];
Arrays.fill(prices, Integer.MAX_VALUE);
prices[src] = 0;
for (int i = 0; i <= k; i++) {
int[] temp = Arrays.copyOf(prices, n);
for (int[] flight : flights) {
int u = flight[0], v = flight[1], w = flight[2];
if (prices[u] != Integer.MAX_VALUE) {
int newPrice = prices[u] + w;
if (newPrice < temp[v]) temp[v] = newPrice;
}
}
prices = temp;
}
return prices[dst] == Integer.MAX_VALUE ? -1 : prices[dst];
}
public static void main(String[] args) {
int[][] flights = {{0,1,100},{1,2,100},{0,2,500}};
System.out.println(cheapest(3, flights, 0, 2, 1)); // 200
}
}
12. Advanced Graphs
Union-Find (Disjoint Set Union)
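Efficient data structure for tracking connected components as edges are added. Two key optimizations: path compression (point nodes directly at the root during find) and union by rank (attach the root of the lower-rank tree under the root of the higher-rank tree). Together they give O(α(n)) amortized time per operation, where α is the inverse Ackermann function, which is at most 4 for any practical n.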
class UnionFind:
"""Union-Find with path compression and union by rank.
Time: O(alpha(n)) per operation (effectively O(1)).
Space: O(n)
"""
def __init__(self, n: int) -> None:
self.parent = list(range(n))
self.rank = [0] * n
self.components = n # number of distinct components
def find(self, x: int) -> int:
"""Find root with path compression."""
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x]) # path compress
return self.parent[x]
def union(self, x: int, y: int) -> bool:
"""Unite the sets containing x and y.
Returns True if they were in different sets (new connection made).
"""
rx, ry = self.find(x), self.find(y)
if rx == ry:
return False # already connected
# Attach smaller rank tree under larger rank root
if self.rank[rx] < self.rank[ry]:
rx, ry = ry, rx
self.parent[ry] = rx
if self.rank[rx] == self.rank[ry]:
self.rank[rx] += 1
self.components -= 1
return True
def connected(self, x: int, y: int) -> bool:
return self.find(x) == self.find(y)
# Example: count connected components in undirected graph
def count_components(n: int, edges: list[list[int]]) -> int:
uf = UnionFind(n)
for u, v in edges:
uf.union(u, v)
return uf.components
Minimum Spanning Tree
A minimum spanning tree (MST) of a connected, undirected, weighted graph is a spanning tree (it connects all V nodes with exactly V-1 edges) whose total edge weight is minimal.
Kruskal's — sort edges, add if no cycle (Union-Find)
def kruskal_mst(n: int,
edges: list[tuple[int, int, int]]) -> tuple[int, list]:
"""Kruskal's MST algorithm.
edges: list of (weight, u, v).
Returns (total_cost, mst_edges).
Time: O(E log E) Space: O(V + E) (sorted() copies the edge list)
"""
uf = UnionFind(n)
mst_cost = 0
mst_edges: list[tuple[int, int, int]] = []
for w, u, v in sorted(edges): # sort by weight ascending
if uf.union(u, v): # no cycle: add edge
mst_cost += w
mst_edges.append((w, u, v))
if len(mst_edges) == n - 1:
break # MST complete
return mst_cost, mst_edges
Prim's — grow MST greedily from a start node (min-heap)
import heapq
from collections import defaultdict
def prim_mst(n: int,
edges: list[tuple[int, int, int]]) -> int:
"""Prim's MST algorithm.
edges: list of (u, v, weight) — undirected.
Returns total MST cost, or -1 if graph is disconnected.
Time: O(E log V) Space: O(V + E)
"""
graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
for u, v, w in edges:
graph[u].append((v, w))
graph[v].append((u, w))
visited: set[int] = set()
heap = [(0, 0)] # (cost, node), start at node 0
total = 0
while heap and len(visited) < n:
cost, u = heapq.heappop(heap)
if u in visited:
continue
visited.add(u)
total += cost
for v, w in graph[u]:
if v not in visited:
heapq.heappush(heap, (w, v))
return total if len(visited) == n else -1
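prim_mst uses a lazy min-heap, which gives O(E log V). On a dense graph, where E is close to V², the classic array-based variant is asymptotically better. Instead of a heap, it scans linearly for the cheapest node outside the tree and runs in O(V²). The sketch below assumes an adjacency-matrix input in which math.inf marks a missing edge. prim_dense is a hypothetical name.

```python
import math


def prim_dense(weights: list[list[float]]) -> float:
    """Array-based Prim's MST for a dense graph given as an adjacency matrix.
    weights[u][v] is the edge weight, or math.inf when there is no edge.
    Returns the total MST cost, or -1 if the graph is disconnected.
    Time: O(V^2)  Space: O(V)
    """
    n = len(weights)
    in_tree = [False] * n
    best = [math.inf] * n  # cheapest known edge linking each node to the tree
    if n:
        best[0] = 0  # start growing the tree from node 0
    total = 0
    for _ in range(n):
        # Linear scan for the cheapest node outside the tree (replaces the heap)
        u = -1
        for v in range(n):
            if not in_tree[v] and (u == -1 or best[v] < best[u]):
                u = v
        if best[u] == math.inf:
            return -1  # no edge reaches the remaining nodes
        in_tree[u] = True
        total += best[u]
        # Offer u's edges as cheaper connections for nodes still outside the tree
        for v in range(n):
            if not in_tree[v] and weights[u][v] < best[v]:
                best[v] = weights[u][v]
    return total
```

The heap version wins on sparse graphs. This version avoids the log factor when nearly every pair of nodes is connected.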
Bipartite Checking (2-Color BFS)
from collections import deque
def is_bipartite(graph: list[list[int]]) -> bool:
"""Check if graph is 2-colorable (no odd-length cycles).
graph: adjacency list indexed 0..n-1.
Time: O(V + E) Space: O(V)
"""
n = len(graph)
color = [-1] * n
for start in range(n):
if color[start] != -1:
continue
color[start] = 0
queue: deque[int] = deque([start])
while queue:
node = queue.popleft()
for neighbor in graph[node]:
if color[neighbor] == -1:
color[neighbor] = 1 - color[node]
queue.append(neighbor)
elif color[neighbor] == color[node]:
return False # same color: odd cycle found
return True
Strongly Connected Components (Kosaraju's Concept)
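An SCC (strongly connected component) is a maximal set of nodes in a directed graph in which every node is reachable from every other node. Kosaraju's algorithm makes two DFS passes. The first pass runs on the original graph and records nodes in the order their DFS calls finish. The second pass runs on the transposed (edge-reversed) graph, starting from nodes in reverse finish order, and each DFS tree it explores is one SCC.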
from collections import defaultdict
def kosaraju_scc(n: int,
edges: list[tuple[int, int]]) -> list[list[int]]:
"""Find all SCCs using Kosaraju's algorithm.
Time: O(V + E) Space: O(V + E)
"""
graph: dict[int, list[int]] = defaultdict(list)
rev_graph: dict[int, list[int]] = defaultdict(list)
for u, v in edges:
graph[u].append(v)
rev_graph[v].append(u)
# Pass 1: DFS on original graph, record finish order
visited: set[int] = set()
finish_order: list[int] = []
def dfs1(node: int) -> None:
visited.add(node)
for nb in graph[node]:
if nb not in visited:
dfs1(nb)
finish_order.append(node) # append on finish
for v in range(n):
if v not in visited:
dfs1(v)
# Pass 2: DFS on reversed graph in reverse finish order
visited.clear()
sccs: list[list[int]] = []
def dfs2(node: int, component: list[int]) -> None:
visited.add(node)
component.append(node)
for nb in rev_graph[node]:
if nb not in visited:
dfs2(nb, component)
for v in reversed(finish_order):
if v not in visited:
component: list[int] = []
dfs2(v, component)
sccs.append(component)
return sccs
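kosaraju_scc recurses once per node along a DFS path, so a long chain can exceed CPython's default recursion limit of 1000 frames. The sketch below implements the same two passes with explicit stacks and assumes the same (n, edges) input. kosaraju_iterative is a hypothetical name. Pass 1 keeps (node, neighbor-iterator) pairs on the stack, so a node is appended to finish_order only after all of its descendants. That matches the recursive post-order.

```python
from collections import defaultdict


def kosaraju_iterative(n: int, edges: list[tuple[int, int]]) -> list[list[int]]:
    """Kosaraju's SCC algorithm using explicit stacks instead of recursion.
    Time: O(V + E)  Space: O(V + E)
    """
    graph: dict[int, list[int]] = defaultdict(list)
    rev_graph: dict[int, list[int]] = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
        rev_graph[v].append(u)

    # Pass 1: iterative DFS on the original graph, recording finish order
    visited = [False] * n
    finish_order: list[int] = []
    for start in range(n):
        if visited[start]:
            continue
        visited[start] = True
        stack = [(start, iter(graph[start]))]
        while stack:
            node, neighbors = stack[-1]
            advanced = False
            for nb in neighbors:  # resumes where this node's iterator left off
                if not visited[nb]:
                    visited[nb] = True
                    stack.append((nb, iter(graph[nb])))
                    advanced = True
                    break
            if not advanced:
                stack.pop()
                finish_order.append(node)  # all descendants done: node finishes

    # Pass 2: DFS on the reversed graph in reverse finish order; each tree is one SCC
    assigned = [False] * n
    sccs: list[list[int]] = []
    for start in reversed(finish_order):
        if assigned[start]:
            continue
        assigned[start] = True
        component: list[int] = []
        todo = [start]
        while todo:
            node = todo.pop()
            component.append(node)
            for nb in rev_graph[node]:
                if not assigned[nb]:
                    assigned[nb] = True
                    todo.append(nb)
        sccs.append(component)
    return sccs
```

In pass 2 the visiting order inside a component does not matter. A plain stack is enough there.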
Union-Find: Connected Components
Count the number of connected components in an undirected graph using Union-Find.
Expected output: 2
def count_components(n, edges):
parent = list(range(n))
def find(x):
# Path halving (a form of path compression): point each visited node at its grandparent
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
def union(a, b):
ra, rb = find(a), find(b)
if ra == rb:
return
# Merge: point one root to the other
parent[ra] = rb
for u, v in edges:
union(u, v)
# Count distinct roots
return sum(1 for i in range(n) if find(i) == i)
print(count_components(5, [[0, 1], [1, 2], [3, 4]])) # 2
public class Solution {
static int[] parent;
static int find(int x) {
// Path halving (a form of path compression): point each visited node at its grandparent
while (parent[x] != x) {
parent[x] = parent[parent[x]];
x = parent[x];
}
return x;
}
static void union(int a, int b) {
int ra = find(a), rb = find(b);
if (ra == rb) return;
// Merge: point one root to the other
parent[ra] = rb;
}
public static int countComponents(int n, int[][] edges) {
parent = new int[n];
for (int i = 0; i < n; i++) parent[i] = i;
for (int[] e : edges) union(e[0], e[1]);
// Count distinct roots
int count = 0;
for (int i = 0; i < n; i++) if (find(i) == i) count++;
return count;
}
public static void main(String[] args) {
int[][] edges = {{0,1},{1,2},{3,4}};
System.out.println(countComponents(5, edges)); // 2
}
}
Initialize parent[i] = i so that every node starts as its own root. find() follows parent pointers up to the root and halves the path as it goes, re-pointing each visited node at its grandparent. union() merges two sets by linking one root to the other. At the end, count the nodes where find(i) == i. Each one is the root of a distinct component.
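The exercise's find() uses path halving, and its union() links roots without ranks. Below is a minimal sketch of the fully optimized variant described at the start of this section. It uses two-pass path compression (find the root, then re-point every node on the path at it) and union by rank. It also decrements a component counter on each successful union instead of scanning for roots at the end. count_components_ranked is a hypothetical name.

```python
def count_components_ranked(n: int, edges: list[list[int]]) -> int:
    """Connected components via union-find with full path compression and union by rank."""
    parent = list(range(n))
    rank = [0] * n  # upper bound on the height of each root's tree

    def find(x: int) -> int:
        # Pass 1: walk up to the root
        root = x
        while parent[root] != root:
            root = parent[root]
        # Pass 2: re-point every node on the path directly at the root
        while parent[x] != root:
            next_node = parent[x]
            parent[x] = root
            x = next_node
        return root

    def union(a: int, b: int) -> bool:
        ra, rb = find(a), find(b)
        if ra == rb:
            return False  # already in the same component
        if rank[ra] < rank[rb]:
            ra, rb = rb, ra  # make ra the root of the taller tree
        parent[rb] = ra
        if rank[ra] == rank[rb]:
            rank[ra] += 1
        return True

    components = n
    for u, v in edges:
        if union(u, v):
            components -= 1  # each successful union merges two components
    return components
```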
Topological Sort
Return a topological ordering of a directed acyclic graph using Kahn's algorithm.
Return an empty list if a cycle exists. Test checks len(result) == n.
Expected output: 4 (length of valid ordering)
from collections import deque
def topo_sort(n, edges):
graph = [[] for _ in range(n)]
in_degree = [0] * n
for u, v in edges:
graph[u].append(v)
# Increment in-degree of destination
in_degree[v] += 1
# Start with all nodes that have no incoming edges
queue = deque(i for i in range(n) if in_degree[i] == 0)
result = []
while queue:
node = queue.popleft()
result.append(node)
for neighbor in graph[node]:
in_degree[neighbor] -= 1
# Add neighbor if it now has no dependencies
if in_degree[neighbor] == 0:
queue.append(neighbor)
# Cycle check: valid only if all nodes are in result
return result if len(result) == n else []
result = topo_sort(4, [[0, 1], [0, 2], [1, 3], [2, 3]])
print(len(result)) # 4 (valid ordering exists)
print(topo_sort(3, [[0, 1], [1, 2], [2, 0]])) # [] (cycle)
import java.util.*;
public class Solution {
public static List<Integer> topoSort(int n, int[][] edges) {
List<Integer>[] graph = new List[n];
int[] inDegree = new int[n];
for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
for (int[] e : edges) {
graph[e[0]].add(e[1]);
// Increment in-degree of destination
inDegree[e[1]]++;
}
// Start with all nodes that have no incoming edges
Queue<Integer> queue = new LinkedList<>();
for (int i = 0; i < n; i++) if (inDegree[i] == 0) queue.offer(i);
List<Integer> result = new ArrayList<>();
while (!queue.isEmpty()) {
int node = queue.poll();
result.add(node);
for (int neighbor : graph[node]) {
inDegree[neighbor]--;
// Add neighbor if it now has no dependencies
if (inDegree[neighbor] == 0) queue.offer(neighbor);
}
}
// Cycle check
return result.size() == n ? result : new ArrayList<>();
}
public static void main(String[] args) {
int[][] edges = {{0,1},{0,2},{1,3},{2,3}};
List<Integer> result = topoSort(4, edges);
System.out.println(result.size()); // 4
int[][] cycle = {{0,1},{1,2},{2,0}};
System.out.println(topoSort(3, cycle)); // []
}
}
Kahn's algorithm: enqueue every node with in-degree 0, then repeatedly pop a node, append it to the result, and decrement the in-degree of each neighbor. Enqueue any neighbor whose in-degree drops to 0. The ordering is valid only if all n nodes are processed. Otherwise, there's a cycle.
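A common alternative to the BFS-based Kahn's algorithm is DFS post-order. A node is appended only after everything reachable from it, so reversing that list gives a valid order. Cycles are caught with three colors: reaching a node that is still on the current DFS path (GRAY) means a back edge. The sketch below is recursive, so very deep graphs can hit Python's recursion limit. topo_sort_dfs is a hypothetical name.

```python
def topo_sort_dfs(n: int, edges: list[list[int]]) -> list[int]:
    """Topological order via DFS post-order; returns [] if a cycle exists."""
    graph: list[list[int]] = [[] for _ in range(n)]
    for u, v in edges:
        graph[u].append(v)
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited, on the current DFS path, finished
    color = [WHITE] * n
    order: list[int] = []

    def dfs(node: int) -> bool:
        """Return False if a cycle is reachable from node."""
        color[node] = GRAY
        for nb in graph[node]:
            if color[nb] == GRAY:
                return False  # back edge: cycle
            if color[nb] == WHITE and not dfs(nb):
                return False
        color[node] = BLACK
        order.append(node)  # post-order: all descendants already appended
        return True

    for v in range(n):
        if color[v] == WHITE and not dfs(v):
            return []
    return order[::-1]
```

For the 4-node example above, this returns [0, 2, 1, 3]. That differs from Kahn's output but is equally valid.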
Redundant Connection
Given edges of an undirected graph added one at a time, find the first edge that
creates a cycle. Return that edge. Expected output: [2, 3]
def find_redundant(edges):
n = len(edges)
parent = list(range(n + 1))
def find(x):
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
for u, v in edges:
ru, rv = find(u), find(v)
# If same root, adding this edge creates a cycle
if ru == rv:
return [u, v]
# Otherwise, union them
parent[ru] = rv
return []
print(find_redundant([[1, 2], [1, 3], [2, 3]])) # [2, 3]
public class Solution {
static int[] parent;
static int find(int x) {
while (parent[x] != x) {
parent[x] = parent[parent[x]];
x = parent[x];
}
return x;
}
public static int[] findRedundant(int[][] edges) {
int n = edges.length;
parent = new int[n + 1];
for (int i = 0; i <= n; i++) parent[i] = i;
for (int[] e : edges) {
int ru = find(e[0]), rv = find(e[1]);
// If same root, adding this edge creates a cycle
if (ru == rv) return e;
// Otherwise, union them
parent[ru] = rv;
}
return new int[]{};
}
public static void main(String[] args) {
int[][] edges = {{1,2},{1,3},{2,3}};
int[] result = findRedundant(edges);
System.out.println("[" + result[0] + ", " + result[1] + "]"); // [2, 3]
}
}
Process the edges in order and call find() on both endpoints. If they already share the same root, the endpoints are already connected, so this edge would form a cycle. Return it immediately. Otherwise, union the two sets and continue.
13. Dynamic Programming Patterns
DP applies when a problem has optimal substructure (optimal solution uses optimal solutions to subproblems) and overlapping subproblems (same subproblems solved repeatedly).
1. State: What does dp[state] represent?
2. Recurrence: How does dp[state] relate to smaller states?
3. Base case: What are the smallest/trivially solved states?
4. Order: Ensure each state is computed before it is needed.
5. Answer: Which state holds the final answer?
6. Optimize: Can you reduce space with a rolling array?
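The same steps also apply top-down. With memoization, the Order step takes care of itself: recursion computes each state the first time it is needed, and the cache stores the result. The sketch below applies the framework to coin change. coin_change_memo is a hypothetical name. Recursion depth grows with amount / smallest coin, so for large amounts the bottom-up table is the safer choice in Python.

```python
import math
from functools import lru_cache


def coin_change_memo(coins: list[int], amount: int) -> int:
    """Top-down (memoized) coin change; -1 if the amount cannot be made."""

    @lru_cache(maxsize=None)
    def solve(remaining: int) -> float:
        # State: fewest coins that sum to `remaining` (math.inf if impossible)
        if remaining == 0:
            return 0  # base case
        best = math.inf
        for c in coins:
            if c <= remaining:
                # Recurrence: use coin c, then solve the smaller state
                best = min(best, solve(remaining - c) + 1)
        return best

    # Answer: the state for the full amount
    result = solve(amount)
    return -1 if result == math.inf else int(result)
```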
1D DP Classics
def climb_stairs(n: int) -> int:
"""Count ways to climb n stairs (1 or 2 steps at a time).
State: dp[i] = number of ways to reach step i.
Recurrence: dp[i] = dp[i-1] + dp[i-2]
Time: O(n) Space: O(1) after rolling optimization
"""
if n <= 2:
return n
a, b = 1, 2
for _ in range(3, n + 1):
a, b = b, a + b
return b
def house_robber(nums: list[int]) -> int:
"""Max sum with no two adjacent elements selected.
State: dp[i] = max profit using houses 0..i.
Recurrence: dp[i] = max(dp[i-1], dp[i-2] + nums[i])
Time: O(n) Space: O(1)
"""
prev2, prev1 = 0, 0
for x in nums:
prev2, prev1 = prev1, max(prev1, prev2 + x)
return prev1
def coin_change(coins: list[int], amount: int) -> int:
"""Minimum coins to make amount. -1 if impossible.
State: dp[i] = fewest coins to make amount i.
Recurrence: dp[i] = min(dp[i - c] + 1) for c in coins if i >= c
Time: O(amount * len(coins)) Space: O(amount)
"""
dp = [float('inf')] * (amount + 1)
dp[0] = 0
for i in range(1, amount + 1):
for c in coins:
if i >= c:
dp[i] = min(dp[i], dp[i - c] + 1)
return dp[amount] if dp[amount] != float('inf') else -1
def word_break(s: str, word_dict: list[str]) -> bool:
"""Can s be segmented into words from word_dict?
State: dp[i] = True if s[:i] can be segmented.
Time: O(n^2) substring checks (O(n^3) counting slicing/hashing) Space: O(n)
"""
words = set(word_dict)
n = len(s)
dp = [False] * (n + 1)
dp[0] = True # empty string always valid
for i in range(1, n + 1):
for j in range(i):
if dp[j] and s[j:i] in words:
dp[i] = True
break
return dp[n]
2D DP Classics
def unique_paths(m: int, n: int) -> int:
"""Count paths from top-left to bottom-right (only right/down).
Time: O(m*n) Space: O(n) rolling row
"""
dp = [1] * n
for _ in range(1, m):
for j in range(1, n):
dp[j] += dp[j - 1]
return dp[n - 1]
def edit_distance(word1: str, word2: str) -> int:
"""Levenshtein distance: minimum insert/delete/replace to convert word1 -> word2.
State: dp[i][j] = edit distance between word1[:i] and word2[:j].
Time: O(m*n) Space: O(m*n) — reducible to O(n)
"""
m, n = len(word1), len(word2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(m + 1):
dp[i][0] = i # delete all of word1
for j in range(n + 1):
dp[0][j] = j # insert all of word2
for i in range(1, m + 1):
for j in range(1, n + 1):
if word1[i-1] == word2[j-1]:
dp[i][j] = dp[i-1][j-1]
else:
dp[i][j] = 1 + min(
dp[i-1][j], # delete from word1
dp[i][j-1], # insert into word1
dp[i-1][j-1] # replace
)
return dp[m][n]
def longest_common_subsequence(text1: str, text2: str) -> int:
"""LCS length. Not necessarily contiguous.
Time: O(m*n) Space: O(m*n)
"""
m, n = len(text1), len(text2)
dp = [[0] * (n + 1) for _ in range(m + 1)]
for i in range(1, m + 1):
for j in range(1, n + 1):
if text1[i-1] == text2[j-1]:
dp[i][j] = dp[i-1][j-1] + 1
else:
dp[i][j] = max(dp[i-1][j], dp[i][j-1])
return dp[m][n]
Knapsack Patterns
def knapsack_01(weights: list[int],
values: list[int],
capacity: int) -> int:
"""0/1 Knapsack: each item used at most once.
State: dp[w] = max value with weight budget w.
Iterate capacities in reverse so dp[c - w] still holds the value from before this item; this prevents using the item twice.
Time: O(n * capacity) Space: O(capacity)
"""
dp = [0] * (capacity + 1)
for w, v in zip(weights, values):
for c in range(capacity, w - 1, -1): # reverse!
dp[c] = max(dp[c], dp[c - w] + v)
return dp[capacity]
def knapsack_unbounded(weights: list[int],
values: list[int],
capacity: int) -> int:
"""Unbounded Knapsack: each item usable unlimited times.
Iterate capacities forward so dp[c - w] may already include the current item; this allows reuse.
"""
dp = [0] * (capacity + 1)
for c in range(1, capacity + 1):
for w, v in zip(weights, values):
if w <= c:
dp[c] = max(dp[c], dp[c - w] + v)
return dp[capacity]
def subset_sum(nums: list[int], target: int) -> bool:
"""Can any subset sum to target? (0/1 knapsack variant)
Time: O(n * target) Space: O(target)
"""
dp = [False] * (target + 1)
dp[0] = True
for x in nums:
for j in range(target, x - 1, -1):
dp[j] = dp[j] or dp[j - x]
return dp[target]
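The loop direction is the whole difference between 0/1 and unbounded knapsack. The sketch below runs the same one-row update both ways on identical inputs, so the double counting becomes visible. Both function names are hypothetical.

```python
def knapsack_reverse(weights: list[int], values: list[int], capacity: int) -> int:
    """Correct 0/1 knapsack: capacities iterated from high to low."""
    dp = [0] * (capacity + 1)
    for w, v in zip(weights, values):
        for c in range(capacity, w - 1, -1):
            dp[c] = max(dp[c], dp[c - w] + v)  # dp[c - w] has not seen this item yet
    return dp[capacity]


def knapsack_forward(weights: list[int], values: list[int], capacity: int) -> int:
    """Same update with capacities from low to high: dp[c - w] may already
    include the current item, so an item can be counted more than once."""
    dp = [0] * (capacity + 1)
    for w, v in zip(weights, values):
        for c in range(w, capacity + 1):
            dp[c] = max(dp[c], dp[c - w] + v)
    return dp[capacity]
```

With weights [1, 3, 4], values [15, 20, 30], and capacity 4, the reverse loop returns 35 (items of weight 1 and 3). The forward loop returns 60, which is four copies of the weight-1 item. That is exactly the unbounded answer.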
Longest Increasing Subsequence
import bisect
def lis_n2(nums: list[int]) -> int:
"""O(n^2) LIS — easier to understand and extend."""
if not nums:
return 0
n = len(nums)
dp = [1] * n # dp[i] = LIS ending at index i
for i in range(1, n):
for j in range(i):
if nums[j] < nums[i]:
dp[i] = max(dp[i], dp[j] + 1)
return max(dp)
def lis_nlogn(nums: list[int]) -> int:
"""O(n log n) LIS using patience sort (binary search).
'tails' stores the smallest tail of all increasing subsequences
of each length. tails is always sorted.
"""
tails: list[int] = []
for x in nums:
pos = bisect.bisect_left(tails, x)
if pos == len(tails):
tails.append(x) # extend longest subsequence
else:
tails[pos] = x # replace: keeps tails minimal
return len(tails)
def lis_with_reconstruction(nums: list[int]) -> list[int]:
"""Return the actual LIS sequence (O(n^2) for simplicity)."""
if not nums:
return []
n = len(nums)
dp = [1] * n
parent = [-1] * n
best_end = 0
for i in range(1, n):
for j in range(i):
if nums[j] < nums[i] and dp[j] + 1 > dp[i]:
dp[i] = dp[j] + 1
parent[i] = j
if dp[i] > dp[best_end]:
best_end = i
path: list[int] = []
idx = best_end
while idx != -1:
path.append(nums[idx])
idx = parent[idx]
return path[::-1]
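The sketch below shows what tails holds by recording its contents after each element. lis_tails_trace is a hypothetical helper. The final tails has the correct length, but it need not be an actual subsequence of the input. For [3, 4, 1] it ends as [1, 4], while the LIS is [3, 4]. That is why lis_with_reconstruction tracks parent pointers instead of reading the answer off tails.

```python
import bisect


def lis_tails_trace(nums: list[int]) -> list[list[int]]:
    """Return a snapshot of `tails` after each element is processed."""
    tails: list[int] = []
    snapshots: list[list[int]] = []
    for x in nums:
        pos = bisect.bisect_left(tails, x)
        if pos == len(tails):
            tails.append(x)  # x extends the longest subsequence so far
        else:
            tails[pos] = x  # x becomes a smaller tail for length pos + 1
        snapshots.append(tails[:])
    return snapshots
```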
DP on Intervals
def burst_balloons(nums: list[int]) -> int:
"""Classic interval DP. Add sentinels 1 at both ends.
dp[i][j] = max coins from bursting all balloons between i and j (exclusive).
Key insight: think of 'k' as the LAST balloon to burst in interval (i,j).
Time: O(n^3) Space: O(n^2)
"""
balloons = [1] + nums + [1]
n = len(balloons)
dp = [[0] * n for _ in range(n)]
for length in range(2, n): # interval length
for left in range(0, n - length):
right = left + length
for k in range(left + 1, right): # last to burst
coins = (balloons[left] * balloons[k] * balloons[right]
+ dp[left][k] + dp[k][right])
dp[left][right] = max(dp[left][right], coins)
return dp[0][n - 1]
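The "last balloon" insight is easy to get wrong, so a brute-force checker helps on small inputs. The sketch below simulates every burst order directly, with a sentinel value of 1 beyond each end, and keeps the best total. It is O(n! · n²), so it is only usable for a handful of balloons. burst_balloons_brute is a hypothetical name.

```python
from itertools import permutations


def burst_balloons_brute(nums: list[int]) -> int:
    """Try every burst order; feasible only for very small inputs."""
    best = 0
    for order in permutations(range(len(nums))):
        alive = list(range(len(nums)))  # indices of balloons still present, left to right
        coins = 0
        for idx in order:
            pos = alive.index(idx)
            left = nums[alive[pos - 1]] if pos > 0 else 1  # sentinel 1 past the left end
            right = nums[alive[pos + 1]] if pos + 1 < len(alive) else 1  # and the right end
            coins += left * nums[idx] * right
            alive.pop(pos)
        best = max(best, coins)
    return best
```

For [3, 1, 5, 8], both this checker and the interval DP give 167.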
Space optimization: rolling array technique
# Many 2D DP problems only look at the previous row.
# Replace the m x n table with two 1D arrays.
def edit_distance_space_optimized(word1: str, word2: str) -> int:
"""O(n) space instead of O(m*n)."""
m, n = len(word1), len(word2)
prev = list(range(n + 1)) # base case: dp[0][j] = j
for i in range(1, m + 1):
curr = [i] + [0] * n # base case: dp[i][0] = i
for j in range(1, n + 1):
if word1[i-1] == word2[j-1]:
curr[j] = prev[j-1]
else:
curr[j] = 1 + min(prev[j], curr[j-1], prev[j-1])
prev = curr
return prev[n]
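The same rolling-array technique applies to longest_common_subsequence above, which stores the full m x n table. In the sketch below, each cell reads only the previous row (prev) and the cell to its left in the current row (curr). Two rows therefore suffice. The strings are swapped first so the rows have the length of the shorter string, which is valid because LCS is symmetric. lcs_space_optimized is a hypothetical name.

```python
def lcs_space_optimized(text1: str, text2: str) -> int:
    """LCS length using two rows: O(min(m, n)) extra space."""
    if len(text2) > len(text1):
        text1, text2 = text2, text1  # make the row length the shorter string
    n = len(text2)
    prev = [0] * (n + 1)  # row i-1 of the full table
    for i in range(1, len(text1) + 1):
        curr = [0] * (n + 1)  # row i; curr[0] = 0 is the base case
        for j in range(1, n + 1):
            if text1[i - 1] == text2[j - 1]:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev = curr
    return prev[n]
```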
Climbing Stairs
Count the number of distinct ways to climb n stairs, taking 1 or 2 steps
at a time. Expected output: 8
def climb_stairs(n):
if n <= 2:
return n
# Use two variables instead of full dp array
prev2, prev1 = 1, 2
for i in range(3, n + 1):
# Current = sum of previous two (like Fibonacci)
curr = prev1 + prev2
prev2 = prev1
prev1 = curr
return prev1
print(climb_stairs(5)) # 8
public class Solution {
public static int climbStairs(int n) {
if (n <= 2) return n;
// Use two variables instead of full dp array
int prev2 = 1, prev1 = 2;
for (int i = 3; i <= n; i++) {
// Current = sum of previous two (like Fibonacci)
int curr = prev1 + prev2;
prev2 = prev1;
prev1 = curr;
}
return prev1;
}
public static void main(String[] args) {
System.out.println(climbStairs(5)); // 8
}
}
Recurrence: dp[i] = dp[i-1] + dp[i-2]. To reach stair i, your last move came either from stair i-1 (one step) or from stair i-2 (two steps). Base cases: dp[1] = 1, dp[2] = 2. Since each state depends only on the previous two, two rolling variables reduce space to O(1).
Coin Change
Given coin denominations and a target amount, find the minimum number of coins needed.
Return -1 if not possible. Expected output: 2
def coin_change(coins, amount):
# dp[i] = min coins to make amount i
dp = [float('inf')] * (amount + 1)
dp[0] = 0
for i in range(1, amount + 1):
for coin in coins:
if coin <= i:
# Use this coin + whatever was needed for the remainder
dp[i] = min(dp[i], dp[i - coin] + 1)
return dp[amount] if dp[amount] != float('inf') else -1
print(coin_change([1, 5, 10, 25], 30)) # 2 (25 + 5)
import java.util.*;
public class Solution {
public static int coinChange(int[] coins, int amount) {
// dp[i] = min coins to make amount i
int[] dp = new int[amount + 1];
Arrays.fill(dp, amount + 1); // sentinel: larger than any valid answer
dp[0] = 0;
for (int i = 1; i <= amount; i++) {
for (int coin : coins) {
if (coin <= i) {
// Use this coin + whatever was needed for the remainder
dp[i] = Math.min(dp[i], dp[i - coin] + 1);
}
}
}
return dp[amount] > amount ? -1 : dp[amount];
}
public static void main(String[] args) {
System.out.println(coinChange(new int[]{1, 5, 10, 25}, 30)); // 2
}
}
Fill dp[0..amount] bottom-up, starting with dp[0] = 0. For each amount i, try every coin: dp[i] = min(dp[i], dp[i - coin] + 1). Initialize all other entries to infinity (the Java version uses the sentinel amount + 1) so that only feasible sub-problems can improve an entry.
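The table gives the minimum count but not which coins produce it. The sketch below also records, for each amount, the coin that produced its best value. It then walks back from amount to 0. coin_change_with_coins and last_coin are hypothetical names. Returning None for an impossible amount is an assumption of this sketch.

```python
import math
from typing import Optional


def coin_change_with_coins(coins: list[int], amount: int) -> Optional[list[int]]:
    """Coins of one minimum-size solution, or None if the amount is unreachable."""
    dp = [math.inf] * (amount + 1)
    last_coin = [0] * (amount + 1)  # coin used last in the best way to make each amount
    dp[0] = 0
    for i in range(1, amount + 1):
        for coin in coins:
            if coin <= i and dp[i - coin] + 1 < dp[i]:
                dp[i] = dp[i - coin] + 1
                last_coin[i] = coin
    if dp[amount] == math.inf:
        return None
    used: list[int] = []
    i = amount
    while i > 0:  # walk back through the recorded choices
        used.append(last_coin[i])
        i -= last_coin[i]
    return used
```

For coins [1, 5, 10, 25] and amount 30, this returns [5, 25].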
House Robber
Find the maximum money you can rob from houses without robbing two adjacent ones.
Expected output: 12
def rob(nums):
if not nums:
return 0
if len(nums) == 1:
return nums[0]
# Two rolling variables: rob up to i-2 vs i-1
prev2, prev1 = 0, 0
for num in nums:
# Rob current house or skip it
curr = max(prev1, prev2 + num)
prev2 = prev1
prev1 = curr
return prev1
print(rob([2, 7, 9, 3, 1])) # 12 (2 + 9 + 1)
public class Solution {
public static int rob(int[] nums) {
if (nums.length == 1) return nums[0];
// Two rolling variables: rob up to i-2 vs i-1
int prev2 = 0, prev1 = 0;
for (int num : nums) {
// Rob current house or skip it
int curr = Math.max(prev1, prev2 + num);
prev2 = prev1;
prev1 = curr;
}
return prev1;
}
public static void main(String[] args) {
System.out.println(rob(new int[]{2, 7, 9, 3, 1})); // 12
}
}
Recurrence: dp[i] = max(dp[i-1], dp[i-2] + nums[i]). Either skip the current house (keep dp[i-1]) or rob it (dp[i-2] + nums[i]). Two variables that track the last two states reduce space to O(1).
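To report which houses were robbed, the sketch below keeps the full table and walks it backwards. Whenever dp[i] differs from dp[i-1], house i-1 must have been robbed, and its neighbor i-2 is skipped. It uses a shifted table, dp[i] = best total using houses 0..i-1, so that dp[0] = 0 needs no special case. rob_with_houses is a hypothetical name.

```python
def rob_with_houses(nums: list[int]) -> tuple[int, list[int]]:
    """Return (max loot, indices of robbed houses in increasing order)."""
    n = len(nums)
    dp = [0] * (n + 1)  # dp[i] = best total using houses 0..i-1
    for i in range(1, n + 1):
        take = (dp[i - 2] if i >= 2 else 0) + nums[i - 1]
        dp[i] = max(dp[i - 1], take)
    houses: list[int] = []
    i = n
    while i > 0:
        if dp[i] == dp[i - 1]:
            i -= 1  # house i-1 was skipped
        else:
            houses.append(i - 1)  # house i-1 was robbed, so house i-2 cannot be
            i -= 2
    return dp[n], houses[::-1]
```

For [2, 7, 9, 3, 1], this returns (12, [0, 2, 4]).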
14. Backtracking
Backtracking is DFS on a decision tree: choose an option, explore recursively, then unchoose (undo) it when returning. For every problem, ask: What is the state at each node? What are the choices? What is the base case?
Universal Template
def backtrack(state, choices, result):
# Base case: state is a complete solution
if is_complete(state):
result.append(list(state)) # copy! don't store reference
return
for choice in choices:
if is_valid(choice, state):
state.append(choice) # CHOOSE
backtrack(state, next_choices(choice, state), result)
state.pop() # UNCHOOSE (undo)
Permutations
def permutations(nums: list[int]) -> list[list[int]]:
"""All permutations of a list with no duplicates.
Time: O(n! * n) Space: O(n)
"""
result: list[list[int]] = []
def backtrack(path: list[int], used: list[bool]) -> None:
if len(path) == len(nums):
result.append(path[:])
return
for i, x in enumerate(nums):
if not used[i]:
used[i] = True
path.append(x)
backtrack(path, used)
path.pop()
used[i] = False
backtrack([], [False] * len(nums))
return result
def permutations_with_duplicates(nums: list[int]) -> list[list[int]]:
"""Permutations where nums may contain duplicates. Prune duplicates.
Key: sort first; skip if nums[i] == nums[i-1] and nums[i-1] not used.
"""
nums.sort()
result: list[list[int]] = []
def backtrack(path: list[int], used: list[bool]) -> None:
if len(path) == len(nums):
result.append(path[:])
return
for i in range(len(nums)):
if used[i]:
continue
# Skip duplicate: same value as previous AND previous was not used in this branch
if i > 0 and nums[i] == nums[i-1] and not used[i-1]:
continue
used[i] = True
path.append(nums[i])
backtrack(path, used)
path.pop()
used[i] = False
backtrack([], [False] * len(nums))
return result
Combinations & Subsets
def combinations(n: int, k: int) -> list[list[int]]:
"""All combinations of k numbers from 1..n.
Key: start index prevents duplicates (only pick forward).
Time: O(C(n,k) * k) Space: O(k)
"""
result: list[list[int]] = []
def backtrack(start: int, path: list[int]) -> None:
if len(path) == k:
result.append(path[:])
return
# Pruning: need (k - len(path)) more elements, so stop early
for i in range(start, n + 1 - (k - len(path)) + 1):
path.append(i)
backtrack(i + 1, path)
path.pop()
backtrack(1, [])
return result
def combination_sum(candidates: list[int], target: int) -> list[list[int]]:
"""Combinations that sum to target; each number usable unlimited times."""
candidates.sort()
result: list[list[int]] = []
def backtrack(start: int, path: list[int], remaining: int) -> None:
if remaining == 0:
result.append(path[:])
return
for i in range(start, len(candidates)):
if candidates[i] > remaining:
break # pruning: sorted, so rest are bigger
path.append(candidates[i])
backtrack(i, path, remaining - candidates[i]) # i, not i+1: reuse allowed
path.pop()
backtrack(0, [], target)
return result
def subsets(nums: list[int]) -> list[list[int]]:
"""Power set: all 2^n subsets.
Time: O(2^n * n) Space: O(n)
"""
result: list[list[int]] = []
def backtrack(start: int, path: list[int]) -> None:
result.append(path[:]) # every state is a valid subset
for i in range(start, len(nums)):
path.append(nums[i])
backtrack(i + 1, path)
path.pop()
backtrack(0, [])
return result
N-Queens
def solve_n_queens(n: int) -> list[list[str]]:
"""Place n queens on n×n board with no two queens attacking each other.
Track columns and both diagonals to check attacks in O(1).
Time: O(n!) Space: O(n)
"""
result: list[list[str]] = []
queens: list[int] = [] # queens[row] = col
cols: set[int] = set()
diag1: set[int] = set() # row - col (NW-SE diagonal)
diag2: set[int] = set() # row + col (NE-SW diagonal)
def backtrack(row: int) -> None:
if row == n:
board = ['.' * c + 'Q' + '.' * (n - c - 1)
for c in queens]
result.append(board)
return
for col in range(n):
if col in cols or (row - col) in diag1 or (row + col) in diag2:
continue
cols.add(col); diag1.add(row - col); diag2.add(row + col)
queens.append(col)
backtrack(row + 1)
queens.pop()
cols.remove(col); diag1.remove(row - col); diag2.remove(row + col)
backtrack(0)
return result
Word Search in Grid
def exist(board: list[list[str]], word: str) -> bool:
"""Check if word exists in grid (any path, no reuse of cells).
Time: O(rows * cols * 4^len(word)) Space: O(len(word))
"""
rows, cols = len(board), len(board[0])
def dfs(r: int, c: int, idx: int) -> bool:
if idx == len(word):
return True
if (r < 0 or r >= rows or c < 0 or c >= cols
or board[r][c] != word[idx]):
return False
temp = board[r][c]
board[r][c] = '#' # mark visited
found = any(
dfs(r + dr, c + dc, idx + 1)
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]
)
board[r][c] = temp # restore
return found
return any(
dfs(r, c, 0)
for r in range(rows)
for c in range(cols)
)
2. Sort first: lets the loop break as soon as a candidate exceeds the remaining target.
3. Skip duplicates: sort, then skip consecutive equal values at the same decision level.
4. Symmetry breaking: add constraints that eliminate mirror-image solutions (e.g. in N-Queens, restrict the first queen to the left half and double the count, handling the middle column separately when n is odd).
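Tip 4 can be applied to N-Queens as follows. Reflecting a solution left to right gives another solution. You can therefore search only with the row-0 queen in the left half, double that count, and add the middle-column solutions once when n is odd. The sketch below counts solutions rather than listing them. It also swaps solve_n_queens's sets for integer bitmasks, a common but separate optimization. count_n_queens is a hypothetical name.

```python
def count_n_queens(n: int) -> int:
    """Count N-Queens solutions, searching only half of row 0 (mirror symmetry)."""
    full = (1 << n) - 1  # bit c set = column c

    def place(cols: int, diag_l: int, diag_r: int) -> int:
        # cols / diag_l / diag_r: columns attacked in the current row
        if cols == full:
            return 1  # every column holds a queen: all n rows are placed
        count = 0
        free = full & ~(cols | diag_l | diag_r)
        while free:
            bit = free & -free  # lowest free column
            free ^= bit
            # Diagonal attacks shift one column per row
            count += place(cols | bit, ((diag_l | bit) << 1) & full, (diag_r | bit) >> 1)
        return count

    total = 0
    for col in range(n // 2):  # first queen in the left half only
        bit = 1 << col
        total += place(bit, (bit << 1) & full, bit >> 1)
    total *= 2  # each of these solutions has a distinct mirror image
    if n % 2 == 1:  # the middle column is its own mirror: count it once
        bit = 1 << (n // 2)
        total += place(bit, (bit << 1) & full, bit >> 1)
    return total
```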
Generate All Subsets
Return all subsets (the power set) of a list of distinct integers. Order within
the output list does not matter. Expected output: 8 subsets total.
def subsets(nums):
result = []
def backtrack(start, current):
# Record a copy of the current subset
result.append(current[:])
for i in range(start, len(nums)):
# Include nums[i]
current.append(nums[i])
backtrack(i + 1, current)
# Undo the choice (backtrack)
current.pop()
backtrack(0, [])
return result
res = subsets([1, 2, 3])
print(len(res)) # 8
print(sorted(res)) # [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
import java.util.*;
public class Solution {
static List<List<Integer>> result = new ArrayList<>();
static void backtrack(int[] nums, int start, List<Integer> current) {
// Record a copy of the current subset
result.add(new ArrayList<>(current));
for (int i = start; i < nums.length; i++) {
// Include nums[i]
current.add(nums[i]);
backtrack(nums, i + 1, current);
// Undo the choice (backtrack)
current.remove(current.size() - 1);
}
}
public static List<List<Integer>> subsets(int[] nums) {
result.clear();
backtrack(nums, 0, new ArrayList<>());
return result;
}
public static void main(String[] args) {
int[] nums = {1, 2, 3};
List<List<Integer>> res = subsets(nums);
System.out.println(res.size()); // 8
}
}
For each index i from start onwards, include nums[i], recurse with start = i + 1, then remove it. Every node of the recursion tree is recorded as a subset, which naturally generates all 2^n subsets.
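An iterative alternative to the recursion: each of the 2^n subsets corresponds to an n-bit mask, where bit i decides whether nums[i] is included. Counting the mask from 0 to 2^n - 1 therefore visits every subset exactly once. This is only practical for small n. subsets_bitmask is a hypothetical name.

```python
def subsets_bitmask(nums: list[int]) -> list[list[int]]:
    """Power set by enumerating bitmasks 0 .. 2^n - 1."""
    n = len(nums)
    result: list[list[int]] = []
    for mask in range(1 << n):
        # Bit i of mask set means nums[i] is in this subset
        result.append([nums[i] for i in range(n) if mask & (1 << i)])
    return result
```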
Generate Permutations
Return all permutations of distinct integers. We check the count, not the order.
Expected output: 6
def permutations(nums):
result = []
used = [False] * len(nums)
def backtrack(path):
# Base case: full-length path is a complete permutation
if len(path) == len(nums):
result.append(path[:])
return
for i in range(len(nums)):
if used[i]:
continue
# Choose nums[i]
used[i] = True
path.append(nums[i])
backtrack(path)
# Unchoose (backtrack)
path.pop()
used[i] = False
backtrack([])
return result
print(len(permutations([1, 2, 3]))) # 6
import java.util.*;
public class Solution {
static List<List<Integer>> result = new ArrayList<>();
static void backtrack(int[] nums, boolean[] used, List<Integer> path) {
// Base case: full-length path is a complete permutation
if (path.size() == nums.length) {
result.add(new ArrayList<>(path));
return;
}
for (int i = 0; i < nums.length; i++) {
if (used[i]) continue;
// Choose nums[i]
used[i] = true;
path.add(nums[i]);
backtrack(nums, used, path);
// Unchoose (backtrack)
path.remove(path.size() - 1);
used[i] = false;
}
}
public static List<List<Integer>> permutations(int[] nums) {
result.clear();
backtrack(nums, new boolean[nums.length], new ArrayList<>());
return result;
}
public static void main(String[] args) {
int[] nums = {1, 2, 3};
System.out.println(permutations(nums).size()); // 6
}
}
Track which elements are already in the path with a used[] boolean array. At each recursion level, try every unused element: mark it used, append it to the path, and recurse. When the path length equals n, record a copy. Then undo the choice: pop from the path and mark the element unused. This generates all n! permutations.
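One way to sanity-check the n! count is to compare the backtracking output against the standard library's itertools.permutations. The sketch below uses the hypothetical name permutations_backtrack. It repeats the approach above so that the block runs on its own.

```python
import itertools
import math

def permutations_backtrack(nums):
    """Backtracking with a used[] array, as described above."""
    result = []
    used = [False] * len(nums)

    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])  # record a copy, not the live path
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True
            path.append(nums[i])
            backtrack(path)
            path.pop()
            used[i] = False

    backtrack([])
    return result

for nums in ([1], [1, 2], [1, 2, 3], [1, 2, 3, 4]):
    mine = permutations_backtrack(nums)
    # n! results, and the same set as itertools (which yields tuples)
    assert len(mine) == math.factorial(len(nums))
    assert sorted(map(tuple, mine)) == sorted(itertools.permutations(nums))
print("ok")
```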
Combination Sum
Find all unique combinations of candidates that sum to target. Each candidate can be used any number of times. Expected output: [[2, 2, 3], [7]]
def combination_sum(candidates, target):
    candidates.sort()
    result = []
    def backtrack(start, current, remaining):
        if remaining == 0:
            result.append(current[:])
            return
        for i in range(start, len(candidates)):
            c = candidates[i]
            # Pruning: no point trying larger candidates
            if c > remaining:
                break
            current.append(c)
            # Use same index i (not i+1) to allow reuse
            backtrack(i, current, remaining - c)
            current.pop()
    backtrack(0, [], target)
    return result
print(combination_sum([2, 3, 6, 7], 7))  # [[2, 2, 3], [7]]
import java.util.*;
public class Solution {
    static List<List<Integer>> result = new ArrayList<>();
    static void backtrack(int[] candidates, int start,
                          List<Integer> current, int remaining) {
        if (remaining == 0) {
            result.add(new ArrayList<>(current));
            return;
        }
        for (int i = start; i < candidates.length; i++) {
            int c = candidates[i];
            // Pruning: no point trying larger candidates
            if (c > remaining) break;
            current.add(c);
            // Use same index i (not i+1) to allow reuse
            backtrack(candidates, i, current, remaining - c);
            current.remove(current.size() - 1);
        }
    }
    public static List<List<Integer>> combinationSum(int[] candidates, int target) {
        result.clear();
        Arrays.sort(candidates);
        backtrack(candidates, 0, new ArrayList<>(), target);
        return result;
    }
    public static void main(String[] args) {
        int[] candidates = {2, 3, 6, 7};
        System.out.println(combinationSum(candidates, 7));
        // [[2, 2, 3], [7]]
    }
}
Pass a start index so that each call only considers candidates at or after the current one. This avoids revisiting earlier candidates and prevents duplicate orderings such as [2, 3, 2] and [3, 2, 2]. Recurse with the same i (not i + 1) so the same number can be reused. Because the candidates are sorted, break early once a candidate exceeds remaining.
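To see what the "same i versus i + 1" choice does, here is a sketch of the variant where each candidate may be used at most once. The name combination_sum_no_reuse is hypothetical, and the sketch assumes the candidates are distinct. Inputs with duplicate values would also need to skip equal neighbours.

```python
def combination_sum_no_reuse(candidates, target):
    """Like combination_sum, but each candidate is used at most once.
    Assumes candidates are distinct."""
    candidates = sorted(candidates)
    result = []

    def backtrack(start, current, remaining):
        if remaining == 0:
            result.append(current[:])
            return
        for i in range(start, len(candidates)):
            c = candidates[i]
            if c > remaining:
                break  # sorted, so every later candidate is too large too
            current.append(c)
            backtrack(i + 1, current, remaining - c)  # i + 1: no reuse
            current.pop()

    backtrack(0, [], target)
    return result

print(combination_sum_no_reuse([2, 3, 6, 7], 7))  # [[7]]
print(combination_sum_no_reuse([2, 3, 6, 7], 9))  # [[2, 7], [3, 6]]
```

With reuse allowed, target 7 also yields [2, 2, 3]. Without reuse, only [7] remains.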
15. Greedy Algorithms
A greedy algorithm makes the locally optimal choice at each step. It works when the problem has the greedy choice property (a global optimum can be built from local optima) and optimal substructure. If you cannot prove these, use DP instead.
| Greedy | DP |
|---|---|
| One pass, irrevocable choices | Explore multiple options per step |
| Greedy choice property provable | Optimal substructure only |
| O(n log n) or O(n) typical | O(n^2) or O(n * capacity) typical |
| Interval scheduling, MST, Huffman | Knapsack, LCS, Edit Distance |
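Coin change is a common illustration of why the greedy choice property has to be proven rather than assumed. With coins [1, 3, 4] and amount 6, always taking the largest coin that fits is not optimal. The function names below are hypothetical.

```python
def greedy_coins(coins, amount):
    """Always take the largest coin that fits. Returns coin count, or -1."""
    count = 0
    for c in sorted(coins, reverse=True):
        take = amount // c
        count += take
        amount -= take * c
    return count if amount == 0 else -1

def dp_coins(coins, amount):
    """Bottom-up DP: dp[a] = fewest coins summing to a (-1 if impossible)."""
    INF = float('inf')
    dp = [0] + [INF] * amount
    for a in range(1, amount + 1):
        for c in coins:
            if c <= a and dp[a - c] + 1 < dp[a]:
                dp[a] = dp[a - c] + 1
    return dp[amount] if dp[amount] != INF else -1

print(greedy_coins([1, 3, 4], 6))  # 3  (4 + 1 + 1)
print(dp_coins([1, 3, 4], 6))      # 2  (3 + 3)
```

For coin systems such as [1, 5, 10, 25], greedy happens to be optimal. For arbitrary denominations, DP is the safe default.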
Interval Scheduling
import heapq
def merge_intervals(intervals: list[list[int]]) -> list[list[int]]:
    """Merge all overlapping intervals.
    Time: O(n log n) Space: O(n)
    """
    if not intervals:
        return []
    intervals.sort(key=lambda x: x[0])  # sort by start
    merged = [intervals[0]]
    for start, end in intervals[1:]:
        if start <= merged[-1][1]:  # overlaps with last merged
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])
    return merged
def meeting_rooms_ii(intervals: list[list[int]]) -> int:
    """Minimum number of rooms to schedule all meetings (no overlap per room).
    Key insight: track when the earliest-ending meeting finishes.
    Time: O(n log n) Space: O(n)
    """
    if not intervals:
        return 0
    intervals.sort(key=lambda x: x[0])
    heap: list[int] = []  # stores end times of active meetings
    for start, end in intervals:
        if heap and heap[0] <= start:  # a room is free
            heapq.heapreplace(heap, end)
        else:
            heapq.heappush(heap, end)  # need a new room
    return len(heap)
def non_overlapping_intervals(intervals: list[list[int]]) -> int:
    """Minimum intervals to remove so none overlap.
    Greedy: always keep the interval with the earliest end time.
    Time: O(n log n) Space: O(1)
    """
    intervals.sort(key=lambda x: x[1])  # sort by end time
    removed = 0
    last_end = float('-inf')
    for start, end in intervals:
        if start >= last_end:
            last_end = end  # keep this interval
        else:
            removed += 1  # remove the later-ending one
    return removed
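The heap in meeting_rooms_ii can be replaced by a sweep over separately sorted start and end times. A new meeting can reuse a room exactly when some meeting has already ended by its start time. This sketch (hypothetical name meeting_rooms_sweep) treats back-to-back meetings as sharing a room, the same way heap[0] <= start does.

```python
def meeting_rooms_sweep(intervals: list[list[int]]) -> int:
    """Minimum rooms using two sorted arrays instead of a heap.
    Time: O(n log n)  Space: O(n)
    """
    starts = sorted(s for s, _ in intervals)
    ends = sorted(e for _, e in intervals)
    rooms = 0
    e = 0  # index of the earliest end time not yet reused
    for s in starts:
        if s >= ends[e]:
            e += 1      # some meeting has ended: reuse its room
        else:
            rooms += 1  # every room is busy: open a new one
    return rooms

print(meeting_rooms_sweep([[0, 30], [5, 10], [15, 20]]))  # 2
print(meeting_rooms_sweep([[7, 10], [2, 4]]))             # 1
```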
Jump Game Variants
def can_jump(nums: list[int]) -> bool:
    """Can you reach the last index? Each element = max jump from that index.
    Greedy: track the maximum reachable index.
    Time: O(n) Space: O(1)
    """
    max_reach = 0
    for i, jump in enumerate(nums):
        if i > max_reach:
            return False  # current index is unreachable
        max_reach = max(max_reach, i + jump)
    return True
def jump_game_ii(nums: list[int]) -> int:
    """Minimum jumps to reach last index.
    Greedy: at each step, choose the jump that extends reach the farthest.
    Time: O(n) Space: O(1)
    """
    jumps = 0
    current_end = 0  # end of current jump range
    farthest = 0  # farthest we can reach from current range
    for i in range(len(nums) - 1):
        farthest = max(farthest, i + nums[i])
        if i == current_end:  # must jump now
            jumps += 1
            current_end = farthest
    return jumps
Task Scheduler
from collections import Counter
def least_interval(tasks: list[str], n: int) -> int:
    """Minimum time to execute all tasks with cooldown n between same tasks.
    Key insight: the most frequent task determines the lower bound.
    Time: O(len(tasks)) Space: O(1) (at most 26 task types)
    """
    freq = Counter(tasks)
    max_freq = max(freq.values())
    # Number of tasks with maximum frequency
    max_count = sum(1 for f in freq.values() if f == max_freq)
    # Formula: (max_freq - 1) * (n + 1) + max_count
    # If other tasks can fill every idle slot, no idling is needed,
    # so the answer is at least len(tasks)
    return max(len(tasks), (max_freq - 1) * (n + 1) + max_count)
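The closed-form formula can be checked against a direct simulation. At each time step, the simulation runs the available task with the most remaining copies, and a task that just ran waits in a cooldown queue for n steps. The sketch below (hypothetical name least_interval_sim) is slower, but it makes the idle slots explicit.

```python
import heapq
from collections import Counter, deque

def least_interval_sim(tasks: list[str], n: int) -> int:
    """Simulate scheduling with a max-heap of remaining counts."""
    heap = [-c for c in Counter(tasks).values()]  # negated for a max-heap
    heapq.heapify(heap)
    cooldown = deque()  # (negated remaining count, time it becomes available)
    time = 0
    while heap or cooldown:
        time += 1
        if heap:
            cnt = heapq.heappop(heap) + 1  # run one copy of that task
            if cnt != 0:
                cooldown.append((cnt, time + n))
        # otherwise this step is idle
        if cooldown and cooldown[0][1] == time:
            heapq.heappush(heap, cooldown.popleft()[0])
    return time

print(least_interval_sim(["A", "A", "A", "B", "B", "B"], 2))  # 8
print(least_interval_sim(["A", "A", "A", "B", "B", "B"], 0))  # 6
```

Both results agree with the formula: (3 - 1) * 3 + 2 = 8, and max(6, 4) = 6.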
Best Time to Buy and Sell Stock
Given daily prices, find the maximum profit from one buy and one sell transaction.
You must buy before you sell. Expected output: 5
def max_profit(prices):
    # Track the lowest price seen so far
    min_price = float('inf')
    best = 0  # best profit so far (named to avoid shadowing the function)
    for price in prices:
        if price < min_price:
            # Found a new minimum: update buying price
            min_price = price
        else:
            # Check if selling today beats our best profit
            best = max(best, price - min_price)
    return best
print(max_profit([7, 1, 5, 3, 6, 4]))  # 5 (buy at 1, sell at 6)
public class Solution {
    public static int maxProfit(int[] prices) {
        // Track the lowest price seen so far
        int minPrice = Integer.MAX_VALUE;
        int maxProfit = 0;
        for (int price : prices) {
            if (price < minPrice) {
                // Found a new minimum: update buying price
                minPrice = price;
            } else {
                // Check if selling today beats our best profit
                maxProfit = Math.max(maxProfit, price - minPrice);
            }
        }
        return maxProfit;
    }
    public static void main(String[] args) {
        System.out.println(maxProfit(new int[]{7, 1, 5, 3, 6, 4})); // 5
    }
}
Scan once while tracking min_price, the lowest price seen so far. For each day, compute price - min_price (the profit from selling today) and update the running maximum. This works because, for any selling day, the best day to have bought is the one with the minimum price before it.
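The same single pass can also report which days to trade. The idea is to store the index of the minimum instead of its value. The sketch below (hypothetical name max_profit_with_days) returns -1 for both days when no profitable trade exists.

```python
def max_profit_with_days(prices):
    """Return (profit, buy_day, sell_day); days are -1 if no profit is possible."""
    best = 0
    best_buy = best_sell = -1
    min_day = 0  # index of the lowest price seen so far
    for day in range(1, len(prices)):
        if prices[day] < prices[min_day]:
            min_day = day
        elif prices[day] - prices[min_day] > best:
            best = prices[day] - prices[min_day]
            best_buy, best_sell = min_day, day
    return best, best_buy, best_sell

print(max_profit_with_days([7, 1, 5, 3, 6, 4]))  # (5, 1, 4)
print(max_profit_with_days([7, 6, 4, 3, 1]))     # (0, -1, -1)
```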
Jump Game
Each element in nums is the maximum jump length from that position.
Determine if you can reach the last index. Expected output: True then False
def can_jump(nums):
    # Farthest index reachable so far
    max_reach = 0
    for i, jump in enumerate(nums):
        # If current index is beyond what we can reach, we're stuck
        if i > max_reach:
            return False
        # Extend max reach from current position
        max_reach = max(max_reach, i + jump)
    return True
print(can_jump([2, 3, 1, 1, 4]))  # True
print(can_jump([3, 2, 1, 0, 4]))  # False
public class Solution {
    public static boolean canJump(int[] nums) {
        // Farthest index reachable so far
        int maxReach = 0;
        for (int i = 0; i < nums.length; i++) {
            // If current index is beyond what we can reach, we're stuck
            if (i > maxReach) return false;
            // Extend max reach from current position
            maxReach = Math.max(maxReach, i + nums[i]);
        }
        return true;
    }
    public static void main(String[] args) {
        System.out.println(canJump(new int[]{2, 3, 1, 1, 4})); // true
        System.out.println(canJump(new int[]{3, 2, 1, 0, 4})); // false
    }
}
Track max_reach, the farthest index reachable so far. For each index i, if i > max_reach you cannot get there, so return False. Otherwise update max_reach = max(max_reach, i + nums[i]). If the loop never gets stuck, return True.
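An equivalent greedy works backwards. Keep a goal index, starting at the last index. Whenever an earlier index can jump to the goal, move the goal to that index. The array is solvable exactly when the goal reaches index 0. The name can_jump_backward is hypothetical.

```python
def can_jump_backward(nums):
    """Move the goal left whenever index i can reach it. Time O(n), space O(1)."""
    goal = len(nums) - 1
    for i in range(len(nums) - 2, -1, -1):
        if i + nums[i] >= goal:
            goal = i  # reaching i is now enough to finish
    return goal == 0

print(can_jump_backward([2, 3, 1, 1, 4]))  # True
print(can_jump_backward([3, 2, 1, 0, 4]))  # False
```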
Non-overlapping Intervals
Find the minimum number of intervals to remove so the rest do not overlap.
Expected output: 1
def erase_overlap(intervals):
    # Sort by end time (greedy: keep intervals that end earliest)
    intervals.sort(key=lambda x: x[1])
    removals = 0
    # End time of the last interval we kept
    prev_end = float('-inf')
    for start, end in intervals:
        if start >= prev_end:
            # No overlap: keep this interval
            prev_end = end
        else:
            # Overlap: remove this interval (it ends later since we sorted by end)
            removals += 1
    return removals
print(erase_overlap([[1, 2], [2, 3], [3, 4], [1, 3]]))  # 1
import java.util.*;
public class Solution {
    public static int eraseOverlapIntervals(int[][] intervals) {
        // Sort by end time (greedy: keep intervals that end earliest)
        Arrays.sort(intervals, Comparator.comparingInt(a -> a[1]));
        int removals = 0;
        // End time of the last interval we kept
        int prevEnd = Integer.MIN_VALUE;
        for (int[] interval : intervals) {
            int start = interval[0], end = interval[1];
            if (start >= prevEnd) {
                // No overlap: keep this interval
                prevEnd = end;
            } else {
                // Overlap: remove this interval
                removals++;
            }
        }
        return removals;
    }
    public static void main(String[] args) {
        int[][] intervals = {{1,2},{2,3},{3,4},{1,3}};
        System.out.println(eraseOverlapIntervals(intervals)); // 1
    }
}
Sort by end time and track prev_end, the end of the last interval kept. If start >= prev_end, there is no overlap: keep the interval and update prev_end. Otherwise, count a removal. Discard the current interval rather than the kept one, because the kept one ends no later and so leaves more room for the intervals that follow.
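The same earliest-end greedy is the classic activity-selection algorithm. The intervals it keeps form a maximum non-overlapping set, so the minimum number of removals is n minus the number kept. This sketch (hypothetical name max_non_overlapping) returns the kept intervals and leaves the input unmodified.

```python
def max_non_overlapping(intervals):
    """Return the intervals kept by the earliest-end-time greedy."""
    kept = []
    prev_end = float('-inf')
    for start, end in sorted(intervals, key=lambda x: x[1]):
        if start >= prev_end:
            kept.append([start, end])
            prev_end = end
    return kept

intervals = [[1, 2], [2, 3], [3, 4], [1, 3]]
kept = max_non_overlapping(intervals)
print(kept)                        # [[1, 2], [2, 3], [3, 4]]
print(len(intervals) - len(kept))  # 1 removal
```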
16. Tries (Prefix Trees)
A trie stores strings character by character, enabling O(L) insert/search where L is word length. Ideal for prefix lookups, autocomplete, and spell checking.
Trie Node — Insert / Search / StartsWith
class TrieNode:
    def __init__(self) -> None:
        self.children: dict[str, 'TrieNode'] = {}
        self.is_end: bool = False
class Trie:
    """Prefix tree supporting insert, search, and startsWith.
    Time: O(L) per operation Space: O(total characters inserted)
    """
    def __init__(self) -> None:
        self.root = TrieNode()
    def insert(self, word: str) -> None:
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True
    def search(self, word: str) -> bool:
        """True if word is in the trie."""
        node = self._traverse(word)
        return node is not None and node.is_end
    def starts_with(self, prefix: str) -> bool:
        """True if any word in trie starts with prefix."""
        return self._traverse(prefix) is not None
    def _traverse(self, s: str) -> 'TrieNode | None':
        node = self.root
        for ch in s:
            if ch not in node.children:
                return None
            node = node.children[ch]
        return node
Word Dictionary with Wildcards
class WordDictionary:
    """Supports addWord and search where '.' matches any character.
    search uses DFS when '.' is encountered.
    Time: O(L) insert; O(26^L) worst case search with many '.'s.
    """
    def __init__(self) -> None:
        self.root = TrieNode()
    def add_word(self, word: str) -> None:
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True
    def search(self, word: str) -> bool:
        return self._dfs(self.root, word, 0)
    def _dfs(self, node: TrieNode, word: str, idx: int) -> bool:
        if idx == len(word):
            return node.is_end
        ch = word[idx]
        if ch == '.':
            return any(
                self._dfs(child, word, idx + 1)
                for child in node.children.values()
            )
        if ch not in node.children:
            return False
        return self._dfs(node.children[ch], word, idx + 1)
Design Autocomplete System
def find_words_with_prefix(words: list[str],
                           prefix: str,
                           top_k: int = 3) -> list[str]:
    """Return top-k lexicographically smallest words matching prefix.
    Uses a trie to navigate to the prefix node, then DFS to collect words.
    Building the trie costs O(total characters in words); the query itself
    is roughly O(L + W), where W = total characters in the matching words visited.
    """
    trie = Trie()  # Trie / TrieNode as defined above
    for w in words:
        trie.insert(w)
    node = trie._traverse(prefix)
    if node is None:
        return []
    results: list[str] = []
    def dfs(cur: TrieNode, path: str) -> None:
        if len(results) >= top_k:
            return  # already collected k words
        if cur.is_end:
            results.append(path)  # a word precedes its extensions
        for ch in sorted(cur.children):  # sorted for lexicographic order
            dfs(cur.children[ch], path + ch)
    dfs(node, prefix)
    return results
XOR Trie for Maximum XOR
class XorTrie:
    """Binary trie for maximum XOR of two numbers in an array.
    Each number stored as 32 binary bits, MSB first
    (assumes 0 <= num < 2**32).
    Time: O(n * 32) = O(n)
    """
    def __init__(self) -> None:
        self.children: list['XorTrie | None'] = [None, None]
    def insert(self, num: int) -> None:
        node = self
        for bit in range(31, -1, -1):
            b = (num >> bit) & 1
            if node.children[b] is None:
                node.children[b] = XorTrie()
            node = node.children[b]
    def max_xor_with(self, num: int) -> int:
        """Return the largest num ^ v over all values v stored in the trie.
        Assumes at least one value has been inserted."""
        node = self
        xor = 0
        for bit in range(31, -1, -1):
            b = (num >> bit) & 1
            want = 1 - b  # prefer opposite bit
            if node.children[want] is not None:
                xor |= (1 << bit)
                node = node.children[want]
            else:
                node = node.children[b]
        return xor
def find_maximum_xor(nums: list[int]) -> int:
    """LeetCode 421: Maximum XOR of Two Numbers in an Array.
    Time: O(n) Space: O(n)
    """
    trie = XorTrie()
    for x in nums:
        trie.insert(x)
    return max(trie.max_xor_with(x) for x in nums)
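For comparison, the quadratic baseline that the XOR trie improves on simply tries every pair. It is also handy for testing the trie version on small inputs. The name find_maximum_xor_brute is hypothetical.

```python
from itertools import combinations

def find_maximum_xor_brute(nums: list[int]) -> int:
    """O(n^2) baseline: try every pair (a single number gives 0)."""
    best = 0
    for a, b in combinations(nums, 2):
        best = max(best, a ^ b)
    return best

print(find_maximum_xor_brute([3, 10, 5, 25, 2, 8]))  # 28 (5 ^ 25)
```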
Implement Trie
Implement insert and search for a basic trie. After inserting "apple" and "app", searching "apple" returns True, "app" returns True, and "appl" returns False.
Each node holds a children map and an is_end flag. For insert, walk each character, creating a new node if it is missing, then set is_end = True on the last node. For search, walk each character and return False if a node is missing; at the end, return is_end.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False
class Trie:
    def __init__(self):
        self.root = TrieNode()
    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True
    def search(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end
t = Trie()
t.insert("apple")
t.insert("app")
print(t.search("apple"))  # True
print(t.search("app"))    # True
print(t.search("appl"))   # False
import java.util.HashMap;
public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }
    static TrieNode root = new TrieNode();
    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
        }
        node.isEnd = true;
    }
    static boolean search(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            if (!node.children.containsKey(ch)) return false;
            node = node.children.get(ch);
        }
        return node.isEnd;
    }
    public static void main(String[] args) {
        insert("apple");
        insert("app");
        System.out.println(search("apple")); // true
        System.out.println(search("app"));   // true
        System.out.println(search("appl"));  // false
    }
}
Prefix Count
Count how many inserted words start with a given prefix. After inserting ["apple", "app", "apricot", "banana"], prefix "ap" → 3, prefix "ban" → 1, prefix "c" → 0.
Add a count field to each node. During insert, increment count on every node you visit, not just the last one. For prefix_count, traverse to the end of the prefix and return that node's count, or 0 if the prefix path does not exist.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.count = 0  # words passing through this node
class Trie:
    def __init__(self):
        self.root = TrieNode()
    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.count += 1
    def prefix_count(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.count
t = Trie()
for w in ["apple", "app", "apricot", "banana"]:
    t.insert(w)
print(t.prefix_count("ap"))   # 3
print(t.prefix_count("ban"))  # 1
print(t.prefix_count("c"))    # 0
import java.util.HashMap;
public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        int count = 0; // words passing through this node
    }
    static TrieNode root = new TrieNode();
    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
            node.count++;
        }
    }
    static int prefixCount(String prefix) {
        TrieNode node = root;
        for (char ch : prefix.toCharArray()) {
            if (!node.children.containsKey(ch)) return 0;
            node = node.children.get(ch);
        }
        return node.count;
    }
    public static void main(String[] args) {
        for (String w : new String[]{"apple", "app", "apricot", "banana"}) {
            insert(w);
        }
        System.out.println(prefixCount("ap"));  // 3
        System.out.println(prefixCount("ban")); // 1
        System.out.println(prefixCount("c"));   // 0
    }
}
Search with Wildcards
Support '.' matching any single character in search. After inserting ["bad", "dad", "mad"]: ".ad" → True, "b.." → True, ".a" → False (wrong length).
Walk the pattern one character at a time. When the current character is '.', recurse into all children with the remaining suffix and return True if any branch matches. When it is a regular character, follow only that child, or return False if it is missing. At the end of the pattern, return is_end.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False
class Trie:
    def __init__(self):
        self.root = TrieNode()
    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True
    def search(self, word, node=None):
        if node is None:
            node = self.root
        for i, ch in enumerate(word):
            if ch == '.':
                # Wildcard: try every child with the rest of the pattern
                return any(
                    self.search(word[i + 1:], child)
                    for child in node.children.values()
                )
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end
t = Trie()
for w in ["bad", "dad", "mad"]:
    t.insert(w)
print(t.search(".ad"))  # True
print(t.search("b.."))  # True
print(t.search("b.d"))  # True
print(t.search("..."))  # True
print(t.search(".a"))   # False
import java.util.HashMap;
public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }
    static TrieNode root = new TrieNode();
    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
        }
        node.isEnd = true;
    }
    static boolean search(String word, int idx, TrieNode node) {
        if (idx == word.length()) return node.isEnd;
        char ch = word.charAt(idx);
        if (ch == '.') {
            // Wildcard: try every child
            for (TrieNode child : node.children.values()) {
                if (search(word, idx + 1, child)) return true;
            }
            return false;
        }
        if (!node.children.containsKey(ch)) return false;
        return search(word, idx + 1, node.children.get(ch));
    }
    public static void main(String[] args) {
        for (String w : new String[]{"bad", "dad", "mad"}) {
            insert(w);
        }
        System.out.println(search(".ad", 0, root)); // true
        System.out.println(search("b..", 0, root)); // true
        System.out.println(search("b.d", 0, root)); // true
        System.out.println(search("...", 0, root)); // true
        System.out.println(search(".a", 0, root));  // false
    }
}
17. Bit Manipulation
Essential Operations
| Operation | Expression | Notes |
|---|---|---|
| Check bit k | (n >> k) & 1 | 1 if bit k is set |
| Set bit k | n \| (1 << k) | Turn on bit k |
| Clear bit k | n & ~(1 << k) | Turn off bit k |
| Toggle bit k | n ^ (1 << k) | Flip bit k |
| Clear lowest set bit | n & (n - 1) | Brian Kernighan trick |
| Isolate lowest set bit | n & (-n) | Also n & (~n + 1) |
| Check power of 2 | n > 0 and (n & (n-1)) == 0 | |
| XOR self | n ^ n == 0 | Cancels out pairs |
| XOR with 0 | n ^ 0 == n | Identity |
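| Swap without temp | a ^= b; b ^= a; a ^= b | Zeroes the value if a and b are the same variable (e.g. arr[i], arr[j] with i == j); equal values in different variables are fine |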
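The table entries can be checked on a concrete value. The sketch below applies each operation to n = 0b1011 (11). It also shows the XOR-swap pitfall: equal values in two different slots swap correctly, but swapping a slot with itself zeroes it. The helper xor_swap is a hypothetical name.

```python
n = 0b1011  # 11

assert (n >> 1) & 1 == 1        # check bit 1: it is set
assert n | (1 << 2) == 0b1111   # set bit 2
assert n & ~(1 << 0) == 0b1010  # clear bit 0
assert n ^ (1 << 3) == 0b0011   # toggle bit 3
assert n & (n - 1) == 0b1010    # clear lowest set bit
assert n & (-n) == 0b0001       # isolate lowest set bit

def xor_swap(arr, i, j):
    """Swap arr[i] and arr[j] with XOR, without a temporary."""
    arr[i] ^= arr[j]
    arr[j] ^= arr[i]
    arr[i] ^= arr[j]

vals = [5, 5]
xor_swap(vals, 0, 1)  # equal values, different slots: still [5, 5]
print(vals)           # [5, 5]
xor_swap(vals, 0, 0)  # same slot: the first step computes 5 ^ 5 = 0
print(vals)           # [0, 5]
```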
Common Bit Tricks
def count_set_bits_kernighan(n: int) -> int:
    """Count 1-bits using Brian Kernighan's algorithm.
    Each iteration clears the lowest set bit.
    Time: O(number of set bits), faster than O(32) when bits are sparse.
    Assumes n >= 0: a negative Python int never reaches 0 and loops forever.
    """
    count = 0
    while n:
        n &= n - 1  # clear lowest set bit
        count += 1
    return count
def count_set_bits_builtin(n: int) -> int:
    """Python built-in: bin(n).count('1') or int.bit_count() in 3.10+"""
    return bin(n).count('1')
def single_number(nums: list[int]) -> int:
    """Find the one element that appears once; all others appear twice.
    XOR: a ^ a = 0, a ^ 0 = a. XOR all elements; pairs cancel out.
    Time: O(n) Space: O(1)
    """
    result = 0
    for x in nums:
        result ^= x
    return result
def single_number_iii(nums: list[int]) -> list[int]:
    """Two elements appear once; all others appear twice. Return the two.
    Step 1: XOR all -> xor = a ^ b (where a, b are the two singles).
    Step 2: Find any set bit in xor (a and b differ at that bit).
    Step 3: Partition nums by that bit; each group XORs down to one of a, b.
    Time: O(n) Space: O(1)
    """
    xor = 0
    for x in nums:
        xor ^= x
    diff_bit = xor & (-xor)  # isolate the lowest bit where a and b differ
    a = 0
    for x in nums:
        if x & diff_bit:
            a ^= x
    return [a, xor ^ a]
def reverse_bits(n: int) -> int:
    """Reverse the 32 bits of a 32-bit unsigned integer."""
    result = 0
    for _ in range(32):
        result = (result << 1) | (n & 1)
        n >>= 1
    return result
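Python ints are unbounded, so a negative number behaves as if it had infinitely many leading 1-bits. In Java or C, by contrast, an int is fixed at 32 bits, and Java's Integer.bitCount(-1) returns 32. A common workaround, sketched below under that 32-bit assumption, is to mask with 0xFFFFFFFF before counting. The names MASK32 and count_set_bits_32 are hypothetical.

```python
MASK32 = 0xFFFFFFFF  # keep only the low 32 bits

def count_set_bits_32(n: int) -> int:
    """Set bits in n's 32-bit two's-complement representation."""
    n &= MASK32  # a negative n becomes a non-negative 32-bit value
    count = 0
    while n:
        n &= n - 1  # Kernighan: clear the lowest set bit
        count += 1
    return count

print(count_set_bits_32(11))  # 3
print(count_set_bits_32(-1))  # 32
print(count_set_bits_32(-8))  # 29
```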
Subsets via Bitmask Enumeration
def all_subsets_bitmask(nums: list[int]) -> list[list[int]]:
    """Enumerate all 2^n subsets using bitmask.
    Bit i of mask indicates whether nums[i] is included.
    Time: O(2^n * n) Space: O(n)
    """
    n = len(nums)
    result: list[list[int]] = []
    for mask in range(1 << n):  # 0 to 2^n - 1
        subset = [nums[i] for i in range(n) if mask & (1 << i)]
        result.append(subset)
    return result
def count_bits(n: int) -> list[int]:
    """For every i in 0..n, count number of 1-bits. LeetCode 338.
    DP: ans[i] = ans[i >> 1] + (i & 1)
    Time: O(n) Space: O(n)
    """
    dp = [0] * (n + 1)
    for i in range(1, n + 1):
        dp[i] = dp[i >> 1] + (i & 1)
    return dp
Single Number
Every element appears exactly twice except one. Find the element that appears only once. [4, 1, 2, 1, 2] → 4.
Because a ^ a = 0 and a ^ 0 = a, every duplicated pair cancels out, and what remains is the single element. As a one-liner: return functools.reduce(lambda a, b: a ^ b, nums), or use a simple loop with result ^= n.
def single_number(nums):
    result = 0
    for n in nums:
        result ^= n
    return result
print(single_number([4, 1, 2, 1, 2]))  # 4
print(single_number([2, 2, 7]))        # 7
public class Solution {
    static int singleNumber(int[] nums) {
        int result = 0;
        for (int n : nums) {
            result ^= n;
        }
        return result;
    }
    public static void main(String[] args) {
        System.out.println(singleNumber(new int[]{4, 1, 2, 1, 2})); // 4
        System.out.println(singleNumber(new int[]{2, 2, 7}));       // 7
    }
}
Count Set Bits
Count the number of 1-bits in an integer. count_bits(11) → 3 (binary 1011 has three 1s).
While n != 0, do n &= (n - 1), which clears the lowest set bit, and increment a counter. The loop runs exactly as many times as there are set bits, which is faster than checking every bit position when bits are sparse. The Python version assumes n >= 0, because a negative Python int never reaches 0 under this loop.
def count_bits(n):
    count = 0
    while n:
        n &= (n - 1)
        count += 1
    return count
print(count_bits(11))  # 3 (1011)
print(count_bits(7))   # 3 (0111)
print(count_bits(0))   # 0
public class Solution {
    static int countBits(int n) {
        int count = 0;
        while (n != 0) {
            n &= (n - 1);
            count++;
        }
        return count;
    }
    public static void main(String[] args) {
        System.out.println(countBits(11)); // 3 (1011)
        System.out.println(countBits(7));  // 3 (0111)
        System.out.println(countBits(0));  // 0
    }
}
Power of Two
Check if n is a power of two. is_power_of_two(16) → True, is_power_of_two(18) → False.
def is_power_of_two(n):
    # TODO: a power of 2 has exactly one set bit
    # use: n > 0 and (n & (n - 1)) == 0
    pass
print(is_power_of_two(16)) # True (10000)
print(is_power_of_two(18)) # False (10010)
print(is_power_of_two(1)) # True (2^0)
print(is_power_of_two(0)) # False

public class Solution {
    static boolean isPowerOfTwo(int n) {
        // TODO: a power of 2 has exactly one set bit
        // use: n > 0 && (n & (n - 1)) == 0
        return false;
    }
    public static void main(String[] args) {
        System.out.println(isPowerOfTwo(16)); // true (10000)
        System.out.println(isPowerOfTwo(18)); // false (10010)
        System.out.println(isPowerOfTwo(1)); // true (2^0)
        System.out.println(isPowerOfTwo(0)); // false
    }
}
A power of two has exactly one set bit (e.g. 16 = 10000). The trick n & (n - 1) clears the lowest set bit; if the result is 0, there was only one set bit. Don't forget to check n > 0, since 0 is not a power of two.
def is_power_of_two(n):
    return n > 0 and (n & (n - 1)) == 0
print(is_power_of_two(16)) # True (10000)
print(is_power_of_two(18)) # False (10010)
print(is_power_of_two(1)) # True (2^0)
print(is_power_of_two(0)) # False

public class Solution {
    static boolean isPowerOfTwo(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }
    public static void main(String[] args) {
        System.out.println(isPowerOfTwo(16)); // true (10000)
        System.out.println(isPowerOfTwo(18)); // false (10010)
        System.out.println(isPowerOfTwo(1)); // true (2^0)
        System.out.println(isPowerOfTwo(0)); // false
    }
}
18. Math & Number Theory
GCD and LCM
import math
def gcd(a: int, b: int) -> int:
    """Euclidean algorithm. Time: O(log min(a,b))"""
    while b:
        a, b = b, a % b
    return a
# Python 3.5+: math.gcd(a, b)
# Python 3.9+: math.gcd(*nums) for multiple values
def lcm(a: int, b: int) -> int:
    """LCM via GCD: lcm(a,b) = a * b // gcd(a,b)"""
    if a == 0 or b == 0:
        return 0  # lcm with 0 is 0; also avoids dividing by gcd(0, 0) == 0
    return a * b // gcd(a, b)
# Python 3.9+: math.lcm(a, b)
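Folding the pairwise functions over a list gives the GCD or LCM of many values on any Python 3 version (math.gcd(*nums) and math.lcm(*nums) need 3.9+). A short sketch; gcd_all and lcm_all are illustrative names, and both assume a non-empty list.

```python
from functools import reduce
from math import gcd

def lcm(a: int, b: int) -> int:
    """Pairwise LCM; returns 0 if either argument is 0 (matching math.lcm)."""
    if a == 0 or b == 0:
        return 0
    return a // gcd(a, b) * b  # divide first to keep the intermediate product small

def gcd_all(nums: list[int]) -> int:
    # gcd is associative: gcd(a, b, c) == gcd(gcd(a, b), c)
    return reduce(gcd, nums)

def lcm_all(nums: list[int]) -> int:
    return reduce(lcm, nums)

print(gcd_all([12, 18, 24]))  # 6
print(lcm_all([4, 6, 10]))    # 60
```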
Sieve of Eratosthenes
def sieve_of_eratosthenes(limit: int) -> list[int]:
    """All primes up to limit.
    Time: O(n log log n) Space: O(n)
    """
    if limit < 2:
        return []  # no primes; also avoids indexing is_prime[1] when limit == 0
    is_prime = [True] * (limit + 1)
    is_prime[0] = is_prime[1] = False
    p = 2
    while p * p <= limit:
        if is_prime[p]:
            for multiple in range(p * p, limit + 1, p):
                is_prime[multiple] = False
        p += 1
    return [x for x in range(2, limit + 1) if is_prime[x]]
def count_primes(n: int) -> int:
    """Count primes strictly less than n. LeetCode 204."""
    if n < 2:
        return 0
    is_prime = [True] * n
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n**0.5) + 1):
        if is_prime[p]:
            is_prime[p*p::p] = [False] * len(is_prime[p*p::p])
    return sum(is_prime)
Modular Arithmetic
MOD = 10**9 + 7 # standard mod in competitive programming
# Key properties (all under MOD):
# (a + b) % MOD == ((a % MOD) + (b % MOD)) % MOD
# (a * b) % MOD == ((a % MOD) * (b % MOD)) % MOD
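# (a - b) % MOD == ((a % MOD) - (b % MOD) + MOD) % MOD # adding MOD keeps it non-negative in Java/C++; Python's % already returns a non-negative result
# Division requires a modular inverse: (a / b) % MOD = a * pow(b, MOD-2, MOD) % MOD (valid because MOD is prime and b % MOD != 0)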
def fast_power(base: int, exp: int, mod: int) -> int:
    """Modular exponentiation: base^exp % mod.
    Python built-in pow(base, exp, mod) does this optimally.
    Time: O(log exp) Space: O(1)
    """
    result = 1
    base %= mod
    while exp > 0:
        if exp & 1:
            result = result * base % mod
        base = base * base % mod
        exp >>= 1
    return result
# Equivalent: return pow(base, exp, mod)
def mod_inverse(a: int, mod: int) -> int:
    """Modular multiplicative inverse using Fermat's little theorem.
    Valid only when mod is prime and a % mod != 0.
    a^(mod-2) % mod
    """
    return pow(a, mod - 2, mod)
# Python 3.8+: pow(a, -1, mod) works for any mod with gcd(a, mod) == 1
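mod_inverse above relies on Fermat's little theorem, so it needs a prime modulus. For a modulus that is not prime, the extended Euclidean algorithm finds an inverse whenever gcd(a, mod) == 1. A minimal sketch; ext_gcd, mod_inverse_any and mod_div are illustrative names, and mod_div assumes b is invertible modulo mod.

```python
MOD = 10**9 + 7

def ext_gcd(a: int, b: int) -> tuple[int, int, int]:
    """Return (g, x, y) with a*x + b*y == g == gcd(a, b)."""
    if b == 0:
        return a, 1, 0
    g, x, y = ext_gcd(b, a % b)
    # b*x + (a % b)*y == g, and a % b == a - (a // b) * b
    return g, y, x - (a // b) * y

def mod_inverse_any(a: int, mod: int) -> int:
    """Inverse of a modulo mod; mod need not be prime, but gcd(a, mod) must be 1."""
    g, x, _ = ext_gcd(a % mod, mod)
    if g != 1:
        raise ValueError(f"{a} has no inverse modulo {mod}")
    return x % mod  # normalize into [0, mod)

def mod_div(a: int, b: int, mod: int = MOD) -> int:
    """(a / b) % mod computed as a * inverse(b)."""
    return a % mod * mod_inverse_any(b, mod) % mod

print(mod_inverse_any(3, 10))  # 7, since 3 * 7 = 21 = 2 * 10 + 1
print(mod_div(12, 4))          # 3
print(mod_inverse_any(4, MOD) == pow(4, MOD - 2, MOD))  # True: Fermat agrees for a prime modulus
```

On Python 3.8+, pow(a, -1, mod) returns the same inverse without a hand-written helper.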
Combinatorics
def combinations_mod(n: int, r: int, mod: int = 10**9 + 7) -> int:
    """n choose r, modulo mod. Builds factorial and inverse-factorial tables on each call.
    Assumes mod is prime and n < mod.
    Time: O(n) Space: O(n)
    """
    if r > n or r < 0:
        return 0
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i-1] * i % mod
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for i in range(n - 1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % mod
    return fact[n] * inv_fact[r] % mod * inv_fact[n-r] % mod
def pascal_triangle(num_rows: int) -> list[list[int]]:
    """Pascal's triangle: row i contains C(i,0)..C(i,i).
    Each entry = sum of two entries above it.
    Time: O(n^2) Space: O(n^2)
    """
    triangle: list[list[int]] = []
    for i in range(num_rows):
        row = [1] * (i + 1)
        for j in range(1, i):
            row[j] = triangle[i-1][j-1] + triangle[i-1][j]
        triangle.append(row)
    return triangle
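combinations_mod above rebuilds its factorial tables on every call, which costs O(n) per query. When many C(n, r) queries share one upper bound, a common approach is to precompute the tables once. A sketch, assuming a prime modulus larger than max_n; the class name Binomial and its max_n parameter are illustrative, not from the source.

```python
import math

MOD = 10**9 + 7

class Binomial:
    """Precompute factorials up to max_n once; each C(n, r) query is then O(1).
    Assumes mod is prime and max_n < mod, so every factorial is invertible."""

    def __init__(self, max_n: int, mod: int = MOD) -> None:
        self.mod = mod
        self.fact = [1] * (max_n + 1)
        for i in range(1, max_n + 1):
            self.fact[i] = self.fact[i - 1] * i % mod
        self.inv_fact = [1] * (max_n + 1)
        self.inv_fact[max_n] = pow(self.fact[max_n], mod - 2, mod)  # Fermat inverse
        for i in range(max_n, 0, -1):
            # 1/(i-1)! == i * 1/i!
            self.inv_fact[i - 1] = self.inv_fact[i] * i % mod

    def comb(self, n: int, r: int) -> int:
        """C(n, r) % mod for 0 <= n <= max_n; returns 0 when r is out of range."""
        if r < 0 or r > n:
            return 0
        return self.fact[n] * self.inv_fact[r] % self.mod * self.inv_fact[n - r] % self.mod

binom = Binomial(1000)
print(binom.comb(5, 2))  # 10
print(binom.comb(1000, 500) == math.comb(1000, 500) % MOD)  # True
```

The trade-off is O(max_n) memory held for the lifetime of the object in exchange for constant-time queries.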
Random Sampling
import random
def reservoir_sample(stream, k: int) -> list:
    """Select k items uniformly at random from a stream of unknown length.
    Algorithm R (Vitter): each element i (0-indexed) replaces a reservoir slot
    with probability k/(i+1).
    Time: O(n) Space: O(k)
    """
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            j = random.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir
def fisher_yates_shuffle(nums: list) -> list:
    """In-place uniform random shuffle.
    Time: O(n) Space: O(1)
    """
    n = len(nums)
    for i in range(n - 1, 0, -1):
        j = random.randint(0, i)
        nums[i], nums[j] = nums[j], nums[i]
    return nums
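The Algorithm R claim (every item ends up in the sample with probability k/n) can be checked empirically. A sketch that reruns reservoir_sample many times with a fixed seed and reports how often each item is kept; the trial count, the seed, and the helper name empirical_frequencies are arbitrary choices for illustration, and the printed fractions shift slightly if they change.

```python
import random
from collections import Counter

def reservoir_sample(stream, k: int) -> list:
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            j = random.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir

def empirical_frequencies(n: int = 10, k: int = 3, trials: int = 20_000, seed: int = 0) -> dict:
    """Fraction of trials in which each item 0..n-1 ends up in the sample."""
    random.seed(seed)  # fixed seed so the run is reproducible
    counts: Counter = Counter()
    for _ in range(trials):
        counts.update(reservoir_sample(range(n), k))
    return {item: counts[item] / trials for item in range(n)}

freqs = empirical_frequencies()
# Each fraction should be close to k/n = 0.3.
print({item: round(f, 3) for item, f in freqs.items()})
```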
GCD
Compute the greatest common divisor using the Euclidean algorithm. gcd(48, 18) → 6, gcd(100, 75) → 25.
def gcd(a, b):
    # TODO: Euclidean algorithm
    # while b != 0: a, b = b, a % b
    # return a
    pass
print(gcd(48, 18)) # 6
print(gcd(100, 75)) # 25
print(gcd(7, 3)) # 1

public class Solution {
    static int gcd(int a, int b) {
        // TODO: Euclidean algorithm
        // while (b != 0): int tmp = b; b = a % b; a = tmp;
        // return a
        return 0;
    }
    public static void main(String[] args) {
        System.out.println(gcd(48, 18)); // 6
        System.out.println(gcd(100, 75)); // 25
        System.out.println(gcd(7, 3)); // 1
    }
}
While b != 0, replace (a, b) with (b, a % b). When b reaches 0, a is the GCD. This works because gcd(a, b) == gcd(b, a % b).
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a
print(gcd(48, 18)) # 6
print(gcd(100, 75)) # 25
print(gcd(7, 3)) # 1

public class Solution {
    static int gcd(int a, int b) {
        while (b != 0) {
            int tmp = b;
            b = a % b;
            a = tmp;
        }
        return a;
    }
    public static void main(String[] args) {
        System.out.println(gcd(48, 18)); // 6
        System.out.println(gcd(100, 75)); // 25
        System.out.println(gcd(7, 3)); // 1
    }
}
Sieve of Eratosthenes
Find all prime numbers up to n (inclusive). sieve(30) → [2, 3, 5, 7, 11, 13, 17, 19, 23, 29].
def sieve(n):
    # TODO: create a boolean array is_prime[0..n], init True
    # mark 0 and 1 as False
    # for each i from 2 to sqrt(n): if is_prime[i], mark multiples False
    # return list of indices where is_prime is True
    pass
print(sieve(30)) # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

import java.util.*;
public class Solution {
    static List<Integer> sieve(int n) {
        // TODO: boolean array is_prime[0..n], init true
        // mark 0 and 1 false
        // for each i from 2 to sqrt(n): if prime, mark multiples false
        // collect indices where is_prime is true
        return new ArrayList<>();
    }
    public static void main(String[] args) {
        System.out.println(sieve(30)); // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
Create a boolean list is_prime of size n+1 initialized to True. Set indices 0 and 1 to False. For each i from 2 through int(n**0.5), if is_prime[i] is True, mark i*i, i*i+i, ... up to n as False (start at i*i because smaller multiples were already marked by smaller primes). Return an empty list early when n < 2.
def sieve(n):
    if n < 2:
        return []  # is_prime would be too short to index 1
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for i in range(2, int(n ** 0.5) + 1):
        if is_prime[i]:
            for j in range(i * i, n + 1, i):
                is_prime[j] = False
    return [i for i in range(n + 1) if is_prime[i]]
print(sieve(30)) # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

import java.util.*;
public class Solution {
    static List<Integer> sieve(int n) {
        List<Integer> primes = new ArrayList<>();
        if (n < 2) return primes;
        boolean[] isPrime = new boolean[n + 1];
        Arrays.fill(isPrime, true);
        isPrime[0] = isPrime[1] = false;
        for (int i = 2; (long) i * i <= n; i++) {
            if (isPrime[i]) {
                for (int j = i * i; j <= n; j += i) {
                    isPrime[j] = false;
                }
            }
        }
        for (int i = 2; i <= n; i++) {
            if (isPrime[i]) primes.add(i);
        }
        return primes;
    }
    public static void main(String[] args) {
        System.out.println(sieve(30)); // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
Reverse Integer
Reverse the digits of an integer. Return 0 if the result would overflow a 32-bit signed integer (-2^31 to 2^31 - 1). reverse(123) → 321, reverse(-456) → -654.
def reverse(x):
    sign = -1 if x < 0 else 1
    x = abs(x)
    result = 0
    # TODO: build result digit by digit
    # digit = x % 10, x //= 10, result = result * 10 + digit
    # at the end apply sign, check 32-bit bounds, return 0 if overflow
    return 0
print(reverse(123)) # 321
print(reverse(-456)) # -654
print(reverse(1534236469)) # 0 (overflow)

public class Solution {
    static int reverse(int x) {
        int sign = x < 0 ? -1 : 1;
        long val = Math.abs((long) x);
        long result = 0;
        // TODO: build result digit by digit
        // digit = val % 10, val /= 10, result = result * 10 + digit
        // at the end apply sign, check 32-bit bounds, return 0 if overflow
        return 0;
    }
    public static void main(String[] args) {
        System.out.println(reverse(123)); // 321
        System.out.println(reverse(-456)); // -654
        System.out.println(reverse(1534236469)); // 0 (overflow)
    }
}
Strip the sign, then repeat while x > 0: digit = x % 10, x //= 10, result = result * 10 + digit. After the loop, reapply the sign. Check bounds: if result > 2**31 - 1 or result < -2**31, return 0.
def reverse(x):
    sign = -1 if x < 0 else 1
    x = abs(x)
    result = 0
    while x:
        result = result * 10 + x % 10
        x //= 10
    result *= sign
    if result < -(2 ** 31) or result > 2 ** 31 - 1:
        return 0
    return result
print(reverse(123)) # 321
print(reverse(-456)) # -654
print(reverse(1534236469)) # 0 (overflow)

public class Solution {
    static int reverse(int x) {
        int sign = x < 0 ? -1 : 1;
        long val = Math.abs((long) x);
        long result = 0;
        while (val != 0) {
            result = result * 10 + val % 10;
            val /= 10;
        }
        result *= sign;
        if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) return 0;
        return (int) result;
    }
    public static void main(String[] args) {
        System.out.println(reverse(123)); // 321
        System.out.println(reverse(-456)); // -654
        System.out.println(reverse(1534236469)); // 0 (overflow)
    }
}
19. Design Problems
LRU Cache
Least Recently Used cache evicts the entry unused for the longest time. Requires O(1) get and O(1) put. Classic solution: HashMap + doubly linked list. Python's OrderedDict provides a concise built-in alternative.
from collections import OrderedDict
class LRUCache:
    """O(1) get and put using OrderedDict (move_to_end + popitem).
    Time: O(1) amortized Space: O(capacity)
    """
    def __init__(self, capacity: int) -> None:
        self.cap = capacity
        self.cache: OrderedDict[int, int] = OrderedDict()
    def get(self, key: int) -> int:
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key) # mark as most recently used
        return self.cache[key]
    def put(self, key: int, value: int) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.cap:
            self.cache.popitem(last=False) # evict LRU (front)
LRU Cache with explicit doubly linked list (interview-expected)
class DLLNode:
    def __init__(self, key: int = 0, val: int = 0) -> None:
        self.key = key
        self.val = val
        self.prev: 'DLLNode | None' = None
        self.next: 'DLLNode | None' = None
class LRUCacheManual:
    """Doubly linked list + HashMap. Dummy head/tail simplify edge cases."""
    def __init__(self, capacity: int) -> None:
        self.cap = capacity
        self.map: dict[int, DLLNode] = {}
        self.head = DLLNode() # dummy head (LRU end)
        self.tail = DLLNode() # dummy tail (MRU end)
        self.head.next = self.tail
        self.tail.prev = self.head
    def _remove(self, node: DLLNode) -> None:
        node.prev.next = node.next
        node.next.prev = node.prev
    def _add_to_tail(self, node: DLLNode) -> None:
        node.prev = self.tail.prev
        node.next = self.tail
        self.tail.prev.next = node
        self.tail.prev = node
    def get(self, key: int) -> int:
        if key not in self.map:
            return -1
        node = self.map[key]
        self._remove(node)
        self._add_to_tail(node) # promote to MRU
        return node.val
    def put(self, key: int, value: int) -> None:
        if key in self.map:
            self._remove(self.map[key])
        node = DLLNode(key, value)
        self._add_to_tail(node)
        self.map[key] = node
        if len(self.map) > self.cap:
            lru = self.head.next # evict from LRU end
            self._remove(lru)
            del self.map[lru.key]
Min Stack
class MinStack:
    """Stack with O(1) getMin. Use a parallel stack tracking minimums.
    Push onto min_stack only when the new value <= current min (saves space slightly).
    Use <=, not <, so duplicate minimums are recorded and pop() stays correct.
    """
    def __init__(self) -> None:
        self.stack: list[int] = []
        self.min_stack: list[int] = [] # parallel min tracking
    def push(self, val: int) -> None:
        self.stack.append(val)
        if not self.min_stack or val <= self.min_stack[-1]:
            self.min_stack.append(val)
    def pop(self) -> None:
        val = self.stack.pop()
        if val == self.min_stack[-1]:
            self.min_stack.pop()
    def top(self) -> int:
        return self.stack[-1]
    def get_min(self) -> int:
        return self.min_stack[-1]
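The <= in push matters when the minimum value is pushed more than once. A short trace, repeating the MinStack above so the snippet runs on its own: with <, the second 2 would not be recorded, and popping one 2 would wrongly expose 5 as the minimum while another 2 is still on the stack.

```python
class MinStack:
    def __init__(self) -> None:
        self.stack: list[int] = []
        self.min_stack: list[int] = []

    def push(self, val: int) -> None:
        self.stack.append(val)
        if not self.min_stack or val <= self.min_stack[-1]:  # <= keeps duplicate minimums
            self.min_stack.append(val)

    def pop(self) -> None:
        if self.stack.pop() == self.min_stack[-1]:
            self.min_stack.pop()

    def top(self) -> int:
        return self.stack[-1]

    def get_min(self) -> int:
        return self.min_stack[-1]

ms = MinStack()
for v in [5, 2, 2, 7]:
    ms.push(v)
print(ms.get_min(), ms.min_stack)  # 2 [5, 2, 2]
ms.pop()  # removes 7; min_stack unchanged
ms.pop()  # removes one 2; min_stack still holds the other 2
print(ms.get_min())  # 2
```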
Time-Based Key-Value Store
class TimeMap:
    """Key-value store where each key maps to timestamped values.
    get(key, timestamp) returns the value with largest timestamp <= given timestamp.
    Uses a hand-written binary search on stored timestamps.
    Time: O(log n) get, O(1) set (timestamps always increase per problem spec)
    """
    def __init__(self) -> None:
        # key -> list of (timestamp, value), sorted by timestamp
        self.store: dict[str, list[tuple[int, str]]] = {}
    def set(self, key: str, value: str, timestamp: int) -> None:
        if key not in self.store:
            self.store[key] = []
        self.store[key].append((timestamp, value))
    def get(self, key: str, timestamp: int) -> str:
        if key not in self.store:
            return ""
        entries = self.store[key]
        # Find rightmost timestamp <= given timestamp
        lo, hi = 0, len(entries) - 1
        result = ""
        while lo <= hi:
            mid = (lo + hi) // 2
            if entries[mid][0] <= timestamp:
                result = entries[mid][1]
                lo = mid + 1
            else:
                hi = mid - 1
        return result
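TimeMap above hand-writes its binary search. Because timestamps for each key arrive in increasing order, the standard-library bisect module can do the same search. A sketch under that ordering assumption; TimeMapBisect and its parallel times/values lists are illustrative choices, not the source's design.

```python
import bisect

class TimeMapBisect:
    """Same get/set behavior as TimeMap, using bisect on per-key timestamp lists.
    Assumes set() is called with non-decreasing timestamps for each key."""

    def __init__(self) -> None:
        self.times: dict[str, list[int]] = {}
        self.values: dict[str, list[str]] = {}

    def set(self, key: str, value: str, timestamp: int) -> None:
        self.times.setdefault(key, []).append(timestamp)
        self.values.setdefault(key, []).append(value)

    def get(self, key: str, timestamp: int) -> str:
        times = self.times.get(key)
        if not times:
            return ""
        # bisect_right returns how many stored timestamps are <= timestamp
        i = bisect.bisect_right(times, timestamp)
        return self.values[key][i - 1] if i > 0 else ""

tm = TimeMapBisect()
tm.set("foo", "bar", 1)
print(tm.get("foo", 3))  # bar
tm.set("foo", "bar2", 4)
print(tm.get("foo", 5))  # bar2
print(repr(tm.get("foo", 0)))  # ''
```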
Flatten Nested Iterator
class NestedIterator:
    """Flatten a nested list of integers lazily.
    LeetCode 341: each element is an integer OR a nested list.
    Use a stack storing (list_ref, current_index) pairs.
    Method names are snake_case here; LeetCode's Python template uses
    hasNext() and NestedInteger.isInteger()/getInteger()/getList().
    Time: O(1) amortized per next()/hasNext() Space: O(depth)
    """
    def __init__(self, nested_list: list) -> None:
        # Stack of (list, index) tuples; start with the outermost list
        self.stack = [(nested_list, 0)]
    def next(self) -> int:
        self.has_next() # ensure top of stack is an int
        lst, idx = self.stack.pop()
        self.stack.append((lst, idx + 1))
        return lst[idx].get_integer()
    def has_next(self) -> bool:
        """Advance through nested structure until an integer is at the top."""
        while self.stack:
            lst, idx = self.stack[-1]
            if idx == len(lst):
                self.stack.pop() # exhausted this list
                continue
            item = lst[idx]
            if item.is_integer():
                return True
            # item is a nested list: push it and advance parent index
            self.stack[-1] = (lst, idx + 1)
            self.stack.append((item.get_list(), 0))
        return False
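NestedIterator above expects elements that provide is_integer(), get_integer() and get_list(). To run it outside a judge, a hypothetical minimal NestedInteger stand-in with those snake_case methods is enough. The sketch repeats the iterator so it is self-contained; to_nested and flatten are illustrative helpers.

```python
class NestedInteger:
    """Hypothetical stand-in for the judge-provided class (snake_case methods)."""
    def __init__(self, value) -> None:
        self._value = value  # an int, or a list of NestedInteger
    def is_integer(self) -> bool:
        return isinstance(self._value, int)
    def get_integer(self) -> int:
        return self._value
    def get_list(self) -> list:
        return self._value

class NestedIterator:
    def __init__(self, nested_list: list) -> None:
        self.stack = [(nested_list, 0)]  # (list, index of next element)
    def next(self) -> int:
        self.has_next()  # make sure an integer is at the top
        lst, idx = self.stack.pop()
        self.stack.append((lst, idx + 1))
        return lst[idx].get_integer()
    def has_next(self) -> bool:
        while self.stack:
            lst, idx = self.stack[-1]
            if idx == len(lst):
                self.stack.pop()  # this list is exhausted
                continue
            item = lst[idx]
            if item.is_integer():
                return True
            self.stack[-1] = (lst, idx + 1)  # skip past the sublist in the parent
            self.stack.append((item.get_list(), 0))
        return False

def to_nested(values: list) -> list:
    """Convert plain lists like [[1, 1], 2, [1, 1]] into NestedInteger lists."""
    return [NestedInteger(to_nested(v) if isinstance(v, list) else v) for v in values]

def flatten(values: list) -> list[int]:
    it = NestedIterator(to_nested(values))
    out = []
    while it.has_next():
        out.append(it.next())
    return out

print(flatten([[1, 1], 2, [1, 1]]))  # [1, 1, 2, 1, 1]
print(flatten([1, [4, [6]]]))        # [1, 4, 6]
print(flatten([[], [[]], 3]))        # [3]
```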
LRU Cache Simplified
Implement get(key) and put(key, value) with a capacity limit. The least recently used entry is evicted when full. put(1,1), put(2,2), get(1)→1, put(3,3) evicts key 2, so get(2)→-1.
from collections import OrderedDict
class LRUCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()
    def get(self, key):
        # TODO: if key missing return -1
        # move key to end (most recently used), return value
        pass
    def put(self, key, value):
        # TODO: if key exists, update and move to end
        # if at capacity, remove oldest (first) item
        # insert key at end
        pass
c = LRUCache(2)
c.put(1, 1)
c.put(2, 2)
print(c.get(1)) # 1 (key 1 is now most recent)
c.put(3, 3) # evicts key 2
print(c.get(2)) # -1 (evicted)
print(c.get(3)) # 3

import java.util.*;
public class Solution {
    static class LRUCache extends LinkedHashMap<Integer, Integer> {
        private final int capacity;
        LRUCache(int capacity) {
            super(capacity, 0.75f, true); // accessOrder = true
            this.capacity = capacity;
        }
        int get(int key) {
            // TODO: return getOrDefault(key, -1)
            return -1;
        }
        void put(int key, int value) {
            // TODO: call super.put(key, value)
        }
        @Override
        protected boolean removeEldestEntry(Map.Entry<Integer, Integer> eldest) {
            // TODO: return true when size exceeds capacity
            return false;
        }
    }
    public static void main(String[] args) {
        LRUCache c = new LRUCache(2);
        c.put(1, 1);
        c.put(2, 2);
        System.out.println(c.get(1)); // 1
        c.put(3, 3); // evicts key 2
        System.out.println(c.get(2)); // -1
        System.out.println(c.get(3)); // 3
    }
}
OrderedDict preserves insertion order. Use move_to_end(key) on every access (get and put). When full, call popitem(last=False) to evict the oldest (leftmost) entry. Java: extend LinkedHashMap with accessOrder=true and override removeEldestEntry to return size() > capacity.
from collections import OrderedDict
class LRUCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()
    def get(self, key):
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)
        return self.cache[key]
    def put(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)
c = LRUCache(2)
c.put(1, 1)
c.put(2, 2)
print(c.get(1)) # 1
c.put(3, 3) # evicts key 2
print(c.get(2)) # -1
print(c.get(3)) # 3

import java.util.*;
public class Solution {
    static class LRUCache extends LinkedHashMap<Integer, Integer> {
        private final int capacity;
        LRUCache(int capacity) {
            super(capacity, 0.75f, true);
            this.capacity = capacity;
        }
        int get(int key) {
            return getOrDefault(key, -1);
        }
        void put(int key, int value) {
            super.put(key, value);
        }
        @Override
        protected boolean removeEldestEntry(Map.Entry<Integer, Integer> eldest) {
            return size() > capacity;
        }
    }
    public static void main(String[] args) {
        LRUCache c = new LRUCache(2);
        c.put(1, 1);
        c.put(2, 2);
        System.out.println(c.get(1)); // 1
        c.put(3, 3); // evicts key 2
        System.out.println(c.get(2)); // -1
        System.out.println(c.get(3)); // 3
    }
}
Min Stack
Implement a stack with push, pop, top, and getMin all in O(1). After push(3), push(5): getMin()→3. After push(1): getMin()→1. After pop(): getMin()→3.
class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = [] # tracks current min at each level
    def push(self, val):
        # TODO: push val to stack
        # push current min to min_stack
        # (min is val if min_stack empty, else min(val, min_stack[-1]))
        pass
    def pop(self):
        # TODO: pop from both stacks
        pass
    def top(self):
        return self.stack[-1]
    def getMin(self):
        return self.min_stack[-1]
ms = MinStack()
ms.push(3)
ms.push(5)
print(ms.getMin()) # 3
ms.push(1)
print(ms.getMin()) # 1
ms.pop()
print(ms.getMin()) # 3

import java.util.*;
public class Solution {
    static Deque<Integer> stack = new ArrayDeque<>();
    static Deque<Integer> minStack = new ArrayDeque<>();
    static void push(int val) {
        // TODO: push val to stack
        // push current min to minStack
        // (min is val if minStack empty, else Math.min(val, minStack.peek()))
    }
    static void pop() {
        // TODO: pop from both stacks
    }
    static int top() { return stack.peek(); }
    static int getMin() { return minStack.peek(); }
    public static void main(String[] args) {
        push(3);
        push(5);
        System.out.println(getMin()); // 3
        push(1);
        System.out.println(getMin()); // 1
        pop();
        System.out.println(getMin()); // 3
    }
}
Maintain a second stack, min_stack. On each push(val), also push min(val, min_stack[-1]) (or just val if min_stack is empty) onto min_stack. On pop(), pop from both stacks. getMin() just peeks at min_stack[-1].
class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []
    def push(self, val):
        self.stack.append(val)
        current_min = val if not self.min_stack else min(val, self.min_stack[-1])
        self.min_stack.append(current_min)
    def pop(self):
        self.stack.pop()
        self.min_stack.pop()
    def top(self):
        return self.stack[-1]
    def getMin(self):
        return self.min_stack[-1]
ms = MinStack()
ms.push(3)
ms.push(5)
print(ms.getMin()) # 3
ms.push(1)
print(ms.getMin()) # 1
ms.pop()
print(ms.getMin()) # 3

import java.util.*;
public class Solution {
    static Deque<Integer> stack = new ArrayDeque<>();
    static Deque<Integer> minStack = new ArrayDeque<>();
    static void push(int val) {
        stack.push(val);
        int currentMin = minStack.isEmpty() ? val : Math.min(val, minStack.peek());
        minStack.push(currentMin);
    }
    static void pop() {
        stack.pop();
        minStack.pop();
    }
    static int top() { return stack.peek(); }
    static int getMin() { return minStack.peek(); }
    public static void main(String[] args) {
        push(3);
        push(5);
        System.out.println(getMin()); // 3
        push(1);
        System.out.println(getMin()); // 1
        pop();
        System.out.println(getMin()); // 3
    }
}
Hit Counter
Count hits in the past 5 minutes (300 seconds). hit(1), hit(2), hit(3), getHits(4)→3. After hit(300), getHits(300)→4. getHits(301)→3 (the hit at t=1 has expired).
from collections import deque
class HitCounter:
    def __init__(self):
        self.hits = deque() # stores timestamps
    def hit(self, timestamp):
        self.hits.append(timestamp)
    def getHits(self, timestamp):
        # TODO: remove all timestamps where timestamp - ts >= 300
        # (they are 300 or more seconds old)
        # return remaining count
        pass
hc = HitCounter()
hc.hit(1)
hc.hit(2)
hc.hit(3)
print(hc.getHits(4)) # 3
hc.hit(300)
print(hc.getHits(300)) # 4
print(hc.getHits(301)) # 3

import java.util.*;
public class Solution {
    static Deque<Integer> hits = new ArrayDeque<>();
    static void hit(int timestamp) { hits.addLast(timestamp); }
    static int getHits(int timestamp) {
        // TODO: remove from front while timestamp - hits.peekFirst() >= 300
        // return hits.size()
        return 0;
    }
    public static void main(String[] args) {
        hit(1);
        hit(2);
        hit(3);
        System.out.println(getHits(4)); // 3
        hit(300);
        System.out.println(getHits(300)); // 4
        System.out.println(getHits(301)); // 3
    }
}
Store hit timestamps in a deque in arrival order. In getHits(t), pop from the left (oldest) while t - hits[0] >= 300; those hits are outside the 300-second window. Return the remaining queue length. The condition is >= 300, not > 300, because a hit at t=1 should expire by t=301.
from collections import deque
class HitCounter:
    def __init__(self):
        self.hits = deque()
    def hit(self, timestamp):
        self.hits.append(timestamp)
    def getHits(self, timestamp):
        while self.hits and timestamp - self.hits[0] >= 300:
            self.hits.popleft()
        return len(self.hits)
hc = HitCounter()
hc.hit(1)
hc.hit(2)
hc.hit(3)
print(hc.getHits(4)) # 3
hc.hit(300)
print(hc.getHits(300)) # 4
print(hc.getHits(301)) # 3

import java.util.*;
public class Solution {
    static Deque<Integer> hits = new ArrayDeque<>();
    static void hit(int timestamp) { hits.addLast(timestamp); }
    static int getHits(int timestamp) {
        while (!hits.isEmpty() && timestamp - hits.peekFirst() >= 300) {
            hits.pollFirst();
        }
        return hits.size();
    }
    public static void main(String[] args) {
        hit(1);
        hit(2);
        hit(3);
        System.out.println(getHits(4)); // 3
        hit(300);
        System.out.println(getHits(300)); // 4
        System.out.println(getHits(301)); // 3
    }
}
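The deque version stores every hit, so memory grows with traffic. A common alternative keeps 300 buckets indexed by timestamp % 300, so getHits scans a fixed 300 buckets regardless of hit volume. A sketch, assuming integer-second timestamps that never decrease; the class and field names are illustrative.

```python
class BucketHitCounter:
    """Fixed memory: one bucket per second of the 300-second window."""
    WINDOW = 300

    def __init__(self) -> None:
        self.times = [0] * self.WINDOW   # last timestamp written into each bucket
        self.counts = [0] * self.WINDOW  # number of hits at that timestamp

    def hit(self, timestamp: int) -> None:
        i = timestamp % self.WINDOW
        if self.times[i] != timestamp:   # bucket holds an older second: reuse it
            self.times[i] = timestamp
            self.counts[i] = 0
        self.counts[i] += 1

    def getHits(self, timestamp: int) -> int:
        # Count only buckets whose timestamp is inside the last 300 seconds.
        return sum(c for t, c in zip(self.times, self.counts) if timestamp - t < self.WINDOW)

hc = BucketHitCounter()
for t in (1, 2, 3):
    hc.hit(t)
print(hc.getHits(4))    # 3
hc.hit(300)
print(hc.getHits(300))  # 4
print(hc.getHits(301))  # 3
```

The trade-off: hit() stays O(1) and memory is constant, while getHits() costs O(300) instead of amortized O(1).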
20. Interview Strategy Cheat Sheet
Problem-Solving Framework
| Step | What to do | Time |
|---|---|---|
| 1. Understand | Restate the problem. Clarify input/output types, constraints, edge cases. | 2 min |
| 2. Examples | Work through 2-3 examples including edge cases (empty, single element, negatives). | 2 min |
| 3. Brute force | State naive \(O(n^2)\) or \(O(2^n)\) solution. Briefly explain why it works. | 1 min |
| 4. Optimize | Identify bottleneck. Apply pattern (sort, two pointers, DP, BFS, etc.). | 5 min |
| 5. Code | Write clean code top-down. Name variables clearly. No premature optimization. | 15 min |
| 6. Test | Trace through examples manually. Check edge cases. Analyze complexity. | 5 min |
Pattern Recognition Table
| Input / Clue | Likely Technique |
|---|---|
| Sorted array, find target | Binary search |
| Find pair/triplet with target sum | Two pointers (sorted) or hash map |
| Subarray / substring sum or length | Sliding window or prefix sums |
| Top K / K-th largest | Heap (min-heap of size k) or Quickselect |
| All subsets / permutations / combinations | Backtracking |
| Optimal choice at each step | Greedy (prove or disprove first) |
| Overlapping subproblems, optimal solution | Dynamic programming |
| Shortest path, unweighted | BFS |
| Shortest path, weighted (non-negative) | Dijkstra |
| Connected components / union operations | Union-Find or BFS/DFS |
| Tree: level-by-level processing | BFS with queue |
| Tree: path / ancestor / subtree | DFS (often recursive) |
| Prefix lookups, autocomplete | Trie |
| Monotone property (next greater, range min) | Monotonic stack or deque |
| String with duplicate characters | Hash map + sliding window |
| Find duplicate / missing number | XOR, index marking, or math |
| Matrix connectivity / regions | DFS/BFS or Union-Find on grid |
| Palindrome | Two pointers from center or Manacher's |
Time Complexity Targets by Input Size
| Input n | Acceptable complexity | Typical algorithm |
|---|---|---|
| \(n \le 15\) | \(O(2^n)\); \(O(n!)\) only for \(n \le 10\) | Backtracking, bitmask DP |
| \(n \le 100\) | \(O(n^3)\) or \(O(n^2 \log n)\) | Floyd-Warshall, interval DP |
| \(n \le 1{,}000\) | \(O(n^2)\) | DP, brute-force with pruning |
| \(n \le 10{,}000\) | \(O(n^2)\) tight, prefer \(O(n \log n)\) | \(O(n^2)\) LIS, sorting-based |
| \(n \le 100{,}000\) | \(O(n \log n)\) | Sort, heap, binary search, BFS/DFS |
| \(n \le 10^6\) | \(O(n)\) or \(O(n \log n)\) | Two pointers, sliding window, prefix |
| \(n \le 10^9\) | \(O(\log n)\) or \(O(1)\) | Binary search, math formula |
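The thresholds in this table follow roughly from an operations budget. A sketch that makes the arithmetic explicit, assuming about 10^8 simple operations per second (a common rule of thumb for compiled languages; CPython is typically much slower); OPS_PER_SECOND, estimated_ops and feasible are illustrative names.

```python
import math

OPS_PER_SECOND = 10**8  # rough assumption, not a measured figure

def estimated_ops(n: int) -> dict[str, float]:
    """Approximate operation counts for common complexity classes."""
    return {
        "O(log n)": math.log2(n),
        "O(n)": n,
        "O(n log n)": n * math.log2(n),
        "O(n^2)": n**2,
        "O(n^3)": n**3,
    }

def feasible(n: int, budget: float = OPS_PER_SECOND) -> list[str]:
    """Complexity classes whose estimated op count fits in the one-second budget."""
    return [name for name, ops in estimated_ops(n).items() if ops <= budget]

print(feasible(100_000))  # ['O(log n)', 'O(n)', 'O(n log n)']
print(feasible(1_000))    # ['O(log n)', 'O(n)', 'O(n log n)', 'O(n^2)']
```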
Common Edge Cases Checklist
- Empty input: empty list, empty string, n=0
- Single element: list of length 1, single node tree
- All same elements: [5, 5, 5, 5]
- Already sorted / reverse sorted
- Negative numbers: does your sum/product logic handle them?
- Integer overflow: Python is safe, but flag it in Java/C++
- Cyclic graph: does your graph algorithm handle cycles?
- Disconnected graph: loop over all unvisited nodes
- Duplicate values: sorted with duplicates, set semantics
- Target not found: return -1, [], or appropriate sentinel
- Very large n: stack overflow from deep recursion? use iterative
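On the last item: CPython caps recursion depth (sys.getrecursionlimit(), 1000 by default), so a recursive DFS over a path-shaped input of 100,000 nodes raises RecursionError. A sketch comparing a recursive tree-depth function with an explicit-stack version; the adjacency-dict input format and the function names are assumptions for illustration.

```python
def depth_recursive(children: dict[int, list[int]], node: int) -> int:
    """Max depth via recursion; one Python frame per level."""
    best = 0
    for child in children.get(node, []):
        best = max(best, depth_recursive(children, child))
    return best + 1

def depth_iterative(children: dict[int, list[int]], root: int) -> int:
    """Same result with an explicit stack of (node, depth) pairs; no recursion limit."""
    best = 0
    stack = [(root, 1)]
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in children.get(node, []):
            stack.append((child, depth + 1))
    return best

chain = {i: [i + 1] for i in range(99_999)}  # a path 0 -> 1 -> ... -> 99999
print(depth_iterative(chain, 0))  # 100000
try:
    depth_recursive(chain, 0)
except RecursionError:
    print("RecursionError: switch to the iterative version")
```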
Communication Tips
# Narrate your thought process as you code.
# Say these phrases:
#
# "Let me start with a brute-force approach to make sure I understand the problem..."
# "I notice this has overlapping subproblems, so DP might work here."
# "My key insight is: [explain the non-obvious observation]."
# "I'll use a [data structure] here because [reason]."
# "The time complexity is O(...) because [explain each term]."
# "Let me trace through this example: input=[...], expected=[...]"
# "An edge case I want to check: what if the input is empty?"
# "I could optimize space further to O(n) by using a rolling array, but
# I'll keep this version for clarity since we're within the time budget."
#
# During coding:
# - Write function signature first with type hints.
# - Add a one-line docstring stating the approach.
# - Use descriptive names: 'left_ptr' not 'l', 'remaining_sum' not 'rs'.
# - If stuck: "Let me think about what information I need at each step."
2. Have I handled the empty input case?
3. Does the code return the correct type (list vs tuple, int vs float)?
4. Off-by-one: are my loop bounds correct (range(n) vs range(n+1))?
5. Am I mutating input the caller might not expect me to mutate?
6. Is the space complexity within the constraints?
7. Did I state the time and space complexity out loud?
Complexity Target
Given n = 100,000 and a 1-second time limit, what is the maximum acceptable time complexity? Run the code to confirm.
# Given n = 100,000 and a 1-second time limit,
# what is the maximum acceptable time complexity?
# Options: O(n^2), O(n log n), O(n), O(log n)
# Rough budget: ~10^8 simple operations per second (compiled languages; Python is slower)
n = 100000
# O(n^2) = 10,000,000,000 ops -- far too slow (~100 s)
# O(n log n) ~ 1,700,000 ops -- fast enough (~0.02 s)
# O(n) = 100,000 ops -- also fine
# O(log n) ~ 17 ops -- trivially fast
print("O(n log n)")

public class Solution {
    public static void main(String[] args) {
        // Given n = 100,000 and a 1-second time limit,
        // what is the maximum acceptable time complexity?
        // Options: O(n^2), O(n log n), O(n), O(log n)
        // O(n^2) = 10,000,000,000 ops -- far too slow
        // O(n log n) ~ 1,700,000 ops -- fast enough
        // O(n) = 100,000 ops -- also fine
        // O(log n) ~ 17 ops -- trivially fast
        System.out.println("O(n log n)");
    }
}
Pattern Recognition
You are given a sorted array and need to find if any pair sums to a target. Which technique is most efficient? Run the code to confirm.
# You're given a sorted array and need to find if a pair sums to target.
# Which technique is most efficient?
# Options: brute force, binary search, two pointers, hash map
# Brute force: O(n^2) time, O(1) space -- too slow
# Binary search: O(n log n) time, O(1) space -- decent
# Two pointers: O(n) time, O(1) space -- best for sorted input
# Hash map: O(n) time, O(n) space -- works but wastes space
# Sorted array + pair sum = two pointers
# Left pointer at start, right at end.
# If sum too small, advance left. If too large, retreat right.
print("two pointers")

public class Solution {
    public static void main(String[] args) {
        // Sorted array + pair sum = two pointers
        // Left pointer at start, right at end.
        // If sum too small, advance left. If too large, retreat right.
        // O(n) time, O(1) space -- best for sorted input.
        System.out.println("two pointers");
    }
}
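The two-pointer answer can be sketched directly, assuming nums is sorted in ascending order; has_pair_with_sum is an illustrative name.

```python
def has_pair_with_sum(nums: list[int], target: int) -> bool:
    """Two pointers on a sorted list: O(n) time, O(1) extra space."""
    left, right = 0, len(nums) - 1
    while left < right:
        s = nums[left] + nums[right]
        if s == target:
            return True
        if s < target:
            left += 1   # need a larger sum
        else:
            right -= 1  # need a smaller sum
    return False

print(has_pair_with_sum([1, 2, 4, 7, 11], 15))  # True (4 + 11)
print(has_pair_with_sum([1, 2, 4, 7, 11], 10))  # False
```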
Choose Data Structure
You need O(1) lookup, O(1) insert, O(1) delete of any element, AND O(1) uniform random selection (getRandom). Which combination of data structures achieves all four? Run the code to confirm.
# You need O(1) lookup, O(1) insert, O(1) delete, AND O(1) getRandom.
# Which combination of data structures achieves this?
# Array alone: O(1) index access/insert-end, but O(n) search and delete (must shift)
# HashMap alone: O(1) lookup/insert/delete, but no O(1) random access by index
# LinkedList: O(1) insert/delete (with pointer), but O(n) lookup
# Solution: HashMap + ArrayList with the "swap trick"
# HashMap: value -> index in ArrayList
# Delete(val): swap with last element, pop last, update map entry
# getRandom: pick a random index into the ArrayList
# All four ops are O(1) on average (hashing) / amortized (list append)
print("HashMap + ArrayList")

public class Solution {
    public static void main(String[] args) {
        // You need O(1) lookup, O(1) insert, O(1) delete, AND O(1) getRandom.
        // Array alone: O(1) index access/insert-end, O(n) search and delete (shifting)
        // HashMap alone: O(1) lookup/insert/delete, but no random index access
        // LinkedList: O(1) insert/delete (with ptr), O(n) lookup
        // Solution: HashMap + ArrayList with the "swap trick"
        // HashMap: value -> index in ArrayList
        // Delete(val): swap with last, pop last, update map
        // getRandom: pick a random index into the ArrayList
        // Result: all four ops are O(1) average/amortized
        System.out.println("HashMap + ArrayList");
    }
}
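A sketch of the HashMap + ArrayList swap trick, modeled on LeetCode 380 (Insert Delete GetRandom O(1)) but with snake_case method names. The bounds are average-case, since dict operations are O(1) on average and list appends are amortized O(1); random.choice picks uniformly from the current elements.

```python
import random

class RandomizedSet:
    """vals holds the elements; index maps each value to its position in vals."""

    def __init__(self) -> None:
        self.vals: list[int] = []
        self.index: dict[int, int] = {}

    def insert(self, val: int) -> bool:
        if val in self.index:
            return False
        self.index[val] = len(self.vals)
        self.vals.append(val)
        return True

    def remove(self, val: int) -> bool:
        if val not in self.index:
            return False
        i = self.index[val]
        last = self.vals[-1]
        self.vals[i] = last   # move the last element into the hole
        self.index[last] = i
        self.vals.pop()       # drop the now-duplicate tail slot
        del self.index[val]   # also correct when val was the last element
        return True

    def get_random(self) -> int:
        return random.choice(self.vals)

rs = RandomizedSet()
print(rs.insert(1), rs.remove(2), rs.insert(2))  # True False True
print(rs.get_random() in (1, 2))                 # True
rs.remove(1)
print(rs.get_random())                           # 2
```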