    Tacotron 무지성 구현 - 6/N
    Tacotron 1 | 2021. 8. 2. 13:24

    <Under revision>

    In the previous post, we finished putting together a simple Tacotron model.

     

    The goal of this post is to implement the training-related code and manage it all in a single class.

     

    I'll upload only the essential code first and explain it piece by piece; in the next post, I'll add the logging-related code as well.

     

     

     


    Hyperparameters

     

    import os, torch
    
    class Hparams():
        # speaker name
        speaker = 'KSS'
        
        # Audio Pre-processing
        origin_sample_rate = 44100
        sample_rate = 22050
        n_fft = 1024
        hop_length = 256
        win_length = 1024
        n_mels = 80
        reduction = 5
        n_specs = n_fft // 2 + 1
        fmin = 0
        fmax = sample_rate // 2
        min_level_db = -80
        ref_level_db = 0
        
        # Text Pre-processing
        PAD = '_'
        EOS = '~'
        SPACE = ' '
        SPECIAL = '.,!?'
        JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
        JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
        JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])
        symbols = PAD + EOS + JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE + SPECIAL
    
        _symbol_to_id = {s: i for i, s in enumerate(symbols)}
        _id_to_symbol = {i: s for i, s in enumerate(symbols)}
        
        # Pre-processing paths (text, mel, spec)
        data_dir = os.path.join('data')
        out_texts_dir = os.path.join(data_dir, 'texts')
        out_mels_dir = os.path.join(data_dir, 'mels')
        out_specs_dir = os.path.join(data_dir, 'specs')
        
        # Embedding Layer
        in_dim = 256
        
        # Encoder Pre-net Layer
        prenet_dropout_ratio = 0.5
        prenet_linear_size = 256
    
        # CBHG
        cbhg_K = 16
        cbhg_mp_k = 2
        cbhg_mp_s = 1
        cbhg_mp_p = 1
        cbhg_mp_d = 2
        cbhg_conv_proj_size = 128
        cbhg_conv_proj_k = 3
        cbhg_conv_proj_p = 1    
        cbhg_gru_hidden_dim = 128
        
        # Decoder
        decoder_rnn_dim = 256
        
        # Train
        batch_size = 32
        split_ratio = 0.2
        init_learning_rate = 0.001
        beta1 = 0.9
        beta2 = 0.999
        seed = 777
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        saved_model_dir = os.path.join(data_dir, 'saved_models', speaker)
        griffin_lim_iters = 32

    What's slightly different this time is that torch is now imported at the top, and several arguments were added under the # Train comment:

     

    The paper trains with a batch_size of 32.

     

    split_ratio separates the training data from the validation data.

     

    init_learning_rate and the betas go into the optimizer settings.

     

    seed keeps the training process reproducible.

     

    device is set for training on the GPU.

     

    saved_model_dir is where the model parameters are saved during training (or when it ends).

     

    More arguments will be added later on, but this post only needs the ones mentioned above, so I won't cover anything else here.

     

    Oh, and until now I had set the Embedding Layer's in_dim argument to 512, but as mentioned in an earlier post, I'll proceed with in_dim set to 256 during training.
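
    By the way, seed is declared in Hparams but never actually applied in the code shown in this post. Here is a minimal sketch of how it might be consumed at the start of training (the set_seed helper is my own illustration, not code from this series):

    import random
    import numpy as np
    import torch

    def set_seed(seed):
        # hypothetical helper: seed every RNG involved in training
        # so repeated runs behave the same (single-GPU assumption)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    set_seed(Hparams.seed)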

     

     

     


    Dataset

     

    import os, glob, torch
    
    class KSSDatasetPath():
        def __init__(self, Hparams):
            self.Hparams = Hparams
            
            # Original data (sorted so the text/mel/spec files pair up by index;
            # glob alone returns paths in arbitrary order)
            self.text_paths = sorted(glob.glob(os.path.join(self.Hparams.out_texts_dir, '*.pt')))
            self.mel_paths = sorted(glob.glob(os.path.join(self.Hparams.out_mels_dir, '*.pt')))
            self.spec_paths = sorted(glob.glob(os.path.join(self.Hparams.out_specs_dir, '*.pt')))
            
            self.original_len = len(self.text_paths)
            
            # Split data lengths (validation length rounded up to a multiple of batch_size)
            self.val_len = int(self.original_len * self.Hparams.split_ratio)
            if self.val_len % self.Hparams.batch_size != 0:
                extra_num = self.Hparams.batch_size - self.val_len % self.Hparams.batch_size
                self.val_len = self.val_len + extra_num
            self.trn_len = self.original_len - self.val_len
    
            # Train data, Validation data
            self.trn_text_paths, self.val_text_paths = (self.text_paths[:self.trn_len],
                                                        self.text_paths[self.trn_len:])
            self.trn_mel_paths, self.val_mel_paths = (self.mel_paths[:self.trn_len],
                                                      self.mel_paths[self.trn_len:])
            self.trn_spec_paths, self.val_spec_paths = (self.spec_paths[:self.trn_len],
                                                        self.spec_paths[self.trn_len:])
    
    
    class KSSTrainDataset(torch.utils.data.Dataset, KSSDatasetPath):
        def __init__(self):
            KSSDatasetPath.__init__(self, Hparams)
            
        def __len__(self):
            return self.trn_len
        
        def __getitem__(self, idx):
            # Train data (locals suffice; storing per-item tensors on self is unnecessary)
            trn_texts = torch.LongTensor(torch.load(self.trn_text_paths[idx]))
            trn_mels = torch.FloatTensor(torch.load(self.trn_mel_paths[idx]))
            trn_specs = torch.FloatTensor(torch.load(self.trn_spec_paths[idx]))
            return (trn_texts, trn_mels, trn_specs)
    
    
    class KSSValidateDataset(torch.utils.data.Dataset, KSSDatasetPath):
        def __init__(self):
            KSSDatasetPath.__init__(self, Hparams)
            
        def __len__(self):
            return self.val_len
        
        def __getitem__(self, idx):
            # Validation data
            val_texts = torch.LongTensor(torch.load(self.val_text_paths[idx]))
            val_mels = torch.FloatTensor(torch.load(self.val_mel_paths[idx]))
            val_specs = torch.FloatTensor(torch.load(self.val_spec_paths[idx]))
            return (val_texts, val_mels, val_specs)
        
    
    def collate_fn(batch):
        texts, mels, specs = zip(*batch)
        
        text_pads = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True)
        mel_pads = torch.nn.utils.rnn.pad_sequence(mels, batch_first=True)
        spec_pads = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True)
        mel_lengths = torch.LongTensor([mel.size(0) for mel in mels])
    
        return (text_pads.contiguous(), 
                mel_pads.contiguous(), 
                spec_pads.contiguous(), 
                mel_lengths.contiguous())
    
    
    trn_dataset = KSSTrainDataset()
    val_dataset = KSSValidateDataset()
    
    trn_dataloader = torch.utils.data.DataLoader(trn_dataset, 
                                                 batch_size=4, 
                                                 shuffle=True, 
                                                 collate_fn=collate_fn)
    
    val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                                                 batch_size=4, 
                                                 shuffle=True, 
                                                 collate_fn=collate_fn)
    
    for step, (trn_values, val_values) in enumerate(zip(trn_dataloader, val_dataloader)):
        trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
        val_texts, val_mels, val_specs, val_mel_lengths = val_values
        
        print("Trn Texts: {} \nTrn Mels: {} \nTrn Specs: {} \nTrn Mel-Lengths: {} \n".format(
            trn_texts.size(), trn_mels.size(), trn_specs.size(), trn_mel_lengths))
        
        print("Val Texts: {} \nVal Mels: {} \nVal Specs: {} \nVal Mel-Lengths: {} \n".format(
            val_texts.size(), val_mels.size(), val_specs.size(), val_mel_lengths))    
        
        if step == 4:
            break
            
    
    ####
    Trn Texts: torch.Size([4, 43]) 
    Trn Mels: torch.Size([4, 365, 80]) 
    Trn Specs: torch.Size([4, 365, 513]) 
    Trn Mel-Lengths: tensor([255, 365, 350, 335]) 
    
    Val Texts: torch.Size([4, 42]) 
    Val Mels: torch.Size([4, 280, 80]) 
    Val Specs: torch.Size([4, 280, 513]) 
    Val Mel-Lengths: tensor([170, 230, 280, 275]) 
    
    Trn Texts: torch.Size([4, 52]) 
    Trn Mels: torch.Size([4, 415, 80]) 
    Trn Specs: torch.Size([4, 415, 513]) 
    Trn Mel-Lengths: tensor([180, 330, 415, 375]) 
    
    Val Texts: torch.Size([4, 34]) 
    Val Mels: torch.Size([4, 215, 80]) 
    Val Specs: torch.Size([4, 215, 513]) 
    Val Mel-Lengths: tensor([145, 175, 175, 215]) 
    
    Trn Texts: torch.Size([4, 38]) 
    Trn Mels: torch.Size([4, 250, 80]) 
    Trn Specs: torch.Size([4, 250, 513]) 
    Trn Mel-Lengths: tensor([155, 215, 250, 220]) 
    
    Val Texts: torch.Size([4, 54]) 
    Val Mels: torch.Size([4, 315, 80]) 
    Val Specs: torch.Size([4, 315, 513]) 
    Val Mel-Lengths: tensor([255, 315, 150, 195]) 
    
    Trn Texts: torch.Size([4, 53]) 
    Trn Mels: torch.Size([4, 335, 80]) 
    Trn Specs: torch.Size([4, 335, 513]) 
    Trn Mel-Lengths: tensor([330, 215, 295, 335]) 
    
    Val Texts: torch.Size([4, 58]) 
    Val Mels: torch.Size([4, 315, 80]) 
    Val Specs: torch.Size([4, 315, 513]) 
    Val Mel-Lengths: tensor([315, 155, 225, 210]) 
    
    Trn Texts: torch.Size([4, 48]) 
    Trn Mels: torch.Size([4, 370, 80]) 
    Trn Specs: torch.Size([4, 370, 513]) 
    Trn Mel-Lengths: tensor([300, 370, 205, 180]) 
    
    Val Texts: torch.Size([4, 68]) 
    Val Mels: torch.Size([4, 410, 80]) 
    Val Specs: torch.Size([4, 410, 513]) 
    Val Mel-Lengths: tensor([200, 370, 410, 115])

     

    Introducing the split_ratio value brought some changes to the Dataset code as well.

     

    I'm not a computer science major, so the concept of class inheritance was unfamiliar to me. It isn't difficult, though, so I'll touch on it briefly.

     

    The first class you see, KSSDatasetPath(), holds only the part of the old KSSDataset's __init__ that sets the paths of the training and validation data.

     

    The following KSSTrainDataset() and KSSValidateDataset() inherit from torch.utils.data.Dataset and, at the same time, from KSSDatasetPath.

     

    When inheriting from KSSDatasetPath, you have to call KSSDatasetPath.__init__(self, Hparams) inside the __init__ of KSSTrainDataset() and KSSValidateDataset(); only then can those classes use the variables defined in KSSDatasetPath's __init__.
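
    To make that concrete, here is a tiny toy example (the names are made up purely for illustration) of why the explicit parent __init__ call matters:

    class Base:
        def __init__(self, value):
            self.value = value

    class Child(Base):
        def __init__(self):
            # without this explicit call, Base.__init__ never runs
            # and self.value would not exist on Child instances
            Base.__init__(self, 42)

    print(Child().value)  # 42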

     

    KSSDatasetPath() also has a part that decides the length of the validation data. Simply multiplying the length of the original data by split_ratio and using that as the validation length would probably be fine, but it comes with small annoyances like handling fractional values, so I instead rounded the validation length up to the nearest multiple of batch_size and set the training length to the difference between the original length and the validation length.
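
    With hypothetical numbers (my own illustration), the rounding works out like this:

    # assume original_len = 1000, split_ratio = 0.2, batch_size = 32
    val_len = int(1000 * 0.2)          # 200
    if val_len % 32 != 0:
        val_len += 32 - val_len % 32   # 200 -> 224, the next multiple of 32
    trn_len = 1000 - val_len           # 776

    As an aside, torch.utils.data.random_split is another common way to carve out a validation set, but it doesn't give this multiple-of-batch_size guarantee by itself.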

     

     

     


    Trainer

     

    <Warning: lots of scrolling ahead>

     

    import os, glob, shutil, time, logging
    import numpy as np
    import matplotlib.pyplot as plt
    import librosa, torch, torchaudio
    
    class Trainer():
        def __init__(self, Hparams):
            self.Hparams = Hparams
            
        
        def train(self, target_step=None, target_loss=None, batch_size=None, train_continue=False):
            
            if not train_continue: # when starting fresh, any existing saved_model_dir is removed
                if os.path.isdir(self.Hparams.saved_model_dir): # force-delete if it already exists
                    shutil.rmtree(self.Hparams.saved_model_dir)
                os.makedirs(self.Hparams.saved_model_dir, exist_ok=True) # create the directory
    
            # Logger
            logger = self.logging_fn()
                
            # Fill in defaults for any unspecified arguments
            if target_step is None:
                target_step = 5*(10**6) # 5 million steps
    
            if target_loss is None:
                target_loss = 0.2
    
            if batch_size is None:
                batch_size = self.Hparams.batch_size
                
            # Original (unsplit) condition vs. split condition
            basic_condition = (self.Hparams.split_ratio is None or self.Hparams.split_ratio == 0)
            splited_condition = (self.Hparams.split_ratio is not None and self.Hparams.split_ratio > 0)
                
            # Dataset
            trn_dataset = KSSTrainDataset()
            val_dataset = KSSValidateDataset()
    
            # Dataloader
            trn_dataloader = torch.utils.data.DataLoader(dataset=trn_dataset, 
                                                         batch_size=batch_size, 
                                                         shuffle=True, 
                                                         collate_fn=collate_fn)
    
            val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                                         batch_size=batch_size, 
                                                         shuffle=True, 
                                                         collate_fn=collate_fn)
    
            # Setting model, loss_fn, optimizer, scheduler
            model = Tacotron(self.Hparams, teacher_forcing=True).to(self.Hparams.device)
            
            criterion = torch.nn.L1Loss().to(self.Hparams.device)
            
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=self.Hparams.init_learning_rate,
                                         betas=[self.Hparams.beta1, self.Hparams.beta2])
            
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, 
                milestones=[5*(10**5), 1*(10**6), 2*(10**6)], # anneal LR at 500K, 1M, 2M steps
                gamma=0.5)
    
            # Load checkpoint
            if train_continue:
                model, optimizer, init_global_step, trn_loss = self.load_fn(model, optimizer, logger)
            else:
                init_global_step = 0
            
            # Start Train
            global_step = init_global_step
            if basic_condition:
                logger.info("Basic Condition")
                while True:
                    for trn_values in trn_dataloader:
                        start_time = time.time()
                        
                        # data
                        trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
                        trn_texts = trn_texts.to(self.Hparams.device)
                        trn_mels = trn_mels.to(self.Hparams.device)
                        trn_specs = trn_specs.to(self.Hparams.device)
                        trn_mel_lengths = trn_mel_lengths.to(self.Hparams.device)
                        
                        # train
                        trn_pred_mels, trn_pred_specs, trn_aligns, trn_loss = self.train_step(
                            trn_texts, trn_mels, trn_specs, trn_mel_lengths,
                            model, criterion, optimizer, scheduler)
                        
                        # update current step
                        global_step += 1
                        
                        # logs
                        running_time = time.time() - start_time
                        logger.info("Time(sec): {:.2f} - Step: {} - Train Loss: {:.6f}".format(
                            running_time, global_step, trn_loss))
    
                        # save model state
                        if global_step > 0 and global_step % 2000 == 0:
                            self.save_fn(global_step, model, optimizer, trn_loss, logger)
    
                        # early stop training (escape from for loop)
                        if trn_loss <= target_loss or global_step == (target_step+init_global_step):
                            self.save_fn(global_step, model, optimizer, trn_loss, logger)
                            break
                    
                    # stop training (escape from while loop)
                    if trn_loss <= target_loss or global_step == (target_step+init_global_step):
                        break
                        
            elif (splited_condition):
                logger.info("Splited Condition")
                while True:
                    for trn_batch_idx, trn_values in enumerate(trn_dataloader):
                        start_time = time.time()
    
                        # data
                        trn_texts, trn_mels, trn_specs, trn_mel_lengths = trn_values
                        trn_texts = trn_texts.to(self.Hparams.device)
                        trn_mels = trn_mels.to(self.Hparams.device)
                        trn_specs = trn_specs.to(self.Hparams.device)                
                        trn_mel_lengths = trn_mel_lengths.to(self.Hparams.device)
                        
                        # train
                        trn_pred_mels, trn_pred_specs, trn_aligns, trn_loss = self.train_step(
                            trn_texts, trn_mels, trn_specs, trn_mel_lengths,
                            model, criterion, optimizer, scheduler)
                        
                        # update current step
                        global_step += 1
                        
                        # train logs
                        logger.info("Time(sec): {:.2f} - Step: {} - Train Loss: {:.6f}".format(
                            (time.time() - start_time), global_step, trn_loss))
    
                        # save model state
                        if (global_step % 500 == 0):
                            self.save_fn(global_step, model, optimizer, trn_loss, logger)
                                                    
                        # early stop training (escape from for loop)
                        if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                            break
    
                    val_losses = 0.0
                    for val_batch_idx, val_values in enumerate(val_dataloader):
                        start_time = time.time()
                        val_texts, val_mels, val_specs, val_mel_lengths = val_values
                        
                        # data
                        val_texts = val_texts.to(self.Hparams.device)
                        val_mels = val_mels.to(self.Hparams.device)
                        val_specs = val_specs.to(self.Hparams.device)
                        val_mel_lengths = val_mel_lengths.to(self.Hparams.device)
                        
                        # validate
                        val_pred_mels, val_pred_specs, val_aligns, val_loss = self.validate_step(
                            val_texts, val_mels, val_specs, val_mel_lengths,
                            model, criterion)
    
                        val_losses += val_loss
                        running_time = time.time() - start_time
    
                    # logs
                    logger.info("Time(sec): {:.2f} - Step: {} - Validation Loss: {:.6f}".format(
                        running_time, global_step, val_losses/(val_batch_idx+1)))
                    self.save_alignment(global_step, val_aligns, logger)
                    self.save_audio(global_step, val_pred_specs, logger)
                    # add save_spectrogram function
    
                    # stop training (escape from while loop)
                    if (trn_loss <= target_loss or global_step == (target_step+init_global_step)):
                        logger.info("Train has been finished at {} step".format(global_step))
                        self.save_fn(global_step, model, optimizer, trn_loss, logger)
                        self.save_alignment(global_step, val_aligns, logger)
                        self.save_audio(global_step, val_pred_specs, logger)
                        # add save_spectrogram function
                        break
        
        
        def train_step(self, texts, true_mels, true_specs, mel_lengths, 
                        model, criterion, optimizer, scheduler):
            model.train()
            optimizer.zero_grad()
            pred_mels, pred_specs, aligns = model(texts, mel_lengths, true_mels)
            mel_loss = self.mask_l1_loss_fn(pred_mels, true_mels, mel_lengths, criterion)
            spec_loss = self.mask_l1_loss_fn(pred_specs, true_specs, mel_lengths, criterion)
            loss = 0.5*mel_loss + 0.5*spec_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            return pred_mels, pred_specs, aligns, loss
    
        
        def validate_step(self, texts, true_mels, true_specs, mel_lengths,
                           model, criterion):
            model.eval()
            with torch.no_grad():
                pred_mels, pred_specs, aligns = model(texts, mel_lengths, true_mels)
                mel_loss = self.mask_l1_loss_fn(pred_mels, true_mels, mel_lengths, criterion)
                spec_loss = self.mask_l1_loss_fn(pred_specs, true_specs, mel_lengths, criterion)
                loss = 0.5*mel_loss + 0.5*spec_loss
            return pred_mels, pred_specs, aligns, loss
    
    
        def mask_l1_loss_fn(self, pred_batch, true_batch, lengths, loss_fn):
            # zero out predicted frames beyond each sequence's true length, so
            # padding doesn't contribute to the L1 loss (padded targets are already zero)
            true_batch, mask = self.get_mask_from_batch(true_batch, lengths)
            masked_pred_batch = pred_batch * mask
            loss_with_mask = loss_fn(masked_pred_batch, true_batch)
            return loss_with_mask
    
    
        def get_mask_from_batch(self, batch, lengths):
            # if batch shape is like sequence
            if len(batch.size()) == 3 and batch.size(1) == max(lengths):
                mask = batch.data.new(batch.permute(0,2,1).size()).fill_(1)
                for batch_idx, true_len in enumerate(lengths):
                    mask[batch_idx,:,true_len:] = 0
                mask = mask.permute(0,2,1)
            # if batch shape is like spectrogram
            elif len(batch.size()) == 2:
                mask = batch.data.new(batch.size()).fill_(1)
                for e_id, true_len in enumerate(lengths):
                    mask[e_id,true_len:] = 0
            return batch, mask
    
        
        def save_fn(self, global_step, model, optimizer, trn_loss, logger):
            model_states_dir = os.path.join(self.Hparams.saved_model_dir, 'model_state')
            os.makedirs(model_states_dir, exist_ok=True)
            torch.save({
                'step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': trn_loss,
            }, os.path.join(model_states_dir, self.Hparams.speaker+'_%08d.pth'%(global_step)))
            logger.info('Model Saved at {} step'.format(global_step))
    
    
        def load_fn(self, model, optimizer, logger):
            # sorted() so that [-1] reliably picks the newest checkpoint
            model_state_paths = sorted(glob.glob(
                os.path.join(self.Hparams.saved_model_dir, 'model_state', '*.pth')))
            model_last_state_path = model_state_paths[-1]
            checkpoint = torch.load(model_last_state_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            global_step = checkpoint['step']
            trn_loss = checkpoint['train_loss']
            logger.info("Training will resume from the last checkpoint.")
            logger.info("Last Trained Model: {} \nLast Updated Optimizer: {} \nLast Global Step: {}\
                        \nLast Train Loss: {} \n".format(model, optimizer, global_step, trn_loss))
            return model, optimizer, global_step, trn_loss
    
        
        def save_alignment(self, global_step, aligns, logger):
            aligns_dir = os.path.join(self.Hparams.saved_model_dir, 'aligns')
            align_path = os.path.join(aligns_dir, self.Hparams.speaker+'_%08d.png'%(global_step))
            os.makedirs(aligns_dir, exist_ok=True)
            # plot the last alignment in the batch
            align = aligns[-1].detach().cpu().numpy()
            plt.imshow(align, aspect='auto', origin='lower')
            plt.xlabel('Decoder timesteps')
            plt.ylabel('Encoder timesteps')
            plt.tight_layout()
            plt.savefig(align_path, format='png')
            plt.close() # free the figure so repeated saves don't accumulate
            logger.info("Alignment Graph saved at {} step".format(global_step))
    
    
        def save_audio(self, global_step, specs, logger):
            audios_dir = os.path.join(self.Hparams.saved_model_dir, 'audio')
            audio_path = os.path.join(audios_dir, self.Hparams.speaker+'_%08d.wav'%(global_step))
            os.makedirs(audios_dir, exist_ok=True)
            spec = specs[-1].detach().cpu().numpy()
            # denormalize: map [0, 1] back to [min_level_db, 0] dB (shifted by ref_level_db)
            spec = (np.clip(spec, 0, 1)*(-self.Hparams.min_level_db))\
                    -(-self.Hparams.min_level_db)+(self.Hparams.ref_level_db)
            # dB -> power
            spec = np.power(10.0, spec * 0.1)
            # Griffin-Lim phase reconstruction (wired to the griffin_lim_iters hparam)
            audio = librosa.core.spectrum.griffinlim(spec.T**2,
                                                     n_iter=self.Hparams.griffin_lim_iters,
                                                     hop_length=self.Hparams.hop_length,
                                                     win_length=self.Hparams.win_length)
            audio = torch.FloatTensor(audio).unsqueeze(0)
            torchaudio.save(audio_path, audio, sample_rate=self.Hparams.sample_rate)
            logger.info("Predicted Audio saved at {} step".format(global_step))
    
    
        def logging_fn(self):
            logs_dir = os.path.join(self.Hparams.saved_model_dir)
            log_path = os.path.join(logs_dir, self.Hparams.speaker+'.log')
            logger = logging.getLogger()
            logger.setLevel(logging.INFO)
            # clear handlers left over from a previous run so log lines aren't duplicated
            logger.handlers.clear()
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            stream_handler = logging.StreamHandler()
            stream_handler.setFormatter(formatter)
            logger.addHandler(stream_handler)
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
            return logger
    
    trainer = Trainer(Hparams)
    trainer.train(target_loss=0.5, target_step=10, batch_size=32)
    
    ####
    2021-08-03 09:23:55,565 - root - INFO - Splited Condition
    2021-08-03 09:23:57,589 - root - INFO - Time(sec): 1.97 - Step: 1 - Train Loss: 42.590023
    2021-08-03 09:23:58,162 - root - INFO - Time(sec): 0.43 - Step: 2 - Train Loss: 40.273907
    2021-08-03 09:23:58,704 - root - INFO - Time(sec): 0.39 - Step: 3 - Train Loss: 44.875076
    2021-08-03 09:23:59,274 - root - INFO - Time(sec): 0.43 - Step: 4 - Train Loss: 35.676170
    2021-08-03 09:23:59,741 - root - INFO - Time(sec): 0.39 - Step: 5 - Train Loss: 44.428562
    2021-08-03 09:24:00,255 - root - INFO - Time(sec): 0.41 - Step: 6 - Train Loss: 44.728226
    2021-08-03 09:24:00,914 - root - INFO - Time(sec): 0.55 - Step: 7 - Train Loss: 32.622391
    2021-08-03 09:24:01,432 - root - INFO - Time(sec): 0.40 - Step: 8 - Train Loss: 49.320732
    2021-08-03 09:24:02,043 - root - INFO - Time(sec): 0.46 - Step: 9 - Train Loss: 38.736153
    2021-08-03 09:24:02,530 - root - INFO - Time(sec): 0.42 - Step: 10 - Train Loss: 44.525208
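
    As a usage note (my addition; the argument values are made up): since load_fn() restores the newest checkpoint when train_continue=True, an interrupted run can be resumed without wiping saved_model_dir:

    trainer = Trainer(Hparams)
    trainer.train(target_step=100000, train_continue=True) # resumes model/optimizer state and global_step from the latest .pth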

     
