19 de jun. de 2017

Algoritmo k-means e árvores de decisão com Weka e JavaFX

Weka é uma das ferramentas mais conhecidas para Machine Learning em Java, que também tem uma ótima API Java que inclui APIs para agrupamento (ou clustering) de dados usando o algoritmo k-means. Usando JavaFX é possível visualizar dados não-classificados e classsificar os dados usando as APIs usando APIs Weka e então visualizar o resultado em um gráfico JavaFX como o gráfico "scatter".


Nesse post vamos mostrar como  uma simples aplicação JavaFX permite você carregar dados, mostrar os dados sem distinção de categoria usando um gráfico, então usaremos weka para classificar os dados usando k-means e finalmnte vamos classificar os dados usando uma árvore de decisão. Usaremos os dados do arquivo iris.2D.arff que vem junto com o download do Weka

K-means clustering usando Weka é realmente simples e requer somente algumas linhas de código como você pode ver nesse post. Na nossa applicação iremos construir 3 gráficos para os dados de flores íris:

  1. Dados sem distinção de classe (sem séries)
  2. Os dados com a classificação real
  3. Dados clusterizados usando Weka

Como você pode ver os dados foram agrupados de uma forma que é bem próxima dos dados reais (os dados com valores coletados na vida real). O código para construir os dados agrupados é:

private List<Series<Number, Number>> buildClusteredSeries(Instances data) throws Exception {
List<XYChart.Series<Number, Number>> clusteredSeries = new ArrayList<>();
// to buld the cluster we remove the class information
Remove remove = new Remove();
remove.setAttributeIndices("3");
remove.setInputFormat(data);
Instances dataToBeClustered = Filter.useFilter(data, remove);
SimpleKMeans kmeans = new SimpleKMeans();
kmeans.setSeed(10);
kmeans.setPreserveInstancesOrder(true);
kmeans.setNumClusters(3);
kmeans.buildClusterer(dataToBeClustered);
data.deleteStringAttributes();
int[] assignments = kmeans.getAssignments();
for (int c = 0; c < 3; c++) {
XYChart.Series<Number, Number> series = new XYChart.Series<>();
series.setName("Cluster " + c);
clusteredSeries.add(series);
}
for (int i = 0; i < assignments.length; i++) {
int clusterNum = assignments[i];
clusteredSeries.get(clusterNum).getData().add(instancetoChartData(data.get(i)));
}
return clusteredSeries;
}

Depois de montar esses três gráficos também modifiquei todo o código para adicionar um classificar com árvores de decisão usando a implementação do  algoritmo J48. Logo após os gráficos você pode ver a árvore de decisão que montamos a partir dos dados:



Quando você clica no gráfico sem classificação você verá que novos dados são adicionados e ele será classificado nos gráficos superiores usando a árvore de decisão e o algoritmo k-means de clustering.

Nós usamos a árvore de decisão que geramos para classificar os ados e também o cluster. Na imagem acima o cluster classifica alguns dados de forma diferente do que é classificado com a árvore de decisão.

datafile = new BufferedReader(new FileReader(DATA_SET));
data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
tree = new J48();
tree.buildClassifier(data);
Instance instance = new DenseInstance(3);
instance.setDataset(data);
instance.setValue(0, xValue.doubleValue());
instance.setValue(1, yValue.doubleValue());
double predictedClass = tree.classifyInstance(instance);
instance.setValue(2, pred

Eu acho que é particularmente interessante como é fácil visualizar dados com JavaFX. O código completo para esse projeto pode ser encontrado no meu  github, mas aqui está o código da classe principal:



package org.fxapps.ml;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javafx.application.Application;
import javafx.geometry.Point2D;
import javafx.scene.Scene;
import javafx.scene.chart.Axis;
import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.ScatterChart;
import javafx.scene.chart.XYChart;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.scene.control.Button;
import javafx.scene.control.Label;
import javafx.scene.control.Separator;
import javafx.scene.layout.GridPane;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.scene.text.Font;
import javafx.scene.text.FontPosture;
import javafx.scene.text.FontWeight;
import javafx.scene.text.Text;
import javafx.scene.text.TextAlignment;
import javafx.stage.Stage;
import weka.classifiers.trees.J48;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
public class Clustering extends Application {
private static final int NUMBER_OF_CLASSES = 3;
private static final String DATA_SET = "/opt/weka/weka-3-7-12/data/iris.2D.arff";
private ScatterChart<Number, Number> clusteredChart;
private ScatterChart<Number, Number> realDataChart;
private ScatterChart<Number, Number> noClassificationChart;
private static int swapIndex = 0;
private int[][] swapColorsCombinations = { { 0, 1 }, { 0, 2 }, { 1, 2 } };
private J48 tree;
private Instances data;
public static void main(String[] args) throws Exception {
launch();
}
@Override
public void start(Stage stage) throws Exception {
loadData();
tree = new J48();
tree.buildClassifier(data);
noClassificationChart = buildChart("No Classification (click to add new data)", buildSingleSeries());
clusteredChart = buildChart("Clustered", buildClusteredSeries());
realDataChart = buildChart("Real Data (+ Decision Tree classification for new data)", buildLabeledSeries());
noClassificationChart.setOnMouseClicked(e -> {
Axis<Number> xAxis = noClassificationChart.getXAxis();
Axis<Number> yAxis = noClassificationChart.getYAxis();
Point2D mouseSceneCoords = new Point2D(e.getSceneX(), e.getSceneY());
double x = xAxis.sceneToLocal(mouseSceneCoords).getX();
double y = yAxis.sceneToLocal(mouseSceneCoords).getY();
Number xValue = xAxis.getValueForDisplay(x);
Number yValue = yAxis.getValueForDisplay(y);
reloadSeries(xValue, yValue);
});
Label lblDecisionTreeTitle = new Label("Decision Tree generated for the Iris dataset:");
Text txtTree = new Text(tree.toString());
Button btnRestore = new Button("Restore original data");
Button btnSwapColors = new Button("Swap clustered chart colors");
VBox vbDecisionTree = new VBox(10, lblDecisionTreeTitle, new Separator(), txtTree, btnRestore, btnSwapColors);
btnRestore.setOnAction(e -> {
loadData();
reloadSeries();
});
btnSwapColors.setOnAction(e -> swapClusteredChartSeriesColors());
lblDecisionTreeTitle.setTextFill(Color.DARKRED);
lblDecisionTreeTitle.setFont(Font.font(Font.getDefault().getFamily(), FontWeight.BOLD, FontPosture.ITALIC, 16));
txtTree.setTranslateX(100);
txtTree.setFont(Font.font(Font.getDefault().getFamily(), FontWeight.BOLD, FontPosture.ITALIC, 14));
txtTree.setLineSpacing(1);
txtTree.setTextAlignment(TextAlignment.LEFT);
vbDecisionTree.setTranslateY(20);
vbDecisionTree.setTranslateX(20);
GridPane gpRoot = new GridPane();
gpRoot.add(realDataChart, 0, 0);
gpRoot.add(clusteredChart, 1, 0);
gpRoot.add(noClassificationChart, 0, 1);
gpRoot.add(vbDecisionTree, 1, 1);
stage.setScene(new Scene(gpRoot));
stage.setTitle("Íris dataset clustering and visualization");
stage.show();
}
private void loadData() {
BufferedReader datafile;
try {
datafile = new BufferedReader(new FileReader(DATA_SET));
data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
} catch (Exception e) {
System.out.println("Exception loading data... Leaving");
e.printStackTrace();
System.exit(0);
}
}
private void reloadSeries(Number xValue, Number yValue) {
try {
Instance instance = new DenseInstance(NUMBER_OF_CLASSES);
instance.setDataset(data);
instance.setValue(0, xValue.doubleValue());
instance.setValue(1, yValue.doubleValue());
double predictedClass = tree.classifyInstance(instance);
instance.setValue(2, predictedClass);
data.add(instance);
reloadSeries();
} catch (Exception e) {
e.printStackTrace();
}
}
private void reloadSeries() {
try {
noClassificationChart.getData().clear();
clusteredChart.getData().clear();
realDataChart.getData().clear();
noClassificationChart.getData().addAll(buildSingleSeries());
clusteredChart.getData().addAll(buildClusteredSeries());
realDataChart.getData().addAll(buildLabeledSeries());
} catch (Exception e) {
e.printStackTrace();
}
}
private void swapClusteredChartSeriesColors() {
List<Series<Number, Number>> clusteredSeries = new ArrayList<>();
// we have to copy the original data to swap the series
clusteredChart.getData().forEach(serie -> {
Series<Number, Number> series = new Series<>();
serie.getData().stream().map(d -> new Data<Number, Number>(d.getXValue(), d.getYValue()))
.forEach(series.getData()::add);
clusteredSeries.add(series);
});
int i = swapColorsCombinations[swapIndex][0];
int j = swapColorsCombinations[swapIndex][1];
Collections.swap(clusteredSeries, i, j);
clusteredChart.getData().clear();
clusteredChart.getData().addAll(clusteredSeries);
swapIndex = swapIndex == NUMBER_OF_CLASSES - 1 ? 0 : swapIndex + 1;
}
private List<XYChart.Series<Number, Number>> buildSingleSeries() {
XYChart.Series<Number, Number> singleSeries = new XYChart.Series<>();
data.stream().map(this::instancetoChartData).forEach(singleSeries.getData()::add);
singleSeries.setName("no classification");
return Arrays.asList(singleSeries);
}
private List<Series<Number, Number>> buildLabeledSeries() {
List<XYChart.Series<Number, Number>> realSeries = new ArrayList<>();
Attribute irisClasses = data.attribute(2);
data.stream().collect(Collectors.groupingBy(d -> {
int i = (int) d.value(2);
return irisClasses.value(i);
})).forEach((e, instances) -> {
XYChart.Series<Number, Number> series = new XYChart.Series<>();
series.setName(e);
instances.stream().map(this::instancetoChartData).forEach(series.getData()::add);
realSeries.add(series);
});
return realSeries;
}
private List<Series<Number, Number>> buildClusteredSeries() throws Exception {
List<XYChart.Series<Number, Number>> clusteredSeries = new ArrayList<>();
// to build the cluster we remove the class information
Remove remove = new Remove();
remove.setAttributeIndices("3");
remove.setInputFormat(data);
Instances dataToBeClustered = Filter.useFilter(data, remove);
SimpleKMeans kmeans = new SimpleKMeans();
kmeans.setSeed(10);
kmeans.setPreserveInstancesOrder(true);
kmeans.setNumClusters(3);
kmeans.buildClusterer(dataToBeClustered);
IntStream.range(0, 3).mapToObj(i -> {
Series<Number, Number> newSeries = new XYChart.Series<>();
newSeries.setName(String.valueOf(i));
return newSeries;
}).forEach(clusteredSeries::add);
int[] assignments = kmeans.getAssignments();
for (int i = 0; i < assignments.length; i++) {
int clusterNum = assignments[i];
clusteredSeries.get(clusterNum).getData().add(instancetoChartData(data.get(i)));
}
return clusteredSeries;
}
private XYChart.Data<Number, Number> instancetoChartData(Instance i) {
return new XYChart.Data<Number, Number>(i.value(0), i.value(1));
}
private ScatterChart<Number, Number> buildChart(String chartName, List<XYChart.Series<Number, Number>> series) {
final NumberAxis xAxis = new NumberAxis();
final NumberAxis yAxis = new NumberAxis();
final ScatterChart<Number, Number> sc = new ScatterChart<Number, Number>(xAxis, yAxis);
sc.setTitle(chartName);
sc.setPrefHeight(450);
sc.setPrefWidth(600);
xAxis.getValueForDisplay(1);
yAxis.getValueForDisplay(2);
sc.getData().addAll(series);
return sc;
}
}
view raw Clustering.java hosted with ❤ by GitHub

13 de jun. de 2017

Reconhecendo dígitos manuscritos em uma aplicação JavaFX usando DeepLearning4J





Já falamos sobre tensorflow e JavaFX no blog em inglês, mas a API Java do tensorflow ainda está incompleta. Uma API madura e melhor documentada é o DeepLearning4J.


Nesse exemplo nos carregamos o modelo treinado na nossa aplicação, criamos um canvas para desenhar e quando o Enter é pressionado, a imagem do canvas é redimensionada e enviada para o modelo já treinado do deeplearning4j para reconhecimento:






A forma que isso "advinha digitos é como o "olá mundo" para aprendizado usando redes neurais. Um neurônio imita o neurônio natural do nosso cérebro e ele tem um "peso" que controla quando o neurônio é ativado (no nosso cérebro acontece químicas para ativar um neurônio).Uma rede neural consiste de muitos neurônios conectados uns aos outros e também organizados em camadas. O que temos que fazer é fornecer dados identificados (com labels) para a nossa rede neural e ajustar os pesos dos neurônio até que ela possa "advinhar" resultados para um valor novo que não sabemos. Esse processo é chamado de treinamento.




Uma vez treinada, nós testamos a rede neural com outros valores também conhecidos para  saber a precisão da rede neural (no nosso caso a precisão é 97.5%!). No nosso case usamos o famoso conjunto de dados  MNIST.


Pelo motivo de termos camadas ocultas entra a camada de entrada de dados e a camada de saída,  nós chamados essas redes neurais de "profundas" (deep neural network) e são usadas no processo de aprendizagem profunda. Nós temos muitos outros conceitos e tipos de redes neurais, eu sugiro você assistis alguns vídeos super interessantes sobre isso no youtube.



E se você estiver ouvindo falar disso pela primeira vez tenha em mente que não será a última!

Se você tentar o código abaixo vai ver que não é tão preciso quanto essa app web, por exemplo. O motivo é que eu não manipulo a imagem precisamente antes de enviar ela para predição, simplesmente redimensionamos ela para ter  28x28 pixels como requerido pelo modelo treinado..

O código da aplicação JavaFX é mostrado abaixo e o projeto completo está no meu github, including o código usando para treinar nossa rede neural que foi tirado dos deeplearning4j examples.

package org.fxapps.deeplearning;
import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import javafx.application.Application;
import javafx.embed.swing.SwingFXUtils;
import javafx.geometry.Pos;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Label;
import javafx.scene.image.ImageView;
import javafx.scene.image.WritableImage;
import javafx.scene.input.KeyCode;
import javafx.scene.input.MouseButton;
import javafx.scene.layout.HBox;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.scene.shape.StrokeLineCap;
import javafx.stage.Stage;
public class MnistTestFXApp extends Application {
private final int CANVAS_WIDTH = 150;
private final int CANVAS_HEIGHT = 150;
private NativeImageLoader loader;
private MultiLayerNetwork model;
private Label lblResult;
public static void main(String[] args) throws IOException {
launch();
}
@Override
public void start(Stage stage) throws Exception {
Canvas canvas = new Canvas(CANVAS_WIDTH, CANVAS_HEIGHT);
ImageView imgView = new ImageView();
GraphicsContext ctx = canvas.getGraphicsContext2D();
model = ModelSerializer.restoreMultiLayerNetwork(new File("minist-model.zip"));
loader = new NativeImageLoader(28,28,1,true);
imgView.setFitHeight(100);
imgView.setFitWidth(100);
ctx.setLineWidth(10);
ctx.setLineCap(StrokeLineCap.SQUARE);
lblResult = new Label();
HBox hbBottom = new HBox(10, imgView, lblResult);
VBox root = new VBox(5, canvas, hbBottom);
hbBottom.setAlignment(Pos.CENTER);
root.setAlignment(Pos.CENTER);
Scene scene = new Scene(root, 520, 300);
stage.setScene(scene);
stage.show();
stage.setTitle("Handwritten digits recognition");
canvas.setOnMousePressed(e -> {
ctx.setStroke(Color.WHITE);
ctx.beginPath();
ctx.moveTo(e.getX(), e.getY());
ctx.stroke();
});
canvas.setOnMouseDragged(e -> {
ctx.setStroke(Color.WHITE);
ctx.lineTo(e.getX(), e.getY());
ctx.stroke();
});
canvas.setOnMouseClicked(e -> {
if (e.getButton() == MouseButton.SECONDARY) {
clear(ctx);
}
});
canvas.setOnKeyReleased(e -> {
if(e.getCode() == KeyCode.ENTER) {
BufferedImage scaledImg = getScaledImage(canvas);
imgView.setImage(SwingFXUtils.toFXImage(scaledImg, null));
try {
predictImage(scaledImg);
} catch (Exception e1) {
e1.printStackTrace();
}
}
});
clear(ctx);
canvas.requestFocus();
}
private BufferedImage getScaledImage(Canvas canvas) {
// for a better recognition we should improve this part of how we retrieve the image from the canvas
WritableImage writableImage = new WritableImage(CANVAS_WIDTH, CANVAS_HEIGHT);
canvas.snapshot(null, writableImage);
Image tmp = SwingFXUtils.fromFXImage(writableImage, null).getScaledInstance(28, 28, Image.SCALE_SMOOTH);
BufferedImage scaledImg = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
Graphics graphics = scaledImg.getGraphics();
graphics.drawImage(tmp, 0, 0, null);
graphics.dispose();
return scaledImg;
}
private void clear(GraphicsContext ctx) {
ctx.setFill(Color.BLACK);
ctx.fillRect(0, 0, 300, 300);
}
private void predictImage(BufferedImage img ) throws IOException {
ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0, 1);
INDArray image = loader.asRowVector(img);
imagePreProcessingScaler.transform(image);
INDArray output = model.output(image);
String putStr = output.toString();
lblResult.setText("Prediction: " + model.predict(image)[0] + "\n " + putStr);
}
}