Spaces:
Sleeping
Sleeping
| from llmlib.base_llm import Message | |
| from PIL import Image | |
| from llmlib.phi3.phi3 import GenConf, Phi3Vision, extract_imgs_and_dicts, pad_left | |
| import pytest | |
| import torch | |
| from .helpers import ( | |
| assert_model_can_answer_batch_of_img_prompts, | |
| assert_model_can_answer_batch_of_text_prompts, | |
| assert_model_knows_capital_of_france, | |
| assert_model_rejects_unsupported_batches, | |
| get_mona_lisa_completion, | |
| is_ci, | |
| ) | |
def test_extract_imgs_and_dicts():
    """Check image extraction and placeholder numbering for repeated images.

    Six messages are built from two image objects, each referenced twice
    under the same name. The test expects two extracted images, one dict
    per message, and a placeholder in each image-bearing message whose
    number is the same for both occurrences of an image.
    """
    first_img = Image.new(mode="RGB", size=(1, 1))
    second_img = Image.new(mode="RGB", size=(1, 1))

    # Same three-message pattern twice; fresh Message objects on each round.
    conversation = [
        msg
        for _ in range(2)
        for msg in (
            a_msg(),
            a_msg(img=first_img, img_name="img1"),
            a_msg(img=second_img, img_name="img2"),
        )
    ]

    images, messages = extract_imgs_and_dicts(conversation)

    assert len(images) == 2
    assert len(messages) == 6

    # Message index -> placeholder expected inside that message's content.
    expected_placeholders = {
        1: "<|image_1|>",
        4: "<|image_1|>",
        5: "<|image_2|>",
        2: "<|image_2|>",
    }
    for position, placeholder in expected_placeholders.items():
        assert placeholder in messages[position]["content"]
def a_msg(img: Image.Image | None = None, img_name: str | None = None) -> Message:
    """Build a user Message with empty text and the given optional image."""
    return Message(
        role="user",
        msg="",
        img=img,
        img_name=img_name,
    )
def test_phi3_vision(model: Phi3Vision):
    """Run the capital-of-France check, then require a str Mona Lisa completion."""
    # NOTE(review): `is_ci` is imported but unused in the visible part of this
    # file -- confirm whether this model-backed test should be skipped in CI.
    assert_model_knows_capital_of_france(model)
    completion = get_mona_lisa_completion(model)
    assert isinstance(completion, str)
def test_phi3_batching(model: Phi3Vision):
    """Run the shared batch checks: text prompts first, then image prompts."""
    # NOTE(review): `is_ci` is imported but unused in the visible part of this
    # file -- confirm whether this model-backed test should be skipped in CI.
    batch_checks = (
        assert_model_can_answer_batch_of_text_prompts,
        assert_model_can_answer_batch_of_img_prompts,
    )
    for check in batch_checks:
        check(model)
def test_phi3_invalid_input(model: Phi3Vision):
    """Run the shared helper that checks unsupported batches are rejected."""
    # NOTE(review): `is_ci` is imported but unused in the visible part of this
    # file -- confirm whether this model-backed test should be skipped in CI.
    assert_model_rejects_unsupported_batches(model)
@pytest.fixture(scope="module")
def model():
    """Provide a Phi3Vision instance to the tests that take a ``model`` argument.

    The tests in this module request ``model`` as a parameter, so this
    function must be registered as a pytest fixture; without the decorator
    pytest reports "fixture 'model' not found".

    Yields:
        A Phi3Vision constructed with ``GenConf(max_new_tokens=30)``.
    """
    # NOTE(review): module scope is chosen so the model is constructed once
    # and shared by all tests in this file -- confirm this matches intent.
    yield Phi3Vision(GenConf(max_new_tokens=30))
def test_padleft():
    """pad_left should left-pad shorter sequences up to the longest one."""
    pad = -1
    ragged = [torch.tensor(row) for row in ([1, 2, 3], [4, 5], [6])]

    # Each row is right-aligned; the pad value fills the missing leading slots.
    want = torch.tensor(
        [
            [1, 2, 3],
            [pad, 4, 5],
            [pad, pad, 6],
        ]
    )

    got = pad_left(ragged, pad)
    assert torch.equal(got, want)