nielsr (HF Staff) and Claude committed
Commit d463280 · 1 Parent(s): c25975a

Fix ZeroGPU and model loading issues


- Add accelerate dependency to requirements
- Replace the deprecated torch_dtype argument with the dtype parameter
- Implement lazy model loading to avoid ZeroGPU context issues
- Load models only when needed, inside @spaces.GPU-decorated functions (see the sketch after the fix list below)

Fixes:
- ValueError: Using a device_map requires accelerate
- torch_dtype deprecation warnings
- ZeroGPU function called outside Gradio context warnings
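
For illustration, a minimal sketch of the pattern these fixes describe, assuming the repo id and parameters used in the diff below; the names `_lazy_load` and `run_task` are placeholders, not code from this commit:

```python
import spaces
import torch
from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration

_model, _processor = None, None  # cached after the first GPU call

def _lazy_load():
    # Loading here (not at import time) avoids touching the GPU outside the
    # ZeroGPU context; `dtype=` replaces the deprecated `torch_dtype=`, and
    # `device_map=` is what pulls in the accelerate requirement.
    global _model, _processor
    if _model is None:
        _model = Kosmos2_5ForConditionalGeneration.from_pretrained(
            "microsoft/kosmos-2.5",
            device_map="cuda",
            dtype=torch.bfloat16,
        )
        _processor = AutoProcessor.from_pretrained("microsoft/kosmos-2.5")
    return _model, _processor

@spaces.GPU
def run_task(image, prompt="<ocr>"):
    # First request pays the load cost inside the GPU context; later requests reuse the cache.
    model, processor = _lazy_load()
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs.pop("height"); inputs.pop("width")  # drop extra size entries returned by the Kosmos-2.5 processor
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    inputs["flattened_patches"] = inputs["flattened_patches"].to(model.dtype)
    with torch.no_grad():
        ids = model.generate(**inputs, max_new_tokens=1024)
    return processor.batch_decode(ids, skip_special_tokens=True)[0]
```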

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

Files changed (2)
  1. app.py +45 -33
  2. requirements.txt +1 -0
app.py CHANGED
@@ -10,31 +10,37 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
 
-# Initialize models and processors
-@spaces.GPU
-def load_models():
-    base_repo = "microsoft/kosmos-2.5"
-    chat_repo = "microsoft/kosmos-2.5-chat"
-
-    base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
-        base_repo,
-        device_map=device,
-        torch_dtype=dtype,
-        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
-    )
-    base_processor = AutoProcessor.from_pretrained(base_repo)
-
-    chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
-        chat_repo,
-        device_map=device,
-        torch_dtype=dtype,
-        attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
-    )
-    chat_processor = AutoProcessor.from_pretrained(chat_repo)
-
-    return base_model, base_processor, chat_model, chat_processor
-
-base_model, base_processor, chat_model, chat_processor = load_models()
+# Initialize models and processors lazily
+base_model = None
+base_processor = None
+chat_model = None
+chat_processor = None
+
+def load_base_model():
+    global base_model, base_processor
+    if base_model is None:
+        base_repo = "microsoft/kosmos-2.5"
+        base_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
+            base_repo,
+            device_map=device,
+            dtype=dtype,
+            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+        )
+        base_processor = AutoProcessor.from_pretrained(base_repo)
+    return base_model, base_processor
+
+def load_chat_model():
+    global chat_model, chat_processor
+    if chat_model is None:
+        chat_repo = "microsoft/kosmos-2.5-chat"
+        chat_model = Kosmos2_5ForConditionalGeneration.from_pretrained(
+            chat_repo,
+            device_map=device,
+            dtype=dtype,
+            attn_implementation="flash_attention_2" if torch.cuda.is_available() else None
+        )
+        chat_processor = AutoProcessor.from_pretrained(chat_repo)
+    return chat_model, chat_processor
 
 def post_process_ocr(y, scale_height, scale_width, prompt="<ocr>"):
     y = y.replace(prompt, "")
@@ -65,8 +71,10 @@ def generate_markdown(image):
     if image is None:
         return "Please upload an image."
 
+    model, processor = load_base_model()
+
     prompt = "<md>"
-    inputs = base_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
 
     height, width = inputs.pop("height"), inputs.pop("width")
     raw_width, raw_height = image.size
@@ -77,12 +85,12 @@ def generate_markdown(image):
     inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
 
     with torch.no_grad():
-        generated_ids = base_model.generate(
+        generated_ids = model.generate(
             **inputs,
             max_new_tokens=1024,
         )
 
-    generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
     result = generated_text[0].replace(prompt, "").strip()
 
     return result
@@ -92,8 +100,10 @@ def generate_ocr(image):
     if image is None:
         return "Please upload an image.", None
 
+    model, processor = load_base_model()
+
     prompt = "<ocr>"
-    inputs = base_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
 
     height, width = inputs.pop("height"), inputs.pop("width")
     raw_width, raw_height = image.size
@@ -104,12 +114,12 @@ def generate_ocr(image):
     inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
 
     with torch.no_grad():
-        generated_ids = base_model.generate(
+        generated_ids = model.generate(
             **inputs,
             max_new_tokens=1024,
         )
 
-    generated_text = base_processor.batch_decode(generated_ids, skip_special_tokens=True)
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
 
     # Post-process OCR output
     output_text = post_process_ocr(generated_text[0], scale_height, scale_width)
@@ -140,10 +150,12 @@ def generate_chat_response(image, question):
     if not question.strip():
         return "Please ask a question."
 
+    model, processor = load_chat_model()
+
     template = "<md>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
     prompt = template.format(question)
 
-    inputs = chat_processor(text=prompt, images=image, return_tensors="pt")
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
 
     height, width = inputs.pop("height"), inputs.pop("width")
     raw_width, raw_height = image.size
@@ -154,12 +166,12 @@ def generate_chat_response(image, question):
     inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)
 
     with torch.no_grad():
-        generated_ids = chat_model.generate(
+        generated_ids = model.generate(
             **inputs,
             max_new_tokens=1024,
         )
 
-    generated_text = chat_processor.batch_decode(generated_ids, skip_special_tokens=True)
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
 
     # Extract only the assistant's response
     result = generated_text[0]
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 gradio==4.44.0
 torch>=2.0.0
 git+https://github.com/huggingface/transformers.git
+accelerate
 pillow
 requests
 spaces
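
As a quick, illustrative check (not part of this commit) that the Space's environment actually provides everything the updated requirements.txt expects:

```python
# Illustrative dependency check; package names are the import names, so pillow appears as PIL.
import importlib.util

for pkg in ("gradio", "torch", "transformers", "accelerate", "PIL", "requests", "spaces"):
    status = "ok" if importlib.util.find_spec(pkg) is not None else "MISSING"
    print(f"{pkg}: {status}")
```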