Source: Recursive Circus
Part 1: A tree is defined as such:
node (weight) -> child1, child2, ...
node (weight)
Where a node always has a weight, but may or may not have child nodes.
What is the name of the root node of the tree (the node without a parent)?
First, we want to load the tree into a useful format. We’ll keep a map of parent to children along with children back to their parents, just in case we need to go either direction:
import re
names = set()
weight_map = {}
child_map = {}
parent_map = {}
# lib.input() comes from a helper module that is not shown here
for line in lib.input():
    # The "-> children" part is optional, so leaf nodes leave children as None
    name, weight, children = re.match(r'(\w+) \((\d+)\)(?: -> (.*))?', line).groups()
    names.add(name)
    weight_map[name] = int(weight)
    if children:
        for child in children.split(', '):
            child_map.setdefault(name, set()).add(child)
            parent_map[child] = name
This is the sort of thing that could easily be stored in an object, but that seems overkill at this point.
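To make the parsing step concrete, here is a minimal sketch that applies the same regular expression to two made-up lines, one with children and one without. The `parse_line` helper and the sample node names are hypothetical and only serve the illustration.

```python
import re

# Same pattern as in the loader above; compiled once for reuse.
LINE_RE = re.compile(r'(\w+) \((\d+)\)(?: -> (.*))?')

def parse_line(line):
    '''Split one input line into (name, weight, list of child names).'''
    name, weight, children = LINE_RE.match(line).groups()
    # The optional group is None for leaf nodes, so fall back to an empty list
    child_names = children.split(', ') if children else []
    return name, int(weight), child_names

print(parse_line('trunk (41) -> left, right'))  # ('trunk', 41, ['left', 'right'])
print(parse_line('leaf (66)'))                  # ('leaf', 66, [])
```

Because the `-> ...` part sits in an optional non-capturing group, a single pattern covers both line shapes.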
As noted in the problem, the root node is the only one without a parent:
for name in names:
    if name not in parent_map:
        root = name
print('root: {}'.format(root), end='; ')
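The same search can also be written as a set difference. This is only a sketch on a made-up three-level tree, and the node names are hypothetical.

```python
# Hypothetical child -> parent links for a small tree.
parent_map = {'left': 'trunk', 'right': 'trunk', 'twig': 'left'}
names = {'trunk', 'left', 'right', 'twig'}

# Every node except the root appears as a key in parent_map, so removing
# those keys from the full set of names leaves exactly one element.
(root,) = names - set(parent_map)
print(root)  # trunk
```

Unpacking into a one-element tuple has the side benefit of raising an error if the input ever had zero or several parentless nodes.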
Part 2: A node is balanced if all of its children have the same weight (including their children, recursively). One node in the tree is the wrong weight to be balanced. What would that node’s weight need to be to balance everything?
This is a pretty interesting problem. Finding the weight of a particular node plus its children is a nice example of recursion at work:
def total_weight(node):
    '''Return the weight of this node + the sum of all children.'''
    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))
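As a quick check of how the recursion unwinds, the sketch below runs the same function on a four-node tree whose weights are made up for illustration.

```python
# Hypothetical tree: trunk -> left, right and left -> twig.
weight_map = {'trunk': 10, 'left': 5, 'right': 7, 'twig': 2}
child_map = {'trunk': {'left', 'right'}, 'left': {'twig'}}

def total_weight(node):
    '''Return the weight of this node plus everything below it.'''
    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))

print(total_weight('twig'))   # 2 (a leaf is just its own weight)
print(total_weight('left'))   # 5 + 2 = 7
print(total_weight('trunk'))  # 10 + 7 + 7 = 24
```

Here `left` and `right` both total 7, so `trunk` would count as balanced even though their own weights differ.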
Using this, we can write a recursive generator to iterate through the tree, find the node that isn’t balanced, and try to fix it:
import collections
def fix_balance(node):
    '''Fix the balance from this node (recursively).'''
    # If we have no children, no point in balancing
    if not child_map.get(node):
        return
    # Collect a map of weight to set of children with that weight
    # We're unbalanced if this map has two keys, one with a single value
    weights = collections.defaultdict(set)
    for child in child_map[node]:
        # Fix the subtree first so the total below reflects any correction;
        # otherwise every ancestor would see the stale weight and "fix" it too
        yield from fix_balance(child)
        weights[total_weight(child)].add(child)
    # If we only have a single weight, this node is not unbalanced
    if len(weights) == 1:
        return
    # Otherwise, figure out which node is unbalanced (the single mismatched weight)
    for weight, children in weights.items():
        if len(children) == 1:
            unbalanced_node = list(children)[0]
            unbalanced_weight = weight
        else:
            balanced_weight = weight
    # Balance it
    weight_map[unbalanced_node] += balanced_weight - unbalanced_weight
    yield unbalanced_node, weight_map[unbalanced_node]
The comments mostly note what we’re doing at each step. One neat bit is that this function is more powerful than it needs to be. If there were many nodes out of balance, this would fix them all. Since there’s only one, we just need to print out the single value this function will yield:
for node, new_weight in fix_balance(root):
    print('{} -> {}'.format(node, new_weight))
It’s quick too:
$ python3 run-all.py day-07
day-07 python3 tree.py input.txt 0.10468697547912598 root: azqje; rfkvap -> 646
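To check the generator by hand, here is a self-contained sketch on a made-up tree where the faulty node sits two levels below the root. All names and weights are hypothetical. The sketch assumes the recursive call should run before a child's total is measured, so that ancestors see the corrected weight and only one fix is reported.

```python
import collections

# Hypothetical tree: 'a3' weighs 5 but should weigh 2 like its siblings.
weight_map = {'trunk': 10, 'a': 4, 'b': 4, 'c': 4,
              'a1': 2, 'a2': 2, 'a3': 5,
              'b1': 2, 'b2': 2, 'b3': 2,
              'c1': 2, 'c2': 2, 'c3': 2}
child_map = {'trunk': ['a', 'b', 'c'],
             'a': ['a1', 'a2', 'a3'],
             'b': ['b1', 'b2', 'b3'],
             'c': ['c1', 'c2', 'c3']}

def total_weight(node):
    '''Return the weight of this node plus everything below it.'''
    return weight_map[node] + sum(total_weight(child) for child in child_map.get(node, []))

def fix_balance(node):
    '''Yield (node, corrected weight) for each fix made below this node.'''
    if not child_map.get(node):
        return
    weights = collections.defaultdict(set)
    for child in child_map[node]:
        # Assumption: repair the subtree first, then measure its total
        yield from fix_balance(child)
        weights[total_weight(child)].add(child)
    if len(weights) == 1:
        return
    for weight, children in weights.items():
        if len(children) == 1:
            unbalanced_node = next(iter(children))
            unbalanced_weight = weight
        else:
            balanced_weight = weight
    weight_map[unbalanced_node] += balanced_weight - unbalanced_weight
    yield unbalanced_node, weight_map[unbalanced_node]

fixes = list(fix_balance('trunk'))
print(fixes)              # [('a3', 2)]
print(total_weight('a'))  # 10, now matching 'b' and 'c'
```

Like the original, this sketch relies on the odd child out being unique, so a node with exactly two mismatched children is not handled.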