File size: 1,886 Bytes
b39b584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""Example of batch processing multiple images with PANDORA."""

import os
from pathlib import Path

from src.pandora_removal import PandoraRemoval, PandoraConfig


def main(dataset_path="path/to/your/dataset", output_dir="output/batch_results"):
    """Demonstrate batch processing of multiple images with PANDORA.

    Expected dataset layout::

        dataset/
          ├── Images/
          │   ├── 001.jpg
          │   ├── 002.jpg
          │   └── ...
          └── Masks/
              ├── 001.png
              ├── 002.png
              └── ...

    Args:
        dataset_path: Root directory that must contain ``Images`` and
            ``Masks`` sub-directories. The default is a placeholder that
            should be replaced with a real dataset location.
        output_dir: Output directory handed to ``PandoraRemoval.batch_process``.

    Returns:
        0 on success, 1 if either input directory is missing, so the value
        can be used directly as a process exit status.
    """
    dataset_root = Path(dataset_path)
    images_dir = dataset_root / "Images"
    masks_dir = dataset_root / "Masks"

    # Validate the inputs before calling load_model(): model loading is
    # presumably the expensive step, so a wrong dataset_path should fail fast
    # instead of after the model is loaded.
    # is_dir() rather than exists(): a regular file named "Images"/"Masks"
    # is not usable as an input directory.
    for label, directory in (("Images", images_dir), ("Masks", masks_dir)):
        if not directory.is_dir():
            print(f"❌ {label} directory not found: {directory}")
            print("Please update the dataset_path variable.")
            return 1

    # Configure the model
    config = PandoraConfig(
        model_path="stabilityai/stable-diffusion-2-1",
        device="cuda",
        max_steps=50,
        guidance_scale_ladg=7.5,
    )

    # Initialize and load model
    print("Initializing PANDORA model...")
    model = PandoraRemoval(config=config)
    model.load_model()
    print("✓ Model loaded successfully!")

    # Process batch
    print(f"\nProcessing images from: {images_dir}")
    print(f"Using masks from: {masks_dir}")
    print(f"Saving results to: {output_dir}\n")

    # batch_process previously received plain strings; keep passing str so
    # the callee's path handling is unaffected by the switch to pathlib.
    model.batch_process(
        images_dir=str(images_dir),
        masks_dir=str(masks_dir),
        output_dir=output_dir,
        border_size=17,
    )

    print("\n✓ Batch processing complete!")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a missing dataset yields a non-zero exit code.
    raise SystemExit(main())