Source code for GraphSL.GNN.GCNSI.model

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn.functional as F


class GCNConv(MessagePassing):
    """
    Define a Graph Convolutional Network (GCN) layer.

    Each node's new representation is the sum, over its neighbours and
    itself (self-loops are added), of the linearly transformed neighbour
    features scaled by 1 / sqrt(deg(i) * deg(j)).
    """

    def __init__(self, in_channels, out_channels):
        """
        Initialize the GCNConv layer with input and output channel dimensions.

        Args:
        - in_channels (int): Number of input channels.
        - out_channels (int): Number of output channels.
        """
        # Messages arriving at a node are combined by summation.
        super(GCNConv, self).__init__(aggr='add')
        # Linear transformation applied to every node before propagation.
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        """
        Perform the forward pass of the GCNConv layer.

        Args:
        - x (torch.Tensor): Input node features, one row per node.
        - edge_index (torch.Tensor): Edge indices representing connectivity
          (a pair of index rows, unpacked below as ``row, col``).

        Returns:
        - torch.Tensor: Tensor after the GCN layer computation.
        """
        num_nodes = x.size(0)

        # Step 1: add self-loops so each node also keeps its own features.
        # This also guarantees every node has degree >= 1, so the
        # inverse square root below never divides by zero.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Step 2: multiply with weights.
        x = self.lin(x)

        # Step 3: symmetric normalization coefficient per edge.
        row, col = edge_index
        deg = degree(row, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4: propagate; `norm` is consumed by `message` below.
        return self.propagate(
            edge_index, size=(num_nodes, num_nodes), x=x, norm=norm)

    def message(self, x_j, norm):
        """
        Scale each neighbour message by its edge normalization coefficient.

        Without this override, MessagePassing's default ``message`` returns
        ``x_j`` unchanged and the ``norm`` computed in ``forward`` is ignored.

        Args:
        - x_j (torch.Tensor): Transformed features of the source node of
          each edge, one row per edge.
        - norm (torch.Tensor): One normalization coefficient per edge.

        Returns:
        - torch.Tensor: The per-edge messages, scaled row-wise by ``norm``.
        """
        return norm.view(-1, 1) * x_j
class GCNSI_model(torch.nn.Module):
    """
    Define the model of Graph Convolutional Networks based Source Identification (GCNSI).

    Architecture: GCNConv -> ReLU -> dropout -> GCNConv -> Linear, applied
    to every node. The output holds raw (un-softmaxed) scores per node.
    """

    def __init__(self, in_channels=4, hidden_channels=32, out_channels=2,
                 dropout=0.5):
        """
        Initialize the GCNSI model. The defaults reproduce the original
        fixed architecture (4 -> 32 -> 32 -> 2, dropout 0.5).

        Args:
        - in_channels (int): Number of input features per node.
        - hidden_channels (int): Width of both GCN layers.
        - out_channels (int): Number of output scores per node.
        - dropout (float): Dropout probability applied after the first GCN
          layer while training.
        """
        super(GCNSI_model, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        # Final per-node projection to the output scores.
        self.fc = torch.nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        """
        Performs the forward pass of the GCNSI model.

        Args:
        - x (numpy.ndarray or torch.Tensor): The input features augmented
          by LPSI, one row per node.
        - edge_index (torch.Tensor): Edge indices representing connectivity.

        Returns:
        - x (torch.Tensor): Raw scores of shape (num_nodes, out_channels) used
          to identify source nodes; no softmax is applied.
        """
        # as_tensor avoids an unnecessary copy (and PyTorch's copy-construct
        # warning) when x is already a tensor. It also places the input on the
        # same device as the model parameters, so a model moved to GPU works.
        x = torch.as_tensor(x, dtype=torch.float, device=self.fc.weight.device)
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return self.fc(x)