Source code for lecture2notes.end_to_end.slide_classifier

import logging
import os
import shutil
from pathlib import Path

from PIL import Image
from tqdm import tqdm

from ..models.slide_classifier import inference  # pylint: disable=wrong-import-position
from ..models.slide_classifier.custom_nnmodules import *  # noqa: F403,F401
from .helpers import make_dir_if_not_exist

logger = logging.getLogger(__name__)


def classify_frames(
    frames_dir, do_move=True, incorrect_threshold=0.60, model_path="model_best.ckpt"
):
    """Classifies images in a directory using the slide classifier model.

    Every entry returned by ``os.listdir(frames_dir)`` is opened as an image and
    passed to ``inference.get_prediction``. The probability of the predicted
    class is recorded as that frame's certainty.

    Args:
        frames_dir (str or pathlib.Path): path to directory containing images to
            classify. The sorted output directory is ``frames_sorted`` next to
            this directory (i.e. inside its parent).
        do_move (bool, optional): move each image into
            ``frames_sorted/<predicted class>``. When False the images are left
            where they are and no output directories are created; only the
            certainties are computed. Defaults to True.
        incorrect_threshold (float, optional): the certainty value that the model
            must be below for a prediction to be marked "probably incorrect".
            Defaults to 0.60.
        model_path (str, optional): checkpoint path handed to
            ``inference.load_model``. Defaults to "model_best.ckpt".

    Returns:
        [tuple]: (frames_sorted_dir, certainties, percent_wrong) where
        ``frames_sorted_dir`` is a ``pathlib.Path``, ``certainties`` holds one
        value per frame in listing order, and ``percent_wrong`` is the percentage
        (0-100) of frames whose certainty was below ``incorrect_threshold``.
    """
    # Accept plain strings as well as Path objects: the path arithmetic below
    # needs Path semantics (``.parents`` and the ``/`` operator).
    frames_dir = Path(frames_dir)

    model = inference.load_model(model_path)

    certainties = []
    frames_sorted_dir = frames_dir.parents[0] / "frames_sorted"
    logger.debug("Received inputs:\nframes_dir=%s", frames_dir)

    frames = os.listdir(frames_dir)
    num_frames = len(frames)
    num_incorrect = 0
    percent_wrong = 0

    logger.info("Ready to classify %s frames", num_frames)
    frames_tqdm = tqdm(frames, total=num_frames, desc="Classifying Frames")
    for frame in frames_tqdm:
        current_frame_path = os.path.join(frames_dir, frame)

        # Run classification. The image is opened in a ``with`` block so the
        # file handle is released before the frame is (possibly) moved below.
        with Image.open(current_frame_path) as frame_image:
            best_guess, best_guess_idx, probs, _ = inference.get_prediction(
                model, frame_image, extract_features=False
            )  # pylint: disable=no-member

        # ``probs`` maps class -> probability; pick the predicted class's value
        # by position.
        prob_max_correct = list(probs.values())[best_guess_idx]
        certainties.append(prob_max_correct)
        logger.debug("Prediction is %s", best_guess)
        logger.debug("Probabilities are %s", probs)

        if prob_max_correct < incorrect_threshold:
            num_incorrect += 1
            # num_frames is non-zero here because the loop body only runs when
            # at least one frame exists.
            percent_wrong = (num_incorrect / num_frames) * 100
            frames_tqdm.set_postfix(
                {"num_incorrect": num_incorrect, "percent_wrong": int(percent_wrong)}
            )

        if do_move:
            classified_image_dir = frames_sorted_dir / best_guess
            make_dir_if_not_exist(classified_image_dir)
            shutil.move(str(current_frame_path), str(classified_image_dir))

    logger.info("Percent frames classified incorrectly: %s", percent_wrong)
    logger.debug("Returning frames_sorted_dir=%s", frames_sorted_dir)
    return frames_sorted_dir, certainties, percent_wrong