Pocket-Gen / app.py
Zaixi's picture
1
6ee4b83
import spaces
import gradio as gr
from gradio_molecule3d import Molecule3D
import os
import numpy as np
import torch
from rdkit import Chem
import argparse
import random
from tqdm import tqdm
from vina import Vina
import esm
from utils.relax import openmm_relax, relax_sdf
from utils.protein_ligand import PDBProtein, parse_sdf_file
from utils.data import torchify_dict
from torch_geometric.transforms import Compose
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.data import *
from torch.utils.data import DataLoader
from models.PD import Pocket_Design_new
from functools import partial
import pickle
import yaml
from easydict import EasyDict
import uuid
from datetime import datetime
import tempfile
import shutil
from Bio import PDB
from Bio.PDB import MMCIFParser, PDBIO
import logging
import zipfile
# 配置日志
logger = logging.getLogger(__name__)
LOG_FORMAT = "%(asctime)s,%(msecs)-3d %(levelname)-8s [%(filename)s:%(lineno)s %(funcName)s] %(message)s"
logging.basicConfig(
format=LOG_FORMAT,
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
filemode="w",
)
# 确保目录存在
os.makedirs("./generate/upload", exist_ok=True)
os.makedirs("./tmp", exist_ok=True)
# 自定义CSS样式
custom_css = """
.title {
font-size: 32px;
font-weight: bold;
color: #4CAF50;
display: flex;
align-items: center;
}
.subtitle {
font-size: 20px;
color: #666;
margin-bottom: 20px;
}
.footer {
margin-top: 20px;
text-align: center;
color: #666;
}
"""
# 3D显示表示设置 - 默认配置
default_reps = [
{
"model": 0,
"chain": "",
"resname": "",
"style": "cartoon",
"color": "whiteCarbon",
"residue_range": "",
"around": 0,
"byres": False,
"visible": True,
"opacity": 1.0
},
{
"model": 0,
"chain": "",
"resname": "",
"style": "stick",
"color": "greenCarbon",
"around": 5, # 显示配体周围5Å的残基
"byres": True,
"visible": True,
"opacity": 0.8
}
]
def create_zip_file(directory_path, zip_filename):
"""将指定目录压缩为zip文件"""
try:
with zipfile.ZipFile(zip_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(directory_path):
for file in files:
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, directory_path)
zipf.write(file_path, arcname)
logger.info(f"成功创建压缩文件: {zip_filename}")
return zip_filename
except Exception as e:
logger.error(f"创建压缩文件时出错: {str(e)}")
return None
def load_config(config_path):
"""加载配置文件"""
with open(config_path, 'r') as f:
config_dict = yaml.load(f, Loader=yaml.FullLoader)
return EasyDict(config_dict)
# 删除了Vina相关的计算函数,因为只需要RMSD结果
def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, residue_dict=None, seq=None, full_seq_idx=None,
r10_idx=None):
"""从蛋白质和配体字典创建数据实例"""
instance = {}
if protein_dict is not None:
for key, item in protein_dict.items():
instance['protein_' + key] = item
if ligand_dict is not None:
for key, item in ligand_dict.items():
instance['ligand_' + key] = item
if residue_dict is not None:
for key, item in residue_dict.items():
instance[key] = item
if seq is not None:
instance['seq'] = seq
if full_seq_idx is not None:
instance['full_seq_idx'] = full_seq_idx
if r10_idx is not None:
instance['r10_idx'] = r10_idx
return instance
def ith_true_index(tensor, i):
"""找到张量中第i个为真的元素的索引"""
true_indices = torch.nonzero(tensor).squeeze()
return true_indices[i].item()
def name2data(pdb_path, lig_path):
"""从PDB和SDF文件生成数据"""
name = os.path.basename(pdb_path).split('.')[0]
dir_name = os.path.dirname(pdb_path)
pocket_path = os.path.join(dir_name, f"{name}_pocket.pdb")
try:
with open(pdb_path, 'r') as f:
pdb_block = f.read()
protein = PDBProtein(pdb_block)
seq = ''.join(protein.to_dict_residue()['seq'])
ligand = parse_sdf_file(lig_path, feat=False)
if ligand is None:
raise ValueError(f"无法从{lig_path}解析配体")
r10_idx, r10_residues = protein.query_residues_ligand(ligand, radius=10, selected_residue=None, return_mask=False)
full_seq_idx, _ = protein.query_residues_ligand(ligand, radius=3.5, selected_residue=r10_residues, return_mask=False)
if not r10_residues:
raise ValueError("在配体10Å范围内未找到任何残基")
assert len(r10_idx) == len(r10_residues)
pdb_block_pocket = protein.residues_to_pdb_block(r10_residues)
with open(pocket_path, 'w') as f:
f.write(pdb_block_pocket)
with open(pocket_path, 'r') as f:
pdb_block = f.read()
pocket = PDBProtein(pdb_block)
pocket_dict = pocket.to_dict_atom()
residue_dict = pocket.to_dict_residue()
_, residue_dict['protein_edit_residue'] = pocket.query_residues_ligand(ligand)
if residue_dict['protein_edit_residue'].sum() == 0:
raise ValueError("在口袋内未找到可编辑残基")
assert residue_dict['protein_edit_residue'].sum() > 0 and residue_dict['protein_edit_residue'].sum() == len(full_seq_idx)
assert len(residue_dict['protein_edit_residue']) == len(r10_idx)
full_seq_idx.sort()
r10_idx.sort()
data = from_protein_ligand_dicts(
protein_dict=torchify_dict(pocket_dict),
ligand_dict=torchify_dict(ligand),
residue_dict=torchify_dict(residue_dict),
seq=seq,
full_seq_idx=torch.tensor(full_seq_idx),
r10_idx=torch.tensor(r10_idx)
)
data['protein_filename'] = pocket_path
data['ligand_filename'] = lig_path
data['whole_protein_name'] = pdb_path
return transform(data)
except Exception as e:
logger.error(f"name2data中出错: {str(e)}")
raise
def convert_cif_to_pdb(cif_path):
"""将CIF文件转换为PDB文件并保存为临时文件"""
try:
parser = MMCIFParser()
structure = parser.get_structure("protein", cif_path)
with tempfile.NamedTemporaryFile(suffix=".pdb", delete=False) as temp_file:
temp_pdb_path = temp_file.name
io = PDBIO()
io.set_structure(structure)
io.save(temp_pdb_path)
return temp_pdb_path
except Exception as e:
logger.error(f"将CIF转换为PDB时出错: {str(e)}")
raise
def align_pdb_files(pdb_file_1, pdb_file_2):
"""将两个PDB文件对齐,将第二个结构对齐到第一个结构上"""
try:
parser = PDB.PPBuilder()
io = PDB.PDBIO()
structure_1 = PDB.PDBParser(QUIET=True).get_structure('Structure_1', pdb_file_1)
structure_2 = PDB.PDBParser(QUIET=True).get_structure('Structure_2', pdb_file_2)
super_imposer = PDB.Superimposer()
model_1 = structure_1[0]
model_2 = structure_2[0]
atoms_1 = [atom for atom in model_1.get_atoms() if atom.get_name() == "CA"]
atoms_2 = [atom for atom in model_2.get_atoms() if atom.get_name() == "CA"]
if not atoms_1 or not atoms_2:
logger.warning("未找到用于对齐的CA原子")
return
min_length = min(len(atoms_1), len(atoms_2))
if min_length == 0:
logger.warning("没有可用于对齐的原子")
return
super_imposer.set_atoms(atoms_1[:min_length], atoms_2[:min_length])
super_imposer.apply(model_2)
io.set_structure(structure_2)
io.save(pdb_file_2)
except Exception as e:
logger.error(f"对齐PDB文件时出错: {str(e)}")
raise
def create_combined_structure(protein_path, ligand_path, output_path):
"""将蛋白质和配体合并为一个PDB文件以便可视化"""
try:
# 读取蛋白质PDB文件
with open(protein_path, 'r') as f:
protein_content = f.read()
# 读取配体SDF文件并转换为PDB格式的字符串
mol = Chem.MolFromMolFile(ligand_path)
if mol is None:
logger.error(f"无法读取配体文件: {ligand_path}")
return protein_path
# 将配体转换为PDB格式
ligand_pdb_block = Chem.MolToPDBBlock(mol)
# 合并蛋白质和配体
combined_content = protein_content.rstrip() + "\n" + ligand_pdb_block
# 保存合并后的文件
with open(output_path, 'w') as f:
f.write(combined_content)
return output_path
except Exception as e:
logger.error(f"创建合并结构时出错: {str(e)}")
return protein_path # 如果失败,返回原始蛋白质文件
@spaces.GPU(duration=500)
def process_files(pdb_file, sdf_file, config_path):
"""处理上传的PDB和SDF文件"""
try:
unique_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
upload_dir = os.path.join("./generate/upload", unique_id)
os.makedirs(upload_dir, exist_ok=True)
logger.info(f"使用ID处理文件: {unique_id}")
config = load_config(config_path)
pdb_save_path = os.path.join(upload_dir, "protein.pdb")
sdf_save_path = os.path.join(upload_dir, "ligand.sdf")
shutil.copy(pdb_file, pdb_save_path)
shutil.copy(sdf_file, sdf_save_path)
logger.info(f"文件已保存到 {upload_dir}")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logger.info(f"使用设备: {device}")
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
global transform
transform = Compose([
protein_featurizer,
ligand_featurizer,
])
logger.info("加载ESM模型...")
name = 'esm2_t33_650M_UR50D'
pretrained_model, alphabet = esm.pretrained.load_model_and_alphabet_hub(name)
batch_converter = alphabet.get_batch_converter()
checkpoint_path = config.model.checkpoint
logger.info(f"从{checkpoint_path}加载检查点")
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
del pretrained_model
logger.info("初始化模型...")
model = Pocket_Design_new(
config.model,
protein_atom_feature_dim=protein_featurizer.feature_dim,
ligand_atom_feature_dim=ligand_featurizer.feature_dim,
device=device
).to(device)
model.load_state_dict(ckpt['model'])
logger.info("处理输入数据...")
data = name2data(pdb_save_path, sdf_save_path)
batch_size = 2
datalist = [data for _ in range(batch_size)]
protein_filename = data['protein_filename']
ligand_filename = data['ligand_filename']
whole_protein_name = data['whole_protein_name']
dir_name = os.path.dirname(protein_filename)
model.generate_id = 0
model.generate_id1 = 0
test_loader = DataLoader(
datalist,
batch_size=batch_size,
shuffle=False,
num_workers=0,
collate_fn=partial(collate_mols_block, batch_converter=batch_converter)
)
logger.info("生成结构...")
with torch.no_grad():
model.eval()
for batch in tqdm(test_loader, desc='Test'):
for key in batch:
if torch.is_tensor(batch[key]):
batch[key] = batch[key].to(device)
aar, rmsd, attend_logits = model.generate(batch, dir_name)
logger.info(f'RMSD: {rmsd}')
# 创建结果文件
result_path = os.path.join(dir_name, "0_whole.pdb")
relaxed_path = os.path.join(dir_name, "0_relaxed.pdb")
if os.path.exists(relaxed_path):
shutil.copy(relaxed_path, result_path)
else:
shutil.copy(pdb_save_path, result_path)
# 创建包含蛋白质和配体的合并文件用于可视化
combined_path = os.path.join(dir_name, "combined_structure.pdb")
visualization_path = create_combined_structure(result_path, sdf_save_path, combined_path)
# 创建压缩文件
zip_filename = os.path.join("./generate/upload", f"{unique_id}_results.zip")
zip_path = create_zip_file(upload_dir, zip_filename)
logger.info(f"结果已保存到 {result_path}")
logger.info(f"压缩文件已创建: {zip_path}")
summary = f"""
处理完成!
结果摘要:
- 均方根偏差 (RMSD): {rmsd}
文件说明:
- 所有结果文件已打包为ZIP文件供下载
- 包含原始输入、处理结果等
- 任务ID: {unique_id}
"""
return visualization_path, zip_path, summary
except Exception as e:
import traceback
error_trace = traceback.format_exc()
logger.error(f"处理过程中出错: {error_trace}")
return None, None, f"处理过程中出错: {str(e)}"
def gradio_interface(pdb_file, sdf_file, config_path):
"""Gradio接口函数"""
if pdb_file is None or sdf_file is None:
return None, None, "请上传PDB和SDF文件。"
logger.info(f"开始处理{pdb_file}{sdf_file}")
pdb_viewer, zip_path, message = process_files(pdb_file, sdf_file, config_path)
if pdb_viewer and os.path.exists(pdb_viewer):
return pdb_viewer, zip_path, message
else:
return None, None, message if message else "处理失败,未知错误。"
# 创建Gradio接口
with gr.Blocks(title="蛋白质-配体处理", css=custom_css) as demo:
gr.Markdown("# 蛋白质-配体结构处理", elem_classes=["title"])
gr.Markdown("上传PDB和SDF文件进行蛋白质口袋设计和配体对接分析", elem_classes=["subtitle"])
with gr.Row():
with gr.Column(scale=1):
pdb_input = gr.File(label="上传PDB文件", file_types=[".pdb"])
sdf_input = gr.File(label="上传SDF文件", file_types=[".sdf"])
config_input = gr.Textbox(label="配置文件路径", value="./configs/train_model_moad.yml")
submit_btn = gr.Button("处理文件", variant="primary")
with gr.Column(scale=2):
# 使用Molecule3D组件,固定为默认样式
view3d = Molecule3D(
label="3D结构可视化 (蛋白质卡通 + 配体周围残基棒状)",
reps=default_reps
)
output_message = gr.Textbox(label="处理状态和结果摘要", lines=8)
output_file = gr.File(label="下载完整结果包 (ZIP)")
# 处理文件的点击事件
submit_btn.click(
fn=gradio_interface,
inputs=[pdb_input, sdf_input, config_input],
outputs=[view3d, output_file, output_message]
)
gr.Markdown("""
## 使用说明
1. **上传文件**: 上传蛋白质PDB文件和配体SDF文件
2. **配置设置**: 保持默认配置路径或调整为您的配置文件位置
3. **处理文件**: 点击"处理文件"按钮开始处理
4. **结果查看**:
- 在3D查看器中交互式查看优化后的蛋白质-配体复合物结构
- 查看详细的处理结果摘要
- 下载包含所有结果文件的ZIP压缩包
## 3D可视化功能
- **旋转**: 鼠标左键拖拽
- **缩放**: 鼠标滚轮或双指缩放
- **平移**: 鼠标右键拖拽
- **重置视图**: 双击重置到初始视角
可视化样式说明:
- 蛋白质以卡通形式显示(白色碳骨架)
- 配体周围5Å内的残基以棒状形式显示(绿色碳骨架)
## 下载文件说明
ZIP压缩包包含以下文件:
- **protein.pdb**: 原始输入蛋白质文件
- **ligand.sdf**: 原始输入配体文件
- **protein_pocket.pdb**: 提取的蛋白质口袋文件
- **0_whole.pdb**: 优化后的完整蛋白质结构
- **0_relaxed.pdb**: 松弛优化后的蛋白质结构
- **combined_structure.pdb**: 用于可视化的蛋白质-配体复合物
## 技术说明
该应用程序使用深度学习方法优化蛋白质口袋结构,提高与特定配体的结合能力。主要功能包括:
- **蛋白质口袋识别**: 自动识别并提取配体结合口袋
- **结构优化设计**: 使用AI模型优化口袋残基构象
- **分子对接评分**: 使用Vina进行结合能评估
- **交互式3D可视化**: 清晰展示蛋白质-配体相互作用
- **完整结果打包**: 所有中间和最终结果文件统一打包下载
处理可能需要几分钟时间,请耐心等待。
""")
gr.Markdown("© 2025 zaixi", elem_classes=["footer"])
if __name__ == "__main__":
demo.launch(share=True)