import pickle

import faiss
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from hazm import Normalizer
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)


DATASET_NAME = 'parsi-ai-nlpclass/tourist-attraction-data'
TEST_DATA_NAME = 'parsi-ai-nlpclass/tourist-attraction-test-data'

# Streaming handles to the source corpora. The FAISS indices loaded below are
# assumed to have been built from these offline; the datasets themselves are
# not iterated at inference time.
dataset = load_dataset(DATASET_NAME, streaming=True)
test_dataset = load_dataset(TEST_DATA_NAME, streaming=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

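# DINOv2 backbone that embeds query images into the same space as the image index.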
vision_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base')
vision_model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)

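# Persian sentence-embedding model that embeds text queries for the text index.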
language_model = SentenceTransformer("xmanii/maux-gte-persian", trust_remote_code=True).to(device)

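# Load the PersianMind generator with 4-bit NF4 quantization (double-quantized)
# to keep GPU memory usage modest; device_map="auto" handles layer placement.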
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "universitytehran/PersianMind-v1.0",
    quantization_config=quantization_config,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(
    "universitytehran/PersianMind-v1.0",
)

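# hazm's Normalizer standardizes Persian text (character forms, spacing) so
# queries match the normalization presumably applied when the index was built.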
normalizer = Normalizer()

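# Put the encoders in inference mode (disables dropout etc.).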
language_model.eval()
vision_model.eval()

# Load FAISS indices
text_index = faiss.read_index("text.index")
image_index = faiss.read_index("image.index")
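
# Query vectors are L2-normalized before searching, which makes inner-product
# similarity equivalent to cosine; this assumes the indexed vectors were
# normalized the same way when the indices were built.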

# Load the index-item mapping
with open("idx_item_mapping.pkl", "rb") as f:
    idx_item_mapping = pickle.load(f)

print("FAISS indices and index-item mapping loaded.")
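
# Optional sanity check: index dimensionality should match the text encoder.
# assert text_index.d == language_model.get_sentence_embedding_dimension()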


def search_by_text(query_text, k=5):
    """
    Searches the text index for the k items most similar to the query text.

    Args:
        query_text: The text query.
        k: The number of top similar items to return.

    Returns:
        A list of up to k unique item texts, most similar first.
    """
    normalized_query = normalizer.normalize(query_text)
    query_embedding = language_model.encode(normalized_query)

    query_embedding_np = query_embedding[np.newaxis, :]
    faiss.normalize_L2(query_embedding_np)

    # Over-fetch so that k unique items remain after deduplication.
    distances, indices = text_index.search(query_embedding_np, 100)

    unique_texts = set()
    results = []
    for idx in indices[0]:
        if idx == -1:  # FAISS pads with -1 when it returns fewer hits
            continue
        text = idx_item_mapping[idx]
        if text not in unique_texts:
            unique_texts.add(text)
            results.append(text)
            if len(results) == k:
                break

    return results

def search_by_image(query_image, k=5):
  """
  Searches the image index for the k items most similar to the query image.

  Args:
    query_image: The query image (PIL.Image).
    k: The number of top similar items to return.

  Returns:
    A list of up to k unique item texts, most similar first.
  """
  inputs = vision_processor(images=query_image, return_tensors="pt").to(device)  # Move image inputs to device
  with torch.no_grad():
    outputs = vision_model(**inputs)
  # Mean-pool the patch embeddings of the last hidden state into one vector.
  image_embedding_np = outputs[0].mean(dim=1)[0].cpu().numpy()

  query_embedding_np = image_embedding_np[np.newaxis, :]
  faiss.normalize_L2(query_embedding_np)

  # Search the FAISS index
  distances, indices = image_index.search(query_embedding_np, 100)

  # Get the top k unique items using the indices and the mapping
  unique_texts = set()
  results = []
  for idx in indices[0]:
    if idx == -1:  # FAISS pads with -1 when it returns fewer hits
      continue
    text = idx_item_mapping[idx]
    if text not in unique_texts:
      unique_texts.add(text)
      results.append(text)
      if len(results) == k:
        break

  return results


def rag_pipeline(question, image=None):
    """
    Runs the RAG pipeline with the given question and optional image.

    Args:
        question: The text question.
        image: Optional image input.

    Returns:
        The generated answer from the language model.
    """
    retrieved_items = []
    if image is not None:
        retrieved_items.extend(search_by_image(image))
    retrieved_items.extend(search_by_text(question))

    # Chat-style prompt format for PersianMind: the retrieved items serve as
    # grounding context, followed by the user turn.
    TEMPLATE = "{context}\nYou: {prompt}\nPersianMind: "
    CONTEXT = '\n'.join(retrieved_items)
    PROMPT = '\n'.join([
        question,
        'به این سوال به فارسی جواب بده.',  # "Answer this question in Persian."
    ])

    model_input = TEMPLATE.format(context=CONTEXT, prompt=PROMPT)
    input_tokens = tokenizer(model_input, return_tensors="pt").to(device)
    generate_ids = model.generate(**input_tokens, max_new_tokens=200, do_sample=False, repetition_penalty=1.1)
    model_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    # The decoded text echoes the prompt, so return only the generated tail.
    return model_output[len(model_input):]

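# Example programmatic use (hypothetical question, "Introduce a historical
# attraction in Isfahan"):
#   answer = rag_pipeline("یک جاذبه تاریخی در اصفهان معرفی کن")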

iface = gr.Interface(
    fn=rag_pipeline,
    inputs=[
        gr.Textbox(label="Your Question"),
        gr.Image(type="pil", label="Optional Image")
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Tourist Attraction RAG Pipeline",
    description="Ask a question about tourist attractions and optionally provide an image."
)


if __name__ == "__main__":
    iface.launch(debug=True)