Tacotron 무지성 구현 - 6/N (2021. 8. 2. 13:24)
<Work in progress>
In the previous post we finished putting together a simple Tacotron model.
The goal of this post is to implement the training-related code
and make it manageable as a single class.
I'll upload only the essential code first and explain parts of it,
then add the logging-related code in the next post.
Hyper Parameters
import os, torch

class Hparams():
    # speaker name
    speaker = 'KSS'

    # Audio Pre-processing
    origin_sample_rate = 44100
    sample_rate = 22050
    n_fft = 1024
    hop_length = 256
    win_length = 1024
    n_mels = 80
    reduction = 5
    n_specs = n_fft // 2 + 1
    fmin = 0
    fmax = sample_rate // 2
    min_level_db = -80
    ref_level_db = 0

    # Text Pre-processing
    PAD = '_'
    EOS = '~'
    SPACE = ' '
    SPECIAL = '.,!?'
    JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
    JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
    JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])
    symbols = PAD + EOS + JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE + SPECIAL
    _symbol_to_id = {s: i for i, s in enumerate(symbols)}
    _id_to_symbol = {i: s for i, s in enumerate(symbols)}

    # Pre-processing paths (text, mel, spec)
    data_dir = os.path.join('data')
    out_texts_dir = os.path.join(data_dir, 'texts')
    out_mels_dir = os.path.join(data_dir, 'mels')
    out_specs_dir = os.path.join(data_dir, 'specs')

    # Embedding Layer
    in_dim = 256

    # Encoder Pre-net Layer
    prenet_dropout_ratio = 0.5
    prenet_linear_size = 256

    # CBHG
    cbhg_K = 16
    cbhg_mp_k = 2
    cbhg_mp_s = 1
    cbhg_mp_p = 1
    cbhg_mp_d = 2
    cbhg_conv_proj_size = 128
    cbhg_conv_proj_k = 3
    cbhg_conv_proj_p = 1
    cbhg_gru_hidden_dim = 128

    # Decoder
    decoder_rnn_dim = 256

    # Train
    batch_size = 32
    split_ratio = 0.2
    init_learning_rate = 0.001
    beta1 = 0.9
    beta2 = 0.999
    seed = 777
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    saved_model_dir = os.path.join(data_dir, 'saved_models', speaker)
    griffin_lim_iters = 32
What changed slightly is that torch is now imported,
and a few arguments were added under the # Train comment at the bottom:
batch_size of 32, the value used in the paper
split_ratio, for separating the data into train and validation sets
learning_rate and betas, which feed into the optimizer settings
seed, to keep the training procedure reproducible across runs
device, for training on the GPU
saved_model_dir, for saving model parameters during (or at the end of) training
More arguments will be added later, but this post only needs the ones listed above,
so I won't cover anything else here; a short sketch of how seed and device
are typically applied follows below.
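Since seed and device only matter once they are actually applied, here is a minimal sketch (my own illustration with a made-up helper name, not code from this series) of how these two hyperparameters are typically used in PyTorch:

import random
import numpy as np
import torch

def apply_seed_and_device(hp):
    # fix every RNG in play so repeated runs follow the same training path
    random.seed(hp.seed)
    np.random.seed(hp.seed)
    torch.manual_seed(hp.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(hp.seed)
    # the model and every batch are later moved to this device with .to(hp.device)
    return hp.device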
Oh, and until now I had set the Embedding Layer's in_dim argument to 512,
but as mentioned in an earlier post,
I'll set in_dim to 256 for the actual training.
Dataset
import os, glob, torch

class KSSDatasetPath():
    def __init__(self, Hparams):
        self.Hparams = Hparams
        # Original data
        self.text_paths = glob.glob(os.path.join(self.Hparams.out_texts_dir, '*.pt'))
        self.mel_paths = glob.glob(os.path.join(self.Hparams.out_mels_dir, '*.pt'))
        self.spec_paths = glob.glob(os.path.join(self.Hparams.out_specs_dir, '*.pt'))
        self.original_len = len(self.text_paths)
        # Split data lengths
        self.val_len = int(self.original_len * self.Hparams.split_ratio)
        if self.val_len % self.Hparams.batch_size != 0:
            extra_num = self.Hparams.batch_size - self.val_len % self.Hparams.batch_size
            self.val_len = self.val_len + extra_num
        self.trn_len = self.original_len - self.val_len
        # Train data, Validation data
        self.trn_text_paths, self.val_text_paths = (self.text_paths[:self.trn_len], self.text_paths[self.trn_len:])
        self.trn_mel_paths, self.val_mel_paths = (self.mel_paths[:self.trn_len], self.mel_paths[self.trn_len:])
        self.trn_spec_paths, self.val_spec_paths = (self.spec_paths[:self.trn_len], self.spec_paths[self.trn_len:])

class KSSTrainDataset(torch.utils.data.Dataset, KSSDatasetPath):
    def __init__(self):
        KSSDatasetPath.__init__(self, Hparams)

    def __len__(self):
        return self.trn_len

    def __getitem__(self, idx):
        # Train data
        self.trn_texts = torch.LongTensor(torch.load(self.trn_text_paths[idx]))
        self.trn_mels = torch.FloatTensor(torch.load(self.trn_mel_paths[idx]))
        self.trn_specs = torch.FloatTensor(torch.load(self.trn_spec_paths[idx]))
        return (self.trn_texts, self.trn_mels, self.trn_specs)

class KSSValidateDataset(torch.utils.data.Dataset, KSSDatasetPath):
    def __init__(self):
        KSSDatasetPath.__init__(self, Hparams)

    def __len__(self):
        return self.val_len

    def __getitem__(self, idx):
        # Validation data
        self.val_texts = torch.LongTensor(torch.load(self.val_text_paths[idx]))
        self.val_mels = torch.FloatTensor(torch.load(self.val_mel_paths[idx]))
        self.val_specs = torch.FloatTensor(torch.load(self.val_spec_paths[idx]))
        return (self.val_texts, self.val_mels, self.val_specs)

def collate_fn(batch):
    texts, mels, specs = zip(*batch)
    text_pads = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True)
    mel_pads = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)
    spec_pads = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True)
    mel_lengths = torch.LongTensor([mel.size(0) for mel in mels])
    return (text_pads.contiguous(), mel_pads.contiguous(),
            spec_pads.contiguous(), mel_lengths.contiguous())
trn_dataset = KSSTrainDataset()
val_dataset = KSSValidateDataset()
trn_dataloader = torch.utils.data.DataLoader(trn_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

for step, (trn_values, val_values) in enumerate(zip(trn_dataloader, val_dataloader)):
    trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
    val_texts, val_mels, val_specs, val_mel_lengths = val_values
    print("Trn Texts: {} \nTrn Mels: {} \nTrn Specs: {} \nTrn Mel-Lengths: {} \n".format(
        trn_texts.size(), trn_mels.size(), trn_specs.size(), trn_mel_lengths))
    print("Val Texts: {} \nVal Mels: {} \nVal Specs: {} \nVal Mel-Lengths: {} \n".format(
        val_texts.size(), val_mels.size(), val_specs.size(), val_mel_lengths))
    if step == 4:
        break

####
Trn Texts: torch.Size([4, 43])
Trn Mels: torch.Size([4, 365, 80])
Trn Specs: torch.Size([4, 365, 513])
Trn Mel-Lengths: tensor([255, 365, 350, 335])

Val Texts: torch.Size([4, 42])
Val Mels: torch.Size([4, 280, 80])
Val Specs: torch.Size([4, 280, 513])
Val Mel-Lengths: tensor([170, 230, 280, 275])

Trn Texts: torch.Size([4, 52])
Trn Mels: torch.Size([4, 415, 80])
Trn Specs: torch.Size([4, 415, 513])
Trn Mel-Lengths: tensor([180, 330, 415, 375])

Val Texts: torch.Size([4, 34])
Val Mels: torch.Size([4, 215, 80])
Val Specs: torch.Size([4, 215, 513])
Val Mel-Lengths: tensor([145, 175, 175, 215])

Trn Texts: torch.Size([4, 38])
Trn Mels: torch.Size([4, 250, 80])
Trn Specs: torch.Size([4, 250, 513])
Trn Mel-Lengths: tensor([155, 215, 250, 220])

Val Texts: torch.Size([4, 54])
Val Mels: torch.Size([4, 315, 80])
Val Specs: torch.Size([4, 315, 513])
Val Mel-Lengths: tensor([255, 315, 150, 195])

Trn Texts: torch.Size([4, 53])
Trn Mels: torch.Size([4, 335, 80])
Trn Specs: torch.Size([4, 335, 513])
Trn Mel-Lengths: tensor([330, 215, 295, 335])

Val Texts: torch.Size([4, 58])
Val Mels: torch.Size([4, 315, 80])
Val Specs: torch.Size([4, 315, 513])
Val Mel-Lengths: tensor([315, 155, 225, 210])

Trn Texts: torch.Size([4, 48])
Trn Mels: torch.Size([4, 370, 80])
Trn Specs: torch.Size([4, 370, 513])
Trn Mel-Lengths: tensor([300, 370, 205, 180])

Val Texts: torch.Size([4, 68])
Val Mels: torch.Size([4, 410, 80])
Val Specs: torch.Size([4, 410, 513])
Val Mel-Lengths: tensor([200, 370, 410, 115])
Introducing the split_ratio value also changed the Dataset code.
I don't come from a CS background, so class inheritance was an unfamiliar concept to me.
It isn't difficult though, so I'll go over it briefly and move on.
The first class here, KSSDatasetPath(), holds only the part of the previous KSSDataset's __init__
that assigns the paths of the train and validation data.
The two classes that follow, KSSTrainDataset() and KSSValidateDataset(),
inherit from torch.utils.data.Dataset and, at the same time, from KSSDatasetPath.
When inheriting from KSSDatasetPath, you have to call KSSDatasetPath.__init__(self, Hparams)
inside the __init__ of KSSTrainDataset() and KSSValidateDataset();
only then can those classes use the variables defined in KSSDatasetPath()'s __init__.
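If that pattern is unfamiliar, here is a toy sketch (PathHolder and ToyDataset are made-up names for illustration, not part of this series) showing why the explicit parent __init__ call is needed when a Dataset also inherits from a plain path-holding class:

import torch

class PathHolder():
    def __init__(self, root):
        # these attributes only exist after this __init__ actually runs
        self.paths = [root + '/%04d.pt' % i for i in range(10)]

class ToyDataset(torch.utils.data.Dataset, PathHolder):
    def __init__(self, root):
        PathHolder.__init__(self, root)  # omit this line and self.paths is never created

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        return self.paths[idx]

ds = ToyDataset('data')
print(len(ds), ds[0])  # 10 data/0000.pt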
KSSDatasetPath() also contains the part that decides the length of the validation data.
You could simply multiply the length of the original data by split_ratio
and use that many samples for validation,
but that leaves several small, tedious details such as handling fractional results,
so instead I rounded the validation length up to a multiple of batch_size
and set the train length to the difference between the original length and the validation length.
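As a quick sanity check of that rounding (using 12,853 clips as an example size, roughly what KSS provides, with split_ratio=0.2 and batch_size=32):

original_len = 12853                               # example dataset size
split_ratio, batch_size = 0.2, 32

val_len = int(original_len * split_ratio)          # 2570
if val_len % batch_size != 0:                      # 2570 % 32 == 10, so round up
    val_len += batch_size - val_len % batch_size   # 2570 + 22 = 2592 (= 81 * 32)
trn_len = original_len - val_len                   # 12853 - 2592 = 10261

print(val_len, trn_len)                            # 2592 10261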
Trainer
<Warning: a long wall of code ahead>
import os, glob, time, shutil, logging
import numpy as np
import torch, torchaudio, librosa
import matplotlib.pyplot as plt

class Trainer():
    def __init__(self, Hparams):
        self.Hparams = Hparams

    def train(self, target_step=None, target_loss=None, batch_size=None, train_continue=False):
        if (train_continue == False):
            # when starting fresh, force-delete any existing saved_model_dir
            if os.path.isdir(self.Hparams.saved_model_dir):
                shutil.rmtree(self.Hparams.saved_model_dir)
            os.makedirs(self.Hparams.saved_model_dir, exist_ok=True)  # create the directory

        # Logger
        logger = self.logging_fn()

        # Fit target step
        if (target_step == None):
            target_step = 5*(10**6)  # 5M steps
        # Fit target loss
        if (target_loss == None):
            target_loss = 0.2
        # Fit batch size
        if (batch_size == None):
            batch_size = self.Hparams.batch_size

        # Original Condition: train/validation split is used only when split_ratio is set and positive
        basic_condition = (self.Hparams.split_ratio == None or self.Hparams.split_ratio == 0)
        splited_condition = (self.Hparams.split_ratio != None and self.Hparams.split_ratio > 0)

        # Dataset
        trn_dataset = KSSTrainDataset()
        val_dataset = KSSValidateDataset()

        # Dataloader
        trn_dataloader = torch.utils.data.DataLoader(dataset=trn_dataset, batch_size=batch_size,
                                                     shuffle=True, collate_fn=collate_fn)
        val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size,
                                                     shuffle=True, collate_fn=collate_fn)

        # Setting model, loss_fn, optimizer, scheduler
        model = Tacotron(self.Hparams, teacher_forcing=True).to(self.Hparams.device)
        criterion = torch.nn.L1Loss().to(self.Hparams.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.Hparams.init_learning_rate,
                                     betas=[self.Hparams.beta1, self.Hparams.beta2])
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[5*(10**5), 1*(10**6), 2*(10**6)], gamma=0.5)

        # Load checkpoint
        if (train_continue == True):
            model, optimizer, init_global_step, trn_loss = self.load_fn(model, optimizer, logger)
        elif (train_continue == False):
            init_global_step = 0

        # Start Train
        global_step = init_global_step
        if (basic_condition):
            logger.info("Basic Condition")
            while True:  # keep iterating over the dataloader until a stop condition is hit
                for trn_values in trn_dataloader:
                    start_time = time.time()
                    # data
                    trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
                    trn_texts = trn_texts.to(self.Hparams.device)
                    trn_mels = trn_mels.to(self.Hparams.device)
                    trn_specs = trn_specs.to(self.Hparams.device)
                    trn_mel_lengths = trn_mel_lengths.to(self.Hparams.device)
                    # train
                    trn_pred_mels, trn_pred_specs, trn_aligns, trn_loss = self.train_step(
                        trn_texts, trn_mels, trn_specs, trn_mel_lengths, model, criterion, optimizer, scheduler)
                    # update current step
                    global_step += 1
                    # logs
                    running_time = time.time() - start_time
                    logger.info("Time(sec): {:.2f} - Step: {} - Train Loss: {:.6f}".format(
                        running_time, global_step, trn_loss))
                    val_loss = None
                    # save model state
                    if (global_step > 0 and global_step % 2000 == 0):
                        self.save_fn(global_step, model, optimizer, trn_loss, logger)
                    # early stop training (escape from for loop)
                    if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                        break
                # stop training (escape from while loop)
                if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                    self.save_fn(global_step, model, optimizer, trn_loss, logger)
                    break

        elif (splited_condition):
            logger.info("Splited Condition")
            while True:
                for trn_batch_idx, trn_values in enumerate(trn_dataloader):
                    start_time = time.time()
                    # data
                    trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
                    trn_texts = trn_texts.to(self.Hparams.device)
                    trn_mels = trn_mels.to(self.Hparams.device)
                    trn_specs = trn_specs.to(self.Hparams.device)
                    trn_mel_lengths = trn_mel_lengths.to(self.Hparams.device)
                    # train
                    trn_pred_mels, trn_pred_specs, trn_aligns, trn_loss = self.train_step(
                        trn_texts, trn_mels, trn_specs, trn_mel_lengths, model, criterion, optimizer, scheduler)
                    # update current step
                    global_step += 1
                    # train logs
                    logger.info("Time(sec): {:.2f} - Step: {} - Train Loss: {:.6f}".format(
                        (time.time() - start_time), global_step, trn_loss))
                    # save model state
                    if (global_step % 500 == 0):
                        self.save_fn(global_step, model, optimizer, trn_loss, logger)
                    # early stop training (escape from for loop)
                    if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                        break

                # run validation after each pass over the train data
                val_losses = 0.0
                for val_batch_idx, val_values in enumerate(val_dataloader):
                    start_time = time.time()
                    # data
                    val_texts, val_mels, val_specs, val_mel_lengths = val_values
                    val_texts = val_texts.to(self.Hparams.device)
                    val_mels = val_mels.to(self.Hparams.device)
                    val_specs = val_specs.to(self.Hparams.device)
                    val_mel_lengths = val_mel_lengths.to(self.Hparams.device)
                    # validate
                    val_pred_mels, val_pred_specs, val_aligns, val_loss = self.validate_step(
                        val_texts, val_mels, val_specs, val_mel_lengths, model, criterion)
                    val_losses += val_loss
                    running_time = time.time() - start_time
                # validation logs (average loss over the validation batches)
                logger.info("Time(sec): {:.2f} - Step: {} - Validation Loss: {:.6f}".format(
                    running_time, global_step, val_losses/(val_batch_idx+1)))
                self.save_alignment(global_step, val_aligns, logger)
                self.save_audio(global_step, val_pred_specs, logger)
                # add save_spectrogram function

                # stop training (escape from while loop)
                if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                    logger.info("Train has been finished at {} step".format(global_step))
                    self.save_fn(global_step, model, optimizer, trn_loss, logger)
                    self.save_alignment(global_step, val_aligns, logger)
                    self.save_audio(global_step, val_pred_specs, logger)
                    # add save_spectrogram function
                    break

    def train_step(self, texts, true_mels, true_specs, mel_lengths, model, criterion, optimizer, scheduler):
        model.train()
        optimizer.zero_grad()
        pred_mels, pred_specs, aligns = model(texts, mel_lengths, true_mels)
        mel_loss = self.mask_l1_loss_fn(pred_mels, true_mels, mel_lengths, criterion)
        spec_loss = self.mask_l1_loss_fn(pred_specs, true_specs, mel_lengths, criterion)
        loss = 0.5*mel_loss + 0.5*spec_loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        return pred_mels, pred_specs, aligns, loss

    def validate_step(self, texts, true_mels, true_specs, mel_lengths, model, criterion):
        model.eval()
        with torch.no_grad():
            pred_mels, pred_specs, aligns = model(texts, mel_lengths, true_mels)
            mel_loss = self.mask_l1_loss_fn(pred_mels, true_mels, mel_lengths, criterion)
            spec_loss = self.mask_l1_loss_fn(pred_specs, true_specs, mel_lengths, criterion)
            loss = 0.5*mel_loss + 0.5*spec_loss
        return pred_mels, pred_specs, aligns, loss

    def mask_l1_loss_fn(self, pred_batch, true_batch, lengths, loss_fn):
        # zero out the padded frames so padding does not contribute to the L1 loss
        true_batch, mask = self.get_mask_from_batch(true_batch, lengths)
        masked_pred_batch = pred_batch * mask
        loss_with_mask = loss_fn(masked_pred_batch, true_batch)
        return loss_with_mask

    def get_mask_from_batch(self, batch, lengths):
        # if batch shape is like sequence (batch, time, channels)
        if len(batch.size()) == 3 and batch.size(1) == max(lengths):
            mask = batch.data.new(batch.permute(0,2,1).size()).fill_(1)
            for batch_idx, true_len in enumerate(lengths):
                mask[batch_idx,:,true_len:] = 0
            mask = mask.permute(0,2,1)
        # if batch shape is like spectrogram (batch, time)
        elif len(batch.size()) == 2:
            mask = batch.data.new(batch.size()).fill_(1)
            for e_id, true_len in enumerate(lengths):
                mask[e_id,true_len:] = 0
        return batch, mask

    def save_fn(self, global_step, model, optimizer, trn_loss, logger):
        model_states_dir = os.path.join(self.Hparams.saved_model_dir, 'model_state')
        os.makedirs(model_states_dir, exist_ok=True)
        torch.save({
            'step': global_step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': trn_loss,
        }, os.path.join(model_states_dir, self.Hparams.speaker+'_%08d.pth'%(global_step)))
        logger.info('Model Saved at {} step'.format(global_step))

    def load_fn(self, model, optimizer, logger):
        # sort the checkpoint paths so the last (highest-step) checkpoint is picked
        model_state_paths = sorted(glob.glob(
            os.path.join(self.Hparams.saved_model_dir, 'model_state', '*.pth')))
        model_last_state_path = model_state_paths[-1]
        checkpoint = torch.load(model_last_state_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        global_step = checkpoint['step']
        trn_loss = checkpoint['train_loss']
        logger.info("Training will resume from the last checkpoint.")
        logger.info("Last Trained Model: {} \nLast Updated Optimizer: {} \nLast Global Step: {}\
            \nLast Train Loss: {} \n".format(model, optimizer, global_step, trn_loss))
        return model, optimizer, global_step, trn_loss

    def save_alignment(self, global_step, aligns, logger):
        aligns_dir = os.path.join(self.Hparams.saved_model_dir, 'aligns')
        align_path = os.path.join(aligns_dir, self.Hparams.speaker+'_%08d.png'%(global_step))
        os.makedirs(aligns_dir, exist_ok=True)
        align = aligns[-1].detach().cpu().numpy()
        plt.imshow(align, aspect='auto', origin='lower')
        plt.xlabel('Decoder timesteps')
        plt.ylabel('Encoder timesteps')
        plt.tight_layout()
        plt.savefig(align_path, format='png')
        logger.info("Alignment Graph saved at {} step".format(global_step))

    def save_audio(self, global_step, specs, logger):
        audios_dir = os.path.join(self.Hparams.saved_model_dir, 'audio')
        audio_path = os.path.join(audios_dir, self.Hparams.speaker+'_%08d.wav'%(global_step))
        os.makedirs(audios_dir, exist_ok=True)
        spec = specs[-1].detach().cpu().numpy()
        # undo the normalization, convert dB back to a power spectrogram, then run Griffin-Lim
        spec = (np.clip(spec, 0, 1)*(-self.Hparams.min_level_db))\
               -(-self.Hparams.min_level_db)+(self.Hparams.ref_level_db)
        spec = np.power(10.0, spec * 0.1)
        audio = librosa.core.spectrum.griffinlim(spec.T**2)
        audio = torch.FloatTensor(audio).unsqueeze(0)
        torchaudio.save(audio_path, audio, sample_rate=self.Hparams.sample_rate)
        logger.info("Predicted Audio saved at {} step".format(global_step))

    def logging_fn(self):
        logs_dir = os.path.join(self.Hparams.saved_model_dir)
        log_path = os.path.join(logs_dir, self.Hparams.speaker+'.log')
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        return logger
hparams = Hparams()  # instantiate the hyperparameter container
trainer = Trainer(hparams)
trainer.train(target_loss=0.5, target_step=10, batch_size=32)
2021-08-03 09:23:55,565 - root - INFO - Splited Condition
2021-08-03 09:23:57,589 - root - INFO - Time(sec): 1.97 - Step: 1 - Train Loss: 42.590023
2021-08-03 09:23:58,162 - root - INFO - Time(sec): 0.43 - Step: 2 - Train Loss: 40.273907
2021-08-03 09:23:58,704 - root - INFO - Time(sec): 0.39 - Step: 3 - Train Loss: 44.875076
2021-08-03 09:23:59,274 - root - INFO - Time(sec): 0.43 - Step: 4 - Train Loss: 35.676170
2021-08-03 09:23:59,741 - root - INFO - Time(sec): 0.39 - Step: 5 - Train Loss: 44.428562
2021-08-03 09:24:00,255 - root - INFO - Time(sec): 0.41 - Step: 6 - Train Loss: 44.728226
2021-08-03 09:24:00,914 - root - INFO - Time(sec): 0.55 - Step: 7 - Train Loss: 32.622391
2021-08-03 09:24:01,432 - root - INFO - Time(sec): 0.40 - Step: 8 - Train Loss: 49.320732
2021-08-03 09:24:02,043 - root - INFO - Time(sec): 0.46 - Step: 9 - Train Loss: 38.736153
2021-08-03 09:24:02,530 - root - INFO - Time(sec): 0.42 - Step: 10 - Train Loss: 44.525208