У меня есть два вопроса:
1) Что творится с этим долбанным фреймворком? Где нормльные книги, где нормальные статить, где нормальные доки? В официальных доках черт ногу сломит (это отличительная черта гугла, например, в доках по апи ютуба и гугл плюса тоже самое), при попытки нагуглить какую-то банальную вещь сразу натыкаешься на version-hell. Я на столько отчаялся в поисках нормальной обучающей инфы, что даже купил вот эту книгу, но меня ждал сюрприз - книга предполагает, что я знаю TensorFlow, лол.
2) У меня есть простая, на первый взгляд, задачка. Мне нужно сохранить обученную модель, загрузить ее и, собственно, использовать.
Вот так я создаю и тренирую модель:
def build_model(learning_rate=0.1):
tf.reset_default_graph()
net = tflearn.input_data([None, VOCAB_SIZE])
net = tflearn.fully_connected(net, 125, activation='ReLU')
net = tflearn.fully_connected(net, 25, activation='ReLU')
net = tflearn.fully_connected(net, 2, activation='softmax')
regression = tflearn.regression(
net,
optimizer='sgd',
learning_rate=learning_rate,
loss='categorical_crossentropy')
model = tflearn.DNN(net)
return model
Создаем модель:
model = build_model(learning_rate=0.75)
Тренируем:
model.fit(
X_train,
y_train,
validation_set=0.1,
show_metric=True,
batch_size=128,
n_epoch=30)
Далее, я пытаюсь сохранить модель:
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './sentiment-model')
Получаю вот такие варнинги:
WARNING:tensorflow:Error encountered when serializing data_preprocessing.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'NoneType' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing data_augmentation.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'NoneType' object has no attribute 'name'
WARNING:tensorflow:Error encountered when serializing summary_tags.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'dict' object has no attribute 'name'
Но файлики таки создаются:
checkpoint
sentiment-model.data-00000-of-00001
sentiment-model.index
sentiment-model.meta
И вот что делать дальше, я не понимаю абсолютно. По идее, нужно выполнить код, типа этого:
sess = tf.Session()
new_saver = tf.train.import_meta_graph('sentiment-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
v_ = sess.run(v)
print(v_)
Но тут я получаю ошибку:
"The name 'SGD' refers to an Operation not in the graph."