robinhad committed on
Commit d03d3f9 · verified · Parent: e6380a7

Update app.py

Files changed (1): app.py (+5, −2)
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import subprocess
 
-subprocess.run('pip install flash-attn==2.7.0.post2 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+#subprocess.run('pip install flash-attn==2.7.0.post2 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 import threading
 
@@ -11,6 +11,9 @@ import spaces
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from kernels import get_kernel
+
+vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
 
 #torch._dynamo.config.disable = True
 
@@ -25,7 +28,7 @@ def load_model():
         MODEL_ID,
         torch_dtype=torch.bfloat16, # if device == "cuda" else torch.float32,
         device_map="auto", # if device == "cuda" else None,
-        attn_implementation="flash_attention_2",
+        attn_implementation=vllm_flash_attn3,
     ) # .cuda()
     print(f"Selected device:", device)
     return model, tokenizer, device
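
For context on what this change does at runtime: instead of pip-installing flash-attn when the Space boots, the app now pulls a pre-built FlashAttention 3 kernel from the Hub via the kernels package and passes it to from_pretrained as the attention implementation. Below is a minimal sketch of that pattern, mirroring the + lines above; MODEL_ID and the tokenizer wiring are placeholders for illustration, and only the get_kernel call and the attn_implementation argument come from this commit.

# Minimal sketch of the pattern introduced in this commit; not the full app.py.
# MODEL_ID is a placeholder; the real value is defined elsewhere in app.py.
import torch
from kernels import get_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fetch the pre-built vllm-flash-attn3 kernel from the Hugging Face Hub
# instead of compiling/installing flash-attn at startup.
vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")

MODEL_ID = "some-org/some-model"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Hand the loaded kernel to transformers as the attention implementation,
    # mirroring the change to load_model() above.
    attn_implementation=vllm_flash_attn3,
)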