Add some tests for the new copy tree functionality

This commit is contained in:
Kovid Goyal 2023-04-08 13:02:17 +05:30
parent 42cabadccb
commit e25f8bf61c
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
2 changed files with 72 additions and 9 deletions

View File

@ -38,7 +38,7 @@ class UnixFileCopier:
with suppress(OSError):
os.link(src_path, dest_path, follow_symlinks=False)
shutil.copystat(src_path, dest_path, follow_symlinks=False)
return
continue
shutil.copy2(src_path, dest_path, follow_symlinks=False)
def delete_all_source_files(self) -> None:
@ -102,6 +102,7 @@ class WindowsFileCopier:
with suppress(Exception):
windows_hardlink(src_path, dest_path)
shutil.copystat(src_path, dest_path, follow_symlinks=False)
continue
handle = self.path_to_handle_map[src_path]
winutil.set_file_pointer(handle, 0, winutil.FILE_BEGIN)
with open(dest_path, 'wb') as f:
@ -140,7 +141,8 @@ def copy_tree(
) -> None:
'''
Copy all files in the tree over. On Windows locks all files before starting the copy to ensure that
other processes cannot interfere once the copy starts.
other processes cannot interfere once the copy starts. Uses hardlinks, falling back to actual file copies
only if hardlinking fails.
'''
if iswindows:
if isinstance(src, bytes):
@ -152,14 +154,15 @@ def copy_tree(
os.makedirs(dest, exist_ok=True)
if samefile(src, dest):
raise ValueError(f'Cannot copy tree if the source and destination are the same: {src!r} == {dest!r}')
dest_dir = dest
def raise_error(e: OSError) -> None:
raise e
def dest_from_entry(dirpath: str, x: str) -> str:
path = os.path.join(dirpath, d)
path = os.path.join(dirpath, x)
rel = os.path.relpath(path, src)
return os.path.join(dest, rel)
return os.path.join(dest_dir, rel)
copier = get_copier()
@ -171,7 +174,7 @@ def copy_tree(
shutil.copystat(make_long_path_useable(path), make_long_path_useable(dest), follow_symlinks=False)
for f in filenames:
path = os.path.join(dirpath, f)
dest = dest_from_entry(dirpath, d)
dest = dest_from_entry(dirpath, f)
dest = transform_destination_filename(path, dest)
if not iswindows:
s = os.stat(path, follow_symlinks=False)

View File

@ -1,23 +1,83 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2023, Kovid Goyal <kovid at kovidgoyal.net>
import os
import shutil
import tempfile
import time
import unittest
from calibre.constants import iswindows
from .copy_files import copy_tree
from .filenames import nlinks_file
class TestCopyFiles(unittest.TestCase):
ae = unittest.TestCase.assertEqual
def setUp(self):
self.tdir = tempfile.mkdtemp()
self.tdir = t = tempfile.mkdtemp()
def wf(*parts):
d = os.path.join(t, *parts)
os.makedirs(os.path.dirname(d), exist_ok=True)
with open(d, 'w') as f:
f.write(' '.join(parts))
wf('base'), wf('src/one'), wf('src/sub/a')
if not iswindows:
os.symlink('sub/a', os.path.join(t, 'src/link'))
def tearDown(self):
shutil.rmtree(self.tdir)
if self.tdir:
try:
shutil.rmtree(self.tdir)
except OSError:
time.sleep(1)
shutil.rmtree(self.tdir)
self.tdir = ''
def test_copy_files(self):
pass
def s(self, *path):
return os.path.abspath(os.path.join(self.tdir, 'src', *path))
def d(self, *path):
return os.path.abspath(os.path.join(self.tdir, 'dest', *path))
def file_data_eq(self, path):
with open(self.s(path)) as src, open(self.d(path)) as dest:
self.ae(src.read(), dest.read())
def reset(self):
self.tearDown()
self.setUp()
def test_copying_of_trees(self):
src, dest = self.s(), self.d()
copy_tree(src, dest)
eq = self.file_data_eq
eq('one')
eq('sub/a')
if not iswindows:
eq('link')
self.ae(os.readlink(self.d('link')), 'sub/a')
self.ae(nlinks_file(self.s('one')), 2)
self.ae(set(os.listdir(self.tdir)), {'src', 'dest', 'base'})
self.reset()
src, dest = self.s(), self.d()
copy_tree(src, dest, delete_source=True)
self.ae(set(os.listdir(self.tdir)), {'dest', 'base'})
self.ae(nlinks_file(self.d('one')), 1)
def transform_destination_filename(src, dest):
return dest + '.extra'
self.reset()
src, dest = self.s(), self.d()
copy_tree(src, dest, transform_destination_filename=transform_destination_filename)
with open(self.d('sub/a.extra')) as d:
self.ae(d.read(), 'src/sub/a')
if not iswindows:
self.ae(os.readlink(self.d('link.extra')), 'sub/a')
def find_tests():