Tweak incoming/outgoing workers
This commit is contained in:
parent
0b6556e54a
commit
0696268d0b
5 changed files with 30 additions and 46 deletions
|
@ -69,13 +69,11 @@ def _set_next_try(
|
||||||
|
|
||||||
async def fetch_next_incoming_activity(
|
async def fetch_next_incoming_activity(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
in_flight: set[int],
|
|
||||||
) -> models.IncomingActivity | None:
|
) -> models.IncomingActivity | None:
|
||||||
where = [
|
where = [
|
||||||
models.IncomingActivity.next_try <= now(),
|
models.IncomingActivity.next_try <= now(),
|
||||||
models.IncomingActivity.is_errored.is_(False),
|
models.IncomingActivity.is_errored.is_(False),
|
||||||
models.IncomingActivity.is_processed.is_(False),
|
models.IncomingActivity.is_processed.is_(False),
|
||||||
models.IncomingActivity.id.not_in(in_flight),
|
|
||||||
]
|
]
|
||||||
q_count = await db_session.scalar(
|
q_count = await db_session.scalar(
|
||||||
select(func.count(models.IncomingActivity.id)).where(*where)
|
select(func.count(models.IncomingActivity.id)).where(*where)
|
||||||
|
@ -144,11 +142,11 @@ class IncomingActivityWorker(Worker[models.IncomingActivity]):
|
||||||
self,
|
self,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
) -> models.IncomingActivity | None:
|
) -> models.IncomingActivity | None:
|
||||||
return await fetch_next_incoming_activity(db_session, self.in_flight_ids())
|
return await fetch_next_incoming_activity(db_session)
|
||||||
|
|
||||||
|
|
||||||
async def loop() -> None:
|
async def loop() -> None:
|
||||||
await IncomingActivityWorker(workers_count=1).run_forever()
|
await IncomingActivityWorker().run_forever()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -170,13 +170,11 @@ def _set_next_try(
|
||||||
|
|
||||||
async def fetch_next_outgoing_activity(
|
async def fetch_next_outgoing_activity(
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
in_fligh: set[int],
|
|
||||||
) -> models.OutgoingActivity | None:
|
) -> models.OutgoingActivity | None:
|
||||||
where = [
|
where = [
|
||||||
models.OutgoingActivity.next_try <= now(),
|
models.OutgoingActivity.next_try <= now(),
|
||||||
models.OutgoingActivity.is_errored.is_(False),
|
models.OutgoingActivity.is_errored.is_(False),
|
||||||
models.OutgoingActivity.is_sent.is_(False),
|
models.OutgoingActivity.is_sent.is_(False),
|
||||||
models.OutgoingActivity.id.not_in(in_fligh),
|
|
||||||
]
|
]
|
||||||
q_count = await db_session.scalar(
|
q_count = await db_session.scalar(
|
||||||
select(func.count(models.OutgoingActivity.id)).where(*where)
|
select(func.count(models.OutgoingActivity.id)).where(*where)
|
||||||
|
@ -289,14 +287,14 @@ class OutgoingActivityWorker(Worker[models.OutgoingActivity]):
|
||||||
self,
|
self,
|
||||||
db_session: AsyncSession,
|
db_session: AsyncSession,
|
||||||
) -> models.OutgoingActivity | None:
|
) -> models.OutgoingActivity | None:
|
||||||
return await fetch_next_outgoing_activity(db_session, self.in_flight_ids())
|
return await fetch_next_outgoing_activity(db_session)
|
||||||
|
|
||||||
async def startup(self, db_session: AsyncSession) -> None:
|
async def startup(self, db_session: AsyncSession) -> None:
|
||||||
await _send_actor_update_if_needed(db_session)
|
await _send_actor_update_if_needed(db_session)
|
||||||
|
|
||||||
|
|
||||||
async def loop() -> None:
|
async def loop() -> None:
|
||||||
await OutgoingActivityWorker(workers_count=3).run_forever()
|
await OutgoingActivityWorker().run_forever()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -12,30 +12,9 @@ T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Worker(Generic[T]):
|
class Worker(Generic[T]):
|
||||||
def __init__(self, workers_count: int) -> None:
|
def __init__(self) -> None:
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
self._in_flight: set[int] = set()
|
|
||||||
self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1)
|
|
||||||
self._stop_event = asyncio.Event()
|
self._stop_event = asyncio.Event()
|
||||||
self._workers_count = workers_count
|
|
||||||
|
|
||||||
async def _consumer(self, db_session: AsyncSession) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
message = await self._queue.get()
|
|
||||||
try:
|
|
||||||
await self.process_message(db_session, message)
|
|
||||||
finally:
|
|
||||||
self._in_flight.remove(message.id) # type: ignore
|
|
||||||
self._queue.task_done()
|
|
||||||
|
|
||||||
async def _producer(self, db_session: AsyncSession) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
next_message = await self.get_next_message(db_session)
|
|
||||||
if next_message:
|
|
||||||
self._in_flight.add(next_message.id) # type: ignore
|
|
||||||
await self._queue.put(next_message)
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
async def process_message(self, db_session: AsyncSession, message: T) -> None:
|
async def process_message(self, db_session: AsyncSession, message: T) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -46,8 +25,16 @@ class Worker(Generic[T]):
|
||||||
async def startup(self, db_session: AsyncSession) -> None:
|
async def startup(self, db_session: AsyncSession) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def in_flight_ids(self) -> set[int]:
|
async def _main_loop(self, db_session: AsyncSession) -> None:
|
||||||
return self._in_flight
|
while not self._stop_event.is_set():
|
||||||
|
next_message = await self.get_next_message(db_session)
|
||||||
|
if next_message:
|
||||||
|
await self.process_message(db_session, next_message)
|
||||||
|
else:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
async def _until_stopped(self) -> None:
|
||||||
|
await self._stop_event.wait()
|
||||||
|
|
||||||
async def run_forever(self) -> None:
|
async def run_forever(self) -> None:
|
||||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||||
|
@ -59,13 +46,14 @@ class Worker(Generic[T]):
|
||||||
|
|
||||||
async with async_session() as db_session:
|
async with async_session() as db_session:
|
||||||
await self.startup(db_session)
|
await self.startup(db_session)
|
||||||
self._loop.create_task(self._producer(db_session))
|
task = self._loop.create_task(self._main_loop(db_session))
|
||||||
for _ in range(self._workers_count):
|
stop_task = self._loop.create_task(self._until_stopped())
|
||||||
self._loop.create_task(self._consumer(db_session))
|
|
||||||
|
|
||||||
await self._stop_event.wait()
|
done, pending = await asyncio.wait(
|
||||||
logger.info("Waiting for tasks to finish")
|
{task, stop_task}, return_when=asyncio.FIRST_COMPLETED
|
||||||
await self._queue.join()
|
)
|
||||||
|
logger.info(f"Waiting for tasks to finish {done=}/{pending=}")
|
||||||
|
await asyncio.sleep(5)
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
logger.info(f"Cancelling {len(tasks)} tasks")
|
logger.info(f"Cancelling {len(tasks)} tasks")
|
||||||
[task.cancel() for task in tasks]
|
[task.cancel() for task in tasks]
|
||||||
|
|
|
@ -24,7 +24,7 @@ from tests.utils import setup_remote_actor_as_follower
|
||||||
|
|
||||||
|
|
||||||
async def _process_next_incoming_activity(db_session: AsyncSession) -> None:
|
async def _process_next_incoming_activity(db_session: AsyncSession) -> None:
|
||||||
next_activity = await fetch_next_incoming_activity(db_session, set())
|
next_activity = await fetch_next_incoming_activity(db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_incoming_activity(db_session, next_activity)
|
await process_next_incoming_activity(db_session, next_activity)
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ async def test_process_next_outgoing_activity__no_next_activity(
|
||||||
respx_mock: respx.MockRouter,
|
respx_mock: respx.MockRouter,
|
||||||
async_db_session: AsyncSession,
|
async_db_session: AsyncSession,
|
||||||
) -> None:
|
) -> None:
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity is None
|
assert next_activity is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ async def test_process_next_outgoing_activity__server_200(
|
||||||
|
|
||||||
# When processing the next outgoing activity
|
# When processing the next outgoing activity
|
||||||
# Then it is processed
|
# Then it is processed
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_outgoing_activity(async_db_session, next_activity)
|
await process_next_outgoing_activity(async_db_session, next_activity)
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ async def test_process_next_outgoing_activity__webmention(
|
||||||
|
|
||||||
# When processing the next outgoing activity
|
# When processing the next outgoing activity
|
||||||
# Then it is processed
|
# Then it is processed
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_outgoing_activity(async_db_session, next_activity)
|
await process_next_outgoing_activity(async_db_session, next_activity)
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ async def test_process_next_outgoing_activity__error_500(
|
||||||
|
|
||||||
# When processing the next outgoing activity
|
# When processing the next outgoing activity
|
||||||
# Then it is processed
|
# Then it is processed
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_outgoing_activity(async_db_session, next_activity)
|
await process_next_outgoing_activity(async_db_session, next_activity)
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ async def test_process_next_outgoing_activity__errored(
|
||||||
|
|
||||||
# When processing the next outgoing activity
|
# When processing the next outgoing activity
|
||||||
# Then it is processed
|
# Then it is processed
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_outgoing_activity(async_db_session, next_activity)
|
await process_next_outgoing_activity(async_db_session, next_activity)
|
||||||
|
|
||||||
|
@ -218,7 +218,7 @@ async def test_process_next_outgoing_activity__errored(
|
||||||
assert outgoing_activity.is_errored is True
|
assert outgoing_activity.is_errored is True
|
||||||
|
|
||||||
# And it is skipped from processing
|
# And it is skipped from processing
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity is None
|
assert next_activity is None
|
||||||
|
|
||||||
|
|
||||||
|
@ -241,7 +241,7 @@ async def test_process_next_outgoing_activity__connect_error(
|
||||||
|
|
||||||
# When processing the next outgoing activity
|
# When processing the next outgoing activity
|
||||||
# Then it is processed
|
# Then it is processed
|
||||||
next_activity = await fetch_next_outgoing_activity(async_db_session, set())
|
next_activity = await fetch_next_outgoing_activity(async_db_session)
|
||||||
assert next_activity
|
assert next_activity
|
||||||
await process_next_outgoing_activity(async_db_session, next_activity)
|
await process_next_outgoing_activity(async_db_session, next_activity)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue