zRzRzRzRzRzRzR committed on
Commit
06459ba
·
1 Parent(s): b94b06e
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -21,14 +21,13 @@ class GLM4VModel:
21
  def __init__(self):
22
  self.processor = None
23
  self.model = None
24
- self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
25
 
26
  def load(self):
27
  self.processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
28
  self.model = Glm4vForConditionalGeneration.from_pretrained(
29
  MODEL_PATH,
30
  torch_dtype=torch.bfloat16,
31
- device_map=self.device,
32
  attn_implementation="sdpa",
33
  )
34
 
@@ -137,7 +136,7 @@ class GLM4VModel:
137
  return_dict=True,
138
  return_tensors="pt",
139
  padding=True,
140
- ).to(self.device)
141
 
142
  streamer = TextIteratorStreamer(self.processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
143
  gen_args = dict(
@@ -251,7 +250,7 @@ def chat(files, msg, raw_hist, sys_prompt):
251
  break
252
  place["content"] = chunk
253
  display_hist = create_display_history(raw_hist)
254
- yield display_hist, copy.deepcopy(raw_hist), None, ""
255
 
256
  display_hist = create_display_history(raw_hist)
257
  yield display_hist, copy.deepcopy(raw_hist), None, ""
 
21
  def __init__(self):
22
  self.processor = None
23
  self.model = None
 
24
 
25
  def load(self):
26
  self.processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
27
  self.model = Glm4vForConditionalGeneration.from_pretrained(
28
  MODEL_PATH,
29
  torch_dtype=torch.bfloat16,
30
+ device_map="auto",
31
  attn_implementation="sdpa",
32
  )
33
 
 
136
  return_dict=True,
137
  return_tensors="pt",
138
  padding=True,
139
+ ).to(self.model.device)
140
 
141
  streamer = TextIteratorStreamer(self.processor.tokenizer, skip_prompt=True, skip_special_tokens=False)
142
  gen_args = dict(
 
250
  break
251
  place["content"] = chunk
252
  display_hist = create_display_history(raw_hist)
253
+ yield display_hist, copyf.deepcopy(raw_hist), None, ""
254
 
255
  display_hist = create_display_history(raw_hist)
256
  yield display_hist, copy.deepcopy(raw_hist), None, ""