Multi-task graph attention framework for predicting drug toxicity. Paper: Mining Toxicity Information from Large Amounts of Toxicity Data. Code: MGA. This walkthrough starts from MGA/interpretation/Ames_interpret.ipynb.
Table of Contents
- 1.build_dataset
- 1.1.load_graph_from_csv_bin_for_splited
- 1.2.split_dataset_according_index
- 2.built_data_and_save_for_splited
- 2.1.multi_task_build_dataset
- 2.2.construct_RGCN_bigraph_from_smiles
- 2.2.1.atom_features
- 2.2.2.one_of_k_encoding_unk
- 2.2.3.one_of_k_encoding
- 2.2.4.etype_features
- 2.3.build_mask
1.build_dataset
```python
args['data_name'] = 'toxicity'  # change
args['bin_path'] = '../data/' + args['data_name'] + '.bin'
args['group_path'] = '../data/' + args['data_name'] + '_group.csv'
args['select_task_list'] = ['Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
                            'Eye irritation', 'Eye corrosion', 'Cardiotoxicity1', 'Cardiotoxicity5',
                            'Cardiotoxicity10', 'Cardiotoxicity30',
                            'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
                            'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
                            'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
                            'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50']  # change
args['all_task_list'] = ['Carcinogenicity', 'Ames Mutagenicity', 'Respiratory toxicity',
                         'Eye irritation', 'Eye corrosion', 'Cardiotoxicity1', 'Cardiotoxicity5',
                         'Cardiotoxicity10', 'Cardiotoxicity30',
                         'CYP1A2', 'CYP2C19', 'CYP2C9', 'CYP2D6', 'CYP3A4',
                         'NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
                         'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53',
                         'Acute oral toxicity (LD50)', 'LC50DM', 'BCF', 'LC50', 'IGC50']  # change
args['select_task_index'] = []
for index, task in enumerate(args['all_task_list']):
    if task in args['select_task_list']:
        args['select_task_index'].append(index)

train_set, val_set, test_set, task_number = build_dataset.load_graph_from_csv_bin_for_splited(
    bin_path=args['bin_path'],
    group_path=args['group_path'],
    select_task_index=args['select_task_index']
)
```
- The dataset is built from these arguments; note that args['all_task_list'] and args['select_task_list'] are exactly the same here, so all 31 tasks are selected and args['select_task_index'] covers every index.
```python
args['bin_path'], args['group_path'], args['select_task_index']
"""
('../data/toxicity.bin', '../data/toxicity_group.csv', [0, 1, 2, ..., 30])
"""
```
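For reference, the .bin file can also be inspected directly with DGL's load_graphs, the same call that load_graph_from_csv_bin_for_splited uses below. A minimal sketch, assuming the path above exists and the file was written by save_graphs as in built_data_and_save_for_splited:

```python
# Minimal inspection of the saved graphs and label tensors (path assumed to exist).
from dgl.data.utils import load_graphs

graphs, info = load_graphs('../data/toxicity.bin')   # list of DGLGraphs, dict of tensors
print(len(graphs), info['labels'].shape, info['mask'].shape)
```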
1.1.load_graph_from_csv_bin_for_splited
```python
def load_graph_from_csv_bin_for_splited(bin_path='example.bin', group_path='example.csv', select_task_index=None):
    smiles = pd.read_csv(group_path, index_col=None).smiles.values
    group = pd.read_csv(group_path, index_col=None).group.to_list()
    graphs, detailed_information = load_graphs(bin_path)
    labels = detailed_information['labels']
    mask = detailed_information['mask']
    if select_task_index is not None:
        labels = labels[:, select_task_index]
        mask = mask[:, select_task_index]
    # calculate not_use index
    notuse_mask = torch.mean(mask.float(), 1).numpy().tolist()
    not_use_index = []
    for index, notuse in enumerate(notuse_mask):
        if notuse == 0:
            not_use_index.append(index)
    train_index = []
    val_index = []
    test_index = []
    for index, group_index in enumerate(group):
        if group_index == 'training' and index not in not_use_index:
            train_index.append(index)
        if group_index == 'valid' and index not in not_use_index:
            val_index.append(index)
        if group_index == 'test' and index not in not_use_index:
            test_index.append(index)
    graph_List = []
    for g in graphs:
        graph_List.append(g)
    graphs_np = np.array(graphs)
    train_smiles, val_smiles, test_smiles = split_dataset_according_index(smiles, train_index, val_index, test_index)
    train_labels, val_labels, test_labels = split_dataset_according_index(labels.numpy(), train_index, val_index,
                                                                          test_index, data_type='pd')
    train_mask, val_mask, test_mask = split_dataset_according_index(mask.numpy(), train_index, val_index, test_index,
                                                                    data_type='pd')
    train_graph, val_graph, test_graph = split_dataset_according_index(graphs_np, train_index, val_index, test_index)
    # delete the 0_pos_label and 0_neg_label
    task_number = train_labels.values.shape[1]
    train_set = []
    val_set = []
    test_set = []
    for i in range(len(train_index)):
        molecule = [train_smiles[i], train_graph[i], train_labels.values[i], train_mask.values[i]]
        train_set.append(molecule)
    for i in range(len(val_index)):
        molecule = [val_smiles[i], val_graph[i], val_labels.values[i], val_mask.values[i]]
        val_set.append(molecule)
    for i in range(len(test_index)):
        molecule = [test_smiles[i], test_graph[i], test_labels.values[i], test_mask.values[i]]
        test_set.append(molecule)
    print(len(train_set), len(val_set), len(test_set), task_number)
    return train_set, val_set, test_set, task_number
```
After changing the data source in the source code to Hepatotoxicity and setting select_task_index to None, the code runs; the printed output is:
```python
train_set[:3], val_set[:3], test_set[:3], task_number
"""
([['O=[N+]([O-])c1cccc([N+](=O)[O-])c1',
   Graph(num_nodes=12, num_edges=24,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['C#C[C@]1(O)CC[C@H]2[C@@H]3CCc4cc(OS(=O)(=O)O)ccc4[C@H]3CC[C@@]21C',
   Graph(num_nodes=26, num_edges=58,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['C=C1C(=CC=C2CCCC3(C)C2CCC3C(C)CCCC(C)C)CC(O)CC1O',
   Graph(num_nodes=29, num_edges=62,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])]],
 [['ClC1C(Cl)C(Cl)C(Cl)C(Cl)C1Cl',
   Graph(num_nodes=12, num_edges=24,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['Cc1c([N+](=O)[O-])cc([N+](=O)[O-])cc1[N+](=O)[O-]',
   Graph(num_nodes=16, num_edges=32,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['Nc1nc(NC2CC2)c2ncn([C@H]3C=C[C@@H](CO)C3)c2n1',
   Graph(num_nodes=21, num_edges=48,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])]],
 [['C[C@H](N)C(=O)N[C@@H](C)C(=O)NC1[C@@H]2CN(c3nc4c(cc3F)c(=O)c(C(=O)O)cn4-c3ccc(F)cc3F)C[C@H]12',
   Graph(num_nodes=40, num_edges=88,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['C=CCOc1ccc(CC(=O)O)cc1Cl',
   Graph(num_nodes=15, num_edges=30,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])],
  ['O=C(O)CCCCCCNC1c2ccccc2CCc2ccccc21',
   Graph(num_nodes=25, num_edges=54,
         ndata_schemes={'atom': Scheme(shape=(40,), dtype=torch.int64)}
         edata_schemes={'normal': Scheme(shape=(), dtype=torch.float32), 'etype': Scheme(shape=(), dtype=torch.int64)}),
   array([1]), array([1])]],
 1)
"""
```
- Each data point therefore contains a smiles, a graph, labels, and a mask. According to the analysis in 2.3.build_mask, a mask value of 0 means the corresponding label was NaN, i.e. that label is invalid for the molecule.
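One common way such a per-task mask is consumed downstream is to zero out the loss at invalid positions. The following is only an illustrative sketch with toy tensors, not the repo's actual training loop:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                 # 4 molecules, 3 tasks (toy values)
labels = torch.tensor([[1., 0., 1.],
                       [0., 1., 1.],
                       [1., 1., 0.],
                       [0., 0., 1.]])
mask = torch.tensor([[1., 0., 1.],         # 0 marks a label that was NaN in the csv
                     [1., 1., 1.],
                     [0., 1., 1.],
                     [1., 1., 0.]])

per_element = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
loss = (per_element * mask).sum() / mask.sum()   # average only over valid entries
print(loss)
```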
1.2.split_dataset_according_index
```python
def split_dataset_according_index(dataset, train_index, val_index, test_index, data_type='np'):
    if data_type == 'pd':
        return pd.DataFrame(dataset[train_index]), pd.DataFrame(dataset[val_index]), pd.DataFrame(dataset[test_index])
    if data_type == 'np':
        return dataset[train_index], dataset[val_index], dataset[test_index]
```
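The function relies on NumPy fancy indexing (which is why the graph list is wrapped in np.array upstream) and on pd.DataFrame for the 'pd' branch. A toy check with made-up values:

```python
import numpy as np
import pandas as pd

data = np.array([10, 11, 12, 13, 14])
train_idx, val_idx, test_idx = [0, 2], [1], [3, 4]
print(data[train_idx], data[val_idx], data[test_idx])   # [10 12] [11] [13 14]
print(pd.DataFrame(data[train_idx]))                    # the same rows wrapped in a DataFrame
```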
- The data construction routine actually used to build the model should be the function below.
2.built_data_and_save_for_splited
```python
def built_data_and_save_for_splited(
        origin_path='example.csv',
        save_path='example.bin',
        group_path='example_group.csv',
        task_list_selected=None,
):
    '''
    origin_path: str
        origin csv data set path, including molecule name, smiles, task
    save_path: str
        graph out put path
    group_path: str
        group out put path
    task_list_selected: list
        a list of selected task
    '''
    data_origin = pd.read_csv(origin_path)
    smiles_name = 'smiles'
    data_origin = data_origin.fillna(123456)
    labels_list = [x for x in data_origin.columns if x not in ['smiles', 'group']]
    if task_list_selected is not None:
        labels_list = task_list_selected
    data_set_gnn = multi_task_build_dataset(dataset_smiles=data_origin, labels_list=labels_list,
                                            smiles_name=smiles_name)
    smiles, graphs, labels, mask, split_index = map(list, zip(*data_set_gnn))
    graph_labels = {'labels': torch.tensor(labels),
                    'mask': torch.tensor(mask)}
    split_index_pd = pd.DataFrame(columns=['smiles', 'group'])
    split_index_pd.smiles = smiles
    split_index_pd.group = split_index
    split_index_pd.to_csv(group_path, index=None, columns=None)
    print('Molecules graph is saved!')
    save_graphs(save_path, graphs, graph_labels)
```
- Builds molecular graphs from the original csv file and writes the _group.csv and .bin files.
- The original csv file should have multiple columns: one is smiles, one is group, and the remaining columns are the task names (a toy example of such a file follows the zip demo below).
- The effect of map(list, zip(*data_set_gnn)) can be understood through the toy case below: it gathers every molecule's smiles, graphs, and so on into separate lists.
```python
dataset = [[1, 2, 3, 4, 5],    # mol1
           [6, 7, 8, 9, 10]]   # mol2
list(map(list, zip(*dataset)))
"""
[[1, 6], [2, 7], [3, 8], [4, 9], [5, 10]]
"""
```
2.1.multi_task_build_dataset
```python
def multi_task_build_dataset(dataset_smiles, labels_list, smiles_name):
    dataset_gnn = []
    failed_molecule = []
    labels = dataset_smiles[labels_list]
    split_index = dataset_smiles['group']
    smilesList = dataset_smiles[smiles_name]
    molecule_number = len(smilesList)
    for i, smiles in enumerate(smilesList):
        try:
            g = construct_RGCN_bigraph_from_smiles(smiles)
            mask = build_mask(labels.loc[i], mask_value=123456)
            molecule = [smiles, g, labels.loc[i], mask, split_index.loc[i]]
            dataset_gnn.append(molecule)
            print('{}/{} molecule is transformed!'.format(i + 1, molecule_number))
        except:
            print('{} is transformed failed!'.format(smiles))
            molecule_number = molecule_number - 1
            failed_molecule.append(smiles)
    print('{}({}) is transformed failed!'.format(failed_molecule, len(failed_molecule)))
    return dataset_gnn
```
- Encodes each smiles into a molecular graph, builds its mask, and returns the dataset; its shape should be (molecule_num, 5).
- When there are multiple tasks, labels holds one molecule's labels for all of those tasks, as read from the origin_path file.
2.2.construct_RGCN_bigraph_from_smiles
```python
def construct_RGCN_bigraph_from_smiles(smiles):
    g = DGLGraph()

    # Add nodes
    mol = MolFromSmiles(smiles)
    num_atoms = mol.GetNumAtoms()
    g.add_nodes(num_atoms)
    atoms_feature_all = []
    for atom_index, atom in enumerate(mol.GetAtoms()):
        atom_feature = atom_features(atom).tolist()
        atoms_feature_all.append(atom_feature)
    g.ndata["atom"] = torch.tensor(atoms_feature_all)

    # Add edges
    src_list = []
    dst_list = []
    etype_feature_all = []
    num_bonds = mol.GetNumBonds()
    for i in range(num_bonds):
        bond = mol.GetBondWithIdx(i)
        etype_feature = etype_features(bond)
        u = bond.GetBeginAtomIdx()
        v = bond.GetEndAtomIdx()
        src_list.extend([u, v])
        dst_list.extend([v, u])
        etype_feature_all.append(etype_feature)
        etype_feature_all.append(etype_feature)

    g.add_edges(src_list, dst_list)

    normal_all = []
    for i in etype_feature_all:
        normal = etype_feature_all.count(i) / len(etype_feature_all)
        normal = round(normal, 1)
        normal_all.append(normal)

    g.edata["etype"] = torch.tensor(etype_feature_all)
    g.edata["normal"] = torch.tensor(normal_all)
    return g
```
- Atom-level features are encoded from the rdkit atom objects. g.ndata["atom"] is a concatenation of (mostly one-hot) atom property encodings, so after a smiles is converted to a molecular graph the atom features have shape (atom_num, n), where n is 40.
- Bond-level features are encoded from the rdkit bond objects. g.edata["etype"] is an integer encoding of bond properties, and g.edata["normal"] is a statistic of g.edata["etype"]: the rounded relative frequency of each etype value within the molecule.
- Where the RGCN part of construct_RGCN_bigraph_from_smiles actually comes in is not yet clear.
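Setting the repo's feature functions aside, the bidirected construction itself can be reproduced in a few lines. A minimal sketch that uses plain atomic numbers as node features instead of the 40-dimensional atom_features:

```python
import torch
import dgl
from rdkit.Chem import MolFromSmiles

mol = MolFromSmiles('CCO')
src, dst = [], []
for bond in mol.GetBonds():
    u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
    src.extend([u, v])    # every chemical bond becomes two directed edges
    dst.extend([v, u])

g = dgl.graph((src, dst), num_nodes=mol.GetNumAtoms())
g.ndata['atom'] = torch.tensor([[a.GetAtomicNum()] for a in mol.GetAtoms()])
print(g.num_nodes(), g.num_edges())   # 3 4 -> 2 bonds doubled into 4 directed edges
```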
2.2.1.atom_features
```python
def atom_features(atom, explicit_H=False, use_chirality=True):
    results = one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At', 'other']
    ) + one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6]) + \
        [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
        one_of_k_encoding_unk(atom.GetHybridization(), [
            Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2, 'other'
        ]) + [atom.GetIsAromatic()]
    # [atom.GetIsAromatic()]  # set all aromaticity feature blank.
    # In case of explicit hydrogen (QM8, QM9), avoid calling `GetTotalNumHs`
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
    if use_chirality:
        try:
            results = results + one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S']) + \
                      [atom.HasProp('_ChiralityPossible')]
        except:
            results = results + [False, False] + [atom.HasProp('_ChiralityPossible')]
    return np.array(results)
```
atom.GetSymbol() returns the atom's element symbol and atom.GetDegree() returns the number of bonds attached to the atom; these and several other atom-level properties are one-hot encoded (the block lengths are tallied below).
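The 40 dimensions mentioned in section 2.2 can be verified by summing the block lengths that appear in atom_features (lengths read off the code above, not recomputed with RDKit):

```python
blocks = [
    16,  # element symbol one-hot (15 symbols + 'other')
    7,   # degree one-hot (0..6)
    2,   # formal charge, num radical electrons
    6,   # hybridization one-hot (SP, SP2, SP3, SP3D, SP3D2, 'other')
    1,   # is aromatic
    5,   # total num Hs one-hot (0..4)
    3,   # chirality: R/S one-hot + '_ChiralityPossible' flag
]
print(sum(blocks))  # 40
```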
2.2.2.one_of_k_encoding_unk
```python
def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
```
- One-hot encodes x against allowable_set; if x is not in allowable_set, the last position is set to 1.
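Two example calls using the function above (toy allowable set, outputs written out by hand):

```python
one_of_k_encoding_unk('C', ['B', 'C', 'N', 'O', 'other'])
# [False, True, False, False, False]
one_of_k_encoding_unk('Zn', ['B', 'C', 'N', 'O', 'other'])
# [False, False, False, False, True]  -> an unseen symbol lands in the last slot
```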
2.2.3.one_of_k_encoding
```python
def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
    return [x == s for s in allowable_set]
```
- One-hot encodes x against allowable_set; raises an exception if x is not in allowable_set.
2.2.4.etype_features
```python
def etype_features(bond, use_chirality=True, atompair=True):
    bt = bond.GetBondType()
    bond_feats_1 = [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE,
                    bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC]
    for i, m in enumerate(bond_feats_1):
        if m == True:
            a = i

    bond_feats_2 = bond.GetIsConjugated()
    if bond_feats_2 == True:
        b = 1
    else:
        b = 0

    # note: IsInRing is referenced without parentheses, so bond_feats_3 is a bound method;
    # the comparison `== True` is then always False and c ends up 0 for every bond
    bond_feats_3 = bond.IsInRing
    if bond_feats_3 == True:
        c = 1
    else:
        c = 0

    index = a * 1 + b * 4 + c * 8
    if use_chirality:
        bond_feats_4 = one_of_k_encoding_unk(str(bond.GetStereo()),
                                             ["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])
        for i, m in enumerate(bond_feats_4):
            if m == True:
                d = i
        index = index + d * 16
    if atompair == True:
        atom_pair_str = bond.GetBeginAtom().GetSymbol() + bond.GetEndAtom().GetSymbol()
        bond_feats_5 = one_of_k_atompair_encoding(
            atom_pair_str,
            [['CC'], ['CN', 'NC'], ['ON', 'NO'], ['CO', 'OC'], ['CS', 'SC'],
             ['SO', 'OS'], ['NN'], ['SN', 'NS'], ['CCl', 'ClC'], ['CF', 'FC'],
             ['CBr', 'BrC'], ['others']]
        )
        for i, m in enumerate(bond_feats_5):
            if m == True:
                e = i
        index = index + e * 64
    return index
```
The bond features are packed into a single decimal integer: index is the binary digit string e0dcb0a converted to decimal, i.e. $index = a\times 2^0 + 0\times 2^1 + b\times 2^2 + c\times 2^3 + d\times 2^4 + 0\times 2^5 + e\times 2^6 = a + 4b + 8c + 16d + 64e$. Strictly speaking it is a weighted sum rather than a true binary number, since $a$ and $d$ can take values 0 to 3 and $e$ can take values 0 to 11.
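A hand-worked example under this formula (a hypothetical conjugated C=O double bond outside any ring, with no stereo label):

```python
# a = 1 (DOUBLE), b = 1 (conjugated), c = 0 (not in ring), d = 0 (STEREONONE), e = 3 ('CO'/'OC' pair)
index = 1 * 1 + 1 * 4 + 0 * 8 + 0 * 16 + 3 * 64
print(index)  # 197
```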
2.3.build_mask
```python
def build_mask(labels_list, mask_value=100):
    mask = []
    for i in labels_list:
        if i == mask_value:
            mask.append(0)
        else:
            mask.append(1)
    return mask
```
- If a label value is invalid its mask is 0; if it is valid the mask is 1.
- Earlier, built_data_and_save_for_splited called data_origin.fillna(123456), and here the call is mask = build_mask(labels.loc[i], mask_value=123456): the fill value 123456 marks the missing labels, which the mask then zeroes out.
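A toy demonstration of this fillna/build_mask pairing, using the build_mask function quoted above and hypothetical label values for three of the tasks:

```python
import pandas as pd

labels = pd.Series({'Carcinogenicity': 1.0, 'Ames Mutagenicity': float('nan'), 'BCF': 0.73})
labels = labels.fillna(123456)
print(build_mask(labels, mask_value=123456))   # [1, 0, 1] -> the missing label is masked out
```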