import torch
import torch.nn as nn
import torch.nn.init as init


class ChebConv(nn.Module):
def __init__(self, in_c, out_c, K, bias=True, normalize=True):
"""
ChebNet conv
:param in_c: input channels
:param out_c: output channels
:param K: the order of Chebyshev Polynomial
:param bias: if use bias
:param normalize: if use norm
"""
super(ChebConv, self).__init__()
self.normalize = normalize
self.weight = nn.Parameter(torch.Tensor(K + 1, 1, in_c, out_c)) # [K+1, 1, in_c, out_c]
init.xavier_normal_(self.weight)
if bias:
self.bias = nn.Parameter(torch.Tensor(1, 1, out_c))
init.zeros_(self.bias)
else:
self.register_parameter("bias", None)
        self.K = K + 1  # number of Chebyshev terms T_0 .. T_K

def forward(self, inputs, graph):
"""
        :param inputs: the input data, [B, N, C]
:param graph: the graph structure, [N, N]
:return: convolution result, [B, N, D]
"""
L = ChebConv.get_laplacian(graph, self.normalize) # [N, N]
        mul_L = self.cheb_polynomial(L).unsqueeze(1)  # [K+1, 1, N, N]
        result = torch.matmul(mul_L, inputs)          # [K+1, B, N, C]
        result = torch.matmul(result, self.weight)    # [K+1, B, N, D]
        result = torch.sum(result, dim=0)             # [B, N, D], sum over the Chebyshev orders
        if self.bias is not None:                     # bias is None when bias=False
            result = result + self.bias
        return result

def cheb_polynomial(self, laplacian):
"""
Compute the Chebyshev Polynomial, according to the graph laplacian
:param laplacian: the multi order Chebyshev laplacian, [K, N, N]
:return:
"""
        N = laplacian.size(0)  # number of nodes
        multi_order_laplacian = torch.zeros([self.K, N, N], device=laplacian.device, dtype=torch.float)  # [K+1, N, N]
        multi_order_laplacian[0] = torch.eye(N, device=laplacian.device, dtype=torch.float)  # T_0(L) = I
        if self.K >= 2:
            multi_order_laplacian[1] = laplacian  # T_1(L) = L
        for k in range(2, self.K):
            # Chebyshev recurrence: T_k(L) = 2 * L * T_{k-1}(L) - T_{k-2}(L)
            multi_order_laplacian[k] = (2 * torch.mm(laplacian, multi_order_laplacian[k - 1])
                                        - multi_order_laplacian[k - 2])
        return multi_order_laplacian

@staticmethod
def get_laplacian(graph, normalize):
"""
        Compute the Laplacian of the graph.
        :param graph: the graph structure (adjacency matrix) without self-loops, [N, N]
        :param normalize: whether to use the normalized Laplacian
        :return: the graph Laplacian, [N, N]
"""
        if normalize:
            # Normalized Laplacian L = I - D^(-1/2) A D^(-1/2); note that an
            # isolated node (zero degree) would produce inf entries here.
            D = torch.diag(torch.sum(graph, dim=-1) ** (-1 / 2))
            L = torch.eye(graph.size(0), device=graph.device, dtype=graph.dtype) - torch.mm(torch.mm(D, graph), D)
        else:
            # Combinatorial Laplacian L = D - A
            D = torch.diag(torch.sum(graph, dim=-1))
            L = D - graph
return L
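

def _chebconv_smoke_test():
    """A minimal usage sketch, not part of the original model: it runs ChebConv
    on a toy 3-node path graph to check the output shape. The function name and
    all sizes here are illustrative assumptions."""
    conv = ChebConv(in_c=4, out_c=8, K=2)
    adj = torch.tensor([[0., 1., 0.],
                        [1., 0., 1.],
                        [0., 1., 0.]])  # adjacency without self-loops, [N, N]
    x = torch.randn(2, 3, 4)            # dummy input, [B, N, C]
    out = conv(x, adj)                  # [B, N, D] = [2, 3, 8]
    assert out.shape == (2, 3, 8)
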
class ChebNet(nn.Module):
def __init__(self, in_c, hid_c, out_c, K):
"""
:param in_c: int, number of input channels.
:param hid_c: int, number of hidden channels.
:param out_c: int, number of output channels.
        :param K: int, order of the Chebyshev polynomial.
"""
super(ChebNet, self).__init__()
self.conv1 = ChebConv(in_c=in_c, out_c=hid_c, K=K)
self.conv2 = ChebConv(in_c=hid_c, out_c=out_c, K=K)
        self.act = nn.ReLU()

def forward(self, data, device):
graph_data = data["graph"].to(device)[0] # [N, N]
flow_x = data["flow_x"].to(device) # [B, N, H, D]
B, N = flow_x.size(0), flow_x.size(1)
flow_x = flow_x.view(B, N, -1) # [B, N, H*D]
output_1 = self.act(self.conv1(flow_x, graph_data))
output_2 = self.act(self.conv2(output_1, graph_data))
        return output_2.unsqueeze(2)  # [B, N, 1, D], restore the time dimension
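
if __name__ == "__main__":
    # A hedged usage sketch, not part of the original file: the dict keys
    # "graph"/"flow_x" and the [B, N, H, D] layout follow forward() above,
    # while the concrete sizes and the ring-graph adjacency are assumptions.
    B, N, H, D = 2, 10, 6, 1  # batch size, nodes, history length, features
    idx = torch.arange(N)
    adj = torch.zeros(N, N)
    adj[idx, (idx + 1) % N] = 1.0
    adj[(idx + 1) % N, idx] = 1.0  # ring graph: every node has degree 2
    data = {
        "graph": adj.unsqueeze(0).repeat(B, 1, 1),  # forward() takes data["graph"][0]
        "flow_x": torch.randn(B, N, H, D),
    }
    net = ChebNet(in_c=H * D, hid_c=16, out_c=D, K=2)
    out = net(data, torch.device("cpu"))
    print(out.shape)  # expected: torch.Size([2, 10, 1, 1])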