From 4efb4b8a2f4438c9dfc6de072cc71f47e30a02e6 Mon Sep 17 00:00:00 2001 From: Michael Pilosov Date: Sat, 16 May 2026 17:04:56 -0600 Subject: [PATCH] rmbg option + some other post-processing --- README.md | 97 ++++++++------ pyproject.toml | 4 +- scripts/client.py | 27 +++- src/birefnet_service/model.py | 136 ++++++++++++------- src/birefnet_service/prompt_segment.py | 135 +++++++++++++++++++ src/birefnet_service/server.py | 98 +++++++++++--- src/birefnet_service/static/index.html | 176 +++++++++++++++++++++---- 7 files changed, 529 insertions(+), 144 deletions(-) create mode 100644 src/birefnet_service/prompt_segment.py diff --git a/README.md b/README.md index 093a6c4..6d77e6f 100644 --- a/README.md +++ b/README.md @@ -1,78 +1,98 @@ # BiRefNet Background Removal Service -GPU-accelerated background removal exposed as an HTTP API. Uses -[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with -[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the -NVIDIA container runtime. +GPU-accelerated background removal as an HTTP API. Two pipelines: + +- **Auto** — [BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) / + [RMBG-2.0](https://huggingface.co/briaai/RMBG-2.0) salient-object matting. +- **Prompt** — [GroundingDINO](https://huggingface.co/IDEA-Research/grounding-dino-tiny) + + [SAM](https://huggingface.co/facebook/sam-vit-large): segment whatever a text + prompt describes. + +Served with [LitServe](https://github.com/Lightning-AI/LitServe), packaged for +the NVIDIA container runtime. 
## Requirements - NVIDIA GPU + driver, Docker, and the `nvidia` container runtime -- ~2 GB free disk for the model weights (downloaded on first run) +- ~5 GB free disk for model weights (downloaded on first use, cached in a volume) ## Quick start ```bash make build # build the Docker image make run # start the service on :8000 (GPU) -make logs # watch startup — first run downloads BiRefNet weights +make logs # watch startup — first run downloads model weights make test # send test.jpg, save output.png ``` -`make test` waits for the service `/health` endpoint before sending the -request, so the first call may block while the model downloads and loads. +`make test` waits for `/health` before sending, so the first call may block +while a model downloads and loads. ### Web UI -A minimal test page is served at the service root — open -**http://localhost:8000/** in a browser, drop in an image, and preview the -transparent-background result (handy when working over SSH). It calls the -same `/predict` endpoint. +Open **http://localhost:8000/** — a two-tab test page (handy over SSH): -### Useful variations +- **Auto remove** — pick a model variant + resolution. +- **Prompt segment** — type what to keep (e.g. `the dog`), tune the + GroundingDINO box / text thresholds. -```bash -make test BG=white # composite onto a white background -make test INPUT=photo.jpg OUTPUT=cut.png -make test-mask # also save the raw alpha mask (mask.png) -make help # list all targets -``` +Both tabs support a transparency checkerboard preview, click-to-zoom lightbox, +optional crop-to-subject, and download. 
## API -`POST /predict` +### `POST /predict` — automatic background removal ```jsonc { "image": "", // required + "model": "HR", // general|HR|portrait|matting|lite|rmbg2 + "resolution": 2048, // inference resolution (×32) "background": "alpha", // alpha|white|black|gray|green|blue|red "mask_blur": 0, // Gaussian blur radius on mask edges + "crop": false, // crop to the foreground bounding box + "crop_margin": 0.0, // crop margin in inches (uses image DPI) "return_mask": false // include the raw mask in the response } ``` -Response: +### `POST /segment` — prompt-conditioned segmentation ```jsonc { - "image": "", - "format": "png", - "width": 3637, - "height": 3637, - "mask": "" // only when return_mask=true + "image": "", // required + "prompt": "the dog", // required — object(s) to keep + "box_threshold": 0.3, // GroundingDINO detection threshold + "text_threshold": 0.25, + "background": "alpha", + "mask_blur": 0, + "crop": false, + "crop_margin": 0.0 } ``` +Response (both): `image` (base64 PNG), `format`, `width`, `height`, plus +`model`/`resolution` (`/predict`) or `detections`/`prompt` (`/segment`). + `GET /health` returns 200 when the service is ready. 
+## CLI + +```bash +python3 scripts/client.py --input photo.jpg --output cut.png --model HR --resolution 2048 --crop +python3 scripts/client.py --input photo.jpg --output dog.png --prompt "the dog" --crop +``` + ## Configuration (environment variables) -| Variable | Default | Purpose | -|----------------------|----------------------|----------------------------------| -| `PORT` | `8000` | HTTP port | -| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights | -| `BIREFNET_RESOLUTION`| `1024` | Inference resolution | -| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) | +| Variable | Default | Purpose | +|----------------------|--------------------------------|-------------------------------| +| `PORT` | `8000` | HTTP port | +| `BIREFNET_MODEL` | `general` | Default Auto variant | +| `BIREFNET_RESOLUTION`| `1024` | Default Auto resolution | +| `DINO_MODEL` | `IDEA-Research/grounding-dino-tiny` | GroundingDINO checkpoint | +| `SAM_MODEL` | `facebook/sam-vit-large` | SAM checkpoint | +| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) | ## Local development (no Docker) @@ -85,10 +105,11 @@ make dev # uv sync + run the server locally ## Layout ``` -src/birefnet_service/model.py BiRefNet wrapper (load + inference) -src/birefnet_service/server.py LitServe API + web UI route -src/birefnet_service/static/ web UI (index.html) -scripts/client.py stdlib-only test client -Dockerfile / docker-compose.yml CUDA image + nvidia runtime -Makefile build / run / test shortcuts +src/birefnet_service/model.py BiRefNet / RMBG-2.0 wrapper + compositing +src/birefnet_service/prompt_segment.py GroundingDINO + SAM pipeline +src/birefnet_service/server.py LitServe /predict + /segment + web UI +src/birefnet_service/static/ web UI (index.html) +scripts/client.py stdlib-only test client +Dockerfile / compose.yml CUDA image + nvidia runtime +Makefile build / run / test shortcuts ``` diff --git a/pyproject.toml b/pyproject.toml index 3adf538..1d3c675 100644 
--- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "birefnet-service" +name = "rmbg-as-a-service" version = "0.1.0" -description = "BiRefNet background removal as a GPU-accelerated API" +description = "Background removal as a GPU-accelerated API" readme = "README.md" requires-python = ">=3.12,<3.13" dependencies = [ diff --git a/scripts/client.py b/scripts/client.py index 3a6772b..b10b1c8 100644 --- a/scripts/client.py +++ b/scripts/client.py @@ -40,6 +40,9 @@ def main() -> None: ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048") ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box") ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches") + ap.add_argument("--prompt", default=None, help="if set, use prompt segmentation (/segment)") + ap.add_argument("--box-threshold", type=float, default=0.3, help="GroundingDINO box threshold") + ap.add_argument("--text-threshold", type=float, default=0.25, help="GroundingDINO text threshold") ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges") ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path") ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health") @@ -55,16 +58,24 @@ def main() -> None: "mask_blur": args.mask_blur, "return_mask": args.mask_output is not None, } - if args.model is not None: - payload["model"] = args.model - if args.resolution is not None: - payload["resolution"] = args.resolution if args.crop: payload["crop"] = True payload["crop_margin"] = args.crop_margin + if args.prompt is not None: + endpoint = "/segment" + payload["prompt"] = args.prompt + payload["box_threshold"] = args.box_threshold + payload["text_threshold"] = args.text_threshold + else: + endpoint = "/predict" + if args.model is not None: + payload["model"] = args.model + if args.resolution is not 
None: + payload["resolution"] = args.resolution + req = urllib.request.Request( - f"{base_url}/predict", + f"{base_url}{endpoint}", data=json.dumps(payload).encode(), headers={"Content-Type": "application/json"}, method="POST", @@ -77,9 +88,13 @@ def main() -> None: with open(args.output, "wb") as f: f.write(base64.b64decode(result["image"])) + if "detections" in result: + detail = f"{result['detections']} object(s) matched '{result.get('prompt')}'" + else: + detail = f"{result.get('model')} @ {result.get('resolution')}" print( f"saved {args.output} {result['width']}x{result['height']} " - f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)" + f"{detail} ({elapsed:.1f}s)" ) if args.mask_output and "mask" in result: diff --git a/src/birefnet_service/model.py b/src/birefnet_service/model.py index b14ebdb..e899109 100644 --- a/src/birefnet_service/model.py +++ b/src/birefnet_service/model.py @@ -1,8 +1,8 @@ """BiRefNet model wrapper for background removal. -Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports -multiple model variants (lazily loaded + cached) and a tunable inference -resolution, both selectable per request. +Loads BiRefNet / RMBG-2.0 weights via ``transformers`` (trust_remote_code). +Supports multiple model variants (lazily loaded + cached) and a tunable +inference resolution, both selectable per request. """ from __future__ import annotations @@ -16,20 +16,21 @@ from torchvision import transforms from transformers import AutoModelForImageSegmentation # Friendly variant names -> HuggingFace repo. A raw repo id may also be passed. -# BiRefNet variants: general is fast for clean single subjects; HR is for large -# / detailed scenes and needs a higher resolution (>=1536) to perform well. +# general: fast for clean single subjects. HR: large / detailed scenes (>=1536). +# rmbg2: BRIA RMBG-2.0 (a BiRefNet-architecture model), loaded the same way. 
MODEL_ALIASES = { "general": "ZhengPeng7/BiRefNet", "HR": "ZhengPeng7/BiRefNet_HR", "portrait": "ZhengPeng7/BiRefNet-portrait", "matting": "ZhengPeng7/BiRefNet-matting", "lite": "ZhengPeng7/BiRefNet_lite", + "rmbg2": "1038lab/RMBG-2.0", } DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general") DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024")) -# ImageNet normalization, matching BiRefNet training. +# ImageNet normalization, matching BiRefNet / RMBG-2.0 training. _MEAN = [0.485, 0.456, 0.406] _STD = [0.229, 0.224, 0.225] @@ -55,8 +56,73 @@ def _normalize_resolution(resolution: int | None) -> int: return max(256, (res // 32) * 32) +def dpi_of(image: Image.Image) -> float: + """Horizontal DPI from image metadata; 96 if not embedded.""" + dpi = image.info.get("dpi") + return float(dpi[0]) if dpi and dpi[0] else 96.0 + + +def apply_mask( + image: Image.Image, + mask: Image.Image, + background: str = "alpha", + mask_blur: int = 0, + mask_offset: int = 0, + crop: bool = False, + crop_margin: float = 0.0, + dpi: float = 96.0, +) -> dict: + """Composite an RGB image with an 'L' mask into a result dict. + + Shared by every removal pipeline (BiRefNet and prompt-based). Handles + edge offset (grow/shrink), edge blur, solid-background compositing, and + optional bounding-box crop. + Returns {"image": ..., "mask": ..., "crop_box"?: ...}. + """ + # Grow (positive) or shrink (negative) the cutout — kills edge halos/fringe. + if mask_offset: + edge_filter = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter + for _ in range(abs(int(mask_offset))): + mask = mask.filter(edge_filter(3)) + if mask_blur > 0: + mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur)) + + cutout = image.convert("RGBA") + cutout.putalpha(mask) + + background = (background or "alpha").lower() + if background == "alpha": + result = cutout + else: + if background not in BG_COLORS: + raise ValueError( + f"Unknown background '{background}'. 
" + f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}" + ) + bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255)) + result = Image.alpha_composite(bg, cutout).convert("RGB") + + out: dict = {"image": result, "mask": mask} + + if crop: + bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox() + if bbox is not None: + margin_px = round(max(0.0, crop_margin) * dpi) + box = ( + max(0, bbox[0] - margin_px), + max(0, bbox[1] - margin_px), + min(image.width, bbox[2] + margin_px), + min(image.height, bbox[3] + margin_px), + ) + out["image"] = result.crop(box) + out["mask"] = mask.crop(box) + out["crop_box"] = box + + return out + + class BiRefNetService: - """Runs BiRefNet background removal; caches loaded model variants.""" + """Runs BiRefNet / RMBG-2.0 background removal; caches loaded variants.""" def __init__( self, @@ -122,64 +188,34 @@ class BiRefNetService: resolution: int | None = None, background: str = "alpha", mask_blur: int = 0, + mask_offset: int = 0, crop: bool = False, crop_margin: float = 0.0, return_mask: bool = False, ) -> dict: """Run background removal. - model: variant alias ('general', 'HR', ...) or a raw HF repo id. + model: variant alias ('general', 'HR', 'rmbg2', ...) or a repo id. resolution: inference resolution (rounded down to a multiple of 32). background: "alpha" for transparency, or a key from BG_COLORS. mask_blur: Gaussian blur radius applied to the mask edges. + mask_offset: grow (+) or shrink (-) the cutout edge, in pixels. crop: crop the output to the foreground's bounding box. crop_margin: extra margin around the crop, in inches (uses image DPI). """ - # DPI for inch->pixel margin conversion; default 96 if not embedded. 
- dpi = image.info.get("dpi") - dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0 - + dpi = dpi_of(image) image = image.convert("RGB") repo = resolve_repo(model) resolution = _normalize_resolution(resolution) mask = self.infer_mask(image, repo, resolution) - if mask_blur > 0: - mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur)) - - cutout = image.convert("RGBA") - cutout.putalpha(mask) - - background = (background or "alpha").lower() - if background == "alpha": - result = cutout - else: - if background not in BG_COLORS: - raise ValueError( - f"Unknown background '{background}'. " - f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}" - ) - bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255)) - result = Image.alpha_composite(bg, cutout).convert("RGB") - - out: dict = {"image": result, "model": repo, "resolution": resolution} - - if crop: - # Bounding box of the foreground (mask above a low alpha threshold). - bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox() - if bbox is not None: - margin_px = round(max(0.0, crop_margin) * dpi_x) - left = max(0, bbox[0] - margin_px) - top = max(0, bbox[1] - margin_px) - right = min(image.width, bbox[2] + margin_px) - bottom = min(image.height, bbox[3] + margin_px) - box = (left, top, right, bottom) - result = result.crop(box) - mask = mask.crop(box) - out["image"] = result - out["crop_box"] = box - out["dpi"] = round(dpi_x, 1) - - if return_mask: - out["mask"] = mask + out = apply_mask( + image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi + ) + out["model"] = repo + out["resolution"] = resolution + if "crop_box" in out: + out["dpi"] = round(dpi, 1) + if not return_mask: + out.pop("mask", None) return out diff --git a/src/birefnet_service/prompt_segment.py b/src/birefnet_service/prompt_segment.py new file mode 100644 index 0000000..4f56550 --- /dev/null +++ b/src/birefnet_service/prompt_segment.py @@ -0,0 +1,135 @@ +"""Prompt-conditioned segmentation: GroundingDINO + SAM. 
+ +Given a text prompt, GroundingDINO detects matching boxes and SAM turns those +boxes into masks. The union of the masks becomes the foreground alpha — which +is then composited/cropped by the shared ``apply_mask`` helper. + +Both models come from ``transformers`` (no custom CUDA extensions). +""" + +from __future__ import annotations + +import os +import threading + +import torch +from PIL import Image +from transformers import ( + AutoProcessor, + GroundingDinoForObjectDetection, + SamModel, + SamProcessor, +) + +from .model import apply_mask, dpi_of + +DINO_MODEL = os.getenv("DINO_MODEL", "IDEA-Research/grounding-dino-tiny") +SAM_MODEL = os.getenv("SAM_MODEL", "facebook/sam-vit-large") + + +class PromptSegmenter: + """Text-prompted segmentation. Models load lazily on first use.""" + + def __init__(self, device: str | None = None): + want_cuda = device != "cpu" and torch.cuda.is_available() + self.device = "cuda" if want_cuda else "cpu" + self._lock = threading.Lock() + self._ready = False + + def _ensure_loaded(self) -> None: + if self._ready: + return + with self._lock: + if self._ready: + return + self.dino_processor = AutoProcessor.from_pretrained(DINO_MODEL) + self.dino = GroundingDinoForObjectDetection.from_pretrained(DINO_MODEL) + self.dino.eval().to(self.device) + self.sam_processor = SamProcessor.from_pretrained(SAM_MODEL) + self.sam = SamModel.from_pretrained(SAM_MODEL) + self.sam.eval().to(self.device) + self._ready = True + + @torch.inference_mode() + def _detect( + self, image: Image.Image, prompt: str, box_threshold: float, text_threshold: float + ) -> torch.Tensor: + # GroundingDINO expects lowercase phrases separated/terminated by '.'. + text = prompt.strip().lower().replace(",", ".") + if not text.endswith("."): + text += "." 
+ inputs = self.dino_processor(images=image, text=text, return_tensors="pt").to( + self.device + ) + outputs = self.dino(**inputs) + results = self.dino_processor.post_process_grounded_object_detection( + outputs, + inputs["input_ids"], + threshold=box_threshold, + text_threshold=text_threshold, + target_sizes=[(image.height, image.width)], + )[0] + return results["boxes"] # [N, 4] xyxy, pixel coords + + @torch.inference_mode() + def _mask_from_boxes(self, image: Image.Image, boxes: torch.Tensor) -> Image.Image: + inputs = self.sam_processor( + image, input_boxes=[boxes.tolist()], return_tensors="pt" + ).to(self.device) + outputs = self.sam(**inputs) + masks = self.sam_processor.image_processor.post_process_masks( + outputs.pred_masks.cpu(), + inputs["original_sizes"].cpu(), + inputs["reshaped_input_sizes"].cpu(), + )[0] # [N, 3, H, W] bool + iou = outputs.iou_scores.cpu()[0] # [N, 3] + # Best of SAM's 3 candidates per box, then union all boxes' masks. + best = iou.argmax(dim=-1) + chosen = torch.stack([masks[i, best[i]] for i in range(masks.shape[0])]) + union = chosen.any(dim=0) # [H, W] bool + return Image.fromarray((union.numpy() * 255).astype("uint8"), mode="L") + + def segment( + self, + image: Image.Image, + prompt: str, + background: str = "alpha", + mask_blur: int = 0, + mask_offset: int = 0, + crop: bool = False, + crop_margin: float = 0.0, + box_threshold: float = 0.3, + text_threshold: float = 0.25, + return_mask: bool = False, + ) -> dict: + """Segment whatever ``prompt`` describes and remove the rest. + + prompt: object(s) to keep, e.g. "the dog" or "cow. person.". + box_threshold / text_threshold: GroundingDINO detection thresholds. + Other args match ``apply_mask`` / the BiRefNet pipeline. 
+ """ + if not prompt or not prompt.strip(): + raise ValueError("A non-empty 'prompt' is required.") + + self._ensure_loaded() + dpi = dpi_of(image) + image = image.convert("RGB") + + with self._lock: + boxes = self._detect(image, prompt, box_threshold, text_threshold) + detections = int(len(boxes)) + if detections == 0: + mask = Image.new("L", image.size, 0) + else: + mask = self._mask_from_boxes(image, boxes) + + out = apply_mask( + image, mask, background, mask_blur, mask_offset, crop, crop_margin, dpi + ) + out["prompt"] = prompt + out["detections"] = detections + if "crop_box" in out: + out["dpi"] = round(dpi, 1) + if not return_mask: + out.pop("mask", None) + return out diff --git a/src/birefnet_service/server.py b/src/birefnet_service/server.py index a437ef2..15945d6 100644 --- a/src/birefnet_service/server.py +++ b/src/birefnet_service/server.py @@ -1,27 +1,34 @@ -"""LitServe API exposing BiRefNet background removal. +"""LitServe API exposing BiRefNet background removal + prompt segmentation. 
-Endpoint: POST /predict -Request JSON: +Endpoints: + POST /predict BiRefNet / RMBG-2.0 automatic background removal (LitServe) + POST /segment GroundingDINO + SAM prompt-conditioned segmentation + GET / minimal web UI (two tabs: Auto / Prompt) + GET /health readiness + +/predict request JSON: { - "image": "", (required) - "model": "general" | "HR" | "portrait" | ..., (default "general") - "resolution": 1024, (default 1024) - "background": "alpha" | "white" | "black" | ..., (default "alpha") - "mask_blur": 0, (default 0) - "return_mask": false (default false) - } -Response JSON: - { - "image": "", - "format": "png", - "width": int, - "height": int, - "model": "", - "resolution": int, - "mask": "" (only when return_mask=true) + "image": "", (required) + "model": "general" | "HR" | "rmbg2" | ..., (default "general") + "resolution": 1024, + "background": "alpha" | "white" | ..., + "mask_blur": 0, + "crop": false, + "crop_margin": 0.0, (inches) + "return_mask": false } -A minimal web UI is served at GET / (same origin as /predict). +/segment request JSON: + { + "image": "", (required) + "prompt": "the dog", (required) + "box_threshold": 0.3, + "text_threshold": 0.25, + "background": "alpha" | ..., + "mask_blur": 0, + "crop": false, + "crop_margin": 0.0 + } """ from __future__ import annotations @@ -29,16 +36,32 @@ from __future__ import annotations import base64 import io import os +import threading from pathlib import Path import litserve as ls +from fastapi import HTTPException from fastapi.responses import HTMLResponse from PIL import Image, ImageOps from .model import BiRefNetService +from .prompt_segment import PromptSegmenter _UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8") +# Lazily-created prompt segmenter (DINO + SAM), shared by the /segment route. 
+_segmenter: PromptSegmenter | None = None +_segmenter_lock = threading.Lock() + + +def _get_segmenter() -> PromptSegmenter: + global _segmenter + if _segmenter is None: + with _segmenter_lock: + if _segmenter is None: + _segmenter = PromptSegmenter() + return _segmenter + def _b64_to_image(data: str) -> Image.Image: image = Image.open(io.BytesIO(base64.b64decode(data))) @@ -65,6 +88,7 @@ class BiRefNetAPI(ls.LitAPI): "resolution": request.get("resolution"), "background": request.get("background", "alpha"), "mask_blur": int(request.get("mask_blur", 0)), + "mask_offset": int(request.get("mask_offset", 0)), "crop": bool(request.get("crop", False)), "crop_margin": float(request.get("crop_margin", 0.0)), "return_mask": bool(request.get("return_mask", False)), @@ -83,6 +107,8 @@ class BiRefNetAPI(ls.LitAPI): "model": output["model"], "resolution": output["resolution"], } + if "crop_box" in output: + response["cropped"] = True if output.get("mask") is not None: response["mask"] = _image_to_b64(output["mask"]) return response @@ -97,7 +123,7 @@ def run() -> None: ) # LitServe registers its own "/" route ("litserve running"); drop it so - # our UI can own the root path. Served same-origin as /predict (no CORS). + # our UI can own the root path. Served same-origin as the APIs (no CORS). 
server.app.router.routes = [ r for r in server.app.router.routes if getattr(r, "path", None) != "/" ] @@ -106,6 +132,36 @@ def run() -> None: def index() -> str: return _UI_HTML + @server.app.post("/segment") + def segment(payload: dict) -> dict: + """Prompt-conditioned segmentation (GroundingDINO + SAM).""" + if "image" not in payload: + raise HTTPException(status_code=400, detail="Missing base64 'image'.") + image = _b64_to_image(payload["image"]) + try: + result = _get_segmenter().segment( + image, + prompt=payload.get("prompt", ""), + background=payload.get("background", "alpha"), + mask_blur=int(payload.get("mask_blur", 0)), + mask_offset=int(payload.get("mask_offset", 0)), + crop=bool(payload.get("crop", False)), + crop_margin=float(payload.get("crop_margin", 0.0)), + box_threshold=float(payload.get("box_threshold", 0.3)), + text_threshold=float(payload.get("text_threshold", 0.25)), + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + image_out: Image.Image = result["image"] + return { + "image": _image_to_b64(image_out), + "format": "png", + "width": image_out.width, + "height": image_out.height, + "detections": result["detections"], + "prompt": result["prompt"], + } + server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False) diff --git a/src/birefnet_service/static/index.html b/src/birefnet_service/static/index.html index 4fe2c51..eec01c9 100644 --- a/src/birefnet_service/static/index.html +++ b/src/birefnet_service/static/index.html @@ -3,7 +3,7 @@ -BiRefNet — Background Removal +Background Removal & Segmentation