Spaces:
Sleeping
Sleeping
| # streamlit_app.py | |
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import torch | |
| from transformers import pipeline | |
| import time | |
| import os | |
| from io import BytesIO # <-- IMPORT BytesIO | |
# --- Page Config (MUST BE FIRST st command) ---
# Configure the browser tab and use the full page width for the two image columns.
st.set_page_config(
    page_title="Depth Blur Studio",
    page_icon="📸",  # was mis-encoded as "๐ธ" (UTF-8 bytes of 📸 decoded as Thai cp874)
    layout="wide",
)
# --- Import Custom Class ---
# PortraitBlurrer is looked up first as ./Portrait/Portrait.py, then as a
# top-level ./PortraitBlurrer.py module. The app cannot run without it, so if
# both imports fail an error is shown and the script run is stopped.
try:
    from Portrait.Portrait import PortraitBlurrer
except ImportError as package_import_error:
    try:
        from PortraitBlurrer import PortraitBlurrer  # type: ignore
    except ImportError as root_import_error:
        # Log both underlying errors: an ImportError raised *inside* one of these
        # modules (e.g. a missing dependency) would otherwise be indistinguishable
        # from the file simply not being there.
        print(f"Import from 'Portrait.Portrait' failed: {package_import_error}")
        print(f"Import from 'PortraitBlurrer' failed: {root_import_error}")
        st.error("Fatal Error: Could not find the PortraitBlurrer class. Please check the file structure and import path.")
        st.stop()  # Stop execution if class can't be found
# --- Model Loading (Cached) ---
@st.cache_resource
def _create_depth_pipeline(device_id):
    """Build the Depth Anything V2 Large depth-estimation pipeline on ``device_id``.

    ``device_id`` follows the transformers ``pipeline(device=...)`` convention used
    by the caller: 0 for the first CUDA GPU, -1 for CPU.

    Wrapped in ``st.cache_resource`` so the model is loaded once and shared across
    Streamlit reruns and sessions instead of being reloaded on every widget
    interaction. Exceptions propagate (and are therefore not cached), so a failed
    load is attempted again on the next rerun.
    """
    print(f"Attempting to load model on device: {'GPU (CUDA)' if device_id == 0 else 'CPU'}")
    # Use default precision (float32)
    depth_pipe = pipeline(task="depth-estimation",
                          model="depth-anything/Depth-Anything-V2-Large-hf",
                          device=device_id)
    print("Depth Anything V2 Large model loaded successfully.")
    return depth_pipe


def load_depth_pipeline():
    """Load (or fetch from the resource cache) the depth estimation pipeline.

    Returns:
        tuple ``(pipeline, device_id)``: ``device_id`` is 0 when CUDA is available,
        otherwise -1 (CPU). ``pipeline`` is None if loading failed; the error is
        printed and the caller is expected to report it in the UI.
    """
    t_device = 0 if torch.cuda.is_available() else -1
    try:
        t_pipe = _create_depth_pipeline(t_device)
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, t_device  # Return None for pipe on error
    return t_pipe, t_device
# Load the depth model (see load_depth_pipeline); pipe is None if loading failed.
pipe, device_used = load_depth_pipeline()

# --- Title and Model Status ---
# Title and info are rendered after the load attempt so they appear even when
# the model failed to load.
st.title("Depth Blur Studio 📸 (Streamlit)")  # emoji was mis-encoded as "๐ธ"
st.markdown(
    "Upload a portrait image. The model will estimate depth and blur the background, keeping the subject sharp."
    "\n*Model: `depth-anything/Depth-Anything-V2-Large-hf`*"
)
st.caption(f"_(Using device: {'GPU (CUDA)' if device_used == 0 else 'CPU'})_")  # Display device info

# Without a model nothing below can work, so halt this script run.
if pipe is None:
    st.error("Error loading depth estimation model. Application cannot proceed.")
    st.stop()
# --- Processing Function ---
def _depth_output_to_pil(depth_output):
    """Extract the depth map from a depth-estimation pipeline result as a PIL Image.

    Accepts either a bare PIL Image or a dict with a ``"depth"`` entry. A non-PIL
    ``"depth"`` value (array-like or torch tensor) is min-max normalized to 8-bit.

    Raises:
        ValueError: if the output has an unexpected format or cannot be converted.
    """
    if isinstance(depth_output, Image.Image):
        return depth_output
    if not (isinstance(depth_output, dict) and "depth" in depth_output):
        raise ValueError(f"Unexpected depth estimation output format: {type(depth_output)}")

    depth_value = depth_output["depth"]
    if isinstance(depth_value, Image.Image):
        return depth_value

    # Fallback for non-PIL depth values; exact type depends on the pipeline version.
    try:
        if isinstance(depth_value, torch.Tensor):
            # np.array() cannot read CUDA tensors directly.
            depth_value = depth_value.detach().cpu().numpy()
        depth_data = np.array(depth_value)
        # Stretch to the full 0-255 range so it can be treated as a grayscale image.
        depth_data = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        return Image.fromarray(depth_data)
    except Exception as conversion_e:
        print(f"Could not convert depth output to PIL Image: {conversion_e}")
        raise ValueError("Depth estimation did not return a usable PIL Image.") from conversion_e


def process_image_blur(pipeline_obj, input_image_pil, max_blur_ksize, depth_thresh, feather_ksize, sharpen_val):
    """
    Estimate depth for an image and blur its background with PortraitBlurrer.

    Args:
        pipeline_obj: callable depth-estimation pipeline (takes a PIL image).
        input_image_pil: source PIL image; converted to RGB if it is in another mode.
        max_blur_ksize, depth_thresh, feather_ksize: passed to PortraitBlurrer as
            ``max_blur``, ``depth_threshold`` and ``feather_strength`` (cast to int).
        sharpen_val: passed to PortraitBlurrer as ``sharpen_strength`` (cast to float).

    Returns:
        On success, a 4-tuple ``(blurred_pil, depth_pil, mask_pil, duration_s)``:
        the RGB result, the refined depth map, the subject mask, and the elapsed
        processing time in seconds.
        On failure, the 3-tuple ``(None, None, None)``; an ``st.error`` has already
        been shown. Callers distinguish the two cases by tuple length.
    """
    print("Processing image...")
    start = time.perf_counter()  # monotonic clock for measuring durations

    # 1. PIL gives RGB; OpenCV / PortraitBlurrer work in BGR. Normalize the mode
    #    first so RGBA / grayscale / palette inputs don't break cvtColor.
    rgb_image = input_image_pil if input_image_pil.mode == "RGB" else input_image_pil.convert("RGB")
    original_bgr_np = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

    # 2. Depth estimation
    try:
        with torch.no_grad():  # Inference only
            depth_output = pipeline_obj(rgb_image)
        depth_image_pil = _depth_output_to_pil(depth_output)
        print("Depth map generated.")
    except Exception as e:
        print(f"Error during depth estimation: {e}")
        st.error(f"Depth estimation failed: {e}")  # Show error in UI
        return None, None, None

    # 3. Blur/sharpen and convert results for display. Blurrer construction and
    #    the conversions are inside the try so any failure is reported in the UI
    #    instead of crashing the app.
    try:
        portrait_blurrer = PortraitBlurrer(
            max_blur=int(max_blur_ksize),
            depth_threshold=int(depth_thresh),
            feather_strength=int(feather_ksize),
            sharpen_strength=float(sharpen_val),
        )
        # process_image returns blurred_bgr, depth_gray, mask_gray
        blurred_bgr_np, refined_depth_np, mask_np = portrait_blurrer.process_image(
            original_bgr_np, depth_image_pil
        )
        blurred_pil = Image.fromarray(cv2.cvtColor(blurred_bgr_np, cv2.COLOR_BGR2RGB))
        # Depth and mask are single-channel arrays; convert directly to PIL.
        depth_pil = Image.fromarray(refined_depth_np)
        mask_pil = Image.fromarray(mask_np)
    except Exception as e:
        print(f"Error during blurring/sharpening: {e}")
        st.error(f"Image processing (blur/sharpen) failed: {e}")  # Show error in UI
        return None, None, None

    processing_duration = time.perf_counter() - start
    print(f"Processing finished in {processing_duration:.2f} seconds.")
    # The success message is displayed by the caller, next to the results.
    return blurred_pil, depth_pil, mask_pil, processing_duration
# --- Initialize Session State --- (Do this early)
# Default value for every session key the app reads. Only keys that are not
# already present are set, so values stored on earlier reruns are preserved.
_SESSION_DEFAULTS = {
    'results': None,                    # tuple (blurred, depth, mask) once processed
    'original_image_pil': None,         # uploaded image converted to RGB
    'processing_error_occurred': False,
    'current_filename': None,
    'last_process_duration': None,      # seconds taken by the last successful run
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Sidebar for Controls ---
# Upload widget, blur parameters and the "Apply Blur" button all live in the sidebar.
with st.sidebar:
    st.title("Controls")
    uploaded_file = st.file_uploader(
        "Upload Portrait Image",
        type=["jpg", "png", "jpeg"],
        label_visibility="collapsed"
    )

    # --- Handle New Upload for Instant Display ---
    if uploaded_file is not None:
        # Identify the upload by name, size and (when this Streamlit version exposes
        # it) its per-upload file_id. Comparing names alone kept showing a stale
        # image when a different file with the same name was uploaded.
        upload_key = (uploaded_file.name, uploaded_file.size, getattr(uploaded_file, "file_id", None))
        if upload_key != st.session_state.get('current_upload_key', None):
            print(f"New file uploaded: {uploaded_file.name}. Loading for display.")
            try:
                # Load the new image immediately so it can be shown before processing.
                st.session_state.original_image_pil = Image.open(uploaded_file).convert("RGB")
                # Clear previous results, error state and duration.
                st.session_state.results = None
                st.session_state.processing_error_occurred = False
                st.session_state.last_process_duration = None
                # Remember which upload is loaded.
                st.session_state.current_filename = uploaded_file.name
                st.session_state.current_upload_key = upload_key
            except Exception as e:
                st.error(f"Error loading image: {e}")
                # Reset all state; with no stored upload key the load is retried on
                # the next rerun while this file remains selected.
                st.session_state.original_image_pil = None
                st.session_state.results = None
                st.session_state.processing_error_occurred = False
                st.session_state.current_filename = None
                st.session_state.current_upload_key = None
                st.session_state.last_process_duration = None
    elif st.session_state.current_filename is not None:
        # The user cleared the uploader (uploaded_file became None): drop everything.
        print("File upload cleared.")
        st.session_state.original_image_pil = None
        st.session_state.results = None
        st.session_state.processing_error_occurred = False
        st.session_state.current_filename = None
        st.session_state.current_upload_key = None
        st.session_state.last_process_duration = None
    # --- End Handle New Upload ---

    st.markdown("---")
    st.markdown("**Adjust Parameters:**")
    # Blur and feather sliders only produce odd values (odd start, step 2),
    # presumably because PortraitBlurrer uses them as kernel sizes — confirm there.
    slider_max_blur = st.slider("Blur Intensity (Kernel Size)", min_value=3, max_value=101, step=2, value=31)
    slider_depth_thr = st.slider("Subject Depth Threshold (Lower=Far Away)", min_value=1, max_value=254, step=1, value=120)
    slider_feather = st.slider("Feathering (Mask Smoothness)", min_value=1, max_value=51, step=2, value=5)
    # Sharpening has no slider; the processing call passes a fixed sharpen value.
    st.markdown("---")

    # Button is disabled until an image has been successfully loaded into session state.
    process_button = st.button(
        "Apply Blur",
        type="primary",
        disabled=(st.session_state.original_image_pil is None)
    )
# --- Main Area for Images ---
# Left column: original upload. Right column: processed result / status.
col1, col2 = st.columns(2)

# --- Handle Processing Trigger ---
if process_button and st.session_state.original_image_pil is None:
    # Safeguard only: the button is disabled while no image is loaded.
    st.error("No image loaded to process.")
elif process_button:
    # Begin each attempt from a clean slate: no error flag, no stale results or timing.
    st.session_state.processing_error_occurred = False
    st.session_state.results = None
    st.session_state.last_process_duration = None

    # The spinner (and any st.error raised during processing) renders in the results column.
    with col2, st.spinner('Applying blur... This may take a moment...'):
        outcome = process_image_blur(
            pipeline_obj=pipe,
            input_image_pil=st.session_state.original_image_pil,
            max_blur_ksize=slider_max_blur,
            depth_thresh=slider_depth_thr,
            feather_ksize=slider_feather,
            sharpen_val=1.0,  # fixed sharpening strength
        )

        # Success yields (blurred, depth, mask, duration); failure yields a 3-tuple of None.
        if outcome is not None and len(outcome) == 4:
            *result_images, elapsed = outcome
            st.session_state.results = tuple(result_images)
            st.session_state.last_process_duration = elapsed
        else:
            st.session_state.results = None
            st.session_state.processing_error_occurred = True
            st.session_state.last_process_duration = None
# --- Display Images based on Session State ---
# Column 1: the original image, or a prompt to upload one.
if st.session_state.original_image_pil is not None:
    col1.image(st.session_state.original_image_pil, caption="Original Image", use_container_width=True)
else:
    col1.markdown("### Upload an image")
    col1.markdown("Use the sidebar controls to upload your portrait.")

# Column 2: the processed result with download and details, when available.
if st.session_state.results is not None:
    blurred_img, depth_img, mask_img = st.session_state.results
    if blurred_img is not None:
        # Keep the success message in the results column with the other result output.
        if st.session_state.last_process_duration is not None:
            col2.success(f"Processing finished in {st.session_state.last_process_duration:.2f} seconds.")
        col2.image(blurred_img, caption="Blurred Background Result", use_container_width=True)

        # Encode the result as PNG bytes for the download button.
        buf = BytesIO()
        blurred_img.save(buf, format="PNG")
        # Strip the source extension so "photo.jpg" downloads as "blurred_photo.png"
        # rather than "blurred_photo.jpg.png".
        download_stem = os.path.splitext(st.session_state.current_filename or "")[0] or "result"
        col2.download_button(
            label="Download Blurred Image",
            data=buf.getvalue(),
            file_name=f"blurred_{download_stem}.png",
            mime="image/png",
        )

        # Depth map and mask go in an expander below the main columns.
        with st.expander("Show Details (Depth Map & Mask)"):
            exp_col1, exp_col2 = st.columns(2)
            exp_col1.image(depth_img, caption="Refined Depth Map", use_container_width=True)
            exp_col2.image(mask_img, caption="Subject Mask", use_container_width=True)
    else:
        # A results tuple without a blurred image is malformed; treat it as an error.
        st.session_state.processing_error_occurred = True
        col2.error("An unexpected issue occurred during processing. Please check logs or try again.")

# Status for column 2: error, ready-to-process, or the default placeholder.
if st.session_state.processing_error_occurred:
    # Fallback notice; a more specific st.error may already have been shown during processing.
    col2.warning("Image processing failed. Check messages above or terminal logs.")
elif st.session_state.original_image_pil is not None and st.session_state.results is None:
    col2.markdown("### Ready to Process")
    col2.markdown("Adjust parameters in the sidebar (if needed) and click **Apply Blur**.")
elif st.session_state.original_image_pil is None:
    col2.markdown("### Results")
    col2.markdown("The processed image and details will appear here after uploading an image and clicking 'Apply Blur'.")