93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Minimal stdlib-only client for the BiRefNet service.
|
|
|
|
Encodes an image, posts it to /predict, and saves the returned PNG.
|
|
No third-party dependencies so it can run with any system Python.
|
|
"""
|
|
|
|
from __future__ import annotations

import argparse
import base64
import http.client
import json
import sys
import time
import urllib.error
import urllib.request
|
|
|
|
|
|
def wait_for_health(base_url: str, timeout: float, *, interval: float = 2.0) -> None:
    """Block until ``{base_url}/health`` answers HTTP 200, or exit the process.

    The endpoint is probed at least once (so ``timeout=0`` means "check once,
    don't wait"), then re-probed every *interval* seconds until *timeout*
    seconds have elapsed on a monotonic clock.

    Args:
        base_url: Service root, without a trailing slash ("/health" is appended).
        timeout: Total seconds to keep retrying.
        interval: Seconds to sleep between failed probes (keyword-only).

    Raises:
        SystemExit: If no probe returned 200 before the deadline.
    """
    health = f"{base_url}/health"
    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout

    while True:
        try:
            # Each probe may itself take up to 5s, so the deadline can be
            # overshot by at most one probe's timeout.
            with urllib.request.urlopen(health, timeout=5) as resp:
                if resp.status == 200:
                    return
        except (OSError, http.client.HTTPException):
            # URLError, HTTPError and ConnectionError are all OSError
            # subclasses; HTTPException covers malformed responses from a
            # server that is still starting up. Either way: not ready yet.
            pass

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        time.sleep(min(interval, remaining))

    sys.exit(f"server at {health} not ready after {timeout:.0f}s")
|
|
|
|
|
|
def _build_payload(args: argparse.Namespace, image_bytes: bytes) -> dict:
    """Translate parsed CLI options into the JSON body posted to /predict.

    Optional fields (model, resolution, crop/crop_margin) are only included
    when they were set on the command line.
    """
    payload = {
        "image": base64.b64encode(image_bytes).decode("ascii"),
        "background": args.background,
        "mask_blur": args.mask_blur,
        "return_mask": args.mask_output is not None,
    }
    if args.model is not None:
        payload["model"] = args.model
    if args.resolution is not None:
        payload["resolution"] = args.resolution
    if args.crop:
        payload["crop"] = True
        # The margin only matters when cropping, so it is not sent otherwise.
        payload["crop_margin"] = args.crop_margin
    return payload


def _post_json(url: str, payload: dict, timeout: float) -> dict:
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    Raises:
        SystemExit: On an HTTP error status (the response body is included in
            the message so the server's reason is visible) or when the
            connection fails.
    """
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as err:
        # HTTPError is also a file-like response; its body usually explains the failure.
        detail = err.read().decode("utf-8", errors="replace").strip()
        sys.exit(f"POST {url} failed: HTTP {err.code} {err.reason}: {detail}")
    except urllib.error.URLError as err:
        sys.exit(f"POST {url} failed: {err.reason}")


def _write_b64(path: str, data: str) -> None:
    """Decode base64 *data* and write the raw bytes to *path*."""
    with open(path, "wb") as f:
        f.write(base64.b64decode(data))


def main() -> None:
    """Parse CLI options, call the service's /predict endpoint, and save the results.

    Writes the returned image to ``--output`` and, when ``--mask-output`` is
    given and the response contains a mask, the mask to that path.

    Raises:
        SystemExit: If the input image cannot be read, /health never reports
            ready within ``--wait`` seconds, or the /predict request fails.
    """
    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument("--url", default="http://localhost:8000", help="service base URL")
    ap.add_argument("--input", default="test.jpg", help="input image path")
    ap.add_argument("--output", default="output.png", help="output PNG path")
    ap.add_argument("--background", default="alpha", help="alpha|white|black|gray|green|blue|red")
    ap.add_argument("--model", default=None, help="variant: general|HR|portrait|matting|lite")
    ap.add_argument("--resolution", type=int, default=None, help="inference resolution, e.g. 2048")
    ap.add_argument("--crop", action="store_true", help="crop output to the subject bounding box")
    ap.add_argument("--crop-margin", type=float, default=0.0, help="crop margin in inches")
    ap.add_argument("--mask-blur", type=int, default=0, help="Gaussian blur radius for mask edges")
    ap.add_argument("--mask-output", default=None, help="also save the raw mask to this path")
    ap.add_argument("--wait", type=float, default=180, help="seconds to wait for /health")
    ap.add_argument("--timeout", type=float, default=300, help="seconds to wait for /predict")
    args = ap.parse_args()

    base_url = args.url.rstrip("/")

    # Read the input before polling the server so a bad path fails fast
    # rather than after up to --wait seconds.
    try:
        with open(args.input, "rb") as f:
            image_bytes = f.read()
    except OSError as err:
        sys.exit(f"cannot read input image {args.input}: {err}")

    wait_for_health(base_url, args.wait)
    payload = _build_payload(args, image_bytes)

    # perf_counter is the right clock for measuring an interval.
    started = time.perf_counter()
    result = _post_json(f"{base_url}/predict", payload, args.timeout)
    elapsed = time.perf_counter() - started

    _write_b64(args.output, result["image"])
    print(
        f"saved {args.output} {result.get('width')}x{result.get('height')} "
        f"{result.get('model')} @ {result.get('resolution')} ({elapsed:.1f}s)"
    )

    if args.mask_output is not None:
        if "mask" in result:
            _write_b64(args.mask_output, result["mask"])
            print(f"saved {args.mask_output}")
        else:
            # A mask was requested (return_mask=True) but not returned; say so
            # instead of silently skipping the file.
            print(
                f"warning: response contained no mask; {args.mask_output} not written",
                file=sys.stderr,
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|