From 7c366c051708260f1d613bf7cac73519c7e4ff5d Mon Sep 17 00:00:00 2001
From: Mikhail Sennikov <mifls@yandex.ru>
Date: Sun, 28 Jan 2024 18:13:04 +0300
Subject: [PATCH] Fix warmup

---
 Backend/node/swarm.py             | 8 ++++----
 ProxyNode/onvif_proxy/hot_cams.py | 2 +-
 database/db_connection.py         | 7 ++-----
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/Backend/node/swarm.py b/Backend/node/swarm.py
index d0d7633..11e284d 100644
--- a/Backend/node/swarm.py
+++ b/Backend/node/swarm.py
@@ -19,7 +19,7 @@ class NodeSwarm:
         self._connections: Dict[int, NodeConnection] = {}
 
     async def _load_from_db(self):
-        node_repo = NodeRepository(await db.get_db())
+        node_repo = NodeRepository(await anext(db.get_db()))
         nodes = await node_repo.list_all()
         self._connections: Dict[int, NodeConnection] = \
             {int(getattr(node, 'id')): NodeConnection(node) for node in nodes}
@@ -34,7 +34,7 @@ class NodeSwarm:
         await self._try_connect_all()
 
     async def renew_nodes_list(self):
-        node_repo = NodeRepository(await db.get_db())
+        node_repo = NodeRepository(await anext(db.get_db()))
         nodes = await node_repo.list_all()
         for node in nodes:
             node_id = int(node.id)
@@ -61,7 +61,7 @@ class NodeSwarm:
                 print(f"Failed to connect to node {node.id}: {e}")
 
     async def set_group_by_id(self, group_id: int) -> Group:
-        session = await db.get_db()
+        session = await anext(db.get_db())
         group_repo = GroupRepository(session)
         group = await group_repo.get(group_id)
         if group is None:
@@ -72,7 +72,7 @@ class NodeSwarm:
         return group
 
     async def set_group_by_slug(self, group_slug: str) -> Group:
-        session = await db.get_db()
+        session = await anext(db.get_db())
         group_repo = GroupRepository(session)
         group = await group_repo.get_by_slug(group_slug)
         if group is None:
diff --git a/ProxyNode/onvif_proxy/hot_cams.py b/ProxyNode/onvif_proxy/hot_cams.py
index 4ccebbd..ccba830 100644
--- a/ProxyNode/onvif_proxy/hot_cams.py
+++ b/ProxyNode/onvif_proxy/hot_cams.py
@@ -36,7 +36,7 @@ class HotCams:
             return _cam_id
 
         async def try_connect(self, cam_id: int):
-            cam_repo = CamRepository(await db.get_db())
+            cam_repo = CamRepository(await anext(db.get_db()))
             cam_db_obj = await cam_repo.get(cam_id)
             if cam_db_obj is None:
                 raise HTTPException(
diff --git a/database/db_connection.py b/database/db_connection.py
index 1facd93..b271020 100644
--- a/database/db_connection.py
+++ b/database/db_connection.py
@@ -24,7 +24,7 @@ class AsyncDBConnection:
         async with self._engine.begin() as connection:
             await connection.run_sync(Base.metadata.create_all)
 
-    async def _get_db(self) -> AsyncGenerator:
+    async def get_db(self) -> AsyncGenerator:
         db = None
         try:
             db = self._session_maker()
@@ -33,14 +33,11 @@ class AsyncDBConnection:
             if isinstance(db, AsyncSession):
                 await db.close()
 
-    async def get_db(self) -> AsyncSession | None:
-        return await anext(self._get_db())
-
     def get_repository(
             self,
             repository: Type[Repository],
     ) -> Any:
-        def repository_getter(db: AsyncSession = Depends(self._get_db)) -> Repository:
+        def repository_getter(db: AsyncSession = Depends(self.get_db)) -> Repository:
             return repository(db)
 
         return repository_getter
-- 
GitLab