一、分子图
分子图(molecular graph)是一种表示分子结构的图:原子被表示为节点(vertices),化学键被表示为边(edges)。在HIV(人类免疫缺陷病毒)相关的药物筛选任务中,每个候选化合物都可以表示为一张分子图,用来详细描述其化学结构和原子间的相互作用;在此基础上预测化合物是否具有抗HIV活性,这对于理解HIV的生物学特性和开发治疗药物至关重要。
二、分子图分类
下面直接通过代码演示,以HIV数据集为例。
HIV分子图一共有8个问题,现在取第一个问题:
针对上面问题,进行二分类
代码环节
oversample_data.py 数据正负样本均衡处理,进行上采样
import pandas as pd
# Balance the training split by replicating the minority (positive) class.
# Input : data/raw/HIV_train.csv (needs the columns "index" and "HIV_active")
# Output: data/raw/HIV_train_oversampled.csv
data = pd.read_csv("data/raw/HIV_train.csv")
data.index = data["index"]
data["HIV_active"].value_counts()
# The first ID is remembered so the IDs re-assigned below start at the same value.
start_index = int(data["index"].iloc[0])

# %% Apply oversampling
# Check how many additional samples we need
class_counts = data["HIV_active"].value_counts()
neg_class = class_counts[0]
pos_class = class_counts[1]
# One copy of every positive row already exists, hence the "- 1".
multiplier = int(neg_class / pos_class) - 1

# Replicate the dataset for the positive class
# (a multiplier <= 0 yields an empty list, i.e. nothing is added)
replicated_pos = [data[data["HIV_active"] == 1]] * multiplier

# Append replicated data.
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
# and also works on older pandas versions.
data = pd.concat([data, *replicated_pos], ignore_index=True)
print(data.shape)

# Shuffle dataset
data = data.sample(frac=1).reset_index(drop=True)

# Re-assign index (This is our ID later)
index = range(start_index, start_index + data.shape[0])
data.index = index
data["index"] = data.index
data.head()

# Save
data.to_csv("data/raw/HIV_train_oversampled.csv")
特征转化 dataset_featurizer.py ,生成pyg格式的Data对象
import os

import deepchem as dc
import numpy as np
import pandas as pd
import torch
import torch_geometric
from rdkit import Chem
from torch_geometric.data import Dataset, Data
from tqdm import tqdm
# Report the runtime environment once at import time; useful when debugging
# version mismatches between torch and torch_geometric.
for _label, _value in (
    ("Torch version", torch.__version__),
    ("Cuda available", torch.cuda.is_available()),
    ("Torch geometric version", torch_geometric.__version__),
):
    print(f"{_label}: {_value}")
class MoleculeDataset(Dataset):
    """PyG ``Dataset`` that turns a CSV of SMILES strings into graph files.

    Every CSV row must provide a ``smiles`` column and an ``HIV_active``
    label. ``process`` featurizes each molecule once and stores it as its own
    ``.pt`` file inside ``processed_dir``; ``get`` loads those files on demand.
    """

    # Feature widths produced by _get_node_features / _get_edge_features.
    # They keep empty feature matrices two-dimensional.
    NUM_NODE_FEATURES = 9
    NUM_EDGE_FEATURES = 2

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data).
        filename = Name of the CSV file; it is read from directly inside root.
        test = If True, processed files get the "data_test_" prefix so that
        train and test graphs can live in the same processed_dir.
        """
        # Set before calling the parent constructor: it may trigger
        # processing, which reads both attributes.
        self.test = test
        self.filename = filename
        super(MoleculeDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)
        """
        return self.filename

    @property
    def processed_file_names(self):
        """ If these files are found in processed_dir, processing is skipped"""
        # The frame is cached on the instance because len() relies on it.
        self.data = pd.read_csv(os.path.join(self.root, self.filename)).reset_index()
        return [self._processed_name(i) for i in list(self.data.index)]

    def download(self):
        pass

    def process(self):
        """Featurize every CSV row and save one ``Data`` object per molecule.

        Raises:
            ValueError: if a SMILES string cannot be parsed by RDKit.
        """
        self.data = pd.read_csv(os.path.join(self.root, self.filename))
        for index, mol in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            mol_obj = Chem.MolFromSmiles(mol["smiles"])
            if mol_obj is None:
                # Fail with the offending row instead of an AttributeError on
                # None further down.
                raise ValueError(
                    f"Row {index}: could not parse SMILES {mol['smiles']!r}"
                )
            # Get node features
            node_feats = self._get_node_features(mol_obj)
            # Get edge features
            edge_feats = self._get_edge_features(mol_obj)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol_obj)
            # Get labels info
            label = self._get_labels(mol["HIV_active"])

            # Create data object
            data = Data(x=node_feats,
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        smiles=mol["smiles"]
                        )
            torch.save(data,
                       os.path.join(self.processed_dir,
                                    self._processed_name(index)))

    def _processed_name(self, idx):
        """Return the processed file name for sample ``idx``."""
        prefix = "data_test" if self.test else "data"
        return f"{prefix}_{idx}.pt"

    def _get_node_features(self, mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []
        for atom in mol.GetAtoms():
            all_node_feats.append([
                atom.GetAtomicNum(),            # Feature 1: Atomic number
                atom.GetDegree(),               # Feature 2: Atom degree
                atom.GetFormalCharge(),         # Feature 3: Formal charge
                atom.GetHybridization(),        # Feature 4: Hybridization
                atom.GetIsAromatic(),           # Feature 5: Aromaticity
                atom.GetTotalNumHs(),           # Feature 6: Total Num Hs
                atom.GetNumRadicalElectrons(),  # Feature 7: Radical Electrons
                atom.IsInRing(),                # Feature 8: In Ring
                atom.GetChiralTag(),            # Feature 9: Chirality
            ])

        # reshape keeps the matrix 2-D even when there are no atoms
        all_node_feats = np.asarray(all_node_feats).reshape(
            -1, self.NUM_NODE_FEATURES)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [
                bond.GetBondTypeAsDouble(),  # Feature 1: Bond type (as double)
                bond.IsInRing(),             # Feature 2: Rings
            ]
            # Append edge features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps the matrix 2-D for molecules without bonds
        all_edge_feats = np.asarray(all_edge_feats).reshape(
            -1, self.NUM_EDGE_FEATURES)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        if not edge_indices:
            # Molecules without bonds (e.g. a single atom): view(2, -1) on an
            # empty tensor raises, so build the empty index explicitly.
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(edge_indices, dtype=torch.long).t().contiguous()

    def _get_labels(self, label):
        """Wrap the scalar label in a 1-element int64 tensor."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        # Read from the same directory process() writes to. The former '../'
        # prefix pointed one level above a relative root.
        return torch.load(os.path.join(self.processed_dir,
                                       self._processed_name(idx)))
模型结构 :GNN1,常见的GAT结构
import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN1(torch.nn.Module):
    """Three GAT + TopK-pooling blocks followed by a two-layer classifier.

    Each block runs a 3-head ``GATConv``, projects the concatenated heads back
    to ``embedding_size`` and pools the graph. The per-block graph readouts
    (max-pool and mean-pool, concatenated) are summed and classified.
    """

    def __init__(self, feature_size):
        """feature_size = number of input features per node."""
        super(GNN1, self).__init__()
        num_classes = 2
        embedding_size = 1024

        # GNN1 layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        # Blocks 2 and 3 receive the embedding_size-wide output of the
        # previous block, not the raw node features.
        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)

        # Linear layers (input is the max/mean readout pair, hence *2)
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Return class logits of shape [num_graphs, 2].

        ``edge_attr`` is accepted for a uniform call signature across the
        GNN variants but is not used by this model.
        """
        # First block
        x = self.conv1(x, edge_index)
        x = self.head_transform1(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        # Max- and mean-pool readout; two identical mean-pools carry no
        # additional information.
        x1 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Second block
        x = self.conv2(x, edge_index)
        x = self.head_transform2(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        x2 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Third block
        x = self.conv3(x, edge_index)
        x = self.head_transform3(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        x3 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Combine the pooled vectors of all blocks (element-wise sum)
        x = x1 + x2 + x3

        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)
        return x
GNN2
import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN2(torch.nn.Module):
    """GIN + Transformer front-end followed by three GAT/TopK-pooling blocks.

    The graph-isomorphism layer lifts the raw node features to
    ``embedding_size``; all later layers work at that width.
    """

    def __init__(self, feature_size):
        """feature_size = number of input features per node."""
        super(GNN2, self).__init__()
        num_classes = 2
        embedding_size = 1024

        # Isomorphism layer. GINConv is PyG's graph-isomorphism operator;
        # XConv is a point-cloud operator with a different constructor
        # (in_channels, out_channels, dim, kernel_size) and cannot be built
        # from a Linear module.
        self.isomorphism = torch_geometric.nn.GINConv(
            Linear(feature_size, embedding_size))

        # Transformer layer. TransformerConv requires out_channels and has no
        # num_layers argument; concat=False averages the heads so the width
        # stays at embedding_size.
        self.transformer = TransformerConv(embedding_size, embedding_size,
                                           heads=4, concat=False)

        # GNN2 layers (conv1 consumes the transformer output)
        self.conv1 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)

        # Linear layers (input is the max/mean readout pair, hence *2)
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Return class logits of shape [num_graphs, 2].

        ``edge_attr`` is accepted for a uniform call signature across the
        GNN variants but is not used by this model.
        """
        # Isomorphism layer
        x = self.isomorphism(x, edge_index)
        # Transformer layer
        x = self.transformer(x, edge_index)

        # First block
        x = self.conv1(x, edge_index)
        x = self.head_transform1(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        # Max- and mean-pool readout; two identical mean-pools carry no
        # additional information.
        x1 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Second block
        x = self.conv2(x, edge_index)
        x = self.head_transform2(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        x2 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Third block
        x = self.conv3(x, edge_index)
        x = self.head_transform3(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        x3 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Combine the pooled vectors of all blocks (element-wise sum)
        x = x1 + x2 + x3

        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)
        return x
GNN3
import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN3(torch.nn.Module):
    """Edge-aware variant: GINE + Transformer front-end and three GAT blocks.

    Edge features enter every message-passing layer through the layers' own
    ``edge_dim`` support. Node features ([num_nodes, F]) and edge features
    ([num_edges, E]) have different first dimensions, so they cannot be
    concatenated along dim=1.
    """

    def __init__(self, feature_size, edge_feature_size):
        """
        feature_size = number of input features per node.
        edge_feature_size = number of features per edge.
        """
        super(GNN3, self).__init__()
        num_classes = 2
        embedding_size = 1024

        # Isomorphism layer. GINEConv is the edge-aware graph-isomorphism
        # operator; edge_dim lets it project edge features to the node width.
        self.isomorphism = torch_geometric.nn.GINEConv(
            Linear(feature_size, embedding_size), edge_dim=edge_feature_size)

        # Transformer layer. TransformerConv requires out_channels and has no
        # num_layers argument; concat=False averages the heads so the width
        # stays at embedding_size.
        self.transformer = TransformerConv(embedding_size, embedding_size,
                                           heads=4, concat=False,
                                           edge_dim=edge_feature_size)

        # GNN3 layers. The GAT output width is embedding_size*3 (three
        # concatenated heads); edge features do not widen it.
        self.conv1 = GATConv(embedding_size, embedding_size, heads=3,
                             dropout=0.3, edge_dim=edge_feature_size)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=3,
                             dropout=0.3, edge_dim=edge_feature_size)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=3,
                             dropout=0.3, edge_dim=edge_feature_size)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)

        # Linear layers (input is the max/mean readout pair, hence *2)
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Return class logits of shape [num_graphs, 2]."""
        # Isomorphism layer
        x = self.isomorphism(x, edge_index, edge_attr)
        # Transformer layer
        x = self.transformer(x, edge_index, edge_attr)

        # First block. TopKPooling filters edge_attr together with
        # edge_index, so both stay aligned for the next block.
        x = self.conv1(x, edge_index, edge_attr)
        x = self.head_transform1(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        x1 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Second block
        x = self.conv2(x, edge_index, edge_attr)
        x = self.head_transform2(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        x2 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Third block
        x = self.conv3(x, edge_index, edge_attr)
        x = self.head_transform3(x)
        x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        x3 = torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1)

        # Combine the pooled vectors of all blocks (element-wise sum)
        x = x1 + x2 + x3

        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)
        return x
train.py 训练流程代码
#%%
import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
from gnn_project.model.GNN1 import GNN1
from gnn_project.model.GNN2 import GNN2
from gnn_project.model.GNN3 import GNN3
import torch.nn.functional as F
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


def _print_label_counts(frame):
    """Print the number of rows per HIV_active label (0 first, then 1)."""
    label_counts = frame["HIV_active"].value_counts()
    for label in (0, 1):
        print(f"{label} : ", label_counts[label])


data_path = 'data/raw/HIV_data.csv'
data = pd.read_csv(data_path, index_col=[0])
print("###### Raw Data Shape - ")
print(data.shape)
_print_label_counts(data)
#%%
# The splits are prepared beforehand and only loaded here. The data set is
# imbalanced, so an oversampled copy of the training split is loaded as well.
print("###### After Data split Shape - ")
train_data = pd.read_csv('data/split_data/HIV_train.csv')
test_data = pd.read_csv('data/split_data/HIV_test.csv')
train_data_oversampled = pd.read_csv('data/split_data/HIV_train_oversampled.csv')
for _title, _frame in (
    ("Train Data ", train_data),
    ("Test Data ", test_data),
    ("Train Data Oversampled ", train_data_oversampled),
):
    print(_title, _frame.shape)
    _print_label_counts(_frame)
# %%
# Render one preview image: the reference molecules from cdk2.sdf with the
# atoms matching a training-set molecule highlighted.
# Define the folder path to save the images
output_folder = "visualization"
os.makedirs(output_folder, exist_ok=True)
sample_smiles = train_data["smiles"][4:30].values
sdf = Chem.SDMolSupplier(output_folder+'/cdk2.sdf')
# SDMolSupplier yields None for records it cannot parse; drop those so that
# GetSubstructMatch is never called on None.
mols = [m for m in sdf if m is not None]
image_path = None
if mols:
    for i, smiles in enumerate(sample_smiles):
        core = Chem.MolFromSmiles(smiles)
        if core is None:
            # Unparsable SMILES: try the next sample instead of crashing.
            continue
        img = Draw.MolsToGridImage(mols, molsPerRow=3, highlightAtomLists=[mol.GetSubstructMatch(core) for mol in mols], useSVG=True)
        ## Save the image in the output folder
        image_path = os.path.join(output_folder, f"image_{i}.png")
        # Depending on the environment the result is either an IPython SVG
        # object (markup in .data) or the SVG markup itself.
        svg_markup = getattr(img, "data", img)
        cairosvg.svg2png(bytestring=svg_markup, write_to=image_path)
        # A single preview image is enough.
        break
if image_path is None:
    print("No image saved: no usable reference molecules or sample SMILES.")
else:
    print(f"Image saved: {image_path}")
#%%
print("######## Loading dataset...")
# Both splits live in the same root folder; the test split is stored with its
# own processed-file prefix (test=True).
_dataset_root = "data/split_data"
train_dataset = MoleculeDataset(root=_dataset_root,
                                filename="HIV_train_oversampled.csv")
test_dataset = MoleculeDataset(root=_dataset_root,
                               filename="HIV_test.csv",
                               test=True)
# %%
print("######## Loading GNN1 Model...")
# Prefer the first GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def count_parameters(model):
    """Return the total number of trainable scalar parameters of ``model``."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            trainable += param.numel()
    return trainable
# The model input width is taken from the first training graph.
model = GNN1(feature_size=train_dataset[0].x.shape[1])
model = model.to(device)
print(f"######### Number of parameters: {count_parameters(model)}")
print(model)
# %%
# Loss and Optimizer
# The labels are imbalanced, so the positive class gets its own weight:
# a weight < 1 favours precision, a weight > 1 favours recall.
weight = torch.tensor([1, 10], dtype=torch.float32).to(device)
# Must be passed as the keyword argument "weight"; "weight == weight" would
# hand an all-True boolean tensor to the loss instead of the class weights.
loss_fn = torch.nn.CrossEntropyLoss(weight=weight)
optimizer = torch.optim.SGD(model.parameters(),
                            lr=0.1,
                            momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
# %%
NUM_GRAPHS_PER_BATCH = 256
train_loader = DataLoader(train_dataset,
                          batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
# Shuffling has no effect on the evaluation metrics; a fixed order keeps the
# debug printouts of the first batch comparable between epochs.
test_loader = DataLoader(test_dataset,
                         batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)
#%%
#%%
def train(epoch, model, train_loader, optimizer, loss_fn):
    """Run one training epoch and return the mean batch loss.

    Args:
        epoch: Epoch number, only forwarded to the metric report.
        model: Network returning class logits of shape [num_graphs, 2].
        train_loader: Loader yielding batched graphs.
        optimizer: Optimizer updating ``model``.
        loss_fn: Loss taking (logits, integer class labels), e.g.
            ``CrossEntropyLoss``.
    """
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for batch in tqdm(train_loader):
        # Using GPU (if available)
        batch = batch.to(device)
        # Reset gradient
        optimizer.zero_grad()
        # ".float" without parentheses would pass the bound method instead of
        # the edge feature tensor.
        pred = model(batch.x.float(),
                     batch.edge_attr.float(),
                     batch.edge_index,
                     batch.batch)
        # Class-index targets must stay integer for a two-logit output; no
        # squeeze, so a batch holding a single graph keeps its batch axis.
        loss = loss_fn(pred, batch.y)
        loss.backward()
        optimizer.step()
        # Update tracking
        running_loss += loss.item()
        step += 1
        # One predicted class per graph (argmax over the two logits).
        all_preds.append(pred.argmax(dim=1).cpu().detach().numpy())
        all_labels.append(batch.y.cpu().detach().numpy())
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_preds, all_labels, epoch, "train")
    return running_loss/step
#%%
def test(epoch, model, test_loader, loss_fn):
    """Evaluate ``model`` on ``test_loader`` and return the mean batch loss.

    Prints the first few positive-class probabilities, predictions and labels,
    reports the metrics and stores the confusion matrix image for ``epoch``.
    ``loss_fn`` must take (logits, integer class labels).
    """
    all_preds = []
    all_preds_raw = []
    all_labels = []
    running_loss = 0.0
    step = 0
    # Evaluation does not need gradients.
    with torch.no_grad():
        for batch in test_loader:
            # Using GPU (if available)
            batch = batch.to(device)
            # ".float" without parentheses would pass the bound method
            # instead of the edge feature tensor.
            pred = model(batch.x.float(),
                         batch.edge_attr.float(),
                         batch.edge_index,
                         batch.batch)
            # Class-index targets must stay integer for a two-logit output.
            loss = loss_fn(pred, batch.y)
            # Update tracking
            running_loss += loss.item()
            step += 1
            # One predicted class per graph plus the positive-class probability.
            all_preds.append(pred.argmax(dim=1).cpu().numpy())
            all_preds_raw.append(torch.softmax(pred, dim=1)[:, 1].cpu().numpy())
            all_labels.append(batch.y.cpu().numpy())
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    print(all_preds_raw[0][:10])
    print(all_preds[:10])
    print(all_labels[:10])
    calculate_metrics(all_preds, all_labels, epoch, "test")
    log_conf_matrix(all_preds, all_labels, epoch)
    return running_loss/step
#%%
def log_conf_matrix(y_pred, y_true, epoch):
    """Save the confusion matrix for ``epoch`` as data/images/cm_<epoch>.png.

    Rows are true labels and columns are predicted labels.
    """
    # scikit-learn expects (y_true, y_pred); fixing the label set keeps the
    # matrix 2x2 even if one class is absent.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    classes = ["0", "1"]
    df_cfm = pd.DataFrame(cm, index=classes, columns=classes)
    fig = plt.figure(figsize=(10, 7))
    cfm_plot = sns.heatmap(df_cfm, annot=True, cmap='Blues', fmt='g')
    # The target folder is not created anywhere else.
    os.makedirs('data/images', exist_ok=True)
    cfm_plot.figure.savefig(f'data/images/cm_{epoch}.png')
    # Close the figure; otherwise one figure per evaluation stays in memory.
    plt.close(fig)
def calculate_metrics(y_pred, y_true, epoch, type):
    """Print confusion matrix, F1, accuracy, precision, recall and ROC AUC.

    ``epoch`` and ``type`` are accepted for the callers' convenience and are
    not used in the report. All scikit-learn metrics take (y_true, y_pred).
    """
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    try:
        # Computed from hard class predictions, as passed in by the callers.
        roc = roc_auc_score(y_true, y_pred)
        print(f"ROC AUC: {roc}")
    except ValueError:
        # Raised when y_true contains a single class only.
        print(f"ROC AUC: not definded")
# %%
# Start training
print("###### Start GNN Model training")
best_loss = 1000
# Counts consecutive evaluations without improvement. The test set is only
# evaluated every 5th epoch, so a limit of 10 covers about 50 epochs.
early_stopping_counter = 0
for epoch in range(300):
    if early_stopping_counter > 10:
        print("Early stopping due to no improvement.")
        print([best_loss])
        # Leave the loop; without the break the message would repeat for
        # every remaining epoch.
        break

    # Training
    model.train()
    loss = train(epoch, model, train_loader, optimizer, loss_fn)
    print(f"Epoch {epoch} | Train Loss {loss}")
    #mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)

    # Testing
    model.eval()
    if epoch % 5 == 0:
        loss = test(epoch, model, test_loader, loss_fn)
        print(f"Epoch {epoch} | Test Loss {loss}")
        # mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)

        # Update best loss
        if float(loss) < best_loss:
            best_loss = loss
            # Save the currently best model
            #mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

    scheduler.step()

print(f"Finishing training with best test loss: {best_loss}")
print([best_loss])
output_folder = "model_weight"
os.makedirs(output_folder, exist_ok=True)
# os.path.join, not os.join.path (which raises AttributeError).
model_path = os.path.join(output_folder, "model.pth")  # Replace with the desired path to save the model
# Stores the whole pickled module (architecture + weights), not a state_dict.
torch.save(model, model_path)
# %%
train_optimization.py
import argparse
import os
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.data import DataLoader
from torch.nn import BCELoss
from torch.optim import Adam
from dataset_featurizer import MoleculeDataset
from model.GNN1 import GNN1
from model.GNN2 import GNN2
from model.GNN3 import GNN3
import optuna
torch.manual_seed(42)

# Create a parser object
parser = argparse.ArgumentParser(description='GNN Model Training')
# Add arguments for data paths
parser.add_argument('--test_data_path', type=str, required=True, help='Path to the test data file')
parser.add_argument('--train_oversampled', type=str, required=True, help='Path to the train oversampled data file')
# Add an argument for the GNN model selection
parser.add_argument('--model', type=str, choices=['GNN1', 'GNN2', 'GNN3'], default='GNN1', help='Choose the GNN model (GNN1, GNN2, GNN3)')
# Add an argument for the number of epochs
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
# Parse the command-line arguments
args = parser.parse_args()

# Get the data paths from the command-line arguments
test_data_path = args.test_data_path
train_oversampled_ = args.train_oversampled
# Get the selected model from the command-line arguments
selected_model = args.model
# Get the number of epochs from the command-line arguments
num_epochs = args.epochs

# Load the data (also fails fast on a wrong path)
test_data = pd.read_csv(test_data_path)
train_data = pd.read_csv(train_oversampled_)

model_folder = "model_weights"
os.makedirs(model_folder, exist_ok=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _make_dataset(csv_path, test=False):
    """Build a MoleculeDataset whose root is the folder containing the CSV."""
    root, filename = os.path.split(csv_path)
    return MoleculeDataset(root=root or ".", filename=filename, test=test)


# Create the datasets. MoleculeDataset takes (root, filename), not a
# DataFrame; the model input sizes are read from the featurized graphs.
test_dataset = _make_dataset(test_data_path, test=True)
train_dataset = _make_dataset(train_oversampled_)


def build_model():
    """Create a freshly initialised instance of the selected GNN on ``device``."""
    sample = train_dataset[0]
    feature_size = sample.x.shape[1]
    if selected_model == 'GNN1':
        net = GNN1(feature_size=feature_size)
    elif selected_model == 'GNN2':
        net = GNN2(feature_size=feature_size)
    elif selected_model == 'GNN3':
        net = GNN3(feature_size=feature_size,
                   edge_feature_size=sample.edge_attr.shape[1])
    else:
        raise ValueError('Invalid model selected')
    return net.to(device)


def _make_training_objects(net, params):
    """Return (loss, optimizer, scheduler) configured from ``params``."""
    # The models emit two class logits, so a class-weighted cross entropy is
    # used (BCELoss would need a single probability per graph).
    class_weight = torch.tensor([1.0, params["pos_weight"]],
                                dtype=torch.float32, device=device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weight)
    opt = torch.optim.SGD(net.parameters(),
                          lr=params["learning_rate"],
                          momentum=params["sgd_momentum"],
                          weight_decay=params["weight_decay"])
    sched = torch.optim.lr_scheduler.ExponentialLR(
        opt, gamma=params["scheduler_gamma"])
    return criterion, opt, sched


def _forward(net, batch):
    """Call ``net`` with the argument order of the GNN forward methods."""
    return net(batch.x.float(), batch.edge_attr.float(),
               batch.edge_index, batch.batch)


def _train_one_epoch(net, loader, criterion, opt):
    """Run a single optimisation pass over ``loader``."""
    net.train()
    for batch in loader:
        batch = batch.to(device)
        opt.zero_grad()
        loss = criterion(_forward(net, batch), batch.y)
        loss.backward()
        opt.step()


def _evaluate(net, loader):
    """Return a dict with accuracy, precision, recall, f1 and auc_roc."""
    net.eval()
    y_true, y_pred, y_score = [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = _forward(net, batch)
            y_pred.extend(out.argmax(dim=1).cpu().tolist())
            # Positive-class probability, used for the ROC AUC.
            y_score.extend(torch.softmax(out, dim=1)[:, 1].cpu().tolist())
            y_true.extend(batch.y.cpu().tolist())
    try:
        auc_roc = roc_auc_score(y_true, y_score)
    except ValueError:
        # Undefined when the labels contain a single class only.
        auc_roc = float("nan")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "auc_roc": auc_roc,
    }


# Define the objective function for Optuna optimization
def objective(trial):
    """Train a fresh model with sampled hyperparameters; return its test F1."""
    # Only training hyperparameters are tuned. The GNN classes expose no
    # constructor arguments for embedding size, heads, dropout etc., so
    # sampling "model_*" values would have no effect on the result.
    hyperparameters = {
        "batch_size": trial.suggest_categorical("batch_size", [32, 128, 64]),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True),
        "sgd_momentum": trial.suggest_float("sgd_momentum", 0.5, 0.9),
        "scheduler_gamma": trial.suggest_categorical("scheduler_gamma", [0.995, 0.9, 0.8, 0.5, 1]),
        "pos_weight": trial.suggest_categorical("pos_weight", [1.0]),
    }

    # Every trial starts from new weights and optimizer state, so trials do
    # not influence each other.
    trial_model = build_model()
    criterion, opt, sched = _make_training_objects(trial_model, hyperparameters)

    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=hyperparameters["batch_size"], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=hyperparameters["batch_size"], shuffle=False)

    # Train the model
    for _ in range(num_epochs):
        _train_one_epoch(trial_model, train_loader, criterion, opt)
        sched.step()

    # Evaluate the model
    return _evaluate(trial_model, test_loader)["f1"]


# Create the Optuna study and run the optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

# Get the best hyperparameters and metric
best_hyperparameters = study.best_params
best_metric = study.best_value

# Re-train a fresh model with the best hyperparameters
model = build_model()
loss_fn, optimizer, scheduler = _make_training_objects(model, best_hyperparameters)

# Create the best data loaders
best_train_loader = DataLoader(train_dataset, batch_size=best_hyperparameters["batch_size"], shuffle=True)
best_test_loader = DataLoader(test_dataset, batch_size=best_hyperparameters["batch_size"], shuffle=False)

# Train the model using the best hyperparameters
for epoch in range(num_epochs):
    _train_one_epoch(model, best_train_loader, loss_fn, optimizer)
    scheduler.step()

    # Evaluate the model and report the metrics of this epoch
    metrics = _evaluate(model, best_test_loader)
    accuracy = metrics["accuracy"]
    precision = metrics["precision"]
    recall = metrics["recall"]
    f1 = metrics["f1"]
    auc_roc = metrics["auc_roc"]
    print(f'Epoch {epoch + 1}: Accuracy={accuracy:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, AUC-ROC={auc_roc:.4f}')

print('Best Hyperparameters:', best_hyperparameters)
print('Best Metric:', best_metric)
inference.py 推理代码
import torch
import pandas as pd
from dataset_featurizer import MoleculeDataset
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score
# Names this script needs in addition to the imports above.
import os

from torch_geometric.data import DataLoader

# Settings matching the training script: batch size and the folder the
# trained model was written to.
NUM_GRAPHS_PER_BATCH = 256
output_folder = "model_weight"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the test dataset
test_dataset = MoleculeDataset(root="data/split_data", filename="HIV_test.csv", test=True)
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

# Load the trained model
# NOTE: the file holds a whole pickled module, so only load files you created
# yourself. Recent PyTorch releases may additionally require
# weights_only=False for such files.
model = torch.load(os.path.join(output_folder, "model.pth"), map_location=device)
model.eval()

# Create lists to store the predicted and true labels
all_preds = []
all_labels = []
all_preds_raw = []

# Perform inference on the test dataset
with torch.no_grad():
    for batch in test_loader:
        # Move the batch to the device
        batch = batch.to(device)
        # Perform forward pass
        pred = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)
        # Convert the predictions to class labels
        preds = torch.argmax(pred, dim=1)
        # Append the predicted and true labels to the lists
        all_preds.extend(preds.cpu().detach().numpy())
        all_labels.extend(batch.y.cpu().detach().numpy())
        # roc_auc_score needs one score per sample for a binary problem: use
        # the positive-class probability rather than both logits.
        all_preds_raw.extend(torch.softmax(pred, dim=1)[:, 1].cpu().detach().numpy())

# Calculate the confusion matrix
cm = confusion_matrix(all_labels, all_preds)
# Calculate the accuracy score
accuracy = accuracy_score(all_labels, all_preds)
# Calculate the ROC AUC score
roc_auc = roc_auc_score(all_labels, all_preds_raw)

# Print the confusion matrix, accuracy score, and ROC AUC score
print("Confusion Matrix:")
print(cm)
print("Accuracy Score:", accuracy)
print("ROC AUC Score:", roc_auc)
config.py
import numpy as np
# Search space: every key maps to the list of candidate values to try.
HYPERPARAMETERS = dict(
    # Optimisation settings
    batch_size=[32, 128, 64],
    learning_rate=[0.1, 0.05, 0.01, 0.001],
    weight_decay=[0.0001, 0.00001, 0.001],
    sgd_momentum=[0.9, 0.8, 0.5],
    scheduler_gamma=[0.995, 0.9, 0.8, 0.5, 1],
    pos_weight=[1.0],
    # Model settings
    model_embedding_size=[8, 16, 32, 64, 128],
    model_attention_heads=[1, 2, 3, 4],
    model_layers=[3],
    model_dropout_rate=[0.2, 0.5, 0.9],
    model_top_k_ratio=[0.2, 0.5, 0.8, 0.9],
    model_top_k_every_n=[0],
    model_dense_neurons=[16, 128, 64, 256, 32],
)
# The block below is disabled on purpose (it is a bare string, not code).
'''
BEST_PARAMETERS = {
"batch_size": [128],
"learning_rate": [0.01],
"weight_decay": [0.0001],
"sgd_momentum": [0.8],
"scheduler_gamma": [0.8],
"pos_weight": [1.3],
"model_embedding_size": [64],
"model_attention_heads": [3],
"model_layers": [4],
"model_dropout_rate": [0.2],
"model_top_k_ratio": [0.5],
"model_top_k_every_n": [1],
"model_dense_neurons": [256]
}
'''