rmbg/scripts/client.py
2026-05-16 17:04:56 -06:00

108 lines
4.0 KiB
Python

#!/usr/bin/env python3
"""Minimal stdlib-only client for the BiRefNet service.
Encodes an image, posts it to /predict (or /segment when --prompt is given), and saves the returned PNG.
No third-party dependencies so it can run with any system Python.
"""
from __future__ import annotations
import argparse
import base64
import json
import sys
import time
import urllib.error
import urllib.request
def wait_for_health(base_url: str, timeout: float, interval: float = 2.0) -> None:
    """Block until ``GET {base_url}/health`` answers with HTTP 200.

    Args:
        base_url: Service base URL without a trailing slash.
        timeout: Total seconds to keep polling. At least one probe is always
            made, so ``timeout=0`` means "check once, don't wait".
        interval: Seconds to sleep between failed probes (clamped so the
            sleep never runs past the deadline).

    Exits the process via ``sys.exit`` with an error message if the service
    does not report healthy before the deadline.
    """
    health = f"{base_url}/health"
    # monotonic() is not affected by wall-clock adjustments (NTP, manual changes).
    deadline = time.monotonic() + timeout
    while True:
        try:
            with urllib.request.urlopen(health, timeout=5) as resp:
                if resp.status == 200:
                    return
        except OSError:
            # URLError/HTTPError (e.g. connection refused, 5xx while the model
            # loads) and ConnectionError are all OSError subclasses.
            pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(interval, remaining))
    sys.exit(f"server at {health} not ready after {timeout:.0f}s")
def _parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--url", default="http://localhost:8000", help="service base URL")
    ap.add_argument("--input", default="test.jpg", help="input image path")
    ap.add_argument("--output", default="output.png", help="output PNG path")
    ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red")
    ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite")
    ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
    ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
    ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
    ap.add_argument("--prompt", default=None, help="if set, use prompt segmentation (/segment)")
    ap.add_argument("--box-threshold", type=float, default=0.3, help="GroundingDINO box threshold")
    ap.add_argument(
        "--text-threshold", type=float, default=0.25, help="GroundingDINO text threshold"
    )
    ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
    ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
    ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
    ap.add_argument(
        "--timeout",
        type=float,
        default=300,
        help="seconds to wait for the /predict or /segment response",
    )
    return ap.parse_args()


def _build_payload(args: argparse.Namespace, image: bytes) -> tuple[str, dict]:
    """Return ``(endpoint, json_payload)`` for the request described by *args*."""
    payload = {
        "image": base64.b64encode(image).decode("ascii"),
        "background": args.background,
        "mask_blur": args.mask_blur,
        "return_mask": args.mask_output is not None,
    }
    # The margin is only meaningful (and only sent) when cropping is requested.
    if args.crop:
        payload["crop"] = True
        payload["crop_margin"] = args.crop_margin
    if args.prompt is not None:
        payload["prompt"] = args.prompt
        payload["box_threshold"] = args.box_threshold
        payload["text_threshold"] = args.text_threshold
        return "/segment", payload
    # Model/resolution selection is only sent to /predict.
    if args.model is not None:
        payload["model"] = args.model
    if args.resolution is not None:
        payload["resolution"] = args.resolution
    return "/predict", payload


def _post_json(url: str, payload: dict, timeout: float) -> dict:
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    On an HTTP error status or a transport failure, exits via ``sys.exit``
    with a readable message (including the server's error body) instead of
    a traceback.
    """
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as exc:
        # HTTPError is also a file-like response; its body usually explains the failure.
        body = exc.read().decode("utf-8", "replace").strip()
        sys.exit(f"{url} returned HTTP {exc.code}: {body[:500] or exc.reason}")
    except OSError as exc:  # URLError, timeouts, connection resets
        sys.exit(f"request to {url} failed: {exc}")


def main() -> None:
    """Send one image to the service and save the returned image (and mask).

    Reads the input image, waits for ``/health``, posts to ``/predict`` (or
    ``/segment`` when ``--prompt`` is given), writes the decoded ``image``
    field to ``--output``, and optionally writes the ``mask`` field to
    ``--mask-output``. Exits with an error message on unreadable input or a
    failed request.
    """
    args = _parse_args()
    base_url = args.url.rstrip("/")

    if args.prompt is not None and (args.model is not None or args.resolution is not None):
        print("warning: --model/--resolution are ignored with --prompt", file=sys.stderr)

    # Read the input before waiting on /health so a bad path fails immediately
    # instead of after up to --wait seconds.
    try:
        with open(args.input, "rb") as f:
            image = f.read()
    except OSError as exc:
        sys.exit(f"cannot read input image: {exc}")
    endpoint, payload = _build_payload(args, image)

    wait_for_health(base_url, args.wait)

    started = time.monotonic()
    result = _post_json(f"{base_url}{endpoint}", payload, args.timeout)
    elapsed = time.monotonic() - started

    with open(args.output, "wb") as f:
        f.write(base64.b64decode(result["image"]))
    # A "detections" key marks a /segment response; otherwise report model info.
    if "detections" in result:
        detail = f"{result['detections']} object(s) matched '{result.get('prompt')}'"
    else:
        detail = f"{result.get('model')} @ {result.get('resolution')}"
    print(
        f"saved {args.output} {result['width']}x{result['height']} "
        f"{detail} ({elapsed:.1f}s)"
    )

    if args.mask_output:
        if "mask" in result:
            with open(args.mask_output, "wb") as f:
                f.write(base64.b64decode(result["mask"]))
            print(f"saved {args.mask_output}")
        else:
            print(
                f"warning: server returned no mask; {args.mask_output} not written",
                file=sys.stderr,
            )


if __name__ == "__main__":
    main()