

Union Find

Union Find AKA Disjoint Set Union (DSU) is used to manage and combine disjoint sets. It is named after its two main operations. Starting with elements $x_1, x_2, \ldots, x_n$, each in its own set $\{x_1\}, \{x_2\}, \ldots, \{x_n\}$, you can combine any two sets and determine which set an element is in:

  1. Union: given elements a and b, merge the sets that a and b belong to
  2. Find: given element v, find the set it is in

In pseudocode, it is usually composed of three functions:

  1. make_set(v) - creates a new set consisting of element v
  2. union_sets(a, b) - merges the two sets that contain a and b
  3. find_set(v) - returns the representative ("leader") of the set that contains v. The representative is an element within the set. a and b are in the same set if and only if find_set(a) == find_set(b).
Key Idea

Each set is identified by a representative element. Elements share the same representative iff they belong to the same set.

A tree data structure is used to organize these sets so that we can find the shared representative quickly. Each node (element) in the tree has a parent. A crucial part of the implementation is that the root node's parent is itself. Hence, a node v is a representative exactly when parent[v] == v.

  • Each element starts off as the root of its own tree
  • Maintain an array parent with parent[x] = y meaning the parent node of x is y
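To make the parent-array representation concrete, here is a small hand-built forest stored as a Python list, where the index is the element and the value is its parent. The list contents and the `root_of` helper are hypothetical, chosen only for illustration.

```python
# Hypothetical forest over elements 0..5, stored as parent pointers:
#   tree A: 2 -> 1 -> 0        (0 is the root)
#   tree B: 4 -> 3, 5 -> 3     (3 is the root)
parent = [0, 0, 1, 3, 3, 3]

def root_of(v: int) -> int:
    """Follow parent pointers until reaching a node that is its own parent."""
    while parent[v] != v:
        v = parent[v]
    return v

# Roots are exactly the nodes with parent[v] == v.
roots = [v for v in range(len(parent)) if parent[v] == v]
print(roots)                     # [0, 3]
print(root_of(2), root_of(5))    # 0 3
print(root_of(1) == root_of(2))  # True: same set
print(root_of(2) == root_of(4))  # False: different sets
```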


Naive implementation:

  • make_set(v) - create a new tree whose root is v

```python
parent = {}  # parent[x] is the parent of x; a root is its own parent

def make_set(v):
    parent[v] = v
```

  • find_set(v) - traverse the ancestors of v until you reach the root (the node whose parent is itself)

```python
def find_set(v):
    if parent[v] == v:
        return v
    return find_set(parent[v])
```

  • union_sets(a, b) - find the representatives of a and b. If they are identical, do nothing; otherwise, make one representative the parent of the other

```python
def union_sets(a, b):
    rep_a = find_set(a)
    rep_b = find_set(b)
    if rep_a != rep_b:
        parent[rep_b] = rep_a
```

The problem here is that you can create degenerate trees that are just one long chain, in which case find_set(a) takes $\mathcal{O}(n)$, and so union_sets(a, b) also takes $\mathcal{O}(n)$.
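A quick way to see the problem is to build such a chain on purpose. The sketch below reuses the naive union_sets logic, but with an iterative find_set so that Python's recursion limit is not hit; the `depth` helper is hypothetical and exists only to measure the tree.

```python
parent = {}

def make_set(v):
    parent[v] = v

def find_set(v):
    # Iterative version of the naive find: walk up until the root.
    while parent[v] != v:
        v = parent[v]
    return v

def union_sets(a, b):
    rep_a, rep_b = find_set(a), find_set(b)
    if rep_a != rep_b:
        parent[rep_b] = rep_a  # always hang b's tree under a's root

def depth(v):
    """Number of edges from v up to its root (hypothetical helper)."""
    d = 0
    while parent[v] != v:
        v = parent[v]
        d += 1
    return d

n = 1000
for v in range(n):
    make_set(v)
# Pass each new element as the first argument, so the old tree goes under it.
for v in range(1, n):
    union_sets(v, v - 1)

print(depth(0))  # 999: element 0 sits at the bottom of a chain of n - 1 edges
```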

Optimization 1: Compress Trees

To speed up find_set(a), on the traversal from a up to the representative, set the parent of each visited node to the representative. This is known as path compression: it flattens the tree and avoids long degenerate chains.

```python
def find_set(v):
    if parent[v] == v:
        return v

    parent[v] = find_set(parent[v])
    return parent[v]
```

Path compression on its own results in an amortized time complexity of $\mathcal{O}(\log n)$ per find_set call.
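The recursive find_set above can exceed Python's default recursion limit (1000 frames in CPython) on a long chain before it gets the chance to compress it. As an alternative not taken from the original, an iterative two-pass variant first finds the root and then re-points every node on the path directly at it:

```python
parent = {}

def make_set(v):
    parent[v] = v

def find_set(v):
    # Pass 1: walk up to the root.
    root = v
    while parent[root] != root:
        root = parent[root]
    # Pass 2: point every node on the path directly at the root.
    while parent[v] != root:
        nxt = parent[v]
        parent[v] = root
        v = nxt
    return root

# Build a chain 0 -> 1 -> ... -> 4999 by hand (4999 is the root).
n = 5000
for v in range(n):
    make_set(v)
for v in range(n - 1):
    parent[v] = v + 1

print(find_set(0))                                # 4999, no RecursionError
print(all(parent[v] == n - 1 for v in range(n)))  # True: path fully compressed
```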

Optimization 2: Union by size / rank

Another problem is that the naive implementation always attaches the second tree to the first tree. This might not be optimal, as attaching a large tree under a small one increases the depth and can lead to degenerate trees.

There are two approaches to ensure we attach a smaller tree (lower rank) to a bigger tree (higher rank):

  1. Use size of trees as rank

We maintain an array size that stores the size of the tree for each root vertex v:

```python
def make_set(v):
    parent[v] = v
    size[v] = 1

def union_sets(a, b):
    rep_a = find_set(a)
    rep_b = find_set(b)

    if rep_a != rep_b:
        if size[rep_a] >= size[rep_b]:
            parent[rep_b] = rep_a
            size[rep_a] += size[rep_b]
        else:
            parent[rep_a] = rep_b
            size[rep_b] += size[rep_a]
```

  2. Use depth of trees as rank

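Similarly, maintain an array rank that stores the depth of the tree with root vertex v (once path compression is also applied, rank becomes an upper bound on the depth rather than its exact value):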

```python
def make_set(v):
    parent[v] = v
    rank[v] = 0

def union_sets(a, b):
    rep_a = find_set(a)
    rep_b = find_set(b)

    if rep_a != rep_b:
        # NOTE: You do not sum depths here because the root of one tree is made the parent of the root of another
        # Also if a tree with smaller depth is added to a larger depth tree this way, the depth does not change
        # However if the two trees are the same depth, only increment the depth by 1!
        if rank[rep_a] >= rank[rep_b]:
            parent[rep_b] = rep_a
            if rank[rep_a] == rank[rep_b]:
                rank[rep_a] += 1
        else:
            parent[rep_a] = rep_b
```
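Union by rank keeps trees shallow: without path compression, rank equals the tree height, and a root of rank r has at least $2^r$ elements, so no tree is deeper than $\log_2 n$. The sketch below checks this on random unions with a fixed seed; the `height` helper and the chosen sizes are illustrative assumptions.

```python
import math
import random

parent, rank = {}, {}

def make_set(v):
    parent[v] = v
    rank[v] = 0

def find_set(v):
    # Plain find without path compression, so rank stays equal to tree height.
    while parent[v] != v:
        v = parent[v]
    return v

def union_sets(a, b):
    rep_a, rep_b = find_set(a), find_set(b)
    if rep_a != rep_b:
        if rank[rep_a] >= rank[rep_b]:
            parent[rep_b] = rep_a
            if rank[rep_a] == rank[rep_b]:
                rank[rep_a] += 1
        else:
            parent[rep_a] = rep_b

def height(v):
    """Edges from v to its root (hypothetical helper for the check)."""
    h = 0
    while parent[v] != v:
        v = parent[v]
        h += 1
    return h

n = 4096
for v in range(n):
    make_set(v)
rng = random.Random(0)
for _ in range(3 * n):
    union_sets(rng.randrange(n), rng.randrange(n))

max_height = max(height(v) for v in range(n))
print(max_height <= math.log2(n))  # True: at most log2(4096) = 12
```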

Time Complexity:

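With both optimizations, a single operation can still take $\mathcal{O}(\log n)$ in the worst case, but the amortized cost per operation is $\mathcal{O}(\alpha(n))$, where $\alpha$ is the inverse Ackermann function. Since $\alpha(n)$ does not exceed 4 for all practical $n$ (approximately $n < 10^{600}$), queries run in effectively constant time.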
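Putting both optimizations together, a compact class-based sketch might look like the following. It combines union by size with iterative path compression; the class name `DisjointSet` and the choice to have union_sets return whether a merge happened are illustrative assumptions, not part of the original.

```python
class DisjointSet:
    """Union find with path compression and union by size (a sketch)."""

    def __init__(self, n: int) -> None:
        # Equivalent to make_set(v) for every v in 0..n-1.
        self.parent = list(range(n))
        self.size = [1] * n

    def find_set(self, v: int) -> int:
        root = v
        while self.parent[root] != root:
            root = self.parent[root]
        # Path compression: re-point every node on the path at the root.
        while self.parent[v] != root:
            nxt = self.parent[v]
            self.parent[v] = root
            v = nxt
        return root

    def union_sets(self, a: int, b: int) -> bool:
        """Merge the sets of a and b; return False if they were already joined."""
        rep_a, rep_b = self.find_set(a), self.find_set(b)
        if rep_a == rep_b:
            return False
        if self.size[rep_a] < self.size[rep_b]:
            rep_a, rep_b = rep_b, rep_a  # make rep_a the root of the larger tree
        self.parent[rep_b] = rep_a
        self.size[rep_a] += self.size[rep_b]
        return True


ds = DisjointSet(6)
ds.union_sets(0, 1)
ds.union_sets(2, 3)
ds.union_sets(1, 3)
print(ds.find_set(0) == ds.find_set(2))  # True
print(ds.find_set(4) == ds.find_set(0))  # False
print(ds.union_sets(0, 3))               # False: already in the same set
```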


Connected Components

A specific implementation for merging overlapping sets can be found here.

Task: Given a number of sets, union any overlapping sets until you have only disjoint or non-overlapping sets left.

For example, given the sets {7, 6}, {5, 4, 3, 2, 1, 0}, {8, 1, 0}, {2, 9}, {8, 9}, the resulting disjoint sets should be {6, 7}, {0, 1, 2, 3, 4, 5, 8, 9}.

We call each initial set a component and each resulting connected component a group. We maintain the following lists for bookkeeping:

  • component - list of length n_items indicating the first component that contains each item, so component[7] = 0 in the above example.
  • group - a list of length n_components of pointers, forming a DAG, that indicates which component each component shares a group with. group[1] = 2 indicates that component 1 is in the same group as component 2. If group[2] = 2, then it points to itself and component 2 is the representative of its group. Initialize with the assumption that each component is disjoint, so group[i] = i for every component i.
Key Idea

To get the group of component i, if group[i] != i, follow the pointers group[i], group[group[i]], and so on, until you reach a component that points to itself. This is the "representative" of the actual group (see union find).

For example, group = [0, 2, 3, 3] means that component 1 shares a group with component 2, which shares a group with component 3, so the group is represented by 3.
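Following the pointers is the same loop as find_set in union find. A minimal sketch, where `group_representative` is a hypothetical helper name:

```python
from typing import Sequence

def group_representative(group: Sequence[int], i: int) -> int:
    """Follow group[i], group[group[i]], ... until a component points to itself."""
    while group[i] != i:
        i = group[i]
    return i

group = [0, 2, 3, 3]
print([group_representative(group, i) for i in range(len(group))])  # [0, 3, 3, 3]
```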

The algorithm consists of the following steps:

  • Number the components from 0 to n_components - 1 and walk over them sequentially.
  • For each component j, walk over its items and mark them in component: for each unmarked item i in component j, assign component[i] = j.
  • If an item's component has already been marked, so that component[cur_item] equals some earlier k with k < j, you have identified an overlap. Mark the two components as sharing the same group by making the representative of the group of k point to j.
  • Once all components have been walked, go from the end of the group list to the beginning, propagating the representative backward, so that instead of indirect pointers, each group[i] is now equal to its representative. This works because every pointer goes from a component to a later one (group[i] >= i), so when index i is reached, group[group[i]] is already final.
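A minimal sketch of just the backward pass, applied to the group list from the example above; `propagate_representatives` is a hypothetical name, and the function copies its input rather than modifying it in place.

```python
from typing import List

def propagate_representatives(group: List[int]) -> List[int]:
    """Backward pass: replace each pointer with its group's representative.

    Assumes group[i] >= i for every i, so group[group[i]] is already final
    when index i is reached.
    """
    group = list(group)  # work on a copy
    for i in reversed(range(len(group))):
        k = group[i]
        if k != i:
            group[i] = group[k]
    return group

print(propagate_representatives([0, 2, 3, 3]))  # [0, 3, 3, 3]
```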
Key Idea

By enforcing the rule that we jump to the representative component of each group before linking, we create a DAG in which previous connectivity information is preserved as we continue to add new components to pre-existing groups.

Let's walk through the tricky part of updating the groups with DAGs:

  • Say you are on component 2, and its first few items are unmarked in the component list, so you simply set component[i] = 2. But then you encounter component[i] = 1, so the item already belongs to component 1. Now we want to indicate that components 1 and 2 share the same group. We do this by setting group[1] = 2. Say later on, in component 5, we come across component[i] = 1. If we set group[1] = 5, the information that components 1 and 2 share a group is lost. Hence we need to traverse the DAG to the representative and set group[group[1]] = group[2] = 5 instead. This way component 1 points to component 2, which points to component 5, preserving prior connectivity information. We illustrate this in Example 1 below.
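The scenario in this walkthrough can be replayed on a plain list. The sketch below uses six components and a hypothetical `rep` helper, and contrasts overwriting group[1] directly with linking the representative of component 1.

```python
from typing import List

def rep(group: List[int], i: int) -> int:
    # Follow pointers until a component points to itself.
    while group[i] != i:
        i = group[i]
    return i

# Component 2 overlaps component 1: link 1 -> 2.
naive = list(range(6))
naive[1] = 2
# Later, component 5 overlaps component 1. Overwriting group[1] loses 1 -> 2.
naive[1] = 5
print(rep(naive, 1) == rep(naive, 2))  # False: components 1 and 2 were split

# Correct: jump to the representative of component 1 and link that instead.
correct = list(range(6))
correct[1] = 2
correct[rep(correct, 1)] = 5  # the representative is 2, so this sets group[2] = 5
print(rep(correct, 1) == rep(correct, 2) == rep(correct, 5))  # True
```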


Optimization: Breaking Up Chains in group

Suppose the item is already marked by a previous component. We currently do the following pointer hopping:

```python
j_component = component[item]

# suppose j_component != current i_component
# jump through DAG to representative and set its group to current component
while True:
    k_group = group[j_component]

    if k_group == j_component:
        group[j_component] = i_component
        break
    else:
        j_component = k_group
```

We can break up the possibly long chains in the group array produced by this naive implementation by reassigning every group pointer visited during the traversal to the current component, which is the newest (largest-indexed) representative. This is analogous to path compression in union find.

```python
j_component = component[item]
while True:
    k_group = group[j_component]
    group[j_component] = i_component

    if k_group == j_component:
        break

    j_component = k_group
```
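To see the effect of breaking up chains, the sketch below builds the degenerate input where every component overlaps the first one (so the marked component is always 0) and counts pointer hops for both loops. The loops mirror the two snippets above; the hop counter and the names `link_naive`, `link_compressing`, and `total_hops` are added for illustration.

```python
def link_naive(group, j_component, i_component):
    hops = 0
    while True:
        k_group = group[j_component]
        hops += 1
        if k_group == j_component:
            group[j_component] = i_component
            return hops
        j_component = k_group

def link_compressing(group, j_component, i_component):
    hops = 0
    while True:
        k_group = group[j_component]
        group[j_component] = i_component  # re-point every visited component
        hops += 1
        if k_group == j_component:
            return hops
        j_component = k_group

def total_hops(link, n_components):
    # Degenerate input: each component i >= 1 overlaps component 0.
    group = list(range(n_components))
    return sum(link(group, 0, i) for i in range(1, n_components))

n = 1000
print(total_hops(link_naive, n))        # 499500 == n * (n - 1) // 2
print(total_hops(link_compressing, n))  # 1997 == 2 * n - 3
```

The naive loop walks a chain that grows by one link per component, while the compressing loop leaves component 0 pointing at the newest representative, so each later overlap needs only two hops.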

Finally, with everything put together, our code looks like this:

```python
from collections import defaultdict


def connected_components(components, n_items):
    """
    Merge overlapping components into disjoint groups.

    Parameters
    ----------
    components : Sequence[Collection[int]]
        Initial sets; items must be integers in range(n_items). Each set is
        iterated twice, so it must not be a one-shot iterator.
    n_items : int
        Total number of distinct items.

    Returns
    -------
    list[set[int]]
        The disjoint groups.
    """
    n_components = len(components)

    # component[item] is the first component containing item (-1 = unmarked)
    component = [-1 for _ in range(n_items)]
    # group[c] points to a component in the same group; a representative points to itself
    group = list(range(n_components))

    for i_component, component_items in enumerate(components):
        for item in component_items:
            j_component = component[item]

            if j_component < 0:
                # item has not been marked
                component[item] = i_component
            else:
                # already marked item!
                if group[j_component] == i_component:
                    # the representative of the j_component is already the current component
                    # due to a previous item - so no need to do anything
                    continue

                # jump through DAG to representative, pointing every visited component at the current one
                while True:
                    k_group = group[j_component]
                    group[j_component] = i_component

                    if k_group == j_component:
                        break

                    j_component = k_group

    # Backward pass through group to propagate direct pointers
    for i in reversed(range(n_components)):
        if (k := group[i]) != i:
            group[i] = group[k]

    # Finish up by aggregating components into disjoint sets
    groups = defaultdict(set)

    for i_comp, i_grp in enumerate(group):
        groups[i_grp].update(components[i_comp])

    return list(groups.values())
```

Time complexity:

The algorithm requires one forward pass through the items, where each already-marked item may require some jumping through the DAG, and one backward pass through the group array.

In the naive degenerate case where each component shares an item with the first component, you will need to hop through the entire group array each time, leading to $\mathcal{O}(n^2)$. In the best case, you do not need any recursive jumps, so it is $\mathcal{O}(n)$.

With the optimization, this degenerate case is less problematic, though the optimization does not always prevent long chains from forming.
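For comparison, here is a sketch that solves the same task with the union find from the first part instead of the group DAG. This is a swapped-in approach, not the implementation above: it unions every item of a component with that component's first item, then buckets the items by representative. `merge_overlapping` is a hypothetical name.

```python
from collections import defaultdict
from typing import Iterable, List, Sequence, Set

def merge_overlapping(components: Sequence[Iterable[int]], n_items: int) -> List[Set[int]]:
    parent = list(range(n_items))
    size = [1] * n_items

    def find_set(v: int) -> int:
        root = v
        while parent[root] != root:
            root = parent[root]
        while parent[v] != root:  # path compression
            nxt = parent[v]
            parent[v] = root
            v = nxt
        return root

    def union_sets(a: int, b: int) -> None:
        rep_a, rep_b = find_set(a), find_set(b)
        if rep_a == rep_b:
            return
        if size[rep_a] < size[rep_b]:  # union by size
            rep_a, rep_b = rep_b, rep_a
        parent[rep_b] = rep_a
        size[rep_a] += size[rep_b]

    seen = set()
    for items in components:
        items = list(items)
        seen.update(items)
        for item in items[1:]:
            union_sets(items[0], item)

    groups = defaultdict(set)
    for item in seen:  # only items that appear in some component
        groups[find_set(item)].add(item)
    return list(groups.values())

sets = [{7, 6}, {5, 4, 3, 2, 1, 0}, {8, 1, 0}, {2, 9}, {8, 9}]
print(sorted(sorted(s) for s in merge_overlapping(sets, 10)))
# [[0, 1, 2, 3, 4, 5, 8, 9], [6, 7]]
```

This trades per-component bookkeeping for a parent and size entry per item, but with both optimizations it keeps every operation at the near-constant amortized cost discussed earlier, regardless of how the components overlap.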