19 de jun. de 2017

Algoritmo k-means e árvores de decisão com Weka e JavaFX

Weka é uma das ferramentas mais conhecidas para Machine Learning em Java, que também tem uma ótima API Java que inclui APIs para agrupamento (ou clustering) de dados usando o algoritmo k-means. Usando JavaFX é possível visualizar dados não-classificados e classsificar os dados usando as APIs usando APIs Weka e então visualizar o resultado em um gráfico JavaFX como o gráfico "scatter".


Nesse post vamos mostrar como  uma simples aplicação JavaFX permite você carregar dados, mostrar os dados sem distinção de categoria usando um gráfico, então usaremos weka para classificar os dados usando k-means e finalmnte vamos classificar os dados usando uma árvore de decisão. Usaremos os dados do arquivo iris.2D.arff que vem junto com o download do Weka

K-means clustering usando Weka é realmente simples e requer somente algumas linhas de código como você pode ver nesse post. Na nossa applicação iremos construir 3 gráficos para os dados de flores íris:

  1. Dados sem distinção de classe (sem séries)
  2. Os dados com a classificação real
  3. Dados clusterizados usando Weka

Como você pode ver os dados foram agrupados de uma forma que é bem próxima dos dados reais (os dados com valores coletados na vida real). O código para construir os dados agrupados é:

private List<Series<Number, Number>> buildClusteredSeries(Instances data) throws Exception {
List<XYChart.Series<Number, Number>> clusteredSeries = new ArrayList<>();
// to buld the cluster we remove the class information
Remove remove = new Remove();
remove.setAttributeIndices("3");
remove.setInputFormat(data);
Instances dataToBeClustered = Filter.useFilter(data, remove);
SimpleKMeans kmeans = new SimpleKMeans();
kmeans.setSeed(10);
kmeans.setPreserveInstancesOrder(true);
kmeans.setNumClusters(3);
kmeans.buildClusterer(dataToBeClustered);
data.deleteStringAttributes();
int[] assignments = kmeans.getAssignments();
for (int c = 0; c < 3; c++) {
XYChart.Series<Number, Number> series = new XYChart.Series<>();
series.setName("Cluster " + c);
clusteredSeries.add(series);
}
for (int i = 0; i < assignments.length; i++) {
int clusterNum = assignments[i];
clusteredSeries.get(clusterNum).getData().add(instancetoChartData(data.get(i)));
}
return clusteredSeries;
}

Depois de montar esses três gráficos também modifiquei todo o código para adicionar um classificar com árvores de decisão usando a implementação do  algoritmo J48. Logo após os gráficos você pode ver a árvore de decisão que montamos a partir dos dados:



Quando você clica no gráfico sem classificação você verá que novos dados são adicionados e ele será classificado nos gráficos superiores usando a árvore de decisão e o algoritmo k-means de clustering.

Nós usamos a árvore de decisão que geramos para classificar os ados e também o cluster. Na imagem acima o cluster classifica alguns dados de forma diferente do que é classificado com a árvore de decisão.

datafile = new BufferedReader(new FileReader(DATA_SET));
data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
tree = new J48();
tree.buildClassifier(data);
Instance instance = new DenseInstance(3);
instance.setDataset(data);
instance.setValue(0, xValue.doubleValue());
instance.setValue(1, yValue.doubleValue());
double predictedClass = tree.classifyInstance(instance);
instance.setValue(2, pred

Eu acho que é particularmente interessante como é fácil visualizar dados com JavaFX. O código completo para esse projeto pode ser encontrado no meu  github, mas aqui está o código da classe principal:



package org.fxapps.ml;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javafx.application.Application;
import javafx.geometry.Point2D;
import javafx.scene.Scene;
import javafx.scene.chart.Axis;
import javafx.scene.chart.NumberAxis;
import javafx.scene.chart.ScatterChart;
import javafx.scene.chart.XYChart;
import javafx.scene.chart.XYChart.Data;
import javafx.scene.chart.XYChart.Series;
import javafx.scene.control.Button;
import javafx.scene.control.Label;
import javafx.scene.control.Separator;
import javafx.scene.layout.GridPane;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.scene.text.Font;
import javafx.scene.text.FontPosture;
import javafx.scene.text.FontWeight;
import javafx.scene.text.Text;
import javafx.scene.text.TextAlignment;
import javafx.stage.Stage;
import weka.classifiers.trees.J48;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
public class Clustering extends Application {
private static final int NUMBER_OF_CLASSES = 3;
private static final String DATA_SET = "/opt/weka/weka-3-7-12/data/iris.2D.arff";
private ScatterChart<Number, Number> clusteredChart;
private ScatterChart<Number, Number> realDataChart;
private ScatterChart<Number, Number> noClassificationChart;
private static int swapIndex = 0;
private int[][] swapColorsCombinations = { { 0, 1 }, { 0, 2 }, { 1, 2 } };
private J48 tree;
private Instances data;
public static void main(String[] args) throws Exception {
launch();
}
@Override
public void start(Stage stage) throws Exception {
loadData();
tree = new J48();
tree.buildClassifier(data);
noClassificationChart = buildChart("No Classification (click to add new data)", buildSingleSeries());
clusteredChart = buildChart("Clustered", buildClusteredSeries());
realDataChart = buildChart("Real Data (+ Decision Tree classification for new data)", buildLabeledSeries());
noClassificationChart.setOnMouseClicked(e -> {
Axis<Number> xAxis = noClassificationChart.getXAxis();
Axis<Number> yAxis = noClassificationChart.getYAxis();
Point2D mouseSceneCoords = new Point2D(e.getSceneX(), e.getSceneY());
double x = xAxis.sceneToLocal(mouseSceneCoords).getX();
double y = yAxis.sceneToLocal(mouseSceneCoords).getY();
Number xValue = xAxis.getValueForDisplay(x);
Number yValue = yAxis.getValueForDisplay(y);
reloadSeries(xValue, yValue);
});
Label lblDecisionTreeTitle = new Label("Decision Tree generated for the Iris dataset:");
Text txtTree = new Text(tree.toString());
Button btnRestore = new Button("Restore original data");
Button btnSwapColors = new Button("Swap clustered chart colors");
VBox vbDecisionTree = new VBox(10, lblDecisionTreeTitle, new Separator(), txtTree, btnRestore, btnSwapColors);
btnRestore.setOnAction(e -> {
loadData();
reloadSeries();
});
btnSwapColors.setOnAction(e -> swapClusteredChartSeriesColors());
lblDecisionTreeTitle.setTextFill(Color.DARKRED);
lblDecisionTreeTitle.setFont(Font.font(Font.getDefault().getFamily(), FontWeight.BOLD, FontPosture.ITALIC, 16));
txtTree.setTranslateX(100);
txtTree.setFont(Font.font(Font.getDefault().getFamily(), FontWeight.BOLD, FontPosture.ITALIC, 14));
txtTree.setLineSpacing(1);
txtTree.setTextAlignment(TextAlignment.LEFT);
vbDecisionTree.setTranslateY(20);
vbDecisionTree.setTranslateX(20);
GridPane gpRoot = new GridPane();
gpRoot.add(realDataChart, 0, 0);
gpRoot.add(clusteredChart, 1, 0);
gpRoot.add(noClassificationChart, 0, 1);
gpRoot.add(vbDecisionTree, 1, 1);
stage.setScene(new Scene(gpRoot));
stage.setTitle("Íris dataset clustering and visualization");
stage.show();
}
private void loadData() {
BufferedReader datafile;
try {
datafile = new BufferedReader(new FileReader(DATA_SET));
data = new Instances(datafile);
data.setClassIndex(data.numAttributes() - 1);
} catch (Exception e) {
System.out.println("Exception loading data... Leaving");
e.printStackTrace();
System.exit(0);
}
}
private void reloadSeries(Number xValue, Number yValue) {
try {
Instance instance = new DenseInstance(NUMBER_OF_CLASSES);
instance.setDataset(data);
instance.setValue(0, xValue.doubleValue());
instance.setValue(1, yValue.doubleValue());
double predictedClass = tree.classifyInstance(instance);
instance.setValue(2, predictedClass);
data.add(instance);
reloadSeries();
} catch (Exception e) {
e.printStackTrace();
}
}
private void reloadSeries() {
try {
noClassificationChart.getData().clear();
clusteredChart.getData().clear();
realDataChart.getData().clear();
noClassificationChart.getData().addAll(buildSingleSeries());
clusteredChart.getData().addAll(buildClusteredSeries());
realDataChart.getData().addAll(buildLabeledSeries());
} catch (Exception e) {
e.printStackTrace();
}
}
private void swapClusteredChartSeriesColors() {
List<Series<Number, Number>> clusteredSeries = new ArrayList<>();
// we have to copy the original data to swap the series
clusteredChart.getData().forEach(serie -> {
Series<Number, Number> series = new Series<>();
serie.getData().stream().map(d -> new Data<Number, Number>(d.getXValue(), d.getYValue()))
.forEach(series.getData()::add);
clusteredSeries.add(series);
});
int i = swapColorsCombinations[swapIndex][0];
int j = swapColorsCombinations[swapIndex][1];
Collections.swap(clusteredSeries, i, j);
clusteredChart.getData().clear();
clusteredChart.getData().addAll(clusteredSeries);
swapIndex = swapIndex == NUMBER_OF_CLASSES - 1 ? 0 : swapIndex + 1;
}
private List<XYChart.Series<Number, Number>> buildSingleSeries() {
XYChart.Series<Number, Number> singleSeries = new XYChart.Series<>();
data.stream().map(this::instancetoChartData).forEach(singleSeries.getData()::add);
singleSeries.setName("no classification");
return Arrays.asList(singleSeries);
}
private List<Series<Number, Number>> buildLabeledSeries() {
List<XYChart.Series<Number, Number>> realSeries = new ArrayList<>();
Attribute irisClasses = data.attribute(2);
data.stream().collect(Collectors.groupingBy(d -> {
int i = (int) d.value(2);
return irisClasses.value(i);
})).forEach((e, instances) -> {
XYChart.Series<Number, Number> series = new XYChart.Series<>();
series.setName(e);
instances.stream().map(this::instancetoChartData).forEach(series.getData()::add);
realSeries.add(series);
});
return realSeries;
}
private List<Series<Number, Number>> buildClusteredSeries() throws Exception {
List<XYChart.Series<Number, Number>> clusteredSeries = new ArrayList<>();
// to build the cluster we remove the class information
Remove remove = new Remove();
remove.setAttributeIndices("3");
remove.setInputFormat(data);
Instances dataToBeClustered = Filter.useFilter(data, remove);
SimpleKMeans kmeans = new SimpleKMeans();
kmeans.setSeed(10);
kmeans.setPreserveInstancesOrder(true);
kmeans.setNumClusters(3);
kmeans.buildClusterer(dataToBeClustered);
IntStream.range(0, 3).mapToObj(i -> {
Series<Number, Number> newSeries = new XYChart.Series<>();
newSeries.setName(String.valueOf(i));
return newSeries;
}).forEach(clusteredSeries::add);
int[] assignments = kmeans.getAssignments();
for (int i = 0; i < assignments.length; i++) {
int clusterNum = assignments[i];
clusteredSeries.get(clusterNum).getData().add(instancetoChartData(data.get(i)));
}
return clusteredSeries;
}
private XYChart.Data<Number, Number> instancetoChartData(Instance i) {
return new XYChart.Data<Number, Number>(i.value(0), i.value(1));
}
private ScatterChart<Number, Number> buildChart(String chartName, List<XYChart.Series<Number, Number>> series) {
final NumberAxis xAxis = new NumberAxis();
final NumberAxis yAxis = new NumberAxis();
final ScatterChart<Number, Number> sc = new ScatterChart<Number, Number>(xAxis, yAxis);
sc.setTitle(chartName);
sc.setPrefHeight(450);
sc.setPrefWidth(600);
xAxis.getValueForDisplay(1);
yAxis.getValueForDisplay(2);
sc.getData().addAll(series);
return sc;
}
}
view raw Clustering.java hosted with ❤ by GitHub

2 comentários:

  1. Parabéns pelo blog, excelente artigo meu caro.

    ResponderExcluir
  2. Casinos Near Me - Go GoCasino
    Best casinos near 가입시꽁머니사이트 me · Lucky Nugget. $3,000. · Harrah's Las Vegas. $3,000. · 장원도메인 Caesars Palace. $3,000. · Harrah's 먹튀 Las 텍사스 홀덤 룰 Vegas. $3,000. 챗 룰렛 · Caesars

    ResponderExcluir