import os
import json
import base64
import argparse
import time
import re
from datetime import datetime
from functools import partial
from multiprocessing import Pool, Manager, Lock

from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark


REASONING_MULTIPLE_CHOICE_TEMPLATE = """
You are an AI assistant evaluating video frames to answer a multiple-choice question.
The user will provide you with a set of video frames and a question with several options (e.g., A, B, C, D).

First, provide a step-by-step reasoning process that analyzes the video frames and leads to your conclusion.
After your reasoning, provide the final answer in a JSON block. The JSON object must contain a single key "answer" with the value being one of 'A', 'B', 'C', or 'D'.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""


def parse_arguments():
    """
    Parse command line arguments for evaluation configuration.

    Returns:
        argparse.Namespace: Parsed command line arguments
    """
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation with Pre-computed Similarity Frame Selection"
    )

    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="Model to be evaluated (e.g., gpt-4o, gpt-4-vision-preview)",
    )
    parser.add_argument(
        "--frame-num",
        "-fn",
        type=int,
        default=32,
        help="Number of most similar frames to select for each video (default: 32)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the base directory containing video frame folders.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the JSON file containing the evaluation dataset.",
    )
    parser.add_argument(
        "--similarity-file",
        "-sf",
        type=str,
        required=True,
        help="Absolute path to the pre-computed similarity JSON file (e.g., lv_bench_similarity.json).",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for API calls (default: 10)",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes for evaluation (default: 20)",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API endpoint URL (Azure OpenAI, Ark, or an OpenAI-compatible server).",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        required=True,
        help="API key for the selected endpoint.",
    )

    return parser.parse_args()


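# Illustrative invocation (a sketch only; the script filename, paths, and key below are
# placeholders, not values taken from this repository):
#
#   python evaluate_precomputed_similarity.py \
#       --target-model gpt-4o \
#       --frame-num 32 \
#       --frames-path /abs/path/to/frames \
#       --data-file /abs/path/to/dataset.json \
#       --similarity-file /abs/path/to/lv_bench_similarity.json \
#       --base_url https://your-endpoint.example.com \
#       --api_key YOUR_API_KEY
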
def save_json_file(data, output_file):
    """
    Save data to a JSON file.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """
    Extracts a JSON object from a string that contains reasoning followed by a tagged JSON block.
    """
    if not response:
        return None
    try:
        match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
        if match:
            json_str = match.group(1)
            return json.loads(json_str)
        return None
    except (json.JSONDecodeError, IndexError):
        return None


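# Illustrative behaviour (not executed): given a response such as
#   "The clip shows ...\n```json\n{\"answer\": \"B\"}\n```"
# extract_json_from_response returns {"answer": "B"}; a missing or malformed
# ```json block yields None, which is treated downstream as an unanswered item.
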
def calculate_metrics(results):
    """
    Calculate evaluation metrics from the results.
    """
    total_samples = len(results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered_samples = sum(1 for x in results if x.get("model_answer") is not None)
    correct_answers = sum(1 for x in results if x.get("is_correct"))

    # Accuracy is computed over answered samples only; unparsable responses are excluded.
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


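# Worked example (illustrative numbers only): for 100 evaluated items of which 95 produced
# a parsable answer and 80 of those were correct, calculate_metrics returns
#   {"total_samples": 100, "answered_samples": 95, "correct_answers": 80, "accuracy": 0.8421...}
# since accuracy = 80 / 95.
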
def call_single_model(client, messages, model, item_id, max_retry_times):
    """
    Make a single API call to the specified model with retry logic.
    """
    # Use a lower completion-token cap for Doubao models; other backends get the larger default.
    if "doubao" in model:
        max_tokens = 32768
    else:
        max_tokens = 65535
    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(
                model=model, messages=messages, max_tokens=max_tokens
            )
            return completion.choices[0].message.content
        except Exception as e:
            retry_times += 1
            print(
                f"Error processing item {item_id} with model {model}: {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                error_log_file = f"error_log_{model.replace('/', '_')}.txt"
                with open(error_log_file, "a") as f:
                    f.write(
                        f"Error processing item {item_id} with model {model} after {max_retry_times} retries: {str(e)}\n"
                    )
                return None
            time.sleep(5)


def evaluate_single_item(
    data_item, frames, target_model, api_key, base_url, max_retry_times
):
    """
    Evaluate a single data item using the target model.
    """
    # Pick the client based on the endpoint URL: Ark for Volcengine, a plain OpenAI client
    # for Aliyun or local servers, and Azure OpenAI otherwise.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    messages = [
        {"role": "system", "content": REASONING_MULTIPLE_CHOICE_TEMPLATE},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Here are the video frames:"},
                *frames,
                {"type": "text", "text": f"Question: {data_item['question']}"},
            ],
        },
    ]

    response = call_single_model(
        client, messages, target_model, data_item["key"], max_retry_times
    )

    is_correct = False
    model_answer_cleaned = None
    parsed_json = None

    if response:
        parsed_json = extract_json_from_response(response)
        if parsed_json and "answer" in parsed_json:
            model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
            gold_answer = data_item["answer"].strip().upper()
            if model_answer_cleaned == gold_answer:
                is_correct = True

    return {
        **data_item,
        "model_reasoning_and_answer": response,
        "model_answer_raw": parsed_json.get("answer") if parsed_json else None,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }


def encode_image(image_path):
    """
    Encode an image file to base64 string.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_frames_from_similarity_file(
    frames_base_path, frame_num, data_item, similarity_data
):
    """
    Select and encode the top N frames using a pre-computed similarity file.
    """
    item_key = data_item["key"]
    question_uid = str(data_item["uid"])

    sorted_filenames = similarity_data.get(question_uid)

    if not sorted_filenames:
        print(
            f"Warning: No similarity data found for question UID '{question_uid}', skipping."
        )
        return []

    try:
        # Take the top-N most similar frames, then re-sort them into temporal order
        # using the numeric index embedded in each frame filename.
        num_frames_to_select = min(frame_num, len(sorted_filenames))
        selected_filenames = sorted_filenames[:num_frames_to_select]
        selected_ids = [int(f.split(".")[0].split("_")[-1]) for f in selected_filenames]
        selected_ids = sorted(selected_ids)
        selected_filenames = [f"frame_{i:06d}.jpg" for i in selected_ids]

        video_frames_path = os.path.join(frames_base_path, item_key)
        sampled_paths = [os.path.join(video_frames_path, f) for f in selected_filenames]

        base64_images = [encode_image(path) for path in sampled_paths]

        return [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"},
            }
            for b64_img in base64_images
        ]
    except Exception as e:
        print(f"Error during frame processing for key '{item_key}': {e}")
        return []


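# Assumed layout of the similarity file (inferred from the lookup above, not verified here):
# a JSON object keyed by question UID, where each value is a list of frame filenames sorted
# by descending relevance to that question, e.g.
#   {"12345": ["frame_000073.jpg", "frame_000012.jpg", ...], ...}
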
def process_single_data(
    data_item,
    args,
    shared_results,
    progress_counter,
    total_items,
    locks,
    similarity_data,
):
    """
    Process a single data item in a multiprocessing context.
    """
    item_key = data_item["key"]
    try:
        frames = process_frames_from_similarity_file(
            args.frames_path, args.frame_num, data_item, similarity_data
        )

        if not frames:
            raise ValueError(
                f"No frames were processed from similarity file for key '{item_key}'"
            )

        result = evaluate_single_item(
            data_item,
            frames,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )

        if result is not None:
            # Checkpoint: append the result and rewrite the full results file so that
            # partial progress survives an interrupted run.
            with locks["results"]:
                shared_results.append(result)
                data_filename_base = os.path.splitext(
                    os.path.basename(args.data_file)
                )[0]
                model_name_safe = args.target_model.replace("/", "_")
                output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"
                results_output_file = f"{output_prefix}_results.json"
                save_json_file(list(shared_results), results_output_file)

    except Exception as e:
        print(f"Error processing video key {item_key}: {str(e)}")
        with locks["file"]:
            error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
            with open(error_log_file, "a") as f:
                f.write(f"Critical error processing video key {item_key}: {str(e)}\n")
    finally:
        with locks["counter"]:
            progress_counter.value += 1
            print(
                f"\rProcessed: {progress_counter.value}/{total_items} videos...",
                end="",
                flush=True,
            )


def load_test_data(json_file):
    """
    Load test data from a JSON file.
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found at {json_file}")
        exit(1)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {json_file}")
        exit(1)


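# Assumed dataset layout (inferred from the fields accessed elsewhere in this script):
# a JSON list of question items, each providing at least "key" (frame-folder name),
# "uid" (question id matching the similarity file), "question" (prompt text with options),
# and "answer" (gold option letter), e.g.
#   [{"key": "video_001", "uid": "12345", "question": "... (A) ... (B) ...", "answer": "B"}, ...]
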
def main():
    """
    Main function to run the video QA evaluation framework.
    """
    args = parse_arguments()

    print("--- Evaluation Configuration ---")
    print(f"Target Model: {args.target_model}")
    print(f"Frames to Sample (by pre-computed similarity): {args.frame_num}")
    print(f"Frames Base Path: {args.frames_path}")
    print(f"Similarity File: {args.similarity_file}")
    print(f"Data File: {args.data_file}")
    print(f"Parallel Processes: {args.pool_processes}")
    print("---------------------------------")

    # Truncate the per-model error log at the start of each run.
    error_log_file = f"error_log_{args.target_model.replace('/', '_')}.txt"
    with open(error_log_file, "w") as f:
        f.write(
            f"=== Error Log Started at {datetime.now()} for model {args.target_model} ===\n"
        )

    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    model_name_safe = args.target_model.replace("/", "_")
    output_prefix = f"{model_name_safe}_{data_filename_base}_{args.frame_num}frames_precomputed_similar"

    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"

    test_data = load_test_data(args.data_file)
    try:
        with open(args.similarity_file, "r", encoding="utf-8") as f:
            similarity_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Similarity file not found at {args.similarity_file}")
        exit(1)

    total_videos = len(test_data)
    print(f"\nLoaded {total_videos} videos to process.")

    with Manager() as manager:
        shared_results = manager.list()
        progress_counter = manager.Value("i", 0)
        locks = {
            "results": manager.Lock(),
            "file": manager.Lock(),
            "counter": manager.Lock(),
        }

        process_func = partial(
            process_single_data,
            args=args,
            shared_results=shared_results,
            progress_counter=progress_counter,
            total_items=total_videos,
            locks=locks,
            similarity_data=similarity_data,
        )

        with Pool(processes=args.pool_processes) as pool:
            pool.map(process_func, test_data)

        all_results = list(shared_results)

    print(f"\n\nProcessing complete for model: {args.target_model}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))

    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")


if __name__ == "__main__":
    main()