Prepare to move up to Gemma3
custom_llm.py  CHANGED  (+11, -2)
@@ -32,8 +32,16 @@ async def models_lifespan(app: FastAPI):
     #model_name = 'google/gemma-1.1-7b-it'
     #model_name = 'google/gemma-1.1-2b-it'
     model_name = 'google/gemma-2-9b-it'
+    #model_name = 'google/gemma-3-12b-it'
+    #model_name = 'google/gemma-3-4b-it'

-
+    if USE_GPU:
+        dtype = torch.bfloat16
+        from transformers import TorchAoConfig
+        quantization_config = None  # TorchAoConfig("int4_weight_only", group_size=128)
+    else:
+        dtype = torch.float16
+        quantization_config = None

     ml_models["llm"] = llm = {
         'tokenizer': AutoTokenizer.from_pretrained(model_name),
@@ -41,7 +49,8 @@ async def models_lifespan(app: FastAPI):
             model_name,
             device_map="auto" if USE_GPU else "cpu",
             torch_dtype=dtype,
-            attn_implementation='eager'
+            attn_implementation='eager',
+            quantization_config=quantization_config,
         )
     }
     print("Loaded llm with device map:")
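The TorchAo quantization is stubbed out in this commit (quantization_config = None, with the config left in a comment). As a minimal sketch of what flipping it on might look like, assuming the torchao package is installed alongside a transformers release that ships TorchAoConfig, and reusing the quant type and group size from the commented-out stub:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig

# int4 weight-only quantization, matching the stub commented out in the diff
quantization_config = TorchAoConfig("int4_weight_only", group_size=128)

model = AutoModelForCausalLM.from_pretrained(
    'google/gemma-2-9b-it',        # or one of the Gemma 3 checkpoints penciled in above
    device_map="auto",             # GPU path, mirroring the USE_GPU branch
    torch_dtype=torch.bfloat16,
    attn_implementation='eager',   # eager is the recommended attention backend for Gemma 2
    quantization_config=quantization_config,
)

Weight-only int4 quantization compresses the stored weights while compute still runs in bfloat16, which is presumably why it is paired with the bf16 GPU branch here.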
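For context, a hypothetical sketch of how a request handler might consume the ml_models["llm"] dict this lifespan hook populates; apply_chat_template and generate are standard transformers APIs, while the handler body and the prompt are made up:

# Hypothetical consumer of the dict built in models_lifespan above.
llm = ml_models["llm"]
tokenizer, model = llm['tokenizer'], llm['model']

messages = [{"role": "user", "content": "Say hello."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))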