-
Notifications
You must be signed in to change notification settings - Fork 7
/
rollout.py
82 lines (70 loc) · 2.42 KB
/
rollout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# -*- coding:utf-8 -*-
import os
import random
import math
import copy
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from helpers import *
class Rollout(object):
    """Roll-out policy for reward estimation in adversarial sequence training.

    Holds a deep copy (`own_model`) of the live generator (`ori_model`) that
    is used to complete partial sequences for Monte-Carlo reward estimates.
    `update_params` slowly blends the copy toward the live generator so the
    roll-out policy lags behind it (controlled by `update_rate`).
    """
    def __init__(self, model, update_rate):
        # `ori_model` is the live generator being trained; `own_model` is a
        # detached deep copy used only for sampling roll-outs.
        self.ori_model = model
        self.own_model = copy.deepcopy(model)
        self.update_rate = update_rate

    def get_reward(self, x, discriminator, VOCAB_SIZE, cuda):
        """Score full sequences with the discriminator (no roll-out).

        Args:
            x : (batch_size, seq_len) tensor of token indices.
            discriminator : discriminator model; its output's column 1 is
                taken as the probability of the "real" class.
            VOCAB_SIZE : vocabulary size for the one-hot encoding.
            cuda : whether the one-hot tensor should live on GPU
                (forwarded to `convert_to_one_hot`).

        Returns:
            numpy array of shape (batch_size,) with per-sequence scores.
        """
        one_hot_samples = convert_to_one_hot(x, VOCAB_SIZE, cuda)
        pred = discriminator(one_hot_samples)
        # Column 1 is assumed to be the positive ("real") class probability.
        pred = pred.cpu().data[:, 1].numpy()
        return pred

    def get_reward_mc(self, x, num, discriminator):
        """Monte-Carlo per-token rewards via roll-out completion.

        For each prefix length l in [1, seq_len), the roll-out policy
        completes the prefix `num` times; the discriminator score of each
        completion is averaged to give the reward for token l. The full
        sequence itself provides the reward for the final token.

        Args:
            x : (batch_size, seq_len) tensor of token indices.
            num : number of roll-outs to average over.
            discriminator : discriminator model (column 1 = "real" prob).

        Returns:
            numpy array of shape (batch_size, seq_len) of averaged rewards.
        """
        rewards = []
        batch_size = x.size(0)
        seq_len = x.size(1)
        for i in range(num):
            for l in range(1, seq_len):
                data = x[:, 0:l]
                # Complete the l-token prefix up to seq_len with own_model.
                samples = self.own_model.sample(batch_size, seq_len, data)
                pred = discriminator(samples)
                pred = pred.cpu().data[:, 1].numpy()
                if i == 0:
                    rewards.append(pred)
                else:
                    rewards[l - 1] += pred
            # The last token's reward comes from the complete sequence itself.
            pred = discriminator(x)
            pred = pred.cpu().data[:, 1].numpy()
            if i == 0:
                rewards.append(pred)
            else:
                rewards[seq_len - 1] += pred
        # (seq_len, batch_size) -> (batch_size, seq_len), averaged over roll-outs.
        rewards = np.transpose(np.array(rewards)) / (1.0 * num)
        return rewards

    def update_params(self):
        """Blend `own_model`'s parameters toward `ori_model`'s.

        Embedding parameters (names starting with 'emb') are copied verbatim;
        all other parameters follow an exponential moving average:
        new = update_rate * own + (1 - update_rate) * ori.
        """
        dic = {}
        for name, param in self.ori_model.named_parameters():
            dic[name] = param.data
        for name, param in self.own_model.named_parameters():
            if name.startswith('emb'):
                # clone() so own_model does not alias ori_model's storage;
                # plain assignment would let later in-place updates to the
                # live generator silently mutate the roll-out copy.
                param.data = dic[name].clone()
            else:
                # EMA blend allocates a fresh tensor, so no aliasing here.
                param.data = self.update_rate * param.data + (1 - self.update_rate) * dic[name]