"""Command node for the ZMQ runner (module ``noob.runner.zmq.command``)."""

import asyncio
import threading
from time import monotonic, time
from typing import cast

import zmq

from noob import init_logger
from noob.config import config
from noob.network.loop import EventloopMixin
from noob.network.message import (
    AnnounceMsg,
    AnnounceValue,
    DeinitMsg,
    EpochEndedMsg,
    IdentifyMsg,
    IdentifyValue,
    Message,
    MessageType,
    PingMsg,
    ProcessMsg,
    StartMsg,
    StatusMsg,
    StopMsg,
)
from noob.types import Epoch, NodeID


class CommandNode(EventloopMixin):
    """
    Pub node that controls the state of the other nodes/announces addresses

    - one PUB socket to distribute commands
    - one ROUTER socket to receive return messages from runner nodes
    - one SUB socket to subscribe to all events

    The wrapping runner should register callbacks with `add_callback` to handle incoming messages.

    :meth:`run` is intended as a :class:`threading.Thread` target and drives the asyncio
    event loop. The synchronous command methods (:meth:`start`, :meth:`stop`,
    :meth:`process`, :meth:`deinit`) hand their work to that loop with
    ``loop.call_soon_threadsafe`` so they can be called from a different thread.
    """

    def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = None):
        """
        Args:
            runner_id (str): The unique ID for the runner/tube session.
                All nodes within a runner use this to limit communication within a tube
            protocol (str): Transport used to build socket addresses. Only ``"ipc"`` is
                implemented; any other value raises :class:`NotImplementedError` when an
                address is requested.
            port (int | None): Stored on the instance; not used when building
                ``"ipc"`` addresses.
        """
        self.runner_id = runner_id
        self.port = port
        self.protocol = protocol
        self.logger = init_logger(f"runner.node.{runner_id}.command")
        self._nodes: dict[str, IdentifyValue] = {}
        # Guards ``_nodes`` (mutated on the loop thread, read by ``await_ready`` on the
        # caller's thread) and wakes ``await_ready`` when node status changes.
        # Created here rather than in ``init`` so it is never ``None`` if
        # ``await_ready`` is entered before the loop thread has run ``init``.
        self._ready_condition: threading.Condition = threading.Condition()
        self._init = threading.Event()
        super().__init__()

    @property
    def pub_address(self) -> str:
        """Address the publisher bound to"""
        return self._ipc_address("outbox")

    @property
    def router_address(self) -> str:
        """Address the return router is bound to"""
        return self._ipc_address("inbox")

    def _ipc_address(self, name: str) -> str:
        """
        Build the endpoint ``ipc://<tmp_dir>/<runner_id>/command/<name>``,
        creating its parent directory if needed.

        Raises:
            NotImplementedError: if ``self.protocol`` is not ``"ipc"``.
        """
        if self.protocol != "ipc":
            raise NotImplementedError(
                f"Protocol {self.protocol!r} is not supported; only 'ipc' is implemented"
            )
        path = config.tmp_dir / f"{self.runner_id}/command/{name}"
        path.parent.mkdir(parents=True, exist_ok=True)
        return f"{self.protocol}://{path}"

    def run(self) -> None:
        """
        Target for :class:`threading.Thread`
        """
        asyncio.run(self._run())

    async def _run(self) -> None:
        """Initialize the loop and sockets, then hand control to ``_poll_receivers``."""
        self.init()
        await self._poll_receivers()

    def init(self) -> None:
        """Create the event loop and sockets, signalling completion via ``_init``."""
        self.logger.debug("Starting command runner")
        self._init.clear()
        self._init_loop()
        self._init_sockets()
        self._init.set()
        self.logger.debug("Command runner started")

    def deinit(self) -> None:
        """
        Broadcast a ``deinit`` command, then set ``_quitting``
        (the mixin's shutdown flag — presumably ends message polling).
        """
        self.logger.debug("Deinitializing")

        async def _deinit() -> None:
            msg = DeinitMsg(node_id="command")
            await self.sockets["outbox"].send_multipart([b"deinit", msg.to_bytes()])
            self._quitting.set()

        # ``loop.create_task`` is not thread-safe; schedule it on the loop's own thread.
        self.loop.call_soon_threadsafe(self.loop.create_task, _deinit())
        self.logger.debug("Queued loop for deinitialization")

    def stop(self) -> None:
        """Queue a ``stop`` command on the PUB socket via ``call_soon_threadsafe``."""
        self.logger.debug("Stopping command runner")
        msg = StopMsg(node_id="command")
        self.loop.call_soon_threadsafe(
            self.sockets["outbox"].send_multipart, [b"stop", msg.to_bytes()]
        )
        # NOTE: the stop message has only been queued on the loop at this point.
        self.logger.debug("Command runner stopped")

    def _init_sockets(self) -> None:
        """Create the PUB outbox, the ROUTER return inbox and the SUB event inbox."""
        self._init_outbox()
        self._init_router()
        self._init_inbox()

    def _init_outbox(self) -> None:
        """Create the main control publisher"""
        pub = self.context.socket(zmq.PUB)
        # ZMQ applies the identity only to subsequent bind/connect calls, so set it first.
        pub.setsockopt_string(zmq.IDENTITY, "command.outbox")
        pub.bind(self.pub_address)
        self.register_socket("outbox", pub)

    def _init_router(self) -> None:
        """Create the inbox router"""
        address = self.router_address
        router = self.context.socket(zmq.ROUTER)
        # ZMQ applies the identity only to subsequent bind/connect calls, so set it first.
        router.setsockopt_string(zmq.IDENTITY, "command.router")
        router.bind(address)
        self.register_socket("router", router, receiver=True)
        self.add_callback("router", self.on_router)
        self.logger.debug("Router bound to %s", address)

    def _init_inbox(self) -> None:
        """Subscriber that receives all events from running nodes"""
        sub = self.context.socket(zmq.SUB)
        sub.setsockopt_string(zmq.IDENTITY, "command.inbox")
        # empty prefix: subscribe to every topic
        sub.setsockopt_string(zmq.SUBSCRIBE, "")
        self.register_socket("inbox", sub, receiver=True)

    async def announce(self) -> None:
        """Publish the router address and the current node registry."""
        msg = AnnounceMsg(
            node_id="command", value=AnnounceValue(inbox=self.router_address, nodes=self._nodes)
        )
        await self.sockets["outbox"].send_multipart([b"announce", msg.to_bytes()])

    async def ping(self) -> None:
        """Send a ping message asking everyone to identify themselves"""
        msg = PingMsg(node_id="command")
        await self.sockets["outbox"].send_multipart([b"ping", msg.to_bytes()])

    def start(self, n: int | None = None) -> None:
        """
        Start running in free-run mode

        Args:
            n (int | None): Forwarded as the ``StartMsg`` value.
        """
        self.loop.call_soon_threadsafe(
            self.sockets["outbox"].send_multipart,
            [b"start", StartMsg(node_id="command", value=n).to_bytes()],
        )
        self.logger.debug("Sent start message")

    def process(self, epoch: Epoch, input: dict | None = None) -> None:
        """Emit a ProcessMsg to process a single round through the graph"""
        # no empty dicts
        input = input if input else None
        self.loop.call_soon_threadsafe(
            self.sockets["outbox"].send_multipart,
            [
                b"process",
                ProcessMsg(node_id="command", value={"input": input, "epoch": epoch}).to_bytes(),
            ],
        )
        self.logger.debug("Sent process message")

    async def epoch_ended(self, epoch: Epoch) -> None:
        """Publish an ``epoch_ended`` message for ``epoch``."""
        await self.sockets["outbox"].send_multipart(
            [b"epoch_ended", EpochEndedMsg(node_id="command", value=epoch).to_bytes()]
        )

    def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> None:
        """
        Wait until all the node_ids have announced themselves

        While waiting, a ping is scheduled roughly once per second so that slow
        subscribers get another chance to identify.

        Args:
            node_ids: Nodes whose status must be ``"ready"``.
            timeout: Total number of seconds to wait.

        Raises:
            TimeoutError: if any of ``node_ids`` is not ready after ``timeout`` seconds.
        """
        waiting_for = set(node_ids)

        def _ready_nodes() -> set[str]:
            return {node_id for node_id, state in self._nodes.items() if state["status"] == "ready"}

        def _is_ready() -> bool:
            ready_nodes = _ready_nodes()
            self.logger.debug(
                "Checking if ready, ready nodes are: %s, waiting for %s",
                ready_nodes,
                waiting_for,
            )
            return waiting_for.issubset(ready_nodes)

        # monotonic clock so wall-clock adjustments cannot stretch or cut the timeout
        deadline = monotonic() + timeout
        with self._ready_condition:
            while True:
                remaining = deadline - monotonic()
                # wake at least once a second to re-ping, but never wait past the deadline;
                # a zero timeout still evaluates the predicate once
                ready = self._ready_condition.wait_for(
                    _is_ready, timeout=max(0.0, min(1.0, remaining))
                )
                if ready or monotonic() >= deadline:
                    break
                # ping periodically for identifications in case we have slow subscribers
                self.loop.call_soon_threadsafe(self.loop.create_task, self.ping())

            if not ready:
                raise TimeoutError(
                    f"Nodes were not ready after the timeout. "
                    f"Waiting for: {waiting_for}, "
                    f"ready: {_ready_nodes()}"
                )

    async def on_router(self, message: Message) -> None:
        """Dispatch a message received on the ROUTER socket by its type."""
        self.logger.debug("Received ROUTER message %s", message)
        if message.type_ == MessageType.identify:
            await self.on_identify(cast(IdentifyMsg, message))
        elif message.type_ == MessageType.status:
            await self.on_status(cast(StatusMsg, message))
        else:
            self.logger.debug("Ignoring ROUTER message with unhandled type %s", message.type_)

    async def on_identify(self, msg: IdentifyMsg) -> None:
        """
        Register a node, subscribe to its outbox, re-announce, and wake ``await_ready``.

        Announce failures are logged rather than raised.
        """
        # mutate under the lock: await_ready iterates ``_nodes`` from another thread
        with self._ready_condition:
            self._nodes[msg.node_id] = msg.value
        self.sockets["inbox"].connect(msg.value["outbox"])
        try:
            await self.announce()
            self.logger.debug("Announced")
        except Exception as e:
            self.logger.exception("Exception announced: %s", e)
        with self._ready_condition:
            self._ready_condition.notify_all()

    async def on_status(self, msg: StatusMsg) -> None:
        """Record a node's new status and wake ``await_ready``; ignore unknown nodes."""
        if msg.node_id not in self._nodes:
            self.logger.warning(
                "Node %s sent us a status before sending its full identify message, ignoring",
                msg.node_id,
            )
            return
        # mutate under the lock: await_ready iterates ``_nodes`` from another thread
        with self._ready_condition:
            self._nodes[msg.node_id]["status"] = msg.value
            self._ready_condition.notify_all()