0%

2023-强网杯-wp-crypto

期末进入中场休息时间,来补一补这段时间这些比赛的wp。*代表赛后复现的题目。

not only rsa

题目描述:

1
这个模数好像很不安全,那你能解密出flag吗

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from Crypto.Util.number import bytes_to_long
from secret import flag
import os

n = 6249734963373034215610144758924910630356277447014258270888329547267471837899275103421406467763122499270790512099702898939814547982931674247240623063334781529511973585977522269522704997379194673181703247780179146749499072297334876619475914747479522310651303344623434565831770309615574478274456549054332451773452773119453059618433160299319070430295124113199473337940505806777950838270849
e = 641747
m = bytes_to_long(flag)

flag = flag + os.urandom(n.bit_length() // 8 - len(flag) - 1)
m = bytes_to_long(flag)

c = pow(m, e, n)

with open('out.txt', 'w') as f:
print(f"{n = }", file=f)
print(f"{e = }", file=f)
print(f"{c = }", file=f)

按题目提示把n放到factordb上看看,可以发现:

然后e果然也是p-1的因子,所以这个题估计是想要在模p下开e次方后用hensel找到模p^5下的根,但是直接用集成的nth_root函数就可以。

exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from Crypto.Util.number import *

p = 91027438112295439314606669837102361953591324472804851543344131406676387779969
n = 6249734963373034215610144758924910630356277447014258270888329547267471837899275103421406467763122499270790512099702898939814547982931674247240623063334781529511973585977522269522704997379194673181703247780179146749499072297334876619475914747479522310651303344623434565831770309615574478274456549054332451773452773119453059618433160299319070430295124113199473337940505806777950838270849
e = 641747
c = 730024611795626517480532940587152891926416120514706825368440230330259913837764632826884065065554839415540061752397144140563698277864414584568812699048873820551131185796851863064509294123861487954267708318027370912496252338232193619491860340395824180108335802813022066531232025997349683725357024257420090981323217296019482516072036780365510855555146547481407283231721904830868033930943

res = Zmod(p^5)(c).nth_root(e, all=True)
for i in res:
temp = long_to_bytes(int(i))
if(b"flag" in temp):
print(temp)

#flag{c19c3ec0-d489-4bbb-83fc-bc0419a6822a}



guess game

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from Crypto.Util.number import bytes_to_long, long_to_bytes
from os import urandom
from binascii import unhexlify

class cipher:
def __init__(self, key, rounds=4):
self.key = key
self.rounds = rounds
self.sbox = [0xc, 0x5, 0x6, 0xb, 0x9, 0x0, 0xa, 0xd, 0x3, 0xe, 0xf, 0x8, 0x4, 0x7, 0x1, 0x2]
self.pbox = [0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63]
self.rk = self.genrk(self.key)

def substitution(self, state):
output = 0
for i in range(16):
output += self.sbox[state >> (i*4) & 0xF] << (i*4)
return output

def permutation(self, state):
output = 0
for i in range(64):
output += ((state >> i) & 0x1) << self.pbox[i]
return output

def genrk(self, key):
rk = []
for i in range(1, self.rounds+1):
rk.append(key >> 16)
key = ((key & (2**19-1)) << 61) + (key >> 19)
key = (self.sbox[key >> 76] << 76)+(key & (2**76-1))
key ^= i << 15
return rk

def addrk(self, state, rk):
return state ^ rk

def encrypt(self, pt):
ct = b""
state = pt
for i in range(self.rounds-1):
state = self.addrk(state, self.rk[i])
state = self.substitution(state)
state = self.permutation(state)
state = self.addrk(state, self.rk[-1])
ct += long_to_bytes(state)
return ct

def hint(self, pt):
return self.encrypt(pt)

with open("flag.txt", "r") as f:
flag = f.read()

op = '''1.get hint\n2.start game\n'''
count = 0
success = 0
key = int.from_bytes(urandom(10), "big")
guess = list(map(int, list(bin(key)[2:].zfill(80))))
game = cipher(key)
while True:
print(op)
user_input = int(input(">").strip())
if user_input == 1:
if count < 80:
pt = int(input("pt in hex:"), 16)
hint = game.hint(pt)
count += 1
print(hint.hex())
else:
print("Sorry~")
elif user_input == 2:
for i in range(len(guess)):
user_guess = int(input(f"Round {i + 1} > ").strip())
if user_guess == guess[i]:
print("Right!")
success += 1
else:
print("Wrong!")
if success > 0.7 * len(guess):
print(flag)
else:
print("Failed!")
exit(-1)
else:
exit(-1)

题目实现了一个三轮的分组加密类,然后生成一个10字节的key,并用它初始化一个加密类对象。

在这之后,题目给了我们两种交互:

  • 输入”1”,我们可以选择至多80组明文,并获得他的加密结果
  • 输入”2”,我们可以对key的80个比特进行猜测,猜测成功的次数>56次的话就能得到flag

看完整个题目的流程,其实不管用什么方式,80个bit猜对56次以上就能拿到flag了。而每一次都随机猜,猜对的概率是0.5,那么对于完全乱猜来说这其实就是一个二项分布,可以求得80次猜对56次以上的概率大约在0.0001,也就是万分之一左右。那么根据生日攻击来说,我们硬交互一万次就有大概率可以拿到flag。

而这个题又恰恰好不设置pow,所以交互上一万次不是难事,时间也不久,运气好点应该甚至不需要一个小时。

exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from Crypto.Util.number import *
from pwn import *

#context.log_level = 'debug'

while(1):
r = remote("47.97.69.130",22333)
r.sendline(token)
r.sendline(b"2")

count = 0
for i in range(80):
r.recvuntil(b">")
r.sendline(b"0")
temp = r.recvline()
if(b"Right" in temp):
count += 1
print(count)
if(count > 56):
print(r.recvline())
exit()
r.close()

#flag{be050d3fe312654d40d4ebb60d667c22}

而预期解是基于一篇论文的积分攻击,关于这一点可以看nepnep战队师傅的wp。不过这种针对分组加密本身漏洞的攻击方法我一直以来都很懵,未来也许有机会会尝试一下吧。



*discrete_log

题目描述:

1
离散对数分离

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from Crypto.Util.number import *
from Crypto.Util.Padding import pad
flag = 'flag{d3b07b0d416ebb}'

assert len(flag) <= 45
assert flag.startswith('flag{')
assert flag.endswith('}')

m = bytes_to_long(pad(flag.encode(), 128))

p = 0xf6e82946a9e7657cebcd14018a314a33c48b80552169f3069923d49c301f8dbfc6a1ca82902fc99a9e8aff92cef927e8695baeba694ad79b309af3b6a190514cb6bfa98bbda651f9dc8f80d8490a47e8b7b22ba32dd5f24fd7ee058b4f6659726b9ac50c8a7f97c3c4a48f830bc2767a15c16fe28a9b9f4ca3949ab6eb2e53c3
g = 5

assert m < (p - 1)

c = pow(g, m, p)

with open('out.txt', 'w') as f:
print(f"{p = }", file=f)
print(f"{g = }", file=f)
print(f"{c = }", file=f)

这个题目没有什么东西要说,赛中因为这句:

1
assert len(flag) <= 45

就一直以为flag虽然短,但怎么说也应该有三四十个字符,所以完全没往mitm上想。并且赛中的前一段时间解出这题的队伍一度很少,所以也很自然地就往cado、论文这种错误的方向上靠了。

exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from Crypto.Util.number import *
from gmpy2 import powmod,invert
from itertools import *
from tqdm import *

def mitm(length,p,g,c):
padding_len = 128 - len(b"flag{}") - length
padding = padding_len*long_to_bytes(padding_len)
suffix = bytes_to_long(b"}" + padding)
prefix = bytes_to_long(b"flag{") * 256**(128-5)
c1 = c * invert(powmod(g,suffix,p),p) % p
c2 = c1 * invert(powmod(g,prefix,p),p) % p

#mitm
pre_len = length // 2
suf_len = length - pre_len
base_h_num = powmod(g,256**(padding_len+1+suf_len),p)
base_l_num = powmod(g,256**(padding_len+1),p)
table = "0123456789abcdef"

dic = {}
for i in product(table, repeat=pre_len):
i = "".join(i)
temp = bytes_to_long(i.encode())
tempc = c2 * invert(powmod(base_h_num,temp,p),p) % p
dic[tempc] = i

for j in product(table, repeat=suf_len):
j = "".join(j)
temp = bytes_to_long(j.encode())
tempc = powmod(base_l_num,temp,p)
if(tempc in dic.keys()):
print(dic[tempc])
print(j)
exit()

p = 173383907346370188246634353442514171630882212643019826706575120637048836061602034776136960080336351252616860522273644431927909101923807914940397420063587913080793842100264484222211278105783220210128152062330954876427406484701993115395306434064667136148361558851998019806319799444970703714594938822660931343299
g = 5
c = 105956730578629949992232286714779776923846577007389446302378719229216496867835280661431342821159505656015790792811649783966417989318584221840008436316642333656736724414761508478750342102083967959048112859470526771487533503436337125728018422740023680376681927932966058904269005466550073181194896860353202252854

for i in trange(1,46):
mitm(i,p,g,c)

#flag{61e8007dd65f}



*babyrsa

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from Crypto.Util.number import isPrime, inverse, bytes_to_long
from random import getrandbits, randrange
from collections import namedtuple


Complex = namedtuple("Complex", ["re", "im"])

def complex_mult(c1, c2, modulus):
return Complex(
(c1.re * c2.re - c1.im * c2.im) % modulus,
(c1.re * c2.im + c1.im * c2.re) % modulus,
)

def complex_pow(c, exp, modulus):
result = Complex(1, 0)
while exp > 0:
if exp & 1:
result = complex_mult(result, c, modulus)
c = complex_mult(c, c, modulus)
exp >>= 1
return result

def rand_prime(nbits, kbits, share):
while True:
p = (getrandbits(nbits) << kbits) + share
if p % 4 == 3 and isPrime(p):
return p

def gen():
while True:
k = getrandbits(100)
pp = getrandbits(400) << 100
qq = getrandbits(400) << 100
p = pp + k
q = qq + k
if isPrime(p) and isPrime(q):
break
if q > p:
p, q = q, p

n = p * q
lb = int(n ** 0.675)
ub = int(n ** 0.70)
d = randrange(lb, ub)
e = inverse(d, (p * p - 1) * (q * q - 1))
sk = (p, q, d)
pk = (n, e)
return pk, sk


pk, sk = gen()
n, e = pk
with open("flag.txt", "rb")as f:
flag = f.read()

m = Complex(bytes_to_long(flag[:len(flag)//2]), bytes_to_long(flag[len(flag)//2:]))
c = complex_pow(m, e, n)
print(n)
print(e)
print(c)

题目是基于复数域上的RSA,其他参数的生成都没有什么特殊,有以下几个部分值得注意:

  • p,q低100位相等

  • d在一个特殊的区间内(n^0.675,n^0.7)

  • e和d满足的关系是:

然后照常的,给出我们公钥n,e以及密文c,要求解出明文m得到flag。

可以看出,这个题目如果能求出私钥d,那么只要用复数域里的模幂运算操作去计算pow(c,d,n)就能得到明文m了。所以说这个题目中的复数其实就是套了层壳,我们真正需要做的事情,是想办法根据上面的两个特殊条件解出私钥d,然后就能顺理成章得到m。

不过这里可以补充一下,之所以使用复数运算,其实还可能有一个原因,就是ed满足的这个关系刚好符合复数域下的欧拉函数。也就是在复数域下:

这个道理通俗一点的解释是这样,首先n的欧拉函数依然满足:

然后对于其中任意一个(这里就取p做例子),因为复数域中的元素是a+bi的形式,那么a和b分别取尽0到p-1的值,就共有p^2种组合,自然也就有p^2-1个元素互素。

然而这并不总对任意p满足,这要求p是4k+3形式的素数,正好也对应题目给的条件。更细节的解释可以看:

rsa - Why is $\phi(N) = (p^2 -1) (q^2 - 1)$ here? - Cryptography Stack Exchange

以上是对一些背景知识的简单介绍,接下来进入到解题环节。我们的目标就是,根据如下式子:

以及p,q低100位相等、d比较小的特性来想办法解出d。

而d比较小这一点,很容易联想到Boneh&Duree做低解密指数攻击所做的构造多项式来进行copper的方法。首先把式子展开:

这里需要注意到一个事情,既然e是和(p^2-1)(q^2-1)差不多数量级的值,那么k应该就和d数量级接近,所以在之后可以看成是一个”小根”。

接下来把右边写作乘起来的形式:

转到模e下分析:

这个时候其实已经可以尝试造多项式了,把k和p^2+q^2分别看作x和y,就有:

然后尝试二元copper,发现这样要求解出来的话,d的上界远远达不到题目的0.7,因此需要做一定优化。发现p,q低位相等的条件还没有用上,而又因为:

于是可以把上面的式子改成:

也就是:

这样又哪里做了优化呢?这是因为p,q既然低100位相等,那么p-q的低100bit就全是0,所以p-q可以写为:

也就有:

然后依然把k看作x,但把u^2看作y,多项式就变成:

这样一来y的数量级降下去了200bit,根更小了。但是实际测试发现,即使这样构造多项式,d的上界也最多取到0.61才能解出根,因此还要进行优化。

而进一步的优化就肯定需要论文才能做到了,这一篇论文其实相比起来还是蛮好找的:

A new attack on some RSA variants - ScienceDirect

按照论文里面的内容逐步实现,就能造出一个根更小的多项式,然而做到这一步,用寻常板子的二元copper依然是出不了结果的(甚至论文样例也不行)。这是因为论文里的多项式构造也和一般使用的不同,需要对着改一遍才行。

对着改之后测一遍样例,能出就可以上靶机了。

exp:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from Pwn4Sage.pwn import *
import itertools
from gmpy2 import iroot
from Crypto.Util.number import *
from collections import namedtuple
from tqdm import *


Complex = namedtuple("Complex", ["re", "im"])

def complex_mult(c1, c2, modulus):
return Complex(
(c1.re * c2.re - c1.im * c2.im) % modulus,
(c1.re * c2.im + c1.im * c2.re) % modulus,
)

def complex_pow(c, exp, modulus):
result = Complex(1, 0)
while exp > 0:
if exp & 1:
result = complex_mult(result, c, modulus)
c = complex_mult(c, c, modulus)
exp >>= 1
return result

def small_roots(e, m, t, X, Y,a1,a2,a3):
PR = PolynomialRing(QQ, 'x,y', 2, order='lex')
x, y = PR.gens()
f = x*y^2 + a1*x*y + a2*x + a3

G_polys = []
for k in range(m + 1):
for i_1 in range(k, m+1):
for i_2 in [2*k, 2*k + 1]:
G_polys.append(x**(i_1-k) * y**(i_2-2*k) * f**k * e**(m-k))

H_polys = []
for k in range(m + 1):
for i_2 in range(2*k+2, 2*k+t+1):
H_polys.append(y**(i_2-2*k) * f**k * e**(m-k))

polys = G_polys + H_polys
monomials = []
for poly in polys:
monomials.append(poly.lm())

dims1 = len(polys)
dims2 = len(monomials)
MM = matrix(QQ, dims1, dims2)
for idx, poly in enumerate(polys):
for idx_, monomial in enumerate(monomials):
if monomial in poly.monomials():
MM[idx, idx_] = poly.monomial_coefficient(monomial) * monomial(X, Y)
B = MM.LLL()

found_polynomials = False

for pol1_idx in range(B.nrows()):
for pol2_idx in range(pol1_idx + 1, B.nrows()):
P = PolynomialRing(QQ, 'a,b', 2)
a, b = P.gens()
pol1 = pol2 = 0
for idx_, monomial in enumerate(monomials):
pol1 += monomial(a,b) * B[pol1_idx, idx_] / monomial(X, Y)
pol2 += monomial(a,b) * B[pol2_idx, idx_] / monomial(X, Y)

# resultant
rr = pol1.resultant(pol2)
# are these good polynomials?
if rr.is_zero() or rr.monomials() == [1]:
continue
else:
print(f"found them, using vectors {pol1_idx}, {pol2_idx}")
found_polynomials = True
break
if found_polynomials:
break

if not found_polynomials:
print("no independant vectors could be found. This should very rarely happen...")


PRq = PolynomialRing(QQ, 'z')
z = PRq.gen()
rr = rr(z, z)
soly = rr.roots()[0][0]

ppol = pol1(z, soly)
solx = ppol.roots()[0][0]
return solx, soly


#get data
r = remote("node4.anna.nssctf.cn",28334)
n = int(r.recvline().strip().decode())
e = int(r.recvline().strip().decode())
c = eval(r.recvline().strip().decode())


#theorem
alpha = 1.997
beta = 0.1
delta = 0.7
bounds = (int(2*n^(alpha+delta-2)),int(3*n^(0.5-2*beta)))
r= 100
u0 = var('u0')
mod = 2^r
solve_u0 = solve_mod([u0^2-n == 0], mod)

for u0 in tqdm(solve_u0):
u0 = int(u0[0])
v0 = (2*u0 + (n-u0^2)*inverse(u0,mod^2)) % (mod^2)
a1 = v0*inverse((mod^2)//2,e) % e
a2 = -((n+1)^2-v0^2) * inverse(mod^4,e) % e
a3 = -inverse(mod^4,e) % e

PR = PolynomialRing(Zmod(e), names='x,y')
x, y = PR.gens()
solves = small_roots(e,4,4,bounds[0],bounds[1],a1,a2,a3)
print(solves)

if(solves != []):
try:
v = int(solves[1])
pplusq = mod^2*v+v0
pminusq = iroot(pplusq^2-4*n,2)[0]
p = (pplusq + pminusq) // 2
q = n // p
phi = (p^2-1)*(q^2-1)
d = inverse(e,phi)
flag = complex_pow(c, d, n)
print(long_to_bytes(int(flag.re)))
print(long_to_bytes(int(flag.im)))
except:
pass

#NSSCTF{b2f4e9c8-fb45-452c-81ab-82f9f1150f48}



*recovery

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from Crypto.Util.number import *
from Crypto.Cipher import AES
from secrets import flag
from math import ceil
from binascii import unhexlify
from os import urandom

BLOCKSIZE = 16

def pad(s):
length = BLOCKSIZE - len(s) % BLOCKSIZE
return s + b"\x00" * length

def rsa_pad(m, size):
return b"\x00" * 2 + m + urandom(size - 2 - len(m))

def genkey(bits):
e = 65537
p = getPrime(bits)
q = getPrime(bits)
if p > q:
p, q = q, p
d = inverse(e, (p - 1) * (q - 1))
n = p * q
u = inverse(q, p)
dp = d % (p - 1)
dq = d % (q - 1)
sk = (n, e, d, p, q, dp, dq, u)
pk = (n, e)
return sk, pk

def rsa_encrypt(m, pk):
n, e = pk
m_padded = rsa_pad(m, ceil(n.bit_length() / 8))
c = pow(bytes_to_long(m_padded), e, n)
return long_to_bytes(c)

def rsa_decrypt(c, sk):
c = bytes_to_long(c)
n, e, d, p, q, dp, dq, u = sk
mp = pow(c, dp, p)
mq = pow(c, dq, q)
t = (mp - mq) % p
h = (u * t) % p
m = (h * q + mq) % n
m = long_to_bytes(m)
m = m.rjust(ceil(n.bit_length() / 8), b"\x00")
return m[2:]

def encrypt_sk(sk, kM):
n, e, d, p, q, dp, dq, u = sk
lp = long_to_bytes(ceil(p.bit_length() / 8) * 8)
lq = long_to_bytes(ceil(q.bit_length() / 8) * 8)
ld = long_to_bytes(ceil(d.bit_length() / 8) * 8)
lu = long_to_bytes(ceil(u.bit_length() / 8) * 8)
s = lq + long_to_bytes(q) + lp + long_to_bytes(p) + ld + long_to_bytes(d) + lu + long_to_bytes(u)
s = pad(s)
return aes_encrypt(s, kM)

def aes_encrypt(pt, key):
aes = AES.new(key, AES.MODE_ECB)
return aes.encrypt(pt)

def aes_decrypt(ct, key):
aes = AES.new(key, AES.MODE_ECB)
return aes.decrypt(ct)

def decrypt_sk(sk_enc, kM):
sk = aes_decrypt(sk_enc, kM)

lq = sk[:2]
sk = sk[2:]
q = bytes_to_long(sk[:bytes_to_long(lq) // 8])
sk = sk[bytes_to_long(lq) // 8:]

lp = sk[:2]
sk = sk[2:]
p = bytes_to_long(sk[:bytes_to_long(lp) // 8])
sk = sk[bytes_to_long(lp) // 8:]

ld = sk[:2]
sk = sk[2:]
d = bytes_to_long(sk[:bytes_to_long(ld) // 8])
sk = sk[bytes_to_long(ld) // 8:]

lu = sk[:2]
sk = sk[2:]
u = bytes_to_long(sk[:bytes_to_long(lu) // 8])
return p, q, d, u

def query(sk_enc, c):
p, q, d, u = decrypt_sk(sk_enc, master_key)
dp = d % (p - 1)
dq = d % (q - 1)
phi = (p - 1) * (q - 1)
n = p * q
e = inverse(d, phi)
sk = (n, e, d, p, q, dp, dq, u)
return rsa_decrypt(c, sk)[:43]

query_times = 0
sk, pk = genkey(1024)
n, e, d, p, q, dp, dq, u = sk
master_key = urandom(16)
wrapped_key = encrypt_sk(sk, master_key)

print(f"my pk: {pk}")
print(f"wrapped_key: {wrapped_key.hex()}")
while query_times < 16:
ct = unhexlify(input("ct:").strip().encode())
wkey = unhexlify(input("wkey:").strip().encode())
result = query(wkey, ct)
print(f"result: {result.hex()}")
query_times += 1

user_input = int(input("p = ").strip())
if user_input == p:
print(flag)

简单说一下题目流程:题目首先照常生成RSA的密钥,密钥包含有私钥sk和公钥pk。

私钥sk:

其中u为q关于p的逆元,主要用在一会儿的crt计算明文中。

公钥pk:

在这之后,题目生成随机16字节masterkey,用于之后的AES。

然后就是关键环节,题目把私钥的p,q,d,u这几个信息按如下方式拼接,并用AES按ECB方式加密,得到一个结果wrapped_key:

1
| |q| |    q    | |p| |    p    | |d| |        d        | |u| |    u    |padding|

在这之后,我们可以从靶机得到pk和wrapped_key两个量,然后就进入如下交互:

  • 总共进行16次
  • 每次输入密文ct及w_key
  • 靶机会将w_key用AES解密,解密后按照上述拼接方法拆出q、p、d、u四个量,并且用这四个值重新计算dp,dq,phi,n,e,然后进行如下方式的RSA解密(这个解密方式有什么问题之后会说):
1
2
3
4
5
6
7
8
9
10
11
def rsa_decrypt(c, sk):
c = bytes_to_long(c)
n, e, d, p, q, dp, dq, u = sk
mp = pow(c, dp, p)
mq = pow(c, dq, q)
t = (mp - mq) % p
h = (u * t) % p
m = (h * q + mq) % n
m = long_to_bytes(m)
m = m.rjust(ceil(n.bit_length() / 8), b"\x00")
return m[2:]
  • 然后,靶机会用上述量计算ct的明文,并且返回结果的前43字节

我们需要利用这16次交互得到p的值,输入给靶机得到flag。接下来首先讲讲我的思路,可能对解出本题并没有什么意义,有兴趣的师傅可以看看:

我的思路

首先要注意的是我们该如何选择传给靶机的w_key,因为这关系到靶机解出来的q、p、d、u四个量究竟是多少。而由于AES的密钥masterkey我们并不知道,所以我们任改其中某个比特的话,他所在分组的全部128个比特都会变得混乱。

但是这正好可以被我们所利用,这是因为刚才提到的本题通过crt计算明文的方式:

1
2
3
4
5
6
7
8
9
10
11
def rsa_decrypt(c, sk):
c = bytes_to_long(c)
n, e, d, p, q, dp, dq, u = sk
mp = pow(c, dp, p)
mq = pow(c, dq, q)
t = (mp - mq) % p
h = (u * t) % p
m = (h * q + mq) % n
m = long_to_bytes(m)
m = m.rjust(ceil(n.bit_length() / 8), b"\x00")
return m[2:]

可以实验得知,这种计算方式,在所有参数均正确的时候,计算结果是和我们常知的crt计算方式相同的,但如果有参数出现故障就会使结果出现差异,比如可以用如下脚本简单测试一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from Crypto.Util.number import *

def genkey(bits):
e = 65537
p = getPrime(bits)
q = getPrime(bits)
if p > q:
p, q = q, p
d = inverse(e, (p - 1) * (q - 1))
n = p * q
u = inverse(q, p)
dp = d % (p - 1)
dq = d % (q - 1)
sk = (n, e, d, p, q, dp, dq, u)
pk = (n, e)
return sk, pk

def rsa_decrypt(c, sk):
n, e, d, p, q, dp, dq, u = sk
mp = pow(c, dp, p)
mq = pow(c, dq, q)
t = (mp - mq) % p
h = (u * t) % p
m = (h * q + mq) % n
return m

sk, pk = genkey(1024)
n, e, d, p, q, dp, dq, u = sk

p = getPrime(1024)
dp = d %(p-1)
sk = (n, e, d, p, q, dp, dq, u)

m = 5
c = pow(m,e,n)
print(rsa_decrypt(c, sk))

from sympy.ntheory.modular import crt
nlist = [p,q]
clist = [pow(c, dp, p),pow(c, dq, q)]
print(crt(nlist,clist)[0])

而在仅有p出错的时候,本题中的计算方式就可以给我们提供一些信息,具体来说这个过程是这样:

  • 我们改变wrapped_key中的p中对应的某一分组,AES解密后该分组就会出错,于是我们得到了一个错误的p和正确的q、d、u
  • 然后,我们进一步会得到错误的dp、n与正确的dq

  • 然后进入rsa_decrypt函数,我们首先会得到正确的mq和错误的mp,进一步也就得到了错误的t和h

在这之后就到关键的一步了,接下来m是这样计算:

1
m = (h * q + mq) % n

其中h是小于p的,因此当mq较小的时候,整个式子的模可以脱掉,变成:

1
m = h * q + mq

而如何使mq较小也很简单,我们令ct为:

那么mq就是2(因为dq,q都是正确量)。所以说,如果题目返回的m不是前43字节,而是完整的话,我们就可以用两个不同的错误p,得到:

然后求GCD(m1-2,m2-2)就能得到q,自然也就有p。

而如果题目给的是[:-43]而不是[:43]的话,我们依然有办法,这时候我们能用16个不同的错误pi,靶机会计算出16组如下mi:

但是我们只能得到这些mi的高位Mi,然而我们把Mi左移43*8位的话,其实Mi可以写作:

ri就是被隐藏的低43字节,再代入mi的表达式就有:

可以发现,这样的式子里,Mi已知,2-ri较小,所以这其实就可以变成一个AGCD问题来求解了。

以上就是我对于这个题目的一些思考。而这个思路之所以没办法最终用于解题,就是因为题目给的是[:43],这导致ri甚至大于p的数量级,整个式子甚至不能写成带余除法,所以没法用AGCD做。

正解

在询问了其他师傅后才明白这其实也是个论文题,论文如下(题目对应第二种攻击方式):

959.pdf (iacr.org)

然后也有对应板子:

keeganryan/attacks-poc: PoC for our attacks on MEGA. (github.com)

题目返回的43这个字节数甚至都是完全一样的。



*1515

题目:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from para import m, n, q, v, flag
import json

Zq = IntegerModRing(q)

A0 = random_matrix(Zq, m-n, n)
A1515 = matrix.block(Zq, [[identity_matrix(ZZ, n)], [A0]])
u = random_vector(Zq, n)

print(A0)
print(u)

x = vector(ZZ, json.loads(input("Give me a list of numbers: ")))

if x * A1515 != u:
print("Failed!", x * A1515)
exit(0)
if x.norm().n() >= v:
print("Failed!", v)
exit(0)

print(flag)

这个题赛中我一直在乱搞,最后也没搞出个所以然。维数低的时候还能用LLL结合左核去找解空间里的短向量,维度大了的时候效果很差不说,用时还很长。

赛后知道这是个论文题并且也有现成攻击代码。然而我很多环境都还没配,就暂时不复现了,有兴趣的师傅可以参照下面这篇,有论文链接和代码链接:

2023强网杯密码部分题解 | Tover’ Blog