You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

123 lines
3.3 KiB

from eden.block import Block
from eden.datatypes import Image
from eden.hosting import host_block
## eden <3 pytorch
from torchvision import models, transforms
import torch
# Pretrained ResNet-50 (torchvision default weights), switched to inference
# mode. It deliberately stays on the CPU here; do_something() moves it to the
# requested device when a request is handled.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()

# Preprocessing applied to each input image before inference: convert to a
# float tensor, then normalize each channel with the ImageNet mean/std that the
# pretrained weights expect.
my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# The eden block; its run() decorator is applied to the request handler below.
eden_block = Block()

# Arguments accepted by the block, with defaults: the resize width/height and
# the input image (eden requires image arguments to be declared with
# eden.datatypes.Image()).
my_args = dict(
    width=224,
    height=224,
    input_image=Image(),
)
import requests
def _fetch_imagenet_labels(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class names, one name per line of the file at *url*.

    A *timeout* (seconds) is passed to ``requests.get`` so a stalled connection
    fails instead of hanging module import forever.

    Raises:
        requests.HTTPError: on a non-2xx response, instead of silently treating
            an error page as the label list.
        requests.RequestException: on connection/timeout failures.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # splitlines() drops the trailing empty entry that split("\n") produced for
    # a newline-terminated file, and strips any "\r" line endings.
    return response.text.splitlines()


# Position i holds the human-readable name for model output class index i.
labels = _fetch_imagenet_labels()
@eden_block.run(args=my_args, progress=False)
def do_something(config):
    """Classify ``config["input_image"]`` with the pretrained ResNet-50.

    Args:
        config: eden run config. Reads the "input_image", "width" and "height"
            arguments, and ``config.gpu`` as the torch device to run on.

    Returns:
        dict with:
          - "value": the top class's raw model output (no softmax is applied),
          - "index": the top class index (argmax of the output),
          - "label": ``labels[index]``, the human-readable class name,
          - "image": the resized input wrapped in eden's ``Image`` so it can be
            serialized in the response.
    """
    # Only `model` is reassigned here, so only it needs a global declaration.
    global model

    # The input is presumably a PIL image (it supports .resize((w, h))).
    pil_image = config["input_image"].resize((config["width"], config["height"]))
    device = config.gpu

    # ResNet and the 3-channel Normalize need RGB; RGBA/grayscale/palette
    # inputs would otherwise fail inside Normalize. Convert only the model
    # input so the returned image keeps its original mode.
    input_tensor = my_transforms(pil_image.convert("RGB")).to(device).unsqueeze(0)

    # Moving the model is cheap once it already lives on `device`.
    model = model.to(device)
    with torch.no_grad():
        pred = model(input_tensor)[0].cpu()

    index = torch.argmax(pred).item()
    value = pred[index].item()
    # Map the class index to its ImageNet name (downloaded at import time).
    label = labels[index]

    return {"value": value, "index": index, "label": label, "image": Image(pil_image)}
def run_host_block():
    """Serve ``eden_block`` via eden's ``host_block`` with this script's settings.

    Settings passed: port 5656 on all interfaces ("0.0.0.0"), Redis at host
    "redis", no log file, debug log level, a single worker, and a GPU required.
    main() runs this in a worker thread and awaits its completion.
    """
    settings = dict(
        block=eden_block,
        port=5656,
        host="0.0.0.0",
        redis_host="redis",
        logfile=None,  # set to a path (e.g. "log.log") to log to a file
        log_level="debug",
        max_num_workers=1,
        requires_gpu=True,
    )
    host_block(**settings)
import asyncio
import socketio
import socket
def get_ip_address():
    """Return this machine's outbound IPv4 address, or "127.0.0.1" if unknown.

    Connecting a UDP socket sends no packets; it only makes the OS pick the
    local address it would route through to reach the target, which
    ``getsockname()`` then reports. The target therefore need not be reachable.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        try:
            probe.connect(("10.255.255.255", 1))
            return probe.getsockname()[0]
        except OSError:
            # No usable route (e.g. no network interface up): use loopback.
            return "127.0.0.1"
import os

# Where the coordinating host server listens, and how this server advertises
# itself to it. Each value can be overridden by an environment variable of the
# same name; the defaults are the built-in values.
# NOTE(review): a client connecting to "0.0.0.0" reaches the local machine on
# Linux but is not portable; set HOST_SERVER_IP explicitly in other setups.
HOST_SERVER_IP = os.environ.get("HOST_SERVER_IP", "0.0.0.0")
HOST_SERVER_PORT = int(os.environ.get("HOST_SERVER_PORT", "4999"))
SERVER_NAME = os.environ.get("SERVER_NAME", "server_1")
# Only probe the network interfaces when no explicit address is configured.
SERVER_IP = os.environ.get("SERVER_IP") or get_ip_address()
# NOTE(review): run_host_block() serves on port 5656, but this port is what the
# 'register' event advertises. Confirm a port mapping exists, or set
# SERVER_PORT=5656.
SERVER_PORT = int(os.environ.get("SERVER_PORT", "8000"))

# Socket.IO client used to register with, and receive events from, the host.
sio = socketio.AsyncClient()
async def announce_server():
    """Connect ``sio`` to the host server and register this server with it.

    Sends a 'register' event carrying SERVER_NAME, SERVER_IP and SERVER_PORT.
    """
    host_url = f"http://{HOST_SERVER_IP}:{HOST_SERVER_PORT}"
    await sio.connect(host_url)

    registration = dict(name=SERVER_NAME, ip=SERVER_IP, port=SERVER_PORT)
    await sio.emit("register", registration)
@sio.on("heartbeat")
async def on_heartbeat(*args):
    """Log each "heartbeat" event received from the host.

    Any payload is accepted and ignored. python-socketio passes an event's data
    as positional arguments, so a zero-argument handler would raise TypeError
    if the host ever attaches data to the heartbeat.
    """
    print("Received heartbeat from host")
@sio.event
async def disconnect(reason=None):
    """Log when the Socket.IO connection to the host is closed.

    The function name matters: ``@sio.event`` registers it for the event of
    the same name ('disconnect'). Newer python-socketio releases pass a
    ``reason`` argument; older ones call the handler with no arguments, so it
    defaults to None and the original message is printed unchanged.
    """
    if reason is None:
        print("Disconnected from host")
    else:
        print(f"Disconnected from host (reason: {reason})")
async def main():
    """Serve the eden block and announce this server to the host.

    ``run_host_block`` runs in the event loop's default thread-pool executor,
    so the loop stays free for the Socket.IO client. The coroutine then
    registers with the host and waits for ``run_host_block`` to return.
    """
    from socketio import exceptions as socketio_exceptions

    loop = asyncio.get_running_loop()
    host_block_future = loop.run_in_executor(None, run_host_block)

    try:
        await announce_server()
    except socketio_exceptions.ConnectionError as err:
        # Registration is best-effort. The block is already serving in the
        # executor thread, and letting this propagate would leave asyncio.run()
        # waiting on that thread at shutdown, hiding the error. Report and
        # keep serving.
        print(f"Could not announce server to host: {err}")

    await host_block_future


if __name__ == "__main__":
    asyncio.run(main())