- 追加された行はこの色です。
- 削除された行はこの色です。
#author("2019-05-30T06:47:32+09:00","default:Authors","Authors")
Evalの次にGene
# #
parser.add_argument('--gene', type=int, default=0, help='whether to generate') # #
parser.add_argument('--checkpoint', type=str, default='./PTB.pt', # #
help='model checkpoint to use') # #
# #
if 2 == args.gene: # 生成指示がある場合(単独) #
with open(args.checkpoint, 'rb') as f: # ptファイルを読込み #
model = torch.load(f) # モデルのファイル読込 #
# #
generate(input) # 生成実施 #
# #
sys.exit(args.gene) # 単独生成時の後続処理スキップ #
# #
#* Generation *************************************************## 生成処理 *#
def generate(input, batch_size=1): # #
model.eval() # 推測モードにしてドロップアウト無効化 #
ntokens = len(corpus.dictionary) # トークン数取得 #
if args.model == 'QRNN': model.reset() # QRNN指示がある場合、各層初期化 #
# #
hidden = model.init_hidden(batch_size) # 中間層初期化依頼 #
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True) # なんでもいいから1単語ランダムに生成 #
if args.cuda: # CUDA指示がある場合 #
input.data = input.data.cuda() # GPU処理のため #
# #
with open(args.outf, 'w') as outf: # 出力ファイルをオープンして #
for i in range(args.words): # 出力する単語数分繰返し #
output, hidden = model(input, hidden) # forward依頼 #
word_weights = output.squeeze().data.div(args.temperature).exp().cpu() #
word_idx = torch.multinomial(word_weights, 1)[0] #
input.data.fill_(word_idx) #
word = corpus.dictionary.idx2word[word_idx] #
# #
outf.write(word + ('\n' if i % 20 == 19 else ' ')) #
# #
if i % args.log_interval == 0:
print('| Generated {}/{} words'.format(i, args.words))
# data, targets = get_batch(data_source, i, args, evaluation=True) # バッチデータ化依頼 #
# output, hidden = model(data, hidden) # forward依頼 #
# total_loss += \
# len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data # 損失合計算出 #
# hidden = repackage_hidden(hidden) # 中間層静止依頼 #
# return total_loss.item() / len(data_source) # 平均誤差を返して処理終了 #
#/*Generation *************************************************## 生成処理 *#
# #
if 1 == args.gene: # 生成指示がある場合(学習後) #