Spaces:
Sleeping
Sleeping
| import time | |
| s = time.time() | |
| import os | |
| import datetime | |
| import faiss | |
| import streamlit as st | |
| import feedparser | |
| import urllib | |
| import cloudpickle as cp | |
| import pickle | |
| from urllib.request import urlopen | |
| from summa import summarizer | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import requests | |
| import json | |
| from scipy import ndimage | |
| from langchain_openai import AzureOpenAIEmbeddings | |
| # from langchain.llms import OpenAI | |
| from langchain_community.llms import OpenAI | |
| from langchain_openai import AzureChatOpenAI | |
| from fns import * | |
# --- Page header -------------------------------------------------------------
st.image('local_files/synth_logo.png')
st.markdown("")  # empty markdown element between the logo and the inputs

# --- User inputs -------------------------------------------------------------
query = st.text_input(
    'Ask me anything:',
    value="What causes galaxy quenching at high redshifts?",
)
top_k = st.slider(
    'How many papers should I show?',
    min_value=1,
    max_value=30,
    value=6,
)
# This page never supplies an arXiv id; None is passed through to retrieve().
arxiv_id = None

# --- Retrieval ---------------------------------------------------------------
# NOTE(review): `retrieval_system` is read from session state but is not stored
# anywhere in the visible code -- presumably set up by `fns` or another page;
# confirm it is always initialised before this script runs.
retrieval_system = st.session_state.retrieval_system
results = retrieval_system.retrieve(query, arxiv_id, top_k)
# Corpus metadata columns pulled from the dataset kept in session state.
# The code below indexes every column with the same row position.
aids = st.session_state.dataset['id']
titles = st.session_state.dataset['title']
auths = st.session_state.dataset['author']
bibcodes = st.session_state.dataset['bibcode']
all_keywords = st.session_state.dataset['keyword_search']
allyrs = st.session_state.dataset['year']

# Map each retrieved id to its row position in the corpus.
# The retriever may hand back fewer than `top_k` ids, so the loop is bounded by
# what actually came back instead of assuming exactly `top_k` entries.
# dtype=int keeps an empty result usable as an index array further down.
ret_indices = np.array(
    [aids.index(results[i]) for i in range(min(top_k, len(results)))],
    dtype=int,
)

# Stored years are two-digit offsets: values below 50 are mapped to 20xx,
# everything else to 19xx.
yrs = [
    allyrs[row] + (2000 if allyrs[row] < 50 else 1900)
    for row in ret_indices
]

# Display strings for each hit: the first title entry, "<first author> et al.
# <year>", and a markdown link to the paper's ADS abstract page.
print_titles = [titles[row][0] for row in ret_indices]
print_auths = [
    auths[row][0] + ' et al. ' + str(year)
    for row, year in zip(ret_indices, yrs)
]
print_links = [
    '[' + bibcodes[row] + '](https://ui.adsabs.harvard.edu/abs/'
    + bibcodes[row] + '/abstract)'
    for row in ret_indices
]
# Ranked list of the retrieved papers: a numbered title followed by a line with
# the author/year string and the ADS link.
st.divider()
st.header('top-k papers:')
for rank, (paper_title, byline, ads_link) in enumerate(
    zip(print_titles, print_auths, print_links), start=1
):
    st.subheader(f"{rank}. {paper_title}")
    st.write(f"{byline} {ads_link}")
st.divider()
# Literature map: the retrieved papers drawn on top of a background image,
# together with per-cluster keyword labels.
st.header('top-k papers in context:')

# `get_keywords` and `load_umapcoords` are not defined in the visible code --
# presumably provided by `from fns import *`; verify their contracts there.
gtkws = get_keywords(query, ret_indices, all_keywords)
umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')

fig = plt.figure(figsize=(12*1.8*1.2, 9*2.*1.2))
im = plt.imread('local_files/astro_worldmap.png')
plt.imshow(im)

# Rescale the embedding onto the plot: column 1 becomes the horizontal axis
# (170..1750) and column 0 the vertical axis (30..1730, flipped).
# NOTE(review): the offsets/scales look tuned to the background image's pixel
# layout -- confirm if the image is ever replaced.
xax = umap[:, 1] - np.amin(umap[:, 1])
xax = xax / np.amax(xax)
xax = xax * 1580 + 170
yax = umap[:, 0] - np.amin(umap[:, 0])
yax = yax / np.amax(yax)
yax = (np.amax(yax) - yax) * 1700 + 30

# One keyword label per cluster, placed at the cluster's median position.
# NOTE(review): range(np.amax(clbls)) never visits the highest label value;
# if labels run 0..max and `all_kws` has an entry for `max`, this should be
# `np.amax(clbls) + 1` -- confirm against load_umapcoords before changing.
for i in range(np.amax(clbls)):
    clust_ids = np.arange(len(clbls))[clbls == i]
    clust_centroid = (np.median(xax[clust_ids]), np.median(yax[clust_ids]))
    plt.text(
        clust_centroid[0], clust_centroid[1], all_kws[i],
        fontsize=9, ha="center", va="center",
        fontfamily='serif', color='w',
        bbox=dict(facecolor='k', edgecolor='none', boxstyle='round,pad=0.1', alpha=0.3),
    )

# Retrieved papers: red dots with a black rim; the first hit is re-drawn in
# white on top so it stands out. Skipped entirely when nothing was retrieved,
# because ret_indices[0] would otherwise raise.
if len(ret_indices) > 0:
    plt.scatter(xax[ret_indices], yax[ret_indices], c='k', s=300, zorder=100)
    plt.scatter(xax[ret_indices], yax[ret_indices], c='firebrick', s=100, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='k', s=300, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='w', s=100, zorder=101)

# Title, query and keyword text are positioned as fractions of the current
# axis limits.
tempx = plt.xlim()
tempy = plt.ylim()
plt.text(0.012*tempx[1], (0.012+0.03)*tempy[0], 'The world of astronomy literature', fontsize=36, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.06)*tempy[0], 'Query: '+query, fontsize=18, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.08)*tempy[0], gtkws, fontsize=18, fontfamily='serif', va='top')
plt.axis('off')

st.pyplot(fig, transparent=True, bbox_inches='tight')
# Streamlit re-executes this script on every interaction; without an explicit
# close, pyplot keeps each rerun's figure alive and memory grows steadily.
plt.close(fig)