prep for metal
This commit is contained in:
parent
4efb4b8a2f
commit
cea7706bea
@ -36,8 +36,8 @@ url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cu124" }
|
||||
torchvision = { index = "pytorch-cu124" }
|
||||
torch = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
|
||||
torchvision = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
61
scripts/bench.py
Normal file
61
scripts/bench.py
Normal file
@ -0,0 +1,61 @@
|
||||
"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from birefnet_service.model import BiRefNetService
|
||||
|
||||
|
||||
def bench(device: str, image: Image.Image, model: str, resolution: int, warmup: int, iters: int):
|
||||
svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution)
|
||||
print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}")
|
||||
for i in range(warmup):
|
||||
t0 = time.perf_counter()
|
||||
svc.remove_background(image, model=model, resolution=resolution)
|
||||
if svc.device == "mps":
|
||||
torch.mps.synchronize()
|
||||
print(f" warmup {i + 1}: {time.perf_counter() - t0:.2f}s")
|
||||
times = []
|
||||
for i in range(iters):
|
||||
t0 = time.perf_counter()
|
||||
svc.remove_background(image, model=model, resolution=resolution)
|
||||
if svc.device == "mps":
|
||||
torch.mps.synchronize()
|
||||
dt = time.perf_counter() - t0
|
||||
times.append(dt)
|
||||
print(f" run {i + 1}: {dt:.2f}s")
|
||||
avg = sum(times) / len(times)
|
||||
print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}")
|
||||
return avg
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--input", default="test.jpg")
|
||||
ap.add_argument("--model", default="HR")
|
||||
ap.add_argument("--resolution", type=int, default=2048)
|
||||
ap.add_argument("--devices", default="cpu,mps")
|
||||
ap.add_argument("--warmup", type=int, default=1)
|
||||
ap.add_argument("--iters", type=int, default=3)
|
||||
args = ap.parse_args()
|
||||
|
||||
image = Image.open(args.input)
|
||||
print(f"image: {args.input} {image.size} mode={image.mode}")
|
||||
print(f"mps available: {torch.backends.mps.is_available()}")
|
||||
|
||||
results = {}
|
||||
for d in args.devices.split(","):
|
||||
d = d.strip()
|
||||
results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters)
|
||||
|
||||
if "cpu" in results and "mps" in results:
|
||||
print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -130,11 +130,15 @@ class BiRefNetService:
|
||||
default_model: str = DEFAULT_MODEL,
|
||||
default_resolution: int = DEFAULT_RESOLUTION,
|
||||
):
|
||||
want_cuda = device != "cpu" and torch.cuda.is_available()
|
||||
if device and device not in ("auto", "cpu"):
|
||||
self.device = device
|
||||
elif device != "cpu" and torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
elif device != "cpu" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cuda" if want_cuda else "cpu"
|
||||
self.device = "cpu"
|
||||
# fp16 is reliable on CUDA; on MPS it can introduce NaNs in BiRefNet — keep fp32.
|
||||
self.use_half = self.device.startswith("cuda")
|
||||
self.default_model = default_model
|
||||
self.default_resolution = default_resolution
|
||||
|
||||
Loading…
Reference in New Issue
Block a user