mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Implement AES-GCM
This commit is contained in:
parent
84f304b7ca
commit
c3ec8fc5d8
228
src/pyj/aes.pyj
228
src/pyj/aes.pyj
@ -56,7 +56,30 @@ def increment_counter(c):
|
|||||||
c[i] += 1
|
c[i] += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
# Lookup tables {{{
|
def convert_to_int32(bytes, output, offset, length):
|
||||||
|
offset = offset or 0
|
||||||
|
length = length or bytes.length
|
||||||
|
for v'var i = offset, j = 0; i < offset + length; i += 4, j++':
|
||||||
|
output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3]
|
||||||
|
|
||||||
|
def convert_to_int32_pad(bytes):
|
||||||
|
extra = bytes.length % 4
|
||||||
|
if extra:
|
||||||
|
t = Uint8Array(bytes.length + 4 - extra)
|
||||||
|
t.set(bytes)
|
||||||
|
bytes = t
|
||||||
|
ans = Uint32Array(bytes.length / 4)
|
||||||
|
convert_to_int32(bytes, ans)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def from_64_to_32(num):
|
||||||
|
# convert 64-bit number to two BE Int32s
|
||||||
|
ans = Uint32Array(2)
|
||||||
|
ans[0] = (num / 0x100000000) | 0
|
||||||
|
ans[1] = num & 0xFFFFFFFF
|
||||||
|
return ans
|
||||||
|
|
||||||
|
# Lookup tables for AES {{{
|
||||||
# Number of rounds by keysize
|
# Number of rounds by keysize
|
||||||
number_of_rounds = {16: 10, 24: 12, 32: 14}
|
number_of_rounds = {16: 10, 24: 12, 32: 14}
|
||||||
# Round constant words
|
# Round constant words
|
||||||
@ -86,13 +109,7 @@ U4 = v'new Uint32Array([0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c
|
|||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
def convert_to_int32(bytes, output, offset, length):
|
class AES: # {{{
|
||||||
offset = offset or 0
|
|
||||||
length = length or bytes.length
|
|
||||||
for v'var i = offset, j = 0; i < offset + length; i += 4, j++':
|
|
||||||
output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3]
|
|
||||||
|
|
||||||
class AES:
|
|
||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
self.working_mem = [Uint32Array(4), Uint32Array(4)]
|
self.working_mem = [Uint32Array(4), Uint32Array(4)]
|
||||||
@ -242,7 +259,7 @@ random_bytes = random_bytes_secure if type(crypto) is not 'undefined' and type(c
|
|||||||
if random_bytes is random_bytes_insecure:
|
if random_bytes is random_bytes_insecure:
|
||||||
print('WARNING: Using insecure RNG for AES')
|
print('WARNING: Using insecure RNG for AES')
|
||||||
|
|
||||||
class ModeOfOperation:
|
class ModeOfOperation: # {{{
|
||||||
|
|
||||||
def __init__(self, key):
|
def __init__(self, key):
|
||||||
self.key = key or generate_key(32)
|
self.key = key or generate_key(32)
|
||||||
@ -260,6 +277,100 @@ class ModeOfOperation:
|
|||||||
if type(tag) is 'string':
|
if type(tag) is 'string':
|
||||||
return string_to_bytes(tag)
|
return string_to_bytes(tag)
|
||||||
raise TypeError('Invalid tag, must be a string or a Uint8Array')
|
raise TypeError('Invalid tag, must be a string or a Uint8Array')
|
||||||
|
# }}}
|
||||||
|
|
||||||
|
class GaloisField: # {{{
|
||||||
|
|
||||||
|
def __init__(self, sub_key):
|
||||||
|
k32 = Uint32Array(4)
|
||||||
|
convert_to_int32(sub_key, k32, 0)
|
||||||
|
self.m = self.generate_hash_table(k32)
|
||||||
|
self.wmem = Uint32Array(4)
|
||||||
|
|
||||||
|
def power(self, x, out):
|
||||||
|
lsb = x[3] & 1
|
||||||
|
for v'var i = 3; i > 0; --i':
|
||||||
|
out[i] = (x[i] >>> 1) | ((x[i - 1] & 1) << 31)
|
||||||
|
out[0] = x[0] >>> 1
|
||||||
|
if lsb:
|
||||||
|
out[0] ^= 0xE1000000
|
||||||
|
|
||||||
|
def multiply(self, x, y):
|
||||||
|
z_i = Uint32Array(4)
|
||||||
|
v_i = y.slice(0)
|
||||||
|
for v'var i = 0; i < 128; ++i':
|
||||||
|
x_i = x[(i / 32) | 0] & (1 << (31 - i % 32))
|
||||||
|
if x_i:
|
||||||
|
z_i[0] ^= v_i[0]
|
||||||
|
z_i[1] ^= v_i[1]
|
||||||
|
z_i[2] ^= v_i[2]
|
||||||
|
z_i[3] ^= v_i[3]
|
||||||
|
self.power(v_i, v_i)
|
||||||
|
return z_i
|
||||||
|
|
||||||
|
def generate_sub_hash_table(self, mid):
|
||||||
|
bits = mid.length
|
||||||
|
size = 1 << bits
|
||||||
|
half = size >>> 1
|
||||||
|
m = Array(size)
|
||||||
|
m[half] = mid.slice(0)
|
||||||
|
i = half >>> 1
|
||||||
|
while i > 0:
|
||||||
|
m[i] = Uint32Array(4)
|
||||||
|
self.power(m[2 * i], m[i])
|
||||||
|
i >>= 1
|
||||||
|
i = 2
|
||||||
|
while i < half:
|
||||||
|
for v'var j = 1; j < i; ++j':
|
||||||
|
m_i = m[i]
|
||||||
|
m_j = m[j]
|
||||||
|
m[i + j] = x = Uint32Array(4)
|
||||||
|
for v'var c = 0; c < 4; c++':
|
||||||
|
x[c] = m_i[c] ^ m_j[c]
|
||||||
|
i *= 2
|
||||||
|
m[0] = Uint32Array(4)
|
||||||
|
for v'i = half + 1; i < size; ++i':
|
||||||
|
x = m[i ^ half]
|
||||||
|
m[i] = y = Uint32Array(4)
|
||||||
|
for v'var c = 0; c < 4; c++':
|
||||||
|
y[c] = mid[c] ^ x[c]
|
||||||
|
return m
|
||||||
|
|
||||||
|
def generate_hash_table(self, key_as_int32_array):
|
||||||
|
bits = key_as_int32_array.length
|
||||||
|
multiplier = 8 / bits
|
||||||
|
per_int = 4 * multiplier
|
||||||
|
size = 16 * multiplier
|
||||||
|
ans = Array(size)
|
||||||
|
for v'var i =0; i < size; ++i':
|
||||||
|
tmp = Uint32Array(4)
|
||||||
|
idx = (i/ per_int) | 0
|
||||||
|
shft = ((per_int - 1 - (i % per_int)) * bits)
|
||||||
|
tmp[idx] = (1 << (bits - 1)) << shft
|
||||||
|
ans[i] = self.generate_sub_hash_table(self.multiply(tmp, key_as_int32_array))
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def table_multiply(self, x):
|
||||||
|
z = Uint32Array(4)
|
||||||
|
for v'var i = 0; i < 32; ++i':
|
||||||
|
idx = (i / 8) | 0
|
||||||
|
x_i = (x[idx] >>> ((7 - (i % 8)) * 4)) & 0xF
|
||||||
|
ah = self.m[i][x_i]
|
||||||
|
z[0] ^= ah[0]
|
||||||
|
z[1] ^= ah[1]
|
||||||
|
z[2] ^= ah[2]
|
||||||
|
z[3] ^= ah[3]
|
||||||
|
return z
|
||||||
|
|
||||||
|
def ghash(self, x, y):
|
||||||
|
z = self.wmem
|
||||||
|
z[0] = y[0] ^ x[0]
|
||||||
|
z[1] = y[1] ^ x[1]
|
||||||
|
z[2] = y[2] ^ x[2]
|
||||||
|
z[3] = y[3] ^ x[3]
|
||||||
|
return self.table_multiply(z)
|
||||||
|
|
||||||
|
# }}}
|
||||||
|
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
@ -388,6 +499,96 @@ class CTR(ModeOfOperation): # {{{
|
|||||||
return bytes_to_string(b, offset)
|
return bytes_to_string(b, offset)
|
||||||
# }}}
|
# }}}
|
||||||
|
|
||||||
|
class GCM(ModeOfOperation):
|
||||||
|
|
||||||
|
# See http://web.cs.ucdavis.edu/~rogaway/ocb/gcm.pdf
|
||||||
|
|
||||||
|
def __init__(self, key):
|
||||||
|
ModeOfOperation.__init__(self, key)
|
||||||
|
|
||||||
|
# Generate the hash subkey
|
||||||
|
H = Uint8Array(16)
|
||||||
|
self.aes.encrypt(Uint8Array(16), H, 0)
|
||||||
|
self.galois = GaloisField(H)
|
||||||
|
|
||||||
|
# Working memory
|
||||||
|
self.J0 = Uint32Array(4)
|
||||||
|
self.wmem = Uint32Array(4)
|
||||||
|
self.out_block = Uint8Array(16)
|
||||||
|
|
||||||
|
def _create_j0(self, iv):
|
||||||
|
J0 = self.J0
|
||||||
|
if iv.length is 12:
|
||||||
|
convert_to_int32(iv, J0)
|
||||||
|
J0[3] = 1
|
||||||
|
else:
|
||||||
|
J0.fill(0)
|
||||||
|
tmp = convert_to_int32_pad(iv)
|
||||||
|
while tmp.length:
|
||||||
|
J0 = self.galois.ghash(J0, tmp)
|
||||||
|
tmp = tmp.subarray(4)
|
||||||
|
tmp = Uint32Array(4)
|
||||||
|
tmp.set(from_64_to_32(iv.length * 8), 2)
|
||||||
|
J0 = self.galois.ghash(J0, tmp)
|
||||||
|
return J0
|
||||||
|
|
||||||
|
def _crypt(self, iv, bytes, additional_data, decrypt):
|
||||||
|
J0 = self._create_j0(iv)
|
||||||
|
# Generate initial counter block
|
||||||
|
in_block = J0.slice(0)
|
||||||
|
in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter
|
||||||
|
outbytes = Uint8Array(bytes.length)
|
||||||
|
ghash = self.galois.ghash.bind(self.galois)
|
||||||
|
|
||||||
|
# Process additional_data
|
||||||
|
overflow = additional_data.length % 16
|
||||||
|
if overflow:
|
||||||
|
t = Uint8Array(additional_data.length + 16 - overflow)
|
||||||
|
t.set(additional_data)
|
||||||
|
additional_data = t
|
||||||
|
S = Uint32Array(4)
|
||||||
|
while additional_data.length:
|
||||||
|
additional_data = additional_data.subarray(4)
|
||||||
|
convert_to_int32(additional_data, self.wmem, 0, 16)
|
||||||
|
S = ghash(S, self.wmem)
|
||||||
|
|
||||||
|
# Create the ciphertext, encrypting block by block
|
||||||
|
for v'var pos = 0; pos < bytes.length; pos += 16':
|
||||||
|
self.aes.encrypt32(in_block, self.out_block, 0)
|
||||||
|
num = min(16, bytes.length - pos) # noqa: unused-local
|
||||||
|
for v'var i = 0; i < num; i++':
|
||||||
|
outbytes[pos + i] = bytes[pos+i] ^ self.out_block[i]
|
||||||
|
convert_to_int32(self.out_block, self.wmem)
|
||||||
|
S = ghash(S, self.wmem)
|
||||||
|
in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter
|
||||||
|
|
||||||
|
# Mix the lengths into S
|
||||||
|
lengths = Uint32Array(4)
|
||||||
|
lengths.set(from_64_to_32(additional_data.length * 8))
|
||||||
|
lengths.set(from_64_to_32(bytes.length * 8))
|
||||||
|
S = ghash(S, lengths)
|
||||||
|
|
||||||
|
# Create the tag
|
||||||
|
self.aes.encrypt32(J0, self.out_block, 0)
|
||||||
|
convert_to_int32(self.out_block, self.wmem)
|
||||||
|
tag = Uint32Array(4)
|
||||||
|
for v'var i = 0; i < S.length; i++':
|
||||||
|
tag[i] = S[i] ^ self.wmem[i]
|
||||||
|
return {'iv':iv, 'cipherbytes':outbytes, 'tag':tag}
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt(self, plaintext, tag):
|
||||||
|
iv = random_bytes(12)
|
||||||
|
return self._crypt(iv, string_to_bytes(plaintext), self.tag_as_bytes(tag))
|
||||||
|
|
||||||
|
def decrypt(self, output_from_encrypt, tag):
|
||||||
|
if output_from_encrypt.tag.length != 4:
|
||||||
|
raise ValueError('Corrupted message')
|
||||||
|
ans = self._crypt(output_from_encrypt.iv, output_from_encrypt.cipherbytes, self.tag_as_bytes(tag), True)
|
||||||
|
if ans.tag != output_from_encrypt.tag:
|
||||||
|
raise ValueError('Corrupted message')
|
||||||
|
return bytes_to_string(ans.cipherbytes)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
text = 'testing a basic roundtrip ø̄ū'
|
text = 'testing a basic roundtrip ø̄ū'
|
||||||
|
|
||||||
@ -404,7 +605,14 @@ if __name__ == '__main__':
|
|||||||
crypted = ctr.encrypt(text)
|
crypted = ctr.encrypt(text)
|
||||||
decrypted = ctr.decrypt(crypted)
|
decrypted = ctr.decrypt(crypted)
|
||||||
print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED')
|
print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED')
|
||||||
|
|
||||||
crypted = ctr.encrypt(text, secret_tag)
|
crypted = ctr.encrypt(text, secret_tag)
|
||||||
decrypted = ctr.decrypt(crypted, secret_tag)
|
decrypted = ctr.decrypt(crypted, secret_tag)
|
||||||
print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')
|
print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')
|
||||||
|
|
||||||
|
gcm = GCM()
|
||||||
|
crypted = gcm.encrypt(text)
|
||||||
|
decrypted = gcm.decrypt(crypted)
|
||||||
|
print('GCM Roundtrip:', 'OK' if text is decrypted else 'FAILED')
|
||||||
|
crypted = gcm.encrypt(text, secret_tag)
|
||||||
|
decrypted = gcm.decrypt(crypted, secret_tag)
|
||||||
|
print('GCM Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user