62 lines
2.1 KiB
Python
62 lines
2.1 KiB
Python
"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import time
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from birefnet_service.model import BiRefNetService
|
|
|
|
|
|
def bench(device: str, image: Image.Image, model: str, resolution: int, warmup: int, iters: int):
|
|
svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution)
|
|
print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}")
|
|
for i in range(warmup):
|
|
t0 = time.perf_counter()
|
|
svc.remove_background(image, model=model, resolution=resolution)
|
|
if svc.device == "mps":
|
|
torch.mps.synchronize()
|
|
print(f" warmup {i + 1}: {time.perf_counter() - t0:.2f}s")
|
|
times = []
|
|
for i in range(iters):
|
|
t0 = time.perf_counter()
|
|
svc.remove_background(image, model=model, resolution=resolution)
|
|
if svc.device == "mps":
|
|
torch.mps.synchronize()
|
|
dt = time.perf_counter() - t0
|
|
times.append(dt)
|
|
print(f" run {i + 1}: {dt:.2f}s")
|
|
avg = sum(times) / len(times)
|
|
print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}")
|
|
return avg
|
|
|
|
|
|
def main() -> None:
|
|
ap = argparse.ArgumentParser()
|
|
ap.add_argument("--input", default="test.jpg")
|
|
ap.add_argument("--model", default="HR")
|
|
ap.add_argument("--resolution", type=int, default=2048)
|
|
ap.add_argument("--devices", default="cpu,mps")
|
|
ap.add_argument("--warmup", type=int, default=1)
|
|
ap.add_argument("--iters", type=int, default=3)
|
|
args = ap.parse_args()
|
|
|
|
image = Image.open(args.input)
|
|
print(f"image: {args.input} {image.size} mode={image.mode}")
|
|
print(f"mps available: {torch.backends.mps.is_available()}")
|
|
|
|
results = {}
|
|
for d in args.devices.split(","):
|
|
d = d.strip()
|
|
results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters)
|
|
|
|
if "cpu" in results and "mps" in results:
|
|
print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
|