Skip to content

Commit

Permalink
feat:update model loader
Browse files Browse the repository at this point in the history
feat:update model loader

feat:update model loader
  • Loading branch information
qinguoyi committed Oct 20, 2024
1 parent 7652e3d commit a26bf30
Show file tree
Hide file tree
Showing 24 changed files with 220 additions and 131 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,5 @@ cache
__pycache__
*.pyc
.pytest_cache
*.tgz
*.tgz
.huggingface
6 changes: 6 additions & 0 deletions api/core/v1alpha1/model_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ type ModelHub struct {
// +kubebuilder:default=main
// +optional
Revision *string `json:"revision,omitempty"`
// AllowPatterns restricts the download to files matching at least one of the given patterns.
// +optional
AllowPatterns []string `json:"allowPatterns,omitempty"`
// IgnorePatterns excludes files matching any of the given patterns from the download.
// +optional
IgnorePatterns []string `json:"ignorePatterns,omitempty"`
}

// URIProtocol represents the protocol of the URI.
Expand Down
10 changes: 10 additions & 0 deletions api/core/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 26 additions & 4 deletions client-go/applyconfiguration/core/v1alpha1/modelhub.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions config/crd/bases/llmaz.io_openmodels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,25 @@ spec:
description: ModelHub represents the model registry for model
downloads.
properties:
allowPatterns:
description: AllowPatterns refers to only files matching at
least one pattern are downloaded.
items:
type: string
type: array
filename:
description: |-
Filename refers to a specified model file rather than the whole repo.
This is helpful to download a specified GGUF model rather than downloading
the whole repo which includes all kinds of quantized models.
in the near future.
type: string
ignorePatterns:
description: IgnorePatterns refers to files matching any of
the patterns are not downloaded.
items:
type: string
type: array
modelID:
description: |-
ModelID refers to the model identifier on model hub,
Expand Down
33 changes: 20 additions & 13 deletions llmaz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,39 @@
import os
from datetime import datetime

from llmaz.model_loader.constant import *

from llmaz.model_loader.objstore.objstore import model_download
from llmaz.model_loader.model_hub.hub_factory import HubFactory
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE
from llmaz.model_loader.model_hub.huggingface import HUB_HUGGING_FACE
from llmaz.util.logger import Logger


if __name__ == "__main__":
model_source_type = os.getenv("MODEL_SOURCE_TYPE")
model_source_type = os.getenv(ENV_HUB_MODEL_SOURCE_TYPE)
start_time = datetime.now()

if model_source_type == "modelhub":
hub_name = os.getenv("MODEL_HUB_NAME", HUGGING_FACE)
revision = os.getenv("REVISION")
model_id = os.getenv("MODEL_ID")
model_file_name = os.getenv("MODEL_FILENAME")
hub_name = os.getenv(ENV_HUB_MODEL_HUB_NAME, HUB_HUGGING_FACE)
revision = os.getenv(ENV_HUB_REVISION)
model_id = os.getenv(ENV_HUB_MODEL_ID)
model_file_name = os.getenv(ENV_HUB_MODEL_FILENAME)
model_allow_patterns = os.getenv(ENV_HUB_MODEL_ALLOW_PATTERNS)
model_ignore_patterns = os.getenv(ENV_HUB_MODEL_IGNORE_PATTERNS)

if not model_id:
    # Report the name of the missing variable; model_id itself is empty
    # or None here, so interpolating it would hide what has to be set.
    raise EnvironmentError(
        f"Environment variable '{ENV_HUB_MODEL_ID}' not found."
    )

hub = HubFactory.new(hub_name)
hub.load_model(model_id, model_file_name, revision)
model_allow_patterns_list, model_ignore_patterns_list = [], []
if model_allow_patterns:
model_allow_patterns_list = model_allow_patterns.split(',')
if model_ignore_patterns:
model_ignore_patterns_list = model_ignore_patterns.split(',')
hub.load_model(model_id, model_file_name, revision, model_allow_patterns_list, model_ignore_patterns_list)
elif model_source_type == "objstore":
provider = os.getenv("PROVIDER")
endpoint = os.getenv("ENDPOINT")
bucket = os.getenv("BUCKET")
src = os.getenv("MODEL_PATH")
provider = os.getenv(ENV_OBJ_PROVIDER)
endpoint = os.getenv(ENV_OBJ_ENDPOINT)
bucket = os.getenv(ENV_OBJ_BUCKET)
src = os.getenv(ENV_OBJ_MODEL_PATH)

model_download(provider=provider, endpoint=endpoint, bucket=bucket, src=src)
else:
Expand Down
16 changes: 16 additions & 0 deletions llmaz/model_loader/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Shared constants for the model loader.

# Base directory that hub downloads are written under.
MODEL_LOCAL_DIR = "/workspace/models/"
# Hub identifiers: returned by each hub's name() and used as the keys of
# the hub factory's registry.
HUB_HUGGING_FACE = "Huggingface"
HUB_MODEL_SCOPE = "ModelScope"

# Names of the environment variables main.py reads for the "modelhub"
# source type (MODEL_SOURCE_TYPE itself selects the source type).
ENV_HUB_MODEL_SOURCE_TYPE = "MODEL_SOURCE_TYPE"
ENV_HUB_MODEL_HUB_NAME = "MODEL_HUB_NAME"
ENV_HUB_REVISION = "REVISION"
ENV_HUB_MODEL_ID = "MODEL_ID"
ENV_HUB_MODEL_FILENAME = "MODEL_FILENAME"
# Comma-separated pattern lists; main.py splits them on ','.
ENV_HUB_MODEL_ALLOW_PATTERNS = "MODEL_ALLOW_PATTERNS"
ENV_HUB_MODEL_IGNORE_PATTERNS = "MODEL_IGNORE_PATTERNS"

# Names of the environment variables main.py reads for the "objstore"
# source type.
ENV_OBJ_PROVIDER = "PROVIDER"
ENV_OBJ_ENDPOINT = "ENDPOINT"
ENV_OBJ_BUCKET = "BUCKET"
ENV_OBJ_MODEL_PATH = "MODEL_PATH"
1 change: 0 additions & 1 deletion llmaz/model_loader/defaults.py

This file was deleted.

12 changes: 6 additions & 6 deletions llmaz/model_loader/model_hub/hub_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

from llmaz.model_loader.constant import HUB_HUGGING_FACE, HUB_MODEL_SCOPE
from llmaz.model_loader.model_hub.model_hub import ModelHub
from llmaz.model_loader.model_hub.huggingface import HUGGING_FACE, Huggingface
from llmaz.model_loader.model_hub.modelscope import MODEL_SCOPE, ModelScope

from llmaz.model_loader.model_hub.huggingface import Huggingface
from llmaz.model_loader.model_hub.modelscope import ModelScope

SUPPORT_MODEL_HUBS = {
HUGGING_FACE: Huggingface,
MODEL_SCOPE: ModelScope,
HUB_HUGGING_FACE: Huggingface,
HUB_MODEL_SCOPE: ModelScope,
}


class HubFactory:

@classmethod
def new(cls, hub_name: str) -> ModelHub:
if hub_name not in SUPPORT_MODEL_HUBS.keys():
Expand Down
64 changes: 23 additions & 41 deletions llmaz/model_loader/model_hub/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,69 +17,51 @@
import concurrent.futures
import os

from huggingface_hub import hf_hub_download, list_repo_files
from huggingface_hub import snapshot_download

from llmaz.model_loader.defaults import MODEL_LOCAL_DIR
from llmaz.model_loader.constant import MODEL_LOCAL_DIR, HUB_HUGGING_FACE
from llmaz.model_loader.model_hub.model_hub import (
HUGGING_FACE,
MAX_WORKERS,
ModelHub,
)
from llmaz.util.logger import Logger
from llmaz.model_loader.model_hub.util import get_folder_total_size

from typing import Optional
from typing import Optional, List


class Huggingface(ModelHub):
    """Model hub backend that downloads model files from the Hugging Face Hub."""

    @classmethod
    def name(cls) -> str:
        """Return the identifier this hub is registered under."""
        return HUB_HUGGING_FACE

    @classmethod
    def load_model(
        cls,
        model_id: str,
        filename: Optional[str],
        revision: Optional[str],
        allow_patterns: Optional[List[str]] = None,
        ignore_patterns: Optional[List[str]] = None,
    ) -> None:
        """Download a model repository, or a filtered subset of it.

        Args:
            model_id: Repository identifier on the hub.
            filename: Optional single file to fetch. When set, it is added
                to the allow list and files are written directly under
                ``MODEL_LOCAL_DIR``; otherwise they are written under
                ``MODEL_LOCAL_DIR/models--<model_id with '/' as '--'>``.
            revision: Optional revision forwarded to ``snapshot_download``.
            allow_patterns: Optional patterns; only matching files are
                downloaded. ``None`` or an empty list means no restriction.
            ignore_patterns: Optional patterns; matching files are skipped.
        """
        Logger.info(
            f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}"
        )

        # Work on copies: the caller's lists must not be mutated, and the
        # arguments may legitimately be None.
        allow = list(allow_patterns) if allow_patterns else []
        ignore = list(ignore_patterns) if ignore_patterns else []

        if filename:
            allow.append(filename)
            local_dir = MODEL_LOCAL_DIR
        else:
            local_dir = os.path.join(
                MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}"
            )

        # Pass None rather than an empty list: huggingface_hub filters
        # against the allow list whenever it is not None, so [] would match
        # no file and nothing would be downloaded.
        snapshot_download(
            repo_id=model_id,
            revision=revision,
            local_dir=local_dir,
            allow_patterns=allow or None,
            ignore_patterns=ignore or None,
        )

        total_size = get_folder_total_size(local_dir)
        Logger.info(f"The total size of {local_dir} is {total_size: .2f} GB")
13 changes: 7 additions & 6 deletions llmaz/model_loader/model_hub/model_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
"""

from abc import ABC, abstractmethod
from typing import Optional

MAX_WORKERS = 4
HUGGING_FACE = "Huggingface"
MODEL_SCOPE = "ModelScope"
from typing import Optional, List


class ModelHub(ABC):
Expand All @@ -31,6 +27,11 @@ def name(cls) -> str:
@classmethod
@abstractmethod
def load_model(
cls, model_id: str, filename: Optional[str], revision: Optional[str]
cls,
model_id: str,
filename: Optional[str],
revision: Optional[str],
allow_patterns: Optional[List[str]],
ignore_patterns: Optional[List[str]],
) -> None:
pass
44 changes: 20 additions & 24 deletions llmaz/model_loader/model_hub/modelscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
"""

import os
import concurrent.futures
from typing import Optional
from typing import Optional, List

from modelscope import snapshot_download

from llmaz.model_loader.defaults import MODEL_LOCAL_DIR
from llmaz.model_loader.constant import *
from llmaz.model_loader.model_hub.model_hub import (
MAX_WORKERS,
MODEL_SCOPE,
ModelHub,
)
from llmaz.util.logger import Logger
Expand All @@ -33,36 +30,35 @@
class ModelScope(ModelHub):
    """Model hub backend that downloads model files from ModelScope."""

    @classmethod
    def name(cls) -> str:
        """Return the identifier this hub is registered under."""
        return HUB_MODEL_SCOPE

    @classmethod
    def load_model(
        cls,
        model_id: str,
        filename: Optional[str],
        revision: Optional[str],
        allow_patterns: Optional[List[str]] = None,
        ignore_patterns: Optional[List[str]] = None,
    ) -> None:
        """Download a model repository, or a filtered subset of it.

        Args:
            model_id: Model identifier on the hub.
            filename: Optional single file to fetch. When set, it is added
                to the allow list and files are written directly under
                ``MODEL_LOCAL_DIR``; otherwise they are written under
                ``MODEL_LOCAL_DIR/models--<model_id with '/' as '--'>``.
            revision: Optional revision forwarded to ``snapshot_download``.
            allow_patterns: Optional patterns; only matching files are
                downloaded. ``None`` or an empty list means no restriction.
            ignore_patterns: Optional patterns; matching files are skipped.
        """
        Logger.info(
            f"Start to download, model_id: {model_id}, filename: {filename}, revision: {revision}"
        )

        # Work on copies: the caller's lists must not be mutated, and the
        # arguments may legitimately be None.
        allow = list(allow_patterns) if allow_patterns else []
        ignore = list(ignore_patterns) if ignore_patterns else []

        if filename:
            allow.append(filename)
            local_dir = MODEL_LOCAL_DIR
        else:
            local_dir = os.path.join(
                MODEL_LOCAL_DIR, f"models--{model_id.replace('/', '--')}"
            )

        # "No patterns" is passed as None so that it means "no filtering".
        # NOTE(review): confirm how modelscope treats an empty list.
        snapshot_download(
            model_id=model_id,
            revision=revision,
            local_dir=local_dir,
            allow_patterns=allow or None,
            ignore_patterns=ignore or None,
        )
        total_size = get_folder_total_size(local_dir)
        Logger.info(f"The total size of {local_dir} is {total_size:.2f} GB")
Loading

0 comments on commit a26bf30

Please sign in to comment.