import functools
import os

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# --- Configuration -----------------------------------------------------------
# Stable Diffusion checkpoint loaded by get_image_pipe().
model_id = "runwayml/stable-diffusion-v1-5"
# Fine-tuned causal language model loaded by get_model().
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
# Not referenced by the code visible in this file.
MIN_WORDS = 120
# Hugging Face access token taken from the environment; None when unset.
READ_TOKEN = os.getenv('HF_ACCESS_TOKEN')

# Shared summarization pipeline, built once at import time. No model is named,
# so transformers presumably picks its default summarization checkpoint.
summarizer = pipeline("summarization")
@functools.lru_cache(maxsize=1)
def get_image_pipe():
    """Return the Stable Diffusion pipeline, loading it onto the GPU on first use.

    The result is cached, so every request reuses one pipeline instead of
    reloading the fp16 weights and moving them to CUDA on each call.

    Returns:
        StableDiffusionPipeline: fp16 pipeline for ``model_id`` on ``'cuda'``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        # NOTE(review): newer diffusers releases deprecate revision="fp16" and
        # use_auth_token in favour of variant="fp16" / token= — confirm against
        # the installed version before changing.
        revision="fp16",
        use_auth_token=READ_TOKEN,
    )
    # Requires a CUDA device; fp16 weights are not meant for CPU inference.
    pipe.to('cuda')
    return pipe
@functools.lru_cache(maxsize=1)
def get_model():
    """Load (once) and return the fine-tuned text-generation model and tokenizer.

    Cached so repeated calls from generate() reuse the same objects instead of
    reloading the checkpoint for every request.

    Returns:
        tuple: ``(model, tokenizer)`` loaded from ``SAVED_CHECKPOINT``.
    """
    model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
    return model, tokenizer
def generate(prompt):
    """Continue *prompt* with the fine-tuned LM and trim to the last full sentence.

    Args:
        prompt: Text the model should continue.

    Returns:
        str: The decoded sample, cut just after its last '.'. If the sample
        contains no '.', the whole text is returned with a '.' appended.
    """
    model, tokenizer = get_model()
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=100,
        temperature=0.7,
        do_sample=True,
        # Only the first sequence is ever used, so sample just one.
        num_return_sequences=1,
        # Presumably a GPT-2 tokenizer without a pad token; pass EOS explicitly
        # instead of relying on generate()'s warning-emitting fallback.
        pad_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the trailing partial sentence left by the max_length cut-off.
    return text.rsplit('.', 1)[0] + '.'
def make_image(prompt):
    """Render one image for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text description to illustrate.

    Returns:
        The first image of the pipeline output (presumably a PIL image, which
        Gradio's "image" output accepts).
    """
    pipe = get_image_pipe()
    image = pipe(prompt).images[0]
    return image
def predict(prompt):
    """Gradio handler: generate a story, summarize it, and illustrate the summary.

    Args:
        prompt: Opening text typed by the user.

    Returns:
        tuple: ``(story, summary, image)``, where ``summary`` is the plain
        summary string, matching the ``["text", "text", "image"]`` outputs.
    """
    story = generate(prompt=prompt)
    # Summarize once; the short summary also serves as the image prompt.
    summary = summarizer(story, min_length=5, max_length=20)[0]['summary_text']
    image = make_image(summary)
    return story, summary, image
title = "Lord of the rings app"
description = """A Lord of the rings inspired app that combines text and image generation"""

# One text box in; story text, summary text and generated image out.
# NOTE(review): the example "Then I a hobbit appeared" looks like a typo —
# confirm the intended wording before changing it.
gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs=["text", "text", "image"],
    title=title,
    description=description,
    examples=[["My new adventure would be"], ["Then I a hobbit appeared"], ["Frodo told me"]],
).launch(share=True)