diff --git a/src/pyj/aes.pyj b/src/pyj/aes.pyj index 086592ecc4..b9ef430bb1 100644 --- a/src/pyj/aes.pyj +++ b/src/pyj/aes.pyj @@ -244,13 +244,19 @@ class CBC: def __init__(self, key): self.aes = AES(key or generate_key(32)) - def encrypt_bytes(self, bytes): + def encrypt_bytes(self, bytes, tag_bytes): iv = first_iv = random_bytes(16) - padsz = 32 - (bytes.length % 16) + mlen = bytes.length + tag_bytes.length + 1 + padsz = 16 - (mlen % 16) + inputbytes = Uint8Array(mlen + padsz) + inputbytes[0] = padsz + if tag_bytes.length: + inputbytes.set(tag_bytes, 1) + inputbytes.set(bytes, 1 + tag_bytes.length) + if padsz: + inputbytes.set(random_bytes(padsz), 1 + tag_bytes.length + bytes.length) + offset = 0 - tag = random_bytes(padsz) - inputbytes = Uint8Array(bytes.length + padsz) - inputbytes.set(tag), inputbytes.set(bytes, tag.length) outputbytes = Uint8Array(inputbytes.length) for v'var block = 0; block < inputbytes.length; block += 16': if block > 0: @@ -258,13 +264,14 @@ class CBC: for v'var i = 0; i < 16; i++': inputbytes[block + i] ^= iv[offset + i] self.aes.encrypt(inputbytes, outputbytes, block) - return {'iv':first_iv, 'tag':tag, 'cipherbytes':outputbytes} + return {'iv':first_iv, 'cipherbytes':outputbytes} - def encrypt(self, plaintext): - return self.encrypt_bytes(string_to_bytes(plaintext)) + def encrypt(self, plaintext, tag): + return self.encrypt_bytes(string_to_bytes(plaintext), string_to_bytes(tag) if tag else Uint8Array(0)) - def decrypt(self, output_from_encrypt): - iv, tag, inputbytes = output_from_encrypt.iv, output_from_encrypt.tag, output_from_encrypt.cipherbytes + def decrypt(self, output_from_encrypt, tag): + tag_bytes = string_to_bytes(tag) if tag else Uint8Array(0) + iv, inputbytes = output_from_encrypt.iv, output_from_encrypt.cipherbytes offset = 0 outputbytes = Uint8Array(inputbytes.length) for v'var block = 0; block < inputbytes.length; block += 16': @@ -273,10 +280,14 @@ class CBC: iv, offset = inputbytes, block - 16 for v'var i = 0; i < 16; i++': outputbytes[block + i] ^= iv[offset + i] - for v'var i = 0; i < tag.length; i++': - if tag[i] != outputbytes[i]: + padsz = outputbytes[0] + for v'var i = 0; i < tag_bytes.length; i++': + if tag_bytes[i] != outputbytes[i+1]: raise ValueError('Corrupt message') - return bytes_to_string(outputbytes, tag.length) + mlen = outputbytes.length - 1 - padsz - tag_bytes.length + mstart = 1 + tag_bytes.length + outputbytes = outputbytes.slice(mstart, mstart + mlen) + return bytes_to_string(outputbytes) if __name__ == '__main__': cbc = CBC() @@ -284,3 +295,7 @@ if __name__ == '__main__': crypted = cbc.encrypt(text) decrypted = cbc.decrypt(crypted) print('Roundtrip:', 'OK' if text is decrypted else 'FAILED') + crypted = cbc.encrypt(text, 'secret') + decrypted = cbc.decrypt(crypted, 'secret') + print('Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED') +