Spaces:
Running
Running
| """ | |
| Generate and save the embeddings of a pre-defined list of icons. | |
| Compare them with keywords embeddings to find most relevant icons. | |
| """ | |
| from typing import Union | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from transformers import BertTokenizer, BertModel | |
| from ..global_config import GlobalConfig | |
| tokenizer = BertTokenizer.from_pretrained(GlobalConfig.TINY_BERT_MODEL) | |
| model = BertModel.from_pretrained(GlobalConfig.TINY_BERT_MODEL) | |
| def get_icons_list() -> list[str]: | |
| """ | |
| Get a list of available icons. | |
| Returns: | |
| The icons file names. | |
| """ | |
| items = GlobalConfig.ICONS_DIR.glob('*.png') | |
| items = [item.stem for item in items] | |
| return items | |
| def get_embeddings(texts: Union[str, list[str]]) -> np.ndarray: | |
| """ | |
| Generate embeddings for a list of texts using a pre-trained language model. | |
| Args: | |
| texts: A string or a list of strings to be converted into embeddings. | |
| Returns: | |
| A NumPy array containing the embeddings for the input texts. | |
| Raises: | |
| ValueError: If the input is not a string or a list of strings, or if any element | |
| in the list is not a string. | |
| Example usage: | |
| >>> keyword = 'neural network' | |
| >>> file_names = ['neural_network_icon.png', 'data_analysis_icon.png', 'machine_learning.png'] | |
| >>> keyword_embeddings = get_embeddings(keyword) | |
| >>> file_name_embeddings = get_embeddings(file_names) | |
| """ | |
| inputs = tokenizer(texts, return_tensors='pt', padding=True, max_length=128, truncation=True) | |
| outputs = model(**inputs) | |
| return outputs.last_hidden_state.mean(dim=1).detach().numpy() | |
| def save_icons_embeddings(): | |
| """ | |
| Generate and save the embeddings for the icon file names. | |
| """ | |
| file_names = get_icons_list() | |
| print(f'{len(file_names)} icon files available...') | |
| file_name_embeddings = get_embeddings(file_names) | |
| print(f'file_name_embeddings.shape: {file_name_embeddings.shape}') | |
| # Save embeddings to a file | |
| np.save(GlobalConfig.EMBEDDINGS_FILE_NAME, file_name_embeddings) | |
| np.save(GlobalConfig.ICONS_FILE_NAME, file_names) # Save file names for reference | |
| def load_saved_embeddings() -> tuple[np.ndarray, np.ndarray]: | |
| """ | |
| Load precomputed embeddings and icons file names. | |
| Returns: | |
| The embeddings and the icon file names. | |
| """ | |
| file_name_embeddings = np.load(GlobalConfig.EMBEDDINGS_FILE_NAME) | |
| file_names = np.load(GlobalConfig.ICONS_FILE_NAME) | |
| return file_name_embeddings, file_names | |
| def find_icons(keywords: list[str]) -> list[str]: | |
| """ | |
| Find relevant icon file names for a list of keywords. | |
| Args: | |
| keywords: The list of one or more keywords. | |
| Returns: | |
| A list of the file names relevant for each keyword. | |
| """ | |
| keyword_embeddings = get_embeddings(keywords) | |
| file_name_embeddings, file_names = load_saved_embeddings() | |
| # Compute similarity | |
| similarities = cosine_similarity(keyword_embeddings, file_name_embeddings) | |
| icon_files = file_names[np.argmax(similarities, axis=-1)] | |
| return icon_files | |
| def main(): | |
| """ | |
| Example usage. | |
| """ | |
| # Run this again if icons are to be added/removed | |
| save_icons_embeddings() | |
| keywords = [ | |
| 'deep learning', | |
| '', | |
| 'recycling', | |
| 'handshake', | |
| 'Ferry', | |
| 'rain drop', | |
| 'speech bubble', | |
| 'mental resilience', | |
| 'turmeric', | |
| 'Art', | |
| 'price tag', | |
| 'Oxygen', | |
| 'oxygen', | |
| 'Social Connection', | |
| 'Accomplishment', | |
| 'Python', | |
| 'XML', | |
| 'Handshake', | |
| ] | |
| icon_files = find_icons(keywords) | |
| print( | |
| f'The relevant icon files are:\n' | |
| f'{list(zip(keywords, icon_files))}' | |
| ) | |
| # BERT tiny: | |
| # [('deep learning', 'deep-learning'), ('', '123'), ('recycling', 'refinery'), | |
| # ('handshake', 'dash-circle'), ('Ferry', 'cart'), ('rain drop', 'bucket'), | |
| # ('speech bubble', 'globe'), ('mental resilience', 'exclamation-triangle'), | |
| # ('turmeric', 'kebab'), ('Art', 'display'), ('price tag', 'bug-fill'), | |
| # ('Oxygen', 'radioactive')] | |
| # BERT mini | |
| # [('deep learning', 'deep-learning'), ('', 'compass'), ('recycling', 'tools'), | |
| # ('handshake', 'bandaid'), ('Ferry', 'cart'), ('rain drop', 'trash'), | |
| # ('speech bubble', 'image'), ('mental resilience', 'recycle'), ('turmeric', 'linkedin'), | |
| # ('Art', 'book'), ('price tag', 'card-image'), ('Oxygen', 'radioactive')] | |
| # BERT small | |
| # [('deep learning', 'deep-learning'), ('', 'gem'), ('recycling', 'tools'), | |
| # ('handshake', 'handbag'), ('Ferry', 'truck'), ('rain drop', 'bucket'), | |
| # ('speech bubble', 'strategy'), ('mental resilience', 'deep-learning'), | |
| # ('turmeric', 'flower'), | |
| # ('Art', 'book'), ('price tag', 'hotdog'), ('Oxygen', 'radioactive')] | |
| if __name__ == '__main__': | |
| main() | |