swarm.py 3.26 KiB
import asyncio

import httpx
from fastapi import HTTPException, Security
from typing import Dict

from sqlalchemy.ext.asyncio import AsyncSession
from starlette import status

from database import db_connection as db
from database.models.cams import Group
from database.repositories.cam_groups import GroupRepository
from database.repositories.nodes import NodeRepository
from .node import NodeConnection


class NodeSwarm:
    """Registry of ``NodeConnection`` objects keyed by node id.

    Loads node records through ``NodeRepository``, wraps each one in a
    ``NodeConnection``, starts their ``connect()`` coroutines as background
    tasks, and forwards camera-group selections to the connected nodes.
    """

    def __init__(self):
        # node id -> connection wrapper for that node.
        self._connections: Dict[int, NodeConnection] = {}
        # Strong references to background tasks started by this swarm. The
        # event loop only holds weak references to tasks, so a task nobody
        # else references may be garbage-collected before it finishes.
        self._tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a background task and keep it alive until done."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def _load_from_db(self):
        """Replace the connection table with one built from every DB node.

        Existing connections are dropped without being cancelled; callers
        (``connect_all``) cancel them first.
        """
        # NOTE(review): the session from db.get_db() is never closed here —
        # confirm whether get_db() returns a managed/shared session.
        node_repo = NodeRepository(await db.get_db())
        nodes = await node_repo.list_all()
        self._connections = {int(node.id): NodeConnection(node) for node in nodes}

    async def _try_connect_all(self):
        """Start ``connect()`` on every tracked connection without awaiting it."""
        for connection in self._connections.values():
            self._spawn(connection.connect())

    async def connect_all(self):
        """Cancel current connections, reload nodes from the DB, reconnect all."""
        await self.a_cancel()
        await self._load_from_db()
        await self._try_connect_all()

    async def renew_nodes_list(self):
        """Track and start connections for DB nodes not yet known to the swarm.

        Already-tracked nodes are left untouched, and nodes that disappeared
        from the DB are not removed.
        """
        node_repo = NodeRepository(await db.get_db())
        nodes = await node_repo.list_all()
        for node in nodes:
            node_id = int(node.id)
            if node_id in self._connections:
                continue
            connection = NodeConnection(node)
            self._connections[node_id] = connection
            # Connect the wrapper, not the DB record it was built from.
            self._spawn(connection.connect())

    async def _set_group(self, group: Group, session: AsyncSession):
        """Ask each node in *group* to switch to its camera for that group.

        Nodes that are not tracked or not connected are skipped with a
        message; connection failures to a node are reported and skipped so
        the remaining nodes are still updated.
        """
        node_repo = NodeRepository(session)
        pairs = await node_repo.get_node_cam_pairs(group)
        for node_id, cam_id in pairs:
            node = self._connections.get(node_id)
            if node is None:
                print(f"Node {node_id} not configured. Try to renew nodes list.")
                continue

            if not node.connected:
                print(f"No connection to node {node_id} or not ready yet.")
                continue

            try:
                await node.choose_camera(cam_id)
            except (httpx.ConnectTimeout, httpx.ConnectError) as e:
                print(f"Failed to connect to node {node_id}: {e}")

    async def _apply_group(self, group: Group | None, session: AsyncSession) -> Group:
        """Raise HTTP 404 if *group* is missing, otherwise apply it to the nodes."""
        if group is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND
            )
        await self._set_group(group, session)
        return group

    async def set_group_by_id(self, group_id: int) -> Group:
        """Look up a group by primary key and apply it to the nodes.

        Raises:
            HTTPException: 404 if no group has id *group_id*.
        """
        session = await db.get_db()
        group = await GroupRepository(session).get(group_id)
        return await self._apply_group(group, session)

    async def set_group_by_slug(self, group_slug: str) -> Group:
        """Look up a group by slug and apply it to the nodes.

        Raises:
            HTTPException: 404 if no group has slug *group_slug*.
        """
        session = await db.get_db()
        group = await GroupRepository(session).get_by_slug(group_slug)
        return await self._apply_group(group, session)

    def get_connection(self, node_id: int) -> NodeConnection | None:
        """Return the tracked connection for *node_id*, or ``None``."""
        return self._connections.get(node_id)

    async def a_cancel(self):
        """Await ``a_cancel()`` on every tracked connection, one at a time."""
        for connection in self._connections.values():
            await connection.a_cancel()

    def cancel(self):
        """Call the synchronous ``cancel()`` on every tracked connection."""
        for connection in self._connections.values():
            connection.cancel()

    def __del__(self):
        # getattr guards against an instance whose __init__ never ran.
        if getattr(self, "_connections", None):
            self.cancel()