Source: Coprocessor Conflagration
Part 1: Create a variation of the previous DuetVM with only the following four instructions:
- set X Y sets register X to Y
- sub X Y sets register X to X - Y
- mul X Y sets register X to X * Y
- jnz X Y jumps with an offset of the value of Y, if X is not equal to zero

If you run the given program, how many times is mul invoked?
Interesting.
We could definitely use the same code as before, implementing the instructions as such:
@VM.register
def set(vm, x, y):
    vm.registers[x] = vm.value(y)

@VM.register
def sub(vm, x, y):
    vm.registers[x] -= vm.value(y)

@VM.register
def mul(vm, x, y):
    # Count every mul so the total can be printed after the run
    vm.mul_count = getattr(vm, 'mul_count', 0) + 1
    vm.registers[x] *= vm.value(y)

@VM.register
def jnz(vm, x, y):
    if vm.value(x) != 0:
        # The - 1 assumes the VM still advances pc by one after this instruction
        vm.pc += vm.value(y) - 1
Just print out vm.mul_count at the end and we’re good. But where’s the fun in that?
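For readers who don't have the earlier DuetVM code at hand, here is a minimal self-contained sketch of the interpreter approach. It does not use the VM class or the VM.register decorator from the previous post; the `run` and `parse` helpers and the tiny demo program are assumptions made up for this example.

```python
def run(program, registers=None):
    '''Interpret a list of (op, x, y) tuples and return (registers, mul_count).

    Hypothetical helper for illustration; not the VM class from the earlier post.
    '''
    regs = {name: 0 for name in 'abcdefgh'}
    regs.update(registers or {})
    mul_count = 0
    pc = 0

    def value(operand):
        # An operand is either a register name or an integer literal
        return regs[operand] if operand in regs else int(operand)

    while 0 <= pc < len(program):
        op, x, y = program[pc]
        if op == 'set':
            regs[x] = value(y)
        elif op == 'sub':
            regs[x] -= value(y)
        elif op == 'mul':
            regs[x] *= value(y)
            mul_count += 1
        elif op == 'jnz':
            if value(x) != 0:
                pc += value(y)
                continue  # skip the normal increment below
        else:
            raise ValueError('unknown instruction: {}'.format(op))
        pc += 1

    return regs, mul_count


def parse(source):
    '''Split program text into (op, x, y) tuples, ignoring blank lines.'''
    return [tuple(line.split()) for line in source.splitlines() if line.strip()]


# A tiny made-up program: double b three times, counting down in a
demo = parse('''
set a 3
set b 1
mul b 2
sub a 1
jnz a -2
''')
regs, mul_count = run(demo)
print(regs['b'], mul_count)
```

Running off either end of the program stops execution, which is the same convention the compiler below relies on when it turns out-of-bounds jumps into an exit.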
Instead, let’s implement a compiler to turn DuetVM code into Python. 😇 It’s probably a bad idea, but we’re just going to use a series of regular expressions that will match the input program and rewrite sections one step at a time. For example:
import re

@compile_step
def rewrite_simple_binops(code):
    '''Rewrite simple binary ops that have a direct python equivalent.'''
    for name, symbol in [('set', ''), ('sub', '-'), ('mul', '*')]:
        code = re.sub(
            r'{} ([a-h]) ([a-h]|-?\d+)'.format(name),
            r'\1 {}= \2'.format(symbol),
            code
        )
    return code
This will take any of the three basic commands and directly rewrite them as the Python equivalent. set X Y becomes X = Y, sub X Y becomes X -= Y (and the same for mul). Easy enough.
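As a quick check of that rewrite, here is the same function without the decorator, applied to a three-line sample. The sample input is made up for this example.

```python
import re

def rewrite_simple_binops(code):
    '''Same rewrite as above, minus the @compile_step decorator.'''
    for name, symbol in [('set', ''), ('sub', '-'), ('mul', '*')]:
        code = re.sub(
            r'{} ([a-h]) ([a-h]|-?\d+)'.format(name),
            r'\1 {}= \2'.format(symbol),
            code
        )
    return code

sample = 'set b 65\nsub b -100000\nmul g e'
print(rewrite_simple_binops(sample))
# b = 65
# b -= -100000
# g *= e
```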
compile_step is a decorator which essentially stores the step to be used later:
import functools

compilation_steps = []

def compile_step(f):
    @functools.wraps(f)
    def new_f(code):
        return f(code)

    compilation_steps.append(new_f)
    return f
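The post doesn't show how the stored steps get used, so the following is only a sketch of one plausible driver. The name `compile_program` and the two toy steps are hypothetical; the registration is simplified to append the function directly.

```python
compilation_steps = []

def compile_step(f):
    '''Register f as a compilation step; steps run in definition order.'''
    compilation_steps.append(f)
    return f

@compile_step
def strip_blank_lines(code):
    # Toy step, made up for this example
    return '\n'.join(line for line in code.split('\n') if line.strip())

@compile_step
def uppercase_labels(code):
    # Another toy step, also made up
    return code.replace('# l', '# L')

def compile_program(code):
    '''Hypothetical driver: thread the code through each registered step.'''
    for step in compilation_steps:
        code = step(code)
    return code

print(compile_program('\n# l0\nset a 2\n\n'))
```

Because steps are appended as they are decorated, the order of the function definitions in the file is also the order of the compiler passes.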
Next, we want to get rid of the relative jumps. First, Python doesn’t have relative jumps (well, it doesn’t have direct jumps at all, we’ll get to that), and second, as we rewrite code, the relative offsets will change. So let’s make them absolute jumps instead, specified (for now) by comments:
@compile_step
def rewrite_jumps_as_absolute(code):
    '''Rewrite jumps in an absolute form with a label at the destination.'''
    labels = {}
    lines = code.split('\n')

    for index, line in enumerate(lines):
        m = re.match(r'jnz ([a-h]|-?\d+) (-?\d+)', line)
        if m:
            register, offset = m.groups()
            offset = int(offset)

            if 0 <= index + offset < len(lines):
                label = 'L{}'.format(len(labels))
                labels[label] = index + offset
                lines[index] = 'jnz {} {}'.format(register, label)
            else:
                # Out of bounds jumps will instead compile into a `halt` (stop execution)
                lines[index] = 'sys.exit(0)'

    for label, index in labels.items():
        lines[index] = '# {}\n{}'.format(label, lines[index])

    return '\n'.join(lines)
As an example:
>>> code = '''
... set a 2
... set b 3
... jnz a 2
... add a b
... mul a b
... '''
>>> print(rewrite_jumps_as_absolute(code))
set a 2
set b 3
jnz a L0
add a b
# L0
mul a b
It’s no longer either DuetVM code or Python, but rather something in between. In an optimal case, it would be best if every stage of the compiler were runnable so that we can test that the functionality of each step is exactly the same, but that’s just not something I’m going to do this time around.
Next, we have a few different possible styles of jumps. The easiest two are jumps that don’t overlap any other jumps, going either forward or backwards. For the case of a jump going forward, it’s basically an if statement:
def rewrite_simple_if(code):
    '''Rewrite non-nested forward jumps as simple ifs.'''
    def make_if(m):
        indentation, register, label, body = m.groups()
        return '{}if {} == 0:\n{}'.format(
            indentation,
            register,
            '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)jnz ([a-h]|-?\d+) (L\d+)\n((?!jnz|# L).*)\n\1# \3',
        make_if,
        code,
        flags = re.DOTALL
    )
The regular expressions get a bit complicated here sadly, but they’re still more or less readable:
Part | Group | Description |
---|---|---|
(\s*) | \1 is indentation | Start by matching whitespace, this is the current indentation (nested if statements) |
jnz ([a-h]\|-?\d+) (L\d+) | \2 is the conditional, \3 is the label | Match the forward jump, storing the conditional and label |
\n | | Start the block |
((?!jnz\|# L).*) | \4 is the body | Match any content so long as you do not see either jnz instructions or a # L label comment; nested ifs are fine |
\n | | Finish the block |
\1# \3 | | Match the closing label comment. It has the same indentation and label as the jnz instruction |
When we rewrite that, we want to generate the condition if {register} == 0: and then indent each line in the block by an additional four spaces. So if we apply this to the code we had earlier:
>>> code = '''
... set a 2
... set b 3
... jnz a 2
... add a b
... mul a b
... '''
>>> print(rewrite_simple_if(rewrite_jumps_as_absolute(code)))
set a 2
set b 3
if a == 0:
    add a b
mul a b
Next up, let’s rewrite non-overlapping backward loops.
def rewrite_simple_while(code):
    '''Rewrite non-nested backward jumps as while loops with a flag.'''
    def make_while(m):
        indentation, label, body, register = m.groups()
        return '''\
{indentation}# {label}
{indentation}while True:
{body}
{indentation}    if {register} == 0: break\
'''.format(
            indentation = indentation,
            register = register,
            label = label,
            body = '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)# (L\d+)\n((?!jnz|# L).*)\n\1jnz ([a-h]|-?\d+) \2',
        make_while,
        code,
        flags = re.DOTALL
    )
This one is a bit more complicated, since we want the code to run at least one time. If only we had a do while loop. But instead, all we have to do is make it while True and put the actual conditional and a break at the end of the loop.
Example:
>>> print(rewrite_jumps_as_absolute(code))
set a 2
set b 3
# L0
add a b
sub b 1
jnz b L0
>>> print(rewrite_jumps(rewrite_jumps_as_absolute(code)))
set a 2
set b 3
# L0
while True:
    add a b
    sub b 1
    if b == 0: break
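The while True plus trailing break pattern can be tried in isolation. This is a small sketch with made-up values, written with Python operators rather than the DuetVM-style lines in the example output above; `do_while_demo` is a hypothetical name.

```python
def do_while_demo(b):
    '''Run the body at least once, then repeat for as long as b != 0.'''
    a = 0
    iterations = 0
    while True:
        # Loop body (stand-in for the instructions between the label and the jnz)
        a += b
        b -= 1
        iterations += 1
        # jnz b L0 jumps back while b != 0, so we leave the loop when b == 0
        if b == 0:
            break
    return a, iterations

print(do_while_demo(3))  # body runs for b = 3, 2, 1
print(do_while_demo(1))  # body still runs once
```

Note that the break condition is the negation of the jump condition: the jump goes back when the register is non-zero, so the loop exits when it is zero.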
Finally, let’s rewrite one simple case that we have in our actual example code, an if-not. Something like this:
jnz a 2
jnz 1 5
mul b 100
sub b -100000
set c b
sub c -17000
set f 1
How to read that: if a != 0, we skip the second jump and run the rest of the code. If a == 0, then we hit the second jump, which always jumps forward five lines (1 != 0 is always true), skipping the four instructions in between. Logically, that means that if a != 0, we’ll run those instructions, but if a == 0, we will not. So we can rewrite that as so:
def rewrite_if_not(code):
    '''Rewrite an overlapping pair of forward jumps as if/else.'''
    def make_if_not(m):
        indentation, register, label_1, label_2, body = m.groups()
        return '''\
{indentation}if {register} != 0:
{body}\
'''.format(
            indentation = indentation,
            register = register,
            body = '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)jnz ([a-h]) (L\d+)\n\1jnz (?!0)-?\d+ (L\d+)\n\1# \3\n((?!jnz|# L).*)\n\1# \4',
        make_if_not,
        code,
        flags = re.DOTALL
    )
It’s similar to the if statements, but the regular expression is a bit more complicated, since we’re matching two back to back jumps and have to have two closing labels. In practice:
>>> code = '''
... set a 2
... set b 3
... jnz a 2
... jnz 1 5
... mul b 100
... sub b -100000
... set c b
... sub c -17000
... set f 1
... '''
>>> print(rewrite_jumps(rewrite_jumps_as_absolute(code)))
set a 2
set b 3
if a != 0:
    mul b 100
    sub b -100000
    set c b
    sub c -17000
set f 1
That’s pretty neat.
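To double check the if-not reading of the two back to back jumps, here is a small sketch that follows only the jumps in the seven-line fragment and records which other instructions are reached. The `executed_lines` helper is hypothetical and is not part of the compiler.

```python
def executed_lines(program, a):
    '''Return the indices of non-jump instructions reached for a given value of a.

    Only jnz is interpreted; every other instruction is just recorded.
    '''
    visited = []
    pc = 0
    while 0 <= pc < len(program):
        op, x, y = program[pc].split()
        if op == 'jnz':
            condition = a if x == 'a' else int(x)
            if condition != 0:
                pc += int(y)
                continue
        else:
            visited.append(pc)
        pc += 1
    return visited

fragment = [
    'jnz a 2',
    'jnz 1 5',
    'mul b 100',
    'sub b -100000',
    'set c b',
    'sub c -17000',
    'set f 1',
]

print(executed_lines(fragment, a=0))  # only 'set f 1' runs
print(executed_lines(fragment, a=1))  # the four-instruction body runs too
```

With a = 0 only index 6 is visited; with a = 1 indices 2 through 6 are, which matches an if a != 0 block followed by set f 1.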
One last thing we want to do is deal with potentially nested if statements. To do that, we just want to run the above three functions over and over again until the current code stops changing:
@compile_step
def rewrite_jumps(code):
    '''Using the previous step, rewrite jumps.'''
    # Keep running these functions until we reach a stable state
    functions = [
        rewrite_simple_if,
        rewrite_if_not,
        rewrite_simple_while,
    ]

    while True:
        new_code = code
        for f in functions:
            new_code = f(new_code)

        if code == new_code:
            break
        else:
            code = new_code

    return code
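The run-until-nothing-changes loop is a fixed-point iteration, and the same shape works for any list of string rewrites. A minimal sketch follows; `apply_until_stable` and the toy rewrite are hypothetical names introduced for this example.

```python
import re

def apply_until_stable(code, functions):
    '''Apply each function in turn, repeating until a full pass changes nothing.'''
    while True:
        new_code = code
        for f in functions:
            new_code = f(new_code)
        if new_code == code:
            return code
        code = new_code

def collapse_double_negation(code):
    # Toy rewrite, made up for this example: removes one 'not not ' per pass
    return re.sub(r'not not ', '', code, count=1)

print(apply_until_stable('not not not not not x', [collapse_double_negation]))
```

Each pass only handles one layer, which is why a single application is not enough for nested structures: the inner construct has to be rewritten before the outer pattern can match.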
That’s actually enough to make runnable Python code:
# BASH$ cat input.txt
set b 65
set c b
jnz a 2
jnz 1 5
mul b 100
sub b -100000
set c b
sub c -17000
set f 1
set d 2
set e 2
set g d
mul g e
sub g b
jnz g 2
set f 0
sub e -1
set g e
sub g b
jnz g -8
sub d -1
set g d
sub g b
jnz g -13
jnz f 2
sub h -1
set g b
sub g c
jnz g 2
jnz 1 3
sub b -17
jnz 1 -23
# BASH$ python3 duetvmc.py input.txt
# Final code:
b = 65
c = b
if a != 0:
    b *= 100
    b -= -100000
    c = b
    c -= -17000
# L7
while True:
    f = 1
    d = 2
    # L4
    while True:
        e = 2
        # L3
        while True:
            g = d
            g *= e
            g -= b
            if g == 0:
                f = 0
            e -= -1
            g = e
            g -= b
            if g == 0: break
        d -= -1
        g = d
        g -= b
        if g == 0: break
    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        sys.exit(0)
    b -= -17
    if 1 == 0: break
Okay. It’s not quite done. We still need four more things:
- Initialize the registers
- Import sys so we can run sys.exit(0) towards the end
- Count mul instructions
- Print the final value of that count, either if we sys.exit(0) or reach the end of the program
To do that, we can add two more compile steps:
@compile_step
def add_debug_statements(code):
    code = re.sub(r'([a-h]) \*=', r'mul_count += 1; \1 *=', code)
    code = re.sub(r'sys.exit', 'print(mul_count); sys.exit', code)
    code = 'mul_count = 0\n' + code + '\nprint(mul_count)'
    return code

@compile_step
def add_initial_registers(code):
    return 'a = b = c = d = e = f = g = h = 0\n' + code
That leaves our actual final code as:
# Final code:
import sys
a = b = c = d = e = f = g = h = 0
mul_count = 0
b = 65
c = b
if a != 0:
    mul_count += 1; b *= 100
    b -= -100000
    c = b
    c -= -17000
# L7
while True:
    f = 1
    d = 2
    # L4
    while True:
        e = 2
        # L3
        while True:
            g = d
            mul_count += 1; g *= e
            g -= b
            if g == 0:
                f = 0
            e -= -1
            g = e
            g -= b
            if g == 0: break
        d -= -1
        g = d
        g -= b
        if g == 0: break
    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        print(mul_count); sys.exit(0)
    b -= -17
    if 1 == 0: break
print(mul_count)
Which we can directly run with the python executable:
$ python3 duetvmc.py input.txt | python3
3969
That’s just cool.
Part 2: Set a = 1. What is the final value of h?
Well. We could run the initial program again. That would work. The problem is that it would take a really long time to finish (I tried it).
We could do a bunch of additional optimizations to try to make the code run more quickly, but for this one case, I decided just to actually look at the code and try to figure it out myself. Specifically, everything after # L7 is something we need to optimize.
First, the innermost block:
# L3
while True:
    g = d
    mul_count += 1; g *= e
    g -= b
    if g == 0:
        f = 0
    e -= -1
    g = e
    g -= b
    if g == 0: break
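If you stare at that for a while, you have a divisibility check: the loop asks whether d * e == b for any e, which only happens when b is an exact multiple of d. If it is, f will be set to 0. Otherwise, f will stay 1 (set in L7). Reusing g and e as loose names for the number under test and the candidate divisor, that loop is more or less equivalent to: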
if g % e == 0:
    f = 0
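That equivalence is easy to check mechanically. In terms of the original registers, the L3 loop clears f when d * e == b for some e between 2 and b - 1, which should be the same as d dividing b. The sketch below transcribes the loop literally and compares it against a plain modulo test for small values; `inner_loop_clears_f` is a hypothetical helper, and it assumes b >= 3 so that the loop terminates.

```python
def inner_loop_clears_f(b, d):
    '''Literal transcription of the L3 loop; returns True if f ends up 0.

    Assumes b >= 3, otherwise e never equals b and the loop would not end.
    '''
    f = 1
    e = 2
    while True:
        g = d
        g *= e
        g -= b
        if g == 0:
            f = 0
        e -= -1
        g = e
        g -= b
        if g == 0:
            break
    return f == 0

# Compare against a plain modulo test for small values
for b in range(3, 60):
    for d in range(2, b):
        assert inner_loop_clears_f(b, d) == (b % d == 0)
print('L3 loop matches b % d == 0 for every tested pair')
```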
Going out one layer, we have:
# L4
while True:
    e = 2
    if g % e == 0:
        f = 0
    d -= -1
    g = d
    g -= b
    if g == 0: break
This is a loop from 2 up to g:
for e in range(2, g):
    if g % e == 0:
        f = 0
And finally, one level more:
# L7
while True:
    f = 1
    d = 2
    for e in range(2, g):
        if g % e == 0:
            f = 0
    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        print(mul_count); sys.exit(0)
    b -= -17
    if 1 == 0: break
Two things to recognize here. First, this is another loop on g from b to c but with a step size of 17 (the b -= -17 line). Also, we’re reading from f here. If it has been cleared to 0, add one to h. Since f is never set back to 1 once it’s 0 (until the next pass of the outer loop), we can rewrite the inner loop as h += 1 with a break as another speed boost. That gives us:
for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)
That looks familiar. Essentially, it’s doing the inverse of a primality test. For each g, if g is a composite number, add 1 to h. Adding this all as a hacky compiler step:
# HACK
@compile_step
def replace_with_composite_counter(code):
    def comment_out(m):
        return '\n'.join('# ' + line for line in m.group(1).split('\n')) + '\n'

    code = re.sub(r'(# L7\n.*)', comment_out, code, flags = re.DOTALL)
    code += '''
# Turns out the code is checking for composite numbers... very inefficiently
# a != 0 sets up the range on b and c (note the off by one for c...)
# The L7 loop is looping from b to c by 17 (the b -= -17 at the end)
# The L4 loop is looping from 2 to g, setting f = 0 if g is divisible by the given numbers
# - NOTE: The original loop doesn't bail out early, which helps speed up a fair bit
# The L3 loop is doing the trial division ()
for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)
'''
    return code
Which gives us a final result of:
$ python3 duetvmc.py input.txt --part 2
# Final code:
import sys
a = 1
b = c = d = e = f = g = h = 0
b = 65
c = b
if a != 0:
    b *= 100
    b -= -100000
    c = b
    c -= -17000
# # L7
# while True:
#     f = 1
#     d = 2
#     # L4
#     while True:
#         e = 2
#         # L3
#         while True:
#             g = d
#             g *= e
#             g -= b
#             if g == 0:
#                 f = 0
#             e -= -1
#             g = e
#             g -= b
#             if g == 0: break
#         d -= -1
#         g = d
#         g -= b
#         if g == 0: break
#     if f == 0:
#         h -= -1
#     g = b
#     g -= c
#     if g == 0:
#         sys.exit(0)
#     b -= -17
#     if 1 == 0: break
# Turns out the code is checking for composite numbers... very inefficiently
# a != 0 sets up the range on b and c (note the off by one for c...)
# The L7 loop is looping from b to c by 17 (the b -= -17 at the end)
# The L4 loop is looping from 2 to g, setting f = 0 if g is divisible by the given numbers
# - NOTE: The original loop doesn't bail out early, which helps speed up a fair bit
# The L3 loop is doing the trial division ()
for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)
Which runs much more quickly (especially if we leverage PyPy):
$ python3 run-all.py day-23
day-23 python3 duetvmc.py input.txt --part 1 | python3 0.10982990264892578 3969
day-23 python3 duetvmc.py input.txt --part 2 | python3 1.5763399600982666 917
day-23 python3 duetvmc.py input.txt --part 2 | pypy 0.42662978172302246 917
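One further speedup that the generated code does not use: trial division only needs to test divisors up to the square root of the number. The sketch below is an assumption-laden variation, not the compiler's output; `is_composite`, `count_composites` and `count_composites_naive` are hypothetical names, and the b and c values are the ones the generated code computes when a = 1.

```python
def is_composite(n):
    '''Trial division up to sqrt(n); assumes n >= 2.'''
    divisor = 2
    while divisor * divisor <= n:
        if n % divisor == 0:
            return True
        divisor += 1
    return False

def count_composites(b, c, step=17):
    '''Count composite numbers in b, b + step, ..., up to and including c.'''
    return sum(1 for n in range(b, c + 1, step) if is_composite(n))

def count_composites_naive(b, c, step=17):
    '''Same count using the full range(2, n) scan from the loop above.'''
    return sum(
        1 for n in range(b, c + 1, step)
        if any(n % e == 0 for e in range(2, n))
    )

# The two versions agree on a small range
assert count_composites(100, 1000) == count_composites_naive(100, 1000)

# With a = 1 the generated code sets b = 65 * 100 + 100000 and c = b + 17000
b = 65 * 100 + 100000
c = b + 17000
print(count_composites(b, c))
```

The printed count can be compared against the part 2 value reported in the timings above.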
That was fun. I want to write a much bigger ‘real’ compiler now.