Implement AES-GCM

This commit is contained in:
Kovid Goyal 2016-03-22 17:25:31 +05:30
parent 84f304b7ca
commit c3ec8fc5d8

View File

@ -56,7 +56,30 @@ def increment_counter(c):
c[i] += 1 c[i] += 1
break break
# Lookup tables {{{ def convert_to_int32(bytes, output, offset, length):
offset = offset or 0
length = length or bytes.length
for v'var i = offset, j = 0; i < offset + length; i += 4, j++':
output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3]
def convert_to_int32_pad(bytes):
extra = bytes.length % 4
if extra:
t = Uint8Array(bytes.length + 4 - extra)
t.set(bytes)
bytes = t
ans = Uint32Array(bytes.length / 4)
convert_to_int32(bytes, ans)
return ans
def from_64_to_32(num):
# convert 64-bit number to two BE Int32s
ans = Uint32Array(2)
ans[0] = (num / 0x100000000) | 0
ans[1] = num & 0xFFFFFFFF
return ans
# Lookup tables for AES {{{
# Number of rounds by keysize # Number of rounds by keysize
number_of_rounds = {16: 10, 24: 12, 32: 14} number_of_rounds = {16: 10, 24: 12, 32: 14}
# Round constant words # Round constant words
@ -86,13 +109,7 @@ U4 = v'new Uint32Array([0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c
# }}} # }}}
def convert_to_int32(bytes, output, offset, length): class AES: # {{{
offset = offset or 0
length = length or bytes.length
for v'var i = offset, j = 0; i < offset + length; i += 4, j++':
output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3]
class AES:
def __init__(self, key): def __init__(self, key):
self.working_mem = [Uint32Array(4), Uint32Array(4)] self.working_mem = [Uint32Array(4), Uint32Array(4)]
@ -242,7 +259,7 @@ random_bytes = random_bytes_secure if type(crypto) is not 'undefined' and type(c
if random_bytes is random_bytes_insecure: if random_bytes is random_bytes_insecure:
print('WARNING: Using insecure RNG for AES') print('WARNING: Using insecure RNG for AES')
class ModeOfOperation: class ModeOfOperation: # {{{
def __init__(self, key): def __init__(self, key):
self.key = key or generate_key(32) self.key = key or generate_key(32)
@ -260,6 +277,100 @@ class ModeOfOperation:
if type(tag) is 'string': if type(tag) is 'string':
return string_to_bytes(tag) return string_to_bytes(tag)
raise TypeError('Invalid tag, must be a string or a Uint8Array') raise TypeError('Invalid tag, must be a string or a Uint8Array')
# }}}
class GaloisField: # {{{
def __init__(self, sub_key):
k32 = Uint32Array(4)
convert_to_int32(sub_key, k32, 0)
self.m = self.generate_hash_table(k32)
self.wmem = Uint32Array(4)
def power(self, x, out):
lsb = x[3] & 1
for v'var i = 3; i > 0; --i':
out[i] = (x[i] >>> 1) | ((x[i - 1] & 1) << 31)
out[0] = x[0] >>> 1
if lsb:
out[0] ^= 0xE1000000
def multiply(self, x, y):
z_i = Uint32Array(4)
v_i = y.slice(0)
for v'var i = 0; i < 128; ++i':
x_i = x[(i / 32) | 0] & (1 << (31 - i % 32))
if x_i:
z_i[0] ^= v_i[0]
z_i[1] ^= v_i[1]
z_i[2] ^= v_i[2]
z_i[3] ^= v_i[3]
self.power(v_i, v_i)
return z_i
def generate_sub_hash_table(self, mid):
bits = mid.length
size = 1 << bits
half = size >>> 1
m = Array(size)
m[half] = mid.slice(0)
i = half >>> 1
while i > 0:
m[i] = Uint32Array(4)
self.power(m[2 * i], m[i])
i >>= 1
i = 2
while i < half:
for v'var j = 1; j < i; ++j':
m_i = m[i]
m_j = m[j]
m[i + j] = x = Uint32Array(4)
for v'var c = 0; c < 4; c++':
x[c] = m_i[c] ^ m_j[c]
i *= 2
m[0] = Uint32Array(4)
for v'i = half + 1; i < size; ++i':
x = m[i ^ half]
m[i] = y = Uint32Array(4)
for v'var c = 0; c < 4; c++':
y[c] = mid[c] ^ x[c]
return m
def generate_hash_table(self, key_as_int32_array):
bits = key_as_int32_array.length
multiplier = 8 / bits
per_int = 4 * multiplier
size = 16 * multiplier
ans = Array(size)
for v'var i =0; i < size; ++i':
tmp = Uint32Array(4)
idx = (i/ per_int) | 0
shft = ((per_int - 1 - (i % per_int)) * bits)
tmp[idx] = (1 << (bits - 1)) << shft
ans[i] = self.generate_sub_hash_table(self.multiply(tmp, key_as_int32_array))
return ans
def table_multiply(self, x):
z = Uint32Array(4)
for v'var i = 0; i < 32; ++i':
idx = (i / 8) | 0
x_i = (x[idx] >>> ((7 - (i % 8)) * 4)) & 0xF
ah = self.m[i][x_i]
z[0] ^= ah[0]
z[1] ^= ah[1]
z[2] ^= ah[2]
z[3] ^= ah[3]
return z
def ghash(self, x, y):
z = self.wmem
z[0] = y[0] ^ x[0]
z[1] = y[1] ^ x[1]
z[2] = y[2] ^ x[2]
z[3] = y[3] ^ x[3]
return self.table_multiply(z)
# }}}
# }}} # }}}
@ -388,6 +499,96 @@ class CTR(ModeOfOperation): # {{{
return bytes_to_string(b, offset) return bytes_to_string(b, offset)
# }}} # }}}
class GCM(ModeOfOperation):
# See http://web.cs.ucdavis.edu/~rogaway/ocb/gcm.pdf
def __init__(self, key):
ModeOfOperation.__init__(self, key)
# Generate the hash subkey
H = Uint8Array(16)
self.aes.encrypt(Uint8Array(16), H, 0)
self.galois = GaloisField(H)
# Working memory
self.J0 = Uint32Array(4)
self.wmem = Uint32Array(4)
self.out_block = Uint8Array(16)
def _create_j0(self, iv):
J0 = self.J0
if iv.length is 12:
convert_to_int32(iv, J0)
J0[3] = 1
else:
J0.fill(0)
tmp = convert_to_int32_pad(iv)
while tmp.length:
J0 = self.galois.ghash(J0, tmp)
tmp = tmp.subarray(4)
tmp = Uint32Array(4)
tmp.set(from_64_to_32(iv.length * 8), 2)
J0 = self.galois.ghash(J0, tmp)
return J0
def _crypt(self, iv, bytes, additional_data, decrypt):
J0 = self._create_j0(iv)
# Generate initial counter block
in_block = J0.slice(0)
in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter
outbytes = Uint8Array(bytes.length)
ghash = self.galois.ghash.bind(self.galois)
# Process additional_data
overflow = additional_data.length % 16
if overflow:
t = Uint8Array(additional_data.length + 16 - overflow)
t.set(additional_data)
additional_data = t
S = Uint32Array(4)
while additional_data.length:
additional_data = additional_data.subarray(4)
convert_to_int32(additional_data, self.wmem, 0, 16)
S = ghash(S, self.wmem)
# Create the ciphertext, encrypting block by block
for v'var pos = 0; pos < bytes.length; pos += 16':
self.aes.encrypt32(in_block, self.out_block, 0)
num = min(16, bytes.length - pos) # noqa: unused-local
for v'var i = 0; i < num; i++':
outbytes[pos + i] = bytes[pos+i] ^ self.out_block[i]
convert_to_int32(self.out_block, self.wmem)
S = ghash(S, self.wmem)
in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter
# Mix the lengths into S
lengths = Uint32Array(4)
lengths.set(from_64_to_32(additional_data.length * 8))
lengths.set(from_64_to_32(bytes.length * 8))
S = ghash(S, lengths)
# Create the tag
self.aes.encrypt32(J0, self.out_block, 0)
convert_to_int32(self.out_block, self.wmem)
tag = Uint32Array(4)
for v'var i = 0; i < S.length; i++':
tag[i] = S[i] ^ self.wmem[i]
return {'iv':iv, 'cipherbytes':outbytes, 'tag':tag}
def encrypt(self, plaintext, tag):
iv = random_bytes(12)
return self._crypt(iv, string_to_bytes(plaintext), self.tag_as_bytes(tag))
def decrypt(self, output_from_encrypt, tag):
if output_from_encrypt.tag.length != 4:
raise ValueError('Corrupted message')
ans = self._crypt(output_from_encrypt.iv, output_from_encrypt.cipherbytes, self.tag_as_bytes(tag), True)
if ans.tag != output_from_encrypt.tag:
raise ValueError('Corrupted message')
return bytes_to_string(ans.cipherbytes)
if __name__ == '__main__': if __name__ == '__main__':
text = 'testing a basic roundtrip ø̄ū' text = 'testing a basic roundtrip ø̄ū'
@ -404,7 +605,14 @@ if __name__ == '__main__':
crypted = ctr.encrypt(text) crypted = ctr.encrypt(text)
decrypted = ctr.decrypt(crypted) decrypted = ctr.decrypt(crypted)
print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED') print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED')
crypted = ctr.encrypt(text, secret_tag) crypted = ctr.encrypt(text, secret_tag)
decrypted = ctr.decrypt(crypted, secret_tag) decrypted = ctr.decrypt(crypted, secret_tag)
print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED') print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')
gcm = GCM()
crypted = gcm.encrypt(text)
decrypted = gcm.decrypt(crypted)
print('GCM Roundtrip:', 'OK' if text is decrypted else 'FAILED')
crypted = gcm.encrypt(text, secret_tag)
decrypted = gcm.decrypt(crypted, secret_tag)
print('GCM Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')