Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import json | |
| import os | |
| import re | |
| import string | |
| import traceback | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import requests | |
| from huggingface_hub import HfApi | |
# Hub client used to list the datasets of the "bigscience-data" organisation.
hf_api = HfApi()

# Lookup table: short dataset name (the repo id after its last "/") -> the
# dataset metadata object returned by the Hub. The access token is read from
# the "bigscience_data_token" environment variable (None when unset).
roots_datasets = {
    info.id.rsplit("/", 1)[-1]: info
    for info in hf_api.list_datasets(
        author="bigscience-data",
        use_auth_token=os.environ.get("bigscience_data_token"),
    )
}
def get_docid_html(docid):
    """Render a document id as an HTML snippet that links to its dataset.

    Args:
        docid: Identifier of the form "<org>/<dataset>/<document>".

    Returns:
        HTML made of a link to the dataset page under
        ``huggingface.co/datasets/bigscience-data/`` followed by the document
        part of the id. Private datasets get a grey link with a lock icon;
        other datasets get a tooltip naming their license tag.

    Raises:
        ValueError: If ``docid`` does not have exactly three "/"-separated parts.
        KeyError: If the dataset is not present in ``roots_datasets``.
    """
    # Function-scope import so the block does not depend on the file header.
    from html import escape

    _, dataset, doc_part = docid.split("/")
    metadata = roots_datasets[dataset]
    locked_color = "LightGray"
    open_color = "#7978FF"

    # The id is produced by the search backend; escape it before embedding it
    # in attributes and element content.
    dataset_html = escape(dataset, quote=True)
    doc_html = escape(doc_part, quote=True)

    if metadata.private:
        return """
<a title="This dataset is private. See the introductory text for more information"
style="color:{locked_color}; font-weight: bold; text-decoration:none"
onmouseover="style='color:{locked_color}; font-weight: bold; text-decoration:underline'"
onmouseout="style='color:{locked_color}; font-weight: bold; text-decoration:none'"
href="https://huggingface.co/datasets/bigscience-data/{dataset}"
target="_blank">
🔒{dataset}
</a>
<span style="color:{open_color}; ">/{docid}</span>""".format(
            dataset=dataset_html,
            docid=doc_html,
            locked_color=locked_color,
            open_color=open_color,
        )

    # Prefer an explicit "license:<name>" tag; fall back to the first tag (the
    # previous behaviour) and finally to "unknown" when there are no tags, so
    # a dataset without tags no longer raises IndexError.
    tags = getattr(metadata, "tags", None) or []
    license_tag = next(
        (tag for tag in tags if tag.startswith("license:")),
        tags[0] if tags else "unknown",
    )
    license_name = escape(license_tag.split(":")[-1], quote=True)

    return """
<a title="This dataset is licensed {metadata}"
style="color:{open_color}; font-weight: bold; text-decoration:none"
onmouseover="style='color:{open_color}; font-weight: bold; text-decoration:underline'"
onmouseout="style='color:{open_color}; font-weight: bold; text-decoration:none'"
href="https://huggingface.co/datasets/bigscience-data/{dataset}"
target="_blank">
{dataset}
</a>
<span style="color:{open_color}; ">/{docid}</span>""".format(
        metadata=license_name,
        dataset=dataset_html,
        docid=doc_html,
        open_color=open_color,
    )
# Tag names that can follow PII_PREFIX in the text returned by the backend.
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
# Marker that precedes a tag name where personal information was removed.
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every ``PI:<TAG>`` placeholder with a highlighted HTML badge.

    Each occurrence of ``PII_PREFIX`` followed by one of ``PII_TAGS`` becomes
    a bold ``<mark>`` element reading "REDACTED <TAG>". Text without
    placeholders is returned unchanged.
    """
    redacted = text
    for pii_kind in PII_TAGS:
        placeholder = f"{PII_PREFIX}{pii_kind}"
        badge = (
            '<b><mark style="background: Fuchsia; color: Lime;">'
            f"REDACTED {pii_kind}</mark></b>"
        )
        redacted = redacted.replace(placeholder, badge)
    return redacted
def extract_lang_from_docid(docid):
    """Return the second "_"-separated field of *docid*.

    ``format_result`` displays this field as the result's language.
    Raises IndexError when *docid* contains no underscore.
    """
    # Only the first two separators matter for picking field number two.
    fields = docid.split("_", 2)
    return fields[1]
def normalize(document):
    """Return a canonical form of *document* for loose text comparison.

    Steps, in order: lower-case the text, delete every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs to single spaces.
    """
    lowered = document.lower()
    # One translate pass deletes the same characters the per-char filter did.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())
def _exact_query_variants(query):
    """Return the case variants of *query* to search for, most specific first."""
    lowered = query.lower()
    candidates = [
        query,
        lowered,
        query.upper(),
        # First letter capitalised; slicing avoids IndexError on an empty query.
        lowered[:1].upper() + lowered[1:],
        # Every whitespace-separated word capitalised.
        " ".join(word[:1].upper() + word[1:].lower() for word in query.split()),
    ]
    # dict.fromkeys drops duplicates while keeping the priority order.
    return list(dict.fromkeys(candidates))


def _highlight_exact(text, query):
    """Escape *text* and bold the first occurrence of the first matching variant."""
    from html import escape

    for variant in _exact_query_variants(query):
        start = text.find(variant) if variant else -1
        if start >= 0:
            end = start + len(variant)
            return (
                escape(text[:start])
                + "<b>{}</b>".format(escape(text[start:end]))
                + escape(text[end:])
            )
    # No variant found: still show the document instead of an empty body.
    return escape(text)


def format_result(result, highlight_terms, exact_search, datasets_filter=None):
    """Render one search hit as an HTML paragraph.

    Args:
        result: ``(text, url, docid)`` tuple; ``url`` may be None.
        highlight_terms: For exact search, the query string. Otherwise a
            collection tested with ``token in highlight_terms`` for every
            whitespace-separated token of the text.
        exact_search: Selects substring highlighting (True) or per-token
            highlighting (False).
        datasets_filter: Optional iterable of dataset names; hits whose
            dataset (second "/"-separated part of ``docid``) is not listed
            are suppressed.

    Returns:
        The HTML for the hit, or "" when it is excluded by ``datasets_filter``.
    """
    from html import escape

    text, url, docid = result

    if datasets_filter is not None:
        dataset = docid.split("/")[1]
        if dataset not in set(datasets_filter):
            return ""

    # Document text and url come from the corpus, so they are escaped before
    # being placed in the page; only the markup added here stays as HTML.
    if exact_search:
        tokens_html = _highlight_exact(text, highlight_terms)
    else:
        tokens_html = " ".join(
            "<b>{}</b>".format(escape(token))
            if token in highlight_terms
            else escape(token)
            for token in text.split()
        )
    tokens_html = process_pii(tokens_html)

    if url is not None:
        url_html = """
<span style='font-size:12px; font-family: Arial; color:Silver; text-align: left;'>
<a style='text-decoration:none; color:Silver;'
onmouseover="style='text-decoration:underline; color:Silver;'"
onmouseout="style='text-decoration:none; color:Silver;'"
href='{url}'
target="_blank">
{url}
</a>
</span><br>
""".format(
            url=escape(url, quote=True)
        )
    else:
        url_html = ""

    docid_html = get_docid_html(docid)
    language = extract_lang_from_docid(docid)
    result_html = """{}
<span style='font-size:14px; font-family: Arial; color:MediumAquaMarine'>Language: {} | </span>
<span style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {} | </span>
<a href="https://forms.gle/AdBLLwRApqcLkHYA8" target="_blank">
<button style="color:#ffcdf8; ">🏴‍☠️ Flag result 🏴‍☠️</button>
</a><br>
<span style='font-family: Arial;'>{}</span><br>
<br>
""".format(
        url_html, language, docid_html, tokens_html
    )
    return "<p>" + result_html + "</p>"
def format_result_page(
    language, results, highlight_terms, num_results, exact_search, datasets_filter=None
) -> str:
    """Build the HTML for a whole results page.

    Args:
        language: Language selected in the UI. "detect_language" adds a
            "Detected language" header and "all" wraps each language's hits
            in a collapsible ``<details>`` element (both only for non-exact
            search).
        results: Mapping of language key -> list of ``(text, url, docid)``.
        highlight_terms: Passed through to ``format_result``.
        num_results: Total match count shown in the header, or None to omit it.
        exact_search: Whether the page shows exact-search results.
        datasets_filter: Optional dataset names passed to ``format_result``.

    Returns:
        The header HTML followed by the per-language result HTML.
    """
    header_html = ""
    if language == "detect_language" and not exact_search:
        # The first key is reported as the detected language. An empty mapping
        # has nothing to report and must not raise.
        detected = next(iter(results), None)
        if detected is not None:
            header_html += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
Detected language: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
                detected
            )

    page_parts = []
    for lang, results_for_lang in results.items():
        print("Processing language", lang)
        if len(results_for_lang) == 0:
            if exact_search:
                page_parts.append(
                    """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
No results found.</div>"""
                )
            else:
                page_parts.append(
                    """<div style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
No results for language: <b>{}</b></div>""".format(
                        lang
                    )
                )
            continue

        # format_result returns "" for hits excluded by the datasets filter.
        results_for_lang_html = "".join(
            format_result(result, highlight_terms, exact_search, datasets_filter)
            for result in results_for_lang
        )
        if language == "all" and not exact_search:
            results_for_lang_html = f"""
<details>
<summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
Results for language: <b>{lang}</b>
</summary>
{results_for_lang_html}
</details>"""
        page_parts.append(results_for_lang_html)

    if num_results is not None:
        header_html += """<div style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
Total number of matches: <b style='color:MediumAquaMarine'>{}</b></div>""".format(
            num_results
        )
    return header_html + "".join(page_parts)
def extract_results_from_payload(query, language, payload, exact_search):
    """Convert a backend payload into the structures the UI works with.

    Args:
        query: The user's query; used as the highlight terms for exact search.
        language: Not used by this function.
        payload: Decoded response. ``payload["results"]`` is a list of hits
            for exact search, otherwise a mapping of language -> list of hits.
            Each hit has "text" and "docid" and may have "meta" with a "url".
        exact_search: Whether the payload comes from the exact-search backend.

    Returns:
        ``(processed_results, highlight_terms, num_results, datasets)`` where
        ``processed_results`` maps a language key to ``(text, url, docid)``
        tuples, ``num_results`` is None for non-exact search, and ``datasets``
        lists the distinct dataset names found in the doc ids.
    """
    raw_results = payload["results"]
    if exact_search:
        highlight_terms = query
        num_results = payload["num_results"]
        # Exact search has no per-language grouping; file it under one key.
        raw_results = {"dummy": raw_results}
    else:
        highlight_terms = payload["highlight_terms"]
        num_results = None

    processed_results = {}
    seen_datasets = set()
    for lang_key, hits in raw_results.items():
        rows = []
        processed_results[lang_key] = rows
        for hit in hits:
            text = hit["text"]
            meta = hit["meta"] if "meta" in hit else None
            url = meta["url"] if meta is not None and "url" in meta else None
            docid = hit["docid"]
            _org, dataset_name, _doc = docid.split("/")
            seen_datasets.add(dataset_name)
            rows.append((text, url, docid))
    return processed_results, highlight_terms, num_results, list(seen_datasets)
def no_query_error_message():
    """Return the HTML notice shown when the submitted query is empty."""
    lines = (
        "",
        "<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>",
        "Please provide a non-empty query.",
        "</p><br><hr><br>",
    )
    return "\n".join(lines)
def process_error(error_type, payload):
    """Return the HTML message for an error reported by the backend.

    Args:
        error_type: Value of ``payload["err"]["type"]``.
        payload: Decoded backend response. For "unsupported_lang" it must
            contain ``payload["err"]["meta"]["detected_lang"]``.

    Returns:
        An HTML snippet. "unsupported_lang" gets a dedicated message; any
        other type gets a generic message naming the type, so the caller
        always receives a string to display instead of None.
    """
    from html import escape

    if error_type == "unsupported_lang":
        detected_lang = escape(str(payload["err"]["meta"]["detected_lang"]))
        return f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Detected language <b>{detected_lang}</b> is not supported.<br>
Please choose a language from the dropdown or type another query.
</p><br><hr><br>"""

    # Unknown error types previously fell through and returned None, which
    # left the results area blank with no explanation.
    shown_type = escape(str(error_type))
    return f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
The search backend reported an error: <b>{shown_type}</b>.<br>
Please try again or type another query.
</p><br><hr><br>"""
def extract_error_from_payload(payload):
    """Return ``payload["err"]["type"]`` when the payload has an "err" entry, else None."""
    return payload["err"]["type"] if "err" in payload else None
def request_payload(query, language, exact_search, num_results=10, received_results=0):
    """POST a search request to the backend and return the decoded JSON reply.

    Args:
        query: Query string to search for.
        language: Language key; sent as "lang" unless it is "detect_language".
        exact_search: Chooses the endpoint: the "address_exact_search"
            environment variable when True, otherwise "address".
        num_results: Sent to the backend as "k".
        received_results: Sent to the backend as "received_results".

    Returns:
        The response body parsed as JSON.
    """
    body = {"query": query, "k": num_results, "received_results": received_results}
    if language != "detect_language":
        body["lang"] = language

    endpoint_var = "address_exact_search" if exact_search else "address"
    response = requests.post(
        os.environ.get(endpoint_var),
        headers={"Content-type": "application/json"},
        data=json.dumps(body),
        timeout=120,
    )
    return json.loads(response.text)
# Page heading rendered with gr.Markdown. The emoji were mis-encoded in the
# stored text and are restored here.
title = (
    """<p style="text-align: center; font-size:28px"> 🌸 🔎 ROOTS search tool 🔍 🌸 </p>"""
)

# Introductory Markdown shown under the heading.
description = """
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). The ROOTS Search
Tool allows you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages
included in ROOTS. We also offer exact search which is enabled if you enclose your query in double quotes. More details
about the implementation and use cases is available in our [paper](https://arxiv.org/abs/2302.14035) - please cite it
if you use ROOTS Search Tool in your work. For more information and instructions on how to access the full corpus
consult [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""
if __name__ == "__main__":
    demo = gr.Blocks(css=".underline-on-hover:hover { text-decoration: underline; }")

    with demo:
        # Per-session state shared between the event handlers below.
        # processed_results is always a dict (language key -> list of hits) so
        # that every handler can call .items() / .values() on it, including
        # after an error or before the first query.
        processed_results_state = gr.State({})
        highlight_terms_state = gr.State([])
        num_results_state = gr.State(0)
        exact_search_state = gr.State(False)
        received_results_state = gr.State(0)

        with gr.Row():
            gr.Markdown(value=title)
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Put your query in double quotes for exact search.",
                label="Query",
            )
        with gr.Row():
            lang = gr.Dropdown(
                choices=[
                    "ar",
                    "ca",
                    "code",
                    "en",
                    "es",
                    "eu",
                    "fr",
                    "id",
                    "indic",
                    "nigercongo",
                    "pt",
                    "vi",
                    "zh",
                    "detect_language",
                    "all",
                ],
                value="en",
                label="Language",
            )
            k = gr.Slider(
                1,
                100,
                value=10,
                step=1,
                label="Max Results in fuzzy search or Max Results per page in exact search",
            )
        with gr.Row():
            submit_btn = gr.Button("Submit")
        # Hidden until a query returns at least one result.
        with gr.Row(visible=False) as datasets_filter:
            available_datasets = gr.Dropdown(
                type="value",
                choices=[],
                value=[],
                label="Datasets Filter",
                multiselect=True,
            )
        with gr.Row():
            result_page_html = gr.HTML(label="Results")
        # Hidden until an exact search has more matches than fit on one page.
        with gr.Row(visible=False) as pagination:
            next_page_btn = gr.Button("Next Page")

        def count_results(processed_results):
            """Return the number of hits across all language keys."""
            return sum(len(hits) for hits in processed_results.values())

        def run_query(query, lang, k, dropdown_input, received_results):
            """Query the backend and return the values the handlers share.

            Returns a 6-tuple: processed results (dict), highlight terms,
            number of matches, whether exact search was used, the rendered
            result page HTML, and the list of dataset names in the results.
            On an empty query or a backend error the first four items are
            empty defaults and the HTML carries the error message.
            """
            # Guard against None before stripping; a quoted query selects
            # exact search and loses its surrounding quotes.
            query = (query or "").strip()
            exact_search = False
            if len(query) >= 2 and query.startswith('"') and query.endswith('"'):
                exact_search = True
                query = query[1:-1]
            else:
                query = " ".join(query.split())

            if query == "":
                return ({}, [], 0, False, no_query_error_message(), [])

            payload = request_payload(query, lang, exact_search, k, received_results)
            err = extract_error_from_payload(payload)
            if err is not None:
                return ({}, [], 0, False, process_error(err, payload), [])

            (
                processed_results,
                highlight_terms,
                num_results,
                ds,
            ) = extract_results_from_payload(query, lang, payload, exact_search)
            result_page = format_result_page(
                lang, processed_results, highlight_terms, num_results, exact_search
            )
            return (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                ds,
            )

        def submit(query, lang, k, dropdown_input):
            """Handle a new search: fetch the first page and reset pagination."""
            print("submitting", query, lang, k)
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                datasets,
            ) = run_query(query, lang, k, dropdown_input, 0)
            has_more_results = bool(exact_search and (num_results > k))
            # Count hits in every language so the datasets filter shows up
            # even when the first language happens to be empty.
            current_results = count_results(processed_results)
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=current_results > 0),
                gr.Dropdown.update(choices=datasets, value=datasets),
                gr.update(visible=has_more_results),
                current_results,
                result_page,
            ]

        def next_page(
            query,
            lang,
            k,
            dropdown_input,
            received_results,
            processed_results,
        ):
            """Fetch the page that follows the ``received_results`` hits seen so far."""
            (
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                result_page,
                datasets,
            ) = run_query(query, lang, k, dropdown_input, received_results)
            current_results = count_results(processed_results)
            has_more_results = bool(
                exact_search and (received_results + current_results < num_results)
            )
            print("received_results", received_results)
            print("current_results", current_results)
            print("has_more_results", has_more_results)
            return [
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                gr.update(visible=current_results > 0),
                gr.Dropdown.update(choices=datasets, value=datasets),
                # A short page means the backend has nothing further to send.
                gr.update(visible=current_results >= k and has_more_results),
                received_results + current_results,
                result_page,
            ]

        def filter_datasets(
            lang,
            processed_results,
            highlight_terms,
            num_results,
            exact_search,
            datasets_filter,
        ):
            """Re-render the stored results restricted to the selected datasets."""
            if not processed_results:
                # Nothing to re-render (no query yet, or the last query ended
                # in an error): leave the currently displayed HTML untouched.
                return gr.update()
            return format_result_page(
                lang,
                processed_results,
                highlight_terms,
                num_results,
                exact_search,
                datasets_filter,
            )

        # Every search/pagination handler fills the same components, in the
        # order of the list returned by submit() and next_page().
        search_outputs = [
            processed_results_state,
            highlight_terms_state,
            num_results_state,
            exact_search_state,
            datasets_filter,
            available_datasets,
            pagination,
            received_results_state,
            result_page_html,
        ]

        query.submit(
            fn=submit,
            inputs=[query, lang, k, available_datasets],
            outputs=search_outputs,
        )
        submit_btn.click(
            submit,
            inputs=[query, lang, k, available_datasets],
            outputs=search_outputs,
        )
        next_page_btn.click(
            next_page,
            inputs=[
                query,
                lang,
                k,
                available_datasets,
                received_results_state,
                processed_results_state,
            ],
            outputs=search_outputs,
        )
        available_datasets.change(
            filter_datasets,
            inputs=[
                lang,
                processed_results_state,
                highlight_terms_state,
                num_results_state,
                exact_search_state,
                available_datasets,
            ],
            outputs=result_page_html,
        )

    demo.launch(enable_queue=False, debug=True)

