semantic-search / src /streamlit_app.py
INLEXIO's picture
Update src/streamlit_app.py
2e73ba2 verified
import streamlit as st
import requests
from sentence_transformers import SentenceTransformer
import numpy as np
from collections import defaultdict
import time
import os
import shutil
# Point every Hugging Face cache location at /tmp/huggingface, which is
# ephemeral storage that gets cleared on restart.
# NOTE(review): sentence_transformers is imported above this point, so any
# library that reads these variables only at import time will not see them;
# load_model() passes cache_folder explicitly as well.
os.environ.update({
    'HF_HOME': '/tmp/huggingface',
    'TRANSFORMERS_CACHE': '/tmp/huggingface',
    'SENTENCE_TRANSFORMERS_HOME': '/tmp/huggingface',
})
# Clear old cache on startup to prevent accumulation
def clear_old_cache(cache_dir='/tmp/huggingface', max_size_mb=5000):
    """Best-effort wipe of a cache directory once it grows too large.

    Sums the sizes of the regular files under ``cache_dir``. If the total
    exceeds ``max_size_mb`` MiB, the directory is removed and recreated empty.
    Filesystem problems never propagate: files that vanish or cannot be
    stat'ed during the walk are skipped, and removal/recreation failures are
    ignored, so app startup is not blocked by the cleanup.

    Args:
        cache_dir: Directory to measure and possibly clear.
        max_size_mb: Threshold in MiB (1024 * 1024 bytes). The directory is
            cleared only when its size is strictly greater than this.

    Returns:
        True if a clear was performed, False if the directory is missing or
        within the limit.
    """
    if not os.path.isdir(cache_dir):
        return False

    total_bytes = 0
    for dirpath, _dirnames, filenames in os.walk(cache_dir):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            try:
                # Symlinks are skipped: os.path.getsize follows them, and the
                # Hugging Face hub cache links snapshot entries to blob files,
                # so following them would count the same bytes twice.
                if os.path.islink(path):
                    continue
                total_bytes += os.path.getsize(path)
            except OSError:
                # File vanished or is unreadable mid-walk; ignore it.
                continue

    if total_bytes / (1024 * 1024) <= max_size_mb:
        return False

    shutil.rmtree(cache_dir, ignore_errors=True)
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except OSError:
        # Best effort: a failure here must not stop the app from starting.
        pass
    return True
# Startup: trim the model cache, then configure the Streamlit page.
clear_old_cache()

_PAGE_CONFIG = {
    "page_title": "OpenAlex Semantic Search",
    "page_icon": "πŸ”¬",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)
# Model loading is wrapped in st.cache_resource so it is not repeated on reruns.
@st.cache_resource
def load_model():
    """Return the SPECTER sentence-embedding model.

    SPECTER is trained on scientific papers, which suits this app better than
    general-purpose models. It is roughly 440MB (vs ~80MB for MiniLM) and
    produces 768-dimensional embeddings (vs 384 for MiniLM).
    """
    model_name = 'allenai/specter'
    download_dir = '/tmp/huggingface'
    return SentenceTransformer(model_name, cache_folder=download_dir)
# LIMITED CACHE: Only store 50 recent searches
@st.cache_data(ttl=3600, max_entries=50, show_spinner=False)
def search_openalex_papers(query, num_results=50, country_code=None, use_fulltext=False, year_min=None, year_max=None):
    """
    Search OpenAlex for works related to ``query``.

    Args:
        query: Free-text search string.
        num_results: Maximum number of works to return; values <= 0 return [].
        country_code: Optional country code; keeps only works with at least
            one author from that country (``authorships.countries`` filter).
        use_fulltext: If True, use the ``fulltext.search`` filter (title +
            abstract + full text when available) instead of the standard
            ``search`` parameter (title + abstract).
        year_min: Optional inclusive lower bound on publication year.
        year_max: Optional inclusive upper bound on publication year.

    Returns:
        A list of at most ``num_results`` work dicts containing only the
        fields named in ``select``. If a request fails, an error is shown and
        the works fetched so far are returned.

    Note: results are cached for 1 hour, max 50 searches stored.
    Requests for more than 200 works (the per-page maximum) are paginated.
    """
    if num_results <= 0:
        # Guards the ceiling division below against per_page == 0.
        return []

    base_url = "https://api.openalex.org/works"
    # OpenAlex max per_page is 200, so larger requests need several pages
    per_page = min(200, num_results)
    num_pages = (num_results + per_page - 1) // per_page  # Ceiling division

    # Everything except the page number is identical for every page.
    base_params = {
        "per_page": per_page,
        "select": "id,title,abstract_inverted_index,authorships,publication_year,cited_by_count,display_name",
        "mailto": "[email protected]"  # Polite pool
    }
    filters = []
    if use_fulltext:
        # Commas separate individual filters in OpenAlex's filter syntax, so a
        # comma inside the query would be parsed as the start of a new filter.
        filters.append(f"fulltext.search:{query.replace(',', ' ')}")
    else:
        base_params["search"] = query
    if country_code:
        filters.append(f"authorships.countries:{country_code}")
    if year_min is not None:
        filters.append(f"publication_year:>{year_min - 1}")  # >= year_min
    if year_max is not None:
        filters.append(f"publication_year:<{year_max + 1}")  # <= year_max
    if filters:
        # Comma-joined filters are combined with AND
        base_params["filter"] = ",".join(filters)

    all_papers = []
    for page in range(1, num_pages + 1):
        params = dict(base_params, page=page)
        try:
            response = requests.get(base_url, params=params, timeout=30)
            response.raise_for_status()
            papers = response.json().get("results") or []
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a response body that is not valid JSON.
            st.error(f"Error fetching papers (page {page}): {str(e)}")
            break
        all_papers.extend(papers)
        # A short page means there are no further results
        if len(papers) < per_page:
            break
        # Rate limiting - be nice to OpenAlex
        if page < num_pages:
            time.sleep(0.1)  # 100ms delay between requests
    return all_papers[:num_results]  # Return exactly what was requested
def reconstruct_abstract(inverted_index):
    """
    Rebuild plain abstract text from an OpenAlex inverted index.

    The index maps each word to the list of positions where it occurs. Words
    are emitted in position order, separated by single spaces. Entries that
    share a position keep their dict-iteration order. An empty or None index
    yields "".
    """
    if not inverted_index:
        return ""
    placed_tokens = sorted(
        (
            (position, token)
            for token, positions in inverted_index.items()
            for position in positions
        ),
        key=lambda pair: pair[0],
    )
    return " ".join(token for _position, token in placed_tokens)
# LIMITED CACHE: Only store 200 recent author lookups
@st.cache_data(ttl=3600, max_entries=200)
def get_author_details(author_id):
    """
    Fetch one author record from the OpenAlex ``/authors/{id}`` endpoint.

    Args:
        author_id: OpenAlex author ID, e.g. ``"A1234567890"``.

    Returns:
        The decoded JSON author record, or None when the request fails, the
        server returns an HTTP error status, or the body is not valid JSON.

    Cache limited to 200 authors to prevent storage issues; the returned value
    (including None) is cached for 1 hour.
    """
    url = f"https://api.openalex.org/authors/{author_id}"
    params = {
        "mailto": "[email protected]"
    }
    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # The caller treats a missing record as "no metrics available".
        return None
# LIMITED CACHE: Only store 200 recent author works lookups
@st.cache_data(ttl=3600, max_entries=200)
def get_author_works(author_id, max_works=20):
    """
    Fetch an author's most-cited works, used to validate topical relevance.

    Args:
        author_id: OpenAlex author ID, with its ``A`` prefix (``"A123..."``)
            or without it (``"123..."``).
        max_works: Number of works to request (sent as ``per_page``).

    Returns:
        Up to ``max_works`` work dicts sorted by citation count (descending),
        each with only id, title, abstract_inverted_index and
        publication_year. Returns an empty list on any request or decoding error.
    """
    author_key = str(author_id)
    # rank_authors only passes IDs that already start with "A"; prefixing them
    # unconditionally produced "author.id:AA...", which matches no author.
    if not author_key.upper().startswith("A"):
        author_key = f"A{author_key}"

    base_url = "https://api.openalex.org/works"
    params = {
        "filter": f"author.id:{author_key}",
        "per_page": max_works,
        "sort": "cited_by_count:desc",  # Most-cited works first
        "select": "id,title,abstract_inverted_index,publication_year",
        "mailto": "[email protected]"
    }
    try:
        response = requests.get(base_url, params=params, timeout=10)
        response.raise_for_status()
        return response.json().get("results") or []
    except (requests.RequestException, ValueError):
        return []
def validate_author_relevance(author_id, query_embedding, model, threshold=0.25, max_works=20):
    """
    Check whether an author's own publications match the search query.

    Builds a text for each of the author's works (as returned by
    get_author_works): the title twice, then the reconstructed abstract.
    Each text is embedded with ``model`` and compared to
    ``query_embedding`` by cosine similarity. The mean similarity is then
    compared against ``threshold``.

    Returns:
        (is_valid, avg_similarity, num_works_checked), or (False, 0.0, 0) when
        the author has no works or none of them yields any text.
    """
    works = get_author_works(author_id, max_works)
    if not works:
        return False, 0.0, 0

    def _work_text(work):
        # Title is repeated so it weighs more than the abstract.
        heading = work.get('title', '') or work.get('display_name', '')
        body = reconstruct_abstract(work.get('abstract_inverted_index', {}))
        return f"{heading} {heading} {body}"

    texts = [candidate for candidate in map(_work_text, works) if candidate.strip()]
    if not texts:
        return False, 0.0, 0

    embeddings = model.encode(texts, convert_to_tensor=False, show_progress_bar=False)
    mean_similarity = np.mean(calculate_semantic_similarity(query_embedding, embeddings))
    return mean_similarity >= threshold, mean_similarity, len(texts)
def calculate_semantic_similarity(query_embedding, paper_embeddings):
    """
    Cosine similarity between one query vector and each row of a matrix.

    Args:
        query_embedding: 1-D array-like of length D.
        paper_embeddings: Array-like of shape (N, D), or a single 1-D vector
            of length D (treated as N == 1).

    Returns:
        1-D NumPy array of N similarities. A zero-length query or row yields
        0.0 instead of NaN.
    """
    query = np.asarray(query_embedding)
    papers = np.asarray(paper_embeddings)
    if papers.ndim == 1:
        papers = papers[np.newaxis, :]

    query_norm = np.linalg.norm(query)
    paper_norms = np.linalg.norm(papers, axis=1, keepdims=True)
    # Zero vectors have no direction. Dividing them by 1 keeps them at zero, so
    # their similarity is 0 instead of NaN. A NaN would otherwise sort to the
    # top of the descending argsort used for ranking.
    safe_query_norm = query_norm if query_norm > 0 else 1.0
    safe_paper_norms = np.where(paper_norms > 0, paper_norms, 1.0)

    return np.dot(papers / safe_paper_norms, query / safe_query_norm)
def _collect_author_data(papers, paper_scores):
    """Group per-paper relevance scores by OpenAlex author ID ("A...")."""
    author_data = defaultdict(lambda: {
        'name': '',
        'id': '',
        'paper_scores': [],
        'paper_ids': [],
        'total_citations': 0,
        'works_count': 0,
        'h_index': 0,
        'institution': '',
        'validation_score': 0.0,
        'validated': False,
    })
    for paper, score in zip(papers, paper_scores):
        for authorship in paper.get('authorships') or []:
            author = authorship.get('author') or {}
            raw_id = author.get('id')
            # IDs arrive as URLs like "https://openalex.org/A123"; keep the tail.
            author_id = raw_id.split('/')[-1] if raw_id else None
            if not (author_id and author_id.startswith('A')):
                continue
            record = author_data[author_id]
            record['name'] = author.get('display_name') or 'Unknown'
            record['id'] = author_id
            record['paper_scores'].append(score)
            record['paper_ids'].append(paper.get('id', ''))
            # Keep the first institution seen for this author
            institutions = authorship.get('institutions') or []
            if institutions and not record['institution']:
                record['institution'] = institutions[0].get('display_name') or ''
    return author_data


def _composite_score(avg_relevance, validation_score, h_index, total_citations, use_validation):
    """Blend relevance, validation, h-index and citations into one ranking score."""
    # h-index is scaled against an assumed maximum of 100; citations use a
    # log scale (log1p / 15). Both components are capped at 1.0.
    h_component = min(h_index / 100.0, 1.0)
    citation_component = min(np.log1p(total_citations) / 15.0, 1.0)
    if use_validation:
        return (
            0.4 * avg_relevance +       # 40% relevance in the search results
            0.3 * validation_score +    # 30% relevance of their body of work
            0.2 * h_component +         # 20% h-index
            0.1 * citation_component    # 10% citations
        )
    return (
        0.5 * avg_relevance +           # 50% relevance
        0.3 * h_component +             # 30% h-index
        0.2 * citation_component        # 20% citations
    )


def rank_authors(papers, paper_scores, model, query_embedding, min_papers=2, validate_authors=True, validation_threshold=0.25):
    """
    Rank the authors of ``papers`` by a weighted composite score.

    The score combines semantic relevance (mean score of the author's papers
    in ``papers``), h-index and total citations. When ``validate_authors`` is
    True, it also includes how well the author's own works (fetched via
    get_author_works) match the query.

    Args:
        papers: OpenAlex work dicts.
        paper_scores: Query similarity for each paper, aligned with ``papers``.
        model: Embedding model, passed through to validate_author_relevance.
        query_embedding: Embedding of the search query.
        min_papers: Authors appearing on fewer of ``papers`` are dropped.
        validate_authors: If True, keep only authors whose body-of-work
            similarity reaches ``validation_threshold``.
        validation_threshold: Threshold used when ``validate_authors`` is True.

    Returns:
        List of author summary dicts, highest ``composite_score`` first.
    """
    author_data = _collect_author_data(papers, paper_scores)

    # Filter authors with minimum paper count
    filtered_authors = {
        aid: data for aid, data in author_data.items()
        if len(data['paper_scores']) >= min_papers
    }
    total = len(filtered_authors)

    # Fetch detailed metrics for each author
    with st.spinner(f"Fetching metrics for {total} authors..."):
        progress_bar = st.progress(0)
        for done, (author_id, data) in enumerate(filtered_authors.items(), 1):
            author_details = get_author_details(author_id)
            if author_details:
                # Fields may be absent or null in the API response.
                stats = author_details.get('summary_stats') or {}
                data['h_index'] = stats.get('h_index') or 0
                data['total_citations'] = author_details.get('cited_by_count') or 0
                data['works_count'] = author_details.get('works_count') or 0
            progress_bar.progress(done / total)
            time.sleep(0.1)  # Rate limiting
        progress_bar.empty()

    # Validate authors if requested
    if validate_authors:
        with st.spinner("Validating author relevance (checking their body of work)..."):
            progress_bar = st.progress(0)
            validated_count = 0
            for done, (author_id, data) in enumerate(filtered_authors.items(), 1):
                is_valid, val_score, num_works = validate_author_relevance(
                    author_id, query_embedding, model, validation_threshold
                )
                data['validated'] = is_valid
                data['validation_score'] = val_score
                data['num_works_checked'] = num_works
                if is_valid:
                    validated_count += 1
                progress_bar.progress(done / total)
                time.sleep(0.1)  # Rate limiting
            progress_bar.empty()
        st.success(f"βœ… {validated_count}/{len(filtered_authors)} authors validated as relevant to your query")
        # Keep only validated authors
        filtered_authors = {
            aid: data for aid, data in filtered_authors.items()
            if data['validated']
        }

    ranked_authors = []
    for author_id, data in filtered_authors.items():
        avg_relevance = np.mean(data['paper_scores'])
        composite_score = _composite_score(
            avg_relevance,
            data['validation_score'],
            data['h_index'],
            data['total_citations'],
            validate_authors,
        )
        ranked_authors.append({
            'name': data['name'],
            'id': author_id,
            'h_index': data['h_index'],
            'total_citations': data['total_citations'],
            'works_count': data['works_count'],
            'num_relevant_papers': len(data['paper_scores']),
            'avg_relevance_score': avg_relevance,
            'validation_score': data['validation_score'],
            'validated': data['validated'],
            'composite_score': composite_score,
            'institution': data['institution'],
            # author_id already carries the "A" prefix
            'openalex_url': f"https://openalex.org/{author_id}"
        })

    # Sort by composite score
    ranked_authors.sort(key=lambda x: x['composite_score'], reverse=True)
    return ranked_authors
# Sidebar label -> country code used for OpenAlex's authorships.countries
# filter; None means no country filter.
_COUNTRY_CHOICES = (
    ("All Countries", None),
    ("Australia", "AU"),
    ("Canada", "CA"),
    ("China", "CN"),
    ("France", "FR"),
    ("Germany", "DE"),
    ("India", "IN"),
    ("Japan", "JP"),
    ("United Kingdom", "GB"),
    ("United States", "US"),
)
COUNTRIES = dict(_COUNTRY_CHOICES)
def _build_author_csv(ranked_authors, include_validation):
    """Return the author ranking as CSV text, one row per author in rank order.

    Args:
        ranked_authors: Author dicts as returned by rank_authors.
        include_validation: If True, add a 'Body of Work Validation Score'
            column just before the 'OpenAlex URL' column.
    """
    import csv
    import io

    header = [
        'Rank', 'Name', 'Institution', 'H-Index', 'Total Citations',
        'Total Works', 'Relevant Papers', 'Avg Relevance Score', 'Composite Score', 'OpenAlex URL'
    ]
    if include_validation:
        header.insert(-1, 'Body of Work Validation Score')

    csv_buffer = io.StringIO()
    csv_writer = csv.writer(csv_buffer)
    csv_writer.writerow(header)
    for rank, author in enumerate(ranked_authors, 1):
        row = [
            rank,
            author['name'],
            author['institution'],
            author['h_index'],
            author['total_citations'],
            author['works_count'],
            author['num_relevant_papers'],
            f"{author['avg_relevance_score']:.4f}",
            f"{author['composite_score']:.4f}",
        ]
        if include_validation:
            row.append(f"{author['validation_score']:.4f}")
        row.append(author['openalex_url'])
        csv_writer.writerow(row)
    return csv_buffer.getvalue()


def main():
    """Render the app and, when Search is pressed, run the full pipeline.

    Steps: read the sidebar settings, then fetch papers from OpenAlex. Next,
    embed the query and the papers with the SPECTER model and show the most
    similar papers. Finally, rank their authors (optionally validating each
    author's body of work) and offer the ranking as a CSV download.
    """
    import datetime
    import re

    # Header
    st.title("πŸ”¬ OpenAlex Semantic Search")
    st.markdown("""
Search for research papers and discover top researchers using semantic similarity matching.
This tool uses **SPECTER** (Scientific Paper Embeddings using Citation-informed TransformERs),
a model specifically trained on scientific papers for better relevance matching.
""")

    # Sidebar configuration
    st.sidebar.header("βš™οΈ Search Configuration")

    # Search mode selection
    search_mode = st.sidebar.radio(
        "Search Mode",
        ["Quick Search", "Deep Search"],
        help="Quick: 50-100 papers (~30s) | Deep: 500-1,000 papers (2-5 min)"
    )
    deep_search = search_mode == "Deep Search"

    # Number of papers based on mode
    if search_mode == "Quick Search":
        num_papers = st.sidebar.slider(
            "Number of papers to analyze",
            min_value=20,
            max_value=100,
            value=50,
            step=10,
            help="More papers = more comprehensive but slower"
        )
    else:  # Deep Search - limited to 1000 to prevent storage issues
        num_papers = st.sidebar.slider(
            "Number of papers to analyze",
            min_value=100,
            max_value=1000,
            value=500,
            step=100,
            help="⚠️ Limited to 1000 papers to prevent storage issues. Deep search takes 2-5 minutes."
        )

    # Country filter
    selected_country = st.sidebar.selectbox(
        "Filter by author country (optional)",
        options=list(COUNTRIES.keys()),
        help="Only include papers where at least one author is from this country"
    )
    country_code = COUNTRIES[selected_country]

    # Year range filter; the upper bound follows the current calendar year
    st.sidebar.subheader("πŸ“… Year Range")
    current_year = datetime.date.today().year
    use_year_filter = st.sidebar.checkbox(
        "Limit by publication year",
        value=False,
        help="Filter papers by publication year range"
    )
    if use_year_filter:
        year_col1, year_col2 = st.sidebar.columns(2)
        with year_col1:
            year_min = st.number_input(
                "From",
                min_value=1900,
                max_value=current_year,
                value=2015,
                step=1
            )
        with year_col2:
            year_max = st.number_input(
                "To",
                min_value=1900,
                max_value=current_year,
                value=current_year,
                step=1
            )
    else:
        year_min = None
        year_max = None

    # Full-text search option
    use_fulltext = st.sidebar.checkbox(
        "Include full text (when available)",
        value=False,
        help="Search within full paper text (not just title/abstract). ~10-15% of papers have full text available. Slightly slower."
    )

    # Author validation
    st.sidebar.subheader("πŸ‘€ Author Validation")
    validate_authors = st.sidebar.checkbox(
        "Validate authors' body of work",
        value=True,
        help="Check each author's recent papers to confirm they're actually working in this area. More accurate but slower."
    )
    if validate_authors:
        validation_threshold = st.sidebar.slider(
            "Validation threshold",
            min_value=0.15,
            max_value=0.50,
            value=0.25,
            step=0.05,
            help="Minimum average similarity score for author's works. Higher = stricter filter."
        )
    else:
        validation_threshold = 0.25

    # Minimum papers per author
    min_papers_per_author = st.sidebar.slider(
        "Minimum papers per author",
        min_value=1,
        max_value=5,
        value=2,
        help="Filters out authors who appear in fewer than N papers"
    )

    # Display settings
    st.sidebar.header("πŸ“Š Display Settings")
    top_papers_display = st.sidebar.slider("Number of top papers to show", 5, 50, 10)
    top_authors_display = st.sidebar.slider("Number of top authors to show", 5, 50, 10)

    # Storage usage info
    st.sidebar.markdown("---")
    st.sidebar.info("πŸ’Ύ Cache limited to prevent storage issues:\n- Max 50 searches stored\n- Max 200 authors cached\n- Max 1000 papers in Deep Search")

    # Main search interface
    st.header("πŸ” Search Query")
    query = st.text_input(
        "Enter your search query:",
        placeholder="e.g., 'graph neural networks for protein structure prediction'",
        help="Enter keywords or a description of what you're looking for"
    )
    search_button = st.button("πŸ” Search", type="primary")

    # Surrounding whitespace is ignored, so a blank query does nothing.
    query = query.strip()
    if not (search_button and query):
        return

    if use_year_filter and year_min > year_max:
        st.error("The 'From' year must not be later than the 'To' year.")
        return

    # Display search parameters
    year_range_text = f"Years: **{year_min}-{year_max}**" if use_year_filter else "Years: **All**"
    validation_text = f"Validation: **On (threshold {validation_threshold})**" if validate_authors else "Validation: **Off**"
    st.info(f"πŸ” Searching: **{query}** | Mode: **{search_mode}** | Papers: **{num_papers}** | {year_range_text} | Country: **{selected_country}** | Full-text: **{'Yes' if use_fulltext else 'No'}** | {validation_text} | Min papers/author: **{min_papers_per_author}**")

    # Load model
    with st.spinner("Loading semantic model..."):
        model = load_model()

    # Search papers (the progress bar is only shown for Deep Search)
    progress_bar = None
    if deep_search:
        progress_text = f"πŸ” Deep search in progress: Fetching up to {num_papers} papers from OpenAlex..."
        progress_bar = st.progress(0, text=progress_text)
    year_filter_text = f" from {year_min}-{year_max}" if use_year_filter else ""
    with st.spinner(f"Searching OpenAlex for papers about '{query}'{year_filter_text}{' from ' + selected_country if country_code else ''}{' (including full text)' if use_fulltext else ''}..."):
        papers = search_openalex_papers(query, num_papers, country_code, use_fulltext, year_min, year_max)
    if progress_bar is not None:
        progress_bar.progress(33, text="πŸ“„ Papers fetched! Now generating embeddings...")

    if not papers:
        if progress_bar is not None:
            progress_bar.empty()
        st.warning("No papers found. Try different search terms.")
        return

    st.success(f"Found {len(papers)} papers!")

    # Show debug info in expander
    with st.expander("πŸ” Search Details", expanded=False):
        st.write(f"**Search Mode:** {search_mode}")
        st.write(f"**Query:** {query}")
        st.write(f"**Full-text search:** {'Enabled' if use_fulltext else 'Disabled'}")
        st.write(f"**Year range:** {year_min}-{year_max}" if use_year_filter else "**Year range:** All years")
        st.write(f"**Papers requested:** {num_papers}")
        st.write(f"**Papers fetched:** {len(papers)}")
        st.write(f"**Country filter:** {selected_country} ({country_code or 'None'})")
        st.write(f"**Author validation:** {'Enabled (threshold: ' + str(validation_threshold) + ')' if validate_authors else 'Disabled'}")
        # display_name may be null in the API response
        st.write(f"**First paper:** {(papers[0].get('display_name') or 'N/A')[:100]}...")
        st.write(f"**Last paper:** {(papers[-1].get('display_name') or 'N/A')[:100]}...")

    # Prepare papers for semantic search
    if progress_bar is not None:
        progress_bar.progress(50, text="🧠 Generating semantic embeddings...")
    with st.spinner("Analyzing papers with semantic search..."):
        paper_texts = []
        valid_papers = []
        for paper in papers:
            title = paper.get('display_name', '') or paper.get('title', '')
            abstract = reconstruct_abstract(paper.get('abstract_inverted_index', {}))
            # Title appears twice so it weighs more than the abstract
            text = f"{title} {title} {abstract}"
            if text.strip():
                paper_texts.append(text)
                valid_papers.append(paper)

        if not paper_texts:
            if progress_bar is not None:
                progress_bar.empty()
            st.error("No valid paper content found.")
            return

        # Generate embeddings
        query_embedding = model.encode(query, convert_to_tensor=False)
        if progress_bar is not None:
            progress_bar.progress(66, text=f"πŸ”’ Computing similarity for {len(paper_texts)} papers...")
        paper_embeddings = model.encode(paper_texts, convert_to_tensor=False, show_progress_bar=False)

        # Rank papers by similarity, most similar first
        similarities = calculate_semantic_similarity(query_embedding, paper_embeddings)
        sorted_indices = np.argsort(similarities)[::-1]
        sorted_papers = [valid_papers[i] for i in sorted_indices]
        sorted_scores = [similarities[i] for i in sorted_indices]

    if progress_bar is not None:
        progress_bar.progress(100, text="βœ… Complete!")
        time.sleep(0.5)
        progress_bar.empty()

    # Display top papers
    st.header(f"πŸ“„ Top {top_papers_display} Most Relevant Papers")
    for idx, (paper, score) in enumerate(zip(sorted_papers[:top_papers_display], sorted_scores[:top_papers_display])):
        with st.expander(f"**{idx+1}. {paper.get('display_name') or 'Untitled'}** (Relevance: {score:.3f})"):
            col1, col2 = st.columns([3, 1])
            with col1:
                abstract = reconstruct_abstract(paper.get('abstract_inverted_index', {}))
                if abstract:
                    st.markdown(f"**Abstract:** {abstract[:500]}{'...' if len(abstract) > 500 else ''}")
                else:
                    st.markdown("*No abstract available*")
                # Authors (names may be null in the API response)
                authors = [
                    (a.get('author') or {}).get('display_name') or 'Unknown'
                    for a in (paper.get('authorships') or [])
                ]
                if authors:
                    st.markdown(f"**Authors:** {', '.join(authors[:5])}{'...' if len(authors) > 5 else ''}")
            with col2:
                st.metric("Year", paper.get('publication_year', 'N/A'))
                st.metric("Citations", paper.get('cited_by_count', 0))
                paper_id = (paper.get('id') or '').split('/')[-1]
                if paper_id:
                    st.markdown(f"[View on OpenAlex](https://openalex.org/{paper_id})")

    # Rank authors
    st.header(f"πŸ‘¨β€πŸ”¬ Top {top_authors_display} Researchers")
    ranked_authors = rank_authors(
        sorted_papers,
        sorted_scores,
        model,
        query_embedding,
        min_papers=min_papers_per_author,
        validate_authors=validate_authors,
        validation_threshold=validation_threshold
    )
    if not ranked_authors:
        st.warning(f"No authors found with at least {min_papers_per_author} relevant papers.")
        return

    # Display authors
    st.markdown(f"Found {len(ranked_authors)} researchers with at least {min_papers_per_author} relevant papers.")
    for idx, author in enumerate(ranked_authors[:top_authors_display], 1):
        with st.container():
            col1, col2, col3, col4 = st.columns([3, 1, 1, 1])
            with col1:
                st.markdown(f"**{idx}. [{author['name']}]({author['openalex_url']})**")
                if author['institution']:
                    st.caption(author['institution'])
            with col2:
                st.metric("H-Index", author['h_index'])
            with col3:
                st.metric("Citations", f"{author['total_citations']:,}")
            with col4:
                if validate_authors:
                    st.metric("Body Relevance", f"{author['validation_score']:.3f}")
                else:
                    st.metric("Relevance", f"{author['avg_relevance_score']:.3f}")
            caption_text = f"Total works: {author['works_count']} | Relevant papers: {author['num_relevant_papers']}"
            if validate_authors:
                caption_text += f" | Paper relevance: {author['avg_relevance_score']:.3f}"
            st.caption(caption_text)
            st.divider()

    # Download results
    st.header("πŸ“₯ Download Results")
    csv_data = _build_author_csv(ranked_authors, validate_authors)
    # Keep the file name safe: every character other than word characters
    # and hyphens (spaces, slashes, ...) becomes "_".
    file_stub = re.sub(r"[^\w-]", "_", query)[:30]
    st.download_button(
        label="Download Author Rankings (CSV)",
        data=csv_data,
        file_name=f"openalex_authors_{file_stub}.csv",
        mime="text/csv"
    )


if __name__ == "__main__":
    main()