Spaces:
Sleeping
Sleeping
| import difflib | |
| from collections import Counter | |
| import streamlit as st | |
| import pandas as pd | |
| import srsly | |
def search(query):
    """Store every grant tagged with *query* in the Streamlit session state.

    A grant matches when ``query`` is a member of its ``"tags"`` entry. Each
    match is reduced to its ``"title"`` and ``"tags"`` fields, and the list of
    matches is written to ``st.session_state["results"]``. Nothing is returned.

    Reads the module-level ``grants`` collection.
    """
    st.session_state["results"] = [
        {"title": grant["title"], "tags": grant["tags"]}
        for grant in grants
        if query in grant["tags"]
    ]
# Page header, informational sidebar and the result-count slider.
#
# NOTE(review): the "π" glyphs in the user-facing strings look like the
# remains of a mis-decoded emoji (only the first byte survived); restore the
# intended character from the original source if it is available.
st.header("Search π grants using MeSH π")

# "ℹ" restored: the previous text was its UTF-8 bytes decoded as ISO-8859-7.
st.sidebar.header("Information ℹ")
st.sidebar.write(
    "A complete list of MeSH tags can be found here https://meshb.nlm.nih.gov/treeView"
)
st.sidebar.write(
    "The grants data can be found at "
    "[https://www.threesixtygiving.org/](https://data.threesixtygiving.org/). "
    "They are published under a "
    "[CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) license."
)
st.sidebar.write(
    "The model used to tag grants is https://huggingface.co/Wellcome/WellcomeBertMesh"
)

st.sidebar.header("Parameters")
# Upper bound on the number of rows shown in the results table.
nb_results = st.sidebar.slider(
    "Number of results to display", value=20, min_value=1, max_value=100
)
# Read the tagged grants once and keep them in the session state so that
# later script reruns reuse the in-memory list instead of re-reading the file.
if "grants" not in st.session_state:
    st.session_state["grants"] = list(srsly.read_jsonl("tagged_grants.jsonl"))
grants = st.session_state["grants"]

# Distinct tags across all grants, also computed only once per session.
if "tags" not in st.session_state:
    unique_tags = {tag for grant in grants for tag in grant["tags"]}
    st.session_state["tags"] = list(unique_tags)
tags = st.session_state["tags"]
# Query box and search button. The button runs ``search`` as a callback, which
# stores the matches in ``st.session_state["results"]`` before the rerun.
query = st.text_input("", value="Malaria")
st.button("Search π", on_click=search, kwargs={"query": query})

# Everything below is only rendered once a search has been run.
if "results" in st.session_state:
    results = st.session_state["results"]

    st.caption("Related MeSH terms")
    if results:
        # Suggest the tags that occur most often among the retrieved grants.
        retrieved_tags = [tag for res in results for tag in res["tags"]]
        most_common_tags = [tag for tag, _ in Counter(retrieved_tags).most_common(20)]
    else:
        # Nothing matched: suggest known tags that are spelled like the query.
        most_common_tags = difflib.get_close_matches(query, tags, n=20)

    # Lay the suggestions out row by row in a five-column grid. The number of
    # rows follows the number of suggestions, so all of them (up to the 20
    # computed above) are shown rather than only the first 15.
    n_columns = 5
    columns = st.columns(n_columns)
    for row_start in range(0, len(most_common_tags), n_columns):
        row_tags = most_common_tags[row_start : row_start + n_columns]
        for col, tag in zip(columns, row_tags):
            with col:
                # Clicking a suggestion runs a new search for that tag.
                st.button(tag, on_click=search, kwargs={"query": tag})

    # Report the number of rows actually shown, which can be lower than the
    # slider value when fewer grants were found.
    nb_displayed = min(nb_results, len(results))
    st.caption(f"Found {len(results)}. Displaying {nb_displayed}")
    # The download contains every result, not just the displayed rows.
    st.download_button(
        "Download results",
        data=pd.DataFrame(results).to_csv(),
        file_name="results.csv",
        mime="text/csv",
    )
    st.table(results[:nb_results])