diff --git a/synthesizer/train.py b/synthesizer/train.py
index fc385ef..f327987 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -67,8 +67,17 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
 
     # Instantiate Tacotron Model
     print("\nInitialising Tacotron Model...\n")
+    num_chars = len(symbols)
+    if weights_fpath.exists():
+        # Backward compatibility: if the checkpoint was trained with a different
+        # symbol set, size the embedding to match the saved weights.
+        loaded_shape = torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape
+        if num_chars != loaded_shape[0]:
+            print("WARNING: you are using compatibility mode due to wrong symbols length, please modify variable _characters in `utils/symbols.py`")
+            num_chars = loaded_shape[0]
+
+
     model = Tacotron(embed_dims=hparams.tts_embed_dims,
-                     num_chars=len(symbols),
+                     num_chars=num_chars,
                      encoder_dims=hparams.tts_encoder_dims,
                      decoder_dims=hparams.tts_decoder_dims,
                      n_mels=hparams.num_mels,
diff --git a/synthesizer/utils/symbols.py b/synthesizer/utils/symbols.py
index d9c3967..2036dde 100644
--- a/synthesizer/utils/symbols.py
+++ b/synthesizer/utils/symbols.py
@@ -9,6 +9,8 @@ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA
 _pad = "_"
 _eos = "~"
 _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? '
+
+#_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz12340!\'(),-.:;? ' # use this old set if you want to train an old model
 
 # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
 #_arpabet = ["@' + s for s in cmudict.valid_symbols]