diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ec4935b03..d583002ae 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,7 +12,7 @@ on: - cron: '0 0 * * *' env: - FAF_DB_VERSION: v138 + FAF_DB_VERSION: v143 FLYWAY_VERSION: 7.5.4 jobs: diff --git a/compose.yaml b/compose.yaml index 4473a4bb1..238d3f41c 100644 --- a/compose.yaml +++ b/compose.yaml @@ -16,7 +16,7 @@ services: - "3306:3306" faf-db-migrations: - image: faforever/faf-db-migrations:v138 + image: faforever/faf-db-migrations:v143 command: migrate environment: FLYWAY_URL: jdbc:mysql://faf-db/faf?useSSL=false diff --git a/server/__init__.py b/server/__init__.py index bf6a5c40a..cb8f762f5 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -117,6 +117,7 @@ import server.metrics as metrics from .asyncio_extensions import map_suppress, synchronizedmethod +from .avatar_change_queue_service import AvatarChangeQueueService from .broadcast_service import BroadcastService from .client_message_queue_service import ClientMessageQueueService from .config import TRACE, config @@ -144,6 +145,7 @@ __copyright__ = "Copyright (c) 2011-2015 " + __author__ __all__ = ( + "AvatarChangeQueueService", "BroadcastService", "ClientMessageQueueService", "ConfigurationService", diff --git a/server/avatar_change_queue_service.py b/server/avatar_change_queue_service.py new file mode 100644 index 000000000..349cc184e --- /dev/null +++ b/server/avatar_change_queue_service.py @@ -0,0 +1,119 @@ +"""RabbitMQ consumer that refreshes player avatars from DB on update events.""" + +import json +import logging +import socket +from typing import Any, ClassVar, Optional + +from aio_pika.abc import AbstractIncomingMessage, AbstractQueue + +from .config import config +from .core import Service +from .decorators import with_logger +from .message_queue_service import MessageQueueService +from .player_service import PlayerService + +PLAYER_AVATAR_UPDATE_ROUTING_KEY = "success.player_avatar.update" + + +@with_logger +class AvatarChangeQueueService(Service): + + """ + Consume `success.player_avatar.update` messages and refresh players. + + Wire contract + ------------- + Publishers post to the `MQ_EXCHANGE_NAME` topic exchange with routing + key `success.player_avatar.update`. The body is a UTF-8 JSON object: + + - `player_id` (int, required): the player whose selected avatar changed. + - `avatar_id` (int or null, optional): the newly selected avatar id, or + null if the player cleared their avatar. The lobby itself ignores this + field — it always re-reads the DB so it gets the matching url/tooltip + and applies the ownership check. The field is shipped for the benefit + of other subscribers that may want to act on the change without an + extra DB roundtrip. + + On receipt the lobby re-reads the affected player's avatar from the DB + and marks them dirty so the existing `BroadcastService` emits a + `player_info` to every connected client on its next tick. + """ + + _logger: ClassVar[logging.Logger] + + def __init__( + self, + message_queue_service: MessageQueueService, + player_service: PlayerService, + ): + """Wire dependencies; consumer is started in `initialize`.""" + self.message_queue_service = message_queue_service + self.player_service = player_service + self._queue: Optional[AbstractQueue] = None + self._consumer_tag: Optional[str] = None + + async def initialize(self) -> None: + # Per-instance queue: every lobby pod must process every event so + # whichever pod is hosting the player can refresh its in-memory + # state. Naming follows `...`, + # matching `ClientMessageQueueService`. + queue_name = ( + f"{config.MQ_EXCHANGE_NAME}.lobby.player_avatar.update" + f".{socket.gethostname()}" + ) + result = await self.message_queue_service.declare_queue_and_consume( + exchange_name=config.MQ_EXCHANGE_NAME, + routing_key=PLAYER_AVATAR_UPDATE_ROUTING_KEY, + callback=self._on_message, + queue_name=queue_name, + ) + if result is not None: + self._queue, self._consumer_tag = result + + async def shutdown(self) -> None: + if self._queue is not None and self._consumer_tag is not None: + await self._queue.cancel(self._consumer_tag) + self._queue = None + self._consumer_tag = None + + async def _on_message(self, message: AbstractIncomingMessage) -> None: + async with message.process(requeue=False): + try: + payload = json.loads(message.body) + except (ValueError, UnicodeDecodeError): + self._logger.warning( + "Dropping avatar-update message with non-JSON body" + ) + return + + if not isinstance(payload, dict): + self._logger.warning( + "Dropping avatar-update message: payload is not a JSON object" + ) + return + + raw_player_id: Any = payload.get("player_id") + # Reject bool explicitly: int(True) == 1 would otherwise sneak + # through and refresh player 1 on every truthy payload. + if isinstance(raw_player_id, bool): + self._logger.warning( + "Dropping avatar-update message: invalid player_id %r", + raw_player_id, + ) + return + try: + player_id = int(raw_player_id) + except (TypeError, ValueError): + self._logger.warning( + "Dropping avatar-update message: invalid player_id %r", + raw_player_id, + ) + return + + refreshed = await self.player_service.refresh_player_avatar(player_id) + if not refreshed: + self._logger.debug( + "avatar-update for player %s ignored: not connected here", + player_id, + ) diff --git a/server/db/models.py b/server/db/models.py index 5934b5e45..e4e671bb5 100644 --- a/server/db/models.py +++ b/server/db/models.py @@ -184,7 +184,8 @@ Column("create_time", TIMESTAMP, nullable=False), Column("update_time", TIMESTAMP, nullable=False), Column("user_agent", String), - Column("last_login", TIMESTAMP) + Column("last_login", TIMESTAMP), + Column("avatar_id", Integer, ForeignKey("avatars_list.id")), ) leaderboard = Table( diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index 5b5ec2e10..c2203b7c9 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -957,6 +957,7 @@ async def command_avatar(self, message): ) self.player.avatar = None + new_avatar_id = row.id if avatar_url is not None else None if avatar_url is not None: await conn.execute( avatars.update().where( @@ -972,6 +973,23 @@ async def command_avatar(self, message): "url": avatar_url, "tooltip": row.tooltip } + # Mirror the selection to login.avatar_id so reads via the new + # authoritative column stay consistent with the legacy flag. + await conn.execute( + t_login.update().where( + t_login.c.id == self.player.id + ).values( + avatar_id=new_avatar_id + ) + ) + avatar_tooltip = ( + self.player.avatar["tooltip"] if self.player.avatar else None + ) + self._logger.info( + "Player %s changed avatar via client connection: " + "avatar_id=%s tooltip=%s", + self.player.id, new_avatar_id, avatar_tooltip + ) self.player_service.mark_dirty(self.player) else: raise KeyError("invalid action") diff --git a/server/player_service.py b/server/player_service.py index f0e544c1f..a0f28844a 100644 --- a/server/player_service.py +++ b/server/player_service.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, ClassVar, Optional, ValuesView import aiocron -from sqlalchemy import and_, select +from sqlalchemy import and_, or_, select import server.metrics as metrics from server.config import config @@ -81,6 +81,25 @@ def pop_dirty_players(self) -> set[Player]: return dirty_players + @staticmethod + def _avatar_grant_join_onclause(): + # ON clause for joining `avatars` against `login` to pick a + # player's currently worn avatar with ownership enforced. + # Prefer the new authoritative `login.avatar_id` column; fall + # back to the legacy `avatars.selected = 1` row only when + # `avatar_id` is null. Either way the row must be a real grant + # in `avatars`, so revoked grants resolve to no avatar. + return and_( + avatars.c.idUser == login.c.id, + or_( + avatars.c.idAvatar == login.c.avatar_id, + and_( + login.c.avatar_id.is_(None), + avatars.c.selected == 1 + ) + ) + ) + async def fetch_player_data(self, player: Player) -> None: async with self._db.acquire() as conn: result = await conn.execute( @@ -100,10 +119,7 @@ async def fetch_player_data(self, player: Player) -> None: .outerjoin(clan) .outerjoin( avatars, - onclause=and_( - avatars.c.idUser == login.c.id, - avatars.c.selected == 1 - ) + onclause=self._avatar_grant_join_onclause() ) .outerjoin(avatars_list) ).where(login.c.id == player.id) # yapf: disable @@ -129,6 +145,60 @@ async def fetch_player_data(self, player: Player) -> None: await self._fetch_player_ratings(player, conn) + async def _fetch_player_avatar( + self, player: Player, conn + ) -> Optional[int]: + """Refresh `player.avatar` from DB; return the avatar id, if any.""" + sql = select( + avatars_list.c.id, + avatars_list.c.url, + avatars_list.c.tooltip, + ).select_from( + login + .outerjoin( + avatars, + onclause=self._avatar_grant_join_onclause() + ) + .outerjoin(avatars_list) + ).where(login.c.id == player.id) + + result = await conn.execute(sql) + row = result.fetchone() + if row is None: + player.avatar = None + return None + + row_mapping = row._mapping + avatar_id = row_mapping.get(avatars_list.c.id) + url = row_mapping.get(avatars_list.c.url) + tooltip = row_mapping.get(avatars_list.c.tooltip) + if url and tooltip: + player.avatar = {"url": url, "tooltip": tooltip} + return avatar_id + player.avatar = None + return None + + async def refresh_player_avatar(self, player_id: int) -> bool: + """ + Re-read avatar for one player and mark them dirty. + + `BroadcastService` emits a `player_info` on the next tick. Returns + True if the player is connected to this instance, False otherwise. + """ + player = self._players.get(player_id) + if player is None: + return False + async with self._db.acquire() as conn: + avatar_id = await self._fetch_player_avatar(player, conn) + avatar_tooltip = player.avatar["tooltip"] if player.avatar else None + self._logger.info( + "Player %s avatar refreshed from RabbitMQ event: " + "avatar_id=%s tooltip=%s", + player_id, avatar_id, avatar_tooltip + ) + self.mark_dirty(player) + return True + async def _fetch_player_ratings(self, player: Player, conn): sql = select( leaderboard_rating.c.mean, diff --git a/tests/unit_tests/test_avatar_change_queue_service.py b/tests/unit_tests/test_avatar_change_queue_service.py new file mode 100644 index 000000000..9439e7b21 --- /dev/null +++ b/tests/unit_tests/test_avatar_change_queue_service.py @@ -0,0 +1,198 @@ +import json +from unittest import mock + +import pytest + +from server.avatar_change_queue_service import ( + PLAYER_AVATAR_UPDATE_ROUTING_KEY, + AvatarChangeQueueService +) +from server.config import config + + +def make_incoming_message(body: bytes): + """Build a stand-in for aio_pika's IncomingMessage.""" + message = mock.Mock() + message.body = body + + process_cm = mock.MagicMock() + process_cm.__aenter__ = mock.AsyncMock(return_value=None) + process_cm.__aexit__ = mock.AsyncMock(return_value=False) + message.process = mock.Mock(return_value=process_cm) + return message + + +@pytest.fixture +def fake_player_service(): + service = mock.Mock() + service.refresh_player_avatar = mock.AsyncMock(return_value=True) + return service + + +@pytest.fixture +async def avatar_queue_service(fake_player_service): + queue = mock.Mock() + queue.cancel = mock.AsyncMock() + mq_service = mock.Mock() + mq_service.declare_queue_and_consume = mock.AsyncMock( + return_value=(queue, "consumer-tag-avatar") + ) + service = AvatarChangeQueueService( + message_queue_service=mq_service, + player_service=fake_player_service, + ) + await service.initialize() + yield service + await service.shutdown() + + +async def test_shutdown_cancels_consumer(fake_player_service): + queue = mock.Mock() + queue.cancel = mock.AsyncMock() + mq_service = mock.Mock() + mq_service.declare_queue_and_consume = mock.AsyncMock( + return_value=(queue, "consumer-tag-xyz") + ) + service = AvatarChangeQueueService( + message_queue_service=mq_service, + player_service=fake_player_service, + ) + await service.initialize() + await service.shutdown() + + queue.cancel.assert_awaited_once_with("consumer-tag-xyz") + assert service._queue is None + assert service._consumer_tag is None + + +async def test_shutdown_noop_when_broker_unavailable(fake_player_service): + mq_service = mock.Mock() + mq_service.declare_queue_and_consume = mock.AsyncMock(return_value=None) + service = AvatarChangeQueueService( + message_queue_service=mq_service, + player_service=fake_player_service, + ) + await service.initialize() + await service.shutdown() + + +async def test_initialize_declares_consumer(avatar_queue_service): + mq = avatar_queue_service.message_queue_service + mq.declare_queue_and_consume.assert_awaited_once() + kwargs = mq.declare_queue_and_consume.await_args.kwargs + assert kwargs["exchange_name"] == config.MQ_EXCHANGE_NAME + assert kwargs["routing_key"] == PLAYER_AVATAR_UPDATE_ROUTING_KEY + assert kwargs["callback"] == avatar_queue_service._on_message + + +async def test_refresh_called_for_valid_payload( + avatar_queue_service, fake_player_service +): + msg = make_incoming_message( + json.dumps({"player_id": 42, "avatar_id": 5}).encode() + ) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_awaited_once_with(42) + + +async def test_refresh_called_when_avatar_cleared( + avatar_queue_service, fake_player_service +): + msg = make_incoming_message( + json.dumps({"player_id": 42, "avatar_id": None}).encode() + ) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_awaited_once_with(42) + + +async def test_extra_fields_are_tolerated( + avatar_queue_service, fake_player_service +): + msg = make_incoming_message( + json.dumps({"player_id": 7, "avatar_id": 1, "future_field": "ok"}).encode() + ) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_awaited_once_with(7) + + +async def test_player_not_connected_is_logged( + avatar_queue_service, fake_player_service, caplog +): + fake_player_service.refresh_player_avatar = mock.AsyncMock(return_value=False) + + msg = make_incoming_message( + json.dumps({"player_id": 999, "avatar_id": 1}).encode() + ) + + import logging + with caplog.at_level(logging.DEBUG): + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_awaited_once_with(999) + assert any("not connected here" in m for m in caplog.messages) + + +async def test_malformed_json_body_is_dropped( + avatar_queue_service, fake_player_service, caplog +): + msg = make_incoming_message(b"definitely not json") + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_not_awaited() + assert any("non-JSON body" in m for m in caplog.messages) + + +async def test_non_object_json_body_is_dropped( + avatar_queue_service, fake_player_service +): + msg = make_incoming_message(b"[1, 2, 3]") + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_not_awaited() + + +async def test_missing_player_id_is_dropped( + avatar_queue_service, fake_player_service, caplog +): + msg = make_incoming_message(json.dumps({"avatar_id": 5}).encode()) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_not_awaited() + assert any("invalid player_id" in m for m in caplog.messages) + + +async def test_non_int_player_id_is_dropped( + avatar_queue_service, fake_player_service, caplog +): + msg = make_incoming_message( + json.dumps({"player_id": "not-an-int", "avatar_id": 5}).encode() + ) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_not_awaited() + assert any("invalid player_id" in m for m in caplog.messages) + + +async def test_bool_player_id_is_dropped( + avatar_queue_service, fake_player_service, caplog +): + # int(True) == 1, so without an explicit bool check this would refresh + # player 1. Guard against that surprise. + msg = make_incoming_message( + json.dumps({"player_id": True, "avatar_id": 5}).encode() + ) + + await avatar_queue_service._on_message(msg) + + fake_player_service.refresh_player_avatar.assert_not_awaited() + assert any("invalid player_id" in m for m in caplog.messages) diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index ea764e84d..d9ee09535 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -700,10 +700,39 @@ async def test_command_avatar_select(database, lobbyconnection: LobbyConnection) }) async with database.acquire() as conn: - result = await conn.execute("SELECT selected from avatars where idUser=2") + result = await conn.execute( + "SELECT idAvatar, selected FROM avatars WHERE idUser=2 AND selected=1" + ) row = result.fetchone() + assert row is not None + selected_avatar_id = row.idAvatar assert row.selected == 1 + result = await conn.execute("SELECT avatar_id FROM login WHERE id=2") + row = result.fetchone() + assert row.avatar_id == selected_avatar_id + + +async def test_command_avatar_select_clear( + database, lobbyconnection: LobbyConnection +): + lobbyconnection.player.id = 2 # Dostya test user + + await lobbyconnection.on_message_received({ + "command": "avatar", + "action": "select", + "avatar": None, + }) + + async with database.acquire() as conn: + result = await conn.execute( + "SELECT COUNT(*) AS n FROM avatars WHERE idUser=2 AND selected=1" + ) + assert result.fetchone().n == 0 + + result = await conn.execute("SELECT avatar_id FROM login WHERE id=2") + assert result.fetchone().avatar_id is None + async def get_friends(player_id, database): async with database.acquire() as conn: diff --git a/tests/unit_tests/test_player_service.py b/tests/unit_tests/test_player_service.py index 99924ef5a..b46c63950 100644 --- a/tests/unit_tests/test_player_service.py +++ b/tests/unit_tests/test_player_service.py @@ -64,6 +64,30 @@ async def test_fetch_player_data_non_existent(player_factory, player_service): await player_service.fetch_player_data(player) +async def test_refresh_player_avatar_connected( + player_factory, player_service +): + player = player_factory(player_id=50) + player.avatar = None # simulate stale (e.g. just connected) + player_service[50] = player + + refreshed = await player_service.refresh_player_avatar(50) + + assert refreshed is True + assert player.avatar == { + "url": "https://content.faforever.com/faf/avatars/UEF.png", + "tooltip": "UEF", + } + assert player in player_service._dirty_players + + +async def test_refresh_player_avatar_not_connected(player_service): + refreshed = await player_service.refresh_player_avatar(999) + + assert refreshed is False + assert not player_service._dirty_players + + async def test_magic_methods(player_factory, player_service): player = player_factory(player_id=0) player_service[0] = player