davanstrien HF Staff commited on
Commit
76273a4
Β·
1 Parent(s): ddf6b7f

Add PaddleOCR-VL script for document processing with vLLM support

Browse files
Files changed (1) hide show
  1. paddleocr-vl.py +676 -0
paddleocr-vl.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.11"
3
+ # dependencies = [
4
+ # "datasets",
5
+ # "huggingface-hub[hf_transfer]",
6
+ # "pillow",
7
+ # "vllm",
8
+ # "tqdm",
9
+ # "toolz",
10
+ # "torch",
11
+ # ]
12
+ #
13
+ # [[tool.uv.index]]
14
+ # url = "https://wheels.vllm.ai/nightly"
15
+ #
16
+ # [tool.uv]
17
+ # prerelease = "allow"
18
+ # ///
19
+
20
+ """
21
+ Convert document images to text/tables/formulas using PaddleOCR-VL with vLLM.
22
+
23
+ PaddleOCR-VL is a compact 0.9B OCR model with task-specific capabilities for
24
+ document parsing. It combines a NaViT-style dynamic resolution visual encoder
25
+ with the ERNIE-4.5-0.3B language model for accurate element recognition.
26
+
27
+ Features:
28
+ - 🎯 Ultra-compact: Only 0.9B parameters (smallest OCR model)
29
+ - πŸ“ OCR mode: General text extraction to markdown
30
+ - πŸ“Š Table mode: HTML table recognition and extraction
31
+ - πŸ“ Formula mode: LaTeX mathematical notation
32
+ - πŸ“ˆ Chart mode: Structured chart analysis
33
+ - 🌍 Multilingual support
34
+ - ⚑ Fast initialization due to small size
35
+ - πŸ”§ Based on ERNIE-4.5 (different from Qwen-based models)
36
+
37
+ Model: PaddlePaddle/PaddleOCR-VL
38
+ vLLM: Requires nightly build for full support
39
+ """
40
+
41
+ import argparse
42
+ import base64
43
+ import io
44
+ import json
45
+ import logging
46
+ import math
47
+ import os
48
+ import sys
49
+ from typing import Any, Dict, List, Union
50
+ from datetime import datetime
51
+
52
+ import torch
53
+ from datasets import load_dataset
54
+ from huggingface_hub import DatasetCard, login
55
+ from PIL import Image
56
+ from toolz import partition_all
57
+ from tqdm.auto import tqdm
58
+ from vllm import LLM, SamplingParams
59
+
60
+ logging.basicConfig(level=logging.INFO)
61
+ logger = logging.getLogger(__name__)
62
+
63
+
64
+ # Task mode configurations from official PaddleOCR-VL documentation
65
+ TASK_MODES = {
66
+ "ocr": "OCR:",
67
+ "table": "Table Recognition:",
68
+ "formula": "Formula Recognition:",
69
+ "chart": "Chart Recognition:",
70
+ }
71
+
72
+ # Task descriptions for dataset card
73
+ TASK_DESCRIPTIONS = {
74
+ "ocr": "General text extraction to markdown format",
75
+ "table": "Table extraction to HTML format",
76
+ "formula": "Mathematical formula recognition to LaTeX",
77
+ "chart": "Chart and diagram analysis",
78
+ }
79
+
80
+
81
+ def check_cuda_availability():
82
+ """Check if CUDA is available and exit if not."""
83
+ if not torch.cuda.is_available():
84
+ logger.error("CUDA is not available. This script requires a GPU.")
85
+ logger.error("Please run on a machine with a CUDA-capable GPU.")
86
+ sys.exit(1)
87
+ else:
88
+ logger.info(f"CUDA is available. GPU: {torch.cuda.get_device_name(0)}")
89
+
90
+
91
+ def smart_resize(
92
+ height: int,
93
+ width: int,
94
+ factor: int = 28,
95
+ min_pixels: int = 28 * 28 * 130,
96
+ max_pixels: int = 28 * 28 * 1280,
97
+ ) -> tuple[int, int]:
98
+ """
99
+ PaddleOCR-VL's intelligent resize logic.
100
+
101
+ Rescales the image so that:
102
+ 1. Both dimensions are divisible by 'factor' (28)
103
+ 2. Total pixels are within [min_pixels, max_pixels]
104
+ 3. Aspect ratio is maintained as closely as possible
105
+
106
+ Args:
107
+ height: Original image height
108
+ width: Original image width
109
+ factor: Dimension divisibility factor (default: 28)
110
+ min_pixels: Minimum total pixels (default: 100,880)
111
+ max_pixels: Maximum total pixels (default: 1,003,520)
112
+
113
+ Returns:
114
+ Tuple of (new_height, new_width)
115
+ """
116
+ if height < factor:
117
+ width = round((width * factor) / height)
118
+ height = factor
119
+
120
+ if width < factor:
121
+ height = round((height * factor) / width)
122
+ width = factor
123
+
124
+ if max(height, width) / min(height, width) > 200:
125
+ logger.warning(
126
+ f"Extreme aspect ratio detected: {max(height, width) / min(height, width):.1f}"
127
+ )
128
+ # Continue anyway, but warn about potential issues
129
+
130
+ h_bar = round(height / factor) * factor
131
+ w_bar = round(width / factor) * factor
132
+
133
+ if h_bar * w_bar > max_pixels:
134
+ beta = math.sqrt((height * width) / max_pixels)
135
+ h_bar = math.floor(height / beta / factor) * factor
136
+ w_bar = math.floor(width / beta / factor) * factor
137
+ elif h_bar * w_bar < min_pixels:
138
+ beta = math.sqrt(min_pixels / (height * width))
139
+ h_bar = math.ceil(height * beta / factor) * factor
140
+ w_bar = math.ceil(width * beta / factor) * factor
141
+
142
+ return h_bar, w_bar
143
+
144
+
145
+ def make_ocr_message(
146
+ image: Union[Image.Image, Dict[str, Any], str],
147
+ task_mode: str = "ocr",
148
+ apply_smart_resize: bool = True,
149
+ ) -> List[Dict]:
150
+ """
151
+ Create chat message for PaddleOCR-VL processing.
152
+
153
+ PaddleOCR-VL expects a specific format with the task prefix after the image.
154
+ """
155
+ # Convert to PIL Image if needed
156
+ if isinstance(image, Image.Image):
157
+ pil_img = image
158
+ elif isinstance(image, dict) and "bytes" in image:
159
+ pil_img = Image.open(io.BytesIO(image["bytes"]))
160
+ elif isinstance(image, str):
161
+ pil_img = Image.open(image)
162
+ else:
163
+ raise ValueError(f"Unsupported image type: {type(image)}")
164
+
165
+ # Convert to RGB
166
+ pil_img = pil_img.convert("RGB")
167
+
168
+ # Apply smart resize if requested
169
+ if apply_smart_resize:
170
+ original_size = pil_img.size
171
+ new_width, new_height = smart_resize(pil_img.height, pil_img.width)
172
+ if (new_width, new_height) != (pil_img.width, pil_img.height):
173
+ pil_img = pil_img.resize((new_width, new_height), Image.Resampling.LANCZOS)
174
+ logger.debug(f"Resized image from {original_size} to {pil_img.size}")
175
+
176
+ # Convert to base64 data URI
177
+ buf = io.BytesIO()
178
+ pil_img.save(buf, format="PNG")
179
+ data_uri = f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"
180
+
181
+ # PaddleOCR-VL message format: image first, then task prefix
182
+ return [
183
+ {
184
+ "role": "user",
185
+ "content": [
186
+ {"type": "image_url", "image_url": {"url": data_uri}},
187
+ {"type": "text", "text": TASK_MODES[task_mode]},
188
+ ],
189
+ }
190
+ ]
191
+
192
+
193
+ def create_dataset_card(
194
+ source_dataset: str,
195
+ model: str,
196
+ task_mode: str,
197
+ num_samples: int,
198
+ processing_time: str,
199
+ batch_size: int,
200
+ max_model_len: int,
201
+ max_tokens: int,
202
+ gpu_memory_utilization: float,
203
+ temperature: float,
204
+ apply_smart_resize: bool,
205
+ image_column: str = "image",
206
+ split: str = "train",
207
+ ) -> str:
208
+ """Create a dataset card documenting the OCR process."""
209
+ task_description = TASK_DESCRIPTIONS[task_mode]
210
+
211
+ return f"""---
212
+ tags:
213
+ - ocr
214
+ - document-processing
215
+ - paddleocr-vl
216
+ - {task_mode}
217
+ - uv-script
218
+ - generated
219
+ ---
220
+
221
+ # Document Processing using PaddleOCR-VL ({task_mode.upper()} mode)
222
+
223
+ This dataset contains {task_mode.upper()} results from images in [{source_dataset}](https://huggingface.co/datasets/{source_dataset}) using PaddleOCR-VL, an ultra-compact 0.9B OCR model.
224
+
225
+ ## Processing Details
226
+
227
+ - **Source Dataset**: [{source_dataset}](https://huggingface.co/datasets/{source_dataset})
228
+ - **Model**: [{model}](https://huggingface.co/{model})
229
+ - **Task Mode**: `{task_mode}` - {task_description}
230
+ - **Number of Samples**: {num_samples:,}
231
+ - **Processing Time**: {processing_time}
232
+ - **Processing Date**: {datetime.now().strftime("%Y-%m-%d %H:%M UTC")}
233
+
234
+ ### Configuration
235
+
236
+ - **Image Column**: `{image_column}`
237
+ - **Output Column**: `paddleocr_{task_mode}`
238
+ - **Dataset Split**: `{split}`
239
+ - **Batch Size**: {batch_size}
240
+ - **Smart Resize**: {"Enabled" if apply_smart_resize else "Disabled"}
241
+ - **Max Model Length**: {max_model_len:,} tokens
242
+ - **Max Output Tokens**: {max_tokens:,}
243
+ - **Temperature**: {temperature}
244
+ - **GPU Memory Utilization**: {gpu_memory_utilization:.1%}
245
+
246
+ ## Model Information
247
+
248
+ PaddleOCR-VL is a state-of-the-art, resource-efficient model tailored for document parsing:
249
+ - 🎯 **Ultra-compact** - Only 0.9B parameters (smallest OCR model)
250
+ - πŸ“ **OCR mode** - General text extraction
251
+ - πŸ“Š **Table mode** - HTML table recognition
252
+ - πŸ“ **Formula mode** - LaTeX mathematical notation
253
+ - πŸ“ˆ **Chart mode** - Structured chart analysis
254
+ - 🌍 **Multilingual** - Support for multiple languages
255
+ - ⚑ **Fast** - Quick initialization and inference
256
+ - πŸ”§ **ERNIE-4.5 based** - Different architecture from Qwen models
257
+
258
+ ### Task Modes
259
+
260
+ - **OCR**: Extract text content to markdown format
261
+ - **Table Recognition**: Extract tables to HTML format
262
+ - **Formula Recognition**: Extract mathematical formulas to LaTeX
263
+ - **Chart Recognition**: Analyze and describe charts/diagrams
264
+
265
+ ## Dataset Structure
266
+
267
+ The dataset contains all original columns plus:
268
+ - `paddleocr_{task_mode}`: The extracted content based on task mode
269
+ - `inference_info`: JSON list tracking all OCR models applied to this dataset
270
+
271
+ ## Usage
272
+
273
+ ```python
274
+ from datasets import load_dataset
275
+ import json
276
+
277
+ # Load the dataset
278
+ dataset = load_dataset("{{output_dataset_id}}", split="{split}")
279
+
280
+ # Access the extracted content
281
+ for example in dataset:
282
+ print(example["paddleocr_{task_mode}"])
283
+ break
284
+
285
+ # View all OCR models applied to this dataset
286
+ inference_info = json.loads(dataset[0]["inference_info"])
287
+ for info in inference_info:
288
+ print(f"Task: {{info['task_mode']}} - Model: {{info['model_id']}}")
289
+ ```
290
+
291
+ ## Reproduction
292
+
293
+ This dataset was generated using the [uv-scripts/ocr](https://huggingface.co/datasets/uv-scripts/ocr) PaddleOCR-VL script:
294
+
295
+ ```bash
296
+ uv run https://huggingface.co/datasets/uv-scripts/ocr/raw/main/paddleocr-vl.py \\
297
+ {source_dataset} \\
298
+ <output-dataset> \\
299
+ --task-mode {task_mode} \\
300
+ --image-column {image_column} \\
301
+ --batch-size {batch_size} \\
302
+ --max-model-len {max_model_len} \\
303
+ --max-tokens {max_tokens} \\
304
+ --gpu-memory-utilization {gpu_memory_utilization}
305
+ ```
306
+
307
+ ## Performance
308
+
309
+ - **Model Size**: 0.9B parameters (smallest among OCR models)
310
+ - **Processing Speed**: ~{num_samples / (float(processing_time.split()[0]) * 60):.2f} images/second
311
+ - **Architecture**: NaViT visual encoder + ERNIE-4.5-0.3B language model
312
+
313
+ Generated with πŸ€– [UV Scripts](https://huggingface.co/uv-scripts)
314
+ """
315
+
316
+
317
+ def main(
318
+ input_dataset: str,
319
+ output_dataset: str,
320
+ image_column: str = "image",
321
+ batch_size: int = 16,
322
+ task_mode: str = "ocr",
323
+ max_model_len: int = 8192,
324
+ max_tokens: int = 4096,
325
+ temperature: float = 0.0,
326
+ gpu_memory_utilization: float = 0.8,
327
+ apply_smart_resize: bool = True,
328
+ hf_token: str = None,
329
+ split: str = "train",
330
+ max_samples: int = None,
331
+ private: bool = False,
332
+ shuffle: bool = False,
333
+ seed: int = 42,
334
+ output_column: str = None,
335
+ ):
336
+ """Process images from HF dataset through PaddleOCR-VL model."""
337
+
338
+ # Check CUDA availability first
339
+ check_cuda_availability()
340
+
341
+ # Track processing start time
342
+ start_time = datetime.now()
343
+
344
+ # Enable HF_TRANSFER for faster downloads
345
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
346
+
347
+ # Login to HF if token provided
348
+ HF_TOKEN = hf_token or os.environ.get("HF_TOKEN")
349
+ if HF_TOKEN:
350
+ login(token=HF_TOKEN)
351
+
352
+ # Validate task mode
353
+ if task_mode not in TASK_MODES:
354
+ raise ValueError(
355
+ f"Invalid task_mode '{task_mode}'. Choose from: {list(TASK_MODES.keys())}"
356
+ )
357
+
358
+ # Auto-generate output column name based on task mode
359
+ if output_column is None:
360
+ output_column = f"paddleocr_{task_mode}"
361
+
362
+ logger.info(f"Using task mode: {task_mode} - {TASK_DESCRIPTIONS[task_mode]}")
363
+ logger.info(f"Output will be written to column: {output_column}")
364
+
365
+ # Load dataset
366
+ logger.info(f"Loading dataset: {input_dataset}")
367
+ dataset = load_dataset(input_dataset, split=split)
368
+
369
+ # Validate image column
370
+ if image_column not in dataset.column_names:
371
+ raise ValueError(
372
+ f"Column '{image_column}' not found. Available: {dataset.column_names}"
373
+ )
374
+
375
+ # Shuffle if requested
376
+ if shuffle:
377
+ logger.info(f"Shuffling dataset with seed {seed}")
378
+ dataset = dataset.shuffle(seed=seed)
379
+
380
+ # Limit samples if requested
381
+ if max_samples:
382
+ dataset = dataset.select(range(min(max_samples, len(dataset))))
383
+ logger.info(f"Limited to {len(dataset)} samples")
384
+
385
+ # Initialize vLLM model
386
+ model_name = "PaddlePaddle/PaddleOCR-VL"
387
+ logger.info(f"Initializing vLLM with {model_name}")
388
+ logger.info("This may take a minute on first run (model is only 0.9B)...")
389
+
390
+ llm = LLM(
391
+ model=model_name,
392
+ trust_remote_code=True,
393
+ max_model_len=max_model_len,
394
+ gpu_memory_utilization=gpu_memory_utilization,
395
+ limit_mm_per_prompt={"image": 1},
396
+ )
397
+
398
+ # Sampling parameters - deterministic for OCR
399
+ sampling_params = SamplingParams(
400
+ temperature=temperature,
401
+ max_tokens=max_tokens,
402
+ )
403
+
404
+ logger.info(f"Processing {len(dataset)} images in batches of {batch_size}")
405
+ if apply_smart_resize:
406
+ logger.info("Smart resize enabled (PaddleOCR-VL's adaptive resolution)")
407
+
408
+ # Process images in batches
409
+ all_outputs = []
410
+
411
+ for batch_indices in tqdm(
412
+ partition_all(batch_size, range(len(dataset))),
413
+ total=(len(dataset) + batch_size - 1) // batch_size,
414
+ desc=f"PaddleOCR-VL {task_mode.upper()} processing",
415
+ ):
416
+ batch_indices = list(batch_indices)
417
+ batch_images = [dataset[i][image_column] for i in batch_indices]
418
+
419
+ try:
420
+ # Create messages for batch with task-specific prefix
421
+ batch_messages = [
422
+ make_ocr_message(img, task_mode=task_mode, apply_smart_resize=apply_smart_resize)
423
+ for img in batch_images
424
+ ]
425
+
426
+ # Process with vLLM
427
+ outputs = llm.chat(batch_messages, sampling_params)
428
+
429
+ # Extract outputs
430
+ for output in outputs:
431
+ text = output.outputs[0].text.strip()
432
+ all_outputs.append(text)
433
+
434
+ except Exception as e:
435
+ logger.error(f"Error processing batch: {e}")
436
+ # Add error placeholders for failed batch
437
+ all_outputs.extend([f"[{task_mode.upper()} ERROR]"] * len(batch_images))
438
+
439
+ # Calculate processing time
440
+ processing_duration = datetime.now() - start_time
441
+ processing_time_str = f"{processing_duration.total_seconds() / 60:.1f} min"
442
+
443
+ # Add output column to dataset
444
+ logger.info(f"Adding '{output_column}' column to dataset")
445
+ dataset = dataset.add_column(output_column, all_outputs)
446
+
447
+ # Handle inference_info tracking (for multi-model comparisons)
448
+ inference_entry = {
449
+ "model_id": model_name,
450
+ "model_name": "PaddleOCR-VL",
451
+ "model_size": "0.9B",
452
+ "task_mode": task_mode,
453
+ "column_name": output_column,
454
+ "timestamp": datetime.now().isoformat(),
455
+ "temperature": temperature,
456
+ "max_tokens": max_tokens,
457
+ "smart_resize": apply_smart_resize,
458
+ }
459
+
460
+ if "inference_info" in dataset.column_names:
461
+ # Append to existing inference info
462
+ logger.info("Updating existing inference_info column")
463
+
464
+ def update_inference_info(example):
465
+ try:
466
+ existing_info = json.loads(example["inference_info"]) if example["inference_info"] else []
467
+ except (json.JSONDecodeError, TypeError):
468
+ existing_info = []
469
+
470
+ existing_info.append(inference_entry)
471
+ return {"inference_info": json.dumps(existing_info)}
472
+
473
+ dataset = dataset.map(update_inference_info)
474
+ else:
475
+ # Create new inference_info column
476
+ logger.info("Creating new inference_info column")
477
+ inference_list = [json.dumps([inference_entry])] * len(dataset)
478
+ dataset = dataset.add_column("inference_info", inference_list)
479
+
480
+ # Push to hub
481
+ logger.info(f"Pushing to {output_dataset}")
482
+ dataset.push_to_hub(output_dataset, private=private, token=HF_TOKEN)
483
+
484
+ # Create and push dataset card
485
+ logger.info("Creating dataset card")
486
+ card_content = create_dataset_card(
487
+ source_dataset=input_dataset,
488
+ model=model_name,
489
+ task_mode=task_mode,
490
+ num_samples=len(dataset),
491
+ processing_time=processing_time_str,
492
+ batch_size=batch_size,
493
+ max_model_len=max_model_len,
494
+ max_tokens=max_tokens,
495
+ gpu_memory_utilization=gpu_memory_utilization,
496
+ temperature=temperature,
497
+ apply_smart_resize=apply_smart_resize,
498
+ image_column=image_column,
499
+ split=split,
500
+ )
501
+
502
+ card = DatasetCard(card_content)
503
+ card.push_to_hub(output_dataset, token=HF_TOKEN)
504
+
505
+ logger.info("βœ… PaddleOCR-VL processing complete!")
506
+ logger.info(f"Dataset available at: https://huggingface.co/datasets/{output_dataset}")
507
+ logger.info(f"Processing time: {processing_time_str}")
508
+ logger.info(f"Task mode: {task_mode} - {TASK_DESCRIPTIONS[task_mode]}")
509
+
510
+
511
+ if __name__ == "__main__":
512
+ # Show example usage if no arguments
513
+ if len(sys.argv) == 1:
514
+ print("=" * 80)
515
+ print("PaddleOCR-VL Document Processing")
516
+ print("=" * 80)
517
+ print("\nUltra-compact 0.9B OCR model with task-specific capabilities")
518
+ print("\nFeatures:")
519
+ print("- 🎯 Smallest OCR model - Only 0.9B parameters")
520
+ print("- πŸ“ OCR mode - General text extraction")
521
+ print("- πŸ“Š Table mode - HTML table recognition")
522
+ print("- πŸ“ Formula mode - LaTeX mathematical notation")
523
+ print("- πŸ“ˆ Chart mode - Structured chart analysis")
524
+ print("- 🌍 Multilingual support")
525
+ print("- ⚑ Fast initialization and inference")
526
+ print("- πŸ”§ Based on ERNIE-4.5 (unique architecture)")
527
+ print("\nTask Modes:")
528
+ for mode, description in TASK_DESCRIPTIONS.items():
529
+ print(f" {mode:8} - {description}")
530
+ print("\nExample usage:")
531
+ print("\n1. Basic OCR (default mode):")
532
+ print(" uv run paddleocr-vl.py input-dataset output-dataset")
533
+ print("\n2. Table extraction:")
534
+ print(" uv run paddleocr-vl.py docs tables-extracted --task-mode table")
535
+ print("\n3. Formula recognition:")
536
+ print(" uv run paddleocr-vl.py papers formulas --task-mode formula --batch-size 32")
537
+ print("\n4. Chart analysis:")
538
+ print(" uv run paddleocr-vl.py diagrams charts-analyzed --task-mode chart")
539
+ print("\n5. Test with small sample:")
540
+ print(" uv run paddleocr-vl.py dataset test --max-samples 10 --shuffle")
541
+ print("\n6. Running on HF Jobs:")
542
+ print(" hf jobs uv run --flavor l4x1 \\")
543
+ print(" -e HF_TOKEN=$(python3 -c \"from huggingface_hub import get_token; print(get_token())\") \\")
544
+ print(" -e HF_HUB_ENABLE_HF_TRANSFER=1 \\")
545
+ print(" https://huggingface.co/datasets/uv-scripts/ocr/raw/main/paddleocr-vl.py \\")
546
+ print(" input-dataset output-dataset --task-mode ocr")
547
+ print("\n" + "=" * 80)
548
+ print("\nFor full help, run: uv run paddleocr-vl.py --help")
549
+ sys.exit(0)
550
+
551
+ parser = argparse.ArgumentParser(
552
+ description="Document processing using PaddleOCR-VL (0.9B task-specific model)",
553
+ formatter_class=argparse.RawDescriptionHelpFormatter,
554
+ epilog="""
555
+ Task Modes:
556
+ ocr General text extraction to markdown (default)
557
+ table Table extraction to HTML format
558
+ formula Mathematical formula recognition to LaTeX
559
+ chart Chart and diagram analysis
560
+
561
+ Examples:
562
+ # Basic text OCR
563
+ uv run paddleocr-vl.py my-docs analyzed-docs
564
+
565
+ # Extract tables from documents
566
+ uv run paddleocr-vl.py papers tables --task-mode table
567
+
568
+ # Recognize mathematical formulas
569
+ uv run paddleocr-vl.py textbooks formulas --task-mode formula
570
+
571
+ # Analyze charts and diagrams
572
+ uv run paddleocr-vl.py reports charts --task-mode chart
573
+
574
+ # Test with random sampling
575
+ uv run paddleocr-vl.py large-dataset test --max-samples 50 --shuffle --task-mode ocr
576
+
577
+ # Disable smart resize for original resolution
578
+ uv run paddleocr-vl.py images output --no-smart-resize
579
+ """,
580
+ )
581
+
582
+ parser.add_argument("input_dataset", help="Input dataset ID from Hugging Face Hub")
583
+ parser.add_argument("output_dataset", help="Output dataset ID for Hugging Face Hub")
584
+ parser.add_argument(
585
+ "--image-column",
586
+ default="image",
587
+ help="Column containing images (default: image)",
588
+ )
589
+ parser.add_argument(
590
+ "--batch-size",
591
+ type=int,
592
+ default=16,
593
+ help="Batch size for processing (default: 16)",
594
+ )
595
+ parser.add_argument(
596
+ "--task-mode",
597
+ choices=list(TASK_MODES.keys()),
598
+ default="ocr",
599
+ help="Task type: ocr (default), table, formula, or chart",
600
+ )
601
+ parser.add_argument(
602
+ "--max-model-len",
603
+ type=int,
604
+ default=8192,
605
+ help="Maximum model context length (default: 8192)",
606
+ )
607
+ parser.add_argument(
608
+ "--max-tokens",
609
+ type=int,
610
+ default=4096,
611
+ help="Maximum tokens to generate (default: 4096)",
612
+ )
613
+ parser.add_argument(
614
+ "--temperature",
615
+ type=float,
616
+ default=0.0,
617
+ help="Sampling temperature (default: 0.0 for deterministic)",
618
+ )
619
+ parser.add_argument(
620
+ "--gpu-memory-utilization",
621
+ type=float,
622
+ default=0.8,
623
+ help="GPU memory utilization (default: 0.8)",
624
+ )
625
+ parser.add_argument(
626
+ "--no-smart-resize",
627
+ action="store_true",
628
+ help="Disable PaddleOCR-VL's smart resize, use original image size",
629
+ )
630
+ parser.add_argument("--hf-token", help="Hugging Face API token")
631
+ parser.add_argument(
632
+ "--split", default="train", help="Dataset split to use (default: train)"
633
+ )
634
+ parser.add_argument(
635
+ "--max-samples",
636
+ type=int,
637
+ help="Maximum number of samples to process (for testing)",
638
+ )
639
+ parser.add_argument(
640
+ "--private", action="store_true", help="Make output dataset private"
641
+ )
642
+ parser.add_argument(
643
+ "--shuffle", action="store_true", help="Shuffle dataset before processing"
644
+ )
645
+ parser.add_argument(
646
+ "--seed",
647
+ type=int,
648
+ default=42,
649
+ help="Random seed for shuffling (default: 42)",
650
+ )
651
+ parser.add_argument(
652
+ "--output-column",
653
+ help="Column name for output (default: paddleocr_[task_mode])",
654
+ )
655
+
656
+ args = parser.parse_args()
657
+
658
+ main(
659
+ input_dataset=args.input_dataset,
660
+ output_dataset=args.output_dataset,
661
+ image_column=args.image_column,
662
+ batch_size=args.batch_size,
663
+ task_mode=args.task_mode,
664
+ max_model_len=args.max_model_len,
665
+ max_tokens=args.max_tokens,
666
+ temperature=args.temperature,
667
+ gpu_memory_utilization=args.gpu_memory_utilization,
668
+ apply_smart_resize=not args.no_smart_resize,
669
+ hf_token=args.hf_token,
670
+ split=args.split,
671
+ max_samples=args.max_samples,
672
+ private=args.private,
673
+ shuffle=args.shuffle,
674
+ seed=args.seed,
675
+ output_column=args.output_column,
676
+ )