"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048.""" from __future__ import annotations import argparse import time import torch from PIL import Image from birefnet_service.model import BiRefNetService def bench(device: str, image: Image.Image, model: str, resolution: int, warmup: int, iters: int): svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution) print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}") for i in range(warmup): t0 = time.perf_counter() svc.remove_background(image, model=model, resolution=resolution) if svc.device == "mps": torch.mps.synchronize() print(f" warmup {i + 1}: {time.perf_counter() - t0:.2f}s") times = [] for i in range(iters): t0 = time.perf_counter() svc.remove_background(image, model=model, resolution=resolution) if svc.device == "mps": torch.mps.synchronize() dt = time.perf_counter() - t0 times.append(dt) print(f" run {i + 1}: {dt:.2f}s") avg = sum(times) / len(times) print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}") return avg def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--input", default="test.jpg") ap.add_argument("--model", default="HR") ap.add_argument("--resolution", type=int, default=2048) ap.add_argument("--devices", default="cpu,mps") ap.add_argument("--warmup", type=int, default=1) ap.add_argument("--iters", type=int, default=3) args = ap.parse_args() image = Image.open(args.input) print(f"image: {args.input} {image.size} mode={image.mode}") print(f"mps available: {torch.backends.mps.is_available()}") results = {} for d in args.devices.split(","): d = d.strip() results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters) if "cpu" in results and "mps" in results: print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x") if __name__ == "__main__": main()