diff --git a/src/calibre/utils/run_tests.py b/src/calibre/utils/run_tests.py index 4e403dc722..19853589ff 100644 --- a/src/calibre/utils/run_tests.py +++ b/src/calibre/utils/run_tests.py @@ -330,6 +330,8 @@ def find_tests(which_tests=None, exclude_tests=None): a(find_tests()) from calibre.utils.config_base import find_tests a(find_tests()) + from calibre.utils.zipfile import find_tests + a(find_tests()) if ok('dbcli'): from calibre.db.cli.tests import find_tests a(find_tests()) diff --git a/src/calibre/utils/zipfile.py b/src/calibre/utils/zipfile.py index 37bca7dae4..ee23dac727 100644 --- a/src/calibre/utils/zipfile.py +++ b/src/calibre/utils/zipfile.py @@ -1771,10 +1771,137 @@ def main(args=None): def find_tests(): import unittest - class TextZipFile(unittest.TestCase): - def test_zipfile_safe_replace(self): - pass + class TestZipFileSafeReplace(unittest.TestCase): + + def setUp(self): + import shutil + self.tdir = tempfile.mkdtemp() + self._cleanup = shutil.rmtree + + def tearDown(self): + self._cleanup(self.tdir, ignore_errors=True) + + def _make_zip(self, path, entries): + # entries: list of (name, data) tuples + with ZipFile(path, 'w') as z: + for name, data in entries: + z.writestr(name, data) + + def _read_zip_entry(self, path_or_stream, name): + if isinstance(path_or_stream, str): + with ZipFile(path_or_stream, 'r') as z: + return z.read(name) + else: + path_or_stream.seek(0) + with ZipFile(path_or_stream, 'r') as z: + return z.read(name) + + def _zip_names(self, path_or_stream): + if isinstance(path_or_stream, str): + with ZipFile(path_or_stream, 'r') as z: + return set(z.namelist()) + else: + path_or_stream.seek(0) + with ZipFile(path_or_stream, 'r') as z: + return set(z.namelist()) + + # ------------------------------------------------------------------ # + # Path-based branch + # ------------------------------------------------------------------ # + + def test_safe_replace_path_replaces_content(self): + zpath = os.path.join(self.tdir, 'test.zip') + self._make_zip(zpath, [('a.txt', b'original'), ('b.txt', b'other')]) + safe_replace(zpath, 'a.txt', io.BytesIO(b'replaced')) + self.assertEqual(self._read_zip_entry(zpath, 'a.txt'), b'replaced') + # unrelated entry must be preserved + self.assertEqual(self._read_zip_entry(zpath, 'b.txt'), b'other') + + def test_safe_replace_path_preserves_all_entries(self): + zpath = os.path.join(self.tdir, 'test.zip') + entries = [('one.txt', b'1'), ('two.txt', b'2'), ('three.txt', b'3')] + self._make_zip(zpath, entries) + safe_replace(zpath, 'two.txt', io.BytesIO(b'TWO')) + self.assertEqual(self._zip_names(zpath), {'one.txt', 'two.txt', 'three.txt'}) + self.assertEqual(self._read_zip_entry(zpath, 'two.txt'), b'TWO') + + def test_safe_replace_path_extra_replacements(self): + zpath = os.path.join(self.tdir, 'test.zip') + self._make_zip(zpath, [('a.txt', b'a'), ('b.txt', b'b')]) + safe_replace(zpath, 'a.txt', io.BytesIO(b'A'), + extra_replacements={'b.txt': io.BytesIO(b'B')}) + self.assertEqual(self._read_zip_entry(zpath, 'a.txt'), b'A') + self.assertEqual(self._read_zip_entry(zpath, 'b.txt'), b'B') + + def test_safe_replace_path_add_missing(self): + zpath = os.path.join(self.tdir, 'test.zip') + self._make_zip(zpath, [('a.txt', b'a')]) + safe_replace(zpath, 'new.txt', io.BytesIO(b'new'), add_missing=True) + self.assertIn('new.txt', self._zip_names(zpath)) + self.assertEqual(self._read_zip_entry(zpath, 'new.txt'), b'new') + + @unittest.skipIf(sys.platform == 'win32', 'POSIX-only attribute preservation test') + def test_safe_replace_path_preserves_file_attributes(self): + zpath = os.path.join(self.tdir, 'test.zip') + self._make_zip(zpath, [('a.txt', b'original')]) + # set specific permissions on the zip file + desired_mode = 0o640 + os.chmod(zpath, desired_mode) + original_stat = os.stat(zpath) + safe_replace(zpath, 'a.txt', io.BytesIO(b'replaced')) + new_stat = os.stat(zpath) + # permissions (lower 12 bits) must be preserved + self.assertEqual( + stat.S_IMODE(new_stat.st_mode), + stat.S_IMODE(original_stat.st_mode), + 'File permissions not preserved after safe_replace with path' + ) + # uid/gid must be preserved + self.assertEqual(new_stat.st_uid, original_stat.st_uid, + 'File owner (uid) not preserved after safe_replace with path') + self.assertEqual(new_stat.st_gid, original_stat.st_gid, + 'File group (gid) not preserved after safe_replace with path') + + # ------------------------------------------------------------------ # + # Stream-based branch + # ------------------------------------------------------------------ # + + def test_safe_replace_stream_replaces_content(self): + buf = io.BytesIO() + self._make_zip(buf, [('a.txt', b'original'), ('b.txt', b'other')]) + buf.seek(0) + safe_replace(buf, 'a.txt', io.BytesIO(b'replaced')) + self.assertEqual(self._read_zip_entry(buf, 'a.txt'), b'replaced') + self.assertEqual(self._read_zip_entry(buf, 'b.txt'), b'other') + + def test_safe_replace_stream_preserves_all_entries(self): + buf = io.BytesIO() + entries = [('one.txt', b'1'), ('two.txt', b'2'), ('three.txt', b'3')] + self._make_zip(buf, entries) + buf.seek(0) + safe_replace(buf, 'two.txt', io.BytesIO(b'TWO')) + self.assertEqual(self._zip_names(buf), {'one.txt', 'two.txt', 'three.txt'}) + self.assertEqual(self._read_zip_entry(buf, 'two.txt'), b'TWO') + + def test_safe_replace_stream_extra_replacements(self): + buf = io.BytesIO() + self._make_zip(buf, [('a.txt', b'a'), ('b.txt', b'b')]) + buf.seek(0) + safe_replace(buf, 'a.txt', io.BytesIO(b'A'), + extra_replacements={'b.txt': io.BytesIO(b'B')}) + self.assertEqual(self._read_zip_entry(buf, 'a.txt'), b'A') + self.assertEqual(self._read_zip_entry(buf, 'b.txt'), b'B') + + def test_safe_replace_stream_add_missing(self): + buf = io.BytesIO() + self._make_zip(buf, [('a.txt', b'a')]) + buf.seek(0) + safe_replace(buf, 'new.txt', io.BytesIO(b'new'), add_missing=True) + self.assertIn('new.txt', self._zip_names(buf)) + self.assertEqual(self._read_zip_entry(buf, 'new.txt'), b'new') + + return unittest.defaultTestLoader.loadTestsFromTestCase(TestZipFileSafeReplace) if __name__ == '__main__':