mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Refactor to get rid of the Counter class
This commit is contained in:
parent
7fad5939b3
commit
05d5ad2f59
@ -47,6 +47,15 @@ def bytes_to_string_slow(bytes, offset):
|
|||||||
string_to_bytes = string_to_bytes_encoder if type(TextEncoder) is 'function' else string_to_bytes_slow
|
string_to_bytes = string_to_bytes_encoder if type(TextEncoder) is 'function' else string_to_bytes_slow
|
||||||
bytes_to_string = bytes_to_string_decoder if type(TextDecoder) is 'function' else bytes_to_string_slow
|
bytes_to_string = bytes_to_string_decoder if type(TextDecoder) is 'function' else bytes_to_string_slow
|
||||||
|
|
||||||
|
def increment_counter(c):
|
||||||
|
# c is a Uint8Array
|
||||||
|
for v'var i = c.length; i >= 0; i--':
|
||||||
|
if c[i] is 255:
|
||||||
|
c[i] = 0
|
||||||
|
else:
|
||||||
|
c[i] += 1
|
||||||
|
break
|
||||||
|
|
||||||
# Lookup tables {{{
|
# Lookup tables {{{
|
||||||
# Number of rounds by keysize
|
# Number of rounds by keysize
|
||||||
number_of_rounds = {16: 10, 24: 12, 32: 14}
|
number_of_rounds = {16: 10, 24: 12, 32: 14}
|
||||||
@ -222,19 +231,6 @@ def random_bytes_secure(sz):
|
|||||||
random_bytes = random_bytes_secure if type(crypto) is not 'undefined' and type(crypto.getRandomValues) is 'function' else random_bytes_insecure
|
random_bytes = random_bytes_secure if type(crypto) is not 'undefined' and type(crypto.getRandomValues) is 'function' else random_bytes_insecure
|
||||||
if random_bytes is random_bytes_insecure:
|
if random_bytes is random_bytes_insecure:
|
||||||
print('WARNING: Using insecure RNG for AES')
|
print('WARNING: Using insecure RNG for AES')
|
||||||
# }}}
|
|
||||||
|
|
||||||
def generate_key(sz):
|
|
||||||
if not number_of_rounds[sz]:
|
|
||||||
raise ValueError('Invalid key size, must be: 16, 24 or 32')
|
|
||||||
return random_bytes(sz)
|
|
||||||
|
|
||||||
def generate_tag(sz):
|
|
||||||
return random_bytes(sz or 32)
|
|
||||||
|
|
||||||
def typed_array_as_js(x):
|
|
||||||
name = x.constructor.name or 'Uint8Array'
|
|
||||||
return '(new ' + name + '(' + JSON.stringify(Array.prototype.slice.call(x)) + '))'
|
|
||||||
|
|
||||||
class ModeOfOperation:
|
class ModeOfOperation:
|
||||||
|
|
||||||
@ -255,7 +251,21 @@ class ModeOfOperation:
|
|||||||
return string_to_bytes(tag)
|
return string_to_bytes(tag)
|
||||||
raise TypeError('Invalid tag, must be a string or a Uint8Array')
|
raise TypeError('Invalid tag, must be a string or a Uint8Array')
|
||||||
|
|
||||||
class CBC(ModeOfOperation):
|
# }}}
|
||||||
|
|
||||||
|
def generate_key(sz):
|
||||||
|
if not number_of_rounds[sz]:
|
||||||
|
raise ValueError('Invalid key size, must be: 16, 24 or 32')
|
||||||
|
return random_bytes(sz)
|
||||||
|
|
||||||
|
def generate_tag(sz):
|
||||||
|
return random_bytes(sz or 32)
|
||||||
|
|
||||||
|
def typed_array_as_js(x):
|
||||||
|
name = x.constructor.name or 'Uint8Array'
|
||||||
|
return '(new ' + name + '(' + JSON.stringify(Array.prototype.slice.call(x)) + '))'
|
||||||
|
|
||||||
|
class CBC(ModeOfOperation): # {{{
|
||||||
|
|
||||||
def encrypt_bytes(self, bytes, tag_bytes):
|
def encrypt_bytes(self, bytes, tag_bytes):
|
||||||
iv = first_iv = random_bytes(16)
|
iv = first_iv = random_bytes(16)
|
||||||
@ -306,39 +316,9 @@ class CBC(ModeOfOperation):
|
|||||||
mstart = 1 + tag_bytes.length
|
mstart = 1 + tag_bytes.length
|
||||||
outputbytes = outputbytes.subarray(mstart, mstart + mlen)
|
outputbytes = outputbytes.subarray(mstart, mstart + mlen)
|
||||||
return bytes_to_string(outputbytes)
|
return bytes_to_string(outputbytes)
|
||||||
|
# }}}
|
||||||
|
|
||||||
class Counter:
|
class CTR(ModeOfOperation): # {{{
|
||||||
|
|
||||||
def __init__(self, initial_value, number_of_bytes=16):
|
|
||||||
initial_value = initial_value or 1
|
|
||||||
if type(initial_value) is 'number':
|
|
||||||
self.bytes = Uint8Array(number_of_bytes)
|
|
||||||
self.set_value(initial_value)
|
|
||||||
else:
|
|
||||||
self.bytes = Uint8Array(initial_value)
|
|
||||||
|
|
||||||
def set_value(self, value):
|
|
||||||
c = self.bytes
|
|
||||||
for v'var index = self.bytes.length; index >= 0; index--':
|
|
||||||
c[index] = value % 256
|
|
||||||
value >>= 8
|
|
||||||
|
|
||||||
def set_bytes(self, bytes):
|
|
||||||
self.bytes.set(bytes)
|
|
||||||
|
|
||||||
def increment(self):
|
|
||||||
c = self.bytes
|
|
||||||
for v'var i = self.bytes.length; i >= 0; i--':
|
|
||||||
if c[i] is 255:
|
|
||||||
c[i] = 0
|
|
||||||
else:
|
|
||||||
c[i] += 1
|
|
||||||
break
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return as_hex(self.bytes)
|
|
||||||
|
|
||||||
class CTR(ModeOfOperation):
|
|
||||||
|
|
||||||
# Note that this mode of operation requires the pair of (counterbytes,
|
# Note that this mode of operation requires the pair of (counterbytes,
|
||||||
# secret key) to always be unique, for every block. Therefore, if you are
|
# secret key) to always be unique, for every block. Therefore, if you are
|
||||||
@ -348,23 +328,23 @@ class CTR(ModeOfOperation):
|
|||||||
def __init__(self, key, counter):
|
def __init__(self, key, counter):
|
||||||
ModeOfOperation.__init__(self, key)
|
ModeOfOperation.__init__(self, key)
|
||||||
self.wmem = Uint8Array(16)
|
self.wmem = Uint8Array(16)
|
||||||
self.counter = Counter(counter)
|
self.counter_block = Uint8Array(16)
|
||||||
self.counter_index = 16
|
self.counter_index = 16
|
||||||
|
|
||||||
def _crypt(self, bytes):
|
def _crypt(self, bytes):
|
||||||
for v'var i = 0; i < bytes.length; i++':
|
for v'var i = 0; i < bytes.length; i++':
|
||||||
if self.counter_index is 16:
|
if self.counter_index is 16:
|
||||||
self.counter_index = 0
|
self.counter_index = 0
|
||||||
self.aes.encrypt(self.counter.bytes, self.wmem, 0)
|
self.aes.encrypt(self.counter_block, self.wmem, 0)
|
||||||
self.counter.increment()
|
increment_counter(self.counter_block)
|
||||||
bytes[i] ^= self.wmem[self.counter_index]
|
bytes[i] ^= self.wmem[self.counter_index]
|
||||||
self.counter_index += 1
|
self.counter_index += 1
|
||||||
self.counter_index = 16
|
self.counter_index = 16
|
||||||
self.counter.increment()
|
increment_counter(self.counter_block)
|
||||||
|
|
||||||
def encrypt(self, plaintext, tag):
|
def encrypt(self, plaintext, tag):
|
||||||
outbytes = string_to_bytes(plaintext)
|
outbytes = string_to_bytes(plaintext)
|
||||||
counterbytes = Uint8Array(self.counter.bytes)
|
counterbytes = Uint8Array(self.counter_block)
|
||||||
if tag:
|
if tag:
|
||||||
tag_bytes = self.tag_as_bytes(tag)
|
tag_bytes = self.tag_as_bytes(tag)
|
||||||
t = Uint8Array(outbytes.length + tag_bytes.length)
|
t = Uint8Array(outbytes.length + tag_bytes.length)
|
||||||
@ -376,16 +356,16 @@ class CTR(ModeOfOperation):
|
|||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.before_index = self.counter_index
|
self.before_index = self.counter_index
|
||||||
self.before_bytes = Uint8Array(self.counter.bytes)
|
self.before_counter = Uint8Array(self.counter_block)
|
||||||
|
|
||||||
def __exit__(self):
|
def __exit__(self):
|
||||||
self.counter_index = self.before_index
|
self.counter_index = self.before_index
|
||||||
self.counter.set_bytes(self.before_bytes)
|
self.counter_block = self.before_counter
|
||||||
|
|
||||||
def decrypt(self, output_from_encrypt, tag):
|
def decrypt(self, output_from_encrypt, tag):
|
||||||
b = Uint8Array(output_from_encrypt.cipherbytes)
|
b = Uint8Array(output_from_encrypt.cipherbytes)
|
||||||
with self:
|
with self:
|
||||||
self.counter.set_bytes(output_from_encrypt.counterbytes)
|
self.counter_block = output_from_encrypt.counterbytes
|
||||||
self.counter_index = 16
|
self.counter_index = 16
|
||||||
self._crypt(b)
|
self._crypt(b)
|
||||||
offset = 0
|
offset = 0
|
||||||
@ -396,7 +376,7 @@ class CTR(ModeOfOperation):
|
|||||||
raise ValueError('Corrupted message')
|
raise ValueError('Corrupted message')
|
||||||
offset = tag_bytes.length
|
offset = tag_bytes.length
|
||||||
return bytes_to_string(b, offset)
|
return bytes_to_string(b, offset)
|
||||||
|
# }}}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
text = 'testing a basic roundtrip ø̄ū'
|
text = 'testing a basic roundtrip ø̄ū'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user