minhvtt committed
Commit f056202 · verified · 1 Parent(s): 36cbe5f

Update embedding_service.py

Files changed (1):
  1. embedding_service.py +61 -88
embedding_service.py CHANGED
@@ -1,8 +1,7 @@
  import torch
  import numpy as np
  from PIL import Image
- from transformers import AutoTokenizer, AutoModel
- import onnxruntime as ort
+ from transformers import AutoModel
  from typing import Union, List
  import io
 
@@ -10,7 +9,7 @@ import io
  class JinaClipEmbeddingService:
      """
      Jina CLIP v2 Embedding Service with Vietnamese support
-     Uses an ONNX model to speed up inference
+     Uses AutoModel with trust_remote_code
      """
 
      def __init__(self, model_path: str = "jinaai/jina-clip-v2"):
@@ -22,40 +21,30 @@ class JinaClipEmbeddingService:
          """
          print(f"Loading Jina CLIP v2 model from {model_path}...")
 
-         # Load the tokenizer for text (Vietnamese support)
-         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-         # Load the ONNX model for the vision encoder
-         self.onnx_model_path = f"{model_path}/onnx/model_fp16.onnx"
-
-         try:
-             # Try to load the ONNX model if available
-             self.vision_session = ort.InferenceSession(
-                 self.onnx_model_path,
-                 providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
-             )
-             self.use_onnx = True
-             print("✓ Loaded ONNX model for vision encoder")
-         except:
-             # Fall back to the PyTorch model
-             self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
-             self.use_onnx = False
-             print("✓ Loaded PyTorch model (ONNX not available)")
-
-         # Switch to eval mode
-         self.model.eval()
-
-         # Use the GPU if available
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.model.to(self.device)
-         print(f"✓ Model running on: {self.device}")
-
-     def encode_text(self, text: Union[str, List[str]], normalize: bool = True) -> np.ndarray:
+         # Load the model with trust_remote_code
+         self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+
+         # Switch to eval mode
+         self.model.eval()
+
+         # Use the GPU if available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+         print(f"✓ Loaded Jina CLIP v2 model on: {self.device}")
+
+     def encode_text(
+         self,
+         text: Union[str, List[str]],
+         truncate_dim: int = None,
+         normalize: bool = True
+     ) -> np.ndarray:
          """
          Encode text into vector embeddings (Vietnamese supported)
 
          Args:
              text: Text or list of texts (Vietnamese)
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
              normalize: Whether to normalize the embeddings
 
          Returns:
@@ -64,28 +53,16 @@ class JinaClipEmbeddingService:
          if isinstance(text, str):
              text = [text]
 
-         # Tokenize text with the max length for Jina CLIP v2
-         inputs = self.tokenizer(
+         # Jina CLIP v2 encode_text method
+         # Handles tokenization internally
+         embeddings = self.model.encode_text(
              text,
-             padding=True,
-             truncation=True,
-             max_length=512,
-             return_tensors="pt"
+             truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
          )
 
-         if not self.use_onnx:
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         # Generate embeddings
-         with torch.no_grad():
-             if self.use_onnx:
-                 # ONNX inference
-                 onnx_inputs = {k: v.numpy() for k, v in inputs.items()}
-                 embeddings = self.vision_session.run(None, onnx_inputs)[0]
-             else:
-                 # PyTorch inference
-                 outputs = self.model.encode_text(**inputs)
-                 embeddings = outputs.cpu().numpy()
+         # Convert to numpy
+         if isinstance(embeddings, torch.Tensor):
+             embeddings = embeddings.cpu().detach().numpy()
 
          # Normalize if needed
          if normalize:
@@ -93,12 +70,18 @@ class JinaClipEmbeddingService:
 
          return embeddings
 
-     def encode_image(self, image: Union[Image.Image, bytes, List], normalize: bool = True) -> np.ndarray:
+     def encode_image(
+         self,
+         image: Union[Image.Image, bytes, List, str],
+         truncate_dim: int = None,
+         normalize: bool = True
+     ) -> np.ndarray:
          """
          Encode an image into vector embeddings
 
          Args:
-             image: PIL Image, bytes, or list of images
+             image: PIL Image, bytes, URL string, or list of images
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
              normalize: Whether to normalize the embeddings
 
          Returns:
@@ -112,40 +95,26 @@ class JinaClipEmbeddingService:
              for img in image:
                  if isinstance(img, bytes):
                      processed_images.append(Image.open(io.BytesIO(img)).convert('RGB'))
+                 elif isinstance(img, str):
+                     # URL string - keep as is, Jina CLIP can handle URLs
+                     processed_images.append(img)
                  else:
                      processed_images.append(img)
              image = processed_images
-         else:
-             if not isinstance(image, list):
-                 image = [image]
-
-         # Process images
-         if self.use_onnx:
-             # Preprocessing for the ONNX model
-             # Resize to 512x512 (Jina CLIP v2 high resolution)
-             from torchvision import transforms
-             preprocess = transforms.Compose([
-                 transforms.Resize((512, 512)),
-                 transforms.ToTensor(),
-                 transforms.Normalize(
-                     mean=[0.48145466, 0.4578275, 0.40821073],
-                     std=[0.26862954, 0.26130258, 0.27577711]
-                 )
-             ])
-
-             if isinstance(image, list):
-                 pixel_values = torch.stack([preprocess(img) for img in image])
-             else:
-                 pixel_values = preprocess(image).unsqueeze(0)
-
-             # ONNX inference
-             onnx_inputs = {"pixel_values": pixel_values.numpy()}
-             embeddings = self.vision_session.run(None, onnx_inputs)[0]
-         else:
-             # PyTorch inference
-             with torch.no_grad():
-                 embeddings = self.model.encode_image(image)
-                 embeddings = embeddings.cpu().numpy()
+         elif not isinstance(image, list) and not isinstance(image, str):
+             # Single PIL Image
+             image = [image]
+
+         # Jina CLIP v2 encode_image method
+         # Supports PIL Images, file paths, or URLs
+         embeddings = self.model.encode_image(
+             image,
+             truncate_dim=truncate_dim  # Optional: 64, 128, 256, 512, 1024
+         )
+
+         # Convert to numpy
+         if isinstance(embeddings, torch.Tensor):
+             embeddings = embeddings.cpu().detach().numpy()
 
          # Normalize if needed
          if normalize:
@@ -157,6 +126,7 @@ class JinaClipEmbeddingService:
          self,
          text: Union[str, List[str]] = None,
          image: Union[Image.Image, bytes, List] = None,
+         truncate_dim: int = None,
          normalize: bool = True
      ) -> np.ndarray:
          """
@@ -165,6 +135,7 @@ class JinaClipEmbeddingService:
          Args:
              text: Text or list of texts (Vietnamese)
              image: PIL Image, bytes, or list of images
+             truncate_dim: Matryoshka dimension (64-1024, None = full 1024)
              normalize: Whether to normalize the embeddings
 
          Returns:
@@ -173,19 +144,21 @@ class JinaClipEmbeddingService:
          embeddings = []
 
          if text is not None:
-             text_emb = self.encode_text(text, normalize=False)
+             text_emb = self.encode_text(text, truncate_dim=truncate_dim, normalize=False)
              embeddings.append(text_emb)
 
          if image is not None:
-             image_emb = self.encode_image(image, normalize=False)
+             image_emb = self.encode_image(image, truncate_dim=truncate_dim, normalize=False)
              embeddings.append(image_emb)
 
-         # Combine embeddings (average or concat)
+         # Combine embeddings (average)
          if len(embeddings) == 2:
              # Average of the text and image embeddings
              combined = np.mean(embeddings, axis=0)
-         else:
+         elif len(embeddings) == 1:
              combined = embeddings[0]
+         else:
+             raise ValueError("At least one of text or image must be provided")
 
          # Normalize if needed
          if normalize:
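
For context, a minimal usage sketch of the service as it stands after this commit. It assumes embedding_service.py is importable from the working directory and that a local file photo.jpg exists; the file name and the sample query are illustrative, and the first call downloads jinaai/jina-clip-v2 with trust_remote_code=True.

import numpy as np
from PIL import Image

from embedding_service import JinaClipEmbeddingService

# Instantiate the service (loads the model and moves it to GPU if available)
service = JinaClipEmbeddingService()

# Text embeddings (Vietnamese supported); truncate_dim uses the
# Matryoshka property to shorten vectors, e.g. 512 instead of the full 1024
text_vecs = service.encode_text(
    ["một con mèo đang ngủ", "a sleeping cat"],
    truncate_dim=512,
)

# Image embeddings from a local file, passed as a list of PIL Images
img = Image.open("photo.jpg").convert("RGB")  # photo.jpg is a placeholder path
img_vecs = service.encode_image([img], truncate_dim=512)

# With normalize=True (the default) the embeddings are normalized,
# so the dot product approximates cosine similarity
similarity = float(np.dot(text_vecs[0], img_vecs[0]))
print(f"text-image similarity: {similarity:.3f}")

Assuming the normalize step applies the usual L2 normalization, the returned vectors can be written straight into a cosine-similarity vector index without further post-processing.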