Einführung in Tensorflow für Java

1. Übersicht

TensorFlow ist eine Open Source-Bibliothek für die Datenflussprogrammierung . Dies wurde ursprünglich von Google entwickelt und ist für eine Vielzahl von Plattformen verfügbar. Obwohl TensorFlow auf einem einzelnen Kern arbeiten kann, kann es genauso leicht von mehreren verfügbaren CPUs, GPUs oder TPUs profitieren .

In diesem Tutorial werden die Grundlagen von TensorFlow und die Verwendung in Java erläutert. Bitte beachten Sie, dass die TensorFlow Java API eine experimentelle API ist und daher nicht unter eine Stabilitätsgarantie fällt. Wir werden später im Tutorial mögliche Anwendungsfälle für die Verwendung der TensorFlow-Java-API behandeln.

2. Grundlagen

Die TensorFlow-Berechnung basiert im Wesentlichen auf zwei grundlegenden Konzepten: Graph und Session . Lassen Sie uns sie schnell durchgehen, um den Hintergrund zu erhalten, der für den Rest des Tutorials erforderlich ist.

2.1. TensorFlow-Diagramm

Lassen Sie uns zunächst die grundlegenden Bausteine ​​von TensorFlow-Programmen verstehen. Berechnungen werden in TensorFlow als Diagramme dargestellt . Ein Graph ist typischerweise ein gerichteter azyklischer Graph von Operationen und Daten, zum Beispiel:

Das obige Bild zeigt den Berechnungsgraphen für die folgende Gleichung:

f(x, y) = z = a*x + b*y

Ein TensorFlow-Berechnungsdiagramm besteht aus zwei Elementen:

  1. Tensor: Dies sind die Kerndateneinheiten in TensorFlow. Sie werden als Kanten in einem Berechnungsdiagramm dargestellt, das den Datenfluss durch das Diagramm darstellt. Ein Tensor kann eine Form mit einer beliebigen Anzahl von Dimensionen haben. Die Anzahl der Dimensionen in einem Tensor wird üblicherweise als Rang bezeichnet. Ein Skalar ist also ein Tensor vom Rang 0, ein Vektor ist ein Tensor vom Rang 1, eine Matrix ist ein Tensor vom Rang 2 und so weiter und so fort.
  2. Operation: Dies sind die Knoten in einem Berechnungsgraphen. Sie beziehen sich auf eine Vielzahl von Berechnungen, die an den in die Operation einspeisenden Tensoren durchgeführt werden können. Sie führen häufig auch zu Tensoren, die aus der Operation in einem Berechnungsgraphen hervorgehen.

2.2. TensorFlow-Sitzung

Ein TensorFlow-Diagramm ist lediglich ein Schema der Berechnung, die tatsächlich keine Werte enthält. Ein solches Diagramm muss innerhalb einer sogenannten TensorFlow-Sitzung ausgeführt werden, damit die Tensoren im auszuwertenden Diagramm ausgewertet werden können . Die Sitzung kann eine Reihe von Tensoren benötigen, um sie aus einem Diagramm als Eingabeparameter auszuwerten. Dann läuft es im Diagramm rückwärts und führt alle Knoten aus, die zur Bewertung dieser Tensoren erforderlich sind.

Mit diesem Wissen sind wir nun bereit, dies auf die Java-API anzuwenden!

3. Maven Setup

Wir werden ein schnelles Maven-Projekt einrichten, um ein TensorFlow-Diagramm in Java zu erstellen und auszuführen. Wir brauchen nur die Tensorflow- Abhängigkeit:

 org.tensorflow tensorflow 1.12.0 

4. Erstellen des Diagramms

Versuchen wir nun, das im vorherigen Abschnitt beschriebene Diagramm mithilfe der TensorFlow-Java-API zu erstellen. Genauer gesagt verwenden wir für dieses Tutorial die TensorFlow Java API, um die durch die folgende Gleichung dargestellte Funktion zu lösen:

z = 3*x + 2*y

Der erste Schritt besteht darin, ein Diagramm zu deklarieren und zu initialisieren:

Graph graph = new Graph()

Jetzt müssen wir alle erforderlichen Operationen definieren. Denken Sie daran, dass Operationen in TensorFlow null oder mehr Tensoren verbrauchen und erzeugen . Darüber hinaus ist jeder Knoten im Diagramm eine Operation, die Konstanten und Platzhalter enthält. Dies mag kontraintuitiv erscheinen, aber halten Sie es für einen Moment aus!

Die Klasse Graph verfügt über eine generische Funktion namens opBuilder () , mit der jede Art von Operation auf TensorFlow erstellt werden kann.

4.1. Konstanten definieren

Zunächst definieren wir in unserer obigen Grafik konstante Operationen. Beachten Sie, dass eine konstante Operation einen Tensor für ihren Wert benötigt :

Operation a = graph.opBuilder("Const", "a") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(3.0, Double.class)) .build(); Operation b = graph.opBuilder("Const", "b") .setAttr("dtype", DataType.fromClass(Double.class)) .setAttr("value", Tensor.create(2.0, Double.class)) .build();

Hier haben wir eine Operation vom konstanten Typ definiert, die den Tensor mit den Doppelwerten 2.0 und 3.0 speist . Es mag zunächst wenig überwältigend erscheinen, aber so ist es derzeit in der Java-API. Diese Konstrukte sind in Sprachen wie Python viel prägnanter.

4.2. Platzhalter definieren

Während wir unseren Konstanten Werte bereitstellen müssen, benötigen Platzhalter zur Definitionszeit keinen Wert . Die Werte für Platzhalter müssen angegeben werden, wenn das Diagramm in einer Sitzung ausgeführt wird. Wir werden diesen Teil später im Tutorial durchgehen.

Lassen Sie uns zunächst sehen, wie wir unsere Platzhalter definieren können:

Operation x = graph.opBuilder("Placeholder", "x") .setAttr("dtype", DataType.fromClass(Double.class)) .build(); Operation y = graph.opBuilder("Placeholder", "y") .setAttr("dtype", DataType.fromClass(Double.class)) .build();

Beachten Sie, dass wir für unsere Platzhalter keinen Wert angeben mussten. Diese Werte werden beim Ausführen als Tensoren eingegeben .

4.3. Funktionen definieren

Schließlich müssen wir die mathematischen Operationen unserer Gleichung definieren, nämlich Multiplikation und Addition, um das Ergebnis zu erhalten.

Dies sind wieder nichts als Operationen in TensorFlow und Graph.opBuilder () ist wieder praktisch:

Operation ax = graph.opBuilder("Mul", "ax") .addInput(a.output(0)) .addInput(x.output(0)) .build(); Operation by = graph.opBuilder("Mul", "by") .addInput(b.output(0)) .addInput(y.output(0)) .build(); Operation z = graph.opBuilder("Add", "z") .addInput(ax.output(0)) .addInput(by.output(0)) .build();

Hier haben wir dort Operation definiert , zwei zum Multiplizieren unserer Eingaben und die letzte zum Aufsummieren der Zwischenergebnisse. Beachten Sie, dass Operationen hier Tensoren erhalten, die nichts anderes als die Ausgabe unserer früheren Operationen sind.

Bitte beachten Sie, dass wir den Ausgabe- Tensor von der Operation mit dem Index '0' erhalten. Wie bereits erwähnt, kann eine Operation zu einem oder mehreren Tensoren führen. Daher müssen wir beim Abrufen eines Handles den Index erwähnen. Da wir wissen, dass unsere Operationen nur einen Tensor zurückgeben , funktioniert '0' einwandfrei!

5. Visualisierung des Diagramms

Es ist schwierig, das Diagramm im Auge zu behalten, wenn es größer wird. Daher ist es wichtig, es auf irgendeine Weise zu visualisieren . Wir können immer eine Handzeichnung wie das zuvor erstellte kleine Diagramm erstellen, dies ist jedoch für größere Diagramme nicht praktikabel. TensorFlow bietet ein Dienstprogramm namens TensorBoard, um dies zu erleichtern .

Leider kann die Java-API keine Ereignisdatei generieren, die von TensorBoard verwendet wird. Mit APIs in Python können wir jedoch eine Ereignisdatei generieren wie:

writer = tf.summary.FileWriter('.') ...... writer.add_graph(tf.get_default_graph()) writer.flush()

Bitte stören Sie sich nicht, wenn dies im Kontext von Java nicht sinnvoll ist. Dies wurde hier nur der Vollständigkeit halber hinzugefügt und ist nicht erforderlich, um den Rest des Tutorials fortzusetzen.

Wir können jetzt die Ereignisdatei in TensorBoard wie folgt laden und visualisieren:

tensorboard --logdir .

TensorBoard ist Teil der TensorFlow-Installation.

Beachten Sie die Ähnlichkeit zwischen diesem und dem zuvor manuell gezeichneten Diagramm!

6. Arbeiten mit Sitzung

We have now created a computational graph for our simple equation in TensorFlow Java API. But how do we run it? Before addressing that, let's see what is the state of Graph we have just created at this point. If we try to print the output of our final Operation “z”:

System.out.println(z.output(0));

This will result in something like:


    

This isn't what we expected! But if we recall what we discussed earlier, this actually makes sense. The Graph we have just defined has not been run yet, so the tensors therein do not actually hold any actual value. The output above just says that this will be a Tensor of type Double.

Let's now define a Session to run our Graph:

Session sess = new Session(graph)

Finally, we are now ready to run our Graph and get the output we have been expecting:

Tensor tensor = sess.runner().fetch("z") .feed("x", Tensor.create(3.0, Double.class)) .feed("y", Tensor.create(6.0, Double.class)) .run().get(0).expect(Double.class); System.out.println(tensor.doubleValue());

So what are we doing here? It should be fairly intuitive:

  • Get a Runner from the Session
  • Define the Operation to fetch by its name “z”
  • Feed in tensors for our placeholders “x” and “y”
  • Run the Graph in the Session

And now we see the scalar output:

21.0

This is what we expected, isn't it!

7. The Use Case for Java API

At this point, TensorFlow may sound like overkill for performing basic operations. But, of course, TensorFlow is meant to run graphs much much larger than this.

Additionally, the tensors it deals with in real-world models are much larger in size and rank. These are the actual machine learning models where TensorFlow finds its real use.

It's not difficult to see that working with the core API in TensorFlow can become very cumbersome as the size of the graph increases. To this end, TensorFlow provides high-level APIs like Keras to work with complex models. Unfortunately, there is little to no official support for Keras on Java just yet.

However, we can use Python to define and train complex models either directly in TensorFlow or using high-level APIs like Keras. Subsequently, we can export a trained model and use that in Java using the TensorFlow Java API.

Now, why would we want to do something like that? This is particularly useful for situations where we want to use machine learning enabled features in existing clients running on Java. For instance, recommending caption for user images on an Android device. Nevertheless, there are several instances where we are interested in the output of a machine learning model but do not necessarily want to create and train that model in Java.

This is where TensorFlow Java API finds the bulk of its use. We'll go through how this can be achieved in the next section.

8. Using Saved Models

We'll now understand how we can save a model in TensorFlow to the file system and load that back possibly in a completely different language and platform. TensorFlow provides APIs to generate model files in a language and platform neutral structure called Protocol Buffer.

8.1. Saving Models to the File System

We'll begin by defining the same graph we created earlier in Python and saving that to the file system.

Let's see we can do this in Python:

import tensorflow as tf graph = tf.Graph() builder = tf.saved_model.builder.SavedModelBuilder('./model') with graph.as_default(): a = tf.constant(2, name="a") b = tf.constant(3, name="b") x = tf.placeholder(tf.int32, name="x") y = tf.placeholder(tf.int32, name="y") z = tf.math.add(a*x, b*y, name="z") sess = tf.Session() sess.run(z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING]) builder.save()

As the focus of this tutorial in Java, let's not pay much attention to the details of this code in Python, except for the fact that it generates a file called “saved_model.pb”. Do note in passing the brevity in defining a similar graph compared to Java!

8.2. Loading Models from the File System

We'll now load “saved_model.pb” into Java. Java TensorFlow API has SavedModelBundle to work with saved models:

SavedModelBundle model = SavedModelBundle.load("./model", "serve"); Tensor tensor = model.session().runner().fetch("z") .feed("x", Tensor.create(3, Integer.class)) .feed("y", Tensor.create(3, Integer.class)) .run().get(0).expect(Integer.class); System.out.println(tensor.intValue());

It should by now be fairly intuitive to understand what the above code is doing. It simply loads the model graph from the protocol buffer and makes available the session therein. From there onward, we can pretty much do anything with this graph as we would have done for a locally-defined graph.

9. Conclusion

To sum up, in this tutorial we went through the basic concepts related to the TensorFlow computational graph. We saw how to use the TensorFlow Java API to create and run such a graph. Then, we talked about the use cases for the Java API with respect to TensorFlow.

In the process, we also understood how to visualize the graph using TensorBoard, and save and reload a model using Protocol Buffer.

Wie immer ist der Code für die Beispiele auf GitHub verfügbar.