py3: Port SSL test

This commit is contained in:
Kovid Goyal 2019-04-15 11:04:14 +05:30
parent 96746583a2
commit 15c4080567
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 30 additions and 8 deletions

View File

@ -207,8 +207,10 @@ class LoopTest(BaseTest):
cert_file, key_file, ca_file = map(lambda x:os.path.join(tdir, x), 'cka') cert_file, key_file, ca_file = map(lambda x:os.path.join(tdir, x), 'cka')
create_server_cert(address, ca_file, cert_file, key_file, key_size=1024) create_server_cert(address, ca_file, cert_file, key_file, key_size=1024)
ctx = ssl.create_default_context(cafile=ca_file) ctx = ssl.create_default_context(cafile=ca_file)
with TestServer(lambda data:(data.path[0] + data.read()), ssl_certfile=cert_file, ssl_keyfile=key_file, listen_on=address, port=0) as server: with TestServer(
conn = http_client.HTTPSConnection(address, server.address[1], strict=True, context=ctx) lambda data:(data.path[0] + data.read().decode('utf-8')),
ssl_certfile=cert_file, ssl_keyfile=key_file, listen_on=address, port=0) as server:
conn = http_client.HTTPSConnection(address, server.address[1], context=ctx)
conn.request('GET', '/test', 'body') conn.request('GET', '/test', 'body')
r = conn.getresponse() r = conn.getresponse()
self.ae(r.status, http_client.OK) self.ae(r.status, http_client.OK)

View File

@ -129,7 +129,7 @@ static PyObject* create_rsa_cert_req(PyObject *self, PyObject *args) {
for (i = 0; i < PySequence_Length(alt_names); i++) { for (i = 0; i < PySequence_Length(alt_names); i++) {
t = PySequence_ITEM(alt_names, i); t = PySequence_ITEM(alt_names, i);
memset(buf, 0, 1024); memset(buf, 0, 1024);
snprintf(buf, 1023, "DNS:%s", PyBytes_AS_STRING(t)); snprintf(buf, 1023, "%s", PyBytes_AS_STRING(t));
Py_XDECREF(t); Py_XDECREF(t);
if(!add_ext(exts, NID_subject_alt_name, buf)) goto error; if(!add_ext(exts, NID_subject_alt_name, buf)) goto error;
} }

View File

@ -6,7 +6,10 @@ from __future__ import (unicode_literals, division, absolute_import,
__license__ = 'GPL v3' __license__ = 'GPL v3'
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
import socket
from calibre.constants import plugins from calibre.constants import plugins
from polyglot.builtins import unicode_type
certgen, err = plugins['certgen'] certgen, err = plugins['certgen']
if err: if err:
raise ImportError('Failed to load the certgen module with error: %s' % err) raise ImportError('Failed to load the certgen module with error: %s' % err)
@ -22,12 +25,13 @@ def create_cert_request(
organizational_unit=None, email_address=None, alt_names=(), basic_constraints=None organizational_unit=None, email_address=None, alt_names=(), basic_constraints=None
): ):
def enc(x): def enc(x):
if isinstance(x, type('')): if isinstance(x, unicode_type):
x = x.encode('ascii') x = x.encode('ascii')
return x or None return x or None
return certgen.create_rsa_cert_req( return certgen.create_rsa_cert_req(
key_pair, tuple(bytes(enc(x)) for x in alt_names if x), key_pair, tuple(enc(x) for x in alt_names if x), common_name,
*map(enc, (common_name, country, state, locality, organization, organizational_unit, email_address, basic_constraints)) country, state, locality, organization, organizational_unit, email_address, basic_constraints
) )
@ -52,11 +56,25 @@ def cert_info(cert):
def create_server_cert( def create_server_cert(
domain, ca_cert_file=None, server_cert_file=None, server_key_file=None, domain_or_ip, ca_cert_file=None, server_cert_file=None, server_key_file=None,
expire=365, ca_key_file=None, ca_name='Dummy Certificate Authority', key_size=2048, expire=365, ca_key_file=None, ca_name='Dummy Certificate Authority', key_size=2048,
country='IN', state='Maharashtra', locality='Mumbai', organization=None, country='IN', state='Maharashtra', locality='Mumbai', organization=None,
organizational_unit=None, email_address=None, alt_names=(), encrypt_key_with_password=None, organizational_unit=None, email_address=None, alt_names=(), encrypt_key_with_password=None,
): ):
is_ip = False
try:
socket.inet_pton(socket.AF_INET, domain_or_ip)
is_ip = True
except Exception:
try:
socket.inet_aton(socket.AF_INET6, domain_or_ip)
is_ip = True
except Exception:
pass
if not alt_names:
prefix = 'IP' if is_ip else 'DNS'
alt_names = ('{}:{}'.format(prefix, domain_or_ip),)
# Create the Certificate Authority # Create the Certificate Authority
cakey = create_key_pair(key_size) cakey = create_key_pair(key_size)
careq = create_cert_request(cakey, ca_name, basic_constraints='CA:TRUE') careq = create_cert_request(cakey, ca_name, basic_constraints='CA:TRUE')
@ -64,12 +82,14 @@ def create_server_cert(
# Create the server certificate issued by the newly created CA # Create the server certificate issued by the newly created CA
pkey = create_key_pair(key_size) pkey = create_key_pair(key_size)
req = create_cert_request(pkey, domain, country, state, locality, organization, organizational_unit, email_address, alt_names) req = create_cert_request(pkey, domain_or_ip, country, state, locality, organization, organizational_unit, email_address, alt_names)
cert = create_cert(req, cacert, cakey, expire=expire) cert = create_cert(req, cacert, cakey, expire=expire)
def export(dest, obj, func, *args): def export(dest, obj, func, *args):
if dest is not None: if dest is not None:
data = func(obj, *args) data = func(obj, *args)
if isinstance(data, unicode_type):
data = data.encode('utf-8')
if hasattr(dest, 'write'): if hasattr(dest, 'write'):
dest.write(data) dest.write(data)
else: else: