rmbg option + some other post-processing
This commit is contained in:
parent
96d16fc654
commit
4efb4b8a2f
97
README.md
97
README.md
@ -1,78 +1,98 @@
|
|||||||
# BiRefNet Background Removal Service
|
# BiRefNet Background Removal Service
|
||||||
|
|
||||||
GPU-accelerated background removal exposed as an HTTP API. Uses
|
GPU-accelerated background removal as an HTTP API. Two pipelines:
|
||||||
[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with
|
|
||||||
[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the
|
- **Auto** — [BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) /
|
||||||
NVIDIA container runtime.
|
[RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) salient-object matting.
|
||||||
|
- **Prompt** — [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-tiny)
|
||||||
|
+ [SAM](https://huggingface.co/facebook/sam-vit-base): segment whatever a text
|
||||||
|
prompt describes.
|
||||||
|
|
||||||
|
Served with [LitServe](https://github.com/Lightning-AI/LitServe), packaged for
|
||||||
|
the NVIDIA container runtime.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
|
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
|
||||||
- ~2 GB free disk for the model weights (downloaded on first run)
|
- ~5 GB free disk for model weights (downloaded on first use, cached in a volume)
|
||||||
|
|
||||||
## Quick start
|
## Quick start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make build # build the Docker image
|
make build # build the Docker image
|
||||||
make run # start the service on :8000 (GPU)
|
make run # start the service on :8000 (GPU)
|
||||||
make logs # watch startup — first run downloads BiRefNet weights
|
make logs # watch startup — first run downloads model weights
|
||||||
make test # send test.jpg, save output.png
|
make test # send test.jpg, save output.png
|
||||||
```
|
```
|
||||||
|
|
||||||
`make test` waits for the service `/health` endpoint before sending the
|
`make test` waits for `/health` before sending, so the first call may block
|
||||||
request, so the first call may block while the model downloads and loads.
|
while a model downloads and loads.
|
||||||
|
|
||||||
### Web UI
|
### Web UI
|
||||||
|
|
||||||
A minimal test page is served at the service root — open
|
Open **http://localhost:8000/** — a two-tab test page (handy over SSH):
|
||||||
**http://localhost:8000/** in a browser, drop in an image, and preview the
|
|
||||||
transparent-background result (handy when working over SSH). It calls the
|
|
||||||
same `/predict` endpoint.
|
|
||||||
|
|
||||||
### Useful variations
|
- **Auto remove** — pick a model variant + resolution.
|
||||||
|
- **Prompt segment** — type what to keep (e.g. `the dog`), tune the
|
||||||
|
GroundingDINO box / text thresholds.
|
||||||
|
|
||||||
```bash
|
Both tabs support a transparency checkerboard preview, click-to-zoom lightbox,
|
||||||
make test BG=white # composite onto a white background
|
optional crop-to-subject, and download.
|
||||||
make test INPUT=photo.jpg OUTPUT=cut.png
|
|
||||||
make test-mask # also save the raw alpha mask (mask.png)
|
|
||||||
make help # list all targets
|
|
||||||
```
|
|
||||||
|
|
||||||
## API
|
## API
|
||||||
|
|
||||||
`POST /predict`
|
### `POST /predict` — automatic background removal
|
||||||
|
|
||||||
```jsonc
|
```jsonc
|
||||||
{
|
{
|
||||||
"image": "<base64 image bytes>", // required
|
"image": "<base64 image bytes>", // required
|
||||||
|
"model": "HR", // general|HR|portrait|matting|lite|rmbg2
|
||||||
|
"resolution": 2048, // inference resolution (×32)
|
||||||
"background": "alpha", // alpha|white|black|gray|green|blue|red
|
"background": "alpha", // alpha|white|black|gray|green|blue|red
|
||||||
"mask_blur": 0, // Gaussian blur radius on mask edges
|
"mask_blur": 0, // Gaussian blur radius on mask edges
|
||||||
|
"crop": false, // crop to the foreground bounding box
|
||||||
|
"crop_margin": 0.0, // crop margin in inches (uses image DPI)
|
||||||
"return_mask": false // include the raw mask in the response
|
"return_mask": false // include the raw mask in the response
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Response:
|
### `POST /segment` — prompt-conditioned segmentation
|
||||||
|
|
||||||
```jsonc
|
```jsonc
|
||||||
{
|
{
|
||||||
"image": "<base64 PNG>",
|
"image": "<base64 image bytes>", // required
|
||||||
"format": "png",
|
"prompt": "the dog", // required — object(s) to keep
|
||||||
"width": 3637,
|
"box_threshold": 0.3, // GroundingDINO detection threshold
|
||||||
"height": 3637,
|
"text_threshold": 0.25,
|
||||||
"mask": "<base64 PNG>" // only when return_mask=true
|
"background": "alpha",
|
||||||
|
"mask_blur": 0,
|
||||||
|
"crop": false,
|
||||||
|
"crop_margin": 0.0
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Response (both): `image` (base64 PNG), `format`, `width`, `height`, plus
|
||||||
|
`model`/`resolution` (`/predict`) or `detections`/`prompt` (`/segment`).
|
||||||
|
|
||||||
`GET /health` returns 200 when the service is ready.
|
`GET /health` returns 200 when the service is ready.
|
||||||
|
|
||||||
|
## CLI
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 scripts/client.py --input photo.jpg --output cut.png --model HR --resolution 2048 --crop
|
||||||
|
python3 scripts/client.py --input photo.jpg --output dog.png --prompt "the dog" --crop
|
||||||
|
```
|
||||||
|
|
||||||
## Configuration (environment variables)
|
## Configuration (environment variables)
|
||||||
|
|
||||||
| Variable | Default | Purpose |
|
| Variable | Default | Purpose |
|
||||||
|----------------------|----------------------|----------------------------------|
|
|----------------------|--------------------------------|-------------------------------|
|
||||||
| `PORT` | `8000` | HTTP port |
|
| `PORT` | `8000` | HTTP port |
|
||||||
| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights |
|
| `BIREFNET_MODEL` | `general` | Default Auto variant |
|
||||||
| `BIREFNET_RESOLUTION`| `1024` | Inference resolution |
|
| `BIREFNET_RESOLUTION`| `1024` | Default Auto resolution |
|
||||||
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
|
| `DINO_MODEL` | `IDEA-Research/grounding-dino-tiny` | GroundingDINO checkpoint |
|
||||||
|
| `SAM_MODEL` | `facebook/sam-vit-large` | SAM checkpoint |
|
||||||
|
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
|
||||||
|
|
||||||
## Local development (no Docker)
|
## Local development (no Docker)
|
||||||
|
|
||||||
@ -85,10 +105,11 @@ make dev # uv sync + run the server locally
|
|||||||
## Layout
|
## Layout
|
||||||
|
|
||||||
```
|
```
|
||||||
src/birefnet_service/model.py BiRefNet wrapper (load + inference)
|
src/birefnet_service/model.py BiRefNet / RMBG-2.0 wrapper + compositing
|
||||||
src/birefnet_service/server.py LitServe API + web UI route
|
src/birefnet_service/prompt_segment.py GroundingDINO + SAM pipeline
|
||||||
src/birefnet_service/static/ web UI (index.html)
|
src/birefnet_service/server.py LitServe /predict + /segment + web UI
|
||||||
scripts/client.py stdlib-only test client
|
src/birefnet_service/static/ web UI (index.html)
|
||||||
Dockerfile / docker-compose.yml CUDA image + nvidia runtime
|
scripts/client.py stdlib-only test client
|
||||||
Makefile build / run / test shortcuts
|
Dockerfile / compose.yml CUDA image + nvidia runtime
|
||||||
|
Makefile build / run / test shortcuts
|
||||||
```
|
```
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "birefnet-service"
|
name = "rmbg-as-a-service"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "BiRefNet background removal as a GPU-accelerated API"
|
description = "Background removal as a GPU-accelerated API"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12,<3.13"
|
requires-python = ">=3.12,<3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|||||||
@ -40,6 +40,9 @@ def main() -> None:
|
|||||||
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
|
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
|
||||||
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
|
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
|
||||||
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
|
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
|
||||||
|
ap.add_argument("--prompt", default=None, help="if set, use prompt segmentation (/segment)")
|
||||||
|
ap.add_argument("--box-threshold", type=float, default=0.3, help="GroundingDINO box threshold")
|
||||||
|
ap.add_argument("--text-threshold", type=float, default=0.25, help="GroundingDINO text threshold")
|
||||||
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
|
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
|
||||||
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
|
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
|
||||||
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
|
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
|
||||||
@ -55,16 +58,24 @@ def main() -> None:
|
|||||||
"mask_blur": args.mask_blur,
|
"mask_blur": args.mask_blur,
|
||||||
"return_mask": args.mask_output is not None,
|
"return_mask": args.mask_output is not None,
|
||||||
}
|
}
|
||||||
if args.model is not None:
|
|
||||||
payload["model"] = args.model
|
|
||||||
if args.resolution is not None:
|
|
||||||
payload["resolution"] = args.resolution
|
|
||||||
if args.crop:
|
if args.crop:
|
||||||
payload["crop"] = True
|
payload["crop"] = True
|
||||||
payload["crop_margin"] = args.crop_margin
|
payload["crop_margin"] = args.crop_margin
|
||||||
|
|
||||||
|
if args.prompt is not None:
|
||||||
|
endpoint = "/segment"
|
||||||
|
payload["prompt"] = args.prompt
|
||||||
|
payload["box_threshold"] = args.box_threshold
|
||||||
|
payload["text_threshold"] = args.text_threshold
|
||||||
|
else:
|
||||||
|
endpoint = "/predict"
|
||||||
|
if args.model is not None:
|
||||||
|
payload["model"] = args.model
|
||||||
|
if args.resolution is not None:
|
||||||
|
payload["resolution"] = args.resolution
|
||||||
|
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
f"{base_url}/predict",
|
f"{base_url}{endpoint}",
|
||||||
data=json.dumps(payload).encode(),
|
data=json.dumps(payload).encode(),
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
method="POST",
|
method="POST",
|
||||||
@ -77,9 +88,13 @@ def main() -> None:
|
|||||||
|
|
||||||
with open(args.output, "wb") as f:
|
with open(args.output, "wb") as f:
|
||||||
f.write(base64.b64decode(result["image"]))
|
f.write(base64.b64decode(result["image"]))
|
||||||
|
if "detections" in result:
|
||||||
|
detail = f"{result['detections']} object(s) matched '{result.get('prompt')}'"
|
||||||
|
else:
|
||||||
|
detail = f"{result.get('model')} @ {result.get('resolution')}"
|
||||||
print(
|
print(
|
||||||
f"saved {args.output} {result['width']}x{result['height']} "
|
f"saved {args.output} {result['width']}x{result['height']} "
|
||||||
f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
|
f"{detail} ({elapsed:.1f}s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.mask_output and "mask" in result:
|
if args.mask_output and "mask" in result:
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
"""BiRefNet model wrapper for background removal.
|
"""BiRefNet model wrapper for background removal.
|
||||||
|
|
||||||
Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports
|
Loads BiRefNet / RMBG-2.0 weights via ``transformers`` (trust_remote_code).
|
||||||
multiple model variants (lazily loaded + cached) and a tunable inference
|
Supports multiple model variants (lazily loaded + cached) and a tunable
|
||||||
resolution, both selectable per request.
|
inference resolution, both selectable per request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -16,20 +16,21 @@ from torchvision import transforms
|
|||||||
from transformers import AutoModelForImageSegmentation
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
|
||||||
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
|
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
|
||||||
# BiRefNet variants: general is fast for clean single subjects; HR is for large
|
# general: fast for clean single subjects. HR: large / detailed scenes (>=1536).
|
||||||
# / detailed scenes and needs a higher resolution (>=1536) to perform well.
|
# rmbg2: BRIA RMBG-2.0 (a BiRefNet-architecture model), loaded the same way.
|
||||||
MODEL_ALIASES = {
|
MODEL_ALIASES = {
|
||||||
"general": "ZhengPeng7/BiRefNet",
|
"general": "ZhengPeng7/BiRefNet",
|
||||||
"HR": "ZhengPeng7/BiRefNet_HR",
|
"HR": "ZhengPeng7/BiRefNet_HR",
|
||||||
"portrait": "ZhengPeng7/BiRefNet-portrait",
|
"portrait": "ZhengPeng7/BiRefNet-portrait",
|
||||||
"matting": "ZhengPeng7/BiRefNet-matting",
|
"matting": "ZhengPeng7/BiRefNet-matting",
|
||||||
"lite": "ZhengPeng7/BiRefNet_lite",
|
"lite": "ZhengPeng7/BiRefNet_lite",
|
||||||
|
"rmbg2": "1038lab/RMBG-2.0",
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
|
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
|
||||||
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
|
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
|
||||||
|
|
||||||
# ImageNet normalization, matching BiRefNet training.
|
# ImageNet normalization, matching BiRefNet / RMBG-2.0 training.
|
||||||
_MEAN = [0.485, 0.456, 0.406]
|
_MEAN = [0.485, 0.456, 0.406]
|
||||||
_STD = [0.229, 0.224, 0.225]
|
_STD = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
@ -55,8 +56,73 @@ def _normalize_resolution(resolution: int | None) -> int:
|
|||||||
return max(256, (res // 32) * 32)
|
return max(256, (res // 32) * 32)
|
||||||
|
|
||||||
|
|
||||||
|
def dpi_of(image: Image.Image) -> float:
|
||||||
|
"""Horizontal DPI from image metadata; 96 if not embedded."""
|
||||||
|
dpi = image.info.get("dpi")
|
||||||
|
return float(dpi[0]) if dpi and dpi[0] else 96.0
|
||||||
|
|
||||||
|
|
||||||
|
def apply_mask(
|
||||||
|
image: Image.Image,
|
||||||
|
mask: Image.Image,
|
||||||
|
background: str = "alpha",
|
||||||
|
mask_blur: int = 0,
|
||||||
|
mask_offset: int = 0,
|
||||||
|
crop: bool = False,
|
||||||
|
crop_margin: float = 0.0,
|
||||||
|
dpi: float = 96.0,
|
||||||
|
) -> dict:
|
||||||
|
"""Composite an RGB image with an 'L' mask into a result dict.
|
||||||
|
|
||||||
|
Shared by every removal pipeline (BiRefNet and prompt-based). Handles
|
||||||
|
edge offset (grow/shrink), edge blur, solid-background compositing, and
|
||||||
|
optional bounding-box crop.
|
||||||
|
Returns {"image": ..., "mask": ..., "crop_box"?: ...}.
|
||||||
|
"""
|
||||||
|
# Grow (positive) or shrink (negative) the cutout — kills edge halos/fringe.
|
||||||
|
if mask_offset:
|
||||||
|
edge_filter = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
|
||||||
|
for _ in range(abs(int(mask_offset))):
|
||||||
|
mask = mask.filter(edge_filter(3))
|
||||||
|
if mask_blur > 0:
|
||||||
|
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
||||||
|
|
||||||
|
cutout = image.convert("RGBA")
|
||||||
|
cutout.putalpha(mask)
|
||||||
|
|
||||||
|
background = (background or "alpha").lower()
|
||||||
|
if background == "alpha":
|
||||||
|
result = cutout
|
||||||
|
else:
|
||||||
|
if background not in BG_COLORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown background '{background}'. "
|
||||||
|
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
|
||||||
|
)
|
||||||
|
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
|
||||||
|
result = Image.alpha_composite(bg, cutout).convert("RGB")
|
||||||
|
|
||||||
|
out: dict = {"image": result, "mask": mask}
|
||||||
|
|
||||||
|
if crop:
|
||||||
|
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
|
||||||
|
if bbox is not None:
|
||||||
|
margin_px = round(max(0.0, crop_margin) * dpi)
|
||||||
|
box = (
|
||||||
|
max(0, bbox[0] - margin_px),
|
||||||
|
max(0, bbox[1] - margin_px),
|
||||||
|
min(image.width, bbox[2] + margin_px),
|
||||||
|
min(image.height, bbox[3] + margin_px),
|
||||||
|
)
|
||||||
|
out["image"] = result.crop(box)
|
||||||
|
out["mask"] = mask.crop(box)
|
||||||
|
out["crop_box"] = box
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class BiRefNetService:
|
class BiRefNetService:
|
||||||
"""Runs BiRefNet background removal; caches loaded model variants."""
|
"""Runs BiRefNet / RMBG-2.0 background removal; caches loaded variants."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -122,64 +188,34 @@ class BiRefNetService:
|
|||||||
resolution: int | None = None,
|
resolution: int | None = None,
|
||||||
background: str = "alpha",
|
background: str = "alpha",
|
||||||
mask_blur: int = 0,
|
mask_blur: int = 0,
|
||||||
|
mask_offset: int = 0,
|
||||||
crop: bool = False,
|
crop: bool = False,
|
||||||
crop_margin: float = 0.0,
|
crop_margin: float = 0.0,
|
||||||
return_mask: bool = False,
|
return_mask: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Run background removal.
|
"""Run background removal.
|
||||||
|
|
||||||
model: variant alias ('general', 'HR', ...) or a raw HF repo id.
|
model: variant alias ('general', 'HR', 'rmbg2', ...) or a repo id.
|
||||||
resolution: inference resolution (rounded down to a multiple of 32).
|
resolution: inference resolution (rounded down to a multiple of 32).
|
||||||
background: "alpha" for transparency, or a key from BG_COLORS.
|
background: "alpha" for transparency, or a key from BG_COLORS.
|
||||||
mask_blur: Gaussian blur radius applied to the mask edges.
|
mask_blur: Gaussian blur radius applied to the mask edges.
|
||||||
|
mask_offset: grow (+) or shrink (-) the cutout edge, in pixels.
|
||||||
crop: crop the output to the foreground's bounding box.
|
crop: crop the output to the foreground's bounding box.
|
||||||
crop_margin: extra margin around the crop, in inches (uses image DPI).
|
crop_margin: extra margin around the crop, in inches (uses image DPI).
|
||||||
"""
|
"""
|
||||||
# DPI for inch->pixel margin conversion; default 96 if not embedded.
|
dpi = dpi_of(image)
|
||||||
dpi = image.info.get("dpi")
|
|
||||||
dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0
|
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
repo = resolve_repo(model)
|
repo = resolve_repo(model)
|
||||||
resolution = _normalize_resolution(resolution)
|
resolution = _normalize_resolution(resolution)
|
||||||
|
|
||||||
mask = self.infer_mask(image, repo, resolution)
|
mask = self.infer_mask(image, repo, resolution)
|
||||||
if mask_blur > 0:
|
out = apply_mask(
|
||||||
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi
|
||||||
|
)
|
||||||
cutout = image.convert("RGBA")
|
out["model"] = repo
|
||||||
cutout.putalpha(mask)
|
out["resolution"] = resolution
|
||||||
|
if "crop_box" in out:
|
||||||
background = (background or "alpha").lower()
|
out["dpi"] = round(dpi, 1)
|
||||||
if background == "alpha":
|
if not return_mask:
|
||||||
result = cutout
|
out.pop("mask", None)
|
||||||
else:
|
|
||||||
if background not in BG_COLORS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown background '{background}'. "
|
|
||||||
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
|
|
||||||
)
|
|
||||||
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
|
|
||||||
result = Image.alpha_composite(bg, cutout).convert("RGB")
|
|
||||||
|
|
||||||
out: dict = {"image": result, "model": repo, "resolution": resolution}
|
|
||||||
|
|
||||||
if crop:
|
|
||||||
# Bounding box of the foreground (mask above a low alpha threshold).
|
|
||||||
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
|
|
||||||
if bbox is not None:
|
|
||||||
margin_px = round(max(0.0, crop_margin) * dpi_x)
|
|
||||||
left = max(0, bbox[0] - margin_px)
|
|
||||||
top = max(0, bbox[1] - margin_px)
|
|
||||||
right = min(image.width, bbox[2] + margin_px)
|
|
||||||
bottom = min(image.height, bbox[3] + margin_px)
|
|
||||||
box = (left, top, right, bottom)
|
|
||||||
result = result.crop(box)
|
|
||||||
mask = mask.crop(box)
|
|
||||||
out["image"] = result
|
|
||||||
out["crop_box"] = box
|
|
||||||
out["dpi"] = round(dpi_x, 1)
|
|
||||||
|
|
||||||
if return_mask:
|
|
||||||
out["mask"] = mask
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
135
src/birefnet_service/prompt_segment.py
Normal file
135
src/birefnet_service/prompt_segment.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
"""Prompt-conditioned segmentation: GroundingDINO + SAM.
|
||||||
|
|
||||||
|
Given a text prompt, GroundingDINO detects matching boxes and SAM turns those
|
||||||
|
boxes into masks. The union of the masks becomes the foreground alpha — which
|
||||||
|
is then composited/cropped by the shared ``apply_mask`` helper.
|
||||||
|
|
||||||
|
Both models come from ``transformers`` (no custom CUDA extensions).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import (
|
||||||
|
AutoProcessor,
|
||||||
|
GroundingDinoForObjectDetection,
|
||||||
|
SamModel,
|
||||||
|
SamProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .model import apply_mask, dpi_of
|
||||||
|
|
||||||
|
DINO_MODEL = os.getenv("DINO_MODEL", "IDEA-Research/grounding-dino-tiny")
|
||||||
|
SAM_MODEL = os.getenv("SAM_MODEL", "facebook/sam-vit-large")
|
||||||
|
|
||||||
|
|
||||||
|
class PromptSegmenter:
|
||||||
|
"""Text-prompted segmentation. Models load lazily on first use."""
|
||||||
|
|
||||||
|
def __init__(self, device: str | None = None):
|
||||||
|
want_cuda = device != "cpu" and torch.cuda.is_available()
|
||||||
|
self.device = "cuda" if want_cuda else "cpu"
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._ready = False
|
||||||
|
|
||||||
|
def _ensure_loaded(self) -> None:
|
||||||
|
if self._ready:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
if self._ready:
|
||||||
|
return
|
||||||
|
self.dino_processor = AutoProcessor.from_pretrained(DINO_MODEL)
|
||||||
|
self.dino = GroundingDinoForObjectDetection.from_pretrained(DINO_MODEL)
|
||||||
|
self.dino.eval().to(self.device)
|
||||||
|
self.sam_processor = SamProcessor.from_pretrained(SAM_MODEL)
|
||||||
|
self.sam = SamModel.from_pretrained(SAM_MODEL)
|
||||||
|
self.sam.eval().to(self.device)
|
||||||
|
self._ready = True
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _detect(
|
||||||
|
self, image: Image.Image, prompt: str, box_threshold: float, text_threshold: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# GroundingDINO expects lowercase phrases separated/terminated by '.'.
|
||||||
|
text = prompt.strip().lower().replace(",", ".")
|
||||||
|
if not text.endswith("."):
|
||||||
|
text += "."
|
||||||
|
inputs = self.dino_processor(images=image, text=text, return_tensors="pt").to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
|
outputs = self.dino(**inputs)
|
||||||
|
results = self.dino_processor.post_process_grounded_object_detection(
|
||||||
|
outputs,
|
||||||
|
inputs["input_ids"],
|
||||||
|
threshold=box_threshold,
|
||||||
|
text_threshold=text_threshold,
|
||||||
|
target_sizes=[(image.height, image.width)],
|
||||||
|
)[0]
|
||||||
|
return results["boxes"] # [N, 4] xyxy, pixel coords
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def _mask_from_boxes(self, image: Image.Image, boxes: torch.Tensor) -> Image.Image:
|
||||||
|
inputs = self.sam_processor(
|
||||||
|
image, input_boxes=[boxes.tolist()], return_tensors="pt"
|
||||||
|
).to(self.device)
|
||||||
|
outputs = self.sam(**inputs)
|
||||||
|
masks = self.sam_processor.image_processor.post_process_masks(
|
||||||
|
outputs.pred_masks.cpu(),
|
||||||
|
inputs["original_sizes"].cpu(),
|
||||||
|
inputs["reshaped_input_sizes"].cpu(),
|
||||||
|
)[0] # [N, 3, H, W] bool
|
||||||
|
iou = outputs.iou_scores.cpu()[0] # [N, 3]
|
||||||
|
# Best of SAM's 3 candidates per box, then union all boxes' masks.
|
||||||
|
best = iou.argmax(dim=-1)
|
||||||
|
chosen = torch.stack([masks[i, best[i]] for i in range(masks.shape[0])])
|
||||||
|
union = chosen.any(dim=0) # [H, W] bool
|
||||||
|
return Image.fromarray((union.numpy() * 255).astype("uint8"), mode="L")
|
||||||
|
|
||||||
|
def segment(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
prompt: str,
|
||||||
|
background: str = "alpha",
|
||||||
|
mask_blur: int = 0,
|
||||||
|
mask_offset: int = 0,
|
||||||
|
crop: bool = False,
|
||||||
|
crop_margin: float = 0.0,
|
||||||
|
box_threshold: float = 0.3,
|
||||||
|
text_threshold: float = 0.25,
|
||||||
|
return_mask: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Segment whatever ``prompt`` describes and remove the rest.
|
||||||
|
|
||||||
|
prompt: object(s) to keep, e.g. "the dog" or "cow. person.".
|
||||||
|
box_threshold / text_threshold: GroundingDINO detection thresholds.
|
||||||
|
Other args match ``apply_mask`` / the BiRefNet pipeline.
|
||||||
|
"""
|
||||||
|
if not prompt or not prompt.strip():
|
||||||
|
raise ValueError("A non-empty 'prompt' is required.")
|
||||||
|
|
||||||
|
self._ensure_loaded()
|
||||||
|
dpi = dpi_of(image)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
boxes = self._detect(image, prompt, box_threshold, text_threshold)
|
||||||
|
detections = int(len(boxes))
|
||||||
|
if detections == 0:
|
||||||
|
mask = Image.new("L", image.size, 0)
|
||||||
|
else:
|
||||||
|
mask = self._mask_from_boxes(image, boxes)
|
||||||
|
|
||||||
|
out = apply_mask(
|
||||||
|
image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi
|
||||||
|
)
|
||||||
|
out["prompt"] = prompt
|
||||||
|
out["detections"] = detections
|
||||||
|
if "crop_box" in out:
|
||||||
|
out["dpi"] = round(dpi, 1)
|
||||||
|
if not return_mask:
|
||||||
|
out.pop("mask", None)
|
||||||
|
return out
|
||||||
@ -1,27 +1,34 @@
|
|||||||
"""LitServe API exposing BiRefNet background removal.
|
"""LitServe API exposing BiRefNet background removal + prompt segmentation.
|
||||||
|
|
||||||
Endpoint: POST /predict
|
Endpoints:
|
||||||
Request JSON:
|
POST /predict BiRefNet / RMBG-2.0 automatic background removal (LitServe)
|
||||||
|
POST /segment GroundingDINO + SAM prompt-conditioned segmentation
|
||||||
|
GET / minimal web UI (two tabs: Auto / Prompt)
|
||||||
|
GET /health readiness
|
||||||
|
|
||||||
|
/predict request JSON:
|
||||||
{
|
{
|
||||||
"image": "<base64-encoded image bytes>", (required)
|
"image": "<base64>", (required)
|
||||||
"model": "general" | "HR" | "portrait" | ..., (default "general")
|
"model": "general" | "HR" | "rmbg2" | ..., (default "general")
|
||||||
"resolution": 1024, (default 1024)
|
"resolution": 1024,
|
||||||
"background": "alpha" | "white" | "black" | ..., (default "alpha")
|
"background": "alpha" | "white" | ...,
|
||||||
"mask_blur": 0, (default 0)
|
"mask_blur": 0,
|
||||||
"return_mask": false (default false)
|
"crop": false,
|
||||||
}
|
"crop_margin": 0.0, (inches)
|
||||||
Response JSON:
|
"return_mask": false
|
||||||
{
|
|
||||||
"image": "<base64 PNG>",
|
|
||||||
"format": "png",
|
|
||||||
"width": int,
|
|
||||||
"height": int,
|
|
||||||
"model": "<repo id used>",
|
|
||||||
"resolution": int,
|
|
||||||
"mask": "<base64 PNG>" (only when return_mask=true)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
A minimal web UI is served at GET / (same origin as /predict).
|
/segment request JSON:
|
||||||
|
{
|
||||||
|
"image": "<base64>", (required)
|
||||||
|
"prompt": "the dog", (required)
|
||||||
|
"box_threshold": 0.3,
|
||||||
|
"text_threshold": 0.25,
|
||||||
|
"background": "alpha" | ...,
|
||||||
|
"mask_blur": 0,
|
||||||
|
"crop": false,
|
||||||
|
"crop_margin": 0.0
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@ -29,16 +36,32 @@ from __future__ import annotations
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import litserve as ls
|
import litserve as ls
|
||||||
|
from fastapi import HTTPException
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from .model import BiRefNetService
|
from .model import BiRefNetService
|
||||||
|
from .prompt_segment import PromptSegmenter
|
||||||
|
|
||||||
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
|
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Lazily-created prompt segmenter (DINO + SAM), shared by the /segment route.
|
||||||
|
_segmenter: PromptSegmenter | None = None
|
||||||
|
_segmenter_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_segmenter() -> PromptSegmenter:
|
||||||
|
global _segmenter
|
||||||
|
if _segmenter is None:
|
||||||
|
with _segmenter_lock:
|
||||||
|
if _segmenter is None:
|
||||||
|
_segmenter = PromptSegmenter()
|
||||||
|
return _segmenter
|
||||||
|
|
||||||
|
|
||||||
def _b64_to_image(data: str) -> Image.Image:
|
def _b64_to_image(data: str) -> Image.Image:
|
||||||
image = Image.open(io.BytesIO(base64.b64decode(data)))
|
image = Image.open(io.BytesIO(base64.b64decode(data)))
|
||||||
@ -65,6 +88,7 @@ class BiRefNetAPI(ls.LitAPI):
|
|||||||
"resolution": request.get("resolution"),
|
"resolution": request.get("resolution"),
|
||||||
"background": request.get("background", "alpha"),
|
"background": request.get("background", "alpha"),
|
||||||
"mask_blur": int(request.get("mask_blur", 0)),
|
"mask_blur": int(request.get("mask_blur", 0)),
|
||||||
|
"mask_offset": int(request.get("mask_offset", 0)),
|
||||||
"crop": bool(request.get("crop", False)),
|
"crop": bool(request.get("crop", False)),
|
||||||
"crop_margin": float(request.get("crop_margin", 0.0)),
|
"crop_margin": float(request.get("crop_margin", 0.0)),
|
||||||
"return_mask": bool(request.get("return_mask", False)),
|
"return_mask": bool(request.get("return_mask", False)),
|
||||||
@ -83,6 +107,8 @@ class BiRefNetAPI(ls.LitAPI):
|
|||||||
"model": output["model"],
|
"model": output["model"],
|
||||||
"resolution": output["resolution"],
|
"resolution": output["resolution"],
|
||||||
}
|
}
|
||||||
|
if "crop_box" in output:
|
||||||
|
response["cropped"] = True
|
||||||
if output.get("mask") is not None:
|
if output.get("mask") is not None:
|
||||||
response["mask"] = _image_to_b64(output["mask"])
|
response["mask"] = _image_to_b64(output["mask"])
|
||||||
return response
|
return response
|
||||||
@ -97,7 +123,7 @@ def run() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# LitServe registers its own "/" route ("litserve running"); drop it so
|
# LitServe registers its own "/" route ("litserve running"); drop it so
|
||||||
# our UI can own the root path. Served same-origin as /predict (no CORS).
|
# our UI can own the root path. Served same-origin as the APIs (no CORS).
|
||||||
server.app.router.routes = [
|
server.app.router.routes = [
|
||||||
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
|
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
|
||||||
]
|
]
|
||||||
@ -106,6 +132,36 @@ def run() -> None:
|
|||||||
def index() -> str:
|
def index() -> str:
|
||||||
return _UI_HTML
|
return _UI_HTML
|
||||||
|
|
||||||
|
@server.app.post("/segment")
|
||||||
|
def segment(payload: dict) -> dict:
|
||||||
|
"""Prompt-conditioned segmentation (GroundingDINO + SAM)."""
|
||||||
|
if "image" not in payload:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing base64 'image'.")
|
||||||
|
image = _b64_to_image(payload["image"])
|
||||||
|
try:
|
||||||
|
result = _get_segmenter().segment(
|
||||||
|
image,
|
||||||
|
prompt=payload.get("prompt", ""),
|
||||||
|
background=payload.get("background", "alpha"),
|
||||||
|
mask_blur=int(payload.get("mask_blur", 0)),
|
||||||
|
mask_offset=int(payload.get("mask_offset", 0)),
|
||||||
|
crop=bool(payload.get("crop", False)),
|
||||||
|
crop_margin=float(payload.get("crop_margin", 0.0)),
|
||||||
|
box_threshold=float(payload.get("box_threshold", 0.3)),
|
||||||
|
text_threshold=float(payload.get("text_threshold", 0.25)),
|
||||||
|
)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
image_out: Image.Image = result["image"]
|
||||||
|
return {
|
||||||
|
"image": _image_to_b64(image_out),
|
||||||
|
"format": "png",
|
||||||
|
"width": image_out.width,
|
||||||
|
"height": image_out.height,
|
||||||
|
"detections": result["detections"],
|
||||||
|
"prompt": result["prompt"],
|
||||||
|
}
|
||||||
|
|
||||||
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)
|
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
<title>BiRefNet — Background Removal</title>
|
<title>Background Removal & Segmentation</title>
|
||||||
<style>
|
<style>
|
||||||
:root { color-scheme: dark; }
|
:root { color-scheme: dark; }
|
||||||
* { box-sizing: border-box; }
|
* { box-sizing: border-box; }
|
||||||
@ -12,36 +12,62 @@
|
|||||||
background: #15171c; color: #e8e8ea; padding: 24px;
|
background: #15171c; color: #e8e8ea; padding: 24px;
|
||||||
}
|
}
|
||||||
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
|
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
|
||||||
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 20px; }
|
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 16px; }
|
||||||
.wrap { max-width: 1100px; margin: 0 auto; }
|
.wrap { max-width: 1100px; margin: 0 auto; }
|
||||||
|
|
||||||
|
.tabs { display: flex; gap: 4px; margin-bottom: 16px; border-bottom: 1px solid #2a2f3a; }
|
||||||
|
.tab { background: none; border: 0; color: #8a8f99; font-size: .9rem; font-weight: 600;
|
||||||
|
padding: 10px 16px; cursor: pointer; border-bottom: 2px solid transparent; }
|
||||||
|
.tab.active { color: #e8e8ea; border-bottom-color: #5b8cff; }
|
||||||
|
|
||||||
#drop {
|
#drop {
|
||||||
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
|
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
|
||||||
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
|
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
|
||||||
}
|
}
|
||||||
#drop.over { border-color: #5b8cff; background: #1c2230; }
|
#drop.over { border-color: #5b8cff; background: #1c2230; }
|
||||||
#drop p { margin: 6px 0; color: #8a8f99; }
|
#drop p { margin: 6px 0; color: #8a8f99; }
|
||||||
.controls { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
|
||||||
|
.controls { display: flex; gap: 12px; align-items: flex-end; margin: 14px 0; flex-wrap: wrap; }
|
||||||
|
.controls[hidden] { display: none; }
|
||||||
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
|
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
|
||||||
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
|
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
|
||||||
select, input[type=number] {
|
select, input[type=number], input[type=text] {
|
||||||
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
|
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
|
||||||
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
|
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
|
||||||
}
|
}
|
||||||
input[type=number] { width: 78px; }
|
input[type=number] { width: 78px; }
|
||||||
input[type=number]:disabled { opacity: .45; }
|
input[type=number]:disabled { opacity: .45; }
|
||||||
|
input[type=text]#prompt { width: 320px; }
|
||||||
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
|
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
|
||||||
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
|
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
|
||||||
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
|
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
|
||||||
button {
|
|
||||||
|
/* help tooltips */
|
||||||
|
.help { display: inline-flex; align-items: center; justify-content: center;
|
||||||
|
width: 14px; height: 14px; margin-left: 5px; border-radius: 50%;
|
||||||
|
border: 1px solid #4a4f5b; color: #8a8f99; font-size: 9px; font-weight: 700;
|
||||||
|
font-style: normal; cursor: help; position: relative; vertical-align: middle; }
|
||||||
|
.help:hover { color: #e8e8ea; border-color: #5b8cff; }
|
||||||
|
.help:hover::after {
|
||||||
|
content: attr(data-tip); position: absolute; bottom: 150%; left: 50%;
|
||||||
|
transform: translateX(-50%); width: 220px; background: #0c0d11;
|
||||||
|
color: #d8d9dc; border: 1px solid #3a3f4b; border-radius: 6px;
|
||||||
|
padding: 7px 9px; font-size: .72rem; font-weight: 400; line-height: 1.4;
|
||||||
|
text-transform: none; letter-spacing: normal; white-space: normal;
|
||||||
|
z-index: 50; pointer-events: none; }
|
||||||
|
|
||||||
|
button.go {
|
||||||
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
|
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
|
||||||
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
|
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
|
||||||
}
|
}
|
||||||
button:disabled { background: #3a3f4b; cursor: not-allowed; }
|
button.go:disabled { background: #3a3f4b; cursor: not-allowed; }
|
||||||
button.ghost { background: #2a2f3a; }
|
button.ghost { background: #2a2f3a; color: #fff; border: 0; border-radius: 8px;
|
||||||
.go-row { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600; }
|
||||||
|
.go-row { display: flex; gap: 12px; align-items: center; margin: 14px 0; flex-wrap: wrap; }
|
||||||
.status { color: #8a8f99; font-size: .85rem; }
|
.status { color: #8a8f99; font-size: .85rem; }
|
||||||
.status.err { color: #ff6b6b; }
|
.status.err { color: #ff6b6b; }
|
||||||
.hint { color: #6b7280; font-size: .78rem; margin-top: -8px; }
|
.hint { color: #6b7280; font-size: .78rem; margin: -4px 0 4px; }
|
||||||
|
|
||||||
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
|
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
|
||||||
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
|
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
|
||||||
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
|
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
|
||||||
@ -84,8 +110,13 @@
|
|||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div class="wrap">
|
<div class="wrap">
|
||||||
<h1>BiRefNet — Background Removal</h1>
|
<h1>Background Removal & Segmentation</h1>
|
||||||
<div class="sub">Drop an image to get a transparent-background PNG.</div>
|
<div class="sub">Automatic removal, or prompt-conditioned segmentation.</div>
|
||||||
|
|
||||||
|
<div class="tabs">
|
||||||
|
<button class="tab active" data-tab="auto">Auto remove</button>
|
||||||
|
<button class="tab" data-tab="prompt">Prompt segment</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div id="drop">
|
<div id="drop">
|
||||||
<p><strong>Drop an image here</strong> or click to choose</p>
|
<p><strong>Drop an image here</strong> or click to choose</p>
|
||||||
@ -93,7 +124,8 @@
|
|||||||
<input id="file" type="file" accept="image/*" hidden />
|
<input id="file" type="file" accept="image/*" hidden />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="controls">
|
<!-- Auto (BiRefNet / RMBG-2.0) controls -->
|
||||||
|
<div class="controls" id="ctl-auto">
|
||||||
<label class="field">Model
|
<label class="field">Model
|
||||||
<select id="model">
|
<select id="model">
|
||||||
<option value="general">general — clean single subjects (fast)</option>
|
<option value="general">general — clean single subjects (fast)</option>
|
||||||
@ -101,6 +133,7 @@
|
|||||||
<option value="portrait">portrait — people</option>
|
<option value="portrait">portrait — people</option>
|
||||||
<option value="matting">matting — soft edges / hair</option>
|
<option value="matting">matting — soft edges / hair</option>
|
||||||
<option value="lite">lite — fastest</option>
|
<option value="lite">lite — fastest</option>
|
||||||
|
<option value="rmbg2">rmbg2 — BRIA RMBG-2.0</option>
|
||||||
</select>
|
</select>
|
||||||
</label>
|
</label>
|
||||||
<label class="field">Resolution
|
<label class="field">Resolution
|
||||||
@ -111,16 +144,54 @@
|
|||||||
<option value="2560" selected>2560</option>
|
<option value="2560" selected>2560</option>
|
||||||
</select>
|
</select>
|
||||||
</label>
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Prompt (GroundingDINO + SAM) controls -->
|
||||||
|
<div class="controls" id="ctl-prompt" hidden>
|
||||||
|
<label class="field">Prompt — what to keep
|
||||||
|
<input type="text" id="prompt" placeholder="e.g. the dog · cow. person." />
|
||||||
|
</label>
|
||||||
|
<label class="field">
|
||||||
|
<span>Box threshold<span class="help" data-tip="Minimum confidence for GroundingDINO to keep a detected box. Lower finds more (and looser) objects; higher keeps only strong matches.">?</span></span>
|
||||||
|
<input type="number" id="boxThr" value="0.3" min="0" max="1" step="0.05" />
|
||||||
|
</label>
|
||||||
|
<label class="field">
|
||||||
|
<span>Text threshold<span class="help" data-tip="How strongly a detection must match your prompt words. Lower = looser word matching; higher = stricter.">?</span></span>
|
||||||
|
<input type="number" id="textThr" value="0.25" min="0" max="1" step="0.05" />
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- Shared output controls -->
|
||||||
|
<div class="controls">
|
||||||
|
<label class="field">Background
|
||||||
|
<select id="background">
|
||||||
|
<option value="alpha" selected>transparent</option>
|
||||||
|
<option value="white">white</option>
|
||||||
|
<option value="black">black</option>
|
||||||
|
<option value="gray">gray</option>
|
||||||
|
<option value="green">green</option>
|
||||||
|
<option value="blue">blue</option>
|
||||||
|
<option value="red">red</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<label class="field">
|
||||||
|
<span>Edge offset (px)<span class="help" data-tip="Grow (+) or shrink (−) the cutout edge by N pixels. A small negative value trims a leftover background-colored fringe around hair or fur.">?</span></span>
|
||||||
|
<input type="number" id="maskOffset" value="0" min="-20" max="20" step="1" />
|
||||||
|
</label>
|
||||||
|
<label class="field">
|
||||||
|
<span>Feather (px)<span class="help" data-tip="Gaussian blur applied to the mask edge, in pixels. Softens the cutout for smoother compositing onto a new background.">?</span></span>
|
||||||
|
<input type="number" id="maskBlur" value="0" min="0" max="64" step="1" />
|
||||||
|
</label>
|
||||||
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
|
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
|
||||||
<label class="field">Margin (in)
|
<label class="field">Margin (in)
|
||||||
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
|
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<div class="hint">Tip: large or busy scenes segment best with <strong>HR</strong> at <strong>2048</strong>.
|
|
||||||
The <em>general</em> model expects a clear single subject at 1024.</div>
|
<div class="hint" id="hint"></div>
|
||||||
|
|
||||||
<div class="go-row">
|
<div class="go-row">
|
||||||
<button id="go" disabled>Remove background</button>
|
<button class="go" id="go" disabled>Remove background</button>
|
||||||
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
|
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
|
||||||
<span id="status" class="status"></span>
|
<span id="status" class="status"></span>
|
||||||
</div>
|
</div>
|
||||||
@ -155,14 +226,42 @@ const dlbtn = document.getElementById('dlbtn');
|
|||||||
const statusEl = document.getElementById('status');
|
const statusEl = document.getElementById('status');
|
||||||
const srcImg = document.getElementById('src');
|
const srcImg = document.getElementById('src');
|
||||||
const outImg = document.getElementById('out');
|
const outImg = document.getElementById('out');
|
||||||
|
const hint = document.getElementById('hint');
|
||||||
|
|
||||||
const modelSel = document.getElementById('model');
|
const modelSel = document.getElementById('model');
|
||||||
const resSel = document.getElementById('resolution');
|
const resSel = document.getElementById('resolution');
|
||||||
|
const promptInput = document.getElementById('prompt');
|
||||||
|
const boxThr = document.getElementById('boxThr');
|
||||||
|
const textThr = document.getElementById('textThr');
|
||||||
|
const bgSel = document.getElementById('background');
|
||||||
|
const maskOffset = document.getElementById('maskOffset');
|
||||||
|
const maskBlur = document.getElementById('maskBlur');
|
||||||
const cropChk = document.getElementById('crop');
|
const cropChk = document.getElementById('crop');
|
||||||
const cropMargin = document.getElementById('cropMargin');
|
const cropMargin = document.getElementById('cropMargin');
|
||||||
|
|
||||||
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
|
const ctlAuto = document.getElementById('ctl-auto');
|
||||||
|
const ctlPrompt = document.getElementById('ctl-prompt');
|
||||||
|
|
||||||
let selectedFile = null;
|
let selectedFile = null;
|
||||||
|
let tab = 'auto';
|
||||||
|
|
||||||
|
const HINTS = {
|
||||||
|
auto: 'Large or busy scenes segment best with HR at 2048+. The general model expects a clear single subject at 1024.',
|
||||||
|
prompt: 'Type what to keep, e.g. "the dog" (or several: "cow. person."). Lower the box threshold to detect more / fainter objects.',
|
||||||
|
};
|
||||||
|
|
||||||
|
function setTab(name) {
|
||||||
|
tab = name;
|
||||||
|
document.querySelectorAll('.tab').forEach(t =>
|
||||||
|
t.classList.toggle('active', t.dataset.tab === name));
|
||||||
|
ctlAuto.hidden = name !== 'auto';
|
||||||
|
ctlPrompt.hidden = name !== 'prompt';
|
||||||
|
go.textContent = name === 'auto' ? 'Remove background' : 'Segment';
|
||||||
|
hint.textContent = HINTS[name];
|
||||||
|
}
|
||||||
|
document.querySelectorAll('.tab').forEach(t =>
|
||||||
|
t.addEventListener('click', () => setTab(t.dataset.tab)));
|
||||||
|
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
|
||||||
|
|
||||||
function setStatus(msg, isErr) {
|
function setStatus(msg, isErr) {
|
||||||
statusEl.textContent = msg;
|
statusEl.textContent = msg;
|
||||||
@ -264,23 +363,38 @@ lbImg.addEventListener('dblclick', e => { e.preventDefault(); lbReset(); });
|
|||||||
|
|
||||||
go.addEventListener('click', async () => {
|
go.addEventListener('click', async () => {
|
||||||
if (!selectedFile) return;
|
if (!selectedFile) return;
|
||||||
|
if (tab === 'prompt' && !promptInput.value.trim()) {
|
||||||
|
setStatus('Enter a prompt describing what to keep.', true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
go.disabled = true;
|
go.disabled = true;
|
||||||
dlbtn.disabled = true;
|
dlbtn.disabled = true;
|
||||||
setStatus('Processing… (first use of a model downloads its weights)');
|
setStatus('Processing… (first use of a model downloads its weights)');
|
||||||
const t0 = performance.now();
|
const t0 = performance.now();
|
||||||
try {
|
try {
|
||||||
const b64 = await fileToBase64(selectedFile);
|
const b64 = await fileToBase64(selectedFile);
|
||||||
const resp = await fetch('/predict', {
|
const shared = {
|
||||||
|
image: b64,
|
||||||
|
background: bgSel.value,
|
||||||
|
mask_offset: parseInt(maskOffset.value, 10) || 0,
|
||||||
|
mask_blur: parseInt(maskBlur.value, 10) || 0,
|
||||||
|
crop: cropChk.checked,
|
||||||
|
crop_margin: parseFloat(cropMargin.value) || 0,
|
||||||
|
};
|
||||||
|
let endpoint, body;
|
||||||
|
if (tab === 'auto') {
|
||||||
|
endpoint = '/predict';
|
||||||
|
body = { ...shared, model: modelSel.value, resolution: parseInt(resSel.value, 10) };
|
||||||
|
} else {
|
||||||
|
endpoint = '/segment';
|
||||||
|
body = { ...shared, prompt: promptInput.value.trim(),
|
||||||
|
box_threshold: parseFloat(boxThr.value) || 0.3,
|
||||||
|
text_threshold: parseFloat(textThr.value) || 0.25 };
|
||||||
|
}
|
||||||
|
const resp = await fetch(endpoint, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(body),
|
||||||
image: b64,
|
|
||||||
background: 'alpha',
|
|
||||||
model: modelSel.value,
|
|
||||||
resolution: parseInt(resSel.value, 10),
|
|
||||||
crop: cropChk.checked,
|
|
||||||
crop_margin: parseFloat(cropMargin.value) || 0,
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
|
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
|
||||||
const data = await resp.json();
|
const data = await resp.json();
|
||||||
@ -290,14 +404,22 @@ go.addEventListener('click', async () => {
|
|||||||
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
|
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
|
||||||
dlbtn.disabled = false;
|
dlbtn.disabled = false;
|
||||||
const secs = ((performance.now() - t0) / 1000).toFixed(1);
|
const secs = ((performance.now() - t0) / 1000).toFixed(1);
|
||||||
setStatus('Done — ' + data.width + '×' + data.height + ' · ' + data.model +
|
if (tab === 'auto') {
|
||||||
' @ ' + data.resolution + ' · ' + secs + 's');
|
setStatus(`Done — ${data.width}×${data.height} · ${data.model} @ ${data.resolution} · ${secs}s`);
|
||||||
|
} else {
|
||||||
|
const n = data.detections;
|
||||||
|
setStatus(`Done — ${n} object${n === 1 ? '' : 's'} matched "${data.prompt}" · ` +
|
||||||
|
`${data.width}×${data.height} · ${secs}s` +
|
||||||
|
(n === 0 ? ' (try a lower box threshold)' : ''));
|
||||||
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
setStatus(err.message || String(err), true);
|
setStatus(err.message || String(err), true);
|
||||||
} finally {
|
} finally {
|
||||||
go.disabled = false;
|
go.disabled = false;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
setTab('auto');
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user