# TetCTF 2022 - fault

```python
from secrets import randbits
from Crypto.Util.number import getPrime  # pycryptodome

NBITS = 1024
D_NBITS = 128  # small d makes decryption faster

class Cipher:
    def __init__(self):
        p = getPrime(NBITS // 2)
        q = getPrime(NBITS // 2)
        self.n = p * q
        self.d = getPrime(D_NBITS)
        self.e = pow(self.d, -1, (p - 1) * (q - 1))

    def encrypt(self, m: int) -> int:
        assert m < self.n
        return pow(m, self.e, self.n)

    def faultily_decrypt(self, c: int):
        assert c < self.n
        fault_vector = randbits(D_NBITS)
        return fault_vector, pow(c, self.d ^ fault_vector, self.n)

def main():
    from secret import FLAG
    cipher = Cipher()
    c = cipher.encrypt(int.from_bytes(FLAG.encode(), "big"))

    for _ in range(2022):
        line = input()
        print(cipher.faultily_decrypt(c if line == 'c' else int(line)))

if __name__ == '__main__':
    main()
```


The server generates a 1024 bit modulus $n$ and a 128 bit prime as the private exponent $d$. We are then given access to an RSA decryption oracle, which we can use 2022 times.

On each call to the oracle, we can either:

• decrypt the flag $c$
• decrypt any integer of our choice

However, the decryption function is broken, and instead of computing $c^d \mod n$, it generates a random 128 bit value $v$, and returns the value:

$c^{d\oplus v}\mod n$

for both functions of the oracle.

As well as this, neither $n$ nor $e$ is given.

# solution

tl;dr: mitm + z3 to recover $d$, then common modulus attack to recover flag.

Since the oracle is all we are given, we first need to recover $n$. This is simple: ask for the decryption of $-1$. The faulty exponent $d_i = d \oplus v_i$ is odd with probability 1/2 (since $d$ is an odd prime, this happens exactly when the lowest bit of $v_i$ is 0). In that case $(-1)^{d_i} \equiv -1 \equiv n - 1 \pmod n$, so adding 1 to the result gives $n$. If the exponent is even, we just get 1 back and ask again.
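Here is a minimal local sketch of this step. It uses a mock oracle that mirrors `faultily_decrypt` from the challenge code. The names `make_oracle` and `recover_n` are hypothetical, and so are the toy parameters: two small primes for $n$, and the Mersenne prime $2^{127} - 1$ standing in for $d$. None of these are values from the real server.

```python
from secrets import randbits

D_NBITS = 128

def make_oracle(n, d):
    # mock of Cipher.faultily_decrypt: a fresh random fault vector per call
    def oracle(c):
        v = randbits(D_NBITS)
        return v, pow(c, d ^ v, n)
    return oracle

def recover_n(oracle):
    # (-1)^k mod n is n - 1 for odd k and 1 for even k, so retry until odd
    while True:
        _, dec = oracle(-1)
        if dec != 1:
            return dec + 1

# toy parameters (illustrative only)
n = 1009 * 1013
d = 2**127 - 1
assert recover_n(make_oracle(n, d)) == n
```

Python's three-argument `pow` accepts a negative base and always returns a value in $[0, n)$, so `pow(-1, odd, n)` is exactly `n - 1`.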

Then, we need to recover $d$.

## before

Before we delve into the actual challenge, consider the equation $a + s = b$, where $a$ and $b$ are unknown but we know $v = a \oplus b$, and we want to find the difference $s$ between $a$ and $b$.

For each of the bits $v_i$ in $v$, we have two cases:

Case 1: $v_i = 0$

In this case, we know that the respective bits of $a$ and $b$ are also equal, meaning that $a_i = b_i$, therefore $a_i - b_i = 0$, and so this means that the difference $s$ is not affected by this bit.

Case 2: $v_i = 1$

The respective bits of $a$ and $b$ are different, but we do not know which one is 0 and which is 1. What we do know is that, regardless of which bit is which, $|b_i - a_i| = 1$. So this bit does affect $s$: we either add $2^i$ or subtract $2^i$. We don't know whether the difference at this bit is positive or negative, only that it exists.

So the value of $s$ depends only on the one-bits of $v$, and we can write it as a sum:

$s = \sum_{i=0}^{127} k_i v_i 2^i$

where $k_i \in \{-1, 1\}$ for all $i$ (choosing the sign of each term).

Notice that for a given $v$ with $l$ one-bits, there are $2^{l}$ candidate values for $s$.
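As a sanity check of this decomposition, here is a small sketch (the helper `candidate_diffs` is a hypothetical name). It enumerates all $2^l$ signed sums for random 16-bit $a$ and $b$. It then confirms that the true difference $b - a$ is among them and that the $2^l$ sums are all distinct.

```python
import itertools
import random

def candidate_diffs(v):
    """Yield every s = sum(k_i * 2^i) over the one-bits i of v, with k_i in {-1, 1}."""
    bits = [1 << i for i in range(v.bit_length()) if (v >> i) & 1]
    for signs in itertools.product((1, -1), repeat=len(bits)):
        yield sum(k * bit for k, bit in zip(signs, bits))

a = random.getrandbits(16)
b = random.getrandbits(16)
v = a ^ b
cands = set(candidate_diffs(v))
print(b - a in cands)                        # True
print(len(cands) == 2 ** bin(v).count("1"))  # True: the 2^l sums are distinct
```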

## application

Now, in the challenge we are given $c^{d \oplus v_1}$ and $c^{d \oplus v_2}$ both taken $\mod n$. Notice then if we let $a = d \oplus v_1, b = d \oplus v_2, v = v_1 \oplus v_2$ (this is actually $d \oplus v_1 \oplus d \oplus v_2$, however the $d$’s cancel out, so we can use this value for $v$), the difference between $a$ and $b$ will be the value $s$ where:

$c^a * c^s = c^b$

By the rules of indices, $c^{a} \cdot c^{s} = c^{a + s}$, and $a + s = b$. Since we know $v$, we can try all $2^l$ candidate values of $s$ and check whether $c^a \cdot c^s \equiv c^b \pmod n$. The candidate that passes is the actual difference between $a$ and $b$. Here $c^a$ and $c^b$ are exactly the two oracle outputs, so we never need $d$ itself. I wrote a branching algorithm to do this for me:

```python
c = 2     # the value we asked the oracle to decrypt
n = ...   # modulus recovered earlier
p1 = ...  # oracle output c^(d ^ v1) mod n
v1 = ...  # fault vector for p1
p2 = ...  # oracle output c^(d ^ v2) mod n
v2 = ...  # fault vector for p2

v = v1 ^ v2
bits = [1 << i for i in range(v.bit_length()) if (v >> i) & 1]

# (c^(2^i), c^(-2^i)) mod n for every one-bit 2^i of v
mul = [(pow(c, x, n), pow(pow(c, x, n), -1, n)) for x in bits]
maxdepth = len(mul)  # number of one-bits

def tree(current, depth, diff):
    # invariant: current == p1 * c^diff mod n
    if depth == maxdepth:
        if current == p2:  # check if the solution is correct
            print("sice", diff)
    else:
        tree(current * mul[depth][0] % n, depth + 1, diff + bits[depth])  # + 2^i
        tree(current * mul[depth][1] % n, depth + 1, diff - bits[depth])  # - 2^i

tree(p1, 0, 0)
```


Then, if we receive many $v_i$s and find pairs where $l$ is small, we should be able to brute-force the $2^l$ candidates for $s$.

Thus, the attack plan is as follows:

• recover $n$ by decrypting -1
• collect around 2000 pairs of $(v_i, c^{d \oplus v_i} \bmod n)$
• find pairs of $v_i$ where $l$ is small
• recover the difference between $d \oplus v_1$ and $d \oplus v_2$
• use this to recover bits of $d$
• repeat until all of $d$ is recovered
• profit?

There's only one small issue: this is not feasible.

If we experiment locally with around 2000 $v_i$s, we see that the smallest $l$ values are usually around 36-40 bits.

```python
import itertools
import random
from collections import Counter

v_is = [random.getrandbits(128) for _ in range(2000)]
ls = [bin(x ^ y).count("1") for x, y in itertools.combinations(v_is, r=2)]
print(sorted(Counter(ls).items()))
# one run: [(36, 1), (37, 1), (38, 6), (39, 12), (40, 14), (41, 34), (42, 67)...
```


This is probably not feasible, even if we were to chuck large amounts of computing power at it. So, how can we make this feasible? The answer is simple: the meet in the middle attack.

## mitm time

The general idea of meet in the middle is to trade space for time. The attack applies when we have a known plaintext/ciphertext pair $(p, c)$ where $p$ has been encrypted several times by some "encryption" function $E(p, k)$, using $i$ keys $k_1, k_2, \dots, k_i$. So we have:

$$\begin{aligned} c_1 &= E(p, k_1) \\ c_2 &= E(c_1, k_2) \\ &\;\vdots \\ c &= E(c_{i-1}, k_i) \end{aligned}$$

Now suppose $E$ has an inverse "decryption" function $D$ satisfying $D(E(p, k), k) = p$. If we pick a split point $m$, we can encrypt $p$ with the first $m$ keys to reach the middle value $c_m$:

$$\begin{aligned} c_1 &= E(p, k_1) \\ c_2 &= E(c_1, k_2) \\ &\;\vdots \\ c_m &= E(c_{m-1}, k_m) \end{aligned}$$

Since we also have the ciphertext, we can instead decrypt $c$ with the remaining $i - m$ keys and reach the same $c_m$:

$$\begin{aligned} c_{i-1} &= D(c, k_i) \\ c_{i-2} &= D(c_{i-1}, k_{i-1}) \\ &\;\vdots \\ c_m &= D(c_{m+1}, k_{m+1}) \end{aligned}$$

Both routes must arrive at the same $c_m$, so we can split the work in two:

1. Brute-force the first $m$ keys. For each choice, encrypt $p$ and store the resulting $c_m$ in a lookup table, along with the keys used to reach it.
2. Brute-force the remaining $i - m$ keys. For each choice, decrypt $c$ and look the result up in the table.

As soon as step 2 reaches a $c_m$ that is already stored, we know the actual $c_m$ and the keys on both halves. Choosing $m$ to be half of $i$ saves a significant amount of time. For example, with $i = 2$ and $20$-bit keys, instead of brute-forcing $2^{20} \cdot 2^{20} = 2^{40}$ key pairs, we only do $2^{20} + 2^{20} = 2^{21}$ operations, at the cost of storing $2^{20}$ table entries.
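Here is a toy illustration of the idea. It assumes a made-up 16-bit "cipher" (add the key, then rotate left by 3), applied twice with two independent 8-bit keys. `E`, `D` and `mitm` are hypothetical names for this sketch, not part of the challenge. The forward table costs $2^8$ encryptions and the backward pass $2^8$ decryptions, instead of $2^{16}$ for trying every key pair.

```python
import random

MASK = 0xFFFF  # toy 16-bit block

def rotl(x, r):
    return ((x << r) | (x >> (16 - r))) & MASK

def rotr(x, r):
    return ((x >> r) | (x << (16 - r))) & MASK

def E(x, k):
    # toy "encryption": add the key mod 2^16, then rotate left by 3
    return rotl((x + k) & MASK, 3)

def D(y, k):
    # inverse of E: rotate right by 3, then subtract the key mod 2^16
    return (rotr(y, 3) - k) & MASK

def mitm(p, c, keybits):
    # forward half: every k1 -> middle value E(p, k1)
    table = {}
    for k1 in range(1 << keybits):
        table.setdefault(E(p, k1), []).append(k1)
    # backward half: every k2 -> middle value D(c, k2), looked up in the table
    hits = []
    for k2 in range(1 << keybits):
        for k1 in table.get(D(c, k2), []):
            hits.append((k1, k2))
    return hits

k1, k2 = random.getrandbits(8), random.getrandbits(8)
p = 0x1234
c = E(E(p, k1), k2)
print((k1, k2) in mitm(p, c, 8))  # True
```

With only a 16-bit block, several key pairs can map $p$ to $c$, so `mitm` returns every consistent pair; a second known pair would filter them. Against a 1024-bit modulus, as in the challenge, spurious matches are unlikely.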

## application 2

So, let's apply this attack to our situation. First, we choose $v_1$ and $v_2$ such that $l$ is small, and take the corresponding oracle outputs $p_1 = c^{d \oplus v_1} \bmod n$ and $p_2 = c^{d \oplus v_2} \bmod n$ as our "plaintext" and "ciphertext". Each one-bit $2^j$ of $v$ acts as one key $k_j \in \{-1, 1\}$:

• the encryption function $E$ multiplies the current value by $c^{k_j 2^j} \bmod n$;
• the decryption function $D$ multiplies by its inverse, $c^{-k_j 2^j} \bmod n$.

The first half of the one-bits is applied forwards from $p_1$, and the second half backwards from $p_2$. I'll reuse code from the branching algorithm.

```python
c = 2     # the value we asked the oracle to decrypt
n = ...   # modulus recovered earlier
p1 = ...  # oracle output c^(d ^ v1) mod n
v1 = ...  # fault vector for p1
p2 = ...  # oracle output c^(d ^ v2) mod n
v2 = ...  # fault vector for p2

v = v1 ^ v2
bits = [1 << i for i in range(v.bit_length()) if (v >> i) & 1]

# split the one-bits of v into two halves
bits1 = bits[:len(bits) // 2]
bits2 = bits[len(bits) // 2:]
n1 = len(bits1)
n2 = len(bits2)
# forward half: (c^(2^i), c^(-2^i)) mod n
mul1 = [(pow(c, x, n), pow(pow(c, x, n), -1, n)) for x in bits1]
# backward half: (c^(-2^i), c^(2^i)) mod n, i.e. the inverses
mul2 = [(pow(pow(c, x, n), -1, n), pow(c, x, n)) for x in bits2]

# meet in the middle
su = 0
lookup_one = {}

def tree1(current, depth, diff):
    # walk forward from p1 and store p1 * c^diff -> diff
    if depth == n1:
        lookup_one[current] = diff
    else:
        tree1(current * mul1[depth][0] % n, depth + 1, diff + bits1[depth])
        tree1(current * mul1[depth][1] % n, depth + 1, diff - bits1[depth])

def tree2(current, depth, diff):
    # walk backward from p2 (current == p2 * c^-diff) until a stored value is hit
    global su
    if su:
        return
    if depth == n2:
        if current in lookup_one:
            print("sice", lookup_one[current], diff, current)
            su = lookup_one[current] + diff
    else:
        tree2(current * mul2[depth][0] % n, depth + 1, diff + bits2[depth])
        tree2(current * mul2[depth][1] % n, depth + 1, diff - bits2[depth])

tree1(p1, 0, 0)
print("lookup created")
tree2(p2, 0, 0)
```


This should be fast enough to give us multiple equations of the form:

$(d \oplus v_1) + s = (d \oplus v_2)$

where $d$ is the only unknown. We can plug these into z3 (because I'm lazy) and keep adding equations until the solution it returns is prime, since the challenge code generates $d$ with `getPrime`.
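The writeup solves these equations with z3. As a z3-free alternative (a different technique from the one used here), the same equations can be solved by fixing $d$ one bit at a time, starting from the least significant end. The idea rests on two facts:

• the low $k+1$ bits of $d$ determine $(d \oplus v_2) - (d \oplus v_1) \bmod 2^{k+2}$;
• any equation whose $v_1, v_2$ differ at bit $k$ pins down $d_k$.

Below is a sketch that assumes each equation is stored as a tuple `(v1, v2, s)`. The helper `solve_d` is a hypothetical name.

```python
import random

def solve_d(equations, nbits=128):
    """Return every nbits-bit d with (d ^ v2) - (d ^ v1) == s for all (v1, v2, s)."""
    cands = [0]  # candidates for the low k bits of d
    for k in range(nbits):
        # the low k+1 bits of d fix each difference modulo 2^(k+2)
        mod = 1 << (k + 2)
        nxt = []
        for low in cands:
            for bit in (0, 1):
                t = low | (bit << k)
                if all(((t ^ v2) - (t ^ v1) - s) % mod == 0
                       for v1, v2, s in equations):
                    nxt.append(t)
        cands = nxt
    # final exact check over the integers
    return [t for t in cands
            if all((t ^ v2) - (t ^ v1) == s for v1, v2, s in equations)]

# usage: a random odd 128-bit d with its top bit set, like getPrime(128) output
d = random.getrandbits(128) | (1 << 127) | 1
eqs = []
for _ in range(20):
    v1, v2 = random.getrandbits(128), random.getrandbits(128)
    eqs.append((v1, v2, (d ^ v2) - (d ^ v1)))
print(d in solve_d(eqs))  # True
```

Some bits of $d$ may have $v_1$ and $v_2$ equal in every collected pair. Those bits are unconstrained and show up as several candidates, which is why the primality check on $d$ (or simply more equations) is still needed. The challenge's pairs are chosen to have small $l$, so each equation constrains fewer bits than the random pairs above.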

Now that we have $d$, we need to recover the flag. However, we still only have the oracle which can faultily decrypt the flag. How can we get the flag without the actual ciphertext?

## flag recovery

Firstly, since we know $d$ and each $v_i$, we can compute every faulty exponent $d_i = d \oplus v_i$. Asking the oracle to decrypt the flag then gives us $c^{d_i} \bmod n$. Collecting several of these gives us the flag ciphertext raised to different exponents under the same modulus. This might sound familiar: it is exactly the setting of the common modulus attack. I'll briefly describe it here.

Let the two exponents be $d_1, d_2$, and suppose $\gcd(d_1, d_2) = 1$. If this is not the case, we can just collect more $d_i$ until we find a coprime pair.

By Bézout's identity, there exist two integers $a, b$ such that $ad_1 + bd_2 = 1$, and we can find them with the extended Euclidean algorithm. Combining this with the rules of indices, we can manipulate the ciphertexts $c_1, c_2$ so that the resulting exponent is exactly 1. That gives back the original flag ciphertext $c$.

Since $c^{x} \cdot c^{y} = c^{x + y}$, we want the values $c^{ad_1}$ and $c^{bd_2}$. Because $ad_1 + bd_2 = 1$, multiplying them leaves us with $c^1 = c$. Getting these values is simple using $(c^x)^y = c^{xy}$:

• raise $c_1 = c^{d_1}$ to the power $a$ to get $c^{ad_1}$;
• raise $c_2 = c^{d_2}$ to the power $b$ to get $c^{bd_2}$.

One of $a, b$ is negative, so that power is taken using the modular inverse of $c_1$ or $c_2$. Python's three-argument `pow` does this automatically for a negative exponent (Python 3.8+).
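Here is a minimal sketch of the common modulus step on toy numbers. The modulus $1009 \cdot 1013$, the value `c`, and the exponents are illustrative and not from the challenge. `common_modulus` is a hypothetical helper name; `egcd` follows the recursive extended Euclid used in the solve script.

```python
def egcd(a, b):
    # extended Euclid: returns (g, x, y) with a*x + b*y == g == gcd(a, b)
    if a == 0:
        return b, 0, 1
    g, x1, y1 = egcd(b % a, a)
    return g, y1 - (b // a) * x1, x1

def common_modulus(c1, c2, d1, d2, n):
    """Given c1 = c^d1 mod n and c2 = c^d2 mod n with gcd(d1, d2) == 1, return c mod n."""
    g, a, b = egcd(d1, d2)
    assert g == 1, "exponents must be coprime"
    # a*d1 + b*d2 == 1, so c1^a * c2^b == c^1; the negative exponent makes
    # pow() use a modular inverse (Python 3.8+, needs gcd(c, n) == 1)
    return pow(c1, a, n) * pow(c2, b, n) % n

# toy check with a small modulus and two coprime exponents
n = 1009 * 1013
c = 123456
d1, d2 = 65537, 40009
c1, c2 = pow(c, d1, n), pow(c, d2, n)
print(common_modulus(c1, c2, d1, d2, n) == c)  # True
```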

After recovering the actual flag ciphertext, we have the private key, so we can just decrypt it.

# solve script

Putting this all together, we get a script that runs in about 5 minutes. The slowest part is the meet in the middle itself, since we have to do several brute forces of around $2^{21}$ steps each.

```python
# type: ignore
from pwn import *
from tqdm import tqdm
import itertools
import ast
from z3 import *
from Crypto.Util.number import *
from math import gcd

s = remote("139.162.61.222", 13373)

def query(x, rep=1):
    # send the same query `rep` times and parse each "(fault, result)" reply
    s.sendline((f"{x}\n" * rep)[:-1])
    dat = s.recvlines(rep)
    dat = [ast.literal_eval(line.decode()) for line in dat]
    if rep == 1:
        return dat[0]
    return dat

# recover n: (-1)^(d ^ v) == n - 1 whenever the faulty exponent is odd
dec = 1
while dec == 1:
    fault, dec = query(-1)

n = dec + 1
flagdat = query("c", 10)   # faulty decryptions of the flag ciphertext
flags = dict(flagdat)

tupledat = query(2, 2000)  # (v_i, 2^(d ^ v_i) mod n) pairs
tuples = dict(tupledat)

s.close()

print("server data recovered")
rands = tuples.keys()
goodtuples = []

def egcd(a, b):
    # returns (g, x, y) with a*x + b*y == g == gcd(a, b)
    if a == 0:
        return b, 0, 1
    g, x1, y1 = egcd(b % a, a)
    x = y1 - (b // a) * x1
    y = x1
    return g, x, y

# keep pairs whose fault vectors differ in few bits
for x, y in tqdm(itertools.combinations(rands, r=2)):
    l = bin(x ^ y).count("1")
    if l < 42:
        goodtuples.append((l, x, y))
        print("found good tuple: ", x, y, l)

goodtuples.sort()

# meet in the middle
def tree1(current, depth, diff):
    if depth == n1:
        lookup_one[current] = diff
    else:
        tree1(current * mul1[depth][0] % n, depth + 1, diff + bits1[depth])
        tree1(current * mul1[depth][1] % n, depth + 1, diff - bits1[depth])

def tree2(current, depth, diff):
    global su
    if su:
        return
    if depth == n2:
        if current in lookup_one:
            print("sice", lookup_one[current], diff, current)
            su = lookup_one[current] + diff
    else:
        tree2(current * mul2[depth][0] % n, depth + 1, diff + bits2[depth])
        tree2(current * mul2[depth][1] % n, depth + 1, diff - bits2[depth])

sol = Solver()
d = BitVec('d', 129)

def recoverflag(_d):
    # map each faulty exponent d ^ v to the flag ciphertext raised to it
    es = {}
    for i, x in flags.items():
        es[(_d ^ i)] = x
    _e1, _e2 = 0, 0
    for i, e1 in enumerate(es):
        for j, e2 in enumerate(es):
            if i != j:
                if gcd(e1, e2) == 1:
                    _e1, _e2 = e1, e2
    if not (_e1 and _e2):
        return
    # common modulus attack: a*e1 + b*e2 == 1
    x, a, b = egcd(_e1, _e2)
    c1, c2 = es[_e1], es[_e2]
    ct1 = pow(c1, a, n)
    ct2 = pow(c2, b, n)
    ct = ct1 * ct2 % n
    flag = long_to_bytes(pow(ct, _d, n))
    if b"TetCTF" in flag:
        print("Flag: ", flag.decode())
        exit(1)

for l, v1, v2 in goodtuples:
    print("using tuple: ", l, v1, v2)

    p1 = tuples[v1]
    p2 = tuples[v2]
    v = v1 ^ v2
    bits = []
    for i in range(v.bit_length() + 1):
        if (v >> i) & 1:
            bits.append(2**i)

    bits1 = bits[:len(bits)//2]
    bits2 = bits[len(bits)//2:]
    n1 = len(bits1)
    n2 = len(bits2)
    mul1 = [(pow(2, x, n), pow(pow(2, x, n), -1, n)) for x in bits1]
    mul2 = [(pow(pow(2, x, n), -1, n), pow(2, x, n)) for x in bits2]

    su = 0
    lookup_one = {}
    tree1(p1, 0, 0)
    print("lookup created")
    tree2(p2, 0, 0)

    if not su:  # no match for this pair
        continue
    assert (p1 * pow(2, su, n) % n == p2)
    sol.add(((d ^ v2) - (d ^ v1)) == su)
    assert sol.check() == sat
    _d = sol.model()[d].as_long()
    print("recovered: ", _d)
    if isPrime(_d):
        print("attempting to decrypt flag with: ", _d)
        recoverflag(_d)
    elif isPrime(_d - 2**128):  # bit 128 of the 129-bit d is unconstrained, so z3 may set it
        print("attempting to decrypt flag with: ", _d - 2**128)
        recoverflag(_d - 2**128)
```


#### Flag: TetCTF{4n_unr34l1st1c_f4ult____1_th1nk}

Sidenote: I rewrote this around the 10th of January because it didn't read very nicely.