-
Notifications
You must be signed in to change notification settings - Fork 7
/
data_loader.py
110 lines (83 loc) · 3.47 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import random
import numpy as np
import torch
# from main import BATCH_SIZE, VOCAB_SIZE, g_sequence_len
class DataLoader:
    """Batched iterator over a text file of math-equation strings.

    Each line of the file is one sample made of vocabulary characters
    ('x', '+', '-', '*', '/', '_').  Iterating yields (input, target)
    LongTensor pairs of shape (batch, seq_len) suitable for nn.Embedding,
    where target is input shifted left by one position.
    """
    def __init__(self, file_path, batch_size=16):
        self.batch_size = batch_size
        # Vocabulary: the only characters expected in the data file.
        self.char_to_ix = {
            'x': 0,
            '+': 1,
            '-': 2,
            '*': 3,
            '/': 4,
            '_': 5,
            #'\n': 6
        }
        self.ix_to_char = {v: k for (k, v) in self.char_to_ix.items()}
        self.readFile(file_path)
        self.idx = 0
    def __len__(self):
        # FIX: previously a bare `pass` (returned None), so len(loader)
        # raised TypeError.  Report the number of loaded samples.
        return self.total_lines
    def __iter__(self):
        return self
    def __next__(self):
        return self.next()
    def reset(self):
        """Rewind the iterator and reshuffle the samples in place."""
        self.idx = 0
        random.shuffle(self.lines)
    def next(self):
        """Return the next (input, target) batch; raise StopIteration when done.

        Returns:
            (torch.LongTensor, torch.LongTensor): input and target batches,
            each of shape (batch, seq_len).  The last batch may be smaller
            than batch_size.
        """
        # Iterator edge case: everything has been consumed.
        if self.idx >= self.total_lines:
            raise StopIteration
        # Clamp the slice end to what is left in the list.
        end_index = min(self.idx + self.batch_size, self.total_lines)
        # List of equation strings for this batch.
        batch_lines = self.lines[self.idx:end_index]
        # Bookkeeping for the iterator.
        self.idx += self.batch_size
        all_input_data = []
        all_target_data = []
        # FIX: removed leftover debug code (`if i == end_index-1: print(...)`);
        # the comparison mixed a batch-local index with an absolute index.
        for line in batch_lines:
            # Convert characters to vocabulary indices.
            input_data = [self.char_to_ix[c] for c in line]
            # Target is the input staggered by one position; the final slot
            # is padded with a random operator index (1-4).
            # NOTE(review): the original comment claimed index 6 ('\n') was
            # appended, but '\n' is commented out of the vocabulary and a
            # random operator is appended instead — confirm this padding is
            # intended.
            target_data = input_data[1:]
            target_data.append(random.choice([1, 2, 3, 4]))
            all_input_data.append(input_data)
            all_target_data.append(target_data)
        # Convert to long tensors (ready for nn.Embedding).
        # NOTE: assumes every line in a batch has the same length; ragged
        # batches would produce an object array and fail here.
        all_input_data = torch.from_numpy(np.asarray(all_input_data)).long()
        all_target_data = torch.from_numpy(np.asarray(all_target_data)).long()
        return all_input_data, all_target_data
    def readFile(self, file_path):
        """Load file_path, one sample per line.

        FIX: drop empty lines — a trailing newline previously produced an
        empty sample, yielding a ragged batch that crashed tensor conversion.
        """
        with open(file_path, 'r') as f:
            self.lines = [line for line in f.read().split('\n') if line]
        self.total_lines = len(self.lines)
    def frequency(self, file_path, vocab_size=5, seq_len=15):
        """Compute and save a normalized character-bigram frequency matrix.

        FIX: the original did `chars = list(self.lines)` and looked whole
        *lines* up in char_to_ix, so .get() returned None and indexing
        freq_arr crashed.  Bigrams are now counted per character within each
        line; characters outside the vocabulary (or >= vocab_size) are
        skipped, and division by zero is guarded when no bigram was counted.
        """
        freq_arr = np.zeros((vocab_size, vocab_size))
        with open(file_path, 'r') as f:
            self.lines = [line for line in f.read().split('\n') if line]
        for line in self.lines:
            for prev_c, cur_c in zip(line, line[1:]):
                a = self.char_to_ix.get(prev_c)
                b = self.char_to_ix.get(cur_c)
                if a is not None and b is not None and a < vocab_size and b < vocab_size:
                    freq_arr[a, b] += 1
        total = np.sum(freq_arr)
        if total > 0:
            freq_arr = freq_arr / total
        if seq_len == 15:
            np.save('freq_array.npy', freq_arr)
        elif seq_len == 3:
            np.save('freq_array_3.npy', freq_arr)
        self.total_lines = len(self.lines)
    def convert_to_char(self, data):
        """Map a sequence of index tensors back to their equation strings.

        Args:
            data: iterable of 1-D LongTensors of vocabulary indices.
        Returns:
            list[str]: the decoded string for each tensor.
        """
        string_arr = []
        for each_tensor in data:
            string = ''.join(self.ix_to_char[i] for i in each_tensor.data.numpy())
            string_arr.append(string)
        return string_arr