import java.io.File;

import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.cpython.*;
import org.bytedeco.numpy.*;
import org.bytedeco.tensorflow.*;

import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.*;
import static org.bytedeco.tensorflow.global.tensorflow.*;

/**
 * Based on the Keras code found at https://www.tensorflow.org/tutorials
 *
 * @author Samuel Audet
 */
public class KerasMNIST {
    public static void main(String[] args) throws Exception {
        /* try to use MKL when available */
        System.setProperty("org.bytedeco.openblas.load", "mkl");

        /* add the Python packages bundled with the presets to the module search path */
        Py_AddPath(org.bytedeco.tensorflow.presets.tensorflow.cachePackages());
        Py_Initialize();
        if (_import_array() < 0) {
            System.err.println("numpy.core.multiarray failed to import");
            PyErr_Print();
            System.exit(-1);
        }
        PyObject globals = PyModule_GetDict(PyImport_AddModule("__main__"));

        System.out.println("Running with TensorFlow " + TF_Version().getString());

        /* define, train, and evaluate a small Keras model on MNIST in the embedded interpreter */
        PyRun_StringFlags("import tensorflow as tf\n"
                        + "mnist = tf.keras.datasets.mnist\n"
                        + "\n"
                        + "(x_train, y_train),(x_test, y_test) = mnist.load_data()\n"
                        + "x_train, x_test = x_train / 255.0, x_test / 255.0\n"
                        + "\n"
                        + "model = tf.keras.models.Sequential([\n"
                        + "  tf.keras.layers.Flatten(input_shape=(28, 28)),\n"
                        + "  tf.keras.layers.Dense(512, activation=tf.nn.relu),\n"
                        + "  tf.keras.layers.Dropout(0.2),\n"
                        + "  tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n"
                        + "])\n"
                        + "model.compile(optimizer='adam',\n"
                        + "              loss='sparse_categorical_crossentropy',\n"
                        + "              metrics=['accuracy'])\n"
                        + "\n"
                        + "model.fit(x_train, y_train, epochs=5)\n"
                        + "model.evaluate(x_test, y_test)\n",
                Py_file_input, globals, globals, null);

        if (PyErr_Occurred() != null) {
            System.err.println("Python error occurred");
            PyErr_Print();
            System.exit(-1);
        }

        /* a hand-drawn "1": a 1x28x28 image whose pixels are 0.0 except for a vertical bar of 1.0 */
        long[] dims = {1, 28, 28};
        DoublePointer data = new DoublePointer(
            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
            0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        );

        /* wrap the pixel data in a NumPy array and expose it to Python as "sample" */
        PyObject sample = PyArray_New(PyArray_Type(), dims.length, new SizeTPointer(dims),
                NPY_DOUBLE, null, data, 0, NPY_ARRAY_CARRAY, null);
        PyDict_SetItemString(globals, "sample", sample);

        System.out.println("sample = " + DoubleIndexer.create(data, dims[1], dims[2]));
        System.out.println("prediction = ");
        PyObject result = PyRun_StringFlags("model.predict_classes(sample, batch_size=1, verbose=1)",
                Py_single_input, globals, globals, null);

        if (PyErr_Occurred() != null) {
            System.err.println("Python error occurred");
            PyErr_Print();
            System.exit(-1);
        }
    }
}