import random
import re

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
    set_seed,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# GIT image-captioning model (fine-tuned on COCO) for the image -> prompt tab.
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# GPT-2-based generator that expands a short text into image prompts.
text_pipe = pipeline('text-generation', model='succinctly/text2image-prompt-generator')

# MarianMT models for zh->en and en->zh translation.
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
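
# All of the models above are loaded on the CPU; only the captioning model is
# moved to `device` inside get_prompt_from_image below. If GPU memory allows,
# the translation models could be moved too (a sketch, not in the original app):
#   zh2en_model = zh2en_model.to(device)
#   en2zh_model = en2zh_model.to(device)
# their tokenizer outputs would then also need .to(device) before generate().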

def load_prompter():
    # Promptist: a GPT-2 model fine-tuned by Microsoft to rewrite plain text
    # into Stable-Diffusion-friendly prompts.
    prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return prompter_model, tokenizer

prompter_model, prompter_tokenizer = load_prompter()
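
# Left padding matters for decoder-only models like GPT-2: with the default
# right padding, generation would condition on trailing pad tokens. A quick
# illustrative check (exact ids vary by tokenizer version):
#   >>> prompter_tokenizer(["a cat", "a cat on a chair"], padding=True).input_ids
# pad ids appear at the *start* of the shorter sequence.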

def generate_prompter(plain_text, max_new_tokens=75, num_beams=8, num_return_sequences=8, length_penalty=-1.0):
    plain_text = plain_text.strip()
    input_ids = prompter_tokenizer(plain_text + " Rephrase:", return_tensors="pt").input_ids
    eos_id = prompter_tokenizer.eos_token_id
    outputs = prompter_model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=length_penalty
    )
    output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    result = []
    for output_text in output_texts:
        # Strip the echoed "<input> Rephrase:" prefix, keeping only the rewrite.
        result.append(output_text.replace(plain_text + " Rephrase:", "").strip())
    return "\n".join(result)

def translate_zh2en(text):
    with torch.no_grad():
        encoded = zh2en_tokenizer([text], return_tensors='pt')
        sequences = zh2en_model.generate(**encoded)
        return zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
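
# Illustrative usage (output is model-dependent; the English shown is approximate):
#   >>> translate_zh2en("一只猫坐在椅子上")
#   'A cat sits on a chair.'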

def translate_en2zh(text):
    with torch.no_grad():
        encoded = en2zh_tokenizer([text], return_tensors="pt")
        sequences = en2zh_model.generate(**encoded)
        return en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
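
# Illustrative usage (approximate; actual wording depends on the checkpoint):
#   >>> translate_en2zh("A cat sits on a chair.")
#   '一只猫坐在椅子上。'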

def text_generate(text_in_english):
    seed = random.randint(100, 1000000)
    set_seed(seed)
    result = ""
    # Retry up to six times: sampling occasionally yields no usable candidate.
    for _ in range(6):
        sequences = text_pipe(text_in_english, max_length=random.randint(60, 90), num_return_sequences=8)
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep candidates that actually extend the input and don't end mid-phrase.
            if line != text_in_english and len(line) > len(text_in_english) + 4 and not line.endswith((':', '-', '—')):
                lines.append(line)
        result = "\n".join(lines)
        # Drop URL-/filename-like tokens (a dot with non-space on both sides)
        # and stray angle brackets left over from generation.
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            break
    return result, "\n".join(translate_en2zh(line) for line in result.split("\n") if len(line) > 0)
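
# Illustrative usage: returns up to eight expanded English prompts plus their
# Chinese translations; the example output below is hypothetical:
#   >>> en, zh = text_generate("a cat sitting on a chair")
#   >>> en.split("\n")[0]
#   'a cat sitting on a chair in a cozy room, warm light, digital painting'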

def get_prompt_from_image(input_image):
    # Caption the image with GIT and return the caption as a prompt.
    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_caption)
    return generated_caption
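
# Illustrative usage (assumes a local file 'cat.jpg'; the caption is hypothetical):
#   >>> from PIL import Image
#   >>> get_prompt_from_image(Image.open('cat.jpg'))
#   'a cat sitting on a wooden chair'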

with gr.Blocks() as block:
    with gr.Column():
        with gr.Tab('文本生成'):  # "Text generation"
            with gr.Row():
                input_text = gr.Textbox(lines=6, label='你的想法', placeholder='在此输入内容...')  # "Your idea" / "Type here..."
                translate_output = gr.Textbox(lines=6, label='翻译结果(Prompt输入)')  # "Translation result (prompt input)"
            with gr.Accordion('SD优化参数设置', open=False):  # "SD optimization settings"
                max_new_tokens = gr.Slider(1, 255, 75, label='max_new_tokens', step=1)
                num_beams = gr.Slider(1, 30, 8, label='num_beams', step=1)
                num_return_sequences = gr.Slider(1, 30, 8, label='num_return_sequences', step=1)
                length_penalty = gr.Slider(-1.0, 1.0, -1.0, label='length_penalty')
            generate_prompter_output = gr.Textbox(lines=6, label='SD优化的 Prompt')  # "SD-optimized prompt"
            output = gr.Textbox(lines=6, label='瞎编的 Prompt')  # "Improvised prompt"
            output_zh = gr.Textbox(lines=6, label='瞎编的 Prompt(zh)')  # "Improvised prompt (zh)"
            with gr.Row():
                translate_btn = gr.Button('翻译')  # "Translate"
                generate_prompter_btn = gr.Button('SD优化')  # "SD optimize"
                gpt_btn = gr.Button('瞎编')  # "Improvise"
        with gr.Tab('从图片中生成'):  # "Generate from image"
            with gr.Row():
                input_image = gr.Image(type='pil')
                img_btn = gr.Button('提交')  # "Submit"
            output_image = gr.Textbox(lines=6, label='生成的 Prompt')  # "Generated prompt"

    translate_btn.click(
        fn=translate_zh2en,
        inputs=input_text,
        outputs=translate_output
    )
    generate_prompter_btn.click(
        fn=generate_prompter,
        inputs=[translate_output, max_new_tokens, num_beams, num_return_sequences, length_penalty],
        outputs=generate_prompter_output
    )
    gpt_btn.click(
        fn=text_generate,
        inputs=translate_output,
        outputs=[output, output_zh]
    )
    img_btn.click(
        fn=get_prompt_from_image,
        inputs=input_image,
        outputs=output_image
    )

# Queueing is configured via .queue(); binding to 0.0.0.0 makes the app
# reachable from outside the container.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')