diff --git a/README.md b/README.md index 69545d2..6946204 100644 --- a/README.md +++ b/README.md @@ -46,4 +46,35 @@ from announce_server import register_service def your_function(): pass +``` + +## Registry + +The `announce_server` CLI provides a simple way to start a registry server. The registry server keeps track of available services and periodically sends heartbeat messages to ensure that registered services are still active. + +### Command + +```bash +announce_server start_registry [--address ADDRESS] [--port PORT] [--heartbeat_interval INTERVAL] [--heartbeat_timeout TIMEOUT] +``` + +### Arguments + +- `--address ADDRESS`: The IP address of the server. Default: `0.0.0.0`. +- `--port PORT`: The port number of the server. Default: `4999`. +- `--heartbeat_interval INTERVAL`: The interval between heartbeat messages in seconds. Default: `5`. +- `--heartbeat_timeout TIMEOUT`: The timeout for waiting for a response in seconds. Default: `3`. + +### Example + +To start the registry server with the default configuration, run: + +```bash +announce_server start_registry +``` + +To start the registry server with a custom IP address, port number, heartbeat interval, and timeout, run: + +```bash +announce_server start_registry --address localhost --port 4998 --heartbeat_interval 10 --heartbeat_timeout 5 ``` \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 0e6bd7b..5f9f862 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,17 @@ pub = setuptools_scm twine +api = + aiohttp + +[options.entry_points] +console_scripts = + announce_server = announce_server.__main__:main + + +[tool:pytest] +addopts = --cov --cov-report term-missing + [coverage:run] source = announce_server branch = True @@ -53,5 +64,3 @@ show_missing = True exclude_lines = if __name__ == .__main__.: -[tool:pytest] -addopts = --cov --cov-report term-missing \ No newline at end of file diff --git a/src/announce_server/__main__.py b/src/announce_server/__main__.py new file mode 100644 index 0000000..92302da --- /dev/null +++ b/src/announce_server/__main__.py @@ -0,0 +1,48 @@ +import argparse + +from announce_server import register_service +from announce_server.server import start_server + + +def main(): + parser = argparse.ArgumentParser(description="Announce server CLI") + subparsers = parser.add_subparsers(dest="command", help="Available subcommands") + + # Start registry subcommand + start_registry_parser = subparsers.add_parser( + "start_registry", help="Start the registry server" + ) + start_registry_parser.add_argument( + "--ip", default="0.0.0.0", help="IP address of the host server" + ) + start_registry_parser.add_argument( + "--port", default=4999, type=int, help="Port of the host server" + ) + start_registry_parser.add_argument( + "--heartbeat-interval", + default=5, + type=float, + help="Heartbeat interval in seconds", + ) + start_registry_parser.add_argument( + "--heartbeat-timeout", + default=3, + type=float, + help="Heartbeat timeout in seconds", + ) + + args = parser.parse_args() + + if args.command == "start_registry": + start_server( + address=args.ip, + port=args.port, + heartbeat_interval=args.heartbeat_interval, + heartbeat_timeout=args.heartbeat_timeout, + ) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/src/announce_server/server.py b/src/announce_server/server.py new file mode 100644 index 0000000..8d6ba35 --- /dev/null +++ b/src/announce_server/server.py @@ -0,0 +1,164 @@ +import asyncio +import signal + +import socketio +from aiohttp import web + +sio = socketio.AsyncServer(async_mode="aiohttp") +app = web.Application() +sio.attach(app) + +servers = {} + + +async def available(request): + """ + Return a JSON response containing the available servers. + + Returns + ------- + aiohttp.web.Response + JSON response containing the available servers. + """ + return web.json_response(servers) + + +app.router.add_get("/available", available) + + +@sio.event +async def connect(sid, environ): + """Handle a new connection to the socket.""" + print("Connected:", sid) + + +@sio.event +async def register(sid, data): + """ + Register a new server. + + Parameters + ---------- + sid : str + Socket ID of the connected server. + data : dict + Server information (name, IP, and port). + """ + server_info = data + name = server_info["name"] + + servers[name] = {"ip": server_info["ip"], "port": server_info["port"], "sid": sid} + print(servers) + + +@sio.event +async def disconnect(sid): + """ + Handle a server disconnect. + + Parameters + ---------- + sid : str + Socket ID of the disconnected server. + """ + print("Disconnected from server:", sid) + for name, server in servers.items(): + if server["sid"] == sid: + del servers[name] + break + + +async def heartbeat(sio, interval, timeout): + """ + Periodically send heartbeat messages to connected servers. + + Parameters + ---------- + sio : socketio.AsyncServer + The socket.io server instance. + interval : int + The interval between heartbeat messages in seconds. + timeout : int + The timeout for waiting for a response in seconds. + """ + while True: + await asyncio.sleep(interval) + server_values_copy = list(servers.values()) + for server in server_values_copy: + sid = server["sid"] + try: + print(f"Sending heartbeat to {sid}...") + heartbeat_future = sio.emit("heartbeat", to=sid) + await asyncio.wait_for(heartbeat_future, timeout=timeout) + except (asyncio.TimeoutError, socketio.exceptions.TimeoutError): + print(f"Server {sid} failed to respond to heartbeat after {timeout}s.") + await sio.disconnect(sid) + + +def create_exit_handler(loop, heartbeat_task): + """ + Create an exit handler for gracefully shutting down the server. + + Parameters + ---------- + loop : asyncio.AbstractEventLoop + The event loop. + heartbeat_task : asyncio.Task + The heartbeat task. + + Returns + ------- + Callable + An asynchronous exit handler function. + """ + + async def exit_handler(sig, frame): + print("Shutting down host...") + heartbeat_task.cancel() + await loop.shutdown_asyncgens() + loop.stop() + + return exit_handler + + +def start_server(address, port, heartbeat_interval, heartbeat_timeout): + """ + Run the main server loop. + + Parameters + ---------- + address : str + The IP address of the server. + port : int + The port number of the server. + heartbeat_interval : int + The interval between heartbeat messages in seconds. + heartbeat_timeout : int + The timeout for waiting for a response in seconds. + """ + loop = asyncio.get_event_loop() + heartbeat_task = loop.create_task( + heartbeat(sio, heartbeat_interval, heartbeat_timeout) + ) + aiohttp_app = loop.create_task(web._run_app(app, host=address, port=port)) + + exit_handler = create_exit_handler(loop, heartbeat_task) + signal.signal(signal.SIGINT, exit_handler) + signal.signal(signal.SIGTERM, exit_handler) + + try: + loop.run_until_complete(asyncio.gather(heartbeat_task, aiohttp_app)) + except asyncio.CancelledError: + pass + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.stop() + + +if __name__ == "__main__": + start_server( + address="0.0.0.0", + port=4999, + heartbeat_interval=5, + heartbeat_timeout=3, + )