diff --git a/gridworld/visualizer.py b/gridworld/visualizer.py index 77d39d5..1d0249a 100644 --- a/gridworld/visualizer.py +++ b/gridworld/visualizer.py @@ -80,8 +80,8 @@ def render_video(self, output, self.set_agent_state(*init_conds) fourcc = cv2.VideoWriter_fourcc(*'mp4v') output.parent.mkdir(exist_ok=True) - writer = cv2.VideoWriter(f'{output}.mp4', fourcc, fps, self.render_size) image = self.render() + writer = cv2.VideoWriter(f'{output}.mp4', fourcc, fps, image.shape[:2]) writer.write(image[..., ::-1]) tq = tqdm(total=len(event_sequence), disable=not verbose) for _, event in enumerate(event_sequence): @@ -106,4 +106,4 @@ def postproc_video(self, output): os.system(f"mv {output}2.mp4 {output}.mp4") else: raise ValueError('Install the latest version of ffmpeg') - \ No newline at end of file +