GraphSL.GNN.GCNSI package
Submodules
GraphSL.GNN.GCNSI.main module
- class GraphSL.GNN.GCNSI.main.GCNSI
Bases: object
Implement Graph Convolutional Network-based Source Identification (GCNSI).
Dong, Ming, et al. “Multiple rumor source detection with graph convolutional networks.” Proceedings of the 28th ACM international conference on information and knowledge management. 2019.
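The alpha parameter of train and test is described as the fraction of label information a node receives from its neighbors. That matches the weighting used in LPSI-style label propagation, and the sketch below assumes alpha plays that role. It also assumes a dense NumPy adjacency matrix (a scipy.sparse.csr_matrix can be converted with .toarray()), a +1/-1 encoding of the diffusion vector, and the symmetric normalization S = D^(-1/2) A D^(-1/2). Under those assumptions it computes the closed-form propagation (1 - alpha)(I - alpha S)^(-1) Y. The four-column stack built by the hypothetical helper gcnsi_style_features is illustrative only; GraphSL's own feature construction may differ.

```python
import numpy as np


def label_propagation(adj: np.ndarray, y: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    """Closed form of the iteration G <- alpha * S @ G + (1 - alpha) * y,
    i.e. (1 - alpha) * (I - alpha * S)^{-1} @ y with S = D^{-1/2} A D^{-1/2}."""
    deg = adj.sum(axis=1).astype(float)
    inv_sqrt = np.zeros_like(deg)
    nonzero = deg > 0                      # isolated nodes keep a zero row/column
    inv_sqrt[nonzero] = 1.0 / np.sqrt(deg[nonzero])
    s = inv_sqrt[:, None] * adj * inv_sqrt[None, :]
    n = adj.shape[0]
    # Solve the linear system rather than forming the matrix inverse explicitly.
    return (1.0 - alpha) * np.linalg.solve(np.eye(n) - alpha * s, y)


def gcnsi_style_features(adj: np.ndarray, diffusion: np.ndarray, alpha: float = 0.01) -> np.ndarray:
    """Hypothetical feature builder: returns a (num_nodes, 4) matrix."""
    y = np.where(diffusion > 0, 1.0, -1.0)   # +1 infected, -1 not infected
    y_pos = np.clip(y, 0.0, None)             # keep only the infected entries
    y_neg = np.clip(y, None, 0.0)             # keep only the uninfected entries
    return np.stack(
        [
            y,
            label_propagation(adj, y, alpha),
            label_propagation(adj, y_pos, alpha),
            label_propagation(adj, y_neg, alpha),
        ],
        axis=1,
    )


# Path graph 0-1-2-3 where nodes 0 and 1 are infected.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
features = gcnsi_style_features(adj, np.array([1, 1, 0, 0]), alpha=0.5)
print(features.shape)  # (4, 4)
```

Because (I - alpha S) has eigenvalues of at least 1 - alpha, the system is always solvable for 0 <= alpha < 1; with the small default alpha=0.01, each node keeps almost all of its own label.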
- test(adj, test_dataset, gcnsi_model, thres, alpha=0.01)
Test the GCNSI model.
Args:
adj (scipy.sparse.csr_matrix): Adjacency matrix of the graph.
test_dataset (torch.utils.data.dataset.Subset): Test dataset containing simulations and graph nodes.
gcnsi_model (GCNSI_model): Trained GCNSI model.
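thres (float): Decision threshold applied to the model's predicted scores to classify nodes as sources (for example, opt_thres returned by train).
alpha (float): Fraction of label information that a node receives from its neighbors (between 0 and 1).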
Returns:
metric (Metric): Evaluation metric containing accuracy, precision, recall, F1 score, and AUC score.
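test reports accuracy, precision, recall, F1 and AUC. The sketch below shows one conventional way to obtain those five numbers from per-node scores and a threshold such as thres. binary_metrics and roc_auc are hypothetical helpers, not GraphSL's Metric class, and counting a score exactly equal to the threshold as a source is an assumption. AUC is computed as the probability that a randomly chosen source outscores a randomly chosen non-source (ties count half), which equals the area under the ROC curve.

```python
import numpy as np


def roc_auc(y_true, scores) -> float:
    """Probability that a random source outscores a random non-source (ties count 0.5)."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    if pos.size == 0 or neg.size == 0:
        return float("nan")  # AUC is undefined with only one class present
    diff = pos[:, None] - neg[None, :]    # all (source, non-source) pairs
    return float((np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (pos.size * neg.size))


def binary_metrics(y_true, scores, thres: float) -> dict:
    """Hypothetical helper: threshold scores, then compute the five metrics."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    y_pred = scores >= thres              # assumption: ties at thres count as sources
    tp = int(np.sum(y_pred & y_true))
    fp = int(np.sum(y_pred & ~y_true))
    fn = int(np.sum(~y_pred & y_true))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": float(np.mean(y_pred == y_true)),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": roc_auc(y_true, scores),
    }


# Four nodes, two true sources (nodes 0 and 2).
print(binary_metrics([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6], thres=0.5))
# {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5, 'auc': 0.75}
```

Note that AUC does not depend on thres, while the other four metrics do. The pairwise AUC uses memory proportional to (number of sources) x (number of non-sources), which is fine for small graphs; a rank-based formula scales better.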
- train(adj, train_dataset, alpha=0.01, num_thres=10, lr=0.001, num_epoch=100, print_epoch=10, random_seed=0)
Train the GCNSI model.
Args:
adj (scipy.sparse.csr_matrix): Adjacency matrix of the graph.
train_dataset (torch.utils.data.dataset.Subset): Training dataset of shape (number of simulations, number of graph nodes, 2); for each simulation, the first column is the seed vector and the second column is the diffusion vector.
alpha (float): Fraction of label information that a node receives from its neighbors (between 0 and 1).
num_thres (int): Number of threshold values to try.
lr (float): Learning rate.
num_epoch (int): Number of training epochs.
print_epoch (int): Print the training loss every print_epoch epochs.
random_seed (int): Random seed.
Returns:
gcnsi_model (GCNSI_model): Trained GCNSI model.
opt_thres (float): Optimal threshold value.
train_auc (float): Training AUC score.
opt_f1 (float): Optimal F1 score.
opt_pred (numpy.ndarray): Predicted seed vectors for the training set, with one column per simulation. It is used to adjust thres_list (the candidate threshold values).
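train tries num_thres threshold values and reports the best one as opt_thres, together with opt_f1 and opt_pred. The sketch below runs such a search on toy data laid out as described for train_dataset (simulations x nodes x 2, seed vector in column 0 and diffusion vector in column 1), transposed so that each column is one simulation, as in opt_pred. It assumes evenly spaced candidates between the minimum and maximum score and an F1 score pooled over all nodes and simulations. search_threshold, f1_score and the stand-in scores are hypothetical; GraphSL's actual candidate grid and averaging may differ.

```python
import numpy as np


def f1_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 = 2TP / (2TP + FP + FN) over boolean arrays of any shape."""
    tp = np.sum(y_pred & y_true)
    fp = np.sum(y_pred & ~y_true)
    fn = np.sum(~y_pred & y_true)
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


def search_threshold(scores: np.ndarray, seeds: np.ndarray, num_thres: int = 10):
    """Hypothetical grid search. scores and seeds are (num_nodes, num_sims)
    arrays with one column per simulation; returns (opt_thres, opt_f1, opt_pred)."""
    seeds = seeds.astype(bool)
    candidates = np.linspace(scores.min(), scores.max(), num_thres)
    opt_thres, opt_f1, opt_pred = None, -1.0, None
    for thres in candidates:
        pred = scores >= thres
        f1 = f1_score(seeds, pred)
        if f1 > opt_f1:                      # keep the first threshold reaching the best F1
            opt_thres, opt_f1, opt_pred = float(thres), f1, pred.astype(int)
    return opt_thres, opt_f1, opt_pred


# Toy data in the documented layout: (num_sims, num_nodes, 2).
num_sims, num_nodes = 3, 5
data = np.zeros((num_sims, num_nodes, 2))
data[:, 0, 0] = 1            # column 0: node 0 is the seed in every simulation
data[:, :2, 1] = 1           # column 1: nodes 0 and 1 end up infected
seeds = data[:, :, 0].T      # (num_nodes, num_sims): one column per simulation
scores = 0.1 + 0.8 * seeds   # stand-in for model outputs
opt_thres, opt_f1, opt_pred = search_threshold(scores, seeds, num_thres=10)
print(round(opt_thres, 3), opt_f1)  # 0.189 1.0
```

Tying the candidate grid to the observed score range keeps num_thres meaningful regardless of how the model's outputs are scaled; inspecting opt_pred shows whether that range, and hence thres_list, should be adjusted.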
GraphSL.GNN.GCNSI.model module
- class GraphSL.GNN.GCNSI.model.GCNConv(in_channels, out_channels)
Bases: MessagePassing
Define a Graph Convolutional Network (GCN) layer.
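GCNConv subclasses MessagePassing, the message-passing base class from PyTorch Geometric. To illustrate what a GCN layer computes, the NumPy sketch below implements the standard propagation rule of Kipf and Welling, H' = D^(-1/2) (A + I) D^(-1/2) X W with D the degree matrix of A + I, in two ways: densely, and as per-edge messages summed at each target node, which is the view MessagePassing takes. Whether GraphSL's GCNConv adds self-loops, a bias term, or uses exactly this normalization is an assumption here; gcn_dense and gcn_message_passing are hypothetical names.

```python
import numpy as np


def gcn_dense(adj: np.ndarray, x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """H' = D^-1/2 (A + I) D^-1/2 X W, with D the degree matrix of A + I."""
    a_hat = adj + np.eye(adj.shape[0])             # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))  # degrees are >= 1
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return norm @ (x @ weight)                     # linear transform, then aggregate


def gcn_message_passing(edge_index: np.ndarray, x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Same layer, written as per-edge messages summed at each target node.
    edge_index is a (2, E) array of (source, target) pairs listing both
    directions of every undirected edge."""
    num_nodes = x.shape[0]
    loops = np.arange(num_nodes)
    src = np.concatenate([edge_index[0], loops])   # self-loops as extra edges
    dst = np.concatenate([edge_index[1], loops])
    deg = np.bincount(dst, minlength=num_nodes).astype(float)
    coef = 1.0 / np.sqrt(deg[src] * deg[dst])      # symmetric normalization per edge
    h = x @ weight
    out = np.zeros((num_nodes, weight.shape[1]))
    np.add.at(out, dst, coef[:, None] * h[src])    # sum messages at target nodes
    return out


# Path graph 0-1-2 with 2 input features and 3 output channels.
rng = np.random.default_rng(0)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
x = rng.normal(size=(3, 2))
w = rng.normal(size=(2, 3))
print(np.allclose(gcn_dense(adj, x, w), gcn_message_passing(edge_index, x, w)))  # True
```

The edge-wise form never materializes the dense normalized adjacency, which is why message-passing implementations scale to sparse graphs where the dense form would not.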