from typing import Union

import torch
import random
import numpy as np
import cv2
import os

def create_random_mask(batch_size, num_frames, height, width, device, dtype, shape_type=None):
    """
    Create random masks for sketch frames.

    Args:
        batch_size: Batch size
        num_frames: Number of frames to mask
        height, width: Image dimensions
        device: Device for the tensor
        dtype: Data type for the tensor
        shape_type: Shape used for masking ('square', 'circle', 'random').
            If None, one is randomly selected per frame.

    Returns:
        Mask tensor of shape [b, 1, num_frames, height, width], where 0 marks
        areas to mask out (the inverse of the previous implementation).
    """
    # Initialize with ones (unmasked)
    masks = torch.ones(batch_size, 1, num_frames, height, width, device=device, dtype=dtype)
    for b in range(batch_size):
        for f in range(num_frames):
            # Randomly select a shape type per frame if not specified.
            # A local variable keeps the parameter untouched, so each frame
            # gets an independent random choice.
            cur_shape_type = shape_type if shape_type is not None else random.choice(['square', 'circle', 'random'])
            # Draw shapes on a numpy mask, which is easier with cv2
            mask = np.zeros((height, width), dtype=np.float32)
            if cur_shape_type == 'square':
                # Random squares
                num_squares = random.randint(1, 5)
                for _ in range(num_squares):
                    # Random square size (proportional to the image dimensions)
                    max_size = min(height, width)
                    size = random.randint(max_size // 4, max_size)
                    # Random position
                    x = random.randint(0, width - size)
                    y = random.randint(0, height - size)
                    # Draw the square
                    mask[y:y+size, x:x+size] = 1.0
            elif cur_shape_type == 'circle':
                # Random circles
                num_circles = random.randint(1, 5)
                for _ in range(num_circles):
                    # Random radius (proportional to the image dimensions)
                    max_radius = min(height, width) // 2
                    radius = random.randint(max_radius // 4, max_radius)
                    # Random center, kept far enough from the borders
                    center_x = random.randint(radius, width - radius)
                    center_y = random.randint(radius, height - radius)
                    # Draw the filled circle
                    cv2.circle(mask, (center_x, center_y), radius, 1.0, -1)
            elif cur_shape_type == 'random':
                # Create a connected random polygon with cv2
                num_points = random.randint(5, 16)
                points = []
                # Generate random points
                for _ in range(num_points):
                    x = random.randint(0, width - 1)
                    y = random.randint(0, height - 1)
                    points.append([x, y])
                # Convert to a numpy array for cv2
                points = np.array(points, dtype=np.int32)
                # Draw the filled polygon
                cv2.fillPoly(mask, [points], 1.0)
            # Invert the mask: drawn shapes (1) become masked areas (0)
            masks[b, 0, f] = 1.0 - torch.from_numpy(mask).to(device=device, dtype=dtype)
    return masks
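
# Usage sketch (illustrative only; the sizes below are assumptions chosen to
# keep the example small, not values used by the training pipeline):
#
#   masks = create_random_mask(
#       batch_size=2, num_frames=4, height=64, width=64,
#       device=torch.device("cpu"), dtype=torch.float32, shape_type="circle",
#   )
#   assert masks.shape == (2, 1, 4, 64, 64)
#   masked_fraction = (masks == 0).float().mean()  # 0 marks masked-out pixels
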
def extract_img_to_sketch(_sketch_model, _img, model_name="random"):
    """
    Return a sketch in [-1, 1] (white background).
    """
    orig_shape = (_img.shape[-2], _img.shape[-1])
    # Extract the sketch at a fixed 2048x2048 resolution, then resize back.
    # Note: float32 is not a CUDA autocast compute dtype, so this context
    # effectively runs the sketch model in full precision.
    with torch.amp.autocast(dtype=torch.float32, device_type="cuda"):
        reshaped_img = torch.nn.functional.interpolate(_img, (2048, 2048))
        sketch = _sketch_model(reshaped_img, model_name=model_name)
        sketch = torch.nn.functional.interpolate(sketch, orig_shape)
    if sketch.shape[1] == 1:
        # Repeat single-channel sketches to 3 channels
        sketch = sketch.repeat(1, 3, 1, 1)
    return sketch
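
# Example call (hypothetical: `sketch_pool` stands in for whatever module
# implements the sketch-pool interface assumed by this file):
#
#   img = torch.rand(1, 3, 480, 832, device="cuda") * 2.0 - 1.0  # RGB in [-1, 1]
#   sketch = extract_img_to_sketch(sketch_pool, img, model_name="lineart")
#   assert sketch.shape == (1, 3, 480, 832)
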
def video_to_frame_and_sketch(
    sketch_model,
    original_video,
    max_num_preserved_sketch_frames=2,
    max_num_preserved_image_frames=1,
    min_num_preserved_sketch_frames=2,
    min_num_preserved_image_frames=1,
    model_name=None,
    detach_image_and_sketch=False,
    equally_spaced_preserve_sketch=False,
    apply_sketch_mask=False,
    sketch_mask_ratio=0.2,
    sketch_mask_shape=None,
    no_first_sketch: Union[bool, float] = False,
    video_clip_names=None,
    is_flux_sketch_available=None,
    is_evaluation=False,
):
    """
    Args:
        sketch_model: torch.nn.Module, a sketch pool for extracting sketches from images
        original_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
        max_num_preserved_sketch_frames: int, maximum number of preserved sketch frames
        max_num_preserved_image_frames: int, maximum number of preserved image frames
        min_num_preserved_sketch_frames: int, minimum number of preserved sketch frames
        min_num_preserved_image_frames: int, minimum number of preserved image frames
        model_name: str, name of the sketch model. If None, randomly selected from ["lineart", "lineart_anime", "anime2sketch"]. Default: None.
        detach_image_and_sketch: bool, whether to return the sketch condition in separate tensors instead of merging it into the image condition. Default: False.
        equally_spaced_preserve_sketch: bool, whether to preserve sketches at equally spaced intervals. Default: False.
        apply_sketch_mask: bool, whether to apply random masking to sketch frames. Default: False.
        sketch_mask_ratio: float, probability (0-1) of masking the preserved sketch frames of a sample. Default: 0.2.
        sketch_mask_shape: str, shape type for masking ('square', 'circle', 'random'). If None, randomly selected. Default: None.
        no_first_sketch: bool or float, whether to drop the first-frame sketch. A float is treated as the probability of dropping it. Default: False.
        video_clip_names: list of str, clip names used to locate precomputed FLUX sketches on disk. Default: None.
        is_flux_sketch_available: list of bool, whether precomputed FLUX sketches exist for each clip. Default: None.
        is_evaluation: bool, if True, disables random sketch masking and always uses FLUX sketches when available. Default: False.

    Returns:
        masked_condition_video: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
        preserved_condition_mask: torch.Tensor, shape=(batch_size, num_frames)
        masked_condition_video_sketch: torch.Tensor with the same shape as masked_condition_video, or None if detach_image_and_sketch=False
        preserved_condition_mask_sketch: torch.Tensor with the same shape as preserved_condition_mask, or None if detach_image_and_sketch=False
        full_sketch_frames: torch.Tensor, shape=(batch_size, num_channels, num_frames, height, width)
        sketch_local_mask: torch.Tensor, shape=(batch_size, 1, num_frames, height, width), or None if apply_sketch_mask=False
        cur_model_name: str, name of the sketch model actually used
    """
    video_shape = original_video.shape
    video_dtype = original_video.dtype
    video_device = original_video.device
    if min_num_preserved_sketch_frames is None or min_num_preserved_sketch_frames < 2:
        min_num_preserved_sketch_frames = 2  # Minimum: 2 (the first and the last frames)
    num_preserved_sketch_frames = random.randint(min_num_preserved_sketch_frames, max_num_preserved_sketch_frames)
    num_preserved_sketch_frames = min(num_preserved_sketch_frames, video_shape[2])
    if video_clip_names is not None and is_flux_sketch_available is not None:
        if is_flux_sketch_available[0]:
            # Precomputed FLUX sketches cover only the first and last frames
            num_preserved_sketch_frames = 2
    if isinstance(no_first_sketch, float):
        # Interpret a float as the probability of dropping the first-frame sketch
        no_first_sketch = random.random() < no_first_sketch
    if equally_spaced_preserve_sketch:
        preserved_sketch_indices = torch.linspace(0, video_shape[2] - 1, num_preserved_sketch_frames).long().tolist()
        if no_first_sketch:
            preserved_sketch_indices = preserved_sketch_indices[1:]
    else:
        # Always include the first and last frames (unless the first sketch is dropped)
        if no_first_sketch:
            preserved_sketch_indices = [video_shape[2] - 1]
        else:
            preserved_sketch_indices = [0, video_shape[2] - 1]
        # If we need more frames than just the first and last
        if num_preserved_sketch_frames > 2 and video_shape[2] > 4:
            # Valid candidates exclude the first and last frames and their adjacent frames
            candidates = set(range(2, video_shape[2] - 2))
            # Determine how many additional frames to select
            additional_frames_needed = min(num_preserved_sketch_frames - 2, len(candidates))
            # Keep selecting frames until we have enough or run out of candidates
            additional_indices = []
            while len(additional_indices) < additional_frames_needed and candidates:
                # Select a random candidate
                idx = random.choice(list(candidates))
                additional_indices.append(idx)
                # Remove the selected index and its neighbors to avoid clustered sketch frames
                candidates.remove(idx)
                candidates.discard(idx - 1)
                candidates.discard(idx + 1)
            preserved_sketch_indices.extend(additional_indices)
            preserved_sketch_indices.sort()
    # The indices to preserve are now fixed; the code below relies only on
    # these indices, not on the number of preserved frames.
    preserved_image_indices = [0]
    if max_num_preserved_image_frames is not None and max_num_preserved_image_frames > 1:
        max_num_preserved_image_frames -= 1  # Frame 0 is always preserved
        if min_num_preserved_image_frames is None or min_num_preserved_image_frames < 1:
            min_num_preserved_image_frames = 1
        min_num_preserved_image_frames -= 1
        other_indices = torch.tensor([i for i in range(video_shape[2]) if i not in preserved_sketch_indices])
        max_num_preserved_image_frames = min(max_num_preserved_image_frames, len(other_indices))
        min_num_preserved_image_frames = min(min_num_preserved_image_frames, max_num_preserved_image_frames)
        num_preserved_image_frames = random.randint(min_num_preserved_image_frames, max_num_preserved_image_frames)
        other_indices = other_indices[torch.randperm(len(other_indices))]
        if num_preserved_image_frames > 0:
            preserved_image_indices.extend(other_indices[:num_preserved_image_frames].tolist())
    preserved_condition_mask = torch.zeros(size=(video_shape[0], video_shape[2]), dtype=video_dtype, device=video_device)  # [b, t]
    masked_condition_video = torch.zeros_like(original_video)  # [b, c, t, h, w]
    full_sketch_frames = torch.zeros_like(original_video)  # [b, c, t, h, w]
    if detach_image_and_sketch:
        preserved_condition_mask_sketch = torch.zeros_like(preserved_condition_mask)
        masked_condition_video_sketch = torch.zeros_like(masked_condition_video)
        if 0 not in preserved_sketch_indices and not no_first_sketch:
            preserved_sketch_indices.append(0)
    else:
        preserved_condition_mask_sketch = None
        masked_condition_video_sketch = None
    for _idx in preserved_image_indices:
        preserved_condition_mask[:, _idx] = 1.0
        masked_condition_video[:, :, _idx, :, :] = original_video[:, :, _idx, :, :]
    # Set up sketch_local_mask if masking is applied
    sketch_local_mask = None
    if apply_sketch_mask:
        # Create a full-sized mask initialized to all ones (unmasked)
        sketch_local_mask = torch.ones(
            video_shape[0], video_shape[2], video_shape[3], video_shape[4],
            device=video_device,
            dtype=video_dtype
        ).unsqueeze(1)  # Add a channel dimension to get [b, 1, t, h, w]
        # sketch_mask_ratio is the probability of masking this sample's sketch frames
        if not is_evaluation and random.random() < sketch_mask_ratio:
            # Apply random masking to the preserved sketch frames
            for i, frame_idx in enumerate(preserved_sketch_indices):
                if i == 0:
                    # Leave the first preserved sketch frame unmasked
                    continue
                # Create a mask for this preserved frame only
                frame_masks = create_random_mask(
                    batch_size=video_shape[0],
                    num_frames=1,  # Just one frame at a time
                    height=video_shape[3],
                    width=video_shape[4],
                    device=video_device,
                    dtype=video_dtype,
                    shape_type=sketch_mask_shape
                )
                # Set the mask for this preserved frame
                sketch_local_mask[:, :, frame_idx:frame_idx+1, :, :] = frame_masks
    # Produce sketches for the preserved frames. Sketches are either
    # 1) computed by the sketch pool or 2) loaded from the FLUX sketch directory.
    if is_flux_sketch_available is not None and is_flux_sketch_available[0]:
        should_use_flux_sketch = True if is_evaluation else random.random() < 0.75
    else:
        should_use_flux_sketch = False
    if should_use_flux_sketch:
        cur_model_name = "flux"
    elif model_name is None:
        cur_model_name = random.choice(["lineart", "lineart_anime", "anime2sketch"])
    else:
        cur_model_name = model_name
    for _idx in preserved_sketch_indices:
        sketch_frame = None
        if should_use_flux_sketch:
            # Load a precomputed FLUX sketch
            sketch_path = f"/group/40005/gzhiwang/iclora/linearts/{video_clip_names[0]}/{_idx}.lineart.png"
            print(f"Loading FLUX sketch from {sketch_path}...")
            if os.path.exists(sketch_path):
                sketch_frame = cv2.imread(sketch_path)
                sketch_frame = cv2.cvtColor(sketch_frame, cv2.COLOR_BGR2RGB)
                # Resize to the video frame size
                sketch_frame = cv2.resize(sketch_frame, (video_shape[4], video_shape[3]))
                sketch_frame = torch.from_numpy(sketch_frame).to(video_device, dtype=video_dtype)
                # Normalize to [-1, 1]
                sketch_frame = sketch_frame / 255.0 * 2.0 - 1.0
                # HWC -> CHW, then add a batch dimension
                sketch_frame = sketch_frame.permute(2, 0, 1)
                sketch_frame = sketch_frame.unsqueeze(0)
            else:
                print(f"FLUX sketch path {sketch_path} does not exist. Falling back to the sketch pool.")
        if sketch_frame is None:
            # Compute the sketch with the sketch pool
            sketch_frame = extract_img_to_sketch(
                sketch_model, original_video[:, :, _idx, :, :].float(),
                model_name=cur_model_name).to(video_device, dtype=video_dtype)
        # Convert the white background (from the sketch pool or the FLUX sketch
        # files) to a black background (for training)
        sketch_frame = -torch.clip(sketch_frame, -1, 1)
        full_sketch_frames[:, :, _idx, :, :] = sketch_frame
    if len(preserved_sketch_indices) > 0:
        _mask_to_add = preserved_condition_mask_sketch if detach_image_and_sketch else preserved_condition_mask
        _video_to_add = masked_condition_video_sketch if detach_image_and_sketch else masked_condition_video
        if not detach_image_and_sketch:
            # Skip the first preserved index, which is already kept as an image condition
            preserved_sketch_indices = preserved_sketch_indices[1:]
        # Apply masking to the sketch frames if required
        if apply_sketch_mask and sketch_local_mask is not None:
            # sketch_local_mask: [b, 1, t, h, w]; masked-out areas are filled with -1 (black)
            for _idx in preserved_sketch_indices:
                _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
                _video_to_add[:, :, _idx, :, :] = torch.where(sketch_local_mask[:, 0:1, _idx, :, :] == 0, -1.0, full_sketch_frames[:, :, _idx, :, :])
        else:
            for _idx in preserved_sketch_indices:
                # Sketch frames are tagged with -1.0 in the merged condition mask
                _mask_to_add[:, _idx] = 1.0 if detach_image_and_sketch else -1.0
                _video_to_add[:, :, _idx, :, :] = full_sketch_frames[:, :, _idx, :, :]
    return masked_condition_video, preserved_condition_mask, masked_condition_video_sketch, preserved_condition_mask_sketch, full_sketch_frames, sketch_local_mask, cur_model_name
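
# Minimal end-to-end sketch (illustrative; `sketch_pool` is a stand-in for the
# sketch-model module this file assumes, and the tensor sizes are arbitrary):
#
#   video = torch.rand(1, 3, 17, 480, 832, device="cuda") * 2.0 - 1.0
#   (cond_video, cond_mask, cond_video_sketch, cond_mask_sketch,
#    sketches, local_mask, used_model) = video_to_frame_and_sketch(
#       sketch_pool, video,
#       max_num_preserved_sketch_frames=4,
#       apply_sketch_mask=True,
#   )
#   # cond_mask is 1.0 at preserved image frames and -1.0 at merged sketch frames.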