From 68693ac560123a04931575eca8e0fcc197e2f5a6 Mon Sep 17 00:00:00 2001 From: MarcusLoppe Date: Wed, 27 Dec 2023 00:56:09 +0100 Subject: [PATCH 1/2] added mesh_dataset --- meshgpt_pytorch/mesh_dataset.py | 107 ++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 meshgpt_pytorch/mesh_dataset.py diff --git a/meshgpt_pytorch/mesh_dataset.py b/meshgpt_pytorch/mesh_dataset.py new file mode 100644 index 00000000..9ce58ee0 --- /dev/null +++ b/meshgpt_pytorch/mesh_dataset.py @@ -0,0 +1,107 @@ +from torch.utils.data import Dataset +import numpy as np +from meshgpt_pytorch import ( + MeshAutoencoder, + MeshTransformer +) + +from meshgpt_pytorch.data import ( + derive_face_edges_from_faces +) + +class MeshDataset(Dataset): + """ + A PyTorch Dataset to load and process mesh data. + The `MeshDataset` provides functions to load mesh data from a file, embed text information, generate face edges, and generate codes. + + Attributes: + data (list): A list of mesh data entries. Each entry is a dictionary containing the following keys: + vertices (torch.Tensor): A tensor of vertices with shape (num_vertices, 3). + faces (torch.Tensor): A tensor of faces with shape (num_faces, 3). + text (str): A string containing the associated text information for the mesh. + text_embeds (torch.Tensor): A tensor of text embeddings for the mesh. + face_edges (torch.Tensor): A tensor of face edges with shape (num_faces, num_edges). + codes (torch.Tensor): A tensor of codes generated from the mesh data. + + Example usage: + + ``` + data = [ + {'vertices': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'faces': torch.tensor([[0, 1, 2]], dtype=torch.long), 'text': 'table'}, + {'vertices': torch.tensor([[10, 20, 30], [40, 50, 60]], dtype=torch.float32), 'faces': torch.tensor([[1, 2, 0]], dtype=torch.long), "text": "chair"}, + ] + + # Create a MeshDataset instance + mesh_dataset = MeshDataset(data) + + # Save the MeshDataset to disk + mesh_dataset.save('mesh_dataset.npz') + + # Load the MeshDataset from disk + loaded_mesh_dataset = MeshDataset.load('mesh_dataset.npz') + + # Generate face edges so it doesn't need to be done every time during training + dataset.generate_face_edges() + ``` + """ + def __init__(self, data): + self.data = data + print(f"[MeshDataset] Created from {len(self.data)} entrys") + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + data = self.data[idx] + return data + + def save(self, path): + np.savez_compressed(path, self.data, allow_pickle=True) + print(f"[MeshDataset] Saved {len(self.data)} entrys at {path}") + + @classmethod + def load(cls, path): + loaded_data = np.load(path, allow_pickle=True) + data = [] + for item in loaded_data["arr_0"]: + data.append(item) + print(f"[MeshDataset] Loaded {len(data)} entrys") + return cls(data) + + + + def generate_face_edges(self): + for i in range(0, len(self.data)): + item = self.data[i] + item['face_edges'] = derive_face_edges_from_faces(item['faces']) + + desired_order = ['vertices', 'faces', 'face_edges', 'texts'] + self.data = [ + {key: d[key] for key in desired_order} for d in self.data + ] + print(f"[MeshDataset] Generated face_edges for {len(self.data)} entrys") + + def generate_codes(self, autoencoder : MeshAutoencoder): + for i in range(0, len(self.data)): + item = self.data[i] + + codes = autoencoder.tokenize( + vertices = item['vertices'], + faces = item['faces'], + face_edges = item['face_edges'] + ) + item['codes'] = codes + + print(f"[MeshDataset] Generated codes for {len(self.data)} entrys") + + def embed_texts(self, transformer : MeshTransformer): + unique_texts = set(item['texts'] for item in self.data) + + text_embeddings = transformer.embed_texts(list(unique_texts)) + print(f"[MeshDataset] Generated {len(text_embeddings)} text_embeddings") + text_embedding_dict = dict(zip(unique_texts, text_embeddings)) + + for item in self.data: + text_value = item['texts'] + item['text_embeds'] = text_embedding_dict.get(text_value, None) + del item['texts'] \ No newline at end of file From 11df7b789a19f23ebd9289d8775593045e9d573b Mon Sep 17 00:00:00 2001 From: MarcusLoppe Date: Fri, 5 Jan 2024 20:48:48 +0100 Subject: [PATCH 2/2] update+ init fix --- meshgpt_pytorch/__init__.py | 5 ++++- meshgpt_pytorch/mesh_dataset.py | 38 +++++++++++++++++---------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/meshgpt_pytorch/__init__.py b/meshgpt_pytorch/__init__.py index 37db1d82..f2253628 100644 --- a/meshgpt_pytorch/__init__.py +++ b/meshgpt_pytorch/__init__.py @@ -12,4 +12,7 @@ from meshgpt_pytorch.data import ( DatasetFromTransforms, cache_text_embeds_for_dataset -) \ No newline at end of file +) +from meshgpt_pytorch.mesh_dataset import ( + MeshDataset +) diff --git a/meshgpt_pytorch/mesh_dataset.py b/meshgpt_pytorch/mesh_dataset.py index 9ce58ee0..201f311b 100644 --- a/meshgpt_pytorch/mesh_dataset.py +++ b/meshgpt_pytorch/mesh_dataset.py @@ -66,25 +66,25 @@ def load(cls, path): for item in loaded_data["arr_0"]: data.append(item) print(f"[MeshDataset] Loaded {len(data)} entrys") - return cls(data) - - - - def generate_face_edges(self): - for i in range(0, len(self.data)): - item = self.data[i] - item['face_edges'] = derive_face_edges_from_faces(item['faces']) - - desired_order = ['vertices', 'faces', 'face_edges', 'texts'] + return cls(data) + def sort_dataset_keys(self): + desired_order = ['vertices', 'faces', 'face_edges', 'texts','text_embeds','codes'] self.data = [ - {key: d[key] for key in desired_order} for d in self.data + {key: d[key] for key in desired_order if key in d} for d in self.data ] - print(f"[MeshDataset] Generated face_edges for {len(self.data)} entrys") + + def generate_face_edges(self): + i = 0 + for item in self.data: + if 'face_edges' not in item: + item['face_edges'] = derive_face_edges_from_faces(item['faces']) + i += 1 + + self.sort_dataset_keys() + print(f"[MeshDataset] Generated face_edges for {i}/{len(self.data)} entrys") def generate_codes(self, autoencoder : MeshAutoencoder): - for i in range(0, len(self.data)): - item = self.data[i] - + for item in self.data: codes = autoencoder.tokenize( vertices = item['vertices'], faces = item['faces'], @@ -92,6 +92,7 @@ def generate_codes(self, autoencoder : MeshAutoencoder): ) item['codes'] = codes + self.sort_dataset_keys() print(f"[MeshDataset] Generated codes for {len(self.data)} entrys") def embed_texts(self, transformer : MeshTransformer): @@ -102,6 +103,7 @@ def embed_texts(self, transformer : MeshTransformer): text_embedding_dict = dict(zip(unique_texts, text_embeddings)) for item in self.data: - text_value = item['texts'] - item['text_embeds'] = text_embedding_dict.get(text_value, None) - del item['texts'] \ No newline at end of file + if 'texts' in item: + item['text_embeds'] = text_embedding_dict.get(item['texts'], None) + del item['texts'] + self.sort_dataset_keys() \ No newline at end of file