import torch
from GraphSL.GNN.IVGD.correction import correction
[docs]
class validity_net(torch.nn.Module):
"""
Validity-aware layers.
Attributes:
- number_layer (int): Number of layers.
- alpha1, alpha2, alpha3, alpha4, alpha5 (float): Alpha values for each layer.
- tau1, tau2, tau3, tau4, tau5 (float): Tau values for each layer.
- net1, net2, net3, net4, net5 (correction): Correction layer.
- rho1, rho2, rho3, rho4, rho5 (float): Rho values for each layer.
"""
def __init__(self, alpha, tau, rho):
"""
Initialize the validity_net model.
Args:
- alpha (float): Alpha value.
- tau (float): Tau value.
- rho (float): Rho value.
"""
super(validity_net, self).__init__()
self.number_layer = 5
self.alpha1 = alpha
self.alpha2 = alpha
self.alpha3 = alpha
self.alpha4 = alpha
self.alpha5 = alpha
self.tau1 = tau
self.tau2 = tau
self.tau3 = tau
self.tau4 = tau
self.tau5 = tau
self.net1 = correction()
self.net2 = correction()
self.net3 = correction()
self.net4 = correction()
self.net5 = correction()
self.rho1 = rho
self.rho2 = rho
self.rho3 = rho
self.rho4 = rho
self.rho5 = rho
[docs]
def forward(self, x, label, lamda):
"""
Forward pass of the validity-ware layer.
Args:
- x (torch.Tensor): corrected prediction of seed vector.
- label (torch.Tensor): Source Label.
- lamda (torch.Tensor): Lambda tensor.
Returns:
- Tensor: prediction subject to the validity constraint.
"""
sum = torch.sum(label)
label = torch.cat((1 - label, label), dim=1)
x = torch.cat((1 - x, x), dim=1)
prob = x[:, 1].unsqueeze(-1)
x = (self.tau1 * self.net1(prob) - label * torch.softmax(x, dim=1) / label.shape[0] - lamda
- self.rho1 * (torch.sum(x) - sum) + self.alpha1 * x) / (
self.tau1 + self.alpha1)
prob = x[:, 1].unsqueeze(-1)
lamda = lamda + self.rho1 * (torch.sum(prob) - sum)
x = (self.tau2 * self.net2(prob) - label * torch.softmax(x, dim=1) / label.shape[0] - lamda
- self.rho2 * (torch.sum(x) - sum) + self.alpha2 * x) / (
self.tau2 + self.alpha2)
prob = x[:, 1].unsqueeze(-1)
lamda = lamda + self.rho2 * (torch.sum(prob) - sum)
x = (self.tau3 * self.net3(prob) - label * torch.softmax(x, dim=1) / label.shape[0] - lamda
- self.rho3 * (torch.sum(x) - sum) + self.alpha3 * x) / (
self.tau3 + self.alpha3)
prob = x[:, 1].unsqueeze(-1)
lamda = lamda + self.rho3 * (torch.sum(prob) - sum)
x = (self.tau4 * self.net4(prob) - label * torch.softmax(x, dim=1) / label.shape[0] - lamda
- self.rho4 * (torch.sum(x) - sum) + self.alpha4 * x) / (
self.tau4 + self.alpha4)
prob = x[:, 1].unsqueeze(-1)
lamda = lamda + self.rho4 * (torch.sum(prob) - sum)
x = (self.tau5 * self.net5(prob) - label * torch.softmax(x, dim=1) / label.shape[0] - lamda
- self.rho5 * (torch.sum(x) - sum) + self.alpha5 * x) / (
self.tau5 + self.alpha5)
return x
[docs]
def correction(self, pred):
"""
Impose validity constraint on predictions.
Args:
- pred (torch.Tensor): Predictions tensor.
Returns:
- Tensor: predictions tensor after passing validity-ware layers.
"""
temp = pred[:, 0].unsqueeze(-1)
return (self.net1(temp) + self.net2(temp) + self.net3(temp) +
self.net4(temp) + self.net5(temp)) / self.number_layer