"""Eden block that classifies an input image with a pretrained ResNet-50.

Each request resizes the incoming image, runs it through torchvision's
ImageNet-pretrained ResNet-50 and returns the top-1 class index, its raw
score and the human-readable ImageNet label.
"""

import requests
import torch
from torchvision import models, transforms

from eden.block import Block
from eden.datatypes import Image
from eden.hosting import host_block

## eden <3 pytorch

# Load the pretrained weights once at import time and switch to inference
# mode (dropout off, batch-norm uses its running statistics).
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
model = model.eval()  ## no dont move it to the gpu just yet :)
# The target device is only known per request (``config.gpu``), so the move
# happens inside ``do_something``.

my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        # this normalizes the image to the same format as the pretrained model
        # (ImageNet per-channel mean/std; expects exactly 3 channels)
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

eden_block = Block()

my_args = {
    "width": 224,  ## width
    "height": 224,  ## height
    "input_image": Image(),  ## images require eden.datatypes.Image()
}

_LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"


def _fetch_imagenet_labels(url=_LABELS_URL, timeout=30.0):
    """Download the ImageNet class names, one per line; entry ``i`` names class ``i``.

    Raises:
        requests.RequestException: on network failure, timeout or an HTTP
            error status, so a bad download fails at startup instead of
            surfacing later as a confusing lookup error.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # splitlines() avoids the trailing "" entry that split("\n") leaves behind.
    return response.text.splitlines()


labels = _fetch_imagenet_labels()


@eden_block.run(args=my_args, progress=False)
def do_something(config):
    """Classify ``config["input_image"]`` and return the top-1 prediction.

    Args:
        config: eden request config. ``config["input_image"]`` is used as a
            PIL image (presumably decoded by eden from the ``Image`` arg —
            confirm), ``config["width"]`` / ``config["height"]`` give the
            resize target and ``config.gpu`` is the device to run on.

    Returns:
        dict with ``"value"`` (the raw model output for the top class — a
        logit, not a probability), ``"index"`` (ImageNet class index),
        ``"label"`` (its human-readable name) and ``"image"`` (the resized
        image wrapped in eden's ``Image`` for serialization).
    """
    pil_image = config["input_image"]
    pil_image = pil_image.resize((config["width"], config["height"]))

    device = config.gpu
    # ToTensor keeps the image's channel count, but the Normalize step above
    # only works with 3 channels; convert so RGBA / grayscale / palette
    # inputs don't crash. The returned image keeps its original mode.
    input_tensor = my_transforms(pil_image.convert("RGB")).to(device).unsqueeze(0)

    # nn.Module.to moves the parameters in place (and returns the module),
    # so no rebinding of the global is needed.
    model.to(device)
    with torch.no_grad():
        pred = model(input_tensor)[0].cpu()

    index = torch.argmax(pred).item()
    value = pred[index].item()

    # the index is the classification label for the pretrained resnet50 model;
    # map it to the human-readable ImageNet class name downloaded at startup.
    label = labels[index]

    # serialize the image
    return {"value": value, "index": index, "label": label, "image": Image(pil_image)}


if __name__ == "__main__":
    host_block(
        block=eden_block,
        port=5656,
        logfile="logs.log",
        log_level="debug",
        max_num_workers=1,
        requires_gpu=True,
    )