import os
import pickle
import json
import time
import datetime
from pathlib import Path
from xml.etree import ElementTree

import requests
import warnings
from datasets import load_dataset_builder
from huggingface_hub import CommitScheduler, HfApi, hf_hub_download

warnings.filterwarnings("ignore")
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from utils import *
import thread6

MAX_DAILY_PAPER = 200
DAY_TIME = 60 * 60 * 24      # one day, in seconds (sleep interval for the daily threads)
DAY_TIME_MIN = 60 * 24       # one day, in minutes (commit interval for the CommitScheduler)

DATA_REPO_ID = "cmulgy/ArxivCopilot_data"
READ_WRITE_TOKEN = os.environ['READ_WRITE']
api = HfApi(token=READ_WRITE_TOKEN)

DATASET_DIR = Path(".")
DATASET_DIR.mkdir(parents=True, exist_ok=True)

scheduler = CommitScheduler(
    repo_id=DATA_REPO_ID,
    repo_type="dataset",
    folder_path=DATASET_DIR,
    path_in_repo=".",
    hf_api=api,
    every=DAY_TIME_MIN,
)
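
# Records a user's A/B feedback for a query, persists it to feedback.json, and,
# if the preferred answer looks informative and sufficiently novel (similarity
# below 0.85 against every stored thought embedding), appends it to the agent's
# thought store and thought embeddings. response_verify, get_bert_embedding and
# calculate_similarity come from utils; their exact behavior is assumed here.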
def feedback_thought(input_ls):  # preload
    agent, query, ansA, ansB, feedbackA, feedbackB = input_ls
    filename_thought = agent.thought_path
    filename = agent.feedback_path

    date = agent.today
    json_data = agent.feedback
    json_data_thought = agent.thought

    if date in json_data:
        if query not in json_data[date]:
            json_data[date][query] = {}
    else:
        json_data[date] = {}
        json_data[date][query] = {}
    if date not in json_data_thought:
        json_data_thought[date] = []

    json_data[date][query]["answerA"] = ansA
    json_data[date][query]["feedbackA"] = feedbackA
    json_data[date][query]["answerB"] = ansB
    json_data[date][query]["feedbackB"] = feedbackB
    with scheduler.lock:
        with open(filename, "w") as f:
            json.dump(json_data, f)

    preferred_ans = ""
    if feedbackA == 1:
        new_knowledge = response_verify([query], [ansA], verify=False)
        preferred_ans = ansA
    else:
        new_knowledge = response_verify([query], [ansB], verify=False)
        preferred_ans = ansB

    if 'idk' not in new_knowledge[0]:
        new_knowledge_embedding = get_bert_embedding(new_knowledge)
        thought_embedding_all = []
        for k in agent.thought_embedding.keys():
            thought_embedding_all.extend(agent.thought_embedding[k])
        similarity = calculate_similarity(thought_embedding_all, new_knowledge_embedding[0])
        similarity_values = [s.item() for s in similarity]  # Convert each tensor to a scalar

        # Only keep the new thought if it is not too similar to anything already stored.
        if all(s < 0.85 for s in similarity_values):
            tem_thought = query + preferred_ans
            json_data_thought[date].append(tem_thought)
            if date not in agent.thought_embedding:
                # Start a new bucket for today without discarding embeddings from earlier days.
                agent.thought_embedding[date] = [get_bert_embedding([tem_thought])[0]]
            else:
                agent.thought_embedding[date].append(get_bert_embedding([tem_thought])[0])

            with scheduler.lock:
                with open(filename_thought, "w") as f:
                    json.dump(json_data_thought, f)
                with open(agent.thought_embedding_path, "wb") as f:
                    pickle.dump(agent.thought_embedding, f)
def dailyDownload(agent_ls):
    agent = agent_ls[0]
    while True:
        time.sleep(DAY_TIME)

        data_collector = []
        keywords = dict()
        keywords["Machine Learning"] = "Machine Learning"

        for topic, keyword in keywords.items():
            data, agent.newest_day = get_daily_papers(topic, query=keyword, max_results=MAX_DAILY_PAPER)
            data_collector.append(data)

        json_file = agent.dataset_path
        update_file = update_json_file(json_file, data_collector, scheduler)

        time_chunks_embed = {}
        for data in data_collector:
            for date in data.keys():
                papers = data[date]['abstract']
                papers_embedding = get_bert_embedding(papers)
                time_chunks_embed[date.strftime("%m/%d/%Y")] = papers_embedding
        update_paper_file = update_pickle_file(agent.embedding_path, time_chunks_embed, scheduler)

        agent.paper = update_file
        agent.paper_embedding = update_paper_file
        print("Today is " + agent.newest_day.strftime("%m/%d/%Y"))
def dailySave(agent_ls):
    agent = agent_ls[0]
    while True:
        time.sleep(DAY_TIME)
        with scheduler.lock:
            with open(agent.trend_idea_path, "w") as f_:
                json.dump(agent.trend_idea, f_)
            with open(agent.thought_path, "w") as f_:
                json.dump(agent.thought, f_)
            with open(agent.thought_embedding_path, "wb") as f:
                pickle.dump(agent.thought_embedding, f)
            with open(agent.profile_path, "w") as f:
                json.dump(agent.profile, f)
            with open(agent.comment_path, "w") as f:
                json.dump(agent.comment, f)
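
# Core backend object for the Space: owns the cached papers, embeddings,
# thoughts, profiles and comments, and exposes the methods the UI calls
# (get_profile, select_date, response, update_feedback_thought, update_comment).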
class ArxivAgent:
    def __init__(self):
        self.dataset_path = DATASET_DIR / "dataset/paper.json"
        self.thought_path = DATASET_DIR / "dataset/thought.json"
        self.trend_idea_path = DATASET_DIR / "dataset/trend_idea.json"
        self.profile_path = DATASET_DIR / "dataset/profile.json"
        self.comment_path = DATASET_DIR / "dataset/comment.json"

        self.embedding_path = DATASET_DIR / "dataset/paper_embedding.pkl"
        self.thought_embedding_path = DATASET_DIR / "dataset/thought_embedding.pkl"

        self.feedback_path = DATASET_DIR / "dataset/feedback.json"

        self.today = datetime.datetime.now().strftime("%m/%d/%Y")
        self.newest_day = ""

        self.load_cache()
        self.download()
        try:
            thread6.run_threaded(dailyDownload, [self])
            thread6.run_threaded(dailySave, [self])
        except Exception:
            print("Error: unable to start thread")

    def edit_profile(self, profile, author_name):
        self.profile[author_name] = profile
        return "Successfully edited profile!"

    def get_profile(self, author_name):
        if author_name == "":
            return None
        profile = self.get_arxiv_data_by_author(author_name)
        return profile
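
    # Return (trend, reference, idea) for the given profile over a time window:
    # "day" uses only the newest crawl date, "week" the last 7 days, anything
    # else the full paper cache. Results are memoized in self.trend_idea per
    # (profile, newest paper date, method); on a cache miss the summaries are
    # produced by summarize_research_field / generate_ideas from utils.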
    def select_date(self, method, profile_input):
        today = self.newest_day
        chunk_embedding_date = {}
        paper_by_date = {}

        if method == "day":
            offset_day = today
            str_day = offset_day.strftime("%m/%d/%Y")
            if str_day in self.paper:
                paper_by_date[str_day] = self.paper[str_day]
                chunk_embedding_date[str_day] = self.paper_embedding[str_day]
        elif method == "week":
            for i in range(7):
                offset_day = today - datetime.timedelta(days=i)
                str_day = offset_day.strftime("%m/%d/%Y")
                if str_day in self.paper:
                    paper_by_date[str_day] = self.paper[str_day]
                    chunk_embedding_date[str_day] = self.paper_embedding[str_day]
        else:
            paper_by_date = self.paper
            chunk_embedding_date = self.paper_embedding

        dataset = paper_by_date
        data_chunk_embedding = chunk_embedding_date
        profile = profile_input
        key_update = list(self.paper.keys())[-1]

        isQuery = False
        if profile in self.trend_idea:
            if key_update in self.trend_idea[profile]:
                if method in self.trend_idea[profile][key_update]:
                    trend = self.trend_idea[profile][key_update][method]["trend"]
                    reference = self.trend_idea[profile][key_update][method]["reference"]
                    idea = self.trend_idea[profile][key_update][method]["idea"]
                    isQuery = True

        if not isQuery:
            trend, paper_link = summarize_research_field(profile, "Machine Learning", dataset, data_chunk_embedding)  # trend
            reference = papertitleAndLink(paper_link)
            idea = generate_ideas(trend)  # idea

            if profile in self.trend_idea:
                if key_update in self.trend_idea[profile]:
                    if method not in self.trend_idea[profile][key_update]:
                        self.trend_idea[profile][key_update][method] = {}
                else:
                    self.trend_idea[profile][key_update] = {}
                    self.trend_idea[profile][key_update][method] = {}
            else:
                self.trend_idea[profile] = {}
                self.trend_idea[profile][key_update] = {}
                self.trend_idea[profile][key_update][method] = {}

            self.trend_idea[profile][key_update][method]["trend"] = trend
            self.trend_idea[profile][key_update][method]["reference"] = reference
            self.trend_idea[profile][key_update][method]["idea"] = idea

            if key_update not in self.thought:
                self.thought[key_update] = []
            if key_update not in self.thought_embedding:
                self.thought_embedding[key_update] = []

            self.thought[key_update].append(trend[0])
            self.thought_embedding[key_update].append(get_bert_embedding([trend])[0])
            self.thought[key_update].append(idea[0])
            self.thought_embedding[key_update].append(get_bert_embedding([idea])[0])

        return trend, reference, idea
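
    # Answer a user query twice: once with retrieval over papers plus stored
    # thoughts, once over papers only, so the two answers can be compared
    # (e.g. for A/B feedback). Generation is delegated to
    # get_response_through_LLM_answer from utils.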
    def response(self, data, profile_input):
        query = [data]
        profile = profile_input

        query_embedding = get_bert_embedding(query)
        retrieve_text, retrieve_text_org = self.generate_pair_retrieve_text(query_embedding)

        context, context_org = [retrieve_text], [retrieve_text_org]
        answer_l = get_response_through_LLM_answer(query, context, profile)
        answer_l_org = get_response_through_LLM_answer(query, context_org, profile)

        return answer_l, answer_l_org
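
    # Build the two retrieval contexts for response(): nearest-neighbor search
    # (neiborhood_search from utils, top 10) over paper abstracts plus thoughts,
    # and over paper abstracts alone.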
    def generate_pair_retrieve_text(self, query_embedding):
        # Access dataset
        dataset = self.paper
        thought = self.thought

        text_chunk_l = []
        chunks_embedding_text_all = []

        text_org_chunk_l = []
        chunks_org_embedding_text_all = []

        # Include all text chunks and their embeddings
        for k in dataset.keys():
            text_chunk_l.extend(dataset[k]['abstract'])
            chunks_embedding_text_all.extend(self.paper_embedding[k])

            text_org_chunk_l.extend(dataset[k]['abstract'])
            chunks_org_embedding_text_all.extend(self.paper_embedding[k])

        # Include thoughts if not excluded
        for k in thought.keys():
            if k in self.thought_embedding.keys():
                text_chunk_l.extend(thought[k])
                chunks_embedding_text_all.extend(self.thought_embedding[k])

        # Compile retrieved text (papers + thoughts)
        neib_all = neiborhood_search(chunks_embedding_text_all, query_embedding, num=10)
        neib_all = neib_all.reshape(-1)
        retrieve_text = ''.join([text_chunk_l[i] for i in neib_all])

        # Compile retrieved text (papers only)
        neib_all = neiborhood_search(chunks_org_embedding_text_all, query_embedding, num=10)
        neib_all = neib_all.reshape(-1)
        retrieve_text_org = ''.join([text_org_chunk_l[i] for i in neib_all])

        return retrieve_text, retrieve_text_org
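
    # Initial crawl: fetch the latest papers, seed dataset/paper.json and
    # dataset/paper_embedding.pkl from the Hub copy when available (falling
    # back to empty local files), then merge in today's crawl.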
    def download(self):
        data_collector = []
        keywords = dict()
        keywords["Machine Learning"] = "Machine Learning"

        for topic, keyword in keywords.items():
            data, self.newest_day = get_daily_papers(topic, query=keyword, max_results=MAX_DAILY_PAPER)
            data_collector.append(data)

        json_file = self.dataset_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/paper.json", local_dir=".", repo_type="dataset")
        except Exception:
            # No cached copy on the Hub yet: create an empty local file.
            with open(json_file, 'w') as a:
                print(json_file)
        update_file = update_json_file(json_file, data_collector, scheduler)

        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/paper_embedding.pkl", local_dir=".", repo_type="dataset")
        except Exception:
            with open(self.embedding_path, 'wb') as a:
                print(self.embedding_path)
        time_chunks_embed = {}
        for data in data_collector:
            for date in data.keys():
                papers = data[date]['abstract']
                papers_embedding = get_bert_embedding(papers)
                time_chunks_embed[date.strftime("%m/%d/%Y")] = papers_embedding
        update_paper_file = update_pickle_file(self.embedding_path, time_chunks_embed, scheduler)

        self.paper = update_file
        self.paper_embedding = update_paper_file
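
    # Restore feedback, trend/idea, profile, thought, thought-embedding and
    # comment caches from the dataset repo; any file that cannot be downloaded
    # or parsed is replaced by an empty local file and an empty dict.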
    def load_cache(self):
        filename = self.feedback_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/feedback.json", local_dir=".", repo_type="dataset")
            with open(filename, "rb") as f:
                content = f.read()
                if not content:
                    m = {}
                else:
                    m = json.loads(content)
        except Exception:
            with open(filename, mode='w', encoding='utf-8') as ff:
                m = {}
        self.feedback = m.copy()

        filename = self.trend_idea_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/trend_idea.json", local_dir=".", repo_type="dataset")
            with open(filename, "rb") as f:
                content = f.read()
                if not content:
                    m = {}
                else:
                    m = json.loads(content)
        except Exception:
            with open(filename, mode='w', encoding='utf-8') as ff:
                m = {}
        self.trend_idea = m.copy()

        filename = self.profile_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/profile.json", local_dir=".", repo_type="dataset")
            with open(filename, "rb") as f:
                content = f.read()
                if not content:
                    m = {}
                else:
                    m = json.loads(content)
        except Exception:
            with open(filename, mode='w', encoding='utf-8') as ff:
                m = {}
        self.profile = m.copy()

        filename = self.thought_path
        filename_emb = self.thought_embedding_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/thought.json", local_dir=".", repo_type="dataset")
            with open(filename, "rb") as f:
                content = f.read()
                if not content:
                    m = {}
                else:
                    m = json.loads(content)
        except Exception:
            with open(filename, mode='w', encoding='utf-8') as ff:
                m = {}
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/thought_embedding.pkl", local_dir=".", repo_type="dataset")
            with open(filename_emb, "rb") as f:
                content = f.read()
                if not content:
                    m_emb = {}
                else:
                    m_emb = pickle.loads(content)
        except Exception:
            with open(filename_emb, mode='w', encoding='utf-8') as ff:
                m_emb = {}
        self.thought = m.copy()
        self.thought_embedding = m_emb.copy()

        filename = self.comment_path
        try:
            hf_hub_download(repo_id=DATA_REPO_ID, filename="dataset/comment.json", local_dir=".", repo_type="dataset")
            with open(filename, "r") as f:
                content = f.read()
                if not content:
                    m = {}
                else:
                    m = json.loads(content)
        except Exception:
            with open(filename, mode='w', encoding='utf-8') as ff:
                m = {}
        self.comment = m.copy()
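
    # Fire-and-forget wrapper: run feedback_thought on a background thread so
    # feedback handling does not block the UI.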
    def update_feedback_thought(self, query, ansA, ansB, feedbackA, feedbackB):
        try:
            thread6.run_threaded(feedback_thought, [self, query, ansA, ansB, feedbackA, feedbackB])
        except Exception:
            print("Error: unable to start thread")

    def update_comment(self, comment):
        date = datetime.datetime.now().strftime("%m/%d/%Y")

        json_data = self.comment
        if date not in json_data:
            json_data[date] = [comment]
        else:
            json_data[date].append(comment)

        return "Thanks for your comment!"
    def get_arxiv_data_by_author(self, author_name):
        if author_name in self.profile:
            return self.profile[author_name]

        author_query = author_name.replace(" ", "+")
        url = f"http://export.arxiv.org/api/query?search_query=au:{author_query}&start=0&max_results=300"  # Adjust max_results if needed

        response = requests.get(url)
        papers_list = []

        if response.status_code == 200:
            root = ElementTree.fromstring(response.content)
            entries = root.findall('{http://www.w3.org/2005/Atom}entry')

            total_papers = 0
            data_to_save = []

            papers_by_year = {}

            for entry in entries:
                title = entry.find('{http://www.w3.org/2005/Atom}title').text.strip()
                published = entry.find('{http://www.w3.org/2005/Atom}published').text.strip()
                abstract = entry.find('{http://www.w3.org/2005/Atom}summary').text.strip()
                authors_elements = entry.findall('{http://www.w3.org/2005/Atom}author')
                authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in authors_elements]
                link = entry.find('{http://www.w3.org/2005/Atom}id').text.strip()  # Get the paper link

                # Check if the specified author is exactly in the authors list
                if author_name in authors:
                    # Remove the specified author from the coauthors list for display
                    coauthors = [author for author in authors if author != author_name]
                    coauthors_str = ", ".join(coauthors)

                    papers_list.append({
                        "date": published,
                        "Title & Abstract": f"{title}; {abstract}",
                        "coauthors": coauthors_str,
                        "link": link  # Add the paper link to the dictionary
                    })

                authors_elements = entry.findall('{http://www.w3.org/2005/Atom}author')
                authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in authors_elements]

                if author_name in authors:
                    total_papers += 1
                    published_date = entry.find('{http://www.w3.org/2005/Atom}published').text.strip()
                    date_obj = datetime.datetime.strptime(published_date, '%Y-%m-%dT%H:%M:%SZ')

                    year = date_obj.year
                    if year not in papers_by_year:
                        papers_by_year[year] = []
                    papers_by_year[year].append(entry)

            if total_papers > 40:
                for cycle_start in range(min(papers_by_year), max(papers_by_year) + 1, 5):
                    cycle_end = cycle_start + 4
                    for year in range(cycle_start, cycle_end + 1):
                        if year in papers_by_year:
                            selected_papers = papers_by_year[year][:2]
                            for paper in selected_papers:
                                title = paper.find('{http://www.w3.org/2005/Atom}title').text.strip()
                                abstract = paper.find('{http://www.w3.org/2005/Atom}summary').text.strip()
                                authors_elements = paper.findall('{http://www.w3.org/2005/Atom}author')
                                co_authors = [author.find('{http://www.w3.org/2005/Atom}name').text for author in authors_elements if author.find('{http://www.w3.org/2005/Atom}name').text != author_name]

                                papers_list.append({
                                    "Author": author_name,
                                    "Title & Abstract": f"{title}; {abstract}",
                                    "Date Period": f"{year}",
                                    "Cycle": f"{cycle_start}-{cycle_end}",
                                    "Co_author": ", ".join(co_authors)
                                })

            # Trim the list to the 10 most recent papers
            papers_list = papers_list[:10]

            # Prepare the data dictionary with the author's name as a key
            personal_info = "; ".join([f"{details['Title & Abstract']}" for details in papers_list])
            info = summarize_research_direction(personal_info)

            self.profile[author_name] = info
            return self.profile[author_name]
        else:
            return None
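
# The sketch below is illustrative only and is not part of the Space itself: a
# minimal example of how the ArxivAgent API defined above might be exercised
# locally, assuming the READ_WRITE token, the utils helpers, and network access
# are all available. The author name and question are hypothetical placeholders.
if __name__ == "__main__":
    agent = ArxivAgent()

    # Look up (or build) a research profile for a hypothetical author.
    profile = agent.get_profile("Ada Lovelace")

    # Summarize the last week's papers for that profile.
    trend, reference, idea = agent.select_date("week", profile)
    print(trend, reference, idea)

    # Ask a question with and without the agent's accumulated thoughts.
    question = "What are recent advances in diffusion models?"
    ansA, ansB = agent.response(question, profile)

    # Record which answer the user preferred (1 = preferred, 0 = not preferred).
    agent.update_feedback_thought(question, ansA, ansB, 1, 0)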