Sequence-to-Sequence Generation

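The paper Sequence to Sequence Learning with Neural Networks [Seq2Seq] describes the sequence-to-sequence model. It reads an input sentence, encodes it into a hidden vector with a recurrent neural network, and then, conditioned on that vector, generates an output sentence with a recurrent neural network language model. The model achieved strong results on machine translation and has also been applied to tasks such as image captioning, question answering, and dialogue generation.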

Recurrent Neural Networks

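A recurrent neural network is a network structure in which neurons can receive their own output as input, forming a loop. This lets it process sequences of arbitrary length, such as sentences in natural language.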
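To make the loop concrete, here is a minimal sketch of a scalar recurrent cell in plain Python. The function names and the weight values (`w_xh`, `w_hh`, `b`) are assumptions chosen for illustration. They are not taken from the model built later in this section, which uses PyTorch layers instead.

```python
import math

def rnn_step(x, h, w_xh, w_hh, b):
    """One step of a scalar recurrent cell: h' = tanh(w_xh * x + w_hh * h + b)."""
    return math.tanh(w_xh * x + w_hh * h + b)

def run_rnn(xs, w_xh=0.5, w_hh=0.8, b=0.0):
    """Feed a sequence of any length through the same cell, reusing the weights."""
    h = 0.0  # initial hidden state
    states = []
    for x in xs:
        # The new state depends on the current input and on the previous state
        h = rnn_step(x, h, w_xh, w_hh, b)
        states.append(h)
    return states

short = run_rnn([1.0, -1.0])
long = run_rnn([1.0, -1.0, 0.5, 0.0, 2.0])
print(len(short), len(long))
```

Because the same weights are applied at every step, the cell handles the two-element and the five-element sequence without any change to its parameters.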

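The book 神经网络与深度学习 (Neural Networks and Deep Learning) [NNDL] explains the basic principles of neural networks and deep learning, including basic neural network models such as recurrent neural networks.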

The two articles The Unreasonable Effectiveness of Recurrent Neural Networks and Understanding LSTM Networks are good references for learning about recurrent neural networks.

LSTM-Based Sequence-to-Sequence Dialogue Generation

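With a sequence-to-sequence model, dialogue generation takes the following steps:

  1. Data preprocessing: convert the dialogue data into QA pairs.

  2. Build the vocabulary: the model works with numbers rather than words, so this step builds a vocabulary from the QA pairs. The vocabulary assigns each word a unique number and converts between words and numbers in both directions.

  3. Implement the model: implement the encoder and the decoder with PyTorch.

  4. Train the model: train the model on the QA pairs.

  5. Generate dialogue: use the trained model to generate replies.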

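Here we implement an LSTM-based sequence-to-sequence dialogue generation model. Note that the code in this section uses GRU layers (torch.nn.GRU) as the recurrent unit. This section draws heavily on the PyTorch Chatbot Tutorial [Chatbot].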

Data Preprocessing

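As before, we use the dataset released with CDial-GPT as training data for the generative model. First, convert the raw data into question-answer pairs.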

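import json

from tqdm import tqdm

def process_corpus(datapath, output_path):
    # Load the raw corpus: a list of conversations, each a list of utterances
    with open(datapath, 'r', encoding='utf-8') as f:
        convs = json.load(f)

    # Pair every utterance with the one that follows it
    qas = []
    for conv in tqdm(convs):
        qas.extend(zip(conv[:-1], conv[1:]))

    # Keep short pairs only; len() counts characters here, spaces included
    qas = [(q, a) for q, a in qas if len(q) < 30 and len(a) < 30 and len(q) > 4]

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(qas, ensure_ascii=False))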
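The pairing step can be seen in isolation on a toy conversation. This is a minimal sketch. The helper name `conversation_to_pairs` and the English sample utterances are hypothetical and only stand in for the real corpus.

```python
def conversation_to_pairs(conv):
    """Pair each utterance with the utterance that follows it."""
    return list(zip(conv[:-1], conv[1:]))

# Hypothetical three-turn conversation (already tokenized, space-separated)
conv = ['how are you', 'fine , thanks', 'good to hear']
pairs = conversation_to_pairs(conv)
print(pairs)

# A single-utterance conversation has no reply, so it yields no pairs
print(conversation_to_pairs(['hello']))
```

A conversation with n utterances therefore contributes n - 1 pairs, and every utterance except the first and the last appears once as a question and once as an answer.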

After processing, the data is saved to a file for later use and then read back in.

# Raw strings keep the backslashes in the Windows-style paths literal
process_corpus(r'\data\LCCC\LCCC-base-split\LCCC-base_train.json', r'\data\seq2seq\LCCC-base_train.seq2seq.json')

with open(r'\data\seq2seq\LCCC-base_train.seq2seq.json', 'r', encoding='utf-8') as f:
    qas = json.load(f)

The question-answer pairs look like this:

qas[:10]

[['道歉 ! ! 再有 时间 找 你 去', '领个 搓衣板 去 吧'],
['咬咬牙 这回 要 全入 了 !', '干完 这 一票 我 的 会员 等级 就要 升 了 !'],
['干完 这 一票 我 的 会员 等级 就要 升 了 !', '升 了 继续'],
['代表 了 哪里 的 普通人 ?', '我 瞎说 的'],
['早点 好 起来 啊 。 生日快乐', '好 得 差不多 啦'],
['好 得 差不多 啦', '那 很 好 啊'],
['那么 早 ! ! 我 的 考试 周 还 没有 开始 呢 !', '是 啊 , 今年 好 快 啊'],
['是 啊 , 今年 好 快 啊', '不止 今年 , 我 发现 你 每次 都 好 早'],
['不止 今年 , 我 发现 你 每次 都 好 早', '小心 乌鸦嘴 , 下 一次 就 最晚 了'],
['今天 不 知道 能下 么 , 不过 天 又 冷 了', '希望 能 …']]

Vocabulary, Inputs, and Outputs

First, write a Voc class to build the vocabulary.

class Voc:

    PAD_TOKEN = "PAD"
    SOS_TOKEN = "SOS"
    EOS_TOKEN = "EOS"

    PAD_TOKEN_IDNEX = 0
    SOS_TOKEN_INDEX = 1
    EOS_TOKEN_INDEX = 2

    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            Voc.PAD_TOKEN_IDNEX: Voc.PAD_TOKEN,
            Voc.SOS_TOKEN_INDEX: Voc.SOS_TOKEN,
            Voc.EOS_TOKEN_INDEX: Voc.EOS_TOKEN
        }
        self.num_words = len(self.index2word)

    def add_sentence(self, sentence):
        # split() without arguments ignores repeated spaces, matching the filtering code
        for word in sentence.split():
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = []
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)

        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            Voc.PAD_TOKEN_IDNEX: Voc.PAD_TOKEN,
            Voc.SOS_TOKEN_INDEX: Voc.SOS_TOKEN,
            Voc.EOS_TOKEN_INDEX: Voc.EOS_TOKEN
        }
        self.num_words = len(self.index2word)

        for word in keep_words:
            self.add_word(word)

Build the vocabulary from the question-answer pairs. Only words that occur at least twice in the pairs are kept.

voc = Voc('LCCC-base')
for q, a in tqdm(qas):
    voc.add_sentence(q)
    voc.add_sentence(a)

MIN_COUNT = 2
voc.trim(MIN_COUNT)
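The effect of trimming can be sketched on a tiny corpus. The version below uses `collections.Counter` instead of the Voc class, so it is a stand-in for the idea rather than the class above. The function name `build_word2index` and the sample sentences are assumptions for illustration.

```python
from collections import Counter

def build_word2index(sentences, min_count, n_special=3):
    """Map every word seen at least min_count times to an index.

    Indexes start after the n_special reserved tokens (PAD, SOS, EOS).
    """
    counts = Counter(word for s in sentences for word in s.split())
    kept = [w for w, c in counts.items() if c >= min_count]
    return {w: i + n_special for i, w in enumerate(kept)}

# Hypothetical corpus: 'a' occurs 3 times, 'b' twice, 'c' and 'd' once
sentences = ['a b c', 'a b', 'a d']
word2index = build_word2index(sentences, min_count=2)
print(word2index)
```

With `min_count=2`, the words that occur only once are dropped, and any sentence containing them can no longer be converted to indexes.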

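Using the trimmed vocabulary, process the question-answer pairs again and drop every pair in which either sentence contains a word that is not in the vocabulary.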

keep_qas = []
for q, a in tqdm(qas):
    keep_q = True
    keep_a = True
    for word in q.split():
        if word not in voc.word2index:
            keep_q = False
            break
    for word in a.split():
        if word not in voc.word2index:
            keep_a = False
            break
    if keep_q and keep_a:
        keep_qas.append((q, a))

print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(qas), len(keep_qas), len(keep_qas) / len(qas)))

Implement a few helper functions for converting the data.

import itertools

import torch

def words_to_indexes(voc, words):
    return [voc.word2index[word] for word in words] + [Voc.EOS_TOKEN_INDEX]

def zero_padding(l, fillvalue=Voc.PAD_TOKEN_IDNEX):
    # Transpose the batch to (max_len, batch_size), padding short sequences
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binary_matrix(l, value=Voc.PAD_TOKEN_IDNEX):
    # 1 for real tokens, 0 for padding
    m = []
    for seq in l:
        mask = [0 if token == value else 1 for token in seq]
        m.append(mask)
    return m

def input_var(l, voc):
    indexes_batch = [words_to_indexes(voc, sentence.split()) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)
    pad_var = torch.LongTensor(pad_list)

    return pad_var, lengths

def output_var(l, voc):
    indexes_batch = [words_to_indexes(voc, sentence.split()) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)

    mask = binary_matrix(pad_list)
    mask = torch.BoolTensor(mask)
    pad_var = torch.LongTensor(pad_list)

    return pad_var, mask, max_target_len

def batch_train_data(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split()), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])

    inp, lengths = input_var(input_batch, voc)
    outp, mask, max_target_len = output_var(output_batch, voc)

    return inp, lengths, outp, mask, max_target_len
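The padding and masking helpers are easiest to follow on a small batch. The sketch below re-declares simplified versions of `zero_padding` and `binary_matrix` so that it runs on its own. The index values in `batch` are made up, with 2 standing for EOS and 0 for PAD as in Voc.

```python
import itertools

PAD = 0  # same value as the PAD index in Voc

def zero_padding(batch, fillvalue=PAD):
    # zip_longest transposes the batch: one row per time step, one column per sentence
    return list(itertools.zip_longest(*batch, fillvalue=fillvalue))

def binary_matrix(padded, pad_value=PAD):
    # 1 marks a real token, 0 marks padding
    return [[0 if token == pad_value else 1 for token in row] for row in padded]

# Three hypothetical index sequences of lengths 4, 2 and 3, each ending in EOS (2)
batch = [[5, 6, 7, 2], [8, 2], [9, 4, 2]]
padded = zero_padding(batch)
mask = binary_matrix(padded)

for row, mask_row in zip(padded, mask):
    print(row, mask_row)
```

The result has shape (max_len, batch_size), which is the time-major layout that the GRU layers and the step-by-step decoder loop in this section work with.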

Model

The model consists of an encoder, an attention mechanism, and a decoder.

Encoder

class EncoderRNN(torch.nn.Module):

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding

        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embeded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embeded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]

        return outputs, hidden

Attention

import torch.nn.functional as F

class Attn(torch.nn.Module):

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def _dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def _general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def _concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self._general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self._concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self._dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
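As a worked illustration of the 'dot' scoring method, the sketch below computes attention weights and the resulting context vector for a single decoder state in plain Python. It covers one sentence without batching, and the vectors are made-up values, so it mirrors the idea of the class above rather than its tensor shapes.

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot_attention(hidden, encoder_outputs):
    # Score each encoder position by its dot product with the decoder state
    scores = [sum(h * e for h, e in zip(hidden, enc)) for enc in encoder_outputs]
    weights = softmax(scores)
    # The context vector is the attention-weighted sum of the encoder outputs
    context = [
        sum(w * enc[d] for w, enc in zip(weights, encoder_outputs))
        for d in range(len(hidden))
    ]
    return weights, context

hidden = [1.0, 0.0]                                     # hypothetical decoder state
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # three encoder positions
weights, context = dot_attention(hidden, encoder_outputs)
print(weights, context)
```

The raw scores here are 1, 0 and -1, so the first encoder position, which points in the same direction as the decoder state, receives the largest weight.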

Decoder

class LuongAttnDecoderRNN(torch.nn.Module):

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = embedding
        self.embedding_dropout = torch.nn.Dropout(dropout)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout), bidirectional=False)
        self.concat = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.out = torch.nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)


    def forward(self, input_step, last_hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)

        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))

        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)

        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))

        output = self.out(concat_output)
        output = F.softmax(output, dim=1)

        return output, hidden

Training the Model

Computing the Loss

def mask_nll_loss(inp, target, mask):
    total = mask.sum()
    cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = cross_entropy.masked_select(mask).mean()
    loss = loss.to(device)

    return loss, total.item()
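The masked loss can be checked by hand on a toy batch. The following sketch does the same computation for one time step in plain Python. The function name `masked_nll` and the probability values are invented for illustration, and it assumes that at least one position is unmasked.

```python
import math

def masked_nll(probs, targets, mask):
    """Average negative log-likelihood over the unmasked positions.

    probs:   one probability distribution over the vocabulary per sentence
    targets: the index of the correct word for each sentence
    mask:    False where the target is padding and must be ignored
    """
    losses = [-math.log(p[t]) for p, t, m in zip(probs, targets, mask) if m]
    return sum(losses) / len(losses), len(losses)

# Hypothetical decoder outputs for a batch of three sentences, vocabulary of size 3
probs = [[0.5, 0.25, 0.25], [0.1, 0.8, 0.1], [0.9, 0.05, 0.05]]
targets = [0, 1, 0]
mask = [True, True, False]  # the third sentence has already ended (PAD)

loss, n_tokens = masked_nll(probs, targets, mask)
print(loss, n_tokens)
```

Only the first two sentences contribute, so the loss is the mean of -log(0.5) and -log(0.8). The padded position has no influence no matter what the model predicts there.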

Training Function

def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=30):

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_variable = input_variable.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)

    lengths = lengths.to('cpu')

    loss = 0
    print_losses = []
    totals = 0

    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    decoder_input = torch.LongTensor([[Voc.SOS_TOKEN_INDEX for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)

    decoder_hidden = encoder_hidden[:decoder.n_layers]

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = target_variable[t].view(1, -1)
            mask_loss, total = mask_nll_loss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            totals += total

            print_losses.append(mask_loss.item() * total)
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)

            mask_loss, total = mask_nll_loss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            totals += total

            print_losses.append(mask_loss.item() * total)

    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / totals

Training

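import os
import random

import torch

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = 's2s_model'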
attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 128

embedding = torch.nn.Embedding(voc.num_words, hidden_size)
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder = encoder.to(device)
decoder = decoder.to(device)

# Configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 60000
print_every = 1000
save_every = 10000

# Ensure dropout layers are in train mode
encoder.train()
decoder.train()

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

# Move any existing optimizer state to the selected device
# (this only matters when the optimizers were restored from a checkpoint)
for state in encoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

for state in decoder_optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

training_batches = [batch_train_data(voc, [random.choice(keep_qas) for _ in range(batch_size)]) for _ in range(n_iteration)]

save_dir = os.path.join("data", "save")
corpus_name = 'LCCC-base'

# Initializations
print('Initializing ...')
start_iteration = 1
print_loss = 0

# Training loop
print("Training...")
for iteration in tqdm(range(start_iteration, n_iteration + 1)):
    training_batch = training_batches[iteration - 1]
    # Extract fields from batch
    input_variable, lengths, target_variable, mask, max_target_len = training_batch

    # Run a training iteration with batch
    loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
    print_loss += loss

    # Print progress
    if iteration % print_every == 0:
        print_loss_avg = print_loss / print_every
        print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
        print_loss = 0

    # Save checkpoint
    if (iteration % save_every == 0):
        directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save({
            'iteration': iteration,
            'en': encoder.state_dict(),
            'de': decoder.state_dict(),
            'en_opt': encoder_optimizer.state_dict(),
            'de_opt': decoder_optimizer.state_dict(),
            'loss': loss,
            'voc_dict': voc.__dict__,
            'embedding': embedding.state_dict()
        }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

Once training finishes, the model has been saved to a checkpoint file.

Dialogue Generation

Using the model trained above, we implement a Seq2SeqChatAgent dialogue engine.

import random
import os

import jieba
import torch

class Seq2SeqChatAgent(object):

    def __init__(self, model_filename, device):
        self.hidden_size = 500
        self.dropout = 0.1
        self.encoder_n_layers = 2
        self.decoder_n_layers = 2
        self.attn_model = 'dot'
        self.max_length = 30
        self.device = device
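        # map_location lets a checkpoint saved on the GPU be loaded on a CPU-only machine
        checkpoint = torch.load(model_filename, map_location=device)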
        self.encoder_sd = checkpoint['en']
        self.decoder_sd = checkpoint['de']
        self.encoder_optimizer_sd = checkpoint['en_opt']
        self.decoder_optimizer_sd = checkpoint['de_opt']
        embedding_sd = checkpoint['embedding']
        self.voc = Voc('LCCC-base')
        self.voc.__dict__ = checkpoint['voc_dict']

        self.embedding = torch.nn.Embedding(self.voc.num_words, self.hidden_size)
        self.embedding.load_state_dict(embedding_sd)

        self.encoder = EncoderRNN(self.hidden_size, self.embedding, self.encoder_n_layers, self.dropout)
        self.decoder = LuongAttnDecoderRNN(self.attn_model, self.embedding, self.hidden_size, self.voc.num_words, self.decoder_n_layers, self.dropout)

        self.encoder.load_state_dict(self.encoder_sd)
        self.decoder.load_state_dict(self.decoder_sd)

        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)

        # Set dropout layers to eval mode
        self.encoder.eval()
        self.decoder.eval()

    def reply(self, message):
        input_sentence = [x for x in jieba.cut(message)]
        # Words outside the vocabulary cannot be encoded; return an empty reply
        # so that the caller can fall back to another engine
        if any(word not in self.voc.word2index for word in input_sentence):
            return ''
        # words -> indexes
        indexes_batch = [words_to_indexes(self.voc, input_sentence)]
        # Create lengths tensor
        lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
        # Transpose dimensions of batch to match models' expectations
        input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
        # Use appropriate device
        input_batch = input_batch.to(self.device)
        # lengths stays on the CPU, as pack_padded_sequence requires

        # Decode sentence with searcher
        with torch.no_grad():
            tokens, scores = self._greedy_search(input_batch, lengths, self.max_length)
            # indexes -> words
            decoded_words = [self.voc.index2word[token.item()] for token in tokens]
            # Format response sentence
            decoded_words[:] = [x for x in decoded_words if not (x == 'EOS' or x == 'PAD')]
            return ''.join(decoded_words)

    def _greedy_search(self, input_seq, input_length, max_length):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # Prepare encoder's final hidden layer to be first hidden input to the decoder
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) * Voc.SOS_TOKEN_INDEX

        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
        all_scores = torch.zeros([0], device=self.device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)

            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)

        # Return collections of word tokens and scores
        return all_tokens, all_scores

We now interact with the dialogue engine based on the generative model.

class ConversationSystem(object):

    patterns = [
        (r'你好[吗啊呀]?', {'*': ['你好', '嗨']}),
        (r'你是男是女', {'*': ['你觉得呢', '你呢']}),
        (r'你是谁', {'*': ['我是一个AI', '我是个机器人']}),
        (r'你(多大|几岁)[吗啊呀]?', {'*': ['我还年轻', '这是个秘密']}),
        (r'你在(哪里|哪儿|什么地方)', {'*': ['我在云端,你在哪儿', '我在你的身边,你在哪里呢']}),
        (r'我是(?P<Gender>[男女])(的|生|人)', {'男': ['帅哥好'], '女': ['美女好']}),
        (r'我也?在(?P<Place>.+)', {'北京': ['北京现在很冷吧'], '*': ['你在{Place}?']})
    ]

    fallback_patterns = [
        (r'.*', {'*': ['你说什么', '不好意思,没明白你的话']})
    ]

    def __init__(self):
        checkpoint_iter = 60000
        checkpoint_filename = os.path.join(save_dir, model_name, corpus_name,
            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
            '{}_checkpoint.tar'.format(checkpoint_iter))

        # Engines are tried in order: templates, then the seq2seq model, then the fallback
        self.chat_engines = [
            TemplateChatAgent(ConversationSystem.patterns),
            Seq2SeqChatAgent(checkpoint_filename, device),
            TemplateChatAgent(ConversationSystem.fallback_patterns),
        ]

    def reply(self, uid, message):
        result = ''
        for chat_engine in self.chat_engines:
            result = chat_engine.reply(message)
            if result:
                break

        return result

    def interact_cli(self):
        while True:
            query = input('User:')
            if query == 'Q':
                break
            print('AI:', self.reply('UserA', query))

conv_system = ConversationSystem()
conv_system.interact_cli()

Running it produces a conversation like the following.

User:你好
AI: 你好
User:昨天下了一场大雨
AI: 下大雨了
User:英超结束了
AI: 结束了
User:明天太阳会不会出来
AI: 不会
User:夏天来了
AI: 夏天夏天夏天夏天夏天夏天夏天就来了

Decoding and Sampling Strategies

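When generating a sentence, Seq2SeqChatAgent picks the most probable word at every step. This way of choosing the target word is called greedy search, and it yields only a single result. Keeping several candidate words at each step may produce better sentences or make the output more diverse. This way of choosing the target word is called beam search.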

class BeamSearchDecoder(torch.nn.Module):

    def __init__(self, encoder, decoder, k):
        super(BeamSearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.k = k

    def forward(self, input_seq, input_length, max_length):
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Create the start token on the same device as the input
        decoder_input = torch.ones(1, 1, device=input_seq.device, dtype=torch.long) * Voc.SOS_TOKEN_INDEX
        candidates = [{
            'Hidden': decoder_hidden,
            'Inputs': decoder_input,
            'Tokens': [],
            'Scores': []
        }]


        dones = []
        for _ in range(max_length):
            next_candidates = []
            for candidate in candidates:
                current_input = candidate['Inputs']
                if current_input == Voc.EOS_TOKEN_INDEX:
                    dones.append(candidate)
                    continue
                current_hidden = candidate['Hidden']
                decoder_output, decoder_hidden = self.decoder(current_input, current_hidden, encoder_outputs)
                c_scores, c_input = torch.topk(decoder_output, self.k, dim=1)
                for s, i in zip(c_scores[0], c_input[0]):
                    new_score = candidate['Scores'] + [s]
                    new_tokens = candidate['Tokens'] + [i]

                    next_candidates.append({
                        'Hidden': decoder_hidden,
                        'Inputs': i.unsqueeze(0).unsqueeze(0),
                        'Tokens': new_tokens,
                        'Scores': new_score
                    })

            next_candidates.sort(key=lambda e: sum(e['Scores'])/(len(e['Scores']))*len(set(e['Tokens']))*len(set(e['Tokens']))/len(e['Tokens'])/len(e['Tokens']), reverse=True)
            if self.k <= len(dones):
                break
            candidates = next_candidates[:self.k-len(dones)]

        for candidate in candidates:
            current_input = candidate['Inputs']
            if current_input == Voc.EOS_TOKEN_INDEX:
                dones.append(candidate)

        candidates = dones
        candidates.sort(key=lambda e: sum(e['Scores'])/(len(e['Scores']))*len(set(e['Tokens']))*len(set(e['Tokens']))/len(e['Tokens'])/len(e['Tokens']), reverse=True)

        if candidates:
            candidate = random.choice(candidates)
            r = torch.stack(candidate['Tokens'])
            s = torch.stack(candidate['Scores'])

            return r, s

        return None, None
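The difference between greedy search and beam search shows up even on a toy problem. The sketch below replaces the neural decoder with a hypothetical lookup table `NEXT` of next-word probabilities, and it ranks candidates by summed log-probability. That is a common textbook scoring rule, not the score used by the class above, which averages probabilities and penalizes repeated tokens.

```python
import math

# Hypothetical toy "model": next-token probabilities given only the previous token
NEXT = {
    'SOS': {'a': 0.6, 'b': 0.4},
    'a': {'c': 0.55, 'd': 0.45},
    'b': {'c': 0.9, 'd': 0.1},
    'c': {'EOS': 1.0},
    'd': {'EOS': 1.0},
}

def beam_search(k, max_length=5):
    beams = [(0.0, ['SOS'])]  # (log-probability, tokens so far)
    finished = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token
        expanded = []
        for logp, tokens in beams:
            for token, p in NEXT[tokens[-1]].items():
                expanded.append((logp + math.log(p), tokens + [token]))
        # Keep only the k best candidates
        expanded.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for cand in expanded[:k]:
            (finished if cand[1][-1] == 'EOS' else beams).append(cand)
        if not beams:
            break
    best = max(finished or beams, key=lambda c: c[0])
    words = [t for t in best[1] if t not in ('SOS', 'EOS')]
    return words, math.exp(best[0])

greedy_words, greedy_prob = beam_search(k=1)  # k=1 is greedy search
beam_words, beam_prob = beam_search(k=2)
print(greedy_words, greedy_prob)
print(beam_words, beam_prob)
```

With k=1 the search commits to 'a' because it is the most probable first word, and ends with a sentence of probability 0.6 * 0.55 = 0.33. With k=2 it also keeps 'b' alive and finds the sentence 'b c' with probability 0.4 * 0.9 = 0.36.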

The article Controllable Neural Text Generation [Controllable] introduces a variety of decoding strategies.