Spaces:
Runtime error
Runtime error
| import argparse | |
| import json | |
| import os | |
| import numpy as np | |
| from PIL import Image | |
| try: | |
| from sklearn.cluster import KMeans # type: ignore[import] | |
| except ImportError: | |
| print("Please install sklearn to use this script.") | |
| exit(1) | |
# Folder holding the per-frame camera JSON files and their matching PNG images.
subfolder = "/path/to/your/dataset"
output_file = os.path.join(subfolder, "transforms.json")

# One dictionary per frame that has both a camera JSON file and an image.
frames = []

# Sorted so that frame indices are reproducible from run to run.
for file in sorted(os.listdir(subfolder)):
    if not file.endswith(".json"):
        continue
    # transforms.json is this script's own output (it is overwritten below),
    # so it can never be a per-frame camera file; skip it on re-runs.
    if file == os.path.basename(output_file):
        continue

    # Swap only the trailing extension; str.replace would also rewrite any
    # ".json" that happens to occur earlier in the file name.
    image_file = os.path.splitext(file)[0] + ".png"
    image_path = os.path.join(subfolder, image_file)
    # Check for the image before parsing the JSON so that unrelated JSON
    # files in the folder are skipped instead of being loaded.
    if not os.path.exists(image_path):
        print(f"Image file not found for {file}, skipping...")
        continue

    # Camera parameters for this frame; the keys "K" and "c2w" are read below.
    json_path = os.path.join(subfolder, file)
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Only the image dimensions are needed.
    with Image.open(image_path) as img:
        w, h = img.size

    # Scale the intrinsics by the image size.
    # NOTE(review): this presumes K is stored in normalized image
    # coordinates -- confirm against the dataset format.
    K = data["K"]
    fx = K[0][0] * w
    fy = K[1][1] * h
    cx = K[0][2] * w
    cy = K[1][2] * h

    # Camera-to-world matrix; negate columns 1 and 2 (the camera Y and Z
    # axes) to adjust for the OpenGL convention.
    transform_matrix = np.array(data["c2w"])
    transform_matrix[..., [1, 2]] *= -1

    frames.append(
        {
            "fl_x": fx,
            "fl_y": fy,
            "cx": cx,
            "cy": cy,
            "w": w,
            "h": h,
            "file_path": f"./{os.path.relpath(image_path, subfolder)}",
            "transform_matrix": transform_matrix.tolist(),
        }
    )

# Assemble and write the transforms.json file.
transforms_data = {"orientation_override": "none", "frames": frames}
with open(output_file, "w") as f:
    json.dump(transforms_data, f, indent=4)
print(f"transforms.json generated at {output_file}")
def _camera_feature(frame):
    """Return the 6-D clustering feature of one frame: position + unit direction."""
    transform_matrix = np.asarray(frame["transform_matrix"], dtype=float)
    position = transform_matrix[:3, 3]  # 3D camera position (translation column)
    axis = transform_matrix[:3, 2]  # third rotation column, used as direction
    norm = np.linalg.norm(axis)
    if norm == 0:
        # Dividing by zero would silently put NaNs into the clustering input.
        raise ValueError("transform_matrix has a zero-length direction column")
    return np.concatenate([position, axis / norm])


# Train-test split function using K-means clustering with stride
def create_train_test_split(frames, n, output_path, stride):
    """Write a train/test split of ``frames`` to ``output_path`` as JSON.

    The frames are clustered with K-means (``n`` clusters, fixed random
    state) on their camera position and normalized direction. For each
    cluster centre the closest frame becomes a training frame. Every
    ``stride``-th of the remaining frames, in index order, becomes a test
    frame.

    Args:
        frames: Sequence of dicts, each with a "transform_matrix" entry
            from which rows 0-2 of columns 2 and 3 are read.
        n: Number of clusters, i.e. the requested number of training frames.
            Two centres can share the same closest frame, in which case
            fewer than ``n`` distinct training ids are returned.
        output_path: Path of the JSON file to write.
        stride: Positive step used when sampling test frames from the
            frames that were not picked for training.

    Returns:
        The dict written to disk, with keys "train_ids" and "test_ids".

    Raises:
        ValueError: If ``stride`` is not positive, ``n`` is not within
            ``1..len(frames)``, or a frame has a zero-length direction.
    """
    # Validate up front: a zero stride would otherwise surface as a cryptic
    # slice error and a negative one would silently reverse the selection.
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride}")
    if not 1 <= n <= len(frames):
        raise ValueError(
            f"n must be between 1 and the number of frames ({len(frames)}), got {n}"
        )

    # Prepare the data for K-means: one 6-D row per frame.
    positions = np.array([_camera_feature(frame) for frame in frames])

    # Fixed random_state keeps the split reproducible.
    kmeans = KMeans(n_clusters=n, random_state=42)
    kmeans.fit(positions)

    # Closest frame to each cluster centre; the set removes duplicates when
    # two centres resolve to the same frame.
    train_ids = sorted(
        {
            int(np.argmin(np.linalg.norm(positions - center, axis=1)))
            for center in kmeans.cluster_centers_
        }
    )

    # Remaining indices, in order, thinned by the stride.
    train_set = set(train_ids)
    remaining_indices = [idx for idx in range(len(frames)) if idx not in train_set]
    test_ids = remaining_indices[::stride]

    split_data = {"train_ids": train_ids, "test_ids": test_ids}
    with open(output_path, "w") as f:
        json.dump(split_data, f, indent=4)
    print(f"Train-test split file generated at {output_path}")
    return split_data
# Command-line entry point: parse the options and write the train-test split
# for the frames collected by the module-level code earlier in this file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate train-test split JSON file using K-means clustering."
    )
    parser.add_argument(
        "--n",
        type=int,
        required=True,
        help="Number of frames to include in the training set.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=1,
        # The stride IS applied: it thins the frames left over after the
        # training frames have been picked.
        help="Stride for selecting test frames from the frames not used for training.",
    )
    args = parser.parse_args()

    # Reject unusable values here so the user gets a usage message instead
    # of a traceback from the clustering or slicing code.
    if args.n < 1:
        parser.error("--n must be a positive integer")
    if args.n > len(frames):
        parser.error(
            f"--n ({args.n}) exceeds the number of available frames ({len(frames)})"
        )
    if args.stride < 1:
        parser.error("--stride must be a positive integer")

    # Create train-test split; the file name records the requested n.
    train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
    create_train_test_split(frames, args.n, train_test_split_path, args.stride)