mvp
This commit is contained in:
commit
96d16fc654
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
.venv
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
hf_cache
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.png
|
||||||
|
output*
|
||||||
|
.python-version
|
||||||
|
Makefile
|
||||||
|
docker-compose.yml
|
||||||
8
.gitignore
vendored
Normal file
8
.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
hf_cache/
|
||||||
|
output.png
|
||||||
|
output*.png
|
||||||
|
mask*.png
|
||||||
|
*.jpg
|
||||||
1
.python-version
Normal file
1
.python-version
Normal file
@ -0,0 +1 @@
|
|||||||
|
3.12
|
||||||
38
Dockerfile
Normal file
38
Dockerfile
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# BiRefNet background removal service — CUDA 12.4 runtime image.
|
||||||
|
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
|
PYTHONUNBUFFERED=1 \
|
||||||
|
UV_PYTHON_INSTALL_DIR=/opt/python \
|
||||||
|
UV_PROJECT_ENVIRONMENT=/app/.venv \
|
||||||
|
UV_COMPILE_BYTECODE=1 \
|
||||||
|
UV_LINK_MODE=copy \
|
||||||
|
HF_HOME=/app/hf_cache \
|
||||||
|
PORT=8000
|
||||||
|
|
||||||
|
# uv: fast, reproducible Python + dependency management.
|
||||||
|
COPY --from=ghcr.io/astral-sh/uv:0.9 /uv /uvx /bin/
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y --no-install-recommends ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install Python + dependencies first so this layer is cached across code changes.
|
||||||
|
# The BuildKit cache mount keeps the uv download cache warm across rebuilds.
|
||||||
|
COPY pyproject.toml ./
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv python install 3.12 \
|
||||||
|
&& uv sync --no-install-project --no-dev
|
||||||
|
|
||||||
|
# Application code.
|
||||||
|
COPY src ./src
|
||||||
|
COPY README.md ./
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv sync --no-dev
|
||||||
|
|
||||||
|
ENV PATH="/app/.venv/bin:${PATH}"
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
CMD ["birefnet-service"]
|
||||||
58
Makefile
Normal file
58
Makefile
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# BiRefNet background removal service — CLI shortcuts.
|
||||||
|
# Override defaults inline, e.g.: make test BG=white INPUT=photo.jpg
|
||||||
|
|
||||||
|
COMPOSE ?= docker compose
|
||||||
|
PYTHON ?= python3
|
||||||
|
PORT ?= 8000
|
||||||
|
INPUT ?= test.jpg
|
||||||
|
OUTPUT ?= output.png
|
||||||
|
BG ?= alpha
|
||||||
|
BLUR ?= 0
|
||||||
|
|
||||||
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
|
.PHONY: help build run up stop down logs log ps test test-mask dev sync shell clean fmt
|
||||||
|
|
||||||
|
help: ## Show this help
|
||||||
|
@grep -E '^[a-zA-Z_-]+:.*?## ' $(MAKEFILE_LIST) \
|
||||||
|
| awk 'BEGIN{FS=":.*?## "}{printf " \033[36m%-12s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
|
build: ## Build the Docker image
|
||||||
|
$(COMPOSE) build
|
||||||
|
|
||||||
|
run up: ## Start the service (GPU) in the background
|
||||||
|
$(COMPOSE) up -d
|
||||||
|
|
||||||
|
stop down: ## Stop and remove the service container
|
||||||
|
$(COMPOSE) down
|
||||||
|
|
||||||
|
logs log: ## Follow service logs
|
||||||
|
$(COMPOSE) logs -f
|
||||||
|
|
||||||
|
ps: ## Show service status
|
||||||
|
$(COMPOSE) ps
|
||||||
|
|
||||||
|
test: ## Send INPUT to the running service, save OUTPUT
|
||||||
|
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
|
||||||
|
--input $(INPUT) --output $(OUTPUT) --background $(BG) --mask-blur $(BLUR)
|
||||||
|
|
||||||
|
test-mask: ## Like 'test' but also save the raw mask (mask.png)
|
||||||
|
$(PYTHON) scripts/client.py --url http://localhost:$(PORT) \
|
||||||
|
--input $(INPUT) --output $(OUTPUT) --background $(BG) \
|
||||||
|
--mask-blur $(BLUR) --mask-output mask.png
|
||||||
|
|
||||||
|
sync: ## Install dependencies locally with uv
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
dev: sync ## Run the service locally (no Docker; needs local CUDA)
|
||||||
|
uv run birefnet-service
|
||||||
|
|
||||||
|
shell: ## Open a shell inside a fresh container
|
||||||
|
$(COMPOSE) run --rm --entrypoint bash birefnet
|
||||||
|
|
||||||
|
fmt: ## Format code with ruff
|
||||||
|
uv run ruff format src scripts
|
||||||
|
|
||||||
|
clean: ## Stop the service and remove build artifacts
|
||||||
|
-$(COMPOSE) down
|
||||||
|
rm -f $(OUTPUT) mask.png
|
||||||
94
README.md
Normal file
94
README.md
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
# BiRefNet Background Removal Service
|
||||||
|
|
||||||
|
GPU-accelerated background removal exposed as an HTTP API. Uses
|
||||||
|
[BiRefNet](https://huggingface.co/ZhengPeng7/BiRefNet) for matting, served with
|
||||||
|
[LitServe](https://github.com/Lightning-AI/LitServe), packaged for the
|
||||||
|
NVIDIA container runtime.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- NVIDIA GPU + driver, Docker, and the `nvidia` container runtime
|
||||||
|
- ~2 GB free disk for the model weights (downloaded on first run)
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make build # build the Docker image
|
||||||
|
make run # start the service on :8000 (GPU)
|
||||||
|
make logs # watch startup — first run downloads BiRefNet weights
|
||||||
|
make test # send test.jpg, save output.png
|
||||||
|
```
|
||||||
|
|
||||||
|
`make test` waits for the service `/health` endpoint before sending the
|
||||||
|
request, so the first call may block while the model downloads and loads.
|
||||||
|
|
||||||
|
### Web UI
|
||||||
|
|
||||||
|
A minimal test page is served at the service root — open
|
||||||
|
**http://localhost:8000/** in a browser, drop in an image, and preview the
|
||||||
|
transparent-background result (handy when working over SSH). It calls the
|
||||||
|
same `/predict` endpoint.
|
||||||
|
|
||||||
|
### Useful variations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make test BG=white # composite onto a white background
|
||||||
|
make test INPUT=photo.jpg OUTPUT=cut.png
|
||||||
|
make test-mask # also save the raw alpha mask (mask.png)
|
||||||
|
make help # list all targets
|
||||||
|
```
|
||||||
|
|
||||||
|
## API
|
||||||
|
|
||||||
|
`POST /predict`
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"image": "<base64 image bytes>", // required
|
||||||
|
"background": "alpha", // alpha|white|black|gray|green|blue|red
|
||||||
|
"mask_blur": 0, // Gaussian blur radius on mask edges
|
||||||
|
"return_mask": false // include the raw mask in the response
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
```jsonc
|
||||||
|
{
|
||||||
|
"image": "<base64 PNG>",
|
||||||
|
"format": "png",
|
||||||
|
"width": 3637,
|
||||||
|
"height": 3637,
|
||||||
|
"mask": "<base64 PNG>" // only when return_mask=true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`GET /health` returns 200 when the service is ready.
|
||||||
|
|
||||||
|
## Configuration (environment variables)
|
||||||
|
|
||||||
|
| Variable | Default | Purpose |
|
||||||
|
|----------------------|----------------------|----------------------------------|
|
||||||
|
| `PORT` | `8000` | HTTP port |
|
||||||
|
| `BIREFNET_MODEL` | `ZhengPeng7/BiRefNet`| HuggingFace repo for the weights |
|
||||||
|
| `BIREFNET_RESOLUTION`| `1024` | Inference resolution |
|
||||||
|
| `REQUEST_TIMEOUT` | `120` | Per-request timeout (seconds) |
|
||||||
|
|
||||||
|
## Local development (no Docker)
|
||||||
|
|
||||||
|
Requires a local CUDA-capable PyTorch environment.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make dev # uv sync + run the server locally
|
||||||
|
```
|
||||||
|
|
||||||
|
## Layout
|
||||||
|
|
||||||
|
```
|
||||||
|
src/birefnet_service/model.py BiRefNet wrapper (load + inference)
|
||||||
|
src/birefnet_service/server.py LitServe API + web UI route
|
||||||
|
src/birefnet_service/static/ web UI (index.html)
|
||||||
|
scripts/client.py stdlib-only test client
|
||||||
|
Dockerfile / docker-compose.yml CUDA image + nvidia runtime
|
||||||
|
Makefile build / run / test shortcuts
|
||||||
|
```
|
||||||
28
compose.yml
Normal file
28
compose.yml
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
services:
|
||||||
|
birefnet:
|
||||||
|
build: .
|
||||||
|
image: birefnet-service:latest
|
||||||
|
container_name: birefnet-service
|
||||||
|
ports:
|
||||||
|
- "${PORT:-8000}:8000"
|
||||||
|
environment:
|
||||||
|
- NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
- NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
||||||
|
# Default variant/resolution; both are also selectable per request.
|
||||||
|
- BIREFNET_MODEL=${BIREFNET_MODEL:-general}
|
||||||
|
- BIREFNET_RESOLUTION=${BIREFNET_RESOLUTION:-1024}
|
||||||
|
# Use the nvidia-container-runtime for GPU acceleration.
|
||||||
|
runtime: nvidia
|
||||||
|
volumes:
|
||||||
|
# Persist downloaded BiRefNet weights across container restarts.
|
||||||
|
- hf-cache:/app/hf_cache
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
|
||||||
|
interval: 15s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 30
|
||||||
|
start_period: 180s
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
hf-cache:
|
||||||
43
pyproject.toml
Normal file
43
pyproject.toml
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
[project]
|
||||||
|
name = "birefnet-service"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "BiRefNet background removal as a GPU-accelerated API"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12,<3.13"
|
||||||
|
dependencies = [
|
||||||
|
"torch==2.5.1",
|
||||||
|
"torchvision==0.20.1",
|
||||||
|
"transformers>=4.44,<5",
|
||||||
|
"timm>=1.0.0",
|
||||||
|
"einops>=0.8.0",
|
||||||
|
"kornia>=0.7.0",
|
||||||
|
"pillow>=10.0.0",
|
||||||
|
"numpy>=1.26",
|
||||||
|
"litserve>=0.2.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
birefnet-service = "birefnet_service.server:run"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = ["ruff>=0.6.0"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/birefnet_service"]
|
||||||
|
|
||||||
|
# BiRefNet (torch) needs CUDA wheels; pull torch/torchvision from the PyTorch index.
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-cu124"
|
||||||
|
url = "https://download.pytorch.org/whl/cu124"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
torch = { index = "pytorch-cu124" }
|
||||||
|
torchvision = { index = "pytorch-cu124" }
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
92
scripts/client.py
Normal file
92
scripts/client.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Minimal stdlib-only client for the BiRefNet service.
|
||||||
|
|
||||||
|
Encodes an image, posts it to /predict, and saves the returned PNG.
|
||||||
|
No third-party dependencies so it can run with any system Python.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_health(base_url: str, timeout: float) -> None:
|
||||||
|
deadline = time.time() + timeout
|
||||||
|
health = f"{base_url}/health"
|
||||||
|
while time.time() < deadline:
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(health, timeout=5) as resp:
|
||||||
|
if resp.status == 200:
|
||||||
|
return
|
||||||
|
except (urllib.error.URLError, ConnectionError, OSError):
|
||||||
|
pass
|
||||||
|
time.sleep(2)
|
||||||
|
sys.exit(f"server at {health} not ready after {timeout:.0f}s")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(description=__doc__)
|
||||||
|
ap.add_argument("--url", default="http://localhost:8000", help="service base URL")
|
||||||
|
ap.add_argument("--input", default="test.jpg", help="input image path")
|
||||||
|
ap.add_argument("--output", default="output.png", help="output PNG path")
|
||||||
|
ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red")
|
||||||
|
ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite")
|
||||||
|
ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
|
||||||
|
ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
|
||||||
|
ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
|
||||||
|
ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
|
||||||
|
ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
|
||||||
|
ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
base_url = args.url.rstrip("/")
|
||||||
|
wait_for_health(base_url, args.wait)
|
||||||
|
|
||||||
|
with open(args.input, "rb") as f:
|
||||||
|
payload = {
|
||||||
|
"image": base64.b64encode(f.read()).decode("ascii"),
|
||||||
|
"background": args.background,
|
||||||
|
"mask_blur": args.mask_blur,
|
||||||
|
"return_mask": args.mask_output is not None,
|
||||||
|
}
|
||||||
|
if args.model is not None:
|
||||||
|
payload["model"] = args.model
|
||||||
|
if args.resolution is not None:
|
||||||
|
payload["resolution"] = args.resolution
|
||||||
|
if args.crop:
|
||||||
|
payload["crop"] = True
|
||||||
|
payload["crop_margin"] = args.crop_margin
|
||||||
|
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"{base_url}/predict",
|
||||||
|
data=json.dumps(payload).encode(),
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
method="POST",
|
||||||
|
)
|
||||||
|
|
||||||
|
started = time.time()
|
||||||
|
with urllib.request.urlopen(req, timeout=300) as resp:
|
||||||
|
result = json.loads(resp.read())
|
||||||
|
elapsed = time.time() - started
|
||||||
|
|
||||||
|
with open(args.output, "wb") as f:
|
||||||
|
f.write(base64.b64decode(result["image"]))
|
||||||
|
print(
|
||||||
|
f"saved {args.output} {result['width']}x{result['height']} "
|
||||||
|
f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.mask_output and "mask" in result:
|
||||||
|
with open(args.mask_output, "wb") as f:
|
||||||
|
f.write(base64.b64decode(result["mask"]))
|
||||||
|
print(f"saved {args.mask_output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
src/birefnet_service/__init__.py
Normal file
3
src/birefnet_service/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""BiRefNet background removal service."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
185
src/birefnet_service/model.py
Normal file
185
src/birefnet_service/model.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
"""BiRefNet model wrapper for background removal.
|
||||||
|
|
||||||
|
Loads BiRefNet weights via ``transformers`` (trust_remote_code). Supports
|
||||||
|
multiple model variants (lazily loaded + cached) and a tunable inference
|
||||||
|
resolution, both selectable per request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageFilter
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
|
||||||
|
# Friendly variant names -> HuggingFace repo. A raw repo id may also be passed.
|
||||||
|
# BiRefNet variants: general is fast for clean single subjects; HR is for large
|
||||||
|
# / detailed scenes and needs a higher resolution (>=1536) to perform well.
|
||||||
|
MODEL_ALIASES = {
|
||||||
|
"general": "ZhengPeng7/BiRefNet",
|
||||||
|
"HR": "ZhengPeng7/BiRefNet_HR",
|
||||||
|
"portrait": "ZhengPeng7/BiRefNet-portrait",
|
||||||
|
"matting": "ZhengPeng7/BiRefNet-matting",
|
||||||
|
"lite": "ZhengPeng7/BiRefNet_lite",
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_MODEL = os.getenv("BIREFNET_MODEL", "general")
|
||||||
|
DEFAULT_RESOLUTION = int(os.getenv("BIREFNET_RESOLUTION", "1024"))
|
||||||
|
|
||||||
|
# ImageNet normalization, matching BiRefNet training.
|
||||||
|
_MEAN = [0.485, 0.456, 0.406]
|
||||||
|
_STD = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
BG_COLORS: dict[str, tuple[int, int, int]] = {
|
||||||
|
"black": (0, 0, 0),
|
||||||
|
"white": (255, 255, 255),
|
||||||
|
"gray": (128, 128, 128),
|
||||||
|
"green": (0, 255, 0),
|
||||||
|
"blue": (0, 0, 255),
|
||||||
|
"red": (255, 0, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_repo(model: str | None) -> str:
|
||||||
|
"""Map a friendly variant name to a repo id (pass-through for raw ids)."""
|
||||||
|
model = model or DEFAULT_MODEL
|
||||||
|
return MODEL_ALIASES.get(model, model)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_resolution(resolution: int | None) -> int:
|
||||||
|
"""BiRefNet's backbone needs the input side divisible by 32."""
|
||||||
|
res = int(resolution or DEFAULT_RESOLUTION)
|
||||||
|
return max(256, (res // 32) * 32)
|
||||||
|
|
||||||
|
|
||||||
|
class BiRefNetService:
|
||||||
|
"""Runs BiRefNet background removal; caches loaded model variants."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: str | None = None,
|
||||||
|
default_model: str = DEFAULT_MODEL,
|
||||||
|
default_resolution: int = DEFAULT_RESOLUTION,
|
||||||
|
):
|
||||||
|
want_cuda = device != "cpu" and torch.cuda.is_available()
|
||||||
|
if device and device not in ("auto", "cpu"):
|
||||||
|
self.device = device
|
||||||
|
else:
|
||||||
|
self.device = "cuda" if want_cuda else "cpu"
|
||||||
|
self.use_half = self.device.startswith("cuda")
|
||||||
|
self.default_model = default_model
|
||||||
|
self.default_resolution = default_resolution
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
self._models: dict[str, torch.nn.Module] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
# Preload the default variant so the worker is ready on first request.
|
||||||
|
self._get_model(resolve_repo(default_model))
|
||||||
|
|
||||||
|
def _get_model(self, repo: str) -> torch.nn.Module:
|
||||||
|
with self._lock:
|
||||||
|
model = self._models.get(repo)
|
||||||
|
if model is None:
|
||||||
|
model = AutoModelForImageSegmentation.from_pretrained(
|
||||||
|
repo, trust_remote_code=True
|
||||||
|
)
|
||||||
|
model.eval().to(self.device)
|
||||||
|
if self.use_half:
|
||||||
|
model.half()
|
||||||
|
self._models[repo] = model
|
||||||
|
return model
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def infer_mask(self, image: Image.Image, repo: str, resolution: int) -> Image.Image:
|
||||||
|
"""Return a single-channel ('L') alpha mask at the image's original size."""
|
||||||
|
model = self._get_model(repo)
|
||||||
|
w, h = image.size
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(
|
||||||
|
(resolution, resolution),
|
||||||
|
interpolation=transforms.InterpolationMode.BICUBIC,
|
||||||
|
),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(_MEAN, _STD),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
x = transform(image).unsqueeze(0).to(self.device)
|
||||||
|
if self.use_half:
|
||||||
|
x = x.half()
|
||||||
|
pred = model(x)[-1].sigmoid().float().cpu()[0] # [1, R, R]
|
||||||
|
mask = transforms.ToPILImage()(pred)
|
||||||
|
return mask.resize((w, h), Image.BICUBIC)
|
||||||
|
|
||||||
|
def remove_background(
|
||||||
|
self,
|
||||||
|
image: Image.Image,
|
||||||
|
model: str | None = None,
|
||||||
|
resolution: int | None = None,
|
||||||
|
background: str = "alpha",
|
||||||
|
mask_blur: int = 0,
|
||||||
|
crop: bool = False,
|
||||||
|
crop_margin: float = 0.0,
|
||||||
|
return_mask: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Run background removal.
|
||||||
|
|
||||||
|
model: variant alias ('general', 'HR', ...) or a raw HF repo id.
|
||||||
|
resolution: inference resolution (rounded down to a multiple of 32).
|
||||||
|
background: "alpha" for transparency, or a key from BG_COLORS.
|
||||||
|
mask_blur: Gaussian blur radius applied to the mask edges.
|
||||||
|
crop: crop the output to the foreground's bounding box.
|
||||||
|
crop_margin: extra margin around the crop, in inches (uses image DPI).
|
||||||
|
"""
|
||||||
|
# DPI for inch->pixel margin conversion; default 96 if not embedded.
|
||||||
|
dpi = image.info.get("dpi")
|
||||||
|
dpi_x = float(dpi[0]) if dpi and dpi[0] else 96.0
|
||||||
|
|
||||||
|
image = image.convert("RGB")
|
||||||
|
repo = resolve_repo(model)
|
||||||
|
resolution = _normalize_resolution(resolution)
|
||||||
|
|
||||||
|
mask = self.infer_mask(image, repo, resolution)
|
||||||
|
if mask_blur > 0:
|
||||||
|
mask = mask.filter(ImageFilter.GaussianBlur(radius=mask_blur))
|
||||||
|
|
||||||
|
cutout = image.convert("RGBA")
|
||||||
|
cutout.putalpha(mask)
|
||||||
|
|
||||||
|
background = (background or "alpha").lower()
|
||||||
|
if background == "alpha":
|
||||||
|
result = cutout
|
||||||
|
else:
|
||||||
|
if background not in BG_COLORS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown background '{background}'. "
|
||||||
|
f"Use 'alpha' or one of: {', '.join(sorted(BG_COLORS))}"
|
||||||
|
)
|
||||||
|
bg = Image.new("RGBA", image.size, (*BG_COLORS[background], 255))
|
||||||
|
result = Image.alpha_composite(bg, cutout).convert("RGB")
|
||||||
|
|
||||||
|
out: dict = {"image": result, "model": repo, "resolution": resolution}
|
||||||
|
|
||||||
|
if crop:
|
||||||
|
# Bounding box of the foreground (mask above a low alpha threshold).
|
||||||
|
bbox = mask.point(lambda p: 255 if p >= 16 else 0).getbbox()
|
||||||
|
if bbox is not None:
|
||||||
|
margin_px = round(max(0.0, crop_margin) * dpi_x)
|
||||||
|
left = max(0, bbox[0] - margin_px)
|
||||||
|
top = max(0, bbox[1] - margin_px)
|
||||||
|
right = min(image.width, bbox[2] + margin_px)
|
||||||
|
bottom = min(image.height, bbox[3] + margin_px)
|
||||||
|
box = (left, top, right, bottom)
|
||||||
|
result = result.crop(box)
|
||||||
|
mask = mask.crop(box)
|
||||||
|
out["image"] = result
|
||||||
|
out["crop_box"] = box
|
||||||
|
out["dpi"] = round(dpi_x, 1)
|
||||||
|
|
||||||
|
if return_mask:
|
||||||
|
out["mask"] = mask
|
||||||
|
return out
|
||||||
113
src/birefnet_service/server.py
Normal file
113
src/birefnet_service/server.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
"""LitServe API exposing BiRefNet background removal.
|
||||||
|
|
||||||
|
Endpoint: POST /predict
|
||||||
|
Request JSON:
|
||||||
|
{
|
||||||
|
"image": "<base64-encoded image bytes>", (required)
|
||||||
|
"model": "general" | "HR" | "portrait" | ..., (default "general")
|
||||||
|
"resolution": 1024, (default 1024)
|
||||||
|
"background": "alpha" | "white" | "black" | ..., (default "alpha")
|
||||||
|
"mask_blur": 0, (default 0)
|
||||||
|
"return_mask": false (default false)
|
||||||
|
}
|
||||||
|
Response JSON:
|
||||||
|
{
|
||||||
|
"image": "<base64 PNG>",
|
||||||
|
"format": "png",
|
||||||
|
"width": int,
|
||||||
|
"height": int,
|
||||||
|
"model": "<repo id used>",
|
||||||
|
"resolution": int,
|
||||||
|
"mask": "<base64 PNG>" (only when return_mask=true)
|
||||||
|
}
|
||||||
|
|
||||||
|
A minimal web UI is served at GET / (same origin as /predict).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import litserve as ls
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
|
from .model import BiRefNetService
|
||||||
|
|
||||||
|
_UI_HTML = (Path(__file__).parent / "static" / "index.html").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _b64_to_image(data: str) -> Image.Image:
|
||||||
|
image = Image.open(io.BytesIO(base64.b64decode(data)))
|
||||||
|
# Honor EXIF orientation so portrait photos aren't processed/returned sideways.
|
||||||
|
return ImageOps.exif_transpose(image)
|
||||||
|
|
||||||
|
|
||||||
|
def _image_to_b64(image: Image.Image) -> str:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
image.save(buf, format="PNG")
|
||||||
|
return base64.b64encode(buf.getvalue()).decode("ascii")
|
||||||
|
|
||||||
|
|
||||||
|
class BiRefNetAPI(ls.LitAPI):
|
||||||
|
def setup(self, device: str) -> None:
|
||||||
|
self.service = BiRefNetService(device=device)
|
||||||
|
|
||||||
|
def decode_request(self, request: dict) -> dict:
|
||||||
|
if "image" not in request:
|
||||||
|
raise ValueError("Request must include a base64 'image' field.")
|
||||||
|
return {
|
||||||
|
"image": _b64_to_image(request["image"]),
|
||||||
|
"model": request.get("model"),
|
||||||
|
"resolution": request.get("resolution"),
|
||||||
|
"background": request.get("background", "alpha"),
|
||||||
|
"mask_blur": int(request.get("mask_blur", 0)),
|
||||||
|
"crop": bool(request.get("crop", False)),
|
||||||
|
"crop_margin": float(request.get("crop_margin", 0.0)),
|
||||||
|
"return_mask": bool(request.get("return_mask", False)),
|
||||||
|
}
|
||||||
|
|
||||||
|
def predict(self, inputs: dict) -> dict:
|
||||||
|
return self.service.remove_background(**inputs)
|
||||||
|
|
||||||
|
def encode_response(self, output: dict) -> dict:
|
||||||
|
image: Image.Image = output["image"]
|
||||||
|
response = {
|
||||||
|
"image": _image_to_b64(image),
|
||||||
|
"format": "png",
|
||||||
|
"width": image.width,
|
||||||
|
"height": image.height,
|
||||||
|
"model": output["model"],
|
||||||
|
"resolution": output["resolution"],
|
||||||
|
}
|
||||||
|
if output.get("mask") is not None:
|
||||||
|
response["mask"] = _image_to_b64(output["mask"])
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def run() -> None:
|
||||||
|
server = ls.LitServer(
|
||||||
|
BiRefNetAPI(),
|
||||||
|
accelerator="auto",
|
||||||
|
devices=1,
|
||||||
|
timeout=int(os.getenv("REQUEST_TIMEOUT", "120")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# LitServe registers its own "/" route ("litserve running"); drop it so
|
||||||
|
# our UI can own the root path. Served same-origin as /predict (no CORS).
|
||||||
|
server.app.router.routes = [
|
||||||
|
r for r in server.app.router.routes if getattr(r, "path", None) != "/"
|
||||||
|
]
|
||||||
|
|
||||||
|
@server.app.get("/", response_class=HTMLResponse)
|
||||||
|
def index() -> str:
|
||||||
|
return _UI_HTML
|
||||||
|
|
||||||
|
server.run(port=int(os.getenv("PORT", "8000")), generate_client_file=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
303
src/birefnet_service/static/index.html
Normal file
303
src/birefnet_service/static/index.html
Normal file
@ -0,0 +1,303 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
|
<title>BiRefNet — Background Removal</title>
|
||||||
|
<style>
|
||||||
|
:root { color-scheme: dark; }
|
||||||
|
* { box-sizing: border-box; }
|
||||||
|
body {
|
||||||
|
margin: 0; font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif;
|
||||||
|
background: #15171c; color: #e8e8ea; padding: 24px;
|
||||||
|
}
|
||||||
|
h1 { font-size: 1.25rem; font-weight: 600; margin: 0 0 4px; }
|
||||||
|
.sub { color: #8a8f99; font-size: .85rem; margin-bottom: 20px; }
|
||||||
|
.wrap { max-width: 1100px; margin: 0 auto; }
|
||||||
|
#drop {
|
||||||
|
border: 2px dashed #3a3f4b; border-radius: 12px; padding: 36px;
|
||||||
|
text-align: center; cursor: pointer; transition: border-color .15s, background .15s;
|
||||||
|
}
|
||||||
|
#drop.over { border-color: #5b8cff; background: #1c2230; }
|
||||||
|
#drop p { margin: 6px 0; color: #8a8f99; }
|
||||||
|
.controls { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
||||||
|
label.field { display: flex; flex-direction: column; gap: 4px; font-size: .72rem;
|
||||||
|
color: #8a8f99; text-transform: uppercase; letter-spacing: .04em; }
|
||||||
|
select, input[type=number] {
|
||||||
|
background: #2a2f3a; color: #e8e8ea; border: 1px solid #3a3f4b;
|
||||||
|
border-radius: 8px; padding: 8px 10px; font-size: .9rem;
|
||||||
|
}
|
||||||
|
input[type=number] { width: 78px; }
|
||||||
|
input[type=number]:disabled { opacity: .45; }
|
||||||
|
.check { display: flex; align-items: center; gap: 6px; font-size: .85rem;
|
||||||
|
color: #e8e8ea; cursor: pointer; align-self: end; padding-bottom: 8px; }
|
||||||
|
.check input { width: 15px; height: 15px; accent-color: #5b8cff; cursor: pointer; }
|
||||||
|
button {
|
||||||
|
background: #5b8cff; color: #fff; border: 0; border-radius: 8px;
|
||||||
|
padding: 10px 18px; font-size: .9rem; cursor: pointer; font-weight: 600;
|
||||||
|
}
|
||||||
|
button:disabled { background: #3a3f4b; cursor: not-allowed; }
|
||||||
|
button.ghost { background: #2a2f3a; }
|
||||||
|
.go-row { display: flex; gap: 12px; align-items: center; margin: 16px 0; flex-wrap: wrap; }
|
||||||
|
.status { color: #8a8f99; font-size: .85rem; }
|
||||||
|
.status.err { color: #ff6b6b; }
|
||||||
|
.hint { color: #6b7280; font-size: .78rem; margin-top: -8px; }
|
||||||
|
.panels { display: grid; grid-template-columns: 1fr 1fr; gap: 16px; margin-top: 16px; }
|
||||||
|
.panel { background: #1c1f27; border-radius: 12px; padding: 12px; }
|
||||||
|
.panel h2 { font-size: .8rem; font-weight: 600; color: #8a8f99; margin: 0 0 8px;
|
||||||
|
text-transform: uppercase; letter-spacing: .05em; }
|
||||||
|
.imgbox {
|
||||||
|
min-height: 260px; display: flex; align-items: center; justify-content: center;
|
||||||
|
border-radius: 8px; overflow: hidden;
|
||||||
|
}
|
||||||
|
.checker {
|
||||||
|
background-image:
|
||||||
|
linear-gradient(45deg, #2a2f3a 25%, transparent 25%),
|
||||||
|
linear-gradient(-45deg, #2a2f3a 25%, transparent 25%),
|
||||||
|
linear-gradient(45deg, transparent 75%, #2a2f3a 75%),
|
||||||
|
linear-gradient(-45deg, transparent 75%, #2a2f3a 75%);
|
||||||
|
background-size: 22px 22px;
|
||||||
|
background-position: 0 0, 0 11px, 11px -11px, -11px 0;
|
||||||
|
background-color: #20242d;
|
||||||
|
}
|
||||||
|
.imgbox img { max-width: 100%; max-height: 70vh; display: block; }
|
||||||
|
.imgbox img[src] { cursor: zoom-in; }
|
||||||
|
@media (max-width: 720px) { .panels { grid-template-columns: 1fr; } }
|
||||||
|
|
||||||
|
/* lightbox */
|
||||||
|
.lightbox { position: fixed; inset: 0; z-index: 100; background: rgba(12,13,17,.97);
|
||||||
|
display: flex; align-items: center; justify-content: center; }
|
||||||
|
.lightbox[hidden] { display: none; }
|
||||||
|
.lb-stage { width: 100vw; height: 100vh; overflow: hidden;
|
||||||
|
display: flex; align-items: center; justify-content: center; }
|
||||||
|
.lb-stage img { max-width: 100vw; max-height: 100vh; transform-origin: 0 0;
|
||||||
|
cursor: grab; user-select: none; -webkit-user-drag: none; will-change: transform; }
|
||||||
|
.lb-stage.grabbing img { cursor: grabbing; }
|
||||||
|
.lb-bar { position: fixed; top: 0; left: 0; right: 0; padding: 14px 20px;
|
||||||
|
z-index: 2; display: flex; justify-content: space-between; align-items: center;
|
||||||
|
color: #8a8f99; font-size: .8rem; pointer-events: none; }
|
||||||
|
.lb-close { pointer-events: auto; background: #2a2f3a; color: #e8e8ea;
|
||||||
|
border: 1px solid #3a3f4b; border-radius: 8px; width: 34px; height: 34px;
|
||||||
|
font-size: 1rem; line-height: 1; padding: 0; cursor: pointer;
|
||||||
|
display: flex; align-items: center; justify-content: center; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="wrap">
|
||||||
|
<h1>BiRefNet — Background Removal</h1>
|
||||||
|
<div class="sub">Drop an image to get a transparent-background PNG.</div>
|
||||||
|
|
||||||
|
<div id="drop">
|
||||||
|
<p><strong>Drop an image here</strong> or click to choose</p>
|
||||||
|
<p id="fname">No file selected</p>
|
||||||
|
<input id="file" type="file" accept="image/*" hidden />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="controls">
|
||||||
|
<label class="field">Model
|
||||||
|
<select id="model">
|
||||||
|
<option value="general">general — clean single subjects (fast)</option>
|
||||||
|
<option value="HR" selected>HR — large / detailed scenes</option>
|
||||||
|
<option value="portrait">portrait — people</option>
|
||||||
|
<option value="matting">matting — soft edges / hair</option>
|
||||||
|
<option value="lite">lite — fastest</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<label class="field">Resolution
|
||||||
|
<select id="resolution">
|
||||||
|
<option value="1024">1024</option>
|
||||||
|
<option value="1536">1536</option>
|
||||||
|
<option value="2048">2048</option>
|
||||||
|
<option value="2560" selected>2560</option>
|
||||||
|
</select>
|
||||||
|
</label>
|
||||||
|
<label class="check"><input type="checkbox" id="crop" checked /> Crop to subject</label>
|
||||||
|
<label class="field">Margin (in)
|
||||||
|
<input type="number" id="cropMargin" value="0" min="0" step="0.1" />
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
<div class="hint">Tip: large or busy scenes segment best with <strong>HR</strong> at <strong>2048</strong>.
|
||||||
|
The <em>general</em> model expects a clear single subject at 1024.</div>
|
||||||
|
|
||||||
|
<div class="go-row">
|
||||||
|
<button id="go" disabled>Remove background</button>
|
||||||
|
<a id="dl" download="cutout.png"><button id="dlbtn" class="ghost" disabled>Download PNG</button></a>
|
||||||
|
<span id="status" class="status"></span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="panels">
|
||||||
|
<div class="panel">
|
||||||
|
<h2>Original</h2>
|
||||||
|
<div class="imgbox"><img id="src" alt="" /></div>
|
||||||
|
</div>
|
||||||
|
<div class="panel">
|
||||||
|
<h2>Result</h2>
|
||||||
|
<div class="imgbox checker"><img id="out" alt="" /></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="lightbox" class="lightbox" hidden>
|
||||||
|
<div class="lb-bar">
|
||||||
|
<span>scroll to zoom · drag to pan · double-click resets · Esc closes</span>
|
||||||
|
<button class="lb-close" id="lbClose" title="Close">✕</button>
|
||||||
|
</div>
|
||||||
|
<div class="lb-stage" id="lbStage"><img id="lbImg" alt="" /></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
const drop = document.getElementById('drop');
|
||||||
|
const fileInput = document.getElementById('file');
|
||||||
|
const fname = document.getElementById('fname');
|
||||||
|
const go = document.getElementById('go');
|
||||||
|
const dl = document.getElementById('dl');
|
||||||
|
const dlbtn = document.getElementById('dlbtn');
|
||||||
|
const statusEl = document.getElementById('status');
|
||||||
|
const srcImg = document.getElementById('src');
|
||||||
|
const outImg = document.getElementById('out');
|
||||||
|
const modelSel = document.getElementById('model');
|
||||||
|
const resSel = document.getElementById('resolution');
|
||||||
|
const cropChk = document.getElementById('crop');
|
||||||
|
const cropMargin = document.getElementById('cropMargin');
|
||||||
|
|
||||||
|
cropChk.addEventListener('change', () => { cropMargin.disabled = !cropChk.checked; });
|
||||||
|
|
||||||
|
let selectedFile = null;
|
||||||
|
|
||||||
|
function setStatus(msg, isErr) {
|
||||||
|
statusEl.textContent = msg;
|
||||||
|
statusEl.className = 'status' + (isErr ? ' err' : '');
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickFile(file) {
|
||||||
|
if (!file || !file.type.startsWith('image/')) {
|
||||||
|
setStatus('Please choose an image file.', true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
selectedFile = file;
|
||||||
|
fname.textContent = file.name + ' (' + Math.round(file.size / 1024) + ' KB)';
|
||||||
|
srcImg.src = URL.createObjectURL(file);
|
||||||
|
outImg.removeAttribute('src');
|
||||||
|
dlbtn.disabled = true;
|
||||||
|
go.disabled = false;
|
||||||
|
setStatus('');
|
||||||
|
}
|
||||||
|
|
||||||
|
drop.addEventListener('click', () => fileInput.click());
|
||||||
|
fileInput.addEventListener('change', e => pickFile(e.target.files[0]));
|
||||||
|
['dragenter', 'dragover'].forEach(ev =>
|
||||||
|
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.add('over'); }));
|
||||||
|
['dragleave', 'drop'].forEach(ev =>
|
||||||
|
drop.addEventListener(ev, e => { e.preventDefault(); drop.classList.remove('over'); }));
|
||||||
|
drop.addEventListener('drop', e => pickFile(e.dataTransfer.files[0]));
|
||||||
|
|
||||||
|
function fileToBase64(file) {
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const r = new FileReader();
|
||||||
|
r.onload = () => resolve(r.result.split(',')[1]); // strip data URL prefix
|
||||||
|
r.onerror = reject;
|
||||||
|
r.readAsDataURL(file);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- lightbox: click to inspect, scroll to zoom, drag to pan ---
|
||||||
|
const lightbox = document.getElementById('lightbox');
|
||||||
|
const lbStage = document.getElementById('lbStage');
|
||||||
|
const lbImg = document.getElementById('lbImg');
|
||||||
|
const lbClose = document.getElementById('lbClose');
|
||||||
|
let lbScale = 1, lbTx = 0, lbTy = 0, lbDrag = null;
|
||||||
|
|
||||||
|
function lbApply() {
|
||||||
|
lbImg.style.transform = `translate(${lbTx}px, ${lbTy}px) scale(${lbScale})`;
|
||||||
|
}
|
||||||
|
function lbReset() { lbScale = 1; lbTx = 0; lbTy = 0; lbApply(); }
|
||||||
|
|
||||||
|
function openLightbox(src, isResult) {
|
||||||
|
if (!src) return;
|
||||||
|
lbImg.src = src;
|
||||||
|
lbImg.classList.toggle('checker', !!isResult);
|
||||||
|
lbReset();
|
||||||
|
lightbox.hidden = false;
|
||||||
|
}
|
||||||
|
function closeLightbox() { lightbox.hidden = true; lbImg.removeAttribute('src'); }
|
||||||
|
|
||||||
|
srcImg.addEventListener('click', () => openLightbox(srcImg.getAttribute('src'), false));
|
||||||
|
outImg.addEventListener('click', () => openLightbox(outImg.getAttribute('src'), true));
|
||||||
|
lbClose.addEventListener('click', closeLightbox);
|
||||||
|
lightbox.addEventListener('mousedown', e => {
|
||||||
|
if (e.target === lightbox || e.target === lbStage) closeLightbox();
|
||||||
|
});
|
||||||
|
document.addEventListener('keydown', e => {
|
||||||
|
if (e.key === 'Escape' && !lightbox.hidden) closeLightbox();
|
||||||
|
});
|
||||||
|
|
||||||
|
lbStage.addEventListener('wheel', e => {
|
||||||
|
e.preventDefault();
|
||||||
|
const rect = lbImg.getBoundingClientRect();
|
||||||
|
const cx = e.clientX - rect.left, cy = e.clientY - rect.top;
|
||||||
|
const factor = e.deltaY < 0 ? 1.2 : 1 / 1.2;
|
||||||
|
const newScale = Math.min(8, Math.max(1, lbScale * factor));
|
||||||
|
const ratio = newScale / lbScale;
|
||||||
|
lbTx -= cx * (ratio - 1);
|
||||||
|
lbTy -= cy * (ratio - 1);
|
||||||
|
lbScale = newScale;
|
||||||
|
if (lbScale === 1) { lbTx = 0; lbTy = 0; }
|
||||||
|
lbApply();
|
||||||
|
}, { passive: false });
|
||||||
|
|
||||||
|
lbImg.addEventListener('mousedown', e => {
|
||||||
|
e.preventDefault();
|
||||||
|
lbDrag = { x: e.clientX, y: e.clientY, tx: lbTx, ty: lbTy };
|
||||||
|
lbStage.classList.add('grabbing');
|
||||||
|
});
|
||||||
|
window.addEventListener('mousemove', e => {
|
||||||
|
if (!lbDrag) return;
|
||||||
|
lbTx = lbDrag.tx + (e.clientX - lbDrag.x);
|
||||||
|
lbTy = lbDrag.ty + (e.clientY - lbDrag.y);
|
||||||
|
lbApply();
|
||||||
|
});
|
||||||
|
window.addEventListener('mouseup', () => {
|
||||||
|
lbDrag = null;
|
||||||
|
lbStage.classList.remove('grabbing');
|
||||||
|
});
|
||||||
|
lbImg.addEventListener('dblclick', e => { e.preventDefault(); lbReset(); });
|
||||||
|
|
||||||
|
go.addEventListener('click', async () => {
|
||||||
|
if (!selectedFile) return;
|
||||||
|
go.disabled = true;
|
||||||
|
dlbtn.disabled = true;
|
||||||
|
setStatus('Processing… (first use of a model downloads its weights)');
|
||||||
|
const t0 = performance.now();
|
||||||
|
try {
|
||||||
|
const b64 = await fileToBase64(selectedFile);
|
||||||
|
const resp = await fetch('/predict', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
image: b64,
|
||||||
|
background: 'alpha',
|
||||||
|
model: modelSel.value,
|
||||||
|
resolution: parseInt(resSel.value, 10),
|
||||||
|
crop: cropChk.checked,
|
||||||
|
crop_margin: parseFloat(cropMargin.value) || 0,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
if (!resp.ok) throw new Error('HTTP ' + resp.status + ': ' + (await resp.text()));
|
||||||
|
const data = await resp.json();
|
||||||
|
const dataUrl = 'data:image/png;base64,' + data.image;
|
||||||
|
outImg.src = dataUrl;
|
||||||
|
dl.href = dataUrl;
|
||||||
|
dl.download = selectedFile.name.replace(/\.[^.]+$/, '') + '.png';
|
||||||
|
dlbtn.disabled = false;
|
||||||
|
const secs = ((performance.now() - t0) / 1000).toFixed(1);
|
||||||
|
setStatus('Done — ' + data.width + '×' + data.height + ' · ' + data.model +
|
||||||
|
' @ ' + data.resolution + ' · ' + secs + 's');
|
||||||
|
} catch (err) {
|
||||||
|
setStatus(err.message || String(err), true);
|
||||||
|
} finally {
|
||||||
|
go.disabled = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
Loading…
Reference in New Issue
Block a user