diff --git a/modules/ChatTTS/ChatTTS/core.py b/modules/ChatTTS/ChatTTS/core.py index 29daaf2..7da39b8 100644 --- a/modules/ChatTTS/ChatTTS/core.py +++ b/modules/ChatTTS/ChatTTS/core.py @@ -1,31 +1,30 @@ -import os import logging +import lzma +import os import tempfile from dataclasses import dataclass -from typing import Literal, Optional, List, Callable, Tuple, Dict, Union from json import load from pathlib import Path -import lzma +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np +import pybase16384 as b14 import torch import torch.nn.functional as F +from huggingface_hub import snapshot_download from omegaconf import OmegaConf from vocos import Vocos -from huggingface_hub import snapshot_download -import pybase16384 as b14 from .model import DVAE, GPT, gen_logits +from .norm import Normalizer from .utils import ( check_all_assets, + del_all, download_all_assets, - select_device, get_latest_modified_file, - del_all, ) from .utils import logger as utils_logger - -from .norm import Normalizer +from .utils import select_device class Chat: diff --git a/modules/ChatTTS/ChatTTS/model/gpt.py b/modules/ChatTTS/ChatTTS/model/gpt.py index 710b3d2..38a2d31 100644 --- a/modules/ChatTTS/ChatTTS/model/gpt.py +++ b/modules/ChatTTS/ChatTTS/model/gpt.py @@ -5,9 +5,9 @@ https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning """ -from dataclasses import dataclass import logging -from typing import Union, List, Optional, Tuple +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union import omegaconf import torch @@ -16,13 +16,12 @@ import torch.nn.utils.parametrize as P from torch.nn.utils.parametrizations import weight_norm from tqdm import tqdm -from transformers import LlamaModel, LlamaConfig, LogitsWarper +from transformers import LlamaConfig, LlamaModel, LogitsWarper from transformers.cache_utils import Cache from transformers.modeling_outputs import BaseModelOutputWithPast -from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat from ..utils import del_all - +from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat """class LlamaMLP(nn.Module): def __init__(self, hidden_size, intermediate_size): diff --git a/modules/ChatTTS/ChatTTS/norm.py b/modules/ChatTTS/ChatTTS/norm.py index bf6e7b8..bd60318 100644 --- a/modules/ChatTTS/ChatTTS/norm.py +++ b/modules/ChatTTS/ChatTTS/norm.py @@ -1,11 +1,11 @@ import json import logging import re -from typing import Dict, Tuple, List, Literal, Callable, Optional import sys +from typing import Callable, Dict, List, Literal, Optional, Tuple -from numba import jit import numpy as np +from numba import jit from .utils import del_all diff --git a/modules/ChatTTS/ChatTTS/utils/__init__.py b/modules/ChatTTS/ChatTTS/utils/__init__.py index 9697dd7..7351879 100644 --- a/modules/ChatTTS/ChatTTS/utils/__init__.py +++ b/modules/ChatTTS/ChatTTS/utils/__init__.py @@ -1,4 +1,4 @@ from .dl import check_all_assets, download_all_assets from .gpu import select_device -from .io import get_latest_modified_file, del_all +from .io import del_all, get_latest_modified_file from .log import logger diff --git a/modules/ChatTTS/ChatTTS/utils/dl.py b/modules/ChatTTS/ChatTTS/utils/dl.py index 31a15fa..8451b2b 100644 --- a/modules/ChatTTS/ChatTTS/utils/dl.py +++ b/modules/ChatTTS/ChatTTS/utils/dl.py @@ -1,10 +1,11 @@ -import os -from pathlib import Path import hashlib -import requests +import os from io import BytesIO +from mmap import ACCESS_READ, mmap +from pathlib import Path from typing import Dict -from mmap import mmap, ACCESS_READ + +import requests from .log import logger @@ -118,8 +119,8 @@ def download_dns_yaml(url: str, folder: str): def download_all_assets(tmpdir: str, version="0.2.5"): - import subprocess import platform + import subprocess archs = { "aarch64": "arm64", diff --git a/modules/ChatTTS/ChatTTS/utils/io.py b/modules/ChatTTS/ChatTTS/utils/io.py index b37f939..3d03430 100644 --- a/modules/ChatTTS/ChatTTS/utils/io.py +++ b/modules/ChatTTS/ChatTTS/utils/io.py @@ -1,7 +1,7 @@ -import os import logging -from typing import Union +import os from dataclasses import is_dataclass +from typing import Union from .log import logger diff --git a/modules/ChatTTSInfer.py b/modules/ChatTTSInfer.py index ca7313d..da59e1d 100644 --- a/modules/ChatTTSInfer.py +++ b/modules/ChatTTSInfer.py @@ -2,9 +2,10 @@ import numpy as np import torch + +from modules import config from modules.ChatTTS.ChatTTS.core import Chat from modules.utils.monkey_tqdm import disable_tqdm -from modules import config # 主要解决类型问题 diff --git a/modules/Enhancer/ResembleEnhance.py b/modules/Enhancer/ResembleEnhance.py index 7344cea..a29ca4d 100644 --- a/modules/Enhancer/ResembleEnhance.py +++ b/modules/Enhancer/ResembleEnhance.py @@ -7,13 +7,13 @@ import numpy as np import torch +from modules import config from modules.devices import devices from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer from modules.repos_static.resemble_enhance.enhancer.hparams import HParams from modules.repos_static.resemble_enhance.inference import inference from modules.utils.constants import MODELS_DIR from modules.utils.monkey_tqdm import disable_tqdm -from modules import config logger = logging.getLogger(__name__) diff --git a/modules/api/impl/handler/AudioHandler.py b/modules/api/impl/handler/AudioHandler.py index 8503b1a..4f6da9a 100644 --- a/modules/api/impl/handler/AudioHandler.py +++ b/modules/api/impl/handler/AudioHandler.py @@ -1,16 +1,14 @@ import base64 import io -from typing import Generator import wave +from typing import Generator import numpy as np import soundfile as sf +from pydub import AudioSegment from modules.api import utils as api_utils from modules.api.impl.model.audio_model import AudioFormat - -from pydub import AudioSegment - from modules.utils.audio import ndarray_to_segment diff --git a/modules/utils/monkey_tqdm.py b/modules/utils/monkey_tqdm.py index f145e59..be02bed 100644 --- a/modules/utils/monkey_tqdm.py +++ b/modules/utils/monkey_tqdm.py @@ -73,6 +73,7 @@ def exit_tqdm(self, exc_type, exc_value, traceback): if __name__ == "__main__": import time + from tqdm import tqdm with disable_tqdm():