diff --git a/src/pyj/aes.pyj b/src/pyj/aes.pyj index 3de088b2dd..34739c7bd8 100644 --- a/src/pyj/aes.pyj +++ b/src/pyj/aes.pyj @@ -56,7 +56,30 @@ def increment_counter(c): c[i] += 1 break -# Lookup tables {{{ +def convert_to_int32(bytes, output, offset, length): + offset = offset or 0 + length = length or bytes.length + for v'var i = offset, j = 0; i < offset + length; i += 4, j++': + output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3] + +def convert_to_int32_pad(bytes): + extra = bytes.length % 4 + if extra: + t = Uint8Array(bytes.length + 4 - extra) + t.set(bytes) + bytes = t + ans = Uint32Array(bytes.length / 4) + convert_to_int32(bytes, ans) + return ans + +def from_64_to_32(num): + # convert 64-bit number to two BE Int32s + ans = Uint32Array(2) + ans[0] = (num / 0x100000000) | 0 + ans[1] = num & 0xFFFFFFFF + return ans + +# Lookup tables for AES {{{ # Number of rounds by keysize number_of_rounds = {16: 10, 24: 12, 32: 14} # Round constant words @@ -86,13 +109,7 @@ U4 = v'new Uint32Array([0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c # }}} -def convert_to_int32(bytes, output, offset, length): - offset = offset or 0 - length = length or bytes.length - for v'var i = offset, j = 0; i < offset + length; i += 4, j++': - output[j] = (bytes[i] << 24) | (bytes[i + 1] << 16) | (bytes[i + 2] << 8) | bytes[i + 3] - -class AES: +class AES: # {{{ def __init__(self, key): self.working_mem = [Uint32Array(4), Uint32Array(4)] @@ -242,7 +259,7 @@ random_bytes = random_bytes_secure if type(crypto) is not 'undefined' and type(c if random_bytes is random_bytes_insecure: print('WARNING: Using insecure RNG for AES') -class ModeOfOperation: +class ModeOfOperation: # {{{ def __init__(self, key): self.key = key or generate_key(32) @@ -260,6 +277,100 @@ class ModeOfOperation: if type(tag) is 'string': return string_to_bytes(tag) raise TypeError('Invalid tag, must be a string or a Uint8Array') +# }}} + +class GaloisField: # {{{ + + def __init__(self, sub_key): + k32 = Uint32Array(4) + convert_to_int32(sub_key, k32, 0) + self.m = self.generate_hash_table(k32) + self.wmem = Uint32Array(4) + + def power(self, x, out): + lsb = x[3] & 1 + for v'var i = 3; i > 0; --i': + out[i] = (x[i] >>> 1) | ((x[i - 1] & 1) << 31) + out[0] = x[0] >>> 1 + if lsb: + out[0] ^= 0xE1000000 + + def multiply(self, x, y): + z_i = Uint32Array(4) + v_i = y.slice(0) + for v'var i = 0; i < 128; ++i': + x_i = x[(i / 32) | 0] & (1 << (31 - i % 32)) + if x_i: + z_i[0] ^= v_i[0] + z_i[1] ^= v_i[1] + z_i[2] ^= v_i[2] + z_i[3] ^= v_i[3] + self.power(v_i, v_i) + return z_i + + def generate_sub_hash_table(self, mid): + bits = mid.length + size = 1 << bits + half = size >>> 1 + m = Array(size) + m[half] = mid.slice(0) + i = half >>> 1 + while i > 0: + m[i] = Uint32Array(4) + self.power(m[2 * i], m[i]) + i >>= 1 + i = 2 + while i < half: + for v'var j = 1; j < i; ++j': + m_i = m[i] + m_j = m[j] + m[i + j] = x = Uint32Array(4) + for v'var c = 0; c < 4; c++': + x[c] = m_i[c] ^ m_j[c] + i *= 2 + m[0] = Uint32Array(4) + for v'i = half + 1; i < size; ++i': + x = m[i ^ half] + m[i] = y = Uint32Array(4) + for v'var c = 0; c < 4; c++': + y[c] = mid[c] ^ x[c] + return m + + def generate_hash_table(self, key_as_int32_array): + bits = key_as_int32_array.length + multiplier = 8 / bits + per_int = 4 * multiplier + size = 16 * multiplier + ans = Array(size) + for v'var i =0; i < size; ++i': + tmp = Uint32Array(4) + idx = (i/ per_int) | 0 + shft = ((per_int - 1 - (i % per_int)) * bits) + tmp[idx] = (1 << (bits - 1)) << shft + ans[i] = self.generate_sub_hash_table(self.multiply(tmp, key_as_int32_array)) + return ans + + def table_multiply(self, x): + z = Uint32Array(4) + for v'var i = 0; i < 32; ++i': + idx = (i / 8) | 0 + x_i = (x[idx] >>> ((7 - (i % 8)) * 4)) & 0xF + ah = self.m[i][x_i] + z[0] ^= ah[0] + z[1] ^= ah[1] + z[2] ^= ah[2] + z[3] ^= ah[3] + return z + + def ghash(self, x, y): + z = self.wmem + z[0] = y[0] ^ x[0] + z[1] = y[1] ^ x[1] + z[2] = y[2] ^ x[2] + z[3] = y[3] ^ x[3] + return self.table_multiply(z) + +# }}} # }}} @@ -388,6 +499,96 @@ class CTR(ModeOfOperation): # {{{ return bytes_to_string(b, offset) # }}} +class GCM(ModeOfOperation): + + # See http://web.cs.ucdavis.edu/~rogaway/ocb/gcm.pdf + + def __init__(self, key): + ModeOfOperation.__init__(self, key) + + # Generate the hash subkey + H = Uint8Array(16) + self.aes.encrypt(Uint8Array(16), H, 0) + self.galois = GaloisField(H) + + # Working memory + self.J0 = Uint32Array(4) + self.wmem = Uint32Array(4) + self.out_block = Uint8Array(16) + + def _create_j0(self, iv): + J0 = self.J0 + if iv.length is 12: + convert_to_int32(iv, J0) + J0[3] = 1 + else: + J0.fill(0) + tmp = convert_to_int32_pad(iv) + while tmp.length: + J0 = self.galois.ghash(J0, tmp) + tmp = tmp.subarray(4) + tmp = Uint32Array(4) + tmp.set(from_64_to_32(iv.length * 8), 2) + J0 = self.galois.ghash(J0, tmp) + return J0 + + def _crypt(self, iv, bytes, additional_data, decrypt): + J0 = self._create_j0(iv) + # Generate initial counter block + in_block = J0.slice(0) + in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter + outbytes = Uint8Array(bytes.length) + ghash = self.galois.ghash.bind(self.galois) + + # Process additional_data + overflow = additional_data.length % 16 + if overflow: + t = Uint8Array(additional_data.length + 16 - overflow) + t.set(additional_data) + additional_data = t + S = Uint32Array(4) + while additional_data.length: + additional_data = additional_data.subarray(4) + convert_to_int32(additional_data, self.wmem, 0, 16) + S = ghash(S, self.wmem) + + # Create the ciphertext, encrypting block by block + for v'var pos = 0; pos < bytes.length; pos += 16': + self.aes.encrypt32(in_block, self.out_block, 0) + num = min(16, bytes.length - pos) # noqa: unused-local + for v'var i = 0; i < num; i++': + outbytes[pos + i] = bytes[pos+i] ^ self.out_block[i] + convert_to_int32(self.out_block, self.wmem) + S = ghash(S, self.wmem) + in_block[3] = (in_block[3] + 1) & 0xFFFFFFFF # increment counter + + # Mix the lengths into S + lengths = Uint32Array(4) + lengths.set(from_64_to_32(additional_data.length * 8)) + lengths.set(from_64_to_32(bytes.length * 8)) + S = ghash(S, lengths) + + # Create the tag + self.aes.encrypt32(J0, self.out_block, 0) + convert_to_int32(self.out_block, self.wmem) + tag = Uint32Array(4) + for v'var i = 0; i < S.length; i++': + tag[i] = S[i] ^ self.wmem[i] + return {'iv':iv, 'cipherbytes':outbytes, 'tag':tag} + + + def encrypt(self, plaintext, tag): + iv = random_bytes(12) + return self._crypt(iv, string_to_bytes(plaintext), self.tag_as_bytes(tag)) + + def decrypt(self, output_from_encrypt, tag): + if output_from_encrypt.tag.length != 4: + raise ValueError('Corrupted message') + ans = self._crypt(output_from_encrypt.iv, output_from_encrypt.cipherbytes, self.tag_as_bytes(tag), True) + if ans.tag != output_from_encrypt.tag: + raise ValueError('Corrupted message') + return bytes_to_string(ans.cipherbytes) + if __name__ == '__main__': text = 'testing a basic roundtrip ø̄ū' @@ -404,7 +605,14 @@ if __name__ == '__main__': crypted = ctr.encrypt(text) decrypted = ctr.decrypt(crypted) print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED') - crypted = ctr.encrypt(text, secret_tag) decrypted = ctr.decrypt(crypted, secret_tag) print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED') + + gcm = GCM() + crypted = gcm.encrypt(text) + decrypted = gcm.decrypt(crypted) + print('GCM Roundtrip:', 'OK' if text is decrypted else 'FAILED') + crypted = gcm.encrypt(text, secret_tag) + decrypted = gcm.decrypt(crypted, secret_tag) + print('GCM Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')