Commit 9d6c917 · Update app.py
Parent: 58c6fa7

app.py CHANGED
@@ -6,6 +6,9 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor
 from huggingface_hub import hf_hub_download
+from pynvml import *
+nvmlInit()
+gpu_h = nvmlDeviceGetHandleByIndex(0)
 
 ctx_limit = 3500
 title = 'ViusualRWKV-v5'
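This hunk wires in pynvml so the Space can report VRAM usage. A minimal standalone sketch of the same pattern, with explicit imports in place of the star import (my substitution, not the commit's):

from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

nvmlInit()                             # must run once before any other NVML call
gpu_h = nvmlDeviceGetHandleByIndex(0)  # handle for GPU 0
info = nvmlDeviceGetMemoryInfo(gpu_h)  # fields .total / .used / .free are in bytes
print(f"vram total={info.total} used={info.used} free={info.free}")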
@@ -14,12 +17,12 @@ vision_remote_path = "rwkv1b5-vitl336p14-577token_mix665k_visual.pth"
 vision_tower_name = 'openai/clip-vit-large-patch14-336'
 
 os.environ["RWKV_JIT_ON"] = '1'
-os.environ["RWKV_CUDA_ON"] = '
+os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
 
 from modeling_vision import VisionEncoder, VisionEncoderConfig
 from modeling_rwkv import RWKV
 model_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=rwkv_remote_path)
-model = RWKV(model=model_path, strategy='
+model = RWKV(model=model_path, strategy='cuda fp16')
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
 
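The switch to RWKV_CUDA_ON = '1' and strategy='cuda fp16' only takes effect because both are set before the model is constructed. A sketch of that ordering, assuming the pip rwkv package API (the Space imports a local modeling_rwkv instead) and a hypothetical model path:

import os
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'  # must be set before the RWKV import to take effect

import torch
from rwkv.model import RWKV  # assumption: pip rwkv package; the Space uses modeling_rwkv

# fall back to CPU when no GPU is present, mirroring the guarded .cuda() later in this diff
strategy = 'cuda fp16' if torch.cuda.is_available() else 'cpu fp32'
model = RWKV(model='path/to/visualrwkv.pth', strategy=strategy)  # hypothetical path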
@@ -32,6 +35,8 @@ vision_local_path = hf_hub_download(repo_id="howard-hou/visualrwkv-5", filename=
 vision_state_dict = torch.load(vision_local_path, map_location='cpu')
 visual_encoder.load_state_dict(vision_state_dict)
 image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
+if torch.cuda.is_available():
+    visual_encoder = visual_encoder.cuda()
 ##########################################################################
 def generate_prompt(instruction):
     instruction = instruction.strip().replace('\r\n','\n').replace('\n\n','\n')
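The guarded .cuda() call keeps the Space loadable on CPU-only hardware. A generic sketch of the same guard for any torch module:

import torch
import torch.nn as nn

def to_available_device(module: nn.Module) -> nn.Module:
    # .to(device) covers both cases; module.cuda() as in the hunk is
    # equivalent when a GPU is present
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return module.to(device)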
@@ -83,9 +88,13 @@ def generate(
         yield out_str.strip()
         out_last = i + 1
 
+    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    print(f'{timestamp} - vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
     del out
     del state
     gc.collect()
+    torch.cuda.empty_cache()
     yield out_str.strip()
 
 
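These additions log VRAM after each generation and release cached CUDA blocks. Note the hunk calls datetime.now() without adding an import, so datetime is presumably imported earlier in app.py. A self-contained sketch of the same cleanup step:

import gc
from datetime import datetime

import torch
from pynvml import nvmlDeviceGetMemoryInfo

def log_vram_and_cleanup(gpu_h):
    info = nvmlDeviceGetMemoryInfo(gpu_h)
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f'{ts} - vram {info.total} used {info.used} free {info.free}')
    gc.collect()                # drop Python references first
    torch.cuda.empty_cache()    # then return cached CUDA blocks to the driver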
@@ -157,7 +166,7 @@ with gr.Blocks(title=title) as demo:
         with gr.Column():
             output = gr.Textbox(label="Output", lines=10)
     data = gr.Dataset(components=[image, prompt], samples=examples, label="Examples", headers=["Image", "Prompt"])
-    submit.click(chatbot, [image, prompt], [output]
+    submit.click(chatbot, [image, prompt], [output])
     clear.click(lambda: None, [], [output])
     data.click(lambda x: x, [data], [image, prompt])
 
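This hunk repairs the submit.click call (the removed line is truncated in the diff view; the replacement closes the parenthesis). A runnable sketch of the same Blocks wiring, with placeholder components and a stand-in chatbot function rather than the Space's real one:

import gradio as gr

def chatbot(image, prompt):  # stand-in for the Space's generate function
    return f"echo: {prompt}"

with gr.Blocks() as demo:
    image = gr.Image(type='pil')
    prompt = gr.Textbox(label="Prompt")
    submit = gr.Button("Submit")
    clear = gr.Button("Clear")
    output = gr.Textbox(label="Output", lines=10)
    submit.click(chatbot, [image, prompt], [output])
    clear.click(lambda: None, [], [output])

demo.launch()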