diff --git a/Backend/node/swarm.py b/Backend/node/swarm.py index d0d76330b12d10764e146f00f4c0f5831b63b7e7..11e284d5670be0fcef1db1aa14e74b301abc5af2 100644 --- a/Backend/node/swarm.py +++ b/Backend/node/swarm.py @@ -19,7 +19,7 @@ class NodeSwarm: self._connections: Dict[int, NodeConnection] = {} async def _load_from_db(self): - node_repo = NodeRepository(await db.get_db()) + node_repo = NodeRepository(await anext(db.get_db())) nodes = await node_repo.list_all() self._connections: Dict[int, NodeConnection] = \ {int(getattr(node, 'id')): NodeConnection(node) for node in nodes} @@ -34,7 +34,7 @@ class NodeSwarm: await self._try_connect_all() async def renew_nodes_list(self): - node_repo = NodeRepository(await db.get_db()) + node_repo = NodeRepository(await anext(db.get_db())) nodes = await node_repo.list_all() for node in nodes: node_id = int(node.id) @@ -61,7 +61,7 @@ class NodeSwarm: print(f"Failed to connect to node {node.id}: {e}") async def set_group_by_id(self, group_id: int) -> Group: - session = await db.get_db() + session = await anext(db.get_db()) group_repo = GroupRepository(session) group = await group_repo.get(group_id) if group is None: @@ -72,7 +72,7 @@ class NodeSwarm: return group async def set_group_by_slug(self, group_slug: str) -> Group: - session = await db.get_db() + session = await anext(db.get_db()) group_repo = GroupRepository(session) group = await group_repo.get_by_slug(group_slug) if group is None: diff --git a/ProxyNode/onvif_proxy/hot_cams.py b/ProxyNode/onvif_proxy/hot_cams.py index 4ccebbd03f4b4c66801b1b301cb6d494eb474d7f..ccba8302a33733bbc7132a2a8827203f28bf2a56 100644 --- a/ProxyNode/onvif_proxy/hot_cams.py +++ b/ProxyNode/onvif_proxy/hot_cams.py @@ -36,7 +36,7 @@ class HotCams: return _cam_id async def try_connect(self, cam_id: int): - cam_repo = CamRepository(await db.get_db()) + cam_repo = CamRepository(await anext(db.get_db())) cam_db_obj = await cam_repo.get(cam_id) if cam_db_obj is None: raise HTTPException( diff --git a/database/db_connection.py b/database/db_connection.py index 1facd934e14afe31abe17517efdbd269a4d7df5a..b2710202935cce5ee0eb3c7ed6c6b58502248ce9 100644 --- a/database/db_connection.py +++ b/database/db_connection.py @@ -24,7 +24,7 @@ class AsyncDBConnection: async with self._engine.begin() as connection: await connection.run_sync(Base.metadata.create_all) - async def _get_db(self) -> AsyncGenerator: + async def get_db(self) -> AsyncGenerator: db = None try: db = self._session_maker() @@ -33,14 +33,11 @@ class AsyncDBConnection: if isinstance(db, AsyncSession): await db.close() - async def get_db(self) -> AsyncSession | None: - return await anext(self._get_db()) - def get_repository( self, repository: Type[Repository], ) -> Any: - def repository_getter(db: AsyncSession = Depends(self._get_db)) -> Repository: + def repository_getter(db: AsyncSession = Depends(self.get_db)) -> Repository: return repository(db) return repository_getter