prep for metal

This commit is contained in:
Michael Pilosov 2026-05-16 17:47:35 -06:00
parent 4efb4b8a2f
commit cea7706bea
4 changed files with 1134 additions and 4 deletions

View File

@ -36,8 +36,8 @@ url = "https://download.pytorch.org/whl/cu124"
explicit = true
[tool.uv.sources]
torch = { index = "pytorch-cu124" }
torchvision = { index = "pytorch-cu124" }
torch = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
torchvision = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
[tool.ruff]
line-length = 100

61
scripts/bench.py Normal file
View File

@ -0,0 +1,61 @@
"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048."""
from __future__ import annotations
import argparse
import time
import torch
from PIL import Image
from birefnet_service.model import BiRefNetService
def bench(device: str, image: Image.Image, model: str, resolution: int, warmup: int, iters: int):
svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution)
print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}")
for i in range(warmup):
t0 = time.perf_counter()
svc.remove_background(image, model=model, resolution=resolution)
if svc.device == "mps":
torch.mps.synchronize()
print(f" warmup {i + 1}: {time.perf_counter() - t0:.2f}s")
times = []
for i in range(iters):
t0 = time.perf_counter()
svc.remove_background(image, model=model, resolution=resolution)
if svc.device == "mps":
torch.mps.synchronize()
dt = time.perf_counter() - t0
times.append(dt)
print(f" run {i + 1}: {dt:.2f}s")
avg = sum(times) / len(times)
print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}")
return avg
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--input", default="test.jpg")
ap.add_argument("--model", default="HR")
ap.add_argument("--resolution", type=int, default=2048)
ap.add_argument("--devices", default="cpu,mps")
ap.add_argument("--warmup", type=int, default=1)
ap.add_argument("--iters", type=int, default=3)
args = ap.parse_args()
image = Image.open(args.input)
print(f"image: {args.input} {image.size} mode={image.mode}")
print(f"mps available: {torch.backends.mps.is_available()}")
results = {}
for d in args.devices.split(","):
d = d.strip()
results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters)
if "cpu" in results and "mps" in results:
print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x")
if __name__ == "__main__":
main()

View File

@ -130,11 +130,15 @@ class BiRefNetService:
default_model: str = DEFAULT_MODEL,
default_resolution: int = DEFAULT_RESOLUTION,
):
want_cuda = device != "cpu" and torch.cuda.is_available()
if device and device not in ("auto", "cpu"):
self.device = device
elif device != "cpu" and torch.cuda.is_available():
self.device = "cuda"
elif device != "cpu" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cuda" if want_cuda else "cpu"
self.device = "cpu"
# fp16 is reliable on CUDA; on MPS it can introduce NaNs in BiRefNet — keep fp32.
self.use_half = self.device.startswith("cuda")
self.default_model = default_model
self.default_resolution = default_resolution

1065
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff