Use full storage number for robustness in matching

This commit is contained in:
Kovid Goyal 2016-01-18 09:56:51 +05:30
parent 3af8ad01e6
commit d8d743ce8b

View File

@ -6,11 +6,14 @@ from __future__ import (unicode_literals, division, absolute_import,
print_function)
import os, string, _winreg as winreg, re, time, sys
from collections import namedtuple, defaultdict
from operator import itemgetter
from ctypes import (
Structure, POINTER, c_ubyte, windll, byref, c_void_p, WINFUNCTYPE,
WinError, get_last_error, sizeof, c_wchar, create_string_buffer, cast,
wstring_at, addressof, create_unicode_buffer, string_at, c_uint64 as QWORD)
from ctypes.wintypes import DWORD, WORD, ULONG, LPCWSTR, HWND, BOOL, LPWSTR, UINT, BYTE, HANDLE
from pprint import pprint, pformat
from calibre import prints, as_unicode
@ -59,6 +62,8 @@ IOCTL_STORAGE_MEDIA_REMOVAL = 0x2D4804
IOCTL_STORAGE_EJECT_MEDIA = 0x2D4808
IOCTL_STORAGE_GET_DEVICE_NUMBER = 0x2D1080
StorageDeviceNumber = namedtuple('StorageDeviceNumber', 'type number partition_number')
class STORAGE_DEVICE_NUMBER(Structure):
_fields_ = [
('DeviceType', DWORD),
@ -66,6 +71,9 @@ class STORAGE_DEVICE_NUMBER(Structure):
('PartitionNumber', ULONG)
]
def as_tuple(self):
return StorageDeviceNumber(self.DeviceType, self.DeviceNumber, self.PartitionNumber)
class SP_DEVINFO_DATA(Structure):
_fields_ = [
('cbSize', DWORD),
@ -373,7 +381,7 @@ def get_storage_number(devpath):
DeviceIoControl(handle, IOCTL_STORAGE_GET_DEVICE_NUMBER, None, 0, byref(sdn), sizeof(STORAGE_DEVICE_NUMBER), None, None)
finally:
CloseHandle(handle)
return sdn.DeviceNumber
return sdn.as_tuple()
def get_all_removable_drives(allow_fixed=False):
mask = GetLogicalDrives()
@ -523,13 +531,13 @@ def get_removable_drives(debug=False): # {{{
return ans
# }}}
def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False): # {{{
def get_drive_letters_for_device(vendor_id, product_id, bcd=None, storage_number_map=None, debug=False): # {{{
'''
Get the drive letters for a connected device with the specieid USB ids. bcd
can be either None, in which case it is not tested, or it must be a list or
set like object containing bcds.
'''
rbuf = wbuf = None
rbuf = None
ans = []
# First search for a device matching the specified USB ids
@ -551,23 +559,30 @@ def get_drive_letters_for_device(vendor_id, product_id, bcd=None, debug=False):
return ans
# Get the device ids for all descendants of the found device
sn_map = get_storage_number_map(debug=debug)
sn_map = get_storage_number_map(debug=debug) if storage_number_map is None else storage_number_map
if debug:
prints('Storage number map:', sn_map)
for devinst in iterdescendants(devinfo.DevInst):
devid, wbuf = get_device_id(devinst, buf=wbuf)
try:
drive_letter = find_drive(devinst, sn_map, debug=debug)
except Exception as err:
prints('Storage number map:')
prints(pformat(sn_map))
descendants = frozenset(iterdescendants(devinfo.DevInst))
for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces():
if devinfo.DevInst in descendants:
if debug:
prints('Failed to get drive letter for: %s with error: %s' % (devid, as_unicode(err)))
import traceback
traceback.print_exc()
else:
if drive_letter:
ans.append(drive_letter)
try:
devid = get_device_id(devinfo.DevInst)[0]
except Exception as err:
devid = 'Unknown'
try:
storage_number = get_storage_number(devpath)
except Exception as err:
if debug:
prints('Failed to get storage number for: %s with error: %s' % (devid, as_unicode(err)))
continue
if debug:
prints('Drive letter for: %s is: %s' % (devid, drive_letter))
prints('Storage number for %s: %s' % (devid, storage_number))
if storage_number:
partitions = sn_map.get(storage_number[:2])
drive_letters = [x[1] for x in partitions or ()]
ans.extend(drive_letters)
return ans
@ -582,28 +597,20 @@ def get_storage_number_map(drive_types=(DRIVE_REMOVABLE, DRIVE_FIXED), debug=Fal
mask = GetLogicalDrives()
type_map = {letter:GetDriveType(letter + ':' + os.sep) for i, letter in enumerate(string.ascii_uppercase) if mask & (1 << i)}
drives = (letter for letter, dt in type_map.iteritems() if dt in drive_types)
ans = {}
ans = defaultdict(list)
for letter in drives:
try:
sn = get_storage_number('\\\\.\\' + letter + ':')
if debug and sn in ans:
prints('Duplicate storage number for drives: %s and %s' % (letter, ans[sn]))
ans[sn] = letter
ans[sn[:2]].append((sn[2], letter))
except WindowsError as err:
if debug:
prints('Failed to get storage number for drive: %s with error: %s' % (letter, as_unicode(err)))
continue
return ans
for val in ans.itervalues():
val.sort(key=itemgetter(0))
return dict(ans)
def find_drive(devinst, storage_number_map, debug=False):
for devinfo, devpath in DeviceSet(GUID_DEVINTERFACE_DISK).interfaces():
if devinfo.DevInst == devinst:
storage_number = get_storage_number(devpath)
drive_letter = storage_number_map.get(storage_number)
if drive_letter:
return drive_letter
# }}}
def get_usb_devices(): # {{{
@ -634,7 +641,7 @@ def is_usb_device_connected(vendor_id, product_id): # {{{
def eject_drive(drive_letter): # {{{
drive_letter = type('')(drive_letter)[0]
volume_access_path = '\\\\.\\' + drive_letter + ':'
devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path))
devinst = devinst_from_device_number(drive_letter, get_storage_number(volume_access_path).number)
if devinst is None:
raise ValueError('Could not find device instance number from drive letter: %s' % drive_letter)
parent = DEVINST(0)
@ -665,12 +672,11 @@ def devinst_from_device_number(drive_letter, device_number):
else:
raise ValueError('Unknown drive_type: %d' % drive_type)
for devinfo, devpath in DeviceSet(guid=guid).interfaces(ignore_errors=True):
if get_storage_number(devpath) == device_number:
if get_storage_number(devpath).number == device_number:
return devinfo.DevInst
# }}}
def develop(vendor_id=0x1949, product_id=0x4, bcd=None, do_eject=False): # {{{
from pprint import pprint
pprint(get_usb_devices())
print()
print('Is device connected:', is_usb_device_connected(vendor_id, product_id))