"""Serve a pretrained ResNet-50 ImageNet classifier as an eden block.

The block resizes the incoming image, runs it through torchvision's pretrained
ResNet-50 and returns the top-1 class index, the model's raw output for that
class, and the human-readable ImageNet label. On startup the process also
connects to a Socket.IO host server and registers itself there.
"""

import asyncio
import socket

import requests
import socketio
import torch
from torchvision import models, transforms

from eden.block import Block
from eden.datatypes import Image
from eden.hosting import host_block

## eden <3 pytorch

# Pretrained ResNet-50 in inference mode. It is deliberately left on the CPU
# here and moved to the worker's device inside the run function.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model = model.eval()

my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # Normalize with the ImageNet mean/std that the pretrained weights expect.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

eden_block = Block()

my_args = {
    "width": 224,  ## width
    "height": 224,  ## height
    "input_image": Image(),  ## images require eden.datatypes.Image()
}

IMAGENET_LABELS_URL = (
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
)


def _fetch_imagenet_labels(url=IMAGENET_LABELS_URL, timeout=30.0):
    """Download the ImageNet class names, one per line, indexed by class id.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            error page is never mistaken for the label list.
        requests.Timeout: if the download takes longer than ``timeout`` seconds.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # splitlines() avoids the spurious trailing "" that split("\n") yields
    # when the file ends with a newline.
    return response.text.splitlines()


# The model's output index is an ImageNet class id; these human-readable names
# are fetched once at import time so each request can just index into them.
labels = _fetch_imagenet_labels()


@eden_block.run(args=my_args, progress=False)
def do_something(config):
    """Classify ``config["input_image"]`` with the pretrained ResNet-50.

    Args:
        config: eden run config. Reads the keys ``"input_image"`` (presumably a
            PIL image, given the ``.resize`` call — confirm), ``"width"`` and
            ``"height"``, plus the ``gpu`` attribute, which is used as the torch
            device.

    Returns:
        dict with ``"value"`` (the model's raw output for the predicted class;
        no softmax is applied), ``"index"`` (ImageNet class id), ``"label"``
        (human-readable class name) and ``"image"`` (the resized input wrapped
        in ``eden.datatypes.Image`` for serialization).
    """
    global model  # reassigned below when moved to the device

    pil_image = config["input_image"].resize((config["width"], config["height"]))

    device = config.gpu
    # Move lazily so the model lands on whatever device this worker was given.
    model = model.to(device)
    input_tensor = my_transforms(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(input_tensor)[0].cpu()

    index = torch.argmax(pred).item()
    value = pred[index].item()
    label = labels[index]

    return {"value": value, "index": index, "label": label, "image": Image(pil_image)}


def run_host_block():
    """Serve ``eden_block`` on port 5656 (blocking call, meant for a worker thread)."""
    host_block(
        block=eden_block,
        port=5656,
        host="0.0.0.0",
        redis_host="redis",
        logfile=None,
        log_level="debug",
        max_num_workers=1,
        requires_gpu=True,
    )


def get_ip_address():
    """Return the local IPv4 address the OS routes outbound traffic from.

    Falls back to ``"127.0.0.1"`` if no route can be determined.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        try:
            # Connecting a UDP socket sends nothing; the address need not be
            # reachable. It only makes the OS choose an outbound interface,
            # whose address getsockname() then reports.
            s.connect(("10.255.255.255", 1))
            return s.getsockname()[0]
        except OSError:
            return "127.0.0.1"


# Update these with the correct values for your host and server.
# NOTE(review): "0.0.0.0" is a bind address; as a connect target it only works
# as a localhost alias on some platforms — set the host server's real address.
HOST_SERVER_IP = "0.0.0.0"
HOST_SERVER_PORT = 4999
SERVER_NAME = "server_1"
SERVER_IP = get_ip_address()
# NOTE(review): the block is hosted on port 5656 (see run_host_block) but port
# 8000 is announced — confirm a proxy/port-mapping exists, otherwise align them.
SERVER_PORT = 8000

sio = socketio.AsyncClient()


async def announce_server():
    """Connect to the host server and register this server's name, IP and port."""
    await sio.connect(f"http://{HOST_SERVER_IP}:{HOST_SERVER_PORT}")
    await sio.emit("register", {"name": SERVER_NAME, "ip": SERVER_IP, "port": SERVER_PORT})


@sio.on("heartbeat")
async def on_heartbeat():
    print("Received heartbeat from host")


@sio.event
async def disconnect():
    print("Disconnected from host")


async def main():
    """Run the eden host in a worker thread and announce this server to the host."""
    loop = asyncio.get_running_loop()
    # host_block is run in the default thread-pool executor so this event loop
    # stays free to service the Socket.IO connection.
    host_block_future = loop.run_in_executor(None, run_host_block)
    try:
        await announce_server()
        await host_block_future
    finally:
        # Release the Socket.IO connection when the host stops or announcing fails.
        await sio.disconnect()


if __name__ == "__main__":
    asyncio.run(main())