This commit is contained in:
Michael Pilosov 2026-05-16 14:32:26 -06:00
commit 96d16fc654
13 changed files with 979 additions and 0 deletions

13
.dockerignore Normal file
View File

@ -0,0 +1,13 @@
.git
.gitignore
.venv
__pycache__
*.pyc
hf_cache
*.jpg
*.jpeg
*.png
output*
.python-version
Makefile
docker-compose.yml

8
.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
.venv/
__pycache__/
*.pyc
hf_cache/
output.png
output*.png
mask*.png
*.jpg

1
.python-version Normal file
View File

@ -0,0 +1 @@
3.12

38
Dockerfile Normal file
View File

@ -0,0 +1,38 @@
# BiRefNet background removal service — CUDA 12.4 runtime image.
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
ENV DEBIAN_FRONTEND=noninteractive \
PYTHONUNBUFFERED=1 \
UV_PYTHON_INSTALL_DIR=/opt/python \
UV_PROJECT_ENVIRONMENT=/app/.venv \
UV_COMPILE_BYTECODE=1 \
UV_LINK_MODE=copy \
HF_HOME=/app/hf_cache \
PORT=8000
# uv: fast, reproducible Python + dependency management.
COPY --from=ghcr.io/astral-sh/uv:0.9 /uv /uvx /bin/
RUN apt-get update \
&& apt-get install -y --no-install-recommends ca-certificates \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Install Python + dependencies first so this layer is cached across code changes.
# The BuildKit cache mount keeps the uv download cache warm across rebuilds.
COPY pyproject.toml ./
RUN --mount=type=cache,target=/root/.cache/uv \
uv python install 3.12 \
&& uv sync --no-install-project --no-dev
# Application code.
COPY src ./src
COPY README.md ./
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --no-dev
ENV PATH="/app/.venv/bin:${PATH}"
EXPOSE 8000
CMD ["birefnet-service"]

58
Makefile Normal file
View File

@ -0,0 +1,58 @@
# BiRefNet background removal service — CLI shortcuts.
# Override defaults inline, e.g.: make test BG=white INPUT=photo.jpg
COMPOSE ?= docker compose
PYTHON ?= python3
PORT ?= 8000
INPUT ?= test.jpg
OUTPUT ?= output.png
BG ?= alpha
BLUR ?= 0
.DEFAULT_GOAL := help
.PHONY: help build run up stop down logs log ps test test-mask dev sync shell clean fmt
help: ## Show this help
@grep -E '^[a-zA-Z_-]+:.*?## ' $(MAKEFILE_LIST) \
| awk 'BEGIN{FS=":.*?## "}{printf " \033[36m%-12s\033[0m %s\n", $$1, $$2}'
build: ## Build the Docker image
$(COMPOSE) build
run up: ## Start the service (GPU) in the background
$(COMPOSE) up -d
stop down: ## Stop and remove the service container
$(COMPOSE) down
logs log: ## Follow service logs
$(COMPOSE) logs -f
ps: ## Show service status
$(COMPOSE) ps
test: ## Send INPUT to the running service, save OUTPUT
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
--input $(INPUT) --output $(OUTPUT) --background $(BG) --mask-blur $(BLUR)
test-mask: ## Like 'test' but also save the raw mask (mask.png)
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
--input $(INPUT) --output $(OUTPUT) --background $(BG) \
--mask-blur $(BLUR) --mask-output mask.png
sync: ## Install dependencies locally with uv
uv sync
dev: sync ## Run the service locally (no Docker; needs local CUDA)
uv run birefnet-service
shell: ## Open a shell inside a fresh container
$(COMPOSE) run --rm --entrypoint bash birefnet
fmt: ## Format code with ruff
uv run ruff format src scripts
clean: ## Stop the service and remove build artifacts
-$(COMPOSE) down
rm -f $(OUTPUT) mask.png

94
README.md Normal file
View File

@ -0,0 +1,94 @@
# BiRefNet Background Removal Service
GPU-accelerated background removal exposed as an HTTP API. Uses
[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with
[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the
NVIDIA container runtime.
## Requirements
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
- ~2 GB free disk for the model weights (downloaded on first run)
## Quick start
```bash
make build # build the Docker image
make run # start the service on :8000 (GPU)
make logs # watch startup — first run downloads BiRefNet weights
make test # send test.jpg, save output.png
```
`make test` waits for the service `/health` endpoint before sending the
request, so the first call may block while the model downloads and loads.
### Web UI
A minimal test page is served at the service root — open
**http://localhost:8000/** in a browser, drop in an image, and preview the
transparent-background result (handy when working over SSH). It calls the
same `/predict` endpoint.
### Useful variations
```bash
make test BG=white # composite onto a white background
make test INPUT=photo.jpg OUTPUT=cut.png
make test-mask # also save the raw alpha mask (mask.png)
make help # list all targets
```
## API
`POST /predict`
```jsonc
{
"image": "<base64 image bytes>", // required
"background": "alpha", // alpha|white|black|gray|green|blue|red
"mask_blur": 0, // Gaussian blur radius on mask edges
"return_mask": false // include the raw mask in the response
}
```
Response:
```jsonc
{
"image": "<base64 PNG>",
"format": "png",
"width": 3637,
"height": 3637,
"mask": "<base64 PNG>" // only when return_mask=true
}
```
`GET /health` returns 200 when the service is ready.
## Configuration (environment variables)
| Variable | Default | Purpose |
|----------------------|----------------------|----------------------------------|
| `PORT` | `8000` | HTTP port |
| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights |
| `BIREFNET_RESOLUTION`| `1024` | Inference resolution |
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
## Local development (no Docker)
Requires a local CUDA-capable PyTorch environment.
```bash
make dev # uv sync + run the server locally
```
## Layout
```
src/birefnet_service/model.py BiRefNet wrapper (load + inference)
src/birefnet_service/server.py LitServe API + web UI route
src/birefnet_service/static/ web UI (index.html)
scripts/client.py stdlib-only test client
Dockerfile / docker-compose.yml CUDA image + nvidia runtime
Makefile build / run / test shortcuts
```

28
compose.yml Normal file
View File

@ -0,0 +1,28 @@
services:
birefnet:
build: .
image: birefnet-service:latest
container_name: birefnet-service
ports:
- "${PORT:-8000}:8000"
environment:
- NVIDIA_VISIBLE_DEVICES=all
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
# Default variant/resolution; both are also selectable per request.
- BIREFNET_MODEL=${BIREFNET_MODEL:-general}
- BIREFNET_RESOLUTION=${BIREFNET_RESOLUTION:-1024}
# Use the nvidia-container-runtime for GPU acceleration.
runtime: nvidia
volumes:
# Persist downloaded BiRefNet weights across container restarts.
- hf-cache:/app/hf_cache
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
interval: 15s
timeout: 5s
retries: 30
start_period: 180s
restart: unless-stopped
volumes:
hf-cache:

43
pyproject.toml Normal file
View File

@ -0,0 +1,43 @@
[project]
name = "birefnet-service"
version = "0.1.0"
description = "BiRefNet background removal as a GPU-accelerated API"
readme = "README.md"
requires-python = ">=3.12,<3.13"
dependencies = [
"torch==2.5.1",
"torchvision==0.20.1",
"transformers>=4.44,<5",
"timm>=1.0.0",
"einops>=0.8.0",
"kornia>=0.7.0",
"pillow>=10.0.0",
"numpy>=1.26",
"litserve>=0.2.4",
]
[project.scripts]
birefnet-service = "birefnet_service.server:run"
[dependency-groups]
dev = ["ruff>=0.6.0"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/birefnet_service"]
# BiRefNet (torch) needs CUDA wheels; pull torch/torchvision from the PyTorch index.
[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true
[tool.uv.sources]
torch = { index = "pytorch-cu124" }
torchvision = { index = "pytorch-cu124" }
[tool.ruff]
line-length = 100

92
scripts/client.py Normal file
View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""Minimal stdlib-only client for the BiRefNet service.
Encodes an image, posts it to /predict, and saves the returned PNG.
No third-party dependencies so it can run with any system Python.
"""
from __future__ import annotations
import argparse
import base64
import json
import sys
import time
import urllib.error
import urllib.request
def wait_for_health(base_url: str, timeout: float) -> None:
deadline = time.time() + timeout
health = f"{base_url}/health"
while time.time() < deadline:
try:
with urllib.request.urlopen(health, timeout=5) as resp:
if resp.status == 200:
return
except (urllib.error.URLError, ConnectionError, OSError):
pass
time.sleep(2)
sys.exit(f"server at {health} not ready after {timeout:.0f}s")
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__)
ap.add_argument("--url", default="http://localhost:8000", help="service base URL")
ap.add_argument("--input", default="test.jpg", help="input image path")
ap.add_argument("--output", default="output.png", help="output PNG path")
ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red")
ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite")
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
args = ap.parse_args()
base_url = args.url.rstrip("/")
wait_for_health(base_url, args.wait)
with open(args.input, "rb") as f:
payload = {
"image": base64.b64encode(f.read()).decode("ascii"),
"background": args.background,
"mask_blur": args.mask_blur,
"return_mask": args.mask_output is not None,
}
if args.model is not None:
payload["model"] = args.model
if args.resolution is not None:
payload["resolution"] = args.resolution
if args.crop:
payload["crop"] = True
payload["crop_margin"] = args.crop_margin
req = urllib.request.Request(
f"{base_url}/predict",
data=json.dumps(payload).encode(),
headers={"Content-Type": "application/json"},
method="POST",
)
started = time.time()
with urllib.request.urlopen(req, timeout=300) as resp:
result = json.loads(resp.read())
elapsed = time.time() - started
with open(args.output, "wb") as f:
f.write(base64.b64decode(result["image"]))
print(
f"saved {args.output} {result['width']}x{result['height']} "
f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
)
if args.mask_output and "mask" in result:
with open(args.mask_output, "wb") as f:
f.write(base64.b64decode(result["mask"]))
print(f"saved {args.mask_output}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,3 @@
"""BiRefNet background removal service."""
__version__ = "0.1.0"

View File

@ -0,0 +1,185 @@
"""BiRefNet model wrapper for background removal.
Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports
multiple model variants (lazily loaded + cached) and a tunable inference
resolution, both selectable per request.
"""
from __future__ import annotations
import os
import threading
import torch
from PIL import Image, ImageFilter
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
# BiRefNet variants: general is fast for clean single subjects; HR is for large
# / detailed scenes and needs a higher resolution (>=1536) to perform well.
MODEL_ALIASES = {
"general": "ZhengPeng7/BiRefNet",
"HR": "ZhengPeng7/BiRefNet_HR",
"portrait": "ZhengPeng7/BiRefNet-portrait",
"matting": "ZhengPeng7/BiRefNet-matting",
"lite": "ZhengPeng7/BiRefNet_lite",
}
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
# ImageNet normalization, matching BiRefNet training.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]
BG_COLORS: dict[str, tuple[int, int, int]] = {
"black": (0, 0, 0),
"white": (255, 255, 255),
"gray": (128, 128, 128),
"green": (0, 255, 0),
"blue": (0, 0, 255),
"red": (255, 0, 0),
}
def resolve_repo(model: str | None) -> str:
"""Map a friendly variant name to a repo id (pass-through for raw ids)."""
model = model or DEFAULT_MODEL
return MODEL_ALIASES.get(model, model)
def _normalize_resolution(resolution: int | None) -> int:
"""BiRefNet's backbone needs the input side divisible by 32."""
res = int(resolution or DEFAULT_RESOLUTION)
return max(256, (res // 32) * 32)
class BiRefNetService:
"""Runs BiRefNet background removal; caches loaded model variants."""
def __init__(
self,
device: str | None = None,
default_model: str = DEFAULT_MODEL,
default_resolution: int = DEFAULT_RESOLUTION,
):
want_cuda = device != "cpu" and torch.cuda.is_available()
if device and device not in ("auto", "cpu"):
self.device = device
else:
self.device = "cuda" if want_cuda else "cpu"
self.use_half = self.device.startswith("cuda")
self.default_model = default_model
self.default_resolution = default_resolution
torch.set_float32_matmul_precision("high")
self._models: dict[str, torch.nn.Module] = {}
self._lock = threading.Lock()
# Preload the default variant so the worker is ready on first request.
self._get_model(resolve_repo(default_model))
def _get_model(self, repo: str) -> torch.nn.Module:
with self._lock:
model = self._models.get(repo)
if model is None:
model = AutoModelForImageSegmentation.from_pretrained(
repo, trust_remote_code=True
)
model.eval().to(self.device)
if self.use_half:
model.half()
self._models[repo] = model
return model
@torch.inference_mode()
def infer_mask(self, image: Image.Image, repo: str, resolution: int) -> Image.Image:
"""Return a single-channel ('L') alpha mask at the image's original size."""
model = self._get_model(repo)
w, h = image.size
transform = transforms.Compose(
[
transforms.Resize(
(resolution, resolution),
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(_MEAN, _STD),
]
)
x = transform(image).unsqueeze(0).to(self.device)
if self.use_half:
x = x.half()
pred = model(x)[-1].sigmoid().float().cpu()[0] # [1, R, R]
mask = transforms.ToPILImage()(pred)
return mask.resize((w, h), Image.BICUBIC)
def remove_background(
self,
image: Image.Image,
model: str | None = None,
resolution: int | None = None,
background: str = "alpha",
mask_blur: int = 0,
crop: bool = False,
crop_margin: float = 0.0,
return_mask: bool = False,
) -> dict:
"""Run background removal.
model: variant alias ('general', 'HR', ...) or a raw HF repo id.
resolution: inference resolution (rounded down to a multiple of 32).
background: "alpha" for transparency, or a key from BG_COLORS.
mask_blur: Gaussian blur radius applied to the mask edges.
crop: crop the output to the foreground's bounding box.
crop_margin: extra margin around the crop, in inches (uses image DPI).
"""
# DPI for inch->pixel margin conversion; default 96 if not embedded.
dpi = image.info.get("dpi")
dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0
image = image.convert("RGB")
repo = resolve_repo(model)
resolution = _normalize_resolution(resolution)
mask = self.infer_mask(image, repo, resolution)
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
cutout = image.convert("RGBA")
cutout.putalpha(mask)
background = (background or "alpha").lower()
if background == "alpha":
result = cutout
else:
if background not in BG_COLORS:
raise ValueError(
f"Unknown background '{background}'. "
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
)
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
result = Image.alpha_composite(bg, cutout).convert("RGB")
out: dict = {"image": result, "model": repo, "resolution": resolution}
if crop:
# Bounding box of the foreground (mask above a low alpha threshold).
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
if bbox is not None:
margin_px = round(max(0.0, crop_margin) * dpi_x)
left = max(0, bbox[0] - margin_px)
top = max(0, bbox[1] - margin_px)
right = min(image.width, bbox[2] + margin_px)
bottom = min(image.height, bbox[3] + margin_px)
box = (left, top, right, bottom)
result = result.crop(box)
mask = mask.crop(box)
out["image"] = result
out["crop_box"] = box
out["dpi"] = round(dpi_x, 1)
if return_mask:
out["mask"] = mask
return out

View File

@ -0,0 +1,113 @@
"""LitServe API exposing BiRefNet background removal.
Endpoint: POST /predict
Request JSON:
{
"image": "<base64-encoded image bytes>", (required)
"model": "general" | "HR" | "portrait" | ..., (default "general")
"resolution": 1024, (default 1024)
"background": "alpha" | "white" | "black" | ..., (default "alpha")
"mask_blur": 0, (default 0)
"return_mask": false (default false)
}
Response JSON:
{
"image": "<base64 PNG>",
"format": "png",
"width": int,
"height": int,
"model": "<repo id used>",
"resolution": int,
"mask": "<base64 PNG>" (only when return_mask=true)
}
A minimal web UI is served at GET / (same origin as /predict).
"""
from __future__ import annotations
import base64
import io
import os
from pathlib import Path
import litserve as ls
from fastapi.responses import HTMLResponse
from PIL import Image, ImageOps
from .model import BiRefNetService
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
def _b64_to_image(data: str) -> Image.Image:
image = Image.open(io.BytesIO(base64.b64decode(data)))
# Honor EXIF orientation so portrait photos aren't processed/returned sideways.
return ImageOps.exif_transpose(image)
def _image_to_b64(image: Image.Image) -> str:
buf = io.BytesIO()
image.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode("ascii")
class BiRefNetAPI(ls.LitAPI):
def setup(self, device: str) -> None:
self.service = BiRefNetService(device=device)
def decode_request(self, request: dict) -> dict:
if "image" not in request:
raise ValueError("Request must include a base64 'image' field.")
return {
"image": _b64_to_image(request["image"]),
"model": request.get("model"),
"resolution": request.get("resolution"),
"background": request.get("background", "alpha"),
"mask_blur": int(request.get("mask_blur", 0)),
"crop": bool(request.get("crop", False)),
"crop_margin": float(request.get("crop_margin", 0.0)),
"return_mask": bool(request.get("return_mask", False)),
}
def predict(self, inputs: dict) -> dict:
return self.service.remove_background(**inputs)
def encode_response(self, output: dict) -> dict:
image: Image.Image = output["image"]
response = {
"image": _image_to_b64(image),
"format": "png",
"width": image.width,
"height": image.height,
"model": output["model"],
"resolution": output["resolution"],
}
if output.get("mask") is not None:
response["mask"] = _image_to_b64(output["mask"])
return response
def run() -> None:
server = ls.LitServer(
BiRefNetAPI(),
accelerator="auto",
devices=1,
timeout=int(os.getenv("REQUEST_TIMEOUT", "120")),
)
# LitServe registers its own "/" route ("litserve running"); drop it so
# our UI can own the root path. Served same-origin as /predict (no CORS).
server.app.router.routes = [
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
]
@server.app.get("/", response_class=HTMLResponse)
def index() -> str:
return _UI_HTML
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)
if __name__ == "__main__":
run()

View File

@ -0,0 +1,303 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>BiRefNet — Background Removal</title>
<style>
:root { color-scheme: dark; }
* { box-sizing: border-box; }
body {
margin: 0; font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
background: #15171c; color: #e8e8ea; padding: 24px;
}
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 20px; }
.wrap { max-width: 1100px; margin: 0 auto; }
#drop {
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
}
#drop.over { border-color: #5b8cff; background: #1c2230; }
#drop p { margin: 6px 0; color: #8a8f99; }
.controls { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
select, input[type=number] {
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
}
input[type=number] { width: 78px; }
input[type=number]:disabled { opacity: .45; }
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
button {
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
}
button:disabled { background: #3a3f4b; cursor: not-allowed; }
button.ghost { background: #2a2f3a; }
.go-row { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
.status { color: #8a8f99; font-size: .85rem; }
.status.err { color: #ff6b6b; }
.hint { color: #6b7280; font-size: .78rem; margin-top: -8px; }
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
text-transform: uppercase; letter-spacing: .05em; }
.imgbox {
min-height: 260px; display: flex; align-items: center; justify-content: center;
border-radius: 8px; overflow: hidden;
}
.checker {
background-image:
linear-gradient(45deg, #2a2f3a 25%, transparent 25%),
linear-gradient(-45deg, #2a2f3a 25%, transparent 25%),
linear-gradient(45deg, transparent 75%, #2a2f3a 75%),
linear-gradient(-45deg, transparent 75%, #2a2f3a 75%);
background-size: 22px 22px;
background-position: 0 0, 0 11px, 11px -11px, -11px 0;
background-color: #20242d;
}
.imgbox img { max-width: 100%; max-height: 70vh; display: block; }
.imgbox img[src] { cursor: zoom-in; }
@media (max-width: 720px) { .panels { grid-template-columns: 1fr; } }
/* lightbox */
.lightbox { position: fixed; inset: 0; z-index: 100; background: rgba(12,13,17,.97);
display: flex; align-items: center; justify-content: center; }
.lightbox[hidden] { display: none; }
.lb-stage { width: 100vw; height: 100vh; overflow: hidden;
display: flex; align-items: center; justify-content: center; }
.lb-stage img { max-width: 100vw; max-height: 100vh; transform-origin: 0 0;
cursor: grab; user-select: none; -webkit-user-drag: none; will-change: transform; }
.lb-stage.grabbing img { cursor: grabbing; }
.lb-bar { position: fixed; top: 0; left: 0; right: 0; padding: 14px 20px;
z-index: 2; display: flex; justify-content: space-between; align-items: center;
color: #8a8f99; font-size: .8rem; pointer-events: none; }
.lb-close { pointer-events: auto; background: #2a2f3a; color: #e8e8ea;
border: 1px solid #3a3f4b; border-radius: 8px; width: 34px; height: 34px;
font-size: 1rem; line-height: 1; padding: 0; cursor: pointer;
display: flex; align-items: center; justify-content: center; }
</style>
</head>
<body>
<div class="wrap">
<h1>BiRefNet — Background Removal</h1>
<div class="sub">Drop an image to get a transparent-background PNG.</div>
<div id="drop">
<p><strong>Drop an image here</strong> or click to choose</p>
<p id="fname">No file selected</p>
<input id="file" type="file" accept="image/*" hidden />
</div>
<div class="controls">
<label class="field">Model
<select id="model">
<option value="general">general — clean single subjects (fast)</option>
<option value="HR" selected>HR — large / detailed scenes</option>
<option value="portrait">portrait — people</option>
<option value="matting">matting — soft edges / hair</option>
<option value="lite">lite — fastest</option>
</select>
</label>
<label class="field">Resolution
<select id="resolution">
<option value="1024">1024</option>
<option value="1536">1536</option>
<option value="2048">2048</option>
<option value="2560" selected>2560</option>
</select>
</label>
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
<label class="field">Margin (in)
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
</label>
</div>
<div class="hint">Tip: large or busy scenes segment best with <strong>HR</strong> at <strong>2048</strong>.
The <em>general</em> model expects a clear single subject at 1024.</div>
<div class="go-row">
<button id="go" disabled>Remove background</button>
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
<span id="status" class="status"></span>
</div>
<div class="panels">
<div class="panel">
<h2>Original</h2>
<div class="imgbox"><img id="src" alt="" /></div>
</div>
<div class="panel">
<h2>Result</h2>
<div class="imgbox checker"><img id="out" alt="" /></div>
</div>
</div>
</div>
<div id="lightbox" class="lightbox" hidden>
<div class="lb-bar">
<span>scroll to zoom · drag to pan · double-click resets · Esc closes</span>
<button class="lb-close" id="lbClose" title="Close"></button>
</div>
<div class="lb-stage" id="lbStage"><img id="lbImg" alt="" /></div>
</div>
<script>
const drop = document.getElementById('drop');
const fileInput = document.getElementById('file');
const fname = document.getElementById('fname');
const go = document.getElementById('go');
const dl = document.getElementById('dl');
const dlbtn = document.getElementById('dlbtn');
const statusEl = document.getElementById('status');
const srcImg = document.getElementById('src');
const outImg = document.getElementById('out');
const modelSel = document.getElementById('model');
const resSel = document.getElementById('resolution');
const cropChk = document.getElementById('crop');
const cropMargin = document.getElementById('cropMargin');
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
let selectedFile = null;
function setStatus(msg, isErr) {
statusEl.textContent = msg;
statusEl.className = 'status' + (isErr ? ' err' : '');
}
function pickFile(file) {
if (!file || !file.type.startsWith('image/')) {
setStatus('Please choose an image file.', true);
return;
}
selectedFile = file;
fname.textContent = file.name + ' (' + Math.round(file.size / 1024) + ' KB)';
srcImg.src = URL.createObjectURL(file);
outImg.removeAttribute('src');
dlbtn.disabled = true;
go.disabled = false;
setStatus('');
}
drop.addEventListener('click', () => fileInput.click());
fileInput.addEventListener('change', e => pickFile(e.target.files[0]));
['dragenter', 'dragover'].forEach(ev =>
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.add('over'); }));
['dragleave', 'drop'].forEach(ev =>
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.remove('over'); }));
drop.addEventListener('drop', e => pickFile(e.dataTransfer.files[0]));
function fileToBase64(file) {
return new Promise((resolve, reject) => {
const r = new FileReader();
r.onload = () => resolve(r.result.split(',')[1]); // strip data URL prefix
r.onerror = reject;
r.readAsDataURL(file);
});
}
// --- lightbox: click to inspect, scroll to zoom, drag to pan ---
const lightbox = document.getElementById('lightbox');
const lbStage = document.getElementById('lbStage');
const lbImg = document.getElementById('lbImg');
const lbClose = document.getElementById('lbClose');
let lbScale = 1, lbTx = 0, lbTy = 0, lbDrag = null;
function lbApply() {
lbImg.style.transform = `translate(${lbTx}px, ${lbTy}px) scale(${lbScale})`;
}
function lbReset() { lbScale = 1; lbTx = 0; lbTy = 0; lbApply(); }
function openLightbox(src, isResult) {
if (!src) return;
lbImg.src = src;
lbImg.classList.toggle('checker', !!isResult);
lbReset();
lightbox.hidden = false;
}
function closeLightbox() { lightbox.hidden = true; lbImg.removeAttribute('src'); }
srcImg.addEventListener('click', () => openLightbox(srcImg.getAttribute('src'), false));
outImg.addEventListener('click', () => openLightbox(outImg.getAttribute('src'), true));
lbClose.addEventListener('click', closeLightbox);
lightbox.addEventListener('mousedown', e => {
if (e.target === lightbox || e.target === lbStage) closeLightbox();
});
document.addEventListener('keydown', e => {
if (e.key === 'Escape' && !lightbox.hidden) closeLightbox();
});
lbStage.addEventListener('wheel', e => {
e.preventDefault();
const rect = lbImg.getBoundingClientRect();
const cx = e.clientX - rect.left, cy = e.clientY - rect.top;
const factor = e.deltaY < 0 ? 1.2 : 1 / 1.2;
const newScale = Math.min(8, Math.max(1, lbScale * factor));
const ratio = newScale / lbScale;
lbTx -= cx * (ratio - 1);
lbTy -= cy * (ratio - 1);
lbScale = newScale;
if (lbScale === 1) { lbTx = 0; lbTy = 0; }
lbApply();
}, { passive: false });
lbImg.addEventListener('mousedown', e => {
e.preventDefault();
lbDrag = { x: e.clientX, y: e.clientY, tx: lbTx, ty: lbTy };
lbStage.classList.add('grabbing');
});
window.addEventListener('mousemove', e => {
if (!lbDrag) return;
lbTx = lbDrag.tx + (e.clientX - lbDrag.x);
lbTy = lbDrag.ty + (e.clientY - lbDrag.y);
lbApply();
});
window.addEventListener('mouseup', () => {
lbDrag = null;
lbStage.classList.remove('grabbing');
});
lbImg.addEventListener('dblclick', e => { e.preventDefault(); lbReset(); });
go.addEventListener('click', async () => {
if (!selectedFile) return;
go.disabled = true;
dlbtn.disabled = true;
setStatus('Processing… (first use of a model downloads its weights)');
const t0 = performance.now();
try {
const b64 = await fileToBase64(selectedFile);
const resp = await fetch('/predict', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
image: b64,
background: 'alpha',
model: modelSel.value,
resolution: parseInt(resSel.value, 10),
crop: cropChk.checked,
crop_margin: parseFloat(cropMargin.value) || 0,
}),
});
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
const data = await resp.json();
const dataUrl = 'data:image/png;base64,' + data.image;
outImg.src = dataUrl;
dl.href = dataUrl;
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
dlbtn.disabled = false;
const secs = ((performance.now() - t0) / 1000).toFixed(1);
setStatus('Done — ' + data.width + '×' + data.height + ' · ' + data.model +
' @ ' + data.resolution + ' · ' + secs + 's');
} catch (err) {
setStatus(err.message || String(err), true);
} finally {
go.disabled = false;
}
});
</script>
</body>
</html>