Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import crystal_toolkit.components as ctc | |
| import numpy as np | |
| import periodictable | |
| from dash import dcc, html | |
| from datasets import concatenate_datasets, load_dataset | |
| from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer | |
# Hugging Face access token taken from the environment (None when unset).
# In the visible code it is only referenced by a commented-out argument.
HF_TOKEN = os.getenv("HF_TOKEN")

# Default number of best-matching rows kept by search_materials.
top_k = 500
def get_dataset():
    """Load the train split of every LeMat-Bulk subset and concatenate them.

    Each subset of ``LeMaterial/LeMat-Bulk`` is loaded with the same explicit
    column list, its ``"train"`` split is taken, and the splits are joined in
    the order the subsets are listed below.

    Returns:
        The result of ``concatenate_datasets`` over the four train splits.
    """
    wanted_columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected",  # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct",  # future release
        # "band_gap_indirect",  # future release
        "dos_ef",
        # "charges",  # future release
        "functional",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]
    subset_names = (
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    )

    # Only the train split of each subset is kept.
    train_splits = [
        load_dataset(
            "LeMaterial/LeMat-Bulk",
            subset_name,
            # token=HF_TOKEN,
            columns=wanted_columns,
        )["train"]
        for subset_name in subset_names
    ]
    return concatenate_datasets(train_splits)
# Dataset column -> human-readable label. The insertion order of this dict is
# also the order of ``display_columns`` below.
# NOTE(review): presumably these drive the results table elsewhere in the
# app; the consumer is not visible in this file.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
display_columns = list(display_names)

# Global shared variables
# search_materials fills a dict of this shape (result position -> dataset row
# index) through its parameter of the same name; presumably this is the
# instance callers pass in.
mapping_table_idx_dataset_idx = {}
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build (or load from cache) the per-material element-composition index.

    Each row of the index is the element-count vector of one material
    (one column per entry of ``periodictable.elements``), scaled to fractions
    and then L2-normalised, stored as a ``scipy.sparse.csr_matrix``.

    Args:
        dataset: Object exposing ``select`` and ``select_columns(...).to_pandas()``.
            It is not touched when the cache is hit or ``empty_data`` is true.
        index_range: Optional indices passed to ``dataset.select`` to restrict
            the rows that are indexed.
        cache_path: Optional directory holding ``train_df.pkl`` and
            ``dataset_index.npz``. When both files exist they are loaded
            instead of rebuilding; otherwise the freshly built index is
            written there (the directory is created if needed).
        empty_data: When true, skip all work and return a placeholder.

    Returns:
        ``(dataset_index, immutable_id_to_idx)`` where ``immutable_id_to_idx``
        maps each ``immutable_id`` to its row label in the dataframe. With
        ``empty_data`` the result is ``(np.zeros((1, 1)), {})``.

    Raises:
        KeyError: If a site species is not a symbol known to ``periodictable``.
    """
    print("Building formula index")
    if empty_data:
        return np.zeros((1, 1)), {}

    # Function-scope imports keep the empty-data path free of these modules.
    import pickle

    from scipy.sparse import csr_matrix, load_npz, save_npz

    frame_path = index_path = None
    if cache_path is not None:
        frame_path = os.path.join(cache_path, "train_df.pkl")
        index_path = os.path.join(cache_path, "dataset_index.npz")

    # Both artefacts are required; a half-written cache triggers a rebuild
    # instead of failing inside load_npz.
    cache_hit = (
        frame_path is not None
        and os.path.exists(frame_path)
        and os.path.exists(index_path)
    )

    if cache_hit:
        # NOTE(security): unpickling runs arbitrary code; cache_path must be a
        # trusted location.
        with open(frame_path, "rb") as handle:
            train_df = pickle.load(handle)
        dataset_index = load_npz(index_path)
    else:
        use_dataset = dataset if index_range is None else dataset.select(index_range)
        train_df = use_dataset.select_columns(
            ["species_at_sites", "immutable_id", "functional"]
        ).to_pandas()

        import tqdm

        # Column position of every element symbol (full element list).
        element_column = {
            str(el.symbol): i for i, el in enumerate(periodictable.elements)
        }
        counts = np.zeros((len(train_df), len(element_column)))

        species_per_row = train_df["species_at_sites"].values
        for idx, species in tqdm.tqdm(
            enumerate(species_per_row), total=len(species_per_row)
        ):
            for el in species:
                counts[idx, element_column[el]] += 1

        # Rows without any site would divide by zero and turn into NaN, which
        # later sorts to the top of every search; keep them as all-zero rows.
        totals = np.sum(counts, axis=1)[:, None]
        totals[totals == 0] = 1
        fractions = counts / totals

        norms = np.linalg.norm(fractions, axis=1)[:, None]
        norms[norms == 0] = 1
        dataset_index = csr_matrix(fractions / norms)  # Normalize vectors

        if cache_path is not None:
            os.makedirs(cache_path, exist_ok=True)
            with open(frame_path, "wb") as handle:
                pickle.dump(train_df, handle)
            save_npz(index_path, dataset_index)

    # Invert row label -> immutable_id into immutable_id -> row label.
    immutable_id_to_idx = {
        immutable_id: row_label
        for row_label, immutable_id in train_df["immutable_id"].items()
    }
    del train_df  # release the dataframe before returning

    return dataset_index, immutable_id_to_idx
| import pickle | |
| from pathlib import Path | |
| # TODO: Just load the index from a file | |
def build_embeddings_index(empty_data=False):
    """Load pickled feature vectors and add them to a fresh ``FAISSIndex``.

    Args:
        empty_data: When true, nothing is read and a placeholder is returned.

    Returns:
        ``(index, features_dict, idx_to_immutable_id)``. ``features_dict`` is
        the object unpickled from ``features_dict.pkl`` and
        ``idx_to_immutable_id`` maps the insertion position of each vector to
        its key in ``features_dict``. With ``empty_data`` the result is
        ``(None, {}, {})``.
    """
    if empty_data:
        return None, {}, {}

    # NOTE(security): unpickling runs arbitrary code; features_dict.pkl must
    # come from a trusted source.
    with open("features_dict.pkl", "rb") as handle:
        features_dict = pickle.load(handle)

    from indexer import FAISSIndex

    index = FAISSIndex()
    idx_to_immutable_id = {}
    # A single pass keeps the insertion order and the position map in lockstep.
    for position, (key, features) in enumerate(features_dict.items()):
        index.index.add(features.reshape(1, -1))
        idx_to_immutable_id[position] = key

    # index = FAISSIndex.from_store("index.faiss")
    return index, features_dict, idx_to_immutable_id
def search_materials(
    query,
    dataset,
    dataset_index,
    mapping_table_idx_dataset_idx,
    map_periodic_table,
    limit=None,
):
    """Rank dataset rows by similarity of their composition to ``query``.

    Args:
        query: Either a comma-separated element list (``"Fe, O"``), where each
            listed element gets weight 1, or a chemical formula (``"Fe2O3"``),
            where each element is weighted by its total count.
        dataset: Indexable collection; ``dataset[int(i)]`` yields a result row.
        dataset_index: Matrix whose ``dot`` with the query vector yields one
            score per dataset row.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled with
            ``result position -> dataset row index``.
        map_periodic_table: Mapping from element symbol to vector position.
        limit: Maximum number of results; defaults to the module-level
            ``top_k`` when ``None``.

    Returns:
        The matching rows, best match first. An empty list when the query
        names no element.

    Raises:
        KeyError: If the query contains a symbol missing from
            ``map_periodic_table``.
    """
    query_vector = np.zeros(len(map_periodic_table))

    if "," in query:
        for symbol in (part.strip() for part in query.split(",")):
            if symbol:  # tolerate trailing or doubled commas
                query_vector[map_periodic_table[symbol]] = 1
    else:
        # Formula: accumulate so that repeated elements (e.g. "CH3COOH")
        # contribute their total count rather than the last occurrence.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            query_vector[map_periodic_table[symbol]] += int(count) if count else 1

    mapping_table_idx_dataset_idx.clear()

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing to search for; avoids dividing by zero and ranking NaNs.
        return []

    if limit is None:
        limit = top_k

    similarity = dataset_index.dot(query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:limit]

    options = []
    for position, dataset_idx in enumerate(indices):
        dataset_idx = int(dataset_idx)
        options.append(dataset[dataset_idx])
        mapping_table_idx_dataset_idx[position] = dataset_idx

    return options
def _round_or_none(value, ndigits=3):
    """Round a scalar with ``round``; pass ``None`` through unchanged."""
    return None if value is None else round(value, ndigits)


def _round_array_or_none(values, ndigits=3):
    """Round array-like data with ``np.round``; pass ``None`` through unchanged."""
    return None if values is None else np.round(values, ndigits)


def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render the properties of one material as a two-column HTML table.

    Args:
        row: Mapping with the keys read below (``immutable_id``,
            ``chemical_formula_descriptive``, ``energy``, ``species_at_sites``,
            ``total_magnetization``, ``dos_ef``, ``magnetic_moments``,
            ``stress_tensor``, ``forces``, ``functional``,
            ``entalpic_fingerprint``).
        structure: Object exposing a ``density`` attribute.
        sga: Object exposing ``get_crystal_system()`` and
            ``get_symmetry_dataset().international``.
        properties_container_update: Mutable sequence; slot 0 receives the
            properties dict when ``container_type == "query"``, slot 1
            otherwise.
        container_type: ``"query"`` selects slot 0, any other value slot 1.

    Returns:
        A ``dash.html.Table`` with one row per property; values are rendered
        with ``str``, so missing values appear as ``"None"``.
    """
    n_sites = len(row["species_at_sites"])
    # Every numeric field may be missing in the dataset; treat them all the
    # same way instead of only guarding total_magnetization and dos_ef.
    energy_per_atom = (
        round(row["energy"] / n_sites, 3)
        if row["energy"] is not None and n_sites
        else None
    )

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": _round_or_none(row["total_magnetization"]),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": _round_or_none(row["dos_ef"]),
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        "Magnetic moments (μB)": _round_array_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": _round_array_or_none(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _round_array_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    # Publish the raw values so the caller can see both sides.
    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties

    # TODO: highlight (backgroundColor "#e6f7ff") str/float values that are
    # equal in properties_container_update[0] and [1].
    header_style = {
        "padding": "10px",
        "verticalAlign": "middle",
    }
    value_style = {
        "padding": "10px",
        "borderBottom": "1px solid #ddd",
    }
    table_style = {
        "width": "100%",
        "borderCollapse": "collapse",
        "fontFamily": "'Arial', sans-serif",
        "fontSize": "14px",
        "color": "#333333",
    }

    # Format properties as an HTML table
    table_rows = [
        html.Tr(
            [
                html.Th(key, style=dict(header_style)),
                html.Td(str(value), style=value_style),
            ],
        )
        for key, value in properties.items()
    ]
    return html.Table([html.Tbody(table_rows)], style=table_style)
def get_crystal_plot(structure):
    """Return the crystal viewer layout and a symmetry analyzer for ``structure``.

    Args:
        structure: Passed unchanged to ``SpacegroupAnalyzer`` and to
            ``ctc.StructureMoleculeComponent``.

    Returns:
        ``(layout, sga)``: the component's ``layout()`` result and the
        ``SpacegroupAnalyzer`` instance.
    """
    analyzer = SpacegroupAnalyzer(structure)
    # Wrap the structure in a crystal-toolkit component to obtain its layout.
    viewer = ctc.StructureMoleculeComponent(structure)
    layout = viewer.layout()
    return layout, analyzer