Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -2,7 +2,7 @@ 
     | 
|
| 2 | 
         
             
            """
         
     | 
| 3 | 
         
             
            Credit to Derek Thomas, [email protected]
         
     | 
| 4 | 
         
             
            """
         
     | 
| 5 | 
         
            -
             
     | 
| 6 | 
         
             
            import subprocess
         
     | 
| 7 | 
         | 
| 8 | 
         
             
            # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
         
     | 
| 
         @@ -59,46 +59,68 @@ def bot(history, cross_encoder): 
     | 
|
| 59 | 
         
             
                     raise ValueError("Empty string was submitted")
         
     | 
| 60 | 
         | 
| 61 | 
         
             
                logger.warning('Retrieving documents...')
         
     | 
| 62 | 
         
            -
                # Retrieve documents relevant to query
         
     | 
| 63 | 
         
            -
                document_start = perf_counter()
         
     | 
| 64 | 
         
            -
             
     | 
| 65 | 
         
            -
                query_vec = retriever.encode(query)
         
     | 
| 66 | 
         
            -
                logger.warning(f'Finished query vec')
         
     | 
| 67 | 
         
            -
                doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
         
     | 
| 68 | 
         
            -
             
     | 
| 69 | 
         | 
| 70 | 
         
            -
             
     | 
| 71 | 
         
            -
                 
     | 
| 72 | 
         
            -
             
     | 
| 73 | 
         
            -
             
     | 
| 74 | 
         
            -
             
     | 
| 75 | 
         
            -
             
     | 
| 76 | 
         
            -
             
     | 
| 77 | 
         
            -
             
     | 
| 78 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 79 | 
         
             
                else:
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
            -
             
     | 
| 82 | 
         
            -
                sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
         
     | 
| 83 | 
         
            -
                logger.warning(f'Finished cross encoder {len(documents)}')
         
     | 
| 84 | 
         | 
| 85 | 
         
            -
             
     | 
| 86 | 
         
            -
             
     | 
| 87 | 
         
            -
             
     | 
| 88 | 
         
            -
                 
     | 
| 89 | 
         
            -
             
     | 
| 90 | 
         
            -
             
     | 
| 91 | 
         
            -
             
     | 
| 92 | 
         
            -
             
     | 
| 93 | 
         
            -
             
     | 
| 94 | 
         
            -
             
     | 
| 95 | 
         
            -
             
     | 
| 96 | 
         
            -
             
     | 
| 97 | 
         
            -
             
     | 
| 98 | 
         
            -
             
     | 
| 99 | 
         
            -
                     
     | 
| 100 | 
         
            -
             
     | 
| 101 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 102 | 
         | 
| 103 | 
         | 
| 104 | 
         
             
            with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
         
     | 
| 
         @@ -128,7 +150,7 @@ with gr.Blocks(theme='Insuz/SimpleIndigo') as demo: 
     | 
|
| 128 | 
         
             
                            )
         
     | 
| 129 | 
         
             
                    txt_btn = gr.Button(value="Submit text", scale=1)
         
     | 
| 130 | 
         | 
| 131 | 
         
            -
                cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker'], value='BGE reranker',label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy")
         
     | 
| 132 | 
         | 
| 133 | 
         
             
                prompt_html = gr.HTML()
         
     | 
| 134 | 
         
             
                # Turn off interactivity while generating if you click
         
     | 
| 
         | 
|
| 2 | 
         
             
            """
         
     | 
| 3 | 
         
             
            Credit to Derek Thomas, [email protected]
         
     | 
| 4 | 
         
             
            """
         
     | 
| 5 | 
         
            +
            from ragatouille import RAGPretrainedModel
         
     | 
| 6 | 
         
             
            import subprocess
         
     | 
| 7 | 
         | 
| 8 | 
         
             
            # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
         
     | 
| 
         | 
|
| 59 | 
         
             
                     raise ValueError("Empty string was submitted")
         
     | 
| 60 | 
         | 
| 61 | 
         
             
                logger.warning('Retrieving documents...')
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 62 | 
         | 
| 63 | 
         
            +
                # if COLBERT RAGATATOUILLE PROCEDURE  : 
         
     | 
| 64 | 
         
            +
                if cross_encoder=='ColBERT':
         
     | 
| 65 | 
         
            +
                    gr.Warning('Retrieving using ColBERT')
         
     | 
| 66 | 
         
            +
                    RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
         
     | 
| 67 | 
         
            +
                    RAG_db=RAG.from_index('.ragatouille/colbert/indexes/mockingbird')
         
     | 
| 68 | 
         
            +
                    documents_full=RAG_db.search(query)
         
     | 
| 69 | 
         
            +
                    
         
     | 
| 70 | 
         
            +
                    documents=[item['content'] for item in documents_full]
         
     | 
| 71 | 
         
            +
                    # Create Prompt
         
     | 
| 72 | 
         
            +
                    prompt = template.render(documents=documents, query=query)
         
     | 
| 73 | 
         
            +
                    prompt_html = template_html.render(documents=documents, query=query)
         
     | 
| 74 | 
         
            +
                
         
     | 
| 75 | 
         
            +
                    generate_fn = generate_hf
         
     | 
| 76 | 
         
            +
                
         
     | 
| 77 | 
         
            +
                    history[-1][1] = ""
         
     | 
| 78 | 
         
            +
                    for character in generate_fn(prompt, history[:-1]):
         
     | 
| 79 | 
         
            +
                        history[-1][1] = character
         
     | 
| 80 | 
         
            +
                        print('Final history is ',history)
         
     | 
| 81 | 
         
            +
                        yield history, prompt_html
         
     | 
| 82 | 
         
             
                else:
         
     | 
| 83 | 
         
            +
                    # Retrieve documents relevant to query
         
     | 
| 84 | 
         
            +
                    document_start = perf_counter()
         
     | 
| 
         | 
|
| 
         | 
|
| 85 | 
         | 
| 86 | 
         
            +
                    query_vec = retriever.encode(query)
         
     | 
| 87 | 
         
            +
                    logger.warning(f'Finished query vec')
         
     | 
| 88 | 
         
            +
                    doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
         
     | 
| 89 | 
         
            +
                
         
     | 
| 90 | 
         
            +
                    
         
     | 
| 91 | 
         
            +
                
         
     | 
| 92 | 
         
            +
                    logger.warning(f'Finished search')
         
     | 
| 93 | 
         
            +
                    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
         
     | 
| 94 | 
         
            +
                    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
         
     | 
| 95 | 
         
            +
                    logger.warning(f'start cross encoder {len(documents)}')
         
     | 
| 96 | 
         
            +
                    # Retrieve documents relevant to query
         
     | 
| 97 | 
         
            +
                    query_doc_pair = [[query, doc] for doc in documents]
         
     | 
| 98 | 
         
            +
                    if cross_encoder=='MiniLM-L6v2' :
         
     | 
| 99 | 
         
            +
                           cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') 
         
     | 
| 100 | 
         
            +
                    elif cross_encoder=='BGE reranker':
         
     | 
| 101 | 
         
            +
                           cross_encoder = CrossEncoder('BAAI/bge-reranker-base')
         
     | 
| 102 | 
         
            +
                    
         
     | 
| 103 | 
         
            +
                    cross_scores = cross_encoder.predict(query_doc_pair)
         
     | 
| 104 | 
         
            +
                    sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
         
     | 
| 105 | 
         
            +
                    logger.warning(f'Finished cross encoder {len(documents)}')
         
     | 
| 106 | 
         
            +
                    
         
     | 
| 107 | 
         
            +
                    documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
         
     | 
| 108 | 
         
            +
                    logger.warning(f'num documents {len(documents)}')
         
     | 
| 109 | 
         
            +
                
         
     | 
| 110 | 
         
            +
                    document_time = perf_counter() - document_start
         
     | 
| 111 | 
         
            +
                    logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
         
     | 
| 112 | 
         
            +
                
         
     | 
| 113 | 
         
            +
                    # Create Prompt
         
     | 
| 114 | 
         
            +
                    prompt = template.render(documents=documents, query=query)
         
     | 
| 115 | 
         
            +
                    prompt_html = template_html.render(documents=documents, query=query)
         
     | 
| 116 | 
         
            +
                
         
     | 
| 117 | 
         
            +
                    generate_fn = generate_hf
         
     | 
| 118 | 
         
            +
                
         
     | 
| 119 | 
         
            +
                    history[-1][1] = ""
         
     | 
| 120 | 
         
            +
                    for character in generate_fn(prompt, history[:-1]):
         
     | 
| 121 | 
         
            +
                        history[-1][1] = character
         
     | 
| 122 | 
         
            +
                        print('Final history is ',history)
         
     | 
| 123 | 
         
            +
                        yield history, prompt_html
         
     | 
| 124 | 
         | 
| 125 | 
         | 
| 126 | 
         
             
            with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
         
     | 
| 
         | 
|
| 150 | 
         
             
                            )
         
     | 
| 151 | 
         
             
                    txt_btn = gr.Button(value="Submit text", scale=1)
         
     | 
| 152 | 
         | 
| 153 | 
         
            +
                cross_encoder = gr.Radio(choices=['MiniLM-L6v2','BGE reranker','ColBERT'], value='BGE reranker',label="Embeddings", info="Choose MiniLM for Speed, BGE reranker for accuracy,ColBERT for both")
         
     | 
| 154 | 
         | 
| 155 | 
         
             
                prompt_html = gr.HTML()
         
     | 
| 156 | 
         
             
                # Turn off interactivity while generating if you click
         
     |