diff --git a/src/calibre/devices/winusb.py b/src/calibre/devices/winusb.py index 2ecec83d2f..141613f84e 100644 --- a/src/calibre/devices/winusb.py +++ b/src/calibre/devices/winusb.py @@ -6,11 +6,14 @@ from __future__ import (unicode_literals, division, absolute_import, print_function) import os, string, _winreg as winreg, re, time, sys +from collections import namedtuple, defaultdict +from operator import itemgetter from ctypes import ( Structure, POINTER, c_ubyte, windll, byref, c_void_p, WINFUNCTYPE, WinError, get_last_error, sizeof, c_wchar, create_string_buffer, cast, wstring_at, addressof, create_unicode_buffer, string_at, c_uint64 as QWORD) from ctypes.wintypes import DWORD, WORD, ULONG, LPCWSTR, HWND, BOOL, LPWSTR, UINT, BYTE, HANDLE +from pprint import pprint, pformat from calibre import prints, as_unicode @@ -59,6 +62,8 @@ IOCTL_STORAGE_MEDIA_REMOVAL = 0x2D4804 IOCTL_STORAGE_EJECT_MEDIA = 0x2D4808 IOCTL_STORAGE_GET_DEVICE_NUMBER = 0x2D1080 +StorageDeviceNumber = namedtuple('StorageDeviceNumber', 'type number partition_number') + class STORAGE_DEVICE_NUMBER(Structure): _fields_ = [ ('DeviceType', DWORD), @@ -66,6 +71,9 @@ class STORAGE_DEVICE_NUMBER(Structure): ('PartitionNumber', ULONG) ] + def as_tuple(self): + return StorageDeviceNumber(self.DeviceType, self.DeviceNumber, self.PartitionNumber) + class SP_DEVINFO_DATA(Structure): _fields_ = [ ('cbSize', DWORD), @@ -373,7 +381,7 @@ def get_storage_number(devpath): DeviceIoControl(handle, IOCTL_STORAGE_GET_DEVICE_NUMBER, None, 0, byref(sdn), sizeof(STORAGE_DEVICE_NUMBER), None, None) finally: CloseHandle(handle) - return sdn.DeviceNumber + return sdn.as_tuple() def get_all_removable_drives(allow_fixed=False): mask = GetLogicalDrives() @@ -523,13 +531,13 @@ def get_removable_drives(debug=False): # {{{ return ans # }}} -def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False): # {{{ +def get_drive_letters_for_device(vendor_id, product_id, bcd=None, storage_number_map=None, debug=False): # {{{ ''' Get the drive letters for a connected device with the specieid USB ids. bcd can be either None, in which case it is not tested, or it must be a list or set like object containing bcds. ''' - rbuf = wbuf = None + rbuf = None ans = [] # First search for a device matching the specified USB ids @@ -551,23 +559,30 @@ def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False): return ans # Get the device ids for all descendants of the found device - sn_map = get_storage_number_map(debug=debug) + sn_map = get_storage_number_map(debug=debug) if storage_number_map is None else storage_number_map if debug: - prints('Storage number map:', sn_map) - for devinst in iterdescendants(devinfo.DevInst): - devid, wbuf = get_device_id(devinst, buf=wbuf) - try: - drive_letter = find_drive(devinst, sn_map, debug=debug) - except Exception as err: + prints('Storage number map:') + prints(pformat(sn_map)) + descendants = frozenset(iterdescendants(devinfo.DevInst)) + for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces(): + if devinfo.DevInst in descendants: if debug: - prints('Failed to get drive letter for: %s with error: %s' % (devid, as_unicode(err))) - import traceback - traceback.print_exc() - else: - if drive_letter: - ans.append(drive_letter) + try: + devid = get_device_id(devinfo.DevInst)[0] + except Exception as err: + devid = 'Unknown' + try: + storage_number = get_storage_number(devpath) + except Exception as err: + if debug: + prints('Failed to get storage number for: %s with error: %s' % (devid, as_unicode(err))) + continue if debug: - prints('Drive letter for: %s is: %s' % (devid, drive_letter)) + prints('Storage number for %s: %s' % (devid, storage_number)) + if storage_number: + partitions = sn_map.get(storage_number[:2]) + drive_letters = [x[1] for x in partitions or ()] + ans.extend(drive_letters) return ans @@ -582,28 +597,20 @@ def get_storage_number_map(drive_types=(DRIVE_REMOVABLE, DRIVE_FIXED), debug=Fal mask = GetLogicalDrives() type_map = {letter:GetDriveType(letter + ':' + os.sep) for i, letter in enumerate(string.ascii_uppercase) if mask & (1 << i)} drives = (letter for letter, dt in type_map.iteritems() if dt in drive_types) - ans = {} + ans = defaultdict(list) for letter in drives: try: sn = get_storage_number('\\\\.\\' + letter + ':') - if debug and sn in ans: - prints('Duplicate storage number for drives: %s and %s' % (letter, ans[sn])) - ans[sn] = letter + ans[sn[:2]].append((sn[2], letter)) except WindowsError as err: if debug: prints('Failed to get storage number for drive: %s with error: %s' % (letter, as_unicode(err))) continue - return ans + for val in ans.itervalues(): + val.sort(key=itemgetter(0)) + return dict(ans) -def find_drive(devinst, storage_number_map, debug=False): - for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces(): - if devinfo.DevInst == devinst: - storage_number = get_storage_number(devpath) - drive_letter = storage_number_map.get(storage_number) - if drive_letter: - return drive_letter - # }}} def get_usb_devices(): # {{{ @@ -634,7 +641,7 @@ def is_usb_device_connected(vendor_id, product_id): # {{{ def eject_drive(drive_letter): # {{{ drive_letter = type('')(drive_letter)[0] volume_access_path = '\\\\.\\' + drive_letter + ':' - devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path)) + devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path).number) if devinst is None: raise ValueError('Could not find device instance number from drive letter: %s' % drive_letter) parent = DEVINST(0) @@ -665,12 +672,11 @@ def devinst_from_device_number(drive_letter, device_number): else: raise ValueError('Unknown drive_type: %d' % drive_type) for devinfo, devpath in DeviceSet(guid=guid).interfaces(ignore_errors=True): - if get_storage_number(devpath) == device_number: + if get_storage_number(devpath).number == device_number: return devinfo.DevInst # }}} def develop(vendor_id=0x1949, product_id=0x4, bcd=None, do_eject=False): # {{{ - from pprint import pprint pprint(get_usb_devices()) print() print('Is device connected:', is_usb_device_connected(vendor_id, product_id))