Skip to content

Commit

Permalink
Replace pooch retrieve calls with create followed by fetch to benefit…
Browse files Browse the repository at this point in the history
… from the retry_if_failed feature
  • Loading branch information
dimitribarbot committed Aug 26, 2024
1 parent 95b8114 commit 379c16a
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 161 deletions.
47 changes: 22 additions & 25 deletions rembg/sessions/birefnet_cod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os

import pooch

from . import BiRefNetSessionGeneral


Expand All @@ -11,42 +7,43 @@ class BiRefNetSessionCOD(BiRefNetSessionGeneral):
"""

@classmethod
def download_models(cls, *args, **kwargs):
def name(cls, *args, **kwargs):
"""
Downloads the BiRefNet-COD model file from a specific URL and saves it.
Returns the name of the BiRefNet-COD session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
str: The name of the session.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

return os.path.join(cls.u2net_home(*args, **kwargs), fname)
return "birefnet-cod"

@classmethod
def url_fname(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-COD file in the model url.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the model file in the model url.
"""
return "BiRefNet-COD-epoch_125.onnx"

@classmethod
def name(cls, *args, **kwargs):
def model_md5(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-COD session.
Returns the md5 of the BiRefNet-COD model file.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
str: The md5 of the model file.
"""
return "birefnet-cod"
return "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"
47 changes: 22 additions & 25 deletions rembg/sessions/birefnet_dis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os

import pooch

from . import BiRefNetSessionGeneral


Expand All @@ -11,42 +7,43 @@ class BiRefNetSessionDIS(BiRefNetSessionGeneral):
"""

@classmethod
def download_models(cls, *args, **kwargs):
def name(cls, *args, **kwargs):
"""
Downloads the BiRefNet-DIS model file from a specific URL and saves it.
Returns the name of the BiRefNet-DIS session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
str: The name of the session.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:2d4d44102b446f33a4ebb2e56c051f2b"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

return os.path.join(cls.u2net_home(*args, **kwargs), fname)
return "birefnet-dis"

@classmethod
def url_fname(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-DIS file in the model url.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the model file in the model url.
"""
return "BiRefNet-DIS-epoch_590.onnx"

@classmethod
def name(cls, *args, **kwargs):
def model_md5(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-DIS session.
Returns the md5 of the BiRefNet-DIS model file.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
str: The md5 of the model file.
"""
return "birefnet-dis"
return "md5:2d4d44102b446f33a4ebb2e56c051f2b"
62 changes: 51 additions & 11 deletions rembg/sessions/birefnet_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class BiRefNetSessionGeneral(BaseSession):
This class represents a BiRefNet-General session, which is a subclass of BaseSession.
"""

base_url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/"

def sigmoid(self, mat):
return 1 / (1 + np.exp(-mat))

Expand Down Expand Up @@ -62,19 +64,29 @@ def download_models(cls, *args, **kwargs):
str: The path to the downloaded model file.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:7a35a0141cbbc80de11d9c9a28f52697"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
url = "".join([cls.base_url, cls.url_fname(*args, **kwargs)])
path = cls.u2net_home(*args, **kwargs)
pooch_instance = pooch.create(
path=path,
base_url=cls.base_url,
registry={
fname: (
None
if cls.checksum_disabled(*args, **kwargs)
else cls.model_hash(*args, **kwargs)
)
},
urls={
fname: url
},
retry_if_failed=2
)
pooch_instance.fetch(
fname,
progressbar=True
)

return os.path.join(cls.u2net_home(*args, **kwargs), fname)
return os.path.join(path, fname)

@classmethod
def name(cls, *args, **kwargs):
Expand All @@ -89,3 +101,31 @@ def name(cls, *args, **kwargs):
str: The name of the session.
"""
return "birefnet-general"

@classmethod
def url_fname(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-General file in the model url.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the model file in the model url.
"""
return "BiRefNet-general-epoch_244.onnx"

@classmethod
def model_hash(cls, *args, **kwargs):
"""
Returns the hash of the BiRefNet-General model file.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The hash of the model file.
"""
return "md5:7a35a0141cbbc80de11d9c9a28f52697"
47 changes: 22 additions & 25 deletions rembg/sessions/birefnet_general_lite.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os

import pooch

from . import BiRefNetSessionGeneral


Expand All @@ -11,42 +7,43 @@ class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
"""

@classmethod
def download_models(cls, *args, **kwargs):
def name(cls, *args, **kwargs):
"""
Downloads the BiRefNet-General-Lite model file from a specific URL and saves it.
Returns the name of the BiRefNet-General-Lite session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
str: The name of the session.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:4fab47adc4ff364be1713e97b7e66334"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

return os.path.join(cls.u2net_home(*args, **kwargs), fname)
return "birefnet-general-lite"

@classmethod
def url_fname(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-General-Lite file in the model url.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the model file in the model url.
"""
return "BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx"

@classmethod
def name(cls, *args, **kwargs):
def model_md5(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-General-Lite session.
Returns the md5 of the BiRefNet-General-Lite model file.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
str: The md5 of the model file.
"""
return "birefnet-general-lite"
return "md5:4fab47adc4ff364be1713e97b7e66334"
47 changes: 22 additions & 25 deletions rembg/sessions/birefnet_hrsod.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import os

import pooch

from . import BiRefNetSessionGeneral


Expand All @@ -11,42 +7,43 @@ class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
"""

@classmethod
def download_models(cls, *args, **kwargs):
def name(cls, *args, **kwargs):
"""
Downloads the BiRefNet-HRSOD model file from a specific URL and saves it.
Returns the name of the BiRefNet-HRSOD session.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The path to the downloaded model file.
str: The name of the session.
"""
fname = f"{cls.name(*args, **kwargs)}.onnx"
pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
(
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:c017ade5de8a50ff0fd74d790d268dda"
),
fname=fname,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

return os.path.join(cls.u2net_home(*args, **kwargs), fname)
return "birefnet-hrsod"

@classmethod
def url_fname(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-HRSOD file in the model url.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the model file in the model url.
"""
return "BiRefNet-HRSOD_DHU-epoch_115.onnx"

@classmethod
def name(cls, *args, **kwargs):
def model_md5(cls, *args, **kwargs):
"""
Returns the name of the BiRefNet-HRSOD session.
Returns the md5 of the BiRefNet-HRSOD model file.
Parameters:
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
str: The name of the session.
str: The md5 of the model file.
"""
return "birefnet-hrsod"
return "md5:c017ade5de8a50ff0fd74d790d268dda"
Loading

0 comments on commit 379c16a

Please sign in to comment.