Skip to content
Snippets Groups Projects
Commit 4c84aedc authored by Gavin M. Roy's avatar Gavin M. Roy
Browse files

Reduce the method complexity in a few areas

parent 8b734cbf
No related branches found
No related tags found
No related merge requests found
Loading
Loading
@@ -10,7 +10,7 @@ except ImportError:
import threading
import time
 
from pamqp import specification
from pamqp import specification as spec
 
from rabbitpy import base
from rabbitpy import heartbeat
Loading
Loading
@@ -45,6 +45,7 @@ else:
SSL_CERT_MAP, SSL_VERSION_MAP = dict(), dict()
 
 
# pylint: disable=too-many-instance-attributes
class Connection(base.StatefulObject):
"""The Connection object is responsible for negotiating a connection and
managing its state. When creating a new instance of the Connection object,
Loading
Loading
@@ -246,6 +247,32 @@ class Connection(base.StatefulObject):
"""
return self._args['username'], self._args['password']
 
@property
def _channel0_closed(self):
"""Returns a boolean indicating if the base connection channel (0)
is closed.
:rtype: bool
"""
return self._channel0.open and not \
self._events.is_set(events.CHANNEL0_CLOSED)
def _close_all_channels(self, force=False):
"""Close all open channels
:param force: Force the connection to shutdown without AMQP negotiation
:type force: bool
"""
for chan_id in [chan_id for chan_id in self._channels
if not self._channels[chan_id].closed]:
if force:
# pylint: disable=protected-access
self._channels[chan_id]._force_close()
else:
self._channels[chan_id].close()
def _close_channels(self):
"""Close all the channels that are currently open."""
for channel_id in self._channels:
Loading
Loading
@@ -355,44 +382,36 @@ class Connection(base.StatefulObject):
raise exceptions.TooManyChannelsError
return self._max_channel_id + 1
 
@staticmethod
def _get_ssl_validation(values):
"""Return the value mapped from the string value in the query string
for the AMQP URL specifying which level of server certificate
validation is required, if any.
@property
def _max_channel_id(self):
"""Return the maximum channel ID that is currently being used.
 
:param dict values: The dict of query values from the AMQP URI
:rtype: int
 
"""
validation = (values.get('verify', [None])[0] or
values.get('ssl_validation', [None])[0])
if validation is None:
return None
if validation not in SSL_CERT_MAP:
raise ValueError('Unsupported server cert validation option: %s',
validation)
return SSL_CERT_MAP[validation]
@staticmethod
def _get_ssl_version(values):
"""Return the value mapped from the string value in the query string
for the AMQP URL for SSL version.
return max(list(self._channels.keys()))
 
:param dict values: The dict of query values from the AMQP URI
:rtype: int
def _maybe_close_connection(self):
"""Perform the steps required to shutdown channel0 and close the
socket.
 
"""
version = values.get('ssl_version', [None])[0]
if version is None:
return None
if version not in SSL_VERSION_MAP:
raise ValueError('Unuspported SSL version: %s' % version)
return SSL_VERSION_MAP[version]
if not self._channel0_closed:
self._channel0.close()
 
@property
def _max_channel_id(self):
return max(list(self._channels.keys()))
# Ensure the connection is closed
self._trigger_write()
# Let the IOLoop know to close
self._events.set(events.SOCKET_CLOSE)
# Break out of select waiting
self._trigger_write()
if (self._events.is_set(events.SOCKET_OPENED) and
not self._events.is_set(events.SOCKET_CLOSED)):
LOGGER.debug('Waiting on socket to close')
self._events.wait(events.SOCKET_CLOSED, 0.1)
 
@staticmethod
def _normalize_expectations(channel_id, expectations):
Loading
Loading
@@ -467,64 +486,135 @@ class Connection(base.StatefulObject):
"""
parsed = utils.urlparse(url)
 
# Ensure the protocol scheme is what is expected
if parsed.scheme not in list(self.PORTS.keys()):
raise ValueError('Unsupported protocol: %s' % parsed.scheme)
self._validate_uri_scheme(parsed.scheme)
 
# Toggle the SSL flag based upon the URL scheme
use_ssl = True if parsed.scheme == 'amqps' else False
# Toggle the SSL flag based upon the URL scheme and if SSL is enabled
use_ssl = True if parsed.scheme == 'amqps' and ssl else False
 
# Ensure that SSL is available if SSL is requested
if use_ssl and not ssl:
if parsed.scheme == 'amqps' and not ssl:
LOGGER.warning('SSL requested but not available, disabling')
use_ssl = False
 
# Use the default ports if one is not specified
port = parsed.port or (self.PORTS[AMQPS] if parsed.scheme == AMQPS
else self.PORTS[AMQP])
# Figure out the port as specified by the scheme
scheme_port = self.PORTS[AMQPS] if parsed.scheme == AMQPS \
else self.PORTS[AMQP]
 
# Set the vhost to be after the base slash if it was specified
vhost = parsed.path[1:] if parsed.path else self.DEFAULT_VHOST
# If the path was just the base path, set the vhost to the default
if not vhost:
vhost = self.DEFAULT_VHOST
vhost = self.DEFAULT_VHOST
if parsed.path:
vhost = parsed.path[1:] or self.DEFAULT_VHOST
 
# Parse the query string
query_values = utils.parse_qs(parsed.query)
channel_max = int(query_values.get('channel_max', [None])[0] or
self.DEFAULT_CHANNEL_MAX)
frame_max = int(query_values.get('frame_max', [None])[0] or
specification.FRAME_MAX_SIZE)
heartbeat_interval = int(query_values.get('heartbeat', [None])[0] or
self.DEFAULT_HEARTBEAT_INTERVAL)
# DEFAULT_TIMEOUT does not have to be 0, so explicitly setting 0
# (False-ish) should not lead to using it, thus no "or" here but
# the precise check against None.
timeout = query_values.get('timeout', [None])[0]
timeout = self.DEFAULT_TIMEOUT if timeout is None else float(timeout)
qargs = utils.parse_qs(parsed.query)
 
# Return the configuration dictionary to use when connecting
return {'host': parsed.hostname,
'port': port,
'virtual_host': utils.unquote(vhost),
'username': parsed.username or self.GUEST,
'password': parsed.password or self.GUEST,
'timeout': timeout,
'heartbeat': heartbeat_interval,
'frame_max': frame_max,
'channel_max': channel_max,
'locale': query_values.get('locale', [None])[0],
'ssl': use_ssl,
'cacertfile': (query_values.get('cacertfile', [None])[0] or
query_values.get('ssl_cacert', [None])[0]),
'certfile': (query_values.get('certfile', [None])[0] or
query_values.get('ssl_cert', [None])[0]),
'keyfile': (query_values.get('keyfile', [None])[0] or
query_values.get('ssl_key', [None])[0]),
'verify': self._get_ssl_validation(query_values),
'ssl_version': self._get_ssl_version(query_values)}
return {
'host': parsed.hostname,
'port': parsed.port or scheme_port,
'virtual_host': utils.unquote(vhost),
'username': parsed.username or self.GUEST,
'password': parsed.password or self.GUEST,
'timeout': self._qargs_int('timeout', qargs, self.DEFAULT_TIMEOUT),
'heartbeat': self._qargs_int('heartbeat', qargs,
self.DEFAULT_HEARTBEAT_INTERVAL),
'frame_max': self._qargs_int('frame_max', qargs,
spec.FRAME_MAX_SIZE),
'channel_max': self._qargs_int('channel_max', qargs,
self.DEFAULT_CHANNEL_MAX),
'locale': self._qargs_value('locale', qargs),
'ssl': use_ssl,
'cacertfile': self._qargs_mk_value(['cacertfile', 'ssl_cacert'],
qargs),
'certfile': self._qargs_mk_value(['certfile', 'ssl_cert'], qargs),
'keyfile': self._qargs_mk_value(['keyfile', 'ssl_key'], qargs),
'verify': self._qargs_ssl_validation(qargs),
'ssl_version': self._qargs_ssl_version(qargs)}
@staticmethod
def _qargs_int(key, values, default):
"""Return the query arg value as an integer for the specified key or
return the specified default value.
:param str key: The key to return the value for
:param dict values: The query value dict returned by urlparse
:param int default: The default return value
:rtype: int
"""
return int(values.get(key, [default])[0])
@staticmethod
def _qargs_float(key, values, default):
"""Return the query arg value as a float for the specified key or
return the specified default value.
:param str key: The key to return the value for
:param dict values: The query value dict returned by urlparse
:param float default: The default return value
:rtype: float
"""
return float(values.get(key, [default])[0])
def _qargs_ssl_validation(self, values):
"""Return the value mapped from the string value in the query string
for the AMQP URL specifying which level of server certificate
validation is required, if any.
:param dict values: The dict of query values from the AMQP URI
:rtype: int
"""
validation = self._qargs_mk_value(['verify', 'ssl_validation'], values)
if not validation:
return
elif validation not in SSL_CERT_MAP:
raise ValueError(
'Unsupported server cert validation option: %s',
validation)
return SSL_CERT_MAP[validation]
def _qargs_ssl_version(self, values):
"""Return the value mapped from the string value in the query string
for the AMQP URL for SSL version.
:param dict values: The dict of query values from the AMQP URI
:rtype: int
"""
version = self._qargs_value('ssl_version', values)
if not version:
return
elif version not in SSL_VERSION_MAP:
raise ValueError('Unuspported SSL version: %s' % version)
return SSL_VERSION_MAP[version]
@staticmethod
def _qargs_value(key, values, default=None):
"""Return the value from the query arguments for the specified key
or the default value.
:param str key: The key to get the value for
:param dict values: The query value dict returned by urlparse
:return: mixed
"""
return values.get(key, [default])[0]
def _qargs_mk_value(self, keys, values):
"""Try and find the query string value where the value can be specified
with different keys.
:param lists keys: The keys to check
:param dict values: The query value dict returned by urlparse
:return: mixed
"""
for key in keys:
value = self._qargs_value(key, values)
if value is not None:
return value
return None
 
def _shutdown_connection(self, force=False):
"""Tell Channel0 and IO to stop if they are not stopped.
Loading
Loading
@@ -541,36 +631,12 @@ class Connection(base.StatefulObject):
LOGGER.debug('Cant shutdown connection, IO is no longer alive')
return
 
# Close any open channels
for chan_id in [chan_id for chan_id in self._channels
if not self._channels[chan_id].closed]:
if force:
# pylint: disable=protected-access
self._channels[chan_id]._force_close()
else:
self._channels[chan_id].close()
self._close_all_channels(force)
self._maybe_close_connection()
 
# If the connection is still established, close it
if (self._channel0.open and
not self._events.is_set(events.CHANNEL0_CLOSED)):
self._channel0.close()
# Ensure the connection is closed
self._trigger_write()
# Let the IOLoop know to close
self._events.set(events.SOCKET_CLOSE)
# Break out of select waiting
self._trigger_write()
# Close the socket
if (self._events.is_set(events.SOCKET_OPENED) and
not self._events.is_set(events.SOCKET_CLOSED)):
LOGGER.debug('Waiting on socket to close')
self._events.wait(events.SOCKET_CLOSED, 0.1)
while self._io.is_alive():
time.sleep(0.25)
# Wait for the IO thread to stop
while self._io.is_alive():
time.sleep(0.25)
 
def _trigger_write(self):
"""Notifies the IO loop we need to write a frame by writing a byte
Loading
Loading
@@ -578,3 +644,13 @@ class Connection(base.StatefulObject):
 
"""
utils.trigger_write(self._io.write_trigger)
def _validate_uri_scheme(self, scheme):
"""Insure that the specified URI scheme is supported by rabbitpy
:param str scheme: The value to validate
:raises: ValueError
"""
if scheme not in list(self.PORTS.keys()):
raise ValueError('Unsupported URI scheme: %s' % scheme)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment