Source

Target

Commits (4)
Showing with 75 additions and 55 deletions
+75 -55
......@@ -19,10 +19,11 @@ class NodeSwarm:
self._connections: Dict[int, NodeConnection] = {}
async def _load_from_db(self):
node_repo = NodeRepository(await db.get_db())
nodes = await node_repo.list_all()
self._connections: Dict[int, NodeConnection] = \
{int(getattr(node, 'id')): NodeConnection(node) for node in nodes}
async with db.get_session() as session:
node_repo = NodeRepository(session)
nodes = await node_repo.list_all()
self._connections: Dict[int, NodeConnection] = \
{int(getattr(node, 'id')): NodeConnection(node) for node in nodes}
async def _try_connect_all(self):
for _, node in self._connections.items():
......@@ -34,13 +35,14 @@ class NodeSwarm:
await self._try_connect_all()
async def renew_nodes_list(self):
node_repo = NodeRepository(await db.get_db())
nodes = await node_repo.list_all()
for node in nodes:
node_id = int(node.id)
if node_id not in self._connections:
self._connections[node_id] = NodeConnection(node)
asyncio.create_task(node.connect())
async with db.get_session() as session:
node_repo = NodeRepository(session)
nodes = await node_repo.list_all()
for node in nodes:
node_id = int(node.id)
if node_id not in self._connections:
self._connections[node_id] = NodeConnection(node)
asyncio.create_task(node.connect())
async def _set_group(self, group: Group, session: AsyncSession):
node_repo = NodeRepository(session)
......@@ -61,26 +63,26 @@ class NodeSwarm:
print(f"Failed to connect to node {node.id}: {e}")
async def set_group_by_id(self, group_id: int) -> Group:
session = await db.get_db()
group_repo = GroupRepository(session)
group = await group_repo.get(group_id)
if group is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
await self._set_group(group, session)
return group
async with db.get_session() as session:
group_repo = GroupRepository(session)
group = await group_repo.get(group_id)
if group is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
await self._set_group(group, session)
return group
async def set_group_by_slug(self, group_slug: str) -> Group:
session = await db.get_db()
group_repo = GroupRepository(session)
group = await group_repo.get_by_slug(group_slug)
if group is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
await self._set_group(group, session)
return group
async with db.get_session() as session:
group_repo = GroupRepository(session)
group = await group_repo.get_by_slug(group_slug)
if group is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
await self._set_group(group, session)
return group
def get_connection(self, node_id: int) -> NodeConnection | None:
return self._connections.get(node_id, None)
......
......@@ -3,6 +3,7 @@ import asyncio
from asyncio import Task
from fastapi import FastAPI
from node_settings.node_settings import node_router
from onvif_proxy.onvif_proxy import onvif_router
from rtsp.rtsp_placeholder import start_rtsp_server
from node_switch.proxy_switch import switch_router
......@@ -21,6 +22,12 @@ app.include_router(
tags=["Proxy switch"],
)
app.include_router(
node_router,
prefix="/node",
tags=["Proxy node settings"],
)
rtsp_task = None
......
......@@ -13,7 +13,8 @@ from onvif_proxy.auth_utils import make_digest
# noinspection HttpUrlsUsage
class Camera:
def __init__(self, name, host, port, username, password: str | bytes):
def __init__(self, cam_id, name, host, port, username, password: str | bytes):
self.id = cam_id
self.name = name
self.host = host
self.port = port
......
......@@ -16,12 +16,12 @@ class HotCams:
class CamConnection:
def __init__(self):
self._connect_task = None
self.camera = None
self.camera: Camera | None = None
async def _create_camera(self, _cam_id, kwargs):
async def _create_camera(self, _cam_id, **kwargs):
while True:
try:
cam = Camera(**kwargs)
cam = Camera(_cam_id, **kwargs)
await cam.update()
break
except (ONVIFError, httpx.ConnectTimeout, httpx.ConnectError) as e:
......@@ -36,16 +36,17 @@ class HotCams:
return _cam_id
async def try_connect(self, cam_id: int):
cam_repo = CamRepository(await db.get_db())
cam_db_obj = await cam_repo.get(cam_id)
if cam_db_obj is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
params = cam_db_obj.to_dict()
del params['id']
self._connect_task = asyncio.create_task(self._create_camera(cam_id, **params))
async with db.get_session() as session:
cam_repo = CamRepository(session)
cam_db_obj = await cam_repo.get(cam_id)
if cam_db_obj is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND
)
params = cam_db_obj.to_dict()
del params['id']
self._connect_task = asyncio.create_task(self._create_camera(cam_id, **params))
async def a_cancel(self):
if isinstance(self._connect_task, Task):
......
......@@ -10,6 +10,19 @@ from database.repositories.base import Repository
class AsyncDBConnection:
class WithSession:
def __init__(self, session_maker):
self._maker = session_maker
self._session = None
async def __aenter__(self):
self._session = self._maker()
return self._session
async def __aexit__(self, exc_type, exc, tb):
if isinstance(self._session, AsyncSession):
await self._session.close()
def __init__(self, url: str):
self._engine = create_async_engine(url, echo=False)
# noinspection PyTypeChecker
......@@ -33,8 +46,8 @@ class AsyncDBConnection:
if isinstance(db, AsyncSession):
await db.close()
async def get_db(self) -> AsyncSession | None:
return await anext(self._get_db())
def get_session(self):
return self.WithSession(self._session_maker)
def get_repository(
self,
......
from typing import Any, Dict
from sqlalchemy.orm import declarative_base
Base = declarative_base()
class Base(declarative_base()):
def to_dict(self) -> Dict[str, Any]:
result: Dict[str, Any] = {}
for column in self.__table__.columns:
result[column.name] = getattr(self, column.name)
return result
class SerializerMixin:
def to_dict(self):
return {column.name: getattr(self, column.name) for column in self.__table__.columns}
from sqlalchemy import Column, Integer, String, ForeignKey,\
Index, UniqueConstraint
from .base import Base
from .base import Base, SerializerMixin
class Cam(Base):
class Cam(Base, SerializerMixin):
__tablename__ = 'cams'
id: Column = Column(Integer, primary_key=True) # Will make SERIAL
......