From 7c366c051708260f1d613bf7cac73519c7e4ff5d Mon Sep 17 00:00:00 2001 From: Mikhail Sennikov <mifls@yandex.ru> Date: Sun, 28 Jan 2024 18:13:04 +0300 Subject: [PATCH] Fix warmup --- Backend/node/swarm.py | 8 ++++---- ProxyNode/onvif_proxy/hot_cams.py | 2 +- database/db_connection.py | 7 ++----- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/Backend/node/swarm.py b/Backend/node/swarm.py index d0d7633..11e284d 100644 --- a/Backend/node/swarm.py +++ b/Backend/node/swarm.py @@ -19,7 +19,7 @@ class NodeSwarm: self._connections: Dict[int, NodeConnection] = {} async def _load_from_db(self): - node_repo = NodeRepository(await db.get_db()) + node_repo = NodeRepository(await anext(db.get_db())) nodes = await node_repo.list_all() self._connections: Dict[int, NodeConnection] = \ {int(getattr(node, 'id')): NodeConnection(node) for node in nodes} @@ -34,7 +34,7 @@ class NodeSwarm: await self._try_connect_all() async def renew_nodes_list(self): - node_repo = NodeRepository(await db.get_db()) + node_repo = NodeRepository(await anext(db.get_db())) nodes = await node_repo.list_all() for node in nodes: node_id = int(node.id) @@ -61,7 +61,7 @@ class NodeSwarm: print(f"Failed to connect to node {node.id}: {e}") async def set_group_by_id(self, group_id: int) -> Group: - session = await db.get_db() + session = await anext(db.get_db()) group_repo = GroupRepository(session) group = await group_repo.get(group_id) if group is None: @@ -72,7 +72,7 @@ class NodeSwarm: return group async def set_group_by_slug(self, group_slug: str) -> Group: - session = await db.get_db() + session = await anext(db.get_db()) group_repo = GroupRepository(session) group = await group_repo.get_by_slug(group_slug) if group is None: diff --git a/ProxyNode/onvif_proxy/hot_cams.py b/ProxyNode/onvif_proxy/hot_cams.py index 4ccebbd..ccba830 100644 --- a/ProxyNode/onvif_proxy/hot_cams.py +++ b/ProxyNode/onvif_proxy/hot_cams.py @@ -36,7 +36,7 @@ class HotCams: return _cam_id async def try_connect(self, cam_id: int): - cam_repo = CamRepository(await db.get_db()) + cam_repo = CamRepository(await anext(db.get_db())) cam_db_obj = await cam_repo.get(cam_id) if cam_db_obj is None: raise HTTPException( diff --git a/database/db_connection.py b/database/db_connection.py index 1facd93..b271020 100644 --- a/database/db_connection.py +++ b/database/db_connection.py @@ -24,7 +24,7 @@ class AsyncDBConnection: async with self._engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) - async def _get_db(self) -> AsyncGenerator: + async def get_db(self) -> AsyncGenerator: db = None try: db = self._session_maker() @@ -33,14 +33,11 @@ class AsyncDBConnection: if isinstance(db, AsyncSession): await db.close() - async def get_db(self) -> AsyncSession | None: - return await anext(self._get_db()) - def get_repository( self, repository: Type[Repository], ) -> Any: - def repository_getter(db: AsyncSession = Depends(self._get_db)) -> Repository: + def repository_getter(db: AsyncSession = Depends(self.get_db)) -> Repository: return repository(db) return repository_getter -- GitLab