commit 96d16fc65439245ba2ba35c2fc02f07c1287e8e9 Author: Michael Pilosov Date: Sat May 16 14:32:26 2026 -0600 mvp diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..fc1db4a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,13 @@ +.git +.gitignore +.venv +__pycache__ +*.pyc +hf_cache +*.jpg +*.jpeg +*.png +output* +.python-version +Makefile +docker-compose.yml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5292b5c --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.venv/ +__pycache__/ +*.pyc +hf_cache/ +output.png +output*.png +mask*.png +*.jpg diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..e4fba21 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..56b10d1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,38 @@ +# BiRefNet background removal service — CUDA 12.4 runtime image. +FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + UV_PYTHON_INSTALL_DIR=/opt/python \ + UV_PROJECT_ENVIRONMENT=/app/.venv \ + UV_COMPILE_BYTECODE=1 \ + UV_LINK_MODE=copy \ + HF_HOME=/app/hf_cache \ + PORT=8000 + +# uv: fast, reproducible Python + dependency management. +COPY --from=ghcr.io/astral-sh/uv:0.9 /uv /uvx /bin/ + +RUN apt-get update \ + && apt-get install -y --no-install-recommends ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install Python + dependencies first so this layer is cached across code changes. +# The BuildKit cache mount keeps the uv download cache warm across rebuilds. +COPY pyproject.toml ./ +RUN --mount=type=cache,target=/root/.cache/uv \ + uv python install 3.12 \ + && uv sync --no-install-project --no-dev + +# Application code. +COPY src ./src +COPY README.md ./ +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --no-dev + +ENV PATH="/app/.venv/bin:${PATH}" + +EXPOSE 8000 +CMD ["birefnet-service"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ba636d1 --- /dev/null +++ b/Makefile @@ -0,0 +1,58 @@ +# BiRefNet background removal service — CLI shortcuts. +# Override defaults inline, e.g.: make test BG=white INPUT=photo.jpg + +COMPOSE ?= docker compose +PYTHON ?= python3 +PORT ?= 8000 +INPUT ?= test.jpg +OUTPUT ?= output.png +BG ?= alpha +BLUR ?= 0 + +.DEFAULT_GOAL := help + +.PHONY: help build run up stop down logs log ps test test-mask dev sync shell clean fmt + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## ' $(MAKEFILE_LIST) \ + | awk 'BEGIN{FS=":.*?## "}{printf " \033[36m%-12s\033[0m %s\n", $$1, $$2}' + +build: ## Build the Docker image + $(COMPOSE) build + +run up: ## Start the service (GPU) in the background + $(COMPOSE) up -d + +stop down: ## Stop and remove the service container + $(COMPOSE) down + +logs log: ## Follow service logs + $(COMPOSE) logs -f + +ps: ## Show service status + $(COMPOSE) ps + +test: ## Send INPUT to the running service, save OUTPUT + $(PYTHON) scripts/client.py --url http://localhost:$(PORT) \ + --input $(INPUT) --output $(OUTPUT) --background $(BG) --mask-blur $(BLUR) + +test-mask: ## Like 'test' but also save the raw mask (mask.png) + $(PYTHON) scripts/client.py --url http://localhost:$(PORT) \ + --input $(INPUT) --output $(OUTPUT) --background $(BG) \ + --mask-blur $(BLUR) --mask-output mask.png + +sync: ## Install dependencies locally with uv + uv sync + +dev: sync ## Run the service locally (no Docker; needs local CUDA) + uv run birefnet-service + +shell: ## Open a shell inside a fresh container + $(COMPOSE) run --rm --entrypoint bash birefnet + +fmt: ## Format code with ruff + uv run ruff format src scripts + +clean: ## Stop the service and remove build artifacts + -$(COMPOSE) down + rm -f $(OUTPUT) mask.png diff --git a/README.md b/README.md new file mode 100644 index 0000000..093a6c4 --- /dev/null +++ b/README.md @@ -0,0 +1,94 @@ +# BiRefNet Background Removal Service + +GPU-accelerated background removal exposed as an HTTP API. Uses +[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with +[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the +NVIDIA container runtime. + +## Requirements + +- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime +- ~2 GB free disk for the model weights (downloaded on first run) + +## Quick start + +```bash +make build # build the Docker image +make run # start the service on :8000 (GPU) +make logs # watch startup — first run downloads BiRefNet weights +make test # send test.jpg, save output.png +``` + +`make test` waits for the service `/health` endpoint before sending the +request, so the first call may block while the model downloads and loads. + +### Web UI + +A minimal test page is served at the service root — open +**http://localhost:8000/** in a browser, drop in an image, and preview the +transparent-background result (handy when working over SSH). It calls the +same `/predict` endpoint. + +### Useful variations + +```bash +make test BG=white # composite onto a white background +make test INPUT=photo.jpg OUTPUT=cut.png +make test-mask # also save the raw alpha mask (mask.png) +make help # list all targets +``` + +## API + +`POST /predict` + +```jsonc +{ + "image": "", // required + "background": "alpha", // alpha|white|black|gray|green|blue|red + "mask_blur": 0, // Gaussian blur radius on mask edges + "return_mask": false // include the raw mask in the response +} +``` + +Response: + +```jsonc +{ + "image": "", + "format": "png", + "width": 3637, + "height": 3637, + "mask": "" // only when return_mask=true +} +``` + +`GET /health` returns 200 when the service is ready. + +## Configuration (environment variables) + +| Variable | Default | Purpose | +|----------------------|----------------------|----------------------------------| +| `PORT` | `8000` | HTTP port | +| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights | +| `BIREFNET_RESOLUTION`| `1024` | Inference resolution | +| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) | + +## Local development (no Docker) + +Requires a local CUDA-capable PyTorch environment. + +```bash +make dev # uv sync + run the server locally +``` + +## Layout + +``` +src/birefnet_service/model.py BiRefNet wrapper (load + inference) +src/birefnet_service/server.py LitServe API + web UI route +src/birefnet_service/static/ web UI (index.html) +scripts/client.py stdlib-only test client +Dockerfile / docker-compose.yml CUDA image + nvidia runtime +Makefile build / run / test shortcuts +``` diff --git a/compose.yml b/compose.yml new file mode 100644 index 0000000..402dc0f --- /dev/null +++ b/compose.yml @@ -0,0 +1,28 @@ +services: + birefnet: + build: . + image: birefnet-service:latest + container_name: birefnet-service + ports: + - "${PORT:-8000}:8000" + environment: + - NVIDIA_VISIBLE_DEVICES=all + - NVIDIA_DRIVER_CAPABILITIES=compute,utility + # Default variant/resolution; both are also selectable per request. + - BIREFNET_MODEL=${BIREFNET_MODEL:-general} + - BIREFNET_RESOLUTION=${BIREFNET_RESOLUTION:-1024} + # Use the nvidia-container-runtime for GPU acceleration. + runtime: nvidia + volumes: + # Persist downloaded BiRefNet weights across container restarts. + - hf-cache:/app/hf_cache + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] + interval: 15s + timeout: 5s + retries: 30 + start_period: 180s + restart: unless-stopped + +volumes: + hf-cache: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3adf538 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[project] +name = "birefnet-service" +version = "0.1.0" +description = "BiRefNet background removal as a GPU-accelerated API" +readme = "README.md" +requires-python = ">=3.12,<3.13" +dependencies = [ + "torch==2.5.1", + "torchvision==0.20.1", + "transformers>=4.44,<5", + "timm>=1.0.0", + "einops>=0.8.0", + "kornia>=0.7.0", + "pillow>=10.0.0", + "numpy>=1.26", + "litserve>=0.2.4", +] + +[project.scripts] +birefnet-service = "birefnet_service.server:run" + +[dependency-groups] +dev = ["ruff>=0.6.0"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/birefnet_service"] + +# BiRefNet (torch) needs CUDA wheels; pull torch/torchvision from the PyTorch index. +[[tool.uv.index]] +name = "pytorch-cu124" +url = "https://download.pytorch.org/whl/cu124" +explicit = true + +[tool.uv.sources] +torch = { index = "pytorch-cu124" } +torchvision = { index = "pytorch-cu124" } + +[tool.ruff] +line-length = 100 diff --git a/scripts/client.py b/scripts/client.py new file mode 100644 index 0000000..3a6772b --- /dev/null +++ b/scripts/client.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +"""Minimal stdlib-only client for the BiRefNet service. + +Encodes an image, posts it to /predict, and saves the returned PNG. +No third-party dependencies so it can run with any system Python. +""" + +from __future__ import annotations + +import argparse +import base64 +import json +import sys +import time +import urllib.error +import urllib.request + + +def wait_for_health(base_url: str, timeout: float) -> None: + deadline = time.time() + timeout + health = f"{base_url}/health" + while time.time() < deadline: + try: + with urllib.request.urlopen(health, timeout=5) as resp: + if resp.status == 200: + return + except (urllib.error.URLError, ConnectionError, OSError): + pass + time.sleep(2) + sys.exit(f"server at {health} not ready after {timeout:.0f}s") + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--url", default="http://localhost:8000", help="service base URL") + ap.add_argument("--input", default="test.jpg", help="input image path") + ap.add_argument("--output", default="output.png", help="output PNG path") + ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red") + ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite") + ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048") + ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box") + ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches") + ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges") + ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path") + ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health") + args = ap.parse_args() + + base_url = args.url.rstrip("/") + wait_for_health(base_url, args.wait) + + with open(args.input, "rb") as f: + payload = { + "image": base64.b64encode(f.read()).decode("ascii"), + "background": args.background, + "mask_blur": args.mask_blur, + "return_mask": args.mask_output is not None, + } + if args.model is not None: + payload["model"] = args.model + if args.resolution is not None: + payload["resolution"] = args.resolution + if args.crop: + payload["crop"] = True + payload["crop_margin"] = args.crop_margin + + req = urllib.request.Request( + f"{base_url}/predict", + data=json.dumps(payload).encode(), + headers={"Content-Type": "application/json"}, + method="POST", + ) + + started = time.time() + with urllib.request.urlopen(req, timeout=300) as resp: + result = json.loads(resp.read()) + elapsed = time.time() - started + + with open(args.output, "wb") as f: + f.write(base64.b64decode(result["image"])) + print( + f"saved {args.output} {result['width']}x{result['height']} " + f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)" + ) + + if args.mask_output and "mask" in result: + with open(args.mask_output, "wb") as f: + f.write(base64.b64decode(result["mask"])) + print(f"saved {args.mask_output}") + + +if __name__ == "__main__": + main() diff --git a/src/birefnet_service/__init__.py b/src/birefnet_service/__init__.py new file mode 100644 index 0000000..a26fa96 --- /dev/null +++ b/src/birefnet_service/__init__.py @@ -0,0 +1,3 @@ +"""BiRefNet background removal service.""" + +__version__ = "0.1.0" diff --git a/src/birefnet_service/model.py b/src/birefnet_service/model.py new file mode 100644 index 0000000..b14ebdb --- /dev/null +++ b/src/birefnet_service/model.py @@ -0,0 +1,185 @@ +"""BiRefNet model wrapper for background removal. + +Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports +multiple model variants (lazily loaded + cached) and a tunable inference +resolution, both selectable per request. +""" + +from __future__ import annotations + +import os +import threading + +import torch +from PIL import Image, ImageFilter +from torchvision import transforms +from transformers import AutoModelForImageSegmentation + +# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed. +# BiRefNet variants: general is fast for clean single subjects; HR is for large +# / detailed scenes and needs a higher resolution (>=1536) to perform well. +MODEL_ALIASES = { + "general": "ZhengPeng7/BiRefNet", + "HR": "ZhengPeng7/BiRefNet_HR", + "portrait": "ZhengPeng7/BiRefNet-portrait", + "matting": "ZhengPeng7/BiRefNet-matting", + "lite": "ZhengPeng7/BiRefNet_lite", +} + +DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general") +DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024")) + +# ImageNet normalization, matching BiRefNet training. +_MEAN = [0.485, 0.456, 0.406] +_STD = [0.229, 0.224, 0.225] + +BG_COLORS: dict[str, tuple[int, int, int]] = { + "black": (0, 0, 0), + "white": (255, 255, 255), + "gray": (128, 128, 128), + "green": (0, 255, 0), + "blue": (0, 0, 255), + "red": (255, 0, 0), +} + + +def resolve_repo(model: str | None) -> str: + """Map a friendly variant name to a repo id (pass-through for raw ids).""" + model = model or DEFAULT_MODEL + return MODEL_ALIASES.get(model, model) + + +def _normalize_resolution(resolution: int | None) -> int: + """BiRefNet's backbone needs the input side divisible by 32.""" + res = int(resolution or DEFAULT_RESOLUTION) + return max(256, (res // 32) * 32) + + +class BiRefNetService: + """Runs BiRefNet background removal; caches loaded model variants.""" + + def __init__( + self, + device: str | None = None, + default_model: str = DEFAULT_MODEL, + default_resolution: int = DEFAULT_RESOLUTION, + ): + want_cuda = device != "cpu" and torch.cuda.is_available() + if device and device not in ("auto", "cpu"): + self.device = device + else: + self.device = "cuda" if want_cuda else "cpu" + self.use_half = self.device.startswith("cuda") + self.default_model = default_model + self.default_resolution = default_resolution + + torch.set_float32_matmul_precision("high") + self._models: dict[str, torch.nn.Module] = {} + self._lock = threading.Lock() + + # Preload the default variant so the worker is ready on first request. + self._get_model(resolve_repo(default_model)) + + def _get_model(self, repo: str) -> torch.nn.Module: + with self._lock: + model = self._models.get(repo) + if model is None: + model = AutoModelForImageSegmentation.from_pretrained( + repo, trust_remote_code=True + ) + model.eval().to(self.device) + if self.use_half: + model.half() + self._models[repo] = model + return model + + @torch.inference_mode() + def infer_mask(self, image: Image.Image, repo: str, resolution: int) -> Image.Image: + """Return a single-channel ('L') alpha mask at the image's original size.""" + model = self._get_model(repo) + w, h = image.size + transform = transforms.Compose( + [ + transforms.Resize( + (resolution, resolution), + interpolation=transforms.InterpolationMode.BICUBIC, + ), + transforms.ToTensor(), + transforms.Normalize(_MEAN, _STD), + ] + ) + x = transform(image).unsqueeze(0).to(self.device) + if self.use_half: + x = x.half() + pred = model(x)[-1].sigmoid().float().cpu()[0] # [1, R, R] + mask = transforms.ToPILImage()(pred) + return mask.resize((w, h), Image.BICUBIC) + + def remove_background( + self, + image: Image.Image, + model: str | None = None, + resolution: int | None = None, + background: str = "alpha", + mask_blur: int = 0, + crop: bool = False, + crop_margin: float = 0.0, + return_mask: bool = False, + ) -> dict: + """Run background removal. + + model: variant alias ('general', 'HR', ...) or a raw HF repo id. + resolution: inference resolution (rounded down to a multiple of 32). + background: "alpha" for transparency, or a key from BG_COLORS. + mask_blur: Gaussian blur radius applied to the mask edges. + crop: crop the output to the foreground's bounding box. + crop_margin: extra margin around the crop, in inches (uses image DPI). + """ + # DPI for inch->pixel margin conversion; default 96 if not embedded. + dpi = image.info.get("dpi") + dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0 + + image = image.convert("RGB") + repo = resolve_repo(model) + resolution = _normalize_resolution(resolution) + + mask = self.infer_mask(image, repo, resolution) + if mask_blur > 0: + mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur)) + + cutout = image.convert("RGBA") + cutout.putalpha(mask) + + background = (background or "alpha").lower() + if background == "alpha": + result = cutout + else: + if background not in BG_COLORS: + raise ValueError( + f"Unknown background '{background}'. " + f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}" + ) + bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255)) + result = Image.alpha_composite(bg, cutout).convert("RGB") + + out: dict = {"image": result, "model": repo, "resolution": resolution} + + if crop: + # Bounding box of the foreground (mask above a low alpha threshold). + bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox() + if bbox is not None: + margin_px = round(max(0.0, crop_margin) * dpi_x) + left = max(0, bbox[0] - margin_px) + top = max(0, bbox[1] - margin_px) + right = min(image.width, bbox[2] + margin_px) + bottom = min(image.height, bbox[3] + margin_px) + box = (left, top, right, bottom) + result = result.crop(box) + mask = mask.crop(box) + out["image"] = result + out["crop_box"] = box + out["dpi"] = round(dpi_x, 1) + + if return_mask: + out["mask"] = mask + return out diff --git a/src/birefnet_service/server.py b/src/birefnet_service/server.py new file mode 100644 index 0000000..a437ef2 --- /dev/null +++ b/src/birefnet_service/server.py @@ -0,0 +1,113 @@ +"""LitServe API exposing BiRefNet background removal. + +Endpoint: POST /predict +Request JSON: + { + "image": "", (required) + "model": "general" | "HR" | "portrait" | ..., (default "general") + "resolution": 1024, (default 1024) + "background": "alpha" | "white" | "black" | ..., (default "alpha") + "mask_blur": 0, (default 0) + "return_mask": false (default false) + } +Response JSON: + { + "image": "", + "format": "png", + "width": int, + "height": int, + "model": "", + "resolution": int, + "mask": "" (only when return_mask=true) + } + +A minimal web UI is served at GET / (same origin as /predict). +""" + +from __future__ import annotations + +import base64 +import io +import os +from pathlib import Path + +import litserve as ls +from fastapi.responses import HTMLResponse +from PIL import Image, ImageOps + +from .model import BiRefNetService + +_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8") + + +def _b64_to_image(data: str) -> Image.Image: + image = Image.open(io.BytesIO(base64.b64decode(data))) + # Honor EXIF orientation so portrait photos aren't processed/returned sideways. + return ImageOps.exif_transpose(image) + + +def _image_to_b64(image: Image.Image) -> str: + buf = io.BytesIO() + image.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + +class BiRefNetAPI(ls.LitAPI): + def setup(self, device: str) -> None: + self.service = BiRefNetService(device=device) + + def decode_request(self, request: dict) -> dict: + if "image" not in request: + raise ValueError("Request must include a base64 'image' field.") + return { + "image": _b64_to_image(request["image"]), + "model": request.get("model"), + "resolution": request.get("resolution"), + "background": request.get("background", "alpha"), + "mask_blur": int(request.get("mask_blur", 0)), + "crop": bool(request.get("crop", False)), + "crop_margin": float(request.get("crop_margin", 0.0)), + "return_mask": bool(request.get("return_mask", False)), + } + + def predict(self, inputs: dict) -> dict: + return self.service.remove_background(**inputs) + + def encode_response(self, output: dict) -> dict: + image: Image.Image = output["image"] + response = { + "image": _image_to_b64(image), + "format": "png", + "width": image.width, + "height": image.height, + "model": output["model"], + "resolution": output["resolution"], + } + if output.get("mask") is not None: + response["mask"] = _image_to_b64(output["mask"]) + return response + + +def run() -> None: + server = ls.LitServer( + BiRefNetAPI(), + accelerator="auto", + devices=1, + timeout=int(os.getenv("REQUEST_TIMEOUT", "120")), + ) + + # LitServe registers its own "/" route ("litserve running"); drop it so + # our UI can own the root path. Served same-origin as /predict (no CORS). + server.app.router.routes = [ + r for r in server.app.router.routes if getattr(r, "path", None) != "/" + ] + + @server.app.get("/", response_class=HTMLResponse) + def index() -> str: + return _UI_HTML + + server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False) + + +if __name__ == "__main__": + run() diff --git a/src/birefnet_service/static/index.html b/src/birefnet_service/static/index.html new file mode 100644 index 0000000..4fe2c51 --- /dev/null +++ b/src/birefnet_service/static/index.html @@ -0,0 +1,303 @@ + + + + + +BiRefNet — Background Removal + + + +
+

BiRefNet — Background Removal

+
Drop an image to get a transparent-background PNG.
+ +
+

Drop an image here or click to choose

+

No file selected

+ +
+ +
+ + + + +
+
Tip: large or busy scenes segment best with HR at 2048. + The general model expects a clear single subject at 1024.
+ +
+ + + +
+ +
+
+

Original

+
+
+
+

Result

+
+
+
+
+ + + + + +