18 changes: 18 additions & 0 deletions model/loss.py
@@ -0,0 +1,18 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GMMLoss(nn.Module):
def __init__(self):
super(GMMLoss, self).__init__()

    def forward(self, x, mu, std, pi):
        # x: [B, M, T] target / mu, std, pi: [B, M, T, K] mixture parameters
        x = x.unsqueeze(-1)
        std = std + 1e-6  # keep scales strictly positive
        # per-component Gaussian density
        distrib = (1.0 / np.sqrt(2*np.pi)) * torch.exp(-0.5 * ((x - mu) / std) ** 2) / std
        distrib = torch.sum(pi * distrib, dim=3) + 1e-6  # mixture likelihood per bin
        loss = -1.0 * torch.log(distrib)  # negative log-likelihood

        return torch.mean(loss)
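
A minimal usage sketch for GMMLoss. The tensor shapes (B, M, T, K) follow the code above; the sizes and the model.loss import path are assumptions for illustration.

import torch

from model.loss import GMMLoss

B, M, T, K = 2, 80, 16, 8  # made-up sizes: batch, mel bins, frames, mixture components
criterion = GMMLoss()

x = torch.rand(B, M, T)                             # target mel spectrogram
mu = torch.rand(B, M, T, K)                         # mixture means
std = torch.rand(B, M, T, K).exp()                  # positive scales
pi = torch.softmax(torch.rand(B, M, T, K), dim=3)   # mixture weights

loss = criterion(x, mu, std, pi)
print(loss.item())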
77 changes: 77 additions & 0 deletions model/model.py
@@ -0,0 +1,77 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .tier import Tier
from .loss import GMMLoss
from utils.constant import f_div, t_div


class MelNet(nn.Module):
def __init__(self, hp):
super(MelNet, self).__init__()
self.hp = hp
self.f_div = f_div[hp.n_tiers+1]
self.t_div = t_div[hp.n_tiers+1]

        # tiers are 1-indexed; index 0 is an unused placeholder so tiers can be indexed from 1
        self.tiers = nn.ModuleList([None] +
                                   [Tier(hp=hp,
                                         freq=hp.n_mels // self.f_div * f_div[tier],
                                         layers=hp.layers[tier-1],
                                         tierN=tier)
                                    for tier in range(1, hp.n_tiers+1)])

def forward(self, x, tier_num, text=None):
assert tier_num > 0, 'tier_num should be larger than 0, got %d' % tier_num

mu, std, pi = self.tiers[tier_num](x, text)
return mu, std, pi



def interleave(self, tier_a, tier_b, axis):
        assert axis in ['time', 'freq'], "Axis should be either time or freq"

if axis=='time':
output = tier_a.new_zeros(tier_a.size(0), tier_a.size(1), 2*tier_a.size(2))
for i in range(tier_a.size(2)):
output[:,:,2*i] = tier_a[:,:,i]
output[:,:,2*i+1] = tier_b[:,:,i]

elif axis=='freq':
output = tier_a.new_zeros(tier_a.size(0), 2*tier_a.size(1), tier_a.size(2))
for i in range(tier_a.size(1)):
output[:,2*i,:] = tier_a[:,i,:]
                output[:,2*i+1,:] = tier_b[:,i,:]

return output



def sample(self, n_sec, text=None, device=None):
        # autoregressively sample the coarsest tier, then upsample through the remaining tiers
n_slice = n_sec * self.hp.sr
x = torch.zeros(1, self.hp.n_mels//self.f_div, n_slice//self.t_div)

if device is not None:
x = x.to(device)

        # the chain below assumes hp.n_tiers == 6, alternating freq/time interleaving
        tier1 = self.tiers[1].sample(x, text)
        tier2 = self.tiers[2].sample(tier1)
        tier1_2 = self.interleave(tier1, tier2, axis='freq')

        tier3 = self.tiers[3].sample(tier1_2)
        tier1_2_3 = self.interleave(tier1_2, tier3, axis='time')

        tier4 = self.tiers[4].sample(tier1_2_3)
        tier1_2_3_4 = self.interleave(tier1_2_3, tier4, axis='freq')

        tier5 = self.tiers[5].sample(tier1_2_3_4)
        tier1_2_3_4_5 = self.interleave(tier1_2_3_4, tier5, axis='time')

        tier6 = self.tiers[6].sample(tier1_2_3_4_5)

output = self.interleave(tier1_2_3_4_5, tier6, axis='freq')

return output
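
A standalone sketch of the interleave logic above (a self-contained copy so it runs without an hp config): even indices along the chosen axis come from tier_a, odd indices from tier_b.

import torch


def interleave(tier_a, tier_b, axis):
    # mirrors MelNet.interleave, written with strided assignment
    if axis == 'time':
        out = tier_a.new_zeros(tier_a.size(0), tier_a.size(1), 2 * tier_a.size(2))
        out[:, :, 0::2] = tier_a
        out[:, :, 1::2] = tier_b
    else:  # 'freq'
        out = tier_a.new_zeros(tier_a.size(0), 2 * tier_a.size(1), tier_a.size(2))
        out[:, 0::2, :] = tier_a
        out[:, 1::2, :] = tier_b
    return out


a = torch.zeros(1, 2, 3)  # coarse tier
b = torch.ones(1, 2, 3)   # newly generated tier
print(interleave(a, b, 'freq').shape)      # torch.Size([1, 4, 3])
print(interleave(a, b, 'freq')[0, :, 0])   # tensor([0., 1., 0., 1.])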
70 changes: 70 additions & 0 deletions model/rnn.py
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .stack_c import Stack_C
from utils.constant import f_div


class DelayedRNN(nn.Module):
def __init__(self, hp, tierN):
super(DelayedRNN, self).__init__()
self.TTS = hp.TTS
self.num_hidden = hp.hidden_dim
self.tierN = tierN

self.freq = hp.n_mels * f_div[tierN] // f_div[hp.n_tiers+1]

self.t_delay_RNN_x = nn.LSTM(input_size=self.num_hidden, hidden_size=self.num_hidden, batch_first=True)
self.t_delay_RNN_yz = nn.LSTM(input_size=self.num_hidden, hidden_size=self.num_hidden, batch_first=True, bidirectional=True)

# use central stack only at initial tier
if (tierN == 1):
self.c_RNN = Stack_C(hp)

self.f_delay_RNN = nn.LSTM(input_size=self.num_hidden, hidden_size=self.num_hidden, batch_first=True)

self.W_t = nn.Linear(3*self.num_hidden, self.num_hidden)
self.W_c = nn.Linear(self.num_hidden, self.num_hidden)
self.W_f = nn.Linear(self.num_hidden, self.num_hidden)

def forward(self, input_h_t, input_h_f, input_h_c=0.0, memory=None):
# input_h_t, input_h_f: [B, M, T, D] / input_h_c: [B, T, D]
B, M, T, D = input_h_t.size()

####### time-delayed stack #######
# Fig. 2(a)-1 can be parallelized by viewing each horizontal line as batch
h_t_x, _ = self.t_delay_RNN_x(input_h_t.view(-1, T, D))
h_t_x = h_t_x.view(B, M, T, D)

# Fig. 2(a)-2,3 can be parallelized by viewing each vertical line as batch,
# using bi-directional version of LSTM
temp = input_h_t.transpose(1, 2).contiguous() # [B, T, M, D]
temp = temp.view(-1, M, D)
h_t_yz, _ = self.t_delay_RNN_yz(temp)
h_t_yz = h_t_yz.view(B, T, M, 2*D)
h_t_yz = h_t_yz.transpose(1, 2)

h_t_concat = torch.cat((h_t_x, h_t_yz), dim=3)
output_h_t = input_h_t + self.W_t(h_t_concat) # residual connection, eq. (6)

####### centralized stack #######
output_h_c = 0.0
h_c_expanded = 0.0
        if self.tierN == 1:
            h_c_temp = self.c_RNN(input_h_c, memory)  # Stack_C returns the hidden sequence only
            output_h_c = input_h_c + self.W_c(h_c_temp) # residual connection, eq. (11)
            h_c_expanded = output_h_c.unsqueeze(1).repeat(1, self.freq, 1, 1)

####### frequency-delayed stack #######
h_f_sum = input_h_f + output_h_t + h_c_expanded
h_f_sum = h_f_sum.transpose(1, 2).contiguous() # [B, T, M, D]
h_f_sum = h_f_sum.view(-1, M, D)

h_f_temp, _ = self.f_delay_RNN(h_f_sum)
h_f_temp = h_f_temp.view(B, T, M, D)
h_f_temp = h_f_temp.transpose(1, 2) # [B, M, T, D]
output_h_f = input_h_f + self.W_f(h_f_temp) # residual connection, eq. (8)

return output_h_t, output_h_f, output_h_c
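
A shape-only sketch of the batching trick used in DelayedRNN (dimensions here are made up): a [B, M, T, D] hidden tensor is folded so each frequency row becomes a batch entry for the time-delayed LSTM, and each time column becomes a batch entry for the frequency-delayed LSTM.

import torch
import torch.nn as nn

B, M, T, D = 2, 4, 8, 16
h = torch.randn(B, M, T, D)

time_lstm = nn.LSTM(input_size=D, hidden_size=D, batch_first=True)
freq_lstm = nn.LSTM(input_size=D, hidden_size=D, batch_first=True)

# time-delayed stack: run along T, treating the M rows as extra batch entries
h_t, _ = time_lstm(h.reshape(B * M, T, D))
h_t = h_t.view(B, M, T, D)

# frequency-delayed stack: run along M, treating the T columns as extra batch entries
h_f, _ = freq_lstm(h.transpose(1, 2).reshape(B * T, M, D))
h_f = h_f.view(B, T, M, D).transpose(1, 2)

print(h_t.shape, h_f.shape)  # both torch.Size([2, 4, 8, 16])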
90 changes: 90 additions & 0 deletions model/stack_c.py
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention_module(nn.Module):
    def __init__(self, hp):
        super(Attention_module, self).__init__()
        # nn.LSTMCell has no batch_first argument; it consumes one step of shape [B, input_size].
        # The hidden size must match hp.hidden_dim used by DelayedRNN and Tier.
        self.rnn_cell = nn.LSTMCell(input_size=2*hp.hidden_dim, hidden_size=hp.hidden_dim)
        self.num_hidden = hp.hidden_dim
        self.M = hp.M

        self.W_g = nn.Linear(hp.hidden_dim, 3*hp.M)


    def attention(self, h_i, memory):
        # h_i: [B, D] decoder state / memory: [B, T_enc, D] encoded text
        phi_hat = self.W_g(h_i)
        self.ksi_hat = self.ksi_hat + torch.exp(phi_hat[:, :self.M])       # monotonic window locations
        self.beta_hat = torch.exp(phi_hat[:, self.M:2*self.M])             # window widths
        self.alpha_hat = F.softmax(phi_hat[:, 2*self.M:3*self.M], dim=-1)  # component weights

        u = torch.arange(memory.size(1), dtype=torch.float, device=memory.device)
        u_R = u + 0.5
        u_L = u - 0.5

        # mixture-of-logistics CDF at u +/- 0.5; 1/(1+exp((ksi-u)/beta)) == sigmoid((u-ksi)/beta)
        term1 = torch.sum(self.alpha_hat.unsqueeze(-1) *
                          torch.sigmoid((u_R - self.ksi_hat.unsqueeze(-1)) / self.beta_hat.unsqueeze(-1)), dim=1)

        term2 = torch.sum(self.alpha_hat.unsqueeze(-1) *
                          torch.sigmoid((u_L - self.ksi_hat.unsqueeze(-1)) / self.beta_hat.unsqueeze(-1)), dim=1)

        weights = (term1 - term2).unsqueeze(1)

        context = torch.bmm(weights, memory)

        termination = (1 - term1).unsqueeze(1)

        return context, weights, termination # (B, 1, D), (B, 1, T), (B, 1, T)



    def forward(self, input_h_c, memory, input_lengths=None):
        B, T, D = input_h_c.size()

        # attention state is re-initialized for every decoded sequence
        self.ksi_hat = input_h_c.new_zeros(B, self.M)

        context = input_h_c.new_zeros(B, D)
        h_i, c_i = input_h_c.new_zeros(B, D), input_h_c.new_zeros(B, D)

        contexts, weights, terminations = [], [], []
        for i in range(T):
            x = torch.cat([input_h_c[:, i], context], dim=-1)
            h_i, c_i = self.rnn_cell(x, (h_i, c_i))
            context, weight, termination = self.attention(h_i, memory)
            context = context.squeeze(1)  # [B, D], fed back as the next decoder input

            contexts.append(context.unsqueeze(1))
            weights.append(weight)
            terminations.append(termination)

        contexts = torch.cat(contexts, dim=1)          # [B, T, D]
        weights = torch.cat(weights, dim=1)            # [B, T, T_enc]
        terminations = torch.cat(terminations, dim=1)  # [B, T, T_enc]
        if input_lengths is not None:
            terminations = torch.gather(
                terminations, 2,
                input_lengths.view(-1, 1, 1).expand(-1, terminations.size(1), -1))

        return contexts, weights, terminations



class Stack_C(nn.Module):
    def __init__(self, hp):
        super(Stack_C, self).__init__()
        self.hp = hp
        self.TTS = hp.TTS

        if self.TTS:
            # TTS mode: the centralized stack attends over the encoded text memory
            self.c_RNN = Attention_module(hp)
        else:
            self.c_RNN = nn.LSTM(input_size=hp.hidden_dim, hidden_size=hp.hidden_dim, batch_first=True)

    def forward(self, x, memory=None):
        if memory is None:
            h_c_temp, _ = self.c_RNN(x)
        else:
            h_c_temp, _, _ = self.c_RNN(x, memory)  # contexts, alignment weights, termination

        return h_c_temp
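
A standalone sketch of the windowed attention computed in Attention_module.attention above (random values; the component count M_comp and memory length T_enc are arbitrary). Weights are a mixture of logistic windows, taken as the difference of the mixture CDF at u+0.5 and u-0.5.

import torch

B, M_comp, T_enc, D = 2, 4, 10, 16
ksi = torch.rand(B, M_comp) * T_enc                     # window locations
beta = torch.rand(B, M_comp) + 0.5                      # window widths
alpha = torch.softmax(torch.rand(B, M_comp), dim=-1)    # component weights
memory = torch.randn(B, T_enc, D)

u = torch.arange(T_enc, dtype=torch.float)


def mixture_cdf(u_pts):
    # sum over components of alpha_k * sigmoid((u - ksi_k) / beta_k)
    return torch.sum(
        alpha.unsqueeze(-1) * torch.sigmoid((u_pts - ksi.unsqueeze(-1)) / beta.unsqueeze(-1)),
        dim=1)


weights = (mixture_cdf(u + 0.5) - mixture_cdf(u - 0.5)).unsqueeze(1)  # [B, 1, T_enc]
context = torch.bmm(weights, memory)                                  # [B, 1, D]
print(weights.sum(dim=-1))  # close to 1 when the windows lie inside the memory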
96 changes: 96 additions & 0 deletions model/tier.py
@@ -0,0 +1,96 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .rnn import DelayedRNN


class Tier(nn.Module):
def __init__(self, hp, freq, layers, tierN):
super(Tier, self).__init__()
self.TTS = hp.TTS
self.tierN = tierN

self.W_t_0 = nn.Linear(1, hp.hidden_dim)
self.W_f_0 = nn.Linear(1, hp.hidden_dim)
self.W_c_0 = nn.Linear(freq, hp.hidden_dim)

self.layers = nn.ModuleList([
DelayedRNN(hp, tierN) for _ in range(layers)
])

# Gaussian Mixture Model: eq. (2)
self.K = hp.K
self.pi_softmax = nn.Softmax(dim=3)

# map output to produce GMM parameter eq. (10)
self.W_theta = nn.Linear(hp.hidden_dim, 3*self.K)

        if self.TTS:
            self.TextEncoder = nn.Sequential(
                nn.Embedding(hp.n_voca, hp.hidden_dim),
                nn.LSTM(input_size=hp.hidden_dim, hidden_size=hp.hidden_dim, batch_first=True))


def forward(self, x, text=None):
        if text is None:
            memory = None
        else:
            # nn.Sequential forwards the LSTM output tuple; keep the hidden sequence only
            memory, _ = self.TextEncoder(text)

if self.tierN == 1:
h_t = self.W_t_0(F.pad(x, [1, -1, 0, 0, 0, 0]).unsqueeze(-1))
h_f = self.W_f_0(F.pad(x, [0, 0, 1, -1, 0, 0]).unsqueeze(-1))
h_c = self.W_c_0(F.pad(x, [1, -1, 0, 0, 0, 0]).transpose(1, 2))
else:
h_t = self.W_t_0(x.unsqueeze(-1))
h_f = self.W_f_0(x.unsqueeze(-1))
h_c = self.W_c_0(x.transpose(1, 2))

# h_t, h_f: [B, M, T, D] / h_c: [B, T, D]
for layer in self.layers:
h_t, h_f, h_c = layer(h_t, h_f, h_c, memory)

theta_hat = self.W_theta(h_f)

mu = torch.sigmoid(theta_hat[:, :, :, :self.K]) # eq. (3)
std = torch.exp(theta_hat[:, :, :, self.K:2*self.K]) # eq. (4)
pi = self.pi_softmax(theta_hat[:, :, :, 2*self.K:]) # eq. (5)

return mu, std, pi




    def sample(self, x, text=None):
        # x: [1, M, T] / B=1, M=mel, T=time
        if self.tierN == 1:
            n_slice = x.size(-1)
            x_t, x_f = x.clone(), x.clone()

            # encode the text once so the centralized stack can attend over it in TTS mode
            memory = None
            if text is not None:
                memory, _ = self.TextEncoder(text)

            for i in range(n_slice):
                h_t = self.W_t_0(x_t.unsqueeze(-1))
                h_f = self.W_f_0(x_f.unsqueeze(-1))
                h_c = self.W_c_0(x_t.transpose(1, 2))

                for layer in self.layers:
                    h_t, h_f, h_c = layer(h_t, h_f, h_c, memory)

theta_hat = self.W_theta(h_f)

mu = torch.sigmoid(theta_hat[:, :, :, :self.K]) # eq. (3)
pi = self.pi_softmax(theta_hat[:, :, :, 2*self.K:]) # eq. (5)

mu = torch.sum(mu*pi, dim=3)

                # feed the predicted frame back in; guard the indices at the boundaries
                if i + 1 < x_t.size(2):
                    x_t[:, :, i+1] = mu[:, :, i]
                if i + 1 < x_f.size(1):
                    x_f[:, i+1, :] = mu[:, i, :]

            return mu

        else:
            # upper tiers condition on the lower-tier input; return the mixture mean directly
            mu, std, pi = self.forward(x)
            return torch.sum(mu * pi, dim=3)
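
An illustrative way to draw an actual sample from the mixture parameters returned by Tier.forward, rather than taking the mixture mean as sample() does above. Shapes follow the code; the sizes are placeholders.

import torch

B, M, T, K = 1, 80, 16, 8
mu = torch.sigmoid(torch.randn(B, M, T, K))
std = torch.exp(torch.randn(B, M, T, K).clamp(-3, 1))
pi = torch.softmax(torch.randn(B, M, T, K), dim=3)

mix = torch.distributions.Categorical(probs=pi)              # pick a component per bin
comp = torch.distributions.Normal(mu, std)                   # Gaussian per component
k = mix.sample().unsqueeze(-1)                               # [B, M, T, 1]
sample = comp.sample().gather(dim=3, index=k).squeeze(-1)    # [B, M, T]
print(sample.shape)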



