Correct way of making a prediction with Spektral?
I am unable to make predictions with `model.predict()` and the loader functionality, using a Spektral model that I trained and evaluated with the same kind of loader. The call also runs for a very long time without ever producing a prediction.
For example, the following code:
```python
import spektral
from dataset_class import GNN_Dataset
from spektral.data.dataset import Dataset
from spektral.layers import GraphSageConv, EdgeConv, ARMAConv, TAGConv, GCNConv
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from spektral.data import Dataset, Graph
from spektral.data import DisjointLoader
import numpy as np
import tensorflow as tf
from datetime import datetime

t1 = datetime.now()
dataset = GNN_Dataset('/dataset/sn/')

train_val_test_split = [0.8, 0.1, 0.10]
n_samples = len(dataset)
idx_first_val_sample = int(np.floor(n_samples * train_val_test_split[0]))
idx_first_test_sample = int(np.floor(n_samples * (train_val_test_split[0] + train_val_test_split[1])))
dataset_train = dataset[0:idx_first_val_sample]
dataset_val = dataset[idx_first_val_sample: idx_first_test_sample]
dataset_test = dataset[idx_first_test_sample:]


class SN_GNN(Model):
    def __init__(self, n_labels):
        super().__init__()
        self.GraphSage = GraphSageConv(512)
        self.TAGConv = TAGConv(512, K=5)
        self.ARMAConv = ARMAConv(512, order=5)
        self.EdgeConv = EdgeConv(512, mlp_hidden=5)
        self.GCNConv = GCNConv(512)
        self.output_layer = Dense(n_labels, activation='softmax')

    def call(self, inputs):
        x, a = inputs[0], inputs[1]
        x = self.GraphSage([x, a])
        out = self.output_layer(x)
        return out


model = SN_GNN(dataset.n_labels)
loader_train = DisjointLoader(dataset_train, node_level=True, batch_size=128)
loader_val = DisjointLoader(dataset_val, node_level=True, batch_size=1)
loader_test = DisjointLoader(dataset_test, node_level=True, batch_size=1)
t2 = datetime.now()
print(f'Data loaded. Duration: {(t2 - t1).seconds}s')

model.compile(optimizer='Adam', loss='binary_crossentropy')
model.fit(loader_train.load(),
          steps_per_epoch=loader_train.steps_per_epoch,
          epochs=2,
          use_multiprocessing=True)
print(f'Evaluation loss: {model.evaluate(loader_val.load(), steps=loader_val.steps_per_epoch, use_multiprocessing=True)}')
t3 = datetime.now()
print(f'Fit and evaluation complete. Duration: {(t3 - t1).seconds}s')

try:
    model.predict(loader_test.load(), use_multiprocessing=True)
except:
    t4 = datetime.now()
    print(f'Prediction error. Duration: {(t4 - t3).seconds}s')
    pass
```
produces this output:
```
Data loaded. Duration: 5s
Epoch 1/2
13/13 [==============================] - 2s 27ms/step - loss: 0.6433
Epoch 2/2
13/13 [==============================] - 0s 25ms/step - loss: 0.4864
297/297 [==============================] - 1s 2ms/step - loss: 0.3796
Evaluation loss: 0.3795754015445709
Fit and evaluation complete. Duration: 8s
```
After that, nothing more is printed even after about an hour of processing, so the prediction never finishes.
However, I can generate a prediction by changing the prediction line to:
```python
for d in dataset_test:
    print(model.predict([d.x, tf.sparse.from_dense(tf.convert_to_tensor(d.a))]))
```
Is there a better way of doing this in Spektral?
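A likely reason the `model.predict(loader_test.load(), ...)` call never returns, assuming the loaders keep Spektral's default `epochs=None`: the generator behind `loader.load()` cycles over the data indefinitely, which is why `fit` and `evaluate` in the script are given `steps_per_epoch` and `steps`. The `predict` call has no `steps` argument, so Keras keeps requesting batches. Passing `steps=loader_test.steps_per_epoch` to `predict` should bound it, as should building the test loader with `epochs=1`. The sketch below reproduces the effect in plain Python; `make_loader` and `predict` are hypothetical stand-ins for a Spektral loader and `Model.predict`, not their real implementations, and `steps_per_epoch` is computed as one step per (possibly partial) batch.

```python
import itertools
import math


def make_loader(graphs, batch_size, epochs=None):
    """Hypothetical stand-in for a Spektral loader's generator.

    With epochs=None it cycles over the data forever (the behaviour assumed
    for Spektral's default); with an integer it stops after that many passes.
    """
    epoch = 0
    while epochs is None or epoch < epochs:
        for start in range(0, len(graphs), batch_size):
            yield graphs[start:start + batch_size]
        epoch += 1


def steps_per_epoch(n_graphs, batch_size):
    # One step per batch, counting a final partial batch as a full step.
    return math.ceil(n_graphs / batch_size)


def predict(batches, steps=None):
    """Hypothetical stand-in for Model.predict: reads batches until the
    iterator is exhausted or, when `steps` is given, after `steps` batches."""
    if steps is not None:
        batches = itertools.islice(batches, steps)
    outputs = []
    for batch in batches:
        outputs.extend(f"pred({g})" for g in batch)
    return outputs


graphs = ["g0", "g1", "g2"]
# predict(make_loader(graphs, batch_size=1)) would never return.
bounded = predict(make_loader(graphs, batch_size=1),
                  steps=steps_per_epoch(len(graphs), batch_size=1))
finite = predict(make_loader(graphs, batch_size=2, epochs=1))
print(bounded)  # ['pred(g0)', 'pred(g1)', 'pred(g2)']
print(finite)   # ['pred(g0)', 'pred(g1)', 'pred(g2)']
```

In the real script, the bounded variant corresponds to `model.predict(loader_test.load(), steps=loader_test.steps_per_epoch)`, mirroring the `evaluate` call.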
Issue details: created 3 years ago; 6 comments (3 by maintainers).
Top GitHub Comments
Ah, ok. That makes sense.
…Sorry, I forgot to mention that! In your stack trace, I noticed that you were calling `loader_test`, and not `loader_test.load()`.
No problem at all, I need to know if there are usability issues with the library or it will never improve 😄 So thanks for opening these issues and feel free to make as many as you like.
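On the question of a better way than calling `predict` graph by graph: with a bounded call, `predict` processes the test set in batches and, for a node-level `DisjointLoader`, returns the node predictions of all graphs stacked along the first axis. Mapping them back to individual graphs assumes the loader keeps dataset order (constructing it with `shuffle=False`) and that the node count of each graph is known, e.g. `[g.n_nodes for g in dataset_test]`. The helper below is a hypothetical sketch of that split, not a Spektral function; `split_by_graph` and its arguments are illustrative names.

```python
import numpy as np


def split_by_graph(preds, nodes_per_graph):
    """Hypothetical helper: split stacked node-level predictions per graph.

    preds: array of shape (total_nodes, n_labels), with the nodes of graph 0
        first, then graph 1, and so on (requires an unshuffled loader).
    nodes_per_graph: number of nodes in each graph, in the same order.
    """
    preds = np.asarray(preds)
    counts = np.asarray(nodes_per_graph, dtype=int)
    if counts.sum() != preds.shape[0]:
        raise ValueError(
            f"{preds.shape[0]} predictions but {counts.sum()} nodes in total"
        )
    # Cumulative node counts give the row where each new graph starts;
    # the end of the array is the implicit final boundary for np.split.
    boundaries = np.cumsum(counts)[:-1]
    return np.split(preds, boundaries)


# Toy data: 5 nodes belonging to two graphs with 2 and 3 nodes.
toy_preds = np.arange(10, dtype=float).reshape(5, 2)
per_graph = split_by_graph(toy_preds, [2, 3])
for k, p in enumerate(per_graph):
    print(k, p.shape, p.argmax(axis=1))
```

Splitting once after a single batched `predict` avoids the per-graph overhead of calling `predict` inside a loop and converting each adjacency matrix by hand.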