diff --git a/src/calibre/utils/copy_files.py b/src/calibre/utils/copy_files.py index 856ff79704..a0aa55fcf7 100644 --- a/src/calibre/utils/copy_files.py +++ b/src/calibre/utils/copy_files.py @@ -21,7 +21,8 @@ WindowsFileId = Tuple[int, int, int] class UnixFileCopier: - def __init__(self): + def __init__(self, delete_all=False): + self.delete_all = delete_all self.copy_map: Dict[str, str] = {} def register(self, path: str, dest: str) -> None: @@ -30,8 +31,9 @@ class UnixFileCopier: def __enter__(self) -> None: pass - def __exit__(self, *a) -> None: - pass + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + if self.delete_all and exc_val is None: + self.delete_all_source_files() def rename_all(self) -> None: for src_path, dest_path in self.copy_map.items(): @@ -57,7 +59,8 @@ class WindowsFileCopier: Locks all files before starting the copy, ensuring other processes cannot interfere ''' - def __init__(self): + def __init__(self, delete_all=False): + self.delete_all = delete_all self.path_to_fileid_map : Dict[str, WindowsFileId] = {} self.fileid_to_paths_map: Dict[WindowsFileId, Set[str]] = defaultdict(set) self.path_to_handle_map: Dict[str, 'winutil.Handle'] = {} @@ -98,9 +101,13 @@ class WindowsFileCopier: for src in self.copy_map: self.path_to_handle_map[src] = self._open_file(src) - def __exit__(self, *a) -> None: - for h in self.path_to_handle_map.values(): - h.close() + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + try: + if self.delete_all and exc_val is None: + self.delete_all_source_files() + finally: + for h in self.path_to_handle_map.values(): + h.close() def copy_all(self) -> None: for src_path, dest_path in self.copy_map.items(): @@ -128,8 +135,8 @@ class WindowsFileCopier: winutil.delete_file(make_long_path_useable(src_path)) -def get_copier() -> Union[UnixFileCopier, WindowsFileCopier]: - return WindowsFileCopier() if iswindows else UnixFileCopier() +def get_copier(delete_all=False) -> Union[UnixFileCopier, WindowsFileCopier]: + return WindowsFileCopier(delete_all) if iswindows else UnixFileCopier(delete_all) def rename_files(src_to_dest_map: Dict[str, str]) -> None: @@ -142,14 +149,12 @@ def rename_files(src_to_dest_map: Dict[str, str]) -> None: def copy_files(src_to_dest_map: Dict[str, str], delete_source: bool = False) -> None: - copier = get_copier() + copier = get_copier(delete_source) for s, d in src_to_dest_map.items(): if not samefile(s, d): copier.register(s, d) with copier: copier.copy_all() - if delete_source: - copier.delete_all_source_files() def copy_tree( @@ -183,7 +188,7 @@ def copy_tree( return os.path.join(dest_dir, rel) - copier = get_copier() + copier = get_copier(delete_source) for (dirpath, dirnames, filenames) in os.walk(src, onerror=raise_error): for d in dirnames: path = os.path.join(dirpath, d) @@ -205,8 +210,6 @@ def copy_tree( with copier: copier.copy_all() - if delete_source: - copier.delete_all_source_files() if delete_source: try: