rmbg option + some other post-processing

This commit is contained in:
Michael Pilosov 2026-05-16 17:04:56 -06:00
parent 96d16fc654
commit 4efb4b8a2f
7 changed files with 529 additions and 144 deletions

View File

@ -1,78 +1,98 @@
# BiRefNet Background Removal Service
GPU-accelerated background removal exposed as an HTTP API. Uses
[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with
[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the
NVIDIA container runtime.
GPU-accelerated background removal as an HTTP API. Two pipelines:
- **Auto** — [BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) /
[RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) salient-object matting.
- **Prompt** — [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-tiny)
+ [SAM](https://huggingface.co/facebook/sam-vit-base): segment whatever a text
prompt describes.
Served with [LitServe](https://github.com/Lightning-AI/LitServe), packaged for
the NVIDIA container runtime.
## Requirements
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
- ~2 GB free disk for the model weights (downloaded on first run)
- ~5 GB free disk for model weights (downloaded on first use, cached in a volume)
## Quick start
```bash
make build # build the Docker image
make run # start the service on :8000 (GPU)
make logs # watch startup — first run downloads BiRefNet weights
make logs # watch startup — first run downloads model weights
make test # send test.jpg, save output.png
```
`make test` waits for the service `/health` endpoint before sending the
request, so the first call may block while the model downloads and loads.
`make test` waits for `/health` before sending, so the first call may block
while a model downloads and loads.
### Web UI
A minimal test page is served at the service root — open
**http://localhost:8000/** in a browser, drop in an image, and preview the
transparent-background result (handy when working over SSH). It calls the
same `/predict` endpoint.
Open **http://localhost:8000/** — a two-tab test page (handy over SSH):
### Useful variations
- **Auto remove** — pick a model variant + resolution.
- **Prompt segment** — type what to keep (e.g. `the dog`), tune the
GroundingDINO box / text thresholds.
```bash
make test BG=white # composite onto a white background
make test INPUT=photo.jpg OUTPUT=cut.png
make test-mask # also save the raw alpha mask (mask.png)
make help # list all targets
```
Both tabs support a transparency checkerboard preview, click-to-zoom lightbox,
optional crop-to-subject, and download.
## API
`POST /predict`
### `POST /predict` — automatic background removal
```jsonc
{
"image": "<base64 image bytes>", // required
"model": "HR", // general|HR|portrait|matting|lite|rmbg2
"resolution": 2048, // inference resolution (×32)
"background": "alpha", // alpha|white|black|gray|green|blue|red
"mask_blur": 0, // Gaussian blur radius on mask edges
"crop": false, // crop to the foreground bounding box
"crop_margin": 0.0, // crop margin in inches (uses image DPI)
"return_mask": false // include the raw mask in the response
}
```
Response:
### `POST /segment` — prompt-conditioned segmentation
```jsonc
{
"image": "<base64 PNG>",
"format": "png",
"width": 3637,
"height": 3637,
"mask": "<base64 PNG>" // only when return_mask=true
"image": "<base64 image bytes>", // required
"prompt": "the dog", // required — object(s) to keep
"box_threshold": 0.3, // GroundingDINO detection threshold
"text_threshold": 0.25,
"background": "alpha",
"mask_blur": 0,
"crop": false,
"crop_margin": 0.0
}
```
Response (both): `image` (base64 PNG), `format`, `width`, `height`, plus
`model`/`resolution` (`/predict`) or `detections`/`prompt` (`/segment`).
`GET /health` returns 200 when the service is ready.
## CLI
```bash
python3 scripts/client.py --input photo.jpg --output cut.png --model HR --resolution 2048 --crop
python3 scripts/client.py --input photo.jpg --output dog.png --prompt "the dog" --crop
```
## Configuration (environment variables)
| Variable | Default | Purpose |
|----------------------|----------------------|----------------------------------|
| `PORT` | `8000` | HTTP port |
| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights |
| `BIREFNET_RESOLUTION`| `1024` | Inference resolution |
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
| Variable | Default | Purpose |
|----------------------|--------------------------------|-------------------------------|
| `PORT` | `8000` | HTTP port |
| `BIREFNET_MODEL` | `general` | Default Auto variant |
| `BIREFNET_RESOLUTION`| `1024` | Default Auto resolution |
| `DINO_MODEL` | `IDEA-Research/grounding-dino-tiny` | GroundingDINO checkpoint |
| `SAM_MODEL` | `facebook/sam-vit-large` | SAM checkpoint |
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
## Local development (no Docker)
@ -85,10 +105,11 @@ make dev # uv sync + run the server locally
## Layout
```
src/birefnet_service/model.py BiRefNet wrapper (load + inference)
src/birefnet_service/server.py LitServe API + web UI route
src/birefnet_service/static/ web UI (index.html)
scripts/client.py stdlib-only test client
Dockerfile / docker-compose.yml CUDA image + nvidia runtime
Makefile build / run / test shortcuts
src/birefnet_service/model.py BiRefNet / RMBG-2.0 wrapper + compositing
src/birefnet_service/prompt_segment.py GroundingDINO + SAM pipeline
src/birefnet_service/server.py LitServe /predict + /segment + web UI
src/birefnet_service/static/ web UI (index.html)
scripts/client.py stdlib-only test client
Dockerfile / compose.yml CUDA image + nvidia runtime
Makefile build / run / test shortcuts
```

View File

@ -1,7 +1,7 @@
[project]
name = "birefnet-service"
name = "rmbg-as-a-service"
version = "0.1.0"
description = "BiRefNet background removal as a GPU-accelerated API"
description = "Background removal as a GPU-accelerated API"
readme = "README.md"
requires-python = ">=3.12,<3.13"
dependencies = [

View File

@ -40,6 +40,9 @@ def main() -> None:
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
ap.add_argument("--prompt", default=None, help="if set, use prompt segmentation (/segment)")
ap.add_argument("--box-threshold", type=float, default=0.3, help="GroundingDINO box threshold")
ap.add_argument("--text-threshold", type=float, default=0.25, help="GroundingDINO text threshold")
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
@ -55,16 +58,24 @@ def main() -> None:
"mask_blur": args.mask_blur,
"return_mask": args.mask_output is not None,
}
if args.model is not None:
payload["model"] = args.model
if args.resolution is not None:
payload["resolution"] = args.resolution
if args.crop:
payload["crop"] = True
payload["crop_margin"] = args.crop_margin
if args.prompt is not None:
endpoint = "/segment"
payload["prompt"] = args.prompt
payload["box_threshold"] = args.box_threshold
payload["text_threshold"] = args.text_threshold
else:
endpoint = "/predict"
if args.model is not None:
payload["model"] = args.model
if args.resolution is not None:
payload["resolution"] = args.resolution
req = urllib.request.Request(
f"{base_url}/predict",
f"{base_url}{endpoint}",
data=json.dumps(payload).encode(),
headers={"Content-Type": "application/json"},
method="POST",
@ -77,9 +88,13 @@ def main() -> None:
with open(args.output, "wb") as f:
f.write(base64.b64decode(result["image"]))
if "detections" in result:
detail = f"{result['detections']} object(s) matched '{result.get('prompt')}'"
else:
detail = f"{result.get('model')} @ {result.get('resolution')}"
print(
f"saved {args.output} {result['width']}x{result['height']} "
f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
f"{detail} ({elapsed:.1f}s)"
)
if args.mask_output and "mask" in result:

View File

@ -1,8 +1,8 @@
"""BiRefNet model wrapper for background removal.
Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports
multiple model variants (lazily loaded + cached) and a tunable inference
resolution, both selectable per request.
Loads BiRefNet / RMBG-2.0 weights via ``transformers`` (trust_remote_code).
Supports multiple model variants (lazily loaded + cached) and a tunable
inference resolution, both selectable per request.
"""
from __future__ import annotations
@ -16,20 +16,21 @@ from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
# BiRefNet variants: general is fast for clean single subjects; HR is for large
# / detailed scenes and needs a higher resolution (>=1536) to perform well.
# general: fast for clean single subjects. HR: large / detailed scenes (>=1536).
# rmbg2: BRIA RMBG-2.0 (a BiRefNet-architecture model), loaded the same way.
MODEL_ALIASES = {
"general": "ZhengPeng7/BiRefNet",
"HR": "ZhengPeng7/BiRefNet_HR",
"portrait": "ZhengPeng7/BiRefNet-portrait",
"matting": "ZhengPeng7/BiRefNet-matting",
"lite": "ZhengPeng7/BiRefNet_lite",
"rmbg2": "1038lab/RMBG-2.0",
}
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
# ImageNet normalization, matching BiRefNet training.
# ImageNet normalization, matching BiRefNet / RMBG-2.0 training.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]
@ -55,8 +56,73 @@ def _normalize_resolution(resolution: int | None) -> int:
return max(256, (res // 32) * 32)
def dpi_of(image: Image.Image) -> float:
"""Horizontal DPI from image metadata; 96 if not embedded."""
dpi = image.info.get("dpi")
return float(dpi[0]) if dpi and dpi[0] else 96.0
def apply_mask(
image: Image.Image,
mask: Image.Image,
background: str = "alpha",
mask_blur: int = 0,
mask_offset: int = 0,
crop: bool = False,
crop_margin: float = 0.0,
dpi: float = 96.0,
) -> dict:
"""Composite an RGB image with an 'L' mask into a result dict.
Shared by every removal pipeline (BiRefNet and prompt-based). Handles
edge offset (grow/shrink), edge blur, solid-background compositing, and
optional bounding-box crop.
Returns {"image": ..., "mask": ..., "crop_box"?: ...}.
"""
# Grow (positive) or shrink (negative) the cutout — kills edge halos/fringe.
if mask_offset:
edge_filter = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
for _ in range(abs(int(mask_offset))):
mask = mask.filter(edge_filter(3))
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
cutout = image.convert("RGBA")
cutout.putalpha(mask)
background = (background or "alpha").lower()
if background == "alpha":
result = cutout
else:
if background not in BG_COLORS:
raise ValueError(
f"Unknown background '{background}'. "
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
)
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
result = Image.alpha_composite(bg, cutout).convert("RGB")
out: dict = {"image": result, "mask": mask}
if crop:
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
if bbox is not None:
margin_px = round(max(0.0, crop_margin) * dpi)
box = (
max(0, bbox[0] - margin_px),
max(0, bbox[1] - margin_px),
min(image.width, bbox[2] + margin_px),
min(image.height, bbox[3] + margin_px),
)
out["image"] = result.crop(box)
out["mask"] = mask.crop(box)
out["crop_box"] = box
return out
class BiRefNetService:
"""Runs BiRefNet background removal; caches loaded model variants."""
"""Runs BiRefNet / RMBG-2.0 background removal; caches loaded variants."""
def __init__(
self,
@ -122,64 +188,34 @@ class BiRefNetService:
resolution: int | None = None,
background: str = "alpha",
mask_blur: int = 0,
mask_offset: int = 0,
crop: bool = False,
crop_margin: float = 0.0,
return_mask: bool = False,
) -> dict:
"""Run background removal.
model: variant alias ('general', 'HR', ...) or a raw HF repo id.
model: variant alias ('general', 'HR', 'rmbg2', ...) or a repo id.
resolution: inference resolution (rounded down to a multiple of 32).
background: "alpha" for transparency, or a key from BG_COLORS.
mask_blur: Gaussian blur radius applied to the mask edges.
mask_offset: grow (+) or shrink (-) the cutout edge, in pixels.
crop: crop the output to the foreground's bounding box.
crop_margin: extra margin around the crop, in inches (uses image DPI).
"""
# DPI for inch->pixel margin conversion; default 96 if not embedded.
dpi = image.info.get("dpi")
dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0
dpi = dpi_of(image)
image = image.convert("RGB")
repo = resolve_repo(model)
resolution = _normalize_resolution(resolution)
mask = self.infer_mask(image, repo, resolution)
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
cutout = image.convert("RGBA")
cutout.putalpha(mask)
background = (background or "alpha").lower()
if background == "alpha":
result = cutout
else:
if background not in BG_COLORS:
raise ValueError(
f"Unknown background '{background}'. "
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
)
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
result = Image.alpha_composite(bg, cutout).convert("RGB")
out: dict = {"image": result, "model": repo, "resolution": resolution}
if crop:
# Bounding box of the foreground (mask above a low alpha threshold).
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
if bbox is not None:
margin_px = round(max(0.0, crop_margin) * dpi_x)
left = max(0, bbox[0] - margin_px)
top = max(0, bbox[1] - margin_px)
right = min(image.width, bbox[2] + margin_px)
bottom = min(image.height, bbox[3] + margin_px)
box = (left, top, right, bottom)
result = result.crop(box)
mask = mask.crop(box)
out["image"] = result
out["crop_box"] = box
out["dpi"] = round(dpi_x, 1)
if return_mask:
out["mask"] = mask
out = apply_mask(
image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi
)
out["model"] = repo
out["resolution"] = resolution
if "crop_box" in out:
out["dpi"] = round(dpi, 1)
if not return_mask:
out.pop("mask", None)
return out

View File

@ -0,0 +1,135 @@
"""Prompt-conditioned segmentation: GroundingDINO + SAM.
Given a text prompt, GroundingDINO detects matching boxes and SAM turns those
boxes into masks. The union of the masks becomes the foreground alpha which
is then composited/cropped by the shared ``apply_mask`` helper.
Both models come from ``transformers`` (no custom CUDA extensions).
"""
from __future__ import annotations
import os
import threading
import torch
from PIL import Image
from transformers import (
AutoProcessor,
GroundingDinoForObjectDetection,
SamModel,
SamProcessor,
)
from .model import apply_mask, dpi_of
DINO_MODEL = os.getenv("DINO_MODEL", "IDEA-Research/grounding-dino-tiny")
SAM_MODEL = os.getenv("SAM_MODEL", "facebook/sam-vit-large")
class PromptSegmenter:
"""Text-prompted segmentation. Models load lazily on first use."""
def __init__(self, device: str | None = None):
want_cuda = device != "cpu" and torch.cuda.is_available()
self.device = "cuda" if want_cuda else "cpu"
self._lock = threading.Lock()
self._ready = False
def _ensure_loaded(self) -> None:
if self._ready:
return
with self._lock:
if self._ready:
return
self.dino_processor = AutoProcessor.from_pretrained(DINO_MODEL)
self.dino = GroundingDinoForObjectDetection.from_pretrained(DINO_MODEL)
self.dino.eval().to(self.device)
self.sam_processor = SamProcessor.from_pretrained(SAM_MODEL)
self.sam = SamModel.from_pretrained(SAM_MODEL)
self.sam.eval().to(self.device)
self._ready = True
@torch.inference_mode()
def _detect(
self, image: Image.Image, prompt: str, box_threshold: float, text_threshold: float
) -> torch.Tensor:
# GroundingDINO expects lowercase phrases separated/terminated by '.'.
text = prompt.strip().lower().replace(",", ".")
if not text.endswith("."):
text += "."
inputs = self.dino_processor(images=image, text=text, return_tensors="pt").to(
self.device
)
outputs = self.dino(**inputs)
results = self.dino_processor.post_process_grounded_object_detection(
outputs,
inputs["input_ids"],
threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[(image.height, image.width)],
)[0]
return results["boxes"] # [N, 4] xyxy, pixel coords
@torch.inference_mode()
def _mask_from_boxes(self, image: Image.Image, boxes: torch.Tensor) -> Image.Image:
inputs = self.sam_processor(
image, input_boxes=[boxes.tolist()], return_tensors="pt"
).to(self.device)
outputs = self.sam(**inputs)
masks = self.sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0] # [N, 3, H, W] bool
iou = outputs.iou_scores.cpu()[0] # [N, 3]
# Best of SAM's 3 candidates per box, then union all boxes' masks.
best = iou.argmax(dim=-1)
chosen = torch.stack([masks[i, best[i]] for i in range(masks.shape[0])])
union = chosen.any(dim=0) # [H, W] bool
return Image.fromarray((union.numpy() * 255).astype("uint8"), mode="L")
def segment(
self,
image: Image.Image,
prompt: str,
background: str = "alpha",
mask_blur: int = 0,
mask_offset: int = 0,
crop: bool = False,
crop_margin: float = 0.0,
box_threshold: float = 0.3,
text_threshold: float = 0.25,
return_mask: bool = False,
) -> dict:
"""Segment whatever ``prompt`` describes and remove the rest.
prompt: object(s) to keep, e.g. "the dog" or "cow. person.".
box_threshold / text_threshold: GroundingDINO detection thresholds.
Other args match ``apply_mask`` / the BiRefNet pipeline.
"""
if not prompt or not prompt.strip():
raise ValueError("A non-empty 'prompt' is required.")
self._ensure_loaded()
dpi = dpi_of(image)
image = image.convert("RGB")
with self._lock:
boxes = self._detect(image, prompt, box_threshold, text_threshold)
detections = int(len(boxes))
if detections == 0:
mask = Image.new("L", image.size, 0)
else:
mask = self._mask_from_boxes(image, boxes)
out = apply_mask(
image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi
)
out["prompt"] = prompt
out["detections"] = detections
if "crop_box" in out:
out["dpi"] = round(dpi, 1)
if not return_mask:
out.pop("mask", None)
return out

View File

@ -1,27 +1,34 @@
"""LitServe API exposing BiRefNet background removal.
"""LitServe API exposing BiRefNet background removal + prompt segmentation.
Endpoint: POST /predict
Request JSON:
Endpoints:
POST /predict BiRefNet / RMBG-2.0 automatic background removal (LitServe)
POST /segment GroundingDINO + SAM prompt-conditioned segmentation
GET / minimal web UI (two tabs: Auto / Prompt)
GET /health readiness
/predict request JSON:
{
"image": "<base64-encoded image bytes>", (required)
"model": "general" | "HR" | "portrait" | ..., (default "general")
"resolution": 1024, (default 1024)
"background": "alpha" | "white" | "black" | ..., (default "alpha")
"mask_blur": 0, (default 0)
"return_mask": false (default false)
}
Response JSON:
{
"image": "<base64 PNG>",
"format": "png",
"width": int,
"height": int,
"model": "<repo id used>",
"resolution": int,
"mask": "<base64 PNG>" (only when return_mask=true)
"image": "<base64>", (required)
"model": "general" | "HR" | "rmbg2" | ..., (default "general")
"resolution": 1024,
"background": "alpha" | "white" | ...,
"mask_blur": 0,
"crop": false,
"crop_margin": 0.0, (inches)
"return_mask": false
}
A minimal web UI is served at GET / (same origin as /predict).
/segment request JSON:
{
"image": "<base64>", (required)
"prompt": "the dog", (required)
"box_threshold": 0.3,
"text_threshold": 0.25,
"background": "alpha" | ...,
"mask_blur": 0,
"crop": false,
"crop_margin": 0.0
}
"""
from __future__ import annotations
@ -29,16 +36,32 @@ from __future__ import annotations
import base64
import io
import os
import threading
from pathlib import Path
import litserve as ls
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
from PIL import Image, ImageOps
from .model import BiRefNetService
from .prompt_segment import PromptSegmenter
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
# Lazily-created prompt segmenter (DINO + SAM), shared by the /segment route.
_segmenter: PromptSegmenter | None = None
_segmenter_lock = threading.Lock()
def _get_segmenter() -> PromptSegmenter:
global _segmenter
if _segmenter is None:
with _segmenter_lock:
if _segmenter is None:
_segmenter = PromptSegmenter()
return _segmenter
def _b64_to_image(data: str) -> Image.Image:
image = Image.open(io.BytesIO(base64.b64decode(data)))
@ -65,6 +88,7 @@ class BiRefNetAPI(ls.LitAPI):
"resolution": request.get("resolution"),
"background": request.get("background", "alpha"),
"mask_blur": int(request.get("mask_blur", 0)),
"mask_offset": int(request.get("mask_offset", 0)),
"crop": bool(request.get("crop", False)),
"crop_margin": float(request.get("crop_margin", 0.0)),
"return_mask": bool(request.get("return_mask", False)),
@ -83,6 +107,8 @@ class BiRefNetAPI(ls.LitAPI):
"model": output["model"],
"resolution": output["resolution"],
}
if "crop_box" in output:
response["cropped"] = True
if output.get("mask") is not None:
response["mask"] = _image_to_b64(output["mask"])
return response
@ -97,7 +123,7 @@ def run() -> None:
)
# LitServe registers its own "/" route ("litserve running"); drop it so
# our UI can own the root path. Served same-origin as /predict (no CORS).
# our UI can own the root path. Served same-origin as the APIs (no CORS).
server.app.router.routes = [
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
]
@ -106,6 +132,36 @@ def run() -> None:
def index() -> str:
return _UI_HTML
@server.app.post("/segment")
def segment(payload: dict) -> dict:
"""Prompt-conditioned segmentation (GroundingDINO + SAM)."""
if "image" not in payload:
raise HTTPException(status_code=400, detail="Missing base64 'image'.")
image = _b64_to_image(payload["image"])
try:
result = _get_segmenter().segment(
image,
prompt=payload.get("prompt", ""),
background=payload.get("background", "alpha"),
mask_blur=int(payload.get("mask_blur", 0)),
mask_offset=int(payload.get("mask_offset", 0)),
crop=bool(payload.get("crop", False)),
crop_margin=float(payload.get("crop_margin", 0.0)),
box_threshold=float(payload.get("box_threshold", 0.3)),
text_threshold=float(payload.get("text_threshold", 0.25)),
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc))
image_out: Image.Image = result["image"]
return {
"image": _image_to_b64(image_out),
"format": "png",
"width": image_out.width,
"height": image_out.height,
"detections": result["detections"],
"prompt": result["prompt"],
}
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)

View File

@ -3,7 +3,7 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>BiRefNet — Background Removal</title>
<title>Background Removal &amp; Segmentation</title>
<style>
:root { color-scheme: dark; }
* { box-sizing: border-box; }
@ -12,36 +12,62 @@
background: #15171c; color: #e8e8ea; padding: 24px;
}
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 20px; }
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 16px; }
.wrap { max-width: 1100px; margin: 0 auto; }
.tabs { display: flex; gap: 4px; margin-bottom: 16px; border-bottom: 1px solid #2a2f3a; }
.tab { background: none; border: 0; color: #8a8f99; font-size: .9rem; font-weight: 600;
padding: 10px 16px; cursor: pointer; border-bottom: 2px solid transparent; }
.tab.active { color: #e8e8ea; border-bottom-color: #5b8cff; }
#drop {
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
}
#drop.over { border-color: #5b8cff; background: #1c2230; }
#drop p { margin: 6px 0; color: #8a8f99; }
.controls { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
.controls { display: flex; gap: 12px; align-items: flex-end; margin: 14px 0; flex-wrap: wrap; }
.controls[hidden] { display: none; }
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
select, input[type=number] {
select, input[type=number], input[type=text] {
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
}
input[type=number] { width: 78px; }
input[type=number]:disabled { opacity: .45; }
input[type=text]#prompt { width: 320px; }
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
button {
/* help tooltips */
.help { display: inline-flex; align-items: center; justify-content: center;
width: 14px; height: 14px; margin-left: 5px; border-radius: 50%;
border: 1px solid #4a4f5b; color: #8a8f99; font-size: 9px; font-weight: 700;
font-style: normal; cursor: help; position: relative; vertical-align: middle; }
.help:hover { color: #e8e8ea; border-color: #5b8cff; }
.help:hover::after {
content: attr(data-tip); position: absolute; bottom: 150%; left: 50%;
transform: translateX(-50%); width: 220px; background: #0c0d11;
color: #d8d9dc; border: 1px solid #3a3f4b; border-radius: 6px;
padding: 7px 9px; font-size: .72rem; font-weight: 400; line-height: 1.4;
text-transform: none; letter-spacing: normal; white-space: normal;
z-index: 50; pointer-events: none; }
button.go {
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
}
button:disabled { background: #3a3f4b; cursor: not-allowed; }
button.ghost { background: #2a2f3a; }
.go-row { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
button.go:disabled { background: #3a3f4b; cursor: not-allowed; }
button.ghost { background: #2a2f3a; color: #fff; border: 0; border-radius: 8px;
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600; }
.go-row { display: flex; gap: 12px; align-items: center; margin: 14px 0; flex-wrap: wrap; }
.status { color: #8a8f99; font-size: .85rem; }
.status.err { color: #ff6b6b; }
.hint { color: #6b7280; font-size: .78rem; margin-top: -8px; }
.hint { color: #6b7280; font-size: .78rem; margin: -4px 0 4px; }
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
@ -84,8 +110,13 @@
</head>
<body>
<div class="wrap">
<h1>BiRefNet — Background Removal</h1>
<div class="sub">Drop an image to get a transparent-background PNG.</div>
<h1>Background Removal &amp; Segmentation</h1>
<div class="sub">Automatic removal, or prompt-conditioned segmentation.</div>
<div class="tabs">
<button class="tab active" data-tab="auto">Auto remove</button>
<button class="tab" data-tab="prompt">Prompt segment</button>
</div>
<div id="drop">
<p><strong>Drop an image here</strong> or click to choose</p>
@ -93,7 +124,8 @@
<input id="file" type="file" accept="image/*" hidden />
</div>
<div class="controls">
<!-- Auto (BiRefNet / RMBG-2.0) controls -->
<div class="controls" id="ctl-auto">
<label class="field">Model
<select id="model">
<option value="general">general — clean single subjects (fast)</option>
@ -101,6 +133,7 @@
<option value="portrait">portrait — people</option>
<option value="matting">matting — soft edges / hair</option>
<option value="lite">lite — fastest</option>
<option value="rmbg2">rmbg2 — BRIA RMBG-2.0</option>
</select>
</label>
<label class="field">Resolution
@ -111,16 +144,54 @@
<option value="2560" selected>2560</option>
</select>
</label>
</div>
<!-- Prompt (GroundingDINO + SAM) controls -->
<div class="controls" id="ctl-prompt" hidden>
<label class="field">Prompt — what to keep
<input type="text" id="prompt" placeholder="e.g. the dog · cow. person." />
</label>
<label class="field">
<span>Box threshold<span class="help" data-tip="Minimum confidence for GroundingDINO to keep a detected box. Lower finds more (and looser) objects; higher keeps only strong matches.">?</span></span>
<input type="number" id="boxThr" value="0.3" min="0" max="1" step="0.05" />
</label>
<label class="field">
<span>Text threshold<span class="help" data-tip="How strongly a detection must match your prompt words. Lower = looser word matching; higher = stricter.">?</span></span>
<input type="number" id="textThr" value="0.25" min="0" max="1" step="0.05" />
</label>
</div>
<!-- Shared output controls -->
<div class="controls">
<label class="field">Background
<select id="background">
<option value="alpha" selected>transparent</option>
<option value="white">white</option>
<option value="black">black</option>
<option value="gray">gray</option>
<option value="green">green</option>
<option value="blue">blue</option>
<option value="red">red</option>
</select>
</label>
<label class="field">
<span>Edge offset (px)<span class="help" data-tip="Grow (+) or shrink () the cutout edge by N pixels. A small negative value trims a leftover background-colored fringe around hair or fur.">?</span></span>
<input type="number" id="maskOffset" value="0" min="-20" max="20" step="1" />
</label>
<label class="field">
<span>Feather (px)<span class="help" data-tip="Gaussian blur applied to the mask edge, in pixels. Softens the cutout for smoother compositing onto a new background.">?</span></span>
<input type="number" id="maskBlur" value="0" min="0" max="64" step="1" />
</label>
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
<label class="field">Margin (in)
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
</label>
</div>
<div class="hint">Tip: large or busy scenes segment best with <strong>HR</strong> at <strong>2048</strong>.
The <em>general</em> model expects a clear single subject at 1024.</div>
<div class="hint" id="hint"></div>
<div class="go-row">
<button id="go" disabled>Remove background</button>
<button class="go" id="go" disabled>Remove background</button>
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
<span id="status" class="status"></span>
</div>
@ -155,14 +226,42 @@ const dlbtn = document.getElementById('dlbtn');
const statusEl = document.getElementById('status');
const srcImg = document.getElementById('src');
const outImg = document.getElementById('out');
const hint = document.getElementById('hint');
const modelSel = document.getElementById('model');
const resSel = document.getElementById('resolution');
const promptInput = document.getElementById('prompt');
const boxThr = document.getElementById('boxThr');
const textThr = document.getElementById('textThr');
const bgSel = document.getElementById('background');
const maskOffset = document.getElementById('maskOffset');
const maskBlur = document.getElementById('maskBlur');
const cropChk = document.getElementById('crop');
const cropMargin = document.getElementById('cropMargin');
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
const ctlAuto = document.getElementById('ctl-auto');
const ctlPrompt = document.getElementById('ctl-prompt');
let selectedFile = null;
let tab = 'auto';
const HINTS = {
auto: 'Large or busy scenes segment best with HR at 2048+. The general model expects a clear single subject at 1024.',
prompt: 'Type what to keep, e.g. "the dog" (or several: "cow. person."). Lower the box threshold to detect more / fainter objects.',
};
function setTab(name) {
tab = name;
document.querySelectorAll('.tab').forEach(t =>
t.classList.toggle('active', t.dataset.tab === name));
ctlAuto.hidden = name !== 'auto';
ctlPrompt.hidden = name !== 'prompt';
go.textContent = name === 'auto' ? 'Remove background' : 'Segment';
hint.textContent = HINTS[name];
}
document.querySelectorAll('.tab').forEach(t =>
t.addEventListener('click', () => setTab(t.dataset.tab)));
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
function setStatus(msg, isErr) {
statusEl.textContent = msg;
@ -264,23 +363,38 @@ lbImg.addEventListener('dblclick', e => { e.preventDefault(); lbReset(); });
go.addEventListener('click', async () => {
if (!selectedFile) return;
if (tab === 'prompt' && !promptInput.value.trim()) {
setStatus('Enter a prompt describing what to keep.', true);
return;
}
go.disabled = true;
dlbtn.disabled = true;
setStatus('Processing… (first use of a model downloads its weights)');
const t0 = performance.now();
try {
const b64 = await fileToBase64(selectedFile);
const resp = await fetch('/predict', {
const shared = {
image: b64,
background: bgSel.value,
mask_offset: parseInt(maskOffset.value, 10) || 0,
mask_blur: parseInt(maskBlur.value, 10) || 0,
crop: cropChk.checked,
crop_margin: parseFloat(cropMargin.value) || 0,
};
let endpoint, body;
if (tab === 'auto') {
endpoint = '/predict';
body = { ...shared, model: modelSel.value, resolution: parseInt(resSel.value, 10) };
} else {
endpoint = '/segment';
body = { ...shared, prompt: promptInput.value.trim(),
box_threshold: parseFloat(boxThr.value) || 0.3,
text_threshold: parseFloat(textThr.value) || 0.25 };
}
const resp = await fetch(endpoint, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
image: b64,
background: 'alpha',
model: modelSel.value,
resolution: parseInt(resSel.value, 10),
crop: cropChk.checked,
crop_margin: parseFloat(cropMargin.value) || 0,
}),
body: JSON.stringify(body),
});
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
const data = await resp.json();
@ -290,14 +404,22 @@ go.addEventListener('click', async () => {
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
dlbtn.disabled = false;
const secs = ((performance.now() - t0) / 1000).toFixed(1);
setStatus('Done — ' + data.width + '×' + data.height + ' · ' + data.model +
' @ ' + data.resolution + ' · ' + secs + 's');
if (tab === 'auto') {
setStatus(`Done — ${data.width}×${data.height} · ${data.model} @ ${data.resolution} · ${secs}s`);
} else {
const n = data.detections;
setStatus(`Done — ${n} object${n === 1 ? '' : 's'} matched "${data.prompt}" · ` +
`${data.width}×${data.height} · ${secs}s` +
(n === 0 ? ' (try a lower box threshold)' : ''));
}
} catch (err) {
setStatus(err.message || String(err), true);
} finally {
go.disabled = false;
}
});
setTab('auto');
</script>
</body>
</html>