angstromCTF 2022 - RSA-AES

This challenge is a sequel to RSA-OTP from angstromCTF 2020. That was one of my first CTFs, so it was nice to come back to a follow-up of a challenge I wasn’t capable of solving back then. Thanks to clam for writing it!

from Crypto.Util.number import bytes_to_long, long_to_bytes
from Crypto.Util.Padding import pad
from Crypto.Random import get_random_bytes
from Crypto.Cipher import AES
from secret import flag, d

assert len(flag) < 256

n = 0xbb7bbd6bb62e0cbbc776f9ceb974eca6f3d30295d31caf456d9bec9b98822de3cb941d3a40a0fba531212f338e7677eb2e3ac05ff28629f248d0bc9f98950ce7e5e637c9764bb7f0b53c2532f3ce47ecbe1205172f8644f28f039cae6f127ccf1137ac88d77605782abe4560ae3473d9fb93886625a6caa7f3a5180836f460c98bbc60df911637fa3f52556fa12a376e3f5f87b5956b705e4e42a30ca38c79e7cd94c9b53a7b4344f2e9de06057da350f3cd9bd84f9af28e137e5190cbe90f046f74ce22f4cd747a1cc9812a1e057b97de39f664ab045700c40c9ce16cf1742d992c99e3537663ede6673f53fbb2f3c28679fb747ab9db9753e692ed353e3551
e = 0x10001
assert pow(2,e*d,n)==2

enc = pow(bytes_to_long(flag),e,n)
print(enc)

k = get_random_bytes(32)
iv = get_random_bytes(16)
cipher = AES.new(k, AES.MODE_CBC, iv)

while 1:
	try:
		i = int(input("Enter message to sign: "))
		assert(0 < i < n)
		print("signed message (encrypted with military-grade aes-256-cbc encryption):")
		print(cipher.encrypt(pad(long_to_bytes(pow(i,d,n)),16)))
	except:
		print("bad input, exiting")

We get access to a server which prints the flag encrypted with RSA and then acts as an RSA decryption oracle. However, every oracle output is encrypted with AES in CBC mode, using a random key and IV that we never see.

primitives

homomorphic RSA

Recall that textbook RSA is multiplicatively homomorphic, i.e. $a^e \cdot b^e \equiv (ab)^e \pmod n$. Since we have $c = f^e$ as the encrypted flag, and we know the public key, asking the oracle to decrypt $c \cdot k^e$ makes the server compute $f \cdot k \bmod n$.

This is useful for a couple of things, most notably with $k = 2^i$: this shifts the flag left by $i$ bits, increasing its bit length by $i$ (as long as the result stays below $n$).
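As a quick sanity check of this primitive, here is a minimal sketch with a toy key. The primes, the stand-in value `f`, and the helper `decrypt_oracle` are assumptions made up for illustration; the real server additionally wraps the result in AES.

```python
# Toy RSA key built from two Mersenne primes (illustration only, not secure).
p, q = 2**31 - 1, 2**61 - 1
n = p * q
e = 0x10001
d = pow(e, -1, (p - 1) * (q - 1))  # private exponent (needs Python 3.8+)

f = 123456789          # hypothetical stand-in for the flag as an integer
c = pow(f, e, n)       # the value the server would print


def decrypt_oracle(ct):
    """Stand-in for the server's 'sign' step: returns ct^d mod n."""
    return pow(ct, d, n)


k = 2**5                              # shift the plaintext left by 5 bits
blinded = (c * pow(k, e, n)) % n      # c * k^e mod n
shifted = decrypt_oracle(blinded)     # equals f * k mod n

assert shifted == (f * k) % n
assert shifted.bit_length() == f.bit_length() + 5
```

The second assertion only holds because `f * k` is still smaller than `n`, which is the same condition the bit-shifting trick relies on.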

somewhat accurate bit length calculator

The idea for RSA-OTP was that the OTP leaked the precise bit length of the decrypted ciphertext. We have a similar primitive here, although we are going to have to use the first primitive to help us.

Notice that we can recover the bit length of a given plaintext accurately provided the bit length is at most 2040 (for reasons I will explain shortly).

The idea here is that we can multiply the plaintext by 2 until we get a change in the number of blocks in the AES ciphertext. This allows us to determine when the number of bytes in the last block reaches 16, in which case the pad function will add another padding block, and so the number of blocks in the ciphertext will increase.

However, this does not work once the bit length reaches 2041. A 2041-bit plaintext already occupies 256 bytes, so its last data block is full and the ciphertext ends with a whole padding block. Multiplying by 2 until the block count grows again would need a plaintext of roughly $2^{2041} \cdot 2^{120} = 2^{2161}$, which is far larger than $n$ and therefore gets reduced $\bmod n$. The decrypted value is then no longer a simple shift of the plaintext, so the method described above breaks down.

Here’s a function to do that:

def recover_bitlen(ciphertext):
    blocklen = len(query(ciphertext)) // 16
    if blocklen == 2048 // 128 + 1: # 17 blocks: plaintext already fills 256 bytes (>= 2041 bits)
        return 2041
    lb = 0
    ub = 128
    while ub - lb != 1: # binary search to find when it changes
        k = (lb + ub)//2
        ct = ciphertext * pow(2, k*e, n)
        bl = len(query(ct % n)) // 16
        if bl == blocklen:
            lb = k
        else:
            ub = k
    return blocklen*128 - lb - 8

This also allows us to recover the bit length of the flag, which is 1759 bits = 220 bytes.
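The return value `blocklen*128 - lb - 8` is easy to get wrong by one, so here is a small offline check of the same binary search. It is only a sketch: `num_blocks` is a hypothetical local model of the ciphertext length (no AES, no server), and it assumes the shifted value never wraps around $n$.

```python
def num_blocks(value):
    """Blocks in the ciphertext of pad(long_to_bytes(value), 16)."""
    nbytes = max(1, (value.bit_length() + 7) // 8)
    return nbytes // 16 + 1  # PKCS#7 padding always adds 1 to 16 bytes


def bitlen_from_blocks(value):
    """Recover value.bit_length() using only block counts of shifted values."""
    blocks = num_blocks(value)
    lb, ub = 0, 128  # a 128-bit shift always adds at least one block
    while ub - lb != 1:
        k = (lb + ub) // 2
        if num_blocks(value << k) == blocks:
            lb = k  # still the same number of blocks
        else:
            ub = k  # crossed into the next block
    # The smallest shift that adds a block satisfies
    # bit_length + ub == 128 * blocks - 7, with ub == lb + 1.
    return blocks * 128 - lb - 8


# Hypothetical 220-byte ASCII flag with the same shape as the real one.
sample = int.from_bytes(b"actf{" + b"x" * 214 + b"}", "big")
print(bitlen_from_blocks(sample))  # 1759
```

Since the leading byte `a` is `0x61`, whose top bit is zero, a 220-byte flag has $220 \cdot 8 - 1 = 1759$ bits, which matches the value recovered from the server.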

solution idea 1

The solution idea here is similar to a solution to RSA-OTP, which is detailed here. The general idea is that we find some value $k$ such that:

\[kf \approx 2^{1024}\]

which we can determine based on seeing when the bit length changes from 1023 to 1024, and then dividing $2^{1024}$ by $k$ to get the flag.

We can do a similar thing here, where we find $k$ such that:

\[kf \approx 2^{2041}\]

as we are only able to distinguish up to here. We can binary search $k$ based on the bit length, and then divide $2^{2041}$ by $k$ to hopefully recover the flag.

An implementation of this is below:

flagmin, flagmax = 2**1758, 2**1759
lb = (2**2041)//flagmax
ub = (2**2041)//flagmin

while ub - lb != 1:
    k = (lb + ub)//2
    ct = pow(k, e, n) * c
    bitlen = recover_bitlen(ct % n)
    if bitlen == 2041:
        ub = k
    else:
        lb = k

print(lb)
print(long_to_bytes(2**2041 // lb))
b'actf{the_letters_in_rsa_and_aes_foryis\x17B\x8a\xbb)\x1d\x9c\xeeX\xd1\xae\xc2\xa8\xe4\xda%\x14\xb9)\xebc\xbei@:\xc5\x0eVD3_\xf4D\xd1\xbf\xdc\x92\xbf\x00\x9ci\xb8\xd7\xed2M\x06{3\x96\n\xee\x10u\xe6\x10\x95\x03P\xc4\x07\x92\xb2\x10J\x11\xf9Lo\x06\x9d\xeek\xe9/\x17\xd3\xd9\xb5E\x19\x7f1\xc4"f\xe7&\x08&\x98=\x8cC\x8a\xaf\xf28\xdeO\x19!\xf7#\xd6SS\nH\x82\xd5\x1c\xbe\xb5\\\xac\xbb\x10*\xb4K<\xa9\xc2|\x94\xb0\xa5\xc4L\r\xae?pW\x02d\x8e\x7f5.t\xa5\xcc\xabd\xeb\xdeF\x92\x87\xfb\xb0\xd8\xb7g2\x0b\x9fR\x9f\xa42\x93\x12\xf3\xab\xb1\x7fI7"\x98"\x84|\xec`%\x8d2\x03'

Ouch, the flag is not accurate enough. We’ll need to find a better method, as we still have a very large portion of the flag to recover. Let’s think about some more primitives and also think about how this challenge differs from RSA-OTP.
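To see why only the start of the flag comes out, note that $k$ has about 283 bits, so $2^{2041} / k$ can only pin down about that many leading bits of $f$. The sketch below checks this offline with a hypothetical 220-byte flag (the filler bytes are made up) and the best $k$ that a perfect oracle could return.

```python
# Hypothetical 220-byte ASCII flag; only its length and prefix mimic the real one.
flag = b"actf{" + b"x" * 214 + b"}"
f = int.from_bytes(flag, "big")  # 1759 bits

# Largest k with k * f <= 2^2041, i.e. what an ideal binary search converges to.
k = 2**2041 // f
approx = (2**2041 // k).to_bytes(len(flag), "big")


def common_prefix_len(a, b):
    """Number of leading bytes on which a and b agree."""
    count = 0
    for x, y in zip(a, b):
        if x != y:
            break
        count += 1
    return count


print(k.bit_length())                   # 283
print(common_prefix_len(flag, approx))  # about 35 bytes
```

The error of the estimate is below $f / k < 2^{1477}$, so everything above bit 1477 of the 1759-bit flag survives. That is about 35 bytes, the same length as the correct prefix `actf{the_letters_in_rsa_and_aes_for` in the output above.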

more primitives

AES is deterministic

The main difference between RSA-OTP and this challenge is that the OTP was randomly generated for every query, whereas here the outputs go through AES, which is deterministic: the same plaintext with the same key and IV always encrypts to the same ciphertext.

This allows us to compare encrypted blocks and deduce whether the underlying plaintext blocks are equal, although we cannot “compare” blocks arithmetically.

There is one minor issue. The cipher object is created once at the start, so its CBC state carries over between queries and each encryption effectively starts from a different IV. This isn’t really a problem: in CBC mode, the effective IV for the next encryption is the last ciphertext block of the previous one, which we know. We can XOR it into a plaintext block that we choose, so that the input to AES, and therefore the resulting ciphertext block and chaining state, is always the same.

An implementation of this is below:

def set_iv(): 
    b = query(1) # this is just a "fake" block so that we can recover the iv used for next round
    iv = b[-16:] # we need to xor our block with this so that what gets fed into AES is our chosen block
    newblock = bytes_to_long(bxor(b"a" * 16, iv))
    ct = pow(newblock, e, n)
    query(ct) # this will always fix the iv

We will call this function every time before we query a value, as this allows us to compare blocks.
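The chaining argument can be checked without a server. The sketch below is a simplified model, not the challenge code: `toy_block_encrypt` is a keyed hash standing in for AES (it is not a real cipher), and `StatefulCBC` and `reset_chain` are hypothetical names that mimic a cipher object reused across queries and the `set_iv` trick.

```python
import hashlib
import os

BLOCK = 16


def toy_block_encrypt(key, block):
    # Stand-in for AES: a keyed hash truncated to one block (assumption).
    return hashlib.sha256(key + block).digest()[:BLOCK]


def pad(data):
    count = BLOCK - len(data) % BLOCK
    return data + bytes([count]) * count


def bxor(a, b):
    return bytes(x ^ y for x, y in zip(a, b))


class StatefulCBC:
    """CBC encryption that keeps its chaining value between calls."""

    def __init__(self, key, iv):
        self.key, self.chain = key, iv

    def encrypt(self, data):
        out = b""
        for i in range(0, len(data), BLOCK):
            mixed = bxor(data[i:i + BLOCK], self.chain)
            self.chain = toy_block_encrypt(self.key, mixed)
            out += self.chain
        return out


def reset_chain(cipher):
    """Same idea as set_iv(): force the chaining value to a fixed state."""
    last = cipher.encrypt(pad(b"\x01"))[-BLOCK:]    # learn the current chaining value
    cipher.encrypt(pad(bxor(b"a" * BLOCK, last)))   # block cipher input is now b"a" * 16


cipher = StatefulCBC(os.urandom(32), os.urandom(16))
message = pad(b"some plaintext")

reset_chain(cipher)
first = cipher.encrypt(message)
cipher.encrypt(pad(b"unrelated traffic"))  # disturbs the chaining value
reset_chain(cipher)
second = cipher.encrypt(message)

assert first == second  # same state, same plaintext, same ciphertext
```

After `reset_chain`, the state depends only on the key, because the second query always encrypts `b"a" * 16` followed by a full padding block. Any two queries made right after a reset can therefore be compared block by block.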

solution idea 2

Our idea will be similar to our initial idea, but since we need to be more accurate, we will instead find $k$ for some known $a$ such that:

\[kf \approx an\]

Now, to ensure we can actually do this, we need to choose $a$ in such a way that we can actually find $k$ without too much bruteforce.

Using our approximation of the flag, we can work out lower and upper bounds $lb$ and $ub$ for $k$. We then double check that the first block of the encryption of $lb \cdot f$ is the same as that of the encryption of $n$. If we choose too large a value of $a$ without a good enough approximation for $f$, the bounds for $k$ will not be accurate enough for this block comparison to work.

We also need to check that the first block of the encryption of $ub * f$ is not the same as the encryption of $n$, as this indicates there is a $\bmod n$ being taken between these values at some point, which is the value we eventually want to find.

We can use binary search to find the best approximation for $k$. If our guess for $k$ is too large, then $kf = an + b$ for some small $b > 0$, so taken $\bmod n$ the value is just $b$. Its first 128 bits are very unlikely to match the first 128 bits of $n$, and therefore the first AES ciphertext block will be different. If our guess is too small, then $kf = an - b$, so taken $\bmod n$ the value is $n - b$, whose first 128 bits are the same as those of $n$ as long as $b$ is small. After each guess we adjust $k$ accordingly, until we find a value where $kf < an$ and $(k+1)f > an$, in which case we know we have the best approximation for $k$.

After finding the best possible $k$, we can get an approximation for the flag by calculating $\frac{an}{k}$, and then after increasing $a$ with the new flag approximation, we repeat until we get the whole flag.
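One round of this search can be simulated offline to check the reasoning. Everything in this sketch is an assumption for illustration: the modulus is an arbitrary 2048-bit number, the flag is a made-up 220-byte string, `below_multiple` compares leading bytes directly instead of AES blocks, and it uses a fixed-width encoding where the server uses `long_to_bytes`.

```python
FLAG_LEN = 220
n = 2**2048 - 159                      # hypothetical modulus, only used for "mod n"
flag = b"actf{" + b"x" * 214 + b"}"    # hypothetical secret
f = int.from_bytes(flag, "big")
known = flag[:35]                      # prefix assumed known from the first stage


def first_block(value):
    # Fixed-width encoding for simplicity (assumption).
    return value.to_bytes(256, "big")[:16]


REFERENCE = first_block(n - 1)


def below_multiple(k):
    """Local stand-in for the block comparison: True while k*f is just under a*n."""
    return first_block(k * f % n) == REFERENCE


def refine(a, known):
    """One round: find the largest k with k*f < a*n, then estimate f as a*n // k."""
    rest = FLAG_LEN - len(known)
    min_f = int.from_bytes(known + b"\x20" * rest, "big")
    max_f = int.from_bytes(known + b"\x7f" * rest, "big")
    lo, hi = a * n // max_f, a * n // min_f + 1
    assert below_multiple(lo) and not below_multiple(hi)
    while hi - lo > 1:                 # invariant: lo passes, hi fails
        mid = (lo + hi) // 2
        if below_multiple(mid):
            lo = mid
        else:
            hi = mid
    return (a * n // lo).to_bytes(FLAG_LEN, "big")


estimate = refine(2**100, known)
print(estimate[:48])                   # agrees with the secret on these bytes
```

With 35 known bytes and $a = 2^{100}$, $k$ has roughly 390 bits, so the estimate is correct for about 48 bytes. Feeding the longer prefix back in with a larger $a$ gives the iteration described above.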

Full solve script (with added comments for clarity) below:

from pwn import *
from Crypto.Util.number import *
import os
from math import log, floor
import string

s = remote('challs.actf.co', 31500)

# pow stuff
s.recvuntil("work:")
powcmd = s.recvline().decode()[:-1]
os.system(powcmd + " > pow.txt")
s.sendline(open("pow.txt").read()[:-1])

n = 0xbb7bbd6bb62e0cbbc776f9ceb974eca6f3d30295d31caf456d9bec9b98822de3cb941d3a40a0fba531212f338e7677eb2e3ac05ff28629f248d0bc9f98950ce7e5e637c9764bb7f0b53c2532f3ce47ecbe1205172f8644f28f039cae6f127ccf1137ac88d77605782abe4560ae3473d9fb93886625a6caa7f3a5180836f460c98bbc60df911637fa3f52556fa12a376e3f5f87b5956b705e4e42a30ca38c79e7cd94c9b53a7b4344f2e9de06057da350f3cd9bd84f9af28e137e5190cbe90f046f74ce22f4cd747a1cc9812a1e057b97de39f664ab045700c40c9ce16cf1742d992c99e3537663ede6673f53fbb2f3c28679fb747ab9db9753e692ed353e3551 
e = 0x10001
c = 8702343735025266604493255023455944506448203943421139140860292426782509680048873569587596911548140388664137318807031840409098358526949080521742044811655775937519290500584015066858945832554481838588845652546365303337275177311841968984511864709414904338587040274646253896911291865516895328016639201704613064462056129248165695806001948541939983934752397658730589917722952341689278968790138018539025744727334202968601465562656795710345742329370762196560147624069504644216320910078454455520823319751612174752365432199529820974601209221975964155055716974539394252799602068585830290371624396504384131654854623710049511046664

def bxor(ba1, ba2):
    return bytes([x^y for x, y in zip(ba1, ba2)])

def query(num):
    s.sendlineafter("Enter message to sign: ", str(num))
    s.recvline()
    return eval(s.recvline()) # not safe but whatever

def recover_bitlen(ciphertext):
    blocklen = len(query(ciphertext)) // 16
    if blocklen == 2048 // 128 + 1: # 17 blocks: plaintext already fills 256 bytes (>= 2041 bits)
        return 2041
    lb = 0
    ub = 128
    while ub - lb != 1: # binary search to find when it changes
        k = (lb + ub)//2
        ct = ciphertext * pow(2, k*e, n)
        bl = len(query(ct % n)) // 16
        if bl == blocklen:
            lb = k
        else:
            ub = k
    return blocklen*128 - lb - 8


flagbitlen = recover_bitlen(c)
flagbytes = 1 + flagbitlen//8

# stage 1: recover first part of flag by finding k*flag = 2^2041
# this takes quite a while, feel free to comment out
flagmin, flagmax = 2**(flagbitlen-1), 2**flagbitlen
lb = (2**2041)//flagmax
ub = (2**2041)//flagmin

while ub - lb != 1:
    print((ub - lb).bit_length())
    k = (lb + ub)//2
    ct = pow(k, e, n) * c
    bitlen = recover_bitlen(ct % n)
    if bitlen == 2041: # k * flag > 2^2041, therefore we need to reduce the upper bound
        ub = k
    else: # k * flag < 2^2041, therefore we need to increase the lower bound
        lb = k

print("flag part 1: ", long_to_bytes(2**2041 // lb))

# generates minimum and maximum values for flag
def genminmax(flag):
    return bytes_to_long(flag + (b"\x20" * (flagbytes - len(flag)))), bytes_to_long(flag + (b"\x7f" * (flagbytes - len(flag))))

# sets iv before each query
def set_iv(): 
    b = query(1) # this is just a "fake" block so that we can recover the iv used for next round
    iv = b[-16:] # we need to xor our chosen block with this because of CBC
    newblock = bytes_to_long(bxor(b"a" * 16, iv))
    ct = pow(newblock, e, n)
    query(ct) # the iv is now fixed
    
# attempts to recover the flag part that is actually correct, there are probably better ways to do this (including manually), but this will do
def determine_flag(flags):
    charset = string.printable.encode()
    f = list(flags)
    f1, f2 = f[-2], f[-1]
    tflag = b""
    for x, y in zip(f1, f2):
        if x == y and x in charset:
            tflag += bytes([x])
        else:
            break
    return tflag[:-1] # take off a couple for accuracy

flag = b"actf{the_letters_in_rsa_and_aes_for"

def query2(num):
    set_iv()
    ct = pow(num, e, n) * c
    return query(ct % n)

# stage 2: recover full flag by finding k*flag = a*n for increasing a
# where a is chosen based on how much flag is known
# this takes even longer, a needs to reach about 2^1500 to recover the whole flag

a = 1
set_iv()
encn = query(n-1)[:16]

while len(flag) != flagbytes:
    print(f"a: 2^{a.bit_length()}")
    minflag, maxflag = genminmax(flag)
    an = a * n
    lowerbound, upperbound = an//maxflag, an//minflag
    assert query2(lowerbound)[:16] == encn # makes sure lowerbound * flag is close to a*n by checking first 128 bits
    assert query2(upperbound)[:16] != encn # makes sure upperbound * flag > a*n, so that there is a difference between these bounds
    power = 1 << ((upperbound - lowerbound).bit_length() - 1) # largest power of two <= the range, computed exactly without floats
    base = 0
    flags = []
    # binary search
    while power:
        print(power.bit_length())
        changed = 0
        for i in range(3): # not quite binary search but close enough, just for reliability
            cip = query2(lowerbound + base + power*i)[:16]
            if cip != encn: # k*flag > a*n
                base += power * (i-1)
                power //= 2
                changed = 1
                break
        f = long_to_bytes((a * n)//(lowerbound + base))
        print(f)
        if f not in flags:
            flags.append(f)
        if not changed: # this is kinda messy and honestly its easier to do this manually
            # but for the sake of having a full solve script, we need to ensure 
            a //= 2**25 # just to even it out a bit
            break
    print(flags)
    flag = determine_flag(flags)
    a *= 2**50
    print(flag)

After an hour or so of running this (with a few manual reruns due to timeouts), we get the flag.

Flag:

actf{the_letters_in_rsa_and_aes_form_aries_if_you_throw_in_the_letter_i_because_that_represents_yourself_or_something_anyway_aries_is_a_zodiac_sign_which_means_that_the_two_cryptosystems_are_mutually_compatble_i_think??}