You cannot select more than 25 topics.
Topics must start with a letter or number, may include dashes ('-'), and can be up to 35 characters long.
76 lines
2.1 KiB
76 lines
2.1 KiB
from eden.block import Block
|
|
from eden.datatypes import Image
|
|
from eden.hosting import host_block
|
|
|
|
## eden <3 pytorch
|
|
from torchvision import models, transforms
|
|
import torch
|
|
|
|
# Pretrained ResNet-50 switched to inference mode (eval() returns the module itself).
# It is intentionally left on the CPU at import time; do_something() moves it to
# the worker's device when a request is actually handled.
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).eval()
|
|
|
|
# Per-channel (R, G, B) statistics of ImageNet, the dataset the pretrained
# weights were trained on. Normalizing with them puts inputs in the same range
# the network saw during training.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Image -> normalized float tensor. Resizing is done separately in do_something().
my_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
|
|
|
|
# The eden Block that do_something() is registered on and run_host_block() serves.
eden_block = Block()

# Argument spec passed to eden_block.run():
#   width / height -- target size the input image is resized to
#   input_image    -- image inputs must be declared as eden.datatypes.Image()
my_args = dict(
    width=224,
    height=224,
    input_image=Image(),
)
|
|
|
|
import requests


def _fetch_labels(url, timeout=30):
    """Download a newline-separated list of class names and return it as a list.

    Args:
        url: Location of a plain-text file with one class name per line.
        timeout: Seconds to wait for the server before giving up, so module
            import cannot hang forever on a stalled connection.

    Returns:
        list[str]: One entry per line. Position i is the name for class index i.

    Raises:
        requests.HTTPError: If the server answers with an error status. Without
            this check, an error page would be silently used as the label list.
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # splitlines() copes with "\r\n" endings and does not append an empty
    # entry for the trailing newline, unlike split("\n").
    return response.text.splitlines()


# ImageNet class names, indexed by the classifier's output index.
# Fetched once at import time.
labels = _fetch_labels(
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
)
|
|
|
|
|
|
@eden_block.run(args=my_args, progress=False)
def do_something(config):
    """Classify ``config["input_image"]`` with the pretrained ResNet-50 ``model``.

    The image is resized to ``(config["width"], config["height"])``. It is then
    turned into a normalized tensor with ``my_transforms`` and run through
    ``model`` on ``config.gpu``.

    Args:
        config: eden request config. Item access yields the ``my_args`` values,
            and ``config.gpu`` is the device to run on. The input image is
            presumably a PIL image, since ``.resize``/``.convert`` are PIL
            methods; confirm against eden.datatypes.Image.

    Returns:
        dict with keys:
            ``"value"``: The model's output score for the top class. This is the
                raw output; no softmax is applied here.
            ``"index"``: The argmax class index.
            ``"label"``: The human-readable name at that index in ``labels``.
            ``"image"``: The resized input image wrapped in
                ``eden.datatypes.Image`` for serialization.
    """
    # ``model`` is reassigned below. ``labels`` is only read, so it needs no global.
    global model

    pil_image = config["input_image"]
    pil_image = pil_image.resize((config["width"], config["height"]))

    device = config.gpu
    # ResNet-50 and the 3-element mean/std in ``my_transforms`` expect exactly
    # three channels. Grayscale, RGBA or palette inputs would otherwise produce
    # 1- or 4-channel tensors and fail, so convert to RGB for inference only.
    rgb_image = pil_image.convert("RGB")
    input_tensor = my_transforms(rgb_image).to(device).unsqueeze(0)  # add batch dim

    # The model is moved lazily so it lands on the device this call was given.
    model = model.to(device)

    with torch.no_grad():
        pred = model(input_tensor)[0].cpu()

    index = torch.argmax(pred).item()
    value = pred[index].item()

    # ``index`` is an ImageNet class index for the pretrained ResNet-50.
    # ``labels`` (downloaded at import time) maps it to a readable name.
    label = labels[index]

    # Wrap the resized (original-mode) image so eden can serialize it.
    pil_image = Image(pil_image)

    return {"value": value, "index": index, "label": label, 'image': pil_image}
|
|
|
|
from announce import announce_server_decorator
|
|
|
|
@announce_server_decorator
def run_host_block():
    """Host ``eden_block`` on 0.0.0.0 port 5656, using the Redis server at host "redis".

    Runs at most one worker, requires a GPU, logs at "debug" level and passes
    ``logfile=None`` to ``host_block``.
    """
    host_settings = {
        "block": eden_block,
        "port": 5656,
        "host": "0.0.0.0",
        "redis_host": "redis",
        "logfile": None,
        "log_level": "debug",
        "max_num_workers": 1,
        "requires_gpu": True,
    }
    host_block(**host_settings)


if __name__ == "__main__":
    run_host_block()
|