Spaces:
Running
Running
- app.py +11 -11
- rembg/_version.py +2 -2
- rembg/bg.py +18 -6
- rembg/cli.py +3 -429
- rembg/commands/__init__.py +13 -0
- rembg/commands/i_command.py +93 -0
- rembg/commands/p_command.py +181 -0
- rembg/commands/s_command.py +238 -0
- rembg/session_factory.py +10 -57
- rembg/sessions/__init__.py +22 -0
- rembg/sessions/base.py +63 -0
- rembg/sessions/dis.py +47 -0
- rembg/sessions/sam.py +165 -0
- rembg/sessions/silueta.py +49 -0
- rembg/sessions/u2net.py +49 -0
- rembg/sessions/u2net_cloth_seg.py +110 -0
- rembg/sessions/u2net_human_seg.py +49 -0
- rembg/sessions/u2netp.py +49 -0
app.py
CHANGED
|
@@ -9,9 +9,7 @@ def inference(file, af, mask, model):
|
|
| 9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
| 10 |
cv2.imwrite(os.path.join("input.png"), im)
|
| 11 |
|
| 12 |
-
from rembg import remove
|
| 13 |
-
from rembg.session_base import BaseSession
|
| 14 |
-
from rembg.session_factory import new_session
|
| 15 |
|
| 16 |
input_path = 'input.png'
|
| 17 |
output_path = 'output.png'
|
|
@@ -19,15 +17,15 @@ def inference(file, af, mask, model):
|
|
| 19 |
with open(input_path, 'rb') as i:
|
| 20 |
with open(output_path, 'wb') as o:
|
| 21 |
input = i.read()
|
| 22 |
-
sessions: dict[str, BaseSession] = {}
|
| 23 |
output = remove(
|
| 24 |
input,
|
| 25 |
-
session=
|
| 26 |
-
model, new_session(model)
|
| 27 |
-
),
|
| 28 |
alpha_matting_erode_size = af,
|
| 29 |
only_mask = (True if mask == "Mask only" else False)
|
| 30 |
-
)
|
|
|
|
|
|
|
|
|
|
| 31 |
o.write(output)
|
| 32 |
return os.path.join("output.png")
|
| 33 |
|
|
@@ -40,7 +38,7 @@ gr.Interface(
|
|
| 40 |
inference,
|
| 41 |
[
|
| 42 |
gr.inputs.Image(type="filepath", label="Input"),
|
| 43 |
-
gr.inputs.Slider(10, 25, default=10, label="Alpha matting"),
|
| 44 |
gr.inputs.Radio(
|
| 45 |
[
|
| 46 |
"Default",
|
|
@@ -55,14 +53,16 @@ gr.Interface(
|
|
| 55 |
"u2netp",
|
| 56 |
"u2net_human_seg",
|
| 57 |
"u2net_cloth_seg",
|
| 58 |
-
"silueta"
|
|
|
|
|
|
|
| 59 |
],
|
| 60 |
type="value",
|
| 61 |
default="u2net",
|
| 62 |
label="Models"
|
| 63 |
),
|
| 64 |
],
|
| 65 |
-
gr.outputs.Image(type="
|
| 66 |
title=title,
|
| 67 |
description=description,
|
| 68 |
article=article,
|
|
|
|
| 9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
| 10 |
cv2.imwrite(os.path.join("input.png"), im)
|
| 11 |
|
| 12 |
+
from rembg import new_session, remove
|
|
|
|
|
|
|
| 13 |
|
| 14 |
input_path = 'input.png'
|
| 15 |
output_path = 'output.png'
|
|
|
|
| 17 |
with open(input_path, 'rb') as i:
|
| 18 |
with open(output_path, 'wb') as o:
|
| 19 |
input = i.read()
|
|
|
|
| 20 |
output = remove(
|
| 21 |
input,
|
| 22 |
+
session = new_session(model),
|
|
|
|
|
|
|
| 23 |
alpha_matting_erode_size = af,
|
| 24 |
only_mask = (True if mask == "Mask only" else False)
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
o.write(output)
|
| 30 |
return os.path.join("output.png")
|
| 31 |
|
|
|
|
| 38 |
inference,
|
| 39 |
[
|
| 40 |
gr.inputs.Image(type="filepath", label="Input"),
|
| 41 |
+
gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
|
| 42 |
gr.inputs.Radio(
|
| 43 |
[
|
| 44 |
"Default",
|
|
|
|
| 53 |
"u2netp",
|
| 54 |
"u2net_human_seg",
|
| 55 |
"u2net_cloth_seg",
|
| 56 |
+
"silueta",
|
| 57 |
+
"isnet-general-use",
|
| 58 |
+
"sam",
|
| 59 |
],
|
| 60 |
type="value",
|
| 61 |
default="u2net",
|
| 62 |
label="Models"
|
| 63 |
),
|
| 64 |
],
|
| 65 |
+
gr.outputs.Image(type="filepath", label="Output"),
|
| 66 |
title=title,
|
| 67 |
description=description,
|
| 68 |
article=article,
|
rembg/_version.py
CHANGED
|
@@ -24,8 +24,8 @@ def get_keywords():
|
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
git_refnames = " (HEAD -> main)"
|
| 27 |
-
git_full = "
|
| 28 |
-
git_date = "
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
|
|
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
git_refnames = " (HEAD -> main)"
|
| 27 |
+
git_full = "e47b2a0ed405a5a30f42bacb142b107f7a4b6536"
|
| 28 |
+
git_date = "2023-04-26 20:40:21 -0300"
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
rembg/bg.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import io
|
| 2 |
from enum import Enum
|
| 3 |
-
from typing import List, Optional, Union
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
from cv2 import (
|
|
@@ -18,8 +18,8 @@ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
| 18 |
from pymatting.util.util import stack_images
|
| 19 |
from scipy.ndimage import binary_erosion
|
| 20 |
|
| 21 |
-
from .session_base import BaseSession
|
| 22 |
from .session_factory import new_session
|
|
|
|
| 23 |
|
| 24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
| 25 |
|
|
@@ -37,7 +37,6 @@ def alpha_matting_cutout(
|
|
| 37 |
background_threshold: int,
|
| 38 |
erode_structure_size: int,
|
| 39 |
) -> PILImage:
|
| 40 |
-
|
| 41 |
if img.mode == "RGBA" or img.mode == "CMYK":
|
| 42 |
img = img.convert("RGB")
|
| 43 |
|
|
@@ -106,6 +105,14 @@ def post_process(mask: np.ndarray) -> np.ndarray:
|
|
| 106 |
return mask
|
| 107 |
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def remove(
|
| 110 |
data: Union[bytes, PILImage, np.ndarray],
|
| 111 |
alpha_matting: bool = False,
|
|
@@ -115,8 +122,10 @@ def remove(
|
|
| 115 |
session: Optional[BaseSession] = None,
|
| 116 |
only_mask: bool = False,
|
| 117 |
post_process_mask: bool = False,
|
|
|
|
|
|
|
|
|
|
| 118 |
) -> Union[bytes, PILImage, np.ndarray]:
|
| 119 |
-
|
| 120 |
if isinstance(data, PILImage):
|
| 121 |
return_type = ReturnType.PILLOW
|
| 122 |
img = data
|
|
@@ -130,9 +139,9 @@ def remove(
|
|
| 130 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
| 131 |
|
| 132 |
if session is None:
|
| 133 |
-
session = new_session("u2net")
|
| 134 |
|
| 135 |
-
masks = session.predict(img)
|
| 136 |
cutouts = []
|
| 137 |
|
| 138 |
for mask in masks:
|
|
@@ -163,6 +172,9 @@ def remove(
|
|
| 163 |
if len(cutouts) > 0:
|
| 164 |
cutout = get_concat_v_multi(cutouts)
|
| 165 |
|
|
|
|
|
|
|
|
|
|
| 166 |
if ReturnType.PILLOW == return_type:
|
| 167 |
return cutout
|
| 168 |
|
|
|
|
| 1 |
import io
|
| 2 |
from enum import Enum
|
| 3 |
+
from typing import Any, List, Optional, Tuple, Union
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
from cv2 import (
|
|
|
|
| 18 |
from pymatting.util.util import stack_images
|
| 19 |
from scipy.ndimage import binary_erosion
|
| 20 |
|
|
|
|
| 21 |
from .session_factory import new_session
|
| 22 |
+
from .sessions.base import BaseSession
|
| 23 |
|
| 24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
| 25 |
|
|
|
|
| 37 |
background_threshold: int,
|
| 38 |
erode_structure_size: int,
|
| 39 |
) -> PILImage:
|
|
|
|
| 40 |
if img.mode == "RGBA" or img.mode == "CMYK":
|
| 41 |
img = img.convert("RGB")
|
| 42 |
|
|
|
|
| 105 |
return mask
|
| 106 |
|
| 107 |
|
| 108 |
+
def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> PILImage:
|
| 109 |
+
r, g, b, a = color
|
| 110 |
+
colored_image = Image.new("RGBA", img.size, (r, g, b, a))
|
| 111 |
+
colored_image.paste(img, mask=img)
|
| 112 |
+
|
| 113 |
+
return colored_image
|
| 114 |
+
|
| 115 |
+
|
| 116 |
def remove(
|
| 117 |
data: Union[bytes, PILImage, np.ndarray],
|
| 118 |
alpha_matting: bool = False,
|
|
|
|
| 122 |
session: Optional[BaseSession] = None,
|
| 123 |
only_mask: bool = False,
|
| 124 |
post_process_mask: bool = False,
|
| 125 |
+
bgcolor: Optional[Tuple[int, int, int, int]] = None,
|
| 126 |
+
*args: Optional[Any],
|
| 127 |
+
**kwargs: Optional[Any]
|
| 128 |
) -> Union[bytes, PILImage, np.ndarray]:
|
|
|
|
| 129 |
if isinstance(data, PILImage):
|
| 130 |
return_type = ReturnType.PILLOW
|
| 131 |
img = data
|
|
|
|
| 139 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
| 140 |
|
| 141 |
if session is None:
|
| 142 |
+
session = new_session("u2net", *args, **kwargs)
|
| 143 |
|
| 144 |
+
masks = session.predict(img, *args, **kwargs)
|
| 145 |
cutouts = []
|
| 146 |
|
| 147 |
for mask in masks:
|
|
|
|
| 172 |
if len(cutouts) > 0:
|
| 173 |
cutout = get_concat_v_multi(cutouts)
|
| 174 |
|
| 175 |
+
if bgcolor is not None and not only_mask:
|
| 176 |
+
cutout = apply_background_color(cutout, bgcolor)
|
| 177 |
+
|
| 178 |
if ReturnType.PILLOW == return_type:
|
| 179 |
return cutout
|
| 180 |
|
rembg/cli.py
CHANGED
|
@@ -1,25 +1,7 @@
|
|
| 1 |
-
import pathlib
|
| 2 |
-
import sys
|
| 3 |
-
import time
|
| 4 |
-
from enum import Enum
|
| 5 |
-
from typing import IO, cast
|
| 6 |
-
|
| 7 |
-
import aiohttp
|
| 8 |
import click
|
| 9 |
-
import filetype
|
| 10 |
-
import uvicorn
|
| 11 |
-
from asyncer import asyncify
|
| 12 |
-
from fastapi import Depends, FastAPI, File, Form, Query
|
| 13 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
-
from starlette.responses import Response
|
| 15 |
-
from tqdm import tqdm
|
| 16 |
-
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
| 17 |
-
from watchdog.observers import Observer
|
| 18 |
|
| 19 |
from . import _version
|
| 20 |
-
from .
|
| 21 |
-
from .session_base import BaseSession
|
| 22 |
-
from .session_factory import new_session
|
| 23 |
|
| 24 |
|
| 25 |
@click.group()
|
|
@@ -28,413 +10,5 @@ def main() -> None:
|
|
| 28 |
pass
|
| 29 |
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"-m",
|
| 34 |
-
"--model",
|
| 35 |
-
default="u2net",
|
| 36 |
-
type=click.Choice(
|
| 37 |
-
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
| 38 |
-
),
|
| 39 |
-
show_default=True,
|
| 40 |
-
show_choices=True,
|
| 41 |
-
help="model name",
|
| 42 |
-
)
|
| 43 |
-
@click.option(
|
| 44 |
-
"-a",
|
| 45 |
-
"--alpha-matting",
|
| 46 |
-
is_flag=True,
|
| 47 |
-
show_default=True,
|
| 48 |
-
help="use alpha matting",
|
| 49 |
-
)
|
| 50 |
-
@click.option(
|
| 51 |
-
"-af",
|
| 52 |
-
"--alpha-matting-foreground-threshold",
|
| 53 |
-
default=240,
|
| 54 |
-
type=int,
|
| 55 |
-
show_default=True,
|
| 56 |
-
help="trimap fg threshold",
|
| 57 |
-
)
|
| 58 |
-
@click.option(
|
| 59 |
-
"-ab",
|
| 60 |
-
"--alpha-matting-background-threshold",
|
| 61 |
-
default=10,
|
| 62 |
-
type=int,
|
| 63 |
-
show_default=True,
|
| 64 |
-
help="trimap bg threshold",
|
| 65 |
-
)
|
| 66 |
-
@click.option(
|
| 67 |
-
"-ae",
|
| 68 |
-
"--alpha-matting-erode-size",
|
| 69 |
-
default=10,
|
| 70 |
-
type=int,
|
| 71 |
-
show_default=True,
|
| 72 |
-
help="erode size",
|
| 73 |
-
)
|
| 74 |
-
@click.option(
|
| 75 |
-
"-om",
|
| 76 |
-
"--only-mask",
|
| 77 |
-
is_flag=True,
|
| 78 |
-
show_default=True,
|
| 79 |
-
help="output only the mask",
|
| 80 |
-
)
|
| 81 |
-
@click.option(
|
| 82 |
-
"-ppm",
|
| 83 |
-
"--post-process-mask",
|
| 84 |
-
is_flag=True,
|
| 85 |
-
show_default=True,
|
| 86 |
-
help="post process the mask",
|
| 87 |
-
)
|
| 88 |
-
@click.argument(
|
| 89 |
-
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
| 90 |
-
)
|
| 91 |
-
@click.argument(
|
| 92 |
-
"output",
|
| 93 |
-
default=(None if sys.stdin.isatty() else "-"),
|
| 94 |
-
type=click.File("wb", lazy=True),
|
| 95 |
-
)
|
| 96 |
-
def i(model: str, input: IO, output: IO, **kwargs) -> None:
|
| 97 |
-
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
@main.command(help="for a folder as input")
|
| 101 |
-
@click.option(
|
| 102 |
-
"-m",
|
| 103 |
-
"--model",
|
| 104 |
-
default="u2net",
|
| 105 |
-
type=click.Choice(
|
| 106 |
-
["u2net", "u2netp", "u2net_human_seg", "u2net_cloth_seg", "silueta"]
|
| 107 |
-
),
|
| 108 |
-
show_default=True,
|
| 109 |
-
show_choices=True,
|
| 110 |
-
help="model name",
|
| 111 |
-
)
|
| 112 |
-
@click.option(
|
| 113 |
-
"-a",
|
| 114 |
-
"--alpha-matting",
|
| 115 |
-
is_flag=True,
|
| 116 |
-
show_default=True,
|
| 117 |
-
help="use alpha matting",
|
| 118 |
-
)
|
| 119 |
-
@click.option(
|
| 120 |
-
"-af",
|
| 121 |
-
"--alpha-matting-foreground-threshold",
|
| 122 |
-
default=240,
|
| 123 |
-
type=int,
|
| 124 |
-
show_default=True,
|
| 125 |
-
help="trimap fg threshold",
|
| 126 |
-
)
|
| 127 |
-
@click.option(
|
| 128 |
-
"-ab",
|
| 129 |
-
"--alpha-matting-background-threshold",
|
| 130 |
-
default=10,
|
| 131 |
-
type=int,
|
| 132 |
-
show_default=True,
|
| 133 |
-
help="trimap bg threshold",
|
| 134 |
-
)
|
| 135 |
-
@click.option(
|
| 136 |
-
"-ae",
|
| 137 |
-
"--alpha-matting-erode-size",
|
| 138 |
-
default=10,
|
| 139 |
-
type=int,
|
| 140 |
-
show_default=True,
|
| 141 |
-
help="erode size",
|
| 142 |
-
)
|
| 143 |
-
@click.option(
|
| 144 |
-
"-om",
|
| 145 |
-
"--only-mask",
|
| 146 |
-
is_flag=True,
|
| 147 |
-
show_default=True,
|
| 148 |
-
help="output only the mask",
|
| 149 |
-
)
|
| 150 |
-
@click.option(
|
| 151 |
-
"-ppm",
|
| 152 |
-
"--post-process-mask",
|
| 153 |
-
is_flag=True,
|
| 154 |
-
show_default=True,
|
| 155 |
-
help="post process the mask",
|
| 156 |
-
)
|
| 157 |
-
@click.option(
|
| 158 |
-
"-w",
|
| 159 |
-
"--watch",
|
| 160 |
-
default=False,
|
| 161 |
-
is_flag=True,
|
| 162 |
-
show_default=True,
|
| 163 |
-
help="watches a folder for changes",
|
| 164 |
-
)
|
| 165 |
-
@click.argument(
|
| 166 |
-
"input",
|
| 167 |
-
type=click.Path(
|
| 168 |
-
exists=True,
|
| 169 |
-
path_type=pathlib.Path,
|
| 170 |
-
file_okay=False,
|
| 171 |
-
dir_okay=True,
|
| 172 |
-
readable=True,
|
| 173 |
-
),
|
| 174 |
-
)
|
| 175 |
-
@click.argument(
|
| 176 |
-
"output",
|
| 177 |
-
type=click.Path(
|
| 178 |
-
exists=False,
|
| 179 |
-
path_type=pathlib.Path,
|
| 180 |
-
file_okay=False,
|
| 181 |
-
dir_okay=True,
|
| 182 |
-
writable=True,
|
| 183 |
-
),
|
| 184 |
-
)
|
| 185 |
-
def p(
|
| 186 |
-
model: str, input: pathlib.Path, output: pathlib.Path, watch: bool, **kwargs
|
| 187 |
-
) -> None:
|
| 188 |
-
session = new_session(model)
|
| 189 |
-
|
| 190 |
-
def process(each_input: pathlib.Path) -> None:
|
| 191 |
-
try:
|
| 192 |
-
mimetype = filetype.guess(each_input)
|
| 193 |
-
if mimetype is None:
|
| 194 |
-
return
|
| 195 |
-
if mimetype.mime.find("image") < 0:
|
| 196 |
-
return
|
| 197 |
-
|
| 198 |
-
each_output = (output / each_input.name).with_suffix(".png")
|
| 199 |
-
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
| 200 |
-
|
| 201 |
-
if not each_output.exists():
|
| 202 |
-
each_output.write_bytes(
|
| 203 |
-
cast(
|
| 204 |
-
bytes,
|
| 205 |
-
remove(each_input.read_bytes(), session=session, **kwargs),
|
| 206 |
-
)
|
| 207 |
-
)
|
| 208 |
-
|
| 209 |
-
if watch:
|
| 210 |
-
print(
|
| 211 |
-
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
| 212 |
-
)
|
| 213 |
-
except Exception as e:
|
| 214 |
-
print(e)
|
| 215 |
-
|
| 216 |
-
inputs = list(input.glob("**/*"))
|
| 217 |
-
if not watch:
|
| 218 |
-
inputs = tqdm(inputs)
|
| 219 |
-
|
| 220 |
-
for each_input in inputs:
|
| 221 |
-
if not each_input.is_dir():
|
| 222 |
-
process(each_input)
|
| 223 |
-
|
| 224 |
-
if watch:
|
| 225 |
-
observer = Observer()
|
| 226 |
-
|
| 227 |
-
class EventHandler(FileSystemEventHandler):
|
| 228 |
-
def on_any_event(self, event: FileSystemEvent) -> None:
|
| 229 |
-
if not (
|
| 230 |
-
event.is_directory or event.event_type in ["deleted", "closed"]
|
| 231 |
-
):
|
| 232 |
-
process(pathlib.Path(event.src_path))
|
| 233 |
-
|
| 234 |
-
event_handler = EventHandler()
|
| 235 |
-
observer.schedule(event_handler, input, recursive=False)
|
| 236 |
-
observer.start()
|
| 237 |
-
|
| 238 |
-
try:
|
| 239 |
-
while True:
|
| 240 |
-
time.sleep(1)
|
| 241 |
-
|
| 242 |
-
finally:
|
| 243 |
-
observer.stop()
|
| 244 |
-
observer.join()
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
@main.command(help="for a http server")
|
| 248 |
-
@click.option(
|
| 249 |
-
"-p",
|
| 250 |
-
"--port",
|
| 251 |
-
default=5000,
|
| 252 |
-
type=int,
|
| 253 |
-
show_default=True,
|
| 254 |
-
help="port",
|
| 255 |
-
)
|
| 256 |
-
@click.option(
|
| 257 |
-
"-l",
|
| 258 |
-
"--log_level",
|
| 259 |
-
default="info",
|
| 260 |
-
type=str,
|
| 261 |
-
show_default=True,
|
| 262 |
-
help="log level",
|
| 263 |
-
)
|
| 264 |
-
@click.option(
|
| 265 |
-
"-t",
|
| 266 |
-
"--threads",
|
| 267 |
-
default=None,
|
| 268 |
-
type=int,
|
| 269 |
-
show_default=True,
|
| 270 |
-
help="number of worker threads",
|
| 271 |
-
)
|
| 272 |
-
def s(port: int, log_level: str, threads: int) -> None:
|
| 273 |
-
sessions: dict[str, BaseSession] = {}
|
| 274 |
-
tags_metadata = [
|
| 275 |
-
{
|
| 276 |
-
"name": "Background Removal",
|
| 277 |
-
"description": "Endpoints that perform background removal with different image sources.",
|
| 278 |
-
"externalDocs": {
|
| 279 |
-
"description": "GitHub Source",
|
| 280 |
-
"url": "https://github.com/danielgatis/rembg",
|
| 281 |
-
},
|
| 282 |
-
},
|
| 283 |
-
]
|
| 284 |
-
app = FastAPI(
|
| 285 |
-
title="Rembg",
|
| 286 |
-
description="Rembg is a tool to remove images background. That is it.",
|
| 287 |
-
version=_version.get_versions()["version"],
|
| 288 |
-
contact={
|
| 289 |
-
"name": "Daniel Gatis",
|
| 290 |
-
"url": "https://github.com/danielgatis",
|
| 291 |
-
"email": "[email protected]",
|
| 292 |
-
},
|
| 293 |
-
license_info={
|
| 294 |
-
"name": "MIT License",
|
| 295 |
-
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
| 296 |
-
},
|
| 297 |
-
openapi_tags=tags_metadata,
|
| 298 |
-
)
|
| 299 |
-
|
| 300 |
-
app.add_middleware(
|
| 301 |
-
CORSMiddleware,
|
| 302 |
-
allow_credentials=True,
|
| 303 |
-
allow_origins=["*"],
|
| 304 |
-
allow_methods=["*"],
|
| 305 |
-
allow_headers=["*"],
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
class ModelType(str, Enum):
|
| 309 |
-
u2net = "u2net"
|
| 310 |
-
u2netp = "u2netp"
|
| 311 |
-
u2net_human_seg = "u2net_human_seg"
|
| 312 |
-
u2net_cloth_seg = "u2net_cloth_seg"
|
| 313 |
-
silueta = "silueta"
|
| 314 |
-
|
| 315 |
-
class CommonQueryParams:
|
| 316 |
-
def __init__(
|
| 317 |
-
self,
|
| 318 |
-
model: ModelType = Query(
|
| 319 |
-
default=ModelType.u2net,
|
| 320 |
-
description="Model to use when processing image",
|
| 321 |
-
),
|
| 322 |
-
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
| 323 |
-
af: int = Query(
|
| 324 |
-
default=240,
|
| 325 |
-
ge=0,
|
| 326 |
-
le=255,
|
| 327 |
-
description="Alpha Matting (Foreground Threshold)",
|
| 328 |
-
),
|
| 329 |
-
ab: int = Query(
|
| 330 |
-
default=10,
|
| 331 |
-
ge=0,
|
| 332 |
-
le=255,
|
| 333 |
-
description="Alpha Matting (Background Threshold)",
|
| 334 |
-
),
|
| 335 |
-
ae: int = Query(
|
| 336 |
-
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 337 |
-
),
|
| 338 |
-
om: bool = Query(default=False, description="Only Mask"),
|
| 339 |
-
ppm: bool = Query(default=False, description="Post Process Mask"),
|
| 340 |
-
):
|
| 341 |
-
self.model = model
|
| 342 |
-
self.a = a
|
| 343 |
-
self.af = af
|
| 344 |
-
self.ab = ab
|
| 345 |
-
self.ae = ae
|
| 346 |
-
self.om = om
|
| 347 |
-
self.ppm = ppm
|
| 348 |
-
|
| 349 |
-
class CommonQueryPostParams:
|
| 350 |
-
def __init__(
|
| 351 |
-
self,
|
| 352 |
-
model: ModelType = Form(
|
| 353 |
-
default=ModelType.u2net,
|
| 354 |
-
description="Model to use when processing image",
|
| 355 |
-
),
|
| 356 |
-
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
| 357 |
-
af: int = Form(
|
| 358 |
-
default=240,
|
| 359 |
-
ge=0,
|
| 360 |
-
le=255,
|
| 361 |
-
description="Alpha Matting (Foreground Threshold)",
|
| 362 |
-
),
|
| 363 |
-
ab: int = Form(
|
| 364 |
-
default=10,
|
| 365 |
-
ge=0,
|
| 366 |
-
le=255,
|
| 367 |
-
description="Alpha Matting (Background Threshold)",
|
| 368 |
-
),
|
| 369 |
-
ae: int = Form(
|
| 370 |
-
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 371 |
-
),
|
| 372 |
-
om: bool = Form(default=False, description="Only Mask"),
|
| 373 |
-
ppm: bool = Form(default=False, description="Post Process Mask"),
|
| 374 |
-
):
|
| 375 |
-
self.model = model
|
| 376 |
-
self.a = a
|
| 377 |
-
self.af = af
|
| 378 |
-
self.ab = ab
|
| 379 |
-
self.ae = ae
|
| 380 |
-
self.om = om
|
| 381 |
-
self.ppm = ppm
|
| 382 |
-
|
| 383 |
-
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
| 384 |
-
return Response(
|
| 385 |
-
remove(
|
| 386 |
-
content,
|
| 387 |
-
session=sessions.setdefault(
|
| 388 |
-
commons.model.value, new_session(commons.model.value)
|
| 389 |
-
),
|
| 390 |
-
alpha_matting=commons.a,
|
| 391 |
-
alpha_matting_foreground_threshold=commons.af,
|
| 392 |
-
alpha_matting_background_threshold=commons.ab,
|
| 393 |
-
alpha_matting_erode_size=commons.ae,
|
| 394 |
-
only_mask=commons.om,
|
| 395 |
-
post_process_mask=commons.ppm,
|
| 396 |
-
),
|
| 397 |
-
media_type="image/png",
|
| 398 |
-
)
|
| 399 |
-
|
| 400 |
-
@app.on_event("startup")
|
| 401 |
-
def startup():
|
| 402 |
-
if threads is not None:
|
| 403 |
-
from anyio import CapacityLimiter
|
| 404 |
-
from anyio.lowlevel import RunVar
|
| 405 |
-
|
| 406 |
-
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
| 407 |
-
|
| 408 |
-
@app.get(
|
| 409 |
-
path="/",
|
| 410 |
-
tags=["Background Removal"],
|
| 411 |
-
summary="Remove from URL",
|
| 412 |
-
description="Removes the background from an image obtained by retrieving an URL.",
|
| 413 |
-
)
|
| 414 |
-
async def get_index(
|
| 415 |
-
url: str = Query(
|
| 416 |
-
default=..., description="URL of the image that has to be processed."
|
| 417 |
-
),
|
| 418 |
-
commons: CommonQueryParams = Depends(),
|
| 419 |
-
):
|
| 420 |
-
async with aiohttp.ClientSession() as session:
|
| 421 |
-
async with session.get(url) as response:
|
| 422 |
-
file = await response.read()
|
| 423 |
-
return await asyncify(im_without_bg)(file, commons)
|
| 424 |
-
|
| 425 |
-
@app.post(
|
| 426 |
-
path="/",
|
| 427 |
-
tags=["Background Removal"],
|
| 428 |
-
summary="Remove from Stream",
|
| 429 |
-
description="Removes the background from an image sent within the request itself.",
|
| 430 |
-
)
|
| 431 |
-
async def post_index(
|
| 432 |
-
file: bytes = File(
|
| 433 |
-
default=...,
|
| 434 |
-
description="Image file (byte stream) that has to be processed.",
|
| 435 |
-
),
|
| 436 |
-
commons: CommonQueryPostParams = Depends(),
|
| 437 |
-
):
|
| 438 |
-
return await asyncify(im_without_bg)(file, commons)
|
| 439 |
-
|
| 440 |
-
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import click
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from . import _version
|
| 4 |
+
from .commands import command_functions
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
@click.group()
|
|
|
|
| 10 |
pass
|
| 11 |
|
| 12 |
|
| 13 |
+
for command in command_functions:
|
| 14 |
+
main.add_command(command)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rembg/commands/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from pkgutil import iter_modules
|
| 4 |
+
|
| 5 |
+
command_functions = []
|
| 6 |
+
|
| 7 |
+
package_dir = Path(__file__).resolve().parent
|
| 8 |
+
for _b, module_name, _p in iter_modules([str(package_dir)]):
|
| 9 |
+
module = import_module(f"{__name__}.{module_name}")
|
| 10 |
+
for attribute_name in dir(module):
|
| 11 |
+
attribute = getattr(module, attribute_name)
|
| 12 |
+
if attribute_name.endswith("_command"):
|
| 13 |
+
command_functions.append(attribute)
|
rembg/commands/i_command.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import sys
|
| 3 |
+
from typing import IO
|
| 4 |
+
|
| 5 |
+
import click
|
| 6 |
+
|
| 7 |
+
from ..bg import remove
|
| 8 |
+
from ..session_factory import new_session
|
| 9 |
+
from ..sessions import sessions_names
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@click.command(
|
| 13 |
+
name="i",
|
| 14 |
+
help="for a file as input",
|
| 15 |
+
)
|
| 16 |
+
@click.option(
|
| 17 |
+
"-m",
|
| 18 |
+
"--model",
|
| 19 |
+
default="u2net",
|
| 20 |
+
type=click.Choice(sessions_names),
|
| 21 |
+
show_default=True,
|
| 22 |
+
show_choices=True,
|
| 23 |
+
help="model name",
|
| 24 |
+
)
|
| 25 |
+
@click.option(
|
| 26 |
+
"-a",
|
| 27 |
+
"--alpha-matting",
|
| 28 |
+
is_flag=True,
|
| 29 |
+
show_default=True,
|
| 30 |
+
help="use alpha matting",
|
| 31 |
+
)
|
| 32 |
+
@click.option(
|
| 33 |
+
"-af",
|
| 34 |
+
"--alpha-matting-foreground-threshold",
|
| 35 |
+
default=240,
|
| 36 |
+
type=int,
|
| 37 |
+
show_default=True,
|
| 38 |
+
help="trimap fg threshold",
|
| 39 |
+
)
|
| 40 |
+
@click.option(
|
| 41 |
+
"-ab",
|
| 42 |
+
"--alpha-matting-background-threshold",
|
| 43 |
+
default=10,
|
| 44 |
+
type=int,
|
| 45 |
+
show_default=True,
|
| 46 |
+
help="trimap bg threshold",
|
| 47 |
+
)
|
| 48 |
+
@click.option(
|
| 49 |
+
"-ae",
|
| 50 |
+
"--alpha-matting-erode-size",
|
| 51 |
+
default=10,
|
| 52 |
+
type=int,
|
| 53 |
+
show_default=True,
|
| 54 |
+
help="erode size",
|
| 55 |
+
)
|
| 56 |
+
@click.option(
|
| 57 |
+
"-om",
|
| 58 |
+
"--only-mask",
|
| 59 |
+
is_flag=True,
|
| 60 |
+
show_default=True,
|
| 61 |
+
help="output only the mask",
|
| 62 |
+
)
|
| 63 |
+
@click.option(
|
| 64 |
+
"-ppm",
|
| 65 |
+
"--post-process-mask",
|
| 66 |
+
is_flag=True,
|
| 67 |
+
show_default=True,
|
| 68 |
+
help="post process the mask",
|
| 69 |
+
)
|
| 70 |
+
@click.option(
|
| 71 |
+
"-bgc",
|
| 72 |
+
"--bgcolor",
|
| 73 |
+
default=None,
|
| 74 |
+
type=(int, int, int, int),
|
| 75 |
+
nargs=4,
|
| 76 |
+
help="Background color (R G B A) to replace the removed background with",
|
| 77 |
+
)
|
| 78 |
+
@click.option("-x", "--extras", type=str)
|
| 79 |
+
@click.argument(
|
| 80 |
+
"input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
|
| 81 |
+
)
|
| 82 |
+
@click.argument(
|
| 83 |
+
"output",
|
| 84 |
+
default=(None if sys.stdin.isatty() else "-"),
|
| 85 |
+
type=click.File("wb", lazy=True),
|
| 86 |
+
)
|
| 87 |
+
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
|
| 88 |
+
try:
|
| 89 |
+
kwargs.update(json.loads(extras))
|
| 90 |
+
except Exception:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
output.write(remove(input.read(), session=new_session(model), **kwargs))
|
rembg/commands/p_command.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pathlib
|
| 3 |
+
import time
|
| 4 |
+
from typing import cast
|
| 5 |
+
|
| 6 |
+
import click
|
| 7 |
+
import filetype
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from watchdog.events import FileSystemEvent, FileSystemEventHandler
|
| 10 |
+
from watchdog.observers import Observer
|
| 11 |
+
|
| 12 |
+
from ..bg import remove
|
| 13 |
+
from ..session_factory import new_session
|
| 14 |
+
from ..sessions import sessions_names
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@click.command(
|
| 18 |
+
name="p",
|
| 19 |
+
help="for a folder as input",
|
| 20 |
+
)
|
| 21 |
+
@click.option(
|
| 22 |
+
"-m",
|
| 23 |
+
"--model",
|
| 24 |
+
default="u2net",
|
| 25 |
+
type=click.Choice(sessions_names),
|
| 26 |
+
show_default=True,
|
| 27 |
+
show_choices=True,
|
| 28 |
+
help="model name",
|
| 29 |
+
)
|
| 30 |
+
@click.option(
|
| 31 |
+
"-a",
|
| 32 |
+
"--alpha-matting",
|
| 33 |
+
is_flag=True,
|
| 34 |
+
show_default=True,
|
| 35 |
+
help="use alpha matting",
|
| 36 |
+
)
|
| 37 |
+
@click.option(
|
| 38 |
+
"-af",
|
| 39 |
+
"--alpha-matting-foreground-threshold",
|
| 40 |
+
default=240,
|
| 41 |
+
type=int,
|
| 42 |
+
show_default=True,
|
| 43 |
+
help="trimap fg threshold",
|
| 44 |
+
)
|
| 45 |
+
@click.option(
|
| 46 |
+
"-ab",
|
| 47 |
+
"--alpha-matting-background-threshold",
|
| 48 |
+
default=10,
|
| 49 |
+
type=int,
|
| 50 |
+
show_default=True,
|
| 51 |
+
help="trimap bg threshold",
|
| 52 |
+
)
|
| 53 |
+
@click.option(
|
| 54 |
+
"-ae",
|
| 55 |
+
"--alpha-matting-erode-size",
|
| 56 |
+
default=10,
|
| 57 |
+
type=int,
|
| 58 |
+
show_default=True,
|
| 59 |
+
help="erode size",
|
| 60 |
+
)
|
| 61 |
+
@click.option(
|
| 62 |
+
"-om",
|
| 63 |
+
"--only-mask",
|
| 64 |
+
is_flag=True,
|
| 65 |
+
show_default=True,
|
| 66 |
+
help="output only the mask",
|
| 67 |
+
)
|
| 68 |
+
@click.option(
|
| 69 |
+
"-ppm",
|
| 70 |
+
"--post-process-mask",
|
| 71 |
+
is_flag=True,
|
| 72 |
+
show_default=True,
|
| 73 |
+
help="post process the mask",
|
| 74 |
+
)
|
| 75 |
+
@click.option(
|
| 76 |
+
"-w",
|
| 77 |
+
"--watch",
|
| 78 |
+
default=False,
|
| 79 |
+
is_flag=True,
|
| 80 |
+
show_default=True,
|
| 81 |
+
help="watches a folder for changes",
|
| 82 |
+
)
|
| 83 |
+
@click.option(
|
| 84 |
+
"-bgc",
|
| 85 |
+
"--bgcolor",
|
| 86 |
+
default=None,
|
| 87 |
+
type=(int, int, int, int),
|
| 88 |
+
nargs=4,
|
| 89 |
+
help="Background color (R G B A) to replace the removed background with",
|
| 90 |
+
)
|
| 91 |
+
@click.option("-x", "--extras", type=str)
|
| 92 |
+
@click.argument(
|
| 93 |
+
"input",
|
| 94 |
+
type=click.Path(
|
| 95 |
+
exists=True,
|
| 96 |
+
path_type=pathlib.Path,
|
| 97 |
+
file_okay=False,
|
| 98 |
+
dir_okay=True,
|
| 99 |
+
readable=True,
|
| 100 |
+
),
|
| 101 |
+
)
|
| 102 |
+
@click.argument(
|
| 103 |
+
"output",
|
| 104 |
+
type=click.Path(
|
| 105 |
+
exists=False,
|
| 106 |
+
path_type=pathlib.Path,
|
| 107 |
+
file_okay=False,
|
| 108 |
+
dir_okay=True,
|
| 109 |
+
writable=True,
|
| 110 |
+
),
|
| 111 |
+
)
|
| 112 |
+
def p_command(
|
| 113 |
+
model: str,
|
| 114 |
+
extras: str,
|
| 115 |
+
input: pathlib.Path,
|
| 116 |
+
output: pathlib.Path,
|
| 117 |
+
watch: bool,
|
| 118 |
+
**kwargs,
|
| 119 |
+
) -> None:
|
| 120 |
+
try:
|
| 121 |
+
kwargs.update(json.loads(extras))
|
| 122 |
+
except Exception:
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
session = new_session(model)
|
| 126 |
+
|
| 127 |
+
def process(each_input: pathlib.Path) -> None:
|
| 128 |
+
try:
|
| 129 |
+
mimetype = filetype.guess(each_input)
|
| 130 |
+
if mimetype is None:
|
| 131 |
+
return
|
| 132 |
+
if mimetype.mime.find("image") < 0:
|
| 133 |
+
return
|
| 134 |
+
|
| 135 |
+
each_output = (output / each_input.name).with_suffix(".png")
|
| 136 |
+
each_output.parents[0].mkdir(parents=True, exist_ok=True)
|
| 137 |
+
|
| 138 |
+
if not each_output.exists():
|
| 139 |
+
each_output.write_bytes(
|
| 140 |
+
cast(
|
| 141 |
+
bytes,
|
| 142 |
+
remove(each_input.read_bytes(), session=session, **kwargs),
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
if watch:
|
| 147 |
+
print(
|
| 148 |
+
f"processed: {each_input.absolute()} -> {each_output.absolute()}"
|
| 149 |
+
)
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(e)
|
| 152 |
+
|
| 153 |
+
inputs = list(input.glob("**/*"))
|
| 154 |
+
if not watch:
|
| 155 |
+
inputs = tqdm(inputs)
|
| 156 |
+
|
| 157 |
+
for each_input in inputs:
|
| 158 |
+
if not each_input.is_dir():
|
| 159 |
+
process(each_input)
|
| 160 |
+
|
| 161 |
+
if watch:
|
| 162 |
+
observer = Observer()
|
| 163 |
+
|
| 164 |
+
class EventHandler(FileSystemEventHandler):
|
| 165 |
+
def on_any_event(self, event: FileSystemEvent) -> None:
|
| 166 |
+
if not (
|
| 167 |
+
event.is_directory or event.event_type in ["deleted", "closed"]
|
| 168 |
+
):
|
| 169 |
+
process(pathlib.Path(event.src_path))
|
| 170 |
+
|
| 171 |
+
event_handler = EventHandler()
|
| 172 |
+
observer.schedule(event_handler, input, recursive=False)
|
| 173 |
+
observer.start()
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
while True:
|
| 177 |
+
time.sleep(1)
|
| 178 |
+
|
| 179 |
+
finally:
|
| 180 |
+
observer.stop()
|
| 181 |
+
observer.join()
|
rembg/commands/s_command.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import Annotated, Optional, Tuple, cast
|
| 3 |
+
|
| 4 |
+
import aiohttp
|
| 5 |
+
import click
|
| 6 |
+
import uvicorn
|
| 7 |
+
from asyncer import asyncify
|
| 8 |
+
from fastapi import Depends, FastAPI, File, Form, Query
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from starlette.responses import Response
|
| 11 |
+
|
| 12 |
+
from .._version import get_versions
|
| 13 |
+
from ..bg import remove
|
| 14 |
+
from ..session_factory import new_session
|
| 15 |
+
from ..sessions import sessions_names
|
| 16 |
+
from ..sessions.base import BaseSession
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@click.command(
|
| 20 |
+
name="s",
|
| 21 |
+
help="for a http server",
|
| 22 |
+
)
|
| 23 |
+
@click.option(
|
| 24 |
+
"-p",
|
| 25 |
+
"--port",
|
| 26 |
+
default=5000,
|
| 27 |
+
type=int,
|
| 28 |
+
show_default=True,
|
| 29 |
+
help="port",
|
| 30 |
+
)
|
| 31 |
+
@click.option(
|
| 32 |
+
"-l",
|
| 33 |
+
"--log_level",
|
| 34 |
+
default="info",
|
| 35 |
+
type=str,
|
| 36 |
+
show_default=True,
|
| 37 |
+
help="log level",
|
| 38 |
+
)
|
| 39 |
+
@click.option(
|
| 40 |
+
"-t",
|
| 41 |
+
"--threads",
|
| 42 |
+
default=None,
|
| 43 |
+
type=int,
|
| 44 |
+
show_default=True,
|
| 45 |
+
help="number of worker threads",
|
| 46 |
+
)
|
| 47 |
+
def s_command(port: int, log_level: str, threads: int) -> None:
|
| 48 |
+
sessions: dict[str, BaseSession] = {}
|
| 49 |
+
tags_metadata = [
|
| 50 |
+
{
|
| 51 |
+
"name": "Background Removal",
|
| 52 |
+
"description": "Endpoints that perform background removal with different image sources.",
|
| 53 |
+
"externalDocs": {
|
| 54 |
+
"description": "GitHub Source",
|
| 55 |
+
"url": "https://github.com/danielgatis/rembg",
|
| 56 |
+
},
|
| 57 |
+
},
|
| 58 |
+
]
|
| 59 |
+
app = FastAPI(
|
| 60 |
+
title="Rembg",
|
| 61 |
+
description="Rembg is a tool to remove images background. That is it.",
|
| 62 |
+
version=get_versions()["version"],
|
| 63 |
+
contact={
|
| 64 |
+
"name": "Daniel Gatis",
|
| 65 |
+
"url": "https://github.com/danielgatis",
|
| 66 |
+
"email": "[email protected]",
|
| 67 |
+
},
|
| 68 |
+
license_info={
|
| 69 |
+
"name": "MIT License",
|
| 70 |
+
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
| 71 |
+
},
|
| 72 |
+
openapi_tags=tags_metadata,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
app.add_middleware(
|
| 76 |
+
CORSMiddleware,
|
| 77 |
+
allow_credentials=True,
|
| 78 |
+
allow_origins=["*"],
|
| 79 |
+
allow_methods=["*"],
|
| 80 |
+
allow_headers=["*"],
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
class CommonQueryParams:
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
model: Annotated[
|
| 87 |
+
str, Query(regex=r"(" + "|".join(sessions_names) + ")")
|
| 88 |
+
] = Query(
|
| 89 |
+
description="Model to use when processing image",
|
| 90 |
+
),
|
| 91 |
+
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
| 92 |
+
af: int = Query(
|
| 93 |
+
default=240,
|
| 94 |
+
ge=0,
|
| 95 |
+
le=255,
|
| 96 |
+
description="Alpha Matting (Foreground Threshold)",
|
| 97 |
+
),
|
| 98 |
+
ab: int = Query(
|
| 99 |
+
default=10,
|
| 100 |
+
ge=0,
|
| 101 |
+
le=255,
|
| 102 |
+
description="Alpha Matting (Background Threshold)",
|
| 103 |
+
),
|
| 104 |
+
ae: int = Query(
|
| 105 |
+
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 106 |
+
),
|
| 107 |
+
om: bool = Query(default=False, description="Only Mask"),
|
| 108 |
+
ppm: bool = Query(default=False, description="Post Process Mask"),
|
| 109 |
+
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
| 110 |
+
extras: Optional[str] = Query(
|
| 111 |
+
default=None, description="Extra parameters as JSON"
|
| 112 |
+
),
|
| 113 |
+
):
|
| 114 |
+
self.model = model
|
| 115 |
+
self.a = a
|
| 116 |
+
self.af = af
|
| 117 |
+
self.ab = ab
|
| 118 |
+
self.ae = ae
|
| 119 |
+
self.om = om
|
| 120 |
+
self.ppm = ppm
|
| 121 |
+
self.extras = extras
|
| 122 |
+
self.bgc = (
|
| 123 |
+
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
| 124 |
+
if bgc
|
| 125 |
+
else None
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
class CommonQueryPostParams:
|
| 129 |
+
def __init__(
|
| 130 |
+
self,
|
| 131 |
+
model: Annotated[
|
| 132 |
+
str, Form(regex=r"(" + "|".join(sessions_names) + ")")
|
| 133 |
+
] = Form(
|
| 134 |
+
description="Model to use when processing image",
|
| 135 |
+
),
|
| 136 |
+
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
| 137 |
+
af: int = Form(
|
| 138 |
+
default=240,
|
| 139 |
+
ge=0,
|
| 140 |
+
le=255,
|
| 141 |
+
description="Alpha Matting (Foreground Threshold)",
|
| 142 |
+
),
|
| 143 |
+
ab: int = Form(
|
| 144 |
+
default=10,
|
| 145 |
+
ge=0,
|
| 146 |
+
le=255,
|
| 147 |
+
description="Alpha Matting (Background Threshold)",
|
| 148 |
+
),
|
| 149 |
+
ae: int = Form(
|
| 150 |
+
default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
|
| 151 |
+
),
|
| 152 |
+
om: bool = Form(default=False, description="Only Mask"),
|
| 153 |
+
ppm: bool = Form(default=False, description="Post Process Mask"),
|
| 154 |
+
bgc: Optional[str] = Query(default=None, description="Background Color"),
|
| 155 |
+
extras: Optional[str] = Query(
|
| 156 |
+
default=None, description="Extra parameters as JSON"
|
| 157 |
+
),
|
| 158 |
+
):
|
| 159 |
+
self.model = model
|
| 160 |
+
self.a = a
|
| 161 |
+
self.af = af
|
| 162 |
+
self.ab = ab
|
| 163 |
+
self.ae = ae
|
| 164 |
+
self.om = om
|
| 165 |
+
self.ppm = ppm
|
| 166 |
+
self.extras = extras
|
| 167 |
+
self.bgc = (
|
| 168 |
+
cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
|
| 169 |
+
if bgc
|
| 170 |
+
else None
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
|
| 174 |
+
kwargs = {}
|
| 175 |
+
|
| 176 |
+
if commons.extras:
|
| 177 |
+
try:
|
| 178 |
+
kwargs.update(json.loads(commons.extras))
|
| 179 |
+
except Exception:
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
return Response(
|
| 183 |
+
remove(
|
| 184 |
+
content,
|
| 185 |
+
session=sessions.setdefault(commons.model, new_session(commons.model)),
|
| 186 |
+
alpha_matting=commons.a,
|
| 187 |
+
alpha_matting_foreground_threshold=commons.af,
|
| 188 |
+
alpha_matting_background_threshold=commons.ab,
|
| 189 |
+
alpha_matting_erode_size=commons.ae,
|
| 190 |
+
only_mask=commons.om,
|
| 191 |
+
post_process_mask=commons.ppm,
|
| 192 |
+
bgcolor=commons.bgc,
|
| 193 |
+
**kwargs
|
| 194 |
+
),
|
| 195 |
+
media_type="image/png",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
@app.on_event("startup")
|
| 199 |
+
def startup():
|
| 200 |
+
if threads is not None:
|
| 201 |
+
from anyio import CapacityLimiter
|
| 202 |
+
from anyio.lowlevel import RunVar
|
| 203 |
+
|
| 204 |
+
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
| 205 |
+
|
| 206 |
+
@app.get(
|
| 207 |
+
path="/",
|
| 208 |
+
tags=["Background Removal"],
|
| 209 |
+
summary="Remove from URL",
|
| 210 |
+
description="Removes the background from an image obtained by retrieving an URL.",
|
| 211 |
+
)
|
| 212 |
+
async def get_index(
|
| 213 |
+
url: str = Query(
|
| 214 |
+
default=..., description="URL of the image that has to be processed."
|
| 215 |
+
),
|
| 216 |
+
commons: CommonQueryParams = Depends(),
|
| 217 |
+
):
|
| 218 |
+
async with aiohttp.ClientSession() as session:
|
| 219 |
+
async with session.get(url) as response:
|
| 220 |
+
file = await response.read()
|
| 221 |
+
return await asyncify(im_without_bg)(file, commons)
|
| 222 |
+
|
| 223 |
+
@app.post(
|
| 224 |
+
path="/",
|
| 225 |
+
tags=["Background Removal"],
|
| 226 |
+
summary="Remove from Stream",
|
| 227 |
+
description="Removes the background from an image sent within the request itself.",
|
| 228 |
+
)
|
| 229 |
+
async def post_index(
|
| 230 |
+
file: bytes = File(
|
| 231 |
+
default=...,
|
| 232 |
+
description="Image file (byte stream) that has to be processed.",
|
| 233 |
+
),
|
| 234 |
+
commons: CommonQueryPostParams = Depends(),
|
| 235 |
+
):
|
| 236 |
+
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
| 237 |
+
|
| 238 |
+
uvicorn.run(app, host="0.0.0.0", port=port, log_level=log_level)
|
rembg/session_factory.py
CHANGED
|
@@ -1,71 +1,24 @@
|
|
| 1 |
-
import hashlib
|
| 2 |
import os
|
| 3 |
-
import sys
|
| 4 |
-
from contextlib import redirect_stdout
|
| 5 |
-
from pathlib import Path
|
| 6 |
from typing import Type
|
| 7 |
|
| 8 |
import onnxruntime as ort
|
| 9 |
-
import pooch
|
| 10 |
|
| 11 |
-
from .
|
| 12 |
-
from .
|
| 13 |
-
from .
|
| 14 |
|
| 15 |
|
| 16 |
-
def new_session(model_name: str = "u2net") -> BaseSession:
|
| 17 |
-
session_class: Type[BaseSession]
|
| 18 |
-
md5 = "60024c5c889badc19c04ad937298a77b"
|
| 19 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
| 20 |
-
session_class = SimpleSession
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
)
|
| 27 |
-
session_class = SimpleSession
|
| 28 |
-
elif model_name == "u2net_human_seg":
|
| 29 |
-
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
| 30 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
| 31 |
-
session_class = SimpleSession
|
| 32 |
-
elif model_name == "u2net_cloth_seg":
|
| 33 |
-
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
| 34 |
-
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
| 35 |
-
session_class = ClothSession
|
| 36 |
-
elif model_name == "silueta":
|
| 37 |
-
md5 = "55e59e0d8062d2f5d013f4725ee84782"
|
| 38 |
-
url = (
|
| 39 |
-
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
| 40 |
-
)
|
| 41 |
-
session_class = SimpleSession
|
| 42 |
-
|
| 43 |
-
u2net_home = os.getenv(
|
| 44 |
-
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
fname = f"{model_name}.onnx"
|
| 48 |
-
path = Path(u2net_home).expanduser()
|
| 49 |
-
full_path = Path(u2net_home).expanduser() / fname
|
| 50 |
-
|
| 51 |
-
pooch.retrieve(
|
| 52 |
-
url,
|
| 53 |
-
f"md5:{md5}",
|
| 54 |
-
fname=fname,
|
| 55 |
-
path=Path(u2net_home).expanduser(),
|
| 56 |
-
progressbar=True,
|
| 57 |
-
)
|
| 58 |
|
| 59 |
sess_opts = ort.SessionOptions()
|
| 60 |
|
| 61 |
if "OMP_NUM_THREADS" in os.environ:
|
| 62 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
| 63 |
|
| 64 |
-
return session_class(
|
| 65 |
-
model_name,
|
| 66 |
-
ort.InferenceSession(
|
| 67 |
-
str(full_path),
|
| 68 |
-
providers=ort.get_available_providers(),
|
| 69 |
-
sess_options=sess_opts,
|
| 70 |
-
),
|
| 71 |
-
)
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
from typing import Type
|
| 3 |
|
| 4 |
import onnxruntime as ort
|
|
|
|
| 5 |
|
| 6 |
+
from .sessions import sessions_class
|
| 7 |
+
from .sessions.base import BaseSession
|
| 8 |
+
from .sessions.u2net import U2netSession
|
| 9 |
|
| 10 |
|
| 11 |
+
def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
| 12 |
+
session_class: Type[BaseSession] = U2netSession
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
+
for sc in sessions_class:
|
| 15 |
+
if sc.name() == model_name:
|
| 16 |
+
session_class = sc
|
| 17 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
sess_opts = ort.SessionOptions()
|
| 20 |
|
| 21 |
if "OMP_NUM_THREADS" in os.environ:
|
| 22 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
| 23 |
|
| 24 |
+
return session_class(model_name, sess_opts, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rembg/sessions/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from importlib import import_module
|
| 2 |
+
from inspect import isclass
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from pkgutil import iter_modules
|
| 5 |
+
|
| 6 |
+
from .base import BaseSession
|
| 7 |
+
|
| 8 |
+
sessions_class = []
|
| 9 |
+
sessions_names = []
|
| 10 |
+
|
| 11 |
+
package_dir = Path(__file__).resolve().parent
|
| 12 |
+
for _b, module_name, _p in iter_modules([str(package_dir)]):
|
| 13 |
+
module = import_module(f"{__name__}.{module_name}")
|
| 14 |
+
for attribute_name in dir(module):
|
| 15 |
+
attribute = getattr(module, attribute_name)
|
| 16 |
+
if (
|
| 17 |
+
isclass(attribute)
|
| 18 |
+
and issubclass(attribute, BaseSession)
|
| 19 |
+
and attribute != BaseSession
|
| 20 |
+
):
|
| 21 |
+
sessions_class.append(attribute)
|
| 22 |
+
sessions_names.append(attribute.name())
|
rembg/sessions/base.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Dict, List, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BaseSession:
|
| 11 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
| 12 |
+
self.model_name = model_name
|
| 13 |
+
self.inner_session = ort.InferenceSession(
|
| 14 |
+
str(self.__class__.download_models()),
|
| 15 |
+
providers=ort.get_available_providers(),
|
| 16 |
+
sess_options=sess_opts,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def normalize(
|
| 20 |
+
self,
|
| 21 |
+
img: PILImage,
|
| 22 |
+
mean: Tuple[float, float, float],
|
| 23 |
+
std: Tuple[float, float, float],
|
| 24 |
+
size: Tuple[int, int],
|
| 25 |
+
*args,
|
| 26 |
+
**kwargs
|
| 27 |
+
) -> Dict[str, np.ndarray]:
|
| 28 |
+
im = img.convert("RGB").resize(size, Image.LANCZOS)
|
| 29 |
+
|
| 30 |
+
im_ary = np.array(im)
|
| 31 |
+
im_ary = im_ary / np.max(im_ary)
|
| 32 |
+
|
| 33 |
+
tmpImg = np.zeros((im_ary.shape[0], im_ary.shape[1], 3))
|
| 34 |
+
tmpImg[:, :, 0] = (im_ary[:, :, 0] - mean[0]) / std[0]
|
| 35 |
+
tmpImg[:, :, 1] = (im_ary[:, :, 1] - mean[1]) / std[1]
|
| 36 |
+
tmpImg[:, :, 2] = (im_ary[:, :, 2] - mean[2]) / std[2]
|
| 37 |
+
|
| 38 |
+
tmpImg = tmpImg.transpose((2, 0, 1))
|
| 39 |
+
|
| 40 |
+
return {
|
| 41 |
+
self.inner_session.get_inputs()[0]
|
| 42 |
+
.name: np.expand_dims(tmpImg, 0)
|
| 43 |
+
.astype(np.float32)
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
|
| 49 |
+
@classmethod
|
| 50 |
+
def u2net_home(cls, *args, **kwargs):
|
| 51 |
+
return os.path.expanduser(
|
| 52 |
+
os.getenv(
|
| 53 |
+
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
| 54 |
+
)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
@classmethod
|
| 58 |
+
def download_models(cls, *args, **kwargs):
|
| 59 |
+
raise NotImplementedError
|
| 60 |
+
|
| 61 |
+
@classmethod
|
| 62 |
+
def name(cls, *args, **kwargs):
|
| 63 |
+
raise NotImplementedError
|
rembg/sessions/dis.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DisSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 20 |
+
|
| 21 |
+
ma = np.max(pred)
|
| 22 |
+
mi = np.min(pred)
|
| 23 |
+
|
| 24 |
+
pred = (pred - mi) / (ma - mi)
|
| 25 |
+
pred = np.squeeze(pred)
|
| 26 |
+
|
| 27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 29 |
+
|
| 30 |
+
return [mask]
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def download_models(cls, *args, **kwargs):
|
| 34 |
+
fname = f"{cls.name()}.onnx"
|
| 35 |
+
pooch.retrieve(
|
| 36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
| 37 |
+
"md5:fc16ebd8b0c10d971d3513d564d01e29",
|
| 38 |
+
fname=fname,
|
| 39 |
+
path=cls.u2net_home(),
|
| 40 |
+
progressbar=True,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 44 |
+
|
| 45 |
+
@classmethod
|
| 46 |
+
def name(cls, *args, **kwargs):
|
| 47 |
+
return "isnet-general-use"
|
rembg/sessions/sam.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import onnxruntime as ort
|
| 6 |
+
import pooch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from PIL.Image import Image as PILImage
|
| 9 |
+
|
| 10 |
+
from .base import BaseSession
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
|
| 14 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
| 15 |
+
newh, neww = oldh * scale, oldw * scale
|
| 16 |
+
neww = int(neww + 0.5)
|
| 17 |
+
newh = int(newh + 0.5)
|
| 18 |
+
return (newh, neww)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
|
| 22 |
+
old_h, old_w = original_size
|
| 23 |
+
new_h, new_w = get_preprocess_shape(
|
| 24 |
+
original_size[0], original_size[1], target_length
|
| 25 |
+
)
|
| 26 |
+
coords = coords.copy().astype(float)
|
| 27 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
| 28 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
| 29 |
+
return coords
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def resize_longes_side(img: PILImage, size=1024):
|
| 33 |
+
w, h = img.size
|
| 34 |
+
if h > w:
|
| 35 |
+
new_h, new_w = size, int(w * size / h)
|
| 36 |
+
else:
|
| 37 |
+
new_h, new_w = int(h * size / w), size
|
| 38 |
+
|
| 39 |
+
return img.resize((new_w, new_h))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def pad_to_square(img: np.ndarray, size=1024):
|
| 43 |
+
h, w = img.shape[:2]
|
| 44 |
+
padh = size - h
|
| 45 |
+
padw = size - w
|
| 46 |
+
img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant")
|
| 47 |
+
img = img.astype(np.float32)
|
| 48 |
+
return img
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SamSession(BaseSession):
|
| 52 |
+
def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
|
| 53 |
+
self.model_name = model_name
|
| 54 |
+
paths = self.__class__.download_models()
|
| 55 |
+
self.encoder = ort.InferenceSession(
|
| 56 |
+
str(paths[0]),
|
| 57 |
+
providers=ort.get_available_providers(),
|
| 58 |
+
sess_options=sess_opts,
|
| 59 |
+
)
|
| 60 |
+
self.decoder = ort.InferenceSession(
|
| 61 |
+
str(paths[1]),
|
| 62 |
+
providers=ort.get_available_providers(),
|
| 63 |
+
sess_options=sess_opts,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def normalize(
|
| 67 |
+
self,
|
| 68 |
+
img: np.ndarray,
|
| 69 |
+
mean=(123.675, 116.28, 103.53),
|
| 70 |
+
std=(58.395, 57.12, 57.375),
|
| 71 |
+
size=(1024, 1024),
|
| 72 |
+
*args,
|
| 73 |
+
**kwargs,
|
| 74 |
+
):
|
| 75 |
+
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
|
| 76 |
+
pixel_std = np.array([*std]).reshape(1, 1, -1)
|
| 77 |
+
x = (img - pixel_mean) / pixel_std
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
def predict(
|
| 81 |
+
self,
|
| 82 |
+
img: PILImage,
|
| 83 |
+
*args,
|
| 84 |
+
**kwargs,
|
| 85 |
+
) -> List[PILImage]:
|
| 86 |
+
# Preprocess image
|
| 87 |
+
image = resize_longes_side(img)
|
| 88 |
+
image = np.array(image)
|
| 89 |
+
image = self.normalize(image)
|
| 90 |
+
image = pad_to_square(image)
|
| 91 |
+
|
| 92 |
+
input_labels = kwargs.get("input_labels")
|
| 93 |
+
input_points = kwargs.get("input_points")
|
| 94 |
+
|
| 95 |
+
if input_labels is None:
|
| 96 |
+
raise ValueError("input_labels is required")
|
| 97 |
+
if input_points is None:
|
| 98 |
+
raise ValueError("input_points is required")
|
| 99 |
+
|
| 100 |
+
# Transpose
|
| 101 |
+
image = image.transpose(2, 0, 1)[None, :, :, :]
|
| 102 |
+
# Run encoder (Image embedding)
|
| 103 |
+
encoded = self.encoder.run(None, {"x": image})
|
| 104 |
+
image_embedding = encoded[0]
|
| 105 |
+
|
| 106 |
+
# Add a batch index, concatenate a padding point, and transform.
|
| 107 |
+
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
|
| 108 |
+
None, :, :
|
| 109 |
+
]
|
| 110 |
+
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
|
| 111 |
+
None, :
|
| 112 |
+
].astype(np.float32)
|
| 113 |
+
onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
|
| 114 |
+
|
| 115 |
+
# Create an empty mask input and an indicator for no mask.
|
| 116 |
+
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
|
| 117 |
+
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
|
| 118 |
+
|
| 119 |
+
decoder_inputs = {
|
| 120 |
+
"image_embeddings": image_embedding,
|
| 121 |
+
"point_coords": onnx_coord,
|
| 122 |
+
"point_labels": onnx_label,
|
| 123 |
+
"mask_input": onnx_mask_input,
|
| 124 |
+
"has_mask_input": onnx_has_mask_input,
|
| 125 |
+
"orig_im_size": np.array(img.size[::-1], dtype=np.float32),
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
|
| 129 |
+
masks = masks > 0.0
|
| 130 |
+
masks = [
|
| 131 |
+
Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
|
| 132 |
+
for i in range(masks.shape[0])
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
return masks
|
| 136 |
+
|
| 137 |
+
@classmethod
|
| 138 |
+
def download_models(cls, *args, **kwargs):
|
| 139 |
+
fname_encoder = f"{cls.name()}_encoder.onnx"
|
| 140 |
+
fname_decoder = f"{cls.name()}_decoder.onnx"
|
| 141 |
+
|
| 142 |
+
pooch.retrieve(
|
| 143 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
| 144 |
+
"md5:13d97c5c79ab13ef86d67cbde5f1b250",
|
| 145 |
+
fname=fname_encoder,
|
| 146 |
+
path=cls.u2net_home(),
|
| 147 |
+
progressbar=True,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
pooch.retrieve(
|
| 151 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
| 152 |
+
"md5:fa3d1c36a3187d3de1c8deebf33dd127",
|
| 153 |
+
fname=fname_decoder,
|
| 154 |
+
path=cls.u2net_home(),
|
| 155 |
+
progressbar=True,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
return (
|
| 159 |
+
os.path.join(cls.u2net_home(), fname_encoder),
|
| 160 |
+
os.path.join(cls.u2net_home(), fname_decoder),
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
@classmethod
|
| 164 |
+
def name(cls, *args, **kwargs):
|
| 165 |
+
return "sam"
|
rembg/sessions/silueta.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SiluetaSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(
|
| 17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
| 18 |
+
),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 22 |
+
|
| 23 |
+
ma = np.max(pred)
|
| 24 |
+
mi = np.min(pred)
|
| 25 |
+
|
| 26 |
+
pred = (pred - mi) / (ma - mi)
|
| 27 |
+
pred = np.squeeze(pred)
|
| 28 |
+
|
| 29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 31 |
+
|
| 32 |
+
return [mask]
|
| 33 |
+
|
| 34 |
+
@classmethod
|
| 35 |
+
def download_models(cls, *args, **kwargs):
|
| 36 |
+
fname = f"{cls.name()}.onnx"
|
| 37 |
+
pooch.retrieve(
|
| 38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
| 39 |
+
"md5:55e59e0d8062d2f5d013f4725ee84782",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "silueta"
|
rembg/sessions/u2net.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class U2netSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(
|
| 17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
| 18 |
+
),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 22 |
+
|
| 23 |
+
ma = np.max(pred)
|
| 24 |
+
mi = np.min(pred)
|
| 25 |
+
|
| 26 |
+
pred = (pred - mi) / (ma - mi)
|
| 27 |
+
pred = np.squeeze(pred)
|
| 28 |
+
|
| 29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 31 |
+
|
| 32 |
+
return [mask]
|
| 33 |
+
|
| 34 |
+
@classmethod
|
| 35 |
+
def download_models(cls, *args, **kwargs):
|
| 36 |
+
fname = f"{cls.name()}.onnx"
|
| 37 |
+
pooch.retrieve(
|
| 38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
| 39 |
+
"md5:60024c5c889badc19c04ad937298a77b",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "u2net"
|
rembg/sessions/u2net_cloth_seg.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
from scipy.special import log_softmax
|
| 9 |
+
|
| 10 |
+
from .base import BaseSession
|
| 11 |
+
|
| 12 |
+
pallete1 = [
|
| 13 |
+
0,
|
| 14 |
+
0,
|
| 15 |
+
0,
|
| 16 |
+
255,
|
| 17 |
+
255,
|
| 18 |
+
255,
|
| 19 |
+
0,
|
| 20 |
+
0,
|
| 21 |
+
0,
|
| 22 |
+
0,
|
| 23 |
+
0,
|
| 24 |
+
0,
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
pallete2 = [
|
| 28 |
+
0,
|
| 29 |
+
0,
|
| 30 |
+
0,
|
| 31 |
+
0,
|
| 32 |
+
0,
|
| 33 |
+
0,
|
| 34 |
+
255,
|
| 35 |
+
255,
|
| 36 |
+
255,
|
| 37 |
+
0,
|
| 38 |
+
0,
|
| 39 |
+
0,
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
pallete3 = [
|
| 43 |
+
0,
|
| 44 |
+
0,
|
| 45 |
+
0,
|
| 46 |
+
0,
|
| 47 |
+
0,
|
| 48 |
+
0,
|
| 49 |
+
0,
|
| 50 |
+
0,
|
| 51 |
+
0,
|
| 52 |
+
255,
|
| 53 |
+
255,
|
| 54 |
+
255,
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class Unet2ClothSession(BaseSession):
|
| 59 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 60 |
+
ort_outs = self.inner_session.run(
|
| 61 |
+
None,
|
| 62 |
+
self.normalize(
|
| 63 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
|
| 64 |
+
),
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
pred = ort_outs
|
| 68 |
+
pred = log_softmax(pred[0], 1)
|
| 69 |
+
pred = np.argmax(pred, axis=1, keepdims=True)
|
| 70 |
+
pred = np.squeeze(pred, 0)
|
| 71 |
+
pred = np.squeeze(pred, 0)
|
| 72 |
+
|
| 73 |
+
mask = Image.fromarray(pred.astype("uint8"), mode="L")
|
| 74 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 75 |
+
|
| 76 |
+
masks = []
|
| 77 |
+
|
| 78 |
+
mask1 = mask.copy()
|
| 79 |
+
mask1.putpalette(pallete1)
|
| 80 |
+
mask1 = mask1.convert("RGB").convert("L")
|
| 81 |
+
masks.append(mask1)
|
| 82 |
+
|
| 83 |
+
mask2 = mask.copy()
|
| 84 |
+
mask2.putpalette(pallete2)
|
| 85 |
+
mask2 = mask2.convert("RGB").convert("L")
|
| 86 |
+
masks.append(mask2)
|
| 87 |
+
|
| 88 |
+
mask3 = mask.copy()
|
| 89 |
+
mask3.putpalette(pallete3)
|
| 90 |
+
mask3 = mask3.convert("RGB").convert("L")
|
| 91 |
+
masks.append(mask3)
|
| 92 |
+
|
| 93 |
+
return masks
|
| 94 |
+
|
| 95 |
+
@classmethod
|
| 96 |
+
def download_models(cls, *args, **kwargs):
|
| 97 |
+
fname = f"{cls.name()}.onnx"
|
| 98 |
+
pooch.retrieve(
|
| 99 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
| 100 |
+
"md5:2434d1f3cb744e0e49386c906e5a08bb",
|
| 101 |
+
fname=fname,
|
| 102 |
+
path=cls.u2net_home(),
|
| 103 |
+
progressbar=True,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 107 |
+
|
| 108 |
+
@classmethod
|
| 109 |
+
def name(cls, *args, **kwargs):
|
| 110 |
+
return "u2net_cloth_seg"
|
rembg/sessions/u2net_human_seg.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class U2netHumanSegSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(
|
| 17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
| 18 |
+
),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 22 |
+
|
| 23 |
+
ma = np.max(pred)
|
| 24 |
+
mi = np.min(pred)
|
| 25 |
+
|
| 26 |
+
pred = (pred - mi) / (ma - mi)
|
| 27 |
+
pred = np.squeeze(pred)
|
| 28 |
+
|
| 29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 31 |
+
|
| 32 |
+
return [mask]
|
| 33 |
+
|
| 34 |
+
@classmethod
|
| 35 |
+
def download_models(cls, *args, **kwargs):
|
| 36 |
+
fname = f"{cls.name()}.onnx"
|
| 37 |
+
pooch.retrieve(
|
| 38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
| 39 |
+
"md5:c09ddc2e0104f800e3e1bb4652583d1f",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "u2net_human_seg"
|
rembg/sessions/u2netp.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class U2netpSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(
|
| 17 |
+
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
|
| 18 |
+
),
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 22 |
+
|
| 23 |
+
ma = np.max(pred)
|
| 24 |
+
mi = np.min(pred)
|
| 25 |
+
|
| 26 |
+
pred = (pred - mi) / (ma - mi)
|
| 27 |
+
pred = np.squeeze(pred)
|
| 28 |
+
|
| 29 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 30 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 31 |
+
|
| 32 |
+
return [mask]
|
| 33 |
+
|
| 34 |
+
@classmethod
|
| 35 |
+
def download_models(cls, *args, **kwargs):
|
| 36 |
+
fname = f"{cls.name()}.onnx"
|
| 37 |
+
pooch.retrieve(
|
| 38 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
| 39 |
+
"md5:8e83ca70e441ab06c318d82300c84806",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "u2netp"
|