# Hugging Face Space (status: Sleeping)
| import gradio as gr | |
| import earthview as ev | |
| import utils | |
| import random | |
| import pandas as pd | |
| import os | |
| from itertools import islice | |
# --- Settings ---------------------------------------------------------------
# Number of streamed samples pulled at once and shuffled together.
chunk_size = 100
# The ratings CSV is kept alongside this script.
label_file = os.path.join(os.path.dirname(__file__), "labels.csv")

# --- Streaming source -------------------------------------------------------
# Open the Satellogic dataset in streaming mode and keep one shared iterator.
dataset = ev.load_dataset("satellogic", streaming=True)
data_iter = iter(dataset)
# State of the chunk currently being served; both are (re)filled lazily by
# get_next_image().
shuffled_chunk = []
chunk_iter = None

# --- Label store ------------------------------------------------------------
# Resume from an existing CSV when one is present, otherwise start with an
# empty table that has the expected columns.
labels_df = (
    pd.read_csv(label_file)
    if os.path.exists(label_file)
    else pd.DataFrame(columns=["image_id", "bounds", "rating", "google_maps_link"])
)
def get_next_image():
    """Return the next image that has not been labeled yet.

    Samples are pulled from the shared streaming iterator ``chunk_size`` at a
    time, shuffled within the chunk, and served one by one. Samples whose id
    (the string form of their bounds) already appears in ``labels_df`` are
    skipped.

    Returns:
        A 4-tuple ``(image, image_id, bounds, google_maps_link)``. When no
        data can be served, ``(None, "Dataset exhausted", None, None)``.
    """
    global data_iter, labels_df, shuffled_chunk, chunk_iter
    # Number of times this call restarted the stream. Running out of data
    # again after a restart means nothing unlabeled was found between the
    # restart and the end of the stream, so we stop instead of looping forever.
    resets = 0
    while True:
        # If we don't have a current chunk or it's exhausted, get a new one.
        if not shuffled_chunk or chunk_iter is None:
            chunk = list(islice(data_iter, chunk_size))
            if not chunk:
                if resets >= 1:
                    print("No unlabeled images found in a full pass.")
                    return None, "Dataset exhausted", None, None
                print("Dataset exhausted, resetting iterator.")
                reset_dataset_iterator()
                resets += 1
                chunk = list(islice(data_iter, chunk_size))
                if not chunk:
                    print("Still no data after reset.")
                    return None, "Dataset exhausted", None, None
            random.shuffle(chunk)
            shuffled_chunk = chunk
            chunk_iter = iter(shuffled_chunk)

        try:
            sample = next(chunk_iter)
        except StopIteration:
            # Current chunk is used up; clear the state so the next loop
            # iteration fetches a fresh chunk.
            shuffled_chunk = []
            chunk_iter = None
            continue

        sample = ev.item_to_images("satellogic", sample)
        image = sample["rgb"][0]
        bounds = sample["metadata"]["bounds"]
        google_maps_link = utils.get_google_map_link(sample, "satellogic")
        # The bounds double as the unique key for a sample.
        image_id = str(bounds)

        # Serve the sample unless it has already been rated.
        if labels_df is None or image_id not in labels_df["image_id"].values:
            return image, image_id, bounds, google_maps_link
def rate_image(image_id, bounds, rating):
    """Store a rating for the current image and advance to the next one.

    The rating is appended to ``labels_df`` and the whole table is written
    back to ``label_file`` so progress survives a restart.

    Args:
        image_id: Identifier of the image being rated.
        bounds: Bounds of the image being rated.
        rating: The selected rating, or None when nothing is selected.

    Returns:
        The 4-tuple from ``get_next_image()``; when ``rating`` is None, four
        no-op updates so the displayed image stays as it is.
    """
    global labels_df
    if rating is None:
        # Nothing selected: don't write a row with an empty rating and don't
        # skip the image; leave every output component untouched.
        return gr.update(), gr.update(), gr.update(), gr.update()

    new_row = pd.DataFrame(
        {
            "image_id": [image_id],
            "bounds": [bounds],
            "rating": [rating],
            # The link is not passed to this function, so it is stored empty.
            "google_maps_link": [""],
        }
    )
    labels_df = pd.concat([labels_df, new_row], ignore_index=True)
    labels_df.to_csv(label_file, index=False)
    return get_next_image()
def save_labels_parquet():
    """Export the collected labels to a Parquet file.

    Uses pandas' own Parquet writer; the previous implementation referenced
    ``pa``/``pq`` names that this file never imports.

    Returns:
        ``'labeled_data.parquet'`` (relative to the working directory) after
        writing the file, or None when there are no labels to export.
    """
    if labels_df is None or labels_df.empty:
        return None
    output_path = 'labeled_data.parquet'
    labels_df.to_parquet(output_path)
    return output_path
def reset_dataset_iterator():
    """Reopen the streaming dataset and discard the chunk being served."""
    global data_iter, shuffled_chunk, chunk_iter
    fresh_stream = ev.load_dataset("satellogic", streaming=True)
    data_iter = iter(fresh_stream)
    # Clearing both forces get_next_image() to pull a new chunk.
    shuffled_chunk, chunk_iter = [], None
def load_different_batch():
    """Restart the dataset stream and return the first image served after it.

    Returns:
        The 4-tuple produced by ``get_next_image()``.
    """
    print("Loading a different batch of images...")
    reset_dataset_iterator()
    first_of_batch = get_next_image()
    return first_of_batch
# Gradio interface
# NOTE(review): indentation was lost in the source; which buttons sit inside
# the Row is a best guess — confirm against the intended layout.
with gr.Blocks() as iface:
    image_out = gr.Image(label="Satellite Image")
    # Hidden fields carry the current image's id and bounds back to rate_image.
    image_id_out = gr.Textbox(label="Image ID", visible=False)
    bounds_out = gr.Textbox(label="Bounds", visible=False)
    google_maps_link_out = gr.Textbox(label="Google Maps Link", visible=True)
    rating_radio = gr.Radio(["Cool", "Not Cool"], label="Rating")
    with gr.Row():
        submit_button = gr.Button("Submit Rating")
        different_batch_button = gr.Button("Load Different Batch")
        download_button = gr.Button("Download Labels (Parquet)")
    download_output = gr.File(label="Download Labeled Data")

    # Save the rating, then show the next unlabeled image.
    submit_button.click(
        fn=rate_image,
        inputs=[image_id_out, bounds_out, rating_radio],
        outputs=[image_out, image_id_out, bounds_out, google_maps_link_out],
    )
    # Export the labels and offer the file for download.
    download_button.click(
        fn=save_labels_parquet,
        inputs=[],
        outputs=[download_output],
    )
    # Restart the stream and show the first image it serves.
    different_batch_button.click(
        fn=load_different_batch,
        inputs=[],
        outputs=[image_out, image_id_out, bounds_out, google_maps_link_out],
    )

    # Load the first image once, at start-up.
    initial_image, initial_image_id, initial_bounds, initial_google_maps_link = get_next_image()

    # Set initial values. Compare against None explicitly: truth-testing an
    # image object is ambiguous (array-like images raise on bool()).
    if initial_image is not None:
        iface.load(
            lambda: (initial_image, initial_image_id, initial_bounds, initial_google_maps_link),
            inputs=None,
            outputs=[image_out, image_id_out, bounds_out, google_maps_link_out],
        )

iface.launch(share=True)