Skip to content
Snippets Groups Projects
Commit 02b5227f authored by nickolas360's avatar nickolas360
Browse files

Better SSL defaults; allow custom contexts

Fixes #4.
parent d37b19b4
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -460,7 +460,10 @@ class IRCBot(object):
 
:param str hostname: The hostname of the IRC server.
:param int port: The port of the IRC server.
:param bool use_ssl: Whether or not to use SSL/TLS.
:param bool use_ssl: Whether or not to use SSL/TLS. This can also be
an `~ssl.SSLContext` object, in which case it will be used instead
of a default `~ssl.SSLContext`, and the ``ca_certs`` and
``verify_ssl`` parameters will be ignored.
:param str ca_certs: Optional path to a list of trusted CA
certificates. If omitted, the system's default CA certificates will
be loaded instead.
Loading
Loading
@@ -478,8 +481,9 @@ class IRCBot(object):
self.socket = socket.create_connection((hostname, port))
 
if use_ssl:
context = use_ssl if isinstance(use_ssl, ssl.SSLContext) else None
self.socket = wrap_socket(
self.socket, hostname, ca_certs, verify_ssl)
self.socket, hostname, ca_certs, verify_ssl, context)
 
self.alive = True
if self.delay:
Loading
Loading
@@ -962,22 +966,40 @@ def get_required_args(func):
# Wraps a plain socket into an SSL one. Attempts to load default CA
# certificates if none are provided. Verifies the server's certificate and
# hostname if specified.
def wrap_socket(sock, hostname=None, ca_certs=None, verify_ssl=True):
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
# Use load_default_certs() if available (Python >= 3.4); otherwise, use
# set_default_verify_paths() (doesn't work on Windows).
load_default_certs = getattr(
context, "load_default_certs", context.set_default_verify_paths)
if verify_ssl:
context.verify_mode = ssl.CERT_REQUIRED
if ca_certs:
context.load_verify_locations(cafile=ca_certs)
else:
load_default_certs()
def wrap_socket(
sock, hostname=None, ca_certs=None, verify_ssl=True, context=None):
created = False
initialized = True
if context is None:
created = True
if hasattr(ssl, "create_default_context"):
context = ssl.create_default_context(cafile=ca_certs)
else:
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
initialized = False
if created:
mode = ssl.CERT_REQUIRED if verify_ssl else ssl.CERT_NONE
context.verify_mode = mode
if hasattr(context, "check_hostname"):
context.check_hostname = bool(verify_ssl)
if not initialized:
# When no certs are provided, use load_default_certs() if available
# (Python >= 3.4); otherwise, use set_default_verify_paths() (doesn't
# work on Windows).
if ca_certs:
context.load_verify_locations(cafile=ca_certs)
elif hasattr(context, "load_default_certs"):
context.load_default_certs()
else:
context.set_default_verify_paths()
if hasattr(context, "check_hostname"):
return context.wrap_socket(sock, server_hostname=hostname)
sock = context.wrap_socket(sock)
if verify_ssl:
if verify_ssl and not hasattr(context, "check_hostname"):
ssl.match_hostname(sock.getpeercert(), hostname)
return sock
 
Loading
Loading
Loading
Loading
@@ -68,10 +68,14 @@ class BaseBotTest(BaseTest):
# Patched classes will return the same instance every
# time so the instances can be inspected.
self.mock_socket = MockSocket()
self.patch("ssl.SSLContext", new=MockSSLContext.get_mock_class())
self.patch("ssl.SSLContext.__new__",
new=MockSSLContext.get_mock_class())
self.patch("ssl.match_hostname", spec=ssl.match_hostname)
self.patch("socket.create_connection", spec=socket.create_connection,
side_effect=get_mock_create_connection(self.mock_socket))
self.patch("ssl.create_default_context", create=True,
spec=ssl.create_default_context,
return_value=ssl.SSLContext())
 
# Loads data into the bot's socket to be returned later by recv().
def from_server(self, *lines):
Loading
Loading
@@ -424,20 +428,56 @@ class TestConnect(BaseBotTest):
self.bot.socket.connect.assert_called_once_with(("example.com", 6667))
 
def test_connect_ssl(self):
# ssl.create_default_context has already been patched with a mock,
# so it will be restored after the test completes.
del ssl.create_default_context
context = ssl.SSLContext()
context.load_default_certs = mock.Mock()
self.bot.connect("example.com", 6697, use_ssl=True)
self.assertIs(context.check_hostname, True)
self.assertCalledOnce(
context.wrap_socket, self.mock_socket,
server_hostname="example.com")
self.assertCalledOnce(context.load_default_certs)
self.assertIs(ssl.match_hostname.called, False)
def test_connect_ssl_ca_certs(self):
del ssl.create_default_context
context = ssl.SSLContext()
peercert = self.bot.socket.getpeercert()
self.bot.connect("example.com", 6697, use_ssl=True, ca_certs="/test")
load_default_certs = getattr(
context, "load_default_certs", context.set_default_verify_paths)
 
self.assertCalledOnce(context.load_verify_locations, cafile="/test")
self.assertIs(load_default_certs.called, False)
def test_connect_ssl_match_hostname(self):
del ssl.create_default_context
context = ssl.SSLContext()
del context.check_hostname
if hasattr(context, "load_default_certs"):
del context.load_default_certs
self.bot.connect("example.com", 6697, use_ssl=True)
peercert = self.bot.socket.getpeercert()
self.assertIs(hasattr(context, "check_hostname"), False)
self.assertCalledOnce(context.wrap_socket, self.mock_socket)
self.assertCalledOnce(load_default_certs)
self.assertCalledOnce(context.set_default_verify_paths)
self.assertCalledOnce(ssl.match_hostname, peercert, "example.com")
 
def test_connect_ssl_ca_certs(self):
self.bot.connect("example.com", 6697, use_ssl=True, ca_certs="/test")
def test_connect_ssl_create_default_context(self):
context = ssl.SSLContext()
self.assertCalledOnce(context.load_verify_locations, cafile="/test")
del context.check_hostname
self.bot.connect("example.com", 6697, use_ssl=True)
peercert = self.bot.socket.getpeercert()
load_default_certs = getattr(
context, "load_default_certs", context.set_default_verify_paths)
self.assertIs(hasattr(context, "check_hostname"), False)
self.assertCalledOnce(context.wrap_socket, self.mock_socket)
self.assertIs(load_default_certs.called, False)
self.assertCalledOnce(ssl.match_hostname, peercert, "example.com")
 
def test_register(self):
self.bot.connect("example.com", 6667)
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment