Use full storage number for robustness in matching

This commit is contained in:
Kovid Goyal 2016-01-18 09:56:51 +05:30
parent 3af8ad01e6
commit d8d743ce8b

View File

@ -6,11 +6,14 @@ from __future__ import (unicode_literals, division, absolute_import,
print_function) print_function)
import os, string, _winreg as winreg, re, time, sys import os, string, _winreg as winreg, re, time, sys
from collections import namedtuple, defaultdict
from operator import itemgetter
from ctypes import ( from ctypes import (
Structure, POINTER, c_ubyte, windll, byref, c_void_p, WINFUNCTYPE, Structure, POINTER, c_ubyte, windll, byref, c_void_p, WINFUNCTYPE,
WinError, get_last_error, sizeof, c_wchar, create_string_buffer, cast, WinError, get_last_error, sizeof, c_wchar, create_string_buffer, cast,
wstring_at, addressof, create_unicode_buffer, string_at, c_uint64 as QWORD) wstring_at, addressof, create_unicode_buffer, string_at, c_uint64 as QWORD)
from ctypes.wintypes import DWORD, WORD, ULONG, LPCWSTR, HWND, BOOL, LPWSTR, UINT, BYTE, HANDLE from ctypes.wintypes import DWORD, WORD, ULONG, LPCWSTR, HWND, BOOL, LPWSTR, UINT, BYTE, HANDLE
from pprint import pprint, pformat
from calibre import prints, as_unicode from calibre import prints, as_unicode
@ -59,6 +62,8 @@ IOCTL_STORAGE_MEDIA_REMOVAL = 0x2D4804
IOCTL_STORAGE_EJECT_MEDIA = 0x2D4808 IOCTL_STORAGE_EJECT_MEDIA = 0x2D4808
IOCTL_STORAGE_GET_DEVICE_NUMBER = 0x2D1080 IOCTL_STORAGE_GET_DEVICE_NUMBER = 0x2D1080
StorageDeviceNumber = namedtuple('StorageDeviceNumber', 'type number partition_number')
class STORAGE_DEVICE_NUMBER(Structure): class STORAGE_DEVICE_NUMBER(Structure):
_fields_ = [ _fields_ = [
('DeviceType', DWORD), ('DeviceType', DWORD),
@ -66,6 +71,9 @@ class STORAGE_DEVICE_NUMBER(Structure):
('PartitionNumber', ULONG) ('PartitionNumber', ULONG)
] ]
def as_tuple(self):
return StorageDeviceNumber(self.DeviceType, self.DeviceNumber, self.PartitionNumber)
class SP_DEVINFO_DATA(Structure): class SP_DEVINFO_DATA(Structure):
_fields_ = [ _fields_ = [
('cbSize', DWORD), ('cbSize', DWORD),
@ -373,7 +381,7 @@ def get_storage_number(devpath):
DeviceIoControl(handle, IOCTL_STORAGE_GET_DEVICE_NUMBER, None, 0, byref(sdn), sizeof(STORAGE_DEVICE_NUMBER), None, None) DeviceIoControl(handle, IOCTL_STORAGE_GET_DEVICE_NUMBER, None, 0, byref(sdn), sizeof(STORAGE_DEVICE_NUMBER), None, None)
finally: finally:
CloseHandle(handle) CloseHandle(handle)
return sdn.DeviceNumber return sdn.as_tuple()
def get_all_removable_drives(allow_fixed=False): def get_all_removable_drives(allow_fixed=False):
mask = GetLogicalDrives() mask = GetLogicalDrives()
@ -523,13 +531,13 @@ def get_removable_drives(debug=False): # {{{
return ans return ans
# }}} # }}}
def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False): # {{{ def get_drive_letters_for_device(vendor_id, product_id, bcd=None, storage_number_map=None, debug=False): # {{{
''' '''
Get the drive letters for a connected device with the specieid USB ids. bcd Get the drive letters for a connected device with the specieid USB ids. bcd
can be either None, in which case it is not tested, or it must be a list or can be either None, in which case it is not tested, or it must be a list or
set like object containing bcds. set like object containing bcds.
''' '''
rbuf = wbuf = None rbuf = None
ans = [] ans = []
# First search for a device matching the specified USB ids # First search for a device matching the specified USB ids
@ -551,23 +559,30 @@ def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False):
return ans return ans
# Get the device ids for all descendants of the found device # Get the device ids for all descendants of the found device
sn_map = get_storage_number_map(debug=debug) sn_map = get_storage_number_map(debug=debug) if storage_number_map is None else storage_number_map
if debug: if debug:
prints('Storage number map:', sn_map) prints('Storage number map:')
for devinst in iterdescendants(devinfo.DevInst): prints(pformat(sn_map))
devid, wbuf = get_device_id(devinst, buf=wbuf) descendants = frozenset(iterdescendants(devinfo.DevInst))
try: for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces():
drive_letter = find_drive(devinst, sn_map, debug=debug) if devinfo.DevInst in descendants:
except Exception as err:
if debug: if debug:
prints('Failed to get drive letter for: %s with error: %s' % (devid, as_unicode(err))) try:
import traceback devid = get_device_id(devinfo.DevInst)[0]
traceback.print_exc() except Exception as err:
else: devid = 'Unknown'
if drive_letter: try:
ans.append(drive_letter) storage_number = get_storage_number(devpath)
except Exception as err:
if debug:
prints('Failed to get storage number for: %s with error: %s' % (devid, as_unicode(err)))
continue
if debug: if debug:
prints('Drive letter for: %s is: %s' % (devid, drive_letter)) prints('Storage number for %s: %s' % (devid, storage_number))
if storage_number:
partitions = sn_map.get(storage_number[:2])
drive_letters = [x[1] for x in partitions or ()]
ans.extend(drive_letters)
return ans return ans
@ -582,28 +597,20 @@ def get_storage_number_map(drive_types=(DRIVE_REMOVABLE, DRIVE_FIXED), debug=Fal
mask = GetLogicalDrives() mask = GetLogicalDrives()
type_map = {letter:GetDriveType(letter + ':' + os.sep) for i, letter in enumerate(string.ascii_uppercase) if mask & (1 << i)} type_map = {letter:GetDriveType(letter + ':' + os.sep) for i, letter in enumerate(string.ascii_uppercase) if mask & (1 << i)}
drives = (letter for letter, dt in type_map.iteritems() if dt in drive_types) drives = (letter for letter, dt in type_map.iteritems() if dt in drive_types)
ans = {} ans = defaultdict(list)
for letter in drives: for letter in drives:
try: try:
sn = get_storage_number('\\\\.\\' + letter + ':') sn = get_storage_number('\\\\.\\' + letter + ':')
if debug and sn in ans: ans[sn[:2]].append((sn[2], letter))
prints('Duplicate storage number for drives: %s and %s' % (letter, ans[sn]))
ans[sn] = letter
except WindowsError as err: except WindowsError as err:
if debug: if debug:
prints('Failed to get storage number for drive: %s with error: %s' % (letter, as_unicode(err))) prints('Failed to get storage number for drive: %s with error: %s' % (letter, as_unicode(err)))
continue continue
return ans for val in ans.itervalues():
val.sort(key=itemgetter(0))
return dict(ans)
def find_drive(devinst, storage_number_map, debug=False):
for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces():
if devinfo.DevInst == devinst:
storage_number = get_storage_number(devpath)
drive_letter = storage_number_map.get(storage_number)
if drive_letter:
return drive_letter
# }}} # }}}
def get_usb_devices(): # {{{ def get_usb_devices(): # {{{
@ -634,7 +641,7 @@ def is_usb_device_connected(vendor_id, product_id): # {{{
def eject_drive(drive_letter): # {{{ def eject_drive(drive_letter): # {{{
drive_letter = type('')(drive_letter)[0] drive_letter = type('')(drive_letter)[0]
volume_access_path = '\\\\.\\' + drive_letter + ':' volume_access_path = '\\\\.\\' + drive_letter + ':'
devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path)) devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path).number)
if devinst is None: if devinst is None:
raise ValueError('Could not find device instance number from drive letter: %s' % drive_letter) raise ValueError('Could not find device instance number from drive letter: %s' % drive_letter)
parent = DEVINST(0) parent = DEVINST(0)
@ -665,12 +672,11 @@ def devinst_from_device_number(drive_letter, device_number):
else: else:
raise ValueError('Unknown drive_type: %d' % drive_type) raise ValueError('Unknown drive_type: %d' % drive_type)
for devinfo, devpath in DeviceSet(guid=guid).interfaces(ignore_errors=True): for devinfo, devpath in DeviceSet(guid=guid).interfaces(ignore_errors=True):
if get_storage_number(devpath) == device_number: if get_storage_number(devpath).number == device_number:
return devinfo.DevInst return devinfo.DevInst
# }}} # }}}
def develop(vendor_id=0x1949, product_id=0x4, bcd=None, do_eject=False): # {{{ def develop(vendor_id=0x1949, product_id=0x4, bcd=None, do_eject=False): # {{{
from pprint import pprint
pprint(get_usb_devices()) pprint(get_usb_devices())
print() print()
print('Is device connected:', is_usb_device_connected(vendor_id, product_id)) print('Is device connected:', is_usb_device_connected(vendor_id, product_id))