AoC 2017 Day 7: Tree

Source: Recursive Circus

Part 1: A tree is defined by lines in one of two forms:

  • node (weight) -> child1, child2, ...
  • node (weight)

Where a node always has a weight, but may or may not have child nodes.

What is the name of the root node of the tree (the node without a parent)?

First, we want to load the tree into a useful format. We’ll keep the set of all node names, a map of each node to its weight, and maps from each parent to its children and from each child back to its parent, just in case we need to go in either direction:

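import re

names = set()
weight_map = {}
child_map = {}
parent_map = {}

# lib is a local helper module (not shown); lib.input() is iterated line by line
for line in lib.input():
    name, weight, children = re.match(r'(\w+) \((\d+)\)(?: -> (.*))?', line).groups()

    # Record every node so we can search for the root later
    names.add(name)
    weight_map[name] = int(weight)

    # The children group is None for lines without a '->' part
    if children:
        for child in children.split(', '):
            child_map.setdefault(name, set()).add(child)
            parent_map[child] = name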

This is the sort of thing that could easily be stored in an object, but that seems like overkill at this point.
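For readers without the `lib.input()` helper, here is a minimal self-contained sketch of the same loading step. The `SAMPLE` lines and the `load_tree` name are made up for illustration (this is not the real puzzle input); only the regular expression comes from the code above.

```python
import re

# Hypothetical sample input, not the real puzzle data
SAMPLE = [
    'leaf1 (5)',
    'leaf2 (5)',
    'leaf3 (7)',
    'mid (10) -> leaf1, leaf2',
    'top (3) -> mid, leaf3',
]

LINE_RE = re.compile(r'(\w+) \((\d+)\)(?: -> (.*))?')

def load_tree(lines):
    '''Parse lines into (names, weight_map, child_map, parent_map).'''
    names, weight_map, child_map, parent_map = set(), {}, {}, {}
    for line in lines:
        name, weight, children = LINE_RE.match(line).groups()
        names.add(name)
        weight_map[name] = int(weight)
        # children is None when the line has no '->' part
        for child in (children.split(', ') if children else []):
            child_map.setdefault(name, set()).add(child)
            parent_map[child] = name
    return names, weight_map, child_map, parent_map

names, weight_map, child_map, parent_map = load_tree(SAMPLE)
```

Leaves never appear as keys in `child_map`, and the root never appears as a key in `parent_map`, which is what the next step relies on.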

As noted in the problem, the root node is the only one without a parent:

for name in names:
    if name not in parent_map:
        root = name

print('root: {}'.format(root), end = '; ')
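The same idea can also be written as a set difference: every known node minus every node that has a parent. This is only a sketch on a made-up five-node tree; `find_root` is a hypothetical helper name, and the check for exactly one candidate is an added assumption about well-formed input.

```python
# Hypothetical data for illustration: top -> mid, leaf3; mid -> leaf1, leaf2
weight_map = {'top': 3, 'mid': 10, 'leaf1': 5, 'leaf2': 5, 'leaf3': 7}
parent_map = {'mid': 'top', 'leaf3': 'top', 'leaf1': 'mid', 'leaf2': 'mid'}

def find_root(weight_map, parent_map):
    '''Return the single node that never appears as a child.'''
    candidates = set(weight_map) - set(parent_map)
    if len(candidates) != 1:
        raise ValueError('expected exactly one root, found {}'.format(len(candidates)))
    return candidates.pop()

root = find_root(weight_map, parent_map)
```

Unlike the loop, this version fails loudly if the input contains zero roots or more than one.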

Part 2: A node is balanced if all of its children have the same total weight (each child’s own weight plus the weight of its children, recursively). Exactly one node in the tree has the wrong weight for everything to be balanced. What would that node’s weight need to be to balance everything?

This is a pretty interesting problem. Finding the weight of a particular node plus its children is a nice example of recursion at work:

def total_weight(node):
    '''Return the weight of this node + the sum of all children.'''

    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))
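To see the recursion on concrete numbers, here is a small sketch on a made-up tree. The data and the `is_balanced` helper are hypothetical additions for illustration; `total_weight` is the function from above.

```python
# Hypothetical tree for illustration: top -> mid, leaf3; mid -> leaf1, leaf2
weight_map = {'top': 3, 'mid': 10, 'leaf1': 5, 'leaf2': 5, 'leaf3': 7}
child_map = {'top': {'mid', 'leaf3'}, 'mid': {'leaf1', 'leaf2'}}

def total_weight(node):
    '''Return the weight of this node plus the total weight of every descendant.'''
    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))

def is_balanced(node):
    '''True if all children of node carry the same total weight.'''
    return len({total_weight(child) for child in child_map.get(node, [])}) <= 1

totals = {name: total_weight(name) for name in weight_map}
# mid is 10 + 5 + 5 = 20 and top is 3 + 20 + 7 = 30
```

Here `mid` is balanced (both leaves weigh 5), while `top` is not, because its children carry 20 and 7.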

Using this, we can write a recursive generator to iterate through the tree, find the node that isn’t balanced, and try to fix it:

import collections

def fix_balance(node):
    '''Fix the balance from this node (recursively).'''

    # If we have no children, no point in balancing
    if not child_map.get(node):
        return

    # Collect a map of weight to set of children with that weight
    # We're unbalanced if this map has two keys, one with a single value
    weights = collections.defaultdict(set)
    for child in child_map[node]:
        # Fix anything below this child first, then record its (fixed) total weight
        yield from fix_balance(child)
        weights[total_weight(child)].add(child)

    # If we only have a single weight, this node is not unbalanced
    if len(weights) == 1:
        return

    # Otherwise, figure out which node is unbalanced (the single mismatched weight)
    for weight, children in weights.items():
        if len(children) == 1:
            unbalanced_node = list(children)[0]
            unbalanced_weight = weight
        else:
            balanced_weight = weight

    # Balance it by shifting the node's own weight by the difference
    weight_map[unbalanced_node] += balanced_weight - unbalanced_weight
    yield unbalanced_node, weight_map[unbalanced_node]

The comments mostly note what we’re doing at each step. One neat bit is that this function is more powerful than it needs to be. If there were several nodes out of balance, this would fix them all, deepest first. Since there’s only one, we just need to print out the single value this function will yield:

for node, new_weight in fix_balance(root):
    print('{} -> {}'.format(node, new_weight))
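As a quick sanity check, here is a self-contained sketch of the generator running on a tiny made-up tree. The node names and weights are hypothetical, and the sketch assumes the unbalanced node has at least two siblings, since with only two children there is no way to tell which one is the odd one out.

```python
import collections

# Hypothetical tree for illustration: root -> a, b, c; a -> a1, a2
weight_map = {'root': 1, 'a': 4, 'a1': 1, 'a2': 1, 'b': 6, 'c': 8}
child_map = {'root': {'a', 'b', 'c'}, 'a': {'a1', 'a2'}}

def total_weight(node):
    '''Return the weight of this node plus the total weight of every descendant.'''
    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))

def fix_balance(node):
    '''Yield (node, corrected weight) for each unbalanced node, deepest first.'''
    if not child_map.get(node):
        return
    weights = collections.defaultdict(set)
    for child in child_map[node]:
        yield from fix_balance(child)
        weights[total_weight(child)].add(child)
    if len(weights) == 1:
        return
    for weight, children in weights.items():
        if len(children) == 1:
            unbalanced_node = next(iter(children))
            unbalanced_weight = weight
        else:
            balanced_weight = weight
    weight_map[unbalanced_node] += balanced_weight - unbalanced_weight
    yield unbalanced_node, weight_map[unbalanced_node]

fixes = list(fix_balance('root'))
```

The subtrees under `a` and `b` each total 6 while `c` totals 8, so the generator yields `('c', 6)` and updates `weight_map` in place.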

It’s quick too:

$ python3 day-07

day-07  python3 input.txt       0.10468697547912598     root: azqje; rfkvap -> 646