Spaces:
Running
Running
| import hashlib | |
| import time | |
| import uuid | |
| from urllib.parse import urlencode | |
| import json | |
| import requests | |
| from PIL import Image, ImageOps | |
| import os | |
| import gradio as gr | |
| import ops | |
| from watermark import WatermarkApp | |
| import db_examples | |
class ApiClient(object):
    """Signed HTTP client for the relighting gateway.

    Every request goes to the single URL ``http://<endpoint>/api``. Besides the
    static ``accessKey``/``appKey`` session headers, each request carries a
    fresh ``requestId``, a millisecond ``timestamp`` and an MD5 ``sign`` built
    from the app key, access key id, timestamp and access key secret. Callers in
    this module pick the backend operation through an ``apiName`` header.
    """

    def __init__(self, app_key: str, access_key_id: str, access_key_secret: str, endpoint: str,
                 timeout: float = 8000):
        """Store credentials and create the shared ``requests`` session.

        Args:
            app_key: Application key; sent as ``appKey`` and used in the signature.
            access_key_id: Access key id; sent as ``accessKey`` and used in the signature.
            access_key_secret: Secret used only to compute the signature; this
                class never sends it.
            endpoint: Gateway address placed between ``http://`` and ``/api``.
            timeout: Per-request timeout handed to ``requests``, which
                interprets it in seconds.
        """
        self.app_key = app_key
        self.access_key_id = access_key_id
        self.access_key_secret = access_key_secret
        self.endpoint = endpoint
        # NOTE(review): plain HTTP, so the access key id and signature travel
        # unencrypted. Switch to https if the gateway supports it.
        self.base_url = 'http://' + self.endpoint + '/api'
        # NOTE(review): requests treats this as seconds, so the default of 8000
        # lets a single call hang for more than two hours. It looks like a
        # millisecond figure (8 s). Confirm, then lower it; the default is kept
        # here so current behavior does not change.
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Content-Type": "application/json;charset=utf-8",
                "accessKey": self.access_key_id,
                "appKey": self.app_key,
            }
        )

    def send_get(self, headers=None, params=None):
        """Send a signed GET with ``params`` URL-encoded into the query string.

        ``None`` values in ``params`` are dropped. Returns the decoded JSON body,
        or the raw response text when the body is not JSON.

        Raises:
            Exception: with the response text, when the HTTP status is >= 400.
        """
        query = self._prepare_params({} if params is None else params)
        return self._send("GET", headers, "params", query)

    def send_post(self, headers=None, params=None):
        """Send a signed POST with ``params`` serialized as the JSON body.

        Unlike ``send_get``, ``None`` values are kept (sent as JSON ``null``).
        Returns the decoded JSON body, or the raw text when it is not JSON.

        Raises:
            Exception: with the response text, when the HTTP status is >= 400.
        """
        return self._send("POST", headers, "json", {} if params is None else params)

    def _send(self, http_method, headers, payload_key, payload):
        """Dispatch one signed request and decode its body (shared by send_get/send_post).

        ``payload_key`` is the ``requests`` keyword that carries ``payload``
        (``"params"`` for the query string, ``"json"`` for a JSON body).
        """
        args = self.cleanNoneValue(
            {
                "url": self.base_url,
                "headers": self._prepare_headers(headers),
                payload_key: payload,
                "timeout": self.timeout,
            }
        )
        response = self._dispatch_request(http_method)(**args)
        self._handle_exception(response)
        try:
            return response.json()
        except ValueError:
            # Body is not JSON: hand back the raw text instead.
            return response.text

    def _prepare_headers(self, headers):
        """Return a copy of ``headers`` with ``requestId``, ``timestamp`` and ``sign`` added.

        The caller's dict is left untouched, so it can safely be reused across
        requests (e.g. in a polling loop).
        """
        prepared = dict(headers) if headers else {}
        timestamp = int(round(time.time() * 1000))  # epoch milliseconds
        prepared['requestId'] = str(uuid.uuid1())
        prepared['timestamp'] = str(timestamp)
        prepared['sign'] = self._get_sign(timestamp)
        return prepared

    def _prepare_params(self, params):
        """Drop ``None`` values from ``params`` and URL-encode the rest."""
        return self.encoded_string(self.cleanNoneValue(params))

    def _get_sign(self, timestamp):
        """Return the MD5 hex digest of app_key + access_key_id + timestamp + access_key_secret."""
        key = self.app_key + self.access_key_id + str(timestamp) + self.access_key_secret
        return hashlib.md5(key.encode('utf-8')).hexdigest()

    def _dispatch_request(self, http_method):
        """Map an HTTP verb to the matching session method; unknown verbs fall back to GET."""
        return {
            "GET": self.session.get,
            "DELETE": self.session.delete,
            "PUT": self.session.put,
            "POST": self.session.post,
        }.get(http_method, self.session.get)

    def _handle_exception(self, response):
        """Raise ``Exception(response.text)`` for HTTP status >= 400; otherwise do nothing."""
        if response.status_code < 400:
            return
        raise Exception(response.text)

    def encoded_string(self, query):
        """URL-encode ``query`` (list values repeat the key) while leaving ``@`` unescaped."""
        return urlencode(query, True).replace("%40", "@")

    def cleanNoneValue(self, d) -> dict:
        """Return a new dict containing only the entries of ``d`` whose value is not ``None``."""
        return {key: value for key, value in d.items() if value is not None}
# Deployment settings. All are mandatory: a missing variable raises KeyError at
# import time.
api_key, ak, sk, endpoint, apiname, callbackapiname, osssigneapiname = (
    os.environ[_var_name]
    for _var_name in (
        'APIKEY',           # gateway app key
        'AK',               # gateway access key id
        'SK',               # gateway access key secret
        'ENDPOINT',         # gateway address
        'APINAME',          # apiName used to submit a relighting task
        'CALLBACKAPINAME',  # apiName used to poll a task by id
        'OSSSIGNAPINAME',   # apiName used to obtain an OSS upload signature
    )
)

# One signed gateway client and one watermarker, shared by every UI request.
client = ApiClient(api_key, ak, sk, endpoint)
watermark_app = WatermarkApp()
def upload_to_oss(file_path, accessid, policy, signature, host, key, callback):
    """POST a local file to OSS using a pre-signed form-upload policy.

    Args:
        file_path: Local file to upload.
        accessid, policy, signature, host: Values from the gateway's OSS sign response.
        key: Object key to store the file under; also used as the multipart file name.
        callback: Value sent in the form's ``callback`` field.

    Returns:
        ``key`` on success. Otherwise ``None``, after printing the response body.
    """
    form_fields = {
        'OSSAccessKeyId': accessid,
        'policy': policy,
        'Signature': signature,
        'key': key,
        'callback': callback,
    }
    with open(file_path, 'rb') as handle:
        reply = requests.post(host, files={'file': (key, handle)}, data=form_fields, timeout=20)
    # ``reply.ok`` is True for any status below 400, so 204 is already covered
    # by it; both checks are kept as in the original.
    if not (reply.status_code == 204 or reply.ok):
        print(f"file upload failed: {reply.text}")
        return None
    return key
def upload_oss_bucket(image_pil):
    """Upload a PIL image to OSS as a JPEG and return its object key.

    Requests a form-upload signature from the gateway (``osssigneapiname``),
    writes the image to a private temporary ``.jpg`` file, posts it with
    ``upload_to_oss`` and always deletes the temporary file afterwards.

    Raises:
        ValueError: 'oss sign error' if the sign response carries no usable
            ``data`` object, or 'oss upload error' if the upload fails.
    """
    import posixpath
    import tempfile

    headers = {'apiName': osssigneapiname}
    params = {'fileType': '1'}
    response = client.send_get(headers, params)
    try:
        sign = response['data']
    except (KeyError, TypeError) as err:
        # send_get returns raw text when the body is not JSON, so indexing can
        # fail with TypeError as well as KeyError.
        raise ValueError('oss sign error') from err
    if not isinstance(sign, dict):
        # e.g. a JSON body with "data": null
        raise ValueError('oss sign error')

    accessid = sign['accessid']
    policy = sign['policy']
    signature = sign['signature']
    host = sign['host']
    callback = ''
    # OSS object keys always use '/', so join with posixpath rather than
    # os.path, which would insert backslashes on Windows.
    # NOTE(review): the file name is only the sign response's 'expire' value.
    # If two uploads ever receive the same value (e.g. image and ref_img
    # uploaded back to back), they map to the same key and overwrite each
    # other. Confirm whether the backend relies on this naming before making
    # it unique.
    oss_key = posixpath.join(sign['dir'], '{}.jpg'.format(sign['expire']))

    # Use a per-call temporary file instead of a fixed 'test.jpg' in the working
    # directory, so concurrent requests cannot clobber each other's upload and
    # nothing is left on disk.
    fd, file_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        image_pil.save(file_path)
        uploaded_key = upload_to_oss(file_path, accessid, policy, signature, host, oss_key, callback)
    finally:
        os.remove(file_path)
    if uploaded_key is None:
        raise ValueError('oss upload error')
    return uploaded_key
def call_text_guided_relighting(image, mode, prompt, seed, steps):
    """Relight ``image`` according to a text prompt via the remote gateway.

    Rotates the input per its EXIF data, converts it to RGB, resizes it
    (``tar_res=1280``) and uploads it to OSS. It then submits a
    ``free_txt2bg_gen`` task and polls it every 10 s while its status is 1.

    Args:
        image: Input PIL image from the UI (``None`` if nothing was uploaded).
        mode: 'normal' or 'uniform-lit' (the UI radio choices).
        prompt: Text prompt describing the target lighting/background.
        seed: Random seed forwarded to the task.
        steps: Step count forwarded to the task.

    Returns:
        (watermarked relit image, mask image), decoded from the task's first result.

    Raises:
        gr.Error: if no image was provided.
        ValueError: if the task ends with a status other than 2.
        TimeoutError: if the task is still at status 1 after ``max_wait_seconds``.
    """
    if image is None:
        raise gr.Error('Please upload an original image first.')
    poll_interval = 10
    # Upper bound on polling. Without it, a task stuck at status 1 would block
    # this worker (and possibly the whole queue) forever.
    max_wait_seconds = 30 * 60

    image = ImageOps.exif_transpose(image).convert('RGB')
    image = ops.resize_keep_hw_rate(image, tar_res=1280)
    image = upload_oss_bucket(image)  # from here on: the OSS object key
    ops.print_with_datetime('start call_text_guided_relighting')
    ops.print_with_datetime(prompt)
    submit_params = {
        'image': image,
        'inference_mode': 'free_txt2bg_gen',
        'mode': mode,
        'prompt': prompt,
        'seed': seed,
        'steps': steps,
    }
    ops.print_with_datetime(f'length params, {len(json.dumps(submit_params))}')
    task_id = client.send_post({'apiName': apiname}, submit_params)['data']

    poll_headers = {'apiName': callbackapiname}
    poll_params = {'id': task_id}
    time.sleep(poll_interval)
    deadline = time.monotonic() + max_wait_seconds
    while True:
        result = client.send_get(poll_headers, poll_params)['data']
        # This code treats status 1 as "still running" and 2 as "succeeded".
        if result['status'] != 1:
            break
        if time.monotonic() >= deadline:
            raise TimeoutError('task {} did not finish within {} s'.format(task_id, max_wait_seconds))
        time.sleep(poll_interval)
    if result['status'] != 2:
        raise ValueError('something wrong in the process')

    first = result['sasMyCreationPicVOs'][0]
    relit = ops.decode_img_from_url(first['picUrl'])
    relit = watermark_app.process_image(relit)
    mask = ops.decode_img_from_url(first['maskUrl'])
    return relit, mask
def call_image_guided_relighting(image, ref_img, seed, steps):
    """Relight ``image`` to match the lighting of ``ref_img`` via the remote gateway.

    Both images are rotated per their EXIF data, converted to RGB, resized
    (``tar_res=1280``) and uploaded to OSS. A ``replica_gen`` task (mode
    'normal') is then submitted and polled every 10 s while its status is 1.

    Args:
        image: Input PIL image from the UI (``None`` if nothing was uploaded).
        ref_img: Reference PIL image from the UI (``None`` if nothing was uploaded).
        seed: Random seed forwarded to the task.
        steps: Step count forwarded to the task.

    Returns:
        (watermarked relit image, mask image), decoded from the task's first result.

    Raises:
        gr.Error: if either image is missing.
        ValueError: if the task ends with a status other than 2.
        TimeoutError: if the task is still at status 1 after ``max_wait_seconds``.
    """
    if image is None or ref_img is None:
        raise gr.Error('Please upload both an original image and a reference image first.')
    poll_interval = 10
    # Upper bound on polling. Without it, a task stuck at status 1 would block
    # this worker (and possibly the whole queue) forever.
    max_wait_seconds = 30 * 60

    image = ImageOps.exif_transpose(image).convert('RGB')
    image = ops.resize_keep_hw_rate(image, tar_res=1280)
    image = upload_oss_bucket(image)  # OSS object key
    ref_img = ImageOps.exif_transpose(ref_img).convert('RGB')
    ref_img = ops.resize_keep_hw_rate(ref_img, tar_res=1280)
    ref_img = upload_oss_bucket(ref_img)  # OSS object key
    ops.print_with_datetime('start call_image_guided_relighting')
    submit_params = {
        'image': image,
        'inference_mode': 'replica_gen',
        'mode': 'normal',
        'ref_img': ref_img,
        'seed': seed,
        'steps': steps,
    }
    ops.print_with_datetime(f'length params, {len(json.dumps(submit_params))}')
    submit_response = client.send_post({'apiName': apiname}, submit_params)
    print(submit_response)
    task_id = submit_response['data']

    poll_headers = {'apiName': callbackapiname}
    poll_params = {'id': task_id}
    time.sleep(poll_interval)
    deadline = time.monotonic() + max_wait_seconds
    while True:
        result = client.send_get(poll_headers, poll_params)['data']
        # This code treats status 1 as "still running" and 2 as "succeeded".
        if result['status'] != 1:
            break
        if time.monotonic() >= deadline:
            raise TimeoutError('task {} did not finish within {} s'.format(task_id, max_wait_seconds))
        time.sleep(poll_interval)
    if result['status'] != 2:
        raise ValueError('something wrong in the process')

    first = result['sasMyCreationPicVOs'][0]
    relit = ops.decode_img_from_url(first['picUrl'])
    relit = watermark_app.process_image(relit)
    mask = ops.decode_img_from_url(first['maskUrl'])
    return relit, mask
# Quick-pick samples for the gr.Dataset widgets below. Each sample is a
# one-element row because each Dataset is bound to the single `prompt` component.

# Lighting descriptions.
quick_prompts = [[text] for text in (
    'warm lighting',
    'sunshine from window',
    'neon lighting',
    'at noon. bright sunlight',
    'at dusk',
    'golden time',
    'natural lighting',
    'shadow from window',
    'soft studio lighting',
    'red lighting',
    'purple lighting',
)]

# Scene / background descriptions.
quick_content_prompts = [[text] for text in (
    'by the sea',
    'in the forest',
    'on the snow mountain',
    'by the city street',
    'on the grassy field',
    'cityscape',
    'on the desert',
    'in the living room',
)]

# Subject descriptions (these replace the whole prompt when clicked).
quick_subjects = [[text] for text in (
    'portrait photography of a woman',
    'portrait photography of man',
    'product photography',
)]
def _append_quick_prompt(sample, current_prompt):
    """Return the prompt that results from clicking a lighting/content quick pick.

    Keeps at most the first two ', '-separated parts of ``current_prompt`` and
    appends the picked text (``sample[0]``). Empty parts are dropped, so picking
    on an empty prompt gives 'warm lighting' rather than ', warm lighting'.
    """
    parts = [part for part in (current_prompt or '').split(', ') if part]
    return ', '.join(parts[:2] + [sample[0]])


with gr.Blocks().queue() as demo:
    gr.HTML("""
    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
    FreeLighting: relighting model with both text-condition and image-condition
    </div>
    """)
    with gr.Row():
        gr.Markdown("See more information in https://github.com/liuyuxuan3060/FreeLighting")
    with gr.Row():
        gr.Markdown("We use a open source segmentation model to generate image mask")
    with gr.Tabs():
        with gr.TabItem("text-guided relighting") as t2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        prompt = gr.Textbox(value="", label="text prompt")
                    with gr.Row():
                        mode = gr.Radio(choices=["normal", "uniform-lit"], value='normal', label="uniform-lit mode will use double time", type='value')
                    seed = gr.Number(value=12345, label="random seed", precision=0)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                    with gr.Row():
                        example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
                    with gr.Row():
                        example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
                    with gr.Row():
                        example_quick_content_prompts = gr.Dataset(samples=quick_content_prompts, label='Content Quick List', samples_per_page=1000, components=[prompt])
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                # No fn: relighting_image is listed among the inputs, so each
                # example row presumably carries a precomputed result that is
                # shown directly.
                gr.Examples(
                    examples=db_examples.text_guided_examples,
                    inputs=[
                        image, mode, prompt, seed, steps,
                        relighting_image,
                    ],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_text_guided_relighting,
                inputs=[
                    image, mode, prompt, seed, steps,
                ],
                outputs=[relighting_image, image_mask],
            )
            # Lighting/content picks extend the current prompt; a subject pick replaces it.
            example_quick_content_prompts.click(_append_quick_prompt, inputs=[example_quick_content_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
            example_quick_prompts.click(_append_quick_prompt, inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
            example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
        with gr.TabItem("image-guided relighting") as i2v_tab:
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        image = gr.Image(label="original_image", type="pil", height=480)
                        ref_img = gr.Image(label="reference_image", type="pil", height=480)
                        image_mask = gr.Image(label="image_mask", type="pil", height=480)
                    with gr.Row():
                        seed = gr.Number(value=12345, label="random seed", precision=0)
                        steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
                    button = gr.Button("generate")
                with gr.Column():
                    relighting_image = gr.Image(label="relighted_image", type="pil", height=480)
            with gr.Row():
                # No fn: relighting_image is listed among the inputs, so each
                # example row presumably carries a precomputed result that is
                # shown directly.
                gr.Examples(
                    examples=db_examples.image_guided_examples,
                    inputs=[
                        image, ref_img, seed, steps,
                        relighting_image,
                    ],
                    outputs=[relighting_image],
                    examples_per_page=1024,
                )
            button.click(
                fn=call_image_guided_relighting,
                inputs=[
                    image, ref_img, seed, steps,
                ],
                outputs=[relighting_image, image_mask],
            )
demo.launch()