Skip to content

Commit

Permalink
Merge pull request jwyang#45 from cclauss/patch-2
Browse files Browse the repository at this point in the history
In __init__() save self.num_data for use in __len__()
  • Loading branch information
jwyang committed Jan 21, 2018
2 parents a48e940 + 43c1afa commit 992bf4d
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions trainval_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,15 @@ def parse_args():

class sampler(Sampler):
def __init__(self, train_size, batch_size):
num_data = train_size
self.num_per_batch = int(num_data / batch_size)
self.num_data = train_size
self.num_per_batch = int(train_size / batch_size)
self.batch_size = batch_size
self.range = torch.arange(0,batch_size).view(1, batch_size).long()
self.leftover_flag = False
if num_data % batch_size:
self.leftover = torch.arange(self.num_per_batch*batch_size, num_data).long()
if train_size % batch_size:
self.leftover = torch.arange(self.num_per_batch*batch_size, train_size).long()
self.leftover_flag = True

def __iter__(self):
rand_num = torch.randperm(self.num_per_batch).view(-1,1) * self.batch_size
self.rand_num = rand_num.expand(self.num_per_batch, self.batch_size) + self.range
Expand All @@ -143,7 +144,7 @@ def __iter__(self):
return iter(self.rand_num_view)

def __len__(self):
return num_data
return self.num_data

if __name__ == '__main__':

Expand Down

0 comments on commit 992bf4d

Please sign in to comment.