Commit 749f953
Parent(s): 4640e16
update utils
utils/db_utils.py CHANGED (+47, -5)
@@ -7,6 +7,8 @@ from utils.bridge_content_encoder import get_matched_entries
 from nltk.tokenize import word_tokenize
 from nltk import ngrams
 
+from whoosh.qparser import QueryParser
+
 def add_a_record(question, db_id):
     conn = sqlite3.connect('data/history/history.sqlite')
     cursor = conn.cursor()
@@ -97,15 +99,19 @@ def get_column_contents(column_name, table_name, cursor):
     return column_contents
 
 def get_matched_contents(question, searcher):
-    # coarse-grained matching between the input text and all contents in database
+    # Coarse-grained matching between the input text and all contents in database
     grams = obtain_n_grams(question, 4)
     hits = []
+
+    # Parse each n-gram query into a valid Whoosh query object
     for query in grams:
-        hits.extend(searcher.search(query, limit = 10))
+        query_parser = QueryParser("content", schema=searcher.schema)  # 'content' should match the field you are searching
+        parsed_query = query_parser.parse(query)  # Convert the query string into a Whoosh Query object
+        hits.extend(searcher.search(parsed_query, limit=10))  # Perform the search with the parsed query
 
     coarse_matched_contents = dict()
-    for i in range(len(hits)):
-        matched_result = json.loads(hits[i].raw)
+    for hit in hits:
+        matched_result = json.loads(hit.raw)
         # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
         tc_name = ".".join(matched_result["id"].split("-**-")[:2])
         if tc_name in coarse_matched_contents.keys():
@@ -116,7 +122,7 @@ def get_matched_contents(question, searcher):
 
     fine_matched_contents = dict()
     for tc_name, contents in coarse_matched_contents.items():
-        # fine-grained matching between the question and coarse matched contents
+        # Fine-grained matching between the question and coarse matched contents
         fm_contents = get_matched_entries(question, contents)
 
         if fm_contents is None:
@@ -132,6 +138,42 @@ def get_matched_contents(question, searcher):
 
     return fine_matched_contents
 
+# def get_matched_contents(question, searcher):
+#     # coarse-grained matching between the input text and all contents in database
+#     grams = obtain_n_grams(question, 4)
+#     hits = []
+#     for query in grams:
+#         hits.extend(searcher.search(query, limit = 10))
+
+#     coarse_matched_contents = dict()
+#     for i in range(len(hits)):
+#         matched_result = json.loads(hits[i].raw)
+#         # `tc_name` refers to column names like `table_name.column_name`, e.g., document_drafts.document_id
+#         tc_name = ".".join(matched_result["id"].split("-**-")[:2])
+#         if tc_name in coarse_matched_contents.keys():
+#             if matched_result["contents"] not in coarse_matched_contents[tc_name]:
+#                 coarse_matched_contents[tc_name].append(matched_result["contents"])
+#         else:
+#             coarse_matched_contents[tc_name] = [matched_result["contents"]]
+
+#     fine_matched_contents = dict()
+#     for tc_name, contents in coarse_matched_contents.items():
+#         # fine-grained matching between the question and coarse matched contents
+#         fm_contents = get_matched_entries(question, contents)
+
+#         if fm_contents is None:
+#             continue
+#         for _match_str, (field_value, _s_match_str, match_score, s_match_score, _match_size,) in fm_contents:
+#             if match_score < 0.9:
+#                 continue
+#             if tc_name in fine_matched_contents.keys():
+#                 if len(fine_matched_contents[tc_name]) < 25:
+#                     fine_matched_contents[tc_name].append(field_value.strip())
+#             else:
+#                 fine_matched_contents[tc_name] = [field_value.strip()]
+
+#     return fine_matched_contents
+
 def get_db_schema_sequence(schema):
     schema_sequence = "database schema :\n"
     for table in schema["schema_items"]:
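
The substantive change in get_matched_contents is the query path: the previous code passed each n-gram string straight to searcher.search, while the new code targets Whoosh, whose Searcher.search expects a parsed whoosh.query.Query object, hence the QueryParser("content", schema=searcher.schema) step. The sketch below shows that pattern end to end against a tiny stand-in index. The index construction is not part of this commit, so the index directory, the Schema with stored id/content fields, and the sample documents are assumptions made for illustration only; the "content" field name and the "-**-"-delimited id format are taken from the diff.

import os

from whoosh.fields import ID, TEXT, Schema
from whoosh.index import create_in
from whoosh.qparser import QueryParser

# Assumed schema: an "id" shaped like "table-**-column-**-rowid" plus the cell text,
# both stored so they can be read back from search hits.
schema = Schema(id=ID(stored=True), content=TEXT(stored=True))

os.makedirs("contents_index", exist_ok=True)
ix = create_in("contents_index", schema)

writer = ix.writer()
writer.add_document(id="document_drafts-**-document_id-**-0", content="1001")
writer.add_document(id="document_drafts-**-draft_details-**-1", content="initial project outline")
writer.commit()

with ix.searcher() as searcher:
    # Same pattern as the diff: parse the query string before searching,
    # because Whoosh's Searcher.search expects a Query object, not a str.
    parser = QueryParser("content", schema=searcher.schema)
    parsed_query = parser.parse("project outline")
    for hit in searcher.search(parsed_query, limit=10):
        # Stored fields come back via item access on the Hit object.
        print(hit["id"], hit["content"])

One Whoosh-specific detail: Hit objects expose stored fields through item access or hit.fields() rather than a Pyserini-style .raw attribute, so how the JSON payload read by json.loads(hit.raw) is attached to hits depends on how the Space builds its index.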
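
get_matched_contents also relies on obtain_n_grams(question, 4), which this diff does not show. Given the nltk imports visible in the context lines (word_tokenize and ngrams), a helper of that shape might look like the hypothetical sketch below; the repository's actual implementation may differ.

from nltk import ngrams
from nltk.tokenize import word_tokenize

def obtain_n_grams(text, max_n):
    # Hypothetical sketch: collect all 1- to max_n-grams of the tokenized
    # text as plain strings, which get_matched_contents then uses as queries.
    tokens = word_tokenize(text)
    all_grams = []
    for n in range(1, max_n + 1):
        for gram in ngrams(tokens, n):
            all_grams.append(" ".join(gram))
    return all_grams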
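
Finally, the tc_name derivation maps a hit id onto a table.column string, as the in-code comment's example suggests; a quick check (with the trailing row segment assumed for illustration):

matched_id = "document_drafts-**-document_id-**-42"  # trailing segment assumed for illustration
tc_name = ".".join(matched_id.split("-**-")[:2])
print(tc_name)  # -> document_drafts.document_id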