From 35e552ada03ae1ea8c732a79ab4d867201c0af56 Mon Sep 17 00:00:00 2001 From: Sayat Mimar Date: Wed, 5 Jul 2023 14:25:07 -0400 Subject: [PATCH] add new predict code --- .../Codes/IterativePredict_1X.py | 388 +++++++++++------- 1 file changed, 238 insertions(+), 150 deletions(-) diff --git a/histomicstk/segmentationschool/Codes/IterativePredict_1X.py b/histomicstk/segmentationschool/Codes/IterativePredict_1X.py index 039b342..7dd4719 100644 --- a/histomicstk/segmentationschool/Codes/IterativePredict_1X.py +++ b/histomicstk/segmentationschool/Codes/IterativePredict_1X.py @@ -1,60 +1,51 @@ import cv2 import numpy as np import os -import json import sys -import girder_client -# import argparse -# import multiprocessing +import argparse +import multiprocessing import lxml.etree as ET -# import warnings -# import time -# import copy -# from PIL import Image +import warnings +import time +import copy +from PIL import Image import glob -from .xml_to_json import convert_xml_json -# from subprocess import call -# from joblib import Parallel, delayed -# from skimage.io import imread,imsave -# from skimage.segmentation import clear_border +from subprocess import call +from joblib import Parallel, delayed +from skimage.io import imread,imsave +from skimage.segmentation import clear_border from tqdm import tqdm -# from skimage.transform import resize +from skimage.transform import resize from shutil import rmtree -# import matplotlib.pyplot as plt -# from matplotlib import path -# import detectron2 +import matplotlib.pyplot as plt +from matplotlib import path +import detectron2 from detectron2.engine import DefaultPredictor from detectron2.config import get_cfg -# from detectron2.utils.visualizer import Visualizer -# from detectron2.data import MetadataCatalog, DatasetCatalog +from detectron2.utils.visualizer import Visualizer +from detectron2.data import MetadataCatalog, DatasetCatalog from detectron2 import model_zoo -from .get_dataset_list import decode_panoptic +from .get_dataset_list import * from scipy.ndimage.morphology import binary_fill_holes -# import tifffile as ti -import tiffslide as openslide -# from skimage.morphology import binary_erosion, disk -from scipy.ndimage import zoom -# import warnings -import torch - +import openslide +from skimage.morphology import binary_erosion, disk +import warnings from skimage.color import rgb2hsv from skimage.filters import gaussian -# from skimage.segmentation import clear_border +from skimage.segmentation import clear_border + -#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" -NAMES = ['cortical_interstitium','medullary_interstitium','non_globally_sclerotic_glomeruli','globally_sclerotic_glomeruli','tubules','arteries/arterioles'] -# from IterativeTraining import get_num_classes -# from .get_choppable_regions import get_choppable_regions -# from .get_network_performance import get_perf +#from .IterativeTraining import get_num_classes +from .get_choppable_regions import get_choppable_regions +from .get_network_performance import get_perf """ Pipeline code to segment regions from WSI """ -# os.environ['CUDA_VISIBLE_DEVICES']='0,1' # define xml class colormap xml_color = [65280, 16776960,65535, 255, 16711680, 33023] @@ -112,50 +103,31 @@ def predict(args): downsample = int(args.downsampleRateHR**.5) region_size = int(args.boxSize*(downsample)) step = int((region_size-(args.bordercrop*2))*(1-args.overlap_percentHR)) - # gc = girder_client.GirderClient(apiUrl=args.girderApiUrl) - # gc.setToken(args.girderToken) - # project_folder = args.project - # project_dir_id = project_folder.split('/')[-2] - #model_file = args.modelfile - #print(model_file,'here model') - #model_file_id = model_file .split('/')[-2] - - print('Handcoded iteration') + + print('Handcoded iteration') iteration=1 print(iteration) - dirs['xml_save_dir'] = args.base_dir - #real_path = os.path.realpath(args.project) - #print(real_path) + dirs['xml_save_dir'] = args.base_dir + '/' + 'test_data' + if iteration == 'none': print('ERROR: no trained models found \n\tplease use [--option train]') else: # check main directory exists - # make_folder(dirs['outDir']) - # outdir = gc.createFolder(project_directory_id,args.outDir) - # it = gc.createFolder(outdir['_id'],str(iteration)) + make_folder(dirs['outDir']) + make_folder(dirs['xml_save_dir']) # get all WSIs - #WSIs = [] + # WSIs = [] # usable_ext=args.wsi_ext.split(',') # for ext in usable_ext: - # WSIs.extend(glob.glob(args.project + '/*' + ext)) - # print('another one') - - # for file in args.files: - # print(file) - # slidename = file['name'] - # _ = os.system("printf '\n---\n\nFOUND: [{}]\n'".format(slidename)) - # WSIs.append(slidename) - - - # print(len(WSIs), 'number of WSI' ) + # WSIs.extend(glob.glob(args.base_dir + '/' + args.project + dirs['training_data_dir'] + str(iteration) + '/*' + ext)) print('Building network configuration ...\n') - #modeldir = args.project + dirs['modeldir'] + str(iteration) + '/HR' + #modeldir = args.base_dir + '/' + args.project + dirs['modeldir'] + str(iteration) + '/HR' + + os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu) - os.environ["CUDA_VISIBLE_DEVICES"]="0,1" - cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_50_3x.yaml")) cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[32],[64],[128], [256], [512], [1024]] @@ -172,11 +144,8 @@ def predict(args): else: cfg.INPUT.MIN_SIZE_TEST=int(region_size/2) cfg.INPUT.MAX_SIZE_TEST=int(region_size/2) - - cfg.MODEL.WEIGHTS = args.modelfile - tc=['G','SG','T','A'] sc=['Ob','C','M','B'] classNum=len(tc)+len(sc)-1 @@ -191,7 +160,6 @@ def predict(args): predictor = DefaultPredictor(cfg) broken_slides=[] for wsi in [args.files]: - # try: # except Exception as e: @@ -204,14 +172,11 @@ def predict(args): extname = extsplit[-1] print(basename) # print(extname) - # try: - slide=openslide.TiffSlide(wsi) - print(wsi,'here/s the silde') - # slide = ti.imread(wsi) - - # except: - # broken_slides.append(wsi) - # continue + try: + slide=openslide.OpenSlide(wsi) + except: + broken_slides.append(wsi) + continue # continue # get image dimensions if extname=='.scn': @@ -225,14 +190,12 @@ def predict(args): offsetx=0 offsety=0 - print(dim_x,dim_y) + fileID=basename.split('/') dirs['fileID'] = fileID[-1] - dirs['extension'] = extname - dirs['file_name'] = wsi.split('/')[-1] - wsiMask = np.zeros([dim_y, dim_x], dtype='uint8') + wsiMask = np.zeros([dim_y, dim_x]).astype(np.uint8) index_y=np.array(range(offsety,dim_y+offsety,step)) index_x=np.array(range(offsetx,dim_x+offsetx,step)) @@ -254,11 +217,14 @@ def predict(args): binary=(g>0.05).astype('bool') binary=binary_fill_holes(binary) + xmlbuilder=XMLBuilder(dirs['xml_save_dir']+'/'+dirs['fileID']+'.xml',xml_color) + print('Segmenting tissue ...\n') totalpatches=len(index_x)*len(index_y) with tqdm(total=totalpatches,unit='image',colour='green',desc='Total WSI progress') as pbar: for i,j in coordinate_pairs(index_y,index_x): - + # for i in tqdm(index_y,unit='strip',colour='green',desc='outer y-index iterator'): + # for j in tqdm(index_x,leave=False,unit='image',colour='blue',desc='inner x-index iterator'): yEnd = min(dim_y+offsety,i+region_size) xEnd = min(dim_x+offsetx,j+region_size) # yStart_small = int(np.round((i-offsety)/resRatio)) @@ -280,27 +246,11 @@ def predict(args): dyS=i dxE=j+xLen dyE=i+yLen - print(xLen,yLen) - print('here is the length') + im=np.array(slide.read_region((dxS,dyS),0,(xLen,yLen)))[:,:,:3] - #print(sys.getsizeof(im), 'first') - #UPSAMPLE - im = zoom(im,(4,4,1),order=1) - print(sys.getsizeof(im), 'second') + panoptic_seg, segments_info = predictor(im)["panoptic_seg"] - del im - torch.cuda.empty_cache() - print(sys.getsizeof(panoptic_seg), 'third') - print(sys.getsizeof(segments_info), 'forth') maskpart=decode_panoptic(panoptic_seg.to("cpu").numpy(),segments_info,'kidney',args) - del panoptic_seg, segments_info - #outImageName=basename+'_'.join(['',str(dxS),str(dyS)]) - #print(sys.getsizeof(maskpart), 'fifth') - #DOWNSAMPLE - maskpart=zoom(maskpart,(0.25,0.25),order=0) - #print(sys.getsizeof(maskpart), 'sixth') - - # imsave(outImageName+'_p.png',maskpart) if dxE != dim_x: maskpart[:,-int(args.bordercrop/2):]=0 if dyE != dim_y: @@ -325,9 +275,7 @@ def predict(args): wsiMask[dyS:dyE,dxS:dxE]=np.maximum(maskpart, wsiMask[dyS:dyE,dxS:dxE]) - - del maskpart - torch.cuda.empty_cache() + # wsiMask[dyS:dyE,dxS:dxE]=maskpart # print('showing mask') @@ -341,17 +289,15 @@ def predict(args): # xmlbuilder.sew(args) # xmlbuilder.dump_to_xml(args,offsetx,offsety) if extname=='.scn': - print('here writing 1') xml_suey(wsiMask=wsiMask, dirs=dirs, args=args, classNum=classNum, downsample=downsample,glob_offset=[offsetx,offsety]) else: - print('here writing 2') xml_suey(wsiMask=wsiMask, dirs=dirs, args=args, classNum=classNum, downsample=downsample,glob_offset=[0,0]) print('\n\n\033[92;5mPlease correct the xml annotations found in: \n\t' + dirs['xml_save_dir']) - print('\nthen place them in: \n\t'+ dirs['training_data_dir'] + str(iteration) + '/') + print('\nthen place them in: \n\t'+ args.base_dir + '/' + args.project + dirs['training_data_dir'] + str(iteration) + '/') print('\nand run [--option train]\033[0m\n') print('The following slides were not openable by openslide:') print(broken_slides) @@ -364,7 +310,7 @@ def coordinate_pairs(v1,v2): for j in v2: yield i,j def get_iteration(args): - currentmodels=os.listdir(args.base_dir) + currentmodels=os.listdir(args.base_dir + '/' + args.project + '/MODELS/') if not currentmodels: return 'none' else: @@ -389,12 +335,8 @@ def get_test_model(modeldir): return ''.join([modeldir,'/model_',maxmodel,'.pth']) def make_folder(directory): - print(directory,'predict dir') - #if not os.path.exists(directory): - try: + if not os.path.exists(directory): os.makedirs(directory) # make directory if it does not exit already # make new directory - except: - print('folder exists!') def restart_line(): # for printing chopped image labels in command line sys.stdout.write('\r') @@ -402,7 +344,7 @@ def restart_line(): # for printing chopped image labels in command line def getWsi(path): #imports a WSI import openslide - slide = openslide.TiffSlide(path) + slide = openslide.OpenSlide(path) return slide def file_len(fname): # get txt file length (number of lines) @@ -429,38 +371,18 @@ def xml_suey(wsiMask, dirs, args, classNum, downsample,glob_offset): # print output print('\t working on: annotationID ' + str(value)) # get only 1 class binary mask - binary_mask = np.zeros(np.shape(wsiMask),dtype='uint8') + binary_mask = np.zeros(np.shape(wsiMask)).astype('uint8') binary_mask[wsiMask == value] = 1 # add mask to xml pointsList = get_contour_points(binary_mask, args=args, downsample=downsample,value=value,offset={'X':glob_offset[0],'Y':glob_offset[1]}) - for i in range(len(pointsList)): + for i in range(np.shape(pointsList)[0]): pointList = pointsList[i] Annotations = xml_add_region(Annotations=Annotations, pointList=pointList, annotationID=value) # save xml - folder = args.base_dir - girder_folder_id = folder.split('/')[-2] - _ = os.system("printf 'Using data from girder_client Folder: {}\n'".format(folder)) - file_name = dirs['file_name'] - print(file_name) - gc = girder_client.GirderClient(apiUrl=args.girderApiUrl) - gc.setToken(args.girderToken) - files = list(gc.listItem(girder_folder_id)) - # dict to link filename to gc id - item_dict = dict() - for file in files: - d = {file['name']:file['_id']} - item_dict.update(d) - print(item_dict) - print(item_dict[file_name]) - annots = convert_xml_json(Annotations, NAMES) - for annot in annots: - _ = gc.post(path='annotation',parameters={'itemId':item_dict[file_name]}, data = json.dumps(annot)) - print('uploating layers') - print('annotation uploaded...\n') - - + print(dirs['xml_save_dir']+'/'+dirs['fileID']+'.xml') + xml_save(Annotations=Annotations, filename=dirs['xml_save_dir']+'/'+dirs['fileID']+'.xml') def get_contour_points(mask, args, downsample,value, offset={'X': 0,'Y': 0}): # returns a dict pointList with point 'X' and 'Y' values @@ -469,7 +391,7 @@ def get_contour_points(mask, args, downsample,value, offset={'X': 0,'Y': 0}): pointsList = [] #maskPoints2=copy.deepcopy(maskPoints) - for j in np.array(range(len(maskPoints))): + for j in np.array(range(np.shape(maskPoints)[0])): if len(maskPoints[j])>2: #m=np.squeeze(np.asarray(maskPoints2[j])) #xMax=np.max(m[:,1]) @@ -486,11 +408,11 @@ def get_contour_points(mask, args, downsample,value, offset={'X': 0,'Y': 0}): if cv2.contourArea(maskPoints[j]) > args.min_size[value-1]: pointList = [] - for i in np.array(range(0,len(maskPoints[j]),4)): + for i in np.array(range(0,np.shape(maskPoints[j])[0],4)): point = {'X': (maskPoints[j][i][0][0] * downsample) + offset['X'], 'Y': (maskPoints[j][i][0][1] * downsample) + offset['Y']} pointList.append(point) pointsList.append(pointList) - return pointsList + return np.array(pointsList) ### functions for building an xml tree of annotations ### def xml_create(): # create new xml tree @@ -525,14 +447,180 @@ def xml_add_region(Annotations, pointList, annotationID=-1, regionID=None): # ad ET.SubElement(Vertices, 'Vertex', attrib={'X': str(pointList[0]['X']), 'Y': str(pointList[0]['Y']), 'Z': '0'}) return Annotations -# def xml_save(Annotations, filename): -# xml_data = ET.tostring(Annotations, pretty_print=True) -# #xml_data = Annotations.toprettyxml() -# f = open(filename, 'w') -# f.write(xml_data.decode()) -# f.close() - -# def read_xml(filename): -# # import xml file -# tree = ET.parse(filename) -# root = tree.getroot() +def xml_save(Annotations, filename): + xml_data = ET.tostring(Annotations, pretty_print=True) + #xml_data = Annotations.toprettyxml() + f = open(filename, 'wb') + f.write(xml_data) + f.close() + +def read_xml(filename): + # import xml file + tree = ET.parse(filename) + root = tree.getroot() + + +class XMLBuilder(): + def __init__(self,out_file,class_colors): + self.dump_contours={'1':[],'2':[],'3':[],'4':[],'5':[]} + self.merge_contours={'1':[],'2':[],'3':[],'4':[],'5':[]} + self.out_file=out_file + self.class_colors=class_colors + def unique_pairs(self,n): + for i in range(n): + for j in range(i+1, n): + yield i, j + def deconstruct(self,mask,offsetx,offsety,args): + classes_in_mask=np.unique(mask) + classes_in_mask=classes_in_mask[classes_in_mask>0] + for value in classes_in_mask: + submask=np.array(mask==value).astype('uint8') + contours, hierarchy = cv2.findContours(submask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + contours=np.array(contours) + for contour in contours: + merge_or_dump=False + for point in contour: + if point[0][0]<15 or point[0][1]<15 or point[0][0]>args.boxSize-15 or point[0][1]>args.boxSize-15: + merge_or_dump=True + break + points=np.asarray(contour) + points[:,0,0]+=offsetx + points[:,0,1]+=offsety + + if merge_or_dump: + self.merge_contours[str(value)].append({'contour':np.squeeze(points,axis=1),'annotationID':value}) + else: + self.dump_contours[str(value)].append({'contour':np.squeeze(points,axis=1),'annotationID':value}) + + def sew(self,args): + for cID in range(1,args.classNum): + print('Merging class... '+ str(cID)) + + did_merge=True + while did_merge: + + did_merge=self.check_and_merge_once(cID) + print('\n') + def check_and_merge_once(self,cID): + contours_at_value=self.merge_contours[str(cID)] + total=len(contours_at_value) + + + print('Total contours... '+ str(total),end='\r') + for idx1,idx2 in self.unique_pairs(total): + containPath=path.Path(contours_at_value[idx1]['contour']) + # print(containPath.contains_points(contour2['contour'])) + ovlpts=containPath.contains_points(contours_at_value[idx2]['contour']) + + if any(ovlpts): + mergePath=path.Path(contours_at_value[idx2]['contour']) + merged_verts=np.concatenate((containPath.vertices,mergePath.vertices),axis=0) + merged_path=path.Path(merged_verts) + bMinX=np.min(merged_verts[:,1]).astype('int32') + bMaxX=np.max(merged_verts[:,1]).astype('int32') + bMinY=np.min(merged_verts[:,0]).astype('int32') + bMaxY=np.max(merged_verts[:,0]).astype('int32') + # testim=np.zeros((bMaxX,bMaxY)).astype('uint8') + # cv2.fillPoly(testim,[np.array(mergePath.vertices).astype('int32')],255) + # # plt.imshow(testim) + # # plt.title('mergee') + # # plt.show() + # cv2.fillPoly(testim,[np.array(containPath.vertices).astype('int32')],128) + # plt.imshow(testim) + # plt.title('merge and contain') + # plt.show() + + testim=np.zeros((bMaxX-bMinX,bMaxY-bMinY)).astype('uint8') + testim=np.pad(testim,((0,1),(0,1))) + + #add offsets back + cvl=[np.array(containPath.vertices).astype('int32')] + mvl=[np.array(mergePath.vertices).astype('int32')] + m_dvl=[np.array(merged_path.vertices).astype('int32')] + cvl[0][:,1]-=bMinX + cvl[0][:,0]-=bMinY + mvl[0][:,1]-=bMinX + mvl[0][:,0]-=bMinY + m_dvl[0][:,1]-=bMinX + m_dvl[0][:,0]-=bMinY + + cv2.fillPoly(testim,cvl,1) + cv2.fillPoly(testim,mvl,1) + # plt.imshow(testim) + # plt.title('merged') + # plt.show() + contours, hierarchy = cv2.findContours(testim, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS) + points=np.asarray(contours[0]) + + points[:,0,0]+=bMinY + points[:,0,1]+=bMinX + + self.merge_contours[str(cID)].pop(idx2) + self.merge_contours[str(cID)].pop(idx1) + self.merge_contours[str(cID)].append({'contour':np.squeeze(points,axis=1),'annotationID':cID}) + + return True + return False + # testim=np.zeros((bMaxX-bMinX,bMaxY-bMinY)).astype('uint8') + # cv2.fillPoly(testim,contours,255) + # plt.imshow(testim) + # plt.title('merged') + # plt.show() + + + # input('1') + + def dump_to_xml(self,args,offsetx,offsety): + # make xml + self.Annotations = ET.Element('Annotations') + + # add annotation + for i in range(args.classNum)[1:]: # exclude background class + print('\t working on: annotationID ' + str(i)) + Annotations = self.xml_add_annotation(annotationID=i) + # for dump_contour in self.merge_contours[str(i)]: + # pointList=dump_contour['contour'] + # pointList[:,0]+=offsetx + # pointList[:,1]+=offsety + # self.xml_add_region(pointList=pointList, annotationID=i) + for dump_contour in self.dump_contours[str(i)]: + pointList=dump_contour['contour'] + pointList[:,0]+=offsetx + pointList[:,1]+=offsety + self.xml_add_region(pointList=pointList, annotationID=i) + self.xml_save() + + + def xml_add_annotation(self, annotationID=None): # add new annotation + # add new Annotation to Annotations + # defualts to new annotationID + if annotationID == None: # not specified + annotationID = len(self.Annotations.findall('Annotation')) + 1 + Annotation = ET.SubElement(self.Annotations, 'Annotation', attrib={'Type': '4', + 'Visible': '1', 'ReadOnly': '0', 'Incremental': '0', 'LineColorReadOnly': '0', + 'LineColor': str(self.class_colors[annotationID-1]), 'Id': str(annotationID), 'NameReadOnly': '0'}) + Regions = ET.SubElement(Annotation, 'Regions') + # return Annotations + + def xml_add_region(self,pointList, annotationID=-1, regionID=None): # add new region to annotation + # add new Region to Annotation + # defualts to last annotationID and new regionID + Annotation = self.Annotations.find("Annotation[@Id='" + str(annotationID) + "']") + Regions = Annotation.find('Regions') + if regionID == None: # not specified + regionID = len(Regions.findall('Region')) + 1 + Region = ET.SubElement(Regions, 'Region', attrib={'NegativeROA': '0', 'ImageFocus': '-1', 'DisplayId': '1', 'InputRegionId': '0', 'Analyze': '0', 'Type': '0', 'Id': str(regionID)}) + Vertices = ET.SubElement(Region, 'Vertices') + for point in pointList: # add new Vertex + ET.SubElement(Vertices, 'Vertex', attrib={'X': str(point[0]), 'Y': str(point[1]), 'Z': '0'}) + # add connecting point + ET.SubElement(Vertices, 'Vertex', attrib={'X': str(pointList[0][0]), 'Y': str(pointList[0][1]), 'Z': '0'}) + # return Annotations + + def xml_save(self): + xml_data = ET.tostring(self.Annotations, pretty_print=True) + #xml_data = Annotations.toprettyxml() + print('Writing... ' + self.out_file) + f = open(self.out_file, 'wb') + f.write(xml_data) + f.close()