Commit 7c366c05 authored by Mikhail Sennikov's avatar Mikhail Sennikov
Browse files

Fix warmup

parent 5b315a8d
No related merge requests found
Showing with 7 additions and 10 deletions
+7 -10
......@@ -19,7 +19,7 @@ class NodeSwarm:
self._connections: Dict[int, NodeConnection] = {}
async def _load_from_db(self):
node_repo = NodeRepository(await db.get_db())
node_repo = NodeRepository(await anext(db.get_db()))
nodes = await node_repo.list_all()
self._connections: Dict[int, NodeConnection] = \
{int(getattr(node, 'id')): NodeConnection(node) for node in nodes}
......@@ -34,7 +34,7 @@ class NodeSwarm:
await self._try_connect_all()
async def renew_nodes_list(self):
node_repo = NodeRepository(await db.get_db())
node_repo = NodeRepository(await anext(db.get_db()))
nodes = await node_repo.list_all()
for node in nodes:
node_id = int(node.id)
......@@ -61,7 +61,7 @@ class NodeSwarm:
print(f"Failed to connect to node {node.id}: {e}")
async def set_group_by_id(self, group_id: int) -> Group:
session = await db.get_db()
session = await anext(db.get_db())
group_repo = GroupRepository(session)
group = await group_repo.get(group_id)
if group is None:
......@@ -72,7 +72,7 @@ class NodeSwarm:
return group
async def set_group_by_slug(self, group_slug: str) -> Group:
session = await db.get_db()
session = await anext(db.get_db())
group_repo = GroupRepository(session)
group = await group_repo.get_by_slug(group_slug)
if group is None:
......
......@@ -36,7 +36,7 @@ class HotCams:
return _cam_id
async def try_connect(self, cam_id: int):
cam_repo = CamRepository(await db.get_db())
cam_repo = CamRepository(await anext(db.get_db()))
cam_db_obj = await cam_repo.get(cam_id)
if cam_db_obj is None:
raise HTTPException(
......
......@@ -24,7 +24,7 @@ class AsyncDBConnection:
async with self._engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all)
async def _get_db(self) -> AsyncGenerator:
async def get_db(self) -> AsyncGenerator:
db = None
try:
db = self._session_maker()
......@@ -33,14 +33,11 @@ class AsyncDBConnection:
if isinstance(db, AsyncSession):
await db.close()
async def get_db(self) -> AsyncSession | None:
return await anext(self._get_db())
def get_repository(
self,
repository: Type[Repository],
) -> Any:
def repository_getter(db: AsyncSession = Depends(self._get_db)) -> Repository:
def repository_getter(db: AsyncSession = Depends(self.get_db)) -> Repository:
return repository(db)
return repository_getter
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment