-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathgnn_explainer_link_pred.py
124 lines (101 loc) · 3.59 KB
/
gnn_explainer_link_pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os.path as osp
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig
from torch_geometric.nn import GCNConv
# Select the compute device in order of preference: CUDA, then MPS, then CPU.
# The `hasattr` guard keeps the check safe when `torch.backends` exposes no
# `mps` attribute, in which case the CPU fallback is used.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# Name of the Planetoid benchmark to load; `dataset` is rebound to the loaded
# dataset object a few lines below.
dataset = 'Cora'

# Data lives in ../data/Planetoid relative to this script's real location.
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '..',
    'data',
    'Planetoid',
)

# Transform pipeline, applied in this order: normalise features, move the
# data to `device`, then split the edges into train/val/test link sets
# (5% validation, 10% test, graph treated as undirected).
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True),
])

dataset = Planetoid(root=path, name=dataset, transform=transform)

# Indexing the first graph yields three objects because of the link split.
train_data, val_data, test_data = dataset[0]
class GCN(torch.nn.Module):
    """Two-layer GCN encoder with a dot-product link decoder.

    ``encode`` produces node embeddings from two stacked ``GCNConv`` layers,
    ``decode`` scores candidate edges by the inner product of their endpoint
    embeddings, and ``forward`` chains the two.

    Args:
        in_channels: Input feature size passed to the first ``GCNConv``.
        hidden_channels: Output size of the first layer / input of the second.
        out_channels: Output size of the second ``GCNConv`` (embedding size).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings: ``conv1`` -> ReLU -> ``conv2``.

        Args:
            x: Node features (presumably ``[num_nodes, in_channels]`` --
                confirm against the dataset).
            edge_index: Graph connectivity forwarded to both conv layers.
        """
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        """Score each candidate edge as the dot product of its endpoints.

        Args:
            z: Node embeddings, indexed by node id along the first dimension.
            edge_label_index: Pair ``(src, dst)`` of node indices; unpacked
                along the first dimension.

        Returns:
            One score per candidate edge, summed over the last dimension.
        """
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)

    def forward(self, x, edge_index, edge_label_index):
        """Encode the graph and return flat link scores for the given edges.

        Uses ``self`` rather than a module-level ``model`` global, so every
        instance scores edges with its own weights and works regardless of
        the name it is bound to.
        """
        z = self.encode(x, edge_index)
        # `view(-1)` keeps the output 1-D even when a single edge is passed
        # as a 1-D index pair, where `decode` returns a 0-dim tensor.
        return self.decode(z, edge_label_index).view(-1)
# Link-prediction model: 128 hidden channels, 64-dimensional node embeddings,
# placed on the selected device.
model = GCN(
    in_channels=dataset.num_features,
    hidden_channels=128,
    out_channels=64,
).to(device)

# Adam over all model parameters with a learning rate of 0.01.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one full-batch optimisation step on ``train_data``.

    Uses the module-level ``model``, ``optimizer`` and ``train_data``.

    Returns:
        The binary cross-entropy (with logits) training loss as a Python
        float.
    """
    model.train()
    optimizer.zero_grad()
    out = model(train_data.x, train_data.edge_index,
                train_data.edge_label_index)
    # The model output is treated as raw logits, hence the `_with_logits`
    # loss variant.
    loss = F.binary_cross_entropy_with_logits(out, train_data.edge_label)
    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test(data):
    """Evaluate the module-level ``model`` on one link split.

    Args:
        data: Split object providing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``.

    Returns:
        ROC-AUC of the sigmoid-activated link scores against ``edge_label``.
    """
    model.eval()
    out = model(data.x, data.edge_index, data.edge_label_index).sigmoid()
    # scikit-learn works on NumPy arrays, so move both tensors to the CPU.
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
# Train for 200 epochs; report validation/test ROC-AUC every 20th epoch.
for epoch in range(1, 201):
    loss = train()
    if epoch % 20 == 0:
        val_auc = test(val_data)
        test_auc = test(test_data)
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')
# Shared description of the trained model for both explainers below: an
# edge-level binary classifier. `return_type='raw'` matches how the script
# uses the model output -- as logits (BCE-with-logits loss in training,
# explicit sigmoid in evaluation).
model_config = ModelConfig(
    task_level='edge',
    mode='binary_classification',
    return_type='raw',
)
# Explain model output for a single edge:
# the first candidate edge of the validation split (a 1-D pair of node ids).
edge_label_index = val_data.edge_label_index[:, 0]

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    model_config=model_config,
    node_mask_type='attributes',
    edge_mask_type='object',
)

# The explanation is computed over the training graph; `edge_label_index` is
# forwarded to the model as an extra keyword argument.
explanation = explainer(
    train_data.x,
    train_data.edge_index,
    edge_label_index=edge_label_index,
)
print(f'Generated model explanations in {explanation.available_explanations}')
# Explain a selected target (phenomenon) for a single edge:
# same validation edge as above, explained against its ground-truth label.
edge_label_index = val_data.edge_label_index[:, 0]

# Label of that edge as a one-element integer tensor.
target = val_data.edge_label[0].unsqueeze(0).long()

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='phenomenon',
    model_config=model_config,
    node_mask_type='attributes',
    edge_mask_type='object',
)

explanation = explainer(
    train_data.x,
    train_data.edge_index,
    target=target,
    edge_label_index=edge_label_index,
)
available_explanations = explanation.available_explanations
print(f'Generated phenomenon explanations in {available_explanations}')