# Source code for GraphSL.GNN.SLVAE.main

import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from GraphSL.GNN.SLVAE.model import VAE, GNN
from torch.optim import Adam
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score
from GraphSL.utils import Metric


class SLVAE_model(nn.Module):
    """
    Source Localization Variational Autoencoder (SLVAE) model combining VAE and GNN.

    Attributes:

    - device (torch.device): CUDA device when available, otherwise CPU.

    - vae (nn.Module): Variational Autoencoder module.

    - gnn (nn.Module): Graph Neural Network module.

    - reg_params (list): List of GNN parameters requiring gradients.
    """

    def __init__(self, vae: nn.Module, gnn: nn.Module):
        """
        Initialize the SLVAE_model.

        Args:

        - vae (nn.Module): Variational Autoencoder module.

        - gnn (nn.Module): Graph Neural Network module.
        """
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vae = vae.to(self.device)
        self.gnn = gnn.to(self.device)
        self.reg_params = [
            param for param in self.gnn.parameters() if param.requires_grad]

    def forward(self, seed_vec, train_mode):
        """
        Forward pass method of the SLVAE model.

        Args:

        - seed_vec (torch.Tensor): Seed vector.

        - train_mode (bool): Flag indicating whether in training mode.

        Returns:

        - seed_hat (torch.Tensor): Reconstructed seed vector produced by the VAE
          (clamped to [0, 1] when train_mode is True).

        - mean (torch.Tensor): Mean of the VAE.

        - log_var (torch.Tensor): Log variance of the VAE.

        - predictions (torch.Tensor): GNN output with its first two dimensions transposed.
        """
        # The VAE is always evaluated on the raw seed vector.
        seed_hat, mean, log_var = self.vae(seed_vec)

        if train_mode:
            # Tensor.clamp is out-of-place, so the result has to be kept for
            # the range restriction to take effect.
            seed_hat = seed_hat.clamp(0, 1)
            gnn_input = seed_hat
        else:
            # Out-of-place on purpose: during seed inference seed_vec is a leaf
            # tensor that requires grad, which must not be modified in place.
            gnn_input = seed_vec.clamp(0, 1)

        # Propagate the (clamped) seed through the GNN.
        predictions = torch.transpose(self.gnn(gnn_input), 0, 1)

        return seed_hat, mean, log_var, predictions

    def train_loss(self, x, x_hat, mean, log_var, y, y_hat):
        """
        Compute training loss.

        Args:

        - x (torch.Tensor): Seed vector.

        - x_hat (torch.Tensor): Reconstructed seed tensor.

        - mean (torch.Tensor): Mean of the VAE.

        - log_var (torch.Tensor): Log variance of the VAE.

        - y (torch.Tensor): Diffusion vector.

        - y_hat (torch.Tensor): Predicted Diffusion vector.

        Returns:

        - total_loss (torch.Tensor): Total loss is the sum of prediction loss (MSE),
          reconstruction loss (binary cross-entropy) and KL divergence.
        """
        forward_loss = F.mse_loss(y_hat, y)
        reproduction_loss = F.binary_cross_entropy(x_hat, x, reduction='mean')
        # KL divergence between N(mean, exp(log_var)) and N(0, 1), summed over
        # all elements.
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        return forward_loss + reproduction_loss + KLD

    def infer_loss(self, y_true, y_hat, x_hat, train_pred):
        """
        Compute inference loss.

        The loss is the MSE between y_hat and y_true minus the log-sum-exp of the
        batch-normalized Bernoulli log-likelihoods of x_hat[0] under every row of
        train_pred.

        Args:

        - y_true (torch.Tensor): True label tensor.

        - y_hat (torch.Tensor): Predicted label tensor.

        - x_hat (torch.Tensor): Reconstructed input tensor; only x_hat[0] is used.

        - train_pred (torch.Tensor): Predicted tensor during training, one row per
          training sample. Only the first len(x_hat[0]) entries of each row are used.

        Returns:

        - total_loss (torch.Tensor): Total loss tensor with a single element.
        """
        epsilon = 1e-8  # keeps the logarithms finite at probabilities 0 and 1

        y_hat = y_hat.to(self.device)
        y_true = y_true.to(self.device)
        forward_loss = F.mse_loss(y_hat, y_true)

        # Vectorized Bernoulli log-likelihood of x_hat[0] under each training
        # row; the result has one row per training sample and a single column.
        x_first = x_hat[0].to(self.device)
        num_terms = x_first.shape[0]
        probs = train_pred.to(self.device)[:, :num_terms]
        log_pmf = (
            x_first * torch.log(probs + epsilon)
            + (1 - x_first) * torch.log(1 - probs + epsilon)
        ).sum(dim=1, keepdim=True)

        # A freshly built BatchNorm1d is in training mode, so it standardizes
        # with the statistics of this batch (needs more than one training row).
        batch_norm = nn.BatchNorm1d(1, affine=False).to(self.device)
        log_pmf = batch_norm(log_pmf.float())

        pmf_max = torch.max(log_pmf)
        pdf_sum = pmf_max + torch.logsumexp(log_pmf - pmf_max, dim=0)

        return forward_loss - pdf_sum
class SLVAE:
    """
    Implement the Source Localization Variational Autoencoder (SLVAE) model.

    Ling C, Jiang J, Wang J, et al. Source localization of graph diffusion via variational autoencoders for graph inverse problems[C]//Proceedings of the 28th ACM SIGKDD conference on knowledge discovery and data mining. 2022: 1010-1020.
    """

    def __init__(self):
        """
        Initialize the SLVAE model.
        """
        # train() and infer() refresh this attribute on every call.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @staticmethod
    def _to_numpy(tensor):
        """Detach a tensor and return it as a flat NumPy array on the CPU."""
        return tensor.detach().cpu().numpy().reshape(-1)

    def _init_seed_estimates(self, slvae_model, seed_vae_train, num_sample):
        """Create one trainable seed estimate per sample from the mean of seed_vae_train."""
        seed_mean = torch.mean(seed_vae_train, 0).unsqueeze(-1).to(self.device)
        seed_infer = []
        for _ in range(num_sample):
            # The model parameters are frozen by the caller, so the output is a
            # leaf tensor that can be turned into an optimizable variable.
            seed_vec_hat, _, _, _ = slvae_model(seed_mean, False)
            seed_vec_hat.requires_grad = True
            seed_infer.append(seed_vec_hat)
        return seed_infer

    def _optimize_seeds(
            self,
            slvae_model,
            dataset,
            seed_infer,
            seed_vae_train,
            lr,
            num_epoch,
            print_epoch,
            weight_decay=0):
        """Optimize the seed estimates in place with Adam against the inference loss."""
        optimizer = Adam(seed_infer, lr=lr, weight_decay=weight_decay)
        num_sample = len(dataset)
        for epoch in range(num_epoch):
            overall_loss = 0
            for i, influ_mat in enumerate(dataset):
                # The last column of every sample is the diffusion vector.
                influ_vec = influ_mat[:, -1].unsqueeze(-1).float()
                optimizer.zero_grad()
                seed_vec_hat, _, _, influ_vec_hat = slvae_model(
                    seed_infer[i], False)
                loss = slvae_model.infer_loss(
                    influ_vec, influ_vec_hat, seed_vec_hat, seed_vae_train)
                overall_loss += loss.item()
                loss.backward()
                optimizer.step()
            average_loss = overall_loss / num_sample
            if epoch % print_epoch == 0:
                print(f"Epoch [{epoch}/{num_epoch}], obj = {average_loss:.4f}")

    def train(
            self,
            adj,
            train_dataset,
            num_thres=10,
            lr=1e-4,
            weight_decay=1e-4,
            num_epoch=100,
            print_epoch=10,
            random_seed=0):
        """
        Train the SLVAE model.

        Args:

        - adj (scipy.sparse.csr_matrix): The adjacency matrix of the graph.

        - train_dataset (torch.utils.data.dataset.Subset): the training dataset (number of simulations * number of graph nodes * 2 (the first column is seed vector and the second column is diffusion vector)).

        - num_thres (int): Number of threshold values to try.

        - lr (float): Learning rate, used for both model training and seed inference.

        - weight_decay (float): Weight decay of the Adam optimizer used for seed inference on the training set.

        - num_epoch (int): Number of training epochs; seed inference runs for int(num_epoch / 10) epochs.

        - print_epoch (int): Number of epochs every time to print loss.

        - random_seed (int): Random seed.

        Returns:

        - slvae_model (SLVAE_model): Trained SLVAE model (its parameters are frozen).

        - seed_vae_train (torch.Tensor): The VAE reconstructions of the training seed vectors, which are used to initialize seed vector in the test set.

        - opt_thres (float): Optimal threshold.

        - train_auc (float): Train AUC.

        - opt_f1 (float): Optimal F1 score.

        - pred (numpy.ndarray): Predicted seed vector of the training set, every column is the prediction of every simulation. It is used to adjust thres_list.

        Raises:

        - ValueError: If train_dataset is empty.

        Example:

        import os

        curr_dir = os.getcwd()

        from GraphSL.utils import load_dataset, diffusion_generation, split_dataset

        from GraphSL.GNN.SLVAE.main import SLVAE

        data_name = 'karate'

        graph = load_dataset(data_name, data_dir=curr_dir)

        dataset = diffusion_generation(graph=graph, infect_prob=0.3, diff_type='IC', sim_num=100, seed_ratio=0.1)

        adj, train_dataset, test_dataset = split_dataset(dataset)

        slvae = SLVAE()

        slvae_model, seed_vae_train, thres, auc, f1, pred = slvae.train(adj, train_dataset)

        print("SLVAE:")

        print(f"train auc: {auc:.3f}, train f1: {f1:.3f}")
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        train_num = len(train_dataset)
        if train_num == 0:
            raise ValueError("train_dataset must contain at least one simulation")

        # Convert the sparse adjacency matrix into a dense torch tensor.
        num_node = adj.shape[0]
        adj_coo = adj.tocoo()
        indices = torch.LongTensor(np.vstack((adj_coo.row, adj_coo.col)))
        values = torch.FloatTensor(adj_coo.data)
        adj_matrix = torch.sparse_coo_tensor(
            indices, values, torch.Size(adj.shape)).to_dense()

        # Seed before the modules are built so their initialization is
        # reproducible.
        torch.manual_seed(random_seed)
        vae = VAE().to(self.device)
        gnn = GNN(adj_matrix=adj_matrix).to(self.device)
        slvae_model = SLVAE_model(vae, gnn).to(self.device)
        optimizer = Adam(slvae_model.parameters(), lr=lr)

        # Train SLVAE
        print("train SLVAE:")
        slvae_model.train()
        for epoch in range(num_epoch):
            overall_loss = 0
            for influ_mat in train_dataset:
                seed_vec = influ_mat[:, 0].to(self.device).unsqueeze(-1).float()
                influ_vec = influ_mat[:, -1].to(self.device).unsqueeze(-1).float()
                optimizer.zero_grad()
                seed_vec_hat, mean, log_var, influ_vec_hat = slvae_model(
                    seed_vec, True)
                loss = slvae_model.train_loss(
                    seed_vec, seed_vec_hat, mean, log_var, influ_vec, influ_vec_hat)
                overall_loss += loss.item()
                loss.backward()
                optimizer.step()
            average_loss = overall_loss / train_num
            if epoch % print_epoch == 0:
                print(f"Epoch [{epoch}/{num_epoch}], loss = {average_loss:.3f}")

        # Evaluation: freeze the model and optimize only the seed estimates.
        print("infer seed from training set:")
        slvae_model.eval()
        for param in slvae_model.parameters():
            param.requires_grad = False

        seed_vae_train = torch.zeros(size=(train_num, num_node))
        for i, influ_mat in enumerate(train_dataset):
            # The input has to live on the same device as the VAE; the result
            # is stored in a CPU tensor.
            seed_vec = influ_mat[:, 0].unsqueeze(-1).float().to(self.device)
            seed_vae_train[i, :] = slvae_model.vae(seed_vec)[0].squeeze(-1).cpu()

        seed_infer = self._init_seed_estimates(
            slvae_model, seed_vae_train, train_num)
        self._optimize_seeds(
            slvae_model,
            train_dataset,
            seed_infer,
            seed_vae_train,
            lr=lr,
            num_epoch=int(num_epoch / 10),
            print_epoch=print_epoch,
            weight_decay=weight_decay)

        # Convert labels and scores once; they are reused for every threshold.
        seed_true = [self._to_numpy(influ_mat[:, 0]) for influ_mat in train_dataset]
        seed_score = [self._to_numpy(seed) for seed in seed_infer]

        train_auc = sum(
            roc_auc_score(label, score)
            for label, score in zip(seed_true, seed_score)) / train_num

        # Candidate thresholds are evenly spaced strictly inside the range of
        # predicted scores.
        pred_min = min(score.min() for score in seed_score)
        pred_max = max(score.max() for score in seed_score)
        thres_list = np.linspace(
            pred_min, pred_max, num=num_thres + 2)[1:-1].tolist()

        opt_f1 = -1
        opt_thres = -1
        for thres in thres_list:
            train_f1 = sum(
                f1_score(label, score >= thres, zero_division=1)
                for label, score in zip(seed_true, seed_score)) / train_num
            print(f"thres = {thres:.3f}, train_f1 = {train_f1:.3f}")
            if train_f1 > opt_f1:
                opt_f1 = train_f1
                opt_thres = thres

        pred = np.zeros((num_node, train_num))
        for i, score in enumerate(seed_score):
            pred[:, i] = score

        return slvae_model, seed_vae_train, opt_thres, train_auc, opt_f1, pred

    def infer(
            self,
            test_dataset,
            slvae_model,
            seed_vae_train,
            thres,
            lr=0.0001,
            num_epoch=10,
            print_epoch=1):
        """
        Infer using the SLVAE model.

        Args:

        - test_dataset (torch.utils.data.dataset.Subset): the test dataset (number of simulations * number of graph nodes * 2 (the first column is seed vector and the second column is diffusion vector)).

        - slvae_model (SLVAE_model): Trained SLVAE model.

        - seed_vae_train (torch.Tensor): The VAE reconstructions of the training seed vectors, which are used to initialize seed vector in the test set.

        - thres (float): Threshold value.

        - lr (float): Learning rate.

        - num_epoch (int): Number of epochs.

        - print_epoch (int): Number of epochs every time to print loss.

        Returns:

        - Metric: Evaluation metric containing accuracy, precision, recall, F1 score, and AUC.

        Raises:

        - ValueError: If test_dataset is empty.

        Example:

        import os

        curr_dir = os.getcwd()

        from GraphSL.utils import load_dataset, diffusion_generation, split_dataset

        from GraphSL.GNN.SLVAE.main import SLVAE

        data_name = 'karate'

        graph = load_dataset(data_name, data_dir=curr_dir)

        dataset = diffusion_generation(graph=graph, infect_prob=0.3, diff_type='IC', sim_num=100, seed_ratio=0.1)

        adj, train_dataset, test_dataset = split_dataset(dataset)

        slvae = SLVAE()

        slvae_model, seed_vae_train, thres, auc, f1, pred = slvae.train(adj, train_dataset)

        print("SLVAE:")

        print(f"train auc: {auc:.3f}, train f1: {f1:.3f}")

        metric = slvae.infer(test_dataset, slvae_model, seed_vae_train, thres)

        print(f"test acc: {metric.acc:.3f}, test pr: {metric.pr:.3f}, test re: {metric.re:.3f}, test f1: {metric.f1:.3f}, test auc: {metric.auc:.3f}")
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        test_num = len(test_dataset)
        if test_num == 0:
            raise ValueError("test_dataset must contain at least one simulation")

        # Freeze the model: only the seed estimates are optimized.
        slvae_model = slvae_model.to(self.device)
        slvae_model.eval()
        for param in slvae_model.parameters():
            param.requires_grad = False

        seed_infer = self._init_seed_estimates(
            slvae_model, seed_vae_train, test_num)

        print("infer seed from test set:")
        self._optimize_seeds(
            slvae_model,
            test_dataset,
            seed_infer,
            seed_vae_train,
            lr=lr,
            num_epoch=num_epoch,
            print_epoch=print_epoch)

        test_acc = 0
        test_pr = 0
        test_re = 0
        test_f1 = 0
        test_auc = 0
        for influ_mat, seed in zip(test_dataset, seed_infer):
            seed_vec = self._to_numpy(influ_mat[:, 0])
            seed_pred = self._to_numpy(seed)
            seed_label = seed_pred >= thres
            test_acc += accuracy_score(seed_vec, seed_label)
            test_pr += precision_score(seed_vec, seed_label, zero_division=1)
            test_re += recall_score(seed_vec, seed_label, zero_division=1)
            test_f1 += f1_score(seed_vec, seed_label, zero_division=1)
            test_auc += roc_auc_score(seed_vec, seed_pred)

        # Report per-simulation averages.
        return Metric(
            test_acc / test_num,
            test_pr / test_num,
            test_re / test_num,
            test_f1 / test_num,
            test_auc / test_num)