🎨 format
zhzLuke96 committed Jun 28, 2024
1 parent a5f90cb commit f5f483a
Showing 10 changed files with 29 additions and 30 deletions.
17 changes: 8 additions & 9 deletions modules/ChatTTS/ChatTTS/core.py
@@ -1,31 +1,30 @@
-import os
 import logging
+import lzma
+import os
 import tempfile
 from dataclasses import dataclass
-from typing import Literal, Optional, List, Callable, Tuple, Dict, Union
 from json import load
 from pathlib import Path
-import lzma
+from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import numpy as np
+import pybase16384 as b14
 import torch
 import torch.nn.functional as F
+from huggingface_hub import snapshot_download
 from omegaconf import OmegaConf
 from vocos import Vocos
-from huggingface_hub import snapshot_download
-import pybase16384 as b14
 
 from .model import DVAE, GPT, gen_logits
+from .norm import Normalizer
 from .utils import (
     check_all_assets,
+    del_all,
     download_all_assets,
     select_device,
     get_latest_modified_file,
-    del_all,
 )
 from .utils import logger as utils_logger
-
-from .norm import Normalizer
-from .utils import select_device
+
 
 class Chat:
9 changes: 4 additions & 5 deletions modules/ChatTTS/ChatTTS/model/gpt.py
@@ -5,9 +5,9 @@
 https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
 """
 
-from dataclasses import dataclass
 import logging
-from typing import Union, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
 import omegaconf
 import torch
@@ -16,13 +16,12 @@
 import torch.nn.utils.parametrize as P
 from torch.nn.utils.parametrizations import weight_norm
 from tqdm import tqdm
-from transformers import LlamaModel, LlamaConfig, LogitsWarper
+from transformers import LlamaConfig, LlamaModel, LogitsWarper
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils import del_all
-
+from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 
 """class LlamaMLP(nn.Module):
     def __init__(self, hidden_size, intermediate_size):
4 changes: 2 additions & 2 deletions modules/ChatTTS/ChatTTS/norm.py
@@ -1,11 +1,11 @@
 import json
 import logging
 import re
-from typing import Dict, Tuple, List, Literal, Callable, Optional
 import sys
+from typing import Callable, Dict, List, Literal, Optional, Tuple
 
-from numba import jit
 import numpy as np
+from numba import jit
 
 from .utils import del_all
 
2 changes: 1 addition & 1 deletion modules/ChatTTS/ChatTTS/utils/__init__.py
@@ -1,4 +1,4 @@
 from .dl import check_all_assets, download_all_assets
 from .gpu import select_device
-from .io import get_latest_modified_file, del_all
+from .io import del_all, get_latest_modified_file
 from .log import logger
11 changes: 6 additions & 5 deletions modules/ChatTTS/ChatTTS/utils/dl.py
@@ -1,10 +1,11 @@
-import os
-from pathlib import Path
 import hashlib
-import requests
+import os
 from io import BytesIO
+from mmap import ACCESS_READ, mmap
+from pathlib import Path
 from typing import Dict
-from mmap import mmap, ACCESS_READ
 
+import requests
+
 from .log import logger
 
@@ -118,8 +119,8 @@ def download_dns_yaml(url: str, folder: str):
 
 
 def download_all_assets(tmpdir: str, version="0.2.5"):
-    import subprocess
     import platform
+    import subprocess
 
     archs = {
         "aarch64": "arm64",
4 changes: 2 additions & 2 deletions modules/ChatTTS/ChatTTS/utils/io.py
@@ -1,7 +1,7 @@
-import os
 import logging
-from typing import Union
+import os
 from dataclasses import is_dataclass
+from typing import Union
 
 from .log import logger
 
3 changes: 2 additions & 1 deletion modules/ChatTTSInfer.py
@@ -2,9 +2,10 @@
 
 import numpy as np
 import torch
+
+from modules import config
 from modules.ChatTTS.ChatTTS.core import Chat
 from modules.utils.monkey_tqdm import disable_tqdm
-from modules import config
 
 
 # 主要解决类型问题
2 changes: 1 addition & 1 deletion modules/Enhancer/ResembleEnhance.py
@@ -7,13 +7,13 @@
 import numpy as np
 import torch
 
+from modules import config
 from modules.devices import devices
 from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
 from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
 from modules.repos_static.resemble_enhance.inference import inference
 from modules.utils.constants import MODELS_DIR
 from modules.utils.monkey_tqdm import disable_tqdm
-from modules import config
 
 logger = logging.getLogger(__name__)
 
6 changes: 2 additions & 4 deletions modules/api/impl/handler/AudioHandler.py
@@ -1,16 +1,14 @@
 import base64
 import io
-from typing import Generator
 import wave
+from typing import Generator
 
 import numpy as np
 import soundfile as sf
+from pydub import AudioSegment
 
 from modules.api import utils as api_utils
 from modules.api.impl.model.audio_model import AudioFormat
-
-from pydub import AudioSegment
-
 from modules.utils.audio import ndarray_to_segment
 
 
1 change: 1 addition & 0 deletions modules/utils/monkey_tqdm.py
@@ -73,6 +73,7 @@ def exit_tqdm(self, exc_type, exc_value, traceback):
 
 if __name__ == "__main__":
     import time
+
     from tqdm import tqdm
 
     with disable_tqdm():

0 comments on commit f5f483a
