Spaces:
Running
Running
Show a problem with the current approach.
Browse files- pyproject.toml +1 -0
- test_llm_inference.py +64 -1
- uv.lock +35 -0
pyproject.toml
CHANGED
|
@@ -8,6 +8,7 @@ dependencies = [
|
|
| 8 |
"fastapi>=0.115.8",
|
| 9 |
"pandas>=2.2.3",
|
| 10 |
"pydantic>=2.10.6",
|
|
|
|
| 11 |
"requests>=2.32.3",
|
| 12 |
"streamlit==1.40.1",
|
| 13 |
]
|
|
|
|
| 8 |
"fastapi>=0.115.8",
|
| 9 |
"pandas>=2.2.3",
|
| 10 |
"pydantic>=2.10.6",
|
| 11 |
+
"pytest>=8.3.4",
|
| 12 |
"requests>=2.32.3",
|
| 13 |
"streamlit==1.40.1",
|
| 14 |
]
|
test_llm_inference.py
CHANGED
|
@@ -13,7 +13,7 @@ def model_and_tokenizer():
|
|
| 13 |
model = AutoModelForCausalLM.from_pretrained(
|
| 14 |
model_name,
|
| 15 |
device_map="cpu",
|
| 16 |
-
|
| 17 |
)
|
| 18 |
return model, tokenizer
|
| 19 |
|
|
@@ -63,3 +63,66 @@ def test_highlights(model_and_tokenizer, sample_inputs):
|
|
| 63 |
assert isinstance(h['token_loss'], float)
|
| 64 |
assert isinstance(h['most_likely_token'], str)
|
| 65 |
assert isinstance(h['topk_tokens'], list)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
model = AutoModelForCausalLM.from_pretrained(
|
| 14 |
model_name,
|
| 15 |
device_map="cpu",
|
| 16 |
+
torch_dtype=torch.float16
|
| 17 |
)
|
| 18 |
return model, tokenizer
|
| 19 |
|
|
|
|
| 63 |
assert isinstance(h['token_loss'], float)
|
| 64 |
assert isinstance(h['most_likely_token'], str)
|
| 65 |
assert isinstance(h['topk_tokens'], list)
|
| 66 |
+
|
| 67 |
+
def compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress, k=5):
|
| 68 |
+
"""
|
| 69 |
+
Extracts and compares the next token predictions between the fast method and slow method.
|
| 70 |
+
Returns the differences between the two approaches for analysis.
|
| 71 |
+
"""
|
| 72 |
+
# Get predictions from the fast method (using cache)
|
| 73 |
+
fast_tokens, fast_logits = custom_llm_inference.get_next_token_predictions_inner(
|
| 74 |
+
model, tokenizer, doc, prompt, doc_in_progress, k
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Get predictions from the slow method (recomputing for each token)
|
| 78 |
+
slow_tokens, slow_logits = custom_llm_inference.get_next_token_predictions_slow(
|
| 79 |
+
model, tokenizer, doc, prompt, doc_in_progress, k
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Compare the decoded tokens (this is what users will see)
|
| 83 |
+
token_matches = [fast == slow for fast, slow in zip(fast_tokens, slow_tokens)]
|
| 84 |
+
|
| 85 |
+
# Calculate the difference in logits for most likely next tokens
|
| 86 |
+
fast_most_likely = fast_logits.argmax(dim=-1)
|
| 87 |
+
slow_most_likely = slow_logits.argmax(dim=-1)
|
| 88 |
+
logit_match = torch.eq(fast_most_likely, slow_most_likely).cpu().numpy()
|
| 89 |
+
|
| 90 |
+
# Calculate numerical difference in logits
|
| 91 |
+
logit_diff_norm = torch.linalg.vector_norm((fast_logits - slow_logits).to(torch.float32), dim=1).cpu().numpy()
|
| 92 |
+
|
| 93 |
+
return {
|
| 94 |
+
"fast_tokens": fast_tokens,
|
| 95 |
+
"slow_tokens": slow_tokens,
|
| 96 |
+
"token_matches": token_matches,
|
| 97 |
+
"token_match_all": all(token_matches),
|
| 98 |
+
"logit_match": logit_match,
|
| 99 |
+
"logit_diff_norm": logit_diff_norm
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
def test_lookahead_token_consistency(model_and_tokenizer, sample_inputs):
|
| 103 |
+
"""
|
| 104 |
+
Test that demonstrates the potential issue with cache position indices
|
| 105 |
+
when generating lookahead tokens.
|
| 106 |
+
"""
|
| 107 |
+
model, tokenizer = model_and_tokenizer
|
| 108 |
+
doc, prompt, doc_in_progress = sample_inputs
|
| 109 |
+
|
| 110 |
+
results = compare_lookahead_predictions(model, tokenizer, doc, prompt, doc_in_progress)
|
| 111 |
+
|
| 112 |
+
# Check if the tokens are the same
|
| 113 |
+
assert results["token_match_all"], (
|
| 114 |
+
f"Fast and slow methods produced different tokens.\n"
|
| 115 |
+
f"Fast: {results['fast_tokens']}\n"
|
| 116 |
+
f"Slow: {results['slow_tokens']}"
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
# Check if the most likely next tokens based on logits are the same
|
| 120 |
+
assert all(results["logit_match"]), (
|
| 121 |
+
f"Fast and slow methods predicted different most likely next tokens"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Check that the logit differences are minimal
|
| 125 |
+
# This might fail if there's a bug in the cache position indices
|
| 126 |
+
assert all(diff < 1e-4 for diff in results["logit_diff_norm"]), (
|
| 127 |
+
f"Significant difference in logits between fast and slow methods: {results['logit_diff_norm']}"
|
| 128 |
+
)
|
uv.lock
CHANGED
|
@@ -287,6 +287,15 @@ wheels = [
|
|
| 287 |
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
| 288 |
]
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
[[package]]
|
| 291 |
name = "ipython"
|
| 292 |
version = "8.32.0"
|
|
@@ -719,6 +728,15 @@ wheels = [
|
|
| 719 |
{ url = "https://files.pythonhosted.org/packages/0b/30/2b61876e2722374558b871dfbfcbe4e406626d63f4f6ed92e9c8e24cac37/pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25", size = 2254890 },
|
| 720 |
]
|
| 721 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
[[package]]
|
| 723 |
name = "prompt-toolkit"
|
| 724 |
version = "3.0.50"
|
|
@@ -917,6 +935,21 @@ wheels = [
|
|
| 917 |
{ url = "https://files.pythonhosted.org/packages/eb/f5/b9e2a42aa8f9e34d52d66de87941ecd236570c7ed2e87775ed23bbe4e224/pymdown_extensions-10.14.3-py3-none-any.whl", hash = "sha256:05e0bee73d64b9c71a4ae17c72abc2f700e8bc8403755a00580b49a4e9f189e9", size = 264467 },
|
| 918 |
]
|
| 919 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 920 |
[[package]]
|
| 921 |
name = "python-dateutil"
|
| 922 |
version = "2.9.0.post0"
|
|
@@ -1505,6 +1538,7 @@ dependencies = [
|
|
| 1505 |
{ name = "fastapi" },
|
| 1506 |
{ name = "pandas" },
|
| 1507 |
{ name = "pydantic" },
|
|
|
|
| 1508 |
{ name = "requests" },
|
| 1509 |
{ name = "streamlit" },
|
| 1510 |
]
|
|
@@ -1526,6 +1560,7 @@ requires-dist = [
|
|
| 1526 |
{ name = "fastapi", specifier = ">=0.115.8" },
|
| 1527 |
{ name = "pandas", specifier = ">=2.2.3" },
|
| 1528 |
{ name = "pydantic", specifier = ">=2.10.6" },
|
|
|
|
| 1529 |
{ name = "requests", specifier = ">=2.32.3" },
|
| 1530 |
{ name = "streamlit", specifier = "==1.40.1" },
|
| 1531 |
]
|
|
|
|
| 287 |
{ url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 },
|
| 288 |
]
|
| 289 |
|
| 290 |
+
[[package]]
|
| 291 |
+
name = "iniconfig"
|
| 292 |
+
version = "2.0.0"
|
| 293 |
+
source = { registry = "https://pypi.org/simple" }
|
| 294 |
+
sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 }
|
| 295 |
+
wheels = [
|
| 296 |
+
{ url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 },
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
[[package]]
|
| 300 |
name = "ipython"
|
| 301 |
version = "8.32.0"
|
|
|
|
| 728 |
{ url = "https://files.pythonhosted.org/packages/0b/30/2b61876e2722374558b871dfbfcbe4e406626d63f4f6ed92e9c8e24cac37/pillow-11.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:27a7860107500d813fcd203b4ea19b04babe79448268403172782754870dac25", size = 2254890 },
|
| 729 |
]
|
| 730 |
|
| 731 |
+
[[package]]
|
| 732 |
+
name = "pluggy"
|
| 733 |
+
version = "1.5.0"
|
| 734 |
+
source = { registry = "https://pypi.org/simple" }
|
| 735 |
+
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
| 736 |
+
wheels = [
|
| 737 |
+
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
| 738 |
+
]
|
| 739 |
+
|
| 740 |
[[package]]
|
| 741 |
name = "prompt-toolkit"
|
| 742 |
version = "3.0.50"
|
|
|
|
| 935 |
{ url = "https://files.pythonhosted.org/packages/eb/f5/b9e2a42aa8f9e34d52d66de87941ecd236570c7ed2e87775ed23bbe4e224/pymdown_extensions-10.14.3-py3-none-any.whl", hash = "sha256:05e0bee73d64b9c71a4ae17c72abc2f700e8bc8403755a00580b49a4e9f189e9", size = 264467 },
|
| 936 |
]
|
| 937 |
|
| 938 |
+
[[package]]
|
| 939 |
+
name = "pytest"
|
| 940 |
+
version = "8.3.4"
|
| 941 |
+
source = { registry = "https://pypi.org/simple" }
|
| 942 |
+
dependencies = [
|
| 943 |
+
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
| 944 |
+
{ name = "iniconfig" },
|
| 945 |
+
{ name = "packaging" },
|
| 946 |
+
{ name = "pluggy" },
|
| 947 |
+
]
|
| 948 |
+
sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 }
|
| 949 |
+
wheels = [
|
| 950 |
+
{ url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 },
|
| 951 |
+
]
|
| 952 |
+
|
| 953 |
[[package]]
|
| 954 |
name = "python-dateutil"
|
| 955 |
version = "2.9.0.post0"
|
|
|
|
| 1538 |
{ name = "fastapi" },
|
| 1539 |
{ name = "pandas" },
|
| 1540 |
{ name = "pydantic" },
|
| 1541 |
+
{ name = "pytest" },
|
| 1542 |
{ name = "requests" },
|
| 1543 |
{ name = "streamlit" },
|
| 1544 |
]
|
|
|
|
| 1560 |
{ name = "fastapi", specifier = ">=0.115.8" },
|
| 1561 |
{ name = "pandas", specifier = ">=2.2.3" },
|
| 1562 |
{ name = "pydantic", specifier = ">=2.10.6" },
|
| 1563 |
+
{ name = "pytest", specifier = ">=8.3.4" },
|
| 1564 |
{ name = "requests", specifier = ">=2.32.3" },
|
| 1565 |
{ name = "streamlit", specifier = "==1.40.1" },
|
| 1566 |
]
|