nielsr (HF Staff) and Claude committed
Commit dce53e9 · 1 Parent(s): 680f872

Fix Flash Attention 2 import error with conditional loading


- Add is_flash_attention_available() helper function to detect flash-attn package
- Update both load_base_model() and load_chat_model() to conditionally use Flash Attention 2
- Fall back to default attention implementation if flash-attn is not installed
- Resolves ImportError in ZeroGPU environments without flash-attn dependency

This ensures the app works in all environments regardless of Flash Attention availability.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
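
For reference, the pattern this commit introduces can be sketched on its own. The snippet below is a minimal, illustrative version rather than the app's exact code: `load_model` is a hypothetical helper, and it assumes a transformers release that ships `Kosmos2_5ForConditionalGeneration` and accepts `dtype` and `attn_implementation` in `from_pretrained`.

```python
import torch
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration


def is_flash_attention_available():
    """Return True only if the flash-attn package can actually be imported."""
    try:
        import flash_attn  # noqa: F401  (presence check only)
        return True
    except ImportError:
        return False


def load_model(repo):
    # Hypothetical helper (not part of the diff below): build the kwargs first,
    # then request Flash Attention 2 only when the package is importable.
    model_kwargs = {"device_map": "cuda", "dtype": torch.bfloat16}
    if is_flash_attention_available():
        model_kwargs["attn_implementation"] = "flash_attention_2"
    model = Kosmos2_5ForConditionalGeneration.from_pretrained(repo, **model_kwargs)
    processor = AutoProcessor.from_pretrained(repo)
    return model, processor
```

Building the kwargs dict first keeps the two loaders in app.py identical except for the repository id.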

Files changed (1)
  1. app.py +27 -6
app.py CHANGED
@@ -9,6 +9,13 @@ import re
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
+# Check if Flash Attention 2 is available
+def is_flash_attention_available():
+    try:
+        import flash_attn
+        return True
+    except ImportError:
+        return False
 
 # Initialize models and processors lazily
 base_model = None
@@ -20,11 +27,18 @@ def load_base_model():
     global base_model, base_processor
     if base_model is None:
         base_repo = "microsoft/kosmos-2.5"
+
+        # Use Flash Attention 2 if available, otherwise use default attention
+        model_kwargs = {
+            "device_map": "cuda",
+            "dtype": dtype,
+        }
+        if is_flash_attention_available():
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+
         base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
             base_repo,
-            device_map="cuda",
-            dtype=dtype,
-            attn_implementation="flash_attention_2"
+            **model_kwargs
         )
         base_processor = AutoProcessor.from_pretrained(base_repo)
     return base_model, base_processor
@@ -33,11 +47,18 @@ def load_chat_model():
     global chat_model, chat_processor
     if chat_model is None:
         chat_repo = "microsoft/kosmos-2.5-chat"
+
+        # Use Flash Attention 2 if available, otherwise use default attention
+        model_kwargs = {
+            "device_map": "cuda",
+            "dtype": dtype,
+        }
+        if is_flash_attention_available():
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+
         chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
             chat_repo,
-            device_map="cuda",
-            dtype=dtype,
-            attn_implementation="flash_attention_2"
+            **model_kwargs
         )
         chat_processor = AutoProcessor.from_pretrained(chat_repo)
     return chat_model, chat_processor
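
After this change, a quick way to confirm which attention backend was actually selected is to inspect the loaded model's config. Note that `_attn_implementation` is an internal transformers field, so treat this as a best-effort debugging aid rather than a stable API:

```python
model, processor = load_base_model()
print("flash-attn installed:", is_flash_attention_available())
# _attn_implementation is internal to transformers; fall back to "unknown" if absent.
print("attention backend:", getattr(model.config, "_attn_implementation", "unknown"))
```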