AoC 2017 Day 23: Duetvmc

Source: Coprocessor Conflagration

Part 1: Create a variation of the previous DuetVM with only the following four instructions:

  • set X Y sets register X to Y
  • sub X Y sets register X to X - Y
  • mul X Y sets register X to X * Y
  • jnz X Y jumps with an offset of the value of Y, iff X is not equal to zero

If you run the given program, how many times is mul invoked?

Interesting.

We could definitely use the same code as before, implementing the instructions as such:

@VM.register
def set(vm, x, y):
    vm.registers[x] = vm.value(y)

@VM.register
def sub(vm, x, y):
    vm.registers[x] -= vm.value(y)

@VM.register
def mul(vm, x, y):
    vm.mul_count = getattr(vm, 'mul_count', 0) + 1
    vm.registers[x] *= vm.value(y)

@VM.register
def jnz(vm, x, y):
    if vm.value(x) != 0:
        vm.pc += vm.value(y) - 1

Just print out vm.mul_count at the end and we’re good. But where’s the fun in that?
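For anyone without the VM class from the earlier post at hand, the interpreter route can be sketched on its own. This is a minimal stand-in, not the DuetVM code: the count_mul function and its internals are made up for illustration, and it assumes every instruction has exactly two operands.

```python
from collections import defaultdict

def count_mul(program):
    '''Interpret the four-instruction language and count mul executions.'''
    registers = defaultdict(int)  # registers default to 0

    def value(token):
        # An operand is either an integer literal or a register name
        try:
            return int(token)
        except ValueError:
            return registers[token]

    lines = [line.split() for line in program.strip().splitlines()]
    pc = 0
    mul_count = 0

    # Jumping outside the program halts it
    while 0 <= pc < len(lines):
        op, x, y = lines[pc]
        if op == 'set':
            registers[x] = value(y)
        elif op == 'sub':
            registers[x] -= value(y)
        elif op == 'mul':
            registers[x] *= value(y)
            mul_count += 1
        elif op == 'jnz' and value(x) != 0:
            pc += value(y)
            continue
        pc += 1

    return mul_count

# A tiny loop that multiplies three times before a reaches zero
print(count_mul('set a 3\nmul b 2\nsub a 1\njnz a -2'))  # prints 3
```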

Instead, let’s implement a compiler to turn DuetVM code into Python. 😇 It’s probably a bad idea, but we’re just going to use a series of regular expressions that will match the input program and rewrite sections one step at a time. For example:

@compile_step
def rewrite_simple_binops(code):
    '''Rewrite simple binary ops that have a direct python equivalent.'''

    for name, symbol in [('set', ''), ('sub', '-'), ('mul', '*')]:
        code = re.sub(
            r'{} ([a-h]) ([a-h]|-?\d+)'.format(name),
            r'\1 {}= \2'.format(symbol),
            code
        )
    return code

This will take any of the three basic commands and directly rewrite them as the Python equivalent. set X Y becomes X = Y, sub X Y becomes X -= Y (and the same for mul). Easy enough.
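As a quick sanity check, here is that step run on three instructions. This is a standalone sketch: it repeats the function without the @compile_step decorator so that it runs on its own, and the sample variable is just made up for illustration.

```python
import re

def rewrite_simple_binops(code):
    '''Rewrite simple binary ops that have a direct python equivalent.'''
    for name, symbol in [('set', ''), ('sub', '-'), ('mul', '*')]:
        code = re.sub(
            r'{} ([a-h]) ([a-h]|-?\d+)'.format(name),
            r'\1 {}= \2'.format(symbol),
            code
        )
    return code

sample = 'set a 2\nsub b -1\nmul g e\njnz g -8'
print(rewrite_simple_binops(sample))
# a = 2
# b -= -1
# g *= e
# jnz g -8
```

Note that jnz lines pass through untouched; they are handled by later steps.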

compile_step is a decorator which essentially stores the step to be used later:

import functools
import re

compilation_steps = []

def compile_step(f):
    '''Register f as a compilation step to be run later.'''
    @functools.wraps(f)
    def new_f(code):
        return f(code)

    compilation_steps.append(new_f)
    return f
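How the stored steps get used is not shown here, so the following is only a guess at the shape of the driver: a hypothetical compile_program function that feeds the code through each registered step in registration order. The two toy steps (strip_blank_lines and add_header) are invented purely to show the ordering.

```python
import functools

compilation_steps = []

def compile_step(f):
    '''Register f as a compilation step to be run later.'''
    @functools.wraps(f)
    def new_f(code):
        return f(code)

    compilation_steps.append(new_f)
    return f

def compile_program(code):
    '''Hypothetical driver: run every registered step in order.'''
    for step in compilation_steps:
        code = step(code)
    return code

# Two toy steps, only here to demonstrate registration order
@compile_step
def strip_blank_lines(code):
    return '\n'.join(line for line in code.split('\n') if line.strip())

@compile_step
def add_header(code):
    return '# compiled\n' + code

compiled = compile_program('\nset a 2\n\nset b 3\n')
print(compiled)
```

Because the decorator appends at definition time, the order of the functions in the file is the order of the compiler passes.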

Next, we want to get rid of the relative jumps. First, Python doesn’t have relative jumps (well, it doesn’t have direct jumps at all, we’ll get to that), and second, as we rewrite code, the relative offsets will no longer point at the right lines. So let’s make them absolute jumps instead, specified (for now) by comments:

@compile_step
def rewrite_jumps_as_absolute(code):
    '''Rewrite jumps in an absolute form with a label at the destination.'''

    labels = {}
    lines = code.split('\n')

    for index, line in enumerate(lines):
        m = re.match(r'jnz ([a-h]|-?\d+) (-?\d+)', line)
        if m:
            register, offset = m.groups()
            offset = int(offset)

            if 0 <= index + offset < len(lines):
                label = 'L{}'.format(len(labels))
                labels[label] = index + offset
                lines[index] = 'jnz {} {}'.format(register, label)
            else:
                # Out of bounds jumps will instead compile into a `halt` (stop execution)
                lines[index] = 'sys.exit(0)'

    for label, index in labels.items():
        lines[index] = '# {}\n{}'.format(label, lines[index])

    return '\n'.join(lines)

As an example:

>>> code = '''
... set a 2
... set b 3
... jnz a 2
... add a b
... mul a b
... '''

>>> print(rewrite_jumps_as_absolute(code))

set a 2
set b 3
jnz a L0
add a b
# L0
mul a b

It’s no longer either DuetVM code or Python, but rather something in between. In an optimal case, it would be best if every stage of the compiler were runnable so that we can test that the functionality of each step is exactly the same, but that’s just not something I’m going to do this time around.

Next, we have a few different possible styles of jumps. The easiest two are jumps that don’t overlap any other jumps, going either forward or backwards. For the case of a jump going forward, it’s basically an if statement:

def rewrite_simple_if(code):
    '''Rewrite non-nested forward jumps as simple ifs.'''

    def make_if(m):
        indentation, register, label, body = m.groups()
        return '{}if {} == 0:\n{}'.format(
            indentation,
            register,
            '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)jnz ([a-h]|-?\d+) (L\d+)\n((?!jnz|# L).*)\n\1# \3',
        make_if,
        code,
        flags = re.DOTALL
    )

The regular expressions get a bit complicated here sadly, but they’re still more or less readable:

  • (\s*) (\1 is the indentation): Start by matching whitespace; this is the current indentation (for nested if statements)
  • jnz ([a-h]|-?\d+) (L\d+) (\2 is the conditional, \3 is the label): Match the forward jump, storing the conditional and label
  • \n: Start the block
  • ((?!jnz|# L).*) (\4 is the body): Match the block's content. The negative lookahead only guards the start of the block, rejecting one that begins with a jnz instruction or a # L label comment; nested ifs are fine
  • \n: Finish the block
  • \1# \3: Match the closing label comment. It has the same indentation and label as the jnz instruction
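To see those groups in action, this small sketch runs the same pattern with re.search against the intermediate code from the earlier example. The SIMPLE_IF and intermediate names are only for this illustration.

```python
import re

# Same pattern as in rewrite_simple_if
SIMPLE_IF = r'(\s*)jnz ([a-h]|-?\d+) (L\d+)\n((?!jnz|# L).*)\n\1# \3'

intermediate = '\n'.join([
    'set a 2',
    'set b 3',
    'jnz a L0',
    'add a b',
    '# L0',
    'mul a b',
])

m = re.search(SIMPLE_IF, intermediate, flags=re.DOTALL)
indentation, register, label, body = m.groups()
print(repr(indentation), register, label, repr(body))
# '' a L0 'add a b'
```

The whole match runs from the jnz line through the closing label comment, which is why the # L0 comment disappears once the if has been generated.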

When we rewrite that, we want to generate the condition if {register} == 0: and then indent each line in the block by an additional four spaces. So if we apply this to the code we had earlier:

>>> code = '''
... set a 2
... set b 3
... jnz a 2
... add a b
... mul a b
... '''

>>> print(rewrite_simple_if(rewrite_jumps_as_absolute(code)))

set a 2
set b 3
if a == 0:
    add a b
mul a b

Next up, let’s rewrite non-overlapping backward loops.

def rewrite_simple_while(code):
    '''Rewrite non-nested backward jumps as while True loops ending in a conditional break.'''

    def make_while(m):
        indentation, label, body, register = m.groups()
        return '''\
{indentation}# {label}
{indentation}while True:
{body}
{indentation}    if {register} == 0: break\
'''.format(
            indentation = indentation,
            register = register,
            label = label,
            body = '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)# (L\d+)\n((?!jnz|# L).*)\n\1jnz ([a-h]|-?\d+) \2',
        make_while,
        code,
        flags = re.DOTALL
    )

This one is a bit more complicated, since we want the code to run at least one time. If only we had a do while loop. But instead, all we have to do is make it while True and put the actual conditional and a break at the end of the loop.
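Here is a tiny sketch of why the test has to sit at the bottom. Both functions below are invented for illustration: they count up e until it reaches b, with g starting at zero like a freshly initialized register. Testing at the top would skip the body entirely, while the while True form runs it at least once.

```python
def bottom_tested(b):
    '''do-while style: run the body first, then decide whether to stop.'''
    e, g, iterations = 2, 0, 0
    while True:
        e += 1
        g = e - b
        iterations += 1
        if g == 0:
            break
    return iterations

def top_tested(b):
    '''Plain while: g is still 0 on entry, so the body never runs.'''
    e, g, iterations = 2, 0, 0
    while g != 0:
        e += 1
        g = e - b
        iterations += 1
    return iterations

print(bottom_tested(5), top_tested(5))  # prints 3 0
```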

Example:

>>> print(rewrite_jumps_as_absolute(code))

set a 2
set b 3
# L0
add a b
sub b 1
jnz b L0

>>> print(rewrite_jumps(rewrite_jumps_as_absolute(code)))

set a 2
set b 3
# L0
while True:
    add a b
    sub b 1
    if b == 0: break

Finally, let’s rewrite one simple case that we have in our actual example code, an if-not. Something like this:

jnz a 2
jnz 1 5
mul b 100
sub b -100000
set c b
sub c -17000
set f 1

The way to read that: if a != 0, we skip the second jump and run the rest of the code. If a == 0, we hit the second jump, which always jumps ahead five instructions (1 != 0 is always true), skipping the four in between. Logically, that means that if a != 0, we’ll run the instructions but if a == 0, we will not. So we can rewrite that like so:

def rewrite_if_not(code):
    '''Rewrite an overlapping pair of forward jumps as if/else.'''

    def make_if_not(m):
        indentation, register, label_1, label_2, body = m.groups()
        return '''\
{indentation}if {register} != 0:
{body}\
'''.format(
indentation = indentation,
register = register,
body = '\n'.join('    ' + line for line in body.split('\n')),
)

    return re.sub(
        r'(\s*)jnz ([a-h]) (L\d+)\n\1jnz (?!0)-?\d+ (L\d+)\n\1# \3\n((?!jnz|# L).*)\n\1# \4',
        make_if_not,
        code,
        flags = re.DOTALL
    )

It’s similar to the if statements, but the regular expression is a bit more complicated, since we’re matching two back to back jumps and have to have two closing labels. In practice:

>>> code = '''
... set a 2
... set b 3
... jnz a 2
... jnz 1 5
... mul b 100
... sub b -100000
... set c b
... sub c -17000
... set f 1
... '''
>>> print(rewrite_jumps(rewrite_jumps_as_absolute(code)))
set a 2
set b 3
if a != 0:
    mul b 100
    sub b -100000
    set c b
    sub c -17000
set f 1

That’s pretty neat.
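To convince ourselves that the if-not reading is right, the sketch below steps through the two-jump form by hand and compares which instructions get executed against the rewritten if a != 0 form. Both helper functions are hypothetical and only track instruction names, not register values.

```python
BODY = ['mul b 100', 'sub b -100000', 'set c b', 'sub c -17000']

def with_jumps(a):
    '''Follow the original jnz pair and record which instructions run.'''
    program = [('jnz', a, 2), ('jnz', 1, 5)] + BODY + ['set f 1']
    executed = []
    pc = 0
    while 0 <= pc < len(program):
        instruction = program[pc]
        if isinstance(instruction, tuple):
            _, condition, offset = instruction
            if condition != 0:
                pc += offset
                continue
        else:
            executed.append(instruction)
        pc += 1
    return executed

def with_if(a):
    '''The rewritten form: the body only runs when a != 0.'''
    executed = []
    if a != 0:
        executed.extend(BODY)
    executed.append('set f 1')
    return executed

for a in (0, 1, 2, -3):
    assert with_jumps(a) == with_if(a)
print(with_jumps(0))  # prints ['set f 1']
```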

One last thing we want to do is deal with potentially nested if statements. To do that, we just want to run the above three functions over and over again until the current code stops changing:

@compile_step
def rewrite_jumps(code):
    '''Using the previous step, rewrite jumps.'''
    # Keep running these functions until we reach a stable state
    functions = [
        rewrite_simple_if,
        rewrite_if_not,
        rewrite_simple_while,
    ]

    while True:
        new_code = code
        for f in functions:
            new_code = f(new_code)

        if code == new_code:
            break
        else:
            code = new_code

    return code
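A small sketch of why the repetition matters: with one forward jump nested inside another, the first pass can only turn the outer jump into an if, and the inner one is picked up on the second pass once it has been indented. The rewrite_simple_if here is a copy of the function above; until_stable and the nested sample are made up for the illustration.

```python
import re

def rewrite_simple_if(code):
    '''Copy of the step above: rewrite forward jumps as simple ifs.'''
    def make_if(m):
        indentation, register, label, body = m.groups()
        return '{}if {} == 0:\n{}'.format(
            indentation,
            register,
            '\n'.join('    ' + line for line in body.split('\n')),
        )

    return re.sub(
        r'(\s*)jnz ([a-h]|-?\d+) (L\d+)\n((?!jnz|# L).*)\n\1# \3',
        make_if,
        code,
        flags=re.DOTALL
    )

def until_stable(code, functions):
    '''Hypothetical helper: apply functions until nothing changes.

    Returns the final code and the number of passes that changed it.
    '''
    passes = 0
    while True:
        new_code = code
        for f in functions:
            new_code = f(new_code)
        if new_code == code:
            return code, passes
        code = new_code
        passes += 1

nested = '\n'.join([
    'jnz a L0',
    'set b 1',
    'jnz b L1',
    'set c 1',
    '# L1',
    'set d 1',
    '# L0',
    'set e 1',
])

result, passes = until_stable(nested, [rewrite_simple_if])
print(result)
print(passes)
```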

That’s actually enough to make runnable Python code:

# BASH$ cat input.txt
set b 65
set c b
jnz a 2
jnz 1 5
mul b 100
sub b -100000
set c b
sub c -17000
set f 1
set d 2
set e 2
set g d
mul g e
sub g b
jnz g 2
set f 0
sub e -1
set g e
sub g b
jnz g -8
sub d -1
set g d
sub g b
jnz g -13
jnz f 2
sub h -1
set g b
sub g c
jnz g 2
jnz 1 3
sub b -17
jnz 1 -23
# BASH$ python3 duetvmc.py input.txt
# Final code:
b = 65
c = b
if a != 0:
    b *= 100
    b -= -100000
    c = b
    c -= -17000
# L7
while True:
    f = 1
    d = 2
    # L4
    while True:
        e = 2
        # L3
        while True:
            g = d
            g *= e
            g -= b
            if g == 0:
                f = 0
            e -= -1
            g = e
            g -= b
            if g == 0: break
        d -= -1
        g = d
        g -= b
        if g == 0: break
    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        sys.exit(0)
    b -= -17
    if 1 == 0: break

Okay. It’s not quite done. We still need four more things:

  • Initialize the registers
  • Import sys so we can run sys.exit(0) towards the end
  • Count mul instructions
  • Print the final value of that count, either if we sys.exit(0) or reach the end of the program

To do that, we can add two more compile steps:

@compile_step
def add_debug_statements(code):
    code = re.sub(r'([a-h]) \*=', r'mul_count += 1; \1 *=', code)
    code = re.sub(r'sys\.exit', 'print(mul_count); sys.exit', code)
    code = 'mul_count = 0\n' + code + '\nprint(mul_count)'
    return code

@compile_step
def add_initial_registers(code):
    return 'a = b = c = d = e = f = g = h = 0\n' + code
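Here is roughly what the instrumentation does to a few lines of already-rewritten code. This is a standalone copy of add_debug_statements without the decorator, and the snippet is a made-up fragment rather than the full program.

```python
import re

def add_debug_statements(code):
    # Count every multiplication right before it happens
    code = re.sub(r'([a-h]) \*=', r'mul_count += 1; \1 *=', code)
    # Report the count on early exit
    code = re.sub(r'sys\.exit', 'print(mul_count); sys.exit', code)
    # Initialize the counter and report it if we fall off the end
    code = 'mul_count = 0\n' + code + '\nprint(mul_count)'
    return code

snippet = 'g = d\ng *= e\nif g == 0:\n    sys.exit(0)'
instrumented = add_debug_statements(snippet)
print(instrumented)
# mul_count = 0
# g = d
# mul_count += 1; g *= e
# if g == 0:
#     print(mul_count); sys.exit(0)
# print(mul_count)
```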

That leaves our actual final code as:

# Final code:
import sys
a = b = c = d = e = f = g = h = 0
mul_count = 0
b = 65
c = b
if a != 0:
    mul_count += 1; b *= 100
    b -= -100000
    c = b
    c -= -17000
# L7
while True:
    f = 1
    d = 2
    # L4
    while True:
        e = 2
        # L3
        while True:
            g = d
            mul_count += 1; g *= e
            g -= b
            if g == 0:
                f = 0
            e -= -1
            g = e
            g -= b
            if g == 0: break
        d -= -1
        g = d
        g -= b
        if g == 0: break
    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        print(mul_count); sys.exit(0)
    b -= -17
    if 1 == 0: break
print(mul_count)

Which we can directly run with the python executable:

$ python3 duetvmc.py input.txt | python3

3969

That’s just cool.

Part 2: Set a = 1. What is the final value of h?

Well. We could run the initial program again. That would work. The problem is that it would take a really long time to finish (I tried it).

We could do a bunch of additional optimizations to try to make the code run more quickly, but for this one case, I decided just to actually look at the code and try to figure it out myself. Specifically, everything after # L7 is something we need to optimize.

First, the innermost block:

# L3
while True:
    g = d
    mul_count += 1; g *= e
    g -= b
    if g == 0:
        f = 0
    e -= -1
    g = e
    g -= b
    if g == 0: break

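If you stare at that for a while, you have a divisibility check. The loop walks e from 2 up to b and sets f to 0 if d * e == b for some e, that is, if d divides b exactly. Otherwise, f stays 1 (set at the top of L7). Loosely renaming the number under test to g and the candidate divisor to e, that loop is more or less equivalent to: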

if g % e == 0:
    f = 0
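That equivalence can be brute-forced for small values. The sketch below keeps the original register names, so the number under test is b and the candidate divisor is d; inner_loop mirrors the L3 loop (minus the mul counter) and shortcut is the modulus version. Both function names are invented for this check, and it assumes b >= 3 and 2 <= d < b, which is what the surrounding loops provide.

```python
def inner_loop(b, d):
    '''The L3 loop as generated: try every e from 2 up to b - 1.'''
    f = 1
    e = 2
    while True:
        g = d
        g *= e
        g -= b
        if g == 0:
            f = 0
        e += 1
        g = e
        g -= b
        if g == 0:
            break
    return f

def shortcut(b, d):
    '''The same answer from a single modulus test.'''
    return 0 if b % d == 0 else 1

for b in range(3, 40):
    for d in range(2, b):
        assert inner_loop(b, d) == shortcut(b, d)
print('equivalent for all tested values')
```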

Going out one layer, we have:

# L4
while True:
    e = 2
    if g % e == 0:
        f = 0
    d -= -1
    g = d
    g -= b
    if g == 0: break

This is a loop from 2 up to g:

for e in range(2, g):
    if g % e == 0:
        f = 0

And finally, one level more:

# L7
while True:
    f = 1
    d = 2

    for e in range(2, g):
        if g % e == 0:
            f = 0

    if f == 0:
        h -= -1
    g = b
    g -= c
    if g == 0:
        print(mul_count); sys.exit(0)
    b -= -17
    if 1 == 0: break

Two things to recognize here. First, this is another loop on g from b to c but with a step size of 17 (the b -= -17 line). Also, we’re reading from f here: if it has been cleared to zero, add one to h. Since f is never set back to one once it’s zero (until the next value is tested), we can rewrite the inner loop as h += 1 with a break as another speed boost. That gives us:

for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)

That looks familiar. Essentially, it’s doing the inverse of a primality test. For each g, if g is a composite number, add 1 to h. Adding this all as a hacky compiler step:

# HACK
@compile_step
def replace_with_composite_counter(code):
    def comment_out(m):
        return '\n'.join('# ' + line for line in m.group(1).split('\n')) + '\n'

    code = re.sub(r'(# L7\n.*)', comment_out, code, flags = re.DOTALL)
    code += '''
# Turns out the code is checking for composite numbers... very inefficiently
# a != 0 sets up the range on b and c (note the off by one for c...)
# The L7 loop is looping from b to c by 17 (the b -= -17 at the end)
# The L4 loop is looping from 2 to g, setting f = 0 if g is divisible by the given numbers
# - NOTE: The original loop doesn't bail out early, which helps speed up a fair bit
# The L3 loop is doing the trial division ()

for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)
'''
    return code

Which gives us a final result of:

$ python3 duetvmc.py input.txt --part 2

# Final code:
import sys
a = 1
b = c = d = e = f = g = h = 0
b = 65
c = b
if a != 0:
    b *= 100
    b -= -100000
    c = b
    c -= -17000
# # L7
# while True:
#     f = 1
#     d = 2
#     # L4
#     while True:
#         e = 2
#         # L3
#         while True:
#             g = d
#             g *= e
#             g -= b
#             if g == 0:
#                 f = 0
#             e -= -1
#             g = e
#             g -= b
#             if g == 0: break
#         d -= -1
#         g = d
#         g -= b
#         if g == 0: break
#     if f == 0:
#         h -= -1
#     g = b
#     g -= c
#     if g == 0:
#         sys.exit(0)
#     b -= -17
#     if 1 == 0: break

# Turns out the code is checking for composite numbers... very inefficiently
# a != 0 sets up the range on b and c (note the off by one for c...)
# The L7 loop is looping from b to c by 17 (the b -= -17 at the end)
# The L4 loop is looping from 2 to g, setting f = 0 if g is divisible by the given numbers
# - NOTE: The original loop doesn't bail out early, which helps speed up a fair bit
# The L3 loop is doing the trial division ()

for g in range(b, c + 1, 17):
    for e in range(2, g):
        if g % e == 0:
            h += 1
            break
print(h)

Which runs much more quickly (especially if we leverage PyPy):

$ python3 run-all.py day-23

day-23  python3 duetvmc.py input.txt --part 1 | python3 0.10982990264892578     3969
day-23  python3 duetvmc.py input.txt --part 2 | python3 1.5763399600982666      917
day-23  python3 duetvmc.py input.txt --part 2 | pypy    0.42662978172302246     917
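If that still feels slow, one further tweak (not something the compiler above does) would be to stop the trial division at the square root of each candidate. A rough sketch, with is_composite, count_composites and count_composites_slow all being names made up here:

```python
def is_composite(n):
    '''Trial division, but only up to the square root of n.'''
    if n < 4:
        return False
    e = 2
    while e * e <= n:
        if n % e == 0:
            return True
        e += 1
    return False

def count_composites(b, c, step=17):
    '''Count composite numbers in range(b, c + 1, step).'''
    return sum(1 for n in range(b, c + 1, step) if is_composite(n))

def count_composites_slow(b, c, step=17):
    '''The hand-optimized loop from above, kept for comparison.'''
    h = 0
    for g in range(b, c + 1, step):
        for e in range(2, g):
            if g % e == 0:
                h += 1
                break
    return h

# Composites between 10 and 20: 10, 12, 14, 15, 16, 18, 20
print(count_composites(10, 20, step=1))  # prints 7
```

Any composite number has a divisor no larger than its square root, so the early cutoff gives the same count while doing far fewer divisions for the primes in the range.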

That was fun. I want to write a much bigger ‘real’ compiler now.