
1. Complexity Analysis

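Big-O describes how an algorithm's running time or memory use grows with the input size n as n tends to infinity. Strictly, it is an asymptotic upper bound that ignores constant factors and lower-order terms. In practice it is usually quoted for the worst case, but average-case and amortized bounds are also written in Big-O.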
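Here is the formal definition behind "ignoring constants and lower-order terms", with one worked instance. The polynomial is an illustrative choice and does not come from the text above.

```latex
f(n) = O(g(n)) \iff \exists\, c > 0,\ n_0 \ge 1 \text{ such that } 0 \le f(n) \le c \cdot g(n) \text{ for all } n \ge n_0

\text{Example: } f(n) = 3n^2 + 5n + 2,\ \text{take } c = 4:
\quad 4n^2 - (3n^2 + 5n + 2) = n^2 - 5n - 2 \ge 0 \iff n \ge \tfrac{5 + \sqrt{33}}{2} \approx 5.37 \quad (n \ge 1)
\quad \Rightarrow f(n) \le 4n^2 \text{ for all } n \ge n_0 = 6, \text{ so } f(n) = O(n^2).
```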

Big-O Reference Table

| Notation | Name | Example | n = 10 | n = 1,000 | n = 1,000,000 |
|---|---|---|---|---|---|
| O(1) | Constant | Hash lookup, array index | 1 | 1 | 1 |
| O(log n) | Logarithmic | Binary search, balanced BST op | ≈3 | ≈10 | ≈20 |
| O(n) | Linear | Linear scan, single loop | 10 | 1K | 1M |
| O(n log n) | Linearithmic | Merge sort, heap sort | ≈33 | ≈10K | ≈20M |
| \(O(n^2)\) | Quadratic | Nested loops, bubble sort | 100 | 1M | \(10^{12}\) |
| \(O(2^n)\) | Exponential | Subset enumeration, naive recursion | 1,024 | \(2^{1000}\) | \(2^{10^6}\) |
| O(n!) | Factorial | Permutations, TSP brute force | ≈3.6M | unusable | unusable |
Interview rule of thumb
  • \(n \le 10\): \(O(n!)\) may be acceptable.
  • \(n \le 20\): \(O(2^n)\) may be acceptable.
  • \(n \le 500\): \(O(n^2)\) is fine.
  • \(n \le 10^5\): aim for \(O(n \log n)\) or better.
  • \(n \le 10^7\): \(O(n)\) or better.
  • \(n \le 10^9\): \(O(\log n)\) or \(O(1)\).
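You can check these thresholds by plugging n into each growth function and comparing the result against an operation budget. The sketch below assumes a budget of about 10^8 simple operations. That is a common back-of-the-envelope figure for roughly one second of work, not a measured constant. The names `LOG2_COST` and `feasible` are hypothetical. Costs are compared in log2 space so that values like (10^7)! are never actually built.

```python
import math

# Assumption: roughly 1e8 simple operations fit in a typical time limit.
# This is a back-of-the-envelope figure, not a measured constant.
LOG2_BUDGET = math.log2(10 ** 8)  # about 26.6

# log2 of the estimated operation count for input size n (n >= 2).
# Working in log space avoids building enormous numbers like (1e7)!.
LOG2_COST = {
    "O(n!)": lambda n: math.lgamma(n + 1) / math.log(2),  # log2(n!)
    "O(2^n)": lambda n: float(n),
    "O(n^2)": lambda n: 2 * math.log2(n),
    "O(n log n)": lambda n: math.log2(n) + math.log2(math.log2(n)),
    "O(n)": lambda n: math.log2(n),
    "O(log n)": lambda n: math.log2(math.log2(n)),
    "O(1)": lambda n: 0.0,
}


def feasible(n: int) -> list[str]:
    """Complexity classes whose estimated cost for size n fits the budget."""
    return [label for label, cost in LOG2_COST.items() if cost(n) <= LOG2_BUDGET]


for n in (10, 20, 500, 10 ** 5, 10 ** 7, 10 ** 9):
    print(n, feasible(n))
```

Under this assumed budget, O(n!) stops fitting at n = 12 and O(2^n) stops fitting at n = 27. The cutoffs shift if the budget is changed.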

Space Complexity

Auxiliary space is the extra space used beyond the input. Total space includes input storage. Interview problems usually ask about auxiliary space.

| Situation | Space | Example |
|---|---|---|
| A few variables | O(1) | Two-pointer, in-place sort |
| Recursion depth | O(depth) | DFS on balanced tree = O(log n); skewed = O(n) |
| Output / copy of input | O(n) | Prefix sum array, result list |
| 2D grid / DP table | \(O(n^2)\) | Edit distance DP on two length-n strings (reducible to O(n) by keeping only two rows) |
| Recursion + data per frame | O(n × k) | DFS with path copy at each level |
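The space rules above can be compared side by side in a small sketch. The function names are hypothetical. It shows two ways to reverse a list that differ only in auxiliary space, plus a helper that measures peak recursion depth for a skewed (one-step) recursion and a balanced (halving) recursion. CPython caps recursion depth (the default limit is 1000), so an O(n)-deep recursion can fail in Python for large n even when its asymptotic space is acceptable.

```python
from collections.abc import Callable


def reverse_copy(nums: list[int]) -> list[int]:
    """O(n) auxiliary space: builds a new list of the same length."""
    return nums[::-1]


def reverse_in_place(nums: list[int]) -> None:
    """O(1) auxiliary space: two index variables; the input is modified."""
    left, right = 0, len(nums) - 1
    while left < right:
        nums[left], nums[right] = nums[right], nums[left]
        left += 1
        right -= 1


def peak_depth(n: int, step: Callable[[int], int]) -> int:
    """Peak number of simultaneously active frames when go(k) recurses once
    on step(k) until k <= 1. Auxiliary space is O(peak depth)."""
    peak = 0

    def go(k: int, depth: int) -> None:
        nonlocal peak
        peak = max(peak, depth)
        if k > 1:
            go(step(k), depth + 1)

    go(n, 1)
    return peak


# Skewed shape (one step at a time) vs balanced shape (halving)
print(peak_depth(500, lambda k: k - 1))    # 500
print(peak_depth(1024, lambda k: k // 2))  # 11
```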

Analyzing Loops and Recursion

Nested loops

n = 8                        # any input size

def process(i, j):           # placeholder for O(1) work
    pass

# O(n^2): each loop runs 0..n-1
for i in range(n):
    for j in range(n):
        process(i, j)

# O(n^2): inner loop runs n - i times; the total is still quadratic
for i in range(n):
    for j in range(i, n):    # triangle: n*(n+1)/2 iterations
        process(i, j)

# O(n log n): inner loop doubles j each time, so it runs about log2(n) times
for i in range(n):
    j = 1
    while j < n:
        process(i, j)
        j *= 2
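The iteration counts claimed in the comments can be checked by counting instead of calling `process`. The function names below are illustrative.

```python
def count_square(n: int) -> int:
    """Both loops run 0..n-1: n * n iterations."""
    return sum(1 for i in range(n) for j in range(n))


def count_triangle(n: int) -> int:
    """Inner loop is range(i, n), i.e. n - i iterations: n*(n+1)/2 total."""
    return sum(1 for i in range(n) for j in range(i, n))


def count_doubling(n: int) -> int:
    """Inner loop doubles j from 1 while j < n: ceil(log2 n) times per i."""
    total = 0
    for _ in range(n):
        j = 1
        while j < n:
            total += 1
            j *= 2
    return total


print(count_square(10), count_triangle(10), count_doubling(16))  # 100 55 64
```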

Recurrence relations (Master Theorem shortcut)

# T(n) = a * T(n/b) + O(n^d)
#
# Compare d vs log_b(a):
#   d > log_b(a)  => O(n^d)
#   d == log_b(a) => O(n^d * log n)
#   d < log_b(a)  => O(n^log_b(a))
#
# Merge sort:  a=2, b=2, d=1  => log_2(2)=1 == d  => O(n log n)
# Binary search: a=1, b=2, d=0 => log_2(1)=0 == d  => O(log n)
# Strassen:    a=7, b=2, d=2  => log_2(7)~2.81 > d => O(n^2.81)
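The three-case comparison can be encoded directly. The sketch below classifies a recurrence of the form T(n) = a*T(n/b) + O(n^d). It assumes a ≥ 1, b > 1, d ≥ 0, and that the non-recursive work really is a polynomial n^d (the simplified form used above, not the general Master Theorem). The function names are hypothetical. Floating-point exponents are compared with a tolerance so that, for example, log_2(2) counts as equal to 1.

```python
import math


def _power_of_n(exp: float) -> str:
    """Format n^exp, simplifying exponents 0 and 1."""
    if math.isclose(exp, 0.0, abs_tol=1e-9):
        return "1"
    if math.isclose(exp, 1.0, abs_tol=1e-9):
        return "n"
    return f"n^{exp:.3g}"


def master_theorem(a: int, b: int, d: float) -> str:
    """Classify T(n) = a*T(n/b) + O(n^d) with the three-case shortcut.
    Assumes a >= 1, b > 1, d >= 0."""
    if a < 1 or b <= 1 or d < 0:
        raise ValueError("requires a >= 1, b > 1, d >= 0")
    critical = math.log(a, b)                      # log_b(a)
    if math.isclose(d, critical, abs_tol=1e-9):    # d == log_b(a)
        base = _power_of_n(d)
        return "O(log n)" if base == "1" else f"O({base} log n)"
    if d > critical:                               # work at the root dominates
        return f"O({_power_of_n(d)})"
    return f"O({_power_of_n(critical)})"           # leaves dominate


print(master_theorem(2, 2, 1))  # O(n log n)
print(master_theorem(1, 2, 0))  # O(log n)
print(master_theorem(7, 2, 2))  # O(n^2.81)
```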
Common recursion patterns
  • Divide & conquer halving: \(T(n) = 2T(n/2) + O(n) \rightarrow O(n \log n)\)
  • Single recursive call halving: \(T(n) = T(n/2) + O(1) \rightarrow O(\log n)\)
  • Linear recursion: \(T(n) = T(n{-}1) + O(1) \rightarrow O(n)\)
  • Two branches, no memoization: \(T(n) = 2T(n{-}1) + O(1) \rightarrow O(2^n)\)
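The last two patterns differ only in whether repeated subproblems are cached. The sketch below counts calls to f(k) = f(k-1) + f(k-1) with and without `functools.lru_cache`. The name `count_calls` is illustrative. Without caching the count is 2^n - 1. With caching each k is computed once, so the count drops to n, which is the linear-recursion pattern.

```python
from functools import lru_cache


def count_calls(n: int, memoize: bool) -> int:
    """Count how many times the body of f(k) = f(k-1) + f(k-1) executes
    (base case k <= 1). With memoize=True, cache hits skip the body."""
    calls = 0

    def f(k: int) -> int:
        nonlocal calls
        calls += 1
        if k <= 1:
            return 1
        return f(k - 1) + f(k - 1)

    if memoize:
        # Rebinding f also changes what the recursive calls inside f see,
        # because f is looked up in this enclosing scope at call time.
        f = lru_cache(maxsize=None)(f)
    f(n)
    return calls


print(count_calls(5, memoize=False))  # 31
print(count_calls(5, memoize=True))   # 5
```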
Practice
Predict

Identify the Time Complexity

What does this function print? Think about how many times the loop runs as n doubles.

def mystery(n):
    count = 0
    i = 1
    while i < n:
        count += 1
        i *= 2
    return count

print(mystery(64))
public class Solution {
    public static int mystery(int n) {
        int count = 0;
        int i = 1;
        while (i < n) {
            count++;
            i *= 2;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(mystery(64));
    }
}
Each iteration doubles i: 1 → 2 → 4 → 8 → 16 → 32, and the loop stops once i reaches 64, so the body runs 6 times (log₂ 64 = 6). In general it runs ⌈log₂ n⌉ times for n ≥ 1, so this is O(log n).
def mystery(n):
    count = 0
    i = 1
    while i < n:
        count += 1
        i *= 2
    return count

print(mystery(64))  # Output: 6
public class Solution {
    public static int mystery(int n) {
        int count = 0;
        int i = 1;
        while (i < n) {
            count++;
            i *= 2;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(mystery(64));  // Output: 6
    }
}
Predict

Nested Loop Complexity

What does this function print? Count how many times count is incremented across all iterations of both loops.

def mystery(n):
    count = 0
    for i in range(n):
        for j in range(i):
            count += 1
    return count

print(mystery(5))
public class Solution {
    public static int mystery(int n) {
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(mystery(5));
    }
}
The inner loop runs 0 + 1 + 2 + 3 + 4 = 10 times total. In general it runs n*(n-1)/2 times, which is O(n²).
def mystery(n):
    count = 0
    for i in range(n):
        for j in range(i):
            count += 1
    return count

print(mystery(5))  # Output: 10
public class Solution {
    public static int mystery(int n) {
        int count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(mystery(5));  // Output: 10
    }
}
Predict

Recursion Complexity

What does this function print? Trace the recursive calls — each call spawns two more.

def mystery(n):
    if n <= 1:
        return 1
    return mystery(n - 1) + mystery(n - 1)

print(mystery(5))
public class Solution {
    public static int mystery(int n) {
        if (n <= 1) {
            return 1;
        }
        return mystery(n - 1) + mystery(n - 1);
    }

    public static void main(String[] args) {
        System.out.println(mystery(5));
    }
}
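mystery(1)=1, mystery(2)=1+1=2, mystery(3)=2+2=4, mystery(4)=4+4=8, mystery(5)=8+8=16, so the result is 2^(n-1). The number of calls also doubles at each level: mystery(n) makes 2^n - 1 calls in total (31 for n = 5), so the running time is O(2^n). The recursion depth, and therefore the stack space, is only O(n).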
def mystery(n):
    if n <= 1:
        return 1
    return mystery(n - 1) + mystery(n - 1)

print(mystery(5))  # Output: 16
public class Solution {
    public static int mystery(int n) {
        if (n <= 1) {
            return 1;
        }
        return mystery(n - 1) + mystery(n - 1);
    }

    public static void main(String[] args) {
        System.out.println(mystery(5));  // Output: 16
    }
}

2. Arrays & Hashing

Two Pointers

The two-pointer technique eliminates an inner loop by maintaining an invariant as the pointers move. Choose the variant based on whether the answer involves a pair of elements (opposite-direction pointers) or a property of a subarray (same-direction pointers).

Opposite-direction — sorted array pair problems

def two_sum_sorted(nums: list[int], target: int) -> tuple[int, int]:
    """Find indices of two numbers that sum to target in a sorted array.
    Time: O(n)  Space: O(1)
    """
    left, right = 0, len(nums) - 1
    while left < right:
        s = nums[left] + nums[right]
        if s == target:
            return left, right
        elif s < target:
            left += 1   # need bigger sum
        else:
            right -= 1  # need smaller sum
    return -1, -1  # not found

Same-direction — remove duplicates / partition

def remove_duplicates(nums: list[int]) -> int:
    """Remove duplicates in-place from sorted array; return new length.
    Time: O(n)  Space: O(1)
    Pattern: slow pointer marks the 'write' position.
    """
    if not nums:
        return 0
    slow = 0
    for fast in range(1, len(nums)):
        if nums[fast] != nums[slow]:
            slow += 1
            nums[slow] = nums[fast]
    return slow + 1


def partition(nums: list[int], pivot: int) -> list[int]:
    """Dutch National Flag — 3-way partition around pivot.
    Time: O(n)  Space: O(1)
    """
    low, mid, high = 0, 0, len(nums) - 1
    while mid <= high:
        if nums[mid] < pivot:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == pivot:
            mid += 1
        else:
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
    return nums

Sliding Window

Sliding window avoids recomputing aggregates from scratch on every subarray. The window expands at the right and contracts at the left based on a validity condition.

Fixed-size window

def max_sum_subarray(nums: list[int], k: int) -> int:
    """Maximum sum of any subarray of length exactly k.
    Time: O(n)  Space: O(1)
    """
    if len(nums) < k:
        return 0

    # Build first window
    window_sum = sum(nums[:k])
    best = window_sum

    for i in range(k, len(nums)):
        window_sum += nums[i] - nums[i - k]  # slide: add right, drop left
        best = max(best, window_sum)

    return best

Variable-size window — find shortest/longest subarray satisfying condition

def min_subarray_len(target: int, nums: list[int]) -> int:
    """Minimum length subarray with sum >= target.
    Assumes all nums are positive, so shrinking the window always lowers its sum.
    Time: O(n)  Space: O(1)
    """
    left = 0
    current_sum = 0
    min_len = float('inf')

    for right in range(len(nums)):
        current_sum += nums[right]                   # expand window
        while current_sum >= target:                 # window valid: try to shrink
            min_len = min(min_len, right - left + 1)
            current_sum -= nums[left]
            left += 1

    return min_len if min_len != float('inf') else 0


def longest_substring_k_distinct(s: str, k: int) -> int:
    """Longest substring with at most k distinct characters.
    Time: O(n)  Space: O(k)
    """
    from collections import defaultdict
    freq: dict[str, int] = defaultdict(int)
    left = 0
    best = 0

    for right, ch in enumerate(s):
        freq[ch] += 1
        while len(freq) > k:                         # window invalid: shrink
            left_ch = s[left]
            freq[left_ch] -= 1
            if freq[left_ch] == 0:
                del freq[left_ch]
            left += 1
        best = max(best, right - left + 1)

    return best
Window invariant checklist
Before coding, state the invariant explicitly: "the window [left, right] always satisfies X." If shrinking breaks the invariant, stop shrinking. If expanding breaks it, shrink until it holds again.
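Here is the checklist applied to a problem not covered above: longest substring without repeating characters. The function name is illustrative. The invariant is "s[left..right] contains each character at most once". Instead of shrinking one step at a time, the left pointer jumps just past the previous occurrence of the repeated character.

```python
def length_of_longest_unique_substring(s: str) -> int:
    """Longest substring with no repeated character.
    Invariant: s[left..right] contains each character at most once.
    Time: O(n)  Space: O(min(n, alphabet size))
    """
    last_seen: dict[str, int] = {}
    left = 0
    best = 0
    for right, ch in enumerate(s):
        # Adding ch breaks the invariant only if ch already occurs inside
        # the current window; if so, move left just past that occurrence.
        if ch in last_seen and last_seen[ch] >= left:
            left = last_seen[ch] + 1
        last_seen[ch] = right
        best = max(best, right - left + 1)
    return best


print(length_of_longest_unique_substring("abcabcbb"))  # 3
```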

Prefix Sums

def build_prefix(nums: list[int]) -> list[int]:
    """prefix[i] = sum of nums[0..i-1].  prefix[0] = 0 (sentinel).
    Time: O(n)  Space: O(n)
    """
    prefix = [0] * (len(nums) + 1)
    for i, x in enumerate(nums):
        prefix[i + 1] = prefix[i] + x
    return prefix


def range_sum(prefix: list[int], left: int, right: int) -> int:
    """Sum of nums[left..right] inclusive. O(1) per query."""
    return prefix[right + 1] - prefix[left]


def subarray_sum_equals_k(nums: list[int], k: int) -> int:
    """Count subarrays with sum exactly k.
    Key insight: sum(i..j) = prefix[j+1] - prefix[i] = k
                 => prefix[i] = prefix[j+1] - k
    Time: O(n)  Space: O(n)
    """
    from collections import defaultdict
    count = 0
    running = 0
    seen: dict[int, int] = defaultdict(int)
    seen[0] = 1  # empty prefix

    for x in nums:
        running += x
        count += seen[running - k]
        seen[running] += 1

    return count
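A quadratic brute force that applies the identity sum(i..j) = prefix[j+1] - prefix[i] directly is a useful cross-check for the hash-map version. The function name is illustrative. Unlike `min_subarray_len`, the prefix-sum approach does not require non-negative values, so the test cases below include negative numbers.

```python
def count_subarrays_brute(nums: list[int], k: int) -> int:
    """O(n^2) reference: check every (i, j) pair with the prefix identity
    sum(nums[i..j]) == prefix[j + 1] - prefix[i]."""
    prefix = [0]
    for x in nums:
        prefix.append(prefix[-1] + x)
    n = len(nums)
    return sum(
        1
        for i in range(n)
        for j in range(i, n)
        if prefix[j + 1] - prefix[i] == k
    )


print(count_subarrays_brute([3, 4, 7, 2, -3, 1, 4, 2], 7))  # 4
```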

Hash Map Patterns

Two Sum (unsorted)

def two_sum(nums: list[int], target: int) -> tuple[int, int]:
    """Time: O(n)  Space: O(n)"""
    seen: dict[int, int] = {}
    for i, x in enumerate(nums):
        complement = target - x
        if complement in seen:
            return seen[complement], i
        seen[x] = i
    return -1, -1

Frequency counting / grouping

from collections import defaultdict, Counter

def group_anagrams(words: list[str]) -> list[list[str]]:
    """Group words that are anagrams of each other.
    Time: O(n * m log m) where m = avg word length.  Space: O(n*m)
    """
    groups: dict[tuple, list[str]] = defaultdict(list)
    for word in words:
        key = tuple(sorted(word))   # canonical form
        groups[key].append(word)
    return list(groups.values())


def top_k_frequent(nums: list[int], k: int) -> list[int]:
    """Return k most frequent elements.
    Time: O(n log k)  Space: O(n)
    """
    import heapq
    freq = Counter(nums)
    return heapq.nlargest(k, freq, key=freq.get)

Kadane's Algorithm — Maximum Subarray

def max_subarray(nums: list[int]) -> int:
    """Maximum sum of any contiguous subarray (Kadane's algorithm).
    Time: O(n)  Space: O(1)

    Invariant: current is the max sum of any subarray ending at index i.
    If current is negative, extending it can only lower the sum, so restart at x.
    Assumes nums is non-empty.
    """
    best = nums[0]
    current = nums[0]
    for x in nums[1:]:
        current = max(x, current + x)   # extend or restart
        best = max(best, current)
    return best


def max_subarray_with_indices(nums: list[int]) -> tuple[int, int, int]:
    """Returns (max_sum, start, end) — inclusive indices."""
    best_sum = nums[0]
    best_start = best_end = 0
    current_sum = nums[0]
    current_start = 0

    for i in range(1, len(nums)):
        if nums[i] > current_sum + nums[i]:
            current_sum = nums[i]
            current_start = i
        else:
            current_sum += nums[i]
        if current_sum > best_sum:
            best_sum = current_sum
            best_start = current_start
            best_end = i

    return best_sum, best_start, best_end
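A frequently quoted form of Kadane's algorithm resets the running sum to 0 whenever it goes negative. The sketch below implements that form (the function name is illustrative). It agrees with `max_subarray` whenever at least one element is non-negative. When every element is negative, it returns 0 instead of the largest element, because resetting to 0 implicitly allows an empty subarray.

```python
def max_subarray_reset_to_zero(nums: list[int]) -> int:
    """Kadane variant that drops a negative running sum by resetting to 0.
    Equivalent to allowing the empty subarray (sum 0)."""
    best = 0
    current = 0
    for x in nums:
        current = max(0, current + x)   # discard a negative prefix entirely
        best = max(best, current)
    return best


print(max_subarray_reset_to_zero([-2, 1, -3, 4, -1, 2, 1, -5, 4]))  # 6
print(max_subarray_reset_to_zero([-1, -2, -3]))                     # 0, not -1
```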

Merge Intervals

def merge_intervals(intervals: list[list[int]]) -> list[list[int]]:
    """Merge all overlapping intervals.
    Time: O(n log n)  Space: O(n)
    """
    if not intervals:
        return []

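    intervals = sorted(intervals, key=lambda x: x[0])   # sort by start without mutating the caller's list
    merged = [list(intervals[0])]                       # copy so merging never edits input intervals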

    for start, end in intervals[1:]:
        if start <= merged[-1][1]:        # overlaps with last merged
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    return merged


def insert_interval(intervals: list[list[int]], new: list[int]) -> list[list[int]]:
    """Insert a new interval, merging as needed.  Assumes sorted input.
    Time: O(n)  Space: O(n)
    """
    result = []
    i, n = 0, len(intervals)

    # All intervals that end before new starts
    while i < n and intervals[i][1] < new[0]:
        result.append(intervals[i])
        i += 1

    # Merge overlapping intervals into new
    while i < n and intervals[i][0] <= new[1]:
        new[0] = min(new[0], intervals[i][0])
        new[1] = max(new[1], intervals[i][1])
        i += 1

    result.append(new)
    result.extend(intervals[i:])
    return result
Practice
Challenge

Two Sum

Given an array of integers and a target, return the indices of the two numbers that add up to the target. Use a hash map for O(n) time.

def two_sum(nums, target):
    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        # YOUR CODE HERE
        seen[num] = i
    return []

print(two_sum([2, 7, 11, 15], 9))   # [0, 1]
print(two_sum([3, 2, 4], 6))        # [1, 2]
import java.util.*;

public class Solution {
    public static int[] twoSum(int[] nums, int target) {
        Map<Integer, Integer> seen = new HashMap<>();
        for (int i = 0; i < nums.length; i++) {
            int complement = target - nums[i];
            // YOUR CODE HERE
            seen.put(nums[i], i);
        }
        return new int[]{};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(twoSum(new int[]{2, 7, 11, 15}, 9)));  // [0, 1]
        System.out.println(Arrays.toString(twoSum(new int[]{3, 2, 4}, 6)));       // [1, 2]
    }
}
For each number, compute complement = target - num. Check if complement is already in seen. If yes, return [seen[complement], i]. Otherwise store num → i in the map.
def two_sum(nums, target):
    seen = {}
    for i, num in enumerate(nums):
        complement = target - num
        if complement in seen:
            return [seen[complement], i]
        seen[num] = i
    return []

print(two_sum([2, 7, 11, 15], 9))   # [0, 1]
print(two_sum([3, 2, 4], 6))        # [1, 2]
import java.util.*;

public class Solution {
    public static int[] twoSum(int[] nums, int target) {
        Map<Integer, Integer> seen = new HashMap<>();
        for (int i = 0; i < nums.length; i++) {
            int complement = target - nums[i];
            if (seen.containsKey(complement)) {
                return new int[]{seen.get(complement), i};
            }
            seen.put(nums[i], i);
        }
        return new int[]{};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(twoSum(new int[]{2, 7, 11, 15}, 9)));  // [0, 1]
        System.out.println(Arrays.toString(twoSum(new int[]{3, 2, 4}, 6)));       // [1, 2]
    }
}
Challenge

Contains Duplicate

Return True if any value appears at least twice in the array, False if all elements are distinct. Aim for O(n) time using a set.

def contains_duplicate(nums):
    seen = set()
    for num in nums:
        # YOUR CODE HERE
        seen.add(num)
    return False

print(contains_duplicate([1, 2, 3, 1]))   # True
print(contains_duplicate([1, 2, 3, 4]))   # False
import java.util.*;

public class Solution {
    public static boolean containsDuplicate(int[] nums) {
        Set<Integer> seen = new HashSet<>();
        for (int num : nums) {
            // YOUR CODE HERE
            seen.add(num);
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(containsDuplicate(new int[]{1, 2, 3, 1}));  // true
        System.out.println(containsDuplicate(new int[]{1, 2, 3, 4}));  // false
    }
}
Before adding each number to the set, check if it is already present with if num in seen. If so, return True immediately.
def contains_duplicate(nums):
    seen = set()
    for num in nums:
        if num in seen:
            return True
        seen.add(num)
    return False

print(contains_duplicate([1, 2, 3, 1]))   # True
print(contains_duplicate([1, 2, 3, 4]))   # False
import java.util.*;

public class Solution {
    public static boolean containsDuplicate(int[] nums) {
        Set<Integer> seen = new HashSet<>();
        for (int num : nums) {
            if (seen.contains(num)) {
                return true;
            }
            seen.add(num);
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(containsDuplicate(new int[]{1, 2, 3, 1}));  // true
        System.out.println(containsDuplicate(new int[]{1, 2, 3, 4}));  // false
    }
}
Challenge

Maximum Subarray Sum

Find the contiguous subarray with the largest sum (Kadane's algorithm). The array may contain negative numbers.

def max_subarray(nums):
    current = nums[0]
    best = nums[0]
    for num in nums[1:]:
        # YOUR CODE HERE: update current and best
        pass
    return best

print(max_subarray([-2, 1, -3, 4, -1, 2, 1, -5, 4]))  # 6
print(max_subarray([1]))                                 # 1
print(max_subarray([-1, -2, -3]))                        # -1
public class Solution {
    public static int maxSubarray(int[] nums) {
        int current = nums[0];
        int best = nums[0];
        for (int i = 1; i < nums.length; i++) {
            // YOUR CODE HERE: update current and best
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxSubarray(new int[]{-2, 1, -3, 4, -1, 2, 1, -5, 4}));  // 6
        System.out.println(maxSubarray(new int[]{1}));                                // 1
        System.out.println(maxSubarray(new int[]{-1, -2, -3}));                      // -1
    }
}
At each step, decide: is it better to extend the current subarray, or start fresh from this element? Use current = max(num, current + num). Then update best = max(best, current).
def max_subarray(nums):
    current = nums[0]
    best = nums[0]
    for num in nums[1:]:
        current = max(num, current + num)
        best = max(best, current)
    return best

print(max_subarray([-2, 1, -3, 4, -1, 2, 1, -5, 4]))  # 6
print(max_subarray([1]))                                 # 1
print(max_subarray([-1, -2, -3]))                        # -1
public class Solution {
    public static int maxSubarray(int[] nums) {
        int current = nums[0];
        int best = nums[0];
        for (int i = 1; i < nums.length; i++) {
            current = Math.max(nums[i], current + nums[i]);
            best = Math.max(best, current);
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxSubarray(new int[]{-2, 1, -3, 4, -1, 2, 1, -5, 4}));  // 6
        System.out.println(maxSubarray(new int[]{1}));                                // 1
        System.out.println(maxSubarray(new int[]{-1, -2, -3}));                      // -1
    }
}

3. Strings

Palindrome Detection

def is_palindrome(s: str) -> bool:
    """O(n) time, O(1) space — two pointers."""
    left, right = 0, len(s) - 1
    while left < right:
        if s[left] != s[right]:
            return False
        left += 1
        right -= 1
    return True


def longest_palindrome(s: str) -> str:
    """Expand from center — handles both odd and even length palindromes.
    Time: O(n^2)  Space: O(1)
    """
    def expand(left: int, right: int) -> str:
        while left >= 0 and right < len(s) and s[left] == s[right]:
            left -= 1
            right += 1
        return s[left + 1:right]   # slice after over-expansion

    best = ""
    for i in range(len(s)):
        odd  = expand(i, i)        # center at single character
        even = expand(i, i + 1)    # center between two characters
        if len(odd) > len(best):
            best = odd
        if len(even) > len(best):
            best = even
    return best

Anagram Patterns

from collections import Counter

def is_anagram(s: str, t: str) -> bool:
    """Time: O(n)  Space: O(1) for a fixed alphabet (e.g. 26 lowercase letters)."""
    return Counter(s) == Counter(t)


def find_all_anagrams(s: str, pattern: str) -> list[int]:
    """Return start indices of all anagrams of pattern in s.
    Sliding window with frequency map.
    Time: O(n)  Space: O(1)
    """
    result = []
    p_count = Counter(pattern)
    w_count: Counter = Counter()
    k = len(pattern)

    for right in range(len(s)):
        w_count[s[right]] += 1

        # Drop character leaving the window
        if right >= k:
            left_ch = s[right - k]
            w_count[left_ch] -= 1
            if w_count[left_ch] == 0:
                del w_count[left_ch]

        if w_count == p_count:
            result.append(right - k + 1)

    return result

KMP Algorithm

Knuth-Morris-Pratt finds all occurrences of a pattern in a text in O(n + m) time by precomputing a failure function (also called LPS — longest proper prefix that is also a suffix). This lets the algorithm skip re-examining characters after a mismatch.

def build_lps(pattern: str) -> list[int]:
    """Build LPS (failure function) array.
    lps[i] = length of longest proper prefix of pattern[0..i]
             that is also a suffix.
    Time: O(m)  Space: O(m)
    """
    m = len(pattern)
    lps = [0] * m
    length = 0   # length of previous longest prefix-suffix
    i = 1

    while i < m:
        if pattern[i] == pattern[length]:
            length += 1
            lps[i] = length
            i += 1
        else:
            if length != 0:
                length = lps[length - 1]  # fall back (don't increment i)
            else:
                lps[i] = 0
                i += 1
    return lps


def kmp_search(text: str, pattern: str) -> list[int]:
    """Return all start indices where pattern occurs in text.
    Time: O(n + m)  Space: O(m)
    """
    n, m = len(text), len(pattern)
    if m == 0:
        return []

    lps = build_lps(pattern)
    result = []
    i = j = 0   # i = text index, j = pattern index

    while i < n:
        if text[i] == pattern[j]:
            i += 1
            j += 1
        if j == m:
            result.append(i - j)
            j = lps[j - 1]          # shift pattern using LPS
        elif i < n and text[i] != pattern[j]:
            if j != 0:
                j = lps[j - 1]      # skip already-matched prefix
            else:
                i += 1

    return result
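When debugging `build_lps`, it helps to compare against a slow reference that follows the definition literally. The sketch below (name illustrative) is O(m^3) and is meant only for testing on short patterns. For "ababaca" it gives [0, 0, 1, 2, 3, 0, 1].

```python
def lps_by_definition(pattern: str) -> list[int]:
    """Reference LPS straight from the definition: for each prefix
    pattern[0..i], try proper lengths from longest to shortest and keep
    the first one whose prefix equals its suffix."""
    lps = []
    for i in range(len(pattern)):
        sub = pattern[: i + 1]
        best = 0
        for length in range(len(sub) - 1, 0, -1):   # proper: length < len(sub)
            if sub[:length] == sub[-length:]:
                best = length
                break
        lps.append(best)
    return lps


print(lps_by_definition("ababaca"))  # [0, 0, 1, 2, 3, 0, 1]
```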

Common String Tricks

# 1. Reverse words in a sentence
def reverse_words(s: str) -> str:
    return ' '.join(s.split()[::-1])


# 2. Check if string is a rotation of another
#    "abcde" is a rotation of "cdeab"
def is_rotation(s1: str, s2: str) -> bool:
    if len(s1) != len(s2):
        return False
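    # Every rotation of s1 is a substring of s1 + s1. Searching with
    # kmp_search is O(n); the worst case of Python's built-in 'in'
    # depends on the interpreter version.
    return s2 in (s1 + s1)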


# 3. Longest common prefix
def longest_common_prefix(strs: list[str]) -> str:
    if not strs:
        return ""
    prefix = strs[0]
    for s in strs[1:]:
        while not s.startswith(prefix):
            prefix = prefix[:-1]
            if not prefix:
                return ""
    return prefix


# 4. Encode / decode strings (interview favorite)
def encode(strs: list[str]) -> str:
    # Format: "length#string" per word — handles all characters
    return ''.join(f"{len(s)}#{s}" for s in strs)

def decode(data: str) -> list[str]:
    result, i = [], 0
    while i < len(data):
        j = data.index('#', i)
        length = int(data[i:j])
        result.append(data[j + 1: j + 1 + length])
        i = j + 1 + length
    return result
Python string immutability
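Python strings are immutable, so each += on a string may copy everything built so far, which makes character-by-character building \(O(n^2)\) in the worst case. CPython can sometimes resize the string in place, but that is an implementation detail, not a guarantee. Accumulate pieces in a list and call ''.join() once at the end: \(O(n)\).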
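Here are the two styles side by side, building the same comma-separated string. The function names are illustrative. Both return identical output; they differ only in how much copying may happen along the way.

```python
def squares_csv(n: int) -> str:
    """Build "0,1,4,...,(n-1)^2" by collecting pieces and joining once."""
    pieces: list[str] = []
    for i in range(n):
        pieces.append(str(i * i))      # amortized O(1) list append
    return ",".join(pieces)            # one pass over the total length


def squares_csv_concat(n: int) -> str:
    """Same output with +=. Each += may copy the whole string so far,
    so the worst case is quadratic in the output length."""
    out = ""
    for i in range(n):
        if i:
            out += ","
        out += str(i * i)
    return out


print(squares_csv(5))  # 0,1,4,9,16
```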
Practice
Challenge

Valid Palindrome

Return True if a string is a palindrome, considering only alphanumeric characters and ignoring case. Use two pointers.

def is_palindrome(s):
    left, right = 0, len(s) - 1
    while left < right:
        while left < right and not s[left].isalnum():
            left += 1
        while left < right and not s[right].isalnum():
            right -= 1
        # YOUR CODE HERE: compare and move pointers
    return True

print(is_palindrome("A man, a plan, a canal: Panama"))  # True
print(is_palindrome("race a car"))                       # False
public class Solution {
    public static boolean isPalindrome(String s) {
        int left = 0, right = s.length() - 1;
        while (left < right) {
            while (left < right && !Character.isLetterOrDigit(s.charAt(left))) left++;
            while (left < right && !Character.isLetterOrDigit(s.charAt(right))) right--;
            // YOUR CODE HERE: compare and move pointers
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isPalindrome("A man, a plan, a canal: Panama"));  // true
        System.out.println(isPalindrome("race a car"));                       // false
    }
}
After skipping non-alphanumeric chars, compare s[left].lower() != s[right].lower(). If they differ, return False. Otherwise move both pointers inward: left += 1; right -= 1.
def is_palindrome(s):
    left, right = 0, len(s) - 1
    while left < right:
        while left < right and not s[left].isalnum():
            left += 1
        while left < right and not s[right].isalnum():
            right -= 1
        if s[left].lower() != s[right].lower():
            return False
        left += 1
        right -= 1
    return True

print(is_palindrome("A man, a plan, a canal: Panama"))  # True
print(is_palindrome("race a car"))                       # False
public class Solution {
    public static boolean isPalindrome(String s) {
        int left = 0, right = s.length() - 1;
        while (left < right) {
            while (left < right && !Character.isLetterOrDigit(s.charAt(left))) left++;
            while (left < right && !Character.isLetterOrDigit(s.charAt(right))) right--;
            if (Character.toLowerCase(s.charAt(left)) != Character.toLowerCase(s.charAt(right))) {
                return false;
            }
            left++;
            right--;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(isPalindrome("A man, a plan, a canal: Panama"));  // true
        System.out.println(isPalindrome("race a car"));                       // false
    }
}
Challenge

Reverse String In-Place

Reverse a list of characters in-place using two pointers. Do not allocate extra space for another array.

def reverse_string(s):
    left, right = 0, len(s) - 1
    while left < right:
        # YOUR CODE HERE: swap s[left] and s[right], then move pointers
        pass

chars = ['h', 'e', 'l', 'l', 'o']
reverse_string(chars)
print(chars)  # ['o', 'l', 'l', 'e', 'h']
import java.util.*;

public class Solution {
    public static void reverseString(char[] s) {
        int left = 0, right = s.length - 1;
        while (left < right) {
            // YOUR CODE HERE: swap s[left] and s[right], then move pointers
        }
    }

    public static void main(String[] args) {
        char[] chars = {'h', 'e', 'l', 'l', 'o'};
        reverseString(chars);
        System.out.println(Arrays.toString(chars));  // [o, l, l, e, h]
    }
}
Swap s[left] and s[right] using a temporary variable (or Python's tuple swap: s[left], s[right] = s[right], s[left]). Then increment left and decrement right.
def reverse_string(s):
    left, right = 0, len(s) - 1
    while left < right:
        s[left], s[right] = s[right], s[left]
        left += 1
        right -= 1

chars = ['h', 'e', 'l', 'l', 'o']
reverse_string(chars)
print(chars)  # ['o', 'l', 'l', 'e', 'h']
import java.util.*;

public class Solution {
    public static void reverseString(char[] s) {
        int left = 0, right = s.length - 1;
        while (left < right) {
            char temp = s[left];
            s[left] = s[right];
            s[right] = temp;
            left++;
            right--;
        }
    }

    public static void main(String[] args) {
        char[] chars = {'h', 'e', 'l', 'l', 'o'};
        reverseString(chars);
        System.out.println(Arrays.toString(chars));  // [o, l, l, e, h]
    }
}
Challenge

First Unique Character

Given a string, return the index of the first non-repeating character. Return -1 if none exists. Aim for O(n) using a frequency map.

def first_uniq_char(s):
    freq = {}
    for ch in s:
        freq[ch] = freq.get(ch, 0) + 1
    # YOUR CODE HERE: find index of first char with freq == 1
    return -1

print(first_uniq_char("leetcode"))  # 0
print(first_uniq_char("loveleet")) # 1
print(first_uniq_char("aabb"))     # -1
import java.util.*;

public class Solution {
    public static int firstUniqChar(String s) {
        Map<Character, Integer> freq = new HashMap<>();
        for (char ch : s.toCharArray()) {
            freq.put(ch, freq.getOrDefault(ch, 0) + 1);
        }
        // YOUR CODE HERE: find index of first char with freq == 1
        return -1;
    }

    public static void main(String[] args) {
        System.out.println(firstUniqChar("leetcode"));  // 0
        System.out.println(firstUniqChar("loveleet"));  // 1
        System.out.println(firstUniqChar("aabb"));      // -1
    }
}
After building the frequency map, do a second pass over the string with enumerate(s). Return the index i as soon as freq[s[i]] == 1.
def first_uniq_char(s):
    freq = {}
    for ch in s:
        freq[ch] = freq.get(ch, 0) + 1
    for i, ch in enumerate(s):
        if freq[ch] == 1:
            return i
    return -1

print(first_uniq_char("leetcode"))  # 0
print(first_uniq_char("loveleet")) # 1
print(first_uniq_char("aabb"))     # -1
import java.util.*;

public class Solution {
    public static int firstUniqChar(String s) {
        Map<Character, Integer> freq = new HashMap<>();
        for (char ch : s.toCharArray()) {
            freq.put(ch, freq.getOrDefault(ch, 0) + 1);
        }
        for (int i = 0; i < s.length(); i++) {
            if (freq.get(s.charAt(i)) == 1) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        System.out.println(firstUniqChar("leetcode"));  // 0
        System.out.println(firstUniqChar("loveleet"));  // 1
        System.out.println(firstUniqChar("aabb"));      // -1
    }
}

4. Linked Lists

class ListNode:
    def __init__(self, val: int = 0, next: 'ListNode | None' = None):
        self.val = val
        self.next = next

Fast / Slow Pointer (Floyd's)

def has_cycle(head: ListNode | None) -> bool:
    """Detect cycle. Time: O(n)  Space: O(1)"""
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            return True
    return False


def find_cycle_start(head: ListNode | None) -> ListNode | None:
    """Find the node where a cycle begins.
    Time: O(n)  Space: O(1)

    Insight: after detecting meeting point, reset one pointer to head.
    Move both one step at a time — they meet at cycle start.
    """
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            # Found meeting point
            slow = head
            while slow is not fast:
                slow = slow.next
                fast = fast.next
            return slow   # cycle start
    return None
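Why does resetting one pointer to head work? Let a be the distance from head to the cycle start, b the distance from the cycle start to the meeting point, and c the cycle length. Slow has travelled a + b (it is caught before finishing its first lap), and fast has travelled twice that, which is also a + b plus some whole number of laps: 2(a + b) = a + b + kc, so a = kc - b. Walking a more steps from the meeting point therefore lands exactly on the cycle start, which is also where the pointer restarted from head arrives after a steps. The check below exercises the function on a list whose tail links back to index 1. `build_with_cycle` is a hypothetical test helper, not part of the original.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def build_with_cycle(values, pos):
    """Hypothetical helper: build a list from values; if pos >= 0,
    link the tail back to the node at index pos to create a cycle."""
    dummy = ListNode()
    tail = dummy
    nodes = []
    for v in values:
        tail.next = ListNode(v)
        tail = tail.next
        nodes.append(tail)
    if pos >= 0 and nodes:
        tail.next = nodes[pos]
    return dummy.next

def find_cycle_start(head):
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            # Phase 2: restart slow from head; both now move 1 step at a time.
            slow = head
            while slow is not fast:
                slow = slow.next
                fast = fast.next
            return slow
    return None

print(find_cycle_start(build_with_cycle([3, 2, 0, -4], pos=1)).val)  # 2
print(find_cycle_start(build_with_cycle([1, 2, 3], pos=-1)))          # None
```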


def find_middle(head: ListNode | None) -> ListNode | None:
    """Find middle node (for even length: returns second middle).
    Time: O(n)  Space: O(1)
    """
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    return slow
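`find_middle` returns the second middle of an even-length list. Some callers want the first middle instead; a common case is merge sort on a linked list, where cutting after the second middle of a two-node list leaves the right half empty and the recursion never shrinks. The sketch below, assuming the same `ListNode` shape, changes only the loop condition so that fast must have two more nodes to step over. `find_middle_first` and `find_middle_second` are illustrative names.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode()
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def find_middle_second(head):
    """Even length: second middle (same loop as find_middle)."""
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    return slow

def find_middle_first(head):
    """Even length: first middle. Fast advances only if two nodes remain."""
    if head is None:
        return None
    slow = fast = head
    while fast.next and fast.next.next:
        slow = slow.next
        fast = fast.next.next
    return slow

for arr in ([1, 2, 3, 4, 5], [1, 2, 3, 4], [1, 2]):
    head = from_list(arr)
    print(arr, find_middle_first(head).val, find_middle_second(head).val)
# [1, 2, 3, 4, 5] 3 3
# [1, 2, 3, 4] 2 3
# [1, 2] 1 2
```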

Reverse Linked List

def reverse_iterative(head: ListNode | None) -> ListNode | None:
    """Time: O(n)  Space: O(1)"""
    prev = None
    curr = head
    while curr:
        next_node = curr.next   # save next before overwriting
        curr.next = prev        # reverse the link
        prev = curr             # advance prev
        curr = next_node        # advance curr
    return prev                 # prev is the new head


def reverse_recursive(head: ListNode | None) -> ListNode | None:
    """Time: O(n)  Space: O(n) — call stack depth."""
    if head is None or head.next is None:
        return head
    new_head = reverse_recursive(head.next)
    head.next.next = head   # reverse the link
    head.next = None        # break the old link, else head and head.next form a 2-node cycle
    return new_head


def reverse_between(head: ListNode | None, left: int, right: int) -> ListNode | None:
    """Reverse sublist from position left to right (1-indexed).
    Time: O(n)  Space: O(1)
    """
    dummy = ListNode(0, head)
    prev = dummy

    # Move prev to node just before 'left'
    for _ in range(left - 1):
        prev = prev.next

    curr = prev.next
    for _ in range(right - left):
        next_node = curr.next
        curr.next = next_node.next
        next_node.next = prev.next
        prev.next = next_node

    return dummy.next
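The loop in `reverse_between` is head insertion: `curr` (the original node at position `left`) stays put and drifts toward the end of the segment, while each node after it is detached and reinserted right after `prev`. The traced copy below records the list after every move. `reverse_between_traced` and its snapshot list are illustrative additions, not part of the original.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode()
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def to_list(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def reverse_between_traced(head, left, right):
    dummy = ListNode(0, head)
    prev = dummy
    for _ in range(left - 1):
        prev = prev.next
    curr = prev.next                # ends up as the tail of the reversed segment
    snapshots = []
    for _ in range(right - left):
        next_node = curr.next
        curr.next = next_node.next  # detach next_node
        next_node.next = prev.next  # reinsert it at the front of the segment
        prev.next = next_node
        snapshots.append(to_list(dummy.next))
    return dummy.next, snapshots

result, steps = reverse_between_traced(from_list([1, 2, 3, 4, 5]), 2, 4)
for s in steps:
    print(s)
# [1, 3, 2, 4, 5]
# [1, 4, 3, 2, 5]
print(to_list(result))  # [1, 4, 3, 2, 5]
```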

Merge Two Sorted Lists

def merge_two_sorted(l1: ListNode | None, l2: ListNode | None) -> ListNode | None:
    """Time: O(n + m)  Space: O(1)
    Dummy head avoids edge cases when the result head is uncertain.
    """
    dummy = ListNode(0)
    curr = dummy

    while l1 and l2:
        if l1.val <= l2.val:
            curr.next = l1
            l1 = l1.next
        else:
            curr.next = l2
            l2 = l2.next
        curr = curr.next

    curr.next = l1 or l2   # attach remaining
    return dummy.next

Dummy Head Trick & Other Patterns

def remove_nth_from_end(head: ListNode | None, n: int) -> ListNode | None:
    """Remove the nth node from the end in one pass.
    Time: O(L) where L is the list length  Space: O(1)
    """
    dummy = ListNode(0, head)
    fast = slow = dummy

    # Advance fast n+1 steps: n nodes now sit strictly between slow and fast
    for _ in range(n + 1):
        fast = fast.next

    # Move both until fast falls off the end; slow then precedes the target
    while fast:
        fast = fast.next
        slow = slow.next

    slow.next = slow.next.next   # unlink nth from end
    return dummy.next
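The dummy node earns its keep when n equals the list length: the node to delete is the head, and `slow` must stop on something in front of it. The self-contained check below re-declares the helpers and the function, assuming 1 <= n <= length as the original does implicitly (a larger n would dereference `None`).

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode()
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def to_list(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def remove_nth_from_end(head, n):
    """Assumes 1 <= n <= length of the list."""
    dummy = ListNode(0, head)
    fast = slow = dummy
    for _ in range(n + 1):
        fast = fast.next
    while fast:
        fast = fast.next
        slow = slow.next
    slow.next = slow.next.next   # slow is the dummy when the head is removed
    return dummy.next

print(to_list(remove_nth_from_end(from_list([1, 2, 3, 4, 5]), 2)))  # [1, 2, 3, 5]
print(to_list(remove_nth_from_end(from_list([1, 2]), 2)))           # [2]  (head removed)
print(to_list(remove_nth_from_end(from_list([1]), 1)))              # []   (only node removed)
```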


def reorder_list(head: ListNode | None) -> None:
    """Reorder L0 -> L1 -> ... -> Ln into L0 -> Ln -> L1 -> Ln-1 -> ...
    In-place. Time: O(n)  Space: O(1)
    Three steps: find middle, reverse second half, merge.
    """
    if not head:
        return

    # Step 1: find middle
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next

    # Step 2: reverse second half
    second = slow.next
    slow.next = None    # cut the list
    prev = None
    while second:
        tmp = second.next
        second.next = prev
        prev = second
        second = tmp

    # Step 3: interleave
    first, second = head, prev
    while second:
        tmp1, tmp2 = first.next, second.next
        first.next = second
        second.next = tmp1
        first, second = tmp1, tmp2
Linked list interview checklist
  • Always use a dummy head when the result head might change
  • Draw the pointer diagram before coding — one missed .next breaks the whole list
  • Check: empty list, single node, two nodes, cycle at tail vs. mid
  • Fast/slow pointer: fast moves 2 steps, slow moves 1; if a cycle exists they meet within O(n) steps, otherwise fast reaches the end first
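The checklist's edge cases (empty list, single node, two nodes) are cheap to automate. A minimal sketch, assuming the `reorder_list` shown earlier: rebuild lists of length 0 through 7 and compare against a plain-array reference. `expected_order` is a hypothetical helper that interleaves the array from both ends.

```python
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode()
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def to_list(head):
    out = []
    while head:
        out.append(head.val)
        head = head.next
    return out

def reorder_list(head):
    if not head:
        return
    slow = fast = head                 # step 1: find middle
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    second, slow.next, prev = slow.next, None, None
    while second:                      # step 2: reverse second half
        tmp = second.next
        second.next = prev
        prev, second = second, tmp
    first, second = head, prev         # step 3: interleave
    while second:
        tmp1, tmp2 = first.next, second.next
        first.next = second
        second.next = tmp1
        first, second = tmp1, tmp2

def expected_order(arr):
    """Hypothetical reference: L0, Ln, L1, Ln-1, ... built on a plain list."""
    out, i, j = [], 0, len(arr) - 1
    while i <= j:
        out.append(arr[i])
        if i != j:
            out.append(arr[j])
        i, j = i + 1, j - 1
    return out

for n in range(8):                     # covers empty, single node, two nodes
    arr = list(range(1, n + 1))
    head = from_list(arr)
    reorder_list(head)
    assert to_list(head) == expected_order(arr), n
print("reorder_list passes lengths 0 through 7")
```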
Practice
Challenge

Reverse Linked List

Reverse a singly linked list iteratively. Fill in the body of reverse_list.

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def reverse_list(head):
    prev = None
    curr = head
    while curr:
        # YOUR CODE HERE: save next, point curr.next to prev, advance
        pass
    return prev

print(to_list(reverse_list(from_list([1, 2, 3, 4, 5]))))  # [5, 4, 3, 2, 1]
print(to_list(reverse_list(from_list([1, 2]))))            # [2, 1]
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    static String toList(ListNode head) {
        StringBuilder sb = new StringBuilder("[");
        while (head != null) {
            sb.append(head.val);
            if (head.next != null) sb.append(", ");
            head = head.next;
        }
        return sb.append("]").toString();
    }

    public static ListNode reverseList(ListNode head) {
        ListNode prev = null;
        ListNode curr = head;
        while (curr != null) {
            // YOUR CODE HERE: save next, point curr.next to prev, advance
        }
        return prev;
    }

    public static void main(String[] args) {
        System.out.println(toList(reverseList(fromList(new int[]{1, 2, 3, 4, 5}))));  // [5, 4, 3, 2, 1]
        System.out.println(toList(reverseList(fromList(new int[]{1, 2}))));            // [2, 1]
    }
}
Save next_node = curr.next, then set curr.next = prev to reverse the pointer. Advance: prev = curr, curr = next_node. Repeat until curr is None.
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def reverse_list(head):
    prev = None
    curr = head
    while curr:
        next_node = curr.next
        curr.next = prev
        prev = curr
        curr = next_node
    return prev

print(to_list(reverse_list(from_list([1, 2, 3, 4, 5]))))  # [5, 4, 3, 2, 1]
print(to_list(reverse_list(from_list([1, 2]))))            # [2, 1]
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    static String toList(ListNode head) {
        StringBuilder sb = new StringBuilder("[");
        while (head != null) {
            sb.append(head.val);
            if (head.next != null) sb.append(", ");
            head = head.next;
        }
        return sb.append("]").toString();
    }

    public static ListNode reverseList(ListNode head) {
        ListNode prev = null;
        ListNode curr = head;
        while (curr != null) {
            ListNode nextNode = curr.next;
            curr.next = prev;
            prev = curr;
            curr = nextNode;
        }
        return prev;
    }

    public static void main(String[] args) {
        System.out.println(toList(reverseList(fromList(new int[]{1, 2, 3, 4, 5}))));  // [5, 4, 3, 2, 1]
        System.out.println(toList(reverseList(fromList(new int[]{1, 2}))));            // [2, 1]
    }
}
Challenge

Middle of Linked List

Return the middle node of a linked list. For even-length lists, return the second middle node. Use the fast/slow pointer technique.

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def middle(head):
    slow = head
    fast = head
    while fast and fast.next:
        # YOUR CODE HERE: advance slow by 1, fast by 2
        pass
    return slow

print(middle(from_list([1, 2, 3, 4, 5])).val)  # 3
print(middle(from_list([1, 2, 3, 4])).val)      # 3 (second middle)
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    public static ListNode middle(ListNode head) {
        ListNode slow = head;
        ListNode fast = head;
        while (fast != null && fast.next != null) {
            // YOUR CODE HERE: advance slow by 1, fast by 2
        }
        return slow;
    }

    public static void main(String[] args) {
        System.out.println(middle(fromList(new int[]{1, 2, 3, 4, 5})).val);  // 3
        System.out.println(middle(fromList(new int[]{1, 2, 3, 4})).val);     // 3
    }
}
Move slow = slow.next and fast = fast.next.next each iteration. When fast or fast.next becomes None, slow is at the middle (the second middle for an even-length list).
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def middle(head):
    slow = head
    fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
    return slow

print(middle(from_list([1, 2, 3, 4, 5])).val)  # 3
print(middle(from_list([1, 2, 3, 4])).val)      # 3 (second middle)
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    public static ListNode middle(ListNode head) {
        ListNode slow = head;
        ListNode fast = head;
        while (fast != null && fast.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        return slow;
    }

    public static void main(String[] args) {
        System.out.println(middle(fromList(new int[]{1, 2, 3, 4, 5})).val);  // 3
        System.out.println(middle(fromList(new int[]{1, 2, 3, 4})).val);     // 3
    }
}
Challenge

Merge Two Sorted Lists

Merge two sorted linked lists and return the sorted result. Use a dummy head node to simplify the logic.

class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def merge(l1, l2):
    dummy = ListNode(0)
    curr = dummy
    while l1 and l2:
        # YOUR CODE HERE: attach smaller node, advance that pointer
        pass
    # YOUR CODE HERE: attach remaining nodes
    return dummy.next

print(to_list(merge(from_list([1, 3, 5]), from_list([2, 4, 6]))))  # [1, 2, 3, 4, 5, 6]
print(to_list(merge(from_list([1, 2]), from_list([3, 4]))))         # [1, 2, 3, 4]
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    static String toList(ListNode head) {
        StringBuilder sb = new StringBuilder("[");
        while (head != null) {
            sb.append(head.val);
            if (head.next != null) sb.append(", ");
            head = head.next;
        }
        return sb.append("]").toString();
    }

    public static ListNode merge(ListNode l1, ListNode l2) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        while (l1 != null && l2 != null) {
            // YOUR CODE HERE: attach smaller node, advance that pointer
        }
        // YOUR CODE HERE: attach remaining nodes
        return dummy.next;
    }

    public static void main(String[] args) {
        System.out.println(toList(merge(fromList(new int[]{1, 3, 5}), fromList(new int[]{2, 4, 6}))));  // [1, 2, 3, 4, 5, 6]
        System.out.println(toList(merge(fromList(new int[]{1, 2}), fromList(new int[]{3, 4}))));         // [1, 2, 3, 4]
    }
}
Compare l1.val and l2.val. Attach the smaller one to curr.next and advance that list's pointer. After the loop, attach whichever list still has nodes with curr.next = l1 or l2.
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next

def to_list(head):
    result = []
    while head:
        result.append(head.val)
        head = head.next
    return result

def from_list(arr):
    dummy = ListNode(0)
    curr = dummy
    for v in arr:
        curr.next = ListNode(v)
        curr = curr.next
    return dummy.next

def merge(l1, l2):
    dummy = ListNode(0)
    curr = dummy
    while l1 and l2:
        if l1.val <= l2.val:
            curr.next = l1
            l1 = l1.next
        else:
            curr.next = l2
            l2 = l2.next
        curr = curr.next
    curr.next = l1 if l1 else l2
    return dummy.next

print(to_list(merge(from_list([1, 3, 5]), from_list([2, 4, 6]))))  # [1, 2, 3, 4, 5, 6]
print(to_list(merge(from_list([1, 2]), from_list([3, 4]))))         # [1, 2, 3, 4]
public class Solution {
    static class ListNode {
        int val;
        ListNode next;
        ListNode(int val) { this.val = val; }
    }

    static ListNode fromList(int[] arr) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        for (int v : arr) { curr.next = new ListNode(v); curr = curr.next; }
        return dummy.next;
    }

    static String toList(ListNode head) {
        StringBuilder sb = new StringBuilder("[");
        while (head != null) {
            sb.append(head.val);
            if (head.next != null) sb.append(", ");
            head = head.next;
        }
        return sb.append("]").toString();
    }

    public static ListNode merge(ListNode l1, ListNode l2) {
        ListNode dummy = new ListNode(0);
        ListNode curr = dummy;
        while (l1 != null && l2 != null) {
            if (l1.val <= l2.val) {
                curr.next = l1;
                l1 = l1.next;
            } else {
                curr.next = l2;
                l2 = l2.next;
            }
            curr = curr.next;
        }
        curr.next = (l1 != null) ? l1 : l2;
        return dummy.next;
    }

    public static void main(String[] args) {
        System.out.println(toList(merge(fromList(new int[]{1, 3, 5}), fromList(new int[]{2, 4, 6}))));  // [1, 2, 3, 4, 5, 6]
        System.out.println(toList(merge(fromList(new int[]{1, 2}), fromList(new int[]{3, 4}))));         // [1, 2, 3, 4]
    }
}

5. Stacks & Queues

Monotonic Stack

A monotonic stack keeps its elements in increasing (or decreasing) order: before each push, it pops every element that would break that order. Each element is pushed and popped at most once, so a full pass is O(n) despite the nested loop. The classic application is Next Greater Element: for each element, find the first element to its right that is strictly larger.

def next_greater_element(nums: list[int]) -> list[int]:
    """For each nums[i], find the next greater element to its right.
    Returns -1 if none exists.
    Time: O(n)  Space: O(n)

    Invariant: the stack holds indices of elements whose next-greater
    has NOT yet been found, in non-increasing order of value (equal values
    stay, because only a strictly larger element resolves them).
    """
    n = len(nums)
    result = [-1] * n
    stack: list[int] = []   # indices, values non-increasing bottom to top

    for i in range(n):
        # Current element is the "next greater" for everything it beats
        while stack and nums[i] > nums[stack[-1]]:
            idx = stack.pop()
            result[idx] = nums[i]
        stack.append(i)

    return result
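To see the invariant at work, the traced copy below records the values on the stack after each push for nums = [2, 1, 2, 4, 3]. Note the [2, 2] step: equal values are not popped, because the comparison is strictly greater. `next_greater_traced` is an illustrative helper, not part of the original.

```python
def next_greater_traced(nums):
    result = [-1] * len(nums)
    stack = []   # indices whose next-greater element is still unknown
    trace = []   # stack contents (as values) after each push
    for i, x in enumerate(nums):
        while stack and x > nums[stack[-1]]:
            result[stack.pop()] = x
        stack.append(i)
        trace.append([nums[j] for j in stack])
    return result, trace

res, trace = next_greater_traced([2, 1, 2, 4, 3])
for step in trace:
    print(step)
# [2]
# [2, 1]
# [2, 2]
# [4]
# [4, 3]
print(res)  # [4, 2, 4, -1, -1]
```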


def largest_rectangle_histogram(heights: list[int]) -> int:
    """Largest rectangle in histogram.
    Time: O(n)  Space: O(n)

    Stack holds indices of bars in non-decreasing height order.
    When a shorter bar arrives at index i, each popped bar sets the height
    of a rectangle spanning from just after the new stack top to i - 1.
    """
    stack: list[int] = []
    max_area = 0
    heights = heights + [0]   # sentinel forces flush at end (copy; input untouched)

    for i, h in enumerate(heights):
        while stack and heights[stack[-1]] > h:
            idx = stack.pop()
            width = i - (stack[-1] + 1 if stack else 0)
            max_area = max(max_area, heights[idx] * width)
        stack.append(i)

    return max_area
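The width formula is easier to trust once you see which bars it covers. The sketch below is the same stack pass, but it also reports the inclusive bar range of the best rectangle. `largest_rectangle_with_bounds` and its `(area, left, right)` return shape are illustrative choices; `(0, -1, -1)` is an assumed convention for "no positive area".

```python
def largest_rectangle_with_bounds(heights):
    """Return (area, left, right): the best rectangle spans bars left..right."""
    stack = []                 # indices, heights non-decreasing bottom to top
    best = (0, -1, -1)
    extended = heights + [0]   # sentinel flushes the stack at the end
    for i, h in enumerate(extended):
        while stack and extended[stack[-1]] > h:
            idx = stack.pop()
            # Every bar strictly between the new stack top and i is at least
            # as tall as bar idx, so height extended[idx] spans [left, i - 1].
            left = stack[-1] + 1 if stack else 0
            area = extended[idx] * (i - left)
            if area > best[0]:
                best = (area, left, i - 1)
        stack.append(i)
    return best

print(largest_rectangle_with_bounds([2, 1, 5, 6, 2, 3]))  # (10, 2, 3): bars 5 and 6 at height 5
print(largest_rectangle_with_bounds([1, 1]))              # (2, 0, 1)
```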


def daily_temperatures(temps: list[int]) -> list[int]:
    """Days until a warmer temperature; 0 if none.
    Classic monotonic stack application.
    Time: O(n)  Space: O(n)
    """
    result = [0] * len(temps)
    stack: list[int] = []   # indices, non-increasing temperatures (equal ones stay)

    for i, t in enumerate(temps):
        while stack and t > temps[stack[-1]]:
            idx = stack.pop()
            result[idx] = i - idx
        stack.append(i)

    return result
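Monotonic-stack code is easy to get subtly wrong (for example `>=` versus `>` on ties), so a randomized cross-check against an O(n^2) brute force is a cheap safety net. A sketch under these assumptions: a fixed seed for reproducibility, and a narrow temperature range (30 to 40) so that ties occur often. `daily_temperatures_brute` is a hypothetical reference implementation.

```python
import random

def daily_temperatures(temps):
    result = [0] * len(temps)
    stack = []   # indices still waiting for a warmer day
    for i, t in enumerate(temps):
        while stack and t > temps[stack[-1]]:
            idx = stack.pop()
            result[idx] = i - idx
        stack.append(i)
    return result

def daily_temperatures_brute(temps):
    """O(n^2) reference: scan forward from each day."""
    n = len(temps)
    result = [0] * n
    for i in range(n):
        for j in range(i + 1, n):
            if temps[j] > temps[i]:
                result[i] = j - i
                break
    return result

print(daily_temperatures([73, 74, 75, 71, 69, 72, 76, 73]))
# [1, 1, 4, 2, 1, 1, 0, 0]

rng = random.Random(0)   # fixed seed so the check is reproducible
for _ in range(500):
    temps = [rng.randint(30, 40) for _ in range(rng.randint(0, 12))]
    assert daily_temperatures(temps) == daily_temperatures_brute(temps), temps
print("stack version matches brute force")
```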

Valid Parentheses

def is_valid(s: str) -> bool:
    """Check balanced brackets: ()[]{}
    Time: O(n)  Space: O(n)
    """
    stack: list[str] = []
    matching = {')': '(', ']': '[', '}': '{'}

    for ch in s:
        if ch in matching:                      # closing bracket
            if not stack or stack[-1] != matching[ch]:
                return False
            stack.pop()
        else:                                   # opening bracket
            stack.append(ch)

    return len(stack) == 0


def min_remove_for_valid(s: str) -> str:
    """Remove minimum parentheses to make string valid.
    Time: O(n)  Space: O(n)
    """
    stack: list[int] = []   # indices of unmatched '('
    to_remove: set[int] = set()

    for i, ch in enumerate(s):
        if ch == '(':
            stack.append(i)
        elif ch == ')':
            if stack:
                stack.pop()
            else:
                to_remove.add(i)   # unmatched ')'

    to_remove.update(stack)        # unmatched '('
    return ''.join(ch for i, ch in enumerate(s) if i not in to_remove)
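The "minimum" claim holds because every unmatched bracket must go, and the function removes exactly those. The sketch below checks it exhaustively on every string of length up to 6 over the alphabet `(`, `)`, `a`. Because only one bracket type is involved, a depth counter suffices for validity. `is_valid_parens` and `min_removals_brute` are hypothetical test helpers.

```python
from itertools import combinations, product

def min_remove_for_valid(s):
    stack, to_remove = [], set()
    for i, ch in enumerate(s):
        if ch == '(':
            stack.append(i)
        elif ch == ')':
            if stack:
                stack.pop()
            else:
                to_remove.add(i)   # unmatched ')'
    to_remove.update(stack)        # unmatched '('
    return ''.join(ch for i, ch in enumerate(s) if i not in to_remove)

def is_valid_parens(s):
    """True if parentheses balance; other characters are ignored."""
    depth = 0
    for ch in s:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth < 0:
                return False
    return depth == 0

def min_removals_brute(s):
    """Smallest number of deletions that leaves a valid string (exhaustive)."""
    for k in range(len(s) + 1):    # k = len(s) leaves "", which is valid
        for drop in combinations(range(len(s)), k):
            if is_valid_parens(''.join(ch for i, ch in enumerate(s) if i not in drop)):
                return k

print(min_remove_for_valid("lee(t(c)o)de)"))  # lee(t(c)o)de
print(min_remove_for_valid("a)b(c)d"))        # ab(c)d
print(repr(min_remove_for_valid("))((")))     # ''

for n in range(7):
    for chars in product('()a', repeat=n):
        s = ''.join(chars)
        out = min_remove_for_valid(s)
        assert is_valid_parens(out) and len(s) - len(out) == min_removals_brute(s), s
print("greedy removal is minimal for all strings up to length 6")
```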

Sliding Window Maximum (Monotonic Deque)

from collections import deque

def sliding_window_maximum(nums: list[int], k: int) -> list[int]:
    """Maximum of each window of size k.
    Time: O(n)  Space: O(k)

    Deque stores indices in non-increasing order of value.
    Front is always the index of the current window maximum.
    """
    result: list[int] = []
    dq: deque[int] = deque()   # indices, values non-increasing front->back

    for i, x in enumerate(nums):
        # Remove indices outside the window
        while dq and dq[0] < i - k + 1:
            dq.popleft()

        # Maintain decreasing invariant — pop smaller elements from back
        while dq and nums[dq[-1]] < x:
            dq.pop()

        dq.append(i)

        if i >= k - 1:   # window is fully formed
            result.append(nums[dq[0]])

    return result
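A trace makes both deque operations visible: expiry from the front when an index leaves the window, and eviction from the back when a larger value arrives (a smaller, older value can never be a window maximum again). The traced copy below records the deque's values after each step for the common example `[1, 3, -1, -3, 5, 3, 6, 7]` with k = 3. `sliding_window_maximum_traced` is an illustrative helper.

```python
from collections import deque

def sliding_window_maximum_traced(nums, k):
    result, trace = [], []
    dq = deque()   # indices; their values are non-increasing front to back
    for i, x in enumerate(nums):
        while dq and dq[0] < i - k + 1:   # front index has left the window
            dq.popleft()
        while dq and nums[dq[-1]] < x:    # smaller values can never be a max again
            dq.pop()
        dq.append(i)
        trace.append([nums[j] for j in dq])
        if i >= k - 1:
            result.append(nums[dq[0]])
    return result, trace

res, trace = sliding_window_maximum_traced([1, 3, -1, -3, 5, 3, 6, 7], 3)
print(trace)  # [[1], [3], [3, -1], [3, -1, -3], [5], [5, 3], [6], [7]]
print(res)    # [3, 3, 5, 5, 6, 7]
```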

Min Stack

class MinStack:
    """Stack that supports O(1) getMin().
    Strategy: pair each pushed value with the minimum at that point.
    Time: O(1) all ops  Space: O(n)
    """
    def __init__(self):
        self._stack: list[tuple[int, int]] = []  # (value, current_min)

    def push(self, val: int) -> None:
        current_min = min(val, self._stack[-1][1]) if self._stack else val
        self._stack.append((val, current_min))

    def pop(self) -> None:
        self._stack.pop()

    def top(self) -> int:
        return self._stack[-1][0]

    def getMin(self) -> int:
        return self._stack[-1][1]
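Repeated copies of the minimum are a classic pitfall for designs that remember only a single "current min" value: popping one copy must not lose the other. Because each entry here stores its own snapshot of the minimum, duplicates need no extra logic. The usage sketch below re-declares the class so it runs on its own.

```python
class MinStack:
    def __init__(self):
        self._stack = []   # (value, min of this value and everything below it)

    def push(self, val):
        current_min = min(val, self._stack[-1][1]) if self._stack else val
        self._stack.append((val, current_min))

    def pop(self):
        self._stack.pop()

    def top(self):
        return self._stack[-1][0]

    def getMin(self):
        return self._stack[-1][1]

ms = MinStack()
for v in (3, 1, 1, 2):
    ms.push(v)
print(ms.getMin())  # 1
ms.pop()            # removes 2
ms.pop()            # removes one copy of 1
print(ms.getMin())  # 1  (the other copy is still on the stack)
ms.pop()            # removes the last 1
print(ms.getMin())  # 3
```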
Practice
Challenge

Valid Parentheses

Given a string containing only ()[]{}, return True if the brackets are correctly matched and nested.

def is_valid(s):
    stack = []
    matching = {')': '(', ']': '[', '}': '{'}
    for ch in s:
        if ch in '([{':
            stack.append(ch)
        else:
            # YOUR CODE HERE: check stack and matching closer
            pass
    return len(stack) == 0

print(is_valid("({[]})"))  # True
print(is_valid("()[]{}"))  # True
print(is_valid("({[)"))    # False
import java.util.*;

public class Solution {
    public static boolean isValid(String s) {
        Deque<Character> stack = new ArrayDeque<>();
        for (char ch : s.toCharArray()) {
            if (ch == '(' || ch == '[' || ch == '{') {
                stack.push(ch);
            } else {
                // YOUR CODE HERE: check stack and matching closer
            }
        }
        return stack.isEmpty();
    }

    public static void main(String[] args) {
        System.out.println(isValid("({[]})"));  // true
        System.out.println(isValid("()[]{}"));  // true
        System.out.println(isValid("({[)"));    // false
    }
}
For a closing bracket: if the stack is empty or stack[-1] != matching[ch], return False. Otherwise stack.pop(). At the end, the stack must be empty.
def is_valid(s):
    stack = []
    matching = {')': '(', ']': '[', '}': '{'}
    for ch in s:
        if ch in '([{':
            stack.append(ch)
        else:
            if not stack or stack[-1] != matching[ch]:
                return False
            stack.pop()
    return len(stack) == 0

print(is_valid("({[]})"))  # True
print(is_valid("()[]{}"))  # True
print(is_valid("({[)"))    # False
import java.util.*;

public class Solution {
    public static boolean isValid(String s) {
        Deque<Character> stack = new ArrayDeque<>();
        for (char ch : s.toCharArray()) {
            if (ch == '(' || ch == '[' || ch == '{') {
                stack.push(ch);
            } else {
                if (stack.isEmpty()) return false;
                char top = stack.pop();
                if (ch == ')' && top != '(') return false;
                if (ch == ']' && top != '[') return false;
                if (ch == '}' && top != '{') return false;
            }
        }
        return stack.isEmpty();
    }

    public static void main(String[] args) {
        System.out.println(isValid("({[]})"));  // true
        System.out.println(isValid("()[]{}"));  // true
        System.out.println(isValid("({[)"));    // false
    }
}
Challenge

Min Stack

Implement a stack that supports push, pop, and get_min all in O(1) time. Track the running minimum with a second stack.

class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []

    def push(self, val):
        self.stack.append(val)
        # YOUR CODE HERE: push to min_stack

    def pop(self):
        self.stack.pop()
        # YOUR CODE HERE: pop from min_stack

    def get_min(self):
        # YOUR CODE HERE: return current minimum
        pass

ms = MinStack()
ms.push(5)
ms.push(2)
ms.push(7)
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 5
import java.util.*;

public class Solution {
    static class MinStack {
        private Deque<Integer> stack = new ArrayDeque<>();
        private Deque<Integer> minStack = new ArrayDeque<>();

        public void push(int val) {
            stack.push(val);
            // YOUR CODE HERE: push to minStack
        }

        public void pop() {
            stack.pop();
            // YOUR CODE HERE: pop from minStack
        }

        public int getMin() {
            // YOUR CODE HERE: return current minimum
            return 0;
        }
    }

    public static void main(String[] args) {
        MinStack ms = new MinStack();
        ms.push(5);
        ms.push(2);
        ms.push(7);
        System.out.println(ms.getMin());  // 2
        ms.pop();
        System.out.println(ms.getMin());  // 2
        ms.pop();
        System.out.println(ms.getMin());  // 5
    }
}
When pushing, append min(val, min_stack[-1]) to min_stack (or just val if min_stack is empty). When popping, pop from both stacks. get_min returns min_stack[-1].
class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []

    def push(self, val):
        self.stack.append(val)
        current_min = val if not self.min_stack else min(val, self.min_stack[-1])
        self.min_stack.append(current_min)

    def pop(self):
        self.stack.pop()
        self.min_stack.pop()

    def get_min(self):
        return self.min_stack[-1]

ms = MinStack()
ms.push(5)
ms.push(2)
ms.push(7)
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 2
ms.pop()
print(ms.get_min())  # 5
import java.util.*;

public class Solution {
    static class MinStack {
        private Deque<Integer> stack = new ArrayDeque<>();
        private Deque<Integer> minStack = new ArrayDeque<>();

        public void push(int val) {
            stack.push(val);
            int currentMin = minStack.isEmpty() ? val : Math.min(val, minStack.peek());
            minStack.push(currentMin);
        }

        public void pop() {
            stack.pop();
            minStack.pop();
        }

        public int getMin() {
            return minStack.peek();
        }
    }

    public static void main(String[] args) {
        MinStack ms = new MinStack();
        ms.push(5);
        ms.push(2);
        ms.push(7);
        System.out.println(ms.getMin());  // 2
        ms.pop();
        System.out.println(ms.getMin());  // 2
        ms.pop();
        System.out.println(ms.getMin());  // 5
    }
}
Challenge

Next Greater Element

For each element in the array, find the next greater element to its right. If none exists, use -1. Use a monotonic stack for O(n) time.

def next_greater(nums):
    result = [-1] * len(nums)
    stack = []  # stores indices
    for i, num in enumerate(nums):
        # YOUR CODE HERE: while stack and nums[stack[-1]] < num, resolve stack top
        stack.append(i)
    return result

print(next_greater([4, 1, 2, 7]))   # [7, 2, 7, -1]
print(next_greater([1, 3, 2, 4]))   # [3, 4, 4, -1]
import java.util.*;

public class Solution {
    public static int[] nextGreater(int[] nums) {
        int[] result = new int[nums.length];
        Arrays.fill(result, -1);
        Deque<Integer> stack = new ArrayDeque<>();  // stores indices
        for (int i = 0; i < nums.length; i++) {
            // YOUR CODE HERE: while stack not empty and nums[stack.peek()] < nums[i], resolve
            stack.push(i);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(nextGreater(new int[]{4, 1, 2, 7})));  // [7, 2, 7, -1]
        System.out.println(Arrays.toString(nextGreater(new int[]{1, 3, 2, 4})));  // [3, 4, 4, -1]
    }
}
The stack holds indices of elements waiting for their next greater. When you reach element i, pop all stack indices where nums[index] < nums[i] — the answer for those positions is nums[i]. Then push i.
def next_greater(nums):
    result = [-1] * len(nums)
    stack = []  # stores indices
    for i, num in enumerate(nums):
        while stack and nums[stack[-1]] < num:
            idx = stack.pop()
            result[idx] = num
        stack.append(i)
    return result

print(next_greater([4, 1, 2, 7]))   # [7, 2, 7, -1]
print(next_greater([1, 3, 2, 4]))   # [3, 4, 4, -1]
import java.util.*;

public class Solution {
    public static int[] nextGreater(int[] nums) {
        int[] result = new int[nums.length];
        Arrays.fill(result, -1);
        Deque<Integer> stack = new ArrayDeque<>();  // stores indices
        for (int i = 0; i < nums.length; i++) {
            while (!stack.isEmpty() && nums[stack.peek()] < nums[i]) {
                result[stack.pop()] = nums[i];
            }
            stack.push(i);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(nextGreater(new int[]{4, 1, 2, 7})));  // [7, 2, 7, -1]
        System.out.println(Arrays.toString(nextGreater(new int[]{1, 3, 2, 4})));  // [3, 4, 4, -1]
    }
}

6. Binary Trees

class TreeNode:
    def __init__(self, val: int = 0,
                 left: 'TreeNode | None' = None,
                 right: 'TreeNode | None' = None):
        self.val = val
        self.left = left
        self.right = right

Traversals — Iterative & Recursive

from collections import deque

# ── Recursive (clean but uses O(h) call stack) ──────────────────────────────

def inorder_recursive(root: TreeNode | None) -> list[int]:
    """Left → Root → Right. Gives sorted order for BST."""
    result: list[int] = []
    def dfs(node: TreeNode | None) -> None:
        if not node:
            return
        dfs(node.left)
        result.append(node.val)
        dfs(node.right)
    dfs(root)
    return result


def preorder_recursive(root: TreeNode | None) -> list[int]:
    """Root → Left → Right. Useful for tree copy / serialization."""
    result: list[int] = []
    def dfs(node: TreeNode | None) -> None:
        if not node:
            return
        result.append(node.val)
        dfs(node.left)
        dfs(node.right)
    dfs(root)
    return result


def postorder_recursive(root: TreeNode | None) -> list[int]:
    """Left → Right → Root. Useful for deletion / bottom-up aggregation."""
    result: list[int] = []
    def dfs(node: TreeNode | None) -> None:
        if not node:
            return
        dfs(node.left)
        dfs(node.right)
        result.append(node.val)
    dfs(root)
    return result


# ── Iterative (avoids stack overflow on deep trees) ─────────────────────────

def inorder_iterative(root: TreeNode | None) -> list[int]:
    result: list[int] = []
    stack: list[TreeNode] = []
    curr = root
    while curr or stack:
        while curr:               # go as far left as possible
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()        # process node
        result.append(curr.val)
        curr = curr.right         # move to right subtree
    return result


def preorder_iterative(root: TreeNode | None) -> list[int]:
    if not root:
        return []
    result: list[int] = []
    stack = [root]
    while stack:
        node = stack.pop()
        result.append(node.val)
        if node.right:            # push right first so left is processed first
            stack.append(node.right)
        if node.left:
            stack.append(node.left)
    return result


def postorder_iterative(root: TreeNode | None) -> list[int]:
    """Trick: reverse of modified preorder (Root→Right→Left reversed = Left→Right→Root)."""
    if not root:
        return []
    result: list[int] = []
    stack = [root]
    while stack:
        node = stack.pop()
        result.append(node.val)
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    return result[::-1]
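You can check the reversal trick on a five-node tree. The Root→Right→Left pass, read backwards, matches a plain recursive postorder. `root_right_left` is a hypothetical name for the loop in `postorder_iterative` before the final reversal.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def postorder_recursive(root):
    out = []

    def dfs(node):
        if node:
            dfs(node.left)
            dfs(node.right)
            out.append(node.val)

    dfs(root)
    return out


def root_right_left(root):
    """Preorder variant visiting Root, then Right, then Left."""
    if not root:
        return []
    out, stack = [], [root]
    while stack:
        node = stack.pop()
        out.append(node.val)
        if node.left:                # pushed first, so popped after the right child
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    return out


#        1
#       / \
#      2   3
#     / \
#    4   5
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
rrl = root_right_left(root)
print(rrl)                        # [1, 3, 2, 5, 4]
print(rrl[::-1])                  # [4, 5, 2, 3, 1]
print(postorder_recursive(root))  # [4, 5, 2, 3, 1]
```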


# ── BFS level-order ──────────────────────────────────────────────────────────

def level_order(root: TreeNode | None) -> list[list[int]]:
    """Returns list of levels. Time: O(n)  Space: O(w) where w = max width."""
    if not root:
        return []
    result: list[list[int]] = []
    queue: deque[TreeNode] = deque([root])
    while queue:
        level_size = len(queue)
        level: list[int] = []
        for _ in range(level_size):
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        result.append(level)
    return result
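Here is a usage sketch of the level-by-level loop on the sample tree from the practice problems. `range(len(queue))` is evaluated once per outer iteration, so children appended during the inner loop wait for the next level. The per-level maximum at the end only illustrates what the grouped output makes easy.

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def level_order(root):
    if not root:
        return []
    result, queue = [], deque([root])
    while queue:
        level = []
        for _ in range(len(queue)):   # snapshot: only nodes already queued belong to this level
            node = queue.popleft()
            level.append(node.val)
            if node.left:
                queue.append(node.left)
            if node.right:
                queue.append(node.right)
        result.append(level)
    return result


root = TreeNode(3, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
levels = level_order(root)
print(levels)                            # [[3], [9, 20], [15, 7]]
print([max(level) for level in levels])  # [3, 20, 15]
```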

DFS Patterns: Top-Down vs. Bottom-Up

# ── Top-Down: pass information from parent to children ──────────────────────

def max_depth_top_down(root: TreeNode | None) -> int:
    """Classic top-down: carry current depth, update global max."""
    result = [0]
    def dfs(node: TreeNode | None, depth: int) -> None:
        if not node:
            return
        result[0] = max(result[0], depth)
        dfs(node.left, depth + 1)
        dfs(node.right, depth + 1)
    dfs(root, 1)
    return result[0]


def has_path_sum(root: TreeNode | None, target: int) -> bool:
    """Does any root-to-leaf path sum to target?"""
    def dfs(node: TreeNode | None, remaining: int) -> bool:
        if not node:
            return False
        remaining -= node.val
        if not node.left and not node.right:   # leaf
            return remaining == 0
        return dfs(node.left, remaining) or dfs(node.right, remaining)
    return dfs(root, target)
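The leaf check matters: a path that reaches the target at an internal node does not count. This small sketch exercises both outcomes; the tree values are chosen only for illustration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def has_path_sum(root, target):
    def dfs(node, remaining):
        if not node:
            return False
        remaining -= node.val
        if not node.left and not node.right:   # only leaves can end a path
            return remaining == 0
        return dfs(node.left, remaining) or dfs(node.right, remaining)
    return dfs(root, target)


#      5
#     / \
#    4   8
#   /
#  11
root = TreeNode(5, TreeNode(4, TreeNode(11)), TreeNode(8))
print(has_path_sum(root, 20))  # True: 5 -> 4 -> 11
print(has_path_sum(root, 13))  # True: 5 -> 8
print(has_path_sum(root, 9))   # False: 5 -> 4 ends at an internal node
print(has_path_sum(TreeNode(1, TreeNode(2)), 1))  # False: the root alone is not a leaf
```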


# ── Bottom-Up: aggregate from children, return to parent ────────────────────

def max_depth_bottom_up(root: TreeNode | None) -> int:
    """Bottom-up is usually cleaner for aggregation problems."""
    if not root:
        return 0
    return 1 + max(max_depth_bottom_up(root.left),
                   max_depth_bottom_up(root.right))


def diameter(root: TreeNode | None) -> int:
    """Length, in edges, of the longest path between any two nodes.
    The path need not pass through the root.
    Time: O(n)  Space: O(h)

    Key insight: the longest path that bends at a node has
    left_depth + right_depth edges, where depth counts nodes.
    """
    best = [0]
    def depth(node: TreeNode | None) -> int:
        if not node:
            return 0
        left  = depth(node.left)
        right = depth(node.right)
        best[0] = max(best[0], left + right)   # update diameter
        return 1 + max(left, right)            # return depth to parent
    depth(root)
    return best[0]
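This sketch covers the case the docstring warns about, where the longest path bends below the root. It uses `nonlocal` instead of the one-element list; both are common ways to update an outer variable from a nested function.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diameter(root):
    best = 0

    def depth(node):
        nonlocal best
        if not node:
            return 0
        left = depth(node.left)
        right = depth(node.right)
        best = max(best, left + right)   # edges on the longest path bending here
        return 1 + max(left, right)      # nodes on the longest downward path

    depth(root)
    return best


#         1
#        /
#       2
#      / \
#     3   5
#    /     \
#   4       6
root = TreeNode(1, TreeNode(2, TreeNode(3, TreeNode(4)), TreeNode(5, None, TreeNode(6))))
print(diameter(root))  # 4: path 4-3-2-5-6 skips the root
```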


def max_path_sum(root: TreeNode | None) -> int:
    """Maximum sum over all non-empty paths (any start node, any end node).
    Values can be negative.
    Time: O(n)  Space: O(h)
    """
    best = [float('-inf')]
    def gain(node: TreeNode | None) -> int:
        if not node:
            return 0
        left  = max(gain(node.left),  0)   # ignore negative branches
        right = max(gain(node.right), 0)
        best[0] = max(best[0], node.val + left + right)
        return node.val + max(left, right)  # can only extend in one direction
    gain(root)
    return best[0]
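Two cases exercise the negative-value handling. The first is a tree whose best path avoids the negative root. The second is a single negative node, which works only because `best` starts at negative infinity rather than 0. The algorithm is the same, written with `nonlocal`.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root):
    best = float('-inf')

    def gain(node):
        nonlocal best
        if not node:
            return 0
        left = max(gain(node.left), 0)    # drop branches that would lower the sum
        right = max(gain(node.right), 0)
        best = max(best, node.val + left + right)
        return node.val + max(left, right)

    gain(root)
    return best


root = TreeNode(-10, TreeNode(9), TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_path_sum(root))          # 42: 15 -> 20 -> 7
print(max_path_sum(TreeNode(-3)))  # -3
```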

Lowest Common Ancestor (LCA)

def lca(root: TreeNode | None,
        p: TreeNode, q: TreeNode) -> TreeNode | None:
    """LCA for general binary tree (not necessarily BST).
    Time: O(n)  Space: O(h)

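    Assumes both p and q are present in the tree.
    Logic: if the current node is p or q, return it. A node counts as its
    own ancestor, so it is the LCA whenever the other target lies below it.
    Otherwise the LCA is the node where both the left and the right
    recursive calls return non-None.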
    """
    if not root or root is p or root is q:
        return root
    left  = lca(root.left,  p, q)
    right = lca(root.right, p, q)
    if left and right:
        return root    # p and q are in different subtrees
    return left or right
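This sketch runs the function on a nine-node tree with arbitrary values. The target nodes are kept in variables because `lca` compares nodes by identity with `is`, not by value.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    if not root or root is p or root is q:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left and right:
        return root        # p and q were found in different subtrees
    return left or right


#          3
#        /   \
#       5     1
#      / \   / \
#     6   2 0   8
#        / \
#       7   4
n7, n4 = TreeNode(7), TreeNode(4)
n2 = TreeNode(2, n7, n4)
n5 = TreeNode(5, TreeNode(6), n2)
n1 = TreeNode(1, TreeNode(0), TreeNode(8))
root = TreeNode(3, n5, n1)
print(lca(root, n5, n1).val)  # 3: targets in different subtrees
print(lca(root, n5, n4).val)  # 5: a node is its own ancestor
print(lca(root, n7, n4).val)  # 2
```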
Practice
Challenge

Maximum Depth of Binary Tree

Return the maximum depth of a binary tree. Depth is the number of nodes along the longest path from root to leaf. Fill in the recursive body.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(root):
    if root is None:
        return ___
    left  = max_depth(root.left)
    right = max_depth(root.right)
    return ___ + max(left, right)

root = TreeNode(3,
    TreeNode(9),
    TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth(root))  # 3
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static int maxDepth(TreeNode root) {
        if (root == null) return ___;
        int left  = maxDepth(root.left);
        int right = maxDepth(root.right);
        return ___ + Math.max(left, right);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(3,
            new TreeNode(9),
            new TreeNode(20, new TreeNode(15), new TreeNode(7)));
        System.out.println(maxDepth(root));  // 3
    }
}
Base case: None (or null) returns 0. For any other node return 1 + max(leftDepth, rightDepth).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def max_depth(root):
    if root is None:
        return 0
    left  = max_depth(root.left)
    right = max_depth(root.right)
    return 1 + max(left, right)

root = TreeNode(3,
    TreeNode(9),
    TreeNode(20, TreeNode(15), TreeNode(7)))
print(max_depth(root))  # 3
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static int maxDepth(TreeNode root) {
        if (root == null) return 0;
        int left  = maxDepth(root.left);
        int right = maxDepth(root.right);
        return 1 + Math.max(left, right);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(3,
            new TreeNode(9),
            new TreeNode(20, new TreeNode(15), new TreeNode(7)));
        System.out.println(maxDepth(root));  // 3
    }
}
Challenge

Invert Binary Tree

Mirror a binary tree by swapping left and right children at every node. Fill in the swap and the recursive calls.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def invert(root):
    if root is None:
        return None
    root.left, root.right = ___, ___
    return root

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result

root = TreeNode(3,
    TreeNode(9),
    TreeNode(20, TreeNode(15), TreeNode(7)))
invert(root)
print(inorder(root))  # [7, 20, 15, 3, 9]
import java.util.*;

public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static TreeNode invert(TreeNode root) {
        if (root == null) return null;
        TreeNode tmp = root.left;
        root.left  = ___;
        root.right = ___;
        invert(root.left);
        invert(root.right);
        return root;
    }

    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(3,
            new TreeNode(9),
            new TreeNode(20, new TreeNode(15), new TreeNode(7)));
        invert(root);
        System.out.println(inorder(root));  // [7, 20, 15, 3, 9]
    }
}
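Recurse and swap in one tuple assignment: root.left, root.right = invert(root.right), invert(root.left). Python evaluates both calls on the right-hand side before assigning, so each inverted subtree lands on the opposite side.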
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def invert(root):
    if root is None:
        return None
    root.left, root.right = invert(root.right), invert(root.left)
    return root

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result

root = TreeNode(3,
    TreeNode(9),
    TreeNode(20, TreeNode(15), TreeNode(7)))
invert(root)
print(inorder(root))  # [7, 20, 15, 3, 9]
import java.util.*;

public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static TreeNode invert(TreeNode root) {
        if (root == null) return null;
        TreeNode tmp  = root.left;
        root.left     = invert(root.right);
        root.right    = invert(tmp);
        return root;
    }

    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(3,
            new TreeNode(9),
            new TreeNode(20, new TreeNode(15), new TreeNode(7)));
        invert(root);
        System.out.println(inorder(root));  // [7, 20, 15, 3, 9]
    }
}
Challenge

Same Tree

Check whether two binary trees are structurally identical with the same node values. Fill in the three conditions.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_same(p, q):
    if p is None and q is None:
        return ___
    if p is None or q is None:
        return ___
    if p.val != q.val:
        return ___
    return is_same(p.left, q.left) and is_same(p.right, q.right)

t1 = TreeNode(1, TreeNode(2), TreeNode(3))
t2 = TreeNode(1, TreeNode(2), TreeNode(3))
t3 = TreeNode(1, TreeNode(2), TreeNode(4))
print(is_same(t1, t2))  # True
print(is_same(t1, t3))  # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static boolean isSame(TreeNode p, TreeNode q) {
        if (p == null && q == null) return ___;
        if (p == null || q == null) return ___;
        if (p.val != q.val)         return ___;
        return isSame(p.left, q.left) && isSame(p.right, q.right);
    }

    public static void main(String[] args) {
        TreeNode t1 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        TreeNode t2 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        TreeNode t3 = new TreeNode(1, new TreeNode(2), new TreeNode(4));
        System.out.println(isSame(t1, t2));  // true
        System.out.println(isSame(t1, t3));  // false
    }
}
Both null → True. Exactly one null → False. Values differ → False. Otherwise recurse both subtrees.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_same(p, q):
    if p is None and q is None:
        return True
    if p is None or q is None:
        return False
    if p.val != q.val:
        return False
    return is_same(p.left, q.left) and is_same(p.right, q.right)

t1 = TreeNode(1, TreeNode(2), TreeNode(3))
t2 = TreeNode(1, TreeNode(2), TreeNode(3))
t3 = TreeNode(1, TreeNode(2), TreeNode(4))
print(is_same(t1, t2))  # True
print(is_same(t1, t3))  # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static boolean isSame(TreeNode p, TreeNode q) {
        if (p == null && q == null) return true;
        if (p == null || q == null) return false;
        if (p.val != q.val)         return false;
        return isSame(p.left, q.left) && isSame(p.right, q.right);
    }

    public static void main(String[] args) {
        TreeNode t1 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        TreeNode t2 = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        TreeNode t3 = new TreeNode(1, new TreeNode(2), new TreeNode(4));
        System.out.println(isSame(t1, t2));  // true
        System.out.println(isSame(t1, t3));  // false
    }
}

7. Binary Search Trees

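A BST satisfies: for every node, all left descendants < node.val < all right descendants (this section assumes distinct keys). Search, insert, and delete therefore follow a single root-to-leaf path and take O(h) time, where h is the tree height. That is O(log n) when balanced and O(n) when degenerate, for example after inserting keys in sorted order.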

BST Operations

def search_bst(root: TreeNode | None, val: int) -> TreeNode | None:
    """Time: O(h)  Space: O(1) iterative, O(h) recursive."""
    while root:
        if val == root.val:
            return root
        root = root.left if val < root.val else root.right
    return None


def insert_bst(root: TreeNode | None, val: int) -> TreeNode:
    """Time: O(h)  Space: O(h) recursive."""
    if not root:
        return TreeNode(val)
    if val < root.val:
        root.left = insert_bst(root.left, val)
    elif val > root.val:
        root.right = insert_bst(root.right, val)
    return root   # val == root.val: duplicate, no insert


def delete_bst(root: TreeNode | None, val: int) -> TreeNode | None:
    """Time: O(h)  Space: O(h)
    Three cases:
      1. Node has no children: just remove it.
      2. Node has one child: replace with child.
      3. Node has two children: replace val with inorder successor
         (smallest in right subtree), then delete successor.
    """
    if not root:
        return None
    if val < root.val:
        root.left = delete_bst(root.left, val)
    elif val > root.val:
        root.right = delete_bst(root.right, val)
    else:
        if not root.left:
            return root.right
        if not root.right:
            return root.left
        # Find inorder successor (min of right subtree)
        successor = root.right
        while successor.left:
            successor = successor.left
        root.val = successor.val
        root.right = delete_bst(root.right, successor.val)
    return root
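This sketch exercises all three deletion cases on a small tree. It assumes distinct keys, as `insert_bst` enforces. The `inorder` helper is added only to check that sorted order survives each deletion.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def insert_bst(root, val):
    if not root:
        return TreeNode(val)
    if val < root.val:
        root.left = insert_bst(root.left, val)
    elif val > root.val:
        root.right = insert_bst(root.right, val)
    return root


def delete_bst(root, val):
    if not root:
        return None
    if val < root.val:
        root.left = delete_bst(root.left, val)
    elif val > root.val:
        root.right = delete_bst(root.right, val)
    else:
        if not root.left:            # cases 1 and 2: zero or one child
            return root.right
        if not root.right:
            return root.left
        successor = root.right       # case 3: copy the inorder successor up
        while successor.left:
            successor = successor.left
        root.val = successor.val
        root.right = delete_bst(root.right, successor.val)
    return root


def inorder(node):
    if not node:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


root = None
for v in [50, 30, 70, 20, 40, 60, 80]:
    root = insert_bst(root, v)

root = delete_bst(root, 50)     # two children: successor 60 replaces 50
print(root.val, inorder(root))  # 60 [20, 30, 40, 60, 70, 80]
root = delete_bst(root, 20)     # leaf
root = delete_bst(root, 30)     # one child left (40), which takes its place
print(inorder(root))            # [40, 60, 70, 80]
```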

BST Validation

def is_valid_bst(root: TreeNode | None) -> bool:
    """Validate using min/max bounds passed down the tree.
    Time: O(n)  Space: O(h)

    Common mistake: checking only parent-child pairs, which misses the
    bounds inherited from higher ancestors.
    """
    def validate(node: TreeNode | None,
                 low: float, high: float) -> bool:
        if not node:
            return True
        if node.val <= low or node.val >= high:
            return False
        return (validate(node.left,  low,      node.val) and
                validate(node.right, node.val, high))
    return validate(root, float('-inf'), float('inf'))
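The common mistake is easy to reproduce. `naive_is_valid` below is a deliberately wrong, hypothetical checker that compares each node only with its direct children. The example tree puts a 3 inside 5's right subtree; the bounds-based version catches it and the naive one does not.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def naive_is_valid(node):
    """Deliberately wrong: compares each node only with its direct children."""
    if not node:
        return True
    if node.left and node.left.val >= node.val:
        return False
    if node.right and node.right.val <= node.val:
        return False
    return naive_is_valid(node.left) and naive_is_valid(node.right)


def is_valid_bst(root):
    def validate(node, low, high):
        if not node:
            return True
        if node.val <= low or node.val >= high:
            return False
        return validate(node.left, low, node.val) and validate(node.right, node.val, high)
    return validate(root, float('-inf'), float('inf'))


#     5
#    / \
#   4   6
#      / \
#     3   7     <- 3 sits in 5's right subtree but 3 < 5
root = TreeNode(5, TreeNode(4), TreeNode(6, TreeNode(3), TreeNode(7)))
print(naive_is_valid(root))  # True (wrong)
print(is_valid_bst(root))    # False
```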


def kth_smallest_bst(root: TreeNode | None, k: int) -> int:
    """Inorder traversal gives sorted order; return kth element.
    Time: O(h + k)  Space: O(h)
    """
    stack: list[TreeNode] = []
    curr = root
    count = 0
    while curr or stack:
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        count += 1
        if count == k:
            return curr.val
        curr = curr.right
    return -1   # k out of range

Balanced BST from Sorted Array

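def sorted_array_to_bst(nums: list[int]) -> TreeNode | None:
    """Build height-balanced BST from sorted array.
    Choose middle element as root at each step.
    Time: O(n log n) as written, because every recursion level copies
    O(n) elements into slices; passing lo/hi indices instead gives O(n).
    Space: O(n) for the live slices plus O(log n) recursion stack
    (O(log n) total with indices), excluding the output tree.
    """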
    if not nums:
        return None
    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    root.left  = sorted_array_to_bst(nums[:mid])
    root.right = sorted_array_to_bst(nums[mid + 1:])
    return root
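Slicing copies elements at every level of the recursion. A variant that passes inclusive index bounds avoids those copies. `(lo + hi) // 2` picks the lower middle on even-length ranges, while `len(nums) // 2` picks the upper middle; either choice keeps the tree height-balanced. The `inorder` and `height` helpers are only for checking the result.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def sorted_array_to_bst(nums):
    """Index-bounded variant: no slices are copied."""
    def build(lo, hi):              # inclusive bounds into nums
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        node = TreeNode(nums[mid])
        node.left = build(lo, mid - 1)
        node.right = build(mid + 1, hi)
        return node
    return build(0, len(nums) - 1)


def inorder(node):
    if not node:
        return []
    return inorder(node.left) + [node.val] + inorder(node.right)


def height(node):
    if not node:
        return 0
    return 1 + max(height(node.left), height(node.right))


root = sorted_array_to_bst(list(range(1, 16)))
print(inorder(root) == list(range(1, 16)))  # True
print(height(root))                         # 4: 15 nodes fill a perfect tree
```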
Operation   Balanced BST   Unbalanced BST   Sorted Array
Search      O(log n)       O(n)             O(log n)
Insert      O(log n)       O(n)             O(n)
Delete      O(log n)       O(n)             O(n)
Min/Max     O(log n)       O(n)             O(1)
In-order    O(n)           O(n)             O(n)
Practice
Challenge

Validate BST

Return True if the tree is a valid BST. Every node must satisfy: all values in its left subtree are strictly less, all values in its right subtree are strictly greater. Fill in the bounds checks.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_valid_bst(root, lo=float('-inf'), hi=float('inf')):
    if root is None:
        return True
    if root.val <= lo or root.val >= hi:
        return ___
    return (is_valid_bst(root.left,  ___, root.val) and
            is_valid_bst(root.right, root.val, ___))

valid   = TreeNode(2, TreeNode(1), TreeNode(3))
invalid = TreeNode(1, TreeNode(2), TreeNode(3))
print(is_valid_bst(valid))    # True
print(is_valid_bst(invalid))  # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static boolean validate(TreeNode node, long lo, long hi) {
        if (node == null) return true;
        if (node.val <= lo || node.val >= hi) return ___;
        return validate(node.left,  ___, node.val) &&
               validate(node.right, node.val, ___);
    }

    static boolean isValidBST(TreeNode root) {
        return validate(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }

    public static void main(String[] args) {
        TreeNode valid   = new TreeNode(2, new TreeNode(1), new TreeNode(3));
        TreeNode invalid = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        System.out.println(isValidBST(valid));    // true
        System.out.println(isValidBST(invalid));  // false
    }
}
Thread a lo and hi bound down the recursion. When going left, the current node's value becomes the new upper bound (hi). When going right, it becomes the lower bound (lo).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def is_valid_bst(root, lo=float('-inf'), hi=float('inf')):
    if root is None:
        return True
    if root.val <= lo or root.val >= hi:
        return False
    return (is_valid_bst(root.left,  lo,       root.val) and
            is_valid_bst(root.right, root.val, hi))

valid   = TreeNode(2, TreeNode(1), TreeNode(3))
invalid = TreeNode(1, TreeNode(2), TreeNode(3))
print(is_valid_bst(valid))    # True
print(is_valid_bst(invalid))  # False
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static boolean validate(TreeNode node, long lo, long hi) {
        if (node == null) return true;
        if (node.val <= lo || node.val >= hi) return false;
        return validate(node.left,  lo,       node.val) &&
               validate(node.right, node.val, hi);
    }

    static boolean isValidBST(TreeNode root) {
        return validate(root, Long.MIN_VALUE, Long.MAX_VALUE);
    }

    public static void main(String[] args) {
        TreeNode valid   = new TreeNode(2, new TreeNode(1), new TreeNode(3));
        TreeNode invalid = new TreeNode(1, new TreeNode(2), new TreeNode(3));
        System.out.println(isValidBST(valid));    // true
        System.out.println(isValidBST(invalid));  // false
    }
}
Challenge

Search in BST

Search the BST for a target value. Return the node's value if found, or -1 if not. Fill in the directional recursive calls.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def search_bst(root, target):
    if root is None:
        return -1
    if target == root.val:
        return root.val
    if target < root.val:
        return search_bst(___, target)
    return search_bst(___, target)

root = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(7))
print(search_bst(root, 2))  # 2
print(search_bst(root, 5))  # -1
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static int searchBST(TreeNode root, int target) {
        if (root == null) return -1;
        if (target == root.val) return root.val;
        if (target < root.val) return searchBST(___, target);
        return searchBST(___, target);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(4,
            new TreeNode(2, new TreeNode(1), new TreeNode(3)),
            new TreeNode(7));
        System.out.println(searchBST(root, 2));  // 2
        System.out.println(searchBST(root, 5));  // -1
    }
}
If target is less than the current node's value, go left (root.left). If greater, go right (root.right). The BST property guarantees only one path to check.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def search_bst(root, target):
    if root is None:
        return -1
    if target == root.val:
        return root.val
    if target < root.val:
        return search_bst(root.left, target)
    return search_bst(root.right, target)

root = TreeNode(4,
    TreeNode(2, TreeNode(1), TreeNode(3)),
    TreeNode(7))
print(search_bst(root, 2))  # 2
print(search_bst(root, 5))  # -1
public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
        TreeNode(int v, TreeNode l, TreeNode r) { val = v; left = l; right = r; }
    }

    static int searchBST(TreeNode root, int target) {
        if (root == null) return -1;
        if (target == root.val) return root.val;
        if (target < root.val) return searchBST(root.left, target);
        return searchBST(root.right, target);
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(4,
            new TreeNode(2, new TreeNode(1), new TreeNode(3)),
            new TreeNode(7));
        System.out.println(searchBST(root, 2));  // 2
        System.out.println(searchBST(root, 5));  // -1
    }
}
Challenge

Sorted Array to Height-Balanced BST

Convert a sorted array into a height-balanced BST. An inorder traversal of the result should reproduce the original array. Fill in the mid-point selection and recursive calls.

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def sorted_array_to_bst(nums):
    if not nums:
        return None
    mid = len(nums) ___
    node = TreeNode(nums[mid])
    node.left  = sorted_array_to_bst(nums[___])
    node.right = sorted_array_to_bst(nums[___])
    return node

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result

root = sorted_array_to_bst([-10, -3, 0, 5, 9])
print(inorder(root))  # [-10, -3, 0, 5, 9]
import java.util.*;

public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
    }

    static TreeNode sortedArrayToBST(int[] nums, int lo, int hi) {
        if (lo > hi) return null;
        int mid = lo + (hi - lo) ___;
        TreeNode node = new TreeNode(nums[mid]);
        node.left  = sortedArrayToBST(nums, ___, mid - 1);
        node.right = sortedArrayToBST(nums, mid + 1, ___);
        return node;
    }

    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }

    public static void main(String[] args) {
        int[] nums = {-10, -3, 0, 5, 9};
        TreeNode root = sortedArrayToBST(nums, 0, nums.length - 1);
        System.out.println(inorder(root));  // [-10, -3, 0, 5, 9]
    }
}
Pick the middle index (len(nums) // 2) as the root. Recurse on nums[:mid] for the left subtree and nums[mid+1:] for the right. This keeps the tree balanced.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def sorted_array_to_bst(nums):
    if not nums:
        return None
    mid = len(nums) // 2
    node = TreeNode(nums[mid])
    node.left  = sorted_array_to_bst(nums[:mid])
    node.right = sorted_array_to_bst(nums[mid+1:])
    return node

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result

root = sorted_array_to_bst([-10, -3, 0, 5, 9])
print(inorder(root))  # [-10, -3, 0, 5, 9]
import java.util.*;

public class Solution {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int v) { val = v; }
    }

    static TreeNode sortedArrayToBST(int[] nums, int lo, int hi) {
        if (lo > hi) return null;
        int mid = lo + (hi - lo) / 2;
        TreeNode node = new TreeNode(nums[mid]);
        node.left  = sortedArrayToBST(nums, lo,      mid - 1);
        node.right = sortedArrayToBST(nums, mid + 1, hi);
        return node;
    }

    static List<Integer> inorder(TreeNode node) {
        List<Integer> res = new ArrayList<>();
        inorderHelper(node, res);
        return res;
    }
    static void inorderHelper(TreeNode node, List<Integer> res) {
        if (node == null) return;
        inorderHelper(node.left, res);
        res.add(node.val);
        inorderHelper(node.right, res);
    }

    public static void main(String[] args) {
        int[] nums = {-10, -3, 0, 5, 9};
        TreeNode root = sortedArrayToBST(nums, 0, nums.length - 1);
        System.out.println(inorder(root));  // [-10, -3, 0, 5, 9]
    }
}

8. Heaps / Priority Queues

A min-heap guarantees the smallest element is always at the top (index 0). Python's heapq is a min-heap. For a max-heap, negate values on push/pop.

import heapq

# Min-heap basics
heap: list[int] = []
heapq.heappush(heap, 5)
heapq.heappush(heap, 1)
heapq.heappush(heap, 3)
print(heap[0])            # 1 — peek min, O(1)
smallest = heapq.heappop(heap)   # 1, O(log n)

# Build from list in O(n)
nums = [3, 1, 4, 1, 5, 9]
heapq.heapify(nums)       # transforms in-place

# Max-heap: negate values
max_heap: list[int] = []
heapq.heappush(max_heap, -5)
heapq.heappush(max_heap, -1)
largest = -heapq.heappop(max_heap)   # 5

# Heap with tuples (priority, value)
tasks: list[tuple[int, str]] = []
heapq.heappush(tasks, (1, "low priority"))
heapq.heappush(tasks, (0, "urgent"))
priority, task = heapq.heappop(tasks)   # (0, "urgent")
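There is one caveat with (priority, value) tuples. When two priorities tie, heapq falls back to comparing the second field, which raises `TypeError` if the payload type defines no ordering. A common remedy, also suggested in the heapq documentation, is to insert a monotonically increasing counter between priority and payload. This also makes ties pop in insertion order. `Job` is a hypothetical payload class.

```python
import heapq
import itertools


class Job:
    """Hypothetical payload with no ordering defined."""
    def __init__(self, name):
        self.name = name


heap = []
counter = itertools.count()   # increasing tiebreaker, so Job objects are never compared
for priority, name in [(2, "write report"), (1, "fix bug"), (1, "review PR")]:
    heapq.heappush(heap, (priority, next(counter), Job(name)))

order = [heapq.heappop(heap)[2].name for _ in range(len(heap))]
print(order)  # ['fix bug', 'review PR', 'write report']
```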

Top-K Pattern

def kth_largest(nums: list[int], k: int) -> int:
    """Kth largest element in unsorted array.
    Min-heap of size k: the top is always the kth largest so far.
    Time: O(n log k)  Space: O(k)

    Alternative: quickselect O(n) average — see binary search section.
    """
    heap: list[int] = []
    for x in nums:
        heapq.heappush(heap, x)
        if len(heap) > k:
            heapq.heappop(heap)   # evict smaller elements
    return heap[0]   # smallest of the k largest = kth largest
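Tracing the heap top after each element shows why the answer sits at `heap[0]`. Once at least k elements have been seen, the top is the kth largest so far; before that it is just the minimum seen. `kth_largest_traced` is a hypothetical instrumented copy of the function above.

```python
import heapq


def kth_largest_traced(nums, k):
    """Same size-k min-heap idea, also recording the heap top after each element."""
    heap, tops = [], []
    for x in nums:
        heapq.heappush(heap, x)
        if len(heap) > k:
            heapq.heappop(heap)      # drop the smallest; k largest so far remain
        tops.append(heap[0])
    return heap[0], tops


answer, tops = kth_largest_traced([3, 2, 1, 5, 6, 4], 2)
print(tops)    # [3, 2, 2, 3, 5, 5]
print(answer)  # 5
```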


def top_k_elements(nums: list[int], k: int) -> list[int]:
    """Return the k largest elements, largest first. Time: O(n log k)  Space: O(k)"""
    return heapq.nlargest(k, nums)   # heapq helper, not a built-in; keeps a size-k heap

Merge K Sorted Lists

def merge_k_sorted(lists: list[ListNode | None]) -> ListNode | None:
    """Merge k sorted linked lists into one sorted list.
    Time: O(n log k) where n = total nodes, k = number of lists.
    Space: O(k) for the heap.
    """
    dummy = ListNode(0)
    curr = dummy
    heap: list[tuple[int, int, ListNode]] = []

    # Initialize heap with head of each list
    # Use list index as tiebreaker to avoid comparing ListNodes
    for i, node in enumerate(lists):
        if node:
            heapq.heappush(heap, (node.val, i, node))

    while heap:
        val, i, node = heapq.heappop(heap)
        curr.next = node
        curr = curr.next
        if node.next:
            heapq.heappush(heap, (node.next.val, i, node.next))

    return dummy.next
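This usage sketch assumes a `ListNode` with `val` and `next` fields; the section relies on it but does not define it here. `from_list` and `to_list` are hypothetical helpers for building and reading linked lists. Each list contributes at most one node to the heap at a time, so the (value, list index) prefix is always unique. The heap therefore never has to compare two `ListNode` objects.

```python
import heapq


class ListNode:
    """Assumed singly linked list node with val and next fields."""
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


def from_list(values):
    dummy = ListNode()
    tail = dummy
    for v in values:
        tail.next = ListNode(v)
        tail = tail.next
    return dummy.next


def to_list(node):
    out = []
    while node:
        out.append(node.val)
        node = node.next
    return out


def merge_k_sorted(lists):
    dummy = ListNode(0)
    curr = dummy
    heap = []
    for i, node in enumerate(lists):
        if node:
            heapq.heappush(heap, (node.val, i, node))
    while heap:
        _, i, node = heapq.heappop(heap)
        curr.next = node
        curr = curr.next
        if node.next:
            heapq.heappush(heap, (node.next.val, i, node.next))
    return dummy.next


lists = [from_list([1, 4, 5]), from_list([1, 3, 4]), None, from_list([2, 6])]
print(to_list(merge_k_sorted(lists)))  # [1, 1, 2, 3, 4, 4, 5, 6]
```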

Two Heaps — Median of a Stream

class MedianFinder:
    """Maintain running median using two heaps.
    - max_heap: lower half of numbers (negate for max-heap in Python)
    - min_heap: upper half of numbers
    Invariant: len(max_heap) == len(min_heap)  or
               len(max_heap) == len(min_heap) + 1
    Time: O(log n) add, O(1) find_median.  Space: O(n).
    """
    def __init__(self):
        self.max_heap: list[int] = []  # lower half, negate for max
        self.min_heap: list[int] = []  # upper half

    def add_num(self, num: int) -> None:
        # Always push to max_heap first
        heapq.heappush(self.max_heap, -num)

        # Balance: ensure max_heap top <= min_heap top
        if (self.min_heap and
                -self.max_heap[0] > self.min_heap[0]):
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)

        # Balance sizes: max_heap can have at most 1 more element
        if len(self.max_heap) > len(self.min_heap) + 1:
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        elif len(self.min_heap) > len(self.max_heap):
            val = heapq.heappop(self.min_heap)
            heapq.heappush(self.max_heap, -val)

    def find_median(self) -> float:
        if len(self.max_heap) > len(self.min_heap):
            return float(-self.max_heap[0])
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0
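Feeding a short stream through the class shows the rebalancing at work. The class body repeats the one above in condensed form so the block runs on its own.

```python
import heapq


class MedianFinder:
    def __init__(self):
        self.max_heap = []  # lower half, stored negated
        self.min_heap = []  # upper half

    def add_num(self, num):
        heapq.heappush(self.max_heap, -num)
        # Keep every lower-half value <= every upper-half value
        if self.min_heap and -self.max_heap[0] > self.min_heap[0]:
            heapq.heappush(self.min_heap, -heapq.heappop(self.max_heap))
        # Keep sizes equal, or max_heap one larger
        if len(self.max_heap) > len(self.min_heap) + 1:
            heapq.heappush(self.min_heap, -heapq.heappop(self.max_heap))
        elif len(self.min_heap) > len(self.max_heap):
            heapq.heappush(self.max_heap, -heapq.heappop(self.min_heap))

    def find_median(self):
        if len(self.max_heap) > len(self.min_heap):
            return float(-self.max_heap[0])
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0


mf = MedianFinder()
medians = []
for x in [5, 15, 1, 3]:
    mf.add_num(x)
    medians.append(mf.find_median())
print(medians)  # [5.0, 10.0, 5.0, 4.0]
```

After 15 arrives, the size check moves it to `min_heap`. After 3 arrives, 5 moves across, leaving the halves {1, 3} and {5, 15}.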
Practice
Challenge

Kth Largest Element

Find the kth largest element in an unsorted array using a min-heap of size k. Fill in the value to push, the heap to pop from when it grows past size k, and the index of the answer.

import heapq

def kth_largest(nums, k):
    heap = []
    for n in nums:
        heapq.heappush(heap, ___)
        if len(heap) > k:
            heapq.heappop(___)
    return heap[___]

print(kth_largest([3, 2, 1, 5, 6, 4], 2))  # 5
print(kth_largest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
import java.util.*;

public class Solution {
    static int kthLargest(int[] nums, int k) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) {
            minHeap.offer(___);
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }
        return minHeap.peek();
    }

    public static void main(String[] args) {
        System.out.println(kthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2));          // 5
        System.out.println(kthLargest(new int[]{3, 2, 3, 1, 2, 4, 5, 5, 6}, 4)); // 4
    }
}
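Push every element onto the min-heap. When the heap grows beyond size k, pop the smallest. After all elements are processed, the heap holds the k largest values and its top (heap[0]) is the kth largest. Each push and pop costs O(log k), so the total is O(n log k) time and O(k) space.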
import heapq

def kth_largest(nums, k):
    heap = []
    for n in nums:
        heapq.heappush(heap, n)
        if len(heap) > k:
            heapq.heappop(heap)
    return heap[0]

print(kth_largest([3, 2, 1, 5, 6, 4], 2))  # 5
print(kth_largest([3, 2, 3, 1, 2, 4, 5, 5, 6], 4))  # 4
import java.util.*;

public class Solution {
    static int kthLargest(int[] nums, int k) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) {
            minHeap.offer(n);
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }
        return minHeap.peek();
    }

    public static void main(String[] args) {
        System.out.println(kthLargest(new int[]{3, 2, 1, 5, 6, 4}, 2));          // 5
        System.out.println(kthLargest(new int[]{3, 2, 3, 1, 2, 4, 5, 5, 6}, 4)); // 4
    }
}
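A common refinement, sketched here under the assumption 1 <= k <= len(nums): heapify the first k elements once, then call heapq.heapreplace (pop the smallest, then push) only for elements larger than the current top. The heap never grows past k, and elements that cannot be among the k largest never touch it. The name `kth_largest_replace` is illustrative.

```python
import heapq


def kth_largest_replace(nums: list[int], k: int) -> int:
    """kth largest via a bounded min-heap. Assumes 1 <= k <= len(nums).
    Time: O(k + (n - k) log k)  Space: O(k)
    """
    heap = nums[:k]
    heapq.heapify(heap)              # O(k) build of the first k values
    for n in nums[k:]:
        if n > heap[0]:              # n beats the smallest of the current top k
            heapq.heapreplace(heap, n)
    return heap[0]


print(kth_largest_replace([3, 2, 1, 5, 6, 4], 2))
```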
Challenge

Sort Using Heap

Sort an array in ascending order using heapq. Fill in the heapify call and the pop loop.

import heapq

def heap_sort(nums):
    heapq.___(nums)
    result = []
    while nums:
        result.append(heapq.___(nums))
    return result

print(heap_sort([5, 3, 8, 1, 2]))  # [1, 2, 3, 5, 8]
print(heap_sort([9, 0, -3, 7]))    # [-3, 0, 7, 9]
import java.util.*;

public class Solution {
    static int[] heapSort(int[] nums) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) minHeap.offer(n);
        int[] result = new int[nums.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = ___;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(heapSort(new int[]{5, 3, 8, 1, 2})));  // [1, 2, 3, 5, 8]
        System.out.println(Arrays.toString(heapSort(new int[]{9, 0, -3, 7})));    // [-3, 0, 7, 9]
    }
}
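Use heapq.heapify(nums) to turn the list into a min-heap in place in O(n). Then repeatedly call heapq.heappop(nums) to extract the smallest remaining element; each pop costs O(log n), so the whole sort is O(n log n). Note that this empties the input list, so pass a copy (heap_sort(nums[:])) if the caller still needs the original.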
import heapq

def heap_sort(nums):
    heapq.heapify(nums)
    result = []
    while nums:
        result.append(heapq.heappop(nums))
    return result

print(heap_sort([5, 3, 8, 1, 2]))  # [1, 2, 3, 5, 8]
print(heap_sort([9, 0, -3, 7]))    # [-3, 0, 7, 9]
import java.util.*;

public class Solution {
    static int[] heapSort(int[] nums) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>();
        for (int n : nums) minHeap.offer(n);
        int[] result = new int[nums.length];
        for (int i = 0; i < result.length; i++) {
            result[i] = minHeap.poll();
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(heapSort(new int[]{5, 3, 8, 1, 2})));  // [1, 2, 3, 5, 8]
        System.out.println(Arrays.toString(heapSort(new int[]{9, 0, -3, 7})));    // [-3, 0, 7, 9]
    }
}
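The heapq solution above builds a separate result list and empties its input. For comparison, here is a sketch of classic in-place heapsort: it builds a max-heap inside the array, then repeatedly swaps the maximum to the end and shrinks the heap. It needs O(1) extra space and O(n log n) time, but it is not stable. `heap_sort_in_place` and its `sift_down` helper are illustrative names.

```python
def heap_sort_in_place(nums: list[int]) -> None:
    """Sort nums ascending in place. Time: O(n log n)  Space: O(1)."""

    def sift_down(start: int, end: int) -> None:
        # Restore the max-heap property for the subtree rooted at start,
        # considering only indices < end.
        root = start
        while True:
            child = 2 * root + 1
            if child >= end:
                return
            if child + 1 < end and nums[child + 1] > nums[child]:
                child += 1                      # pick the larger child
            if nums[root] >= nums[child]:
                return
            nums[root], nums[child] = nums[child], nums[root]
            root = child

    n = len(nums)
    for i in range(n // 2 - 1, -1, -1):         # build max-heap bottom-up, O(n)
        sift_down(i, n)
    for end in range(n - 1, 0, -1):
        nums[0], nums[end] = nums[end], nums[0]  # move current max to its final slot
        sift_down(0, end)


arr = [5, 3, 8, 1, 2]
heap_sort_in_place(arr)
print(arr)
```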
Challenge

Last Stone Weight

Smash the two heaviest stones together each round. If their weights differ, push the difference back. Repeat until at most one stone remains, then return its weight, or 0 if no stones remain. Fill in the max-heap simulation (Python negates values; Java uses Collections.reverseOrder()).

import heapq

def last_stone(stones):
    heap = [___ for s in stones]   # negate for max-heap
    heapq.heapify(heap)
    while len(heap) > 1:
        a = -heapq.heappop(heap)   # largest
        b = -heapq.heappop(heap)   # second largest
        if a != b:
            heapq.heappush(heap, ___)
    return -heap[0] if heap else ___

print(last_stone([2, 7, 4, 1, 8, 1]))  # 1
print(last_stone([1, 1]))              # 0
import java.util.*;

public class Solution {
    static int lastStone(int[] stones) {
        // Java PriorityQueue is min-heap; use Collections.reverseOrder() for max-heap
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int s : stones) maxHeap.offer(s);
        while (maxHeap.size() > 1) {
            int a = ___;  // largest
            int b = ___;  // second largest
            if (a != b) maxHeap.offer(a - b);
        }
        return maxHeap.isEmpty() ? 0 : maxHeap.peek();
    }

    public static void main(String[] args) {
        System.out.println(lastStone(new int[]{2, 7, 4, 1, 8, 1}));  // 1
        System.out.println(lastStone(new int[]{1, 1}));               // 0
    }
}
Python's heapq is a min-heap. Store negated weights so the most negative (= largest weight) is always at the top. Push -(a - b) when the stones differ. Return -heap[0] if one stone remains, else 0.
import heapq

def last_stone(stones):
    heap = [-s for s in stones]
    heapq.heapify(heap)
    while len(heap) > 1:
        a = -heapq.heappop(heap)
        b = -heapq.heappop(heap)
        if a != b:
            heapq.heappush(heap, -(a - b))
    return -heap[0] if heap else 0

print(last_stone([2, 7, 4, 1, 8, 1]))  # 1
print(last_stone([1, 1]))              # 0
import java.util.*;

public class Solution {
    static int lastStone(int[] stones) {
        PriorityQueue<Integer> maxHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int s : stones) maxHeap.offer(s);
        while (maxHeap.size() > 1) {
            int a = maxHeap.poll();
            int b = maxHeap.poll();
            if (a != b) maxHeap.offer(a - b);
        }
        return maxHeap.isEmpty() ? 0 : maxHeap.peek();
    }

    public static void main(String[] args) {
        System.out.println(lastStone(new int[]{2, 7, 4, 1, 8, 1}));  // 1
        System.out.println(lastStone(new int[]{1, 1}));               // 0
    }
}

9. Binary Search

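Binary search works on any monotone predicate, not just sorted arrays: if a condition is False for every value below some boundary and True from the boundary onward, binary search finds that boundary in O(log n) checks. The key is stating the invariant precisely: which part of [left, right] can still contain the answer, and what left and right point to when the loop exits.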
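The monotone-predicate idea can be captured as one reusable helper. This is a sketch, assuming `pred` is False for a prefix of [lo, hi] and True for the rest; `first_true` is a hypothetical name. Returning hi + 1 signals that no value in the range satisfies the predicate.

```python
from typing import Callable


def first_true(lo: int, hi: int, pred: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi] with pred(x) True, for a monotone pred
    (False ... False True ... True). Returns hi + 1 if pred is never True.
    Time: O(log(hi - lo)) calls to pred
    """
    left, right = lo, hi + 1      # invariant: the answer lies in [left, right]
    while left < right:
        mid = (left + right) // 2  # mid < right <= hi + 1, so pred sees only [lo, hi]
        if pred(mid):
            right = mid            # mid works; answer is mid or earlier
        else:
            left = mid + 1         # mid fails; answer is after mid
    return left


print(first_true(0, 10, lambda x: x * x > 10))
```

For example, first_true(0, 10, lambda x: x * x > 10) returns 4, the smallest integer whose square exceeds 10.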

Standard Template

def binary_search(nums: list[int], target: int) -> int:
    """Return index of target, or -1 if not found.
    Time: O(log n)  Space: O(1)

    Invariant: target is in nums[left..right] if it exists.
    """
    left, right = 0, len(nums) - 1
    while left <= right:
    mid = left + (right - left) // 2   # overflow-safe form needed in Java/C++; Python ints never overflow, so (left + right) // 2 also works
        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1

Leftmost / Rightmost Occurrence

def search_leftmost(nums: list[int], target: int) -> int:
    """First (leftmost) index of target. Returns -1 if not found.
    Time: O(log n)  Space: O(1)
    """
    left, right = 0, len(nums) - 1
    result = -1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            result = mid
            right = mid - 1   # keep searching left
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result


def search_rightmost(nums: list[int], target: int) -> int:
    """Last (rightmost) index of target. Returns -1 if not found."""
    left, right = 0, len(nums) - 1
    result = -1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            result = mid
            left = mid + 1    # keep searching right
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result


def count_occurrences(nums: list[int], target: int) -> int:
    left  = search_leftmost(nums, target)
    right = search_rightmost(nums, target)
    if left == -1:
        return 0
    return right - left + 1
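Python's standard bisect module performs the same boundary searches: bisect_left returns the first index whose value is >= target, and bisect_right the first index whose value is > target. A sketch of the three helpers above rebuilt on it, assuming nums is sorted ascending (the `_bisect` names are illustrative):

```python
from bisect import bisect_left, bisect_right


def search_leftmost_bisect(nums: list[int], target: int) -> int:
    i = bisect_left(nums, target)          # first index with nums[i] >= target
    return i if i < len(nums) and nums[i] == target else -1


def search_rightmost_bisect(nums: list[int], target: int) -> int:
    i = bisect_right(nums, target) - 1     # last index with nums[i] <= target
    return i if i >= 0 and nums[i] == target else -1


def count_occurrences_bisect(nums: list[int], target: int) -> int:
    # Width of the run of values equal to target; 0 when target is absent
    return bisect_right(nums, target) - bisect_left(nums, target)


print(count_occurrences_bisect([1, 2, 2, 2, 3, 5], 2))
```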

Binary Search on Answer

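When a problem asks for the minimum (or maximum) value that satisfies a condition, and the feasibility check is monotone (if a value works, every larger value also works, or every smaller one when maximizing), binary search directly over the range of candidate answers instead of over an array.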

def min_days_to_make_bouquets(bloom_day: list[int],
                               m: int, k: int) -> int:
    """Minimum days to make m bouquets, each needing k adjacent flowers.
    Binary search on the answer: day range [min(bloom_day), max(bloom_day)].
    Time: O(n log(max_day))  Space: O(1)
    """
    def can_make(day: int) -> bool:
        bouquets = consecutive = 0
        for bd in bloom_day:
            if bd <= day:
                consecutive += 1
                if consecutive == k:
                    bouquets += 1
                    consecutive = 0
            else:
                consecutive = 0
        return bouquets >= m

    if len(bloom_day) < m * k:
        return -1

    left, right = min(bloom_day), max(bloom_day)
    while left < right:
        mid = (left + right) // 2
        if can_make(mid):
            right = mid       # valid, try smaller
        else:
            left = mid + 1    # invalid, need more days
    return left


def split_array_largest_sum(nums: list[int], k: int) -> int:
    """Minimize the largest sum when splitting array into k subarrays.
    Binary search on answer: sum range [max(nums), sum(nums)].
    Time: O(n log(sum))  Space: O(1)
    """
    def can_split(limit: int) -> bool:
        parts = 1
        running = 0
        for x in nums:
            if running + x > limit:
                parts += 1
                running = 0
            running += x
        return parts <= k

    left, right = max(nums), sum(nums)
    while left < right:
        mid = (left + right) // 2
        if can_split(mid):
            right = mid
        else:
            left = mid + 1
    return left
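Both examples above minimize. For a maximize question (the largest value that still passes a monotone check), the mirror-image loop rounds mid up so that `left = mid` always makes progress. A sketch using integer square root as the example; `int_sqrt` is an illustrative name, and in real Python code math.isqrt does this job directly.

```python
def int_sqrt(n: int) -> int:
    """Largest x with x * x <= n, for n >= 0 (binary search on answer, maximize form).
    feasible(x) = x * x <= n is True ... True False ... False.
    Time: O(log n)  Space: O(1)
    """
    left, right = 0, n
    while left < right:
        mid = (left + right + 1) // 2   # round up, or left = mid could loop forever
        if mid * mid <= n:
            left = mid                  # mid feasible; answer is mid or larger
        else:
            right = mid - 1             # mid infeasible; answer is smaller
    return left


print(int_sqrt(15), int_sqrt(16))
```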

Rotated Sorted Array

def search_rotated(nums: list[int], target: int) -> int:
    """Search in rotated sorted array (no duplicates).
    Time: O(log n)  Space: O(1)

    Key insight: one half of the array is always sorted.
    Determine which half, check if target falls in it.
    """
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = (left + right) // 2
        if nums[mid] == target:
            return mid

        if nums[left] <= nums[mid]:       # left half is sorted
            if nums[left] <= target < nums[mid]:
                right = mid - 1
            else:
                left = mid + 1
        else:                             # right half is sorted
            if nums[mid] < target <= nums[right]:
                left = mid + 1
            else:
                right = mid - 1
    return -1


def find_min_rotated(nums: list[int]) -> int:
    """Find minimum element in rotated sorted array.
    Time: O(log n)  Space: O(1)
    """
    left, right = 0, len(nums) - 1
    while left < right:
        mid = (left + right) // 2
        if nums[mid] > nums[right]:
            left = mid + 1    # min is in right half
        else:
            right = mid       # mid could be the min
    return nums[left]
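An alternative to the single-pass search_rotated, sketched here assuming distinct values: find the rotation point with the same loop as find_min_rotated, then run an ordinary binary search (bisect_left with lo/hi bounds) on whichever sorted run could contain the target. It is still O(log n), just in two passes. The name `search_rotated_two_pass` is illustrative.

```python
from bisect import bisect_left


def search_rotated_two_pass(nums: list[int], target: int) -> int:
    """Index of target in a rotated sorted array of distinct values, or -1.
    Time: O(log n)  Space: O(1)
    """
    if not nums:
        return -1
    # Pass 1: index of the minimum (the rotation point)
    left, right = 0, len(nums) - 1
    while left < right:
        mid = (left + right) // 2
        if nums[mid] > nums[right]:
            left = mid + 1
        else:
            right = mid
    pivot = left
    # Pass 2: nums[0:pivot] holds values >= nums[0]; nums[pivot:] holds the rest
    if pivot > 0 and target >= nums[0]:
        lo, hi = 0, pivot
    else:
        lo, hi = pivot, len(nums)
    i = bisect_left(nums, target, lo, hi)
    return i if i < hi and nums[i] == target else -1


print(search_rotated_two_pass([4, 5, 6, 7, 0, 1, 2], 0))
```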

10. Graphs — BFS & DFS

Representations

from collections import defaultdict, deque
from collections.abc import Iterator

# Adjacency list (most common in interviews)
# Undirected graph
graph: dict[int, list[int]] = defaultdict(list)
edges = [(0, 1), (0, 2), (1, 3), (2, 3), (3, 4)]
for u, v in edges:
    graph[u].append(v)
    graph[v].append(u)

# Weighted graph: store (neighbor, weight) pairs
weighted: dict[int, list[tuple[int, int]]] = defaultdict(list)
for u, v, w in [(0,1,4), (0,2,1), (1,3,2)]:
    weighted[u].append((v, w))
    weighted[v].append((u, w))

# Grid as implicit graph: 4-directional neighbors.
# This is a generator (it yields), so it is annotated as an Iterator, not a list.
def neighbors_4(r: int, c: int,
                rows: int, cols: int) -> Iterator[tuple[int, int]]:
    for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
        nr, nc = r + dr, c + dc
        if 0 <= nr < rows and 0 <= nc < cols:
            yield nr, nc

BFS Template

def bfs(graph: dict[int, list[int]], start: int) -> list[int]:
    """BFS: explores nodes level by level (shortest path in unweighted graph).
    Time: O(V + E)  Space: O(V)
    """
    visited: set[int] = {start}
    queue: deque[int] = deque([start])
    order: list[int] = []

    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append(neighbor)
    return order


def bfs_shortest_path(graph: dict[int, list[int]],
                      start: int, end: int) -> int:
    """Shortest path length in unweighted graph. Returns -1 if unreachable.
    Time: O(V + E)  Space: O(V)
    """
    if start == end:
        return 0
    visited: set[int] = {start}
    queue: deque[tuple[int, int]] = deque([(start, 0)])

    while queue:
        node, dist = queue.popleft()
        for neighbor in graph[node]:
            if neighbor == end:
                return dist + 1
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    return -1


def bfs_grid(grid: list[list[int]],
             start: tuple[int, int]) -> list[list[int]]:
    """BFS on a grid, returns distance from start to every reachable cell.
    0 = open, 1 = wall.
    Time: O(rows * cols)  Space: O(rows * cols)
    """
    rows, cols = len(grid), len(grid[0])
    dist = [[-1] * cols for _ in range(rows)]
    sr, sc = start
    dist[sr][sc] = 0
    queue: deque[tuple[int, int]] = deque([(sr, sc)])

    while queue:
        r, c = queue.popleft()
        for nr, nc in neighbors_4(r, c, rows, cols):
            if grid[nr][nc] == 0 and dist[nr][nc] == -1:
                dist[nr][nc] = dist[r][c] + 1
                queue.append((nr, nc))
    return dist
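bfs_shortest_path returns only a length. When the route itself is needed, a common approach is to record each node's parent when it is first discovered and then walk the parents back from end. A sketch, with `bfs_path` as an illustrative name; it uses graph.get so plain dicts with missing keys also work.

```python
from collections import deque


def bfs_path(graph: dict[int, list[int]], start: int, end: int) -> list[int]:
    """One shortest path from start to end as a node list, or [] if unreachable.
    Time: O(V + E)  Space: O(V)
    """
    parent = {start: None}          # doubles as the visited set
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == end:
            break                   # BFS reaches end first along a shortest path
        for neighbor in graph.get(node, []):
            if neighbor not in parent:
                parent[neighbor] = node
                queue.append(neighbor)
    if end not in parent:
        return []
    path = []
    node = end
    while node is not None:         # walk back to start, whose parent is None
        path.append(node)
        node = parent[node]
    return path[::-1]


g = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}
print(bfs_path(g, 0, 4))
```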

DFS Template

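def dfs_recursive(graph: dict[int, list[int]],
                  node: int,
                  visited: set[int] | None = None,
                  order: list[int] | None = None) -> list[int]:
    """DFS recursive. Time: O(V + E)  Space: O(V) call stack.
    All calls append to one shared order list; extending with each
    child's returned list instead would copy sublists, O(V^2) on a path.
    """
    if visited is None:
        visited = set()
    if order is None:
        order = []
    visited.add(node)
    order.append(node)
    for neighbor in graph[node]:
        if neighbor not in visited:
            dfs_recursive(graph, neighbor, visited, order)
    return order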


def dfs_iterative(graph: dict[int, list[int]], start: int) -> list[int]:
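    """DFS iterative using explicit stack.
    Marking a node when it is popped (not when pushed) and pushing neighbors
    in reverse reproduces the recursive preorder.
    Time: O(V + E)  Space: O(V + E), since a node can be pushed more than
    once (up to once per incoming edge) before it is visited.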
    """
    visited: set[int] = set()
    stack: list[int] = [start]
    order: list[int] = []

    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        order.append(node)
        # Push neighbors in reverse so original order is preserved
        for neighbor in reversed(graph[node]):
            if neighbor not in visited:
                stack.append(neighbor)
    return order

Topological Sort (Kahn's Algorithm)

def topological_sort(num_nodes: int,
                     edges: list[tuple[int, int]]) -> list[int]:
    """Kahn's algorithm — BFS-based topological sort.
    Time: O(V + E)  Space: O(V + E)

    Returns empty list if graph has a cycle (not a DAG).
    """
    graph: dict[int, list[int]] = defaultdict(list)
    in_degree: list[int] = [0] * num_nodes

    for u, v in edges:
        graph[u].append(v)
        in_degree[v] += 1

    # Start with all nodes that have no incoming edges
    queue: deque[int] = deque(
        node for node in range(num_nodes) if in_degree[node] == 0
    )
    order: list[int] = []

    while queue:
        node = queue.popleft()
        order.append(node)
        for neighbor in graph[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    # If not all nodes processed, cycle exists
    return order if len(order) == num_nodes else []


def course_schedule(num_courses: int,
                    prerequisites: list[list[int]]) -> bool:
    """Can all courses be finished? True if no cycle exists.
    Classic topological sort application.
    Time: O(V + E)  Space: O(V + E)
    """
    order = topological_sort(num_courses,
                             [(b, a) for a, b in prerequisites])
    return len(order) == num_courses
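The BFS vs DFS guide in this section lists a DFS version of topological sort. A sketch of it, reusing the three-state coloring from has_cycle_directed: a node is appended to postorder only after all of its descendants, so reversing the postorder gives a valid order. Like the other recursive helpers here, it can hit Python's recursion limit on very deep graphs. `topo_sort_dfs` is an illustrative name.

```python
from collections import defaultdict


def topo_sort_dfs(num_nodes: int, edges: list[tuple[int, int]]) -> list[int]:
    """DFS topological sort (reverse postorder). Returns [] if a cycle exists.
    Time: O(V + E)  Space: O(V + E)
    """
    graph: dict[int, list[int]] = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
    state = [0] * num_nodes        # 0=unvisited, 1=on current path, 2=done
    postorder: list[int] = []

    def dfs(node: int) -> bool:    # returns False if a cycle is found
        state[node] = 1
        for nxt in graph[node]:
            if state[nxt] == 1:
                return False       # back edge: cycle
            if state[nxt] == 0 and not dfs(nxt):
                return False
        state[node] = 2
        postorder.append(node)     # all descendants are already in postorder
        return True

    for n in range(num_nodes):
        if state[n] == 0 and not dfs(n):
            return []
    return postorder[::-1]


print(topo_sort_dfs(4, [(0, 1), (1, 2), (0, 3), (3, 2)]))
```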

Cycle Detection

def has_cycle_directed(num_nodes: int,
                       edges: list[tuple[int, int]]) -> bool:
    """Detect cycle in directed graph using DFS coloring.
    Colors: 0=unvisited, 1=in current path, 2=fully processed.
    Time: O(V + E)  Space: O(V)
    """
    graph: dict[int, list[int]] = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
    color = [0] * num_nodes

    def dfs(node: int) -> bool:
        color[node] = 1   # mark as in-progress
        for neighbor in graph[node]:
            if color[neighbor] == 1:
                return True   # back edge = cycle
            if color[neighbor] == 0 and dfs(neighbor):
                return True
        color[node] = 2   # fully processed
        return False

    return any(color[n] == 0 and dfs(n) for n in range(num_nodes))


def has_cycle_undirected(graph: dict[int, list[int]]) -> bool:
    """Detect cycle in undirected graph.
    Track parent to avoid treating the edge back to parent as a cycle.
    Time: O(V + E)  Space: O(V)
    """
    visited: set[int] = set()

    def dfs(node: int, parent: int) -> bool:
        visited.add(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                if dfs(neighbor, node):
                    return True
            elif neighbor != parent:
                return True   # back edge to non-parent = cycle
        return False

    return any(dfs(n, -1) for n in graph if n not in visited)
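A different technique for undirected graphs, not the DFS above: union-find (disjoint set union). Process each edge once; if both endpoints already share a root, that edge closes a cycle. This sketch assumes nodes are labeled 0..num_nodes-1 and that each undirected edge appears exactly once in `edges`; it uses no recursion. `has_cycle_union_find` is an illustrative name.

```python
def has_cycle_union_find(num_nodes: int, edges: list[tuple[int, int]]) -> bool:
    """Cycle detection in an undirected graph via union-find.
    Time: O(E * alpha(V)) in practice  Space: O(V)
    """
    parent = list(range(num_nodes))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]   # path halving keeps trees shallow
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return True      # u and v already connected: this edge closes a cycle
        parent[ru] = rv      # merge the two components
    return False


print(has_cycle_union_find(3, [(0, 1), (1, 2), (2, 0)]))
```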

Island Problems / Flood Fill

def num_islands(grid: list[list[str]]) -> int:
    """Count connected components of '1's.
    Time: O(m * n)  Space: O(m * n) worst case DFS stack.
    """
    if not grid:
        return 0
    rows, cols = len(grid), len(grid[0])
    count = 0

    def dfs(r: int, c: int) -> None:
        if r < 0 or r >= rows or c < 0 or c >= cols:
            return
        if grid[r][c] != '1':
            return
        grid[r][c] = '#'    # mark visited in-place (restore if needed)
        dfs(r + 1, c); dfs(r - 1, c)
        dfs(r, c + 1); dfs(r, c - 1)

    for r in range(rows):
        for c in range(cols):
            if grid[r][c] == '1':
                dfs(r, c)
                count += 1
    return count
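The recursive DFS above can exceed Python's default recursion limit (typically 1000 frames) on large, mostly-land grids, and it overwrites the input. A sketch of an iterative BFS alternative that tracks visited cells in a separate `seen` matrix and leaves `grid` untouched; `num_islands_bfs` is an illustrative name.

```python
from collections import deque


def num_islands_bfs(grid: list[list[str]]) -> int:
    """Count connected components of '1's without recursion or mutating grid.
    Time: O(m * n)  Space: O(m * n)
    """
    if not grid:
        return 0
    rows, cols = len(grid), len(grid[0])
    seen = [[False] * cols for _ in range(rows)]
    count = 0
    for r in range(rows):
        for c in range(cols):
            if grid[r][c] != '1' or seen[r][c]:
                continue
            count += 1                       # new island found; flood it with BFS
            seen[r][c] = True
            queue = deque([(r, c)])
            while queue:
                cr, cc = queue.popleft()
                for nr, nc in ((cr + 1, cc), (cr - 1, cc), (cr, cc + 1), (cr, cc - 1)):
                    if (0 <= nr < rows and 0 <= nc < cols
                            and grid[nr][nc] == '1' and not seen[nr][nc]):
                        seen[nr][nc] = True  # mark on enqueue so no cell is queued twice
                        queue.append((nr, nc))
    return count
```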


def max_area_of_island(grid: list[list[int]]) -> int:
    """Maximum area of an island. 1=land, 0=water.
    Time: O(m * n)  Space: O(m * n)
    """
    rows, cols = len(grid), len(grid[0])

    def dfs(r: int, c: int) -> int:
        if r < 0 or r >= rows or c < 0 or c >= cols or grid[r][c] == 0:
            return 0
        grid[r][c] = 0    # mark visited
        return 1 + dfs(r+1,c) + dfs(r-1,c) + dfs(r,c+1) + dfs(r,c-1)

    # default=0 covers a grid with no land: max() of an empty sequence raises ValueError
    return max((dfs(r, c)
                for r in range(rows)
                for c in range(cols)
                if grid[r][c] == 1), default=0)


def flood_fill(image: list[list[int]],
               sr: int, sc: int, color: int) -> list[list[int]]:
    """Flood fill: repaint connected region starting from (sr, sc).
    Time: O(m * n)  Space: O(m * n)
    """
    original = image[sr][sc]
    if original == color:
        return image

    rows, cols = len(image), len(image[0])

    def dfs(r: int, c: int) -> None:
        if r < 0 or r >= rows or c < 0 or c >= cols:
            return
        if image[r][c] != original:
            return
        image[r][c] = color
        dfs(r+1,c); dfs(r-1,c); dfs(r,c+1); dfs(r,c-1)

    dfs(sr, sc)
    return image
BFS vs DFS decision guide

Use BFS when                        | Use DFS when
----------------------------------- | ---------------------------------
Shortest path in unweighted graph   | Detecting cycles
Level-order processing              | Topological sort (DFS version)
Finding closest/nearest node        | Exhaustive search / backtracking
Word ladder, 0-1 BFS                | Connected components, flood fill
The graph is wide (many branches)   | The graph is deep (long paths)
Practice
Challenge

BFS Shortest Path

Find the shortest path length (number of edges) between two nodes in an unweighted undirected graph. Fill in the BFS queue operations and distance tracking.

from collections import deque

def shortest_path(graph, start, end):
    if start == end:
        return 0
    visited = {start}
    queue = deque([(start, 0)])   # (node, distance)
    while queue:
        node, dist = queue.___(   )
        for neighbor in graph.get(node, []):
            if neighbor == end:
                return ___
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, ___))
    return -1   # unreachable

graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}
print(shortest_path(graph, 0, 4))  # 3
print(shortest_path(graph, 0, 3))  # 2
print(shortest_path(graph, 0, 0))  # 0
import java.util.*;

public class Solution {
    static int shortestPath(Map<Integer, List<Integer>> graph, int start, int end) {
        if (start == end) return 0;
        Set<Integer> visited = new HashSet<>();
        Queue<int[]> queue = new LinkedList<>();  // {node, dist}
        queue.offer(new int[]{start, 0});
        visited.add(start);
        while (!queue.isEmpty()) {
            int[] curr = ___;
            int node = curr[0], dist = curr[1];
            for (int neighbor : graph.getOrDefault(node, Collections.emptyList())) {
                if (neighbor == end) return ___;
                if (!visited.contains(neighbor)) {
                    visited.add(neighbor);
                    queue.offer(new int[]{neighbor, ___});
                }
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        Map<Integer, List<Integer>> graph = new HashMap<>();
        graph.put(0, Arrays.asList(1, 2));
        graph.put(1, Arrays.asList(0, 3));
        graph.put(2, Arrays.asList(0, 3));
        graph.put(3, Arrays.asList(1, 2, 4));
        graph.put(4, Arrays.asList(3));
        System.out.println(shortestPath(graph, 0, 4));  // 3
        System.out.println(shortestPath(graph, 0, 3));  // 2
        System.out.println(shortestPath(graph, 0, 0));  // 0
    }
}
Use queue.popleft() in Python (or queue.poll() in Java). When you find the destination, return dist + 1 (one more edge). Push neighbors with dist + 1.
from collections import deque

def shortest_path(graph, start, end):
    if start == end:
        return 0
    visited = {start}
    queue = deque([(start, 0)])
    while queue:
        node, dist = queue.popleft()
        for neighbor in graph.get(node, []):
            if neighbor == end:
                return dist + 1
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, dist + 1))
    return -1

graph = {0: [1, 2], 1: [0, 3], 2: [0, 3], 3: [1, 2, 4], 4: [3]}
print(shortest_path(graph, 0, 4))  # 3
print(shortest_path(graph, 0, 3))  # 2
print(shortest_path(graph, 0, 0))  # 0
import java.util.*;

public class Solution {
    static int shortestPath(Map<Integer, List<Integer>> graph, int start, int end) {
        if (start == end) return 0;
        Set<Integer> visited = new HashSet<>();
        Queue<int[]> queue = new LinkedList<>();
        queue.offer(new int[]{start, 0});
        visited.add(start);
        while (!queue.isEmpty()) {
            int[] curr = queue.poll();
            int node = curr[0], dist = curr[1];
            for (int neighbor : graph.getOrDefault(node, Collections.emptyList())) {
                if (neighbor == end) return dist + 1;
                if (!visited.contains(neighbor)) {
                    visited.add(neighbor);
                    queue.offer(new int[]{neighbor, dist + 1});
                }
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        Map<Integer, List<Integer>> graph = new HashMap<>();
        graph.put(0, Arrays.asList(1, 2));
        graph.put(1, Arrays.asList(0, 3));
        graph.put(2, Arrays.asList(0, 3));
        graph.put(3, Arrays.asList(1, 2, 4));
        graph.put(4, Arrays.asList(3));
        System.out.println(shortestPath(graph, 0, 4));  // 3
        System.out.println(shortestPath(graph, 0, 3));  // 2
        System.out.println(shortestPath(graph, 0, 0));  // 0
    }
}
Challenge

Number of Connected Components

Given n nodes (labeled 0 to n-1) and a list of undirected edges, count the number of connected components. Fill in the BFS/DFS traversal and component counter.

from collections import deque, defaultdict

def count_components(n, edges):
    graph = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    visited = set()
    components = ___

    for node in range(n):
        if node not in visited:
            components ___
            queue = deque([node])
            visited.add(node)
            while queue:
                curr = queue.popleft()
                for neighbor in graph[curr]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append(___)
    return components

print(count_components(5, [[0,1],[1,2],[3,4]]))        # 2
print(count_components(5, [[0,1],[1,2],[2,3],[3,4]]))  # 1
print(count_components(4, []))                         # 4
import java.util.*;

public class Solution {
    static int countComponents(int n, int[][] edges) {
        List<List<Integer>> graph = new ArrayList<>();
        for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) {
            graph.get(e[0]).add(e[1]);
            graph.get(e[1]).add(e[0]);
        }

        boolean[] visited = new boolean[n];
        int components = ___;

        for (int i = 0; i < n; i++) {
            if (!visited[i]) {
                components___;
                Queue<Integer> queue = new LinkedList<>();
                queue.offer(i);
                visited[i] = true;
                while (!queue.isEmpty()) {
                    int curr = queue.poll();
                    for (int neighbor : graph.get(curr)) {
                        if (!visited[neighbor]) {
                            visited[neighbor] = true;
                            queue.offer(___);
                        }
                    }
                }
            }
        }
        return components;
    }

    public static void main(String[] args) {
        System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{3,4}}));       // 2
        System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{2,3},{3,4}})); // 1
        System.out.println(countComponents(4, new int[][]{}));                        // 4
    }
}
Initialise components = 0. Each time you start a BFS/DFS from an unvisited node, increment it by 1. The traversal marks all nodes in that component as visited so they are skipped on future iterations.
from collections import deque, defaultdict

def count_components(n, edges):
    graph = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    visited = set()
    components = 0

    for node in range(n):
        if node not in visited:
            components += 1
            queue = deque([node])
            visited.add(node)
            while queue:
                curr = queue.popleft()
                for neighbor in graph[curr]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        queue.append(neighbor)
    return components

print(count_components(5, [[0,1],[1,2],[3,4]]))        # 2
print(count_components(5, [[0,1],[1,2],[2,3],[3,4]]))  # 1
print(count_components(4, []))                         # 4
import java.util.*;

public class Solution {
    static int countComponents(int n, int[][] edges) {
        List<List<Integer>> graph = new ArrayList<>();
        for (int i = 0; i < n; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) {
            graph.get(e[0]).add(e[1]);
            graph.get(e[1]).add(e[0]);
        }

        boolean[] visited = new boolean[n];
        int components = 0;

        for (int i = 0; i < n; i++) {
            if (!visited[i]) {
                components++;
                Queue<Integer> queue = new LinkedList<>();
                queue.offer(i);
                visited[i] = true;
                while (!queue.isEmpty()) {
                    int curr = queue.poll();
                    for (int neighbor : graph.get(curr)) {
                        if (!visited[neighbor]) {
                            visited[neighbor] = true;
                            queue.offer(neighbor);
                        }
                    }
                }
            }
        }
        return components;
    }

    public static void main(String[] args) {
        System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{3,4}}));       // 2
        System.out.println(countComponents(5, new int[][]{{0,1},{1,2},{2,3},{3,4}})); // 1
        System.out.println(countComponents(4, new int[][]{}));                        // 4
    }
}
Challenge

Detect Cycle in Directed Graph

Return True if the directed graph contains a cycle. Use DFS with three node states: unvisited (0), currently in the DFS stack (1), and fully processed (2). A back edge to a state-1 node indicates a cycle. Fill in the state transitions and cycle detection condition.

from collections import defaultdict

def has_cycle(num_nodes, edges):
    graph = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)

    # 0 = unvisited, 1 = in-progress, 2 = done
    state = [0] * num_nodes

    def dfs(node):
        state[node] = ___     # mark in-progress
        for neighbor in graph[node]:
            if state[neighbor] == ___:   # back edge = cycle
                return True
            if state[neighbor] == 0 and dfs(neighbor):
                return True
        state[node] = ___     # mark done
        return False

    for node in range(num_nodes):
        if state[node] == 0:
            if dfs(node):
                return True
    return False

print(has_cycle(4, [[0,1],[1,2],[2,0],[2,3]]))  # True
print(has_cycle(4, [[0,1],[1,2],[2,3]]))        # False
import java.util.*;

public class Solution {
    static List<List<Integer>> graph;
    static int[] state;  // 0=unvisited, 1=in-progress, 2=done

    static boolean dfs(int node) {
        state[node] = ___;         // in-progress
        for (int neighbor : graph.get(node)) {
            if (state[neighbor] == ___) return true;  // back edge
            if (state[neighbor] == 0 && dfs(neighbor)) return true;
        }
        state[node] = ___;         // done
        return false;
    }

    static boolean hasCycle(int numNodes, int[][] edges) {
        graph = new ArrayList<>();
        state = new int[numNodes];
        for (int i = 0; i < numNodes; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) graph.get(e[0]).add(e[1]);

        for (int i = 0; i < numNodes; i++) {
            if (state[i] == 0 && dfs(i)) return true;
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,0},{2,3}}));  // true
        System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,3}}));        // false
    }
}
Set state[node] = 1 when entering DFS. If you reach a neighbor with state 1, you have found a back edge — return True. After exploring all neighbors, set state[node] = 2 (done). Nodes with state 2 are safe and can be skipped.
from collections import defaultdict

def has_cycle(num_nodes, edges):
    graph = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)

    state = [0] * num_nodes  # 0=unvisited, 1=in-progress, 2=done

    def dfs(node):
        state[node] = 1
        for neighbor in graph[node]:
            if state[neighbor] == 1:
                return True
            if state[neighbor] == 0 and dfs(neighbor):
                return True
        state[node] = 2
        return False

    for node in range(num_nodes):
        if state[node] == 0:
            if dfs(node):
                return True
    return False

print(has_cycle(4, [[0,1],[1,2],[2,0],[2,3]]))  # True
print(has_cycle(4, [[0,1],[1,2],[2,3]]))        # False
import java.util.*;

public class Solution {
    static List<List<Integer>> graph;
    static int[] state;

    static boolean dfs(int node) {
        state[node] = 1;
        for (int neighbor : graph.get(node)) {
            if (state[neighbor] == 1) return true;
            if (state[neighbor] == 0 && dfs(neighbor)) return true;
        }
        state[node] = 2;
        return false;
    }

    static boolean hasCycle(int numNodes, int[][] edges) {
        graph = new ArrayList<>();
        state = new int[numNodes];
        for (int i = 0; i < numNodes; i++) graph.add(new ArrayList<>());
        for (int[] e : edges) graph.get(e[0]).add(e[1]);

        for (int i = 0; i < numNodes; i++) {
            if (state[i] == 0 && dfs(i)) return true;
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,0},{2,3}}));  // true
        System.out.println(hasCycle(4, new int[][]{{0,1},{1,2},{2,3}}));        // false
    }
}
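The recursive versions above can exceed Python's default recursion limit (1000 frames) on a long chain of nodes. The sketch below is a minimal iterative version of the same three-state DFS. It assumes nodes are labeled 0..num_nodes-1, and the name has_cycle_iterative is illustrative. Each stack entry pairs a node with an iterator over its neighbors, so a node resumes scanning where it left off after a child finishes.

```python
def has_cycle_iterative(num_nodes, edges):
    graph = [[] for _ in range(num_nodes)]
    for u, v in edges:
        graph[u].append(v)

    state = [0] * num_nodes  # 0=unvisited, 1=in-progress, 2=done

    for start in range(num_nodes):
        if state[start] != 0:
            continue
        state[start] = 1
        stack = [(start, iter(graph[start]))]
        while stack:
            node, neighbors = stack[-1]
            descended = False
            for nb in neighbors:          # resumes where this node left off
                if state[nb] == 1:
                    return True           # back edge: cycle found
                if state[nb] == 0:
                    state[nb] = 1
                    stack.append((nb, iter(graph[nb])))
                    descended = True
                    break                 # explore the child first
            if not descended:
                state[node] = 2           # all neighbors finished
                stack.pop()
    return False


print(has_cycle_iterative(4, [[0, 1], [1, 2], [2, 0], [2, 3]]))  # True
print(has_cycle_iterative(4, [[0, 1], [1, 2], [2, 3]]))          # False
```

This version works on a 5,000-node chain, where the recursive Python version would hit the recursion limit.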

11. Shortest Path Algorithms

Weighted graphs require algorithms that account for edge costs. Choose based on constraints: negative weights, number of sources, and graph density all matter.

Dijkstra's Algorithm

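Single-source shortest paths for graphs with non-negative edge weights. A min-heap keyed on tentative distance always pops the closest unsettled node next. Once a node is popped with an up-to-date cost, its distance is final, and the algorithm relaxes its outgoing edges.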

import heapq
from collections import defaultdict

def dijkstra(n: int,
             edges: list[tuple[int, int, int]],
             src: int) -> list[float]:
    """Single-source shortest paths from src.
    edges: list of (u, v, weight) — undirected here.
    Returns dist[] where dist[i] = shortest distance from src to i,
    or float('inf') if unreachable.
    Time: O(E log V)  Space: O(V + E)
    """
    graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))          # omit for directed graph

    dist = [float('inf')] * n
    dist[src] = 0
    heap: list[tuple[float, int]] = [(0, src)]  # (cost, node)

    while heap:
        cost, u = heapq.heappop(heap)
        if cost > dist[u]:               # stale entry — skip
            continue
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                heapq.heappush(heap, (dist[v], v))

    return dist


def dijkstra_with_path(n: int,
                       edges: list[tuple[int, int, int]],
                       src: int,
                       dst: int) -> tuple[float, list[int]]:
    """Returns (distance, path). Path is empty if dst unreachable."""
    graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    dist = [float('inf')] * n
    prev = [-1] * n
    dist[src] = 0
    heap = [(0, src)]

    while heap:
        cost, u = heapq.heappop(heap)
        if cost > dist[u]:
            continue
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                prev[v] = u
                heapq.heappush(heap, (dist[v], v))

    if dist[dst] == float('inf'):
        return float('inf'), []

    path: list[int] = []
    cur = dst
    while cur != -1:
        path.append(cur)
        cur = prev[cur]
    return dist[dst], path[::-1]
LeetCode problems
743 Network Delay Time · 1631 Path With Minimum Effort · 778 Swim in Rising Water · 1514 Path with Maximum Probability

Bellman-Ford

Handles negative edge weights. It relaxes every edge V-1 times, because a shortest path without cycles uses at most V-1 edges. If one more (V-th) pass still reduces some distance, then a negative cycle is reachable from the source.

def bellman_ford(n: int,
                 edges: list[tuple[int, int, int]],
                 src: int) -> tuple[list[float], bool]:
    """Single-source shortest path supporting negative weights.
    edges: list of (u, v, weight) — directed.
    Returns (dist[], has_negative_cycle).
    Time: O(V * E)  Space: O(V)
    """
    dist = [float('inf')] * n
    dist[src] = 0

    for _ in range(n - 1):              # relax all edges V-1 times
        updated = False
        for u, v, w in edges:
            if dist[u] != float('inf') and dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                updated = True
        if not updated:                  # early exit: already converged
            break

    # Check for negative cycle: one more pass still reduces?
    for u, v, w in edges:
        if dist[u] != float('inf') and dist[u] + w < dist[v]:
            return dist, True           # negative cycle reachable from src

    return dist, False
LeetCode problems
787 Cheapest Flights Within K Stops · 743 Network Delay Time (also solvable with Bellman-Ford) · 1334 Find the City With the Smallest Number of Neighbors
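Why V-1 rounds suffice can be written as a recurrence. Let \(d_i(v)\) denote the length of the shortest path from src to v that uses at most i edges:

\[
d_0(v) = \begin{cases} 0 & v = \text{src} \\ \infty & \text{otherwise} \end{cases}
\qquad
d_i(v) = \min\Big(d_{i-1}(v),\ \min_{(u, v, w) \in E} \big(d_{i-1}(u) + w\big)\Big)
\]

If no negative cycle is reachable from src, every shortest path is simple and has at most V-1 edges, so \(d_{V-1}\) is final and a further pass changes nothing. The copy-per-round loop used in the K-stops challenge computes \(d_i\) exactly. The in-place loop in bellman_ford can reuse updates from the current round, which can only make it converge sooner.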

Floyd-Warshall

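All-pairs shortest paths. Works with negative edge weights, but not with negative cycles. It is a DP over intermediate nodes: after processing node k, dist[i][j] holds the shortest i→j distance that uses only nodes 0..k as intermediates.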

def floyd_warshall(n: int,
                   edges: list[tuple[int, int, int]]
                   ) -> list[list[float]]:
    """All-pairs shortest paths.
    Time: O(V^3)  Space: O(V^2)
    Use when n is small (n <= 500 typically).
    """
    INF = float('inf')
    dist = [[INF] * n for _ in range(n)]
    for i in range(n):
        dist[i][i] = 0
    for u, v, w in edges:
        dist[u][v] = min(dist[u][v], w)  # directed; duplicate edges: take min

    for k in range(n):                   # k = intermediate node
        for i in range(n):
            for j in range(n):
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]

    # Negative cycle exists if any dist[i][i] < 0
    return dist
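The triple loop implements the following recurrence. Here \(d^{(k)}_{ij}\) is the shortest i→j distance whose intermediate nodes all come from \(\{0, \dots, k\}\), and \(d^{(-1)}_{ij}\) is the direct edge weight (0 on the diagonal, \(\infty\) where there is no edge):

\[
d^{(k)}_{ij} = \min\left(d^{(k-1)}_{ij},\; d^{(k-1)}_{ik} + d^{(k-1)}_{kj}\right)
\]

The k loop must be outermost because round k reads the results of round k-1. Updating one matrix in place is still safe when there are no negative cycles. In that case round k never changes row k or column k, because \(d_{ik} + d_{kk} = d_{ik}\) when \(d_{kk} = 0\).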

0-1 BFS

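When every edge weight is 0 or 1, a deque can replace the heap. Push a neighbor reached by a 0-weight edge to the front and a neighbor reached by a 1-weight edge to the back. The deque then stays ordered by distance, so the total cost is O(V + E). That is faster than Dijkstra's O(E log V) for this special case.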

from collections import deque

def bfs_01(grid: list[list[int]],
           src: tuple[int, int],
           dst: tuple[int, int]) -> int:
    """Minimum cost path in grid where moving to 0-cell costs 0,
    moving to 1-cell costs 1.
    Time: O(rows * cols)  Space: O(rows * cols)
    """
    rows, cols = len(grid), len(grid[0])
    dist = [[float('inf')] * cols for _ in range(rows)]
    sr, sc = src
    dist[sr][sc] = 0
    dq: deque[tuple[int, int]] = deque([(sr, sc)])

    while dq:
        r, c = dq.popleft()
        for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
            nr, nc = r + dr, c + dc
            if 0 <= nr < rows and 0 <= nc < cols:
                w = grid[nr][nc]          # edge weight: 0 or 1
                new_dist = dist[r][c] + w
                if new_dist < dist[nr][nc]:
                    dist[nr][nc] = new_dist
                    if w == 0:
                        dq.appendleft((nr, nc))   # cheaper: goes to front
                    else:
                        dq.append((nr, nc))       # costlier: goes to back

    dr, dc = dst
    return dist[dr][dc] if dist[dr][dc] != float('inf') else -1
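The grid version above hard-codes the edge weight as the target cell's value. The same deque technique on a general directed graph is sketched below. It assumes edges given as (u, v, w) with w in {0, 1} and nodes labeled 0..n-1, and the name zero_one_bfs is illustrative.

```python
from collections import deque


def zero_one_bfs(n, edges, src):
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))           # directed edge, w must be 0 or 1

    dist = [float('inf')] * n
    dist[src] = 0
    dq = deque([src])

    while dq:
        u = dq.popleft()
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
                if w == 0:
                    dq.appendleft(v)      # same distance as u: front
                else:
                    dq.append(v)          # one more than u: back
    return dist


edges = [(0, 1, 1), (0, 2, 0), (2, 1, 0), (1, 3, 1), (2, 3, 1)]
print(zero_one_bfs(4, edges, 0))  # [0, 0, 0, 1]
```

Node 1 is first reached through the direct 1-weight edge. It is then improved to 0 via the 0-weight path 0→2→1 and pushed to the front.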
Which shortest-path algorithm to use?

| Situation | Algorithm | Complexity |
|---|---|---|
| Non-negative weights, single source | Dijkstra (min-heap) | O(E log V) |
| Negative weights, single source | Bellman-Ford | O(VE) |
| All pairs, small n | Floyd-Warshall | \(O(V^3)\) |
| Unweighted graph | BFS | O(V + E) |
| 0/1 edge weights | 0-1 BFS (deque) | O(V + E) |
| DAG, any weights | Topo sort + relax | O(V + E) |
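The "DAG, any weights" row (topo sort + relax) has no code elsewhere in this section. The sketch below assumes a directed acyclic graph with nodes 0..n-1 and edges as (u, v, weight) tuples, and the name dag_shortest_paths is illustrative. It computes a topological order with Kahn's algorithm and then relaxes each node's outgoing edges once, in that order. Every predecessor of a node is finalized before the node itself, so negative weights are handled correctly.

```python
from collections import deque


def dag_shortest_paths(n, edges, src):
    graph = [[] for _ in range(n)]
    in_degree = [0] * n
    for u, v, w in edges:
        graph[u].append((v, w))
        in_degree[v] += 1

    # Kahn's algorithm: topological order
    order = []
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    while queue:
        u = queue.popleft()
        order.append(u)
        for v, _ in graph[u]:
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)
    if len(order) != n:
        raise ValueError("graph contains a cycle")

    # Relax edges in topological order: each edge is processed exactly once
    dist = [float('inf')] * n
    dist[src] = 0
    for u in order:
        if dist[u] == float('inf'):
            continue                      # not reachable from src
        for v, w in graph[u]:
            if dist[u] + w < dist[v]:
                dist[v] = dist[u] + w
    return dist


edges = [(0, 1, 5), (0, 2, 3), (2, 1, -4), (1, 3, 1), (2, 3, 6)]
print(dag_shortest_paths(4, edges, 0))  # [0, -1, 3, 0]
```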
Practice
Challenge

Dijkstra Simplified

Given a weighted undirected graph as a list of edges (u, v, w), find the shortest distance from node 0 to every other node. Fill in the Dijkstra implementation. Expected output: [0, 3, 1, 4]

import heapq

def dijkstra(edges, n):
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    dist = [float('inf')] * n
    dist[0] = 0
    heap = [(0, 0)]  # (cost, node)

    while heap:
        cost, node = heapq.heappop(heap)
        # __ Skip if we already found a better path
        if cost > dist[node]:
            continue
        for neighbor, weight in graph[node]:
            new_cost = cost + weight
            # __ Relax the edge if cheaper
            if new_cost < dist[neighbor]:
                dist[neighbor] = new_cost
                heapq.heappush(heap, (new_cost, neighbor))

    return dist

edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 1), (2, 3, 5)]
print(dijkstra(edges, 4))  # [0, 3, 1, 4]
import java.util.*;

public class Solution {
    public static int[] dijkstra(int[][] edges, int n) {
        List<int[]>[] graph = new List[n];
        for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
        for (int[] e : edges) {
            graph[e[0]].add(new int[]{e[1], e[2]});
            graph[e[1]].add(new int[]{e[0], e[2]});
        }

        int[] dist = new int[n];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[0] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        pq.offer(new int[]{0, 0}); // {cost, node}

        while (!pq.isEmpty()) {
            int[] curr = pq.poll();
            int cost = curr[0], node = curr[1];
            // Skip if we already found a better path
            if (cost > dist[node]) continue;
            for (int[] next : graph[node]) {
                int neighbor = next[0], weight = next[1];
                int newCost = cost + weight;
                // Relax the edge if cheaper
                if (newCost < dist[neighbor]) {
                    dist[neighbor] = newCost;
                    pq.offer(new int[]{newCost, neighbor});
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1,4},{0,2,1},{2,1,2},{1,3,1},{2,3,5}};
        System.out.println(Arrays.toString(dijkstra(edges, 4)));
        // [0, 3, 1, 4]
    }
}
Use a min-heap storing (cost, node). Pop the cheapest node, skip if its cost exceeds the recorded best. Relax edges to neighbors and push improved distances.
import heapq

def dijkstra(edges, n):
    graph = [[] for _ in range(n)]
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    dist = [float('inf')] * n
    dist[0] = 0
    heap = [(0, 0)]

    while heap:
        cost, node = heapq.heappop(heap)
        if cost > dist[node]:
            continue
        for neighbor, weight in graph[node]:
            new_cost = cost + weight
            if new_cost < dist[neighbor]:
                dist[neighbor] = new_cost
                heapq.heappush(heap, (new_cost, neighbor))

    return dist

edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 1), (2, 3, 5)]
print(dijkstra(edges, 4))  # [0, 3, 1, 4]
import java.util.*;

public class Solution {
    public static int[] dijkstra(int[][] edges, int n) {
        List<int[]>[] graph = new List[n];
        for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();
        for (int[] e : edges) {
            graph[e[0]].add(new int[]{e[1], e[2]});
            graph[e[1]].add(new int[]{e[0], e[2]});
        }

        int[] dist = new int[n];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[0] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        pq.offer(new int[]{0, 0});

        while (!pq.isEmpty()) {
            int[] curr = pq.poll();
            int cost = curr[0], node = curr[1];
            if (cost > dist[node]) continue;
            for (int[] next : graph[node]) {
                int neighbor = next[0], weight = next[1];
                int newCost = cost + weight;
                if (newCost < dist[neighbor]) {
                    dist[neighbor] = newCost;
                    pq.offer(new int[]{newCost, neighbor});
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1,4},{0,2,1},{2,1,2},{1,3,1},{2,3,5}};
        System.out.println(Arrays.toString(dijkstra(edges, 4)));
        // [0, 3, 1, 4]
    }
}
Challenge

Network Delay Time

A signal is sent from node k in a network of n nodes labeled 1 to n. Given weighted directed edges (u, v, w), find how long it takes for all nodes to receive the signal, which is the maximum of all shortest-path distances from k. Return -1 if any node is unreachable. Expected output: 3

import heapq

def network_delay(times, n, k):
    graph = [[] for _ in range(n + 1)]
    for u, v, w in times:
        graph[u].append((v, w))

    dist = [float('inf')] * (n + 1)
    dist[k] = 0
    heap = [(0, k)]

    while heap:
        cost, node = heapq.heappop(heap)
        if cost > dist[node]:
            continue
        for neighbor, weight in graph[node]:
            new_cost = cost + weight
            # __ Relax if cheaper
            if new_cost < dist[neighbor]:
                dist[neighbor] = new_cost
                heapq.heappush(heap, (new_cost, neighbor))

    # __ Check reachability and return answer
    max_dist = max(dist[1:])
    return max_dist if max_dist < float('inf') else -1

edges = [[1, 2, 1], [2, 3, 2], [1, 3, 4]]
print(network_delay(edges, 3, 1))  # 3
import java.util.*;

public class Solution {
    public static int networkDelay(int[][] times, int n, int k) {
        List<int[]>[] graph = new List[n + 1];
        for (int i = 0; i <= n; i++) graph[i] = new ArrayList<>();
        for (int[] t : times) graph[t[0]].add(new int[]{t[1], t[2]});

        int[] dist = new int[n + 1];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[k] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        pq.offer(new int[]{0, k});

        while (!pq.isEmpty()) {
            int[] curr = pq.poll();
            int cost = curr[0], node = curr[1];
            if (cost > dist[node]) continue;
            for (int[] next : graph[node]) {
                int newCost = cost + next[1];
                // Relax if cheaper
                if (newCost < dist[next[0]]) {
                    dist[next[0]] = newCost;
                    pq.offer(new int[]{newCost, next[0]});
                }
            }
        }

        // Check reachability and return answer
        int maxDist = 0;
        for (int i = 1; i <= n; i++) {
            if (dist[i] == Integer.MAX_VALUE) return -1;
            maxDist = Math.max(maxDist, dist[i]);
        }
        return maxDist;
    }

    public static void main(String[] args) {
        int[][] edges = {{1,2,1},{2,3,2},{1,3,4}};
        System.out.println(networkDelay(edges, 3, 1));  // 3
    }
}
Run Dijkstra from node k. The answer is max(distances). Return -1 if any distance is still infinity after the algorithm completes.
import heapq

def network_delay(times, n, k):
    graph = [[] for _ in range(n + 1)]
    for u, v, w in times:
        graph[u].append((v, w))

    dist = [float('inf')] * (n + 1)
    dist[k] = 0
    heap = [(0, k)]

    while heap:
        cost, node = heapq.heappop(heap)
        if cost > dist[node]:
            continue
        for neighbor, weight in graph[node]:
            new_cost = cost + weight
            if new_cost < dist[neighbor]:
                dist[neighbor] = new_cost
                heapq.heappush(heap, (new_cost, neighbor))

    max_dist = max(dist[1:])
    return max_dist if max_dist < float('inf') else -1

edges = [[1, 2, 1], [2, 3, 2], [1, 3, 4]]
print(network_delay(edges, 3, 1))  # 3
import java.util.*;

public class Solution {
    public static int networkDelay(int[][] times, int n, int k) {
        List<int[]>[] graph = new List[n + 1];
        for (int i = 0; i <= n; i++) graph[i] = new ArrayList<>();
        for (int[] t : times) graph[t[0]].add(new int[]{t[1], t[2]});

        int[] dist = new int[n + 1];
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[k] = 0;
        PriorityQueue<int[]> pq = new PriorityQueue<>(Comparator.comparingInt(a -> a[0]));
        pq.offer(new int[]{0, k});

        while (!pq.isEmpty()) {
            int[] curr = pq.poll();
            int cost = curr[0], node = curr[1];
            if (cost > dist[node]) continue;
            for (int[] next : graph[node]) {
                int newCost = cost + next[1];
                if (newCost < dist[next[0]]) {
                    dist[next[0]] = newCost;
                    pq.offer(new int[]{newCost, next[0]});
                }
            }
        }

        int maxDist = 0;
        for (int i = 1; i <= n; i++) {
            if (dist[i] == Integer.MAX_VALUE) return -1;
            maxDist = Math.max(maxDist, dist[i]);
        }
        return maxDist;
    }

    public static void main(String[] args) {
        int[][] edges = {{1,2,1},{2,3,2},{1,3,4}};
        System.out.println(networkDelay(edges, 3, 1));  // 3
    }
}
Challenge

Cheapest Flights with K Stops

Find the cheapest price from src to dst using at most k stops (intermediate cities), that is, at most k+1 flights. Use the Bellman-Ford approach and relax all edges exactly k+1 times. Expected output: 200

def cheapest(n, flights, src, dst, k):
    prices = [float('inf')] * n
    prices[src] = 0

    for i in range(k + 1):
        # __ Work on a copy so we don't reuse this round's updates
        temp = prices[:]
        for u, v, w in flights:
            if prices[u] != float('inf'):
                new_price = prices[u] + w
                # __ Relax if cheaper
                if new_price < temp[v]:
                    temp[v] = new_price
        prices = temp

    return prices[dst] if prices[dst] != float('inf') else -1

flights = [[0, 1, 100], [1, 2, 100], [0, 2, 500]]
print(cheapest(3, flights, 0, 2, 1))  # 200
import java.util.*;

public class Solution {
    public static int cheapest(int n, int[][] flights, int src, int dst, int k) {
        int[] prices = new int[n];
        Arrays.fill(prices, Integer.MAX_VALUE);
        prices[src] = 0;

        for (int i = 0; i <= k; i++) {
            // Work on a copy so we don't reuse this round's updates
            int[] temp = Arrays.copyOf(prices, n);
            for (int[] flight : flights) {
                int u = flight[0], v = flight[1], w = flight[2];
                if (prices[u] != Integer.MAX_VALUE) {
                    int newPrice = prices[u] + w;
                    // Relax if cheaper
                    if (newPrice < temp[v]) temp[v] = newPrice;
                }
            }
            prices = temp;
        }

        return prices[dst] == Integer.MAX_VALUE ? -1 : prices[dst];
    }

    public static void main(String[] args) {
        int[][] flights = {{0,1,100},{1,2,100},{0,2,500}};
        System.out.println(cheapest(3, flights, 0, 2, 1));  // 200
    }
}
Bellman-Ford with a twist: run exactly k+1 relaxation rounds. Each round represents one more edge (hop) used. Copy the prices array before each round so updates from the current round don't bleed into same-round relaxations.
def cheapest(n, flights, src, dst, k):
    prices = [float('inf')] * n
    prices[src] = 0

    for i in range(k + 1):
        temp = prices[:]
        for u, v, w in flights:
            if prices[u] != float('inf'):
                new_price = prices[u] + w
                if new_price < temp[v]:
                    temp[v] = new_price
        prices = temp

    return prices[dst] if prices[dst] != float('inf') else -1

flights = [[0, 1, 100], [1, 2, 100], [0, 2, 500]]
print(cheapest(3, flights, 0, 2, 1))  # 200
import java.util.*;

public class Solution {
    public static int cheapest(int n, int[][] flights, int src, int dst, int k) {
        int[] prices = new int[n];
        Arrays.fill(prices, Integer.MAX_VALUE);
        prices[src] = 0;

        for (int i = 0; i <= k; i++) {
            int[] temp = Arrays.copyOf(prices, n);
            for (int[] flight : flights) {
                int u = flight[0], v = flight[1], w = flight[2];
                if (prices[u] != Integer.MAX_VALUE) {
                    int newPrice = prices[u] + w;
                    if (newPrice < temp[v]) temp[v] = newPrice;
                }
            }
            prices = temp;
        }

        return prices[dst] == Integer.MAX_VALUE ? -1 : prices[dst];
    }

    public static void main(String[] args) {
        int[][] flights = {{0,1,100},{1,2,100},{0,2,500}};
        System.out.println(cheapest(3, flights, 0, 2, 1));  // 200
    }
}
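Comparing against a variant that skips the per-round copy shows why the copy matters. Both functions below are illustrative sketches, and cheapest_in_place is deliberately wrong. With k = 0 only the direct flight is allowed. In the in-place version, the update to node 1 feeds the relaxation of edge 1→2 in the same round, so a two-flight route slips through. The exact wrong answer depends on the order in which flights are listed.

```python
def cheapest_in_place(n, flights, src, dst, k):
    """Buggy on purpose: updates within a round are reused immediately."""
    prices = [float('inf')] * n
    prices[src] = 0
    for _ in range(k + 1):
        for u, v, w in flights:
            if prices[u] != float('inf') and prices[u] + w < prices[v]:
                prices[v] = prices[u] + w
    return prices[dst] if prices[dst] != float('inf') else -1


def cheapest_with_copy(n, flights, src, dst, k):
    """Correct: each round reads only the previous round's prices."""
    prices = [float('inf')] * n
    prices[src] = 0
    for _ in range(k + 1):
        temp = prices[:]
        for u, v, w in flights:
            if prices[u] != float('inf') and prices[u] + w < temp[v]:
                temp[v] = prices[u] + w
        prices = temp
    return prices[dst] if prices[dst] != float('inf') else -1


flights = [[0, 1, 100], [1, 2, 100], [0, 2, 500]]
print(cheapest_in_place(3, flights, 0, 2, 0))   # 200 (wrong: uses 2 flights)
print(cheapest_with_copy(3, flights, 0, 2, 0))  # 500 (direct flight only)
```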

12. Advanced Graphs

Union-Find (Disjoint Set Union)

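An efficient data structure for tracking connected components. It relies on two key optimizations. Path compression makes the nodes visited by find point at the root. Union by rank attaches the root of the lower-rank tree under the root of the higher-rank tree. Together they give O(α(n)) amortized time per operation, where α is the inverse Ackermann function, which is at most 4 for any input that fits in memory.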

class UnionFind:
    """Union-Find with path compression and union by rank.
    Time: O(alpha(n)) per operation (effectively O(1)).
    Space: O(n)
    """
    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n
        self.components = n           # number of distinct components

    def find(self, x: int) -> int:
        """Find root with path compression."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compress
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        """Unite the sets containing x and y.
        Returns True if they were in different sets (new connection made).
        """
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False              # already connected
        # Attach smaller rank tree under larger rank root
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        self.components -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)


# Example: count connected components in undirected graph
def count_components(n: int, edges: list[list[int]]) -> int:
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    return uf.components
LeetCode problems
200 Number of Islands · 684 Redundant Connection · 323 Number of Connected Components · 1584 Min Cost to Connect All Points · 721 Accounts Merge
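The boolean returned by union is what makes Union-Find handy for cycle detection in undirected graphs. If both endpoints already share a root, the edge closes a cycle. The sketch below follows the idea behind 684 Redundant Connection but, as an assumption, uses 0-indexed nodes. It uses union by size and path halving, and find_redundant_edge is an illustrative name.

```python
def find_redundant_edge(n, edges):
    """Return the first edge whose endpoints are already connected, else None."""
    parent = list(range(n))
    size = [1] * n

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:
            return (u, v)                  # this edge would close a cycle
        if size[ru] < size[rv]:
            ru, rv = rv, ru
        parent[rv] = ru                    # attach smaller tree under larger
        size[ru] += size[rv]
    return None


print(find_redundant_edge(3, [(0, 1), (0, 2), (1, 2)]))  # (1, 2)
```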

Minimum Spanning Tree

A minimum spanning tree (MST) of a connected, undirected, weighted graph connects all V nodes using exactly V-1 edges with the minimum possible total weight.

Kruskal's — sort edges, add if no cycle (Union-Find)

def kruskal_mst(n: int,
                edges: list[tuple[int, int, int]]) -> tuple[int, list]:
    """Kruskal's MST algorithm.
    edges: list of (weight, u, v).
    Returns (total_cost, mst_edges).
    Time: O(E log E)  Space: O(V + E) (sorted() copies the edge list)
    """
    uf = UnionFind(n)
    mst_cost = 0
    mst_edges: list[tuple[int, int, int]] = []
    for w, u, v in sorted(edges):       # sort by weight ascending
        if uf.union(u, v):              # no cycle: add edge
            mst_cost += w
            mst_edges.append((w, u, v))
            if len(mst_edges) == n - 1:
                break                   # MST complete
    return mst_cost, mst_edges

Prim's — grow MST greedily from a start node (min-heap)

def prim_mst(n: int,
             edges: list[tuple[int, int, int]]) -> int:
    """Prim's MST algorithm.
    edges: list of (u, v, weight) — undirected.
    Returns total MST cost, or -1 if graph is disconnected.
    Time: O(E log V)  Space: O(V + E)
    """
    graph: dict[int, list[tuple[int, int]]] = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((v, w))
        graph[v].append((u, w))

    visited: set[int] = set()
    heap = [(0, 0)]                     # (cost, node), start at node 0
    total = 0

    while heap and len(visited) < n:
        cost, u = heapq.heappop(heap)
        if u in visited:
            continue
        visited.add(u)
        total += cost
        for v, w in graph[u]:
            if v not in visited:
                heapq.heappush(heap, (w, v))

    return total if len(visited) == n else -1

Bipartite Checking (2-Color BFS)

def is_bipartite(graph: list[list[int]]) -> bool:
    """Check if graph is 2-colorable (no odd-length cycles).
    graph: adjacency list indexed 0..n-1.
    Time: O(V + E)  Space: O(V)
    """
    n = len(graph)
    color = [-1] * n

    for start in range(n):
        if color[start] != -1:
            continue
        color[start] = 0
        queue: deque[int] = deque([start])
        while queue:
            node = queue.popleft()
            for neighbor in graph[node]:
                if color[neighbor] == -1:
                    color[neighbor] = 1 - color[node]
                    queue.append(neighbor)
                elif color[neighbor] == color[node]:
                    return False        # same color: odd cycle found
    return True

Strongly Connected Components (Kosaraju's Concept)

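An SCC is a maximal set of nodes in a directed graph in which every node is reachable from every other node. Kosaraju's algorithm runs two DFS passes. The first pass runs on the original graph and records nodes in order of finish time. The second pass runs on the transposed (edge-reversed) graph, starting from nodes in reverse finish order. Each DFS tree produced in the second pass is exactly one SCC.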

def kosaraju_scc(n: int,
                 edges: list[tuple[int, int]]) -> list[list[int]]:
    """Find all SCCs using Kosaraju's algorithm.
    Time: O(V + E)  Space: O(V + E)
    """
    graph: dict[int, list[int]] = defaultdict(list)
    rev_graph: dict[int, list[int]] = defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
        rev_graph[v].append(u)

    # Pass 1: DFS on original graph, record finish order
    visited: set[int] = set()
    finish_order: list[int] = []

    def dfs1(node: int) -> None:
        visited.add(node)
        for nb in graph[node]:
            if nb not in visited:
                dfs1(nb)
        finish_order.append(node)       # append on finish

    for v in range(n):
        if v not in visited:
            dfs1(v)

    # Pass 2: DFS on reversed graph in reverse finish order
    visited.clear()
    sccs: list[list[int]] = []

    def dfs2(node: int, component: list[int]) -> None:
        visited.add(node)
        component.append(node)
        for nb in rev_graph[node]:
            if nb not in visited:
                dfs2(nb, component)

    for v in reversed(finish_order):
        if v not in visited:
            component: list[int] = []
            dfs2(v, component)
            sccs.append(component)

    return sccs
Network Flow (concept only)
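Max flow / min cut: given a directed network with edge capacities, find the maximum flow from source to sink. Ford-Fulkerson repeatedly finds an augmenting path in the residual graph and pushes the path's bottleneck capacity along it. Finding the shortest augmenting path with BFS gives Edmonds-Karp, which runs in \(O(VE^2)\). The max-flow min-cut theorem states that the maximum flow value equals the capacity of the minimum s-t cut. Max flow is rarely tested directly in FAANG interviews, but the idea appears in matching problems, since maximum bipartite matching reduces to max flow with unit capacities.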
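Interviews rarely ask for an implementation, but the Edmonds-Karp description above is sketched below for reference. It assumes nodes 0..n-1, edges as (u, v, capacity), and a dense residual-capacity matrix, which is simple but uses O(V^2) memory. The name edmonds_karp is illustrative.

```python
from collections import deque


def edmonds_karp(n, edges, source, sink):
    if source == sink:
        raise ValueError("source and sink must differ")
    cap = [[0] * n for _ in range(n)]      # residual capacities
    adj = [set() for _ in range(n)]        # neighbors in the residual graph
    for u, v, c in edges:
        cap[u][v] += c                     # parallel edges: capacities add
        adj[u].add(v)
        adj[v].add(u)                      # reverse edge, capacity starts at 0

    flow = 0
    while True:
        # BFS for the shortest augmenting path in the residual graph
        parent = [-1] * n
        parent[source] = source
        queue = deque([source])
        while queue and parent[sink] == -1:
            u = queue.popleft()
            for v in adj[u]:
                if parent[v] == -1 and cap[u][v] > 0:
                    parent[v] = u
                    queue.append(v)
        if parent[sink] == -1:
            return flow                    # no augmenting path remains

        # Bottleneck: smallest residual capacity along the path
        bottleneck = float('inf')
        v = sink
        while v != source:
            bottleneck = min(bottleneck, cap[parent[v]][v])
            v = parent[v]

        # Push flow forward and add the same amount of reverse capacity
        v = sink
        while v != source:
            u = parent[v]
            cap[u][v] -= bottleneck
            cap[v][u] += bottleneck
            v = u
        flow += bottleneck


edges = [(0, 1, 3), (0, 2, 2), (1, 2, 1), (1, 3, 2), (2, 3, 3)]
print(edmonds_karp(4, edges, 0, 3))  # 5
```

For bipartite matching, connect the source to every left node and every right node to the sink with capacity 1. Give each allowed pair a capacity-1 edge. The max flow is then the size of a maximum matching.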
Practice
Challenge

Union-Find: Connected Components

Count the number of connected components in an undirected graph using Union-Find. Expected output: 2

def count_components(n, edges):
    parent = list(range(n))

    def find(x):
        # __ Path halving (a form of path compression): point each visited node at its grandparent
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        # __ Merge: point one root to the other
        parent[ra] = rb

    for u, v in edges:
        union(u, v)

    # __ Count distinct roots
    return sum(1 for i in range(n) if find(i) == i)

print(count_components(5, [[0, 1], [1, 2], [3, 4]]))  # 2
public class Solution {
    static int[] parent;

    static int find(int x) {
        // Path halving (a form of path compression): point each visited node at its grandparent
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    static void union(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        // Merge: point one root to the other
        parent[ra] = rb;
    }

    public static int countComponents(int n, int[][] edges) {
        parent = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;
        for (int[] e : edges) union(e[0], e[1]);

        // Count distinct roots
        int count = 0;
        for (int i = 0; i < n; i++) if (find(i) == i) count++;
        return count;
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1},{1,2},{3,4}};
        System.out.println(countComponents(5, edges));  // 2
    }
}
Initialize parent[i] = i. find() follows parent pointers to the root (with path compression). union() merges two sets by connecting their roots. At the end, count nodes where find(i) == i (they are roots of distinct components).
def count_components(n, edges):
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        parent[ra] = rb

    for u, v in edges:
        union(u, v)

    return sum(1 for i in range(n) if find(i) == i)

print(count_components(5, [[0, 1], [1, 2], [3, 4]]))  # 2
public class Solution {
    static int[] parent;

    static int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    static void union(int a, int b) {
        int ra = find(a), rb = find(b);
        if (ra == rb) return;
        parent[ra] = rb;
    }

    public static int countComponents(int n, int[][] edges) {
        parent = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;
        for (int[] e : edges) union(e[0], e[1]);

        int count = 0;
        for (int i = 0; i < n; i++) if (find(i) == i) count++;
        return count;
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1},{1,2},{3,4}};
        System.out.println(countComponents(5, edges));  // 2
    }
}
Challenge

Topological Sort

Return a topological ordering of a directed acyclic graph using Kahn's algorithm, or an empty list if the graph contains a cycle. A DAG can have several valid orderings, so the test checks len(result) == n rather than a specific order. Expected output: 4 (length of valid ordering)

from collections import deque

def topo_sort(n, edges):
    graph = [[] for _ in range(n)]
    in_degree = [0] * n

    for u, v in edges:
        graph[u].append(v)
        # __ Increment in-degree of destination
        in_degree[v] += 1

    # __ Start with all nodes that have no incoming edges
    queue = deque(i for i in range(n) if in_degree[i] == 0)
    result = []

    while queue:
        node = queue.popleft()
        result.append(node)
        for neighbor in graph[node]:
            in_degree[neighbor] -= 1
            # __ Add neighbor if it now has no dependencies
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    # __ Cycle check: valid only if all nodes are in result
    return result if len(result) == n else []

result = topo_sort(4, [[0, 1], [0, 2], [1, 3], [2, 3]])
print(len(result))  # 4 (valid ordering exists)
print(topo_sort(3, [[0, 1], [1, 2], [2, 0]]))  # [] (cycle)
import java.util.*;

public class Solution {
    public static List<Integer> topoSort(int n, int[][] edges) {
        List<Integer>[] graph = new List[n];
        int[] inDegree = new int[n];
        for (int i = 0; i < n; i++) graph[i] = new ArrayList<>();

        for (int[] e : edges) {
            graph[e[0]].add(e[1]);
            // Increment in-degree of destination
            inDegree[e[1]]++;
        }

        // Start with all nodes that have no incoming edges
        Queue<Integer> queue = new LinkedList<>();
        for (int i = 0; i < n; i++) if (inDegree[i] == 0) queue.offer(i);

        List<Integer> result = new ArrayList<>();
        while (!queue.isEmpty()) {
            int node = queue.poll();
            result.add(node);
            for (int neighbor : graph[node]) {
                inDegree[neighbor]--;
                // Add neighbor if it now has no dependencies
                if (inDegree[neighbor] == 0) queue.offer(neighbor);
            }
        }

        // Cycle check
        return result.size() == n ? result : new ArrayList<>();
    }

    public static void main(String[] args) {
        int[][] edges = {{0,1},{0,2},{1,3},{2,3}};
        List<Integer> result = topoSort(4, edges);
        System.out.println(result.size());  // 4

        int[][] cycle = {{0,1},{1,2},{2,0}};
        System.out.println(topoSort(3, cycle));  // []
    }
}
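Kahn's algorithm: compute in-degrees and seed the queue with every zero-in-degree node. Then repeatedly dequeue a node, append it to the result, and decrement the in-degree of each of its neighbors, enqueuing any neighbor whose in-degree reaches zero. A valid topological order exists only if all n nodes are processed. Nodes on a cycle, or reachable from one, never reach in-degree zero, so a short result means the graph has a cycle. The whole pass runs in O(n + E) time for E edges.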
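A DAG often has more than one valid ordering. For the example graph, both [0, 1, 2, 3] and [0, 2, 1, 3] work, which is why the test above only compares lengths. A stricter check is sketched below. `is_valid_topo_order` is a hypothetical helper, not part of the original solution. It confirms that the proposed order is a permutation of 0..n-1 and that every edge u → v places u before v.

```python
def is_valid_topo_order(n: int, edges: list[list[int]], order: list[int]) -> bool:
    """Hypothetical checker: True if `order` is a valid topological order."""
    # Every node must appear exactly once.
    if sorted(order) != list(range(n)):
        return False
    # Record where each node sits in the proposed order.
    position = {node: idx for idx, node in enumerate(order)}
    # Each edge u -> v must point "forward" in the order.
    return all(position[u] < position[v] for u, v in edges)


edges = [[0, 1], [0, 2], [1, 3], [2, 3]]
print(is_valid_topo_order(4, edges, [0, 1, 2, 3]))  # True
print(is_valid_topo_order(4, edges, [0, 2, 1, 3]))  # True
print(is_valid_topo_order(4, edges, [1, 0, 2, 3]))  # False: 1 placed before 0
```

An empty list returned for a cyclic graph fails the permutation check whenever n > 0, so the same helper also rejects the cycle case.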
Challenge

Redundant Connection

Given the edges of an undirected graph, added one at a time, find the first edge that creates a cycle and return it. The graph is a tree on nodes 1..n plus one extra edge, so n equals the number of edges. Expected output: [2, 3]

def find_redundant(edges):
    n = len(edges)
    parent = list(range(n + 1))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for u, v in edges:
        ru, rv = find(u), find(v)
        # If same root, adding this edge creates a cycle
        if ru == rv:
            return [u, v]
        # Otherwise, union them
        parent[ru] = rv

    return []

print(find_redundant([[1, 2], [1, 3], [2, 3]]))  # [2, 3]
public class Solution {
    static int[] parent;

    static int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    public static int[] findRedundant(int[][] edges) {
        int n = edges.length;
        parent = new int[n + 1];
        for (int i = 0; i <= n; i++) parent[i] = i;

        for (int[] e : edges) {
            int ru = find(e[0]), rv = find(e[1]);
            // If same root, adding this edge creates a cycle
            if (ru == rv) return e;
            // Otherwise, union them
            parent[ru] = rv;
        }
        return new int[]{};
    }

    public static void main(String[] args) {
        int[][] edges = {{1,2},{1,3},{2,3}};
        int[] result = findRedundant(edges);
        System.out.println("[" + result[0] + ", " + result[1] + "]");  // [2, 3]
    }
}
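Process edges in order with Union-Find. For each edge, call find() on both endpoints. If they already share a root, the endpoints are already connected, so this edge would close a cycle; return it immediately. Otherwise, union the two sets and continue. find() uses path halving (parent[x] = parent[parent[x]]) to keep the trees shallow.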
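The version above always attaches ru under rv, so trees can still grow tall on unlucky inputs. A common refinement adds union by size alongside path halving. With both, find() runs in near-constant amortized time (inverse Ackermann). The sketch below keeps the assumption that nodes are labeled 1..n. The `DSU` class name and its `union` method returning a bool are choices made for this illustration, not a library API.

```python
class DSU:
    """Disjoint-set union with path halving and union by size."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a: int, b: int) -> bool:
        """Merge the sets containing a and b; return False if already joined."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra                 # attach the smaller tree under the larger
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True


def find_redundant(edges: list[list[int]]) -> list[int]:
    dsu = DSU(len(edges) + 1)               # index 0 unused; nodes are 1..n
    for u, v in edges:
        if not dsu.union(u, v):             # endpoints already connected
            return [u, v]
    return []


print(find_redundant([[1, 2], [1, 3], [2, 3]]))  # [2, 3]
```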

13. Dynamic Programming Patterns

DP applies when a problem has optimal substructure (an optimal solution is built from optimal solutions to its subproblems) and overlapping subproblems (the same subproblems are solved repeatedly).

DP design framework
1. State: What information uniquely identifies a subproblem?
2. Recurrence: How does dp[state] relate to smaller states?
3. Base case: What are the smallest/trivially solved states?
4. Order: Ensure each state is computed before it is needed.
5. Answer: Which state holds the final answer?
6. Optimize: Can you reduce space with a rolling array?
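To show how the six steps map to code, here is a sketch of climbing stairs written top-down with memoization, using `functools.lru_cache`. The state is the stair index i, and the recurrence is ways(i) = ways(i-1) + ways(i-2). The base cases are ways(1) = 1 and ways(2) = 2, and the answer is ways(n). Memoized recursion sidesteps step 4 (order) because each state is computed the first time it is requested. The cost is O(n) recursion depth, which can exceed Python's default recursion limit for large n. The name `climb_stairs_memo` is illustrative.

```python
from functools import lru_cache


def climb_stairs_memo(n: int) -> int:
    """Top-down DP: count ways to climb n stairs taking 1 or 2 steps."""

    @lru_cache(maxsize=None)
    def ways(i: int) -> int:
        # Base cases (step 3): smallest states solved directly.
        if i <= 2:
            return i
        # Recurrence (step 2): last move was a 1-step or a 2-step.
        return ways(i - 1) + ways(i - 2)

    # Answer (step 5): the state for the top stair.
    return ways(n)


print(climb_stairs_memo(5))   # 8
print(climb_stairs_memo(30))  # 1346269
```

Step 6 turns this into the bottom-up `climb_stairs` below, which keeps only the last two values.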

1D DP Classics

def climb_stairs(n: int) -> int:
    """Count ways to climb n stairs (1 or 2 steps at a time).
    State: dp[i] = number of ways to reach step i.
    Recurrence: dp[i] = dp[i-1] + dp[i-2]
    Time: O(n)  Space: O(1) after rolling optimization
    """
    if n <= 2:
        return n
    a, b = 1, 2
    for _ in range(3, n + 1):
        a, b = b, a + b
    return b


def house_robber(nums: list[int]) -> int:
    """Max sum with no two adjacent elements selected.
    State: dp[i] = max profit using houses 0..i.
    Recurrence: dp[i] = max(dp[i-1], dp[i-2] + nums[i])
    Time: O(n)  Space: O(1)
    """
    prev2, prev1 = 0, 0
    for x in nums:
        prev2, prev1 = prev1, max(prev1, prev2 + x)
    return prev1


def coin_change(coins: list[int], amount: int) -> int:
    """Minimum coins to make amount. -1 if impossible.
    State: dp[i] = fewest coins to make amount i.
    Recurrence: dp[i] = min(dp[i - c] + 1) for c in coins if i >= c
    Time: O(amount * len(coins))  Space: O(amount)
    """
    dp = [float('inf')] * (amount + 1)
    dp[0] = 0
    for i in range(1, amount + 1):
        for c in coins:
            if i >= c:
                dp[i] = min(dp[i], dp[i - c] + 1)
    return dp[amount] if dp[amount] != float('inf') else -1


def word_break(s: str, word_dict: list[str]) -> bool:
    """Can s be segmented into words from word_dict?
    State: dp[i] = True if s[:i] can be segmented.
    Time: O(n^2) set lookups; O(n^3) counting substring slicing  Space: O(n)
    """
    words = set(word_dict)
    n = len(s)
    dp = [False] * (n + 1)
    dp[0] = True                        # empty string always valid
    for i in range(1, n + 1):
        for j in range(i):
            if dp[j] and s[j:i] in words:
                dp[i] = True
                break
    return dp[n]

2D DP Classics

def unique_paths(m: int, n: int) -> int:
    """Count paths from top-left to bottom-right (only right/down).
    Time: O(m*n)  Space: O(n) rolling row
    """
    dp = [1] * n
    for _ in range(1, m):
        for j in range(1, n):
            dp[j] += dp[j - 1]
    return dp[n - 1]


def edit_distance(word1: str, word2: str) -> int:
    """Levenshtein distance: minimum insert/delete/replace to convert word1 -> word2.
    State: dp[i][j] = edit distance between word1[:i] and word2[:j].
    Time: O(m*n)  Space: O(m*n) — reducible to O(n)
    """
    m, n = len(word1), len(word2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i                    # delete all of word1
    for j in range(n + 1):
        dp[0][j] = j                    # insert all of word2

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if word1[i-1] == word2[j-1]:
                dp[i][j] = dp[i-1][j-1]
            else:
                dp[i][j] = 1 + min(
                    dp[i-1][j],         # delete from word1
                    dp[i][j-1],         # insert into word1
                    dp[i-1][j-1]        # replace
                )
    return dp[m][n]


def longest_common_subsequence(text1: str, text2: str) -> int:
    """LCS length. Not necessarily contiguous.
    Time: O(m*n)  Space: O(m*n)
    """
    m, n = len(text1), len(text2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i-1] == text2[j-1]:
                dp[i][j] = dp[i-1][j-1] + 1
            else:
                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
    return dp[m][n]

Knapsack Patterns

def knapsack_01(weights: list[int],
                values: list[int],
                capacity: int) -> int:
    """0/1 Knapsack: each item used at most once.
    State: dp[w] = max value with weight budget w.
    Iterate capacities in reverse so dp[c - w] still reflects earlier items only (no item used twice).
    Time: O(n * capacity)  Space: O(capacity)
    """
    dp = [0] * (capacity + 1)
    for w, v in zip(weights, values):
        for c in range(capacity, w - 1, -1):  # reverse!
            dp[c] = max(dp[c], dp[c - w] + v)
    return dp[capacity]


def knapsack_unbounded(weights: list[int],
                       values: list[int],
                       capacity: int) -> int:
    """Unbounded Knapsack: each item usable unlimited times.
    Capacities are filled in increasing order, so dp[c - w] may already include this item (reuse allowed).
    Time: O(n * capacity)  Space: O(capacity)
    """
    dp = [0] * (capacity + 1)
    for c in range(1, capacity + 1):
        for w, v in zip(weights, values):
            if w <= c:
                dp[c] = max(dp[c], dp[c - w] + v)
    return dp[capacity]


def subset_sum(nums: list[int], target: int) -> bool:
    """Can any subset sum to target? (0/1 knapsack variant)
    Time: O(n * target)  Space: O(target)
    """
    dp = [False] * (target + 1)
    dp[0] = True
    for x in nums:
        for j in range(target, x - 1, -1):
            dp[j] = dp[j] or dp[j - x]
    return dp[target]

Longest Increasing Subsequence

import bisect

def lis_n2(nums: list[int]) -> int:
    """O(n^2) LIS — easier to understand and extend."""
    if not nums:
        return 0
    n = len(nums)
    dp = [1] * n                        # dp[i] = LIS ending at index i
    for i in range(1, n):
        for j in range(i):
            if nums[j] < nums[i]:
                dp[i] = max(dp[i], dp[j] + 1)
    return max(dp)


def lis_nlogn(nums: list[int]) -> int:
    """O(n log n) LIS using patience sort (binary search).
    'tails' stores the smallest tail of all increasing subsequences
    of each length. tails is always sorted, but it is not itself an LIS;
    only len(tails) is meaningful.
    """
    tails: list[int] = []
    for x in nums:
        pos = bisect.bisect_left(tails, x)
        if pos == len(tails):
            tails.append(x)             # extend longest subsequence
        else:
            tails[pos] = x              # replace: keeps tails minimal
    return len(tails)


def lis_with_reconstruction(nums: list[int]) -> list[int]:
    """Return the actual LIS sequence (O(n^2) for simplicity)."""
    if not nums:
        return []
    n = len(nums)
    dp = [1] * n
    parent = [-1] * n
    best_end = 0

    for i in range(1, n):
        for j in range(i):
            if nums[j] < nums[i] and dp[j] + 1 > dp[i]:
                dp[i] = dp[j] + 1
                parent[i] = j
        if dp[i] > dp[best_end]:
            best_end = i

    path: list[int] = []
    idx = best_end
    while idx != -1:
        path.append(nums[idx])
        idx = parent[idx]
    return path[::-1]
LeetCode problems
322 Coin Change · 300 Longest Increasing Subsequence · 1143 Longest Common Subsequence · 72 Edit Distance · 416 Partition Equal Subset Sum · 139 Word Break · 198 House Robber · 152 Maximum Product Subarray

DP on Intervals

def burst_balloons(nums: list[int]) -> int:
    """Classic interval DP. Add sentinels 1 at both ends.
    dp[i][j] = max coins from bursting all balloons between i and j (exclusive).
    Key insight: think of 'k' as the LAST balloon to burst in interval (i,j).
    Time: O(n^3)  Space: O(n^2)
    """
    balloons = [1] + nums + [1]
    n = len(balloons)
    dp = [[0] * n for _ in range(n)]

    for length in range(2, n):          # interval length
        for left in range(0, n - length):
            right = left + length
            for k in range(left + 1, right):   # last to burst
                coins = (balloons[left] * balloons[k] * balloons[right]
                         + dp[left][k] + dp[k][right])
                dp[left][right] = max(dp[left][right], coins)

    return dp[0][n - 1]
Space optimization: rolling array technique
# Many 2D DP problems only look at the previous row.
# Replace the m x n table with two 1D arrays.

def edit_distance_space_optimized(word1: str, word2: str) -> int:
    """O(n) space instead of O(m*n)."""
    m, n = len(word1), len(word2)
    prev = list(range(n + 1))           # base case: dp[0][j] = j

    for i in range(1, m + 1):
        curr = [i] + [0] * n            # base case: dp[i][0] = i
        for j in range(1, n + 1):
            if word1[i-1] == word2[j-1]:
                curr[j] = prev[j-1]
            else:
                curr[j] = 1 + min(prev[j], curr[j-1], prev[j-1])
        prev = curr

    return prev[n]
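The same rolling-row idea applies to `longest_common_subsequence` above, whose recurrence also reads only row i-1 and the current row. Here is a sketch that cuts its space from O(m*n) to O(n). The name `lcs_space_optimized` is illustrative.

```python
def lcs_space_optimized(text1: str, text2: str) -> int:
    """LCS length using two rows: O(m*n) time, O(n) space."""
    n = len(text2)
    prev = [0] * (n + 1)                    # row i-1; row 0 is all zeros
    for i in range(1, len(text1) + 1):
        curr = [0] * (n + 1)                # curr[0] = 0: empty text2 prefix
        for j in range(1, n + 1):
            if text1[i - 1] == text2[j - 1]:
                curr[j] = prev[j - 1] + 1   # extend the diagonal match
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev = curr
    return prev[n]


print(lcs_space_optimized("abcde", "ace"))  # 3
```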
Practice
Challenge

Climbing Stairs

Count the number of distinct ways to climb n stairs, taking 1 or 2 steps at a time. Expected output for n = 5: 8

def climb_stairs(n):
    if n <= 2:
        return n
    # Use two variables instead of full dp array
    prev2, prev1 = 1, 2
    for i in range(3, n + 1):
        # Current = sum of previous two (like Fibonacci)
        curr = prev1 + prev2
        prev2 = prev1
        prev1 = curr
    return prev1

print(climb_stairs(5))  # 8
public class Solution {
    public static int climbStairs(int n) {
        if (n <= 2) return n;
        // Use two variables instead of full dp array
        int prev2 = 1, prev1 = 2;
        for (int i = 3; i <= n; i++) {
            // Current = sum of previous two (like Fibonacci)
            int curr = prev1 + prev2;
            prev2 = prev1;
            prev1 = curr;
        }
        return prev1;
    }

    public static void main(String[] args) {
        System.out.println(climbStairs(5));  // 8
    }
}
dp[i] = dp[i-1] + dp[i-2]: the last move onto stair i is either a single step from stair i-1 or a double step from stair i-2. Base cases: dp[1] = 1, dp[2] = 2. Each state needs only the previous two, so two rolling variables reduce space to O(1).
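Before the rolling optimization, the same recurrence can be kept in a full table, which is handy when you want to inspect intermediate counts. Here is a sketch with the same base cases. The name `climb_stairs_table` is illustrative, and dp[0] is left at 0 as an unused slot.

```python
def climb_stairs_table(n: int) -> list[int]:
    """Return dp where dp[i] = ways to reach stair i (dp[0] unused)."""
    dp = [0] * (n + 1)
    if n >= 1:
        dp[1] = 1
    if n >= 2:
        dp[2] = 2
    for i in range(3, n + 1):
        dp[i] = dp[i - 1] + dp[i - 2]       # arrive from i-1 or from i-2
    return dp


table = climb_stairs_table(5)
print(table)     # [0, 1, 2, 3, 5, 8]
print(table[5])  # 8
```

Only dp[i-1] and dp[i-2] are read inside the loop, which is exactly why the table can shrink to two variables.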
Challenge

Coin Change

Given coin denominations and a target amount, find the minimum number of coins needed. Return -1 if not possible. Expected output for coins [1, 5, 10, 25] and amount 30: 2

def coin_change(coins, amount):
    # dp[i] = min coins to make amount i
    dp = [float('inf')] * (amount + 1)
    dp[0] = 0

    for i in range(1, amount + 1):
        for coin in coins:
            if coin <= i:
                # Use this coin + whatever was needed for the remainder
                dp[i] = min(dp[i], dp[i - coin] + 1)

    return dp[amount] if dp[amount] != float('inf') else -1

print(coin_change([1, 5, 10, 25], 30))  # 2 (25 + 5)
import java.util.*;

public class Solution {
    public static int coinChange(int[] coins, int amount) {
        // dp[i] = min coins to make amount i
        int[] dp = new int[amount + 1];
        Arrays.fill(dp, amount + 1);  // sentinel: larger than any valid answer
        dp[0] = 0;

        for (int i = 1; i <= amount; i++) {
            for (int coin : coins) {
                if (coin <= i) {
                    // Use this coin + whatever was needed for the remainder
                    dp[i] = Math.min(dp[i], dp[i - coin] + 1);
                }
            }
        }

        return dp[amount] > amount ? -1 : dp[amount];
    }

    public static void main(String[] args) {
        System.out.println(coinChange(new int[]{1, 5, 10, 25}, 30));  // 2
    }
}
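Build dp[0..amount] bottom-up with dp[0] = 0. For each amount i, try every coin that fits: dp[i] = min(dp[i], dp[i - coin] + 1). All other entries start at infinity; the Java version uses amount + 1, which no valid answer can reach. Unreachable amounts therefore keep the sentinel, and the function returns -1 for them. Time: O(amount * len(coins)); space: O(amount).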
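Why not just take the largest coin that fits? For canonical coin systems such as [1, 5, 10, 25], greedy happens to give the optimum, but it fails for arbitrary denominations. The sketch below compares a hypothetical `greedy_coins` helper with the DP on coins [1, 3, 4] and amount 6.

```python
def greedy_coins(coins: list[int], amount: int) -> int:
    """Largest-coin-first heuristic; returns -1 if it gets stuck."""
    count = 0
    for c in sorted(coins, reverse=True):
        take = amount // c                  # use as many of this coin as fit
        count += take
        amount -= take * c
    return count if amount == 0 else -1


def coin_change(coins: list[int], amount: int) -> int:
    """Bottom-up DP from the section: fewest coins, or -1 if impossible."""
    dp = [float('inf')] * (amount + 1)
    dp[0] = 0
    for i in range(1, amount + 1):
        for coin in coins:
            if coin <= i:
                dp[i] = min(dp[i], dp[i - coin] + 1)
    return dp[amount] if dp[amount] != float('inf') else -1


print(greedy_coins([1, 3, 4], 6))  # 3  (4 + 1 + 1)
print(coin_change([1, 3, 4], 6))   # 2  (3 + 3)
```

Greedy commits to the 4 and never reconsiders. The DP tries every last coin for every amount, so it finds 3 + 3.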
Challenge

House Robber

Find the maximum amount of money you can rob from a row of houses without robbing two adjacent ones. Expected output for [2, 7, 9, 3, 1]: 12

def rob(nums):
    if not nums:
        return 0
    if len(nums) == 1:
        return nums[0]

    # Two rolling variables: rob up to i-2 vs i-1
    prev2, prev1 = 0, 0
    for num in nums:
        # Rob current house or skip it
        curr = max(prev1, prev2 + num)
        prev2 = prev1
        prev1 = curr

    return prev1

print(rob([2, 7, 9, 3, 1]))  # 12 (2 + 9 + 1)
public class Solution {
    public static int rob(int[] nums) {
        if (nums.length == 1) return nums[0];

        // Two rolling variables: rob up to i-2 vs i-1
        int prev2 = 0, prev1 = 0;
        for (int num : nums) {
            // Rob current house or skip it
            int curr = Math.max(prev1, prev2 + num);
            prev2 = prev1;
            prev1 = curr;
        }
        return prev1;
    }

    public static void main(String[] args) {
        System.out.println(rob(new int[]{2, 7, 9, 3, 1}));  // 12
    }
}
dp[i] = max(dp[i-1], dp[i-2] + nums[i]) — either skip the current house (take dp[i-1]) or rob it (take dp[i-2] + nums[i]). Optimize to O(1) space using two variables that track the last two states.
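The two-variable version returns only the total. To recover which houses were robbed, one option is to keep the full dp array and walk it backward. Whenever dp[i] differs from dp[i-1], house i must have been taken, so the walk jumps to i-2. This sketch assumes non-negative amounts, and the name `rob_with_houses` is illustrative.

```python
def rob_with_houses(nums: list[int]) -> tuple[int, list[int]]:
    """Return (max total, indices of the houses robbed)."""
    n = len(nums)
    if n == 0:
        return 0, []
    # dp[i] = best total using houses 0..i (same recurrence as above)
    dp = [0] * n
    dp[0] = nums[0]
    for i in range(1, n):
        skip = dp[i - 1]
        take = (dp[i - 2] if i >= 2 else 0) + nums[i]
        dp[i] = max(skip, take)

    # Walk backward: if dp[i] beats dp[i-1], house i was robbed.
    houses: list[int] = []
    i = n - 1
    while i >= 0:
        prev = dp[i - 1] if i >= 1 else 0
        if dp[i] != prev:
            houses.append(i)
            i -= 2                          # neighbor i-1 cannot also be robbed
        else:
            i -= 1
    return dp[-1], houses[::-1]


print(rob_with_houses([2, 7, 9, 3, 1]))  # (12, [0, 2, 4])
```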

14. Backtracking

Backtracking is DFS on a decision tree: choose an option, explore recursively, then unchoose (undo) it when returning. For each problem, ask: What is the state at each node? What are the choices? What is the base case?

Universal Template

def backtrack(state, choices, result):
    # is_complete, is_valid and next_choices are problem-specific placeholders
    # Base case: state is a complete solution
    if is_complete(state):
        result.append(list(state))      # copy! don't store reference
        return

    for choice in choices:
        if is_valid(choice, state):
            state.append(choice)        # CHOOSE
            backtrack(state, next_choices(choice, state), result)
            state.pop()                 # UNCHOOSE (undo)

Permutations

def permutations(nums: list[int]) -> list[list[int]]:
    """All permutations of a list with no duplicates.
    Time: O(n! * n)  Space: O(n)
    """
    result: list[list[int]] = []

    def backtrack(path: list[int], used: list[bool]) -> None:
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i, x in enumerate(nums):
            if not used[i]:
                used[i] = True
                path.append(x)
                backtrack(path, used)
                path.pop()
                used[i] = False

    backtrack([], [False] * len(nums))
    return result


def permutations_with_duplicates(nums: list[int]) -> list[list[int]]:
    """Permutations where nums may contain duplicates. Prune duplicates.
    Key: sort first; skip if nums[i] == nums[i-1] and nums[i-1] not used.
    """
    nums = sorted(nums)                 # sorted copy; caller's list is unchanged
    result: list[list[int]] = []

    def backtrack(path: list[int], used: list[bool]) -> None:
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            # Skip duplicate: same value as previous AND previous was not used in this branch
            if i > 0 and nums[i] == nums[i-1] and not used[i-1]:
                continue
            used[i] = True
            path.append(nums[i])
            backtrack(path, used)
            path.pop()
            used[i] = False

    backtrack([], [False] * len(nums))
    return result

Combinations & Subsets

def combinations(n: int, k: int) -> list[list[int]]:
    """All combinations of k numbers from 1..n.
    Key: start index prevents duplicates (only pick forward).
    Time: O(C(n,k) * k)  Space: O(k)
    """
    result: list[list[int]] = []

    def backtrack(start: int, path: list[int]) -> None:
        if len(path) == k:
            result.append(path[:])
            return
        # Pruning: need (k - len(path)) more elements, so stop early
        for i in range(start, n + 1 - (k - len(path)) + 1):
            path.append(i)
            backtrack(i + 1, path)
            path.pop()

    backtrack(1, [])
    return result


def combination_sum(candidates: list[int], target: int) -> list[list[int]]:
    """Combinations that sum to target; each number usable unlimited times."""
    candidates = sorted(candidates)     # sorted copy; enables the break below
    result: list[list[int]] = []

    def backtrack(start: int, path: list[int], remaining: int) -> None:
        if remaining == 0:
            result.append(path[:])
            return
        for i in range(start, len(candidates)):
            if candidates[i] > remaining:
                break                   # pruning: sorted, so rest are bigger
            path.append(candidates[i])
            backtrack(i, path, remaining - candidates[i])  # i, not i+1: reuse allowed
            path.pop()

    backtrack(0, [], target)
    return result


def subsets(nums: list[int]) -> list[list[int]]:
    """Power set: all 2^n subsets.
    Time: O(2^n * n)  Space: O(n)
    """
    result: list[list[int]] = []

    def backtrack(start: int, path: list[int]) -> None:
        result.append(path[:])          # every state is a valid subset
        for i in range(start, len(nums)):
            path.append(nums[i])
            backtrack(i + 1, path)
            path.pop()

    backtrack(0, [])
    return result

N-Queens

def solve_n_queens(n: int) -> list[list[str]]:
    """Place n queens on n×n board with no two queens attacking each other.
    Track columns and both diagonals to check attacks in O(1).
    Time: O(n!)  Space: O(n)
    """
    result: list[list[str]] = []
    queens: list[int] = []              # queens[row] = col
    cols: set[int] = set()
    diag1: set[int] = set()             # row - col (NW-SE diagonal)
    diag2: set[int] = set()             # row + col (NE-SW diagonal)

    def backtrack(row: int) -> None:
        if row == n:
            board = ['.' * c + 'Q' + '.' * (n - c - 1)
                     for c in queens]
            result.append(board)
            return
        for col in range(n):
            if col in cols or (row - col) in diag1 or (row + col) in diag2:
                continue
            cols.add(col); diag1.add(row - col); diag2.add(row + col)
            queens.append(col)
            backtrack(row + 1)
            queens.pop()
            cols.remove(col); diag1.remove(row - col); diag2.remove(row + col)

    backtrack(0)
    return result

Word Search

def exist(board: list[list[str]], word: str) -> bool:
    """Check if word exists in grid (any path, no reuse of cells).
    Time: O(rows * cols * 4^len(word))  Space: O(len(word))
    """
    rows, cols = len(board), len(board[0])

    def dfs(r: int, c: int, idx: int) -> bool:
        if idx == len(word):
            return True
        if (r < 0 or r >= rows or c < 0 or c >= cols
                or board[r][c] != word[idx]):
            return False
        temp = board[r][c]
        board[r][c] = '#'               # mark visited
        found = any(
            dfs(r + dr, c + dc, idx + 1)
            for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]
        )
        board[r][c] = temp              # restore
        return found

    return any(
        dfs(r, c, 0)
        for r in range(rows)
        for c in range(cols)
    )
Pruning techniques
1. Early termination: if remaining elements cannot satisfy the constraint, return immediately.
2. Sort first: enables break statements when candidates exceed target.
3. Skip duplicates: sort + skip consecutive equal values at same decision level.
4. Symmetry breaking: encode constraints that eliminate mirror solutions (e.g. first queen in left half only).
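Techniques 1 to 3 often appear together. Below is a sketch for Combination Sum II (LeetCode 40), where each candidate may be used at most once and the output must not repeat a combination. Sorting enables the `break` (early termination), and skipping equal values at the same decision level removes duplicate branches. The function name `combination_sum_2` is illustrative.

```python
def combination_sum_2(candidates: list[int], target: int) -> list[list[int]]:
    """Unique combinations summing to target; each element used at most once."""
    nums = sorted(candidates)               # sort first (technique 2)
    result: list[list[int]] = []

    def backtrack(start: int, path: list[int], remaining: int) -> None:
        if remaining == 0:
            result.append(path[:])
            return
        for i in range(start, len(nums)):
            # Skip duplicates at the same decision level (technique 3)
            if i > start and nums[i] == nums[i - 1]:
                continue
            # Early termination: sorted, so every later value is too big (technique 1)
            if nums[i] > remaining:
                break
            path.append(nums[i])
            backtrack(i + 1, path, remaining - nums[i])  # i + 1: no reuse
            path.pop()

    backtrack(0, [], target)
    return result


print(combination_sum_2([10, 1, 2, 7, 6, 1, 5], 8))
# [[1, 1, 6], [1, 2, 5], [1, 7], [2, 6]]
```

The skip condition uses `i > start` rather than `i > 0`. Two equal values can still appear in one combination (as in [1, 1, 6]), as long as they are chosen at different depths.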
LeetCode problems
46 Permutations · 47 Permutations II · 39 Combination Sum · 40 Combination Sum II · 78 Subsets · 90 Subsets II · 51 N-Queens · 79 Word Search · 37 Sudoku Solver
Practice
Challenge

Generate All Subsets

Return all subsets (the power set) of a list of distinct integers. Order within the output list does not matter. Expected output for [1, 2, 3]: 8 subsets total.

def subsets(nums):
    result = []

    def backtrack(start, current):
        # Record a copy of the current subset
        result.append(current[:])
        for i in range(start, len(nums)):
            # Include nums[i]
            current.append(nums[i])
            backtrack(i + 1, current)
            # Undo the choice (backtrack)
            current.pop()

    backtrack(0, [])
    return result

res = subsets([1, 2, 3])
print(len(res))  # 8
print(sorted(res))  # [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
import java.util.*;

public class Solution {
    static List<List<Integer>> result = new ArrayList<>();

    static void backtrack(int[] nums, int start, List<Integer> current) {
        // Record a copy of the current subset
        result.add(new ArrayList<>(current));
        for (int i = start; i < nums.length; i++) {
            // Include nums[i]
            current.add(nums[i]);
            backtrack(nums, i + 1, current);
            // Undo the choice (backtrack)
            current.remove(current.size() - 1);
        }
    }

    public static List<List<Integer>> subsets(int[] nums) {
        result.clear();
        backtrack(nums, 0, new ArrayList<>());
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 2, 3};
        List<List<Integer>> res = subsets(nums);
        System.out.println(res.size());  // 8
    }
}
At each recursive call, add the current subset to results (the empty list is a valid subset). Then for each index from start onwards, include that element, recurse with start = i + 1, then remove it. This naturally generates all 2^n subsets.
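An iterative alternative, which is a different technique from backtracking, maps each subset to an n-bit mask in which bit i set means nums[i] is included. Looping the mask over 0..2^n - 1 enumerates all 2^n subsets. That is practical only for small n (roughly n ≤ 20), where exponential output is acceptable anyway. The name `subsets_bitmask` is illustrative.

```python
def subsets_bitmask(nums: list[int]) -> list[list[int]]:
    """Enumerate the power set: bit i of mask selects nums[i]."""
    n = len(nums)
    result = []
    for mask in range(1 << n):              # 0 .. 2^n - 1
        result.append([nums[i] for i in range(n) if mask & (1 << i)])
    return result


res = subsets_bitmask([1, 2, 3])
print(len(res))  # 8
print(res)       # [[], [1], [2], [1, 2], [3], [1, 3], [2, 3], [1, 2, 3]]
```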
Challenge

Generate Permutations

Return all permutations of a list of distinct integers. The test checks the count, not the order. Expected output for [1, 2, 3]: 6

def permutations(nums):
    result = []
    used = [False] * len(nums)

    def backtrack(path):
        # Base case: full-length path is a complete permutation
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            # Choose nums[i]
            used[i] = True
            path.append(nums[i])
            backtrack(path)
            # Unchoose (backtrack)
            path.pop()
            used[i] = False

    backtrack([])
    return result

print(len(permutations([1, 2, 3])))  # 6
import java.util.*;

public class Solution {
    static List<List<Integer>> result = new ArrayList<>();

    static void backtrack(int[] nums, boolean[] used, List<Integer> path) {
        // Base case: full-length path is a complete permutation
        if (path.size() == nums.length) {
            result.add(new ArrayList<>(path));
            return;
        }
        for (int i = 0; i < nums.length; i++) {
            if (used[i]) continue;
            // Choose nums[i]
            used[i] = true;
            path.add(nums[i]);
            backtrack(nums, used, path);
            // Unchoose (backtrack)
            path.remove(path.size() - 1);
            used[i] = false;
        }
    }

    public static List<List<Integer>> permutations(int[] nums) {
        result.clear();
        backtrack(nums, new boolean[nums.length], new ArrayList<>());
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 2, 3};
        System.out.println(permutations(nums).size());  // 6
    }
}
Track a used[] boolean array. At each recursion level, try every unused element. Mark it used, append to path, recurse. When path length equals n, record a copy. Then undo: pop from path, mark unused. This gives n! permutations.
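An alternative that needs no used[] array swaps elements in place. It is a different technique from the one above and emits permutations in a different order. Slot `first` takes each remaining element in turn by swapping it forward, the call recurses on `first + 1`, and a second swap undoes the choice. The name `permutations_swap` is illustrative, and the sketch works on a copy so the caller's list is left alone.

```python
def permutations_swap(nums: list[int]) -> list[list[int]]:
    """All permutations of distinct values via in-place swapping."""
    arr = list(nums)                        # work on a copy of the input
    result: list[list[int]] = []

    def backtrack(first: int) -> None:
        if first == len(arr):
            result.append(arr[:])           # record a copy of the arrangement
            return
        for i in range(first, len(arr)):
            arr[first], arr[i] = arr[i], arr[first]  # choose arr[i] for slot `first`
            backtrack(first + 1)
            arr[first], arr[i] = arr[i], arr[first]  # undo the swap

    backtrack(0)
    return result


print(len(permutations_swap([1, 2, 3])))  # 6
```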
def permutations(nums):
    result = []
    used = [False] * len(nums)

    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True
            path.append(nums[i])
            backtrack(path)
            path.pop()
            used[i] = False

    backtrack([])
    return result

print(len(permutations([1, 2, 3])))  # 6
import java.util.*;

public class Solution {
    static List<List<Integer>> result = new ArrayList<>();

    static void backtrack(int[] nums, boolean[] used, List<Integer> path) {
        if (path.size() == nums.length) {
            result.add(new ArrayList<>(path));
            return;
        }
        for (int i = 0; i < nums.length; i++) {
            if (used[i]) continue;
            used[i] = true;
            path.add(nums[i]);
            backtrack(nums, used, path);
            path.remove(path.size() - 1);
            used[i] = false;
        }
    }

    public static List<List<Integer>> permutations(int[] nums) {
        result.clear();
        backtrack(nums, new boolean[nums.length], new ArrayList<>());
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 2, 3};
        System.out.println(permutations(nums).size());  // 6
    }
}
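A quick way to sanity-check the used[] backtracking is to compare it with the standard library. The sketch below repeats the algorithm compactly so it runs on its own and checks it against `itertools.permutations`, which yields tuples in the same position-based order, plus the n! count claimed above.

```python
import itertools
import math


def permutations(nums):
    """used[]-array backtracking, repeated here so the block runs on its own."""
    result, used = [], [False] * len(nums)

    def backtrack(path):
        if len(path) == len(nums):
            result.append(path[:])  # record a copy of the full-length path
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True          # choose nums[i]
            path.append(nums[i])
            backtrack(path)
            path.pop()              # unchoose
            used[i] = False

    backtrack([])
    return result


nums = [1, 2, 3]
perms = permutations(nums)
print(perms)  # [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
assert len(perms) == math.factorial(len(nums))
assert perms == [list(p) for p in itertools.permutations(nums)]
```

In an interview you would write the backtracking yourself; `itertools.permutations` is useful mainly as a reference when testing.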
Challenge

Combination Sum

Find all unique combinations of candidates that sum to target. Each candidate may be used any number of times. Expected output: [[2, 2, 3], [7]]

def combination_sum(candidates, target):
    candidates.sort()
    result = []

    def backtrack(start, current, remaining):
        if remaining == 0:
            result.append(current[:])
            return
        for i in range(start, len(candidates)):
            c = candidates[i]
            # Pruning: no point trying larger candidates
            if c > remaining:
                break
            current.append(c)
            # Use same index i (not i+1) to allow reuse
            backtrack(i, current, remaining - c)
            current.pop()

    backtrack(0, [], target)
    return result

print(combination_sum([2, 3, 6, 7], 7))  # [[2, 2, 3], [7]]
import java.util.*;

public class Solution {
    static List<List<Integer>> result = new ArrayList<>();

    static void backtrack(int[] candidates, int start,
                          List<Integer> current, int remaining) {
        if (remaining == 0) {
            result.add(new ArrayList<>(current));
            return;
        }
        for (int i = start; i < candidates.length; i++) {
            int c = candidates[i];
            // Pruning: no point trying larger candidates
            if (c > remaining) break;
            current.add(c);
            // Use same index i (not i+1) to allow reuse
            backtrack(candidates, i, current, remaining - c);
            current.remove(current.size() - 1);
        }
    }

    public static List<List<Integer>> combinationSum(int[] candidates, int target) {
        result.clear();
        Arrays.sort(candidates);
        backtrack(candidates, 0, new ArrayList<>(), target);
        return result;
    }

    public static void main(String[] args) {
        int[] candidates = {2, 3, 6, 7};
        System.out.println(combinationSum(candidates, 7));
        // [[2, 2, 3], [7]]
    }
}
Sort candidates first so the loop can stop early. Pass a start index so each combination is built in non-decreasing order, which prevents duplicates such as [2, 3, 2] and [3, 2, 2]. Recurse with the same i (not i+1) so a number can be reused. Break as soon as a candidate exceeds remaining, because every later candidate is at least as large.
def combination_sum(candidates, target):
    candidates.sort()
    result = []

    def backtrack(start, current, remaining):
        if remaining == 0:
            result.append(current[:])
            return
        for i in range(start, len(candidates)):
            c = candidates[i]
            if c > remaining:
                break
            current.append(c)
            backtrack(i, current, remaining - c)
            current.pop()

    backtrack(0, [], target)
    return result

print(combination_sum([2, 3, 6, 7], 7))  # [[2, 2, 3], [7]]
import java.util.*;

public class Solution {
    static List<List<Integer>> result = new ArrayList<>();

    static void backtrack(int[] candidates, int start,
                          List<Integer> current, int remaining) {
        if (remaining == 0) {
            result.add(new ArrayList<>(current));
            return;
        }
        for (int i = start; i < candidates.length; i++) {
            int c = candidates[i];
            if (c > remaining) break;
            current.add(c);
            backtrack(candidates, i, current, remaining - c);
            current.remove(current.size() - 1);
        }
    }

    public static List<List<Integer>> combinationSum(int[] candidates, int target) {
        result.clear();
        Arrays.sort(candidates);
        backtrack(candidates, 0, new ArrayList<>(), target);
        return result;
    }

    public static void main(String[] args) {
        int[] candidates = {2, 3, 6, 7};
        System.out.println(combinationSum(candidates, 7));
        // [[2, 2, 3], [7]]
    }
}
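To see why the start index matters, the sketch below adds a hypothetical `use_start` flag to the same backtracking. With the flag off, the loop restarts at index 0 on every call, so the same multiset is emitted once per ordering.

```python
def combination_sum(candidates, target, use_start=True):
    """Combination Sum backtracking; use_start=False is a deliberately broken variant."""
    candidates = sorted(candidates)
    result = []

    def backtrack(start, current, remaining):
        if remaining == 0:
            result.append(current[:])
            return
        # With use_start=False every call scans from 0, so orderings repeat.
        for i in range(start if use_start else 0, len(candidates)):
            c = candidates[i]
            if c > remaining:
                break  # sorted: every later candidate is also too large
            current.append(c)
            backtrack(i, current, remaining - c)  # same i allows reuse
            current.pop()

    backtrack(0, [], target)
    return result


print(combination_sum([2, 3, 6, 7], 7))                   # [[2, 2, 3], [7]]
print(combination_sum([2, 3, 6, 7], 7, use_start=False))  # [[2, 2, 3], [2, 3, 2], [3, 2, 2], [7]]
```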

15. Greedy Algorithms

A greedy algorithm makes the locally optimal choice at each step and never revisits it. It is correct when the problem has the greedy choice property (some globally optimal solution starts with the locally optimal choice) and optimal substructure. If you cannot prove the greedy choice property, use DP instead.

Greedy vs DP decision guide

| Greedy | DP |
|---|---|
| One pass, irrevocable choices | Explore multiple options per step |
| Greedy choice property provable | Optimal substructure only |
| O(n log n) or O(n) typical | O(n^2) or O(n * capacity) typical |
| Interval scheduling, MST, Huffman | Knapsack, LCS, Edit Distance |
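Coin change is a common illustration of what goes wrong when the greedy choice property does not hold. The sketch below uses an illustrative coin system [1, 3, 4] (chosen for this example, not from the text): taking the largest coin first uses three coins for amount 6, while a standard DP finds two.

```python
def greedy_coins(coins, amount):
    """Always take the largest coin that still fits; -1 if amount is not reached."""
    count = 0
    for c in sorted(coins, reverse=True):
        count += amount // c
        amount %= c
    return count if amount == 0 else -1


def dp_coins(coins, amount):
    """dp[a] = fewest coins summing to a (unbounded coin-change DP)."""
    inf = float("inf")
    dp = [0] + [inf] * amount
    for a in range(1, amount + 1):
        for c in coins:
            if c <= a:
                dp[a] = min(dp[a], dp[a - c] + 1)
    return dp[amount] if dp[amount] != inf else -1


print(greedy_coins([1, 3, 4], 6))  # 3  (4 + 1 + 1)
print(dp_coins([1, 3, 4], 6))      # 2  (3 + 3)
```

For some coin systems, such as [1, 5, 10, 25], greedy happens to be optimal; the point is that this must be proven, not assumed.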

Interval Scheduling

import heapq


def merge_intervals(intervals: list[list[int]]) -> list[list[int]]:
    """Merge all overlapping intervals.
    Time: O(n log n)  Space: O(n)
    """
    if not intervals:
        return []
    intervals = sorted(intervals, key=lambda x: x[0])  # sort by start; caller's list is untouched
    merged = [list(intervals[0])]       # copy so merging never edits an input interval

    for start, end in intervals[1:]:
        if start <= merged[-1][1]:       # overlaps with last merged
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    return merged


def meeting_rooms_ii(intervals: list[list[int]]) -> int:
    """Minimum number of rooms to schedule all meetings (no overlap per room).
    Key insight: track when the earliest-ending meeting finishes.
    Time: O(n log n)  Space: O(n)
    """
    if not intervals:
        return 0
    intervals.sort(key=lambda x: x[0])
    heap: list[int] = []                # stores end times of active meetings

    for start, end in intervals:
        if heap and heap[0] <= start:   # a room is free
            heapq.heapreplace(heap, end)
        else:
            heapq.heappush(heap, end)   # need a new room

    return len(heap)


def non_overlapping_intervals(intervals: list[list[int]]) -> int:
    """Minimum intervals to remove so none overlap.
    Greedy: always keep the interval with the earliest end time.
    Time: O(n log n)  Space: O(1)
    """
    intervals.sort(key=lambda x: x[1])  # sort by end time
    removed = 0
    last_end = float('-inf')

    for start, end in intervals:
        if start >= last_end:
            last_end = end              # keep this interval
        else:
            removed += 1               # remove the later-ending one

    return removed
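An alternative to the heap in meeting_rooms_ii is a sweep line over separately sorted start and end times. This is a swapped-in technique, not the one above, and the sketch assumes every meeting has start < end. It keeps the same convention that a meeting ending at time t frees its room for one starting at t.

```python
def min_rooms_sweep(intervals):
    """Minimum rooms via a sweep over sorted start and end times."""
    starts = sorted(s for s, _ in intervals)
    ends = sorted(e for _, e in intervals)
    rooms = best = 0
    j = 0  # index of the earliest end time not yet released
    for s in starts:
        # Release every room whose meeting ended at or before this start.
        while j < len(ends) and ends[j] <= s:
            rooms -= 1
            j += 1
        rooms += 1  # this meeting occupies a room
        best = max(best, rooms)
    return best


print(min_rooms_sweep([[0, 30], [5, 10], [15, 20]]))  # 2
```

Both versions run in O(n log n) because of sorting; the sweep trades the heap for two sorted arrays.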

Jump Game Variants

def can_jump(nums: list[int]) -> bool:
    """Can you reach the last index? Each element = max jump from that index.
    Greedy: track the maximum reachable index.
    Time: O(n)  Space: O(1)
    """
    max_reach = 0
    for i, jump in enumerate(nums):
        if i > max_reach:
            return False                # current index is unreachable
        max_reach = max(max_reach, i + jump)
    return True


def jump_game_ii(nums: list[int]) -> int:
    """Minimum jumps to reach last index.
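    Greedy: scan the range covered by the current jump and make the next jump
    to the farthest index reachable from anywhere in that range.
    Assumes the last index is reachable, as LeetCode 45 guarantees.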
    Time: O(n)  Space: O(1)
    """
    jumps = 0
    current_end = 0                     # end of current jump range
    farthest = 0                        # farthest we can reach from current range

    for i in range(len(nums) - 1):
        farthest = max(farthest, i + nums[i])
        if i == current_end:            # must jump now
            jumps += 1
            current_end = farthest

    return jumps
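One way to gain confidence in the greedy jump_game_ii is to compare it with a straightforward O(n^2) DP, in line with the greedy-vs-DP guide above. The `jump_dp` helper below is a hypothetical reference, assuming the last index is reachable.

```python
def jump_game_ii(nums):
    """Greedy from above: each jump covers indices up to current_end."""
    jumps = current_end = farthest = 0
    for i in range(len(nums) - 1):
        farthest = max(farthest, i + nums[i])
        if i == current_end:  # range exhausted: take another jump
            jumps += 1
            current_end = farthest
    return jumps


def jump_dp(nums):
    """O(n^2) DP: dp[i] = fewest jumps needed to land on index i."""
    n = len(nums)
    dp = [0] + [float("inf")] * (n - 1)
    for i in range(n):
        for j in range(i + 1, min(n, i + nums[i] + 1)):
            dp[j] = min(dp[j], dp[i] + 1)
    return dp[-1]


for nums in ([2, 3, 1, 1, 4], [2, 3, 0, 1, 4], [1, 1, 1, 1], [0]):
    assert jump_game_ii(nums) == jump_dp(nums)
print(jump_game_ii([2, 3, 1, 1, 4]))  # 2
```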

Task Scheduler

from collections import Counter

def least_interval(tasks: list[str], n: int) -> int:
    """Minimum time to execute all tasks with cooldown n between same tasks.
    Key insight: the most frequent task determines the lower bound.
    Time: O(len(tasks))  Space: O(1) — at most 26 task types
    """
    freq = Counter(tasks)
    max_freq = max(freq.values())
    # Number of tasks with maximum frequency
    max_count = sum(1 for f in freq.values() if f == max_freq)

    # (max_freq - 1) full frames of length n + 1, then one final slot per
    # max-frequency task. If other tasks fill every idle slot, no idling is
    # needed and the answer is simply len(tasks), hence the max.
    return max(len(tasks), (max_freq - 1) * (n + 1) + max_count)
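The closed-form answer is easy to get subtly wrong, so a tick-by-tick simulation makes a useful cross-check. The sketch below assumes a greedy schedule that always runs the available task with the most remaining copies (a max-heap of negated counts) and parks each run task in a cooldown queue for n ticks.

```python
import heapq
from collections import Counter, deque


def least_interval_formula(tasks, n):
    """Closed-form answer, as in least_interval above."""
    freq = Counter(tasks)
    max_freq = max(freq.values())
    max_count = sum(1 for f in freq.values() if f == max_freq)
    return max(len(tasks), (max_freq - 1) * (n + 1) + max_count)


def least_interval_sim(tasks, n):
    """Simulate one tick at a time, running the most frequent available task."""
    heap = [-c for c in Counter(tasks).values()]  # max-heap via negated counts
    heapq.heapify(heap)
    cooling = deque()  # (negated remaining count, tick at which cooldown ends)
    time = 0
    while heap or cooling:
        time += 1
        if heap:
            remaining = heapq.heappop(heap) + 1  # run one copy; counts are negative
            if remaining:  # copies left: cool down for n ticks
                cooling.append((remaining, time + n))
        # If the heap was empty, this tick was idle.
        if cooling and cooling[0][1] == time:
            heapq.heappush(heap, cooling.popleft()[0])  # available next tick
    return time


print(least_interval_sim(list("AAABBB"), 2))      # 8
print(least_interval_formula(list("AAABBB"), 2))  # 8
```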
LeetCode problems
56 Merge Intervals · 57 Insert Interval · 252 Meeting Rooms · 253 Meeting Rooms II · 435 Non-Overlapping Intervals · 55 Jump Game · 45 Jump Game II · 621 Task Scheduler · 763 Partition Labels
Practice
Challenge

Best Time to Buy and Sell Stock

Given daily prices, find the maximum profit from one buy and one sell transaction. You must buy before you sell, and if no profit is possible the answer is 0. Expected output: 5

def max_profit(prices):
    # Track the lowest price seen so far
    min_price = float('inf')
    max_profit = 0

    for price in prices:
        if price < min_price:
            # Found a new minimum — update buying price
            min_price = price
        else:
            # Check if selling today beats our best profit
            max_profit = max(max_profit, price - min_price)

    return max_profit

print(max_profit([7, 1, 5, 3, 6, 4]))  # 5 (buy at 1, sell at 6)
public class Solution {
    public static int maxProfit(int[] prices) {
        // Track the lowest price seen so far
        int minPrice = Integer.MAX_VALUE;
        int maxProfit = 0;

        for (int price : prices) {
            if (price < minPrice) {
                // Found a new minimum — update buying price
                minPrice = price;
            } else {
                // Check if selling today beats our best profit
                maxProfit = Math.max(maxProfit, price - minPrice);
            }
        }

        return maxProfit;
    }

    public static void main(String[] args) {
        System.out.println(maxProfit(new int[]{7, 1, 5, 3, 6, 4}));  // 5
    }
}
Greedily track the minimum price seen so far. At each step, compute the potential profit price - min_price and update the running maximum. This works because, for a sale on any given day, the best buying price is the lowest price seen before that day.
def max_profit(prices):
    min_price = float('inf')
    max_profit = 0

    for price in prices:
        if price < min_price:
            min_price = price
        else:
            max_profit = max(max_profit, price - min_price)

    return max_profit

print(max_profit([7, 1, 5, 3, 6, 4]))  # 5
public class Solution {
    public static int maxProfit(int[] prices) {
        int minPrice = Integer.MAX_VALUE;
        int maxProfit = 0;

        for (int price : prices) {
            if (price < minPrice) {
                minPrice = price;
            } else {
                maxProfit = Math.max(maxProfit, price - minPrice);
            }
        }

        return maxProfit;
    }

    public static void main(String[] args) {
        System.out.println(maxProfit(new int[]{7, 1, 5, 3, 6, 4}));  // 5
    }
}
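A small extension of the same one-pass greedy also reports which days to trade, and an O(n^2) brute force over every (buy, sell) pair confirms the profit. Both function names are hypothetical.

```python
def max_profit_with_days(prices):
    """One-pass greedy; returns (profit, buy_index, sell_index), indices None if no profit."""
    best, buy_day, sell_day = 0, None, None
    min_day = 0  # index of the lowest price seen so far
    for day, price in enumerate(prices):
        if price < prices[min_day]:
            min_day = day  # new cheapest day to buy
        elif price - prices[min_day] > best:
            best, buy_day, sell_day = price - prices[min_day], min_day, day
    return best, buy_day, sell_day


def max_profit_brute(prices):
    """O(n^2) check over every pair with buy before sell; 0 if none is profitable."""
    n = len(prices)
    return max([0] + [prices[j] - prices[i] for i in range(n) for j in range(i + 1, n)])


print(max_profit_with_days([7, 1, 5, 3, 6, 4]))  # (5, 1, 4)
print(max_profit_brute([7, 1, 5, 3, 6, 4]))      # 5
```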
Challenge

Jump Game

Each element in nums is the maximum jump length from that position. Determine if you can reach the last index. Expected output: True then False

def can_jump(nums):
    # Farthest index reachable so far
    max_reach = 0

    for i, jump in enumerate(nums):
        # If current index is beyond what we can reach, we're stuck
        if i > max_reach:
            return False
        # Extend max reach from current position
        max_reach = max(max_reach, i + jump)

    return True

print(can_jump([2, 3, 1, 1, 4]))  # True
print(can_jump([3, 2, 1, 0, 4]))  # False
public class Solution {
    public static boolean canJump(int[] nums) {
        // Farthest index reachable so far
        int maxReach = 0;

        for (int i = 0; i < nums.length; i++) {
            // If current index is beyond what we can reach, we're stuck
            if (i > maxReach) return false;
            // Extend max reach from current position
            maxReach = Math.max(maxReach, i + nums[i]);
        }

        return true;
    }

    public static void main(String[] args) {
        System.out.println(canJump(new int[]{2, 3, 1, 1, 4}));  // true
        System.out.println(canJump(new int[]{3, 2, 1, 0, 4}));  // false
    }
}
Track max_reach: the farthest index reachable so far. For each index i, if i > max_reach you cannot get here — return false. Otherwise update max_reach = max(max_reach, i + nums[i]). If you never get stuck, return true.
def can_jump(nums):
    max_reach = 0

    for i, jump in enumerate(nums):
        if i > max_reach:
            return False
        max_reach = max(max_reach, i + jump)

    return True

print(can_jump([2, 3, 1, 1, 4]))  # True
print(can_jump([3, 2, 1, 0, 4]))  # False
public class Solution {
    public static boolean canJump(int[] nums) {
        int maxReach = 0;

        for (int i = 0; i < nums.length; i++) {
            if (i > maxReach) return false;
            maxReach = Math.max(maxReach, i + nums[i]);
        }

        return true;
    }

    public static void main(String[] args) {
        System.out.println(canJump(new int[]{2, 3, 1, 1, 4}));  // true
        System.out.println(canJump(new int[]{3, 2, 1, 0, 4}));  // false
    }
}
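The same question can also be answered by a backward greedy, shown here as an alternative technique rather than the one above: start with the goal at the last index and move it left whenever some index can reach it. The sketch assumes nums is non-empty and checks that both directions agree.

```python
def can_jump(nums):
    """Forward greedy from above: track the farthest reachable index."""
    max_reach = 0
    for i, jump in enumerate(nums):
        if i > max_reach:
            return False
        max_reach = max(max_reach, i + jump)
    return True


def can_jump_backward(nums):
    """Backward greedy: shift the goal left whenever index i can reach it."""
    goal = len(nums) - 1
    for i in range(len(nums) - 2, -1, -1):
        if i + nums[i] >= goal:
            goal = i  # reaching i is now enough to finish
    return goal == 0


for nums in ([2, 3, 1, 1, 4], [3, 2, 1, 0, 4], [0], [1, 0, 1]):
    assert can_jump(nums) == can_jump_backward(nums)
print(can_jump_backward([2, 3, 1, 1, 4]), can_jump_backward([3, 2, 1, 0, 4]))  # True False
```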
Challenge

Non-overlapping Intervals

Find the minimum number of intervals to remove so the rest do not overlap. Intervals that only touch at an endpoint, such as [1, 2] and [2, 3], do not overlap. Expected output: 1

def erase_overlap(intervals):
    # Sort by end time (greedy: keep intervals that end earliest)
    intervals.sort(key=lambda x: x[1])

    removals = 0
    # End time of the last interval we kept
    prev_end = float('-inf')

    for start, end in intervals:
        if start >= prev_end:
            # No overlap: keep this interval
            prev_end = end
        else:
            # Overlap: remove this interval (it ends no earlier than the kept one, since we sorted by end)
            removals += 1

    return removals

print(erase_overlap([[1, 2], [2, 3], [3, 4], [1, 3]]))  # 1
import java.util.*;

public class Solution {
    public static int eraseOverlapIntervals(int[][] intervals) {
        // Sort by end time (greedy: keep intervals that end earliest)
        Arrays.sort(intervals, Comparator.comparingInt(a -> a[1]));

        int removals = 0;
        // End time of the last interval we kept
        int prevEnd = Integer.MIN_VALUE;

        for (int[] interval : intervals) {
            int start = interval[0], end = interval[1];
            if (start >= prevEnd) {
                // No overlap: keep this interval
                prevEnd = end;
            } else {
                // Overlap: remove this interval
                removals++;
            }
        }

        return removals;
    }

    public static void main(String[] args) {
        int[][] intervals = {{1,2},{2,3},{3,4},{1,3}};
        System.out.println(eraseOverlapIntervals(intervals));  // 1
    }
}
Sort intervals by end time. Greedily keep the intervals that end earliest, since they leave the most room for later intervals. Track prev_end. If start >= prev_end, there is no overlap: keep the interval and update prev_end. Otherwise, count a removal; we discard the current interval because the kept one ends no later.
def erase_overlap(intervals):
    intervals.sort(key=lambda x: x[1])

    removals = 0
    prev_end = float('-inf')

    for start, end in intervals:
        if start >= prev_end:
            prev_end = end
        else:
            removals += 1

    return removals

print(erase_overlap([[1, 2], [2, 3], [3, 4], [1, 3]]))  # 1
import java.util.*;

public class Solution {
    public static int eraseOverlapIntervals(int[][] intervals) {
        Arrays.sort(intervals, Comparator.comparingInt(a -> a[1]));

        int removals = 0;
        int prevEnd = Integer.MIN_VALUE;

        for (int[] interval : intervals) {
            int start = interval[0], end = interval[1];
            if (start >= prevEnd) {
                prevEnd = end;
            } else {
                removals++;
            }
        }

        return removals;
    }

    public static void main(String[] args) {
        int[][] intervals = {{1,2},{2,3},{3,4},{1,3}};
        System.out.println(eraseOverlapIntervals(intervals));  // 1
    }
}
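The choice of sort key is the whole trick. The sketch below runs the same keep-or-remove loop with a hypothetical `key` parameter on an illustrative input: sorting by start keeps the long interval [1, 10] and removes two, while sorting by end removes only one. A start-sorted greedy can be made correct by keeping the smaller end on each overlap, but the naive version shown here is not.

```python
def count_removals(intervals, key):
    """Keep-or-remove greedy over intervals sorted by the given key."""
    removals, prev_end = 0, float("-inf")
    for start, end in sorted(intervals, key=key):
        if start >= prev_end:
            prev_end = end  # keep this interval
        else:
            removals += 1   # overlaps the last kept interval
    return removals


intervals = [[1, 10], [2, 3], [4, 5]]
print(count_removals(intervals, key=lambda x: x[1]))  # 1  (keeps [2, 3] and [4, 5])
print(count_removals(intervals, key=lambda x: x[0]))  # 2  (keeps only [1, 10])
```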

16. Tries (Prefix Trees)

A trie stores strings character by character, so insert and search take O(L) time, where L is the word length, regardless of how many words are stored. Tries suit prefix lookups, autocomplete, and spell checking.

Trie Node — Insert / Search / StartsWith

class TrieNode:
    def __init__(self) -> None:
        self.children: dict[str, 'TrieNode'] = {}
        self.is_end: bool = False


class Trie:
    """Prefix tree supporting insert, search, and startsWith.
    Time: O(L) per operation  Space: O(total characters inserted)
    """
    def __init__(self) -> None:
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word: str) -> bool:
        """True if word is in the trie."""
        node = self._traverse(word)
        return node is not None and node.is_end

    def starts_with(self, prefix: str) -> bool:
        """True if any word in trie starts with prefix."""
        return self._traverse(prefix) is not None

    def _traverse(self, s: str) -> 'TrieNode | None':
        node = self.root
        for ch in s:
            if ch not in node.children:
                return None
            node = node.children[ch]
        return node
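In Python interviews a trie is often written more compactly as nested dicts, with a sentinel key marking the end of a word. This is an alternative representation, not the class above. The `END = "$"` marker and the function names are assumptions, and the sketch assumes "$" never appears in a word or query.

```python
END = "$"  # hypothetical end-of-word marker; assumes "$" never appears in words


def build_trie(words):
    root = {}
    for word in words:
        node = root
        for ch in word:
            node = node.setdefault(ch, {})  # create the child dict on first visit
        node[END] = True
    return root


def _walk(root, s):
    """Return the dict for prefix s, or None if the path is missing."""
    node = root
    for ch in s:
        if ch not in node:
            return None
        node = node[ch]
    return node


def search(root, word):
    node = _walk(root, word)
    return node is not None and END in node


def starts_with(root, prefix):
    return _walk(root, prefix) is not None


root = build_trie(["apple", "app"])
print(search(root, "app"), search(root, "appl"), starts_with(root, "appl"))  # True False True
```

The class version is easier to extend with extra per-node fields (counts, stored words); the dict version is shorter to type.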

Word Dictionary with Wildcards

class WordDictionary:
    """Supports addWord and search where '.' matches any character.
    search uses DFS when '.' is encountered.
    Time: O(L) insert; O(26^L) worst case search with many '.'s.
    """
    def __init__(self) -> None:
        self.root = TrieNode()

    def add_word(self, word: str) -> None:
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word: str) -> bool:
        return self._dfs(self.root, word, 0)

    def _dfs(self, node: TrieNode, word: str, idx: int) -> bool:
        if idx == len(word):
            return node.is_end
        ch = word[idx]
        if ch == '.':
            return any(
                self._dfs(child, word, idx + 1)
                for child in node.children.values()
            )
        if ch not in node.children:
            return False
        return self._dfs(node.children[ch], word, idx + 1)
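A regex gives a convenient brute-force reference for the wildcard search. The sketch below condenses the WordDictionary logic so it runs alone and compares it with `re.fullmatch`, where '.' also matches any single non-newline character. It assumes patterns contain only lowercase letters and '.', as in LeetCode 211, so no other regex metacharacters appear.

```python
import re


class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False


class WordDictionary:
    """Trie with DFS on '.', condensed from the class above so this block runs alone."""

    def __init__(self):
        self.root = TrieNode()

    def add_word(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word):
        def dfs(node, idx):
            if idx == len(word):
                return node.is_end
            if word[idx] == ".":
                # Wildcard: any child may continue the match.
                return any(dfs(child, idx + 1) for child in node.children.values())
            child = node.children.get(word[idx])
            return child is not None and dfs(child, idx + 1)

        return dfs(self.root, 0)


def brute_search(words, pattern):
    """Reference: fullmatch requires the whole word to match, so lengths must agree."""
    return any(re.fullmatch(pattern, w) is not None for w in words)


words = ["bad", "dad", "mad"]
wd = WordDictionary()
for w in words:
    wd.add_word(w)
for pattern in [".ad", "b..", "...", ".a", "pad"]:
    assert wd.search(pattern) == brute_search(words, pattern)
```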

Design Autocomplete System

def find_words_with_prefix(words: list[str],
                           prefix: str,
                           top_k: int = 3) -> list[str]:
    """Return top-k lexicographically smallest words matching prefix.
    Builds a trie, navigates to the prefix node, then runs a sorted DFS that stops after top_k words.
    Time: O(total characters in words) to build the trie, plus O(L + W) for the query,
    where W = total characters in matching words.
    """
    trie = Trie()
    for w in words:
        trie.insert(w)

    node = trie._traverse(prefix)
    if node is None:
        return []

    results: list[str] = []

    def dfs(cur: TrieNode, path: str) -> None:
        if len(results) >= top_k:
            return
        if cur.is_end:
            results.append(path)
        for ch in sorted(cur.children):    # sorted for lexicographic order
            dfs(cur.children[ch], path + ch)

    dfs(node, prefix)
    return results
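When the word list is static, a sorted list plus binary search returns the same results without building a trie. This is a swapped-in technique, not the trie approach above. It relies on the fact that all words sharing a prefix form one contiguous run in sorted order, starting at `bisect.bisect_left(words, prefix)`. The function name is hypothetical.

```python
import bisect


def prefix_matches_sorted(words, prefix, top_k=3):
    """Top-k lexicographically smallest words starting with prefix, via binary search."""
    ordered = sorted(set(words))  # dedupe, like a trie's is_end flag
    i = bisect.bisect_left(ordered, prefix)  # first word >= prefix
    out = []
    # Every word with this prefix sits in one contiguous run starting at i.
    while i < len(ordered) and len(out) < top_k and ordered[i].startswith(prefix):
        out.append(ordered[i])
        i += 1
    return out


words = ["mobile", "mouse", "moneypot", "monitor", "mousepad"]
print(prefix_matches_sorted(words, "mo"))  # ['mobile', 'moneypot', 'monitor']
```

Each query costs O(log n) comparisons plus the matches it returns; the trie wins when words are inserted incrementally.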

XOR Trie for Maximum XOR

class XorTrie:
    """Binary trie for maximum XOR of two numbers in an array.
    Each number stored as 32 binary bits, MSB first (assumes 0 <= num < 2**32).
    Time: O(n * 32) = O(n)
    """
    def __init__(self) -> None:
        self.children: list['XorTrie | None'] = [None, None]

    def insert(self, num: int) -> None:
        node = self
        for bit in range(31, -1, -1):
            b = (num >> bit) & 1
            if node.children[b] is None:
                node.children[b] = XorTrie()
            node = node.children[b]

    def max_xor_with(self, num: int) -> int:
        """Find value in trie that maximizes XOR with num."""
        node = self
        xor = 0
        for bit in range(31, -1, -1):
            b = (num >> bit) & 1
            want = 1 - b                # prefer opposite bit
            if node.children[want] is not None:
                xor |= (1 << bit)
                node = node.children[want]
            else:
                node = node.children[b]
        return xor


def find_maximum_xor(nums: list[int]) -> int:
    """LeetCode 421: Maximum XOR of Two Numbers in an Array.
    Time: O(n)  Space: O(n)
    """
    trie = XorTrie()
    for x in nums:
        trie.insert(x)
    return max(trie.max_xor_with(x) for x in nums)
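The same maximum can also be found without a trie by fixing the answer one bit at a time and checking a set of prefixes. This is a different technique from XorTrie and is shown for comparison, with an O(n^2) brute force as a reference. It relies on a ^ b == c implying a ^ c == b, and it assumes the same 32-bit non-negative inputs.

```python
def max_xor_prefix_set(nums):
    """Decide the answer bit by bit, from bit 31 down, using a set of prefixes."""
    answer = 0
    mask = 0
    for bit in range(31, -1, -1):
        mask |= 1 << bit                 # keep only the bits from 31 down to `bit`
        prefixes = {x & mask for x in nums}
        candidate = answer | (1 << bit)  # try to set this bit in the answer
        # Some pair of prefixes XORs to candidate iff p ^ candidate is also a prefix.
        if any((p ^ candidate) in prefixes for p in prefixes):
            answer = candidate
    return answer


def max_xor_brute(nums):
    """O(n^2) reference: try every pair."""
    return max(a ^ b for a in nums for b in nums)


nums = [3, 10, 5, 25, 2, 8]
print(max_xor_prefix_set(nums))  # 28  (5 ^ 25)
assert max_xor_prefix_set(nums) == max_xor_brute(nums)
```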
LeetCode problems
208 Implement Trie · 211 Word Dictionary · 212 Word Search II · 421 Maximum XOR · 1268 Search Suggestions System · 720 Longest Word in Dictionary
Practice
Challenge

Implement Trie

Implement insert and search for a basic trie. After inserting "apple" and "app", searching "apple" returns True, "app" returns True, and "appl" returns False.

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        # TODO: walk char by char, create nodes as needed
        # mark is_end = True at the last node
        pass

    def search(self, word):
        # TODO: walk char by char, return False if node missing
        # return is_end at the last node
        pass

t = Trie()
t.insert("apple")
t.insert("app")
print(t.search("apple"))  # True
print(t.search("app"))    # True
print(t.search("appl"))   # False
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        // TODO: walk char by char, create nodes as needed
        // mark isEnd = true at the last node
    }

    static boolean search(String word) {
        // TODO: walk char by char, return false if node missing
        // return isEnd at the last node
        return false;
    }

    public static void main(String[] args) {
        insert("apple");
        insert("app");
        System.out.println(search("apple"));  // true
        System.out.println(search("app"));    // true
        System.out.println(search("appl"));   // false
    }
}
Each node has a dict of children keyed by character, plus an is_end flag. For insert, walk each character, creating a new node if missing, then set is_end = True on the last node. For search, walk each character, returning False if a node is missing, then return is_end.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end

t = Trie()
t.insert("apple")
t.insert("app")
print(t.search("apple"))  # True
print(t.search("app"))    # True
print(t.search("appl"))   # False
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
        }
        node.isEnd = true;
    }

    static boolean search(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            if (!node.children.containsKey(ch)) return false;
            node = node.children.get(ch);
        }
        return node.isEnd;
    }

    public static void main(String[] args) {
        insert("apple");
        insert("app");
        System.out.println(search("apple"));  // true
        System.out.println(search("app"));    // true
        System.out.println(search("appl"));   // false
    }
}
Challenge

Prefix Count

Count how many inserted words start with a given prefix. After inserting ["apple", "app", "apricot", "banana"], prefix "ap" returns 3, prefix "ban" returns 1, and prefix "c" returns 0.

class TrieNode:
    def __init__(self):
        self.children = {}
        self.count = 0  # words passing through this node

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        # TODO: walk each char, increment count on every node visited
        pass

    def prefix_count(self, prefix):
        # TODO: walk to the end of prefix, return count
        # return 0 if prefix not found
        pass

t = Trie()
for w in ["apple", "app", "apricot", "banana"]:
    t.insert(w)
print(t.prefix_count("ap"))   # 3
print(t.prefix_count("ban"))  # 1
print(t.prefix_count("c"))    # 0
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        int count = 0;  // words passing through this node
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        // TODO: walk each char, increment count on every node visited
    }

    static int prefixCount(String prefix) {
        // TODO: walk to the end of prefix, return count
        // return 0 if prefix not found
        return 0;
    }

    public static void main(String[] args) {
        for (String w : new String[]{"apple", "app", "apricot", "banana"}) {
            insert(w);
        }
        System.out.println(prefixCount("ap"));   // 3
        System.out.println(prefixCount("ban"));  // 1
        System.out.println(prefixCount("c"));    // 0
    }
}
Add a count field to each node. During insert, increment count on every node you visit (not just the last). For prefix_count, traverse to the end of the prefix and return that node's count. Return 0 if the prefix path doesn't exist.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.count = 0

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
            node.count += 1

    def prefix_count(self, prefix):
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return 0
            node = node.children[ch]
        return node.count

t = Trie()
for w in ["apple", "app", "apricot", "banana"]:
    t.insert(w)
print(t.prefix_count("ap"))   # 3
print(t.prefix_count("ban"))  # 1
print(t.prefix_count("c"))    # 0
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        int count = 0;
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
            node.count++;
        }
    }

    static int prefixCount(String prefix) {
        TrieNode node = root;
        for (char ch : prefix.toCharArray()) {
            if (!node.children.containsKey(ch)) return 0;
            node = node.children.get(ch);
        }
        return node.count;
    }

    public static void main(String[] args) {
        for (String w : new String[]{"apple", "app", "apricot", "banana"}) {
            insert(w);
        }
        System.out.println(prefixCount("ap"));   // 3
        System.out.println(prefixCount("ban"));  // 1
        System.out.println(prefixCount("c"));    // 0
    }
}
Challenge

Search with Wildcards

Support '.' matching any single character in search. After inserting ["bad", "dad", "mad"], ".ad" returns True, "b.." returns True, and ".a" returns False (wrong length).

class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word, node=None):
        if node is None:
            node = self.root
        # TODO: walk each char; if '.', recurse into ALL children
        # if regular char, go to that specific child
        # at end of word, return is_end
        pass

t = Trie()
for w in ["bad", "dad", "mad"]:
    t.insert(w)
print(t.search(".ad"))  # True
print(t.search("b.."))  # True
print(t.search("b.d"))  # True
print(t.search("..."))  # True
print(t.search(".a"))   # False
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
        }
        node.isEnd = true;
    }

    static boolean search(String word, int idx, TrieNode node) {
        // TODO: if idx == word.length(), return node.isEnd
        // if word.charAt(idx) == '.', recurse into ALL children
        // otherwise go to the specific child
        return false;
    }

    public static void main(String[] args) {
        for (String w : new String[]{"bad", "dad", "mad"}) {
            insert(w);
        }
        System.out.println(search(".ad", 0, root));  // true
        System.out.println(search("b..", 0, root));  // true
        System.out.println(search("b.d", 0, root));  // true
        System.out.println(search("...", 0, root));  // true
        System.out.println(search(".a",  0, root));  // false
    }
}
When the current character is '.', recurse into all children with the remaining suffix — return True if any branch matches. When it's a regular character, follow only that child (or return False if missing). At the end of the word string, check is_end.
class TrieNode:
    def __init__(self):
        self.children = {}
        self.is_end = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word, node=None):
        if node is None:
            node = self.root
        for i, ch in enumerate(word):
            if ch == '.':
                return any(
                    self.search(word[i + 1:], child)
                    for child in node.children.values()
                )
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end

t = Trie()
for w in ["bad", "dad", "mad"]:
    t.insert(w)
print(t.search(".ad"))  # True
print(t.search("b.."))  # True
print(t.search("b.d"))  # True
print(t.search("..."))  # True
print(t.search(".a"))   # False
import java.util.HashMap;

public class Solution {
    static class TrieNode {
        HashMap<Character, TrieNode> children = new HashMap<>();
        boolean isEnd = false;
    }

    static TrieNode root = new TrieNode();

    static void insert(String word) {
        TrieNode node = root;
        for (char ch : word.toCharArray()) {
            node.children.putIfAbsent(ch, new TrieNode());
            node = node.children.get(ch);
        }
        node.isEnd = true;
    }

    static boolean search(String word, int idx, TrieNode node) {
        if (idx == word.length()) return node.isEnd;
        char ch = word.charAt(idx);
        if (ch == '.') {
            for (TrieNode child : node.children.values()) {
                if (search(word, idx + 1, child)) return true;
            }
            return false;
        }
        if (!node.children.containsKey(ch)) return false;
        return search(word, idx + 1, node.children.get(ch));
    }

    public static void main(String[] args) {
        for (String w : new String[]{"bad", "dad", "mad"}) {
            insert(w);
        }
        System.out.println(search(".ad", 0, root));  // true
        System.out.println(search("b..", 0, root));  // true
        System.out.println(search("b.d", 0, root));  // true
        System.out.println(search("...", 0, root));  // true
        System.out.println(search(".a",  0, root));  // false
    }
}

17. Bit Manipulation

Essential Operations

Operation               Expression                    Notes
Check bit k             (n >> k) & 1                  1 if bit k is set, else 0
Set bit k               n | (1 << k)                  Turn on bit k
Clear bit k             n & ~(1 << k)                 Turn off bit k
Toggle bit k            n ^ (1 << k)                  Flip bit k
Clear lowest set bit    n & (n - 1)                   Brian Kernighan trick
Isolate lowest set bit  n & (-n)                      Same as n & (~n + 1)
Check power of 2        n > 0 and (n & (n - 1)) == 0  Exactly one set bit
XOR self                n ^ n == 0                    Cancels out pairs
XOR with 0              n ^ 0 == n                    Identity
Swap without temp       a ^= b; b ^= a; a ^= b        Breaks if a and b are the same storage (e.g. arr[i], arr[j] with i == j); equal values in distinct variables are fine
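A quick sketch exercising several rows of the table on n = 0b1011 (11). The helper names check_bit, set_bit, clear_bit, toggle_bit, and xor_swap are illustrative, not part of any library. The last two calls show that the XOR-swap hazard comes from aliasing, not from equal values.

```python
def check_bit(n: int, k: int) -> int:
    return (n >> k) & 1          # 1 if bit k is set, else 0

def set_bit(n: int, k: int) -> int:
    return n | (1 << k)

def clear_bit(n: int, k: int) -> int:
    return n & ~(1 << k)

def toggle_bit(n: int, k: int) -> int:
    return n ^ (1 << k)

def xor_swap(arr: list[int], i: int, j: int) -> None:
    # In-place XOR swap of two list slots; zeroes the slot when i == j
    arr[i] ^= arr[j]
    arr[j] ^= arr[i]
    arr[i] ^= arr[j]

n = 0b1011                      # 11
print(check_bit(n, 1))          # 1
print(bin(set_bit(n, 2)))       # 0b1111
print(bin(clear_bit(n, 0)))     # 0b1010
print(bin(toggle_bit(n, 3)))    # 0b11
print(bin(n & (n - 1)))         # 0b1010  (lowest set bit cleared)
print(bin(n & -n))              # 0b1     (lowest set bit isolated)

arr = [5, 5]
xor_swap(arr, 0, 1)             # equal values in distinct slots: fine
print(arr)                      # [5, 5]
xor_swap(arr, 0, 0)             # same slot: 5 ^ 5 wipes it out
print(arr)                      # [0, 5]
```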

Common Bit Tricks

def count_set_bits_kernighan(n: int) -> int:
    """Count 1-bits using Brian Kernighan's algorithm.
    Each iteration clears the lowest set bit.
    Time: O(number of set bits), faster than O(32) when bits are sparse.
    Assumes n >= 0: Python ints are unbounded, so a negative n never reaches 0.
    """
    count = 0
    while n:
        n &= n - 1                      # clear lowest set bit
        count += 1
    return count


def count_set_bits_builtin(n: int) -> int:
    """Python built-in: bin(n).count('1') or int.bit_count() in 3.10+"""
    return bin(n).count('1')


def single_number(nums: list[int]) -> int:
    """Find the one element that appears once; all others appear twice.
    XOR: a ^ a = 0, a ^ 0 = a. XOR all elements; pairs cancel out.
    Time: O(n)  Space: O(1)
    """
    result = 0
    for x in nums:
        result ^= x
    return result


def single_number_iii(nums: list[int]) -> list[int]:
    """Two elements appear once; all others appear twice. Return the two.
    Step 1: XOR all -> xor = a ^ b (where a, b are the two singles).
    Step 2: Find any set bit in xor (a and b differ here).
    Step 3: Partition nums by that bit — each group XORs down to one of a,b.
    Time: O(n)  Space: O(1)
    """
    xor = 0
    for x in nums:
        xor ^= x

    diff_bit = xor & (-xor)             # isolate any differing bit

    a = 0
    for x in nums:
        if x & diff_bit:
            a ^= x

    return [a, xor ^ a]


def reverse_bits(n: int) -> int:
    """Reverse the 32 bits of a 32-bit unsigned integer."""
    result = 0
    for _ in range(32):
        result = (result << 1) | (n & 1)
        n >>= 1
    return result
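A worked trace of single_number_iii's three steps on a small input chosen for illustration, where 3 and 5 are the two singles:

```python
nums = [1, 2, 1, 3, 2, 5]       # singles are 3 and 5

xor = 0
for x in nums:
    xor ^= x                    # pairs cancel: xor == 3 ^ 5 == 0b110
print(bin(xor))                 # 0b110

diff_bit = xor & -xor           # lowest set bit of xor: 0b10
print(bin(diff_bit))            # 0b10

group_with = [x for x in nums if x & diff_bit]         # [2, 3, 2]
group_without = [x for x in nums if not x & diff_bit]  # [1, 1, 5]

a = 0
for x in group_with:
    a ^= x                      # 2 ^ 3 ^ 2 == 3
print(a, xor ^ a)               # 3 5
```

Each group contains exactly one single plus whole pairs, so XOR-ing a group leaves that single.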

Subsets via Bitmask Enumeration

def all_subsets_bitmask(nums: list[int]) -> list[list[int]]:
    """Enumerate all 2^n subsets using bitmask.
    Bit i of mask indicates whether nums[i] is included.
    Time: O(2^n * n)  Space: O(n) auxiliary, excluding the 2^n subsets returned
    """
    n = len(nums)
    result: list[list[int]] = []
    for mask in range(1 << n):          # 0 to 2^n - 1
        subset = [nums[i] for i in range(n) if mask & (1 << i)]
        result.append(subset)
    return result


def count_bits(n: int) -> list[int]:
    """For every i in 0..n, count the 1-bits in i. LeetCode 338.
    DP: ans[i] = ans[i >> 1] + (i & 1)
    Time: O(n)  Space: O(n)
    """
    dp = [0] * (n + 1)
    for i in range(1, n + 1):
        dp[i] = dp[i >> 1] + (i & 1)
    return dp
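The recurrence works because i >> 1 is i with its last bit dropped, and that smaller value's count is already in dp; i & 1 adds the dropped bit back. For example, 5 = 0b101 gives dp[2] + 1 = 1 + 1 = 2. A self-contained sketch with a brute-force cross-check against bin(i).count('1'):

```python
def count_bits(n: int) -> list[int]:
    dp = [0] * (n + 1)
    for i in range(1, n + 1):
        # i >> 1 drops the last bit (its count is already known); i & 1 adds it back
        dp[i] = dp[i >> 1] + (i & 1)
    return dp

print(count_bits(5))   # [0, 1, 1, 2, 1, 2]

# Cross-check the DP against the string-based count over a larger range
assert count_bits(1000) == [bin(i).count('1') for i in range(1001)]
```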
LeetCode problems
136 Single Number · 260 Single Number III · 191 Number of 1 Bits · 338 Counting Bits · 190 Reverse Bits · 268 Missing Number · 371 Sum of Two Integers (no +/-) · 1318 Minimum Flips to Make a OR b Equal to c
Practice
Challenge

Single Number

Every element appears exactly twice except one. Find the element that appears only once. [4, 1, 2, 1, 2] → 4.

def single_number(nums):
    # TODO: XOR all elements together
    # pairs cancel out (a ^ a = 0), leaving the single one
    pass

print(single_number([4, 1, 2, 1, 2]))  # 4
print(single_number([2, 2, 7]))        # 7
public class Solution {
    static int singleNumber(int[] nums) {
        // TODO: XOR all elements together
        // pairs cancel out (a ^ a = 0), leaving the single one
        return 0;
    }

    public static void main(String[] args) {
        System.out.println(singleNumber(new int[]{4, 1, 2, 1, 2}));  // 4
        System.out.println(singleNumber(new int[]{2, 2, 7}));        // 7
    }
}
XOR all elements together. Because a ^ a = 0 and a ^ 0 = a, every duplicated pair cancels out, leaving the single element. One line: return functools.reduce(lambda a, b: a ^ b, nums) (after import functools), or a simple loop with result ^= n.
def single_number(nums):
    result = 0
    for n in nums:
        result ^= n
    return result

print(single_number([4, 1, 2, 1, 2]))  # 4
print(single_number([2, 2, 7]))        # 7
public class Solution {
    static int singleNumber(int[] nums) {
        int result = 0;
        for (int n : nums) {
            result ^= n;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(singleNumber(new int[]{4, 1, 2, 1, 2}));  // 4
        System.out.println(singleNumber(new int[]{2, 2, 7}));        // 7
    }
}
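The one-liner from the hint, written with operator.xor in place of the lambda. Passing 0 as the initializer is an added choice in this sketch so an empty list returns 0 instead of raising TypeError.

```python
import functools
import operator

def single_number(nums: list[int]) -> int:
    # Fold XOR across the list; the 0 initializer covers the empty list
    return functools.reduce(operator.xor, nums, 0)

print(single_number([4, 1, 2, 1, 2]))  # 4
print(single_number([2, 2, 7]))        # 7
```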
Challenge

Count Set Bits

Count the number of 1-bits in an integer. count_bits(11) → 3 (binary 1011 has three 1s).

def count_bits(n):
    # TODO: use Brian Kernighan's trick
    # n &= (n - 1) clears the lowest set bit each iteration
    count = 0
    # your loop here
    return count

print(count_bits(11))  # 3  (1011)
print(count_bits(7))   # 3  (0111)
print(count_bits(0))   # 0
public class Solution {
    static int countBits(int n) {
        // TODO: use Brian Kernighan's trick
        // n &= (n - 1) clears the lowest set bit each iteration
        int count = 0;
        // your loop here
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countBits(11));  // 3  (1011)
        System.out.println(countBits(7));   // 3  (0111)
        System.out.println(countBits(0));   // 0
    }
}
Brian Kernighan's algorithm: while n != 0, do n &= (n - 1) (clears the lowest set bit) and increment a counter. The loop runs exactly once per set bit, so it beats checking every bit position when few bits are set.
def count_bits(n):
    count = 0
    while n:
        n &= (n - 1)
        count += 1
    return count

print(count_bits(11))  # 3  (1011)
print(count_bits(7))   # 3  (0111)
print(count_bits(0))   # 0
public class Solution {
    static int countBits(int n) {
        int count = 0;
        while (n != 0) {
            n &= (n - 1);
            count++;
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countBits(11));  // 3  (1011)
        System.out.println(countBits(7));   // 3  (0111)
        System.out.println(countBits(0));   // 0
    }
}
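The two solutions above differ on negative input: Java's int is 32 bits, so countBits(-1) returns 32, while Python's int is unbounded and the Python loop never terminates for a negative n. A sketch that mimics the 32-bit behavior by masking with 0xFFFFFFFF first; the function name count_bits_32 is hypothetical, and treating negatives as 32-bit two's complement is an assumption about the desired semantics.

```python
def count_bits_32(n: int) -> int:
    n &= 0xFFFFFFFF          # keep the low 32 bits (two's complement view of n)
    count = 0
    while n:
        n &= n - 1           # clear the lowest set bit
        count += 1
    return count

print(count_bits_32(11))   # 3
print(count_bits_32(-1))   # 32
print(count_bits_32(-2))   # 31
```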
Challenge

Power of Two

Check if n is a power of two. is_power_of_two(16) → True, is_power_of_two(18) → False.

def is_power_of_two(n):
    # TODO: a power of 2 has exactly one set bit
    # use: n > 0 and (n & (n - 1)) == 0
    pass

print(is_power_of_two(16))  # True  (10000)
print(is_power_of_two(18))  # False (10010)
print(is_power_of_two(1))   # True  (2^0)
print(is_power_of_two(0))   # False
public class Solution {
    static boolean isPowerOfTwo(int n) {
        // TODO: a power of 2 has exactly one set bit
        // use: n > 0 && (n & (n - 1)) == 0
        return false;
    }

    public static void main(String[] args) {
        System.out.println(isPowerOfTwo(16));  // true  (10000)
        System.out.println(isPowerOfTwo(18));  // false (10010)
        System.out.println(isPowerOfTwo(1));   // true  (2^0)
        System.out.println(isPowerOfTwo(0));   // false
    }
}
Powers of two have exactly one set bit (e.g. 16 = 10000). The trick n & (n - 1) clears the lowest set bit — if the result is 0, there was only one set bit. Don't forget to check n > 0 since 0 is not a power of two.
def is_power_of_two(n):
    return n > 0 and (n & (n - 1)) == 0

print(is_power_of_two(16))  # True  (10000)
print(is_power_of_two(18))  # False (10010)
print(is_power_of_two(1))   # True  (2^0)
print(is_power_of_two(0))   # False
public class Solution {
    static boolean isPowerOfTwo(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }

    public static void main(String[] args) {
        System.out.println(isPowerOfTwo(16));  // true  (10000)
        System.out.println(isPowerOfTwo(18));  // false (10010)
        System.out.println(isPowerOfTwo(1));   // true  (2^0)
        System.out.println(isPowerOfTwo(0));   // false
    }
}

18. Math & Number Theory

GCD and LCM

import math

def gcd(a: int, b: int) -> int:
    """Euclidean algorithm. Time: O(log min(a,b))"""
    while b:
        a, b = b, a % b
    return a

# Python 3.5+: math.gcd(a, b)
# Python 3.9+: math.gcd(*nums) for multiple values

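def lcm(a: int, b: int) -> int:
    """LCM via GCD: lcm(a,b) = a * b // gcd(a,b). Returns 0 if either input is 0."""
    if a == 0 or b == 0:
        return 0                        # avoids dividing by gcd(0, 0) == 0
    return a // gcd(a, b) * b           # divide first to keep the intermediate small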

# Python 3.9+: math.lcm(a, b)
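Before Python 3.9, math.gcd takes exactly two arguments and math.lcm does not exist; folding with functools.reduce covers lists of any length. A sketch with hypothetical helper names gcd_all and lcm_all:

```python
import functools
import math

def gcd_all(nums: list[int]) -> int:
    # gcd(0, x) == x, so 0 is a neutral starting value
    return functools.reduce(math.gcd, nums, 0)

def lcm_all(nums: list[int]) -> int:
    def lcm2(a: int, b: int) -> int:
        return 0 if a == 0 or b == 0 else a // math.gcd(a, b) * b
    # lcm(1, x) == x, so 1 is a neutral starting value
    return functools.reduce(lcm2, nums, 1)

print(gcd_all([12, 18, 24]))  # 6
print(lcm_all([4, 6, 8]))     # 24
```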

Sieve of Eratosthenes

def sieve_of_eratosthenes(limit: int) -> list[int]:
    """All primes up to limit.
    Time: O(n log log n)  Space: O(n)
    """
    if limit < 2:
        return []                       # no primes below 2; also avoids indexing is_prime[1]
    is_prime = [True] * (limit + 1)
    is_prime[0] = is_prime[1] = False

    p = 2
    while p * p <= limit:
        if is_prime[p]:
            for multiple in range(p * p, limit + 1, p):
                is_prime[multiple] = False
        p += 1

    return [x for x in range(2, limit + 1) if is_prime[x]]


def count_primes(n: int) -> int:
    """Count primes strictly less than n. LeetCode 204."""
    if n < 2:
        return 0
    is_prime = [True] * n
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(n**0.5) + 1):
        if is_prime[p]:
            is_prime[p*p::p] = [False] * len(is_prime[p*p::p])
    return sum(is_prime)

Modular Arithmetic

MOD = 10**9 + 7   # standard mod in competitive programming

# Key properties (all under MOD):
# (a + b) % MOD == ((a % MOD) + (b % MOD)) % MOD
# (a * b) % MOD == ((a % MOD) * (b % MOD)) % MOD
# (a - b) % MOD == ((a % MOD) - (b % MOD) + MOD) % MOD  # + MOD keeps it non-negative in Java/C++; Python's % already does
# Division requires a modular inverse (MOD prime, b % MOD != 0): (a / b) % MOD == a * pow(b, MOD-2, MOD) % MOD

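def fast_power(base: int, exp: int, mod: int) -> int:
    """Modular exponentiation: base^exp % mod, by repeated squaring.
    Python's built-in pow(base, exp, mod) computes the same result and is faster in practice.
    Time: O(log exp)  Space: O(1)
    """
    result = 1 % mod                    # 1 % mod handles mod == 1, where every result is 0
    base %= mod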
    while exp > 0:
        if exp & 1:
            result = result * base % mod
        base = base * base % mod
        exp >>= 1
    return result
    # Equivalent: return pow(base, exp, mod)


def mod_inverse(a: int, mod: int) -> int:
    """Modular multiplicative inverse using Fermat's little theorem: a^(mod-2) % mod.
    Valid only when mod is prime and a % mod != 0.
    Python 3.8+: pow(a, -1, mod) also works for any mod coprime to a.
    """
    return pow(a, mod - 2, mod)
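A small numeric check of the division rule with MOD = 10**9 + 7 (prime). mod_div is a hypothetical helper name used only for this sketch.

```python
MOD = 10**9 + 7

def mod_div(a: int, b: int, mod: int = MOD) -> int:
    """(a / b) % mod via Fermat's inverse. Assumes mod is prime and b % mod != 0."""
    return a % mod * pow(b, mod - 2, mod) % mod

print(mod_div(6, 3))            # 2, same as the exact 6 / 3
half = mod_div(1, 2)
print(half)                     # 500000004
print(2 * half % MOD)           # 1, so half really is the inverse of 2
print((3 - 5) % MOD)            # 1000000005; Python's % is already non-negative
```

When the division is exact, the modular result matches the ordinary quotient; when it is not (1 / 2), the result is the residue that multiplies back to 1.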

Combinatorics

def combinations_mod(n: int, r: int, mod: int = 10**9 + 7) -> int:
    """n choose r, modulo a prime mod (requires n < mod).
    Rebuilds the factorial tables on every call; precompute them once for many queries.
    Time: O(n + log mod)  Space: O(n)
    """
    if r > n or r < 0:
        return 0
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i-1] * i % mod
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for i in range(n - 1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % mod
    return fact[n] * inv_fact[r] % mod * inv_fact[n-r] % mod


def pascal_triangle(num_rows: int) -> list[list[int]]:
    """Pascal's triangle: row i contains C(i,0)..C(i,i).
    Each entry = sum of two entries above it.
    Time: O(n^2)  Space: O(n^2)
    """
    triangle: list[list[int]] = []
    for i in range(num_rows):
        row = [1] * (i + 1)
        for j in range(1, i):
            row[j] = triangle[i-1][j-1] + triangle[i-1][j]
        triangle.append(row)
    return triangle
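Since combinations_mod rebuilds its tables on each call, a common pattern when answering many C(n, r) queries is to build them once up to a maximum n. A sketch under the same assumptions (mod prime, max_n < mod); the Binomial class name and max_n parameter are illustrative, and queries with n > max_n are not supported.

```python
class Binomial:
    """Answer many C(n, r) % mod queries for n <= max_n after O(max_n) setup."""

    def __init__(self, max_n: int, mod: int = 10**9 + 7) -> None:
        self.mod = mod
        self.fact = [1] * (max_n + 1)
        for i in range(1, max_n + 1):
            self.fact[i] = self.fact[i - 1] * i % mod
        self.inv_fact = [1] * (max_n + 1)
        # One Fermat inverse, then walk down: 1/(i-1)! = i * (1/i!)
        self.inv_fact[max_n] = pow(self.fact[max_n], mod - 2, mod)
        for i in range(max_n, 0, -1):
            self.inv_fact[i - 1] = self.inv_fact[i] * i % mod

    def comb(self, n: int, r: int) -> int:
        if r < 0 or r > n:
            return 0
        return self.fact[n] * self.inv_fact[r] % self.mod * self.inv_fact[n - r] % self.mod

b = Binomial(1000)
print(b.comb(5, 2))                      # 10
print([b.comb(6, r) for r in range(7)])  # [1, 6, 15, 20, 15, 6, 1]
```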

Random Sampling

import random

def reservoir_sample(stream, k: int) -> list:
    """Select k items uniformly at random from a stream of unknown length.
    Algorithm R (Vitter): the first k items fill the reservoir; each later element i
    (0-indexed, i >= k) replaces a random reservoir slot with probability k/(i+1).
    Time: O(n)  Space: O(k)
    """
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            j = random.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir


def fisher_yates_shuffle(nums: list) -> list:
    """In-place uniform random shuffle.
    Time: O(n)  Space: O(1)
    """
    n = len(nums)
    for i in range(n - 1, 0, -1):
        j = random.randint(0, i)
        nums[i], nums[j] = nums[j], nums[i]
    return nums
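Why Algorithm R is uniform, writing n for the stream length: an item i with i >= k enters with probability k/(i+1), and each later item j > i lands in that item's slot with probability k/(j+1) times 1/k, that is 1/(j+1). The product telescopes:

```latex
\Pr[\text{item } i \text{ is in the final reservoir}]
  = \frac{k}{i+1} \prod_{j=i+1}^{n-1} \left(1 - \frac{1}{j+1}\right)
  = \frac{k}{i+1} \prod_{j=i+1}^{n-1} \frac{j}{j+1}
  = \frac{k}{i+1} \cdot \frac{i+1}{n}
  = \frac{k}{n}
% For i < k the item enters with probability 1, and the same product
% over j = k, ..., n-1 telescopes to k/n.
```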
LeetCode problems
204 Count Primes · 50 Pow(x, n) · 372 Super Pow · 118 Pascal's Triangle · 382 Linked List Random Node (reservoir sampling) · 398 Random Pick Index · 470 Implement rand10 Using rand7
Practice
Challenge

GCD

Compute the greatest common divisor using the Euclidean algorithm. gcd(48, 18) → 6, gcd(100, 75) → 25.

def gcd(a, b):
    # TODO: Euclidean algorithm
    # while b != 0: a, b = b, a % b
    # return a
    pass

print(gcd(48, 18))   # 6
print(gcd(100, 75))  # 25
print(gcd(7, 3))     # 1
public class Solution {
    static int gcd(int a, int b) {
        // TODO: Euclidean algorithm
        // while (b != 0): int tmp = b; b = a % b; a = tmp;
        // return a
        return 0;
    }

    public static void main(String[] args) {
        System.out.println(gcd(48, 18));   // 6
        System.out.println(gcd(100, 75));  // 25
        System.out.println(gcd(7, 3));     // 1
    }
}
The Euclidean algorithm: while b != 0, replace (a, b) with (b, a % b). When b reaches 0, a is the GCD. It works because gcd(a, b) == gcd(b, a % b).
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a

print(gcd(48, 18))   # 6
print(gcd(100, 75))  # 25
print(gcd(7, 3))     # 1
public class Solution {
    static int gcd(int a, int b) {
        while (b != 0) {
            int tmp = b;
            b = a % b;
            a = tmp;
        }
        return a;
    }

    public static void main(String[] args) {
        System.out.println(gcd(48, 18));   // 6
        System.out.println(gcd(100, 75));  // 25
        System.out.println(gcd(7, 3));     // 1
    }
}
Challenge

Sieve of Eratosthenes

Find all prime numbers up to n (inclusive). sieve(30) → [2, 3, 5, 7, 11, 13, 17, 19, 23, 29].

def sieve(n):
    # TODO: create a boolean array is_prime[0..n], init True
    # mark 0 and 1 as False
    # for each i from 2 to sqrt(n): if is_prime[i], mark multiples False
    # return list of indices where is_prime is True
    pass

print(sieve(30))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
import java.util.*;

public class Solution {
    static List<Integer> sieve(int n) {
        // TODO: boolean array is_prime[0..n], init true
        // mark 0 and 1 false
        // for each i from 2 to sqrt(n): if prime, mark multiples false
        // collect indices where is_prime is true
        return new ArrayList<>();
    }

    public static void main(String[] args) {
        System.out.println(sieve(30));
        // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
Create a boolean list of size n+1 initialized to True (return an empty list early when n < 2). Set indices 0 and 1 to False. For each i from 2 through int(n**0.5), if is_prime[i] is True, mark i*i, i*i+i, ... up to n as False. Starting at i*i is safe because every smaller multiple k*i (k < i) was already marked while processing a prime factor of k.
def sieve(n):
    if n < 2:
        return []
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for i in range(2, int(n ** 0.5) + 1):
        if is_prime[i]:
            for j in range(i * i, n + 1, i):
                is_prime[j] = False
    return [i for i in range(n + 1) if is_prime[i]]

print(sieve(30))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
import java.util.*;

public class Solution {
    static List<Integer> sieve(int n) {
        if (n < 2) return new ArrayList<>();
        boolean[] isPrime = new boolean[n + 1];
        Arrays.fill(isPrime, true);
        isPrime[0] = isPrime[1] = false;
        for (int i = 2; (long) i * i <= n; i++) {
            if (isPrime[i]) {
                for (int j = i * i; j <= n; j += i) {
                    isPrime[j] = false;
                }
            }
        }
        List<Integer> primes = new ArrayList<>();
        for (int i = 2; i <= n; i++) {
            if (isPrime[i]) primes.add(i);
        }
        return primes;
    }

    public static void main(String[] args) {
        System.out.println(sieve(30));
        // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
Challenge

Reverse Integer

Reverse the digits of an integer. Return 0 if the result would overflow a 32-bit signed integer (-2^31 to 2^31 - 1). reverse(123) → 321, reverse(-456) → -654.

def reverse(x):
    sign = -1 if x < 0 else 1
    x = abs(x)
    result = 0
    # TODO: build result digit by digit
    # digit = x % 10, x //= 10, result = result * 10 + digit
    # at the end apply sign, check 32-bit bounds, return 0 if overflow
    return 0

print(reverse(123))   # 321
print(reverse(-456))  # -654
print(reverse(1534236469))  # 0  (overflow)
public class Solution {
    static int reverse(int x) {
        int sign = x < 0 ? -1 : 1;
        long val = Math.abs((long) x);
        long result = 0;
        // TODO: build result digit by digit
        // digit = val % 10, val /= 10, result = result * 10 + digit
        // at the end apply sign, check 32-bit bounds, return 0 if overflow
        return 0;
    }

    public static void main(String[] args) {
        System.out.println(reverse(123));          // 321
        System.out.println(reverse(-456));         // -654
        System.out.println(reverse(1534236469));   // 0  (overflow)
    }
}
Strip the sign, then loop: digit = x % 10, x //= 10, result = result * 10 + digit. After the loop, reapply the sign. Check bounds: if result > 2**31 - 1 or result < -2**31, return 0.
def reverse(x):
    sign = -1 if x < 0 else 1
    x = abs(x)
    result = 0
    while x:
        result = result * 10 + x % 10
        x //= 10
    result *= sign
    if result < -(2 ** 31) or result > 2 ** 31 - 1:
        return 0
    return result

print(reverse(123))          # 321
print(reverse(-456))         # -654
print(reverse(1534236469))   # 0  (overflow)
public class Solution {
    static int reverse(int x) {
        int sign = x < 0 ? -1 : 1;
        long val = Math.abs((long) x);
        long result = 0;
        while (val != 0) {
            result = result * 10 + val % 10;
            val /= 10;
        }
        result *= sign;
        if (result < Integer.MIN_VALUE || result > Integer.MAX_VALUE) return 0;
        return (int) result;
    }

    public static void main(String[] args) {
        System.out.println(reverse(123));          // 321
        System.out.println(reverse(-456));         // -654
        System.out.println(reverse(1534236469));   // 0  (overflow)
    }
}
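Both solutions above build the full reversed value and range-check it afterward, which needs a wider type (Python's big ints, Java's long). LeetCode 7 asks you to assume 64-bit integers are unavailable; the usual answer is to check before each multiply-and-add. A Python sketch of that pre-check (reverse_checked is a hypothetical name). Python ints cannot actually overflow, so this only illustrates the test; a true 32-bit port would also need care with abs(-2**31).

```python
INT_MAX = 2**31 - 1  # 2147483647

def reverse_checked(x: int) -> int:
    """Reverse digits, returning 0 on 32-bit overflow; the running result never exceeds INT_MAX."""
    sign = -1 if x < 0 else 1
    x = abs(x)
    result = 0
    while x:
        digit = x % 10
        x //= 10
        # result * 10 + digit > INT_MAX  exactly when  result > (INT_MAX - digit) // 10
        if result > (INT_MAX - digit) // 10:
            return 0
        result = result * 10 + digit
    # Using INT_MAX for negatives too is safe for 32-bit input:
    # no valid input reverses to a magnitude of exactly 2**31.
    return sign * result

print(reverse_checked(123))         # 321
print(reverse_checked(-456))        # -654
print(reverse_checked(1534236469))  # 0  (overflow)
```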

19. Design Problems

LRU Cache

Least Recently Used cache evicts the entry unused for the longest time. Requires O(1) get and O(1) put. Classic solution: HashMap + doubly linked list. Python's OrderedDict provides a concise built-in alternative.

from collections import OrderedDict

class LRUCache:
    """O(1) get and put using OrderedDict (move_to_end + popitem).
    Time: O(1) amortized  Space: O(capacity)
    """
    def __init__(self, capacity: int) -> None:
        self.cap = capacity
        self.cache: OrderedDict[int, int] = OrderedDict()

    def get(self, key: int) -> int:
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)     # mark as most recently used
        return self.cache[key]

    def put(self, key: int, value: int) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.cap:
            self.cache.popitem(last=False)  # evict LRU (front)
LRU Cache with explicit doubly linked list (interview-expected)
class DLLNode:
    def __init__(self, key: int = 0, val: int = 0) -> None:
        self.key = key
        self.val = val
        self.prev: 'DLLNode | None' = None
        self.next: 'DLLNode | None' = None


class LRUCacheManual:
    """Doubly linked list + HashMap. Dummy head/tail simplify edge cases."""
    def __init__(self, capacity: int) -> None:
        self.cap = capacity
        self.map: dict[int, DLLNode] = {}
        self.head = DLLNode()           # dummy head (LRU end)
        self.tail = DLLNode()           # dummy tail (MRU end)
        self.head.next = self.tail
        self.tail.prev = self.head

    def _remove(self, node: DLLNode) -> None:
        node.prev.next = node.next
        node.next.prev = node.prev

    def _add_to_tail(self, node: DLLNode) -> None:
        node.prev = self.tail.prev
        node.next = self.tail
        self.tail.prev.next = node
        self.tail.prev = node

    def get(self, key: int) -> int:
        if key not in self.map:
            return -1
        node = self.map[key]
        self._remove(node)
        self._add_to_tail(node)         # promote to MRU
        return node.val

    def put(self, key: int, value: int) -> None:
        if key in self.map:
            self._remove(self.map[key])
        node = DLLNode(key, value)
        self._add_to_tail(node)
        self.map[key] = node
        if len(self.map) > self.cap:
            lru = self.head.next        # evict from LRU end
            self._remove(lru)
            del self.map[lru.key]

Min Stack

class MinStack:
    """Stack with O(1) getMin. Use a parallel stack tracking minimums.
    Push min only when new value <= current min (saves space slightly).
    """
    def __init__(self) -> None:
        self.stack: list[int] = []
        self.min_stack: list[int] = []  # parallel min tracking

    def push(self, val: int) -> None:
        self.stack.append(val)
        if not self.min_stack or val <= self.min_stack[-1]:
            self.min_stack.append(val)

    def pop(self) -> None:
        val = self.stack.pop()
        if val == self.min_stack[-1]:
            self.min_stack.pop()

    def top(self) -> int:
        return self.stack[-1]

    def get_min(self) -> int:
        return self.min_stack[-1]
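Why MinStack pushes on <= rather than <: when the current minimum is pushed again, each copy needs its own min_stack entry, otherwise popping one copy would drop the minimum too early. A trimmed copy of the class above (top() omitted) showing the duplicate case:

```python
class MinStack:
    def __init__(self) -> None:
        self.stack: list[int] = []
        self.min_stack: list[int] = []

    def push(self, val: int) -> None:
        self.stack.append(val)
        # <= (not <) so a duplicate of the current min gets its own entry
        if not self.min_stack or val <= self.min_stack[-1]:
            self.min_stack.append(val)

    def pop(self) -> None:
        if self.stack.pop() == self.min_stack[-1]:
            self.min_stack.pop()

    def get_min(self) -> int:
        return self.min_stack[-1]

s = MinStack()
for v in (2, 1, 1):
    s.push(v)
s.pop()              # removes one of the two 1s
print(s.get_min())   # 1, still correct because both 1s were recorded
```

With < instead, min_stack would hold only [2, 1], and the first pop would wrongly leave 2 as the minimum.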

Time-Based Key-Value Store

class TimeMap:
    """Key-value store where each key maps to timestamped values.
    get(key, timestamp) returns the value with largest timestamp <= given timestamp.
    Uses binary search on stored timestamps.
    Time: O(log n) get, O(1) set (LeetCode 981 guarantees set() timestamps are strictly increasing, so appending keeps each list sorted)
    """
    def __init__(self) -> None:
        # key -> list of (timestamp, value), sorted by timestamp
        self.store: dict[str, list[tuple[int, str]]] = {}

    def set(self, key: str, value: str, timestamp: int) -> None:
        if key not in self.store:
            self.store[key] = []
        self.store[key].append((timestamp, value))

    def get(self, key: str, timestamp: int) -> str:
        if key not in self.store:
            return ""
        entries = self.store[key]
        # Find rightmost timestamp <= given timestamp
        lo, hi = 0, len(entries) - 1
        result = ""
        while lo <= hi:
            mid = (lo + hi) // 2
            if entries[mid][0] <= timestamp:
                result = entries[mid][1]
                lo = mid + 1
            else:
                hi = mid - 1
        return result
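The hand-written binary search in get() can also be expressed with the standard-library bisect module. A sketch that keeps timestamps and values in parallel lists so bisect.bisect_right can search the timestamps directly (TimeMapBisect is a hypothetical name; it relies on the same increasing-timestamp guarantee):

```python
import bisect
from collections import defaultdict

class TimeMapBisect:
    def __init__(self) -> None:
        self.times = defaultdict(list)    # key -> sorted timestamps
        self.values = defaultdict(list)   # key -> values, aligned with times

    def set(self, key: str, value: str, timestamp: int) -> None:
        # Timestamps arrive in increasing order, so appending keeps them sorted
        self.times[key].append(timestamp)
        self.values[key].append(value)

    def get(self, key: str, timestamp: int) -> str:
        ts = self.times.get(key)          # .get avoids creating an empty entry
        if not ts:
            return ""
        # bisect_right returns how many stored timestamps are <= timestamp
        i = bisect.bisect_right(ts, timestamp)
        return self.values[key][i - 1] if i else ""

tm = TimeMapBisect()
tm.set("foo", "bar", 1)
print(tm.get("foo", 1))   # bar
print(tm.get("foo", 3))   # bar
tm.set("foo", "bar2", 4)
print(tm.get("foo", 5))   # bar2
```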

Flatten Nested Iterator

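class NestedIterator:
    """Flatten a nested list of integers lazily.
    LeetCode 341: each element is an integer OR a nested list.
    Assumes elements expose is_integer(), get_integer(), get_list();
    LeetCode's Python stub names these isInteger(), getInteger(), getList().
    Use a stack storing (list_ref, current_index) pairs.
    Time: O(1) amortized per next()/has_next()  Space: O(depth)
    """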
    def __init__(self, nested_list: list) -> None:
        # Stack of (list, index) tuples; start with the outermost list
        self.stack = [(nested_list, 0)]

    def next(self) -> int:
        self.has_next()                 # ensure top of stack is an int
        lst, idx = self.stack.pop()
        self.stack.append((lst, idx + 1))
        return lst[idx].get_integer()

    def has_next(self) -> bool:
        """Advance through nested structure until an integer is at the top."""
        while self.stack:
            lst, idx = self.stack[-1]
            if idx == len(lst):
                self.stack.pop()        # exhausted this list
                continue
            item = lst[idx]
            if item.is_integer():
                return True
            # item is a nested list: push it and advance parent index
            self.stack[-1] = (lst, idx + 1)
            self.stack.append((item.get_list(), 0))
        return False
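The same stack-based idea, sketched as a Python generator over plain nested lists instead of NestedInteger objects (an assumption made so the example runs on its own; flatten_lazy is a hypothetical name). The stack holds one iterator per open list, and values are produced one at a time, so the flattening stays lazy.

```python
from typing import Iterator

def flatten_lazy(nested: list) -> Iterator[int]:
    """Yield the integers of an arbitrarily nested list, left to right."""
    stack = [iter(nested)]                 # one iterator per open list
    while stack:
        try:
            item = next(stack[-1])
        except StopIteration:
            stack.pop()                    # this list is exhausted
            continue
        if isinstance(item, list):
            stack.append(iter(item))       # descend into the nested list
        else:
            yield item

print(list(flatten_lazy([[1, 1], 2, [1, 1]])))  # [1, 1, 2, 1, 1]
print(list(flatten_lazy([1, [4, [6]]])))        # [1, 4, 6]
```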
LeetCode problems
146 LRU Cache · 460 LFU Cache · 155 Min Stack · 981 Time Based Key-Value Store · 341 Flatten Nested List Iterator · 208 Implement Trie · 380 Insert Delete GetRandom O(1) · 295 Find Median from Data Stream
Practice
Challenge

LRU Cache Simplified

Implement get(key) and put(key, value) with a capacity limit. The least recently used entry is evicted when full. put(1,1), put(2,2), get(1)→1, put(3,3) evicts key 2, so get(2)→-1.

from collections import OrderedDict

class LRUCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        # TODO: if key missing return -1
        # move key to end (most recently used), return value
        pass

    def put(self, key, value):
        # TODO: if key exists, update and move to end
        # if at capacity, remove oldest (first) item
        # insert key at end
        pass

c = LRUCache(2)
c.put(1, 1)
c.put(2, 2)
print(c.get(1))   # 1   (key 1 is now most recent)
c.put(3, 3)       # evicts key 2
print(c.get(2))   # -1  (evicted)
print(c.get(3))   # 3
import java.util.*;

public class Solution {
    static class LRUCache extends LinkedHashMap<Integer, Integer> {
        private final int capacity;

        LRUCache(int capacity) {
            super(capacity, 0.75f, true);  // accessOrder = true
            this.capacity = capacity;
        }

        int get(int key) {
            // TODO: return getOrDefault(key, -1)
            return -1;
        }

        void put(int key, int value) {
            // TODO: call super.put(key, value)
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<Integer, Integer> eldest) {
            // TODO: return true when size exceeds capacity
            return false;
        }
    }

    public static void main(String[] args) {
        LRUCache c = new LRUCache(2);
        c.put(1, 1);
        c.put(2, 2);
        System.out.println(c.get(1));   // 1
        c.put(3, 3);                    // evicts key 2
        System.out.println(c.get(2));   // -1
        System.out.println(c.get(3));   // 3
    }
}
Python: OrderedDict preserves insertion order. Use move_to_end(key) on every access (get and put). When full, call popitem(last=False) to evict the oldest (leftmost) entry. Java: extend LinkedHashMap with accessOrder=true and override removeEldestEntry to return size() > capacity.
from collections import OrderedDict

class LRUCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = OrderedDict()

    def get(self, key):
        if key not in self.cache:
            return -1
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        if key in self.cache:
            self.cache.move_to_end(key)
        self.cache[key] = value
        if len(self.cache) > self.capacity:
            self.cache.popitem(last=False)

c = LRUCache(2)
c.put(1, 1)
c.put(2, 2)
print(c.get(1))   # 1
c.put(3, 3)       # evicts key 2
print(c.get(2))   # -1
print(c.get(3))   # 3
import java.util.*;

public class Solution {
    static class LRUCache extends LinkedHashMap<Integer, Integer> {
        private final int capacity;

        LRUCache(int capacity) {
            super(capacity, 0.75f, true);
            this.capacity = capacity;
        }

        int get(int key) {
            return getOrDefault(key, -1);
        }

        void put(int key, int value) {
            super.put(key, value);
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<Integer, Integer> eldest) {
            return size() > capacity;
        }
    }

    public static void main(String[] args) {
        LRUCache c = new LRUCache(2);
        c.put(1, 1);
        c.put(2, 2);
        System.out.println(c.get(1));   // 1
        c.put(3, 3);                    // evicts key 2
        System.out.println(c.get(2));   // -1
        System.out.println(c.get(3));   // 3
    }
}
Challenge

Min Stack

Implement a stack with push, pop, top, and getMin all in O(1). After push(3), push(5): getMin()→3. After push(1): getMin()→1. After pop(): getMin()→3.

class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []  # tracks current min at each level

    def push(self, val):
        # TODO: push val to stack
        # push current min to min_stack
        # (min is val if min_stack empty, else min(val, min_stack[-1]))
        pass

    def pop(self):
        # TODO: pop from both stacks
        pass

    def top(self):
        return self.stack[-1]

    def getMin(self):
        return self.min_stack[-1]

ms = MinStack()
ms.push(3)
ms.push(5)
print(ms.getMin())  # 3
ms.push(1)
print(ms.getMin())  # 1
ms.pop()
print(ms.getMin())  # 3
import java.util.*;

public class Solution {
    static Deque<Integer> stack = new ArrayDeque<>();
    static Deque<Integer> minStack = new ArrayDeque<>();

    static void push(int val) {
        // TODO: push val to stack
        // push current min to minStack
        // (min is val if minStack empty, else Math.min(val, minStack.peek()))
    }

    static void pop() {
        // TODO: pop from both stacks
    }

    static int top() {
        return stack.peek();
    }

    static int getMin() {
        return minStack.peek();
    }

    public static void main(String[] args) {
        push(3);
        push(5);
        System.out.println(getMin());  // 3
        push(1);
        System.out.println(getMin());  // 1
        pop();
        System.out.println(getMin());  // 3
    }
}
Maintain a parallel min_stack. On each push(val), also push min(val, min_stack[-1]) (or just val if min_stack is empty) onto min_stack. On pop(), pop from both stacks. getMin() just peeks at min_stack[-1].
class MinStack:
    def __init__(self):
        self.stack = []
        self.min_stack = []

    def push(self, val):
        self.stack.append(val)
        current_min = val if not self.min_stack else min(val, self.min_stack[-1])
        self.min_stack.append(current_min)

    def pop(self):
        self.stack.pop()
        self.min_stack.pop()

    def top(self):
        return self.stack[-1]

    def getMin(self):
        return self.min_stack[-1]

ms = MinStack()
ms.push(3)
ms.push(5)
print(ms.getMin())  # 3
ms.push(1)
print(ms.getMin())  # 1
ms.pop()
print(ms.getMin())  # 3
import java.util.*;

public class Solution {
    static Deque<Integer> stack = new ArrayDeque<>();
    static Deque<Integer> minStack = new ArrayDeque<>();

    static void push(int val) {
        stack.push(val);
        int currentMin = minStack.isEmpty() ? val : Math.min(val, minStack.peek());
        minStack.push(currentMin);
    }

    static void pop() {
        stack.pop();
        minStack.pop();
    }

    static int top() {
        return stack.peek();
    }

    static int getMin() {
        return minStack.peek();
    }

    public static void main(String[] args) {
        push(3);
        push(5);
        System.out.println(getMin());  // 3
        push(1);
        System.out.println(getMin());  // 1
        pop();
        System.out.println(getMin());  // 3
    }
}
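A common variant saves space when most pushed values are larger than the current minimum. It is sketched below under the same interface; the class name LeanMinStack is hypothetical. It pushes onto min_stack only when val <= the current minimum, and pops from min_stack only when the value leaving the stack equals its top. The <= (not <) matters: with duplicate minimums, each copy needs its own entry so that popping one copy leaves the other as the minimum. Worst-case space is still O(n), for example when values are pushed in decreasing order.

```python
class LeanMinStack:
    """Min stack that records a minimum only when one is pushed (sketch)."""

    def __init__(self):
        self.stack = []
        self.min_stack = []  # non-increasing; top is the current minimum

    def push(self, val):
        self.stack.append(val)
        # Use <= so duplicate minimums each get their own entry.
        if not self.min_stack or val <= self.min_stack[-1]:
            self.min_stack.append(val)

    def pop(self):
        val = self.stack.pop()
        # Retire the minimum only if the value leaving the stack is that minimum.
        if val == self.min_stack[-1]:
            self.min_stack.pop()
        return val

    def top(self):
        return self.stack[-1]

    def getMin(self):
        return self.min_stack[-1]


ms = LeanMinStack()
for v in (3, 5, 1, 1):
    ms.push(v)
print(ms.getMin())  # 1
ms.pop()
print(ms.getMin())  # 1 (the second copy of 1 is still on the stack)
ms.pop()
print(ms.getMin())  # 3
```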
Challenge

Hit Counter

Count hits in the past 5 minutes (300 seconds). hit(1), hit(2), hit(3), getHits(4)→3. After hit(300), getHits(300)→4. getHits(301)→3 (the hit at t=1 has expired).

from collections import deque

class HitCounter:
    def __init__(self):
        self.hits = deque()  # stores timestamps

    def hit(self, timestamp):
        self.hits.append(timestamp)

    def getHits(self, timestamp):
        # TODO: remove all timestamps where timestamp - ts >= 300
        # (they are older than 300 seconds)
        # return remaining count
        pass

hc = HitCounter()
hc.hit(1)
hc.hit(2)
hc.hit(3)
print(hc.getHits(4))    # 3
hc.hit(300)
print(hc.getHits(300))  # 4
print(hc.getHits(301))  # 3
import java.util.*;

public class Solution {
    static Deque<Integer> hits = new ArrayDeque<>();

    static void hit(int timestamp) {
        hits.addLast(timestamp);
    }

    static int getHits(int timestamp) {
        // TODO: remove from front while timestamp - hits.peekFirst() >= 300
        // return hits.size()
        return 0;
    }

    public static void main(String[] args) {
        hit(1); hit(2); hit(3);
        System.out.println(getHits(4));    // 3
        hit(300);
        System.out.println(getHits(300));  // 4
        System.out.println(getHits(301));  // 3
    }
}
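Store hit timestamps in a queue (deque). On getHits(t), pop from the left (oldest end) while t - hits[0] >= 300, because those hits fall outside the 300-second window (t-300, t]. Return the remaining queue length. The condition is >= 300, not > 300: at t=301 the window covers timestamps 2 through 301, so the hit at t=1 must already have expired. This relies on timestamps arriving in non-decreasing order, which keeps the oldest hit at the front of the queue. Each timestamp is appended once and popped at most once, so getHits is amortized O(1).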
from collections import deque

class HitCounter:
    def __init__(self):
        self.hits = deque()

    def hit(self, timestamp):
        self.hits.append(timestamp)

    def getHits(self, timestamp):
        while self.hits and timestamp - self.hits[0] >= 300:
            self.hits.popleft()
        return len(self.hits)

hc = HitCounter()
hc.hit(1)
hc.hit(2)
hc.hit(3)
print(hc.getHits(4))    # 3
hc.hit(300)
print(hc.getHits(300))  # 4
print(hc.getHits(301))  # 3
import java.util.*;

public class Solution {
    static Deque<Integer> hits = new ArrayDeque<>();

    static void hit(int timestamp) {
        hits.addLast(timestamp);
    }

    static int getHits(int timestamp) {
        while (!hits.isEmpty() && timestamp - hits.peekFirst() >= 300) {
            hits.pollFirst();
        }
        return hits.size();
    }

    public static void main(String[] args) {
        hit(1); hit(2); hit(3);
        System.out.println(getHits(4));    // 3
        hit(300);
        System.out.println(getHits(300));  // 4
        System.out.println(getHits(301));  // 3
    }
}
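If many hits share a timestamp, the deque grows with the number of hits. A common follow-up design keeps a fixed ring of 300 buckets indexed by timestamp % 300. It is sketched below, assuming whole-second, non-decreasing timestamps; the class name BucketHitCounter is hypothetical. Each bucket remembers which second it last counted and how many hits it saw. getHits scans all 300 buckets, so memory and query time are O(300) regardless of hit volume.

```python
class BucketHitCounter:
    WINDOW = 300  # seconds

    def __init__(self):
        self.times = [0] * self.WINDOW   # second each bucket last recorded
        self.counts = [0] * self.WINDOW  # hits recorded for that second

    def hit(self, timestamp):
        idx = timestamp % self.WINDOW
        if self.times[idx] != timestamp:
            # Bucket holds a stale second (or is unused): reset it.
            self.times[idx] = timestamp
            self.counts[idx] = 1
        else:
            self.counts[idx] += 1

    def getHits(self, timestamp):
        # Same window as the deque version: count hits with timestamp - ts < 300.
        return sum(
            count
            for ts, count in zip(self.times, self.counts)
            if timestamp - ts < self.WINDOW
        )


hc = BucketHitCounter()
hc.hit(1)
hc.hit(2)
hc.hit(3)
print(hc.getHits(4))    # 3
hc.hit(300)
print(hc.getHits(300))  # 4
print(hc.getHits(301))  # 3
```

The trade-off: the deque version answers in amortized O(1) but stores every hit, while the bucket version stores at most 300 pairs but always scans them.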

20. Interview Strategy Cheat Sheet

Problem-Solving Framework

| Step | What to do | Time |
|---|---|---|
| 1. Understand | Restate the problem. Clarify input/output types, constraints, edge cases. | 2 min |
| 2. Examples | Work through 2-3 examples including edge cases (empty, single element, negatives). | 2 min |
| 3. Brute force | State naive \(O(n^2)\) or \(O(2^n)\) solution. Briefly explain why it works. | 1 min |
| 4. Optimize | Identify bottleneck. Apply pattern (sort, two pointers, DP, BFS, etc.). | 5 min |
| 5. Code | Write clean code top-down. Name variables clearly. No premature optimization. | 15 min |
| 6. Test | Trace through examples manually. Check edge cases. Analyze complexity. | 5 min |

Pattern Recognition Table

| Input / Clue | Likely Technique |
|---|---|
| Sorted array, find target | Binary search |
| Find pair/triplet with target sum | Two pointers (sorted) or hash map |
| Subarray / substring sum or length | Sliding window or prefix sums |
| Top K / K-th largest | Heap (min-heap of size k) or Quickselect |
| All subsets / permutations / combinations | Backtracking |
| Optimal choice at each step | Greedy (prove or disprove first) |
| Overlapping subproblems, optimal substructure | Dynamic programming |
| Shortest path, unweighted | BFS |
| Shortest path, weighted (non-negative) | Dijkstra |
| Connected components / union operations | Union-Find or BFS/DFS |
| Tree: level-by-level processing | BFS with queue |
| Tree: path / ancestor / subtree | DFS (often recursive) |
| Prefix lookups, autocomplete | Trie |
| Monotone property (next greater element, sliding-window min/max) | Monotonic stack or deque |
| Substring with no repeated characters (or at most k distinct) | Hash map + sliding window |
| Find duplicate / missing number | XOR, index marking, or math |
| Matrix connectivity / regions | DFS/BFS or Union-Find on grid |
| Palindrome | Two pointers (from both ends, or expanding around the center) or Manacher's |
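To make the monotonic-stack entry of the table concrete, here is a sketch of "next greater element": for each index, find the first later value that is strictly larger, or -1 if there is none. The function name next_greater is illustrative. The stack holds indices still waiting for an answer, with values in non-increasing order. Each index is pushed and popped at most once, so the scan is O(n).

```python
def next_greater(nums):
    """For each position, the next strictly greater value to its right, else -1."""
    result = [-1] * len(nums)
    waiting = []  # indices with no answer yet; their values are non-increasing
    for i, value in enumerate(nums):
        # value is the answer for every waiting index with a smaller value
        while waiting and nums[waiting[-1]] < value:
            result[waiting.pop()] = value
        waiting.append(i)
    return result


print(next_greater([2, 1, 2, 4, 3]))  # [4, 2, 4, -1, -1]
```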

Time Complexity Targets by Input Size

| Input n | Acceptable complexity | Typical algorithm |
|---|---|---|
| \(n \le 15\) | \(O(2^n)\) or \(O(2^n \cdot n^2)\); \(O(n!)\) only up to about \(n = 10\) | Backtracking, bitmask DP |
| \(n \le 100\) | \(O(n^3)\) or \(O(n^2 \log n)\) | Floyd-Warshall, interval DP |
| \(n \le 1{,}000\) | \(O(n^2)\) | DP, brute-force with pruning |
| \(n \le 10{,}000\) | \(O(n^2)\) tight, prefer \(O(n \log n)\) | \(O(n^2)\) LIS, sorting-based |
| \(n \le 100{,}000\) | \(O(n \log n)\) | Sort, heap, binary search, BFS/DFS |
| \(n \le 10^6\) | \(O(n)\) or \(O(n \log n)\) | Two pointers, sliding window, prefix |
| \(n \le 10^9\) | \(O(\log n)\) or \(O(1)\) | Binary search, math formula |
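The rows above come from dividing an operation budget by a cost function. The sketch below reproduces that arithmetic, assuming a budget of 10^8 simple operations for a 1-second limit. That is a rough compiled-language figure; interpreted Python is typically slower. The helper name max_feasible_n is hypothetical. Exponential search finds an upper bound, then binary search finds the largest n whose cost fits. This only works for costs that grow with n, so O(1) and O(log n) are left out.

```python
import math

OPS_BUDGET = 10**8  # assumed operations per ~1 s; adjust for your language


def max_feasible_n(cost, budget=OPS_BUDGET):
    """Largest n with cost(n) <= budget, assuming cost is non-decreasing in n."""
    hi = 1
    while cost(hi) <= budget:  # double until we overshoot the budget
        hi *= 2
    lo = hi // 2  # last value that fit (0 if even n = 1 is too costly)
    while lo + 1 < hi:  # invariant: lo fits, hi does not
        mid = (lo + hi) // 2
        if cost(mid) <= budget:
            lo = mid
        else:
            hi = mid
    return lo


costs = {
    "O(n!)": math.factorial,
    "O(2^n)": lambda n: 2**n,
    "O(n^3)": lambda n: n**3,
    "O(n^2)": lambda n: n**2,
    "O(n log n)": lambda n: n * math.log2(n),
    "O(n)": lambda n: n,
}
for name, cost in costs.items():
    print(f"{name:>10}: n <= {max_feasible_n(cost):,}")
```

Under this budget, n! fits only up to n = 11, 2^n up to 26, n^3 up to 464, n^2 up to 10,000, and n log n up to a few million. The table's limits sit at or below these values, which leaves headroom for constant factors.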

Common Edge Cases Checklist

Always check these before finalizing your solution
  • Empty input: empty list, empty string, n=0
  • Single element: list of length 1, single node tree
  • All same elements: [5, 5, 5, 5]
  • Already sorted / reverse sorted
  • Negative numbers: does your sum/product logic handle them?
  • Integer overflow: Python is safe, but flag it in Java/C++
  • Cyclic graph: does your graph algorithm handle cycles?
  • Disconnected graph: loop over all unvisited nodes
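  • Duplicate values: sorted input with repeats, or logic that silently assumes distinct values (set semantics)
  • Target not found: return -1, [], or another agreed sentinel
  • Very large n: could deep recursion overflow the stack? CPython's default recursion limit is 1000, so switch to an iterative version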
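The deep-recursion item above is easy to trigger in Python. The sketch uses a hypothetical linked chain of (value, next) tuples to stand in for a skewed tree or a long linked list. The recursive length function needs one stack frame per node and fails on a 10,000-node chain under the default limit. The iterative loop handles the same chain with constant stack usage.

```python
import sys


def chain_length_recursive(node):
    # One stack frame per node: fails once depth exceeds the recursion limit.
    if node is None:
        return 0
    return 1 + chain_length_recursive(node[1])


def chain_length_iterative(node):
    # Same traversal with a loop: constant stack usage.
    length = 0
    while node is not None:
        length += 1
        node = node[1]
    return length


# Build a 10,000-node chain: (9999, (9998, ... (0, None)))
chain = None
for value in range(10_000):
    chain = (value, chain)

print(sys.getrecursionlimit())        # 1000 in plain CPython unless changed
print(chain_length_iterative(chain))  # 10000
try:
    chain_length_recursive(chain)
except RecursionError:
    print("RecursionError: recursion too deep")
```

Raising the limit with sys.setrecursionlimit is possible but can crash the interpreter if the underlying C stack runs out, so converting to a loop or an explicit stack is usually the safer fix.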

Communication Tips

# Narrate your thought process as you code.
# Say these phrases:
#
# "Let me start with a brute-force approach to make sure I understand the problem..."
# "I notice this has overlapping subproblems, so DP might work here."
# "My key insight is: [explain the non-obvious observation]."
# "I'll use a [data structure] here because [reason]."
# "The time complexity is O(...) because [explain each term]."
# "Let me trace through this example: input=[...], expected=[...]"
# "An edge case I want to check: what if the input is empty?"
# "I could optimize space further to O(n) by using a rolling array, but
#  I'll keep this version for clarity since we're within the time budget."
#
# During coding:
# - Write function signature first with type hints.
# - Add a one-line docstring stating the approach.
# - Use descriptive names: 'left_ptr' not 'l', 'remaining_sum' not 'rs'.
# - If stuck: "Let me think about what information I need at each step."
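The sketch below illustrates those habits (a signature with type hints, a one-line docstring naming the approach, descriptive names). It applies the "rolling array" remark above to edit distance (Levenshtein distance); the function name edit_distance is illustrative. The full DP table needs O(n·m) space. Each row depends only on the previous one, so keeping two rows cuts space to O(m), where m = len(target).

```python
def edit_distance(source: str, target: str) -> int:
    """Levenshtein distance by DP, keeping only the previous row (rolling array)."""
    # prev_row[j]: distance between the source prefix handled so far and target[:j]
    prev_row = list(range(len(target) + 1))  # empty source: j insertions
    for i in range(1, len(source) + 1):
        curr_row = [i] + [0] * len(target)  # empty target: i deletions
        for j in range(1, len(target) + 1):
            if source[i - 1] == target[j - 1]:
                curr_row[j] = prev_row[j - 1]  # characters match: no cost
            else:
                curr_row[j] = 1 + min(
                    prev_row[j],      # delete source[i - 1]
                    curr_row[j - 1],  # insert target[j - 1]
                    prev_row[j - 1],  # substitute
                )
        prev_row = curr_row
    return prev_row[-1]


print(edit_distance("kitten", "sitting"))  # 3
```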
Final 5-minute checklist before submitting
1. Does the code compile / run without syntax errors?
2. Have I handled the empty input case?
3. Does the code return the correct type (list vs tuple, int vs float)?
4. Off-by-one: are my loop bounds correct (range(n) vs range(n+1))?
5. Am I mutating input the caller might not expect me to mutate?
6. Is the space complexity within the constraints?
7. Did I state the time and space complexity out loud?
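Item 5 of the checklist often comes up with sorting. In Python, list.sort() reorders the caller's list in place and returns None. sorted() returns a new list and leaves the input untouched. A short sketch with hypothetical helper names, assuming a non-empty input list:

```python
def median_mutating(nums: list[int]) -> float:
    """Median via in-place sort: reorders the caller's list as a side effect."""
    nums.sort()
    mid = len(nums) // 2
    return nums[mid] if len(nums) % 2 else (nums[mid - 1] + nums[mid]) / 2


def median_safe(nums: list[int]) -> float:
    """Median via sorted copy: the caller's list keeps its order."""
    ordered = sorted(nums)
    mid = len(ordered) // 2
    return ordered[mid] if len(ordered) % 2 else (ordered[mid - 1] + ordered[mid]) / 2


data = [3, 1, 2]
print(median_safe(data), data)      # 2 [3, 1, 2]
print(median_mutating(data), data)  # 2 [1, 2, 3]
```

The copy costs O(n) extra space, which is usually worth stating out loud rather than silently mutating the input.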
Practice
Predict

Complexity Target

Given n = 100,000 and a 1-second time limit, what is the maximum acceptable time complexity? Run the code to confirm.

# Given n = 100,000 and a 1-second time limit,
# what is the maximum acceptable time complexity?
# Options: O(n^2), O(n log n), O(n), O(log n)

n = 100000
# O(n^2)    = 10,000,000,000 ops -- far too slow (~100 s at 10^8 ops/s)
# O(n log n) ~   1,700,000 ops -- fast enough (~0.02 s at 10^8 ops/s)
# O(n)       =     100,000 ops -- also fine
# O(log n)   ~          17 ops -- trivially fast

print("O(n log n)")
public class Solution {
    public static void main(String[] args) {
        // Given n = 100,000 and a 1-second time limit,
        // what is the maximum acceptable time complexity?
        // Options: O(n^2), O(n log n), O(n), O(log n)

        // O(n^2)    = 10,000,000,000 ops -- far too slow
        // O(n log n) ~   1,700,000 ops -- fast enough
        // O(n)       =     100,000 ops -- also fine
        // O(log n)   ~          17 ops -- trivially fast

        System.out.println("O(n log n)");
    }
}
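Compiled code executes roughly 10^8 to 10^9 simple operations per second; interpreted Python is often an order of magnitude or more slower. With n = 100,000, O(n^2) needs 10^10 operations, which is far too slow. O(n log n) needs about 1.7 × 10^6 operations, which fits comfortably within the budget. O(n log n), O(n), and O(log n) are all acceptable, so O(n log n) is the slowest-growing complexity that still passes; O(n^2) is the one to avoid.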
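The arithmetic behind these numbers can be checked with a quick sketch. The rate of 10^8 operations per second is an assumption (a rough figure for compiled code), so the estimates indicate orders of magnitude only.

```python
import math

n = 100_000
ASSUMED_OPS_PER_SECOND = 10**8  # rough compiled-language rate (assumption)

estimates = {
    "O(n^2)": n * n,
    "O(n log n)": n * math.log2(n),
    "O(n)": n,
    "O(log n)": math.log2(n),
}
for name, ops in estimates.items():
    seconds = ops / ASSUMED_OPS_PER_SECOND
    print(f"{name:>10}: {ops:>16,.0f} ops, about {seconds:.4f} s")
```

With these assumptions, O(n^2) comes out near 100 seconds and O(n log n) under two hundredths of a second, a gap of more than three orders of magnitude.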
Predict

Pattern Recognition

You are given a sorted array and need to find if any pair sums to a target. Which technique is most efficient? Run the code to confirm.

# You're given a sorted array and need to find if a pair sums to target.
# Which technique is most efficient?
# Options: brute force, binary search, two pointers, hash map

# Brute force:   O(n^2) time, O(1) space — too slow
# Binary search: O(n log n) time, O(1) space — decent
# Two pointers:  O(n) time, O(1) space — best for sorted input
# Hash map:      O(n) time, O(n) space — works but wastes space

# Sorted array + pair sum = two pointers
# Left pointer at start, right at end.
# If sum too small, advance left. If too large, retreat right.
print("two pointers")
public class Solution {
    public static void main(String[] args) {
        // Sorted array + pair sum = two pointers
        // Left pointer at start, right at end.
        // If sum too small, advance left. If too large, retreat right.
        // O(n) time, O(1) space -- best for sorted input.
        System.out.println("two pointers");
    }
}
When the array is already sorted, two pointers achieves O(n) time with O(1) space — no extra memory needed. A hash map also gives O(n) time but uses O(n) space unnecessarily. The sorted property is the key signal: it tells you two pointers will work.
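Here is a sketch of the two-pointer scan, assuming the input is sorted in non-decreasing order and a "pair" means two different positions. The function name has_pair_with_sum is illustrative. Each step discards one element that cannot be part of any answer. If the sum is too small, nums[left] cannot reach the target even paired with the largest remaining value, so left moves right. If the sum is too large, nums[right] is symmetrically ruled out.

```python
def has_pair_with_sum(nums, target):
    """Return True if two distinct positions in sorted nums sum to target."""
    left, right = 0, len(nums) - 1
    while left < right:
        pair_sum = nums[left] + nums[right]
        if pair_sum == target:
            return True
        if pair_sum < target:
            left += 1   # need a larger sum
        else:
            right -= 1  # need a smaller sum
    return False


print(has_pair_with_sum([1, 3, 4, 6, 9], 10))  # True (1 + 9)
print(has_pair_with_sum([1, 3, 4, 6, 9], 2))   # False
```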
Predict

Choose Data Structure

You need O(1) insert, O(1) delete of any element, AND O(1) getRandom (return a uniformly random stored element). Which combination of data structures achieves all three? Run the code to confirm.

# You need O(1) insert, O(1) delete, AND O(1) getRandom.
# Which combination of data structures achieves this?

# Array alone:    O(1) index access and append, but O(n) search and delete (must shift)
# HashMap alone:  O(1) lookup/insert/delete, but no O(1) way to pick a random element
# LinkedList:     O(1) insert/delete (given a node pointer), but O(n) lookup

# Solution: HashMap + ArrayList with the "swap trick"
#   HashMap: value -> index in ArrayList
#   Delete(val): swap with last element, pop last, update map entry
#   getRandom: pick a random index into the ArrayList
#   All three ops are O(1) on average (hash map) / amortized (list append)
print("HashMap + ArrayList")
public class Solution {
    public static void main(String[] args) {
        // You need O(1) insert, O(1) delete, AND O(1) getRandom.

        // Array alone:   O(1) index access/append, O(n) search and delete (shifting)
        // HashMap alone: O(1) lookup/insert/delete, but no O(1) random pick
        // LinkedList:    O(1) insert/delete (with ptr), O(n) lookup

        // Solution: HashMap + ArrayList with the "swap trick"
        //   HashMap: value -> index in ArrayList
        //   Delete(val): swap with last, pop last, update map
        //   getRandom: random index into the ArrayList
        //   Result: all three ops are O(1) average/amortized
        System.out.println("HashMap + ArrayList");
    }
}
The classic "RandomizedSet" trick: store elements in an ArrayList so that getRandom can pick a random index in O(1), and keep a HashMap from each value to its index in the list. To delete, look up the index in the map and overwrite that slot with the last element in the list, updating the map entry for the moved element. Then pop the last element and remove the deleted value from the map. No shifting is needed, so delete is O(1) on average.
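A sketch of that design in Python uses a dict for value → index and a list for the values. The class and method names follow the common "RandomizedSet" exercise and are assumptions here. The list exists for getRandom: random.choice picks a uniformly random element of a list in O(1), which a dict alone cannot do.

```python
import random


class RandomizedSet:
    def __init__(self):
        self.values = []    # the elements, in arbitrary order
        self.index_of = {}  # value -> its position in self.values

    def insert(self, val):
        if val in self.index_of:
            return False
        self.index_of[val] = len(self.values)
        self.values.append(val)
        return True

    def remove(self, val):
        if val not in self.index_of:
            return False
        idx = self.index_of[val]
        last = self.values[-1]
        # Move the last element into the hole, then drop the last slot.
        # Works even when val is the last element (idx points at itself).
        self.values[idx] = last
        self.index_of[last] = idx
        self.values.pop()
        del self.index_of[val]
        return True

    def contains(self, val):
        return val in self.index_of

    def getRandom(self):
        return random.choice(self.values)  # raises IndexError if empty


rs = RandomizedSet()
for v in (1, 2, 3):
    rs.insert(v)
rs.remove(1)
print(sorted(rs.values))         # [2, 3]
print(rs.getRandom() in {2, 3})  # True
```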