import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.mol_attention import MOLAttention
from .utils.basic_layers import Linear
from .utils.vc_utils import get_mask_from_lengths


class DecoderPrenet(nn.Module):
    """Prenet applied to the previous mel frame before the attention RNN."""
    def __init__(self, in_dim, sizes):
        super().__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [Linear(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
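            # NOTE: dropout stays active even in eval mode (training=True),
            # matching the Tacotron 2 prenet behavior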
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Decoder(nn.Module):
    """Mixture of Logistic (MoL) attention-based RNN Decoder."""
    def __init__(
        self,
        enc_dim,
        num_mels,
        frames_per_step,
        attention_rnn_dim,
        decoder_rnn_dim,
        prenet_dims,
        num_mixtures,
        encoder_down_factor=1,
        num_decoder_rnn_layer=1,
        use_stop_tokens=False,
        concat_context_to_last=False,
    ):
        super().__init__()
        self.enc_dim = enc_dim
        self.encoder_down_factor = encoder_down_factor
        self.num_mels = num_mels
        self.frames_per_step = frames_per_step
        self.attention_rnn_dim = attention_rnn_dim
        self.decoder_rnn_dim = decoder_rnn_dim
        self.prenet_dims = prenet_dims
        self.use_stop_tokens = use_stop_tokens
        self.num_decoder_rnn_layer = num_decoder_rnn_layer
        self.concat_context_to_last = concat_context_to_last

        # Mel prenet
        self.prenet = DecoderPrenet(num_mels, prenet_dims)
        self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)  # NOTE: pitch path is currently unused

        # Attention RNN
        self.attention_rnn = nn.LSTMCell(
            prenet_dims[-1] + enc_dim,
            attention_rnn_dim
        )
        
        # Attention
        self.attention_layer = MOLAttention(
            attention_rnn_dim,
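            # r is presumably the expected attention step: output frames per
            # decoder step divided by the encoder's temporal downsampling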
            r=frames_per_step / encoder_down_factor,
            M=num_mixtures,
        )

        # Decoder RNN
        self.decoder_rnn_layers = nn.ModuleList()
        for i in range(num_decoder_rnn_layer):
            if i == 0:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        enc_dim + attention_rnn_dim,
                        decoder_rnn_dim))
            else:
                self.decoder_rnn_layers.append(
                    nn.LSTMCell(
                        decoder_rnn_dim,
                        decoder_rnn_dim))
        if concat_context_to_last:
            self.linear_projection = Linear(
                enc_dim + decoder_rnn_dim,
                num_mels * frames_per_step
            )
        else:
            self.linear_projection = Linear(
                decoder_rnn_dim,
                num_mels * frames_per_step
            )

        # Stop-token layer
        if self.use_stop_tokens:
            if concat_context_to_last:
                self.stop_layer = Linear(
                    enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )
            else:
                self.stop_layer = Linear(
                    decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
                )
                

    def get_go_frame(self, memory):
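        """Return an all-zeros <GO> frame to start autoregressive decoding."""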
        B = memory.size(0)
        go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
                               device=memory.device)
        return go_frame

    def initialize_decoder_states(self, memory, mask):
        device = next(self.parameters()).device
        B = memory.size(0)
        
        # attention rnn states
        self.attention_hidden = torch.zeros(
            (B, self.attention_rnn_dim), device=device)
        self.attention_cell = torch.zeros(
            (B, self.attention_rnn_dim), device=device)

        # decoder rnn states
        self.decoder_hiddens = []
        self.decoder_cells = []
        for i in range(self.num_decoder_rnn_layer):
            self.decoder_hiddens.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
            self.decoder_cells.append(
                torch.zeros((B, self.decoder_rnn_dim),
                            device=device)
            )
        self.attention_context = torch.zeros(
            (B, self.enc_dim), device=device)

        self.memory = memory
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """Prepare ground-truth mels as decoder inputs for teacher forcing.
        Args:
            decoder_inputs: (B, T_out, num_mels) mel targets.
        Returns:
            (T_out//r, B, num_mels) time-major inputs holding the last mel
            frame of each group of r = frames_per_step frames.
        """
        # (B, T_out, num_mels) -> (B, T_out//r, r*num_mels)
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            decoder_inputs.size(1) // self.frames_per_step, -1)
        # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        # keep only the last frame of each group: (T_out//r, B, num_mels)
        decoder_inputs = decoder_inputs[:, :, -self.num_mels:]
        return decoder_inputs
        
    def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
        """Prepare per-step decoder outputs for returning.
        Args:
            mel_outputs: list of (B, num_mels*r) tensors, one per decoder step.
            alignments: list of (B, T_enc) attention weights, one per step.
            stop_outputs: list of per-step stop logits, or None.
        """
        # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (T_out//r, B) -> (B, T_out//r)
        if stop_outputs is not None:
            if alignments.size(0) == 1:
                # batch size 1: the squeezed stop outputs are 0-dim, so
                # stacking gives (T_out//r,); restore the batch dimension
                stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
            else:
                stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
            stop_outputs = stop_outputs.contiguous()
        # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step: (B, T_out, num_mels)
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.num_mels)
        return mel_outputs, alignments, stop_outputs
    
    def attend(self, decoder_input):
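        """One step of the attention RNN followed by the MoL attention update."""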
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_context, attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, None, self.mask)
        
        decoder_rnn_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)

        return decoder_rnn_input, self.attention_context, attention_weights

    def decode(self, decoder_input):
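        """Advance the stacked decoder LSTM cells by one step."""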
        for i in range(self.num_decoder_rnn_layer):
            if i == 0:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
            else:
                self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
                    self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
        return self.decoder_hiddens[-1]
    
    def forward(self, memory, mel_inputs, memory_lengths):
        """Decoder forward pass for teacher-forced training.
        Args:
            memory: (B, T_enc, enc_dim) encoder outputs.
            mel_inputs: (B, T, num_mels) ground-truth mels for teacher forcing.
            memory_lengths: (B,) encoder output lengths for attention masking.
        Returns:
            mel_outputs: (B, T, num_mels) predicted mels.
            stop_outputs: (B, T//r) stop-token logits, only when
                use_stop_tokens is True.
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [1, B, num_mels]
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        # [T//r, B, num_mels]
        mel_inputs = self.parse_decoder_inputs(mel_inputs)
        # [T//r + 1, B, num_mels]
        mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
        # [T//r + 1, B, prenet_dim]
        decoder_inputs = self.prenet(mel_inputs)

        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths),
        )
        
        self.attention_layer.init_states(memory)

        mel_outputs, alignments = [], []
        if self.use_stop_tokens:
            stop_outputs = []
        else:
            stop_outputs = None
        while len(mel_outputs) < decoder_inputs.size(0) - 1:
            decoder_input = decoder_inputs[len(mel_outputs)]

            decoder_rnn_input, context, attention_weights = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_rnn_input)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            if self.use_stop_tokens:
                stop_output = self.stop_layer(decoder_rnn_output)
                stop_outputs += [stop_output.squeeze()]
            # mel_output is already (B, num_mels*r); the squeeze was a no-op
            mel_outputs += [mel_output]
            alignments += [attention_weights]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        if stop_outputs is None:
            return mel_outputs, alignments
        else:
            return mel_outputs, stop_outputs, alignments

    def inference(self, memory, stop_threshold=0.5):
        """Decoder inference for a single utterance.
        Args:
            memory: (1, T_enc, enc_dim) encoder outputs.
            stop_threshold: stop-probability threshold that ends decoding.
        Returns:
            mel_outputs: (1, T, num_mels) mel outputs from the decoder.
            alignments: (1, T//r, T_enc) attention weights.
        """
        # [1, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)
        
        mel_outputs, alignments = [], []
        # NOTE(sx): heuristic bounds on decoding length from the encoder length
        max_decoder_step = memory.size(1) * self.encoder_down_factor // self.frames_per_step
        min_decoder_step = max_decoder_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            # NOTE: assumes the model was built with use_stop_tokens=True,
            # since stop_layer only exists in that case
            stop_output = self.stop_layer(decoder_rnn_output)
            
            mel_outputs += [mel_output]
            alignments += [alignment]
            
            if (torch.sigmoid(stop_output).item() > stop_threshold
                    and len(mel_outputs) >= min_decoder_step):
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning: reached the maximum number of decoder steps.")
                break

            # use the last predicted frame as the next autoregressive input
            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, _ = self.parse_decoder_outputs(
            mel_outputs, alignments, None)

        return mel_outputs, alignments

    def inference_batched(self, memory, stop_threshold=0.5):
        """Batched decoder inference.
        Args:
            memory: (B, T_enc, enc_dim) encoder outputs.
            stop_threshold: stop-probability threshold that ends decoding.
        Returns:
            mel_outputs: (1, T_total, num_mels) mels of all utterances,
                each truncated at its stop token, concatenated along time.
            alignments: (B, T//r, T_enc) attention weights.
        """
        # [B, num_mels]
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        self.attention_layer.init_states(memory)
        
        mel_outputs, alignments = [], []
        stop_outputs = []
        # NOTE(sx): heuristic bounds on decoding length from the encoder length
        max_decoder_step = memory.size(1) * self.encoder_down_factor // self.frames_per_step
        min_decoder_step = max_decoder_step - 5
        while True:
            decoder_input = self.prenet(decoder_input)

            decoder_input_final, context, alignment = self.attend(decoder_input)

            decoder_rnn_output = self.decode(decoder_input_final)
            if self.concat_context_to_last:
                decoder_rnn_output = torch.cat(
                    (decoder_rnn_output, context), dim=1)

            mel_output = self.linear_projection(decoder_rnn_output)
            # (B, 1)
            stop_output = self.stop_layer(decoder_rnn_output)
            stop_outputs += [stop_output.squeeze()]

            # mel_output is already (B, num_mels*r); the squeeze was a no-op
            mel_outputs += [mel_output]
            alignments += [alignment]
            if torch.all(torch.sigmoid(stop_output.squeeze().detach()) > stop_threshold) \
                    and len(mel_outputs) >= min_decoder_step:
                break
            if len(mel_outputs) >= max_decoder_step:
                # print("Warning: reached the maximum number of decoder steps.")
                break

            # use the last predicted frame as the next autoregressive input
            decoder_input = mel_output[:, -self.num_mels:]

        mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
            mel_outputs, alignments, stop_outputs)
        # truncate each utterance at its first stop token, if one fired;
        # the original np.argwhere(...)[0] crashed when no stop token fired
        mel_outputs_stacked = []
        for mel, stop_logit in zip(mel_outputs, stop_outputs):
            stop_idxs = torch.nonzero(
                torch.sigmoid(stop_logit.cpu()) > stop_threshold, as_tuple=True)[0]
            if len(stop_idxs) > 0:
                # stop indices count decoder steps of frames_per_step frames
                end = stop_idxs[0].item() * self.frames_per_step
            else:
                end = mel.size(0)
            mel_outputs_stacked.append(mel[:end, :])
        mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
        return mel_outputs, alignments
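

# ---------------------------------------------------------------------------
# Minimal shape-check sketch. The dimensions below are illustrative
# assumptions, not values from the original configs, and the relative imports
# above mean this must be run as a module from the package root
# (e.g. `python -m <package>.decoder`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = Decoder(
        enc_dim=256,
        num_mels=80,
        frames_per_step=2,
        attention_rnn_dim=512,
        decoder_rnn_dim=512,
        prenet_dims=[256, 256],
        num_mixtures=5,
        use_stop_tokens=True,
    )
    memory = torch.randn(2, 100, 256)   # (B, T_enc, enc_dim)
    mels = torch.randn(2, 200, 80)      # (B, T_out, num_mels); T_out divisible by frames_per_step
    lengths = torch.tensor([100, 90])   # encoder output lengths
    mel_out, stop_out, align = decoder(memory, mels, lengths)
    # expected: (2, 200, 80), (2, 100), (2, 100, 100)
    print(mel_out.shape, stop_out.shape, align.shape)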