# Source code for lecture2notes.models.slide_classifier.class_cluster_scikit

import logging
import os
from collections import OrderedDict

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from sklearn.cluster import AffinityPropagation, KMeans
from sklearn.metrics import pairwise_distances_argmin_min

from . import inference

logger = logging.getLogger(__name__)

# An unset or empty DISPLAY variable is taken to mean that no display is
# available, so matplotlib is switched to the non-interactive Agg backend.
if not os.environ.get("DISPLAY"):
    logger.debug("No display found. Using non-interactive Agg backend")
    mpl.use("Agg")


class Cluster:
    """Collect feature vectors keyed by filename and cluster them.

    Vectors are registered with :meth:`add` and clustered with either
    scikit-learn's ``KMeans`` or ``AffinityPropagation``. The model is fitted
    lazily: methods that need a fitted model call
    :meth:`create_algorithm_if_none` first.
    """

    # Values accepted for the ``algorithm_name`` constructor argument.
    _ALGORITHMS = ("kmeans", "affinity_propagation")

    def __init__(
        self,
        algorithm_name="kmeans",
        num_centroids=20,
        preference=None,
        damping=0.5,
        max_iter=200,
    ):
        """Store the clustering configuration and initialise empty state.

        Args:
            algorithm_name: ``"kmeans"`` or ``"affinity_propagation"``.
            num_centroids: Number of clusters passed to ``KMeans``.
            preference: ``preference`` passed to ``AffinityPropagation``.
            damping: ``damping`` passed to ``AffinityPropagation``.
            max_iter: ``max_iter`` passed to ``AffinityPropagation``.

        Raises:
            AssertionError: If ``algorithm_name`` is not supported.
        """
        if algorithm_name not in self._ALGORITHMS:
            raise AssertionError(
                "Unsupported algorithm_name %r; expected one of %s"
                % (algorithm_name, ", ".join(self._ALGORITHMS))
            )

        # Insertion order matters: the position of a filename in this mapping
        # is the row index of its vector in the fitted array and in ``labels``.
        self.vectors = OrderedDict()
        self.algorithm_name = algorithm_name

        # Populated once a model has been fitted with ``store=True``.
        self.centroids = None
        self.algorithm = None
        self.cost = None
        self.labels = None

        # Caches filled by the corresponding getter methods.
        self.closest = None
        self.move_list = None
        self.closest_filenames = None

        self.num_centroids = num_centroids
        self.preference = preference
        self.damping = damping
        self.max_iter = max_iter

    def add(self, vector, filename):
        """Register ``vector`` as the feature vector belonging to ``filename``."""
        self.vectors[filename] = vector

    def get_vectors(self):
        """Return the ordered mapping of filename to feature vector."""
        return self.vectors

    def get_labels(self):
        """Return the stored cluster labels, or ``None`` if nothing is fitted."""
        return self.labels

    def create_algorithm_if_none(self):
        """Fit the algorithm chosen in ``__init__`` unless one is already stored."""
        if self.algorithm is not None:
            return
        if self.algorithm_name == "kmeans":
            self.create_kmeans(self.num_centroids)
        elif self.algorithm_name == "affinity_propagation":
            self.create_affinity_propagation(
                self.preference, self.damping, self.max_iter
            )

    def predict(self, array):
        """Return ``algorithm.predict(array)``, fitting the algorithm if needed."""
        self.create_algorithm_if_none()
        return self.algorithm.predict(array)

    def get_vector_array(self):
        """Return the stored vectors stacked into one numpy array.

        Rows follow the insertion order of ``self.vectors``.
        """
        return np.stack(list(self.vectors.values()))

    def create_affinity_propagation(self, preference, damping, max_iter, store=True):
        """Create and fit an affinity propagation cluster.

        Args:
            preference: Forwarded to ``AffinityPropagation``.
            damping: Forwarded to ``AffinityPropagation``.
            max_iter: Forwarded to ``AffinityPropagation``.
            store: When true, keep the model, centroids and labels on ``self``.

        Returns:
            Tuple ``(model, centroids, labels)``.
        """
        logger.info("Creating and fitting affinity propagation cluster")
        affinity_propagation = AffinityPropagation(
            preference=preference, damping=damping, max_iter=max_iter
        )
        affinity_propagation.fit(self.get_vector_array())

        centroids = affinity_propagation.cluster_centers_
        labels = affinity_propagation.labels_

        if store:
            self.centroids = centroids
            self.algorithm = affinity_propagation
            self.labels = labels

        return affinity_propagation, centroids, labels

    def create_kmeans(self, num_centroids, store=True):
        """Create and fit a kmeans cluster.

        Args:
            num_centroids: Number of clusters (``n_clusters``) for ``KMeans``.
            store: When true, keep the model, centroids, cost and labels on
                ``self``.

        Returns:
            Tuple ``(model, centroids, cost, labels)``. ``cost`` is the model's
            ``inertia_`` converted to a string.
        """
        logger.info("Creating and fitting kmeans cluster")
        kmeans = KMeans(n_clusters=num_centroids, random_state=0)
        kmeans.fit(self.get_vector_array())

        # The cost is handed around as a string; calculate_best_k parses it back.
        cost = str(kmeans.inertia_)
        centroids = kmeans.cluster_centers_
        labels = kmeans.labels_

        if store:
            self.centroids = centroids
            self.algorithm = kmeans
            self.cost = cost
            self.labels = labels

        return kmeans, centroids, cost, labels

    def get_move_list(self):
        """Return a dictionary mapping each filename to its cluster label.

        The result is cached on ``self.move_list`` and reused by later calls.
        """
        if self.move_list is not None:
            return self.move_list

        self.create_algorithm_if_none()
        self.move_list = {
            filename: self.labels[idx] for idx, filename in enumerate(self.vectors)
        }
        return self.move_list

    def get_num_clusters(self):
        """Return the number of clusters.

        For affinity propagation this is the number of exemplars found by the
        fitted model (fitting it first if necessary); for kmeans it is the
        configured ``num_centroids``.
        """
        if self.algorithm_name == "affinity_propagation":
            # The count only exists on a fitted model, so make sure there is one.
            self.create_algorithm_if_none()
            return len(self.algorithm.cluster_centers_indices_)
        return self.num_centroids

    def get_closest_sample_filenames_to_centroids(self):
        """Return, for each centroid, the filename of the closest stored vector.

        Entry ``i`` of the returned list is the filename whose vector is
        nearest to centroid ``i``. The matching sample indexes are stored in
        ``self.closest``. The filenames are cached on ``self.closest_filenames``
        and reused by later calls.
        """
        if self.closest_filenames is not None:
            return self.closest_filenames

        # Centroids are only available after fitting.
        self.create_algorithm_if_none()

        closest, _ = pairwise_distances_argmin_min(
            self.centroids, self.get_vector_array()
        )
        self.closest = closest

        # Row i of the vector array belongs to the i-th inserted filename.
        filenames = list(self.vectors)
        self.closest_filenames = [filenames[sample_index] for sample_index in closest]
        return self.closest_filenames

    def visualize(self, tensorboard_dir, slides_dir=None):
        """Write a tensorboard embedding projection of the stored vectors.

        Args:
            tensorboard_dir: Directory handed to ``SummaryWriter``.
            slides_dir: Directory containing the image files named by the keys
                of ``self.vectors``. Defaults to ``self.slides_dir``, which
                this class does not set itself.

        Raises:
            AttributeError: If no slides directory is given and
                ``self.slides_dir`` is not set.
        """
        if slides_dir is None:
            slides_dir = getattr(self, "slides_dir", None)
        if slides_dir is None:
            raise AttributeError(
                "slides_dir is not set; pass slides_dir to visualize() or set "
                "the slides_dir attribute before calling it"
            )

        logger.info("Visualizing cluster")
        import torch
        from torch.utils.tensorboard import SummaryWriter

        images = []
        for current_image in self.vectors:
            # Close each image file as soon as it has been converted.
            with Image.open(os.path.join(slides_dir, current_image)) as img:
                # NOTE(review): torch.cat along dim 0 below presumably relies on
                # transform_image returning a batched tensor -- confirm.
                images.append(inference.transform_image(img).float())

        feature_vectors = self.get_vector_array()

        writer = SummaryWriter(tensorboard_dir)
        try:
            writer.add_embedding(
                feature_vectors, metadata=self.labels, label_img=torch.cat(images, 0)
            )
        finally:
            writer.close()

    def calculate_best_k(self, max_k=50):
        """Plot the kmeans cost against the number of centroids (elbow method).

        A kmeans model is fitted for every ``k`` in ``range(1, max_k)`` (so
        ``max_k`` itself is excluded) and its cost (squared error) is plotted.
        The point at which the graph becomes essentially linear is the optimal
        value of k. The plot is saved to ``best_k_value.png`` when matplotlib
        uses the Agg backend and shown interactively otherwise.

        Only works if the algorithm is "kmeans".

        Raises:
            AssertionError: If ``algorithm_name`` is not ``"kmeans"``.
        """
        # Elbow method: https://www.geeksforgeeks.org/elbow-method-for-optimal-value-of-k-in-kmeans/
        # Other methods: https://en.wikipedia.org/wiki/Determining_the_number_of_clusters_in_a_data_set
        if self.algorithm_name != "kmeans":
            raise AssertionError(
                "calculate_best_k only works when algorithm_name is 'kmeans'"
            )

        k_values = range(1, max_k)
        costs = []
        for k in k_values:
            _, _, cost, _ = self.create_kmeans(num_centroids=k, store=False)
            logger.info("Iteration %s: %s", k, cost)
            # create_kmeans returns the cost as a string; plot whole numbers.
            costs.append(int(float(cost)))

        # plot the cost against K values
        plt.plot(k_values, costs)
        plt.xlabel("Value of K")
        plt.ylabel("Squared Error (Cost)")

        # Ask matplotlib for the active backend through its public API and
        # compare case-insensitively, since the name is configured as "Agg".
        if mpl.get_backend().lower() == "agg":
            plt.savefig("best_k_value.png")
        else:
            plt.show()