Compare commits
1 Commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cea7706bea |
@ -36,8 +36,8 @@ url = "https://download.pytorch.org/whl/cu124"
|
|||||||
explicit = true
|
explicit = true
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
torch = { index = "pytorch-cu124" }
|
torch = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
|
||||||
torchvision = { index = "pytorch-cu124" }
|
torchvision = [{ index = "pytorch-cu124", marker = "sys_platform == 'linux'" }]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
|
|||||||
61
scripts/bench.py
Normal file
61
scripts/bench.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
"""Quick CPU vs MPS benchmark for BiRefNet HR at 2048."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from birefnet_service.model import BiRefNetService
|
||||||
|
|
||||||
|
|
||||||
|
def bench(
    device: str,
    image: Image.Image,
    model: str,
    resolution: int,
    warmup: int,
    iters: int,
) -> float:
    """Time background removal on one device and return the average seconds per run.

    Args:
        device: Device name handed to ``BiRefNetService``.
        image: Input image passed to every ``remove_background`` call.
        model: Model name used both as the service default and per call.
        resolution: Inference resolution used both as the service default and per call.
        warmup: Number of untimed-for-the-average runs executed first (each is still printed).
        iters: Number of measured runs; must be at least 1.

    Returns:
        The mean wall-clock duration, in seconds, of the measured runs.

    Raises:
        ValueError: If ``iters`` is less than 1 (there would be nothing to average).
    """
    # Fail before the model is loaded instead of hitting a ZeroDivisionError afterwards.
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    svc = BiRefNetService(device=device, default_model=model, default_resolution=resolution)
    print(f"[{device}] loaded {model} @ {resolution}, runtime device={svc.device}")

    def timed_run() -> float:
        """Run one inference and return its wall-clock duration in seconds."""
        t0 = time.perf_counter()
        svc.remove_background(image, model=model, resolution=resolution)
        # Wait for outstanding accelerator kernels so the measurement covers the whole run.
        if svc.device == "mps":
            torch.mps.synchronize()
        elif svc.device.startswith("cuda"):
            torch.cuda.synchronize()
        return time.perf_counter() - t0

    for i in range(warmup):
        print(f"  warmup {i + 1}: {timed_run():.2f}s")

    times = []
    for i in range(iters):
        dt = timed_run()
        times.append(dt)
        print(f"  run {i + 1}: {dt:.2f}s")

    avg = sum(times) / len(times)
    print(f"[{device}] avg={avg:.2f}s min={min(times):.2f}s best-of-{iters}")
    return avg
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse CLI options, benchmark each requested device, and print the MPS-vs-CPU speedup."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", default="test.jpg")
    ap.add_argument("--model", default="HR")
    ap.add_argument("--resolution", type=int, default=2048)
    ap.add_argument("--devices", default="cpu,mps")
    ap.add_argument("--warmup", type=int, default=1)
    ap.add_argument("--iters", type=int, default=3)
    args = ap.parse_args()

    # Drop blank entries (e.g. from a trailing comma) so an empty name is never benchmarked.
    devices = [name.strip() for name in args.devices.split(",") if name.strip()]
    if not devices:
        ap.error("--devices must name at least one device")

    results = {}
    # Image.open keeps the underlying file open; the context manager releases it when done.
    with Image.open(args.input) as image:
        print(f"image: {args.input} {image.size} mode={image.mode}")
        print(f"mps available: {torch.backends.mps.is_available()}")

        for d in devices:
            results[d] = bench(d, image, args.model, args.resolution, args.warmup, args.iters)

    # The speedup is only meaningful when both devices were actually benchmarked.
    if "cpu" in results and "mps" in results:
        print(f"\nspeedup mps vs cpu: {results['cpu'] / results['mps']:.2f}x")


if __name__ == "__main__":
    main()
|
||||||
@ -130,11 +130,15 @@ class BiRefNetService:
|
|||||||
default_model: str = DEFAULT_MODEL,
|
default_model: str = DEFAULT_MODEL,
|
||||||
default_resolution: int = DEFAULT_RESOLUTION,
|
default_resolution: int = DEFAULT_RESOLUTION,
|
||||||
):
|
):
|
||||||
want_cuda = device != "cpu" and torch.cuda.is_available()
|
|
||||||
if device and device not in ("auto", "cpu"):
|
if device and device not in ("auto", "cpu"):
|
||||||
self.device = device
|
self.device = device
|
||||||
|
elif device != "cpu" and torch.cuda.is_available():
|
||||||
|
self.device = "cuda"
|
||||||
|
elif device != "cpu" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
|
||||||
|
self.device = "mps"
|
||||||
else:
|
else:
|
||||||
self.device = "cuda" if want_cuda else "cpu"
|
self.device = "cpu"
|
||||||
|
# fp16 is reliable on CUDA; on MPS it can introduce NaNs in BiRefNet — keep fp32.
|
||||||
self.use_half = self.device.startswith("cuda")
|
self.use_half = self.device.startswith("cuda")
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self.default_resolution = default_resolution
|
self.default_resolution = default_resolution
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user