import logging
from typing import Optional

import torch
from comfy_api.input.video_types import VideoInput


def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
    """Return (height, width) of an image tensor shaped [B, H, W, C] or [H, W, C]."""
    if len(image.shape) == 4:
        return image.shape[1], image.shape[2]
    elif len(image.shape) == 3:
        return image.shape[0], image.shape[1]
    else:
        raise ValueError(f"Invalid image tensor shape: {tuple(image.shape)}")


def validate_image_dimensions(
    image: torch.Tensor,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    height, width = get_image_dimensions(image)

    if min_width is not None and width < min_width:
        raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Image height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Image height must be at most {max_height}px, got {height}px")


def validate_image_aspect_ratio(
    image: torch.Tensor,
    min_aspect_ratio: Optional[float] = None,
    max_aspect_ratio: Optional[float] = None,
):
    height, width = get_image_dimensions(image)  # returned as (height, width)
    aspect_ratio = width / height

    if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
        )
    if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
        raise ValueError(
            f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
        )
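
# Illustrative usage (not part of the original module); the 768x512 (W x H)
# image below has an aspect ratio of 1.5:
#
#     img = torch.zeros(1, 512, 768, 3)
#     validate_image_aspect_ratio(img, min_aspect_ratio=1.0, max_aspect_ratio=2.0)  # passes
#     validate_image_aspect_ratio(img, max_aspect_ratio=1.2)  # raises ValueError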


def validate_image_aspect_ratio_range(
    image: torch.Tensor,
    min_ratio: tuple[float, float],  # e.g. (1, 4)
    max_ratio: tuple[float, float],  # e.g. (4, 1)
    *,
    strict: bool = True,             # True -> (min, max); False -> [min, max]
) -> float:
    a1, b1 = min_ratio
    a2, b2 = max_ratio
    if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
        raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
    lo, hi = (a1 / b1), (a2 / b2)
    if lo > hi:
        lo, hi = hi, lo
        a1, b1, a2, b2 = a2, b2, a1, b1  # swap only for error text
    h, w = get_image_dimensions(image)  # returned as (height, width)
    if w <= 0 or h <= 0:
        raise ValueError(f"Invalid image dimensions: {w}x{h}")
    ar = w / h
    ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
    if not ok:
        op = "<" if strict else "≤"
        raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
    return ar
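
# Illustrative usage (not part of the original module): ratios are given as
# (width, height) pairs, e.g. (1, 4) for 1:4 portrait through (4, 1) for
# 4:1 landscape:
#
#     img = torch.zeros(1, 512, 768, 3)                            # ratio 1.5
#     ar = validate_image_aspect_ratio_range(img, (1, 4), (4, 1))  # returns 1.5
#     square = torch.zeros(1, 512, 512, 3)                         # ratio 1.0
#     validate_image_aspect_ratio_range(square, (1, 1), (4, 1), strict=True)  # raises (1.0 not strictly above 1:1)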


def validate_aspect_ratio_closeness(
    start_img: torch.Tensor,
    end_img: torch.Tensor,
    min_rel: float,
    max_rel: float,
    *,
    strict: bool = False,   # True => exclusive, False => inclusive
) -> None:
    h1, w1 = get_image_dimensions(start_img)  # returned as (height, width)
    h2, w2 = get_image_dimensions(end_img)
    if min(w1, h1, w2, h2) <= 0:
        raise ValueError("Invalid image dimensions")
    ar1 = w1 / h1
    ar2 = w2 / h2
    # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
    closeness = max(ar1, ar2) / min(ar1, ar2)
    limit = max(max_rel, 1.0 / min_rel)  # for 0.8..1.25 this is 1.25
    if (closeness >= limit) if strict else (closeness > limit):
        raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")


def validate_video_dimensions(
    video: VideoInput,
    min_width: Optional[int] = None,
    max_width: Optional[int] = None,
    min_height: Optional[int] = None,
    max_height: Optional[int] = None,
):
    try:
        width, height = video.get_dimensions()
    except Exception as e:
        logging.error("Error getting dimensions of video: %s", e)
        return

    if min_width is not None and width < min_width:
        raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
    if max_width is not None and width > max_width:
        raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
    if min_height is not None and height < min_height:
        raise ValueError(
            f"Video height must be at least {min_height}px, got {height}px"
        )
    if max_height is not None and height > max_height:
        raise ValueError(f"Video height must be at most {max_height}px, got {height}px")


def validate_video_duration(
    video: VideoInput,
    min_duration: Optional[float] = None,
    max_duration: Optional[float] = None,
):
    try:
        duration = video.get_duration()
    except Exception as e:
        logging.error("Error getting duration of video: %s", e)
        return

    epsilon = 0.0001
    if min_duration is not None and min_duration - epsilon > duration:
        raise ValueError(
            f"Video duration must be at least {min_duration}s, got {duration}s"
        )
    if max_duration is not None and duration > max_duration + epsilon:
        raise ValueError(
            f"Video duration must be at most {max_duration}s, got {duration}s"
        )
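
# Illustrative usage (not part of the original module): `video` is any
# VideoInput implementation; if its metadata cannot be read, the check is
# logged and skipped rather than raised:
#
#     video: VideoInput = ...  # e.g. provided by an upstream video-loading node
#     validate_video_duration(video, min_duration=1.0, max_duration=10.0)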


def get_number_of_images(images):
    if isinstance(images, torch.Tensor):
        return images.shape[0] if images.ndim >= 4 else 1
    return len(images)
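
# Illustrative usage (not part of the original module):
#
#     get_number_of_images(torch.zeros(8, 512, 512, 3))  # batched [B, H, W, C] tensor -> 8
#     get_number_of_images(torch.zeros(512, 512, 3))     # single [H, W, C] image -> 1
#     get_number_of_images([img_a, img_b])               # any sized container (two image tensors) -> 2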