Search
Duplicate
🎳

FACT model

최종 업데이트
9/28/2021, 9:07:00 AM
Tags
Empty
일시
Empty
작성자
소준섭
포스팅 종류
Empty
✅ main

구현 코드

""" sinusoid position encoding """ def get_sinusoid_encoding_table(n_seq, d_hidn): def cal_angle(position, i_hidn): return position / np.power(10000, 2 * (i_hidn // 2) / d_hidn) def get_posi_angle_vec(position): return [cal_angle(position, i_hidn) for i_hidn in range(d_hidn)] sinusoid_table = np.array([get_posi_angle_vec(i_seq) for i_seq in range(n_seq)]) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # even index sin sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # odd index cos return sinusoid_table
Python
""" encoder """ class EncoderMel(nn.Module): def __init__(self, config): super().__init__() self.config = config # self.enc_emb = nn.Embedding(self.config.n_enc_vocab, self.config.d_hidn) # mel inputs self.enc_emb = nn.Linear(128, 128) sinusoid_table = torch.FloatTensor(get_sinusoid_encoding_table(self.config.n_enc_seq + 1, self.config.d_hidn)) self.pos_emb = nn.Embedding.from_pretrained(sinusoid_table, freeze=True) self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config.n_layer)]) def forward(self, inputs): positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).expand(inputs.size(0), inputs.size(1)).contiguous() + 1 # print(positions.shape) # pos_mask = inputs.eq(self.config.i_pad) # print(pos_mask.shape) # positions.masked_fill_(pos_mask, 0) # (bs, n_enc_seq, d_hidn) outputs = self.enc_emb(inputs) + self.pos_emb(positions.long()) print(outputs.shape) # (bs, n_enc_seq, n_enc_seq) # attn_mask = get_attn_pad_mask(inputs, inputs, self.config.i_pad) attn_probs = [] for layer in self.layers: # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq) outputs, attn_prob = layer(outputs) attn_probs.append(attn_prob) # (bs, n_enc_seq, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)] return outputs, attn_probs
Python
Input data embedding은 Linear로 처리
pos_emb 파트 수정이 필요
input data가 이미 embed 형태로 들어오기때문에 기존 Transformer 인풋의 pos embedding과 달라져야한다.

모델학습

Dataset ⇒ keypoints, mel 정상적으로 구현
keypoints 정답 데이터를 다시 설정해야됩니다.
성준님께서 수정해주신 모델로 한번 학습돌려보겠습니다.