diff --git a/src/pyj/aes.pyj b/src/pyj/aes.pyj index f10ee62b6a..98f81a5b7d 100644 --- a/src/pyj/aes.pyj +++ b/src/pyj/aes.pyj @@ -310,14 +310,13 @@ class CBC(ModeOfOperation): class Counter: def __init__(self, initial_value): + self.bytes = Uint8Array(16) if not initial_value: - self.bytes = Uint8Array(16) self.set_value(1) elif type(initial_value) is 'number': - self.bytes = Uint8Array(16) self.set_value(initial_value) else: - self.bytes = Uint8Array(initial_value) + self.bytes.set(initial_value) def set_value(self, value): c = self.bytes @@ -326,7 +325,7 @@ class Counter: value >>= 8 def set_bytes(self, bytes): - self.bytes = Uint8Array(bytes) + self.bytes.set(bytes) def increment(self): c = self.bytes @@ -342,6 +341,11 @@ class Counter: class CTR(ModeOfOperation): + # Note that this mode of operation requires the pair of (counterbytes, + # secret key) to always be unique, for every block. Therefore, if you are + # using it for bi-directional messaging it is best to use a different + # secret key for each direction + def __init__(self, key, counter): ModeOfOperation.__init__(self, key) self.cval = counter @@ -357,9 +361,12 @@ class CTR(ModeOfOperation): self.counter.increment() bytes[i] ^= self.wmem[self.counter_index] self.counter_index += 1 + self.counter_index = 16 + self.counter.increment() def encrypt(self, plaintext, tag): outbytes = string_to_bytes(plaintext) + counterbytes = Uint8Array(self.counter.bytes) if tag: tag_bytes = self.tag_as_bytes(tag) t = Uint8Array(outbytes.length + tag_bytes.length) @@ -367,12 +374,22 @@ class CTR(ModeOfOperation): t.set(outbytes, tag_bytes.length) outbytes = t self._crypt(outbytes) - return outbytes + return {'cipherbytes':outbytes, 'counterbytes':counterbytes} + + def __enter__(self): + self.before_index = self.counter_index + self.before_bytes = Uint8Array(self.counter.bytes) + + def __exit__(self): + self.counter_index = self.before_index + self.counter.set_bytes(self.before_bytes) def decrypt(self, output_from_encrypt, tag): - b = Uint8Array(output_from_encrypt.length) - b.set(output_from_encrypt) - self._crypt(b) + b = Uint8Array(output_from_encrypt.cipherbytes) + with self: + self.counter.set_bytes(output_from_encrypt.counterbytes) + self.counter_index = 16 + self._crypt(b) offset = 0 if tag: tag_bytes = self.tag_as_bytes(tag) @@ -394,12 +411,11 @@ if __name__ == '__main__': decrypted = cbc.decrypt(crypted, secret_tag) print('CBC Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED') - ctre = CTR() - ctrd = CTR(ctre.key) - crypted = ctre.encrypt(text) - decrypted = ctrd.decrypt(crypted) + ctr = CTR() + crypted = ctr.encrypt(text) + decrypted = ctr.decrypt(crypted) print('CTR Roundtrip:', 'OK' if text is decrypted else 'FAILED') - crypted = ctre.encrypt(text, secret_tag) - decrypted = ctrd.decrypt(crypted, secret_tag) + crypted = ctr.encrypt(text, secret_tag) + decrypted = ctr.decrypt(crypted, secret_tag) print('CTR Roundtrip with tag:', 'OK' if text is decrypted else 'FAILED')