robinhad committed on
Commit
7719ac7
·
verified ·
1 Parent(s): dc5393d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -13,7 +13,7 @@ import torch
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
  from kernels import get_kernel
15
 
16
- vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
17
 
18
  #torch._dynamo.config.disable = True
19
 
@@ -28,7 +28,7 @@ def load_model():
28
  MODEL_ID,
29
  torch_dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
30
  device_map="auto", # if device == "cuda" else None,
31
- attn_implementation="flash_attention_3", # "flash_attention_2", #
32
  ) # .cuda()
33
  print(f"Selected device:", device)
34
  return model, tokenizer, device
 
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
  from kernels import get_kernel
15
 
16
+ #vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
17
 
18
  #torch._dynamo.config.disable = True
19
 
 
28
  MODEL_ID,
29
  torch_dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
30
  device_map="auto", # if device == "cuda" else None,
31
+ attn_implementation="kernels-community/vllm-flash-attn3", # "flash_attention_2", #
32
  ) # .cuda()
33
  print(f"Selected device:", device)
34
  return model, tokenizer, device