Evalの次にGene
                                                                #                                                  #
parser.add_argument('--gene', type=int, default=0, help='whether to generate') #                                   #
parser.add_argument('--checkpoint', type=str, default='./PTB.pt', #                                                #
help='model checkpoint to use')                                 #                                                  #
                                                                #                                                  #
if 2 == args.gene:                                              # 生成指示がある場合(単独)                         #
    with open(args.checkpoint, 'rb') as f:                      # ptファイルを読込み                               #
        model = torch.load(f)                                   # モデルのファイル読込                             #
                                                                #                                                  #
    generate(input)                                             # 生成実施                                         #
                                                                #                                                  #
    sys.exit(args.gene)                                         # 単独生成時の後続処理スキップ                     #
                                                                #                                                  #
#* Generation *************************************************## 生成処理                                        *#
def generate(input, batch_size=1):                              #                                                  #
    model.eval()                                                # 推測モードにしてドロップアウト無効化             #
    ntokens = len(corpus.dictionary)                            # トークン数取得                                   #
    if args.model == 'QRNN': model.reset()                      # QRNN指示がある場合、各層初期化                   #
                                                                #                                                  #
    hidden = model.init_hidden(batch_size)                      # 中間層初期化依頼                                 #
    input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) # なんでもいいから1単語ランダムに生成    #
    if args.cuda:                                               # CUDA指示がある場合                               #
        input.data = input.data.cuda()                          # GPU処理のため                                    #
                                                                #                                                  #
    with open(args.outf, 'w') as outf:                          # 出力ファイルをオープンして                       #
        for i in range(args.words):                             # 出力する単語数分繰返し                           #
            output, hidden = model(input, hidden)               # forward依頼                                      #
            word_weights = output.squeeze().data.div(args.temperature).exp().cpu() # 
            word_idx = torch.multinomial(word_weights, 1)[0]    # 
            input.data.fill_(word_idx)                          # 
            word = corpus.dictionary.idx2word[word_idx]         # 
                                                                #                                                  #
            outf.write(word + ('\n' if i % 20 == 19 else ' '))  # 
                                                                #                                                  #
            if i % args.log_interval == 0:
                print('| Generated {}/{} words'.format(i, args.words))
# data, targets = get_batch(data_source, i, args, evaluation=True) # バッチデータ化依頼 #
# output, hidden = model(data, hidden) # forward依頼 #
# total_loss += \
# len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data # 損失合計算出 #
# hidden = repackage_hidden(hidden) # 中間層静止依頼 #
# return total_loss.item() / len(data_source) # 平均誤差を返して処理終了 #
#/*Generation *************************************************## 生成処理 *#
                                                                #                                                  #
if 1 == args.gene:                                              # 生成指示がある場合(学習後)                       #

トップ   編集 凍結 差分 バックアップ 添付 複製 名前変更 リロード   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2019-05-30 (木) 06:47:32 (87d)