# Hugging Face Space app (Space status when captured: Runtime error)
# Kept first deliberately: the Hugging Face ZeroGPU `spaces` helper is
# presumably meant to be imported before torch touches CUDA.
import spaces

import io
import tempfile
import traceback

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from starvector.data.util import process_and_rasterize_svg
# Set this to True to load both the 8b and 1b checkpoints; when False only
# the 8b checkpoint is loaded.
USE_BOTH_MODELS = True

# Hugging Face Hub repo ids of the selectable checkpoints.
model_name_8b = "starvector/starvector-8b-im2svg"
model_name_1b = "starvector/starvector-1b-im2svg"


def _load_model(repo_id):
    """Load one StarVector checkpoint in float16, move it to the GPU, set eval mode.

    Returns a dict holding the model under 'model' and the image processor
    the checkpoint exposes as ``model.model.processor`` under 'processor'.
    """
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.float16, trust_remote_code=True
    )
    model.cuda()
    model.eval()
    return {'model': model, 'processor': model.model.processor}


# Load models once at startup; keys are the values offered by the UI selector.
models = {'8b': _load_model(model_name_8b)}
if USE_BOTH_MODELS:
    models['1b'] = _load_model(model_name_1b)
# NOTE(review): `spaces` is imported at the top of the file but never used. If
# this Space runs on ZeroGPU hardware, this handler likely needs the
# `@spaces.GPU` decorator; confirm against the Space's hardware settings.
def convert_to_svg(image, model_choice="8b"):
    """Convert an uploaded raster image to SVG with a StarVector model.

    Args:
        image: Filesystem path of the uploaded image (the upload component is
            declared with ``type="filepath"``), or ``None`` if nothing was
            uploaded.
        model_choice: Key into ``models`` selecting the checkpoint ("8b" or
            "1b"). Defaults to "8b" so the handler still works when the UI
            wires only the image input (``USE_BOTH_MODELS = False``).

    Returns:
        A tuple ``(raster_preview, svg_path, status_message)``. ``svg_path``
        names a temporary ``.svg`` file, which is what a ``gr.File`` output
        expects. On any failure the first two items are ``None`` and the
        message starts with ``"Error: "``.
    """
    if image is None:
        return None, None, "Please upload an image first"
    try:
        # Select the model based on user choice
        selected_model = models[model_choice]['model']
        selected_processor = models[model_choice]['processor']

        # Preprocess inside `with` so the uploaded file handle is closed afterwards.
        with Image.open(image) as image_pil:
            image_tensor = selected_processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
        # NOTE(review): Tensor.squeeze(0) only removes dim 0 when its size is 1,
        # so this branch (taken only when dim 0 != 1) leaves the tensor
        # unchanged. Kept as-is pending confirmation of the intended shape handling.
        if not image_tensor.shape[0] == 1:
            image_tensor = image_tensor.squeeze(0)
        batch = {"image": image_tensor}

        # Generate SVG
        raw_svg = selected_model.generate_im2svg(batch, max_length=4000)[0]
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # gr.File outputs take a file path, not an in-memory buffer, so persist
        # the SVG to a temp file that Gradio can serve for download.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".svg", delete=False, encoding="utf-8"
        ) as svg_file:
            svg_file.write(svg)
        return raster_image, svg_file.name, f"Conversion successful using {model_choice} model!"
    except Exception as e:
        # UI boundary: show the failure to the user, but keep the traceback in the logs.
        traceback.print_exc()
        return None, None, f"Error: {str(e)}"
# Assemble the Gradio interface: inputs in the left column, results in the right.
with gr.Blocks(title="StarVector") as demo:
    gr.Markdown("# StarVector")
    gr.Markdown("Upload an image to convert it to SVG format using StarVector model")

    with gr.Row():
        # Left column: upload, optional model picker, action button, examples.
        with gr.Column(scale=1):
            input_image = gr.Image(type="filepath", label="Upload Image")
            if USE_BOTH_MODELS:
                model_choice = gr.Radio(
                    choices=["8b", "1b"],
                    value="8b",
                    label="Select Model",
                    info="Choose between 8b (larger) and 1b (smaller) models",
                )
            convert_btn = gr.Button("Convert to SVG")
            example = gr.Examples(
                examples=[["assets/examples/sample-18.png"]],
                inputs=input_image,
            )

        # Right column: rasterized preview, downloadable SVG, status text.
        with gr.Column(scale=1):
            output_preview = gr.Image(type="pil", label="Rasterized SVG Preview")
            output_file = gr.File(label="Download SVG")
            status = gr.Textbox(label="Status")

    # The model picker exists only when both checkpoints are loaded, so it is
    # wired as a second input only in that configuration.
    inputs = [input_image, model_choice] if USE_BOTH_MODELS else [input_image]
    convert_btn.click(
        fn=convert_to_svg,
        inputs=inputs,
        outputs=[output_preview, output_file, status],
    )

# Start the Gradio server.
demo.launch()