Já falamos sobre tensorflow e JavaFX no blog em inglês, mas a API Java do tensorflow ainda está incompleta. Uma API madura e melhor documentada é o DeepLearning4J.
Nesse exemplo nos carregamos o modelo treinado na nossa aplicação, criamos um canvas para desenhar e quando o Enter é pressionado, a imagem do canvas é redimensionada e enviada para o modelo já treinado do deeplearning4j para reconhecimento:
A forma que isso "advinha digitos é como o "olá mundo" para aprendizado usando redes neurais. Um neurônio imita o neurônio natural do nosso cérebro e ele tem um "peso" que controla quando o neurônio é ativado (no nosso cérebro acontece químicas para ativar um neurônio).Uma rede neural consiste de muitos neurônios conectados uns aos outros e também organizados em camadas. O que temos que fazer é fornecer dados identificados (com labels) para a nossa rede neural e ajustar os pesos dos neurônio até que ela possa "advinhar" resultados para um valor novo que não sabemos. Esse processo é chamado de treinamento.
Uma vez treinada, nós testamos a rede neural com outros valores também conhecidos para saber a precisão da rede neural (no nosso caso a precisão é 97.5%!). No nosso case usamos o famoso conjunto de dados MNIST.
Pelo motivo de termos camadas ocultas entra a camada de entrada de dados e a camada de saída, nós chamados essas redes neurais de "profundas" (deep neural network) e são usadas no processo de aprendizagem profunda. Nós temos muitos outros conceitos e tipos de redes neurais, eu sugiro você assistis alguns vídeos super interessantes sobre isso no youtube.
E se você estiver ouvindo falar disso pela primeira vez tenha em mente que não será a última!
Se você tentar o código abaixo vai ver que não é tão preciso quanto essa app web, por exemplo. O motivo é que eu não manipulo a imagem precisamente antes de enviar ela para predição, simplesmente redimensionamos ela para ter 28x28 pixels como requerido pelo modelo treinado..
O código da aplicação JavaFX é mostrado abaixo e o projeto completo está no meu github, including o código usando para treinar nossa rede neural que foi tirado dos deeplearning4j examples.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.fxapps.deeplearning; | |
import java.awt.Graphics; | |
import java.awt.Image; | |
import java.awt.image.BufferedImage; | |
import java.io.File; | |
import java.io.IOException; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; | |
import javafx.application.Application; | |
import javafx.embed.swing.SwingFXUtils; | |
import javafx.geometry.Pos; | |
import javafx.scene.Scene; | |
import javafx.scene.canvas.Canvas; | |
import javafx.scene.canvas.GraphicsContext; | |
import javafx.scene.control.Label; | |
import javafx.scene.image.ImageView; | |
import javafx.scene.image.WritableImage; | |
import javafx.scene.input.KeyCode; | |
import javafx.scene.input.MouseButton; | |
import javafx.scene.layout.HBox; | |
import javafx.scene.layout.VBox; | |
import javafx.scene.paint.Color; | |
import javafx.scene.shape.StrokeLineCap; | |
import javafx.stage.Stage; | |
public class MnistTestFXApp extends Application { | |
private final int CANVAS_WIDTH = 150; | |
private final int CANVAS_HEIGHT = 150; | |
private NativeImageLoader loader; | |
private MultiLayerNetwork model; | |
private Label lblResult; | |
public static void main(String[] args) throws IOException { | |
launch(); | |
} | |
@Override | |
public void start(Stage stage) throws Exception { | |
Canvas canvas = new Canvas(CANVAS_WIDTH, CANVAS_HEIGHT); | |
ImageView imgView = new ImageView(); | |
GraphicsContext ctx = canvas.getGraphicsContext2D(); | |
model = ModelSerializer.restoreMultiLayerNetwork(new File("minist-model.zip")); | |
loader = new NativeImageLoader(28,28,1,true); | |
imgView.setFitHeight(100); | |
imgView.setFitWidth(100); | |
ctx.setLineWidth(10); | |
ctx.setLineCap(StrokeLineCap.SQUARE); | |
lblResult = new Label(); | |
HBox hbBottom = new HBox(10, imgView, lblResult); | |
VBox root = new VBox(5, canvas, hbBottom); | |
hbBottom.setAlignment(Pos.CENTER); | |
root.setAlignment(Pos.CENTER); | |
Scene scene = new Scene(root, 520, 300); | |
stage.setScene(scene); | |
stage.show(); | |
stage.setTitle("Handwritten digits recognition"); | |
canvas.setOnMousePressed(e -> { | |
ctx.setStroke(Color.WHITE); | |
ctx.beginPath(); | |
ctx.moveTo(e.getX(), e.getY()); | |
ctx.stroke(); | |
}); | |
canvas.setOnMouseDragged(e -> { | |
ctx.setStroke(Color.WHITE); | |
ctx.lineTo(e.getX(), e.getY()); | |
ctx.stroke(); | |
}); | |
canvas.setOnMouseClicked(e -> { | |
if (e.getButton() == MouseButton.SECONDARY) { | |
clear(ctx); | |
} | |
}); | |
canvas.setOnKeyReleased(e -> { | |
if(e.getCode() == KeyCode.ENTER) { | |
BufferedImage scaledImg = getScaledImage(canvas); | |
imgView.setImage(SwingFXUtils.toFXImage(scaledImg, null)); | |
try { | |
predictImage(scaledImg); | |
} catch (Exception e1) { | |
e1.printStackTrace(); | |
} | |
} | |
}); | |
clear(ctx); | |
canvas.requestFocus(); | |
} | |
private BufferedImage getScaledImage(Canvas canvas) { | |
// for a better recognition we should improve this part of how we retrieve the image from the canvas | |
WritableImage writableImage = new WritableImage(CANVAS_WIDTH, CANVAS_HEIGHT); | |
canvas.snapshot(null, writableImage); | |
Image tmp = SwingFXUtils.fromFXImage(writableImage, null).getScaledInstance(28, 28, Image.SCALE_SMOOTH); | |
BufferedImage scaledImg = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); | |
Graphics graphics = scaledImg.getGraphics(); | |
graphics.drawImage(tmp, 0, 0, null); | |
graphics.dispose(); | |
return scaledImg; | |
} | |
private void clear(GraphicsContext ctx) { | |
ctx.setFill(Color.BLACK); | |
ctx.fillRect(0, 0, 300, 300); | |
} | |
private void predictImage(BufferedImage img ) throws IOException { | |
ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0, 1); | |
INDArray image = loader.asRowVector(img); | |
imagePreProcessingScaler.transform(image); | |
INDArray output = model.output(image); | |
String putStr = output.toString(); | |
lblResult.setText("Prediction: " + model.predict(image)[0] + "\n " + putStr); | |
} | |
} |
Nenhum comentário:
Postar um comentário