# Source code for PVBM.DiscSegmenter

import os
import requests
import onnxruntime as ort
import numpy as np
import PIL
from PIL import Image
from torchvision import transforms
import cv2
import gdown


class DiscSegmenter:
    """A class that performs optic disc segmentation.

    The segmentation itself is done by an ONNX model stored next to this
    module (downloaded on first use); :meth:`post_processing` turns a
    segmentation mask into concentric circular zones around the disc.
    """

    def __init__(self):
        """Initialize the DiscSegmenter class with image size and model path.

        The model file is downloaded immediately if it is not already
        present next to this module.
        """
        # Side length of the square image that is fed to the network.
        self.img_size = 512
        script_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(script_dir, "lunetv2_odc.onnx")
        self.download_model()

    def download_model(self):
        """Download the ONNX model from Google Drive if it does not exist."""
        file_id = "116EEFBn7qr_LpCBb8GBuyzpa_KGp4xPX"
        url = f"https://drive.google.com/uc?id={file_id}"
        print(f"Model path: {self.model_path}")

        if not os.path.exists(self.model_path):
            print("Downloading model from Google Drive...")
            gdown.download(url, self.model_path, quiet=False)
            print(f"Model downloaded to {self.model_path}")
        else:
            print("Model already exists, skipping download.")

    def find_biggest_contour(self, contours):
        """
        Find the contour with the largest minimum enclosing circle.

        :param contours: Non-empty sequence of contours.
        :type contours: list of numpy arrays
        :return: The center ``(x, y)``, the radius (both truncated to int)
            and the selected contour.
        :rtype: tuple
        :raises IndexError: If ``contours`` is empty.
        """
        best_contour = contours[0]
        (best_x, best_y), best_radius = cv2.minEnclosingCircle(best_contour)

        # Compare the untruncated float radii: truncating the running best
        # before comparing lets a slightly smaller circle replace a larger one
        # (e.g. 5.3 > int(5.7)).
        for contour in contours[1:]:
            (x, y), candidate_radius = cv2.minEnclosingCircle(contour)
            if candidate_radius > best_radius:
                best_x, best_y = x, y
                best_radius = candidate_radius
                best_contour = contour

        return (int(best_x), int(best_y)), int(best_radius), best_contour

    def post_processing(self, segmentation, max_roi_size):
        """
        Post-process the segmentation result to extract relevant zones.

        Filled circles of one, two and three times the disc radius, and one of
        radius ``max_roi_size``, are drawn around the center of the biggest
        contour. If no contour can be processed, every mask is left empty, the
        radius is 0 and the center is the middle of the image.

        :param segmentation: Segmentation result as a 2D numpy array.
        :type segmentation: numpy array
        :param max_roi_size: Radius, in pixels, of the region of interest.
        :type max_roi_size: int
        :return: The center ``(x, y)``, the radius, the region of interest
            (``H x W x 3`` uint8, drawn in the green channel) and zones ABC
            (``H x W x 4`` float array).
        :rtype: tuple
        """
        segmentation = np.array(segmentation, dtype=np.uint8)
        height, width = segmentation.shape[0], segmentation.shape[1]

        def blank_canvas():
            return np.zeros((height, width, 3), dtype=np.uint8)

        try:
            # findContours returns (contours, hierarchy) in OpenCV 4.x but
            # (image, contours, hierarchy) in 3.x; [-2] is the contours in both.
            contours = cv2.findContours(
                segmentation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            center, radius, _ = self.find_biggest_contour(contours)

            discs = []
            for disc_radius in (radius, radius * 2, radius * 3, max_roi_size):
                canvas = blank_canvas()
                cv2.circle(canvas, center, disc_radius, (0, 255, 0), -1)
                discs.append(canvas)
            one_radius, two_radius, three_radius, roi = discs
        except Exception:
            # Best-effort fallback (e.g. empty mask -> no contours): return
            # empty masks instead of failing. Only Exception is caught so that
            # KeyboardInterrupt / SystemExit still propagate.
            one_radius = blank_canvas()
            two_radius = blank_canvas()
            three_radius = blank_canvas()
            roi = blank_canvas()
            # (x, y) order, i.e. (column, row), to match the OpenCV center
            # returned on the success path.
            center = (width // 2, height // 2)
            radius = 0

        # Channel 1 holds the one-radius disc, channel 0 the two-radius disc,
        # channel 2 the three-radius disc; channel 3 is half of the
        # element-wise maximum of the one- and three-radius discs.
        zones_ABC = np.zeros((height, width, 4))
        zones_ABC[:, :, :3] = one_radius
        zones_ABC[:, :, 0] = two_radius[:, :, 1]
        zones_ABC[:, :, 2] = three_radius[:, :, 1]
        zones_ABC[:, :, 3] = np.maximum(one_radius[:, :, 1], three_radius[:, :, 1]) / 2

        return center, radius, roi, zones_ABC

    def segment(self, image_path):
        """
        Perform the optic disc segmentation given an image path.

        :param image_path: Path to the image.
        :type image_path: str
        :return: A PIL Image containing the Optic Disc segmentation
            (0 / 255 mask at the original image size).
        :rtype: PIL.Image
        """
        session = ort.InferenceSession(self.model_path)
        input_name = session.get_inputs()[0].name

        # The context manager closes the underlying file once the pixels have
        # been copied by convert()/resize().
        with PIL.Image.open(image_path) as img_orig:
            original_size = img_orig.size
            # Force three channels: the normalization below is defined for RGB
            # and fails on grayscale, palette or RGBA inputs.
            image = img_orig.convert("RGB").resize((self.img_size, self.img_size))

        image = transforms.ToTensor()(image)
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        image = normalize(image)
        image_np = np.expand_dims(image.numpy(), axis=0)

        outputs = session.run(None, {input_name: image_np})
        # First output, first batch item, first channel; positive values are
        # taken as optic disc pixels.
        od = outputs[0][0, 0] > 0

        mask = PIL.Image.fromarray(np.array(od, dtype=np.uint8) * 255)
        return mask.resize(original_size, PIL.Image.Resampling.NEAREST)