mvp
This commit is contained in:
commit
96d16fc654
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@ -0,0 +1,13 @@
|
||||
.git
|
||||
.gitignore
|
||||
.venv
|
||||
__pycache__
|
||||
*.pyc
|
||||
hf_cache
|
||||
*.jpg
|
||||
*.jpeg
|
||||
*.png
|
||||
output*
|
||||
.python-version
|
||||
Makefile
|
||||
docker-compose.yml
|
||||
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
hf_cache/
|
||||
output.png
|
||||
output*.png
|
||||
mask*.png
|
||||
*.jpg
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
||||
3.12
|
||||
38
Dockerfile
Normal file
38
Dockerfile
Normal file
@ -0,0 +1,38 @@
|
||||
# BiRefNet background removal service — CUDA 12.4 runtime image.
|
||||
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive \
|
||||
PYTHONUNBUFFERED=1 \
|
||||
UV_PYTHON_INSTALL_DIR=/opt/python \
|
||||
UV_PROJECT_ENVIRONMENT=/app/.venv \
|
||||
UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy \
|
||||
HF_HOME=/app/hf_cache \
|
||||
PORT=8000
|
||||
|
||||
# uv: fast, reproducible Python + dependency management.
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.9 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python + dependencies first so this layer is cached across code changes.
|
||||
# The BuildKit cache mount keeps the uv download cache warm across rebuilds.
|
||||
COPY pyproject.toml ./
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv python install 3.12 \
|
||||
&& uv sync --no-install-project --no-dev
|
||||
|
||||
# Application code.
|
||||
COPY src ./src
|
||||
COPY README.md ./
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --no-dev
|
||||
|
||||
ENV PATH="/app/.venv/bin:${PATH}"
|
||||
|
||||
EXPOSE 8000
|
||||
CMD ["birefnet-service"]
|
||||
58
Makefile
Normal file
58
Makefile
Normal file
@ -0,0 +1,58 @@
|
||||
# BiRefNet background removal service — CLI shortcuts.
|
||||
# Override defaults inline, e.g.: make test BG=white INPUT=photo.jpg
|
||||
|
||||
COMPOSE ?= docker compose
|
||||
PYTHON ?= python3
|
||||
PORT ?= 8000
|
||||
INPUT ?= test.jpg
|
||||
OUTPUT ?= output.png
|
||||
BG ?= alpha
|
||||
BLUR ?= 0
|
||||
|
||||
.DEFAULT_GOAL := help
|
||||
|
||||
.PHONY: help build run up stop down logs log ps test test-mask dev sync shell clean fmt
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## ' $(MAKEFILE_LIST) \
|
||||
| awk 'BEGIN{FS=":.*?## "}{printf " \033[36m%-12s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
build: ## Build the Docker image
|
||||
$(COMPOSE) build
|
||||
|
||||
run up: ## Start the service (GPU) in the background
|
||||
$(COMPOSE) up -d
|
||||
|
||||
stop down: ## Stop and remove the service container
|
||||
$(COMPOSE) down
|
||||
|
||||
logs log: ## Follow service logs
|
||||
$(COMPOSE) logs -f
|
||||
|
||||
ps: ## Show service status
|
||||
$(COMPOSE) ps
|
||||
|
||||
test: ## Send INPUT to the running service, save OUTPUT
|
||||
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
|
||||
--input $(INPUT) --output $(OUTPUT) --background $(BG) --mask-blur $(BLUR)
|
||||
|
||||
test-mask: ## Like 'test' but also save the raw mask (mask.png)
|
||||
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
|
||||
--input $(INPUT) --output $(OUTPUT) --background $(BG) \
|
||||
--mask-blur $(BLUR) --mask-output mask.png
|
||||
|
||||
sync: ## Install dependencies locally with uv
|
||||
uv sync
|
||||
|
||||
dev: sync ## Run the service locally (no Docker; needs local CUDA)
|
||||
uv run birefnet-service
|
||||
|
||||
shell: ## Open a shell inside a fresh container
|
||||
$(COMPOSE) run --rm --entrypoint bash birefnet
|
||||
|
||||
fmt: ## Format code with ruff
|
||||
uv run ruff format src scripts
|
||||
|
||||
clean: ## Stop the service and remove build artifacts
|
||||
-$(COMPOSE) down
|
||||
rm -f $(OUTPUT) mask.png
|
||||
94
README.md
Normal file
94
README.md
Normal file
@ -0,0 +1,94 @@
|
||||
# BiRefNet Background Removal Service
|
||||
|
||||
GPU-accelerated background removal exposed as an HTTP API. Uses
|
||||
[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with
|
||||
[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the
|
||||
NVIDIA container runtime.
|
||||
|
||||
## Requirements
|
||||
|
||||
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
|
||||
- ~2 GB free disk for the model weights (downloaded on first run)
|
||||
|
||||
## Quick start
|
||||
|
||||
```bash
|
||||
make build # build the Docker image
|
||||
make run # start the service on :8000 (GPU)
|
||||
make logs # watch startup — first run downloads BiRefNet weights
|
||||
make test # send test.jpg, save output.png
|
||||
```
|
||||
|
||||
`make test` waits for the service `/health` endpoint before sending the
|
||||
request, so the first call may block while the model downloads and loads.
|
||||
|
||||
### Web UI
|
||||
|
||||
A minimal test page is served at the service root — open
|
||||
**http://localhost:8000/** in a browser, drop in an image, and preview the
|
||||
transparent-background result (handy when working over SSH). It calls the
|
||||
same `/predict` endpoint.
|
||||
|
||||
### Useful variations
|
||||
|
||||
```bash
|
||||
make test BG=white # composite onto a white background
|
||||
make test INPUT=photo.jpg OUTPUT=cut.png
|
||||
make test-mask # also save the raw alpha mask (mask.png)
|
||||
make help # list all targets
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
`POST /predict`
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"image": "<base64 image bytes>", // required
|
||||
"background": "alpha", // alpha|white|black|gray|green|blue|red
|
||||
"mask_blur": 0, // Gaussian blur radius on mask edges
|
||||
"return_mask": false // include the raw mask in the response
|
||||
}
|
||||
```
|
||||
|
||||
Response:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"image": "<base64 PNG>",
|
||||
"format": "png",
|
||||
"width": 3637,
|
||||
"height": 3637,
|
||||
"mask": "<base64 PNG>" // only when return_mask=true
|
||||
}
|
||||
```
|
||||
|
||||
`GET /health` returns 200 when the service is ready.
|
||||
|
||||
## Configuration (environment variables)
|
||||
|
||||
| Variable | Default | Purpose |
|
||||
|----------------------|----------------------|----------------------------------|
|
||||
| `PORT` | `8000` | HTTP port |
|
||||
| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights |
|
||||
| `BIREFNET_RESOLUTION`| `1024` | Inference resolution |
|
||||
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
|
||||
|
||||
## Local development (no Docker)
|
||||
|
||||
Requires a local CUDA-capable PyTorch environment.
|
||||
|
||||
```bash
|
||||
make dev # uv sync + run the server locally
|
||||
```
|
||||
|
||||
## Layout
|
||||
|
||||
```
|
||||
src/birefnet_service/model.py BiRefNet wrapper (load + inference)
|
||||
src/birefnet_service/server.py LitServe API + web UI route
|
||||
src/birefnet_service/static/ web UI (index.html)
|
||||
scripts/client.py stdlib-only test client
|
||||
Dockerfile / docker-compose.yml CUDA image + nvidia runtime
|
||||
Makefile build / run / test shortcuts
|
||||
```
|
||||
28
compose.yml
Normal file
28
compose.yml
Normal file
@ -0,0 +1,28 @@
|
||||
services:
|
||||
birefnet:
|
||||
build: .
|
||||
image: birefnet-service:latest
|
||||
container_name: birefnet-service
|
||||
ports:
|
||||
- "${PORT:-8000}:8000"
|
||||
environment:
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||
# Default variant/resolution; both are also selectable per request.
|
||||
- BIREFNET_MODEL=${BIREFNET_MODEL:-general}
|
||||
- BIREFNET_RESOLUTION=${BIREFNET_RESOLUTION:-1024}
|
||||
# Use the nvidia-container-runtime for GPU acceleration.
|
||||
runtime: nvidia
|
||||
volumes:
|
||||
# Persist downloaded BiRefNet weights across container restarts.
|
||||
- hf-cache:/app/hf_cache
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 180s
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
hf-cache:
|
||||
43
pyproject.toml
Normal file
43
pyproject.toml
Normal file
@ -0,0 +1,43 @@
|
||||
[project]
|
||||
name = "birefnet-service"
|
||||
version = "0.1.0"
|
||||
description = "BiRefNet background removal as a GPU-accelerated API"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12,<3.13"
|
||||
dependencies = [
|
||||
"torch==2.5.1",
|
||||
"torchvision==0.20.1",
|
||||
"transformers>=4.44,<5",
|
||||
"timm>=1.0.0",
|
||||
"einops>=0.8.0",
|
||||
"kornia>=0.7.0",
|
||||
"pillow>=10.0.0",
|
||||
"numpy>=1.26",
|
||||
"litserve>=0.2.4",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
birefnet-service = "birefnet_service.server:run"
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["ruff>=0.6.0"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/birefnet_service"]
|
||||
|
||||
# BiRefNet (torch) needs CUDA wheels; pull torch/torchvision from the PyTorch index.
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu124"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cu124" }
|
||||
torchvision = { index = "pytorch-cu124" }
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
92
scripts/client.py
Normal file
92
scripts/client.py
Normal file
@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Minimal stdlib-only client for the BiRefNet service.
|
||||
|
||||
Encodes an image, posts it to /predict, and saves the returned PNG.
|
||||
No third-party dependencies so it can run with any system Python.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
|
||||
def wait_for_health(base_url: str, timeout: float) -> None:
|
||||
deadline = time.time() + timeout
|
||||
health = f"{base_url}/health"
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
with urllib.request.urlopen(health, timeout=5) as resp:
|
||||
if resp.status == 200:
|
||||
return
|
||||
except (urllib.error.URLError, ConnectionError, OSError):
|
||||
pass
|
||||
time.sleep(2)
|
||||
sys.exit(f"server at {health} not ready after {timeout:.0f}s")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument("--url", default="http://localhost:8000", help="service base URL")
|
||||
ap.add_argument("--input", default="test.jpg", help="input image path")
|
||||
ap.add_argument("--output", default="output.png", help="output PNG path")
|
||||
ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red")
|
||||
ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite")
|
||||
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
|
||||
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
|
||||
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
|
||||
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
|
||||
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
|
||||
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
|
||||
args = ap.parse_args()
|
||||
|
||||
base_url = args.url.rstrip("/")
|
||||
wait_for_health(base_url, args.wait)
|
||||
|
||||
with open(args.input, "rb") as f:
|
||||
payload = {
|
||||
"image": base64.b64encode(f.read()).decode("ascii"),
|
||||
"background": args.background,
|
||||
"mask_blur": args.mask_blur,
|
||||
"return_mask": args.mask_output is not None,
|
||||
}
|
||||
if args.model is not None:
|
||||
payload["model"] = args.model
|
||||
if args.resolution is not None:
|
||||
payload["resolution"] = args.resolution
|
||||
if args.crop:
|
||||
payload["crop"] = True
|
||||
payload["crop_margin"] = args.crop_margin
|
||||
|
||||
req = urllib.request.Request(
|
||||
f"{base_url}/predict",
|
||||
data=json.dumps(payload).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
started = time.time()
|
||||
with urllib.request.urlopen(req, timeout=300) as resp:
|
||||
result = json.loads(resp.read())
|
||||
elapsed = time.time() - started
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(base64.b64decode(result["image"]))
|
||||
print(
|
||||
f"saved {args.output} {result['width']}x{result['height']} "
|
||||
f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
|
||||
)
|
||||
|
||||
if args.mask_output and "mask" in result:
|
||||
with open(args.mask_output, "wb") as f:
|
||||
f.write(base64.b64decode(result["mask"]))
|
||||
print(f"saved {args.mask_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
src/birefnet_service/__init__.py
Normal file
3
src/birefnet_service/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""BiRefNet background removal service."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
185
src/birefnet_service/model.py
Normal file
185
src/birefnet_service/model.py
Normal file
@ -0,0 +1,185 @@
|
||||
"""BiRefNet model wrapper for background removal.
|
||||
|
||||
Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports
|
||||
multiple model variants (lazily loaded + cached) and a tunable inference
|
||||
resolution, both selectable per request.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageFilter
|
||||
from torchvision import transforms
|
||||
from transformers import AutoModelForImageSegmentation
|
||||
|
||||
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
|
||||
# BiRefNet variants: general is fast for clean single subjects; HR is for large
|
||||
# / detailed scenes and needs a higher resolution (>=1536) to perform well.
|
||||
MODEL_ALIASES = {
|
||||
"general": "ZhengPeng7/BiRefNet",
|
||||
"HR": "ZhengPeng7/BiRefNet_HR",
|
||||
"portrait": "ZhengPeng7/BiRefNet-portrait",
|
||||
"matting": "ZhengPeng7/BiRefNet-matting",
|
||||
"lite": "ZhengPeng7/BiRefNet_lite",
|
||||
}
|
||||
|
||||
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
|
||||
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
|
||||
|
||||
# ImageNet normalization, matching BiRefNet training.
|
||||
_MEAN = [0.485, 0.456, 0.406]
|
||||
_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
BG_COLORS: dict[str, tuple[int, int, int]] = {
|
||||
"black": (0, 0, 0),
|
||||
"white": (255, 255, 255),
|
||||
"gray": (128, 128, 128),
|
||||
"green": (0, 255, 0),
|
||||
"blue": (0, 0, 255),
|
||||
"red": (255, 0, 0),
|
||||
}
|
||||
|
||||
|
||||
def resolve_repo(model: str | None) -> str:
|
||||
"""Map a friendly variant name to a repo id (pass-through for raw ids)."""
|
||||
model = model or DEFAULT_MODEL
|
||||
return MODEL_ALIASES.get(model, model)
|
||||
|
||||
|
||||
def _normalize_resolution(resolution: int | None) -> int:
|
||||
"""BiRefNet's backbone needs the input side divisible by 32."""
|
||||
res = int(resolution or DEFAULT_RESOLUTION)
|
||||
return max(256, (res // 32) * 32)
|
||||
|
||||
|
||||
class BiRefNetService:
|
||||
"""Runs BiRefNet background removal; caches loaded model variants."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
default_model: str = DEFAULT_MODEL,
|
||||
default_resolution: int = DEFAULT_RESOLUTION,
|
||||
):
|
||||
want_cuda = device != "cpu" and torch.cuda.is_available()
|
||||
if device and device not in ("auto", "cpu"):
|
||||
self.device = device
|
||||
else:
|
||||
self.device = "cuda" if want_cuda else "cpu"
|
||||
self.use_half = self.device.startswith("cuda")
|
||||
self.default_model = default_model
|
||||
self.default_resolution = default_resolution
|
||||
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self._models: dict[str, torch.nn.Module] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Preload the default variant so the worker is ready on first request.
|
||||
self._get_model(resolve_repo(default_model))
|
||||
|
||||
def _get_model(self, repo: str) -> torch.nn.Module:
|
||||
with self._lock:
|
||||
model = self._models.get(repo)
|
||||
if model is None:
|
||||
model = AutoModelForImageSegmentation.from_pretrained(
|
||||
repo, trust_remote_code=True
|
||||
)
|
||||
model.eval().to(self.device)
|
||||
if self.use_half:
|
||||
model.half()
|
||||
self._models[repo] = model
|
||||
return model
|
||||
|
||||
@torch.inference_mode()
|
||||
def infer_mask(self, image: Image.Image, repo: str, resolution: int) -> Image.Image:
|
||||
"""Return a single-channel ('L') alpha mask at the image's original size."""
|
||||
model = self._get_model(repo)
|
||||
w, h = image.size
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(
|
||||
(resolution, resolution),
|
||||
interpolation=transforms.InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(_MEAN, _STD),
|
||||
]
|
||||
)
|
||||
x = transform(image).unsqueeze(0).to(self.device)
|
||||
if self.use_half:
|
||||
x = x.half()
|
||||
pred = model(x)[-1].sigmoid().float().cpu()[0] # [1, R, R]
|
||||
mask = transforms.ToPILImage()(pred)
|
||||
return mask.resize((w, h), Image.BICUBIC)
|
||||
|
||||
def remove_background(
|
||||
self,
|
||||
image: Image.Image,
|
||||
model: str | None = None,
|
||||
resolution: int | None = None,
|
||||
background: str = "alpha",
|
||||
mask_blur: int = 0,
|
||||
crop: bool = False,
|
||||
crop_margin: float = 0.0,
|
||||
return_mask: bool = False,
|
||||
) -> dict:
|
||||
"""Run background removal.
|
||||
|
||||
model: variant alias ('general', 'HR', ...) or a raw HF repo id.
|
||||
resolution: inference resolution (rounded down to a multiple of 32).
|
||||
background: "alpha" for transparency, or a key from BG_COLORS.
|
||||
mask_blur: Gaussian blur radius applied to the mask edges.
|
||||
crop: crop the output to the foreground's bounding box.
|
||||
crop_margin: extra margin around the crop, in inches (uses image DPI).
|
||||
"""
|
||||
# DPI for inch->pixel margin conversion; default 96 if not embedded.
|
||||
dpi = image.info.get("dpi")
|
||||
dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0
|
||||
|
||||
image = image.convert("RGB")
|
||||
repo = resolve_repo(model)
|
||||
resolution = _normalize_resolution(resolution)
|
||||
|
||||
mask = self.infer_mask(image, repo, resolution)
|
||||
if mask_blur > 0:
|
||||
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
||||
|
||||
cutout = image.convert("RGBA")
|
||||
cutout.putalpha(mask)
|
||||
|
||||
background = (background or "alpha").lower()
|
||||
if background == "alpha":
|
||||
result = cutout
|
||||
else:
|
||||
if background not in BG_COLORS:
|
||||
raise ValueError(
|
||||
f"Unknown background '{background}'. "
|
||||
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
|
||||
)
|
||||
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
|
||||
result = Image.alpha_composite(bg, cutout).convert("RGB")
|
||||
|
||||
out: dict = {"image": result, "model": repo, "resolution": resolution}
|
||||
|
||||
if crop:
|
||||
# Bounding box of the foreground (mask above a low alpha threshold).
|
||||
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
|
||||
if bbox is not None:
|
||||
margin_px = round(max(0.0, crop_margin) * dpi_x)
|
||||
left = max(0, bbox[0] - margin_px)
|
||||
top = max(0, bbox[1] - margin_px)
|
||||
right = min(image.width, bbox[2] + margin_px)
|
||||
bottom = min(image.height, bbox[3] + margin_px)
|
||||
box = (left, top, right, bottom)
|
||||
result = result.crop(box)
|
||||
mask = mask.crop(box)
|
||||
out["image"] = result
|
||||
out["crop_box"] = box
|
||||
out["dpi"] = round(dpi_x, 1)
|
||||
|
||||
if return_mask:
|
||||
out["mask"] = mask
|
||||
return out
|
||||
113
src/birefnet_service/server.py
Normal file
113
src/birefnet_service/server.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""LitServe API exposing BiRefNet background removal.
|
||||
|
||||
Endpoint: POST /predict
|
||||
Request JSON:
|
||||
{
|
||||
"image": "<base64-encoded image bytes>", (required)
|
||||
"model": "general" | "HR" | "portrait" | ..., (default "general")
|
||||
"resolution": 1024, (default 1024)
|
||||
"background": "alpha" | "white" | "black" | ..., (default "alpha")
|
||||
"mask_blur": 0, (default 0)
|
||||
"return_mask": false (default false)
|
||||
}
|
||||
Response JSON:
|
||||
{
|
||||
"image": "<base64 PNG>",
|
||||
"format": "png",
|
||||
"width": int,
|
||||
"height": int,
|
||||
"model": "<repo id used>",
|
||||
"resolution": int,
|
||||
"mask": "<base64 PNG>" (only when return_mask=true)
|
||||
}
|
||||
|
||||
A minimal web UI is served at GET / (same origin as /predict).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import litserve as ls
|
||||
from fastapi.responses import HTMLResponse
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from .model import BiRefNetService
|
||||
|
||||
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _b64_to_image(data: str) -> Image.Image:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(data)))
|
||||
# Honor EXIF orientation so portrait photos aren't processed/returned sideways.
|
||||
return ImageOps.exif_transpose(image)
|
||||
|
||||
|
||||
def _image_to_b64(image: Image.Image) -> str:
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format="PNG")
|
||||
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
|
||||
|
||||
class BiRefNetAPI(ls.LitAPI):
|
||||
def setup(self, device: str) -> None:
|
||||
self.service = BiRefNetService(device=device)
|
||||
|
||||
def decode_request(self, request: dict) -> dict:
|
||||
if "image" not in request:
|
||||
raise ValueError("Request must include a base64 'image' field.")
|
||||
return {
|
||||
"image": _b64_to_image(request["image"]),
|
||||
"model": request.get("model"),
|
||||
"resolution": request.get("resolution"),
|
||||
"background": request.get("background", "alpha"),
|
||||
"mask_blur": int(request.get("mask_blur", 0)),
|
||||
"crop": bool(request.get("crop", False)),
|
||||
"crop_margin": float(request.get("crop_margin", 0.0)),
|
||||
"return_mask": bool(request.get("return_mask", False)),
|
||||
}
|
||||
|
||||
def predict(self, inputs: dict) -> dict:
|
||||
return self.service.remove_background(**inputs)
|
||||
|
||||
def encode_response(self, output: dict) -> dict:
|
||||
image: Image.Image = output["image"]
|
||||
response = {
|
||||
"image": _image_to_b64(image),
|
||||
"format": "png",
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"model": output["model"],
|
||||
"resolution": output["resolution"],
|
||||
}
|
||||
if output.get("mask") is not None:
|
||||
response["mask"] = _image_to_b64(output["mask"])
|
||||
return response
|
||||
|
||||
|
||||
def run() -> None:
|
||||
server = ls.LitServer(
|
||||
BiRefNetAPI(),
|
||||
accelerator="auto",
|
||||
devices=1,
|
||||
timeout=int(os.getenv("REQUEST_TIMEOUT", "120")),
|
||||
)
|
||||
|
||||
# LitServe registers its own "/" route ("litserve running"); drop it so
|
||||
# our UI can own the root path. Served same-origin as /predict (no CORS).
|
||||
server.app.router.routes = [
|
||||
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
|
||||
]
|
||||
|
||||
@server.app.get("/", response_class=HTMLResponse)
|
||||
def index() -> str:
|
||||
return _UI_HTML
|
||||
|
||||
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
303
src/birefnet_service/static/index.html
Normal file
303
src/birefnet_service/static/index.html
Normal file
@ -0,0 +1,303 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>BiRefNet — Background Removal</title>
|
||||
<style>
|
||||
:root { color-scheme: dark; }
|
||||
* { box-sizing: border-box; }
|
||||
body {
|
||||
margin: 0; font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
|
||||
background: #15171c; color: #e8e8ea; padding: 24px;
|
||||
}
|
||||
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
|
||||
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 20px; }
|
||||
.wrap { max-width: 1100px; margin: 0 auto; }
|
||||
#drop {
|
||||
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
|
||||
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
|
||||
}
|
||||
#drop.over { border-color: #5b8cff; background: #1c2230; }
|
||||
#drop p { margin: 6px 0; color: #8a8f99; }
|
||||
.controls { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
||||
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
|
||||
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
|
||||
select, input[type=number] {
|
||||
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
|
||||
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
|
||||
}
|
||||
input[type=number] { width: 78px; }
|
||||
input[type=number]:disabled { opacity: .45; }
|
||||
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
|
||||
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
|
||||
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
|
||||
button {
|
||||
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
|
||||
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
|
||||
}
|
||||
button:disabled { background: #3a3f4b; cursor: not-allowed; }
|
||||
button.ghost { background: #2a2f3a; }
|
||||
.go-row { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
||||
.status { color: #8a8f99; font-size: .85rem; }
|
||||
.status.err { color: #ff6b6b; }
|
||||
.hint { color: #6b7280; font-size: .78rem; margin-top: -8px; }
|
||||
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
|
||||
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
|
||||
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
|
||||
text-transform: uppercase; letter-spacing: .05em; }
|
||||
.imgbox {
|
||||
min-height: 260px; display: flex; align-items: center; justify-content: center;
|
||||
border-radius: 8px; overflow: hidden;
|
||||
}
|
||||
.checker {
|
||||
background-image:
|
||||
linear-gradient(45deg, #2a2f3a 25%, transparent 25%),
|
||||
linear-gradient(-45deg, #2a2f3a 25%, transparent 25%),
|
||||
linear-gradient(45deg, transparent 75%, #2a2f3a 75%),
|
||||
linear-gradient(-45deg, transparent 75%, #2a2f3a 75%);
|
||||
background-size: 22px 22px;
|
||||
background-position: 0 0, 0 11px, 11px -11px, -11px 0;
|
||||
background-color: #20242d;
|
||||
}
|
||||
.imgbox img { max-width: 100%; max-height: 70vh; display: block; }
|
||||
.imgbox img[src] { cursor: zoom-in; }
|
||||
@media (max-width: 720px) { .panels { grid-template-columns: 1fr; } }
|
||||
|
||||
/* lightbox */
|
||||
.lightbox { position: fixed; inset: 0; z-index: 100; background: rgba(12,13,17,.97);
|
||||
display: flex; align-items: center; justify-content: center; }
|
||||
.lightbox[hidden] { display: none; }
|
||||
.lb-stage { width: 100vw; height: 100vh; overflow: hidden;
|
||||
display: flex; align-items: center; justify-content: center; }
|
||||
.lb-stage img { max-width: 100vw; max-height: 100vh; transform-origin: 0 0;
|
||||
cursor: grab; user-select: none; -webkit-user-drag: none; will-change: transform; }
|
||||
.lb-stage.grabbing img { cursor: grabbing; }
|
||||
.lb-bar { position: fixed; top: 0; left: 0; right: 0; padding: 14px 20px;
|
||||
z-index: 2; display: flex; justify-content: space-between; align-items: center;
|
||||
color: #8a8f99; font-size: .8rem; pointer-events: none; }
|
||||
.lb-close { pointer-events: auto; background: #2a2f3a; color: #e8e8ea;
|
||||
border: 1px solid #3a3f4b; border-radius: 8px; width: 34px; height: 34px;
|
||||
font-size: 1rem; line-height: 1; padding: 0; cursor: pointer;
|
||||
display: flex; align-items: center; justify-content: center; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="wrap">
|
||||
<h1>BiRefNet — Background Removal</h1>
|
||||
<div class="sub">Drop an image to get a transparent-background PNG.</div>
|
||||
|
||||
<div id="drop">
|
||||
<p><strong>Drop an image here</strong> or click to choose</p>
|
||||
<p id="fname">No file selected</p>
|
||||
<input id="file" type="file" accept="image/*" hidden />
|
||||
</div>
|
||||
|
||||
<div class="controls">
|
||||
<label class="field">Model
|
||||
<select id="model">
|
||||
<option value="general">general — clean single subjects (fast)</option>
|
||||
<option value="HR" selected>HR — large / detailed scenes</option>
|
||||
<option value="portrait">portrait — people</option>
|
||||
<option value="matting">matting — soft edges / hair</option>
|
||||
<option value="lite">lite — fastest</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="field">Resolution
|
||||
<select id="resolution">
|
||||
<option value="1024">1024</option>
|
||||
<option value="1536">1536</option>
|
||||
<option value="2048">2048</option>
|
||||
<option value="2560" selected>2560</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
|
||||
<label class="field">Margin (in)
|
||||
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
|
||||
</label>
|
||||
</div>
|
||||
<div class="hint">Tip: large or busy scenes segment best with <strong>HR</strong> at <strong>2048</strong>.
|
||||
The <em>general</em> model expects a clear single subject at 1024.</div>
|
||||
|
||||
<div class="go-row">
|
||||
<button id="go" disabled>Remove background</button>
|
||||
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
|
||||
<span id="status" class="status"></span>
|
||||
</div>
|
||||
|
||||
<div class="panels">
|
||||
<div class="panel">
|
||||
<h2>Original</h2>
|
||||
<div class="imgbox"><img id="src" alt="" /></div>
|
||||
</div>
|
||||
<div class="panel">
|
||||
<h2>Result</h2>
|
||||
<div class="imgbox checker"><img id="out" alt="" /></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="lightbox" class="lightbox" hidden>
|
||||
<div class="lb-bar">
|
||||
<span>scroll to zoom · drag to pan · double-click resets · Esc closes</span>
|
||||
<button class="lb-close" id="lbClose" title="Close">✕</button>
|
||||
</div>
|
||||
<div class="lb-stage" id="lbStage"><img id="lbImg" alt="" /></div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const drop = document.getElementById('drop');
|
||||
const fileInput = document.getElementById('file');
|
||||
const fname = document.getElementById('fname');
|
||||
const go = document.getElementById('go');
|
||||
const dl = document.getElementById('dl');
|
||||
const dlbtn = document.getElementById('dlbtn');
|
||||
const statusEl = document.getElementById('status');
|
||||
const srcImg = document.getElementById('src');
|
||||
const outImg = document.getElementById('out');
|
||||
const modelSel = document.getElementById('model');
|
||||
const resSel = document.getElementById('resolution');
|
||||
const cropChk = document.getElementById('crop');
|
||||
const cropMargin = document.getElementById('cropMargin');
|
||||
|
||||
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
|
||||
|
||||
let selectedFile = null;
|
||||
|
||||
function setStatus(msg, isErr) {
|
||||
statusEl.textContent = msg;
|
||||
statusEl.className = 'status' + (isErr ? ' err' : '');
|
||||
}
|
||||
|
||||
function pickFile(file) {
|
||||
if (!file || !file.type.startsWith('image/')) {
|
||||
setStatus('Please choose an image file.', true);
|
||||
return;
|
||||
}
|
||||
selectedFile = file;
|
||||
fname.textContent = file.name + ' (' + Math.round(file.size / 1024) + ' KB)';
|
||||
srcImg.src = URL.createObjectURL(file);
|
||||
outImg.removeAttribute('src');
|
||||
dlbtn.disabled = true;
|
||||
go.disabled = false;
|
||||
setStatus('');
|
||||
}
|
||||
|
||||
drop.addEventListener('click', () => fileInput.click());
|
||||
fileInput.addEventListener('change', e => pickFile(e.target.files[0]));
|
||||
['dragenter', 'dragover'].forEach(ev =>
|
||||
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.add('over'); }));
|
||||
['dragleave', 'drop'].forEach(ev =>
|
||||
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.remove('over'); }));
|
||||
drop.addEventListener('drop', e => pickFile(e.dataTransfer.files[0]));
|
||||
|
||||
function fileToBase64(file) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const r = new FileReader();
|
||||
r.onload = () => resolve(r.result.split(',')[1]); // strip data URL prefix
|
||||
r.onerror = reject;
|
||||
r.readAsDataURL(file);
|
||||
});
|
||||
}
|
||||
|
||||
// --- lightbox: click to inspect, scroll to zoom, drag to pan ---
|
||||
const lightbox = document.getElementById('lightbox');
|
||||
const lbStage = document.getElementById('lbStage');
|
||||
const lbImg = document.getElementById('lbImg');
|
||||
const lbClose = document.getElementById('lbClose');
|
||||
let lbScale = 1, lbTx = 0, lbTy = 0, lbDrag = null;
|
||||
|
||||
function lbApply() {
|
||||
lbImg.style.transform = `translate(${lbTx}px, ${lbTy}px) scale(${lbScale})`;
|
||||
}
|
||||
function lbReset() { lbScale = 1; lbTx = 0; lbTy = 0; lbApply(); }
|
||||
|
||||
function openLightbox(src, isResult) {
|
||||
if (!src) return;
|
||||
lbImg.src = src;
|
||||
lbImg.classList.toggle('checker', !!isResult);
|
||||
lbReset();
|
||||
lightbox.hidden = false;
|
||||
}
|
||||
function closeLightbox() { lightbox.hidden = true; lbImg.removeAttribute('src'); }
|
||||
|
||||
srcImg.addEventListener('click', () => openLightbox(srcImg.getAttribute('src'), false));
|
||||
outImg.addEventListener('click', () => openLightbox(outImg.getAttribute('src'), true));
|
||||
lbClose.addEventListener('click', closeLightbox);
|
||||
lightbox.addEventListener('mousedown', e => {
|
||||
if (e.target === lightbox || e.target === lbStage) closeLightbox();
|
||||
});
|
||||
document.addEventListener('keydown', e => {
|
||||
if (e.key === 'Escape' && !lightbox.hidden) closeLightbox();
|
||||
});
|
||||
|
||||
lbStage.addEventListener('wheel', e => {
|
||||
e.preventDefault();
|
||||
const rect = lbImg.getBoundingClientRect();
|
||||
const cx = e.clientX - rect.left, cy = e.clientY - rect.top;
|
||||
const factor = e.deltaY < 0 ? 1.2 : 1 / 1.2;
|
||||
const newScale = Math.min(8, Math.max(1, lbScale * factor));
|
||||
const ratio = newScale / lbScale;
|
||||
lbTx -= cx * (ratio - 1);
|
||||
lbTy -= cy * (ratio - 1);
|
||||
lbScale = newScale;
|
||||
if (lbScale === 1) { lbTx = 0; lbTy = 0; }
|
||||
lbApply();
|
||||
}, { passive: false });
|
||||
|
||||
lbImg.addEventListener('mousedown', e => {
|
||||
e.preventDefault();
|
||||
lbDrag = { x: e.clientX, y: e.clientY, tx: lbTx, ty: lbTy };
|
||||
lbStage.classList.add('grabbing');
|
||||
});
|
||||
window.addEventListener('mousemove', e => {
|
||||
if (!lbDrag) return;
|
||||
lbTx = lbDrag.tx + (e.clientX - lbDrag.x);
|
||||
lbTy = lbDrag.ty + (e.clientY - lbDrag.y);
|
||||
lbApply();
|
||||
});
|
||||
window.addEventListener('mouseup', () => {
|
||||
lbDrag = null;
|
||||
lbStage.classList.remove('grabbing');
|
||||
});
|
||||
lbImg.addEventListener('dblclick', e => { e.preventDefault(); lbReset(); });
|
||||
|
||||
go.addEventListener('click', async () => {
|
||||
if (!selectedFile) return;
|
||||
go.disabled = true;
|
||||
dlbtn.disabled = true;
|
||||
setStatus('Processing… (first use of a model downloads its weights)');
|
||||
const t0 = performance.now();
|
||||
try {
|
||||
const b64 = await fileToBase64(selectedFile);
|
||||
const resp = await fetch('/predict', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
image: b64,
|
||||
background: 'alpha',
|
||||
model: modelSel.value,
|
||||
resolution: parseInt(resSel.value, 10),
|
||||
crop: cropChk.checked,
|
||||
crop_margin: parseFloat(cropMargin.value) || 0,
|
||||
}),
|
||||
});
|
||||
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
|
||||
const data = await resp.json();
|
||||
const dataUrl = 'data:image/png;base64,' + data.image;
|
||||
outImg.src = dataUrl;
|
||||
dl.href = dataUrl;
|
||||
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
|
||||
dlbtn.disabled = false;
|
||||
const secs = ((performance.now() - t0) / 1000).toFixed(1);
|
||||
setStatus('Done — ' + data.width + '×' + data.height + ' · ' + data.model +
|
||||
' @ ' + data.resolution + ' · ' + secs + 's');
|
||||
} catch (err) {
|
||||
setStatus(err.message || String(err), true);
|
||||
} finally {
|
||||
go.disabled = false;
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
Loading…
Reference in New Issue
Block a user