diff --git a/tests/test_decode.py b/tests/test_decode.py
index 4750ca051..5f294931a 100644
--- a/tests/test_decode.py
+++ b/tests/test_decode.py
@@ -8,7 +8,7 @@ import uuid
 from uamqp import _decode as decode
-from uamqp.types import AMQPTypes, TYPE, VALUE
+from uamqp.amqp_types import AMQPTypes, TYPE, VALUE
 import pytest
diff --git a/tests/test_encode.py b/tests/test_encode.py
index 3ae5ba4d7..2fbc77ec8 100644
--- a/tests/test_encode.py
+++ b/tests/test_encode.py
@@ -8,7 +8,7 @@ import uuid
 import uamqp._encode as encode
-from uamqp.types import AMQPTypes, TYPE, VALUE
+from uamqp.amqp_types import AMQPTypes, TYPE, VALUE
 from uamqp.message import Message, Header, Properties
 import pytest
diff --git a/uamqp/__init__.py b/uamqp/__init__.py
index d4160e1a9..08e99ceee 100644
--- a/uamqp/__init__.py
+++ b/uamqp/__init__.py
@@ -1,13 +1,13 @@
-#-------------------------------------------------------------------------
+# -------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
-#-------------------------------------------------------------------------
+# -------------------------------------------------------------------------
 __version__ = "2.0.0a1"
-from ._connection import Connection
-from ._transport import SSLTransport
+from uamqp._connection import Connection
+from uamqp._transport import SSLTransport
-from .client import AMQPClient, ReceiveClient, SendClient
+from uamqp.client import AMQPClient, ReceiveClient, SendClient
diff --git a/uamqp/_connection.py b/uamqp/_connection.py
index a26d220f3..c6089c100 100644
--- a/uamqp/_connection.py
+++ b/uamqp/_connection.py
@@ -1,35 +1,33 @@
-#-------------------------------------------------------------------------
+# -------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint:disable=protected-access -import uuid import logging -import time -from urllib.parse import urlparse import socket +import time +import uuid from ssl import SSLError +from typing import Optional, List, Dict, Any, Tuple, NamedTuple, Union +from urllib.parse import urlparse -from ._transport import Transport -from .sasl import SASLTransport -from .session import Session -from .performatives import OpenFrame, CloseFrame -from .constants import ( +from uamqp._transport import Transport +from uamqp.constants import ( PORT, SECURE_PORT, MAX_CHANNELS, MAX_FRAME_SIZE_BYTES, HEADER_FRAME, ConnectionState, - EMPTY_FRAME -) - -from .error import ( - ErrorCondition, - AMQPConnectionError, - AMQPError + EMPTY_FRAME, ) +from uamqp.error import ErrorCondition, AMQPConnectionError, AMQPError +from uamqp.performatives import OpenFrame, CloseFrame +from uamqp.sasl import SASLTransport +from uamqp.session import Session _LOGGER = logging.getLogger(__name__) _CLOSING_STATES = ( @@ -37,12 +35,12 @@ ConnectionState.CLOSE_PIPE, ConnectionState.DISCARDING, ConnectionState.CLOSE_SENT, - ConnectionState.END + ConnectionState.END, ) def get_local_timeout(now, idle_timeout, last_frame_received_time): - # type: (float, float, float) -> bool + # type: (float, Optional[float], Optional[float]) -> bool """Check whether the local timeout has been reached since a new incoming frame was received. :param float now: The current time to check against. @@ -85,47 +83,53 @@ def __init__(self, endpoint, **kwargs): self._hostname = parsed_url.hostname if parsed_url.port: self._port = parsed_url.port - elif parsed_url.scheme == 'amqps': + elif parsed_url.scheme == "amqps": self._port = SECURE_PORT else: self._port = PORT self.state = None # type: Optional[ConnectionState] - transport = kwargs.get('transport') + transport = kwargs.get("transport") if transport: self._transport = transport - elif 'sasl_credential' in kwargs: + elif "sasl_credential" in kwargs: self._transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs + host=parsed_url.netloc, credential=kwargs["sasl_credential"], **kwargs ) else: self._transport = Transport(parsed_url.netloc, **kwargs) - self._container_id = kwargs.pop('container_id', None) or str(uuid.uuid4()) # type: str - self._max_frame_size = kwargs.pop('max_frame_size', MAX_FRAME_SIZE_BYTES) # type: int - self._remote_max_frame_size = None # type: Optional[int] - self._channel_max = kwargs.pop('channel_max', MAX_CHANNELS) # type: int - self._idle_timeout = kwargs.pop('idle_timeout', None) # type: Optional[int] - self._outgoing_locales = kwargs.pop('outgoing_locales', None) # type: Optional[List[str]] - self._incoming_locales = kwargs.pop('incoming_locales', None) # type: Optional[List[str]] - self._offered_capabilities = None # type: Optional[str] - self._desired_capabilities = kwargs.pop('desired_capabilities', None) # type: Optional[str] - self._properties = kwargs.pop('properties', None) # type: Optional[Dict[str, str]] - - self._allow_pipelined_open = kwargs.pop('allow_pipelined_open', True) # type: bool - self._remote_idle_timeout = None # type: Optional[int] - self._remote_idle_timeout_send_frame = None # type: Optional[int] - self._idle_timeout_empty_frame_send_ratio = kwargs.get('idle_timeout_empty_frame_send_ratio', 0.5) # type: float - self._last_frame_received_time = None # type: 
Optional[float] - self._last_frame_sent_time = None # type: Optional[float] - self._idle_wait_time = kwargs.get('idle_wait_time', 0.1) # type: float - self._network_trace = kwargs.get('network_trace', False) + self._container_id: str = kwargs.pop("container_id", None) or str(uuid.uuid4()) + self._max_frame_size: int = kwargs.pop("max_frame_size", MAX_FRAME_SIZE_BYTES) + self._remote_max_frame_size: Optional[int] = None + self._channel_max: int = kwargs.pop("channel_max", MAX_CHANNELS) + self._idle_timeout: Optional[int] = kwargs.pop("idle_timeout", None) + self._outgoing_locales: Optional[List[str]] = kwargs.pop( + "outgoing_locales", None + ) + self._incoming_locales: Optional[List[str]] = kwargs.pop( + "incoming_locales", None + ) + self._offered_capabilities: Optional[str] = None + self._desired_capabilities: Optional[str] = kwargs.pop( + "desired_capabilities", None + ) + self._properties: Optional[Dict[str, str]] = kwargs.pop("properties", None) + + self._allow_pipelined_open: bool = kwargs.pop("allow_pipelined_open", True) + self._remote_idle_timeout: Optional[int] = None + self._remote_idle_timeout_send_frame: Optional[int] = None + self._idle_timeout_empty_frame_send_ratio: float = kwargs.get( + "idle_timeout_empty_frame_send_ratio", 0.5 + ) + self._last_frame_received_time: Optional[float] = None + self._last_frame_sent_time: Optional[float] = None + self._idle_wait_time: float = kwargs.get("idle_wait_time", 0.1) + self._network_trace: bool = kwargs.get("network_trace", False) self._network_trace_params = { - 'connection': self._container_id, - 'session': None, - 'link': None + "connection": self._container_id, + "session": None, + "link": None, } self._error = None self._outgoing_endpoints = {} # type: Dict[int, Session] @@ -145,9 +149,14 @@ def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state) + _LOGGER.info( + "Connection '%s' state changed: %r -> %r", + self._container_id, + previous_state, + new_state, + ) for session in self._outgoing_endpoints.values(): - session._on_connection_state_change() # pylint:disable=protected-access + session._on_connection_state_change() def _connect(self): # type: () -> None @@ -170,17 +179,18 @@ def _connect(self): self._process_incoming_frame(*self._read_frame(wait=True)) if self.state != ConnectionState.HDR_EXCH: self._disconnect() - raise ValueError("Did not receive reciprocal protocol header. Disconnecting.") + raise ValueError( + "Did not receive reciprocal protocol header. Disconnecting." + ) else: self._set_state(ConnectionState.HDR_SENT) except (OSError, IOError, SSLError, socket.error) as exc: raise AMQPConnectionError( ErrorCondition.SocketError, - description="Failed to initiate the connection due to exception: " + str(exc), - error=exc + description="Failed to initiate the connection due to exception: " + + str(exc), + error=exc, ) - except Exception: - raise def _disconnect(self): # type: () -> None @@ -196,7 +206,7 @@ def _can_read(self): return self.state not in (ConnectionState.CLOSE_RCVD, ConnectionState.END) def _read_frame(self, wait=True, **kwargs): - # type: (bool, Any) -> Tuple[int, Optional[Tuple[int, NamedTuple]]] + # type: (Union[bool, float], Any) -> Tuple[Optional[int], Optional[Tuple[int, NamedTuple]]] """Read an incoming frame from the transport. :param Union[bool, float] wait: Whether to block on the socket while waiting for an incoming frame. 
@@ -208,15 +218,15 @@ def _read_frame(self, wait=True, **kwargs): descriptor and field values. """ if self._can_read(): - if wait == False: + if wait is False: return self._transport.receive_frame(**kwargs) - elif wait == True: + if wait is True: with self._transport.block(): return self._transport.receive_frame(**kwargs) - else: - with self._transport.block_with_timeout(timeout=wait): - return self._transport.receive_frame(**kwargs) + with self._transport.block_with_timeout(timeout=wait): + return self._transport.receive_frame(**kwargs) _LOGGER.warning("Cannot read frame in current state: %r", self.state) + return None, None def _can_write(self): # type: () -> bool @@ -249,10 +259,8 @@ def _send_frame(self, channel, frame, timeout=None, **kwargs): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send frame out due to exception: " + str(exc), - error=exc + error=exc, ) - except Exception: - raise else: _LOGGER.warning("Cannot write frame in current state: %r", self.state) @@ -264,9 +272,17 @@ def _get_next_outgoing_channel(self): :returns: The next available outgoing channel number. :rtype: int """ - if (len(self._incoming_endpoints) + len(self._outgoing_endpoints)) >= self._channel_max: - raise ValueError("Maximum number of channels ({}) has been reached.".format(self._channel_max)) - next_channel = next(i for i in range(1, self._channel_max) if i not in self._outgoing_endpoints) + if ( + len(self._incoming_endpoints) + len(self._outgoing_endpoints) + ) >= self._channel_max: + raise ValueError( + "Maximum number of channels ({}) has been reached.".format( + self._channel_max + ) + ) + next_channel = next( + i for i in range(1, self._channel_max) if i not in self._outgoing_endpoints + ) return next_channel def _outgoing_empty(self): @@ -286,21 +302,21 @@ def _outgoing_empty(self): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send empty frame due to exception: " + str(exc), - error=exc + error=exc, ) - except Exception: - raise def _outgoing_header(self): # type: () -> None """Send the AMQP protocol header to initiate the connection.""" self._last_frame_sent_time = time.time() if self._network_trace: - _LOGGER.info("-> header(%r)", HEADER_FRAME, extra=self._network_trace_params) + _LOGGER.info( + "-> header(%r)", HEADER_FRAME, extra=self._network_trace_params + ) self._transport.write(HEADER_FRAME) def _incoming_header(self, _, frame): - # type: (int, bytes) -> None + # type: (int, Tuple[Any, ...]) -> None """Process an incoming AMQP protocol header and update the connection state.""" if self._network_trace: _LOGGER.info("<- header(%r)", frame, extra=self._network_trace_params) @@ -319,11 +335,17 @@ def _outgoing_open(self): hostname=self._hostname, max_frame_size=self._max_frame_size, channel_max=self._channel_max, - idle_timeout=self._idle_timeout * 1000 if self._idle_timeout else None, # Convert to milliseconds + idle_timeout=self._idle_timeout * 1000 + if self._idle_timeout + else None, # Convert to milliseconds outgoing_locales=self._outgoing_locales, incoming_locales=self._incoming_locales, - offered_capabilities=self._offered_capabilities if self.state == ConnectionState.OPEN_RCVD else None, - desired_capabilities=self._desired_capabilities if self.state == ConnectionState.HDR_EXCH else None, + offered_capabilities=self._offered_capabilities + if self.state == ConnectionState.OPEN_RCVD + else None, + desired_capabilities=self._desired_capabilities + if self.state == ConnectionState.HDR_EXCH + else None, 
properties=self._properties, ) if self._network_trace: @@ -360,7 +382,8 @@ def _incoming_open(self, channel, frame): self.close( error=AMQPError( condition=ErrorCondition.NotAllowed, - description="OPEN frame received on a channel that is not 0." + description="OPEN frame received on a channel that is not 0.", + info=None, ) ) self._set_state(ConnectionState.END) @@ -368,8 +391,10 @@ def _incoming_open(self, channel, frame): _LOGGER.error("OPEN frame received in the OPENED state.") self.close() if frame[4]: - self._remote_idle_timeout = frame[4]/1000 # Convert to seconds - self._remote_idle_timeout_send_frame = self._idle_timeout_empty_frame_send_ratio * self._remote_idle_timeout + self._remote_idle_timeout = frame[4] / 1000 # Convert to seconds + self._remote_idle_timeout_send_frame = ( + self._idle_timeout_empty_frame_send_ratio * self._remote_idle_timeout # type: ignore + ) if frame[2] < 512: # Ensure minimum max frame size. pass # TODO: error @@ -381,7 +406,7 @@ def _incoming_open(self, channel, frame): self._outgoing_open() self._set_state(ConnectionState.OPENED) else: - pass # TODO what now...? + pass # TODO what now...? def _outgoing_close(self, error=None): # type: (Optional[AMQPError]) -> None @@ -407,7 +432,7 @@ def _incoming_close(self, channel, frame): ConnectionState.HDR_EXCH, ConnectionState.OPEN_RCVD, ConnectionState.CLOSE_SENT, - ConnectionState.DISCARDING + ConnectionState.DISCARDING, ] if self.state in disconnect_states: self._disconnect() @@ -417,7 +442,11 @@ def _incoming_close(self, channel, frame): close_error = None if channel > self._channel_max: _LOGGER.error("Invalid channel") - close_error = AMQPError(condition=ErrorCondition.InvalidField, description="Invalid channel", info=None) + close_error = AMQPError( + condition=ErrorCondition.InvalidField, + description="Invalid channel", + info=None, + ) self._set_state(ConnectionState.CLOSE_RCVD) self._outgoing_close(error=close_error) @@ -426,9 +455,7 @@ def _incoming_close(self, channel, frame): if frame[0]: self._error = AMQPConnectionError( - condition=frame[0][0], - description=frame[0][1], - info=frame[0][2] + condition=frame[0][0], description=frame[0][1], info=frame[0][2] ) _LOGGER.error("Connection error: {}".format(frame[0])) @@ -455,11 +482,13 @@ def _incoming_begin(self, channel, frame): try: existing_session = self._outgoing_endpoints[frame[0]] self._incoming_endpoints[channel] = existing_session - self._incoming_endpoints[channel]._incoming_begin(frame) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_begin( + frame + ) except KeyError: new_session = Session.from_incoming_frame(self, channel, frame) self._incoming_endpoints[channel] = new_session - new_session._incoming_begin(frame) # pylint:disable=protected-access + new_session._incoming_begin(frame) def _incoming_end(self, channel, frame): # type: (int, Tuple[Any, ...]) -> None @@ -475,14 +504,17 @@ def _incoming_end(self, channel, frame): :rtype: None """ try: - self._incoming_endpoints[channel]._incoming_end(frame) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_end( + frame + ) except KeyError: pass # TODO: channel error - #self._incoming_endpoints.pop(channel) # TODO - #self._outgoing_endpoints.pop(channel) # TODO + # self._incoming_endpoints.pop(channel) # TODO + # self._outgoing_endpoints.pop(channel) # TODO def _process_incoming_frame(self, channel, frame): - # type: (int, Optional[Union[bytes, Tuple[int, Tuple[Any, ...]]]]) -> bool + # type: (Optional[int], Optional[Tuple[int, 
Tuple[Any, ...]]]) -> bool + # pylint: disable=too-many-return-statements """Process an incoming frame, either directly or by passing to the necessary Session. :param int channel: The channel the frame arrived on. @@ -495,48 +527,57 @@ def _process_incoming_frame(self, channel, frame): should be interrupted. """ try: - performative, fields = frame # type: int, Tuple[Any, ...] + performative, fields = frame # type: ignore except TypeError: return True # Empty Frame or socket timeout try: self._last_frame_received_time = time.time() if performative == 20: - self._incoming_endpoints[channel]._incoming_transfer(fields) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_transfer( # type:ignore + fields + ) return False if performative == 21: - self._incoming_endpoints[channel]._incoming_disposition(fields) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_disposition( # type:ignore + fields + ) return False if performative == 19: - self._incoming_endpoints[channel]._incoming_flow(fields) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_flow( # type:ignore + fields + ) return False if performative == 18: - self._incoming_endpoints[channel]._incoming_attach(fields) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_attach( # type:ignore + fields + ) return False if performative == 22: - self._incoming_endpoints[channel]._incoming_detach(fields) # pylint:disable=protected-access + self._incoming_endpoints[channel]._incoming_detach( # type:ignore + fields + ) return True if performative == 17: - self._incoming_begin(channel, fields) + self._incoming_begin(channel, fields) # type:ignore return True if performative == 23: - self._incoming_end(channel, fields) + self._incoming_end(channel, fields) # type:ignore return True if performative == 16: - self._incoming_open(channel, fields) + self._incoming_open(channel, fields) # type:ignore return True if performative == 24: - self._incoming_close(channel, fields) + self._incoming_close(channel, fields) # type:ignore return True if performative == 0: - self._incoming_header(channel, fields) + self._incoming_header(channel, fields) # type:ignore return True if performative == 1: return False # TODO: incoming EMPTY - else: - _LOGGER.error("Unrecognized incoming frame: {}".format(frame)) - return True + _LOGGER.error("Unrecognized incoming frame: {}".format(frame)) + return True except KeyError: - return True #TODO: channel error + return True # TODO: channel error def _process_outgoing_frame(self, channel, frame): # type: (int, NamedTuple) -> None @@ -546,19 +587,29 @@ def _process_outgoing_frame(self, channel, frame): """ if self._network_trace: _LOGGER.info("-> %r", frame, extra=self._network_trace_params) - if not self._allow_pipelined_open and self.state in [ConnectionState.OPEN_PIPE, ConnectionState.OPEN_SENT]: + if not self._allow_pipelined_open and self.state in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ]: raise ValueError("Connection not configured to allow pipeline send.") - if self.state not in [ConnectionState.OPEN_PIPE, ConnectionState.OPEN_SENT, ConnectionState.OPENED]: + if self.state not in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ConnectionState.OPENED, + ]: raise ValueError("Connection not open.") now = time.time() - if get_local_timeout(now, self._idle_timeout, self._last_frame_received_time) or self._get_remote_timeout(now): + if get_local_timeout( + now, self._idle_timeout, 
self._last_frame_received_time + ) or self._get_remote_timeout(now): self.close( # TODO: check error condition error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, - description="No frame received for the idle timeout." + description="No frame received for the idle timeout.", + info=None, ), - wait=False + wait=False, ) return self._send_frame(channel, frame) @@ -576,8 +627,8 @@ def _get_remote_timeout(self, now): :returns: Whether the local connection should be shutdown due to timeout. """ if self._remote_idle_timeout and self._last_frame_sent_time: - time_since_last_sent = now - self._last_frame_sent_time - if time_since_last_sent > self._remote_idle_timeout_send_frame: + time_since_last_sent = now - self._last_frame_sent_time # type: ignore + if time_since_last_sent > self._remote_idle_timeout_send_frame: # type: ignore self._outgoing_empty() return False @@ -626,21 +677,24 @@ def listen(self, wait=False, batch=1, **kwargs): try: if self.state not in _CLOSING_STATES: now = time.time() - if get_local_timeout(now, self._idle_timeout, self._last_frame_received_time) or self._get_remote_timeout(now): + if get_local_timeout( + now, self._idle_timeout, self._last_frame_received_time + ) or self._get_remote_timeout(now): # TODO: check error condition self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, - description="No frame received for the idle timeout." + description="No frame received for the idle timeout.", + info=None, ), - wait=False + wait=False, ) return if self.state == ConnectionState.END: # TODO: check error condition self._error = AMQPConnectionError( condition=ErrorCondition.ConnectionCloseForced, - description="Connection was already closed." + description="Connection was already closed.", ) return for _ in range(batch): @@ -651,10 +705,8 @@ def listen(self, wait=False, batch=1, **kwargs): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send frame out due to exception: " + str(exc), - error=exc + error=exc, ) - except Exception: - raise def create_session(self, **kwargs): # type: (Any) -> Session @@ -678,14 +730,15 @@ def create_session(self, **kwargs): will be logged at the logging.INFO level. Default value is that configured for the connection. """ assigned_channel = self._get_next_outgoing_channel() - kwargs['allow_pipelined_open'] = self._allow_pipelined_open - kwargs['idle_wait_time'] = self._idle_wait_time + kwargs["allow_pipelined_open"] = self._allow_pipelined_open + kwargs["idle_wait_time"] = self._idle_wait_time session = Session( self, assigned_channel, - network_trace=kwargs.pop('network_trace', self._network_trace), + network_trace=kwargs.pop("network_trace", self._network_trace), network_trace_params=dict(self._network_trace_params), - **kwargs) + **kwargs + ) self._outgoing_endpoints[assigned_channel] = session return session @@ -708,7 +761,9 @@ def open(self, wait=False): if wait: self._wait_for_response(wait, ConnectionState.OPENED) elif not self._allow_pipelined_open: - raise ValueError("Connection has been configured to not allow piplined-open. Please set 'wait' parameter.") + raise ValueError( + "Connection has been configured to not allow piplined-open. Please set 'wait' parameter." + ) def close(self, error=None, wait=False): # type: (Optional[AMQPError], bool) -> None @@ -720,15 +775,19 @@ def close(self, error=None, wait=False): :param bool wait: Whether to wait for a service Close response. Default is `False`. 
:rtype: None """ - if self.state in [ConnectionState.END, ConnectionState.CLOSE_SENT, ConnectionState.DISCARDING]: + if self.state in [ + ConnectionState.END, + ConnectionState.CLOSE_SENT, + ConnectionState.DISCARDING, + ]: return try: self._outgoing_close(error=error) if error: self._error = AMQPConnectionError( condition=error.condition, - description=error.descrption, - info=error.info + description=error.description, + info=error.info, ) if self.state == ConnectionState.OPEN_PIPE: self._set_state(ConnectionState.OC_PIPE) @@ -739,7 +798,7 @@ def close(self, error=None, wait=False): else: self._set_state(ConnectionState.CLOSE_SENT) self._wait_for_response(wait, ConnectionState.END) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except # If error happened during closing, ignore the error and set state to END _LOGGER.info("An error occurred when closing the connection: %r", exc) self._set_state(ConnectionState.END) diff --git a/uamqp/_decode.py b/uamqp/_decode.py index 53915069b..a61796c57 100644 --- a/uamqp/_decode.py +++ b/uamqp/_decode.py @@ -1,39 +1,46 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # pylint: disable=redefined-builtin, import-error +import logging import struct import uuid -import logging -from typing import List, Union, Tuple, Dict, Callable # pylint: disable=unused-import - - -from .message import Message, Header, Properties +from typing import ( + List, + Tuple, + Dict, + Callable, + Any, + Union, + Optional, +) # pylint: disable=unused-import + +from uamqp.message import Message, Header, Properties _LOGGER = logging.getLogger(__name__) -_HEADER_PREFIX = memoryview(b'AMQP') +_HEADER_PREFIX = memoryview(b"AMQP") _COMPOSITES = { - 35: 'received', - 36: 'accepted', - 37: 'rejected', - 38: 'released', - 39: 'modified', + 35: "received", + 36: "accepted", + 37: "rejected", + 38: "released", + 39: "modified", } -c_unsigned_char = struct.Struct('>B') -c_signed_char = struct.Struct('>b') -c_unsigned_short = struct.Struct('>H') -c_signed_short = struct.Struct('>h') -c_unsigned_int = struct.Struct('>I') -c_signed_int = struct.Struct('>i') -c_unsigned_long = struct.Struct('>L') -c_unsigned_long_long = struct.Struct('>Q') -c_signed_long_long = struct.Struct('>q') -c_float = struct.Struct('>f') -c_double = struct.Struct('>d') +c_unsigned_char = struct.Struct(">B") +c_signed_char = struct.Struct(">b") +c_unsigned_short = struct.Struct(">H") +c_signed_short = struct.Struct(">h") +c_unsigned_int = struct.Struct(">I") +c_signed_int = struct.Struct(">i") +c_unsigned_long = struct.Struct(">L") +c_unsigned_long_long = struct.Struct(">Q") +c_signed_long_long = struct.Struct(">q") +c_float = struct.Struct(">f") +c_double = struct.Struct(">d") def _decode_null(buffer): @@ -63,7 +70,7 @@ def _decode_empty(buffer): def _decode_boolean(buffer): # type: (memoryview) -> Tuple[memoryview, bool] - return buffer[1:], buffer[:1] == b'\x01' + return buffer[1:], buffer[:1] == b"\x01" def _decode_ubyte(buffer): @@ -164,7 +171,7 @@ def _decode_list_small(buffer): buffer = buffer[2:] values = [None] * count for i in range(count): - buffer, values[i] = 
_DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore return buffer, values @@ -174,30 +181,30 @@ def _decode_list_large(buffer): buffer = buffer[8:] values = [None] * count for i in range(count): - buffer, values[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore return buffer, values def _decode_map_small(buffer): # type: (memoryview) -> Tuple[memoryview, Dict[Any, Any]] - count = int(buffer[1]/2) + count = int(buffer[1] / 2) buffer = buffer[2:] values = {} - for _ in range(count): - buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) - buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + for _ in range(count): + buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore values[key] = value return buffer, values def _decode_map_large(buffer): # type: (memoryview) -> Tuple[memoryview, Dict[Any, Any]] - count = int(c_unsigned_long.unpack(buffer[4:8])[0]/2) + count = int(c_unsigned_long.unpack(buffer[4:8])[0] / 2) buffer = buffer[8:] values = {} - for _ in range(count): - buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) - buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + for _ in range(count): + buffer, key = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore values[key] = value return buffer, values @@ -210,7 +217,7 @@ def _decode_array_small(buffer): buffer = buffer[3:] values = [None] * count for i in range(count): - buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) # type: ignore return buffer, values return buffer[2:], [] @@ -223,7 +230,7 @@ def _decode_array_large(buffer): buffer = buffer[9:] values = [None] * count for i in range(count): - buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) + buffer, values[i] = _DECODE_BY_CONSTRUCTOR[subconstructor](buffer) # type: ignore return buffer, values return buffer[8:], [] @@ -233,10 +240,10 @@ def _decode_described(buffer): # TODO: to move the cursor of the buffer to the described value based on size of the # descriptor without decoding descriptor value composite_type = buffer[0] - buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[1:]) - buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, descriptor = _DECODE_BY_CONSTRUCTOR[composite_type](buffer[1:]) # type: ignore + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore try: - composite_type = _COMPOSITES[descriptor] + composite_type = _COMPOSITES[descriptor] # type: ignore return buffer, {composite_type: value} except KeyError: return buffer, value @@ -244,12 +251,12 @@ def _decode_described(buffer): def decode_payload(buffer): # type: (memoryview) -> Message - message = {} + message: Dict[str, Any] = {} while buffer: # Ignore the first two bytes, they will always be the constructors for # described type then ulong. descriptor = buffer[2] - buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[3]](buffer[4:]) + buffer, value = _DECODE_BY_CONSTRUCTOR[buffer[3]](buffer[4:]) # type: ignore if descriptor == 112: message["header"] = Header(*value) elif descriptor == 113: @@ -285,7 +292,7 @@ def decode_frame(data): # described type then ulong. 
frame_type = data[2] compound_list_type = data[3] - if compound_list_type == 0xd0: + if compound_list_type == 0xD0: # list32 0xd0: data[4:8] is size, data[8:12] is count count = c_signed_int.unpack(data[8:12])[0] buffer = data[12:] @@ -293,16 +300,16 @@ def decode_frame(data): # list8 0xc0: data[4] is size, data[5] is count count = data[5] buffer = data[6:] - fields = [None] * count + fields: List[Union[None, memoryview]] = [None] * count for i in range(count): - buffer, fields[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) + buffer, fields[i] = _DECODE_BY_CONSTRUCTOR[buffer[0]](buffer[1:]) # type: ignore if frame_type == 20: fields.append(buffer) return frame_type, fields def decode_empty_frame(header): - # type: (memory) -> bytes + # type: (memoryview) -> Tuple[int,bytes] if header[0:4] == _HEADER_PREFIX: return 0, header.tobytes() if header[5] == 0: @@ -310,7 +317,9 @@ def decode_empty_frame(header): raise ValueError("Received unrecognized empty frame") -_DECODE_BY_CONSTRUCTOR = [None] * 256 # type: List[Callable[memoryview]] +_DECODE_BY_CONSTRUCTOR = [ + None +] * 256 # type: List[Optional[Callable[[memoryview], Tuple[memoryview, Any]]]] _DECODE_BY_CONSTRUCTOR[0] = _decode_described _DECODE_BY_CONSTRUCTOR[64] = _decode_null _DECODE_BY_CONSTRUCTOR[65] = _decode_true diff --git a/uamqp/_encode.py b/uamqp/_encode.py index 1eae46895..a7c576f1c 100644 --- a/uamqp/_encode.py +++ b/uamqp/_encode.py @@ -1,32 +1,35 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access import calendar import struct import uuid from datetime import datetime -from typing import Iterable, Union, Tuple, Dict # pylint: disable=unused-import - -import six - -from .types import TYPE, VALUE, AMQPTypes, FieldDefinition, ObjDefinition, ConstructorBytes -from .message import Header, Properties, Message -from . import performatives -from . import outcomes -from . import endpoints -from . 
import error - +from typing import Union, Tuple, Dict, Optional, Any, Sequence + +from uamqp.amqp_types import ( + TYPE, + VALUE, + AMQPTypes, + FieldDefinition, + ObjDefinition, + ConstructorBytes, +) +from uamqp.message import Message +from uamqp.performatives import Performative, TransferFrame _FRAME_OFFSET = b"\x02" -_FRAME_TYPE = b'\x00' +_FRAME_TYPE = b"\x00" def _construct(byte, construct): # type: (bytes, bool) -> bytes - return byte if construct else b'' + return byte if construct else b"" def encode_null(output, *args, **kwargs): # pylint: disable=unused-argument @@ -37,7 +40,9 @@ def encode_null(output, *args, **kwargs): # pylint: disable=unused-argument output.extend(ConstructorBytes.null) -def encode_boolean(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_boolean( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, bool, bool, Any) -> None """ @@ -48,13 +53,15 @@ def encode_boolean(output, value, with_constructor=True, **kwargs): # pylint: d value = bool(value) if with_constructor: output.extend(_construct(ConstructorBytes.bool, with_constructor)) - output.extend(b'\x01' if value else b'\x00') + output.extend(b"\x01" if value else b"\x00") return output.extend(ConstructorBytes.bool_true if value else ConstructorBytes.bool_false) -def encode_ubyte(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_ubyte( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, Union[int, bytes], bool, Any) -> None """ @@ -62,15 +69,17 @@ def encode_ubyte(output, value, with_constructor=True, **kwargs): # pylint: dis try: value = int(value) except ValueError: - value = ord(value) + value = ord(value) # type: ignore try: output.extend(_construct(ConstructorBytes.ubyte, with_constructor)) - output.extend(struct.pack('>B', abs(value))) + output.extend(struct.pack(">B", abs(value))) except struct.error: raise ValueError("Unsigned byte value must be 0-255") -def encode_ushort(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_ushort( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, int, bool, Any) -> None """ @@ -78,7 +87,7 @@ def encode_ushort(output, value, with_constructor=True, **kwargs): # pylint: di value = int(value) try: output.extend(_construct(ConstructorBytes.ushort, with_constructor)) - output.extend(struct.pack('>H', abs(value))) + output.extend(struct.pack(">H", abs(value))) except struct.error: raise ValueError("Unsigned byte value must be 0-65535") @@ -98,10 +107,10 @@ def encode_uint(output, value, with_constructor=True, use_smallest=True): try: if use_smallest and value <= 255: output.extend(_construct(ConstructorBytes.uint_small, with_constructor)) - output.extend(struct.pack('>B', abs(value))) + output.extend(struct.pack(">B", abs(value))) return output.extend(_construct(ConstructorBytes.uint_large, with_constructor)) - output.extend(struct.pack('>I', abs(value))) + output.extend(struct.pack(">I", abs(value))) except struct.error: raise ValueError("Value supplied for unsigned int invalid: {}".format(value)) @@ -114,25 +123,24 @@ def encode_ulong(output, value, with_constructor=True, use_smallest=True): label="unsigned long value in the range 0 to 255 inclusive"/> """ - try: - value = long(value) - except NameError: - value = int(value) + value = int(value) if value == 0: 
output.extend(ConstructorBytes.ulong_0) return try: if use_smallest and value <= 255: output.extend(_construct(ConstructorBytes.ulong_small, with_constructor)) - output.extend(struct.pack('>B', abs(value))) + output.extend(struct.pack(">B", abs(value))) return output.extend(_construct(ConstructorBytes.ulong_large, with_constructor)) - output.extend(struct.pack('>Q', abs(value))) + output.extend(struct.pack(">Q", abs(value))) except struct.error: raise ValueError("Value supplied for unsigned long invalid: {}".format(value)) -def encode_byte(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_byte( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, int, bool, Any) -> None """ @@ -140,12 +148,14 @@ def encode_byte(output, value, with_constructor=True, **kwargs): # pylint: disa value = int(value) try: output.extend(_construct(ConstructorBytes.byte, with_constructor)) - output.extend(struct.pack('>b', value)) + output.extend(struct.pack(">b", value)) except struct.error: raise ValueError("Byte value must be -128-127") -def encode_short(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_short( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, int, bool, Any) -> None """ @@ -153,7 +163,7 @@ def encode_short(output, value, with_constructor=True, **kwargs): # pylint: dis value = int(value) try: output.extend(_construct(ConstructorBytes.short, with_constructor)) - output.extend(struct.pack('>h', value)) + output.extend(struct.pack(">h", value)) except struct.error: raise ValueError("Short value must be -32768-32767") @@ -168,10 +178,10 @@ def encode_int(output, value, with_constructor=True, use_smallest=True): try: if use_smallest and (-128 <= value <= 127): output.extend(_construct(ConstructorBytes.int_small, with_constructor)) - output.extend(struct.pack('>b', value)) + output.extend(struct.pack(">b", value)) return output.extend(_construct(ConstructorBytes.int_large, with_constructor)) - output.extend(struct.pack('>i', value)) + output.extend(struct.pack(">i", value)) except struct.error: raise ValueError("Value supplied for int invalid: {}".format(value)) @@ -182,63 +192,69 @@ def encode_long(output, value, with_constructor=True, use_smallest=True): """ - try: - value = long(value) - except NameError: - value = int(value) + value = int(value) try: if use_smallest and (-128 <= value <= 127): output.extend(_construct(ConstructorBytes.long_small, with_constructor)) - output.extend(struct.pack('>b', value)) + output.extend(struct.pack(">b", value)) return output.extend(_construct(ConstructorBytes.long_large, with_constructor)) - output.extend(struct.pack('>q', value)) + output.extend(struct.pack(">q", value)) except struct.error: raise ValueError("Value supplied for long invalid: {}".format(value)) -def encode_float(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument + +def encode_float( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, float, bool, Any) -> None """ """ value = float(value) output.extend(_construct(ConstructorBytes.float, with_constructor)) - output.extend(struct.pack('>f', value)) + output.extend(struct.pack(">f", value)) -def encode_double(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_double( + output, value, with_constructor=True, **kwargs +): # 
pylint: disable=unused-argument # type: (bytearray, float, bool, Any) -> None """ """ value = float(value) output.extend(_construct(ConstructorBytes.double, with_constructor)) - output.extend(struct.pack('>d', value)) + output.extend(struct.pack(">d", value)) -def encode_timestamp(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_timestamp( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, Union[int, datetime], bool, Any) -> None """ """ if isinstance(value, datetime): - value = (calendar.timegm(value.utctimetuple()) * 1000) + (value.microsecond/1000) - value = int(value) + value = (calendar.timegm(value.utctimetuple()) * 1000) + (value.microsecond / 1000) # type: ignore + value = int(value) # type: ignore output.extend(_construct(ConstructorBytes.timestamp, with_constructor)) - output.extend(struct.pack('>q', value)) + output.extend(struct.pack(">q", value)) -def encode_uuid(output, value, with_constructor=True, **kwargs): # pylint: disable=unused-argument +def encode_uuid( + output, value, with_constructor=True, **kwargs +): # pylint: disable=unused-argument # type: (bytearray, Union[uuid.UUID, str, bytes], bool, Any) -> None """ """ - if isinstance(value, six.text_type): + if isinstance(value, str): value = uuid.UUID(value).bytes elif isinstance(value, uuid.UUID): value = value.bytes - elif isinstance(value, six.binary_type): + elif isinstance(value, bytes): value = uuid.UUID(bytes=value).bytes else: raise TypeError("Invalid UUID type: {}".format(type(value))) @@ -255,12 +271,12 @@ def encode_binary(output, value, with_constructor=True, use_smallest=True): length = len(value) if use_smallest and length <= 255: output.extend(_construct(ConstructorBytes.binary_small, with_constructor)) - output.extend(struct.pack('>B', length)) + output.extend(struct.pack(">B", length)) output.extend(value) return try: output.extend(_construct(ConstructorBytes.binary_large, with_constructor)) - output.extend(struct.pack('>L', length)) + output.extend(struct.pack(">L", length)) output.extend(value) except struct.error: raise ValueError("Binary data to long to encode") @@ -274,17 +290,17 @@ def encode_string(output, value, with_constructor=True, use_smallest=True): """ - if isinstance(value, six.text_type): - value = value.encode('utf-8') + if isinstance(value, str): + value = value.encode("utf-8") length = len(value) if use_smallest and length <= 255: output.extend(_construct(ConstructorBytes.string_small, with_constructor)) - output.extend(struct.pack('>B', length)) + output.extend(struct.pack(">B", length)) output.extend(value) return try: output.extend(_construct(ConstructorBytes.string_large, with_constructor)) - output.extend(struct.pack('>L', length)) + output.extend(struct.pack(">L", length)) output.extend(value) except struct.error: raise ValueError("String value too long to encode.") @@ -298,24 +314,24 @@ def encode_symbol(output, value, with_constructor=True, use_smallest=True): """ - if isinstance(value, six.text_type): - value = value.encode('utf-8') + if isinstance(value, str): + value = value.encode("utf-8") length = len(value) if use_smallest and length <= 255: output.extend(_construct(ConstructorBytes.symbol_small, with_constructor)) - output.extend(struct.pack('>B', length)) + output.extend(struct.pack(">B", length)) output.extend(value) return try: output.extend(_construct(ConstructorBytes.symbol_large, with_constructor)) - output.extend(struct.pack('>L', length)) + 
output.extend(struct.pack(">L", length)) output.extend(value) except struct.error: raise ValueError("Symbol value too long to encode.") def encode_list(output, value, with_constructor=True, use_smallest=True): - # type: (bytearray, Iterable[Any], bool, bool) -> None + # type: (bytearray, Sequence[Any], bool, bool) -> None """ @@ -335,20 +351,20 @@ def encode_list(output, value, with_constructor=True, use_smallest=True): encoded_size += len(encoded_values) if use_smallest and count <= 255 and encoded_size < 255: output.extend(_construct(ConstructorBytes.list_small, with_constructor)) - output.extend(struct.pack('>B', encoded_size + 1)) - output.extend(struct.pack('>B', count)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) else: try: output.extend(_construct(ConstructorBytes.list_large, with_constructor)) - output.extend(struct.pack('>L', encoded_size + 4)) - output.extend(struct.pack('>L', count)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L", count)) except struct.error: raise ValueError("List is too large or too long to be encoded.") output.extend(encoded_values) def encode_map(output, value, with_constructor=True, use_smallest=True): - # type: (bytearray, Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]], bool, bool) -> None + # type: (bytearray, Union[Dict[Any, Any], Sequence[Tuple[Any, Any]]], bool, bool) -> None """ @@ -359,7 +375,7 @@ def encode_map(output, value, with_constructor=True, use_smallest=True): encoded_size = 0 encoded_values = bytearray() try: - items = value.items() + items = value.items() # type: ignore except AttributeError: items = value for key, data in items: @@ -368,27 +384,26 @@ def encode_map(output, value, with_constructor=True, use_smallest=True): encoded_size = len(encoded_values) if use_smallest and count <= 255 and encoded_size < 255: output.extend(_construct(ConstructorBytes.map_small, with_constructor)) - output.extend(struct.pack('>B', encoded_size + 1)) - output.extend(struct.pack('>B', count)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) else: try: output.extend(_construct(ConstructorBytes.map_large, with_constructor)) - output.extend(struct.pack('>L', encoded_size + 4)) - output.extend(struct.pack('>L', count)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L", count)) except struct.error: raise ValueError("Map is too large or too long to be encoded.") output.extend(encoded_values) - return def _check_element_type(item, element_type): if not element_type: try: - return item['TYPE'] + return item["TYPE"] except (KeyError, TypeError): return type(item) try: - if item['TYPE'] != element_type: + if item["TYPE"] != element_type: raise TypeError("All elements in an array must be the same type.") except (KeyError, TypeError): if not isinstance(item, element_type): @@ -397,7 +412,7 @@ def _check_element_type(item, element_type): def encode_array(output, value, with_constructor=True, use_smallest=True): - # type: (bytearray, Iterable[Any], bool, bool) -> None + # type: (bytearray, Sequence[Any], bool, bool) -> None """ @@ -411,7 +426,9 @@ def encode_array(output, value, with_constructor=True, use_smallest=True): element_type = None for item in value: element_type = _check_element_type(item, element_type) - encode_value(encoded_values, item, with_constructor=first_item, use_smallest=False) + encode_value( + encoded_values, item, with_constructor=first_item, use_smallest=False + ) first_item 
= False if item is None: encoded_size -= 1 @@ -419,20 +436,20 @@ def encode_array(output, value, with_constructor=True, use_smallest=True): encoded_size += len(encoded_values) if use_smallest and count <= 255 and encoded_size < 255: output.extend(_construct(ConstructorBytes.array_small, with_constructor)) - output.extend(struct.pack('>B', encoded_size + 1)) - output.extend(struct.pack('>B', count)) + output.extend(struct.pack(">B", encoded_size + 1)) + output.extend(struct.pack(">B", count)) else: try: output.extend(_construct(ConstructorBytes.array_large, with_constructor)) - output.extend(struct.pack('>L', encoded_size + 4)) - output.extend(struct.pack('>L', count)) + output.extend(struct.pack(">L", encoded_size + 4)) + output.extend(struct.pack(">L", count)) except struct.error: raise ValueError("Array is too large or too long to be encoded.") output.extend(encoded_values) def encode_described(output, value, _=None, **kwargs): - # type: (bytearray, Tuple(Any, Any), bool, Any) -> None + # type: (bytearray, Sequence[Any], bool, Any) -> None output.extend(ConstructorBytes.descriptor) encode_value(output, value[0], **kwargs) encode_value(output, value[1], **kwargs) @@ -450,11 +467,11 @@ def encode_fields(value): """ if not value: return {TYPE: AMQPTypes.null, VALUE: None} - fields = {TYPE: AMQPTypes.map, VALUE:[]} + fields = {TYPE: AMQPTypes.map, VALUE: []} for key, data in value.items(): - if isinstance(key, six.text_type): - key = key.encode('utf-8') - fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: key}, data)) + if isinstance(key, str): + key = key.encode("utf-8") # type: ignore + fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: key}, data)) # type: ignore return fields @@ -471,14 +488,16 @@ def encode_annotations(value): """ if not value: return {TYPE: AMQPTypes.null, VALUE: None} - fields = {TYPE: AMQPTypes.map, VALUE:[]} + fields = {TYPE: AMQPTypes.map, VALUE: []} for key, data in value.items(): if isinstance(key, int): - fields[VALUE].append(({TYPE: AMQPTypes.ulong, VALUE: key}, {TYPE: None, VALUE: data})) + fields[VALUE].append( + ({TYPE: AMQPTypes.ulong, VALUE: key}, {TYPE: None, VALUE: data}) + ) else: - if isinstance(key, six.text_type): - key = key.encode('utf-8') - fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: key}, {TYPE: None, VALUE: data})) + if isinstance(key, str): + key = key.encode("utf-8") # type: ignore + fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: key}, {TYPE: None, VALUE: data})) # type: ignore return fields @@ -496,9 +515,9 @@ def encode_application_properties(value): """ if not value: return {TYPE: AMQPTypes.null, VALUE: None} - fields = {TYPE: AMQPTypes.map, VALUE:[]} + fields = {TYPE: AMQPTypes.map, VALUE: []} for key, data in value.items(): - fields[VALUE].append(({TYPE: AMQPTypes.string, VALUE: key}, data)) + fields[VALUE].append(({TYPE: AMQPTypes.string, VALUE: key}, data)) # type: ignore return fields @@ -512,11 +531,11 @@ def encode_message_id(value): """ if isinstance(value, int): return {TYPE: AMQPTypes.ulong, VALUE: value} - elif isinstance(value, uuid.UUID): + if isinstance(value, uuid.UUID): return {TYPE: AMQPTypes.uuid, VALUE: value} - elif isinstance(value, six.binary_type): + if isinstance(value, bytes): return {TYPE: AMQPTypes.binary, VALUE: value} - elif isinstance(value, six.text_type): + if isinstance(value, str): return {TYPE: AMQPTypes.string, VALUE: value} raise TypeError("Unsupported Message ID type.") @@ -526,16 +545,16 @@ def encode_node_properties(value): """Properties of a node. 
- + A symbol-keyed map containing properties of a node used when requesting creation or reporting the creation of a dynamic node. The following common properties are defined:: - + - `lifetime-policy`: The lifetime of a dynamically generated node. Definitionally, the lifetime will never be less than the lifetime of the link which caused its creation, however it is possible to extend the lifetime of dynamically created node using a lifetime policy. The value of this entry MUST be of a type which provides the lifetime-policy archetype. The following standard lifetime-policies are defined below: delete-on-close, delete-on-no-links, delete-on-no-messages or delete-on-no-links-or-messages. - + - `supported-dist-modes`: The distribution modes that the node supports. The value of this entry MUST be one or more symbols which are valid distribution-modes. That is, the value MUST be of the same type as would be valid in a field defined with the following attributes: @@ -544,7 +563,7 @@ def encode_node_properties(value): if not value: return {TYPE: AMQPTypes.null, VALUE: None} # TODO - fields = {TYPE: AMQPTypes.map, VALUE:[]} + fields = {TYPE: AMQPTypes.map, VALUE: []} # fields[{TYPE: AMQPTypes.symbol, VALUE: b'lifetime-policy'}] = { # TYPE: AMQPTypes.described, # VALUE: ( @@ -573,27 +592,27 @@ def encode_filter_set(value): """ if not value: return {TYPE: AMQPTypes.null, VALUE: None} - fields = {TYPE: AMQPTypes.map, VALUE:[]} + fields = {TYPE: AMQPTypes.map, VALUE: []} for name, data in value.items(): if data is None: described_filter = {TYPE: AMQPTypes.null, VALUE: None} else: - if isinstance(name, six.text_type): - name = name.encode('utf-8') + if isinstance(name, str): + name = name.encode("utf-8") # type: ignore descriptor, filter_value = data described_filter = { TYPE: AMQPTypes.described, - VALUE: ( + VALUE: ( # type: ignore {TYPE: AMQPTypes.symbol, VALUE: descriptor}, - filter_value - ) + filter_value, + ), } - fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: name}, described_filter)) + fields[VALUE].append(({TYPE: AMQPTypes.symbol, VALUE: name}, described_filter)) # type: ignore return fields def encode_unknown(output, value, **kwargs): - # type: (bytearray, Optional[Any]) -> None + # type: (bytearray, Optional[Any], Any) -> None """ Dynamic encoding according to the type of `value`. 
""" @@ -601,15 +620,15 @@ def encode_unknown(output, value, **kwargs): encode_null(output, **kwargs) elif isinstance(value, bool): encode_boolean(output, value, **kwargs) - elif isinstance(value, six.string_types): + elif isinstance(value, str): encode_string(output, value, **kwargs) elif isinstance(value, uuid.UUID): encode_uuid(output, value, **kwargs) - elif isinstance(value, (bytearray, six.binary_type)): + elif isinstance(value, (bytearray, bytes)): encode_binary(output, value, **kwargs) elif isinstance(value, float): encode_double(output, value, **kwargs) - elif isinstance(value, six.integer_types): + elif isinstance(value, int): encode_int(output, value, **kwargs) elif isinstance(value, datetime): encode_timestamp(output, value, **kwargs) @@ -661,39 +680,43 @@ def encode_unknown(output, value, **kwargs): def encode_value(output, value, **kwargs): # type: (bytearray, Any, Any) -> None try: - _ENCODE_MAP[value[TYPE]](output, value[VALUE], **kwargs) + _ENCODE_MAP[value[TYPE]](output, value[VALUE], **kwargs) # type: ignore except (KeyError, TypeError): encode_unknown(output, value, **kwargs) def describe_performative(performative): - # type: (Performative) -> Tuple(bytes, bytes) + # type: (Performative) -> Dict body = [] for index, value in enumerate(performative): - field = performative._definition[index] + field = performative._definition[index] # type: ignore if value is None: body.append({TYPE: AMQPTypes.null, VALUE: None}) elif field is None: continue elif isinstance(field.type, FieldDefinition): if field.multiple: - body.append({TYPE: AMQPTypes.array, VALUE: [_FIELD_DEFINITIONS[field.type](v) for v in value]}) + body.append( + {TYPE: AMQPTypes.array, VALUE: [_FIELD_DEFINITIONS[field.type](v) for v in value]} # type: ignore + ) else: body.append(_FIELD_DEFINITIONS[field.type](value)) elif isinstance(field.type, ObjDefinition): body.append(describe_performative(value)) else: if field.multiple: - body.append({TYPE: AMQPTypes.array, VALUE: [{TYPE: field.type, VALUE: v} for v in value]}) + body.append( + {TYPE: AMQPTypes.array, VALUE: [{TYPE: field.type, VALUE: v} for v in value]} # type: ignore + ) else: body.append({TYPE: field.type, VALUE: value}) return { TYPE: AMQPTypes.described, VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: performative._code}, - {TYPE: AMQPTypes.list, VALUE: body} - ) + {TYPE: AMQPTypes.ulong, VALUE: performative._code}, # type: ignore + {TYPE: AMQPTypes.list, VALUE: body}, + ), } @@ -707,13 +730,16 @@ def encode_payload(output, payload): encode_value(output, describe_performative(payload[0])) if payload[2]: # message annotations - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000072}, - encode_annotations(payload[2]), - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000072}, + encode_annotations(payload[2]), + ), + }, + ) if payload[3]: # properties # TODO: Header and Properties encoding can be optimized to @@ -722,51 +748,66 @@ def encode_payload(output, payload): encode_value(output, describe_performative(payload[3])) if payload[4]: # application properties - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000074}, - {TYPE: AMQPTypes.map, VALUE: payload[4]} - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000074}, + {TYPE: AMQPTypes.map, VALUE: payload[4]}, + ), + }, + ) if payload[5]: # data for item_value in payload[5]: - 
encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000075}, - {TYPE: AMQPTypes.binary, VALUE: item_value} - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000075}, + {TYPE: AMQPTypes.binary, VALUE: item_value}, + ), + }, + ) if payload[6]: # sequence for item_value in payload[6]: - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000076}, - {TYPE: None, VALUE: item_value} - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000076}, + {TYPE: None, VALUE: item_value}, + ), + }, + ) if payload[7]: # value - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000077}, - {TYPE: None, VALUE: payload[7]} - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000077}, + {TYPE: None, VALUE: payload[7]}, + ), + }, + ) if payload[8]: # footer - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000078}, - encode_annotations(payload[8]), - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000078}, + encode_annotations(payload[8]), + ), + }, + ) # TODO: # currently the delivery annotations must be finally encoded instead of being encoded at the 2nd position @@ -774,31 +815,34 @@ def encode_payload(output, payload): # -- received message doesn't have it populated # check with service team? if payload[1]: # delivery annotations - encode_value(output, { - TYPE: AMQPTypes.described, - VALUE: ( - {TYPE: AMQPTypes.ulong, VALUE: 0x00000071}, - encode_annotations(payload[1]), - ) - }) + encode_value( + output, + { + TYPE: AMQPTypes.described, + VALUE: ( + {TYPE: AMQPTypes.ulong, VALUE: 0x00000071}, + encode_annotations(payload[1]), + ), + }, + ) return output def encode_frame(frame, frame_type=_FRAME_TYPE): - # type: (Performative) -> Tuple(bytes, bytes) + # type: (Performative, bytes) -> Tuple[bytes, Optional[bytes]] # TODO: allow passing type specific bytes manually, e.g. Empty Frame needs padding if frame is None: size = 8 - header = size.to_bytes(4, 'big') + _FRAME_OFFSET + frame_type + header = size.to_bytes(4, "big") + _FRAME_OFFSET + frame_type return header, None frame_description = describe_performative(frame) frame_data = bytearray() encode_value(frame_data, frame_description) - if isinstance(frame, performatives.TransferFrame): + if isinstance(frame, TransferFrame): frame_data += frame.payload size = len(frame_data) + 8 - header = size.to_bytes(4, 'big') + _FRAME_OFFSET + frame_type + header = size.to_bytes(4, "big") + _FRAME_OFFSET + frame_type return header, frame_data diff --git a/uamqp/_platform.py b/uamqp/_platform.py index e52153aa2..fe6ac2aec 100644 --- a/uamqp/_platform.py +++ b/uamqp/_platform.py @@ -7,6 +7,7 @@ import re import struct import sys +from typing import Tuple # Jython does not have this attribute try: @@ -15,12 +16,12 @@ from socket import IPPROTO_TCP as SOL_TCP # noqa -RE_NUM = re.compile(r'(\d+).+') +RE_NUM = re.compile(r"(\d+).+") def _linux_version_to_tuple(s): - # type: (str) -> Tuple[int, int, int] - return tuple(map(_versionatom, s.split('.')[:3])) + # type: (str) -> Tuple[int, ...] 
+ return tuple(map(_versionatom, s.split(".")[:3])) def _versionatom(s): @@ -33,48 +34,55 @@ def _versionatom(s): # available socket options for TCP level KNOWN_TCP_OPTS = { - 'TCP_CORK', 'TCP_DEFER_ACCEPT', 'TCP_KEEPCNT', - 'TCP_KEEPIDLE', 'TCP_KEEPINTVL', 'TCP_LINGER2', - 'TCP_MAXSEG', 'TCP_NODELAY', 'TCP_QUICKACK', - 'TCP_SYNCNT', 'TCP_USER_TIMEOUT', 'TCP_WINDOW_CLAMP', + "TCP_CORK", + "TCP_DEFER_ACCEPT", + "TCP_KEEPCNT", + "TCP_KEEPIDLE", + "TCP_KEEPINTVL", + "TCP_LINGER2", + "TCP_MAXSEG", + "TCP_NODELAY", + "TCP_QUICKACK", + "TCP_SYNCNT", + "TCP_USER_TIMEOUT", + "TCP_WINDOW_CLAMP", } LINUX_VERSION = None -if sys.platform.startswith('linux'): +if sys.platform.startswith("linux"): LINUX_VERSION = _linux_version_to_tuple(platform.release()) if LINUX_VERSION < (2, 6, 37): - KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') + KNOWN_TCP_OPTS.remove("TCP_USER_TIMEOUT") # Windows Subsystem for Linux is an edge-case: the Python socket library # returns most TCP_* enums, but they aren't actually supported if platform.release().endswith("Microsoft"): - KNOWN_TCP_OPTS = {'TCP_NODELAY', 'TCP_KEEPIDLE', 'TCP_KEEPINTVL', - 'TCP_KEEPCNT'} + KNOWN_TCP_OPTS = {"TCP_NODELAY", "TCP_KEEPIDLE", "TCP_KEEPINTVL", "TCP_KEEPCNT"} -elif sys.platform.startswith('darwin'): - KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') +elif sys.platform.startswith("darwin"): + KNOWN_TCP_OPTS.remove("TCP_USER_TIMEOUT") -elif 'bsd' in sys.platform: - KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') +elif "bsd" in sys.platform: + KNOWN_TCP_OPTS.remove("TCP_USER_TIMEOUT") # According to MSDN Windows platforms support getsockopt(TCP_MAXSSEG) but not # setsockopt(TCP_MAXSEG) on IPPROTO_TCP sockets. -elif sys.platform.startswith('win'): - KNOWN_TCP_OPTS = {'TCP_NODELAY'} +elif sys.platform.startswith("win"): + KNOWN_TCP_OPTS = {"TCP_NODELAY"} -elif sys.platform.startswith('cygwin'): - KNOWN_TCP_OPTS = {'TCP_NODELAY'} +elif sys.platform.startswith("cygwin"): + KNOWN_TCP_OPTS = {"TCP_NODELAY"} # illumos does not allow to set the TCP_MAXSEG socket option, # even if the Oracle documentation says otherwise. -elif sys.platform.startswith('sunos'): - KNOWN_TCP_OPTS.remove('TCP_MAXSEG') +elif sys.platform.startswith("sunos"): + KNOWN_TCP_OPTS.remove("TCP_MAXSEG") # aix does not allow to set the TCP_MAXSEG # or the TCP_USER_TIMEOUT socket options. 
-elif sys.platform.startswith('aix'): - KNOWN_TCP_OPTS.remove('TCP_MAXSEG') - KNOWN_TCP_OPTS.remove('TCP_USER_TIMEOUT') +elif sys.platform.startswith("aix"): + KNOWN_TCP_OPTS.remove("TCP_MAXSEG") + KNOWN_TCP_OPTS.remove("TCP_USER_TIMEOUT") if sys.version_info < (2, 7, 7): # pragma: no cover import functools @@ -83,6 +91,7 @@ def _to_bytes_arg(fun): @functools.wraps(fun) def _inner(s, *args, **kwargs): return fun(s.encode(), *args, **kwargs) + return _inner pack = _to_bytes_arg(struct.pack) @@ -96,11 +105,11 @@ def _inner(s, *args, **kwargs): unpack_from = struct.unpack_from __all__ = [ - 'LINUX_VERSION', - 'SOL_TCP', - 'KNOWN_TCP_OPTS', - 'pack', - 'pack_into', - 'unpack', - 'unpack_from', + "LINUX_VERSION", + "SOL_TCP", + "KNOWN_TCP_OPTS", + "pack", + "pack_into", + "unpack", + "unpack_from", ] diff --git a/uamqp/_transport.py b/uamqp/_transport.py index 85371fdd0..c33c80eef 100644 --- a/uamqp/_transport.py +++ b/uamqp/_transport.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # This is a fork of the transport.py which was originally written by Barry Pederson and # maintained by the Celery project: https://github.com/celery/py-amqp. # @@ -30,47 +30,47 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF # THE POSSIBILITY OF SUCH DAMAGE. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- from __future__ import absolute_import, unicode_literals import errno +import logging import re import socket import ssl import struct -from ssl import SSLError from contextlib import contextmanager from io import BytesIO -import logging +from ssl import SSLError from threading import Lock import certifi -from ._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack -from ._encode import encode_frame -from ._decode import decode_frame, decode_empty_frame -from .constants import TLS_HEADER_FRAME - +from uamqp._decode import decode_frame, decode_empty_frame +from uamqp._encode import encode_frame +from uamqp._platform import KNOWN_TCP_OPTS, SOL_TCP +from uamqp.constants import TLS_HEADER_FRAME try: import fcntl except ImportError: # pragma: no cover - fcntl = None # noqa + fcntl = None # type: ignore # noqa +# TODO: drop set_cloexec completely? try: - from os import set_cloexec # Python 3.4? + from os import set_cloexec # type: ignore # Python 3.4? except ImportError: # pragma: no cover # TODO: Drop this once we drop Python 2.7 support def set_cloexec(fd, cloexec): # noqa """Set flag to close fd after exec.""" if fcntl is None: - return + return None try: FD_CLOEXEC = fcntl.FD_CLOEXEC except AttributeError: raise NotImplementedError( - 'close-on-exec flag not supported on this platform', + "close-on-exec flag not supported on this platform", ) flags = fcntl.fcntl(fd, fcntl.F_GETFD) if cloexec: @@ -79,24 +79,49 @@ def set_cloexec(fd, cloexec): # noqa flags &= ~FD_CLOEXEC return fcntl.fcntl(fd, fcntl.F_SETFD, flags) + +try: + from socket import TCP_USER_TIMEOUT as imported_enum # type: ignore +except ImportError: + # should be in Python 3.6+ on Linux. 
+ imported_enum = 18 + + +def _get_tcp_socket_defaults(sock): + tcp_opts = {} + for opt in KNOWN_TCP_OPTS: + enum = None + if opt == "TCP_USER_TIMEOUT": + enum = imported_enum + elif hasattr(socket, opt): + enum = getattr(socket, opt) + + if enum: + if opt in DEFAULT_SOCKET_SETTINGS: + tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] + elif hasattr(socket, opt): + tcp_opts[enum] = sock.getsockopt(SOL_TCP, getattr(socket, opt)) + return tcp_opts + + _LOGGER = logging.getLogger(__name__) _UNAVAIL = {errno.EAGAIN, errno.EINTR, errno.ENOENT, errno.EWOULDBLOCK} AMQP_PORT = 5672 AMQPS_PORT = 5671 -AMQP_FRAME = memoryview(b'AMQP') +AMQP_FRAME = memoryview(b"AMQP") EMPTY_BUFFER = bytes() SIGNED_INT_MAX = 0x7FFFFFFF # Match things like: [fe80::1]:5432, from RFC 2732 -IPV6_LITERAL = re.compile(r'\[([\.0-9a-f:]+)\](?::(\d+))?') +IPV6_LITERAL = re.compile(r"\[([\.0-9a-f:]+)\](?::(\d+))?") DEFAULT_SOCKET_SETTINGS = { - 'TCP_NODELAY': 1, - 'TCP_USER_TIMEOUT': 1000, - 'TCP_KEEPIDLE': 60, - 'TCP_KEEPINTVL': 10, - 'TCP_KEEPCNT': 9, + "TCP_NODELAY": 1, + "TCP_USER_TIMEOUT": 1000, + "TCP_KEEPIDLE": 60, + "TCP_KEEPINTVL": 10, + "TCP_KEEPCNT": 9, } @@ -127,8 +152,8 @@ def to_host_port(host, port=AMQP_PORT): if m.group(2): port = int(m.group(2)) else: - if ':' in host: - host, port = host.rsplit(':', 1) + if ":" in host: + host, port = host.rsplit(":", 1) port = int(port) return host, port @@ -140,9 +165,17 @@ class UnexpectedFrame(Exception): class _AbstractTransport(object): """Common superclass for TCP and SSL transports.""" - def __init__(self, host, port=AMQP_PORT, connect_timeout=None, - read_timeout=None, write_timeout=None, - socket_settings=None, raise_on_initial_eintr=True, **kwargs): + def __init__( # pylint: disable=unused-argument + self, + host, + port=AMQP_PORT, + connect_timeout=None, + read_timeout=None, + write_timeout=None, + socket_settings=None, + raise_on_initial_eintr=True, + **kwargs + ): self.connected = False self.sock = None self.raise_on_initial_eintr = raise_on_initial_eintr @@ -161,7 +194,9 @@ def connect(self): return self._connect(self.host, self.port, self.connect_timeout) self._init_socket( - self.socket_settings, self.read_timeout, self.write_timeout, + self.socket_settings, + self.read_timeout, + self.write_timeout, ) # we've sent the banner; signal connect # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner @@ -186,10 +221,10 @@ def block_with_timeout(self, timeout): try: yield self.sock except SSLError as exc: - if 'timed out' in str(exc): + if "timed out" in str(exc): # http://bugs.python.org/issue10272 raise socket.timeout() - elif 'The operation did not complete' in str(exc): + if "The operation did not complete" in str(exc): # Non-blocking SSL sockets can throw SSLError raise socket.timeout() raise @@ -211,10 +246,10 @@ def block(self): try: yield self.sock except SSLError as exc: - if 'timed out' in str(exc): + if "timed out" in str(exc): # http://bugs.python.org/issue10272 raise socket.timeout() - elif 'The operation did not complete' in str(exc): + if "The operation did not complete" in str(exc): # Non-blocking SSL sockets can throw SSLError raise socket.timeout() raise @@ -236,10 +271,10 @@ def non_blocking(self): try: yield self.sock except SSLError as exc: - if 'timed out' in str(exc): + if "timed out" in str(exc): # http://bugs.python.org/issue10272 raise socket.timeout() - elif 'The operation did not complete' in str(exc): + if "The operation did not complete" in str(exc): # Non-blocking SSL sockets can throw SSLError raise socket.timeout() raise @@ -269,7 
+304,8 @@ def _connect(self, host, port, timeout): # first, resolve the address for a single address family try: entries = socket.getaddrinfo( - host, port, family, socket.SOCK_STREAM, SOL_TCP) + host, port, family, socket.SOCK_STREAM, SOL_TCP + ) entries_num = len(entries) except socket.gaierror: # we may have depleted all our options @@ -277,10 +313,11 @@ def _connect(self, host, port, timeout): # if getaddrinfo succeeded before for another address # family, reraise the previous socket.error since it's more # relevant to users - raise (e - if e is not None - else socket.error( - "failed to resolve broker hostname")) + raise ( + e + if e is not None + else socket.error("failed to resolve broker hostname") + ) continue # pragma: no cover # now that we have address(es) for the hostname, connect to broker @@ -306,7 +343,9 @@ def _connect(self, host, port, timeout): # hurray, we established connection return - def _init_socket(self, socket_settings, read_timeout, write_timeout): + def _init_socket( + self, socket_settings, read_timeout, write_timeout + ): # pylint: disable=unused-argument self.sock.settimeout(None) # set socket back to blocking mode self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self._set_socket_options(socket_settings) @@ -327,49 +366,26 @@ def _init_socket(self, socket_settings, read_timeout, write_timeout): # 1 second is enough for perf analysis self.sock.settimeout(1) # set socket back to non-blocking mode - def _get_tcp_socket_defaults(self, sock): - tcp_opts = {} - for opt in KNOWN_TCP_OPTS: - enum = None - if opt == 'TCP_USER_TIMEOUT': - try: - from socket import TCP_USER_TIMEOUT as enum - except ImportError: - # should be in Python 3.6+ on Linux. - enum = 18 - elif hasattr(socket, opt): - enum = getattr(socket, opt) - - if enum: - if opt in DEFAULT_SOCKET_SETTINGS: - tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] - elif hasattr(socket, opt): - tcp_opts[enum] = sock.getsockopt( - SOL_TCP, getattr(socket, opt)) - return tcp_opts - def _set_socket_options(self, socket_settings): - tcp_opts = self._get_tcp_socket_defaults(self.sock) + tcp_opts = _get_tcp_socket_defaults(self.sock) if socket_settings: tcp_opts.update(socket_settings) for opt, val in tcp_opts.items(): self.sock.setsockopt(SOL_TCP, opt, val) - def _read(self, n, initial=False): + def _read(self, n, initial=False, **kwargs): """Read exactly n bytes from the peer.""" - raise NotImplementedError('Must be overriden in subclass') + raise NotImplementedError("Must be overridden in subclass") def _setup_transport(self): """Do any additional initialization of the class.""" - pass def _shutdown_transport(self): """Do any preliminary work in shutting down the connection.""" - pass def _write(self, s): """Completely write a string to the peer.""" - raise NotImplementedError('Must be overriden in subclass') + raise NotImplementedError("Must be overridden in subclass") def close(self): if self.sock is not None: @@ -379,7 +395,7 @@ def close(self): # calling this method. try: self.sock.shutdown(socket.SHUT_RDWR) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except # TODO: shutdown could raise OSError, Transport endpoint is not connected if the endpoint is already # disconnected. can we safely ignore the errors since the close operation is initiated by us.
_LOGGER.info("An error occurred when shutting down the socket: %r", exc) @@ -387,20 +403,21 @@ def close(self): self.sock = None self.connected = False - def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? + def read(self, verify_frame_type=0, **kwargs): # pylint: disable=unused-argument + # TODO: verify frame type? read = self._read read_frame_buffer = BytesIO() try: frame_header = memoryview(bytearray(8)) read_frame_buffer.write(read(8, buffer=frame_header, initial=True)) - channel = struct.unpack('>H', frame_header[6:])[0] + channel = struct.unpack(">H", frame_header[6:])[0] size = frame_header[0:4] if size == AMQP_FRAME: # Empty frame or AMQP header negotiation TODO return frame_header, channel, None - size = struct.unpack('>I', size)[0] + size = struct.unpack(">I", size)[0] offset = frame_header[4] - frame_type = frame_header[5] + # frame_type = frame_header[5] # >I is an unsigned int, but the argument to sock.recv is signed, # so we know the size can be at most 2 * SIGNED_INT_MAX @@ -408,7 +425,9 @@ def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? payload = memoryview(bytearray(payload_size)) if size > SIGNED_INT_MAX: read_frame_buffer.write(read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write(read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) + read_frame_buffer.write( + read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:]) + ) else: read_frame_buffer.write(read(payload_size, buffer=payload)) except socket.timeout: @@ -419,7 +438,7 @@ def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? except (OSError, IOError, SSLError, socket.error) as exc: # Don't disconnect for ssl read time outs # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): + if isinstance(exc, SSLError) and "timed out" in str(exc): raise socket.timeout() if get_errno(exc) not in _UNAVAIL: self.connected = False @@ -437,9 +456,9 @@ def write(self, s): self.connected = False raise - def receive_frame(self, *args, **kwargs): + def receive_frame(self, *args, **kwargs): # pylint: disable=unused-argument try: - header, channel, payload = self.read(**kwargs) + header, channel, payload = self.read(**kwargs) if not payload: decoded = decode_empty_frame(header) else: @@ -454,7 +473,7 @@ def send_frame(self, channel, frame, **kwargs): if performative is None: data = header else: - encoded_channel = struct.pack('>H', channel) + encoded_channel = struct.pack(">H", channel) data = header + encoded_channel + performative self.write(data) @@ -466,39 +485,48 @@ def negotiate(self, encode, decode): class SSLTransport(_AbstractTransport): """Transport that works over SSL.""" - def __init__(self, host, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + def __init__( + self, host, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs + ): # pylint: disable=redefined-outer-name self.sslopts = ssl if isinstance(ssl, dict) else {} self._read_buffer = BytesIO() super(SSLTransport, self).__init__( - host, - port=port, - connect_timeout=connect_timeout, - **kwargs + host, port=port, connect_timeout=connect_timeout, **kwargs ) def _setup_transport(self): """Wrap the socket in an SSL object.""" self.sock = self._wrap_socket(self.sock, **self.sslopts) - a = self.sock.do_handshake() - self._quick_recv = self.sock.recv + self.sock.do_handshake() def _wrap_socket(self, sock, context=None, **sslopts): if context: return self._wrap_context(sock, sslopts, **context) return self._wrap_socket_sni(sock, 
**sslopts) - def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): + def _wrap_context( + self, sock, sslopts, check_hostname=None, **ctx_options + ): # pylint: disable=no-self-use ctx = ssl.create_default_context(**ctx_options) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(cafile=certifi.where()) ctx.check_hostname = check_hostname return ctx.wrap_socket(sock, **sslopts) - def _wrap_socket_sni(self, sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=ssl.CERT_REQUIRED, - ca_certs=None, do_handshake_on_connect=False, - suppress_ragged_eofs=True, server_hostname=None, - ciphers=None, ssl_version=None): + def _wrap_socket_sni( # pylint: disable=no-self-use + self, + sock, + keyfile=None, + certfile=None, + server_side=False, + cert_reqs=ssl.CERT_REQUIRED, + ca_certs=None, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=None, + ciphers=None, + ssl_version=None, + ): """Socket wrap with SNI headers. Default `ssl.wrap_socket` method augmented with support for @@ -511,30 +539,32 @@ def _wrap_socket_sni(self, sock, keyfile=None, certfile=None, # ssl.PROTOCOL_TLS defined the equivalent is ssl.PROTOCOL_SSLv23 # we default to PROTOCOL_TLS and fallback to PROTOCOL_SSLv23 # TODO: Drop this once we drop Python 2.7 support - if hasattr(ssl, 'PROTOCOL_TLS'): + if hasattr(ssl, "PROTOCOL_TLS"): ssl_version = ssl.PROTOCOL_TLS else: ssl_version = ssl.PROTOCOL_SSLv23 opts = { - 'sock': sock, - 'keyfile': keyfile, - 'certfile': certfile, - 'server_side': server_side, - 'cert_reqs': cert_reqs, - 'ca_certs': ca_certs, - 'do_handshake_on_connect': do_handshake_on_connect, - 'suppress_ragged_eofs': suppress_ragged_eofs, - 'ciphers': ciphers, + "sock": sock, + "keyfile": keyfile, + "certfile": certfile, + "server_side": server_side, + "cert_reqs": cert_reqs, + "ca_certs": ca_certs, + "do_handshake_on_connect": do_handshake_on_connect, + "suppress_ragged_eofs": suppress_ragged_eofs, + "ciphers": ciphers, #'ssl_version': ssl_version } sock = ssl.wrap_socket(**opts) # Set SNI headers if supported - if (server_hostname is not None) and ( - hasattr(ssl, 'HAS_SNI') and ssl.HAS_SNI) and ( - hasattr(ssl, 'SSLContext')): - context = ssl.SSLContext(opts['ssl_version']) + if ( + (server_hostname is not None) + and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) + and (hasattr(ssl, "SSLContext")) + ): + context = ssl.SSLContext(opts["ssl_version"]) context.verify_mode = cert_reqs if cert_reqs != ssl.CERT_NONE: context.check_hostname = True @@ -551,8 +581,13 @@ def _shutdown_transport(self): except OSError: pass - def _read(self, toread, initial=False, buffer=None, - _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): + def _read( # pylint: disable=arguments-differ + self, + toread, + initial=False, + buffer=None, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR), + ): # According to SSL_read(3), it can at most return 16kb of data. # Thus, we use an internal read buffer like TCPTransport._read # to get the exact number of bytes wanted. @@ -568,7 +603,7 @@ def _read(self, toread, initial=False, buffer=None, except socket.error as exc: # ssl.sock.read may cause a SSLerror without errno # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): + if isinstance(exc, SSLError) and "timed out" in str(exc): raise socket.timeout() # ssl.sock.read may cause ENOENT if the # operation couldn't be performed (Issue celery#1414). 
@@ -578,7 +613,7 @@ def _read(self, toread, initial=False, buffer=None, continue raise if not nbytes: - raise IOError('Server unexpectedly closed connection') + raise IOError("Server unexpectedly closed connection") length += nbytes toread -= nbytes @@ -600,16 +635,20 @@ def _write(self, s): # None. n = 0 if not n: - raise IOError('Socket closed') + raise IOError("Socket closed") s = s[n:] - def negotiate(self): + def negotiate(self): # pylint: disable=arguments-differ with self.block(): self.write(TLS_HEADER_FRAME) - channel, returned_header = self.receive_frame(verify_frame_type=None) + # receive_frame returns tuple, [0] for channel, [1] for returned header + returned_header = self.receive_frame(verify_frame_type=None)[1] if returned_header[1] == TLS_HEADER_FRAME: - raise ValueError("Mismatching TLS header protocol. Excpected: {}, received: {}".format( - TLS_HEADER_FRAME, returned_header[1])) + raise ValueError( + "Mismatching TLS header protocol. Expected: {!r}, received: {!r}".format( + TLS_HEADER_FRAME, returned_header[1] + ) + ) class TCPTransport(_AbstractTransport): @@ -620,16 +659,16 @@ def _setup_transport(self): # do our own buffered reads. self._write = self.sock.sendall self._read_buffer = EMPTY_BUFFER - self._quick_recv = self.sock.recv - def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): + def _read( + self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR) + ): # pylint: disable=arguments-differ """Read exactly n bytes from the socket.""" - recv = self._quick_recv rbuf = self._read_buffer try: while len(rbuf) < n: try: - s = self.sock.read(n - len(rbuf)) + s = self.sock.recv(n - len(rbuf)) except socket.error as exc: if exc.errno in _errnos: if initial and self.raise_on_initial_eintr: @@ -637,7 +676,7 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): continue raise if not s: - raise IOError('Server unexpectedly closed connection') + raise IOError("Server unexpectedly closed connection") rbuf += s except: # noqa self._read_buffer = rbuf @@ -646,8 +685,13 @@ def _read(self, n, initial=False, _errnos=(errno.EAGAIN, errno.EINTR)): result, self._read_buffer = rbuf[:n], rbuf[n:] return result + def _write(self, s): + raise NotImplementedError("Not implemented") + -def Transport(host, connect_timeout=None, ssl=False, **kwargs): +def Transport( + host, connect_timeout=None, ssl=False, **kwargs +): # pylint: disable=redefined-outer-name """Create transport. Given a few parameters from the Connection constructor, diff --git a/uamqp/aio/__init__.py b/uamqp/aio/__init__.py index c513f35b9..6a8dbb2d4 100644 --- a/uamqp/aio/__init__.py +++ b/uamqp/aio/__init__.py @@ -1,15 +1,19 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from ._connection_async import Connection, ConnectionState -from ._link_async import Link, LinkDeliverySettleReason, LinkState -from ._receiver_async import ReceiverLink -from ._sasl_async import SASLPlainCredential, SASLTransport -from ._sender_async import SenderLink -from ._session_async import Session, SessionState -from ._transport_async import AsyncTransport -from ._client_async import AMQPClientAsync, ReceiveClientAsync, SendClientAsync -from ._authentication_async import SASTokenAuthAsync + +from uamqp.aio._authentication_async import SASTokenAuthAsync +from uamqp.aio._client_async import AMQPClientAsync, ReceiveClientAsync, SendClientAsync +from uamqp.aio._link_async import Link, LinkDeliverySettleReason, LinkState +from uamqp.aio._receiver_async import ReceiverLink +from uamqp.aio._sasl_async import SASLPlainCredential, SASLTransport +from uamqp.aio._sender_async import SenderLink +from uamqp.aio._session_async import Session, SessionState +from uamqp.aio._transport_async import AsyncTransport +from uamqp.aio._management_link_async import ManagementLink +from uamqp.aio._cbs_async import CBSAuthenticator +from uamqp.aio._management_operation_async import ManagementOperation +from uamqp.aio._connection_async import Connection, ConnectionState diff --git a/uamqp/aio/_authentication_async.py b/uamqp/aio/_authentication_async.py index 938fbe0a8..5157c55c7 100644 --- a/uamqp/aio/_authentication_async.py +++ b/uamqp/aio/_authentication_async.py @@ -1,45 +1,30 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- from functools import partial -from ..authentication import ( - _generate_sas_access_token, - SASTokenAuth, - JWTTokenAuth -) -from ..constants import AUTH_DEFAULT_EXPIRATION_SECONDS +from uamqp.authentication import _generate_sas_access_token, SASTokenAuth, JWTTokenAuth +from uamqp.constants import AUTH_DEFAULT_EXPIRATION_SECONDS -try: - from urlparse import urlparse - from urllib import quote_plus # type: ignore -except ImportError: - from urllib.parse import urlparse, quote_plus - -async def _generate_sas_token_async(auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS): +async def _generate_sas_token_async( + auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS +): return _generate_sas_access_token(auth_uri, sas_name, sas_key, expiry_in=expiry_in) class JWTTokenAuthAsync(JWTTokenAuth): - """""" # TODO: # 1. naming decision, suffix with Auth vs Credential + pass class SASTokenAuthAsync(SASTokenAuth): # TODO: # 1. naming decision, suffix with Auth vs Credential - def __init__( - self, - uri, - audience, - username, - password, - **kwargs - ): + def __init__(self, uri, audience, username, password, **kwargs): """ CBS authentication using SAS tokens. 
@@ -67,10 +52,8 @@ def __init__( """ super(SASTokenAuthAsync, self).__init__( - uri, - audience, - username, - password, - **kwargs + uri, audience, username, password, **kwargs + ) + self.get_token = partial( + _generate_sas_token_async, uri, username, password, self.expires_in ) - self.get_token = partial(_generate_sas_token_async, uri, username, password, self.expires_in) diff --git a/uamqp/aio/_cbs_async.py b/uamqp/aio/_cbs_async.py index c7f4e8c94..dae96cb32 100644 --- a/uamqp/aio/_cbs_async.py +++ b/uamqp/aio/_cbs_async.py @@ -1,23 +1,17 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- + +# pylint: disable=protected-access import logging -import asyncio from datetime import datetime -from ._management_link_async import ManagementLink -from ..utils import utc_now, utc_from_timestamp -from ..message import Message, Properties -from ..error import ( - AuthenticationException, - TokenAuthFailure, - TokenExpired, - ErrorCondition -) -from ..constants import ( +from uamqp.aio._management_link_async import ManagementLink +from uamqp.cbs import check_put_timeout_status, check_expiration_and_refresh_status +from uamqp.constants import ( CbsState, CbsAuthState, CBS_PUT_TOKEN, @@ -27,35 +21,34 @@ CBS_OPERATION, ManagementExecuteOperationResult, ManagementOpenResult, - DEFAULT_AUTH_TIMEOUT + DEFAULT_AUTH_TIMEOUT, ) -from ..cbs import ( - check_put_timeout_status, - check_expiration_and_refresh_status +from uamqp.error import ( + AuthenticationException, + TokenAuthFailure, + TokenExpired, + ErrorCondition, ) +from uamqp.message import Message, Properties +from uamqp.utils import utc_now, utc_from_timestamp _LOGGER = logging.getLogger(__name__) class CBSAuthenticator(object): - def __init__( - self, - session, - auth, - **kwargs - ): + def __init__(self, session, auth, **kwargs): self._session = session self._connection = self._session._connection self._mgmt_link = self._session.create_request_response_link_pair( - endpoint='$cbs', + endpoint="$cbs", on_amqp_management_open_complete=self._on_amqp_management_open_complete, on_amqp_management_error=self._on_amqp_management_error, - status_code_field=b'status-code', - status_description_field=b'status-description' + status_code_field=b"status-code", + status_description_field=b"status-description", ) # type: ManagementLink self._auth = auth - self._encoding = 'UTF-8' - self._auth_timeout = kwargs.pop('auth_timeout', DEFAULT_AUTH_TIMEOUT) + self._encoding = "UTF-8" + self._auth_timeout = kwargs.pop("auth_timeout", DEFAULT_AUTH_TIMEOUT) self._token_put_time = None self._expires_on = None self._token = None @@ -69,22 +62,22 @@ def __init__( async def _put_token(self, token, token_type, audience, expires_on=None): # type: (str, str, str, datetime) -> None - message = Message( + message = Message( # type: ignore value=token, - properties=Properties(message_id=self._mgmt_link.next_message_id), + properties=Properties(message_id=self._mgmt_link.next_message_id), # type: ignore application_properties={ CBS_NAME: audience, CBS_OPERATION: CBS_PUT_TOKEN, CBS_TYPE: token_type, - CBS_EXPIRATION: expires_on - } + CBS_EXPIRATION: expires_on, + }, ) await 
self._mgmt_link.execute_operation( message, self._on_execute_operation_complete, timeout=self._auth_timeout, operation=CBS_PUT_TOKEN, - type=token_type + type=token_type, ) self._mgmt_link.next_message_id += 1 @@ -95,12 +88,19 @@ async def _on_amqp_management_open_complete(self, management_open_result): self.state = CbsState.ERROR _LOGGER.info( "Unexpected AMQP management open complete in OPEN, CBS error occurred on connection %r.", - self._connection._container_id + self._connection._container_id, ) elif self.state == CbsState.OPENING: - self.state = CbsState.OPEN if management_open_result == ManagementOpenResult.OK else CbsState.CLOSED - _LOGGER.info("CBS for connection %r completed opening with status: %r", - self._connection._container_id, management_open_result) + self.state = ( + CbsState.OPEN + if management_open_result == ManagementOpenResult.OK + else CbsState.CLOSED + ) + _LOGGER.info( + "CBS for connection %r completed opening with status: %r", + self._connection._container_id, + management_open_result, + ) async def _on_amqp_management_error(self): # TODO: review the logging information, adjust level/information @@ -110,22 +110,31 @@ async def _on_amqp_management_error(self): elif self.state == CbsState.OPENING: self.state = CbsState.ERROR await self._mgmt_link.close() - _LOGGER.info("CBS for connection %r failed to open with status: %r", - self._connection._container_id, ManagementOpenResult.ERROR) + _LOGGER.info( + "CBS for connection %r failed to open with status: %r", + self._connection._container_id, + ManagementOpenResult.ERROR, + ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR - _LOGGER.info("CBS error occurred on connection %r.", self._connection._container_id) + _LOGGER.info( + "CBS error occurred on connection %r.", self._connection._container_id + ) async def _on_execute_operation_complete( - self, + self, + execute_operation_result, + status_code, + status_description, + message, + error_condition=None, + ): # pylint: disable=unused-argument + _LOGGER.info( + "CBS Put token result (%r), status code: %s, status_description: %s.", execute_operation_result, status_code, status_description, - message, - error_condition=None - ): - _LOGGER.info("CBS Put token result (%r), status code: %s, status_description: %s.", - execute_operation_result, status_code, status_description) + ) self._token_status_code = status_code self._token_status_description = status_description @@ -136,18 +145,25 @@ async def _on_execute_operation_complete( # put-token-message sending failure, rejected self._token_status_code = 0 self._token_status_description = "Auth message has been rejected." 
- elif execute_operation_result == ManagementExecuteOperationResult.FAILED_BAD_STATUS: + elif ( + execute_operation_result + == ManagementExecuteOperationResult.FAILED_BAD_STATUS + ): self.auth_state = CbsAuthState.ERROR async def _update_status(self): if self.state == CbsAuthState.OK or self.state == CbsAuthState.REFRESH_REQUIRED: - is_expired, is_refresh_required = check_expiration_and_refresh_status(self._expires_on, self._refresh_window) + is_expired, is_refresh_required = check_expiration_and_refresh_status( + self._expires_on, self._refresh_window + ) if is_expired: self.state = CbsAuthState.EXPIRED elif is_refresh_required: self.state = CbsAuthState.REFRESH_REQUIRED elif self.state == CbsAuthState.IN_PROGRESS: - put_timeout = check_put_timeout_status(self._auth_timeout, self._token_put_time) + put_timeout = check_put_timeout_status( + self._auth_timeout, self._token_put_time + ) if put_timeout: self.state = CbsAuthState.TIMEOUT @@ -161,7 +177,7 @@ async def _cbs_link_ready(self): # Think how upper layer handle this exception + condition code raise AuthenticationException( condition=ErrorCondition.ClientError, - description="CBS authentication link is in broken status, please recreate the cbs link." + description="CBS authentication link is in broken status, please recreate the cbs link.", ) async def open(self): @@ -183,39 +199,51 @@ async def update_token(self): except AttributeError: self._token = access_token.token self._token_put_time = int(utc_now().timestamp()) - await self._put_token(self._token, self._auth.token_type, self._auth.audience, utc_from_timestamp(self._expires_on)) + await self._put_token( + self._token, + self._auth.token_type, + self._auth.audience, + utc_from_timestamp(self._expires_on), + ) async def handle_token(self): - if not (await self._cbs_link_ready()): + if not await self._cbs_link_ready(): return False await self._update_status() if self.auth_state == CbsAuthState.IDLE: await self.update_token() return False - elif self.auth_state == CbsAuthState.IN_PROGRESS: + if self.auth_state == CbsAuthState.IN_PROGRESS: return False - elif self.auth_state == CbsAuthState.OK: + if self.auth_state == CbsAuthState.OK: return True - elif self.auth_state == CbsAuthState.REFRESH_REQUIRED: - _LOGGER.info("Token on connection %r will expire soon - attempting to refresh.", - self._connection._container_id) + if self.auth_state == CbsAuthState.REFRESH_REQUIRED: + _LOGGER.info( + "Token on connection %r will expire soon - attempting to refresh.", + self._connection._container_id, + ) await self.update_token() return False - elif self.auth_state == CbsAuthState.FAILURE: + if self.auth_state == CbsAuthState.FAILURE: raise AuthenticationException( condition=ErrorCondition.InternalError, - description="Failed to open CBS authentication link." + description="Failed to open CBS authentication link.", ) - elif self.auth_state == CbsAuthState.ERROR: + if self.auth_state == CbsAuthState.ERROR: raise TokenAuthFailure( self._token_status_code, self._token_status_description, - encoding=self._encoding # TODO: drop off all the encodings + encoding=self._encoding, # TODO: drop off all the encodings ) - elif self.auth_state == CbsAuthState.TIMEOUT: + if self.auth_state == CbsAuthState.TIMEOUT: raise TimeoutError("Authentication attempt timed-out.") - elif self.auth_state == CbsAuthState.EXPIRED: + if self.auth_state == CbsAuthState.EXPIRED: raise TokenExpired( condition=ErrorCondition.InternalError, - description="CBS Authentication Expired." 
+ description="CBS Authentication Expired.", ) + # default error case + raise AuthenticationException( + condition=ErrorCondition.InternalError, + description="Unrecognized authentication state", + ) diff --git a/uamqp/aio/_client_async.py b/uamqp/aio/_client_async.py index afb07cf56..1a4cf7931 100644 --- a/uamqp/aio/_client_async.py +++ b/uamqp/aio/_client_async.py @@ -1,36 +1,30 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # TODO: check this -# pylint: disable=super-init-not-called,too-many-lines +# pylint: disable=super-init-not-called,too-many-lines,protected-access import asyncio -import collections.abc import logging -import uuid -import time import queue -import certifi +import time from functools import partial -from ._connection_async import Connection -from ._management_operation_async import ManagementOperation -from ._receiver_async import ReceiverLink -from ._sender_async import SenderLink -from ._session_async import Session -from ._sasl_async import SASLTransport -from ._cbs_async import CBSAuthenticator -from ..client import AMQPClient as AMQPClientSync -from ..client import ReceiveClient as ReceiveClientSync -from ..client import SendClient as SendClientSync -from ..message import _MessageDelivery -from ..endpoints import Source, Target -from ..constants import ( - SenderSettleMode, - ReceiverSettleMode, +import certifi + +from uamqp.aio._cbs_async import CBSAuthenticator +from uamqp.aio._connection_async import Connection +from uamqp.aio._management_operation_async import ManagementOperation +from uamqp.client import ( + AMQPClient as AMQPClientSync, + ReceiveClient as ReceiveClientSync, + SendClient as SendClientSync, +) +from uamqp.constants import ( + LinkState, MessageDeliveryState, SEND_DISPOSITION_ACCEPT, SEND_DISPOSITION_REJECT, @@ -38,13 +32,8 @@ MESSAGE_DELIVERY_DONE_STATES, AUTH_TYPE_CBS, ) -from ..error import ( - ErrorResponse, - ErrorCondition, - AMQPException, - MessageException -) -from ..constants import LinkState +from uamqp.error import ErrorCondition, AMQPException, MessageException +from uamqp.message import _MessageDelivery _logger = logging.getLogger(__name__) @@ -136,11 +125,11 @@ async def _client_ready_async(self): # pylint: disable=no-self-use async def _client_run_async(self, **kwargs): """Perform a single Connection iteration.""" - await self._connection.listen(wait=self._socket_timeout) + await self._connection.listen(wait=self._socket_timeout, **kwargs) async def _close_link_async(self, **kwargs): if self._link and not self._link._is_closed: - await self._link.detach(close=True) + await self._link.detach(kwargs.pop('close',True), **kwargs) self._link = None async def _do_retryable_operation_async(self, operation, *args, **kwargs): @@ -160,21 +149,24 @@ async def _do_retryable_operation_async(self, operation, *args, **kwargs): retry_active = self._retry_policy.increment(retry_settings, exc) if not retry_active: break - await asyncio.sleep(self._retry_policy.get_backoff_time(retry_settings, exc)) + await asyncio.sleep( + self._retry_policy.get_backoff_time(retry_settings, exc) + ) if exc.condition == 
ErrorCondition.LinkDetachForced: await self._close_link_async() # if link level error, close and open a new link # TODO: check if there's any other code that we want to close link? - if exc.condition in (ErrorCondition.ConnectionCloseForced, ErrorCondition.SocketError): + if exc.condition in ( + ErrorCondition.ConnectionCloseForced, + ErrorCondition.SocketError, + ): # if connection detach or socket error, close and open a new connection await self.close_async() # TODO: check if there's any other code we want to close connection - except Exception: - raise finally: end_time = time.time() if absolute_timeout > 0: - absolute_timeout -= (end_time - start_time) - raise retry_settings['history'][-1] + absolute_timeout -= end_time - start_time + raise retry_settings["history"][-1] async def _keep_alive_worker_async(self): interval = 10 if self._keep_alive is True else self._keep_alive @@ -182,16 +174,20 @@ async def _keep_alive_worker_async(self): try: while self._connection and not self._shutdown: current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if elapsed_time >= interval: - _logger.info("Keeping %r connection alive. %r", - self.__class__.__name__, - self._connection._container_id) + _logger.info( + "Keeping %r connection alive. %r", + self.__class__.__name__, + self._connection._container_id, + ) await self._connection._get_remote_timeout(current_time) start_time = current_time await asyncio.sleep(1) except Exception as e: # pylint: disable=broad-except - _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) + _logger.info( + "Connection keep-alive for %r failed: %r.", self.__class__.__name__, e + ) async def open_async(self): """Asynchronously open the client. 
The client can create a new Connection @@ -212,30 +208,30 @@ async def open_async(self): self._connection = Connection( "amqps://" + self._hostname, sasl_credential=self._auth.sasl, - ssl={'ca_certs': certifi.where()}, + ssl={"ca_certs": certifi.where()}, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, ) await self._connection.open() if not self._session: self._session = self._connection.create_session( incoming_window=self._incoming_window, - outgoing_window=self._outgoing_window + outgoing_window=self._outgoing_window, ) await self._session.begin() if self._auth.auth_type == AUTH_TYPE_CBS: self._cbs_authenticator = CBSAuthenticator( - session=self._session, - auth=self._auth, - auth_timeout=self._auth_timeout + session=self._session, auth=self._auth, auth_timeout=self._auth_timeout ) await self._cbs_authenticator.open() if self._keep_alive: - self._keep_alive_thread = asyncio.ensure_future(self._keep_alive_worker_async()) + self._keep_alive_thread = asyncio.ensure_future( + self._keep_alive_worker_async() + ) self._shutdown = False async def close_async(self): @@ -266,7 +262,7 @@ async def auth_complete_async(self): :rtype: bool """ - if self._cbs_authenticator and not (await self._cbs_authenticator.handle_token()): + if self._cbs_authenticator and not await self._cbs_authenticator.handle_token(): await self._connection.listen(wait=self._socket_timeout) return False return True @@ -326,7 +322,7 @@ async def mgmt_request_async(self, message, **kwargs): operation = kwargs.pop("operation", None) operation_type = kwargs.pop("operation_type", None) node = kwargs.pop("node", "$management") - timeout = kwargs.pop('timeout', 0) + timeout = kwargs.pop("timeout", 0) try: mgmt_link = self._mgmt_links[node] except KeyError: @@ -338,18 +334,21 @@ async def mgmt_request_async(self, message, **kwargs): while not await mgmt_link.ready(): await self._connection.listen(wait=False) - operation_type = operation_type or b'empty' - status, description, response = await mgmt_link.execute( - message, - operation=operation, - operation_type=operation_type, - timeout=timeout - ) + operation_type = operation_type or b"empty" + response = ( + await mgmt_link.execute( + message, + operation=operation, + operation_type=operation_type, + timeout=timeout, + ) + )[ + 2 + ] # [0] for status, [1] for description, [2] for response return response class SendClientAsync(SendClientSync, AMQPClientAsync): - async def _client_ready_async(self): """Determine whether the client is ready to start receiving messages. 
To be ready, the connection must be open and authentication complete, @@ -368,7 +367,8 @@ async def _client_ready_async(self): send_settle_mode=self._send_settle_mode, rcv_settle_mode=self._receive_settle_mode, max_message_size=self._max_message_size, - properties=self._link_properties) + properties=self._link_properties, + ) await self._link.attach() return False if (await self._link.get_state()) != LinkState.ATTACHED: # ATTACHED @@ -395,9 +395,7 @@ async def _transfer_message_async(self, message_delivery, timeout=0): message_delivery.state = MessageDeliveryState.WaitingForSendAck on_send_complete = partial(self._on_send_complete_async, message_delivery) delivery = await self._link.send_transfer( - message_delivery.message, - on_send_complete=on_send_complete, - timeout=timeout + message_delivery.message, on_send_complete=on_send_complete, timeout=timeout ) if not delivery.sent: raise RuntimeError("Message is not sent.") @@ -416,12 +414,11 @@ async def _on_send_complete_async(self, message_delivery, reason, state): message_delivery, condition=error_info[0][0], description=error_info[0][1], - info=error_info[0][2] + info=error_info[0][2], ) except TypeError: self._process_send_error( - message_delivery, - condition=ErrorCondition.UnknownError + message_delivery, condition=ErrorCondition.UnknownError ) elif reason == LinkDeliverySettleReason.SETTLED: message_delivery.state = MessageDeliveryState.Ok @@ -431,8 +428,7 @@ async def _on_send_complete_async(self, message_delivery, reason, state): else: # NotDelivered and other unknown errors self._process_send_error( - message_delivery, - condition=ErrorCondition.UnknownError + message_delivery, condition=ErrorCondition.UnknownError ) async def _send_message_impl_async(self, message, **kwargs): @@ -440,9 +436,7 @@ async def _send_message_impl_async(self, message, **kwargs): expire_time = (time.time() + timeout) if timeout else None await self.open_async() message_delivery = _MessageDelivery( - message, - MessageDeliveryState.WaitingToBeSent, - expire_time + message, MessageDeliveryState.WaitingToBeSent, expire_time ) while not await self.client_ready_async(): @@ -454,25 +448,31 @@ async def _send_message_impl_async(self, message, **kwargs): while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: await self.do_work_async() if message_delivery.expiry and time.time() > message_delivery.expiry: - await self._on_send_complete_async(message_delivery, LinkDeliverySettleReason.TIMEOUT, None) + await self._on_send_complete_async( + message_delivery, LinkDeliverySettleReason.TIMEOUT, None + ) if message_delivery.state in ( MessageDeliveryState.Error, MessageDeliveryState.Cancelled, - MessageDeliveryState.Timeout + MessageDeliveryState.Timeout, ): try: raise message_delivery.error except TypeError: # This is a default handler - raise MessageException(condition=ErrorCondition.UnknownError, description="Send failed.") + raise MessageException( + condition=ErrorCondition.UnknownError, description="Send failed." 
+ ) async def send_message_async(self, message, **kwargs): """ :param ~uamqp.message.Message message: :param int timeout: timeout in seconds """ - await self._do_retryable_operation_async(self._send_message_impl_async, message=message, **kwargs) + await self._do_retryable_operation_async( + self._send_message_impl_async, message=message, **kwargs + ) class ReceiveClientAsync(ReceiveClientSync, AMQPClientAsync): @@ -583,9 +583,9 @@ async def _client_ready_async(self): send_settle_mode=self._send_settle_mode, rcv_settle_mode=self._receive_settle_mode, max_message_size=self._max_message_size, - on_message_received=self._message_received, + on_message_received=self._message_received_async, properties=self._link_properties, - desired_capabilities=self._desired_capabilities + desired_capabilities=self._desired_capabilities, ) await self._link.attach() return False @@ -608,7 +608,7 @@ async def _client_run_async(self, **kwargs): return False return True - async def _message_received(self, message): + async def _message_received_async(self, message): """Callback run on receipt of every message. If there is a user-defined callback, this will be called. Additionally if the client is retrieving messages for a batch @@ -626,7 +626,9 @@ async def _message_received(self, message): # # Message was received with callback processing and wasn't settled. # _logger.info("Message was not settled.") - async def _receive_message_batch_impl_async(self, max_batch_size=None, on_message_received=None, timeout=0): + async def _receive_message_batch_impl_async( + self, max_batch_size=None, on_message_received=None, timeout=0 + ): self._message_received_callback = on_message_received max_batch_size = max_batch_size or self._link_credit timeout_time = time.time() + timeout if timeout else 0 @@ -653,7 +655,7 @@ async def _receive_message_batch_impl_async(self, max_batch_size=None, on_messag try: await asyncio.wait_for( self.do_work_async(batch=to_receive_size), - timeout=timeout_time - now_time if timeout else None + timeout=timeout_time - now_time if timeout else None, ) except asyncio.TimeoutError: pass @@ -708,6 +710,5 @@ async def receive_message_batch_async(self, **kwargs): :type timeout: float """ return await self._do_retryable_operation( - self._receive_message_batch_impl_async, - **kwargs + self._receive_message_batch_impl_async, **kwargs ) diff --git a/uamqp/aio/_connection_async.py b/uamqp/aio/_connection_async.py index efca9e235..7168266fe 100644 --- a/uamqp/aio/_connection_async.py +++ b/uamqp/aio/_connection_async.py @@ -4,37 +4,32 @@ # license information. 
# -------------------------------------------------------------------------- -import threading -import struct -import uuid +# pylint: disable=protected-access + +import asyncio import logging -import time -from urllib.parse import urlparse import socket +import time +import uuid from ssl import SSLError -from enum import Enum -import asyncio +from typing import Union +from urllib.parse import urlparse -from ._transport_async import AsyncTransport -from ._sasl_async import SASLTransport -from ._session_async import Session -from ..performatives import OpenFrame, CloseFrame -from .._connection import get_local_timeout -from ..constants import ( +from uamqp._connection import get_local_timeout +from uamqp.aio._sasl_async import SASLTransport +from uamqp.aio._session_async import Session +from uamqp.aio._transport_async import AsyncTransport +from uamqp.constants import ( PORT, SECURE_PORT, MAX_FRAME_SIZE_BYTES, MAX_CHANNELS, HEADER_FRAME, ConnectionState, - EMPTY_FRAME -) - -from ..error import ( - ErrorCondition, - AMQPConnectionError, - AMQPError + EMPTY_FRAME, ) +from uamqp.error import ErrorCondition, AMQPConnectionError, AMQPError +from uamqp.performatives import OpenFrame, CloseFrame _LOGGER = logging.getLogger(__name__) _CLOSING_STATES = ( @@ -42,7 +37,7 @@ ConnectionState.CLOSE_PIPE, ConnectionState.DISCARDING, ConnectionState.CLOSE_SENT, - ConnectionState.END + ConnectionState.END, ) @@ -65,46 +60,46 @@ def __init__(self, endpoint, **kwargs): self.hostname = parsed_url.hostname if parsed_url.port: self.port = parsed_url.port - elif parsed_url.scheme == 'amqps': + elif parsed_url.scheme == "amqps": self.port = SECURE_PORT else: self.port = PORT self.state = None - transport = kwargs.get('transport') + transport = kwargs.get("transport") if transport: self.transport = transport - elif 'sasl_credential' in kwargs: + elif "sasl_credential" in kwargs: self.transport = SASLTransport( - host=parsed_url.netloc, - credential=kwargs['sasl_credential'], - **kwargs + host=parsed_url.netloc, credential=kwargs["sasl_credential"], **kwargs ) else: self.transport = AsyncTransport(parsed_url.netloc, **kwargs) - self._container_id = kwargs.get('container_id') or str(uuid.uuid4()) - self.max_frame_size = kwargs.get('max_frame_size', MAX_FRAME_SIZE_BYTES) + self._container_id = kwargs.get("container_id") or str(uuid.uuid4()) + self.max_frame_size = kwargs.get("max_frame_size", MAX_FRAME_SIZE_BYTES) self._remote_max_frame_size = None - self.channel_max = kwargs.get('channel_max', MAX_CHANNELS) - self.idle_timeout = kwargs.get('idle_timeout') - self.outgoing_locales = kwargs.get('outgoing_locales') - self.incoming_locales = kwargs.get('incoming_locales') + self.channel_max = kwargs.get("channel_max", MAX_CHANNELS) + self.idle_timeout = kwargs.get("idle_timeout") + self.outgoing_locales = kwargs.get("outgoing_locales") + self.incoming_locales = kwargs.get("incoming_locales") self.offered_capabilities = None - self.desired_capabilities = kwargs.get('desired_capabilities') - self.properties = kwargs.pop('properties', None) + self.desired_capabilities = kwargs.get("desired_capabilities") + self.properties = kwargs.pop("properties", None) - self.allow_pipelined_open = kwargs.get('allow_pipelined_open', True) + self.allow_pipelined_open = kwargs.get("allow_pipelined_open", True) self.remote_idle_timeout = None self.remote_idle_timeout_send_frame = None - self.idle_timeout_empty_frame_send_ratio = kwargs.get('idle_timeout_empty_frame_send_ratio', 0.5) + self.idle_timeout_empty_frame_send_ratio = kwargs.get( + 
"idle_timeout_empty_frame_send_ratio", 0.5 + ) self.last_frame_received_time = None self.last_frame_sent_time = None - self.idle_wait_time = kwargs.get('idle_wait_time', 0.1) - self.network_trace = kwargs.get('network_trace', False) + self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) + self.network_trace = kwargs.get("network_trace", False) self.network_trace_params = { - 'connection': self._container_id, - 'session': None, - 'link': None + "connection": self._container_id, + "session": None, + "link": None, } self._error = None self.outgoing_endpoints = {} @@ -124,7 +119,12 @@ async def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Connection '%s' state changed: %r -> %r", self._container_id, previous_state, new_state) + _LOGGER.info( + "Connection '%s' state changed: %r -> %r", + self._container_id, + previous_state, + new_state, + ) for session in self.outgoing_endpoints.values(): await session._on_connection_state_change() @@ -141,17 +141,20 @@ async def _connect(self): await self._process_incoming_frame(*(await self._read_frame(wait=True))) if self.state != ConnectionState.HDR_EXCH: await self._disconnect() - raise ValueError("Did not receive reciprocal protocol header. Disconnecting.") + raise ValueError( + "Did not receive reciprocal protocol header. Disconnecting." + ) else: await self._set_state(ConnectionState.HDR_SENT) except (OSError, IOError, SSLError, socket.error) as exc: raise AMQPConnectionError( ErrorCondition.SocketError, - description="Failed to initiate the connection due to exception: " + str(exc), - error=exc + description="Failed to initiate the connection due to exception: " + + str(exc), + error=exc, ) - async def _disconnect(self, *args): + async def _disconnect(self, *args): # pylint: disable=unused-argument if self.state == ConnectionState.END: return await self._set_state(ConnectionState.END) @@ -172,7 +175,9 @@ def _can_write(self): """Whether the connection is in a state where it is legal to write outgoing frames.""" return self.state not in _CLOSING_STATES - async def _send_frame(self, channel, frame, timeout=None, **kwargs): + async def _send_frame( + self, channel, frame, timeout=None, **kwargs + ): # pylint: disable=unused-argument try: raise self._error except TypeError: @@ -186,7 +191,7 @@ async def _send_frame(self, channel, frame, timeout=None, **kwargs): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send frame out due to exception: " + str(exc), - error=exc + error=exc, ) else: _LOGGER.warning("Cannot write frame in current state: %r", self.state) @@ -199,9 +204,17 @@ def _get_next_outgoing_channel(self): :returns: The next available outgoing channel number. 
:rtype: int """ - if (len(self.incoming_endpoints) + len(self.outgoing_endpoints)) >= self.channel_max: - raise ValueError("Maximum number of channels ({}) has been reached.".format(self.channel_max)) - next_channel = next(i for i in range(1, self.channel_max) if i not in self.outgoing_endpoints) + if ( + len(self.incoming_endpoints) + len(self.outgoing_endpoints) + ) >= self.channel_max: + raise ValueError( + "Maximum number of channels ({}) has been reached.".format( + self.channel_max + ) + ) + next_channel = next( + i for i in range(1, self.channel_max) if i not in self.outgoing_endpoints + ) return next_channel async def _outgoing_empty(self): @@ -215,7 +228,7 @@ async def _outgoing_empty(self): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send empty frame due to exception: " + str(exc), - error=exc + error=exc, ) async def _outgoing_header(self): @@ -224,7 +237,7 @@ async def _outgoing_header(self): _LOGGER.info("-> header(%r)", HEADER_FRAME, extra=self.network_trace_params) await self.transport.write(HEADER_FRAME) - async def _incoming_header(self, channel, frame): + async def _incoming_header(self, channel, frame): # pylint: disable=unused-argument if self.network_trace: _LOGGER.info("<- header(%r)", frame, extra=self.network_trace_params) if self.state == ConnectionState.START: @@ -240,11 +253,17 @@ async def _outgoing_open(self): hostname=self.hostname, max_frame_size=self.max_frame_size, channel_max=self.channel_max, - idle_timeout=self.idle_timeout * 1000 if self.idle_timeout else None, # Convert to milliseconds + idle_timeout=self.idle_timeout * 1000 + if self.idle_timeout + else None, # Convert to milliseconds outgoing_locales=self.outgoing_locales, incoming_locales=self.incoming_locales, - offered_capabilities=self.offered_capabilities if self.state == ConnectionState.OPEN_RCVD else None, - desired_capabilities=self.desired_capabilities if self.state == ConnectionState.HDR_EXCH else None, + offered_capabilities=self.offered_capabilities + if self.state == ConnectionState.OPEN_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == ConnectionState.HDR_EXCH + else None, properties=self.properties, ) if self.network_trace: @@ -263,7 +282,9 @@ async def _incoming_open(self, channel, frame): await self.close() if frame[4]: self.remote_idle_timeout = frame[4] / 1000 # Convert to seconds - self.remote_idle_timeout_send_frame = self.idle_timeout_empty_frame_send_ratio * self.remote_idle_timeout + self.remote_idle_timeout_send_frame = ( + self.idle_timeout_empty_frame_send_ratio * self.remote_idle_timeout + ) if frame[2] < 512: pass # TODO: error @@ -291,7 +312,7 @@ async def _incoming_close(self, channel, frame): ConnectionState.HDR_EXCH, ConnectionState.OPEN_RCVD, ConnectionState.CLOSE_SENT, - ConnectionState.DISCARDING + ConnectionState.DISCARDING, ] if self.state in disconnect_states: await self._disconnect() @@ -307,9 +328,7 @@ async def _incoming_close(self, channel, frame): if frame[0]: self._error = AMQPConnectionError( - condition=frame[0][0], - description=frame[0][1], - info=frame[0][2] + condition=frame[0][0], description=frame[0][1], info=frame[0][2] ) _LOGGER.error("Connection error: {}".format(frame[0])) @@ -332,6 +351,7 @@ async def _incoming_end(self, channel, frame): # self.outgoing_endpoints.pop(channel) # TODO async def _process_incoming_frame(self, channel, frame): + # pylint: disable=too-many-return-statements try: performative, fields = frame except TypeError: @@ -370,29 +390,36 @@ async def 
_process_incoming_frame(self, channel, frame): return True if performative == 1: return False # TODO: incoming EMPTY - else: - _LOGGER.error("Unrecognized incoming frame: {}".format(frame)) - return True + _LOGGER.error("Unrecognized incoming frame: {}".format(frame)) + return True except KeyError: return True # TODO: channel error async def _process_outgoing_frame(self, channel, frame): if self.network_trace: _LOGGER.info("-> %r", frame, extra=self.network_trace_params) - if not self.allow_pipelined_open and self.state in [ConnectionState.OPEN_PIPE, ConnectionState.OPEN_SENT]: + if not self.allow_pipelined_open and self.state in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ]: raise ValueError("Connection not configured to allow pipeline send.") - if self.state not in [ConnectionState.OPEN_PIPE, ConnectionState.OPEN_SENT, ConnectionState.OPENED]: + if self.state not in [ + ConnectionState.OPEN_PIPE, + ConnectionState.OPEN_SENT, + ConnectionState.OPENED, + ]: raise ValueError("Connection not open.") now = time.time() if get_local_timeout(now, self.idle_timeout, self.last_frame_received_time) or ( - await self._get_remote_timeout(now)): + await self._get_remote_timeout(now) + ): await self.close( # TODO: check error condition error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, - description="No frame received for the idle timeout." + description="No frame received for the idle timeout.", ), - wait=False + wait=False, ) return await self._send_frame(channel, frame) @@ -406,7 +433,7 @@ async def _get_remote_timeout(self, now): async def _wait_for_response(self, wait, end_state): # type: (Union[bool, float], ConnectionState) -> None - if wait == True: + if wait is True: await self.listen(wait=False) while self.state != end_state: await asyncio.sleep(self.idle_wait_time) @@ -424,7 +451,9 @@ async def _listen_one_frame(self, **kwargs): new_frame = await self._read_frame(**kwargs) return await self._process_incoming_frame(*new_frame) - async def listen(self, wait=False, batch=1, **kwargs): + async def listen( + self, wait=False, batch=1, **kwargs + ): # pylint: disable=unused-argument try: raise self._error except TypeError: @@ -432,22 +461,23 @@ async def listen(self, wait=False, batch=1, **kwargs): try: if self.state not in _CLOSING_STATES: now = time.time() - if get_local_timeout(now, self.idle_timeout, self.last_frame_received_time) or ( - await self._get_remote_timeout(now)): + if get_local_timeout( + now, self.idle_timeout, self.last_frame_received_time + ) or (await self._get_remote_timeout(now)): # TODO: check error condition await self.close( error=AMQPError( condition=ErrorCondition.ConnectionCloseForced, - description="No frame received for the idle timeout." + description="No frame received for the idle timeout.", ), - wait=False + wait=False, ) return if self.state == ConnectionState.END: # TODO: check error condition self._error = AMQPConnectionError( condition=ErrorCondition.ConnectionCloseForced, - description="Connection was already closed." 
+ description="Connection was already closed.", ) return for _ in range(batch): @@ -458,19 +488,20 @@ async def listen(self, wait=False, batch=1, **kwargs): self._error = AMQPConnectionError( ErrorCondition.SocketError, description="Can not send frame out due to exception: " + str(exc), - error=exc + error=exc, ) def create_session(self, **kwargs): assigned_channel = self._get_next_outgoing_channel() - kwargs['allow_pipelined_open'] = self.allow_pipelined_open - kwargs['idle_wait_time'] = self.idle_wait_time + kwargs["allow_pipelined_open"] = self.allow_pipelined_open + kwargs["idle_wait_time"] = self.idle_wait_time session = Session( self, assigned_channel, - network_trace=kwargs.pop('network_trace', self.network_trace), + network_trace=kwargs.pop("network_trace", self.network_trace), network_trace_params=dict(self.network_trace_params), - **kwargs) + **kwargs + ) self.outgoing_endpoints[assigned_channel] = session return session @@ -484,7 +515,9 @@ async def open(self, wait=False): if wait: await self._wait_for_response(wait, ConnectionState.OPENED) elif not self.allow_pipelined_open: - raise ValueError("Connection has been configured to not allow piplined-open. Please set 'wait' parameter.") + raise ValueError( + "Connection has been configured to not allow piplined-open. Please set 'wait' parameter." + ) async def close(self, error=None, wait=False): if self.state in [ConnectionState.END, ConnectionState.CLOSE_SENT]: @@ -495,7 +528,7 @@ async def close(self, error=None, wait=False): self._error = AMQPConnectionError( condition=error.condition, description=error.description, - info=error.info + info=error.info, ) if self.state == ConnectionState.OPEN_PIPE: await self._set_state(ConnectionState.OC_PIPE) @@ -506,7 +539,7 @@ async def close(self, error=None, wait=False): else: await self._set_state(ConnectionState.CLOSE_SENT) await self._wait_for_response(wait, ConnectionState.END) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except # If error happened during closing, ignore the error and set state to END _LOGGER.info("An error occurred when closing the connection: %r", exc) await self._set_state(ConnectionState.END) diff --git a/uamqp/aio/_link_async.py b/uamqp/aio/_link_async.py index f89e02d23..e746eec9b 100644 --- a/uamqp/aio/_link_async.py +++ b/uamqp/aio/_link_async.py @@ -1,41 +1,34 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access + import asyncio -import threading -import struct -import uuid import logging -import time -from urllib.parse import urlparse -from enum import Enum -from io import BytesIO +import uuid -from ..endpoints import Source, Target -from ..constants import ( +from uamqp.constants import ( DEFAULT_LINK_CREDIT, SessionState, - SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, - ReceiverSettleMode + ReceiverSettleMode, ) -from ..performatives import ( - AttachFrame, - DetachFrame, - TransferFrame, - DispositionFrame, - FlowFrame, -) -from ..error import ( +from uamqp.endpoints import Source, Target +from uamqp.error import ( AMQPConnectionError, AMQPLinkRedirect, AMQPLinkError, - ErrorCondition + ErrorCondition, +) +from uamqp.performatives import ( + AttachFrame, + DetachFrame, ) _LOGGER = logging.getLogger(__name__) @@ -43,7 +36,7 @@ class Link(object): """ - + AMQP link """ def __init__(self, session, handle, name, role, **kwargs): @@ -52,54 +45,64 @@ def __init__(self, session, handle, name, role, **kwargs): self.handle = handle self.remote_handle = None self.role = role - source_address = kwargs['source_address'] + source_address = kwargs["source_address"] target_address = kwargs["target_address"] - self.source = source_address if isinstance(source_address, Source) else Source( - address=kwargs['source_address'], - durable=kwargs.get('source_durable'), - expiry_policy=kwargs.get('source_expiry_policy'), - timeout=kwargs.get('source_timeout'), - dynamic=kwargs.get('source_dynamic'), - dynamic_node_properties=kwargs.get('source_dynamic_node_properties'), - distribution_mode=kwargs.get('source_distribution_mode'), - filters=kwargs.get('source_filters'), - default_outcome=kwargs.get('source_default_outcome'), - outcomes=kwargs.get('source_outcomes'), - capabilities=kwargs.get('source_capabilities')) - self.target = target_address if isinstance(target_address,Target) else Target( - address=kwargs['target_address'], - durable=kwargs.get('target_durable'), - expiry_policy=kwargs.get('target_expiry_policy'), - timeout=kwargs.get('target_timeout'), - dynamic=kwargs.get('target_dynamic'), - dynamic_node_properties=kwargs.get('target_dynamic_node_properties'), - capabilities=kwargs.get('target_capabilities')) - self.link_credit = kwargs.pop('link_credit', None) or DEFAULT_LINK_CREDIT + self.source = ( + source_address + if isinstance(source_address, Source) + else Source( + address=kwargs["source_address"], + durable=kwargs.get("source_durable"), + expiry_policy=kwargs.get("source_expiry_policy"), + timeout=kwargs.get("source_timeout"), + dynamic=kwargs.get("source_dynamic"), + dynamic_node_properties=kwargs.get("source_dynamic_node_properties"), + distribution_mode=kwargs.get("source_distribution_mode"), + filters=kwargs.get("source_filters"), + default_outcome=kwargs.get("source_default_outcome"), + outcomes=kwargs.get("source_outcomes"), + capabilities=kwargs.get("source_capabilities"), + ) + ) + self.target = ( + target_address + if isinstance(target_address, Target) + else Target( + address=kwargs["target_address"], + durable=kwargs.get("target_durable"), + expiry_policy=kwargs.get("target_expiry_policy"), + timeout=kwargs.get("target_timeout"), + dynamic=kwargs.get("target_dynamic"), + dynamic_node_properties=kwargs.get("target_dynamic_node_properties"), + 
capabilities=kwargs.get("target_capabilities"), + ) + ) + self.link_credit = kwargs.pop("link_credit", None) or DEFAULT_LINK_CREDIT self.current_link_credit = self.link_credit - self.send_settle_mode = kwargs.pop('send_settle_mode', SenderSettleMode.Mixed) - self.rcv_settle_mode = kwargs.pop('rcv_settle_mode', ReceiverSettleMode.First) - self.unsettled = kwargs.pop('unsettled', None) - self.incomplete_unsettled = kwargs.pop('incomplete_unsettled', None) - self.initial_delivery_count = kwargs.pop('initial_delivery_count', 0) + self.send_settle_mode = kwargs.pop("send_settle_mode", SenderSettleMode.Mixed) + self.rcv_settle_mode = kwargs.pop("rcv_settle_mode", ReceiverSettleMode.First) + self.unsettled = kwargs.pop("unsettled", None) + self.incomplete_unsettled = kwargs.pop("incomplete_unsettled", None) + self.initial_delivery_count = kwargs.pop("initial_delivery_count", 0) self.delivery_count = self.initial_delivery_count self.received_delivery_id = None - self.max_message_size = kwargs.pop('max_message_size', None) + self.max_message_size = kwargs.pop("max_message_size", None) self.remote_max_message_size = None - self.available = kwargs.pop('available', None) - self.properties = kwargs.pop('properties', None) + self.available = kwargs.pop("available", None) + self.properties = kwargs.pop("properties", None) self.offered_capabilities = None - self.desired_capabilities = kwargs.pop('desired_capabilities', None) + self.desired_capabilities = kwargs.pop("desired_capabilities", None) - self.network_trace = kwargs['network_trace'] - self.network_trace_params = kwargs['network_trace_params'] - self.network_trace_params['link'] = self.name + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["link"] = self.name self._session = session self._is_closed = False self._send_links = {} self._receive_links = {} self._pending_deliveries = {} self._received_payload = bytearray() - self._on_link_state_change = kwargs.get('on_link_state_change') + self._on_link_state_change = kwargs.get("on_link_state_change") self._error = None async def __aenter__(self): @@ -112,7 +115,9 @@ async def __aexit__(self, *args): @classmethod def from_incoming_frame(cls, session, handle, frame): # check link_create_from_endpoint in C lib - raise NotImplementedError('Pending') # TODO: Assuming we establish all links for now... + raise NotImplementedError( + "Pending" + ) # TODO: Assuming we establish all links for now... async def get_state(self): try: @@ -128,7 +133,7 @@ async def _check_if_closed(self): except TypeError: raise AMQPConnectionError( condition=ErrorCondition.InternalError, - description="Link already closed." 
+ description="Link already closed.", ) async def _set_state(self, new_state): @@ -138,21 +143,34 @@ async def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Link state changed: %r -> %r", previous_state, new_state, extra=self.network_trace_params) + _LOGGER.info( + "Link state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) try: await self._on_link_state_change(previous_state, new_state) except TypeError: pass except Exception as e: # pylint: disable=broad-except - _LOGGER.error("Link state change callback failed: '%r'", e, extra=self.network_trace_params) + _LOGGER.error( + "Link state change callback failed: '%r'", + e, + extra=self.network_trace_params, + ) async def _remove_pending_deliveries(self): # TODO: move to sender futures = [] for delivery in self._pending_deliveries.values(): - futures.append(asyncio.ensure_future(delivery.on_settled(LinkDeliverySettleReason.NOT_DELIVERED, None))) + futures.append( + asyncio.ensure_future( + delivery.on_settled(LinkDeliverySettleReason.NOT_DELIVERED, None) + ) + ) await asyncio.gather(*futures) self._pending_deliveries = {} - + async def _on_session_state_change(self): if self._session.state == SessionState.MAPPED: if not self._is_closed and self.state == LinkState.DETACHED: @@ -174,11 +192,17 @@ async def _outgoing_attach(self): target=self.target, unsettled=self.unsettled, incomplete_unsettled=self.incomplete_unsettled, - initial_delivery_count=self.initial_delivery_count if self.role == Role.Sender else None, + initial_delivery_count=self.initial_delivery_count + if self.role == Role.Sender + else None, max_message_size=self.max_message_size, - offered_capabilities=self.offered_capabilities if self.state == LinkState.ATTACH_RCVD else None, - desired_capabilities=self.desired_capabilities if self.state == LinkState.DETACHED else None, - properties=self.properties + offered_capabilities=self.offered_capabilities + if self.state == LinkState.ATTACH_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == LinkState.DETACHED + else None, + properties=self.properties, ) if self.network_trace: _LOGGER.info("-> %r", attach_frame, extra=self.network_trace_params) @@ -189,7 +213,7 @@ async def _incoming_attach(self, frame): _LOGGER.info("<- %r", AttachFrame(*frame), extra=self.network_trace_params) if self._is_closed: raise ValueError("Invalid link") - elif not frame[5] or not frame[6]: # TODO: not sure if we should check here + if not frame[5] or not frame[6]: # TODO: not sure if we should check here _LOGGER.info("Cannot get source or target. Detaching link") await self._remove_pending_deliveries() await self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
@@ -205,16 +229,16 @@ async def _incoming_attach(self, frame): await self._set_state(LinkState.ATTACH_RCVD) elif self.state == LinkState.ATTACH_SENT: await self._set_state(LinkState.ATTACHED) - + async def _outgoing_flow(self): flow_frame = { - 'handle': self.handle, - 'delivery_count': self.delivery_count, - 'link_credit': self.current_link_credit, - 'available': None, - 'drain': None, - 'echo': None, - 'properties': None + "handle": self.handle, + "delivery_count": self.delivery_count, + "link_credit": self.current_link_credit, + "available": None, + "drain": None, + "echo": None, + "properties": None, } await self._session._outgoing_flow(flow_frame) @@ -237,7 +261,11 @@ async def _incoming_detach(self, frame): _LOGGER.info("<- %r", DetachFrame(*frame), extra=self.network_trace_params) if self.state == LinkState.ATTACHED: await self._outgoing_detach(close=frame[1]) - elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + elif ( + frame[1] + and not self._is_closed + and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD] + ): # Received a closing detach after we sent a non-closing detach. # In this case, we MUST signal that we closed by reattaching and then sending a closing detach. await self._outgoing_attach() @@ -246,8 +274,14 @@ async def _incoming_detach(self, frame): # TODO: on_detach_hook if frame[2]: # error # frame[2][0] is condition, frame[2][1] is description, frame[2][2] is info - error_cls = AMQPLinkRedirect if frame[2][0] == ErrorCondition.LinkRedirect else AMQPLinkError - self._error = error_cls(condition=frame[2][0], description=frame[2][1], info=frame[2][2]) + error_cls = ( + AMQPLinkRedirect + if frame[2][0] == ErrorCondition.LinkRedirect + else AMQPLinkError + ) + self._error = error_cls( + condition=frame[2][0], description=frame[2][1], info=frame[2][2] + ) await self._set_state(LinkState.ERROR) else: await self._set_state(LinkState.DETACHED) @@ -271,6 +305,6 @@ async def detach(self, close=False, error=None): elif self.state == LinkState.ATTACHED: await self._outgoing_detach(close=close, error=error) await self._set_state(LinkState.DETACH_SENT) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except _LOGGER.info("An error occurred when detaching the link: %r", exc) await self._set_state(LinkState.DETACHED) diff --git a/uamqp/aio/_management_link_async.py b/uamqp/aio/_management_link_async.py index 5daddf9a6..68cace2eb 100644 --- a/uamqp/aio/_management_link_async.py +++ b/uamqp/aio/_management_link_async.py @@ -8,9 +8,9 @@ import time from functools import partial -from ._sender_async import SenderLink -from ._receiver_async import ReceiverLink -from ..constants import ( +from uamqp.aio._receiver_async import ReceiverLink +from uamqp.aio._sender_async import SenderLink +from uamqp.constants import ( ManagementLinkState, LinkState, SenderSettleMode, @@ -18,11 +18,11 @@ ManagementExecuteOperationResult, ManagementOpenResult, MessageDeliveryState, - SEND_DISPOSITION_REJECT + SEND_DISPOSITION_REJECT, ) -from ..message import Properties, _MessageDelivery -from ..management_link import PendingManagementOperation -from ..error import AMQPException, ErrorCondition +from uamqp.error import AMQPException, ErrorCondition +from uamqp.management_link import PendingManagementOperation +from uamqp.message import Properties, _MessageDelivery _LOGGER = logging.getLogger(__name__) @@ -39,24 +39,28 @@ def __init__(self, session, endpoint, **kwargs): self.state = ManagementLinkState.IDLE 
self._pending_operations = [] self._session = session - self._request_link = session.create_sender_link( # type: SenderLink + self._request_link: SenderLink = session.create_sender_link( endpoint, on_link_state_change=self._on_sender_state_change, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, ) - self._response_link = session.create_receiver_link( # type: ReceiverLink + self._response_link: ReceiverLink = session.create_receiver_link( endpoint, on_link_state_change=self._on_receiver_state_change, on_message_received=self._on_message_received, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, + ) + self._on_amqp_management_error = kwargs.get("on_amqp_management_error") + self._on_amqp_management_open_complete = kwargs.get( + "on_amqp_management_open_complete" ) - self._on_amqp_management_error = kwargs.get('on_amqp_management_error') - self._on_amqp_management_open_complete = kwargs.get('on_amqp_management_open_complete') - self._status_code_field = kwargs.pop('status_code_field', b'statusCode') - self._status_description_field = kwargs.pop('status_description_field', b'statusDescription') + self._status_code_field = kwargs.pop("status_code_field", b"statusCode") + self._status_description_field = kwargs.pop( + "status_description_field", b"statusDescription" + ) self._sender_connected = False self._receiver_connected = False @@ -69,7 +73,9 @@ async def __aexit__(self, *args): await self.close() async def _on_sender_state_change(self, previous_state, new_state): - _LOGGER.info("Management link sender state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link sender state changed: %r -> %r", previous_state, new_state + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: @@ -77,8 +83,15 @@ async def _on_sender_state_change(self, previous_state, new_state): self._sender_connected = True if self._receiver_connected: self.state = ManagementLinkState.OPEN - await self._on_amqp_management_open_complete(ManagementOpenResult.OK) - elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + await self._on_amqp_management_open_complete( + ManagementOpenResult.OK + ) + elif new_state in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + LinkState.ERROR, + ]: self.state = ManagementLinkState.IDLE await self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) elif self.state == ManagementLinkState.OPEN: @@ -86,7 +99,11 @@ async def _on_sender_state_change(self, previous_state, new_state): self.state = ManagementLinkState.ERROR await self._on_amqp_management_error() elif self.state == ManagementLinkState.CLOSING: - if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + if new_state not in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + ]: self.state = ManagementLinkState.ERROR await self._on_amqp_management_error() elif self.state == ManagementLinkState.ERROR: @@ -94,7 +111,11 @@ async def _on_sender_state_change(self, previous_state, new_state): return async def _on_receiver_state_change(self, previous_state, new_state): - _LOGGER.info("Management link receiver state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + ) if new_state == 
previous_state: return if self.state == ManagementLinkState.OPENING: @@ -102,8 +123,15 @@ async def _on_receiver_state_change(self, previous_state, new_state): self._receiver_connected = True if self._sender_connected: self.state = ManagementLinkState.OPEN - await self._on_amqp_management_open_complete(ManagementOpenResult.OK) - elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + await self._on_amqp_management_open_complete( + ManagementOpenResult.OK + ) + elif new_state in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + LinkState.ERROR, + ]: self.state = ManagementLinkState.IDLE await self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) elif self.state == ManagementLinkState.OPEN: @@ -111,7 +139,11 @@ async def _on_receiver_state_change(self, previous_state, new_state): self.state = ManagementLinkState.ERROR await self._on_amqp_management_error() elif self.state == ManagementLinkState.CLOSING: - if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + if new_state not in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + ]: self.state = ManagementLinkState.ERROR await self._on_amqp_management_error() elif self.state == ManagementLinkState.ERROR: @@ -132,18 +164,24 @@ async def _on_message_received(self, message): to_remove_operation = operation break if to_remove_operation: - mgmt_result = ManagementExecuteOperationResult.OK \ - if 200 <= status_code <= 299 else ManagementExecuteOperationResult.FAILED_BAD_STATUS + mgmt_result = ( + ManagementExecuteOperationResult.OK + if 200 <= status_code <= 299 + else ManagementExecuteOperationResult.FAILED_BAD_STATUS + ) await to_remove_operation.on_execute_operation_complete( mgmt_result, status_code, status_description, message, - response_detail.get(b'error-condition') + response_detail.get(b"error-condition"), ) self._pending_operations.remove(to_remove_operation) - async def _on_send_complete(self, message_delivery, reason, state): # todo: reason is never used, should check spec + async def _on_send_complete( + self, message_delivery, reason, state + ): # pylint: disable=unused-argument + # todo: reason is never used, should check spec if SEND_DISPOSITION_REJECT in state: # sample reject state: {'rejected': [[b'amqp:not-allowed', b"Invalid command 'RE1AD'.", None]]} to_remove_operation = None @@ -155,16 +193,21 @@ async def _on_send_complete(self, message_delivery, reason, state): # todo: rea # TODO: better error handling # AMQPException is too general? to be more specific: MessageReject(Error) or AMQPManagementError? 
# or should there an error mapping which maps the condition to the error type - await to_remove_operation.on_execute_operation_complete( # The callback is defined in management_operation.py + # The callback is defined in management_operation.py + await to_remove_operation.on_execute_operation_complete( ManagementExecuteOperationResult.ERROR, None, None, message_delivery.message, error=AMQPException( - condition=state[SEND_DISPOSITION_REJECT][0][0], # 0 is error condition - description=state[SEND_DISPOSITION_REJECT][0][1], # 1 is error description + condition=state[SEND_DISPOSITION_REJECT][0][ + 0 + ], # 0 is error condition + description=state[SEND_DISPOSITION_REJECT][0][ + 1 + ], # 1 is error description info=state[SEND_DISPOSITION_REJECT][0][2], # 2 is error info - ) + ), ) async def open(self): @@ -174,38 +217,33 @@ async def open(self): await self._response_link.attach() await self._request_link.attach() - async def execute_operation( - self, - message, - on_execute_operation_complete, - **kwargs - ): + async def execute_operation(self, message, on_execute_operation_complete, **kwargs): timeout = kwargs.get("timeout") message.application_properties["operation"] = kwargs.get("operation") message.application_properties["type"] = kwargs.get("type") message.application_properties["locales"] = kwargs.get("locales") try: # TODO: namedtuple is immutable, which may push us to re-think about the namedtuple approach for Message - new_properties = message.properties._replace(message_id=self.next_message_id) + new_properties = message.properties._replace( + message_id=self.next_message_id + ) except AttributeError: new_properties = Properties(message_id=self.next_message_id) message = message._replace(properties=new_properties) expire_time = (time.time() + timeout) if timeout else None message_delivery = _MessageDelivery( - message, - MessageDeliveryState.WaitingToBeSent, - expire_time + message, MessageDeliveryState.WaitingToBeSent, expire_time ) on_send_complete = partial(self._on_send_complete, message_delivery) await self._request_link.send_transfer( - message, - on_send_complete=on_send_complete, - timeout=timeout + message, on_send_complete=on_send_complete, timeout=timeout ) self.next_message_id += 1 - self._pending_operations.append(PendingManagementOperation(message, on_execute_operation_complete)) + self._pending_operations.append( + PendingManagementOperation(message, on_execute_operation_complete) + ) async def close(self): if self.state != ManagementLinkState.IDLE: @@ -218,7 +256,10 @@ async def close(self): None, None, pending_operation.message, - AMQPException(condition=ErrorCondition.ClientError, description="Management link already closed.") + AMQPException( + condition=ErrorCondition.ClientError, + description="Management link already closed.", + ), ) self._pending_operations = [] self.state = ManagementLinkState.IDLE diff --git a/uamqp/aio/_management_operation_async.py b/uamqp/aio/_management_operation_async.py index 7c916a3be..46361a712 100644 --- a/uamqp/aio/_management_operation_async.py +++ b/uamqp/aio/_management_operation_async.py @@ -1,32 +1,22 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging -import uuid import time +import uuid from functools import partial -from ._management_link_async import ManagementLink -from ..message import Message -from ..error import ( - AMQPException, - AMQPConnectionError, - AMQPLinkError, - ErrorCondition -) - -from ..constants import ( - ManagementOpenResult, - ManagementExecuteOperationResult -) +from uamqp.aio._management_link_async import ManagementLink +from uamqp.constants import ManagementOpenResult, ManagementExecuteOperationResult +from uamqp.error import AMQPLinkError, ErrorCondition _LOGGER = logging.getLogger(__name__) class ManagementOperation(object): - def __init__(self, session, endpoint='$management', **kwargs): + def __init__(self, session, endpoint="$management", **kwargs): self._mgmt_link_open_status = None self._session = session @@ -61,7 +51,7 @@ async def _on_execute_operation_complete( status_code, status_description, raw_message, - error=None + error=None, ): _LOGGER.debug( "mgmt operation completed, operation id: %r; operation_result: %r; status_code: %r; " @@ -71,18 +61,25 @@ async def _on_execute_operation_complete( status_code, status_description, raw_message, - error + error, ) - if operation_result in\ - (ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): + if operation_result in ( + ManagementExecuteOperationResult.ERROR, + ManagementExecuteOperationResult.LINK_CLOSED, + ): self._mgmt_error = error _LOGGER.error( "Failed to complete mgmt operation due to error: %r. The management request message is: %r", - error, raw_message + error, + raw_message, ) else: - self._responses[operation_id] = (status_code, status_description, raw_message) + self._responses[operation_id] = ( + status_code, + status_description, + raw_message, + ) async def execute(self, message, operation=None, operation_type=None, timeout=0): start_time = time.time() @@ -95,14 +92,16 @@ async def execute(self, message, operation=None, operation_type=None, timeout=0) partial(self._on_execute_operation_complete, operation_id), timeout=timeout, operation=operation, - type=operation_type + type=operation_type, ) while not self._responses[operation_id] and not self._mgmt_error: if timeout > 0: now = time.time() if (now - start_time) >= timeout: - raise TimeoutError("Failed to receive mgmt response in {}ms".format(timeout)) + raise TimeoutError( + "Failed to receive mgmt response in {}ms".format(timeout) + ) await self._connection.listen() if self._mgmt_error: @@ -130,8 +129,10 @@ async def ready(self): # TODO: update below with correct status code + info raise AMQPLinkError( condition=ErrorCondition.ClientError, - description="Failed to open mgmt link, management link status: {}".format(self._mgmt_link_open_status), - info=None + description="Failed to open mgmt link, management link status: {}".format( + self._mgmt_link_open_status + ), + info=None, ) async def close(self): diff --git a/uamqp/aio/_receiver_async.py b/uamqp/aio/_receiver_async.py index 9bbe7aca9..ef6404907 100644 --- a/uamqp/aio/_receiver_async.py +++ b/uamqp/aio/_receiver_async.py @@ -1,46 +1,37 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -import uuid import logging -from io import BytesIO +import uuid -from .._decode import decode_payload -from ._link_async import Link -from ..constants import DEFAULT_LINK_CREDIT, Role -from ..endpoints import Target -from ..constants import ( - DEFAULT_LINK_CREDIT, - SessionState, - SessionTransferState, - LinkDeliverySettleReason, - LinkState -) -from ..performatives import ( - AttachFrame, - DetachFrame, +from uamqp._decode import decode_payload +from uamqp.aio._link_async import Link +from uamqp.constants import LinkState +from uamqp.constants import Role +from uamqp.performatives import ( TransferFrame, DispositionFrame, - FlowFrame, ) - _LOGGER = logging.getLogger(__name__) class ReceiverLink(Link): - def __init__(self, session, handle, source_address, **kwargs): - name = kwargs.pop('name', None) or str(uuid.uuid4()) + name = kwargs.pop("name", None) or str(uuid.uuid4()) role = Role.Receiver - if 'target_address' not in kwargs: - kwargs['target_address'] = "receiver-link-{}".format(name) - super(ReceiverLink, self).__init__(session, handle, name, role, source_address=source_address, **kwargs) - self.on_message_received = kwargs.get('on_message_received') - self.on_transfer_received = kwargs.get('on_transfer_received') + if "target_address" not in kwargs: + kwargs["target_address"] = "receiver-link-{}".format(name) + super(ReceiverLink, self).__init__( + session, handle, name, role, source_address=source_address, **kwargs + ) + self.on_message_received = kwargs.get("on_message_received") + self.on_transfer_received = kwargs.get("on_transfer_received") if not self.on_message_received and not self.on_transfer_received: raise ValueError("Must specify either a message or transfer handler.") @@ -48,9 +39,9 @@ async def _process_incoming_message(self, frame, message): try: if self.on_message_received: return await self.on_message_received(message) - elif self.on_transfer_received: + if self.on_transfer_received: return await self.on_transfer_received(frame, message) - except Exception as e: + except Exception as e: # pylint: disable=broad-except _LOGGER.error("Handler function failed with error: %r", e) return None @@ -66,7 +57,9 @@ async def _incoming_attach(self, frame): async def _incoming_transfer(self, frame): if self.network_trace: - _LOGGER.info("<- %r", TransferFrame(*frame), extra=self.network_trace_params) + _LOGGER.info( + "<- %r", TransferFrame(*frame), extra=self.network_trace_params + ) self.current_link_credit -= 1 self.delivery_count += 1 self.received_delivery_id = frame[1] @@ -94,13 +87,24 @@ async def _outgoing_disposition(self, delivery_id, delivery_state): last=delivery_id, settled=True, state=delivery_state, - batchable=None + batchable=None, ) if self.network_trace: - _LOGGER.info("-> %r", DispositionFrame(*disposition_frame), extra=self.network_trace_params) + _LOGGER.info( + "-> %r", + DispositionFrame(*disposition_frame), + extra=self.network_trace_params, + ) await self._session._outgoing_disposition(disposition_frame) async def send_disposition(self, delivery_id, delivery_state=None): if self._is_closed: raise ValueError("Link already closed.") await self._outgoing_disposition(delivery_id, delivery_state) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # check link_create_from_endpoint 
in C lib + raise NotImplementedError( + "Pending" + ) # TODO: Assuming we establish all links for now... diff --git a/uamqp/aio/_sasl_async.py b/uamqp/aio/_sasl_async.py index dda1931b9..369ed7efc 100644 --- a/uamqp/aio/_sasl_async.py +++ b/uamqp/aio/_sasl_async.py @@ -1,25 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- - -import struct -from enum import Enum +# -------------------------------------------------------------------------- from ._transport_async import AsyncTransport -from ..types import AMQPTypes, TYPE, VALUE -from ..constants import FIELD, SASLCode, SASL_HEADER_FRAME from .._transport import AMQPS_PORT -from ..performatives import ( - SASLOutcome, - SASLResponse, - SASLChallenge, - SASLInit -) - +from ..constants import SASLCode, SASL_HEADER_FRAME +from ..performatives import SASLInit -_SASL_FRAME_TYPE = b'\x01' +_SASL_FRAME_TYPE = b"\x01" # TODO: do we need it here? it's a duplicate of the sync version @@ -28,7 +18,7 @@ class SASLPlainCredential(object): See https://tools.ietf.org/html/rfc4616 for details """ - mechanism = b'PLAIN' + mechanism = b"PLAIN" def __init__(self, authcid, passwd, authzid=None): self.authcid = authcid @@ -37,13 +27,13 @@ def __init__(self, authcid, passwd, authzid=None): def start(self): if self.authzid: - login_response = self.authzid.encode('utf-8') + login_response = self.authzid.encode("utf-8") else: - login_response = b'' - login_response += b'\0' - login_response += self.authcid.encode('utf-8') - login_response += b'\0' - login_response += self.passwd.encode('utf-8') + login_response = b"" + login_response += b"\0" + login_response += self.authcid.encode("utf-8") + login_response += b"\0" + login_response += self.passwd.encode("utf-8") return login_response @@ -53,10 +43,10 @@ class SASLAnonymousCredential(object): See https://tools.ietf.org/html/rfc4505 for details """ - mechanism = b'ANONYMOUS' + mechanism = b"ANONYMOUS" - def start(self): - return b'' + def start(self): # pylint: disable=no-self-use + return b"" # TODO: do we need it here? it's a duplicate of the sync version @@ -67,33 +57,50 @@ class SASLExternalCredential(object): authentication data. """ - mechanism = b'EXTERNAL' + mechanism = b"EXTERNAL" - def start(self): - return b'' + def start(self): # pylint: disable=no-self-use + return b"" class SASLTransport(AsyncTransport): - - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + def __init__( + self, + host, + credential, + port=AMQPS_PORT, + connect_timeout=None, + ssl=None, + **kwargs + ): self.credential = credential ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + super(SASLTransport, self).__init__( + host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs + ) async def negotiate(self): await self.write(SASL_HEADER_FRAME) _, returned_header = await self.receive_frame() if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Excpected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) + raise ValueError( + "Mismatching AMQP header protocol. 
Expected: {!r}, received: {!r}".format( + SASL_HEADER_FRAME, returned_header[1] + ) + ) _, supported_mechanisms = await self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechanisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + if ( + self.credential.mechanism not in supported_mechanisms[1][0] + ): # sasl_server_mechanisms + raise ValueError( + "Unsupported SASL credential type: {}".format(self.credential.mechanism) + ) sasl_init = SASLInit( mechanism=self.credential.mechanism, initial_response=self.credential.start(), - hostname=self.host) + hostname=self.host, + ) await self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) _, next_frame = await self.receive_frame(verify_frame_type=1) @@ -102,5 +109,6 @@ async def negotiate(self): raise NotImplementedError("Unsupported SASL challenge") if fields[0] == SASLCode.Ok: return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + raise ValueError( + "SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields) + ) diff --git a/uamqp/aio/_sender_async.py b/uamqp/aio/_sender_async.py index b113c51df..d791265d0 100644 --- a/uamqp/aio/_sender_async.py +++ b/uamqp/aio/_sender_async.py @@ -1,64 +1,61 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -import uuid import logging import time +import uuid -from ._link_async import Link -from .._encode import encode_payload -from ..endpoints import Source -from ..constants import ( - SessionState, +from uamqp._encode import encode_payload +from uamqp.aio._link_async import Link +from uamqp.constants import ( SessionTransferState, LinkDeliverySettleReason, LinkState, Role, - SenderSettleMode + SenderSettleMode, ) -from ..performatives import ( - AttachFrame, - DetachFrame, +from uamqp.performatives import ( TransferFrame, DispositionFrame, - FlowFrame, ) _LOGGER = logging.getLogger(__name__) class PendingDelivery(object): - def __init__(self, **kwargs): - self.message = kwargs.get('message') + self.message = kwargs.get("message") self.sent = False self.frame = None - self.on_delivery_settled = kwargs.get('on_delivery_settled') - self.link = kwargs.get('link') + self.on_delivery_settled = kwargs.get("on_delivery_settled") + self.link = kwargs.get("link") self.start = time.time() self.transfer_state = None - self.timeout = kwargs.get('timeout') - self.settled = kwargs.get('settled', False) - + self.timeout = kwargs.get("timeout") + self.settled = kwargs.get("settled", False) + async def on_settled(self, reason, state): if self.on_delivery_settled and not self.settled: try: await self.on_delivery_settled(reason, state) - except Exception as e: + except Exception as e: # pylint: disable=broad-except _LOGGER.warning("Message 'on_send_complete' callback failed: %r", e) class SenderLink(Link): - def __init__(self, session, handle, target_address, **kwargs): - name = kwargs.pop('name', None) or str(uuid.uuid4()) + name = kwargs.pop("name", None) or str(uuid.uuid4()) role = Role.Sender - if 
'source_address' not in kwargs: - kwargs['source_address'] = "sender-link-{}".format(name) - super(SenderLink, self).__init__(session, handle, name, role, target_address=target_address, **kwargs) + if "source_address" not in kwargs: + kwargs["source_address"] = "sender-link-{}".format(name) + super(SenderLink, self).__init__( + session, handle, name, role, target_address=target_address, **kwargs + ) self._unsent_messages = [] async def _incoming_attach(self, frame): @@ -72,11 +69,15 @@ async def _incoming_flow(self, frame): rcv_delivery_count = frame[5] if frame[4] is not None: if rcv_link_credit is None or rcv_delivery_count is None: - _LOGGER.info("Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.") + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link." + ) await self._remove_pending_deliveries() await self._set_state(LinkState.DETACHED) # TODO: Send detach now? else: - self.current_link_credit = rcv_delivery_count + rcv_link_credit - self.delivery_count + self.current_link_credit = ( + rcv_delivery_count + rcv_link_credit - self.delivery_count + ) if self.current_link_credit > 0: await self._send_unsent_messages() @@ -85,20 +86,24 @@ async def _outgoing_transfer(self, delivery): encode_payload(output, delivery.message) delivery_count = self.delivery_count + 1 delivery.frame = { - 'handle': self.handle, - 'delivery_tag': bytes(delivery_count), - 'message_format': delivery.message._code, - 'settled': delivery.settled, - 'more': False, - 'rcv_settle_mode': None, - 'state': None, - 'resume': None, - 'aborted': None, - 'batchable': None, - 'payload': output + "handle": self.handle, + "delivery_tag": bytes(delivery_count), + "message_format": delivery.message._code, + "settled": delivery.settled, + "more": False, + "rcv_settle_mode": None, + "state": None, + "resume": None, + "aborted": None, + "batchable": None, + "payload": output, } if self.network_trace: - _LOGGER.info("-> %r", TransferFrame(delivery_id='', **delivery.frame), extra=self.network_trace_params) + _LOGGER.info( + "-> %r", + TransferFrame(delivery_id="", **delivery.frame), + extra=self.network_trace_params, + ) await self._session._outgoing_transfer(delivery) if delivery.transfer_state == SessionTransferState.OKAY: self.delivery_count = delivery_count @@ -107,7 +112,7 @@ async def _outgoing_transfer(self, delivery): if delivery.settled: await delivery.on_settled(LinkDeliverySettleReason.SETTLED, None) else: - self._pending_deliveries[delivery.frame['delivery_id']] = delivery + self._pending_deliveries[delivery.frame["delivery_id"]] = delivery elif delivery.transfer_state == SessionTransferState.ERROR: raise ValueError("Message failed to send") if self.current_link_credit <= 0: @@ -116,24 +121,29 @@ async def _outgoing_transfer(self, delivery): async def _incoming_disposition(self, frame): if self.network_trace: - _LOGGER.info("<- %r", DispositionFrame(*frame), extra=self.network_trace_params) + _LOGGER.info( + "<- %r", DispositionFrame(*frame), extra=self.network_trace_params + ) if not frame[3]: return range_end = (frame[2] or frame[1]) + 1 - settled_ids = [i for i in range(frame[1], range_end)] - for settled_id in settled_ids: + for settled_id in range(frame[1], range_end): delivery = self._pending_deliveries.pop(settled_id, None) if delivery: - await delivery.on_settled(LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4]) + await delivery.on_settled( + LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4] + ) async def 
_update_pending_delivery_status(self): now = time.time() expired = [] for delivery in self._pending_deliveries.values(): if delivery.timeout and (now - delivery.start) >= delivery.timeout: - expired.append(delivery.frame['delivery_id']) + expired.append(delivery.frame["delivery_id"]) await delivery.on_settled(LinkDeliverySettleReason.TIMEOUT, None) - self._pending_deliveries = {i: d for i, d in self._pending_deliveries.items() if i not in expired} + self._pending_deliveries = { + i: d for i, d in self._pending_deliveries.items() if i not in expired + } async def _send_unsent_messages(self): unsent = [] @@ -151,10 +161,10 @@ async def send_transfer(self, message, **kwargs): raise ValueError("Link is not attached.") settled = self.send_settle_mode == SenderSettleMode.Settled if self.send_settle_mode == SenderSettleMode.Mixed: - settled = kwargs.pop('settled', True) + settled = kwargs.pop("settled", True) delivery = PendingDelivery( - on_delivery_settled=kwargs.get('on_send_complete'), - timeout=kwargs.get('timeout'), + on_delivery_settled=kwargs.get("on_send_complete"), + timeout=kwargs.get("timeout"), link=self, message=message, settled=settled, @@ -169,10 +179,21 @@ async def send_transfer(self, message, **kwargs): async def cancel_transfer(self, delivery): try: - delivery = self._pending_deliveries.pop(delivery.frame['delivery_id']) + delivery = self._pending_deliveries.pop(delivery.frame["delivery_id"]) await delivery.on_settled(LinkDeliverySettleReason.CANCELLED, None) return except KeyError: pass # todo remove from unset messages - raise ValueError("No pending delivery with ID '{}' found.".format(delivery.frame['delivery_id'])) + raise ValueError( + "No pending delivery with ID '{}' found.".format( + delivery.frame["delivery_id"] + ) + ) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # check link_create_from_endpoint in C lib + raise NotImplementedError( + "Pending" + ) # TODO: Assuming we establish all links for now... diff --git a/uamqp/aio/_session_async.py b/uamqp/aio/_session_async.py index a40f60290..a0537c0a9 100644 --- a/uamqp/aio/_session_async.py +++ b/uamqp/aio/_session_async.py @@ -1,37 +1,24 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -import uuid +# pylint: disable=protected-access + +import asyncio import logging import time -import asyncio +import uuid from typing import Optional, Union -from ..constants import ( - INCOMING_WINDOW, - OUTGOING_WIDNOW, - ConnectionState, - SessionState, - SessionTransferState, - Role -) -from ..endpoints import Source, Target -from ._management_link_async import ManagementLink -from ._sender_async import SenderLink -from ._receiver_async import ReceiverLink -from ..performatives import ( - BeginFrame, - EndFrame, - FlowFrame, - AttachFrame, - DetachFrame, - TransferFrame, - DispositionFrame -) -from .._encode import encode_frame +from uamqp._encode import encode_frame +from uamqp.aio._management_link_async import ManagementLink +from uamqp.aio._receiver_async import ReceiverLink +from uamqp.aio._sender_async import SenderLink +from uamqp.constants import ConnectionState, SessionState, SessionTransferState, Role +from uamqp.error import AMQPError +from uamqp.performatives import BeginFrame, EndFrame, FlowFrame, TransferFrame _LOGGER = logging.getLogger(__name__) @@ -49,27 +36,27 @@ class Session(object): """ def __init__(self, connection, channel, **kwargs): - self.name = kwargs.pop('name', None) or str(uuid.uuid4()) + self.name = kwargs.pop("name", None) or str(uuid.uuid4()) self.state = SessionState.UNMAPPED - self.handle_max = kwargs.get('handle_max', 4294967295) - self.properties = kwargs.pop('properties', None) + self.handle_max = kwargs.get("handle_max", 4294967295) + self.properties = kwargs.pop("properties", None) self.channel = channel self.remote_channel = None - self.next_outgoing_id = kwargs.pop('next_outgoing_id', 0) + self.next_outgoing_id = kwargs.pop("next_outgoing_id", 0) self.next_incoming_id = None - self.incoming_window = kwargs.pop('incoming_window', 1) - self.outgoing_window = kwargs.pop('outgoing_window', 1) + self.incoming_window = kwargs.pop("incoming_window", 1) + self.outgoing_window = kwargs.pop("outgoing_window", 1) self.target_incoming_window = self.incoming_window self.remote_incoming_window = 0 self.remote_outgoing_window = 0 self.offered_capabilities = None - self.desired_capabilities = kwargs.pop('desired_capabilities', None) + self.desired_capabilities = kwargs.pop("desired_capabilities", None) - self.allow_pipelined_open = kwargs.pop('allow_pipelined_open', True) - self.idle_wait_time = kwargs.get('idle_wait_time', 0.1) - self.network_trace = kwargs['network_trace'] - self.network_trace_params = kwargs['network_trace_params'] - self.network_trace_params['session'] = self.name + self.allow_pipelined_open = kwargs.pop("allow_pipelined_open", True) + self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["session"] = self.name self.links = {} self._connection = connection @@ -84,7 +71,9 @@ async def __aexit__(self, *args): await self.end() @classmethod - def from_incoming_frame(cls, connection, channel, frame): + def from_incoming_frame( + cls, connection, channel, frame + ): # pylint: disable=unused-argument # check session_create_from_endpoint in C lib new_session = cls(connection, channel) return new_session @@ -96,7 +85,12 @@ async def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Session state changed: %r 
-> %r", previous_state, new_state, extra=self.network_trace_params) + _LOGGER.info( + "Session state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) futures = [] for link in self.links.values(): @@ -117,19 +111,31 @@ def _get_next_output_handle(self): :rtype: int """ if len(self._output_handles) >= self.handle_max: - raise ValueError("Maximum number of handles ({}) has been reached.".format(self.handle_max)) - next_handle = next(i for i in range(1, self.handle_max) if i not in self._output_handles) + raise ValueError( + "Maximum number of handles ({}) has been reached.".format( + self.handle_max + ) + ) + next_handle = next( + i for i in range(1, self.handle_max) if i not in self._output_handles + ) return next_handle - + async def _outgoing_begin(self): begin_frame = BeginFrame( - remote_channel=self.remote_channel if self.state == SessionState.BEGIN_RCVD else None, + remote_channel=self.remote_channel + if self.state == SessionState.BEGIN_RCVD + else None, next_outgoing_id=self.next_outgoing_id, outgoing_window=self.outgoing_window, incoming_window=self.incoming_window, handle_max=self.handle_max, - offered_capabilities=self.offered_capabilities if self.state == SessionState.BEGIN_RCVD else None, - desired_capabilities=self.desired_capabilities if self.state == SessionState.UNMAPPED else None, + offered_capabilities=self.offered_capabilities + if self.state == SessionState.BEGIN_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == SessionState.UNMAPPED + else None, properties=self.properties, ) if self.network_trace: @@ -160,7 +166,11 @@ async def _outgoing_end(self, error=None): async def _incoming_end(self, frame): if self.network_trace: _LOGGER.info("<- %r", EndFrame(*frame), extra=self.network_trace_params) - if self.state not in [SessionState.END_RCVD, SessionState.END_SENT, SessionState.DISCARDING]: + if self.state not in [ + SessionState.END_RCVD, + SessionState.END_SENT, + SessionState.DISCARDING, + ]: await self._set_state(SessionState.END_RCVD) # TODO: Clean up all links await self._outgoing_end() @@ -171,12 +181,16 @@ async def _outgoing_attach(self, frame): async def _incoming_attach(self, frame): try: - self._input_handles[frame[1]] = self.links[frame[0].decode('utf-8')] + self._input_handles[frame[1]] = self.links[frame[0].decode("utf-8")] await self._input_handles[frame[1]]._incoming_attach(frame) except KeyError: - outgoing_handle = self._get_next_output_handle() # TODO: catch max-handles error + outgoing_handle = ( + self._get_next_output_handle() + ) # TODO: catch max-handles error if frame[2] == Role.Sender: - new_link = ReceiverLink.from_incoming_frame(self, outgoing_handle, frame) + new_link = ReceiverLink.from_incoming_frame( + self, outgoing_handle, frame + ) else: new_link = SenderLink.from_incoming_frame(self, outgoing_handle, frame) await new_link._incoming_attach(frame) @@ -185,15 +199,17 @@ async def _incoming_attach(self, frame): self._input_handles[frame[1]] = new_link except ValueError: pass # TODO: Reject link - + async def _outgoing_flow(self, frame=None): link_flow = frame or {} - link_flow.update({ - 'next_incoming_id': self.next_incoming_id, - 'incoming_window': self.incoming_window, - 'next_outgoing_id': self.next_outgoing_id, - 'outgoing_window': self.outgoing_window - }) + link_flow.update( + { + "next_incoming_id": self.next_incoming_id, + "incoming_window": self.incoming_window, + "next_outgoing_id": self.next_outgoing_id, + "outgoing_window": self.outgoing_window, + } + ) 
flow_frame = FlowFrame(**link_flow) if self.network_trace: _LOGGER.info("-> %r", flow_frame, extra=self.network_trace_params) @@ -203,8 +219,12 @@ async def _incoming_flow(self, frame): if self.network_trace: _LOGGER.info("<- %r", FlowFrame(*frame), extra=self.network_trace_params) self.next_incoming_id = frame[2] - remote_incoming_id = frame[0] or self.next_outgoing_id # TODO "initial-outgoing-id" - self.remote_incoming_window = remote_incoming_id + frame[1] - self.next_outgoing_id + remote_incoming_id = ( + frame[0] or self.next_outgoing_id + ) # TODO "initial-outgoing-id" + self.remote_incoming_window = ( + remote_incoming_id + frame[1] - self.next_outgoing_id + ) self.remote_outgoing_window = frame[3] if frame[4] is not None: await self._input_handles[frame[4]]._incoming_flow(frame) @@ -222,58 +242,64 @@ async def _outgoing_transfer(self, delivery): delivery.transfer_state = SessionTransferState.BUSY else: - payload = delivery.frame['payload'] + payload = delivery.frame["payload"] payload_size = len(payload) - delivery.frame['delivery_id'] = self.next_outgoing_id + delivery.frame["delivery_id"] = self.next_outgoing_id # calculate the transfer frame encoding size excluding the payload - delivery.frame['payload'] = b"" + delivery.frame["payload"] = b"" # TODO: encoding a frame would be expensive, we might want to improve depending on the perf test results encoded_frame = encode_frame(TransferFrame(**delivery.frame))[1] transfer_overhead_size = len(encoded_frame) # available size for payload per frame is calculated as following: # remote max frame size - transfer overhead (calculated) - header (8 bytes) - available_frame_size = self._connection._remote_max_frame_size - transfer_overhead_size - 8 + available_frame_size = ( + self._connection._remote_max_frame_size - transfer_overhead_size - 8 + ) start_idx = 0 remaining_payload_cnt = payload_size # encode n-1 frames if payload_size > available_frame_size while remaining_payload_cnt > available_frame_size: tmp_delivery_frame = { - 'handle': delivery.frame['handle'], - 'delivery_tag': delivery.frame['delivery_tag'], - 'message_format': delivery.frame['message_format'], - 'settled': delivery.frame['settled'], - 'more': True, - 'rcv_settle_mode': delivery.frame['rcv_settle_mode'], - 'state': delivery.frame['state'], - 'resume': delivery.frame['resume'], - 'aborted': delivery.frame['aborted'], - 'batchable': delivery.frame['batchable'], - 'payload': payload[start_idx:start_idx+available_frame_size], - 'delivery_id': self.next_outgoing_id + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": True, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "payload": payload[start_idx : start_idx + available_frame_size], + "delivery_id": self.next_outgoing_id, } - await self._connection._process_outgoing_frame(self.channel, TransferFrame(**tmp_delivery_frame)) + await self._connection._process_outgoing_frame( + self.channel, TransferFrame(**tmp_delivery_frame) + ) start_idx += available_frame_size remaining_payload_cnt -= available_frame_size # encode the last frame tmp_delivery_frame = { - 'handle': delivery.frame['handle'], - 'delivery_tag': delivery.frame['delivery_tag'], - 'message_format': delivery.frame['message_format'], - 'settled': 
delivery.frame['settled'], - 'more': False, - 'rcv_settle_mode': delivery.frame['rcv_settle_mode'], - 'state': delivery.frame['state'], - 'resume': delivery.frame['resume'], - 'aborted': delivery.frame['aborted'], - 'batchable': delivery.frame['batchable'], - 'payload': payload[start_idx:], - 'delivery_id': self.next_outgoing_id + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": False, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "payload": payload[start_idx:], + "delivery_id": self.next_outgoing_id, } - await self._connection._process_outgoing_frame(self.channel, TransferFrame(**tmp_delivery_frame)) + await self._connection._process_outgoing_frame( + self.channel, TransferFrame(**tmp_delivery_frame) + ) self.next_outgoing_id += 1 self.remote_incoming_window -= 1 self.outgoing_window -= 1 @@ -286,7 +312,7 @@ async def _incoming_transfer(self, frame): try: await self._input_handles[frame[0]]._incoming_transfer(frame) except KeyError: - pass #TODO: "unattached handle" + pass # TODO: "unattached handle" if self.incoming_window == 0: self.incoming_window = self.target_incoming_window await self._outgoing_flow() @@ -316,7 +342,7 @@ async def _incoming_detach(self, frame): async def _wait_for_response(self, wait, end_state): # type: (Union[bool, float], SessionState) -> None - if wait == True: + if wait is True: await self._connection.listen(wait=False) while self.state != end_state: await asyncio.sleep(self.idle_wait_time) @@ -336,10 +362,12 @@ async def begin(self, wait=False): if wait: await self._wait_for_response(wait, SessionState.BEGIN_SENT) elif not self.allow_pipelined_open: - raise ValueError("Connection has been configured to not allow piplined-open. Please set 'wait' parameter.") + raise ValueError( + "Connection has been configured to not allow piplined-open. Please set 'wait' parameter." 
+ ) async def end(self, error=None, wait=False): - # type: (Optional[AMQPError]) -> None + # type: (Optional[AMQPError], Union[bool, int]) -> None try: if self.state not in [SessionState.UNMAPPED, SessionState.DISCARDING]: await self._outgoing_end(error=error) @@ -347,7 +375,7 @@ async def end(self, error=None, wait=False): new_state = SessionState.DISCARDING if error else SessionState.END_SENT await self._set_state(new_state) await self._wait_for_response(wait, SessionState.UNMAPPED) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except _LOGGER.info("An error occurred when ending the session: %r", exc) await self._set_state(SessionState.UNMAPPED) @@ -357,9 +385,10 @@ def create_receiver_link(self, source_address, **kwargs): self, handle=assigned_handle, source_address=source_address, - network_trace=kwargs.pop('network_trace', self.network_trace), + network_trace=kwargs.pop("network_trace", self.network_trace), network_trace_params=dict(self.network_trace_params), - **kwargs) + **kwargs + ) self.links[link.name] = link self._output_handles[assigned_handle] = link return link @@ -370,9 +399,10 @@ def create_sender_link(self, target_address, **kwargs): self, handle=assigned_handle, target_address=target_address, - network_trace=kwargs.pop('network_trace', self.network_trace), + network_trace=kwargs.pop("network_trace", self.network_trace), network_trace_params=dict(self.network_trace_params), - **kwargs) + **kwargs + ) self._output_handles[assigned_handle] = link self.links[link.name] = link return link @@ -381,5 +411,6 @@ def create_request_response_link_pair(self, endpoint, **kwargs): return ManagementLink( self, endpoint, - network_trace=kwargs.pop('network_trace', self.network_trace), - **kwargs) + network_trace=kwargs.pop("network_trace", self.network_trace), + **kwargs + ) diff --git a/uamqp/aio/_transport_async.py b/uamqp/aio/_transport_async.py index acbdd8af8..312cbb15b 100644 --- a/uamqp/aio/_transport_async.py +++ b/uamqp/aio/_transport_async.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # This is a fork of the transport.py which was originally written by Barry Pederson and # maintained by the Celery project: https://github.com/celery/py-amqp. # @@ -30,54 +30,50 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF # THE POSSIBILITY OF SUCH DAMAGE. 
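# --- Editor's illustration (not part of the patch) --------------------------
# How the multi-frame transfer in Session._outgoing_transfer above sizes its
# chunks: usable payload per frame is the peer's max frame size minus the
# encoded transfer-performative overhead minus the 8-byte frame header, and
# every chunk except the last is sent with more=True. The helper name and the
# sample sizes below are invented for illustration.
def split_payload(payload, remote_max_frame_size, transfer_overhead_size):
    available = remote_max_frame_size - transfer_overhead_size - 8
    chunks, start = [], 0
    while len(payload) - start > available:
        chunks.append((payload[start:start + available], True))   # more=True
        start += available
    chunks.append((payload[start:], False))                        # final frame
    return chunks

chunks = split_payload(b"x" * 1500, remote_max_frame_size=512, transfer_overhead_size=50)
assert [len(body) for body, _more in chunks] == [454, 454, 454, 138]
# -----------------------------------------------------------------------------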
-#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import asyncio import errno -import re +import logging import socket import ssl import struct -from ssl import SSLError -from contextlib import contextmanager from io import BytesIO -import logging -from threading import Lock +from ssl import SSLError import certifi -from .._platform import KNOWN_TCP_OPTS, SOL_TCP, pack, unpack -from .._encode import encode_frame -from .._decode import decode_frame, decode_empty_frame -from ..constants import TLS_HEADER_FRAME -from .._transport import ( +from uamqp._decode import decode_frame, decode_empty_frame +from uamqp._encode import encode_frame +from uamqp._platform import SOL_TCP +from uamqp._transport import ( AMQP_FRAME, get_errno, to_host_port, - DEFAULT_SOCKET_SETTINGS, - IPV6_LITERAL, SIGNED_INT_MAX, _UNAVAIL, set_cloexec, - AMQP_PORT + AMQP_PORT, + _get_tcp_socket_defaults, ) - +from uamqp.constants import TLS_HEADER_FRAME _LOGGER = logging.getLogger(__name__) def get_running_loop(): try: - import asyncio # pylint: disable=import-error return asyncio.get_running_loop() except AttributeError: # 3.6 loop = None try: loop = asyncio._get_running_loop() # pylint: disable=protected-access except AttributeError: - _LOGGER.warning('This version of Python is deprecated, please upgrade to >= v3.6') + _LOGGER.warning( + "This version of Python is deprecated, please upgrade to >= v3.6" + ) if loop is None: - _LOGGER.warning('No running event loop') + _LOGGER.warning("No running event loop") loop = asyncio.get_event_loop() return loop @@ -85,9 +81,18 @@ def get_running_loop(): class AsyncTransport(object): """Common superclass for TCP and SSL transports.""" - def __init__(self, host, port=AMQP_PORT, connect_timeout=None, - read_timeout=None, write_timeout=None, ssl=False, - socket_settings=None, raise_on_initial_eintr=True, **kwargs): + def __init__( + self, # pylint: disable=unused-argument + host, + port=AMQP_PORT, + connect_timeout=None, + read_timeout=None, + write_timeout=None, + ssl=False, # pylint: disable=redefined-outer-name + socket_settings=None, + raise_on_initial_eintr=True, + **kwargs + ): self.connected = False self.sock = None self.reader = None @@ -108,19 +113,23 @@ def _build_ssl_opts(self, sslopts): if sslopts in [True, False, None, {}]: return sslopts try: - if 'context' in sslopts: - return self._build_ssl_context(sslopts, **sslopts.pop('context')) - ssl_version = sslopts.get('ssl_version') + if "context" in sslopts: + return self._build_ssl_context(sslopts, **sslopts.pop("context")) + ssl_version = sslopts.get("ssl_version") if ssl_version is None: ssl_version = ssl.PROTOCOL_TLS # Set SNI headers if supported - server_hostname = sslopts.get('server_hostname') - if (server_hostname is not None) and (hasattr(ssl, 'HAS_SNI') and ssl.HAS_SNI) and (hasattr(ssl, 'SSLContext')): + server_hostname = sslopts.get("server_hostname") + if ( + (server_hostname is not None) + and (hasattr(ssl, "HAS_SNI") and ssl.HAS_SNI) + and (hasattr(ssl, "SSLContext")) + ): context = ssl.SSLContext(ssl_version) - cert_reqs = sslopts.get('cert_reqs', ssl.CERT_REQUIRED) - certfile = sslopts.get('certfile') - keyfile = sslopts.get('keyfile') + cert_reqs = sslopts.get("cert_reqs", ssl.CERT_REQUIRED) + certfile = sslopts.get("certfile") + keyfile = sslopts.get("keyfile") context.verify_mode = cert_reqs if cert_reqs != ssl.CERT_NONE: context.check_hostname = True @@ -129,9 +138,13 @@ def _build_ssl_opts(self, 
sslopts): return context return True except TypeError: - raise TypeError('SSL configuration must be a dictionary, or the value True.') + raise TypeError( + "SSL configuration must be a dictionary, or the value True." + ) - def _build_ssl_context(self, sslopts, check_hostname=None, **ctx_options): + def _build_ssl_context( + self, sslopts, check_hostname=None, **ctx_options + ): # pylint: disable=unused-argument,no-self-use ctx = ssl.create_default_context(**ctx_options) ctx.verify_mode = ssl.CERT_REQUIRED ctx.load_verify_locations(cafile=certifi.where()) @@ -145,12 +158,14 @@ async def connect(self): return await self._connect(self.host, self.port, self.connect_timeout) self._init_socket( - self.socket_settings, self.read_timeout, self.write_timeout, + self.socket_settings, + self.read_timeout, + self.write_timeout, ) self.reader, self.writer = await asyncio.open_connection( sock=self.sock, ssl=self.sslopts, - server_hostname=self.host if self.sslopts else None + server_hostname=self.host if self.sslopts else None, ) # we've sent the banner; signal connect # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner @@ -180,7 +195,8 @@ async def _connect(self, host, port, timeout): # first, resolve the address for a single address family try: entries = await self.loop.getaddrinfo( - host, port, family=family, type=socket.SOCK_STREAM, proto=SOL_TCP) + host, port, family=family, type=socket.SOCK_STREAM, proto=SOL_TCP + ) entries_num = len(entries) except socket.gaierror: # we may have depleted all our options @@ -188,10 +204,11 @@ async def _connect(self, host, port, timeout): # if getaddrinfo succeeded before for another address # family, reraise the previous socket.error since it's more # relevant to users - raise (e - if e is not None - else socket.error( - "failed to resolve broker hostname")) + raise ( + e + if e is not None + else socket.error("failed to resolve broker hostname") + ) continue # pragma: no cover # now that we have address(es) for the hostname, connect to broker @@ -217,7 +234,9 @@ async def _connect(self, host, port, timeout): # hurray, we established connection return - def _init_socket(self, socket_settings, read_timeout, write_timeout): + def _init_socket( + self, socket_settings, read_timeout, write_timeout + ): # pylint: disable=unused-argument self.sock.settimeout(None) # set socket back to blocking mode self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self._set_socket_options(socket_settings) @@ -235,36 +254,20 @@ def _init_socket(self, socket_settings, read_timeout, write_timeout): self.sock.settimeout(1) # set socket back to non-blocking mode - def _get_tcp_socket_defaults(self, sock): - tcp_opts = {} - for opt in KNOWN_TCP_OPTS: - enum = None - if opt == 'TCP_USER_TIMEOUT': - try: - from socket import TCP_USER_TIMEOUT as enum - except ImportError: - # should be in Python 3.6+ on Linux. 
- enum = 18 - elif hasattr(socket, opt): - enum = getattr(socket, opt) - - if enum: - if opt in DEFAULT_SOCKET_SETTINGS: - tcp_opts[enum] = DEFAULT_SOCKET_SETTINGS[opt] - elif hasattr(socket, opt): - tcp_opts[enum] = sock.getsockopt( - SOL_TCP, getattr(socket, opt)) - return tcp_opts - def _set_socket_options(self, socket_settings): - tcp_opts = self._get_tcp_socket_defaults(self.sock) + tcp_opts = _get_tcp_socket_defaults(self.sock) if socket_settings: tcp_opts.update(socket_settings) for opt, val in tcp_opts.items(): self.sock.setsockopt(SOL_TCP, opt, val) - async def _read(self, toread, initial=False, buffer=None, - _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): + async def _read( + self, + toread, + initial=False, + buffer=None, + _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR), + ): # According to SSL_read(3), it can at most return 16kb of data. # Thus, we use an internal read buffer like TCPTransport._read # to get the exact number of bytes wanted. @@ -276,16 +279,18 @@ async def _read(self, toread, initial=False, buffer=None, try: while toread: try: - view[nbytes:nbytes + toread] = await self.reader.readexactly(toread) + view[nbytes : nbytes + toread] = await self.reader.readexactly( + toread + ) nbytes = toread except asyncio.IncompleteReadError as exc: pbytes = len(exc.partial) - view[nbytes:nbytes + pbytes] = exc.partial + view[nbytes : nbytes + pbytes] = exc.partial nbytes = pbytes except socket.error as exc: # ssl.sock.read may cause a SSLerror without errno # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): + if isinstance(exc, SSLError) and "timed out" in str(exc): raise socket.timeout() # ssl.sock.read may cause ENOENT if the # operation couldn't be performed (Issue celery#1414). @@ -295,7 +300,7 @@ async def _read(self, toread, initial=False, buffer=None, continue raise if not nbytes: - raise IOError('Server unexpectedly closed connection') + raise IOError("Server unexpectedly closed connection") length += nbytes toread -= nbytes @@ -318,30 +323,43 @@ def close(self): self.sock = None self.connected = False - async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? + async def read( + self, verify_frame_type=0, **kwargs + ): # pylint: disable=unused-argument + # TODO: verify frame type? 
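        # (Editor's note, illustrative only; not in the original patch.)
        # The 8-byte AMQP frame header parsed below is laid out as:
        #   bytes 0-3  SIZE    unsigned 32-bit big-endian, total frame length
        #   byte  4    DOFF    data offset, counted in 4-byte words
        #   byte  5    TYPE    0x00 = AMQP frame, 0x01 = SASL frame
        #   bytes 6-7  CHANNEL unsigned 16-bit big-endian
        # which is why the code unpacks ">H" at offset 6 for the channel and
        # ">I" over bytes 0-3 for the size.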
async with self.socket_lock: read_frame_buffer = BytesIO() try: frame_header = memoryview(bytearray(8)) - read_frame_buffer.write(await self._read(8, buffer=frame_header, initial=True)) + read_frame_buffer.write( + await self._read(8, buffer=frame_header, initial=True) + ) - channel = struct.unpack('>H', frame_header[6:])[0] + channel = struct.unpack(">H", frame_header[6:])[0] size = frame_header[0:4] if size == AMQP_FRAME: # Empty frame or AMQP header negotiation return frame_header, channel, None - size = struct.unpack('>I', size)[0] + size = struct.unpack(">I", size)[0] offset = frame_header[4] - frame_type = frame_header[5] + # frame_type = frame_header[5] # >I is an unsigned int, but the argument to sock.recv is signed, # so we know the size can be at most 2 * SIGNED_INT_MAX payload_size = size - len(frame_header) payload = memoryview(bytearray(payload_size)) if size > SIGNED_INT_MAX: - read_frame_buffer.write(await self._read(SIGNED_INT_MAX, buffer=payload)) - read_frame_buffer.write(await self._read(size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:])) + read_frame_buffer.write( + await self._read(SIGNED_INT_MAX, buffer=payload) + ) + read_frame_buffer.write( + await self._read( + size - SIGNED_INT_MAX, buffer=payload[SIGNED_INT_MAX:] + ) + ) else: - read_frame_buffer.write(await self._read(payload_size, buffer=payload)) + read_frame_buffer.write( + await self._read(payload_size, buffer=payload) + ) except (socket.timeout, asyncio.IncompleteReadError): read_frame_buffer.write(self._read_buffer.getvalue()) self._read_buffer = read_frame_buffer @@ -350,7 +368,7 @@ async def read(self, verify_frame_type=0, **kwargs): # TODO: verify frame type? except (OSError, IOError, SSLError, socket.error) as exc: # Don't disconnect for ssl read time outs # http://bugs.python.org/issue10272 - if isinstance(exc, SSLError) and 'timed out' in str(exc): + if isinstance(exc, SSLError) and "timed out" in str(exc): raise socket.timeout() if get_errno(exc) not in _UNAVAIL: self.connected = False @@ -368,20 +386,22 @@ async def write(self, s): self.connected = False raise - async def receive_frame(self, *args, **kwargs): + async def receive_frame(self, *args, **kwargs): # pylint: disable=unused-argument try: - header, channel, payload = await self.read(**kwargs) + header, channel, payload = await self.read(**kwargs) if not payload: decoded = decode_empty_frame(header) else: decoded = decode_frame(payload) # TODO: Catch decode error and return amqp:decode-error - #_LOGGER.info("ICH%d <- %r", channel, decoded) + # _LOGGER.info("ICH%d <- %r", channel, decoded) return channel, decoded except (socket.timeout, asyncio.IncompleteReadError, asyncio.TimeoutError): return None, None - async def receive_frame_with_lock(self, *args, **kwargs): + async def receive_frame_with_lock( + self, *args, **kwargs + ): # pylint: disable=unused-argument try: async with self.socket_lock: header, channel, payload = await self.read(**kwargs) @@ -398,17 +418,21 @@ async def send_frame(self, channel, frame, **kwargs): if performative is None: data = header else: - encoded_channel = struct.pack('>H', channel) + encoded_channel = struct.pack(">H", channel) data = header + encoded_channel + performative await self.write(data) - #_LOGGER.info("OCH%d -> %r", channel, frame) + # _LOGGER.info("OCH%d -> %r", channel, frame) async def negotiate(self): if not self.sslopts: return await self.write(TLS_HEADER_FRAME) - channel, returned_header = await self.receive_frame(verify_frame_type=None) + # receive_frame returns tuple, [0] for channel, [1] 
for returned header + returned_header = (await self.receive_frame(verify_frame_type=None))[1] if returned_header[1] == TLS_HEADER_FRAME: - raise ValueError("Mismatching TLS header protocol. Excpected: {}, received: {}".format( - TLS_HEADER_FRAME, returned_header[1])) + raise ValueError( + "Mismatching TLS header protocol. Expected: {!r}, received: {!r}".format( + TLS_HEADER_FRAME, returned_header[1] + ) + ) diff --git a/uamqp/amqp_types.py b/uamqp/amqp_types.py new file mode 100644 index 000000000..5a51531b1 --- /dev/null +++ b/uamqp/amqp_types.py @@ -0,0 +1,90 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from enum import Enum + + +TYPE = "TYPE" +VALUE = "VALUE" + + +class AMQPTypes(object): # pylint: disable=no-init + null = "NULL" + boolean = "BOOL" + ubyte = "UBYTE" + byte = "BYTE" + ushort = "USHORT" + short = "SHORT" + uint = "UINT" + int = "INT" + ulong = "ULONG" + long = "LONG" + float = "FLOAT" + double = "DOUBLE" + timestamp = "TIMESTAMP" + uuid = "UUID" + binary = "BINARY" + string = "STRING" + symbol = "SYMBOL" + list = "LIST" + map = "MAP" + array = "ARRAY" + described = "DESCRIBED" + + +class FieldDefinition(Enum): + fields = "fields" + annotations = "annotations" + message_id = "message-id" + app_properties = "application-properties" + node_properties = "node-properties" + filter_set = "filter-set" + + +class ObjDefinition(Enum): + source = "source" + target = "target" + delivery_state = "delivery-state" + error = "error" + + +class ConstructorBytes(object): # pylint: disable=no-init + null = b"\x40" + bool = b"\x56" + bool_true = b"\x41" + bool_false = b"\x42" + ubyte = b"\x50" + byte = b"\x51" + ushort = b"\x60" + short = b"\x61" + uint_0 = b"\x43" + uint_small = b"\x52" + int_small = b"\x54" + uint_large = b"\x70" + int_large = b"\x71" + ulong_0 = b"\x44" + ulong_small = b"\x53" + long_small = b"\x55" + ulong_large = b"\x80" + long_large = b"\x81" + float = b"\x72" + double = b"\x82" + timestamp = b"\x83" + uuid = b"\x98" + binary_small = b"\xA0" + binary_large = b"\xB0" + string_small = b"\xA1" + string_large = b"\xB1" + symbol_small = b"\xA3" + symbol_large = b"\xB3" + list_0 = b"\x45" + list_small = b"\xC0" + list_large = b"\xD0" + map_small = b"\xC1" + map_large = b"\xD1" + array_small = b"\xE0" + array_large = b"\xF0" + descriptor = b"\x00" diff --git a/uamqp/authentication.py b/uamqp/authentication.py index 6fb937867..9861b8ceb 100644 --- a/uamqp/authentication.py +++ b/uamqp/authentication.py @@ -1,41 +1,32 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
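# --- Editor's illustration (not part of the patch) --------------------------
# The ConstructorBytes table added in uamqp/amqp_types.py mirrors the AMQP 1.0
# primitive encodings: every encoded value begins with a one-byte constructor,
# optionally followed by a size and then the data. A hand-rolled example for a
# short UTF-8 string (str8-utf8, constructor 0xA1):
encoded = b"\xa1" + bytes([len(b"abc")]) + b"abc"
assert encoded == b"\xa1\x03abc"
# Likewise b"\x40" is a complete encoding of null, and b"\x41" / b"\x42" encode
# boolean true / false with no payload bytes at all.
# -----------------------------------------------------------------------------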
-#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- import time -import urllib from collections import namedtuple from functools import partial -from .sasl import SASLAnonymousCredential, SASLPlainCredential -from .utils import generate_sas_token - -from .constants import ( +from uamqp.constants import ( AUTH_DEFAULT_EXPIRATION_SECONDS, TOKEN_TYPE_JWT, TOKEN_TYPE_SASTOKEN, AUTH_TYPE_CBS, - AUTH_TYPE_SASL_PLAIN + AUTH_TYPE_SASL_PLAIN, ) - -try: - from urlparse import urlparse - from urllib import quote_plus # type: ignore -except ImportError: - from urllib.parse import urlparse, quote_plus +from uamqp.sasl import SASLAnonymousCredential, SASLPlainCredential +from uamqp.utils import generate_sas_token AccessToken = namedtuple("AccessToken", ["token", "expires_on"]) -def _generate_sas_access_token(auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS): +def _generate_sas_access_token( + auth_uri, sas_name, sas_key, expiry_in=AUTH_DEFAULT_EXPIRATION_SECONDS +): expires_on = int(time.time() + expiry_in) token = generate_sas_token(auth_uri, sas_name, sas_key, expires_on) - return AccessToken( - token, - expires_on - ) + return AccessToken(token, expires_on) class SASLPlainAuth(object): @@ -52,14 +43,7 @@ class _CBSAuth(object): # 1. naming decision, suffix with Auth vs Credential auth_type = AUTH_TYPE_CBS - def __init__( - self, - uri, - audience, - token_type, - get_token, - **kwargs - ): + def __init__(self, uri, audience, token_type, get_token, **kwargs): """ CBS authentication using JWT tokens. @@ -101,13 +85,7 @@ def _set_expiry(expires_in, expires_on): class JWTTokenAuth(_CBSAuth): # TODO: # 1. naming decision, suffix with Auth vs Credential - def __init__( - self, - uri, - audience, - get_token, - **kwargs - ): + def __init__(self, uri, audience, get_token, **kwargs): """ CBS authentication using JWT tokens. @@ -125,21 +103,16 @@ def __init__( :type token_type: str """ - super(JWTTokenAuth, self).__init__(uri, audience, kwargs.pop("kwargs", TOKEN_TYPE_JWT), get_token) + super(JWTTokenAuth, self).__init__( + uri, audience, kwargs.pop("kwargs", TOKEN_TYPE_JWT), get_token + ) self.get_token = get_token class SASTokenAuth(_CBSAuth): # TODO: # 1. naming decision, suffix with Auth vs Credential - def __init__( - self, - uri, - audience, - username, - password, - **kwargs - ): + def __init__(self, uri, audience, username, password, **kwargs): """ CBS authentication using SAS tokens. @@ -171,12 +144,14 @@ def __init__( expires_in = kwargs.pop("expires_in", AUTH_DEFAULT_EXPIRATION_SECONDS) expires_on = kwargs.pop("expires_on", None) expires_in, expires_on = self._set_expiry(expires_in, expires_on) - self.get_token = partial(_generate_sas_access_token, uri, username, password, expires_in) + self.get_token = partial( + _generate_sas_access_token, uri, username, password, expires_in + ) super(SASTokenAuth, self).__init__( uri, audience, kwargs.pop("token_type", TOKEN_TYPE_SASTOKEN), self.get_token, expires_in=expires_in, - expires_on=expires_on + expires_on=expires_on, ) diff --git a/uamqp/cbs.py b/uamqp/cbs.py index b8ac11192..59e2356f6 100644 --- a/uamqp/cbs.py +++ b/uamqp/cbs.py @@ -1,22 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
See License.txt in the project root for # license information. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- + +# pylint: disable=protected-access import logging from datetime import datetime -from .utils import utc_now, utc_from_timestamp -from .management_link import ManagementLink -from .message import Message, Properties -from .error import ( - AuthenticationException, - ErrorCondition, - TokenAuthFailure, - TokenExpired -) -from .constants import ( +from uamqp.constants import ( CbsState, CbsAuthState, CBS_PUT_TOKEN, @@ -26,8 +19,17 @@ CBS_OPERATION, ManagementExecuteOperationResult, ManagementOpenResult, - DEFAULT_AUTH_TIMEOUT + DEFAULT_AUTH_TIMEOUT, +) +from uamqp.error import ( + AuthenticationException, + ErrorCondition, + TokenAuthFailure, + TokenExpired, ) +from uamqp.management_link import ManagementLink +from uamqp.message import Message, Properties +from uamqp.utils import utc_now, utc_from_timestamp _LOGGER = logging.getLogger(__name__) @@ -40,35 +42,27 @@ def check_expiration_and_refresh_status(expires_on, refresh_window): def check_put_timeout_status(auth_timeout, token_put_time): - if auth_timeout > 0: - return (int(utc_now().timestamp()) - token_put_time) >= auth_timeout - else: - return False + return 0 < auth_timeout <= (int(utc_now().timestamp()) - token_put_time) class CBSAuthenticator(object): - def __init__( - self, - session, - auth, - **kwargs - ): + def __init__(self, session, auth, **kwargs): self._session = session self._connection = self._session._connection self._mgmt_link = self._session.create_request_response_link_pair( - endpoint='$cbs', + endpoint="$cbs", on_amqp_management_open_complete=self._on_amqp_management_open_complete, on_amqp_management_error=self._on_amqp_management_error, - status_code_field=b'status-code', - status_description_field=b'status-description' + status_code_field=b"status-code", + status_description_field=b"status-description", ) # type: ManagementLink if not auth.get_token or not callable(auth.get_token): raise ValueError("get_token must be a callable object.") self._auth = auth - self._encoding = 'UTF-8' - self._auth_timeout = kwargs.pop('auth_timeout', DEFAULT_AUTH_TIMEOUT) + self._encoding = "UTF-8" + self._auth_timeout = kwargs.pop("auth_timeout", DEFAULT_AUTH_TIMEOUT) self._token_put_time = None self._expires_on = None self._token = None @@ -82,38 +76,48 @@ def __init__( def _put_token(self, token, token_type, audience, expires_on=None): # type: (str, str, str, datetime) -> None - message = Message( + message = Message( # type: ignore value=token, - properties=Properties(message_id=self._mgmt_link.next_message_id), + properties=Properties(message_id=self._mgmt_link.next_message_id), # type: ignore application_properties={ CBS_NAME: audience, CBS_OPERATION: CBS_PUT_TOKEN, CBS_TYPE: token_type, - CBS_EXPIRATION: expires_on - } + CBS_EXPIRATION: expires_on, + }, ) self._mgmt_link.execute_operation( message, self._on_execute_operation_complete, timeout=self._auth_timeout, operation=CBS_PUT_TOKEN, - type=token_type + type=token_type, ) self._mgmt_link.next_message_id += 1 def _on_amqp_management_open_complete(self, management_open_result): if self.state in (CbsState.CLOSED, CbsState.ERROR): - _LOGGER.debug("CSB with status: %r encounters unexpected AMQP management open complete.", self.state) + _LOGGER.debug( + "CSB with status: %r encounters unexpected AMQP management open complete.", + self.state, + ) elif self.state == 
CbsState.OPEN: self.state = CbsState.ERROR _LOGGER.info( "Unexpected AMQP management open complete in OPEN, CBS error occurred on connection %r.", - self._connection._container_id + self._connection._container_id, ) elif self.state == CbsState.OPENING: - self.state = CbsState.OPEN if management_open_result == ManagementOpenResult.OK else CbsState.CLOSED - _LOGGER.info("CBS for connection %r completed opening with status: %r", - self._connection._container_id, management_open_result) + self.state = ( + CbsState.OPEN + if management_open_result == ManagementOpenResult.OK + else CbsState.CLOSED + ) + _LOGGER.info( + "CBS for connection %r completed opening with status: %r", + self._connection._container_id, + management_open_result, + ) def _on_amqp_management_error(self): if self.state == CbsState.CLOSED: @@ -121,22 +125,31 @@ def _on_amqp_management_error(self): elif self.state == CbsState.OPENING: self.state = CbsState.ERROR self._mgmt_link.close() - _LOGGER.info("CBS for connection %r failed to open with status: %r", - self._connection._container_id, ManagementOpenResult.ERROR) + _LOGGER.info( + "CBS for connection %r failed to open with status: %r", + self._connection._container_id, + ManagementOpenResult.ERROR, + ) elif self.state == CbsState.OPEN: self.state = CbsState.ERROR - _LOGGER.info("CBS error occurred on connection %r.", self._connection._container_id) + _LOGGER.info( + "CBS error occurred on connection %r.", self._connection._container_id + ) def _on_execute_operation_complete( - self, + self, + execute_operation_result, + status_code, + status_description, + message, + error_condition=None, + ): # pylint: disable=unused-argument + _LOGGER.info( + "CBS Put token result (%r), status code: %s, status_description: %s.", execute_operation_result, status_code, status_description, - message, - error_condition=None - ): - _LOGGER.info("CBS Put token result (%r), status code: %s, status_description: %s.", - execute_operation_result, status_code, status_description) + ) self._token_status_code = status_code self._token_status_description = status_description @@ -147,33 +160,41 @@ def _on_execute_operation_complete( # put-token-message sending failure, rejected self._token_status_code = 0 self._token_status_description = "Auth message has been rejected." - elif execute_operation_result == ManagementExecuteOperationResult.FAILED_BAD_STATUS: + elif ( + execute_operation_result + == ManagementExecuteOperationResult.FAILED_BAD_STATUS + ): + # TODO: log error and message self.auth_state = CbsAuthState.ERROR def _update_status(self): if self.state == CbsAuthState.OK or self.state == CbsAuthState.REFRESH_REQUIRED: - is_expired, is_refresh_required = check_expiration_and_refresh_status(self._expires_on, self._refresh_window) + is_expired, is_refresh_required = check_expiration_and_refresh_status( + self._expires_on, self._refresh_window + ) if is_expired: self.state = CbsAuthState.EXPIRED elif is_refresh_required: self.state = CbsAuthState.REFRESH_REQUIRED elif self.state == CbsAuthState.IN_PROGRESS: - put_timeout = check_put_timeout_status(self._auth_timeout, self._token_put_time) + put_timeout = check_put_timeout_status( + self._auth_timeout, self._token_put_time + ) if put_timeout: self.state = CbsAuthState.TIMEOUT def _cbs_link_ready(self): if self.state == CbsState.OPEN: return True - if self.state != CbsState.OPEN: + if self.state == CbsState.OPENING: return False - if self.state in (CbsState.CLOSED, CbsState.ERROR): - # TODO: raise proper error type also should this be a ClientError? 
- # Think how upper layer handle this exception + condition code - raise AuthenticationException( - condition=ErrorCondition.ClientError, - description="CBS authentication link is in broken status, please recreate the cbs link." - ) + # cbs state in CbsState.CLOSED or CbsState.ERROR + # TODO: raise proper error type also should this be a ClientError? + # Think how upper layer handle this exception + condition code + raise AuthenticationException( + condition=ErrorCondition.ClientError, + description="CBS authentication link is in broken status, please recreate the cbs link.", + ) def open(self): self.state = CbsState.OPENING @@ -194,7 +215,12 @@ def update_token(self): except AttributeError: self._token = access_token.token self._token_put_time = int(utc_now().timestamp()) - self._put_token(self._token, self._auth.token_type, self._auth.audience, utc_from_timestamp(self._expires_on)) + self._put_token( + self._token, + self._auth.token_type, + self._auth.audience, + utc_from_timestamp(self._expires_on), + ) def handle_token(self): if not self._cbs_link_ready(): @@ -203,30 +229,37 @@ def handle_token(self): if self.auth_state == CbsAuthState.IDLE: self.update_token() return False - elif self.auth_state == CbsAuthState.IN_PROGRESS: + if self.auth_state == CbsAuthState.IN_PROGRESS: return False - elif self.auth_state == CbsAuthState.OK: + if self.auth_state == CbsAuthState.OK: return True - elif self.auth_state == CbsAuthState.REFRESH_REQUIRED: - _LOGGER.info("Token on connection %r will expire soon - attempting to refresh.", - self._connection._container_id) + if self.auth_state == CbsAuthState.REFRESH_REQUIRED: + _LOGGER.info( + "Token on connection %r will expire soon - attempting to refresh.", + self._connection._container_id, + ) self.update_token() return False - elif self.auth_state == CbsAuthState.FAILURE: + if self.auth_state == CbsAuthState.FAILURE: raise AuthenticationException( condition=ErrorCondition.InternalError, - description="Failed to open CBS authentication link." + description="Failed to open CBS authentication link.", ) - elif self.auth_state == CbsAuthState.ERROR: + if self.auth_state == CbsAuthState.ERROR: raise TokenAuthFailure( self._token_status_code, self._token_status_description, - encoding=self._encoding # TODO: drop off all the encodings + encoding=self._encoding, # TODO: drop off all the encodings ) - elif self.auth_state == CbsAuthState.TIMEOUT: + if self.auth_state == CbsAuthState.TIMEOUT: raise TimeoutError("Authentication attempt timed-out.") - elif self.auth_state == CbsAuthState.EXPIRED: + if self.auth_state == CbsAuthState.EXPIRED: raise TokenExpired( condition=ErrorCondition.InternalError, - description="CBS Authentication Expired." + description="CBS Authentication Expired.", ) + # default error case + raise AuthenticationException( + condition=ErrorCondition.InternalError, + description="Unrecognized authentication state", + ) diff --git a/uamqp/client.py b/uamqp/client.py index d9a947a54..33c3d30e6 100644 --- a/uamqp/client.py +++ b/uamqp/client.py @@ -1,38 +1,23 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
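# --- Editor's illustration (not part of the patch) --------------------------
# How the CBS pieces above fit together: an auth object such as SASTokenAuth
# owns a get_token callable, and CBSAuthenticator drives the $cbs put-token
# exchange over a request/response link on an existing Session. The endpoint,
# key name and key below are placeholders, and `connection` / `session` are
# assumed to already exist; in the library this wiring happens inside
# AMQPClient.open() rather than by hand.
from uamqp.authentication import SASTokenAuth
from uamqp.cbs import CBSAuthenticator

auth = SASTokenAuth(
    uri="sb://myns.servicebus.windows.net/myhub",        # assumed endpoint
    audience="sb://myns.servicebus.windows.net/myhub",
    username="RootManageSharedAccessKey",                 # assumed SAS key name
    password="<shared-access-key>",                       # placeholder secret
)
cbs = CBSAuthenticator(session=session, auth=auth, auth_timeout=60)
cbs.open()
while not cbs.handle_token():      # True once the put-token request is accepted
    connection.listen(wait=False)
# -----------------------------------------------------------------------------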
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -# pylint: disable=too-many-lines +# pylint: disable=protected-access,too-many-lines -from collections import namedtuple import logging +import queue import threading import time import uuid -import certifi -import queue from functools import partial -from ._connection import Connection -from .message import _MessageDelivery -from .session import Session -from .sender import SenderLink -from .receiver import ReceiverLink -from .sasl import SASLTransport -from .endpoints import Source, Target -from .error import ( - AMQPConnectionError, - AMQPException, - ErrorResponse, - ErrorCondition, - MessageException, - MessageSendFailed, - RetryPolicy -) +import certifi -from .constants import ( +from uamqp._connection import Connection +from uamqp.cbs import CBSAuthenticator +from uamqp.constants import ( MessageDeliveryState, SenderSettleMode, ReceiverSettleMode, @@ -46,11 +31,15 @@ DEFAULT_AUTH_TIMEOUT, MESSAGE_DELIVERY_DONE_STATES, ) - -from .management_operation import ManagementOperation -from .cbs import CBSAuthenticator -from .authentication import _CBSAuth - +from uamqp.error import ( + AMQPException, + ErrorCondition, + MessageException, + MessageSendFailed, + RetryPolicy, +) +from uamqp.management_operation import ManagementOperation +from uamqp.message import _MessageDelivery _logger = logging.getLogger(__name__) @@ -141,21 +130,27 @@ def __init__(self, hostname, auth=None, **kwargs): self._keep_alive_thread = None # Connection settings - self._max_frame_size = kwargs.pop('max_frame_size', None) or MAX_FRAME_SIZE_BYTES - self._channel_max = kwargs.pop('channel_max', None) or 65535 - self._idle_timeout = kwargs.pop('idle_timeout', None) - self._properties = kwargs.pop('properties', None) + self._max_frame_size = ( + kwargs.pop("max_frame_size", None) or MAX_FRAME_SIZE_BYTES + ) + self._channel_max = kwargs.pop("channel_max", None) or 65535 + self._idle_timeout = kwargs.pop("idle_timeout", None) + self._properties = kwargs.pop("properties", None) self._network_trace = kwargs.pop("network_trace", False) # Session settings - self._outgoing_window = kwargs.pop('outgoing_window', None) or OUTGOING_WIDNOW - self._incoming_window = kwargs.pop('incoming_window', None) or INCOMING_WINDOW - self._handle_max = kwargs.pop('handle_max', None) + self._outgoing_window = kwargs.pop("outgoing_window", None) or OUTGOING_WIDNOW + self._incoming_window = kwargs.pop("incoming_window", None) or INCOMING_WINDOW + self._handle_max = kwargs.pop("handle_max", None) # Link settings - self._send_settle_mode = kwargs.pop('send_settle_mode', SenderSettleMode.Unsettled) - self._receive_settle_mode = kwargs.pop('receive_settle_mode', ReceiverSettleMode.Second) - self._desired_capabilities = kwargs.pop('desired_capabilities', None) + self._send_settle_mode = kwargs.pop( + "send_settle_mode", SenderSettleMode.Unsettled + ) + self._receive_settle_mode = kwargs.pop( + "receive_settle_mode", ReceiverSettleMode.Second + ) + self._desired_capabilities = kwargs.pop("desired_capabilities", None) def __enter__(self): """Run Client in a context manager.""" @@ -177,11 +172,11 @@ def _client_ready(self): # pylint: disable=no-self-use def _client_run(self, **kwargs): """Perform a single Connection iteration.""" - self._connection.listen(wait=self._socket_timeout) + self._connection.listen(wait=self._socket_timeout, **kwargs) def _close_link(self, **kwargs): if self._link and not 
self._link._is_closed: - self._link.detach(close=True) + self._link.detach(kwargs.pop('close', True), **kwargs) self._link = None def _do_retryable_operation(self, operation, *args, **kwargs): @@ -205,17 +200,18 @@ def _do_retryable_operation(self, operation, *args, **kwargs): if exc.condition == ErrorCondition.LinkDetachForced: self._close_link() # if link level error, close and open a new link # TODO: check if there's any other code that we want to close link? - if exc.condition in (ErrorCondition.ConnectionCloseForced, ErrorCondition.SocketError): + if exc.condition in ( + ErrorCondition.ConnectionCloseForced, + ErrorCondition.SocketError, + ): # if connection detach or socket error, close and open a new connection self.close() # TODO: check if there's any other code we want to close connection - except Exception: - raise finally: end_time = time.time() if absolute_timeout > 0: - absolute_timeout -= (end_time - start_time) - raise retry_settings['history'][-1] + absolute_timeout -= end_time - start_time + raise retry_settings["history"][-1] def _keep_alive_worker(self): interval = 10 if self._keep_alive is True else self._keep_alive @@ -223,17 +219,20 @@ def _keep_alive_worker(self): try: while self._connection and not self._shutdown: current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if elapsed_time >= interval: - _logger.info("Keeping %r connection alive. %r", - self.__class__.__name__, - self._connection._container_id) + _logger.info( + "Keeping %r connection alive. %r", + self.__class__.__name__, + self._connection._container_id, + ) self._connection._get_remote_timeout(current_time) start_time = current_time time.sleep(1) - except Exception as e: - _logger.info("Connection keep-alive for %r failed: %r.", self.__class__.__name__, e) - + except Exception as e: # pylint: disable=broad-except + _logger.info( + "Connection keep-alive for %r failed: %r.", self.__class__.__name__, e + ) def open(self): """Open the client. 
The client can create a new Connection @@ -254,26 +253,24 @@ def open(self): self._connection = Connection( "amqps://" + self._hostname, sasl_credential=self._auth.sasl, - ssl={'ca_certs':certifi.where()}, + ssl={"ca_certs": certifi.where()}, container_id=self._name, max_frame_size=self._max_frame_size, channel_max=self._channel_max, idle_timeout=self._idle_timeout, properties=self._properties, - network_trace=self._network_trace + network_trace=self._network_trace, ) self._connection.open() if not self._session: self._session = self._connection.create_session( incoming_window=self._incoming_window, - outgoing_window=self._outgoing_window + outgoing_window=self._outgoing_window, ) self._session.begin() if self._auth.auth_type == AUTH_TYPE_CBS: self._cbs_authenticator = CBSAuthenticator( - session=self._session, - auth=self._auth, - auth_timeout=self._auth_timeout + session=self._session, auth=self._auth, auth_timeout=self._auth_timeout ) self._cbs_authenticator.open() if self._keep_alive: @@ -376,7 +373,7 @@ def mgmt_request(self, message, **kwargs): operation = kwargs.pop("operation", None) operation_type = kwargs.pop("operation_type", None) node = kwargs.pop("node", "$management") - timeout = kwargs.pop('timeout', 0) + timeout = kwargs.pop("timeout", 0) try: mgmt_link = self._mgmt_links[node] except KeyError: @@ -388,13 +385,12 @@ def mgmt_request(self, message, **kwargs): while not mgmt_link.ready(): self._connection.listen(wait=False) - operation_type = operation_type or b'empty' - status, description, response = mgmt_link.execute( - message, - operation=operation, - operation_type=operation_type, - timeout=timeout - ) + operation_type = operation_type or b"empty" + response = mgmt_link.execute( + message, operation=operation, operation_type=operation_type, timeout=timeout + )[ + 2 + ] # [0] for status, [1] for description, [2] for response return response @@ -402,9 +398,11 @@ class SendClient(AMQPClient): def __init__(self, hostname, target, auth=None, **kwargs): self.target = target # Sender and Link settings - self._max_message_size = kwargs.pop('max_message_size', None) or MAX_FRAME_SIZE_BYTES - self._link_properties = kwargs.pop('link_properties', None) - self._link_credit = kwargs.pop('link_credit', None) + self._max_message_size = ( + kwargs.pop("max_message_size", None) or MAX_FRAME_SIZE_BYTES + ) + self._link_properties = kwargs.pop("link_properties", None) + self._link_credit = kwargs.pop("link_credit", None) super(SendClient, self).__init__(hostname, auth=auth, **kwargs) def _client_ready(self): @@ -425,7 +423,8 @@ def _client_ready(self): send_settle_mode=self._send_settle_mode, rcv_settle_mode=self._receive_settle_mode, max_message_size=self._max_message_size, - properties=self._link_properties) + properties=self._link_properties, + ) self._link.attach() return False if self._link.get_state().value != 3: # ATTACHED @@ -452,9 +451,7 @@ def _transfer_message(self, message_delivery, timeout=0): message_delivery.state = MessageDeliveryState.WaitingForSendAck on_send_complete = partial(self._on_send_complete, message_delivery) delivery = self._link.send_transfer( - message_delivery.message, - on_send_complete=on_send_complete, - timeout=timeout + message_delivery.message, on_send_complete=on_send_complete, timeout=timeout ) if not delivery.sent: raise RuntimeError("Message is not sent.") @@ -466,7 +463,9 @@ def _process_send_error(message_delivery, condition, description=None, info=None except ValueError: error = MessageException(condition, description=description, info=info) 
else: - error = MessageSendFailed(amqp_condition, description=description, info=info) + error = MessageSendFailed( + amqp_condition, description=description, info=info + ) message_delivery.state = MessageDeliveryState.Error message_delivery.error = error @@ -484,12 +483,11 @@ def _on_send_complete(self, message_delivery, reason, state): message_delivery, condition=error_info[0][0], description=error_info[0][1], - info=error_info[0][2] + info=error_info[0][2], ) except TypeError: self._process_send_error( - message_delivery, - condition=ErrorCondition.UnknownError + message_delivery, condition=ErrorCondition.UnknownError ) elif reason == LinkDeliverySettleReason.SETTLED: message_delivery.state = MessageDeliveryState.Ok @@ -499,8 +497,7 @@ def _on_send_complete(self, message_delivery, reason, state): else: # NotDelivered and other unknown errors self._process_send_error( - message_delivery, - condition=ErrorCondition.UnknownError + message_delivery, condition=ErrorCondition.UnknownError ) def _send_message_impl(self, message, **kwargs): @@ -508,9 +505,7 @@ def _send_message_impl(self, message, **kwargs): expire_time = (time.time() + timeout) if timeout else None self.open() message_delivery = _MessageDelivery( - message, - MessageDeliveryState.WaitingToBeSent, - expire_time + message, MessageDeliveryState.WaitingToBeSent, expire_time ) while not self.client_ready(): @@ -522,14 +517,22 @@ def _send_message_impl(self, message, **kwargs): while running and message_delivery.state not in MESSAGE_DELIVERY_DONE_STATES: running = self.do_work() if message_delivery.expiry and time.time() > message_delivery.expiry: - self._on_send_complete(message_delivery, LinkDeliverySettleReason.TIMEOUT, None) - - if message_delivery.state in (MessageDeliveryState.Error, MessageDeliveryState.Cancelled, MessageDeliveryState.Timeout): + self._on_send_complete( + message_delivery, LinkDeliverySettleReason.TIMEOUT, None + ) + + if message_delivery.state in ( + MessageDeliveryState.Error, + MessageDeliveryState.Cancelled, + MessageDeliveryState.Timeout, + ): try: raise message_delivery.error except TypeError: # This is a default handler - raise MessageException(condition=ErrorCondition.UnknownError, description="Send failed.") + raise MessageException( + condition=ErrorCondition.UnknownError, description="Send failed." + ) def send_message(self, message, **kwargs): """ @@ -627,14 +630,20 @@ class ReceiveClient(AMQPClient): def __init__(self, hostname, source, auth=None, **kwargs): self.source = source - self._streaming_receive = kwargs.pop("streaming_receive", False) # TODO: whether public? + self._streaming_receive = kwargs.pop( + "streaming_receive", False + ) # TODO: whether public? self._received_messages = queue.Queue() - self._message_received_callback = kwargs.pop("message_received_callback", None) # TODO: whether public? + self._message_received_callback = kwargs.pop( + "message_received_callback", None + ) # TODO: whether public? 
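# --- Editor's illustration (not part of the patch) --------------------------
# Minimal SendClient usage implied by the send path above (open, then the
# client_ready / do_work loop inside send_message). Hostname, target and the
# auth object are placeholders; Message comes from uamqp.message.
from uamqp import SendClient
from uamqp.message import Message

send_client = SendClient(
    "myns.servicebus.windows.net",     # assumed hostname
    "myhub",                           # assumed target address
    auth=auth,                         # e.g. a SASTokenAuth as sketched earlier
    idle_timeout=10,
)
with send_client:                      # context manager opens and closes the client
    send_client.send_message(Message(value=b"hello"))
# If the delivery ends in an Error / Cancelled / Timeout state, a
# MessageException (for example MessageSendFailed) is raised instead.
# -----------------------------------------------------------------------------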
# Sender and Link settings - self._max_message_size = kwargs.pop('max_message_size', None) or MAX_FRAME_SIZE_BYTES - self._link_properties = kwargs.pop('link_properties', None) - self._link_credit = kwargs.pop('link_credit', 300) + self._max_message_size = ( + kwargs.pop("max_message_size", None) or MAX_FRAME_SIZE_BYTES + ) + self._link_properties = kwargs.pop("link_properties", None) + self._link_credit = kwargs.pop("link_credit", 300) super(ReceiveClient, self).__init__(hostname, auth=auth, **kwargs) def _client_ready(self): @@ -657,7 +666,7 @@ def _client_ready(self): max_message_size=self._max_message_size, on_message_received=self._message_received, properties=self._link_properties, - desired_capabilities=self._desired_capabilities + desired_capabilities=self._desired_capabilities, ) self._link.attach() return False @@ -694,11 +703,13 @@ def _message_received(self, message): if not self._streaming_receive: self._received_messages.put(message) # TODO: do we need settled property for a message? - #elif not message.settled: + # elif not message.settled: # # Message was received with callback processing and wasn't settled. # _logger.info("Message was not settled.") - def _receive_message_batch_impl(self, max_batch_size=None, on_message_received=None, timeout=0): + def _receive_message_batch_impl( + self, max_batch_size=None, on_message_received=None, timeout=0 + ): self._message_received_callback = on_message_received max_batch_size = max_batch_size or self._link_credit timeout = time.time() + timeout if timeout else 0 @@ -770,7 +781,4 @@ def receive_message_batch(self, **kwargs): default is 0. :type timeout: float """ - return self._do_retryable_operation( - self._receive_message_batch_impl, - **kwargs - ) + return self._do_retryable_operation(self._receive_message_batch_impl, **kwargs) diff --git a/uamqp/constants.py b/uamqp/constants.py index 0e60bbca7..115f26e35 100644 --- a/uamqp/constants.py +++ b/uamqp/constants.py @@ -1,13 +1,13 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +import struct from collections import namedtuple from enum import Enum -import struct -_AS_BYTES = struct.Struct('>B') +_AS_BYTES = struct.Struct(">B") #: The IANA assigned port number for AMQP.The standard AMQP port number that has been assigned by IANA #: for TCP, UDP, and SCTP.There are currently no UDP or SCTP mappings defined for AMQP. @@ -24,20 +24,32 @@ MAJOR = 1 #: Major protocol version. MINOR = 0 #: Minor protocol version. REV = 0 #: Protocol revision. -HEADER_FRAME = b"AMQP\x00" + _AS_BYTES.pack(MAJOR) + _AS_BYTES.pack(MINOR) + _AS_BYTES.pack(REV) +HEADER_FRAME = ( + b"AMQP\x00" + _AS_BYTES.pack(MAJOR) + _AS_BYTES.pack(MINOR) + _AS_BYTES.pack(REV) +) TLS_MAJOR = 1 #: Major protocol version. TLS_MINOR = 0 #: Minor protocol version. TLS_REV = 0 #: Protocol revision. -TLS_HEADER_FRAME = b"AMQP\x02" + _AS_BYTES.pack(TLS_MAJOR) + _AS_BYTES.pack(TLS_MINOR) + _AS_BYTES.pack(TLS_REV) +TLS_HEADER_FRAME = ( + b"AMQP\x02" + + _AS_BYTES.pack(TLS_MAJOR) + + _AS_BYTES.pack(TLS_MINOR) + + _AS_BYTES.pack(TLS_REV) +) SASL_MAJOR = 1 #: Major protocol version. SASL_MINOR = 0 #: Minor protocol version. 
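# --- Editor's illustration (not part of the patch) --------------------------
# Matching ReceiveClient usage for the batch-receive path above: link_credit
# caps how many messages are prefetched on the link, and receive_message_batch
# drains up to max_batch_size of them, or whatever arrived before the timeout.
# Source address and the auth object are placeholders.
from uamqp import ReceiveClient

receive_client = ReceiveClient(
    "myns.servicebus.windows.net",                        # assumed hostname
    "myhub/ConsumerGroups/$default/Partitions/0",         # assumed source
    auth=auth,                                            # e.g. a SASTokenAuth
    link_credit=300,
)
with receive_client:
    for message in receive_client.receive_message_batch(max_batch_size=10, timeout=5):
        print(message)
# -----------------------------------------------------------------------------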
SASL_REV = 0 #: Protocol revision. -SASL_HEADER_FRAME = b"AMQP\x03" + _AS_BYTES.pack(SASL_MAJOR) + _AS_BYTES.pack(SASL_MINOR) + _AS_BYTES.pack(SASL_REV) +SASL_HEADER_FRAME = ( + b"AMQP\x03" + + _AS_BYTES.pack(SASL_MAJOR) + + _AS_BYTES.pack(SASL_MINOR) + + _AS_BYTES.pack(SASL_REV) +) -EMPTY_FRAME = b'\x00\x00\x00\x08\x02\x00\x00\x00' +EMPTY_FRAME = b"\x00\x00\x00\x08\x02\x00\x00\x00" #: The lower bound for the agreed maximum frame size (in bytes). During the initial Connection negotiation, the #: two peers must agree upon a maximum frame size. This constant defines the minimum value to which the maximum @@ -51,7 +63,7 @@ DEFAULT_LINK_CREDIT = 10000 -FIELD = namedtuple('field', 'name, type, mandatory, default, multiple') +FIELD = namedtuple("FIELD", "name, type, mandatory, default, multiple") DEFAULT_AUTH_TIMEOUT = 60 @@ -223,6 +235,7 @@ class Role(object): """ + Sender = False Receiver = True @@ -241,6 +254,7 @@ class SenderSettleMode(object): """ + Unsettled = 0 Settled = 1 Mixed = 2 @@ -259,6 +273,7 @@ class ReceiverSettleMode(object): """ + First = 0 Second = 1 @@ -274,6 +289,7 @@ class SASLCode(object): """ + #: Connection authentication succeeded. Ok = 0 #: Connection authentication failed due to an unspecified problem with the supplied credentials. @@ -300,5 +316,5 @@ class MessageDeliveryState(object): MessageDeliveryState.Ok, MessageDeliveryState.Error, MessageDeliveryState.Timeout, - MessageDeliveryState.Cancelled + MessageDeliveryState.Cancelled, ) diff --git a/uamqp/endpoints.py b/uamqp/endpoints.py index c68cc05c3..c729356ff 100644 --- a/uamqp/endpoints.py +++ b/uamqp/endpoints.py @@ -1,8 +1,10 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access # The messaging layer defines two concrete types (source and target) to be used as the source and target of a # link. These types are supplied in the source and target fields of the attach frame when establishing or @@ -16,9 +18,9 @@ from collections import namedtuple -from .types import AMQPTypes, FieldDefinition, ObjDefinition -from .constants import FIELD -from .performatives import _CAN_ADD_DOCSTRING +from uamqp.constants import FIELD +from uamqp.performatives import _CAN_ADD_DOCSTRING +from uamqp.amqp_types import AMQPTypes, FieldDefinition, ObjDefinition class TerminusDurability(object): @@ -32,6 +34,7 @@ class TerminusDurability(object): Determines which state of the terminus is held durably. """ + #: No Terminus state is retained durably NoDurability = 0 #: Only the existence and configuration of the Terminus is retained durably. @@ -56,6 +59,7 @@ class ExpiryPolicy(object): count down is aborted. If the conditions for the terminus-expiry-policy are subsequently re-met, the expiry timer restarts from its originally configured timeout value. """ + #: The expiry timer starts when Terminus is detached. LinkDetach = b"link-detach" #: The expiry timer starts when the most recently associated session is ended. @@ -76,31 +80,32 @@ class DistributionMode(object): Policies for distributing messages when multiple links are connected to the same node. 
""" + #: Once successfully transferred over the link, the message will no longer be available #: to other links from the same node. - Move = b'move' + Move = b"move" #: Once successfully transferred over the link, the message is still available for other #: links from the same node. - Copy = b'copy' + Copy = b"copy" class LifeTimePolicy(object): #: Lifetime of dynamic node scoped to lifetime of link which caused creation. #: A node dynamically created with this lifetime policy will be deleted at the point that the link #: which caused its creation ceases to exist. - DeleteOnClose = 0x0000002b + DeleteOnClose = 0x0000002B #: Lifetime of dynamic node scoped to existence of links to the node. #: A node dynamically created with this lifetime policy will be deleted at the point that there remain #: no links for which the node is either the source or target. - DeleteOnNoLinks = 0x0000002c + DeleteOnNoLinks = 0x0000002C #: Lifetime of dynamic node scoped to existence of messages on the node. #: A node dynamically created with this lifetime policy will be deleted at the point that the link which #: caused its creation no longer exists and there remain no messages at the node. - DeleteOnNoMessages = 0x0000002d + DeleteOnNoMessages = 0x0000002D #: Lifetime of node scoped to existence of messages on or links to the node. #: A node dynamically created with this lifetime policy will be deleted at the point that the there are no #: links which have this node as their source or target, and there remain no messages at the node. - DeleteOnNoLinksOrMessages = 0x0000002e + DeleteOnNoLinksOrMessages = 0x0000002E class SupportedOutcomes(object): @@ -128,34 +133,38 @@ class ApacheFilters(object): Source = namedtuple( - 'source', + "Source", [ - 'address', - 'durable', - 'expiry_policy', - 'timeout', - 'dynamic', - 'dynamic_node_properties', - 'distribution_mode', - 'filters', - 'default_outcome', - 'outcomes', - 'capabilities' - ]) -Source.__new__.__defaults__ = (None,) * len(Source._fields) -Source._code = 0x00000028 -Source._definition = ( + "address", + "durable", + "expiry_policy", + "timeout", + "dynamic", + "dynamic_node_properties", + "distribution_mode", + "filters", + "default_outcome", + "outcomes", + "capabilities", + ], +) +Source.__new__.__defaults__ = (None,) * len(Source._fields) # type: ignore +Source._code = 0x00000028 # type: ignore +Source._definition = ( # type: ignore FIELD("address", AMQPTypes.string, False, None, False), FIELD("durable", AMQPTypes.uint, False, "none", False), FIELD("expiry_policy", AMQPTypes.symbol, False, ExpiryPolicy.SessionEnd, False), FIELD("timeout", AMQPTypes.uint, False, 0, False), FIELD("dynamic", AMQPTypes.boolean, False, False, False), - FIELD("dynamic_node_properties", FieldDefinition.node_properties, False, None, False), + FIELD( + "dynamic_node_properties", FieldDefinition.node_properties, False, None, False + ), FIELD("distribution_mode", AMQPTypes.symbol, False, None, False), FIELD("filters", FieldDefinition.filter_set, False, None, False), FIELD("default_outcome", ObjDefinition.delivery_state, False, None, False), FIELD("outcomes", AMQPTypes.symbol, False, None, True), - FIELD("capabilities", AMQPTypes.symbol, False, None, True)) + FIELD("capabilities", AMQPTypes.symbol, False, None, True), +) if _CAN_ADD_DOCSTRING: Source.__doc__ = """ For containers which do not implement address resolution (and do not admit spontaneous link @@ -217,26 +226,30 @@ class ApacheFilters(object): Target = namedtuple( - 'target', + "Target", [ - 'address', - 'durable', - 
'expiry_policy', - 'timeout', - 'dynamic', - 'dynamic_node_properties', - 'capabilities' - ]) -Target._code = 0x00000029 -Target.__new__.__defaults__ = (None,) * len(Target._fields) -Target._definition = ( + "address", + "durable", + "expiry_policy", + "timeout", + "dynamic", + "dynamic_node_properties", + "capabilities", + ], +) +Target._code = 0x00000029 # type: ignore +Target.__new__.__defaults__ = (None,) * len(Target._fields) # type: ignore +Target._definition = ( # type: ignore FIELD("address", AMQPTypes.string, False, None, False), FIELD("durable", AMQPTypes.uint, False, "none", False), FIELD("expiry_policy", AMQPTypes.symbol, False, ExpiryPolicy.SessionEnd, False), FIELD("timeout", AMQPTypes.uint, False, 0, False), FIELD("dynamic", AMQPTypes.boolean, False, False, False), - FIELD("dynamic_node_properties", FieldDefinition.node_properties, False, None, False), - FIELD("capabilities", AMQPTypes.symbol, False, None, True)) + FIELD( + "dynamic_node_properties", FieldDefinition.node_properties, False, None, False + ), + FIELD("capabilities", AMQPTypes.symbol, False, None, True), +) if _CAN_ADD_DOCSTRING: Target.__doc__ = """ For containers which do not implement address resolution (and do not admit spontaneous link attachment diff --git a/uamqp/error.py b/uamqp/error.py index fc2b8cbfe..e4f88502a 100644 --- a/uamqp/error.py +++ b/uamqp/error.py @@ -1,14 +1,16 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -from enum import Enum from collections import namedtuple +from enum import Enum -from .constants import SECURE_PORT, FIELD -from .types import AMQPTypes, FieldDefinition +from uamqp.constants import SECURE_PORT, FIELD +from uamqp.amqp_types import AMQPTypes, FieldDefinition class ErrorCondition(bytes, Enum): @@ -88,8 +90,8 @@ class ErrorCondition(bytes, Enum): class RetryMode(str, Enum): - EXPONENTIAL = 'exponential' - FIXED = 'fixed' + EXPONENTIAL = "exponential" + FIXED = "fixed" class RetryPolicy: @@ -114,13 +116,10 @@ class RetryPolicy: ErrorCondition.SessionUnattachedHandle, ErrorCondition.SessionHandleInUse, ErrorCondition.SessionErrantLink, - ErrorCondition.SessionWindowViolation + ErrorCondition.SessionWindowViolation, ] - def __init__( - self, - **kwargs - ): + def __init__(self, **kwargs): """ keyword int retry_total: keyword float retry_backoff_factor: @@ -129,30 +128,30 @@ def __init__( keyword list no_retry: keyword dict custom_retry_policy: """ - self.total_retries = kwargs.pop('retry_total', 3) + self.total_retries = kwargs.pop("retry_total", 3) # TODO: A. consider letting retry_backoff_factor be either a float or a callback obj which returns a float # to give more extensibility on customization of retry backoff time, the callback could take the exception # as input. 
- self.backoff_factor = kwargs.pop('retry_backoff_factor', 0.8) - self.backoff_max = kwargs.pop('retry_backoff_max', 120) - self.retry_mode = kwargs.pop('retry_mode', RetryMode.EXPONENTIAL) - self.no_retry.extend(kwargs.get('no_retry', [])) + self.backoff_factor = kwargs.pop("retry_backoff_factor", 0.8) + self.backoff_max = kwargs.pop("retry_backoff_max", 120) + self.retry_mode = kwargs.pop("retry_mode", RetryMode.EXPONENTIAL) + self.no_retry.extend(kwargs.get("no_retry", [])) self.custom_condition_backoff = kwargs.pop("custom_condition_backoff", None) # TODO: B. As an alternative of option A, we could have a new kwarg serve the goal def configure_retries(self, **kwargs): return { - 'total': kwargs.pop("retry_total", self.total_retries), - 'backoff': kwargs.pop("retry_backoff_factor", self.backoff_factor), - 'max_backoff': kwargs.pop("retry_backoff_max", self.backoff_max), - 'retry_mode': kwargs.pop("retry_mode", self.retry_mode), - 'history': [] + "total": kwargs.pop("retry_total", self.total_retries), + "backoff": kwargs.pop("retry_backoff_factor", self.backoff_factor), + "max_backoff": kwargs.pop("retry_backoff_max", self.backoff_max), + "retry_mode": kwargs.pop("retry_mode", self.retry_mode), + "history": [], } - def increment(self, settings, error): - settings['total'] -= 1 - settings['history'].append(error) - if settings['total'] < 0: + def increment(self, settings, error): # pylint: disable=no-self-use + settings["total"] -= 1 + settings["history"].append(error) + if settings["total"] < 0: return False return True @@ -170,24 +169,24 @@ def get_backoff_time(self, settings, error): except (KeyError, TypeError): pass - consecutive_errors_len = len(settings['history']) + consecutive_errors_len = len(settings["history"]) if consecutive_errors_len <= 1: return 0 if self.retry_mode == RetryMode.FIXED: - backoff_value = settings['backoff'] + backoff_value = settings["backoff"] else: - backoff_value = settings['backoff'] * (2 ** (consecutive_errors_len - 1)) - return min(settings['max_backoff'], backoff_value) + backoff_value = settings["backoff"] * (2 ** (consecutive_errors_len - 1)) + return min(settings["max_backoff"], backoff_value) -AMQPError = namedtuple('error', ['condition', 'description', 'info']) -AMQPError.__new__.__defaults__ = (None,) * len(AMQPError._fields) -AMQPError._code = 0x0000001d -AMQPError._definition = ( - FIELD('condition', AMQPTypes.symbol, True, None, False), - FIELD('description', AMQPTypes.string, False, None, False), - FIELD('info', FieldDefinition.fields, False, None, False), +AMQPError = namedtuple("AMQPError", ["condition", "description", "info"]) +AMQPError.__new__.__defaults__ = (None,) * len(AMQPError._fields) # type: ignore +AMQPError._code = 0x0000001D # type: ignore +AMQPError._definition = ( # type: ignore + FIELD("condition", AMQPTypes.symbol, True, None, False), + FIELD("description", AMQPTypes.string, False, None, False), + FIELD("info", FieldDefinition.fields, False, None, False), ) @@ -198,6 +197,7 @@ class AMQPException(Exception): :keyword str description: A description of the error. :keyword dict info: A dictionary of additional data associated with the error. 
""" + def __init__(self, condition, **kwargs): self.condition = condition or ErrorCondition.UnknownError self.description = kwargs.get("description", None) @@ -205,7 +205,9 @@ def __init__(self, condition, **kwargs): self.message = kwargs.get("message", None) self.inner_error = kwargs.get("error", None) message = self.message or "Error condition: {}".format( - str(condition) if isinstance(condition, ErrorCondition) else condition.decode() + str(condition) + if isinstance(condition, ErrorCondition) + else condition.decode() ) if self.description: try: @@ -216,15 +218,11 @@ def __init__(self, condition, **kwargs): class AMQPDecodeError(AMQPException): - """An error occurred while decoding an incoming frame. - - """ + """An error occurred while decoding an incoming frame.""" class AMQPConnectionError(AMQPException): - """Details of a Connection-level error. - - """ + """Details of a Connection-level error.""" class AMQPConnectionRedirect(AMQPConnectionError): @@ -237,11 +235,14 @@ class AMQPConnectionRedirect(AMQPConnectionError): :keyword str description: A description of the error. :keyword dict info: A dictionary of additional data associated with the error. """ + def __init__(self, condition, description=None, info=None): - self.hostname = info.get(b'hostname', b'').decode('utf-8') - self.network_host = info.get(b'network-host', b'').decode('utf-8') - self.port = int(info.get(b'port', SECURE_PORT)) - super(AMQPConnectionRedirect, self).__init__(condition, description=description, info=info) + self.hostname = info.get(b"hostname", b"").decode("utf-8") + self.network_host = info.get(b"network-host", b"").decode("utf-8") + self.port = int(info.get(b"port", SECURE_PORT)) + super(AMQPConnectionRedirect, self).__init__( + condition, description=description, info=info + ) class AMQPSessionError(AMQPException): @@ -255,7 +256,7 @@ class AMQPSessionError(AMQPException): class AMQPLinkError(AMQPException): """ - + AMQP link error """ @@ -271,57 +272,68 @@ class AMQPLinkRedirect(AMQPLinkError): """ def __init__(self, condition, description=None, info=None): - self.hostname = info.get(b'hostname', b'').decode('utf-8') - self.network_host = info.get(b'network-host', b'').decode('utf-8') - self.port = int(info.get(b'port', SECURE_PORT)) - self.address = info.get(b'address', b'').decode('utf-8') - super(AMQPLinkError, self).__init__(condition, description=description, info=info) + self.hostname = info.get(b"hostname", b"").decode("utf-8") + self.network_host = info.get(b"network-host", b"").decode("utf-8") + self.port = int(info.get(b"port", SECURE_PORT)) + self.address = info.get(b"address", b"").decode("utf-8") + super(AMQPLinkRedirect, self).__init__( + condition, description=description, info=info + ) class AuthenticationException(AMQPException): """ - + Authentication exception """ class TokenExpired(AuthenticationException): """ - + Token expired exception """ class TokenAuthFailure(AuthenticationException): """ - + Token authentication failure """ + def __init__(self, status_code, status_description, **kwargs): - encoding = kwargs.get("encoding", 'utf-8') + encoding = kwargs.get("encoding", "utf-8") self.status_code = status_code self.status_description = status_description - message = "CBS Token authentication failed.\nStatus code: {}".format(self.status_code) + message = "CBS Token authentication failed.\nStatus code: {}".format( + self.status_code + ) if self.status_description: try: - message += "\nDescription: {}".format(self.status_description.decode(encoding)) + message += "\nDescription: 
{}".format( + self.status_description.decode(encoding) + ) except (TypeError, AttributeError): message += "\nDescription: {}".format(self.status_description) - super(TokenAuthFailure, self).__init__(condition=ErrorCondition.ClientError, message=message) + super(TokenAuthFailure, self).__init__( + condition=ErrorCondition.ClientError, message=message + ) class MessageException(AMQPException): """ - + Message exception """ class MessageSendFailed(MessageException): """ - + Message send failed """ class ErrorResponse(object): """ + Error response """ + def __init__(self, **kwargs): self.condition = kwargs.get("condition") self.description = kwargs.get("description") diff --git a/uamqp/link.py b/uamqp/link.py index e65b5614a..c22e39e06 100644 --- a/uamqp/link.py +++ b/uamqp/link.py @@ -1,42 +1,33 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -import threading -import struct -import uuid import logging -import time -from enum import Enum -from io import BytesIO -from urllib.parse import urlparse +import uuid -from .endpoints import Source, Target -from .constants import ( +from uamqp.constants import ( DEFAULT_LINK_CREDIT, SessionState, - SessionTransferState, LinkDeliverySettleReason, LinkState, Role, SenderSettleMode, - ReceiverSettleMode + ReceiverSettleMode, ) -from .performatives import ( - AttachFrame, - DetachFrame, - TransferFrame, - DispositionFrame, - FlowFrame, -) - -from .error import ( +from uamqp.endpoints import Source, Target +from uamqp.error import ( ErrorCondition, AMQPLinkError, AMQPLinkRedirect, - AMQPConnectionError + AMQPConnectionError, +) +from uamqp.performatives import ( + AttachFrame, + DetachFrame, ) _LOGGER = logging.getLogger(__name__) @@ -44,7 +35,7 @@ class Link(object): """ - + AMQP link """ def __init__(self, session, handle, name, role, **kwargs): @@ -53,56 +44,64 @@ def __init__(self, session, handle, name, role, **kwargs): self.handle = handle self.remote_handle = None self.role = role - source_address = kwargs['source_address'] + source_address = kwargs["source_address"] target_address = kwargs["target_address"] - self.source = source_address if isinstance(source_address, Source) else Source( - address=kwargs['source_address'], - durable=kwargs.get('source_durable'), - expiry_policy=kwargs.get('source_expiry_policy'), - timeout=kwargs.get('source_timeout'), - dynamic=kwargs.get('source_dynamic'), - dynamic_node_properties=kwargs.get('source_dynamic_node_properties'), - distribution_mode=kwargs.get('source_distribution_mode'), - filters=kwargs.get('source_filters'), - default_outcome=kwargs.get('source_default_outcome'), - outcomes=kwargs.get('source_outcomes'), - capabilities=kwargs.get('source_capabilities') + self.source = ( + source_address + if isinstance(source_address, Source) + else Source( + address=kwargs["source_address"], + durable=kwargs.get("source_durable"), + expiry_policy=kwargs.get("source_expiry_policy"), + timeout=kwargs.get("source_timeout"), + dynamic=kwargs.get("source_dynamic"), + dynamic_node_properties=kwargs.get("source_dynamic_node_properties"), + 
distribution_mode=kwargs.get("source_distribution_mode"), + filters=kwargs.get("source_filters"), + default_outcome=kwargs.get("source_default_outcome"), + outcomes=kwargs.get("source_outcomes"), + capabilities=kwargs.get("source_capabilities"), + ) ) - self.target = target_address if isinstance(target_address,Target) else Target( - address=kwargs['target_address'], - durable=kwargs.get('target_durable'), - expiry_policy=kwargs.get('target_expiry_policy'), - timeout=kwargs.get('target_timeout'), - dynamic=kwargs.get('target_dynamic'), - dynamic_node_properties=kwargs.get('target_dynamic_node_properties'), - capabilities=kwargs.get('target_capabilities') + self.target = ( + target_address + if isinstance(target_address, Target) + else Target( + address=kwargs["target_address"], + durable=kwargs.get("target_durable"), + expiry_policy=kwargs.get("target_expiry_policy"), + timeout=kwargs.get("target_timeout"), + dynamic=kwargs.get("target_dynamic"), + dynamic_node_properties=kwargs.get("target_dynamic_node_properties"), + capabilities=kwargs.get("target_capabilities"), + ) ) - self.link_credit = kwargs.pop('link_credit', None) or DEFAULT_LINK_CREDIT + self.link_credit = kwargs.pop("link_credit", None) or DEFAULT_LINK_CREDIT self.current_link_credit = self.link_credit - self.send_settle_mode = kwargs.pop('send_settle_mode', SenderSettleMode.Mixed) - self.rcv_settle_mode = kwargs.pop('rcv_settle_mode', ReceiverSettleMode.First) - self.unsettled = kwargs.pop('unsettled', None) - self.incomplete_unsettled = kwargs.pop('incomplete_unsettled', None) - self.initial_delivery_count = kwargs.pop('initial_delivery_count', 0) + self.send_settle_mode = kwargs.pop("send_settle_mode", SenderSettleMode.Mixed) + self.rcv_settle_mode = kwargs.pop("rcv_settle_mode", ReceiverSettleMode.First) + self.unsettled = kwargs.pop("unsettled", None) + self.incomplete_unsettled = kwargs.pop("incomplete_unsettled", None) + self.initial_delivery_count = kwargs.pop("initial_delivery_count", 0) self.delivery_count = self.initial_delivery_count self.received_delivery_id = None - self.max_message_size = kwargs.pop('max_message_size', None) + self.max_message_size = kwargs.pop("max_message_size", None) self.remote_max_message_size = None - self.available = kwargs.pop('available', None) - self.properties = kwargs.pop('properties', None) + self.available = kwargs.pop("available", None) + self.properties = kwargs.pop("properties", None) self.offered_capabilities = None - self.desired_capabilities = kwargs.pop('desired_capabilities', None) + self.desired_capabilities = kwargs.pop("desired_capabilities", None) - self.network_trace = kwargs['network_trace'] - self.network_trace_params = kwargs['network_trace_params'] - self.network_trace_params['link'] = self.name + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["link"] = self.name self._session = session self._is_closed = False self._send_links = {} self._receive_links = {} self._pending_deliveries = {} self._received_payload = bytearray() - self._on_link_state_change = kwargs.get('on_link_state_change') + self._on_link_state_change = kwargs.get("on_link_state_change") self._error = None def __enter__(self): @@ -115,7 +114,9 @@ def __exit__(self, *args): @classmethod def from_incoming_frame(cls, session, handle, frame): # check link_create_from_endpoint in C lib - raise NotImplementedError('Pending') # TODO: Assuming we establish all links for now... 
+ raise NotImplementedError( + "Pending" + ) # TODO: Assuming we establish all links for now... def get_state(self): try: @@ -131,7 +132,7 @@ def _check_if_closed(self): except TypeError: raise AMQPConnectionError( condition=ErrorCondition.InternalError, - description="Link already closed." + description="Link already closed.", ) def _set_state(self, new_state): @@ -141,19 +142,28 @@ def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Link state changed: %r -> %r", previous_state, new_state, extra=self.network_trace_params) + _LOGGER.info( + "Link state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) try: self._on_link_state_change(previous_state, new_state) except TypeError: pass except Exception as e: # pylint: disable=broad-except - _LOGGER.error("Link state change callback failed: '%r'", e, extra=self.network_trace_params) + _LOGGER.error( + "Link state change callback failed: '%r'", + e, + extra=self.network_trace_params, + ) def _remove_pending_deliveries(self): # TODO: move to sender for delivery in self._pending_deliveries.values(): delivery.on_settled(LinkDeliverySettleReason.NOT_DELIVERED, None) self._pending_deliveries = {} - + def _on_session_state_change(self): if self._session.state == SessionState.MAPPED: if not self._is_closed and self.state == LinkState.DETACHED: @@ -175,11 +185,17 @@ def _outgoing_attach(self): target=self.target, unsettled=self.unsettled, incomplete_unsettled=self.incomplete_unsettled, - initial_delivery_count=self.initial_delivery_count if self.role == Role.Sender else None, + initial_delivery_count=self.initial_delivery_count + if self.role == Role.Sender + else None, max_message_size=self.max_message_size, - offered_capabilities=self.offered_capabilities if self.state == LinkState.ATTACH_RCVD else None, - desired_capabilities=self.desired_capabilities if self.state == LinkState.DETACHED else None, - properties=self.properties + offered_capabilities=self.offered_capabilities + if self.state == LinkState.ATTACH_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == LinkState.DETACHED + else None, + properties=self.properties, ) if self.network_trace: _LOGGER.info("-> %r", attach_frame, extra=self.network_trace_params) @@ -190,7 +206,9 @@ def _incoming_attach(self, frame): _LOGGER.info("<- %r", AttachFrame(*frame), extra=self.network_trace_params) if self._is_closed: raise ValueError("Invalid link") - elif not frame[5] or not frame[6]: # TODO: not sure if we should source + target check here + if ( + not frame[5] or not frame[6] + ): # TODO: not sure if we should source + target check here _LOGGER.info("Cannot get source or target. Detaching link") self._remove_pending_deliveries() self._set_state(LinkState.DETACHED) # TODO: Send detach now? 
@@ -209,19 +227,19 @@ def _incoming_attach(self, frame): def _outgoing_flow(self): flow_frame = { - 'handle': self.handle, - 'delivery_count': self.delivery_count, - 'link_credit': self.current_link_credit, - 'available': None, - 'drain': None, - 'echo': None, - 'properties': None + "handle": self.handle, + "delivery_count": self.delivery_count, + "link_credit": self.current_link_credit, + "available": None, + "drain": None, + "echo": None, + "properties": None, } self._session._outgoing_flow(flow_frame) def _incoming_flow(self, frame): pass - + def _incoming_disposition(self, frame): pass @@ -238,7 +256,11 @@ def _incoming_detach(self, frame): _LOGGER.info("<- %r", DetachFrame(*frame), extra=self.network_trace_params) if self.state == LinkState.ATTACHED: self._outgoing_detach(close=frame[1]) # closed - elif frame[1] and not self._is_closed and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD]: + elif ( + frame[1] + and not self._is_closed + and self.state in [LinkState.ATTACH_SENT, LinkState.ATTACH_RCVD] + ): # Received a closing detach after we sent a non-closing detach. # In this case, we MUST signal that we closed by reattaching and then sending a closing detach. self._outgoing_attach() @@ -247,8 +269,14 @@ def _incoming_detach(self, frame): # TODO: on_detach_hook if frame[2]: # error # frame[2][0] is condition, frame[2][1] is description, frame[2][2] is info - error_cls = AMQPLinkRedirect if frame[2][0] == ErrorCondition.LinkRedirect else AMQPLinkError - self._error = error_cls(condition=frame[2][0], description=frame[2][1], info=frame[2][2]) + error_cls = ( + AMQPLinkRedirect + if frame[2][0] == ErrorCondition.LinkRedirect + else AMQPLinkError + ) + self._error = error_cls( + condition=frame[2][0], description=frame[2][1], info=frame[2][2] + ) self._set_state(LinkState.ERROR) else: self._set_state(LinkState.DETACHED) @@ -272,6 +300,6 @@ def detach(self, close=False, error=None): elif self.state == LinkState.ATTACHED: self._outgoing_detach(close=close, error=error) self._set_state(LinkState.DETACH_SENT) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except _LOGGER.info("An error occurred when detaching the link: %r", exc) self._set_state(LinkState.DETACHED) diff --git a/uamqp/management_link.py b/uamqp/management_link.py index 9784a96f8..7ddd6bb4b 100644 --- a/uamqp/management_link.py +++ b/uamqp/management_link.py @@ -1,62 +1,68 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -import time import logging -from functools import partial +import time from collections import namedtuple +from functools import partial -from .sender import SenderLink -from .receiver import ReceiverLink -from .constants import ( +from uamqp.constants import ( ManagementLinkState, LinkState, SenderSettleMode, ReceiverSettleMode, ManagementExecuteOperationResult, ManagementOpenResult, - SEND_DISPOSITION_ACCEPT, SEND_DISPOSITION_REJECT, - MessageDeliveryState + MessageDeliveryState, ) -from .error import ErrorResponse, AMQPException, ErrorCondition -from .message import Message, Properties, _MessageDelivery +from uamqp.error import AMQPException, ErrorCondition +from uamqp.message import Properties, _MessageDelivery +from uamqp.receiver import ReceiverLink +from uamqp.sender import SenderLink _LOGGER = logging.getLogger(__name__) -PendingManagementOperation = namedtuple('PendingManagementOperation', ['message', 'on_execute_operation_complete']) +PendingManagementOperation = namedtuple( + "PendingManagementOperation", ["message", "on_execute_operation_complete"] +) class ManagementLink(object): """ - + AMQP management link """ + def __init__(self, session, endpoint, **kwargs): self.next_message_id = 0 self.state = ManagementLinkState.IDLE self._pending_operations = [] self._session = session - self._request_link = session.create_sender_link( # type: SenderLink + self._request_link: SenderLink = session.create_sender_link( endpoint, on_link_state_change=self._on_sender_state_change, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, ) - self._response_link = session.create_receiver_link( # type: ReceiverLink + self._response_link: ReceiverLink = session.create_receiver_link( endpoint, on_link_state_change=self._on_receiver_state_change, on_message_received=self._on_message_received, send_settle_mode=SenderSettleMode.Unsettled, - rcv_settle_mode=ReceiverSettleMode.First + rcv_settle_mode=ReceiverSettleMode.First, + ) + self._on_amqp_management_error = kwargs.get("on_amqp_management_error") + self._on_amqp_management_open_complete = kwargs.get( + "on_amqp_management_open_complete" ) - self._on_amqp_management_error = kwargs.get('on_amqp_management_error') - self._on_amqp_management_open_complete = kwargs.get('on_amqp_management_open_complete') - self._status_code_field = kwargs.get('status_code_field', b'statusCode') - self._status_description_field = kwargs.get('status_description_field', b'statusDescription') + self._status_code_field = kwargs.get("status_code_field", b"statusCode") + self._status_description_field = kwargs.get( + "status_description_field", b"statusDescription" + ) self._sender_connected = False self._receiver_connected = False @@ -69,7 +75,9 @@ def __exit__(self, *args): self.close() def _on_sender_state_change(self, previous_state, new_state): - _LOGGER.info("Management link sender state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link sender state changed: %r -> %r", previous_state, new_state + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: @@ -78,7 +86,12 @@ def _on_sender_state_change(self, previous_state, new_state): if self._receiver_connected: self.state = ManagementLinkState.OPEN self._on_amqp_management_open_complete(ManagementOpenResult.OK) - elif new_state in 
[LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + elif new_state in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + LinkState.ERROR, + ]: self.state = ManagementLinkState.IDLE self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) elif self.state == ManagementLinkState.OPEN: @@ -86,7 +99,11 @@ def _on_sender_state_change(self, previous_state, new_state): self.state = ManagementLinkState.ERROR self._on_amqp_management_error() elif self.state == ManagementLinkState.CLOSING: - if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + if new_state not in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + ]: self.state = ManagementLinkState.ERROR self._on_amqp_management_error() elif self.state == ManagementLinkState.ERROR: @@ -94,7 +111,11 @@ def _on_sender_state_change(self, previous_state, new_state): return def _on_receiver_state_change(self, previous_state, new_state): - _LOGGER.info("Management link receiver state changed: %r -> %r", previous_state, new_state) + _LOGGER.info( + "Management link receiver state changed: %r -> %r", + previous_state, + new_state, + ) if new_state == previous_state: return if self.state == ManagementLinkState.OPENING: @@ -103,7 +124,12 @@ def _on_receiver_state_change(self, previous_state, new_state): if self._sender_connected: self.state = ManagementLinkState.OPEN self._on_amqp_management_open_complete(ManagementOpenResult.OK) - elif new_state in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD, LinkState.ERROR]: + elif new_state in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + LinkState.ERROR, + ]: self.state = ManagementLinkState.IDLE self._on_amqp_management_open_complete(ManagementOpenResult.ERROR) elif self.state == ManagementLinkState.OPEN: @@ -111,7 +137,11 @@ def _on_receiver_state_change(self, previous_state, new_state): self.state = ManagementLinkState.ERROR self._on_amqp_management_error() elif self.state == ManagementLinkState.CLOSING: - if new_state not in [LinkState.DETACHED, LinkState.DETACH_SENT, LinkState.DETACH_RCVD]: + if new_state not in [ + LinkState.DETACHED, + LinkState.DETACH_SENT, + LinkState.DETACH_RCVD, + ]: self.state = ManagementLinkState.ERROR self._on_amqp_management_error() elif self.state == ManagementLinkState.ERROR: @@ -132,18 +162,24 @@ def _on_message_received(self, message): to_remove_operation = operation break if to_remove_operation: - mgmt_result = ManagementExecuteOperationResult.OK \ - if 200 <= status_code <= 299 else ManagementExecuteOperationResult.FAILED_BAD_STATUS + mgmt_result = ( + ManagementExecuteOperationResult.OK + if 200 <= status_code <= 299 + else ManagementExecuteOperationResult.FAILED_BAD_STATUS + ) to_remove_operation.on_execute_operation_complete( mgmt_result, status_code, status_description, message, - response_detail.get(b'error-condition') + response_detail.get(b"error-condition"), ) self._pending_operations.remove(to_remove_operation) - def _on_send_complete(self, message_delivery, reason, state): # todo: reason is never used, should check spec + def _on_send_complete( + self, message_delivery, reason, state + ): # pylint: disable=unused-argument + # todo: reason is never used, should check spec if SEND_DISPOSITION_REJECT in state: # sample reject state: {'rejected': [[b'amqp:not-allowed', b"Invalid command 'RE1AD'.", None]]} to_remove_operation = None @@ -161,10 +197,14 @@ def _on_send_complete(self, 
message_delivery, reason, state): # todo: reason is None, message_delivery.message, error=AMQPException( - condition=state[SEND_DISPOSITION_REJECT][0][0], # 0 is error condition - description=state[SEND_DISPOSITION_REJECT][0][1], # 1 is error description + condition=state[SEND_DISPOSITION_REJECT][0][ + 0 + ], # 0 is error condition + description=state[SEND_DISPOSITION_REJECT][0][ + 1 + ], # 1 is error description info=state[SEND_DISPOSITION_REJECT][0][2], # 2 is error info - ) + ), ) def open(self): @@ -174,12 +214,7 @@ def open(self): self._response_link.attach() self._request_link.attach() - def execute_operation( - self, - message, - on_execute_operation_complete, - **kwargs - ): + def execute_operation(self, message, on_execute_operation_complete, **kwargs): """Execute a request and wait on a response. :param message: The message to send in the management request. @@ -208,26 +243,26 @@ def execute_operation( message.application_properties["locales"] = kwargs.get("locales") try: # TODO: namedtuple is immutable, which may push us to re-think about the namedtuple approach for Message - new_properties = message.properties._replace(message_id=self.next_message_id) + new_properties = message.properties._replace( + message_id=self.next_message_id + ) except AttributeError: new_properties = Properties(message_id=self.next_message_id) message = message._replace(properties=new_properties) expire_time = (time.time() + timeout) if timeout else None message_delivery = _MessageDelivery( - message, - MessageDeliveryState.WaitingToBeSent, - expire_time + message, MessageDeliveryState.WaitingToBeSent, expire_time ) on_send_complete = partial(self._on_send_complete, message_delivery) self._request_link.send_transfer( - message, - on_send_complete=on_send_complete, - timeout=timeout + message, on_send_complete=on_send_complete, timeout=timeout ) self.next_message_id += 1 - self._pending_operations.append(PendingManagementOperation(message, on_execute_operation_complete)) + self._pending_operations.append( + PendingManagementOperation(message, on_execute_operation_complete) + ) def close(self): if self.state != ManagementLinkState.IDLE: @@ -240,7 +275,10 @@ def close(self): None, None, pending_operation.message, - AMQPException(condition=ErrorCondition.ClientError, description="Management link already closed.") + AMQPException( + condition=ErrorCondition.ClientError, + description="Management link already closed.", + ), ) self._pending_operations = [] self.state = ManagementLinkState.IDLE diff --git a/uamqp/management_operation.py b/uamqp/management_operation.py index 811074f4b..e55edf59f 100644 --- a/uamqp/management_operation.py +++ b/uamqp/management_operation.py @@ -1,32 +1,22 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging -import uuid import time +import uuid from functools import partial -from .management_link import ManagementLink -from .message import Message -from .error import ( - AMQPException, - AMQPConnectionError, - AMQPLinkError, - ErrorCondition -) - -from .constants import ( - ManagementOpenResult, - ManagementExecuteOperationResult -) +from uamqp.constants import ManagementOpenResult, ManagementExecuteOperationResult +from uamqp.error import AMQPLinkError, ErrorCondition +from uamqp.management_link import ManagementLink _LOGGER = logging.getLogger(__name__) class ManagementOperation(object): - def __init__(self, session, endpoint='$management', **kwargs): + def __init__(self, session, endpoint="$management", **kwargs): self._mgmt_link_open_status = None self._session = session @@ -61,7 +51,7 @@ def _on_execute_operation_complete( status_code, status_description, raw_message, - error=None + error=None, ): _LOGGER.debug( "mgmt operation completed, operation id: %r; operation_result: %r; status_code: %r; " @@ -71,18 +61,25 @@ def _on_execute_operation_complete( status_code, status_description, raw_message, - error + error, ) - if operation_result in\ - (ManagementExecuteOperationResult.ERROR, ManagementExecuteOperationResult.LINK_CLOSED): + if operation_result in ( + ManagementExecuteOperationResult.ERROR, + ManagementExecuteOperationResult.LINK_CLOSED, + ): self._mgmt_error = error _LOGGER.error( "Failed to complete mgmt operation due to error: %r. The management request message is: %r", - error, raw_message + error, + raw_message, ) else: - self._responses[operation_id] = (status_code, status_description, raw_message) + self._responses[operation_id] = ( + status_code, + status_description, + raw_message, + ) def execute(self, message, operation=None, operation_type=None, timeout=0): start_time = time.time() @@ -95,14 +92,16 @@ def execute(self, message, operation=None, operation_type=None, timeout=0): partial(self._on_execute_operation_complete, operation_id), timeout=timeout, operation=operation, - type=operation_type + type=operation_type, ) while not self._responses[operation_id] and not self._mgmt_error: if timeout > 0: now = time.time() if (now - start_time) >= timeout: - raise TimeoutError("Failed to receive mgmt response in {}ms".format(timeout)) + raise TimeoutError( + "Failed to receive mgmt response in {}ms".format(timeout) + ) self._connection.listen() if self._mgmt_error: @@ -130,8 +129,10 @@ def ready(self): # TODO: update below with correct status code + info raise AMQPLinkError( condition=ErrorCondition.ClientError, - description="Failed to open mgmt link, management link status: {}".format(self._mgmt_link_open_status), - info=None + description="Failed to open mgmt link, management link status: {}".format( + self._mgmt_link_open_status + ), + info=None, ) def close(self): diff --git a/uamqp/message.py b/uamqp/message.py index a2ef0087f..ca5b2a920 100644 --- a/uamqp/message.py +++ b/uamqp/message.py @@ -1,33 +1,29 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from collections import namedtuple +# pylint: disable=protected-access -from .types import AMQPTypes, FieldDefinition -from .constants import FIELD, MessageDeliveryState -from .performatives import _CAN_ADD_DOCSTRING +from collections import namedtuple +from uamqp.amqp_types import AMQPTypes, FieldDefinition +from uamqp.constants import FIELD, MessageDeliveryState +from uamqp.performatives import _CAN_ADD_DOCSTRING Header = namedtuple( - 'header', - [ - 'durable', - 'priority', - 'ttl', - 'first_acquirer', - 'delivery_count' - ]) -Header._code = 0x00000070 -Header.__new__.__defaults__ = (None,) * len(Header._fields) -Header._definition = ( + "Header", ["durable", "priority", "ttl", "first_acquirer", "delivery_count"] +) +Header._code = 0x00000070 # type: ignore +Header.__new__.__defaults__ = (None,) * len(Header._fields) # type: ignore +Header._definition = ( # type: ignore FIELD("durable", AMQPTypes.boolean, False, None, False), FIELD("priority", AMQPTypes.ubyte, False, None, False), FIELD("ttl", AMQPTypes.uint, False, None, False), FIELD("first_acquirer", AMQPTypes.boolean, False, None, False), - FIELD("delivery_count", AMQPTypes.uint, False, None, False)) + FIELD("delivery_count", AMQPTypes.uint, False, None, False), +) if _CAN_ADD_DOCSTRING: Header.__doc__ = """ Transport headers for a Message. @@ -75,25 +71,26 @@ Properties = namedtuple( - 'properties', + "Properties", [ - 'message_id', - 'user_id', - 'to', - 'subject', - 'reply_to', - 'correlation_id', - 'content_type', - 'content_encoding', - 'absolute_expiry_time', - 'creation_time', - 'group_id', - 'group_sequence', - 'reply_to_group_id' - ]) -Properties._code = 0x00000073 -Properties.__new__.__defaults__ = (None,) * len(Properties._fields) -Properties._definition = ( + "message_id", + "user_id", + "to", + "subject", + "reply_to", + "correlation_id", + "content_type", + "content_encoding", + "absolute_expiry_time", + "creation_time", + "group_id", + "group_sequence", + "reply_to_group_id", + ], +) +Properties._code = 0x00000073 # type: ignore +Properties.__new__.__defaults__ = (None,) * len(Properties._fields) # type: ignore +Properties._definition = ( # type: ignore FIELD("message_id", FieldDefinition.message_id, False, None, False), FIELD("user_id", AMQPTypes.binary, False, None, False), FIELD("to", AMQPTypes.string, False, None, False), @@ -106,7 +103,8 @@ FIELD("creation_time", AMQPTypes.timestamp, False, None, False), FIELD("group_id", AMQPTypes.string, False, None, False), FIELD("group_sequence", AMQPTypes.uint, False, None, False), - FIELD("reply_to_group_id", AMQPTypes.string, False, None, False)) + FIELD("reply_to_group_id", AMQPTypes.string, False, None, False), +) if _CAN_ADD_DOCSTRING: Properties.__doc__ = """ Immutable properties of the Message. 
@@ -165,30 +163,38 @@ # TODO: should be a class, namedtuple or dataclass, immutability vs performance, need to collect performance data Message = namedtuple( - 'message', + "Message", [ - 'header', - 'delivery_annotations', - 'message_annotations', - 'properties', - 'application_properties', - 'data', - 'sequence', - 'value', - 'footer', - ]) -Message.__new__.__defaults__ = (None,) * len(Message._fields) -Message._code = 0 -Message._definition = ( + "header", + "delivery_annotations", + "message_annotations", + "properties", + "application_properties", + "data", + "sequence", + "value", + "footer", + ], +) +Message.__new__.__defaults__ = (None,) * len(Message._fields) # type: ignore +Message._code = 0 # type: ignore +Message._definition = ( # type: ignore (0x00000070, FIELD("header", Header, False, None, False)), - (0x00000071, FIELD("delivery_annotations", FieldDefinition.annotations, False, None, False)), - (0x00000072, FIELD("message_annotations", FieldDefinition.annotations, False, None, False)), + ( + 0x00000071, + FIELD("delivery_annotations", FieldDefinition.annotations, False, None, False), + ), + ( + 0x00000072, + FIELD("message_annotations", FieldDefinition.annotations, False, None, False), + ), (0x00000073, FIELD("properties", Properties, False, None, False)), (0x00000074, FIELD("application_properties", AMQPTypes.map, False, None, False)), (0x00000075, FIELD("data", AMQPTypes.binary, False, None, True)), (0x00000076, FIELD("sequence", AMQPTypes.list, False, None, False)), (0x00000077, FIELD("value", None, False, None, False)), - (0x00000078, FIELD("footer", FieldDefinition.annotations, False, None, False))) + (0x00000078, FIELD("footer", FieldDefinition.annotations, False, None, False)), +) if _CAN_ADD_DOCSTRING: Message.__doc__ = """ An annotated message consists of the bare message plus sections for annotation at the head and tail @@ -258,7 +264,9 @@ class BatchMessage(Message): class _MessageDelivery: - def __init__(self, message, state=MessageDeliveryState.WaitingToBeSent, expiry=None): + def __init__( + self, message, state=MessageDeliveryState.WaitingToBeSent, expiry=None + ): self.message = message self.state = state self.expiry = expiry diff --git a/uamqp/outcomes.py b/uamqp/outcomes.py index 970a1d92b..fb00438ea 100644 --- a/uamqp/outcomes.py +++ b/uamqp/outcomes.py @@ -1,8 +1,10 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access # The Messaging layer defines a concrete set of delivery states which can be used (via the disposition frame) # to indicate the state of the message at the receiver. 
@@ -27,16 +29,16 @@ from collections import namedtuple -from .types import AMQPTypes, FieldDefinition, ObjDefinition -from .constants import FIELD -from .performatives import _CAN_ADD_DOCSTRING - +from uamqp.amqp_types import AMQPTypes, FieldDefinition, ObjDefinition +from uamqp.constants import FIELD +from uamqp.performatives import _CAN_ADD_DOCSTRING # type: ignore -Received = namedtuple('received', ['section_number', 'section_offset']) -Received._code = 0x00000023 -Received._definition = ( +Received = namedtuple("Received", ["section_number", "section_offset"]) +Received._code = 0x00000023 # type: ignore +Received._definition = ( # type: ignore FIELD("section_number", AMQPTypes.uint, True, None, False), - FIELD("section_offset", AMQPTypes.ulong, True, None, False)) + FIELD("section_offset", AMQPTypes.ulong, True, None, False), +) if _CAN_ADD_DOCSTRING: Received.__doc__ = """ At the target the received state indicates the furthest point in the payload of the message @@ -64,9 +66,9 @@ """ -Accepted = namedtuple('accepted', []) -Accepted._code = 0x00000024 -Accepted._definition = () +Accepted = namedtuple("Accepted", []) +Accepted._code = 0x00000024 # type: ignore +Accepted._definition = () # type: ignore if _CAN_ADD_DOCSTRING: Accepted.__doc__ = """ The accepted outcome. @@ -82,9 +84,9 @@ """ -Rejected = namedtuple('rejected', ['error']) -Rejected._code = 0x00000025 -Rejected._definition = (FIELD("error", ObjDefinition.error, False, None, False),) +Rejected = namedtuple("Rejected", ["error"]) +Rejected._code = 0x00000025 # type: ignore +Rejected._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore if _CAN_ADD_DOCSTRING: Rejected.__doc__ = """ The rejected outcome. @@ -101,9 +103,9 @@ """ -Released = namedtuple('released', []) -Released._code = 0x00000026 -Released._definition = () +Released = namedtuple("Released", []) +Released._code = 0x00000026 # type: ignore +Released._definition = () # type: ignore if _CAN_ADD_DOCSTRING: Released.__doc__ = """ The released outcome. @@ -122,12 +124,15 @@ """ -Modified = namedtuple('modified', ['delivery_failed', 'undeliverable_here', 'message_annotations']) -Modified._code = 0x00000027 -Modified._definition = ( - FIELD('delivery_failed', AMQPTypes.boolean, False, None, False), - FIELD('undeliverable_here', AMQPTypes.boolean, False, None, False), - FIELD('message_annotations', FieldDefinition.fields, False, None, False)) +Modified = namedtuple( + "Modified", ["delivery_failed", "undeliverable_here", "message_annotations"] +) +Modified._code = 0x00000027 # type: ignore +Modified._definition = ( # type: ignore + FIELD("delivery_failed", AMQPTypes.boolean, False, None, False), + FIELD("undeliverable_here", AMQPTypes.boolean, False, None, False), + FIELD("message_annotations", FieldDefinition.fields, False, None, False), +) if _CAN_ADD_DOCSTRING: Modified.__doc__ = """ The modified outcome. diff --git a/uamqp/performatives.py b/uamqp/performatives.py index 8b27295fa..eb33f276b 100644 --- a/uamqp/performatives.py +++ b/uamqp/performatives.py @@ -1,34 +1,36 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from collections import namedtuple import sys +from collections import namedtuple +from typing import Union -from .types import AMQPTypes, FieldDefinition, ObjDefinition -from .constants import FIELD +from uamqp.constants import FIELD +from uamqp.amqp_types import AMQPTypes, FieldDefinition, ObjDefinition _CAN_ADD_DOCSTRING = sys.version_info.major >= 3 OpenFrame = namedtuple( - 'open', + "OpenFrame", [ - 'container_id', - 'hostname', - 'max_frame_size', - 'channel_max', - 'idle_timeout', - 'outgoing_locales', - 'incoming_locales', - 'offered_capabilities', - 'desired_capabilities', - 'properties' - ]) -OpenFrame._code = 0x00000010 # pylint:disable=protected-access -OpenFrame._definition = ( # pylint:disable=protected-access + "container_id", + "hostname", + "max_frame_size", + "channel_max", + "idle_timeout", + "outgoing_locales", + "incoming_locales", + "offered_capabilities", + "desired_capabilities", + "properties", + ], +) +OpenFrame._code = 0x00000010 # type: ignore # pylint:disable=protected-access +OpenFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("container_id", AMQPTypes.string, True, None, False), FIELD("hostname", AMQPTypes.string, False, None, False), FIELD("max_frame_size", AMQPTypes.uint, False, 4294967295, False), @@ -38,7 +40,8 @@ FIELD("incoming_locales", AMQPTypes.symbol, False, None, True), FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), - FIELD("properties", FieldDefinition.fields, False, None, False)) + FIELD("properties", FieldDefinition.fields, False, None, False), +) if _CAN_ADD_DOCSTRING: OpenFrame.__doc__ = """ OPEN performative. Negotiate Connection parameters. @@ -103,19 +106,20 @@ BeginFrame = namedtuple( - 'begin', + "BeginFrame", [ - 'remote_channel', - 'next_outgoing_id', - 'incoming_window', - 'outgoing_window', - 'handle_max', - 'offered_capabilities', - 'desired_capabilities', - 'properties' - ]) -BeginFrame._code = 0x00000011 # pylint:disable=protected-access -BeginFrame._definition = ( # pylint:disable=protected-access + "remote_channel", + "next_outgoing_id", + "incoming_window", + "outgoing_window", + "handle_max", + "offered_capabilities", + "desired_capabilities", + "properties", + ], +) +BeginFrame._code = 0x00000011 # type: ignore # pylint:disable=protected-access +BeginFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("remote_channel", AMQPTypes.ushort, False, None, False), FIELD("next_outgoing_id", AMQPTypes.uint, True, None, False), FIELD("incoming_window", AMQPTypes.uint, True, None, False), @@ -123,7 +127,8 @@ FIELD("handle_max", AMQPTypes.uint, False, 4294967295, False), FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), - FIELD("properties", FieldDefinition.fields, False, None, False)) + FIELD("properties", FieldDefinition.fields, False, None, False), +) if _CAN_ADD_DOCSTRING: BeginFrame.__doc__ = """ BEGIN performative. Begin a Session on a channel. 
@@ -163,25 +168,26 @@ AttachFrame = namedtuple( - 'attach', + "AttachFrame", [ - 'name', - 'handle', - 'role', - 'send_settle_mode', - 'rcv_settle_mode', - 'source', - 'target', - 'unsettled', - 'incomplete_unsettled', - 'initial_delivery_count', - 'max_message_size', - 'offered_capabilities', - 'desired_capabilities', - 'properties' - ]) -AttachFrame._code = 0x00000012 # pylint:disable=protected-access -AttachFrame._definition = ( # pylint:disable=protected-access + "name", + "handle", + "role", + "send_settle_mode", + "rcv_settle_mode", + "source", + "target", + "unsettled", + "incomplete_unsettled", + "initial_delivery_count", + "max_message_size", + "offered_capabilities", + "desired_capabilities", + "properties", + ], +) +AttachFrame._code = 0x00000012 # type: ignore # pylint:disable=protected-access +AttachFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("name", AMQPTypes.string, True, None, False), FIELD("handle", AMQPTypes.uint, True, None, False), FIELD("role", AMQPTypes.boolean, True, None, False), @@ -195,7 +201,8 @@ FIELD("max_message_size", AMQPTypes.ulong, False, None, False), FIELD("offered_capabilities", AMQPTypes.symbol, False, None, True), FIELD("desired_capabilities", AMQPTypes.symbol, False, None, True), - FIELD("properties", FieldDefinition.fields, False, None, False)) + FIELD("properties", FieldDefinition.fields, False, None, False), +) if _CAN_ADD_DOCSTRING: AttachFrame.__doc__ = """ ATTACH performative. Attach a Link to a Session. @@ -262,23 +269,24 @@ FlowFrame = namedtuple( - 'flow', + "FlowFrame", [ - 'next_incoming_id', - 'incoming_window', - 'next_outgoing_id', - 'outgoing_window', - 'handle', - 'delivery_count', - 'link_credit', - 'available', - 'drain', - 'echo', - 'properties' - ]) -FlowFrame.__new__.__defaults__ = (None, None, None, None, None, None, None) -FlowFrame._code = 0x00000013 # pylint:disable=protected-access -FlowFrame._definition = ( # pylint:disable=protected-access + "next_incoming_id", + "incoming_window", + "next_outgoing_id", + "outgoing_window", + "handle", + "delivery_count", + "link_credit", + "available", + "drain", + "echo", + "properties", + ], +) +FlowFrame._code = 0x00000013 # type: ignore # pylint:disable=protected-access +FlowFrame.__new__.__defaults__ = (None,) * len(FlowFrame._fields) # type: ignore +FlowFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("next_incoming_id", AMQPTypes.uint, False, None, False), FIELD("incoming_window", AMQPTypes.uint, True, None, False), FIELD("next_outgoing_id", AMQPTypes.uint, True, None, False), @@ -289,7 +297,8 @@ FIELD("available", AMQPTypes.uint, False, None, False), FIELD("drain", AMQPTypes.boolean, False, False, False), FIELD("echo", AMQPTypes.boolean, False, False, False), - FIELD("properties", FieldDefinition.fields, False, None, False)) + FIELD("properties", FieldDefinition.fields, False, None, False), +) if _CAN_ADD_DOCSTRING: FlowFrame.__doc__ = """ FLOW performative. Update link state. 
@@ -334,23 +343,24 @@ TransferFrame = namedtuple( - 'transfer', + "TransferFrame", [ - 'handle', - 'delivery_id', - 'delivery_tag', - 'message_format', - 'settled', - 'more', - 'rcv_settle_mode', - 'state', - 'resume', - 'aborted', - 'batchable', - 'payload' - ]) -TransferFrame._code = 0x00000014 # pylint:disable=protected-access -TransferFrame._definition = ( # pylint:disable=protected-access + "handle", + "delivery_id", + "delivery_tag", + "message_format", + "settled", + "more", + "rcv_settle_mode", + "state", + "resume", + "aborted", + "batchable", + "payload", + ], +) +TransferFrame._code = 0x00000014 # type: ignore # pylint:disable=protected-access +TransferFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("handle", AMQPTypes.uint, True, None, False), FIELD("delivery_id", AMQPTypes.uint, False, None, False), FIELD("delivery_tag", AMQPTypes.binary, False, None, False), @@ -361,8 +371,9 @@ FIELD("state", ObjDefinition.delivery_state, False, None, False), FIELD("resume", AMQPTypes.boolean, False, False, False), FIELD("aborted", AMQPTypes.boolean, False, False, False), - FIELD("batchable", AMQPTypes.boolean, False, False, False), - None) + FIELD("batchable", AMQPTypes.boolean, False, False, False), + None, +) if _CAN_ADD_DOCSTRING: TransferFrame.__doc__ = """ TRANSFER performative. Transfer a Message. @@ -435,23 +446,17 @@ DispositionFrame = namedtuple( - 'disposition', - [ - 'role', - 'first', - 'last', - 'settled', - 'state', - 'batchable' - ]) -DispositionFrame._code = 0x00000015 # pylint:disable=protected-access -DispositionFrame._definition = ( # pylint:disable=protected-access + "DispositionFrame", ["role", "first", "last", "settled", "state", "batchable"] +) +DispositionFrame._code = 0x00000015 # type: ignore # pylint:disable=protected-access +DispositionFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("role", AMQPTypes.boolean, True, None, False), FIELD("first", AMQPTypes.uint, True, None, False), FIELD("last", AMQPTypes.uint, False, None, False), FIELD("settled", AMQPTypes.boolean, False, False, False), FIELD("state", ObjDefinition.delivery_state, False, None, False), - FIELD("batchable", AMQPTypes.boolean, False, False, False)) + FIELD("batchable", AMQPTypes.boolean, False, False, False), +) if _CAN_ADD_DOCSTRING: DispositionFrame.__doc__ = """ DISPOSITION performative. Inform remote peer of delivery state changes. @@ -484,12 +489,13 @@ implementation uses when communicating delivery states, and thereby save bandwidth. """ -DetachFrame = namedtuple('detach', ['handle', 'closed', 'error']) -DetachFrame._code = 0x00000016 # pylint:disable=protected-access -DetachFrame._definition = ( # pylint:disable=protected-access +DetachFrame = namedtuple("DetachFrame", ["handle", "closed", "error"]) +DetachFrame._code = 0x00000016 # type: ignore # pylint:disable=protected-access +DetachFrame._definition = ( # type: ignore # pylint:disable=protected-access FIELD("handle", AMQPTypes.uint, True, None, False), FIELD("closed", AMQPTypes.boolean, False, False, False), - FIELD("error", ObjDefinition.error, False, None, False)) + FIELD("error", ObjDefinition.error, False, None, False), +) if _CAN_ADD_DOCSTRING: DetachFrame.__doc__ = """ DETACH performative. Detach the Link Endpoint from the Session. 
@@ -505,9 +511,9 @@ """ -EndFrame = namedtuple('end', ['error']) -EndFrame._code = 0x00000017 # pylint:disable=protected-access -EndFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # pylint:disable=protected-access +EndFrame = namedtuple("EndFrame", ["error"]) +EndFrame._code = 0x00000017 # type: ignore # pylint:disable=protected-access +EndFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore # pylint:disable=protected-access if _CAN_ADD_DOCSTRING: EndFrame.__doc__ = """ END performative. End the Session. @@ -520,9 +526,9 @@ """ -CloseFrame = namedtuple('close', ['error']) -CloseFrame._code = 0x00000018 # pylint:disable=protected-access -CloseFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # pylint:disable=protected-access +CloseFrame = namedtuple("CloseFrame", ["error"]) +CloseFrame._code = 0x00000018 # type: ignore # pylint:disable=protected-access +CloseFrame._definition = (FIELD("error", ObjDefinition.error, False, None, False),) # type: ignore # pylint:disable=protected-access if _CAN_ADD_DOCSTRING: CloseFrame.__doc__ = """ CLOSE performative. Signal a Connection close. @@ -537,9 +543,9 @@ """ -SASLMechanism = namedtuple('sasl_mechanism', ['sasl_server_mechanisms']) -SASLMechanism._code = 0x00000040 # pylint:disable=protected-access -SASLMechanism._definition = (FIELD('sasl_server_mechanisms', AMQPTypes.symbol, True, None, True),) # pylint:disable=protected-access +SASLMechanism = namedtuple("SASLMechanism", ["sasl_server_mechanisms"]) +SASLMechanism._code = 0x00000040 # type: ignore # pylint:disable=protected-access +SASLMechanism._definition = (FIELD("sasl_server_mechanisms", AMQPTypes.symbol, True, None, True),) # type: ignore # pylint:disable=protected-access if _CAN_ADD_DOCSTRING: SASLMechanism.__doc__ = """ Advertise available sasl mechanisms. @@ -554,12 +560,13 @@ """ -SASLInit = namedtuple('sasl_init', ['mechanism', 'initial_response', 'hostname']) -SASLInit._code = 0x00000041 # pylint:disable=protected-access -SASLInit._definition = ( # pylint:disable=protected-access - FIELD('mechanism', AMQPTypes.symbol, True, None, False), - FIELD('initial_response', AMQPTypes.binary, False, None, False), - FIELD('hostname', AMQPTypes.string, False, None, False)) +SASLInit = namedtuple("SASLInit", ["mechanism", "initial_response", "hostname"]) +SASLInit._code = 0x00000041 # type: ignore # pylint:disable=protected-access +SASLInit._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("mechanism", AMQPTypes.symbol, True, None, False), + FIELD("initial_response", AMQPTypes.binary, False, None, False), + FIELD("hostname", AMQPTypes.string, False, None, False), +) if _CAN_ADD_DOCSTRING: SASLInit.__doc__ = """ Initiate sasl exchange. @@ -585,9 +592,9 @@ """ -SASLChallenge = namedtuple('sasl_challenge', ['challenge']) -SASLChallenge._code = 0x00000042 # pylint:disable=protected-access -SASLChallenge._definition = (FIELD('challenge', AMQPTypes.binary, True, None, False),) # pylint:disable=protected-access +SASLChallenge = namedtuple("SASLChallenge", ["challenge"]) +SASLChallenge._code = 0x00000042 # type: ignore # pylint:disable=protected-access +SASLChallenge._definition = (FIELD("challenge", AMQPTypes.binary, True, None, False),) # type: ignore # pylint:disable=protected-access if _CAN_ADD_DOCSTRING: SASLChallenge.__doc__ = """ Security mechanism challenge. 
@@ -599,9 +606,9 @@ """ -SASLResponse = namedtuple('sasl_response', ['response']) -SASLResponse._code = 0x00000043 # pylint:disable=protected-access -SASLResponse._definition = (FIELD('response', AMQPTypes.binary, True, None, False),) # pylint:disable=protected-access +SASLResponse = namedtuple("SASLResponse", ["response"]) +SASLResponse._code = 0x00000043 # type: ignore # pylint:disable=protected-access +SASLResponse._definition = (FIELD("response", AMQPTypes.binary, True, None, False),) # type: ignore # pylint:disable=protected-access if _CAN_ADD_DOCSTRING: SASLResponse.__doc__ = """ Security mechanism response. @@ -612,11 +619,12 @@ """ -SASLOutcome = namedtuple('sasl_outcome', ['code', 'additional_data']) -SASLOutcome._code = 0x00000044 # pylint:disable=protected-access -SASLOutcome._definition = ( # pylint:disable=protected-access - FIELD('code', AMQPTypes.ubyte, True, None, False), - FIELD('additional_data', AMQPTypes.binary, False, None, False)) +SASLOutcome = namedtuple("SASLOutcome", ["code", "additional_data"]) +SASLOutcome._code = 0x00000044 # type: ignore # pylint:disable=protected-access +SASLOutcome._definition = ( # type: ignore # pylint:disable=protected-access + FIELD("code", AMQPTypes.ubyte, True, None, False), + FIELD("additional_data", AMQPTypes.binary, False, None, False), +) if _CAN_ADD_DOCSTRING: SASLOutcome.__doc__ = """ Indicates the outcome of the sasl dialog. @@ -631,3 +639,20 @@ The additional-data field carries additional data on successful authentication outcomeas specified by the SASL specification (RFC-4422). If the authentication is unsuccessful, this field is not set. """ + +Performative = Union[ + OpenFrame, + BeginFrame, + AttachFrame, + FlowFrame, + TransferFrame, + DispositionFrame, + DetachFrame, + EndFrame, + CloseFrame, + SASLMechanism, + SASLInit, + SASLChallenge, + SASLResponse, + SASLOutcome, +] diff --git a/uamqp/receiver.py b/uamqp/receiver.py index a4d93b01e..f961c81a4 100644 --- a/uamqp/receiver.py +++ b/uamqp/receiver.py @@ -1,47 +1,37 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -import uuid import logging -from io import BytesIO +import uuid -from ._decode import decode_payload -from .constants import DEFAULT_LINK_CREDIT, Role -from .endpoints import Target -from .link import Link -from .message import Message, Properties, Header -from .constants import ( - DEFAULT_LINK_CREDIT, - SessionState, - SessionTransferState, - LinkDeliverySettleReason, - LinkState -) -from .performatives import ( - AttachFrame, - DetachFrame, +from uamqp._decode import decode_payload +from uamqp.constants import LinkState +from uamqp.constants import Role +from uamqp.link import Link +from uamqp.performatives import ( TransferFrame, DispositionFrame, - FlowFrame, ) - _LOGGER = logging.getLogger(__name__) class ReceiverLink(Link): - def __init__(self, session, handle, source_address, **kwargs): - name = kwargs.pop('name', None) or str(uuid.uuid4()) + name = kwargs.pop("name", None) or str(uuid.uuid4()) role = Role.Receiver - if 'target_address' not in kwargs: - kwargs['target_address'] = "receiver-link-{}".format(name) - super(ReceiverLink, self).__init__(session, handle, name, role, source_address=source_address, **kwargs) - self.on_message_received = kwargs.get('on_message_received') - self.on_transfer_received = kwargs.get('on_transfer_received') + if "target_address" not in kwargs: + kwargs["target_address"] = "receiver-link-{}".format(name) + super(ReceiverLink, self).__init__( + session, handle, name, role, source_address=source_address, **kwargs + ) + self.on_message_received = kwargs.get("on_message_received") + self.on_transfer_received = kwargs.get("on_transfer_received") if not self.on_message_received and not self.on_transfer_received: raise ValueError("Must specify either a message or transfer handler.") @@ -49,9 +39,9 @@ def _process_incoming_message(self, frame, message): try: if self.on_message_received: return self.on_message_received(message) - elif self.on_transfer_received: + if self.on_transfer_received: return self.on_transfer_received(frame, message) - except Exception as e: + except Exception as e: # pylint: disable=broad-except _LOGGER.error("Handler function failed with error: %r", e) return None @@ -67,7 +57,9 @@ def _incoming_attach(self, frame): def _incoming_transfer(self, frame): if self.network_trace: - _LOGGER.info("<- %r", TransferFrame(*frame), extra=self.network_trace_params) + _LOGGER.info( + "<- %r", TransferFrame(*frame), extra=self.network_trace_params + ) self.current_link_credit -= 1 self.delivery_count += 1 self.received_delivery_id = frame[1] # delivery_id @@ -95,13 +87,24 @@ def _outgoing_disposition(self, delivery_id, delivery_state): last=delivery_id, settled=True, state=delivery_state, - batchable=None + batchable=None, ) if self.network_trace: - _LOGGER.info("-> %r", DispositionFrame(*disposition_frame), extra=self.network_trace_params) + _LOGGER.info( + "-> %r", + DispositionFrame(*disposition_frame), + extra=self.network_trace_params, + ) self._session._outgoing_disposition(disposition_frame) def send_disposition(self, delivery_id, delivery_state=None): if self._is_closed: raise ValueError("Link already closed.") self._outgoing_disposition(delivery_id, delivery_state) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # check link_create_from_endpoint in C lib + raise NotImplementedError( + "Pending" + ) # TODO: Assuming we 
establish all links for now... diff --git a/uamqp/sasl.py b/uamqp/sasl.py index 99dd25d43..b355f8438 100644 --- a/uamqp/sasl.py +++ b/uamqp/sasl.py @@ -1,24 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -import struct -from enum import Enum +from uamqp._transport import SSLTransport, AMQPS_PORT +from uamqp.constants import SASLCode, SASL_HEADER_FRAME +from uamqp.performatives import SASLInit -from ._transport import SSLTransport, AMQPS_PORT -from .types import AMQPTypes, TYPE, VALUE -from .constants import FIELD, SASLCode, SASL_HEADER_FRAME -from .performatives import ( - SASLOutcome, - SASLResponse, - SASLChallenge, - SASLInit -) - -_SASL_FRAME_TYPE = b'\x01' +_SASL_FRAME_TYPE = b"\x01" class SASLPlainCredential(object): @@ -26,7 +17,7 @@ class SASLPlainCredential(object): See https://tools.ietf.org/html/rfc4616 for details """ - mechanism = b'PLAIN' + mechanism = b"PLAIN" def __init__(self, authcid, passwd, authzid=None): self.authcid = authcid @@ -35,13 +26,13 @@ def __init__(self, authcid, passwd, authzid=None): def start(self): if self.authzid: - login_response = self.authzid.encode('utf-8') + login_response = self.authzid.encode("utf-8") else: - login_response = b'' - login_response += b'\0' - login_response += self.authcid.encode('utf-8') - login_response += b'\0' - login_response += self.passwd.encode('utf-8') + login_response = b"" + login_response += b"\0" + login_response += self.authcid.encode("utf-8") + login_response += b"\0" + login_response += self.passwd.encode("utf-8") return login_response @@ -50,10 +41,10 @@ class SASLAnonymousCredential(object): See https://tools.ietf.org/html/rfc4505 for details """ - mechanism = b'ANONYMOUS' + mechanism = b"ANONYMOUS" - def start(self): - return b'' + def start(self): # pylint: disable=no-self-use + return b"" class SASLExternalCredential(object): @@ -63,34 +54,53 @@ class SASLExternalCredential(object): authentication data. """ - mechanism = b'EXTERNAL' + mechanism = b"EXTERNAL" - def start(self): - return b'' + def start(self): # pylint: disable=no-self-use + return b"" class SASLTransport(SSLTransport): - - def __init__(self, host, credential, port=AMQPS_PORT, connect_timeout=None, ssl=None, **kwargs): + def __init__( + self, + host, + credential, + port=AMQPS_PORT, + connect_timeout=None, + ssl=None, + **kwargs + ): self.credential = credential ssl = ssl or True - super(SASLTransport, self).__init__(host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs) + super(SASLTransport, self).__init__( + host, port=port, connect_timeout=connect_timeout, ssl=ssl, **kwargs + ) def negotiate(self): with self.block(): self.write(SASL_HEADER_FRAME) _, returned_header = self.receive_frame() if returned_header[1] != SASL_HEADER_FRAME: - raise ValueError("Mismatching AMQP header protocol. Expected: {}, received: {}".format( - SASL_HEADER_FRAME, returned_header[1])) + raise ValueError( + "Mismatching AMQP header protocol. 
Expected: {!r}, received: {!r}".format( + SASL_HEADER_FRAME, returned_header[1] + ) + ) _, supported_mechansisms = self.receive_frame(verify_frame_type=1) - if self.credential.mechanism not in supported_mechansisms[1][0]: # sasl_server_mechanisms - raise ValueError("Unsupported SASL credential type: {}".format(self.credential.mechanism)) + if ( + self.credential.mechanism not in supported_mechansisms[1][0] + ): # sasl_server_mechanisms + raise ValueError( + "Unsupported SASL credential type: {}".format( + self.credential.mechanism + ) + ) sasl_init = SASLInit( mechanism=self.credential.mechanism, initial_response=self.credential.start(), - hostname=self.host) + hostname=self.host, + ) self.send_frame(0, sasl_init, frame_type=_SASL_FRAME_TYPE) _, next_frame = self.receive_frame(verify_frame_type=1) @@ -99,5 +109,6 @@ def negotiate(self): raise NotImplementedError("Unsupported SASL challenge") if fields[0] == SASLCode.Ok: # code return - else: - raise ValueError("SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields)) + raise ValueError( + "SASL negotiation failed.\nOutcome: {}\nDetails: {}".format(*fields) + ) diff --git a/uamqp/sender.py b/uamqp/sender.py index 7b53f793c..b381c7ba9 100644 --- a/uamqp/sender.py +++ b/uamqp/sender.py @@ -1,54 +1,50 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- -import struct -import uuid +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access + import logging +import struct import time +import uuid -from ._encode import encode_payload -from .endpoints import Source -from .link import Link -from .constants import ( - SessionState, +from uamqp._encode import encode_payload +from uamqp.constants import ( SessionTransferState, LinkDeliverySettleReason, LinkState, Role, - SenderSettleMode + SenderSettleMode, ) -from .performatives import ( - AttachFrame, - DetachFrame, +from uamqp.error import AMQPLinkError, ErrorCondition +from uamqp.link import Link +from uamqp.performatives import ( TransferFrame, - DispositionFrame, - FlowFrame, ) -from .error import AMQPLinkError, ErrorCondition _LOGGER = logging.getLogger(__name__) class PendingDelivery(object): - def __init__(self, **kwargs): - self.message = kwargs.get('message') + self.message = kwargs.get("message") self.sent = False self.frame = None - self.on_delivery_settled = kwargs.get('on_delivery_settled') - self.link = kwargs.get('link') + self.on_delivery_settled = kwargs.get("on_delivery_settled") + self.link = kwargs.get("link") self.start = time.time() self.transfer_state = None - self.timeout = kwargs.get('timeout') - self.settled = kwargs.get('settled', False) - + self.timeout = kwargs.get("timeout") + self.settled = kwargs.get("settled", False) + def on_settled(self, reason, state): if self.on_delivery_settled and not self.settled: try: self.on_delivery_settled(reason, state) - except Exception as e: + except Exception as e: # pylint: disable=broad-except # TODO: this swallows every error in on_delivery_settled, which mean we # 1. only handle errors we care about in the callback # 2. 
ignore errors we don't care @@ -58,13 +54,14 @@ def on_settled(self, reason, state): class SenderLink(Link): - def __init__(self, session, handle, target_address, **kwargs): - name = kwargs.pop('name', None) or str(uuid.uuid4()) + name = kwargs.pop("name", None) or str(uuid.uuid4()) role = Role.Sender - if 'source_address' not in kwargs: - kwargs['source_address'] = "sender-link-{}".format(name) - super(SenderLink, self).__init__(session, handle, name, role, target_address=target_address, **kwargs) + if "source_address" not in kwargs: + kwargs["source_address"] = "sender-link-{}".format(name) + super(SenderLink, self).__init__( + session, handle, name, role, target_address=target_address, **kwargs + ) self._unsent_messages = [] def _incoming_attach(self, frame): @@ -78,11 +75,15 @@ def _incoming_flow(self, frame): rcv_delivery_count = frame[5] # delivery_count if frame[4] is not None: # handle if rcv_link_credit is None or rcv_delivery_count is None: - _LOGGER.info("Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link.") + _LOGGER.info( + "Unable to get link-credit or delivery-count from incoming ATTACH. Detaching link." + ) self._remove_pending_deliveries() self._set_state(LinkState.DETACHED) # TODO: Send detach now? else: - self.current_link_credit = rcv_delivery_count + rcv_link_credit - self.delivery_count + self.current_link_credit = ( + rcv_delivery_count + rcv_link_credit - self.delivery_count + ) if self.current_link_credit > 0: self._send_unsent_messages() @@ -91,21 +92,25 @@ def _outgoing_transfer(self, delivery): encode_payload(output, delivery.message) delivery_count = self.delivery_count + 1 delivery.frame = { - 'handle': self.handle, - 'delivery_tag': struct.pack('>I', abs(delivery_count)), - 'message_format': delivery.message._code, - 'settled': delivery.settled, - 'more': False, - 'rcv_settle_mode': None, - 'state': None, - 'resume': None, - 'aborted': None, - 'batchable': None, - 'payload': output + "handle": self.handle, + "delivery_tag": struct.pack(">I", abs(delivery_count)), + "message_format": delivery.message._code, + "settled": delivery.settled, + "more": False, + "rcv_settle_mode": None, + "state": None, + "resume": None, + "aborted": None, + "batchable": None, + "payload": output, } if self.network_trace: # TODO: whether we should move frame tracing into centralized place e.g. 
connection.py - _LOGGER.info("-> %r", TransferFrame(delivery_id='', **delivery.frame), extra=self.network_trace_params) + _LOGGER.info( + "-> %r", + TransferFrame(delivery_id="", **delivery.frame), + extra=self.network_trace_params, + ) self._session._outgoing_transfer(delivery) if delivery.transfer_state == SessionTransferState.OKAY: self.delivery_count = delivery_count @@ -114,7 +119,7 @@ def _outgoing_transfer(self, delivery): if delivery.settled: delivery.on_settled(LinkDeliverySettleReason.SETTLED, None) else: - self._pending_deliveries[delivery.frame['delivery_id']] = delivery + self._pending_deliveries[delivery.frame["delivery_id"]] = delivery elif delivery.transfer_state == SessionTransferState.ERROR: raise ValueError("Message failed to send") if self.current_link_credit <= 0: @@ -125,20 +130,23 @@ def _incoming_disposition(self, frame): if not frame[3]: # settled return range_end = (frame[2] or frame[1]) + 1 # first or last - settled_ids = [i for i in range(frame[1], range_end)] - for settled_id in settled_ids: + for settled_id in range(frame[1], range_end): delivery = self._pending_deliveries.pop(settled_id, None) if delivery: - delivery.on_settled(LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4]) # state + delivery.on_settled( + LinkDeliverySettleReason.DISPOSITION_RECEIVED, frame[4] + ) # state def _update_pending_delivery_status(self): # TODO now = time.time() expired = [] for delivery in self._pending_deliveries.values(): if delivery.timeout and (now - delivery.start) >= delivery.timeout: - expired.append(delivery.frame['delivery_id']) + expired.append(delivery.frame["delivery_id"]) delivery.on_settled(LinkDeliverySettleReason.TIMEOUT, None) - self._pending_deliveries = {i: d for i, d in self._pending_deliveries.items() if i not in expired} + self._pending_deliveries = { + i: d for i, d in self._pending_deliveries.items() if i not in expired + } def _send_unsent_messages(self): unsent = [] @@ -154,14 +162,14 @@ def send_transfer(self, message, **kwargs): if self.state != LinkState.ATTACHED: raise AMQPLinkError( # TODO: should we introduce MessageHandler to indicate the handler is in wrong state condition=ErrorCondition.ClientError, # TODO: should this be a ClientError? - description="Link is not attached." + description="Link is not attached.", ) settled = self.send_settle_mode == SenderSettleMode.Settled if self.send_settle_mode == SenderSettleMode.Mixed: - settled = kwargs.pop('settled', True) + settled = kwargs.pop("settled", True) delivery = PendingDelivery( - on_delivery_settled=kwargs.get('on_send_complete'), - timeout=kwargs.get('timeout'), + on_delivery_settled=kwargs.get("on_send_complete"), + timeout=kwargs.get("timeout"), link=self, message=message, settled=settled, @@ -176,10 +184,21 @@ def send_transfer(self, message, **kwargs): def cancel_transfer(self, delivery): try: - delivery = self._pending_deliveries.pop(delivery.frame['delivery_id']) + delivery = self._pending_deliveries.pop(delivery.frame["delivery_id"]) delivery.on_settled(LinkDeliverySettleReason.CANCELLED, None) return except KeyError: pass # todo remove from unset messages - raise ValueError("No pending delivery with ID '{}' found.".format(delivery.frame['delivery_id'])) + raise ValueError( + "No pending delivery with ID '{}' found.".format( + delivery.frame["delivery_id"] + ) + ) + + @classmethod + def from_incoming_frame(cls, session, handle, frame): + # check link_create_from_endpoint in C lib + raise NotImplementedError( + "Pending" + ) # TODO: Assuming we establish all links for now... 
diff --git a/uamqp/session.py b/uamqp/session.py index 905a35da5..6caf27a66 100644 --- a/uamqp/session.py +++ b/uamqp/session.py @@ -1,36 +1,29 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- + +# pylint: disable=protected-access -import uuid import logging -from enum import Enum import time +import uuid +from typing import Union, Optional -from .constants import ( - INCOMING_WINDOW, - OUTGOING_WIDNOW, - ConnectionState, - SessionState, - SessionTransferState, - Role -) -from .endpoints import Source, Target -from .sender import SenderLink -from .receiver import ReceiverLink -from .management_link import ManagementLink -from .performatives import ( +from uamqp._encode import encode_frame +from uamqp.constants import ConnectionState, SessionState, SessionTransferState, Role +from uamqp.management_link import ManagementLink +from uamqp.performatives import ( BeginFrame, EndFrame, FlowFrame, - AttachFrame, - DetachFrame, TransferFrame, - DispositionFrame + DispositionFrame, ) -from ._encode import encode_frame +from uamqp.receiver import ReceiverLink +from uamqp.sender import SenderLink +from uamqp.error import AMQPError _LOGGER = logging.getLogger(__name__) @@ -48,27 +41,27 @@ class Session(object): """ def __init__(self, connection, channel, **kwargs): - self.name = kwargs.pop('name', None) or str(uuid.uuid4()) + self.name = kwargs.pop("name", None) or str(uuid.uuid4()) self.state = SessionState.UNMAPPED - self.handle_max = kwargs.get('handle_max', 4294967295) - self.properties = kwargs.pop('properties', None) + self.handle_max = kwargs.get("handle_max", 4294967295) + self.properties = kwargs.pop("properties", None) self.channel = channel self.remote_channel = None - self.next_outgoing_id = kwargs.pop('next_outgoing_id', 0) + self.next_outgoing_id = kwargs.pop("next_outgoing_id", 0) self.next_incoming_id = None - self.incoming_window = kwargs.pop('incoming_window', 1) - self.outgoing_window = kwargs.pop('outgoing_window', 1) + self.incoming_window = kwargs.pop("incoming_window", 1) + self.outgoing_window = kwargs.pop("outgoing_window", 1) self.target_incoming_window = self.incoming_window self.remote_incoming_window = 0 self.remote_outgoing_window = 0 self.offered_capabilities = None - self.desired_capabilities = kwargs.pop('desired_capabilities', None) + self.desired_capabilities = kwargs.pop("desired_capabilities", None) - self.allow_pipelined_open = kwargs.pop('allow_pipelined_open', True) - self.idle_wait_time = kwargs.get('idle_wait_time', 0.1) - self.network_trace = kwargs['network_trace'] - self.network_trace_params = kwargs['network_trace_params'] - self.network_trace_params['session'] = self.name + self.allow_pipelined_open = kwargs.pop("allow_pipelined_open", True) + self.idle_wait_time = kwargs.get("idle_wait_time", 0.1) + self.network_trace = kwargs["network_trace"] + self.network_trace_params = kwargs["network_trace_params"] + self.network_trace_params["session"] = self.name self.links = {} self._connection = connection @@ -84,6 +77,7 @@ def __exit__(self, *args): @classmethod def from_incoming_frame(cls, connection, channel, frame): + # pylint: 
disable=unused-argument # check session_create_from_endpoint in C lib new_session = cls(connection, channel) return new_session @@ -95,7 +89,12 @@ def _set_state(self, new_state): return previous_state = self.state self.state = new_state - _LOGGER.info("Session state changed: %r -> %r", previous_state, new_state, extra=self.network_trace_params) + _LOGGER.info( + "Session state changed: %r -> %r", + previous_state, + new_state, + extra=self.network_trace_params, + ) for link in self.links.values(): link._on_session_state_change() @@ -113,19 +112,31 @@ def _get_next_output_handle(self): :rtype: int """ if len(self._output_handles) >= self.handle_max: - raise ValueError("Maximum number of handles ({}) has been reached.".format(self.handle_max)) - next_handle = next(i for i in range(1, self.handle_max) if i not in self._output_handles) + raise ValueError( + "Maximum number of handles ({}) has been reached.".format( + self.handle_max + ) + ) + next_handle = next( + i for i in range(1, self.handle_max) if i not in self._output_handles + ) return next_handle - + def _outgoing_begin(self): begin_frame = BeginFrame( - remote_channel=self.remote_channel if self.state == SessionState.BEGIN_RCVD else None, + remote_channel=self.remote_channel + if self.state == SessionState.BEGIN_RCVD + else None, next_outgoing_id=self.next_outgoing_id, outgoing_window=self.outgoing_window, incoming_window=self.incoming_window, handle_max=self.handle_max, - offered_capabilities=self.offered_capabilities if self.state == SessionState.BEGIN_RCVD else None, - desired_capabilities=self.desired_capabilities if self.state == SessionState.UNMAPPED else None, + offered_capabilities=self.offered_capabilities + if self.state == SessionState.BEGIN_RCVD + else None, + desired_capabilities=self.desired_capabilities + if self.state == SessionState.UNMAPPED + else None, properties=self.properties, ) if self.network_trace: @@ -156,7 +167,11 @@ def _outgoing_end(self, error=None): def _incoming_end(self, frame): if self.network_trace: _LOGGER.info("<- %r", EndFrame(*frame), extra=self.network_trace_params) - if self.state not in [SessionState.END_RCVD, SessionState.END_SENT, SessionState.DISCARDING]: + if self.state not in [ + SessionState.END_RCVD, + SessionState.END_SENT, + SessionState.DISCARDING, + ]: self._set_state(SessionState.END_RCVD) # TODO: Clean up all links # TODO: handling error @@ -168,12 +183,18 @@ def _outgoing_attach(self, frame): def _incoming_attach(self, frame): try: - self._input_handles[frame[1]] = self.links[frame[0].decode('utf-8')] # name and handle + self._input_handles[frame[1]] = self.links[ + frame[0].decode("utf-8") + ] # name and handle self._input_handles[frame[1]]._incoming_attach(frame) except KeyError: - outgoing_handle = self._get_next_output_handle() # TODO: catch max-handles error + outgoing_handle = ( + self._get_next_output_handle() + ) # TODO: catch max-handles error if frame[2] == Role.Sender: # role - new_link = ReceiverLink.from_incoming_frame(self, outgoing_handle, frame) + new_link = ReceiverLink.from_incoming_frame( + self, outgoing_handle, frame + ) else: new_link = SenderLink.from_incoming_frame(self, outgoing_handle, frame) new_link._incoming_attach(frame) @@ -182,15 +203,17 @@ def _incoming_attach(self, frame): self._input_handles[frame[1]] = new_link except ValueError: pass # TODO: Reject link - + def _outgoing_flow(self, frame=None): link_flow = frame or {} - link_flow.update({ - 'next_incoming_id': self.next_incoming_id, - 'incoming_window': self.incoming_window, - 
'next_outgoing_id': self.next_outgoing_id, - 'outgoing_window': self.outgoing_window - }) + link_flow.update( + { + "next_incoming_id": self.next_incoming_id, + "incoming_window": self.incoming_window, + "next_outgoing_id": self.next_outgoing_id, + "outgoing_window": self.outgoing_window, + } + ) flow_frame = FlowFrame(**link_flow) if self.network_trace: _LOGGER.info("-> %r", flow_frame, extra=self.network_trace_params) @@ -200,8 +223,12 @@ def _incoming_flow(self, frame): if self.network_trace: _LOGGER.info("<- %r", FlowFrame(*frame), extra=self.network_trace_params) self.next_incoming_id = frame[2] # next_outgoing_id - remote_incoming_id = frame[0] or self.next_outgoing_id # next_incoming_id TODO "initial-outgoing-id" - self.remote_incoming_window = remote_incoming_id + frame[1] - self.next_outgoing_id # incoming_window + remote_incoming_id = ( + frame[0] or self.next_outgoing_id + ) # next_incoming_id TODO "initial-outgoing-id" + self.remote_incoming_window = ( + remote_incoming_id + frame[1] - self.next_outgoing_id + ) # incoming_window self.remote_outgoing_window = frame[3] # outgoing_window if frame[4] is not None: # handle self._input_handles[frame[4]]._incoming_flow(frame) @@ -216,58 +243,64 @@ def _outgoing_transfer(self, delivery): if self.remote_incoming_window <= 0: delivery.transfer_state = SessionTransferState.BUSY else: - payload = delivery.frame['payload'] + payload = delivery.frame["payload"] payload_size = len(payload) - delivery.frame['delivery_id'] = self.next_outgoing_id + delivery.frame["delivery_id"] = self.next_outgoing_id # calculate the transfer frame encoding size excluding the payload - delivery.frame['payload'] = b"" + delivery.frame["payload"] = b"" # TODO: encoding a frame would be expensive, we might want to improve depending on the perf test results encoded_frame = encode_frame(TransferFrame(**delivery.frame))[1] transfer_overhead_size = len(encoded_frame) # available size for payload per frame is calculated as following: # remote max frame size - transfer overhead (calculated) - header (8 bytes) - available_frame_size = self._connection._remote_max_frame_size - transfer_overhead_size - 8 + available_frame_size = ( + self._connection._remote_max_frame_size - transfer_overhead_size - 8 + ) start_idx = 0 remaining_payload_cnt = payload_size # encode n-1 frames if payload_size > available_frame_size while remaining_payload_cnt > available_frame_size: tmp_delivery_frame = { - 'handle': delivery.frame['handle'], - 'delivery_tag': delivery.frame['delivery_tag'], - 'message_format': delivery.frame['message_format'], - 'settled': delivery.frame['settled'], - 'more': True, - 'rcv_settle_mode': delivery.frame['rcv_settle_mode'], - 'state': delivery.frame['state'], - 'resume': delivery.frame['resume'], - 'aborted': delivery.frame['aborted'], - 'batchable': delivery.frame['batchable'], - 'payload': payload[start_idx:start_idx+available_frame_size], - 'delivery_id': self.next_outgoing_id + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": True, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "payload": payload[start_idx : start_idx + available_frame_size], + "delivery_id": self.next_outgoing_id, } - self._connection._process_outgoing_frame(self.channel, 
TransferFrame(**tmp_delivery_frame)) + self._connection._process_outgoing_frame( + self.channel, TransferFrame(**tmp_delivery_frame) + ) start_idx += available_frame_size remaining_payload_cnt -= available_frame_size # encode the last frame tmp_delivery_frame = { - 'handle': delivery.frame['handle'], - 'delivery_tag': delivery.frame['delivery_tag'], - 'message_format': delivery.frame['message_format'], - 'settled': delivery.frame['settled'], - 'more': False, - 'rcv_settle_mode': delivery.frame['rcv_settle_mode'], - 'state': delivery.frame['state'], - 'resume': delivery.frame['resume'], - 'aborted': delivery.frame['aborted'], - 'batchable': delivery.frame['batchable'], - 'payload': payload[start_idx:], - 'delivery_id': self.next_outgoing_id + "handle": delivery.frame["handle"], + "delivery_tag": delivery.frame["delivery_tag"], + "message_format": delivery.frame["message_format"], + "settled": delivery.frame["settled"], + "more": False, + "rcv_settle_mode": delivery.frame["rcv_settle_mode"], + "state": delivery.frame["state"], + "resume": delivery.frame["resume"], + "aborted": delivery.frame["aborted"], + "batchable": delivery.frame["batchable"], + "payload": payload[start_idx:], + "delivery_id": self.next_outgoing_id, } - self._connection._process_outgoing_frame(self.channel, TransferFrame(**tmp_delivery_frame)) + self._connection._process_outgoing_frame( + self.channel, TransferFrame(**tmp_delivery_frame) + ) self.next_outgoing_id += 1 self.remote_incoming_window -= 1 self.outgoing_window -= 1 @@ -280,7 +313,7 @@ def _incoming_transfer(self, frame): try: self._input_handles[frame[0]]._incoming_transfer(frame) # handle except KeyError: - pass #TODO: "unattached handle" + pass # TODO: "unattached handle" if self.incoming_window == 0: self.incoming_window = self.target_incoming_window self._outgoing_flow() @@ -290,7 +323,9 @@ def _outgoing_disposition(self, frame): def _incoming_disposition(self, frame): if self.network_trace: - _LOGGER.info("<- %r", DispositionFrame(*frame), extra=self.network_trace_params) + _LOGGER.info( + "<- %r", DispositionFrame(*frame), extra=self.network_trace_params + ) for link in self._input_handles.values(): link._incoming_disposition(frame) @@ -310,7 +345,7 @@ def _incoming_detach(self, frame): def _wait_for_response(self, wait, end_state): # type: (Union[bool, float], SessionState) -> None - if wait == True: + if wait is True: self._connection.listen(wait=False) while self.state != end_state: time.sleep(self.idle_wait_time) @@ -330,10 +365,12 @@ def begin(self, wait=False): if wait: self._wait_for_response(wait, SessionState.BEGIN_SENT) elif not self.allow_pipelined_open: - raise ValueError("Connection has been configured to not allow piplined-open. Please set 'wait' parameter.") + raise ValueError( + "Connection has been configured to not allow piplined-open. Please set 'wait' parameter." 
+ ) def end(self, error=None, wait=False): - # type: (Optional[AMQPError]) -> None + # type: (Optional[AMQPError], Union[bool, float]) -> None try: if self.state not in [SessionState.UNMAPPED, SessionState.DISCARDING]: self._outgoing_end(error=error) @@ -341,7 +378,7 @@ def end(self, error=None, wait=False): new_state = SessionState.DISCARDING if error else SessionState.END_SENT self._set_state(new_state) self._wait_for_response(wait, SessionState.UNMAPPED) - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except _LOGGER.info("An error occurred when ending the session: %r", exc) self._set_state(SessionState.UNMAPPED) @@ -351,9 +388,10 @@ def create_receiver_link(self, source_address, **kwargs): self, handle=assigned_handle, source_address=source_address, - network_trace=kwargs.pop('network_trace', self.network_trace), + network_trace=kwargs.pop("network_trace", self.network_trace), network_trace_params=dict(self.network_trace_params), - **kwargs) + **kwargs + ) self.links[link.name] = link self._output_handles[assigned_handle] = link return link @@ -364,16 +402,18 @@ def create_sender_link(self, target_address, **kwargs): self, handle=assigned_handle, target_address=target_address, - network_trace=kwargs.pop('network_trace', self.network_trace), + network_trace=kwargs.pop("network_trace", self.network_trace), network_trace_params=dict(self.network_trace_params), - **kwargs) + **kwargs + ) self._output_handles[assigned_handle] = link self.links[link.name] = link return link - + def create_request_response_link_pair(self, endpoint, **kwargs): return ManagementLink( self, endpoint, - network_trace=kwargs.pop('network_trace', self.network_trace), - **kwargs) + network_trace=kwargs.pop("network_trace", self.network_trace), + **kwargs + ) diff --git a/uamqp/types.py b/uamqp/types.py deleted file mode 100644 index db478af59..000000000 --- a/uamqp/types.py +++ /dev/null @@ -1,90 +0,0 @@ -#------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
-#-------------------------------------------------------------------------- - -from enum import Enum - - -TYPE = 'TYPE' -VALUE = 'VALUE' - - -class AMQPTypes(object): # pylint: disable=no-init - null = 'NULL' - boolean = 'BOOL' - ubyte = 'UBYTE' - byte = 'BYTE' - ushort = 'USHORT' - short = 'SHORT' - uint = 'UINT' - int = 'INT' - ulong = 'ULONG' - long = 'LONG' - float = 'FLOAT' - double = 'DOUBLE' - timestamp = 'TIMESTAMP' - uuid = 'UUID' - binary = 'BINARY' - string = 'STRING' - symbol = 'SYMBOL' - list = 'LIST' - map = 'MAP' - array = 'ARRAY' - described = 'DESCRIBED' - - -class FieldDefinition(Enum): - fields = "fields" - annotations = "annotations" - message_id = "message-id" - app_properties = "application-properties" - node_properties = "node-properties" - filter_set = "filter-set" - - -class ObjDefinition(Enum): - source = "source" - target = "target" - delivery_state = "delivery-state" - error = "error" - - -class ConstructorBytes(object): # pylint: disable=no-init - null = b'\x40' - bool = b'\x56' - bool_true = b'\x41' - bool_false = b'\x42' - ubyte = b'\x50' - byte = b'\x51' - ushort = b'\x60' - short = b'\x61' - uint_0 = b'\x43' - uint_small = b'\x52' - int_small = b'\x54' - uint_large = b'\x70' - int_large = b'\x71' - ulong_0 = b'\x44' - ulong_small = b'\x53' - long_small = b'\x55' - ulong_large = b'\x80' - long_large = b'\x81' - float = b'\x72' - double = b'\x82' - timestamp = b'\x83' - uuid = b'\x98' - binary_small = b'\xA0' - binary_large = b'\xB0' - string_small = b'\xA1' - string_large = b'\xB1' - symbol_small = b'\xA3' - symbol_large = b'\xB3' - list_0 = b'\x45' - list_small = b'\xC0' - list_large = b'\xD0' - map_small = b'\xC1' - map_large = b'\xD1' - array_small = b'\xE0' - array_large = b'\xF0' - descriptor = b'\x00' diff --git a/uamqp/utils.py b/uamqp/utils.py index 33255956c..67c138572 100644 --- a/uamqp/utils.py +++ b/uamqp/utils.py @@ -1,19 +1,18 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -import six import datetime +import time from base64 import b64encode from hashlib import sha256 from hmac import HMAC from urllib.parse import urlencode, quote_plus -import time -from .types import TYPE, VALUE, AMQPTypes -from ._encode import encode_payload +from uamqp._encode import encode_payload +from uamqp.amqp_types import TYPE, VALUE, AMQPTypes class UTC(datetime.tzinfo): @@ -48,8 +47,8 @@ def utc_now(): return datetime.datetime.now(tz=TZ_UTC) -def encode(value, encoding='UTF-8'): - return value.encode(encoding) if isinstance(value, six.text_type) else value +def encode(value, encoding="UTF-8"): + return value.encode(encoding) if isinstance(value, str) else value def generate_sas_token(audience, policy, key, expiry=None): @@ -70,16 +69,12 @@ def generate_sas_token(audience, policy, key, expiry=None): encoded_key = key.encode("utf-8") ttl = int(expiry) - sign_key = '%s\n%d' % (encoded_uri, ttl) - signature = b64encode(HMAC(encoded_key, sign_key.encode('utf-8'), sha256).digest()) - result = { - 'sr': audience, - 'sig': signature, - 'se': str(ttl) - } + sign_key = "%s\n%d" % (encoded_uri, ttl) + signature = b64encode(HMAC(encoded_key, sign_key.encode("utf-8"), sha256).digest()) + result = {"sr": audience, "sig": signature, "se": str(ttl)} if policy: - result['skn'] = encoded_policy - return 'SharedAccessSignature ' + urlencode(result) + result["skn"] = encoded_policy + return "SharedAccessSignature " + urlencode(result) def add_batch(batch, message): @@ -89,7 +84,7 @@ def add_batch(batch, message): batch.data.append(output) -def encode_str(data, encoding='utf-8'): +def encode_str(data, encoding="utf-8"): try: return data.encode(encoding) except AttributeError: @@ -101,16 +96,14 @@ def normalized_data_body(data, **kwargs): encoding = kwargs.get("encoding", "utf-8") if isinstance(data, list): return [encode_str(item, encoding) for item in data] - else: - return [encode_str(data, encoding)] + return [encode_str(data, encoding)] def normalized_sequence_body(sequence): # A helper method to normalize input into AMQP Sequence Body format if isinstance(sequence, list) and all([isinstance(b, list) for b in sequence]): return sequence - elif isinstance(sequence, list): - return [sequence] + return [sequence] def get_message_encoded_size(message):