You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
3.3 KiB
123 lines
3.3 KiB
from eden.block import Block
|
|
from eden.datatypes import Image
|
|
from eden.hosting import host_block
|
|
|
|
## eden <3 pytorch
|
|
from torchvision import models, transforms
|
|
import torch
|
|
|
|
# Pretrained ResNet-50, switched to inference mode. It is deliberately left on
# the CPU here; do_something() moves it to the device given by config.gpu.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()
|
|
|
|
# Per-channel ImageNet statistics expected by the pretrained torchvision weights.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> float tensor scaled to [0, 1] (ToTensor) -> normalized with the
# statistics above so inputs match what the pretrained model was trained on.
my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
# The eden Block whose run() decorator registers do_something().
eden_block = Block()

# Argument spec handed to eden_block.run(args=...).
my_args = dict(
    width=224,  # target width in pixels for the resize in do_something()
    height=224,  # target height in pixels for the resize in do_something()
    input_image=Image(),  # image inputs must be declared with eden.datatypes.Image()
)
|
|
|
|
import requests

# Human-readable ImageNet-1k class names, one per line; the list position is
# the class index that do_something() looks up.
_IMAGENET_LABELS_URL = (
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
)
# Bounded timeout so a stalled download cannot hang startup indefinitely.
_labels_response = requests.get(_IMAGENET_LABELS_URL, timeout=30)
# Fail loudly on an HTTP error instead of treating an error page as labels.
_labels_response.raise_for_status()
# splitlines() (unlike split("\n")) yields no spurious empty entry for a
# trailing newline and also handles "\r\n" line endings.
labels = _labels_response.text.splitlines()
|
|
|
|
|
|
@eden_block.run(args=my_args, progress=False)
def do_something(config):
    """Classify ``config["input_image"]`` with the pretrained ResNet-50.

    Args:
        config: eden run config. This reads ``"input_image"`` (presumably a PIL
            image, given the ``.resize``/``.convert`` calls), ``"width"`` and
            ``"height"`` (the resize target) and ``config.gpu`` (the torch
            device to run on).

    Returns:
        dict with:

        * ``"value"``: the top raw model output (no softmax is applied).
        * ``"index"``: the argmax class index.
        * ``"label"``: the class name ``labels[index]``.
        * ``"image"``: the resized input wrapped in ``eden.datatypes.Image``.
    """
    global model  # reassigned below when moved onto the device

    pil_image = config["input_image"]
    pil_image = pil_image.resize((config["width"], config["height"]))

    device = config.gpu
    # Normalize() expects exactly 3 channels, so RGBA, grayscale or palette
    # inputs would crash. Convert only the tensor input; the resized image
    # returned below is left as it was.
    input_tensor = my_transforms(pil_image.convert("RGB")).to(device).unsqueeze(0)

    # The model is kept on the CPU at import time and moved here, on demand.
    model = model.to(device)

    with torch.no_grad():
        pred = model(input_tensor)[0].cpu()
    index = torch.argmax(pred).item()
    value = pred[index].item()

    # index is an ImageNet-1k class index for the pretrained ResNet-50 model;
    # its human-readable name comes from the label list fetched at import time.
    label = labels[index]

    # Wrap the PIL image so eden can serialize it in the response.
    pil_image = Image(pil_image)

    return {"value": value, "index": index, "label": label, 'image': pil_image}
|
|
|
|
|
|
def run_host_block():
    """Serve ``eden_block`` through eden's ``host_block`` with fixed settings.

    ``main()`` runs this in a worker thread and awaits it, so it presumably
    blocks for as long as the block is being served.
    """
    host_block(
        block=eden_block,
        host="0.0.0.0",
        port=5656,
        redis_host="redis",
        max_num_workers=1,
        requires_gpu=True,
        # presumably a path such as "log.log" enables file logging
        logfile=None,
        log_level="debug",
    )
|
|
|
|
import asyncio
|
|
import socketio
|
|
import socket
|
|
|
|
def get_ip_address():
    """Return this machine's outbound IPv4 address, or ``"127.0.0.1"``.

    Connecting a UDP socket sends no packets. It only makes the OS choose the
    local address it would use to reach the target, which ``getsockname()``
    then reports. Any socket error (no route, no network) falls back to the
    loopback address.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            # This address need not be reachable; we only use it to find the
            # local IP address the OS would route through.
            sock.connect(("10.255.255.255", 1))
            return sock.getsockname()[0]
    except OSError:
        return "127.0.0.1"
|
|
|
|
|
|
# Settings used to announce this block server to the host. Each one can be
# overridden with an environment variable of the same name; the defaults are
# the previously hard-coded values.
import os

# NOTE(review): "0.0.0.0" as a *connect* target only works as localhost on
# some platforms (e.g. Linux); set a real host address in deployment.
HOST_SERVER_IP = os.environ.get("HOST_SERVER_IP", "0.0.0.0")
HOST_SERVER_PORT = int(os.environ.get("HOST_SERVER_PORT", "4999"))
SERVER_NAME = os.environ.get("SERVER_NAME", "server_1")
# Auto-detected via get_ip_address() unless explicitly set.
SERVER_IP = os.environ.get("SERVER_IP") or get_ip_address()
# NOTE(review): run_host_block() serves the block on port 5656, but this
# advertises 8000. Confirm something actually listens on 8000, or set
# SERVER_PORT=5656.
SERVER_PORT = int(os.environ.get("SERVER_PORT", "8000"))

# Socket.IO client used to register with, and listen to, the host.
sio = socketio.AsyncClient()
|
|
|
|
async def announce_server():
    """Connect to the host over Socket.IO and register this server.

    Emits a ``register`` event carrying this server's name, IP and port.
    """
    host_url = f'http://{HOST_SERVER_IP}:{HOST_SERVER_PORT}'
    await sio.connect(host_url)

    registration = {'name': SERVER_NAME, 'ip': SERVER_IP, 'port': SERVER_PORT}
    await sio.emit('register', registration)
|
|
|
|
@sio.on("heartbeat")
async def on_heartbeat(*_payload):
    """Log each ``heartbeat`` event received from the host.

    python-socketio passes any data sent with an event to the handler as
    positional arguments. Accepting and ignoring them means a heartbeat that
    carries a payload does not raise ``TypeError`` in this handler.
    """
    print("Received heartbeat from host")
|
|
|
|
@sio.event
async def disconnect(*_reason):
    """Log when the Socket.IO connection to the host is closed.

    Newer python-socketio releases may pass a disconnect reason argument.
    Accepting and ignoring extra positional arguments keeps this handler
    working with both the old and the new handler signatures.
    """
    print("Disconnected from host")
|
|
|
|
async def main():
    """Run the eden block server and announce it to the host.

    ``run_host_block`` runs in the default thread-pool executor so the event
    loop stays free for the Socket.IO client. Returns once ``run_host_block``
    returns. The Socket.IO connection is closed on the way out, including when
    announcing fails.

    NOTE(review): if ``announce_server`` raises, the exception propagates, but
    the executor thread running ``host_block`` cannot be cancelled and keeps
    running.
    """
    # Inside a coroutine the asyncio docs prefer get_running_loop() over
    # get_event_loop().
    loop = asyncio.get_running_loop()
    # Run host_block in a separate thread
    host_block_future = loop.run_in_executor(None, run_host_block)

    try:
        # Announce the server to the host
        await announce_server()
        # Wait for host_block to finish
        await host_block_future
    finally:
        # Don't leave a dangling Socket.IO session behind.
        if sio.connected:
            await sio.disconnect()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start the block server and the host registration
    # together on a fresh event loop.
    entry = main()
    asyncio.run(entry)
|