rmbg/scripts/bench.py
2026-05-16 17:47:35 -06:00

62 lines
2.1 KiB
Python

"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048."""
from __future__ import annotations
import argparse
import time
import torch
from PIL import Image
from birefnet_service.model import BiRefNetService
def _synchronize(device) -> None:
    """Wait for queued accelerator work on *device* to finish.

    *device* may be a plain string ("mps", "cuda:0", ...) or a ``torch.device``;
    it is normalised with ``str`` so both forms are handled. Device types other
    than MPS and CUDA (e.g. CPU) need no explicit synchronisation: no-op.
    """
    kind = str(device).split(":", 1)[0]
    if kind == "mps":
        torch.mps.synchronize()
    elif kind == "cuda":
        torch.cuda.synchronize()


def bench(device: str, image: Image.Image, model: str, resolution: int, warmup: int, iters: int) -> float:
    """Time ``BiRefNetService.remove_background`` on a single device.

    Loads a fresh service for *device*, runs *warmup* passes whose times are
    printed but excluded from the statistics, then *iters* timed passes.

    Args:
        device: Device string handed to ``BiRefNetService`` (e.g. "cpu", "mps").
        image: Input image, reused unchanged for every pass.
        model: Model name passed to the service for every call.
        resolution: Inference resolution passed to the service for every call.
        warmup: Number of untimed warm-up passes (a value <= 0 means none).
        iters: Number of timed passes; must be at least 1.

    Returns:
        Mean wall-clock seconds per timed pass.

    Raises:
        ValueError: If *iters* < 1 (nothing to average). Checked before the
            model is loaded so a bad argument fails fast.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution)
    print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}")

    def timed_pass() -> float:
        """Run one inference and return its wall-clock duration in seconds."""
        t0 = time.perf_counter()
        svc.remove_background(image, model=model, resolution=resolution)
        # Synchronise before stopping the clock so asynchronous accelerator
        # work is included in the measurement, not just its submission.
        _synchronize(svc.device)
        return time.perf_counter() - t0

    for i in range(warmup):
        print(f" warmup {i + 1}: {timed_pass():.2f}s")

    times = []
    for i in range(iters):
        dt = timed_pass()
        times.append(dt)
        print(f" run {i + 1}: {dt:.2f}s")

    avg = sum(times) / len(times)
    print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}")
    return avg
def main() -> None:
    """Parse CLI arguments, benchmark each requested device, and report speed-up.

    Each device in ``--devices`` (comma-separated) is benchmarked with
    ``bench``. When both "cpu" and "mps" were run, the ratio of their
    average times is printed as the MPS speed-up.
    """
    ap = argparse.ArgumentParser(description="Benchmark BiRefNet background removal across devices.")
    ap.add_argument("--input", default="test.jpg", help="image file to process")
    ap.add_argument("--model", default="HR", help="model name passed to the service")
    ap.add_argument("--resolution", type=int, default=2048, help="inference resolution")
    ap.add_argument("--devices", default="cpu,mps", help="comma-separated device list")
    ap.add_argument("--warmup", type=int, default=1, help="untimed warm-up passes per device")
    ap.add_argument("--iters", type=int, default=3, help="timed passes per device (>= 1)")
    args = ap.parse_args()

    # Fail with a CLI usage message instead of a ZeroDivisionError after
    # the model has already been loaded.
    if args.iters < 1:
        ap.error("--iters must be at least 1")

    # Ignore empty entries such as the trailing one in "cpu,".
    devices = [d.strip() for d in args.devices.split(",") if d.strip()]
    if not devices:
        ap.error("--devices must name at least one device")

    image = Image.open(args.input)
    # Image.open is lazy: decode now so the one-off decoding cost is not
    # charged to the first benchmarked pass (which is timed when --warmup 0).
    image.load()
    print(f"image: {args.input} {image.size} mode={image.mode}")
    print(f"mps available: {torch.backends.mps.is_available()}")

    results = {}
    for d in devices:
        results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters)

    if "cpu" in results and "mps" in results:
        print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x")


if __name__ == "__main__":
    main()