mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
py3: Port SSL test
This commit is contained in:
parent
96746583a2
commit
15c4080567
@ -207,8 +207,10 @@ class LoopTest(BaseTest):
|
|||||||
cert_file, key_file, ca_file = map(lambda x:os.path.join(tdir, x), 'cka')
|
cert_file, key_file, ca_file = map(lambda x:os.path.join(tdir, x), 'cka')
|
||||||
create_server_cert(address, ca_file, cert_file, key_file, key_size=1024)
|
create_server_cert(address, ca_file, cert_file, key_file, key_size=1024)
|
||||||
ctx = ssl.create_default_context(cafile=ca_file)
|
ctx = ssl.create_default_context(cafile=ca_file)
|
||||||
with TestServer(lambda data:(data.path[0] + data.read()), ssl_certfile=cert_file, ssl_keyfile=key_file, listen_on=address, port=0) as server:
|
with TestServer(
|
||||||
conn = http_client.HTTPSConnection(address, server.address[1], strict=True, context=ctx)
|
lambda data:(data.path[0] + data.read().decode('utf-8')),
|
||||||
|
ssl_certfile=cert_file, ssl_keyfile=key_file, listen_on=address, port=0) as server:
|
||||||
|
conn = http_client.HTTPSConnection(address, server.address[1], context=ctx)
|
||||||
conn.request('GET', '/test', 'body')
|
conn.request('GET', '/test', 'body')
|
||||||
r = conn.getresponse()
|
r = conn.getresponse()
|
||||||
self.ae(r.status, http_client.OK)
|
self.ae(r.status, http_client.OK)
|
||||||
|
@ -129,7 +129,7 @@ static PyObject* create_rsa_cert_req(PyObject *self, PyObject *args) {
|
|||||||
for (i = 0; i < PySequence_Length(alt_names); i++) {
|
for (i = 0; i < PySequence_Length(alt_names); i++) {
|
||||||
t = PySequence_ITEM(alt_names, i);
|
t = PySequence_ITEM(alt_names, i);
|
||||||
memset(buf, 0, 1024);
|
memset(buf, 0, 1024);
|
||||||
snprintf(buf, 1023, "DNS:%s", PyBytes_AS_STRING(t));
|
snprintf(buf, 1023, "%s", PyBytes_AS_STRING(t));
|
||||||
Py_XDECREF(t);
|
Py_XDECREF(t);
|
||||||
if(!add_ext(exts, NID_subject_alt_name, buf)) goto error;
|
if(!add_ext(exts, NID_subject_alt_name, buf)) goto error;
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,10 @@ from __future__ import (unicode_literals, division, absolute_import,
|
|||||||
__license__ = 'GPL v3'
|
__license__ = 'GPL v3'
|
||||||
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
__copyright__ = '2015, Kovid Goyal <kovid at kovidgoyal.net>'
|
||||||
|
|
||||||
|
import socket
|
||||||
from calibre.constants import plugins
|
from calibre.constants import plugins
|
||||||
|
from polyglot.builtins import unicode_type
|
||||||
|
|
||||||
certgen, err = plugins['certgen']
|
certgen, err = plugins['certgen']
|
||||||
if err:
|
if err:
|
||||||
raise ImportError('Failed to load the certgen module with error: %s' % err)
|
raise ImportError('Failed to load the certgen module with error: %s' % err)
|
||||||
@ -22,12 +25,13 @@ def create_cert_request(
|
|||||||
organizational_unit=None, email_address=None, alt_names=(), basic_constraints=None
|
organizational_unit=None, email_address=None, alt_names=(), basic_constraints=None
|
||||||
):
|
):
|
||||||
def enc(x):
|
def enc(x):
|
||||||
if isinstance(x, type('')):
|
if isinstance(x, unicode_type):
|
||||||
x = x.encode('ascii')
|
x = x.encode('ascii')
|
||||||
return x or None
|
return x or None
|
||||||
|
|
||||||
return certgen.create_rsa_cert_req(
|
return certgen.create_rsa_cert_req(
|
||||||
key_pair, tuple(bytes(enc(x)) for x in alt_names if x),
|
key_pair, tuple(enc(x) for x in alt_names if x), common_name,
|
||||||
*map(enc, (common_name, country, state, locality, organization, organizational_unit, email_address, basic_constraints))
|
country, state, locality, organization, organizational_unit, email_address, basic_constraints
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -52,11 +56,25 @@ def cert_info(cert):
|
|||||||
|
|
||||||
|
|
||||||
def create_server_cert(
|
def create_server_cert(
|
||||||
domain, ca_cert_file=None, server_cert_file=None, server_key_file=None,
|
domain_or_ip, ca_cert_file=None, server_cert_file=None, server_key_file=None,
|
||||||
expire=365, ca_key_file=None, ca_name='Dummy Certificate Authority', key_size=2048,
|
expire=365, ca_key_file=None, ca_name='Dummy Certificate Authority', key_size=2048,
|
||||||
country='IN', state='Maharashtra', locality='Mumbai', organization=None,
|
country='IN', state='Maharashtra', locality='Mumbai', organization=None,
|
||||||
organizational_unit=None, email_address=None, alt_names=(), encrypt_key_with_password=None,
|
organizational_unit=None, email_address=None, alt_names=(), encrypt_key_with_password=None,
|
||||||
):
|
):
|
||||||
|
is_ip = False
|
||||||
|
try:
|
||||||
|
socket.inet_pton(socket.AF_INET, domain_or_ip)
|
||||||
|
is_ip = True
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
socket.inet_aton(socket.AF_INET6, domain_or_ip)
|
||||||
|
is_ip = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not alt_names:
|
||||||
|
prefix = 'IP' if is_ip else 'DNS'
|
||||||
|
alt_names = ('{}:{}'.format(prefix, domain_or_ip),)
|
||||||
|
|
||||||
# Create the Certificate Authority
|
# Create the Certificate Authority
|
||||||
cakey = create_key_pair(key_size)
|
cakey = create_key_pair(key_size)
|
||||||
careq = create_cert_request(cakey, ca_name, basic_constraints='CA:TRUE')
|
careq = create_cert_request(cakey, ca_name, basic_constraints='CA:TRUE')
|
||||||
@ -64,12 +82,14 @@ def create_server_cert(
|
|||||||
|
|
||||||
# Create the server certificate issued by the newly created CA
|
# Create the server certificate issued by the newly created CA
|
||||||
pkey = create_key_pair(key_size)
|
pkey = create_key_pair(key_size)
|
||||||
req = create_cert_request(pkey, domain, country, state, locality, organization, organizational_unit, email_address, alt_names)
|
req = create_cert_request(pkey, domain_or_ip, country, state, locality, organization, organizational_unit, email_address, alt_names)
|
||||||
cert = create_cert(req, cacert, cakey, expire=expire)
|
cert = create_cert(req, cacert, cakey, expire=expire)
|
||||||
|
|
||||||
def export(dest, obj, func, *args):
|
def export(dest, obj, func, *args):
|
||||||
if dest is not None:
|
if dest is not None:
|
||||||
data = func(obj, *args)
|
data = func(obj, *args)
|
||||||
|
if isinstance(data, unicode_type):
|
||||||
|
data = data.encode('utf-8')
|
||||||
if hasattr(dest, 'write'):
|
if hasattr(dest, 'write'):
|
||||||
dest.write(data)
|
dest.write(data)
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user