#author("2019-05-30T06:47:32+09:00","default:Authors","Authors")
 Evalの次にGene
                                                                 #                                                  #
 parser.add_argument('--gene', type=int, default=0, help='whether to generate') #                                   #
 parser.add_argument('--checkpoint', type=str, default='./PTB.pt', #                                                #
 help='model checkpoint to use')                                 #                                                  #
                                                                 #                                                  #
 if 2 == args.gene:                                              # 生成指示がある場合(単独)                         #
     with open(args.checkpoint, 'rb') as f:                      # ptファイルを読込み                               #
         model = torch.load(f)                                   # モデルのファイル読込                             #
                                                                 #                                                  #
     generate(input)                                             # 生成実施                                         #
                                                                 #                                                  #
     sys.exit(args.gene)                                         # 単独生成時の後続処理スキップ                     #
                                                                 #                                                  #
 #* Generation *************************************************## 生成処理                                        *#
 def generate(input, batch_size=1):                              #                                                  #
     model.eval()                                                # 推測モードにしてドロップアウト無効化             #
     ntokens = len(corpus.dictionary)                            # トークン数取得                                   #
     if args.model == 'QRNN': model.reset()                      # QRNN指示がある場合、各層初期化                   #
                                                                 #                                                  #
     hidden = model.init_hidden(batch_size)                      # 中間層初期化依頼                                 #
     input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) # なんでもいいから1単語ランダムに生成    #
     if args.cuda:                                               # CUDA指示がある場合                               #
         input.data = input.data.cuda()                          # GPU処理のため                                    #
                                                                 #                                                  #
     with open(args.outf, 'w') as outf:                          # 出力ファイルをオープンして                       #
         for i in range(args.words):                             # 出力する単語数分繰返し                           #
             output, hidden = model(input, hidden)               # forward依頼                                      #
             word_weights = output.squeeze().data.div(args.temperature).exp().cpu() # 
             word_idx = torch.multinomial(word_weights, 1)[0]    # 
             input.data.fill_(word_idx)                          # 
             word = corpus.dictionary.idx2word[word_idx]         # 
                                                                 #                                                  #
             outf.write(word + ('\n' if i % 20 == 19 else ' '))  # 
                                                                 #                                                  #
             if i % args.log_interval == 0:
                 print('| Generated {}/{} words'.format(i, args.words))
 # data, targets = get_batch(data_source, i, args, evaluation=True) # バッチデータ化依頼 #
 # output, hidden = model(data, hidden) # forward依頼 #
 # total_loss += \
 # len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data # 損失合計算出 #
 # hidden = repackage_hidden(hidden) # 中間層静止依頼 #
 # return total_loss.item() / len(data_source) # 平均誤差を返して処理終了 #
 #/*Generation *************************************************## 生成処理 *#
                                                                 #                                                  #
 if 1 == args.gene:                                              # 生成指示がある場合(学習後)                       #


トップ   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS