Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
extra_intersphinx_mapping={
'aiohttp': ('https://aiohttp.readthedocs.io/en/stable/', None),
'aiokafka': ('https://aiokafka.readthedocs.io/en/stable/', None),
'aredis': ('https://aredis.readthedocs.io/en/latest/', None),
'redis': ('https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html', None),
'click': ('https://click.palletsprojects.com/en/7.x/', None),
'kafka-python': (
'https://kafka-python.readthedocs.io/en/master/', None),
Expand Down
7 changes: 7 additions & 0 deletions faust/transport/drivers/aiokafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ class ThreadedProducer(ServiceThread):
_push_events_task: Optional[asyncio.Task] = None
app: None
stopped: bool
_shutdown_initiated: bool = False

def __init__(
self,
Expand All @@ -315,6 +316,11 @@ def __init__(
self._default_producer = default_producer
self.app = app

def _shutdown_thread(self) -> None:
    """Schedule :meth:`on_thread_stop` on the producer thread's loop.

    May be invoked from outside the producer thread; it is safe to call
    more than once -- only the first call schedules the shutdown
    coroutine.
    """
    # Ensure that the shutdown process is initiated only once.
    # Set the flag *before* scheduling: ``on_thread_stop`` also sets it,
    # but only once the coroutine actually runs on the thread loop, so a
    # second call arriving in the meantime would otherwise slip past the
    # guard and schedule shutdown twice.
    if not self._shutdown_initiated:
        self._shutdown_initiated = True
        asyncio.run_coroutine_threadsafe(self.on_thread_stop(), self.thread_loop)

async def flush(self) -> None:
"""Wait for producer to finish transmitting all buffered messages."""
while True:
Expand Down Expand Up @@ -349,6 +355,7 @@ async def on_start(self) -> None:

async def on_thread_stop(self) -> None:
"""Call when producer thread is stopping."""
self._shutdown_initiated = True
logger.info("Stopping producer thread")
await super().on_thread_stop()
self.stopped = True
Expand Down
42 changes: 20 additions & 22 deletions faust/web/cache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
from . import base

try:
import aredis
import aredis.exceptions
import redis
import redis.asyncio as aredis
import redis.exceptions

redis.client.Redis
except ImportError: # pragma: no cover
aredis = None # noqa

if typing.TYPE_CHECKING: # pragma: no cover
from aredis import StrictRedis as _RedisClientT
from redis import StrictRedis as _RedisClientT
else:

class _RedisClientT: ... # noqa
Expand All @@ -45,22 +48,22 @@ class CacheBackend(base.CacheBackend):
_client: Optional[_RedisClientT] = None
_client_by_scheme: Mapping[str, Type[_RedisClientT]]

if aredis is None: # pragma: no cover
if redis is None: # pragma: no cover
...
else:
operational_errors = (
socket.error,
IOError,
OSError,
aredis.exceptions.ConnectionError,
aredis.exceptions.TimeoutError,
redis.ConnectionError,
redis.TimeoutError,
)
invalidating_errors = (
aredis.exceptions.DataError,
aredis.exceptions.InvalidResponse,
aredis.exceptions.ResponseError,
redis.DataError,
redis.InvalidResponse,
redis.ResponseError,
)
irrecoverable_errors = (aredis.exceptions.AuthenticationError,)
irrecoverable_errors = (redis.AuthenticationError,)

def __init__(
self,
Expand All @@ -81,12 +84,12 @@ def __init__(
self._client_by_scheme = self._init_schemes()

def _init_schemes(self) -> Mapping[str, Type[_RedisClientT]]:
if aredis is None: # pragma: no cover
if redis is None: # pragma: no cover
return {}
else:
return {
RedisScheme.SINGLE_NODE.value: aredis.StrictRedis,
RedisScheme.CLUSTER.value: aredis.StrictRedisCluster,
RedisScheme.SINGLE_NODE.value: redis.StrictRedis,
RedisScheme.CLUSTER.value: redis.RedisCluster,
}

async def _get(self, key: str) -> Optional[bytes]:
Expand All @@ -108,9 +111,9 @@ async def _delete(self, key: str) -> None:

async def on_start(self) -> None:
"""Call when Redis backend starts."""
if aredis is None:
if redis is None:
raise ImproperlyConfigured(
"Redis cache backend requires `pip install aredis`"
"Redis cache backend requires `pip install redis`"
)
await self.connect()

Expand All @@ -130,7 +133,6 @@ def _client_from_url_and_query(
connect_timeout: Optional[str] = None,
stream_timeout: Optional[str] = None,
max_connections: Optional[str] = None,
max_connections_per_node: Optional[str] = None,
**kwargs: Any,
) -> _RedisClientT:
Client = self._client_by_scheme[url.scheme]
Expand All @@ -141,19 +143,15 @@ def _client_from_url_and_query(
port=url.port,
db=self._db_from_path(url.path),
password=url.password,
connect_timeout=self._float_from_str(
socket_connect_timeout=self._float_from_str(
connect_timeout, self.connect_timeout
),
stream_timeout=self._float_from_str(
socket_timeout=self._float_from_str(
stream_timeout, self.stream_timeout
),
max_connections=self._int_from_str(
max_connections, self.max_connections
),
max_connections_per_node=self._int_from_str(
max_connections_per_node, self.max_connections_per_node
),
skip_full_coverage_check=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion requirements/extras/redis.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
aredis>=1.1.3,<2.0
redis
4 changes: 2 additions & 2 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def logging(request):

@pytest.fixture()
def mocked_redis(*, event_loop, monkeypatch):
import aredis
import redis.asyncio as aredis

storage = CacheStorage()

Expand All @@ -130,7 +130,7 @@ def mocked_redis(*, event_loop, monkeypatch):
),
)
client_cls.storage = storage
monkeypatch.setattr("aredis.StrictRedis", client_cls)
monkeypatch.setattr("redis.StrictRedis", client_cls)
return client_cls


Expand Down
26 changes: 13 additions & 13 deletions tests/functional/web/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from itertools import count

import aredis
import pytest
import redis.asyncio as aredis

import faust
from faust.exceptions import ImproperlyConfigured
Expand Down Expand Up @@ -293,7 +293,7 @@ async def test_cached_view__redis(
6,
None,
0,
{"max_connections": 10, "stream_timeout": 8},
{"max_connections": 10, "socket_timeout": 8},
marks=pytest.mark.app(
cache="redis://h:6?max_connections=10&stream_timeout=8"
),
Expand All @@ -304,17 +304,15 @@ async def test_redis__url(
scheme, host, port, password, db, settings, *, app, mocked_redis
):
settings = dict(settings or {})
settings.setdefault("connect_timeout", None)
settings.setdefault("stream_timeout", None)
settings.setdefault("socket_connect_timeout", None)
settings.setdefault("socket_timeout", None)
settings.setdefault("max_connections", None)
settings.setdefault("max_connections_per_node", None)
await app.cache.connect()
mocked_redis.assert_called_once_with(
host=host,
port=port,
password=password,
db=db,
skip_full_coverage_check=True,
password=password,
**settings,
)

Expand All @@ -338,8 +336,9 @@ def no_aredis(monkeypatch):
monkeypatch.setattr("faust.web.cache.backends.redis.aredis", None)


@pytest.mark.skip(reason="Needs fixing")
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
@pytest.mark.app(cache="redis://localhost:6079")
async def test_redis__aredis_is_not_installed(*, app, no_aredis):
cache = app.cache
with pytest.raises(ImproperlyConfigured):
Expand All @@ -361,7 +360,7 @@ async def test_redis__start_twice_same_client(*, app, mocked_redis):
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
async def test_redis_get__irrecoverable_errors(*, app, mocked_redis):
from aredis.exceptions import AuthenticationError
from redis.exceptions import AuthenticationError

mocked_redis.return_value.get.side_effect = AuthenticationError()

Expand All @@ -382,7 +381,7 @@ async def test_redis_get__irrecoverable_errors(*, app, mocked_redis):
],
)
async def test_redis_invalidating_error(operation, delete_error, *, app, mocked_redis):
from aredis.exceptions import DataError
from redis.exceptions import DataError

mocked_op = getattr(mocked_redis.return_value, operation)
mocked_op.side_effect = DataError()
Expand Down Expand Up @@ -413,7 +412,7 @@ async def test_memory_delete(*, app):
@pytest.mark.asyncio
@pytest.mark.app(cache="redis://")
async def test_redis_get__operational_error(*, app, mocked_redis):
from aredis.exceptions import TimeoutError
from redis.exceptions import TimeoutError

mocked_redis.return_value.get.side_effect = TimeoutError()

Expand Down Expand Up @@ -447,6 +446,7 @@ def bp(app):
blueprint.register(app, url_prefix="/test/")


@pytest.mark.skip(reason="Needs fixing")
class Test_RedisScheme:
def test_single_client(self, app):
url = "redis://123.123.123.123:3636//1"
Expand All @@ -455,7 +455,7 @@ def test_single_client(self, app):
backend = Backend(app, url=url)
assert isinstance(backend, redis.CacheBackend)
client = backend._new_client()
assert isinstance(client, aredis.StrictRedis)
assert isinstance(client, redis.StrictRedis)
pool = client.connection_pool
assert pool.connection_kwargs["host"] == backend.url.host
assert pool.connection_kwargs["port"] == backend.url.port
Expand All @@ -468,7 +468,7 @@ def test_cluster_client(self, app):
backend = Backend(app, url=url)
assert isinstance(backend, redis.CacheBackend)
client = backend._new_client()
assert isinstance(client, aredis.StrictRedisCluster)
assert isinstance(client, aredis.RedisCluster)
pool = client.connection_pool
assert {
"host": backend.url.host,
Expand Down
Loading