Michael Pilosov
2 years ago
25 changed files with 1496 additions and 4 deletions
@ -0,0 +1,57 @@ |
|||||
|
import time |
||||
|
from eden.client import Client |
||||
|
from eden.datatypes import Image |
||||
|
|
||||
|
import subprocess |
||||
|
import socket |
||||
|
|
||||
|
# Get IP address of eden-server service |
||||
|
hostname = 'eden-server' |
||||
|
port = 5656 |
||||
|
network_name = 'eden-network' |
||||
|
import docker |
||||
|
client = docker.from_env() |
||||
|
project_name = 'not_so_minimal' |
||||
|
container_name = f'{project_name}_{hostname}_1' |
||||
|
container = client.containers.get(container_name) |
||||
|
ip_address = container.attrs['NetworkSettings']['Networks'][network_name]['IPAddress'] |
||||
|
print(ip_address) |
||||
|
url = f"http://{ip_address}:{port}" |
||||
|
|
||||
|
## set up a client |
||||
|
c = Client(url=url, username="abraham") |
||||
|
|
||||
|
# get server's identity |
||||
|
generator_id = c.get_generator_identity() |
||||
|
print(generator_id) |
||||
|
|
||||
|
## define input args to be sent |
||||
|
config = { |
||||
|
"width": 2000, ## width |
||||
|
"height": 1000, ## height |
||||
|
"input_image": Image( |
||||
|
"/home/mm/Downloads/FF06F0EC-1B54-458A-BF12-FF7FC2A43C10.jpeg" |
||||
|
), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
# start the task |
||||
|
run_response = c.run(config) |
||||
|
|
||||
|
print("Intitial response") |
||||
|
# check status of the task, returns the output too if the task is complete |
||||
|
results = c.fetch(token=run_response["token"]) |
||||
|
print(results) |
||||
|
|
||||
|
# one eternity later |
||||
|
# time.sleep(5) |
||||
|
|
||||
|
print("Trying") |
||||
|
while results["status"].get("status") != "complete": |
||||
|
results = c.fetch(token=run_response["token"]) |
||||
|
print(results) |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
## check status again, hopefully the task is complete by now |
||||
|
# results = c.fetch(token=run_response["token"]) |
||||
|
# print(results) |
||||
|
# results['output']['image'].show() |
@ -0,0 +1,41 @@ |
|||||
|
import time |
||||
|
from eden.client import Client |
||||
|
from eden.datatypes import Image |
||||
|
|
||||
|
## set up a client |
||||
|
c = Client(url="http://0.0.0.0:5656", username="abraham") |
||||
|
|
||||
|
# get server's identity |
||||
|
generator_id = c.get_generator_identity() |
||||
|
print(generator_id) |
||||
|
|
||||
|
## define input args to be sent |
||||
|
config = { |
||||
|
"width": 2000, ## width |
||||
|
"height": 1000, ## height |
||||
|
"input_image": Image( |
||||
|
"/home/mm/Downloads/FF06F0EC-1B54-458A-BF12-FF7FC2A43C10.jpeg" |
||||
|
), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
# start the task |
||||
|
run_response = c.run(config) |
||||
|
|
||||
|
print("Intitial response") |
||||
|
# check status of the task, returns the output too if the task is complete |
||||
|
results = c.fetch(token=run_response["token"]) |
||||
|
print(results) |
||||
|
|
||||
|
# one eternity later |
||||
|
# time.sleep(5) |
||||
|
|
||||
|
print("Trying") |
||||
|
while results["status"].get("status") != "complete": |
||||
|
results = c.fetch(token=run_response["token"]) |
||||
|
print(results) |
||||
|
time.sleep(0.1) |
||||
|
|
||||
|
## check status again, hopefully the task is complete by now |
||||
|
# results = c.fetch(token=run_response["token"]) |
||||
|
# print(results) |
||||
|
# results['output']['image'].show() |
@ -0,0 +1,47 @@ |
|||||
|
# docker-compose for redis service defined in ./redis |
||||
|
version: '3.7' |
||||
|
|
||||
|
services: |
||||
|
redis: |
||||
|
build: ./redis |
||||
|
image: redis |
||||
|
ports: |
||||
|
- "6379:6379" |
||||
|
volumes: |
||||
|
- ./data:/data |
||||
|
networks: |
||||
|
- default |
||||
|
|
||||
|
# eden server, started with python server.py, based on Dockerfile in cwd. |
||||
|
eden-server: |
||||
|
build: ./eden-server |
||||
|
image: eden-server |
||||
|
# ports: |
||||
|
# - "5656:5656" |
||||
|
volumes: |
||||
|
- /home/mm/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth:/root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth |
||||
|
networks: |
||||
|
- default |
||||
|
depends_on: |
||||
|
- redis |
||||
|
# pass nvidia gpu |
||||
|
runtime: nvidia |
||||
|
environment: |
||||
|
- CUDA_VISIBLE_DEVICES=0 |
||||
|
- NVIDIA_VISIBLE_DEVICES=0 |
||||
|
|
||||
|
# load-balancer: |
||||
|
# image: nginx |
||||
|
# ports: |
||||
|
# - "5656:80" |
||||
|
# volumes: |
||||
|
# - ./nginx.conf:/etc/nginx/nginx.conf:ro |
||||
|
# networks: |
||||
|
# - default |
||||
|
# depends_on: |
||||
|
# - eden-server |
||||
|
|
||||
|
networks: |
||||
|
default: |
||||
|
name: eden-network |
||||
|
external: true |
@ -0,0 +1,42 @@ |
|||||
|
FROM pytorch/pytorch:1.13.1-cuda11.6-cudnn8-runtime |
||||
|
|
||||
|
RUN apt-get update && apt-get install -y \ |
||||
|
libgl1-mesa-glx \ |
||||
|
libglib2.0-0 \ |
||||
|
&& rm -rf /var/lib/apt/lists/* |
||||
|
|
||||
|
|
||||
|
# until we hack around gitpython, we need git |
||||
|
# RUN apt-get update && apt-get install -y \ |
||||
|
# git \ |
||||
|
# && rm -rf /var/lib/apt/lists/* |
||||
|
|
||||
|
WORKDIR /app |
||||
|
# create a safe user |
||||
|
RUN useradd -ms /bin/bash eden |
||||
|
# make them own /app |
||||
|
RUN chown eden:eden /app |
||||
|
|
||||
|
USER eden |
||||
|
# add /home/eden/.local/bin to PATH |
||||
|
ENV PATH="/home/eden/.local/bin:${PATH}" |
||||
|
RUN pip install eden-python |
||||
|
RUN pip install python-socketio[asyncio_server] aiohttp |
||||
|
COPY server.py . |
||||
|
# attempted bugfix |
||||
|
COPY image_utils.py /home/eden/.local/lib/python3.10/site-packages/eden/image_utils.py |
||||
|
# attempt git-python hackaround |
||||
|
COPY hosting.py /home/eden/.local/lib/python3.10/site-packages/eden/hosting.py |
||||
|
|
||||
|
EXPOSE 5656 |
||||
|
# ENV GIT_PYTHON_REFRESH=quiet |
||||
|
# hack around gitpython |
||||
|
# RUN git init . |
||||
|
# RUN git config --global user.email "none@site.com" |
||||
|
# RUN git config --global user.name "eden-service-user" |
||||
|
# # add fake remote upstream |
||||
|
# RUN git remote add origin https://git.clfx.cc/mm/eden-app.git |
||||
|
# RUN git add server.py |
||||
|
# RUN git commit -am "initial commit" |
||||
|
ENV GIT_PYTHON_REFRESH=quiet |
||||
|
CMD ["python", "server.py"] |
@ -0,0 +1,55 @@ |
|||||
|
import asyncio |
||||
|
from functools import wraps |
||||
|
import socketio |
||||
|
import socket |
||||
|
|
||||
|
def get_ip_address(): |
||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
||||
|
try: |
||||
|
# This IP address doesn't need to be reachable, as we're only using it to find the local IP address |
||||
|
s.connect(("10.255.255.255", 1)) |
||||
|
ip = s.getsockname()[0] |
||||
|
except Exception: |
||||
|
ip = "127.0.0.1" |
||||
|
finally: |
||||
|
s.close() |
||||
|
return ip |
||||
|
|
||||
|
# Update these with the correct values for your host and server |
||||
|
HOST_SERVER_IP = "192.168.1.113" |
||||
|
HOST_SERVER_PORT = 4999 |
||||
|
SERVER_NAME = "server_1" |
||||
|
SERVER_IP = get_ip_address() |
||||
|
SERVER_PORT = 8000 |
||||
|
|
||||
|
sio = socketio.AsyncClient() |
||||
|
|
||||
|
async def announce_server(): |
||||
|
await sio.connect(f'http://{HOST_SERVER_IP}:{HOST_SERVER_PORT}') |
||||
|
await sio.emit('register', {'name': SERVER_NAME, 'ip': SERVER_IP, 'port': SERVER_PORT}) |
||||
|
|
||||
|
@sio.on("heartbeat") |
||||
|
async def on_heartbeat(): |
||||
|
print("Received heartbeat from host") |
||||
|
|
||||
|
@sio.event |
||||
|
async def disconnect(): |
||||
|
print("Disconnected from host") |
||||
|
|
||||
|
def announce_server_decorator(host_block_function): |
||||
|
@wraps(host_block_function) |
||||
|
def wrapper(*args, **kwargs): |
||||
|
loop = asyncio.get_event_loop() |
||||
|
|
||||
|
# Start the server announcement task |
||||
|
announce_task = loop.create_task(announce_server()) |
||||
|
|
||||
|
# Run the original host_block function |
||||
|
result = host_block_function(*args, **kwargs) |
||||
|
|
||||
|
# Cancel the announcement task after the host_block function is done |
||||
|
announce_task.cancel() |
||||
|
|
||||
|
return result |
||||
|
|
||||
|
return wrapper |
@ -0,0 +1,515 @@ |
|||||
|
import os |
||||
|
import git |
||||
|
import warnings |
||||
|
import uvicorn |
||||
|
import logging |
||||
|
from fastapi import FastAPI |
||||
|
from prometheus_client import Gauge |
||||
|
from starlette_exporter import PrometheusMiddleware, handle_metrics |
||||
|
from fastapi.middleware.cors import CORSMiddleware |
||||
|
|
||||
|
from .datatypes import Image |
||||
|
from .queue import QueueData |
||||
|
from .log_utils import Colors |
||||
|
from .models import Credentials, WaitFor |
||||
|
from .result_storage import ResultStorage |
||||
|
from .config_wrapper import ConfigWrapper |
||||
|
from .data_handlers import Encoder, Decoder |
||||
|
from .threaded_server import ThreadedServer |
||||
|
from .progress_tracker import fetch_progress_from_token |
||||
|
from .log_utils import log_levels, celery_log_levels, PREFIX |
||||
|
from .prometheus_utils import PrometheusMetrics |
||||
|
|
||||
|
from .utils import stop_everything_gracefully, generate_random_string |
||||
|
|
||||
|
from uvicorn.config import LOGGING_CONFIG |
||||
|
|
||||
|
""" |
||||
|
Celery+redis is needed to be able to queue tasks |
||||
|
""" |
||||
|
from celery import Celery |
||||
|
from .celery_utils import run_celery_app |
||||
|
|
||||
|
""" |
||||
|
tool to allocate gpus on queued tasks |
||||
|
""" |
||||
|
from .gpu_allocator import GPUAllocator |
||||
|
|
||||
|
|
||||
|
def host_block( |
||||
|
block, |
||||
|
port=8080, |
||||
|
host="0.0.0.0", |
||||
|
max_num_workers=4, |
||||
|
redis_port=6379, |
||||
|
redis_host="localhost", |
||||
|
requires_gpu=True, |
||||
|
log_level="warning", |
||||
|
logfile="logs.log", |
||||
|
exclude_gpu_ids: list = [], |
||||
|
remove_result_on_fetch = False |
||||
|
): |
||||
|
""" |
||||
|
Use this to host your eden.Block on a server. Supports multiple GPUs and queues tasks automatically with celery. |
||||
|
|
||||
|
Args: |
||||
|
block (eden.block.Block): The eden block you'd want to host. |
||||
|
port (int, optional): Localhost port where the block would be hosted. Defaults to 8080. |
||||
|
host (str): specifies where the endpoint would be hosted. Defaults to '0.0.0.0'. |
||||
|
max_num_workers (int, optional): Maximum number of tasks to run in parallel. Defaults to 4. |
||||
|
redis_port (int, optional): Port number for celery's redis server. Defaults to 6379. |
||||
|
redis_host (str, optional): Place to host redis for `eden.queue.QueueData`. Defaults to localhost. |
||||
|
requires_gpu (bool, optional): Set this to False if your tasks dont necessarily need GPUs. |
||||
|
log_level (str, optional): Can be 'debug', 'info', or 'warning'. Defaults to 'warning' |
||||
|
logfile(str, optional): Name of the file where the logs would be stored. If set to None, it will show all logs on stdout. Defaults to 'logs.log' |
||||
|
exclude_gpu_ids (list, optional): List of gpu ids to not use for hosting. Example: [2,3] |
||||
|
""" |
||||
|
|
||||
|
""" |
||||
|
Response templates: |
||||
|
|
||||
|
/run: |
||||
|
{ |
||||
|
'token': some_long_token, |
||||
|
} |
||||
|
|
||||
|
/fetch: |
||||
|
if task is queued: |
||||
|
{ |
||||
|
'status': { |
||||
|
'status': queued, |
||||
|
'queue_position': int |
||||
|
}, |
||||
|
config: current_config |
||||
|
} |
||||
|
|
||||
|
elif task is running: |
||||
|
{ |
||||
|
'status': { |
||||
|
'status': 'running', |
||||
|
'progress': float between 0 and 1, |
||||
|
|
||||
|
}, |
||||
|
config: current_config, |
||||
|
'output': {} ## optionally the user should be able to write outputs here |
||||
|
} |
||||
|
elif task failed: |
||||
|
{ |
||||
|
'status': { |
||||
|
'status': 'failed', |
||||
|
} |
||||
|
'config': current_config, |
||||
|
'output': {} ## will still include the outputs if any so that it gets returned even though the task failed |
||||
|
} |
||||
|
elif task succeeded: |
||||
|
{ |
||||
|
'status': { |
||||
|
'status': 'complete' |
||||
|
}, |
||||
|
'output': user_output, |
||||
|
'config': config |
||||
|
} |
||||
|
""" |
||||
|
|
||||
|
""" |
||||
|
Initiating celery app |
||||
|
""" |
||||
|
celery_app = Celery(__name__, broker=f"redis://{redis_host}:{str(redis_port)}") |
||||
|
celery_app.conf.broker_url = os.environ.get( |
||||
|
"CELERY_BROKER_URL", f"redis://{redis_host}:{str(redis_port)}" |
||||
|
) |
||||
|
celery_app.conf.result_backend = os.environ.get( |
||||
|
"CELERY_RESULT_BACKEND", f"redis://{redis_host}:{str(redis_port)}" |
||||
|
) |
||||
|
celery_app.conf.task_track_started = os.environ.get( |
||||
|
"CELERY_TRACK_STARTED", default=True |
||||
|
) |
||||
|
|
||||
|
celery_app.conf.worker_send_task_events = True |
||||
|
celery_app.conf.task_send_sent_event = True |
||||
|
|
||||
|
""" |
||||
|
each block gets its wown queue |
||||
|
""" |
||||
|
celery_app.conf.task_default_queue = block.name |
||||
|
|
||||
|
""" |
||||
|
set prefetch mult to 1 so that tasks dont get pre-fetched by workers |
||||
|
""" |
||||
|
celery_app.conf.worker_prefetch_multiplier = 1 |
||||
|
|
||||
|
""" |
||||
|
task messages will be acknowledged after the task has been executed |
||||
|
""" |
||||
|
celery_app.conf.task_acks_late = True |
||||
|
|
||||
|
""" |
||||
|
Initiating GPUAllocator only if requires_gpu is True |
||||
|
""" |
||||
|
if requires_gpu == True: |
||||
|
gpu_allocator = GPUAllocator(exclude_gpu_ids=exclude_gpu_ids) |
||||
|
else: |
||||
|
print(PREFIX + " Initiating server with no GPUs since requires_gpu = False") |
||||
|
|
||||
|
if requires_gpu == True: |
||||
|
if gpu_allocator.num_gpus < max_num_workers: |
||||
|
""" |
||||
|
if a task requires a gpu, and the number of workers is > the number of available gpus, |
||||
|
then max_num_workers is automatically set to the number of gpus available |
||||
|
this is because eden assumes that each task requires one gpu (all of it) |
||||
|
""" |
||||
|
warnings.warn( |
||||
|
"max_num_workers is greater than the number of GPUs found, overriding max_num_workers to be: " |
||||
|
+ str(gpu_allocator.num_gpus) |
||||
|
) |
||||
|
max_num_workers = gpu_allocator.num_gpus |
||||
|
|
||||
|
""" |
||||
|
Initiating queue data to keep track of the queue |
||||
|
""" |
||||
|
queue_data = QueueData( |
||||
|
redis_port=redis_port, redis_host=redis_host, queue_name=block.name |
||||
|
) |
||||
|
|
||||
|
""" |
||||
|
Initiate encoder and decoder |
||||
|
""" |
||||
|
|
||||
|
data_encoder = Encoder() |
||||
|
data_decoder = Decoder() |
||||
|
|
||||
|
""" |
||||
|
Initiate fastAPI app |
||||
|
""" |
||||
|
app = FastAPI() |
||||
|
origins = ["*"] |
||||
|
app.add_middleware( |
||||
|
CORSMiddleware, |
||||
|
allow_origins=origins, |
||||
|
allow_credentials=True, |
||||
|
allow_methods=["*"], |
||||
|
allow_headers=["*"], |
||||
|
) |
||||
|
app.add_middleware(PrometheusMiddleware) |
||||
|
app.add_route("/metrics", handle_metrics) |
||||
|
|
||||
|
""" |
||||
|
Initiate result storage on redis |
||||
|
""" |
||||
|
|
||||
|
result_storage = ResultStorage( |
||||
|
redis_host=redis_host, |
||||
|
redis_port=redis_port, |
||||
|
) |
||||
|
|
||||
|
## set up result storage and data encoder for block |
||||
|
block.result_storage = result_storage |
||||
|
block.data_encoder = data_encoder |
||||
|
|
||||
|
""" |
||||
|
initiate a wrapper which handles 4 metrics for prometheus: |
||||
|
* number of queued jobs |
||||
|
* number of running jobs |
||||
|
* number of failed jobs |
||||
|
* number of succeeded jobs |
||||
|
""" |
||||
|
prometheus_metrics = PrometheusMetrics() |
||||
|
|
||||
|
""" |
||||
|
define celery task |
||||
|
""" |
||||
|
|
||||
|
@celery_app.task(name="run") |
||||
|
def run(args, token: str): |
||||
|
|
||||
|
## job moves from queue to running |
||||
|
prometheus_metrics.queued.dec(1) |
||||
|
prometheus_metrics.running.inc(1) |
||||
|
|
||||
|
args = data_decoder.decode(args) |
||||
|
""" |
||||
|
allocating a GPU ID to the tast based on usage |
||||
|
for now let's settle for max 1 GPU per task :( |
||||
|
""" |
||||
|
|
||||
|
if requires_gpu == True: |
||||
|
# returns None if there are no gpus available |
||||
|
gpu_name = gpu_allocator.get_gpu() |
||||
|
else: |
||||
|
gpu_name = None ## default value either if there are no gpus available or requires_gpu = False |
||||
|
|
||||
|
""" |
||||
|
If there are no GPUs available, then it returns a sad message. |
||||
|
But if there ARE GPUs available, then it starts run() |
||||
|
""" |
||||
|
if ( |
||||
|
gpu_name == None and requires_gpu == True |
||||
|
): ## making sure there are no gpus available |
||||
|
|
||||
|
status = { |
||||
|
"status": "No GPUs are available at the moment, please try again later", |
||||
|
} |
||||
|
|
||||
|
else: |
||||
|
|
||||
|
""" |
||||
|
refer: |
||||
|
https://github.com/abraham-ai/eden/issues/14 |
||||
|
""" |
||||
|
args = ConfigWrapper( |
||||
|
data=args, |
||||
|
token=token, |
||||
|
result_storage=result_storage, |
||||
|
gpu=None, ## will be provided later on in the run |
||||
|
progress=None, ## will be provided later on in the run |
||||
|
) |
||||
|
|
||||
|
if requires_gpu == True: |
||||
|
args.gpu = gpu_name |
||||
|
|
||||
|
if block.progress == True: |
||||
|
""" |
||||
|
if progress was set to True on @eden.Block.run() decorator, then add a progress tracker into the config |
||||
|
""" |
||||
|
args.progress = block.get_progress_bar( |
||||
|
token=token, result_storage=result_storage |
||||
|
) |
||||
|
|
||||
|
try: |
||||
|
output = block.__run__(args) |
||||
|
|
||||
|
# job moves from running to succeeded |
||||
|
prometheus_metrics.running.dec(1) |
||||
|
prometheus_metrics.succeeded.inc(1) |
||||
|
|
||||
|
# prevent further jobs from hitting a busy gpu after a caught exception |
||||
|
except Exception as e: |
||||
|
|
||||
|
# job moves from running to failed |
||||
|
prometheus_metrics.running.dec(1) |
||||
|
prometheus_metrics.failed.inc(1) |
||||
|
if requires_gpu == True: |
||||
|
gpu_allocator.set_as_free(name=gpu_name) |
||||
|
raise Exception(str(e)) |
||||
|
|
||||
|
if requires_gpu == True: |
||||
|
gpu_allocator.set_as_free(name=gpu_name) |
||||
|
|
||||
|
success = block.write_results(output=output, token=token) |
||||
|
|
||||
|
return success ## return None because results go to result_storage instead |
||||
|
|
||||
|
@app.post("/run") |
||||
|
def start_run(config: block.data_model): |
||||
|
|
||||
|
## job moves into queue |
||||
|
prometheus_metrics.queued.inc(1) |
||||
|
|
||||
|
""" |
||||
|
refer: |
||||
|
https://github.com/celery/celery/issues/1813#issuecomment-33142648 |
||||
|
""" |
||||
|
token = generate_random_string(len=10) |
||||
|
|
||||
|
kwargs = dict(args=dict(config), token=token) |
||||
|
|
||||
|
res = run.apply_async(kwargs=kwargs, task_id=token, queue_name=block.name) |
||||
|
|
||||
|
initial_dict = {"config": dict(config), "output": {}, "progress": "__none__"} |
||||
|
|
||||
|
success = result_storage.add(token=token, encoded_results=initial_dict) |
||||
|
|
||||
|
response = {"token": token} |
||||
|
|
||||
|
return response |
||||
|
|
||||
|
@app.post("/update") |
||||
|
def update(credentials: Credentials, config: block.data_model): |
||||
|
|
||||
|
token = credentials.token |
||||
|
config = dict(config) |
||||
|
|
||||
|
status = queue_data.get_status(token=token) |
||||
|
|
||||
|
if status["status"] != "invalid token": |
||||
|
|
||||
|
if ( |
||||
|
status["status"] == "queued" |
||||
|
or status["status"] == "running" |
||||
|
or status["status"] == "starting" |
||||
|
): |
||||
|
|
||||
|
output_from_storage = result_storage.get(token=token) |
||||
|
output_from_storage["config"] = config |
||||
|
|
||||
|
success = result_storage.add( |
||||
|
encoded_results=output_from_storage, token=token |
||||
|
) |
||||
|
|
||||
|
response = { |
||||
|
"status": { |
||||
|
"status": "successfully updated config", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return response |
||||
|
|
||||
|
elif status["status"] == "failed": |
||||
|
|
||||
|
return { |
||||
|
"status": { |
||||
|
"status": "could not update config because job failed", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
elif status["status"] == "complete": |
||||
|
|
||||
|
return { |
||||
|
"status": { |
||||
|
"status": "could not update config because job is already complete", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
else: |
||||
|
response = {"status": {"status": "invalid token"}} |
||||
|
return response |
||||
|
|
||||
|
@app.post("/fetch") |
||||
|
def fetch(credentials: Credentials): |
||||
|
""" |
||||
|
Returns either the status of the task or the result depending on whether it's queued, running, complete or failed. |
||||
|
|
||||
|
Args: |
||||
|
credentials (Credentials): should contain a token that points to a task |
||||
|
""" |
||||
|
|
||||
|
token = credentials.token |
||||
|
|
||||
|
status = queue_data.get_status(token=token) |
||||
|
|
||||
|
if status["status"] != "invalid token": |
||||
|
|
||||
|
if status["status"] == "running": |
||||
|
|
||||
|
results = result_storage.get(token=token) |
||||
|
|
||||
|
response = { |
||||
|
"status": status, |
||||
|
"config": results["config"], |
||||
|
"output": results["output"], |
||||
|
} |
||||
|
|
||||
|
if block.progress == True: |
||||
|
progress_value = fetch_progress_from_token( |
||||
|
result_storage=result_storage, token=token |
||||
|
) |
||||
|
response["status"]["progress"] = progress_value |
||||
|
|
||||
|
elif status["status"] == "complete": |
||||
|
|
||||
|
results = result_storage.get(token=token) |
||||
|
|
||||
|
## if results are deleted, it still returns the same schema |
||||
|
if results == None and remove_result_on_fetch == True: |
||||
|
response = { |
||||
|
"status": { |
||||
|
"status": "removed" |
||||
|
}, |
||||
|
} |
||||
|
else: |
||||
|
response = { |
||||
|
"status": status, |
||||
|
"config": results["config"], |
||||
|
"output": results["output"], |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
if remove_result_on_fetch == True: |
||||
|
result_storage.delete(token=token) |
||||
|
|
||||
|
elif ( |
||||
|
status["status"] == "queued" |
||||
|
or status["status"] == "starting" |
||||
|
or status["status"] == "failed" |
||||
|
or status["status"] == "revoked" |
||||
|
): |
||||
|
|
||||
|
results = result_storage.get(token=token) |
||||
|
|
||||
|
response = {"status": status, "config": results["config"]} |
||||
|
|
||||
|
else: |
||||
|
|
||||
|
response = {"status": status} ## invalid token |
||||
|
|
||||
|
return response |
||||
|
|
||||
|
@app.post("/stop") |
||||
|
async def stop(wait_for: WaitFor): |
||||
|
""" |
||||
|
Stops the eden block, and exits the script |
||||
|
|
||||
|
Args: |
||||
|
config (dict, optional): Amount of time in seconds before the server shuts down. Defaults to {'time': 0}. |
||||
|
""" |
||||
|
logging.info(f"Stopping gracefully in {wait_for.seconds} seconds") |
||||
|
stop_everything_gracefully(t=wait_for.seconds) |
||||
|
|
||||
|
@app.post("/get_identity") |
||||
|
def get_identity(): |
||||
|
""" |
||||
|
Returns name and active commit hash of the generator |
||||
|
""" |
||||
|
try: |
||||
|
repo = git.Repo(search_parent_directories=True) |
||||
|
name = repo.remotes.origin.url.split('.git')[0].split('/')[-1] |
||||
|
sha = repo.head.object.hexsha |
||||
|
except git.exc.InvalidGitRepositoryError: |
||||
|
name = "repo-less-eden" |
||||
|
sha = "none" |
||||
|
|
||||
|
response = { |
||||
|
"name": name, |
||||
|
"commit": sha |
||||
|
} |
||||
|
|
||||
|
return response |
||||
|
|
||||
|
|
||||
|
## overriding the boring old [INFO] thingy |
||||
|
LOGGING_CONFIG["formatters"]["default"]["fmt"] = ( |
||||
|
"[" + Colors.CYAN + "EDEN" + Colors.END + "] %(asctime)s %(message)s" |
||||
|
) |
||||
|
LOGGING_CONFIG["formatters"]["access"]["fmt"] = ( |
||||
|
"[" |
||||
|
+ Colors.CYAN |
||||
|
+ "EDEN" |
||||
|
+ Colors.END |
||||
|
+ "] %(levelprefix)s %(client_addr)s - '%(request_line)s' %(status_code)s" |
||||
|
) |
||||
|
|
||||
|
config = uvicorn.config.Config(app=app, host=host, port=port, log_level=log_level) |
||||
|
server = ThreadedServer(config=config) |
||||
|
|
||||
|
# context starts fastAPI stuff and run_celery_app starts celery |
||||
|
with server.run_in_thread(): |
||||
|
message = ( |
||||
|
PREFIX |
||||
|
+ " Initializing celery worker on: " |
||||
|
+ f"redis://localhost:{str(redis_port)}" |
||||
|
) |
||||
|
print(message) |
||||
|
## starts celery app |
||||
|
run_celery_app( |
||||
|
celery_app, |
||||
|
max_num_workers=max_num_workers, |
||||
|
loglevel=celery_log_levels[log_level], |
||||
|
logfile=logfile, |
||||
|
queue_name=block.name, |
||||
|
) |
||||
|
|
||||
|
message = PREFIX + " Stopped" |
||||
|
|
||||
|
print(message) |
||||
|
|
@ -0,0 +1,75 @@ |
|||||
|
import PIL |
||||
|
import cv2 |
||||
|
import base64 |
||||
|
import numpy as np |
||||
|
from PIL.Image import Image as ImageFile |
||||
|
from PIL.JpegImagePlugin import JpegImageFile |
||||
|
from PIL.PngImagePlugin import PngImageFile |
||||
|
from PIL import Image |
||||
|
from io import BytesIO |
||||
|
|
||||
|
|
||||
|
def _encode_numpy_array_image(image): |
||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||
|
|
||||
|
if image.shape[-1] == 3: |
||||
|
_, buffer = cv2.imencode(".jpg", image) |
||||
|
|
||||
|
elif image.shape[-1] == 4: |
||||
|
_, buffer = cv2.imencode(".png", image) |
||||
|
|
||||
|
image_as_text = base64.b64encode(buffer) |
||||
|
|
||||
|
return image_as_text |
||||
|
|
||||
|
|
||||
|
def _encode_pil_image(image): |
||||
|
opencv_image = np.array(image) |
||||
|
image_as_text = _encode_numpy_array_image(image=opencv_image) |
||||
|
|
||||
|
return image_as_text |
||||
|
|
||||
|
|
||||
|
def _encode_image_file(image): |
||||
|
pil_image = Image.open(image) |
||||
|
|
||||
|
return _encode_pil_image(pil_image) |
||||
|
|
||||
|
|
||||
|
def encode(image): |
||||
|
|
||||
|
if ( |
||||
|
type(image) == np.ndarray |
||||
|
or type(image) == str |
||||
|
or isinstance( |
||||
|
image, |
||||
|
( |
||||
|
JpegImageFile, |
||||
|
PngImageFile, |
||||
|
ImageFile, |
||||
|
), |
||||
|
) |
||||
|
): |
||||
|
|
||||
|
if type(image) == np.ndarray: |
||||
|
image_as_text = _encode_numpy_array_image(image) |
||||
|
|
||||
|
elif type(image) == str: |
||||
|
image_as_text = _encode_image_file(image) |
||||
|
|
||||
|
else: |
||||
|
image_as_text = _encode_pil_image(image) |
||||
|
|
||||
|
return image_as_text.decode("ascii") |
||||
|
|
||||
|
else: |
||||
|
raise Exception( |
||||
|
"expected numpy.array, PIL.Image or str, not: ", str(type(image)) |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def decode(jpg_as_text): |
||||
|
if jpg_as_text is None: |
||||
|
return None |
||||
|
pil_image = Image.open(BytesIO(base64.b64decode(jpg_as_text))) |
||||
|
return pil_image |
@ -0,0 +1,69 @@ |
|||||
|
from eden.block import Block |
||||
|
from eden.datatypes import Image |
||||
|
from eden.hosting import host_block |
||||
|
|
||||
|
## eden <3 pytorch |
||||
|
from torchvision import models, transforms |
||||
|
import torch |
||||
|
|
||||
|
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) |
||||
|
model = model.eval() ## no dont move it to the gpu just yet :) |
||||
|
|
||||
|
my_transforms = transforms.Compose( |
||||
|
[ |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # this normalizes the image to the same format as the pretrained model |
||||
|
] |
||||
|
) |
||||
|
|
||||
|
eden_block = Block() |
||||
|
|
||||
|
my_args = { |
||||
|
"width": 224, ## width |
||||
|
"height": 224, ## height |
||||
|
"input_image": Image(), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
import requests |
||||
|
labels = requests.get( |
||||
|
"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" |
||||
|
).text.split("\n") |
||||
|
|
||||
|
|
||||
|
@eden_block.run(args=my_args, progress=False) |
||||
|
def do_something(config): |
||||
|
global model, labels |
||||
|
|
||||
|
pil_image = config["input_image"] |
||||
|
pil_image = pil_image.resize((config["width"], config["height"])) |
||||
|
|
||||
|
device = config.gpu |
||||
|
input_tensor = my_transforms(pil_image).to(device).unsqueeze(0) |
||||
|
|
||||
|
model = model.to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
pred = model(input_tensor)[0].cpu() |
||||
|
index = torch.argmax(pred).item() |
||||
|
value = pred[index].item() |
||||
|
# the index is the classification label for the pretrained resnet18 model. |
||||
|
# the human-readable labels associated with this index are pulled and returned as "label" |
||||
|
# we need to get them from imagenet labels, which we need to get online. |
||||
|
|
||||
|
label = labels[index] |
||||
|
# serialize the image |
||||
|
pil_image = Image(pil_image) |
||||
|
return {"value": value, "index": index, "label": label, 'image': pil_image} |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
host_block( |
||||
|
block=eden_block, |
||||
|
port=5656, |
||||
|
host="0.0.0.0", |
||||
|
redis_host="redis", |
||||
|
# logfile="log.log", |
||||
|
logfile=None, |
||||
|
log_level="debug", |
||||
|
max_num_workers=1, |
||||
|
requires_gpu=True, |
||||
|
) |
@ -0,0 +1,123 @@ |
|||||
|
from eden.block import Block |
||||
|
from eden.datatypes import Image |
||||
|
from eden.hosting import host_block |
||||
|
|
||||
|
## eden <3 pytorch |
||||
|
from torchvision import models, transforms |
||||
|
import torch |
||||
|
|
||||
|
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) |
||||
|
model = model.eval() ## no dont move it to the gpu just yet :) |
||||
|
|
||||
|
my_transforms = transforms.Compose( |
||||
|
[ |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # this normalizes the image to the same format as the pretrained model |
||||
|
] |
||||
|
) |
||||
|
|
||||
|
eden_block = Block() |
||||
|
|
||||
|
my_args = { |
||||
|
"width": 224, ## width |
||||
|
"height": 224, ## height |
||||
|
"input_image": Image(), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
import requests |
||||
|
labels = requests.get( |
||||
|
"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" |
||||
|
).text.split("\n") |
||||
|
|
||||
|
|
||||
|
@eden_block.run(args=my_args, progress=False) |
||||
|
def do_something(config): |
||||
|
global model, labels |
||||
|
|
||||
|
pil_image = config["input_image"] |
||||
|
pil_image = pil_image.resize((config["width"], config["height"])) |
||||
|
|
||||
|
device = config.gpu |
||||
|
input_tensor = my_transforms(pil_image).to(device).unsqueeze(0) |
||||
|
|
||||
|
model = model.to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
pred = model(input_tensor)[0].cpu() |
||||
|
index = torch.argmax(pred).item() |
||||
|
value = pred[index].item() |
||||
|
# the index is the classification label for the pretrained resnet18 model. |
||||
|
# the human-readable labels associated with this index are pulled and returned as "label" |
||||
|
# we need to get them from imagenet labels, which we need to get online. |
||||
|
|
||||
|
label = labels[index] |
||||
|
# serialize the image |
||||
|
pil_image = Image(pil_image) |
||||
|
return {"value": value, "index": index, "label": label, 'image': pil_image} |
||||
|
|
||||
|
|
||||
|
def run_host_block(): |
||||
|
host_block( |
||||
|
block=eden_block, |
||||
|
port=5656, |
||||
|
host="0.0.0.0", |
||||
|
redis_host="redis", |
||||
|
# logfile="log.log", |
||||
|
logfile=None, |
||||
|
log_level="debug", |
||||
|
max_num_workers=1, |
||||
|
requires_gpu=True, |
||||
|
) |
||||
|
|
||||
|
import asyncio |
||||
|
import socketio |
||||
|
import socket |
||||
|
|
||||
|
def get_ip_address(): |
||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
||||
|
try: |
||||
|
# This IP address doesn't need to be reachable, as we're only using it to find the local IP address |
||||
|
s.connect(("10.255.255.255", 1)) |
||||
|
ip = s.getsockname()[0] |
||||
|
except Exception: |
||||
|
ip = "127.0.0.1" |
||||
|
finally: |
||||
|
s.close() |
||||
|
return ip |
||||
|
|
||||
|
|
||||
|
# Update these with the correct values for your host and server |
||||
|
HOST_SERVER_IP = "0.0.0.0" |
||||
|
HOST_SERVER_PORT = 4999 |
||||
|
SERVER_NAME = "server_1" |
||||
|
SERVER_IP = get_ip_address() |
||||
|
SERVER_PORT = 8000 |
||||
|
|
||||
|
sio = socketio.AsyncClient() |
||||
|
|
||||
|
async def announce_server(): |
||||
|
await sio.connect(f'http://{HOST_SERVER_IP}:{HOST_SERVER_PORT}') |
||||
|
await sio.emit('register', {'name': SERVER_NAME, 'ip': SERVER_IP, 'port': SERVER_PORT}) |
||||
|
|
||||
|
@sio.on("heartbeat") |
||||
|
async def on_heartbeat(): |
||||
|
print("Received heartbeat from host") |
||||
|
|
||||
|
@sio.event |
||||
|
async def disconnect(): |
||||
|
print("Disconnected from host") |
||||
|
|
||||
|
async def main(): |
||||
|
# Run host_block in a separate thread |
||||
|
loop = asyncio.get_event_loop() |
||||
|
host_block_thread = loop.run_in_executor(None, run_host_block) |
||||
|
|
||||
|
# Announce the server to the host |
||||
|
await announce_server() |
||||
|
|
||||
|
# Wait for host_block to finish |
||||
|
await host_block_thread |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
asyncio.run(main()) |
@ -0,0 +1,24 @@ |
|||||
|
worker_processes 1; |
||||
|
|
||||
|
events { |
||||
|
worker_connections 1024; |
||||
|
} |
||||
|
|
||||
|
http { |
||||
|
upstream eden-servers { |
||||
|
server eden-server:5656; |
||||
|
} |
||||
|
|
||||
|
server { |
||||
|
listen 80; |
||||
|
server_name _; |
||||
|
|
||||
|
location / { |
||||
|
proxy_pass http://eden-servers; |
||||
|
proxy_set_header Host $host; |
||||
|
proxy_set_header X-Real-IP $remote_addr; |
||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
|
@ -0,0 +1,12 @@ |
|||||
|
# Use an official Redis image as a parent image |
||||
|
FROM redis:latest |
||||
|
|
||||
|
# Set the working directory to /data |
||||
|
WORKDIR /data |
||||
|
|
||||
|
# Expose Redis port |
||||
|
EXPOSE 6379 |
||||
|
|
||||
|
# Run Redis server as daemon |
||||
|
#CMD ["redis-server", "--daemonize", "yes"] |
||||
|
CMD ["redis-server", "--daemonize", "no"] |
@ -0,0 +1,66 @@ |
|||||
|
from eden.block import Block |
||||
|
from eden.datatypes import Image |
||||
|
from eden.hosting import host_block |
||||
|
|
||||
|
## eden <3 pytorch |
||||
|
from torchvision import models, transforms |
||||
|
import torch |
||||
|
|
||||
|
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) |
||||
|
model = model.eval() ## no dont move it to the gpu just yet :) |
||||
|
|
||||
|
my_transforms = transforms.Compose( |
||||
|
[ |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # this normalizes the image to the same format as the pretrained model |
||||
|
] |
||||
|
) |
||||
|
|
||||
|
eden_block = Block() |
||||
|
|
||||
|
my_args = { |
||||
|
"width": 224, ## width |
||||
|
"height": 224, ## height |
||||
|
"input_image": Image(), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
import requests |
||||
|
labels = requests.get( |
||||
|
"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" |
||||
|
).text.split("\n") |
||||
|
|
||||
|
|
||||
|
@eden_block.run(args=my_args, progress=False) |
||||
|
def do_something(config): |
||||
|
global model, labels |
||||
|
|
||||
|
pil_image = config["input_image"] |
||||
|
pil_image = pil_image.resize((config["width"], config["height"])) |
||||
|
|
||||
|
device = config.gpu |
||||
|
input_tensor = my_transforms(pil_image).to(device).unsqueeze(0) |
||||
|
|
||||
|
model = model.to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
pred = model(input_tensor)[0].cpu() |
||||
|
index = torch.argmax(pred).item() |
||||
|
value = pred[index].item() |
||||
|
# the index is the classification label for the pretrained resnet18 model. |
||||
|
# the human-readable labels associated with this index are pulled and returned as "label" |
||||
|
# we need to get them from imagenet labels, which we need to get online. |
||||
|
|
||||
|
label = labels[index] |
||||
|
# serialize the image |
||||
|
pil_image = Image(pil_image) |
||||
|
return {"value": value, "index": index, "label": label, 'image': pil_image} |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
host_block( |
||||
|
block=eden_block, |
||||
|
port=5655, |
||||
|
logfile="log2.log", |
||||
|
log_level="debug", |
||||
|
max_num_workers=1, |
||||
|
requires_gpu=True, |
||||
|
) |
@ -0,0 +1,66 @@ |
|||||
|
from eden.block import Block |
||||
|
from eden.datatypes import Image |
||||
|
from eden.hosting import host_block |
||||
|
|
||||
|
## eden <3 pytorch |
||||
|
from torchvision import models, transforms |
||||
|
import torch |
||||
|
|
||||
|
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) |
||||
|
model = model.eval() ## no dont move it to the gpu just yet :) |
||||
|
|
||||
|
my_transforms = transforms.Compose( |
||||
|
[ |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # this normalizes the image to the same format as the pretrained model |
||||
|
] |
||||
|
) |
||||
|
|
||||
|
eden_block = Block() |
||||
|
|
||||
|
my_args = { |
||||
|
"width": 224, ## width |
||||
|
"height": 224, ## height |
||||
|
"input_image": Image(), ## images require eden.datatypes.Image() |
||||
|
} |
||||
|
|
||||
|
import requests |
||||
|
labels = requests.get( |
||||
|
"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" |
||||
|
).text.split("\n") |
||||
|
|
||||
|
|
||||
|
@eden_block.run(args=my_args, progress=False) |
||||
|
def do_something(config): |
||||
|
global model, labels |
||||
|
|
||||
|
pil_image = config["input_image"] |
||||
|
pil_image = pil_image.resize((config["width"], config["height"])) |
||||
|
|
||||
|
device = config.gpu |
||||
|
input_tensor = my_transforms(pil_image).to(device).unsqueeze(0) |
||||
|
|
||||
|
model = model.to(device) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
pred = model(input_tensor)[0].cpu() |
||||
|
index = torch.argmax(pred).item() |
||||
|
value = pred[index].item() |
||||
|
# the index is the classification label for the pretrained resnet18 model. |
||||
|
# the human-readable labels associated with this index are pulled and returned as "label" |
||||
|
# we need to get them from imagenet labels, which we need to get online. |
||||
|
|
||||
|
label = labels[index] |
||||
|
# serialize the image |
||||
|
pil_image = Image(pil_image) |
||||
|
return {"value": value, "index": index, "label": label, 'image': pil_image} |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
host_block( |
||||
|
block=eden_block, |
||||
|
port=5656, |
||||
|
logfile="logs.log", |
||||
|
log_level="debug", |
||||
|
max_num_workers=1, |
||||
|
requires_gpu=True, |
||||
|
) |
@ -0,0 +1,75 @@ |
|||||
|
import PIL |
||||
|
import cv2 |
||||
|
import base64 |
||||
|
import numpy as np |
||||
|
from PIL.Image import Image as ImageFile |
||||
|
from PIL.JpegImagePlugin import JpegImageFile |
||||
|
from PIL.PngImagePlugin import PngImageFile |
||||
|
from PIL import Image |
||||
|
from io import BytesIO |
||||
|
|
||||
|
|
||||
|
def _encode_numpy_array_image(image): |
||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
||||
|
|
||||
|
if image.shape[-1] == 3: |
||||
|
_, buffer = cv2.imencode(".jpg", image) |
||||
|
|
||||
|
elif image.shape[-1] == 4: |
||||
|
_, buffer = cv2.imencode(".png", image) |
||||
|
|
||||
|
image_as_text = base64.b64encode(buffer) |
||||
|
|
||||
|
return image_as_text |
||||
|
|
||||
|
|
||||
|
def _encode_pil_image(image): |
||||
|
opencv_image = np.array(image) |
||||
|
image_as_text = _encode_numpy_array_image(image=opencv_image) |
||||
|
|
||||
|
return image_as_text |
||||
|
|
||||
|
|
||||
|
def _encode_image_file(image): |
||||
|
pil_image = Image.open(image) |
||||
|
|
||||
|
return _encode_pil_image(pil_image) |
||||
|
|
||||
|
|
||||
|
def encode(image): |
||||
|
|
||||
|
if ( |
||||
|
type(image) == np.ndarray |
||||
|
or type(image) == str |
||||
|
or isinstance( |
||||
|
image, |
||||
|
( |
||||
|
JpegImageFile, |
||||
|
PngImageFile, |
||||
|
ImageFile, |
||||
|
), |
||||
|
) |
||||
|
): |
||||
|
|
||||
|
if type(image) == np.ndarray: |
||||
|
image_as_text = _encode_numpy_array_image(image) |
||||
|
|
||||
|
elif type(image) == str: |
||||
|
image_as_text = _encode_image_file(image) |
||||
|
|
||||
|
else: |
||||
|
image_as_text = _encode_pil_image(image) |
||||
|
|
||||
|
return image_as_text.decode("ascii") |
||||
|
|
||||
|
else: |
||||
|
raise Exception( |
||||
|
"expected numpy.array, PIL.Image or str, not: ", str(type(image)) |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def decode(jpg_as_text): |
||||
|
if jpg_as_text is None: |
||||
|
return None |
||||
|
pil_image = Image.open(BytesIO(base64.b64decode(jpg_as_text))) |
||||
|
return pil_image |
@ -0,0 +1,77 @@ |
|||||
|
import asyncio |
||||
|
import signal |
||||
|
import socketio |
||||
|
from aiohttp import web |
||||
|
|
||||
|
SERVER_0_IP = "192.168.1.113" |
||||
|
FLASK_SERVER_PORT = 4999 |
||||
|
HEARTBEAT_INTERVAL = 1 |
||||
|
HEARTBEAT_TIMEOUT = 3 |
||||
|
|
||||
|
sio = socketio.AsyncServer(async_mode='aiohttp') |
||||
|
app = web.Application() |
||||
|
sio.attach(app) |
||||
|
|
||||
|
servers = {} |
||||
|
|
||||
|
async def available(request): |
||||
|
return web.json_response(servers) |
||||
|
|
||||
|
app.router.add_get("/available", available) |
||||
|
|
||||
|
@sio.event |
||||
|
async def connect(sid, environ): |
||||
|
print("I'm connected!", sid) |
||||
|
|
||||
|
@sio.event |
||||
|
async def register(sid, data): |
||||
|
server_info = data |
||||
|
name = server_info["name"] |
||||
|
|
||||
|
servers[name] = {"ip": server_info["ip"], "port": server_info["port"], "sid": sid} |
||||
|
print(servers) |
||||
|
|
||||
|
@sio.event |
||||
|
async def disconnect(sid): |
||||
|
print("I'm disconnected!", sid) |
||||
|
for name, server in servers.items(): |
||||
|
if server["sid"] == sid: |
||||
|
del servers[name] |
||||
|
break |
||||
|
|
||||
|
async def heartbeat(): |
||||
|
while True: |
||||
|
await asyncio.sleep(HEARTBEAT_INTERVAL) |
||||
|
server_values_copy = list(servers.values()) |
||||
|
for server in server_values_copy: |
||||
|
sid = server["sid"] |
||||
|
try: |
||||
|
print(f"Sending heartbeat to {sid}...") |
||||
|
heartbeat_future = sio.emit("heartbeat", to=sid) |
||||
|
await asyncio.wait_for(heartbeat_future, timeout=HEARTBEAT_TIMEOUT) |
||||
|
except (asyncio.TimeoutError, socketio.exceptions.TimeoutError): |
||||
|
print(f"Server {sid} failed to respond to heartbeat.") |
||||
|
await sio.disconnect(sid) |
||||
|
|
||||
|
def exit_handler(sig, frame): |
||||
|
print("Shutting down host...") |
||||
|
loop = asyncio.get_event_loop() |
||||
|
heartbeat_task.cancel() |
||||
|
loop.run_until_complete(loop.shutdown_asyncgens()) |
||||
|
loop.stop() |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
signal.signal(signal.SIGINT, exit_handler) |
||||
|
signal.signal(signal.SIGTERM, exit_handler) |
||||
|
|
||||
|
loop = asyncio.get_event_loop() |
||||
|
heartbeat_task = loop.create_task(heartbeat()) |
||||
|
aiohttp_app = loop.create_task(web._run_app(app, host=SERVER_0_IP, port=FLASK_SERVER_PORT)) |
||||
|
|
||||
|
try: |
||||
|
loop.run_until_complete(asyncio.gather(heartbeat_task, aiohttp_app)) |
||||
|
except asyncio.CancelledError: |
||||
|
pass |
||||
|
finally: |
||||
|
loop.run_until_complete(loop.shutdown_asyncgens()) |
||||
|
loop.stop() |
@ -0,0 +1,2 @@ |
|||||
|
python-socketio[asyncio_client]==6.1.1 |
||||
|
aiohttp==3.8.1 |
@ -0,0 +1,39 @@ |
|||||
|
import signal |
||||
|
import socketio |
||||
|
|
||||
|
SERVER_0_IP = "localhost" |
||||
|
SERVER_0_PORT = 4999 |
||||
|
SERVER_1_PORT = 5001 |
||||
|
SERVER_1_NAME = "server_1" |
||||
|
|
||||
|
sio = socketio.Client() |
||||
|
|
||||
|
@sio.event |
||||
|
def connect(): |
||||
|
print("I'm connected!") |
||||
|
sio.emit("register", {"name": SERVER_1_NAME, "ip": SERVER_0_IP, "port": SERVER_1_PORT}) |
||||
|
|
||||
|
@sio.event |
||||
|
def connect_error(data): |
||||
|
print("The connection failed!") |
||||
|
|
||||
|
@sio.event |
||||
|
def disconnect(): |
||||
|
print("I'm disconnected!") |
||||
|
|
||||
|
@sio.event |
||||
|
def heartbeat(): |
||||
|
print("Received heartbeat") |
||||
|
|
||||
|
def main(): |
||||
|
sio.connect(f"http://{SERVER_0_IP}:{SERVER_0_PORT}") |
||||
|
sio.wait() |
||||
|
|
||||
|
def exit_handler(sig, frame): |
||||
|
sio.disconnect() |
||||
|
exit(0) |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
signal.signal(signal.SIGINT, exit_handler) |
||||
|
signal.signal(signal.SIGTERM, exit_handler) |
||||
|
main() |
@ -0,0 +1,39 @@ |
|||||
|
import signal |
||||
|
import socketio |
||||
|
|
||||
|
SERVER_0_IP = "localhost" |
||||
|
SERVER_0_PORT = 4999 |
||||
|
SERVER_1_PORT = 5002 |
||||
|
SERVER_1_NAME = "server_2" |
||||
|
|
||||
|
sio = socketio.Client() |
||||
|
|
||||
|
@sio.event |
||||
|
def connect(): |
||||
|
print("I'm connected!") |
||||
|
sio.emit("register", {"name": SERVER_1_NAME, "ip": SERVER_0_IP, "port": SERVER_1_PORT}) |
||||
|
|
||||
|
@sio.event |
||||
|
def connect_error(data): |
||||
|
print("The connection failed!") |
||||
|
|
||||
|
@sio.event |
||||
|
def disconnect(): |
||||
|
print("I'm disconnected!") |
||||
|
|
||||
|
@sio.event |
||||
|
def heartbeat(): |
||||
|
print("Received heartbeat") |
||||
|
|
||||
|
def main(): |
||||
|
sio.connect(f"http://{SERVER_0_IP}:{SERVER_0_PORT}") |
||||
|
sio.wait() |
||||
|
|
||||
|
def exit_handler(sig, frame): |
||||
|
sio.disconnect() |
||||
|
exit(0) |
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
signal.signal(signal.SIGINT, exit_handler) |
||||
|
signal.signal(signal.SIGTERM, exit_handler) |
||||
|
main() |
Loading…
Reference in new issue