Skip to content

Commit

Permalink
Fixed bug for loading trajectories not collected on your machine
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKhazatsky committed Apr 21, 2023
1 parent 204a584 commit a1b9a3f
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 47 deletions.
28 changes: 14 additions & 14 deletions r2d2/camera_utils/recording_readers/mp4_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,23 @@
class MP4Reader:
def __init__(self, filepath, serial_number):
# Save Parameters #
self.filepath = filepath
self.serial_number = serial_number
self._index = 0

# Open Video Reader #
self._mp4_reader = cv2.VideoCapture(filepath)
if not self._mp4_reader.isOpened():
import pdb; pdb.set_trace()
print(filepath)
raise RuntimeError

# Load Recording Timestamps #
timestamp_filepath = filepath[:-4] + "_timestamps.json"
if not os.path.isfile(timestamp_filepath):
self._recording_timestamps = []
with open(timestamp_filepath, "r") as jsonFile:
self._recording_timestamps = json.load(jsonFile)

def set_reading_parameters(
self,
image=True,
Expand All @@ -30,19 +43,6 @@ def set_reading_parameters(
if self.skip_reading:
return

# Open Video Reader #
self._mp4_reader = cv2.VideoCapture(self.filepath)
if not self._mp4_reader.isOpened():
print(self.filepath)
return # raise RuntimeError

# Load Recording Timestamps #
timestamp_filepath = self.filepath[:-4] + "_timestamps.json"
if not os.path.isfile(timestamp_filepath):
self._recording_timestamps = []
with open(timestamp_filepath, "r") as jsonFile:
self._recording_timestamps = json.load(jsonFile)

def get_frame_resolution(self):
width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
Expand Down
22 changes: 11 additions & 11 deletions r2d2/camera_utils/recording_readers/svo_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
class SVOReader:
def __init__(self, filepath, serial_number):
# Save Parameters #
self.filepath = filepath
self.serial_number = serial_number
self._index = 0

Expand All @@ -26,6 +25,16 @@ def __init__(self, filepath, serial_number):
self._left_pointcloud = sl.Mat()
self._right_pointcloud = sl.Mat()

# Set SVO path for playback
init_parameters = sl.InitParameters()
init_parameters.set_from_svo_file(filepath)

# Open the ZED
self._cam = sl.Camera()
status = self._cam.open(init_parameters)
if status != sl.ERROR_CODE.SUCCESS:
print("Zed Error: " + repr(status))

def set_reading_parameters(
self,
image=True,
Expand Down Expand Up @@ -53,16 +62,6 @@ def set_reading_parameters(
if self.skip_reading:
return

# Set SVO path for playback
init_parameters = sl.InitParameters()
init_parameters.set_from_svo_file(self.filepath)

# Open the ZED
self._cam = sl.Camera()
status = self._cam.open(init_parameters)
if status != sl.ERROR_CODE.SUCCESS:
print("Zed Error: " + repr(status))

def get_frame_resolution(self):
camera_info = self._cam.get_camera_information().camera_configuration
width = camera_info.resolution.width
Expand Down Expand Up @@ -109,6 +108,7 @@ def read_camera(self, ignore_data=False, correct_timestamp=None, return_timestam
timestamp_error = (correct_timestamp is not None) and (correct_timestamp != received_time)

if timestamp_error:
import pdb; pdb.set_trace()
print("Timestamps did not match...")
return None

Expand Down
11 changes: 9 additions & 2 deletions r2d2/camera_utils/wrappers/recorded_multi_camera_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

class RecordedMultiCameraWrapper:
def __init__(self, recording_folderpath, camera_kwargs={}):
# Save Camera Info #
self.camera_kwargs = camera_kwargs

# Open Camera Readers #
svo_filepaths = glob.glob(recording_folderpath + "/*.svo")
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
Expand All @@ -28,19 +31,23 @@ def __init__(self, recording_folderpath, camera_kwargs={}):
raise ValueError

self.camera_dict[serial_number] = Reader(f, serial_number)
self.camera_dict[serial_number].set_reading_parameters(**curr_cam_kwargs)

def read_cameras(self, index=None, timestamp_dict={}):
def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}):
full_obs_dict = defaultdict(dict)

# Read Cameras In Randomized Order #
all_cam_ids = list(self.camera_dict.keys())
random.shuffle(all_cam_ids)

for cam_id in all_cam_ids:
cam_type = camera_type_dict[cam_id]
curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)

timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
if index is not None:
self.camera_dict[cam_id].set_frame_index(index)

data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)

# Process Returned Data #
Expand Down
4 changes: 3 additions & 1 deletion r2d2/trajectory_utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from r2d2.calibration.calibration_utils import *
from r2d2.camera_utils.wrappers.recorded_multi_camera_wrapper import RecordedMultiCameraWrapper
from r2d2.camera_utils.info import camera_type_to_string_dict
from r2d2.misc.parameters import *
from r2d2.misc.time import time_ms
from r2d2.misc.transformations import change_pose_frame
Expand Down Expand Up @@ -352,7 +353,8 @@ def load_trajectory(
# If Applicable, Get Recorded Data #
if read_recording_folderpath:
timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
camera_obs = camera_reader.read_cameras(index=i, timestamp_dict=timestamp_dict)
camera_type_dict = {k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()}
camera_obs = camera_reader.read_cameras(index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict)
camera_failed = camera_obs is None

# Add Data To Timestep If Successful #
Expand Down
21 changes: 19 additions & 2 deletions scripts/post_processing/svo_to_mp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm

from r2d2.camera_utils.recording_readers.svo_reader import SVOReader
from r2d2.training.data_loading.trajectory_sampler import collect_data_folderpaths
from r2d2.data_loading.trajectory_sampler import collect_data_folderpaths


def convert_svo_to_mp4(filepath, recording_folderpath):
Expand Down Expand Up @@ -44,7 +44,7 @@ def convert_svo_to_mp4(filepath, recording_folderpath):
with open(timestamp_output_path, "w") as jsonFile:
json.dump(received_timestamps, jsonFile)


corrupted_traj = []
all_folderpaths = collect_data_folderpaths()
for folderpath in tqdm(all_folderpaths):
recording_folderpath = os.path.join(folderpath, "recordings")
Expand All @@ -71,7 +71,24 @@ def convert_svo_to_mp4(filepath, recording_folderpath):
serial_number = f.split("/")[-1][:-4]
if not any([serial_number in f for f in mp4_filepaths]):
files_to_convert.append(f)
for f in mp4_filepaths:
reader = cv2.VideoCapture(f)
if reader.isOpened():
reader.release()
else:
files_to_convert.append(f)

# Convert Files #
for f in files_to_convert:
convert_svo_to_mp4(f, recording_folderpath)

# Check Success #
num_mp4 = len(glob.glob(mp4_folderpath + "/*.mp4"))
num_svo = len(svo_filepaths)

if num_svo > num_mp4:
corrupted_traj.append(folderpath)

print('The following trajectories are corrupted: ')
for folderpath in corrupted_traj:
print(folderpath)
12 changes: 6 additions & 6 deletions scripts/tests/memory_leak.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import pyzed.sl as sl
from tqdm import tqdm

from r2d2.camera_utils.readers.recorded_zed_camera import RecordedZedCamera
from r2d2.camera_utils.recording_readers.svo_reader import SVOReader
from r2d2.camera_utils.wrappers.recorded_multi_camera_wrapper import RecordedMultiCameraWrapper
from r2d2.training.data_loading.trajectory_sampler import *
from r2d2.data_loading.trajectory_sampler import *


def example_script():
Expand Down Expand Up @@ -85,9 +85,9 @@ def traj_sampling_script():
def load_random_traj_script():
folderpath = random.choice(train_folderpaths)
filepath = os.path.join(folderpath, "trajectory.h5")
recording_folderpath = os.path.join(folderpath, "recordings")
recording_folderpath = os.path.join(folderpath, "recordings/MP4")

load_trajectory(
samples = load_trajectory(
filepath,
recording_folderpath=recording_folderpath,
read_cameras=True,
Expand Down Expand Up @@ -194,8 +194,8 @@ def single_reader_script():
curr_mem_usage = psutil.virtual_memory()[3]
memory_usage.append(curr_mem_usage)

single_reader_script()
# load_random_traj_script()
#single_reader_script()
load_random_traj_script()

try:
plt.plot(memory_usage)
Expand Down
14 changes: 3 additions & 11 deletions scripts/training/train_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from r2d2.training.model_trainer import exp_launcher

task_label_filepath = "/home/sasha/Desktop/R2D2/scripts/labeling/task_label_filepath.json"
task_label_filepath = "/home/sasha/R2D2/scripts/labeling/task_label_filepath.json"
with open(task_label_filepath, "r") as jsonFile:
task_labels = json.load(jsonFile)

Expand All @@ -22,14 +22,6 @@ def filter_func(h5_metadata, put_in_only=False):
weight_decay=1e-4,
lr=1e-4,
),
# camera_kwargs=dict(
# hand_camera=dict(
# image=True, depth=False, pointcloud=False, concatenate_images=False,
# resolution=(128, 128)),
# varied_camera=dict(
# image=True, depth=False, pointcloud=False, concatenate_images=False,
# resolution=(128, 128)),
# ),
camera_kwargs=dict(
hand_camera=dict(image=True, concatenate_images=False, resolution=(128, 128), resize_func="cv2"),
varied_camera=dict(image=False, concatenate_images=False, resolution=(128, 128), resize_func="cv2"),
Expand All @@ -51,8 +43,8 @@ def filter_func(h5_metadata, put_in_only=False):
recording_prefix="MP4",
batch_size=4,
prefetch_factor=1,
buffer_size=200,
num_workers=2,
buffer_size=1000,
num_workers=4,
data_filtering_kwargs=dict(
train_p=0.9,
remove_failures=True,
Expand Down

0 comments on commit a1b9a3f

Please sign in to comment.