Update

- README.md: +3 -2
- modeling_minicpmv.py: +20 -9
README.md CHANGED

@@ -120,11 +120,12 @@ tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V', trust_remote_code
 model.eval().cuda()
 
 image = Image.open('xx.jpg').convert('RGB')
-question = '
+question = 'What is in the image?'
+msgs = [{'role': 'user', 'content': question}]
 
 res, context, _ = model.chat(
     image=image,
-
+    msgs=msgs,
     context=None,
     tokenizer=tokenizer,
     sampling=True,
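
Read as a whole, the README change swaps the bare question argument for a msgs conversation list. A minimal end-to-end sketch of the updated usage follows; the AutoModel loading line and the second turn that feeds the returned context back in are assumptions for illustration (only the tokenizer setup and the chat() call itself appear in the diff), and 'xx.jpg' is the README's own placeholder path.

from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Assumed model/tokenizer setup mirroring the README's surrounding lines.
model = AutoModel.from_pretrained('openbmb/MiniCPM-V', trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V', trust_remote_code=True)
model.eval().cuda()

image = Image.open('xx.jpg').convert('RGB')
question = 'What is in the image?'
msgs = [{'role': 'user', 'content': question}]

# New calling convention: the conversation goes in as `msgs`.
res, context, _ = model.chat(
    image=image,
    msgs=msgs,
    context=None,
    tokenizer=tokenizer,
    sampling=True,
)
print(res)

# Assumption: the returned `context` (msgs plus the appended assistant answer)
# can be extended and passed back as `msgs` for a follow-up turn.
context.append({'role': 'user', 'content': 'Describe it in one sentence.'})
res, context, _ = model.chat(
    image=image,
    msgs=context,
    context=None,
    tokenizer=tokenizer,
    sampling=True,
)
print(res)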
modeling_minicpmv.py CHANGED

@@ -235,12 +235,22 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         return result
 
 
-    def chat(self, image,
-        if
-
-
-
-
+    def chat(self, image, msgs, context, tokenizer, vision_hidden_states=None, max_new_tokens=2048, sampling=False, **kwargs):
+        if isinstance(msgs, str):
+            msgs = json.loads(msgs)
+        # msgs to prompt
+        prompt = ''
+        for i, msg in enumerate(msgs):
+            role = msg['role']
+            content = msg['content']
+            assert role in ['user', 'assistant']
+            if i == 0:
+                assert role == 'user', 'The role of first msg should be user'
+                content = tokenizer.im_start + tokenizer.unk_token * self.config.query_num + tokenizer.im_end + '\n' + content
+            prompt += '<用户>' if role=='user' else '<AI>'
+            prompt += content
+        prompt += '<AI>'
+        final_input = prompt
 
         if sampling:
             generation_config = {
@@ -268,10 +278,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             return_vision_hidden_states=True,
             **generation_config
         )
+        answer = res[0]
+        context = msgs
+        context.append({'role':'assistant', 'content': answer})
 
-
-
-        return res[0], context, generation_config
+        return answer, context, generation_config
 
 
 class LlamaTokenizerWrapper(LlamaTokenizer):
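
The substance of the modeling change is the new chat() signature: it accepts a msgs list (or a JSON-encoded string), flattens it into the model's prompt format, and after generation appends the assistant answer to msgs and returns it as context. Below is a small self-contained sketch of just that flattening step, assuming a tokenizer that exposes im_start, im_end, and unk_token and a config query_num as in the diff; the build_prompt name and the stand-in token strings in the demo are illustrative only, not part of the repo.

import json

def build_prompt(msgs, tokenizer, query_num):
    """Mirror the prompt construction in the new chat(): image placeholder
    tokens are prepended to the first user turn, roles are marked with
    <用户> / <AI>, and a trailing <AI> cues the assistant's reply."""
    if isinstance(msgs, str):  # chat() also accepts a JSON-encoded string
        msgs = json.loads(msgs)
    prompt = ''
    for i, msg in enumerate(msgs):
        role, content = msg['role'], msg['content']
        assert role in ['user', 'assistant']
        if i == 0:
            assert role == 'user', 'The role of first msg should be user'
            # Reserve query_num slots that are later filled with image features.
            content = tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end + '\n' + content
        prompt += '<用户>' if role == 'user' else '<AI>'
        prompt += content
    prompt += '<AI>'
    return prompt

# Illustrative stand-ins for the real tokenizer attributes and config value.
class _Tok:
    im_start, im_end, unk_token = '<image>', '</image>', '<unk>'

msgs = [{'role': 'user', 'content': 'What is in the image?'}]
print(build_prompt(msgs, _Tok(), query_num=4))
# -> <用户><image><unk><unk><unk><unk></image>
#    What is in the image?<AI>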