293 changes: 185 additions & 108 deletions histomicstk/segmentationschool/Codes/IterativePredict_1X.py
@@ -1,4 +1,5 @@
import cv2
from patchify import patchify, unpatchify
import numpy as np
import os
import json
@@ -36,8 +37,7 @@
from scipy.ndimage import zoom
# import warnings
import torch


from torch.utils.data import DataLoader

from skimage.color import rgb2hsv
from skimage.filters import gaussian
@@ -93,6 +93,25 @@ def decode_panoptic(image,segments_info,organType,args):
return out.astype('uint8')


class NewPredictor(DefaultPredictor):
    def __call__(self, original_images: list):
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to every image in the batch.
            if self.input_format == "RGB":
                # the model expects BGR inputs, so flip the channel order
                original_images = [im[:, :, ::-1] for im in original_images]
            inputs_list = []
            for original_image in original_images:
                # height/width are taken before the 4x upsample, so detectron2's
                # postprocessing resizes predictions back to the original tile size
                height, width = original_image.shape[:2]
                original_image = zoom(original_image, (4, 4, 1), order=1)
                image = self.aug.get_transform(original_image).apply_image(original_image)
                image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

                inputs = {"image": image, "height": height, "width": width}
                inputs_list.append(inputs)
            # single batched forward pass over all tiles
            predictions = self.model(inputs_list)
            return predictions
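# For reference, a minimal usage sketch of the batched call above (not part of
# the pipeline); `tile_list` is a hypothetical list of H x W x 3 uint8 arrays
# and cfg is assumed to be configured as in predict():
#   predictor = NewPredictor(cfg)
#   outputs = predictor(list(tile_list))  # one prediction dict per input tile
#   for out in outputs:
#       panoptic_seg, segments_info = out["panoptic_seg"]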

def predict(args):
# define folder structure dict
@@ -188,7 +207,8 @@ def predict(args):
# cfg.MODEL.PANOPTIC_FPN.INSTANCES_CONFIDENCE_THRESH = args.roi_thresh
# cfg.MODEL.PANOPTIC_FPN.OVERLAP_THRESH = 1

predictor = DefaultPredictor(cfg)
#predictor = DefaultPredictor(cfg)
new_predictor = NewPredictor(cfg)
broken_slides=[]
for wsi in [args.files]:
print(wsi.split('/')[-1])
@@ -208,8 +228,7 @@
slide=openslide.TiffSlide(wsi)
print(wsi, "here's the slide")
# slide = ti.imread(wsi)

# except:
# except:
# broken_slides.append(wsi)
# continue
# continue
@@ -254,86 +273,162 @@ def predict(args):
binary=binary_fill_holes(binary)

print('Segmenting tissue ...\n')
totalpatches=len(index_x)*len(index_y)
with tqdm(total=totalpatches,unit='image',colour='green',desc='Total WSI progress') as pbar:
for i,j in coordinate_pairs(index_y,index_x):


image_width, image_height = slide.dimensions  # the openslide-style API reports (width, height)

channel_count = 3
patch_height, patch_width = 2048, 2048
print(step)
print(thumbIm.shape)
patch_shape = (patch_height, patch_width, channel_count)
patches = patchify(np.array(slide.get_thumbnail((fullSize[0], fullSize[1]))), patch_shape, step=step)
count = 1
zoomed_patches = []
#output_patches = np.empty((patch_height, patch_width)).astype(np.uint8)
maskparts = []
for i in range(patches.shape[0]):
    for j in range(patches.shape[1]):
        im = patches[i, j, 0]
        print(im.shape)
        zoomed_patches.append(im)
        # run inference in batches of 3 tiles to bound GPU memory use
        if count % 3 == 0:
            predictions = new_predictor(zoomed_patches)
            for prediction in predictions:
                panoptic_seg, segments_info = prediction["panoptic_seg"]
                maskpart = decode_panoptic(panoptic_seg.to("cpu").numpy(), segments_info, 'kidney', args)
                maskpart = zoom(maskpart, (0.25, 0.25), order=0)
                maskparts.append(maskpart)
            zoomed_patches = []
            # output_patches[i,j,0] = maskpart
        count += 1
# flush the leftover tiles when the patch count is not a multiple of 3
if zoomed_patches:
    predictions = new_predictor(zoomed_patches)
    for prediction in predictions:
        panoptic_seg, segments_info = prediction["panoptic_seg"]
        maskpart = decode_panoptic(panoptic_seg.to("cpu").numpy(), segments_info, 'kidney', args)
        maskpart = zoom(maskpart, (0.25, 0.25), order=0)
        maskparts.append(maskpart)
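# A minimal sketch of reassembling the collected per-patch masks with
# unpatchify (imported above), assuming step equals the patch size
# (non-overlapping patches) and mask_h, mask_w are the per-patch mask
# dims after the 0.25 zoom (512 x 512 for 2048-px patches):
#   grid = np.array(maskparts).reshape(patches.shape[0], patches.shape[1], mask_h, mask_w)
#   small_mask = unpatchify(grid, (patches.shape[0] * mask_h, patches.shape[1] * mask_w))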



# zoomed_patches=[]
# for i in range(patches.shape[0]):
# for j in range(patches.shape[1]):
# im = patches[i, j, 0]
# im = zoom(im,(4,4,1),order=1)
# print(im.shape)
# zoomed_patches.append(im)
# print(len(zoomed_patches))
# totalpatches=len(index_x)*len(index_y)
# all_patches = []
# with tqdm(total=totalpatches,unit='image',colour='green',desc='Total WSI progress') as pbar:
# for i,j in coordinate_pairs(index_y,index_x):

yEnd = min(dim_y+offsety,i+region_size)
xEnd = min(dim_x+offsetx,j+region_size)
# yStart_small = int(np.round((i-offsety)/resRatio))
# yStop_small = int(np.round(((i-offsety)+args.boxSize)/resRatio))
# xStart_small = int(np.round((j-offsetx)/resRatio))
# xStop_small = int(np.round(((j-offsetx)+args.boxSize)/resRatio))
yStart_small = int(np.round((i-offsety)/resRatio))
yStop_small = int(np.round(((yEnd-offsety))/resRatio))
xStart_small = int(np.round((j-offsetx)/resRatio))
xStop_small = int(np.round(((xEnd-offsetx))/resRatio))
box_total=(xStop_small-xStart_small)*(yStop_small-yStart_small)
pbar.update(1)
if np.sum(binary[yStart_small:yStop_small,xStart_small:xStop_small])>(args.white_percent*box_total):

xLen=xEnd-j
yLen=yEnd-i

dxS=j
dyS=i
dxE=j+xLen
dyE=i+yLen
print(xLen,yLen)
print('here is the length')
im=np.array(slide.read_region((dxS,dyS),0,(xLen,yLen)))[:,:,:3]
#print(sys.getsizeof(im), 'first')
#UPSAMPLE
im = zoom(im,(4,4,1),order=1)
print(sys.getsizeof(im), 'second')
panoptic_seg, segments_info = predictor(im)["panoptic_seg"]
del im
torch.cuda.empty_cache()
print(sys.getsizeof(panoptic_seg), 'third')
print(sys.getsizeof(segments_info), 'fourth')
maskpart=decode_panoptic(panoptic_seg.to("cpu").numpy(),segments_info,'kidney',args)
del panoptic_seg, segments_info
#outImageName=basename+'_'.join(['',str(dxS),str(dyS)])
#print(sys.getsizeof(maskpart), 'fifth')
#DOWNSAMPLE
maskpart=zoom(maskpart,(0.25,0.25),order=0)
#print(sys.getsizeof(maskpart), 'sixth')
# yEnd = min(dim_y+offsety,i+region_size)
# xEnd = min(dim_x+offsetx,j+region_size)
# # yStart_small = int(np.round((i-offsety)/resRatio))
# # yStop_small = int(np.round(((i-offsety)+args.boxSize)/resRatio))
# # xStart_small = int(np.round((j-offsetx)/resRatio))
# # xStop_small = int(np.round(((j-offsetx)+args.boxSize)/resRatio))
# yStart_small = int(np.round((i-offsety)/resRatio))
# yStop_small = int(np.round(((yEnd-offsety))/resRatio))
# xStart_small = int(np.round((j-offsetx)/resRatio))
# xStop_small = int(np.round(((xEnd-offsetx))/resRatio))
# box_total=(xStop_small-xStart_small)*(yStop_small-yStart_small)
# pbar.update(1)
# if np.sum(binary[yStart_small:yStop_small,xStart_small:xStop_small])>(args.white_percent*box_total):

# xLen=xEnd-j
# yLen=yEnd-i

# dxS=j
# dyS=i
# dxE=j+xLen
# dyE=i+yLen
# print(xLen,yLen)
# print('here is the length')
# im=np.array(slide.read_region((dxS,dyS),0,(xLen,yLen)))[:,:,:3]
# #im = zoom(im,(4,4,1),order=1)
# all_patches.append(im)
# #print(sys.getsizeof(im), 'first')
# #UPSAMPLE
# #im = zoom(im,(4,4,1),order=1)
# count=0
# test_1,test_2 = new_predictor(all_patches)#["panoptic_seg"]
# print(len(test_1))
# print(len(test_2))
# with tqdm(total=totalpatches,unit='image',colour='green',desc='Total WSI progress') as pbar:
# for i,j in coordinate_pairs(index_y,index_x):
# yEnd = min(dim_y+offsety,i+region_size)
# xEnd = min(dim_x+offsetx,j+region_size)
# # yStart_small = int(np.round((i-offsety)/resRatio))
# # yStop_small = int(np.round(((i-offsety)+args.boxSize)/resRatio))
# # xStart_small = int(np.round((j-offsetx)/resRatio))
# # xStop_small = int(np.round(((j-offsetx)+args.boxSize)/resRatio))
# yStart_small = int(np.round((i-offsety)/resRatio))
# yStop_small = int(np.round(((yEnd-offsety))/resRatio))
# xStart_small = int(np.round((j-offsetx)/resRatio))
# xStop_small = int(np.round(((xEnd-offsetx))/resRatio))
# box_total=(xStop_small-xStart_small)*(yStop_small-yStart_small)
# pbar.update(1)
# if np.sum(binary[yStart_small:yStop_small,xStart_small:xStop_small])>(args.white_percent*box_total):

# xLen=xEnd-j
# yLen=yEnd-i

# dxS=j
# dyS=i
# dxE=j+xLen
# dyE=i+yLen

# panoptic_seg, segments_info =test_1[count][0]["panoptic_seg"],test_2[count]
# count+=1
# print(test_1,test_2,'newmodel')
# #del im
# # torch.cuda.empty_cache()

# imsave(outImageName+'_p.png',maskpart)
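# zero out half the border-crop width on any tile edge that lies in the
# slide interior, so overlapping neighbor tiles supply those pixels instead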
if dxE != dim_x:
maskpart[:,-int(args.bordercrop/2):]=0
if dyE != dim_y:
maskpart[-int(args.bordercrop/2):,:]=0

if dxS != offsetx:
maskpart[:,:int(args.bordercrop/2)]=0
if dyS != offsety:
maskpart[:int(args.bordercrop/2),:]=0

# xmlbuilder.deconstruct(maskpart,dxS-offsetx,dyS-offsety,args)
# plt.subplot(121)
# plt.imshow(im)
# plt.subplot(122)
# plt.imshow(maskpart)
# plt.show()

dyE-=offsety
dyS-=offsety
dxS-=offsetx
dxE-=offsetx

# merge this tile into the whole-slide mask; np.maximum keeps the higher
# class label wherever neighboring tiles overlap
wsiMask[dyS:dyE,dxS:dxE]=np.maximum(maskpart,
wsiMask[dyS:dyE,dxS:dxE])

del maskpart
torch.cuda.empty_cache()
# wsiMask[dyS:dyE,dxS:dxE]=maskpart

# print('showing mask')
# plt.imshow(wsiMask)
# plt.show()
slide.close()
print('\n\nStarting XML construction: ')
# maskpart=decode_panoptic(panoptic_seg.to("cpu").numpy(),segments_info,'kidney',args)
# #del panoptic_seg, segments_info
# #outImageName=basename+'_'.join(['',str(dxS),str(dyS)])
# #print(sys.getsizeof(maskpart), 'fifth')
# #DOWNSAMPLE
# #maskpart=zoom(maskpart,(0.25,0.25),order=0)
# #print(sys.getsizeof(maskpart), 'sixth')

# # imsave(outImageName+'_p.png',maskpart)
# if dxE != dim_x:
# maskpart[:,-int(args.bordercrop/2):]=0
# if dyE != dim_y:
# maskpart[-int(args.bordercrop/2):,:]=0

# if dxS != offsetx:
# maskpart[:,:int(args.bordercrop/2)]=0
# if dyS != offsety:
# maskpart[:int(args.bordercrop/2),:]=0

# # xmlbuilder.deconstruct(maskpart,dxS-offsetx,dyS-offsety,args)
# # plt.subplot(121)
# # plt.imshow(im)
# # plt.subplot(122)
# # plt.imshow(maskpart)
# # plt.show()

# dyE-=offsety
# dyS-=offsety
# dxS-=offsetx
# dxE-=offsetx

# wsiMask[dyS:dyE,dxS:dxE]=np.maximum(maskpart,
# wsiMask[dyS:dyE,dxS:dxE])

# del maskpart
# #torch.cuda.empty_cache()
# # wsiMask[dyS:dyE,dxS:dxE]=maskpart

# # print('showing mask')
# # plt.imshow(wsiMask)
# # plt.show()
# slide.close()
# print('\n\nStarting XML construction: ')

# wsiMask=np.swapaxes(wsiMask,0,1)
# print('swapped axes')
@@ -439,25 +534,7 @@ def xml_suey(wsiMask, dirs, args, classNum, downsample,glob_offset):

# save xml
folder = args.base_dir
girder_folder_id = folder.split('/')[-2]
_ = os.system("printf 'Using data from girder_client Folder: {}\n'".format(folder))
file_name = dirs['fileID'] + dirs['extension']
print(file_name)
gc = girder_client.GirderClient(apiUrl=args.girderApiUrl)
gc.setToken(args.girderToken)
files = list(gc.listItem(girder_folder_id))
# dict to link filename to gc id
item_dict = dict()
for file in files:
d = {file['name']:file['_id']}
item_dict.update(d)
print(item_dict)
print(item_dict[file_name])
annots = convert_xml_json(Annotations, NAMES)
for annot in annots:
_ = gc.post(path='annotation',parameters={'itemId':item_dict[file_name]}, data = json.dumps(annot))
print('uploading layers')
print('annotation uploaded...\n')
xml_save(Annotations=Annotations, filename=folder+'/test_data/'+dirs['fileID']+'.xml')



@@ -524,12 +601,12 @@ def xml_add_region(Annotations, pointList, annotationID=-1, regionID=None): # ad
ET.SubElement(Vertices, 'Vertex', attrib={'X': str(pointList[0]['X']), 'Y': str(pointList[0]['Y']), 'Z': '0'})
return Annotations
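# Minimal usage sketch (hypothetical pointList; assumes an Annotation element
# with the given annotationID already exists in the tree):
#   pointList = [{'X': 0, 'Y': 0}, {'X': 100, 'Y': 0}, {'X': 100, 'Y': 100}]
#   Annotations = xml_add_region(Annotations, pointList, annotationID=1)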

# def xml_save(Annotations, filename):
# xml_data = ET.tostring(Annotations, pretty_print=True)
# #xml_data = Annotations.toprettyxml()
# f = open(filename, 'w')
# f.write(xml_data.decode())
# f.close()
def xml_save(Annotations, filename):
    xml_data = ET.tostring(Annotations, pretty_print=True)
    #xml_data = Annotations.toprettyxml()
    with open(filename, 'wb') as f:
        f.write(xml_data)
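# Note: the pretty_print keyword assumes ET is lxml.etree; the stdlib
# xml.etree.ElementTree.tostring does not accept pretty_print.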

# def read_xml(filename):
# # import xml file
35 changes: 35 additions & 0 deletions histomicstk/segmentationschool/slurm_script.sh
@@ -0,0 +1,35 @@
#!/bin/sh
#SBATCH --account=pinaki.sarder
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=12
#SBATCH --mem-per-cpu=16gb
#SBATCH --partition=gpu
#SBATCH --gpus=a100
#SBATCH --time=72:00:00
#SBATCH --output=./slurm_log.out
#SBATCH --job-name="segmentation_frozen"
echo "SLURM_JOBID="$SLURM_JOBID
echo "SLURM_JOB_NODELIST="$SLURM_JOB_NODELIST
echo "SLURM_NNODES="$SLURM_NNODES
echo "SLURMTMPDIR="$SLURMTMPDIR

echo "working directory = "$SLURM_SUBMIT_DIR
ulimit -s unlimited
module load singularity
ls
ml

# Add your userid here:
USER=sayat.mimar
# Add the name of the folder containing WSIs here
PROJECT=multic_segment

CODESDIR=/blue/pinaki.sarder/sayat.mimar/multi_compart_segment/Multi-Compartment-Segmentation/histomicstk/segmentationschool

DATADIR=$CODESDIR/test_data
MODELDIR=$CODESDIR/pretrained_model

CONTAINER=$CODESDIR/multic_segment.sif
export CUDA_LAUNCH_BLOCKING=1
singularity exec --nv -B $(pwd):/exec/,$DATADIR/:/data,$MODELDIR/:/model/ $CONTAINER python3 /exec/segmentation_school.py --option 'predict' --base_dir $CODESDIR --modelfile /model/model_0214999.pth --files /data/18-142_PAS_6of6.svs
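# Submit with: sbatch slurm_script.sh
# SLURM writes stdout/stderr to ./slurm_log.out per the --output directive above.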