Der K-Means-Clustering-Algorithmus in Java

1. Übersicht

Clustering ist ein Überbegriff für eine Klasse unbeaufsichtigter Algorithmen, um Gruppen von Dingen, Personen oder Ideen zu entdecken, die eng miteinander verbunden sind .

In dieser scheinbar einfachen Einzeiler-Definition haben wir einige Schlagworte gesehen. Was genau ist Clustering? Was ist ein unbeaufsichtigter Algorithmus?

In diesem Tutorial werden wir zunächst einige Konzepte beleuchten. Dann werden wir sehen, wie sie sich in Java manifestieren können.

2. Unüberwachte Algorithmen

Bevor wir die meisten Lernalgorithmen verwenden, sollten wir ihnen einige Beispieldaten zuführen und dem Algorithmus erlauben, aus diesen Daten zu lernen. In der Terminologie des maschinellen Lernens nennen wir diese Beispieldatensatz-Trainingsdaten. Außerdem wird der gesamte Prozess als der Trainingsprozess bekannt.

Auf jeden Fall können wir Lernalgorithmen anhand des Umfangs der Überwachung klassifizieren, die sie während des Trainingsprozesses benötigen. Die zwei Haupttypen von Lernalgorithmen in dieser Kategorie sind:

  • Überwachtes Lernen : In überwachten Algorithmen sollten die Trainingsdaten die tatsächliche Lösung für jeden Punkt enthalten. Wenn wir beispielsweise unseren Spamfilteralgorithmus trainieren möchten, geben wir sowohl die Beispiel-E-Mails als auch deren Bezeichnung, dh Spam oder Nicht-Spam, an den Algorithmus weiter. Mathematisch gesehen werden wir das f (x) aus einem Trainingssatz ableiten, der sowohl xs als auch ys enthält.
  • Unüberwachtes Lernen : Wenn die Trainingsdaten keine Beschriftungen enthalten, ist der Algorithmus unbeaufsichtigt. Zum Beispiel haben wir viele Daten über Musiker und wir werden Gruppen ähnlicher Musiker in den Daten entdecken.

3. Clustering

Clustering ist ein unbeaufsichtigter Algorithmus, um Gruppen ähnlicher Dinge, Ideen oder Personen zu entdecken. Im Gegensatz zu überwachten Algorithmen trainieren wir keine Clustering-Algorithmen mit Beispielen bekannter Labels. Stattdessen versucht das Clustering, Strukturen innerhalb eines Trainingssatzes zu finden, bei denen kein Punkt der Daten die Bezeichnung ist.

3.1. K-Means Clustering

K-Means ist ein Clustering-Algorithmus mit einer grundlegenden Eigenschaft: Die Anzahl der Cluster wird im Voraus definiert . Zusätzlich zu K-Means gibt es andere Arten von Clustering-Algorithmen wie Hierarchical Clustering, Affinity Propagation oder Spectral Clustering.

3.2. Wie K-Means funktioniert

Angenommen, unser Ziel ist es, einige ähnliche Gruppen in einem Datensatz zu finden, wie:

K-Means beginnt mit k zufällig platzierten Zentroiden. Zentroide sind, wie der Name schon sagt, die Mittelpunkte der Cluster . Zum Beispiel fügen wir hier vier zufällige Schwerpunkte hinzu:

Dann ordnen wir jeden vorhandenen Datenpunkt seinem nächsten Schwerpunkt zu:

Nach der Zuweisung verschieben wir die Schwerpunkte an die durchschnittliche Position der ihm zugewiesenen Punkte. Denken Sie daran, Zentroide sollen die Mittelpunkte von Clustern sein:

Die aktuelle Iteration endet jedes Mal, wenn wir die Zentroide verschoben haben. Wir wiederholen diese Iterationen, bis sich die Zuordnung zwischen mehreren aufeinanderfolgenden Iterationen nicht mehr ändert:

Wenn der Algorithmus beendet wird, werden diese vier Cluster wie erwartet gefunden. Nachdem wir nun wissen, wie K-Means funktioniert, implementieren wir es in Java.

3.3. Funktionsdarstellung

Bei der Modellierung verschiedener Trainingsdatensätze benötigen wir eine Datenstruktur, um Modellattribute und ihre entsprechenden Werte darzustellen. Ein Musiker kann beispielsweise ein Genre-Attribut mit einem Wert wie Rock haben . Normalerweise verwenden wir den Begriff Feature, um die Kombination eines Attributs und seines Werts zu bezeichnen.

Um einen Datensatz für einen bestimmten Lernalgorithmus vorzubereiten, verwenden wir normalerweise einen gemeinsamen Satz numerischer Attribute, mit denen verschiedene Elemente verglichen werden können. Wenn wir beispielsweise zulassen, dass unsere Benutzer jeden Künstler mit einem Genre markieren, können wir am Ende des Tages zählen, wie oft jeder Künstler mit einem bestimmten Genre markiert wird:

Der Merkmalsvektor für einen Künstler wie Linkin Park ist [Rock -> 7890, Nu-Metal -> 700, Alternative -> 520, Pop -> 3]. Wenn wir also einen Weg finden könnten, Attribute als numerische Werte darzustellen, können wir einfach zwei verschiedene Elemente, z. B. Künstler, vergleichen, indem wir ihre entsprechenden Vektoreinträge vergleichen.

Da numerische Vektoren so vielseitige Datenstrukturen sind, werden wir Features darstellen, die sie verwenden . So implementieren wir Feature-Vektoren in Java:

public class Record { private final String description; private final Map features; // constructor, getter, toString, equals and hashcode }

3.4. Ähnliche Gegenstände finden

In jeder Iteration von K-Means benötigen wir eine Möglichkeit, den nächstgelegenen Schwerpunkt zu jedem Element im Datensatz zu finden. Eine der einfachsten Methoden zur Berechnung des Abstands zwischen zwei Merkmalsvektoren ist die Verwendung des euklidischen Abstands. Der euklidische Abstand zwischen zwei Vektoren wie [p1, q1] und [p2, q2] ist gleich:

Lassen Sie uns diese Funktion in Java implementieren. Erstens die Abstraktion:

public interface Distance { double calculate(Map f1, Map f2); }

Zusätzlich zum euklidischen Abstand gibt es andere Ansätze, um den Abstand oder die Ähnlichkeit zwischen verschiedenen Elementen wie dem Pearson-Korrelationskoeffizienten zu berechnen . Diese Abstraktion erleichtert das Umschalten zwischen verschiedenen Entfernungsmetriken.

Sehen wir uns die Implementierung für die euklidische Distanz an:

public class EuclideanDistance implements Distance { @Override public double calculate(Map f1, Map f2) { double sum = 0; for (String key : f1.keySet()) { Double v1 = f1.get(key); Double v2 = f2.get(key); if (v1 != null && v2 != null) { sum += Math.pow(v1 - v2, 2); } } return Math.sqrt(sum); } }

Zunächst berechnen wir die Summe der quadratischen Differenzen zwischen den entsprechenden Einträgen. Dann berechnen wir durch Anwenden der sqrt- Funktion den tatsächlichen euklidischen Abstand.

3.5. Schwerpunktdarstellung

Zentroide befinden sich im selben Raum wie normale Features, sodass wir sie ähnlich wie Features darstellen können:

public class Centroid { private final Map coordinates; // constructors, getter, toString, equals and hashcode }

Jetzt, da wir einige notwendige Abstraktionen haben, ist es Zeit, unsere K-Means-Implementierung zu schreiben. Hier ist ein kurzer Blick auf unsere Methodensignatur:

public class KMeans { private static final Random random = new Random(); public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { // omitted } }
    

Lassen Sie uns diese Methodensignatur aufschlüsseln:

  • Der Datensatz besteht aus einer Reihe von Merkmalsvektoren. Da jeder Merkmalsvektor ein Datensatz ist, lautet der Dataset-Typ Liste
  • Der Parameter k bestimmt die Anzahl der Cluster, die wir im Voraus bereitstellen sollten
  • distance encapsulates the way we're going to calculate the difference between two features
  • K-Means terminates when the assignment stops changing for a few consecutive iterations. In addition to this termination condition, we can place an upper bound for the number of iterations, too. The maxIterations argument determines that upper bound
  • When K-Means terminates, each centroid should have a few assigned features, hence we're using a Map as the return type. Basically, each map entry corresponds to a cluster

3.6. Centroid Generation

The first step is to generate k randomly placed centroids.

Although each centroid can contain totally random coordinates, it's a good practice to generate random coordinates between the minimum and maximum possible values for each attribute. Generating random centroids without considering the range of possible values would cause the algorithm to converge more slowly.

First, we should compute the minimum and maximum value for each attribute, and then, generate the random values between each pair of them:

private static List randomCentroids(List records, int k) { List centroids = new ArrayList(); Map maxs = new HashMap(); Map mins = new HashMap(); for (Record record : records) { record.getFeatures().forEach((key, value) -> ); } Set attributes = records.stream() .flatMap(e -> e.getFeatures().keySet().stream()) .collect(toSet()); for (int i = 0; i < k; i++) { Map coordinates = new HashMap(); for (String attribute : attributes) { double max = maxs.get(attribute); double min = mins.get(attribute); coordinates.put(attribute, random.nextDouble() * (max - min) + min); } centroids.add(new Centroid(coordinates)); } return centroids; }

Now, we can assign each record to one of these random centroids.

3.7. Assignment

First off, given a Record, we should find the centroid nearest to it:

private static Centroid nearestCentroid(Record record, List centroids, Distance distance) { double minimumDistance = Double.MAX_VALUE; Centroid nearest = null; for (Centroid centroid : centroids) { double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates()); if (currentDistance < minimumDistance) { minimumDistance = currentDistance; nearest = centroid; } } return nearest; }

Each record belongs to its nearest centroid cluster:

private static void assignToCluster(Map
    
      clusters, Record record, Centroid centroid) { clusters.compute(centroid, (key, list) -> { if (list == null) { list = new ArrayList(); } list.add(record); return list; }); }
    

3.8. Centroid Relocation

If, after one iteration, a centroid does not contain any assignments, then we won't relocate it. Otherwise, we should relocate the centroid coordinate for each attribute to the average location of all assigned records:

private static Centroid average(Centroid centroid, List records) { if (records == null || records.isEmpty()) { return centroid; } Map average = centroid.getCoordinates(); records.stream().flatMap(e -> e.getFeatures().keySet().stream()) .forEach(k -> average.put(k, 0.0)); for (Record record : records) { record.getFeatures().forEach( (k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue) ); } average.forEach((k, v) -> average.put(k, v / records.size())); return new Centroid(average); }

Since we can relocate a single centroid, now it's possible to implement the relocateCentroids method:

private static List relocateCentroids(Map
    
      clusters) { return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList()); }
    

This simple one-liner iterates through all centroids, relocates them, and returns the new centroids.

3.9. Putting It All Together

In each iteration, after assigning all records to their nearest centroid, first, we should compare the current assignments with the last iteration.

If the assignments were identical, then the algorithm terminates. Otherwise, before jumping to the next iteration, we should relocate the centroids:

public static Map
    
      fit(List records, int k, Distance distance, int maxIterations) { List centroids = randomCentroids(records, k); Map
     
       clusters = new HashMap(); Map
      
        lastState = new HashMap(); // iterate for a pre-defined number of times for (int i = 0; i < maxIterations; i++) { boolean isLastIteration = i == maxIterations - 1; // in each iteration we should find the nearest centroid for each record for (Record record : records) { Centroid centroid = nearestCentroid(record, centroids, distance); assignToCluster(clusters, record, centroid); } // if the assignments do not change, then the algorithm terminates boolean shouldTerminate = isLastIteration || clusters.equals(lastState); lastState = clusters; if (shouldTerminate) { break; } // at the end of each iteration we should relocate the centroids centroids = relocateCentroids(clusters); clusters = new HashMap(); } return lastState; }
      
     
    

4. Example: Discovering Similar Artists on Last.fm

Last.fm builds a detailed profile of each user's musical taste by recording details of what the user listens to. In this section, we're going to find clusters of similar artists. To build a dataset appropriate for this task, we'll use three APIs from Last.fm:

  1. API to get a collection of top artists on Last.fm.
  2. Another API to find popular tags. Each user can tag an artist with something, e.g. rock. So, Last.fm maintains a database of those tags and their frequencies.
  3. And an API to get the top tags for an artist, ordered by popularity. Since there are many such tags, we'll only keep those tags that are among the top global tags.

4.1. Last.fm's API

To use these APIs, we should get an API Key from Last.fm and send it in every HTTP request. We're going to use the following Retrofit service for calling those APIs:

public interface LastFmService { @GET("/2.0/?method=chart.gettopartists&format=json&limit=50") Call topArtists(@Query("page") int page); @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1") Call topTagsFor(@Query("artist") String artist); @GET("/2.0/?method=chart.gettoptags&format=json&limit=100") Call topTags(); // A few DTOs and one interceptor }

So, let's find the most popular artists on Last.fm:

// setting up the Retrofit service private static List getTop100Artists() throws IOException { List artists = new ArrayList(); // Fetching the first two pages, each containing 50 records. for (int i = 1; i <= 2; i++) { artists.addAll(lastFm.topArtists(i).execute().body().all()); } return artists; }

Similarly, we can fetch the top tags:

private static Set getTop100Tags() throws IOException { return lastFm.topTags().execute().body().all(); }

Finally, we can build a dataset of artists along with their tag frequencies:

private static List datasetWithTaggedArtists(List artists, Set topTags) throws IOException { List records = new ArrayList(); for (String artist : artists) { Map tags = lastFm.topTagsFor(artist).execute().body().all(); // Only keep popular tags. tags.entrySet().removeIf(e -> !topTags.contains(e.getKey())); records.add(new Record(artist, tags)); } return records; }

4.2. Forming Artist Clusters

Now, we can feed the prepared dataset to our K-Means implementation:

List artists = getTop100Artists(); Set topTags = getTop100Tags(); List records = datasetWithTaggedArtists(artists, topTags); Map
    
      clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000); // Printing the cluster configuration clusters.forEach((key, value) -> { System.out.println("-------------------------- CLUSTER ----------------------------"); // Sorting the coordinates to see the most significant tags first. System.out.println(sortedCentroid(key)); String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet())); System.out.print(members); System.out.println(); System.out.println(); });
    

If we run this code, then it would visualize the clusters as text output:

------------------------------ CLUSTER ----------------------------------- Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... } David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica, Fleetwood Mac, The Beatles, Elton John, The Clash ------------------------------ CLUSTER ----------------------------------- Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... } Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion, Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake ------------------------------ CLUSTER ----------------------------------- Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... } Tame Impala, The Black Keys ------------------------------ CLUSTER ----------------------------------- Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... } Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars, Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes, Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna, Adele, Lady Gaga, Jonas Brothers ------------------------------ CLUSTER ----------------------------------- Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... } Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons, The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire, Arctic Monkeys ------------------------------ CLUSTER ----------------------------------- Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... } Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers, Avicii, Kygo, Marshmello, David Guetta, Major Lazer ------------------------------ CLUSTER ----------------------------------- Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... } Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz, Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park, Red Hot Chili Peppers, Muse

Since centroid coordinations are sorted by the average tag frequency, we can easily spot the dominant genre in each cluster. For example, the last cluster is a cluster of a good old rock-bands, or the second one is filled with rap stars.

Although this clustering makes sense, for the most part, it's not perfect since the data is merely collected from user behavior.

5. Visualization

A few moments ago, our algorithm visualized the cluster of artists in a terminal-friendly way. If we convert our cluster configuration to JSON and feed it to D3.js, then with a few lines of JavaScript, we'll have a nice human-friendly Radial Tidy-Tree:

We have to convert our Map to a JSON with a similar schema like this d3.js example.

6. Number of Clusters

One of the fundamental properties of K-Means is the fact that we should define the number of clusters in advance. So far, we used a static value for k, but determining this value can be a challenging problem. There are two common ways to calculate the number of clusters:

  1. Domain Knowledge
  2. Mathematical Heuristics

If we're lucky enough that we know so much about the domain, then we might be able to simply guess the right number. Otherwise, we can apply a few heuristics like Elbow Method or Silhouette Method to get a sense on the number of clusters.

Before going any further, we should know that these heuristics, although useful, are just heuristics and may not provide clear-cut answers.

6.1. Elbow Method

To use the elbow method, we should first calculate the difference between each cluster centroid and all its members. As we group more unrelated members in a cluster, the distance between the centroid and its members goes up, hence the cluster quality decreases.

One way to perform this distance calculation is to use the Sum of Squared Errors. Sum of squared errors or SSE is equal to the sum of squared differences between a centroid and all its members:

public static double sse(Map
    
      clustered, Distance distance) { double sum = 0; for (Map.Entry
     
       entry : clustered.entrySet()) { Centroid centroid = entry.getKey(); for (Record record : entry.getValue()) { double d = distance.calculate(centroid.getCoordinates(), record.getFeatures()); sum += Math.pow(d, 2); } } return sum; }
     
    

Then, we can run the K-Means algorithm for different values of kand calculate the SSE for each of them:

List records = // the dataset; Distance distance = new EuclideanDistance(); List sumOfSquaredErrors = new ArrayList(); for (int k = 2; k <= 16; k++) { Map
    
      clusters = KMeans.fit(records, k, distance, 1000); double sse = Errors.sse(clusters, distance); sumOfSquaredErrors.add(sse); }
    

At the end of the day, it's possible to find an appropriate k by plotting the number of clusters against the SSE:

Usually, as the number of clusters increases, the distance between cluster members decreases. However, we can't choose any arbitrary large values for k, since having multiple clusters with just one member defeats the whole purpose of clustering.

Die Idee hinter der Ellbogenmethode besteht darin, einen geeigneten Wert für k so zu finden , dass die SSE um diesen Wert herum dramatisch abnimmt. Zum Beispiel kann k = 9 hier ein guter Kandidat sein.

7. Fazit

In diesem Tutorial haben wir zunächst einige wichtige Konzepte des maschinellen Lernens behandelt. Dann wurden wir mit der Mechanik des K-Means-Clustering-Algorithmus vertraut gemacht. Schließlich haben wir eine einfache Implementierung für K-Means geschrieben, unseren Algorithmus mit einem realen Datensatz von Last.fm getestet und das Clustering-Ergebnis auf eine schöne grafische Weise visualisiert.

Wie üblich ist der Beispielcode in unserem GitHub-Projekt verfügbar. Probieren Sie ihn also unbedingt aus!