Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Lupin1998 committed May 7, 2023
1 parent f5d55ff commit 20703af
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion openstl/datasets/dataloader_taxibj.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import random
import numpy as np
import torch
Expand Down Expand Up @@ -42,7 +43,7 @@ def load_data(batch_size, val_batch_size, data_root, num_workers=4,
pre_seq_length=None, aft_seq_length=None, in_shape=None,
distributed=False, use_augment=False, use_prefetcher=False):

dataset = np.load(data_root+'taxibj/dataset.npz')
dataset = np.load(os.path.join(data_root, 'taxibj/dataset.npz'))
X_train, Y_train, X_test, Y_test = dataset['X_train'], dataset[
'Y_train'], dataset['X_test'], dataset['Y_test']
assert X_train.shape[1] == pre_seq_length and Y_train.shape[1] == aft_seq_length
Expand Down
2 changes: 1 addition & 1 deletion openstl/methods/crevnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class CrevNet(Base_method):
"""

def __init__(self, args, device, steps_per_epoch):
args.pre_seq_length = 8
args.pre_seq_length = args.pre_seq_length - 2
args.total_length = args.pre_seq_length + args.aft_seq_length
Base_method.__init__(self, args, device, steps_per_epoch)
self.model = self._build_model(self.config)
Expand Down
1 change: 1 addition & 0 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ sphinx==4.0.2
sphinx-copybutton
sphinx_markdown_tables>=0.0.16
sphinx_rtd_theme==0.5.2
urllib3==1.26.15

0 comments on commit 20703af

Please sign in to comment.