|
|
|
@ -67,8 +67,17 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
|
|
|
|
|
|
|
|
|
|
# Instantiate Tacotron Model
|
|
|
|
|
print("\nInitialising Tacotron Model...\n")
|
|
|
|
|
num_chars = len(symbols)
|
|
|
|
|
if weights_fpath.exists():
|
|
|
|
|
# for compatibility purpose, change symbols accordingly:
|
|
|
|
|
loaded_shape = torch.load(str(weights_fpath), map_location=device)["model_state"]["encoder.embedding.weight"].shape
|
|
|
|
|
if num_chars != loaded_shape[0]:
|
|
|
|
|
print("WARNING: you are using compatible mode due to wrong sympols length, please modify varible _characters in `utils\symbols.py`")
|
|
|
|
|
num_chars != loaded_shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = Tacotron(embed_dims=hparams.tts_embed_dims,
|
|
|
|
|
num_chars=len(symbols),
|
|
|
|
|
num_chars=num_chars,
|
|
|
|
|
encoder_dims=hparams.tts_encoder_dims,
|
|
|
|
|
decoder_dims=hparams.tts_decoder_dims,
|
|
|
|
|
n_mels=hparams.num_mels,
|
|
|
|
|