Skip to content
Open

SAM2 #155

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public enum ImageSegmentaionModels
deeplabv3,
u2net,
modnet,
segment_anything1
segment_anything1,
segment_anything2
}

public class AiliaImageSegmentationSample : MonoBehaviour
Expand Down Expand Up @@ -69,6 +70,7 @@ public class AiliaImageSegmentationSample : MonoBehaviour
// Segment Anything Model
private SegmentationModel segModel;
private SegmentAnythingModel samModel;
private SegmentAnything2Model sam2Model;
private bool isDraggingForBox = false;
private Rect boxRect = new ();

Expand Down Expand Up @@ -122,13 +124,18 @@ void UISetup()
raw_image = UICanvas.transform.Find("RawImage").GetComponent<RawImage>();
raw_image.gameObject.SetActive(false);

if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1) {
mode_text.text = "ailia Image Segmentation\n" +
switch (imageSegmentaionModels)
{
case ImageSegmentaionModels.segment_anything1:
case ImageSegmentaionModels.segment_anything2:
mode_text.text = "ailia Image Segmentation\n" +
"Left/Right click: positive/negative point\n" +
"Middle click: drag to define border box\n" +
"Space key down to reset";
} else {
mode_text.text = "ailia Image Segmentation";
break;
default:
mode_text.text = "ailia Image Segmentation";
break;
}
}
Color32 [] VerticalFlip(Color32[] inputImage, int InputWidth, int InputHeight){
Expand All @@ -151,8 +158,14 @@ void Update()
}
if (modelPrepared && !modelAllocated)
{
if (imageSegmentaionModels != ImageSegmentaionModels.segment_anything1){
segModel.AllocateInputAndOutputTensor(imageSegmentaionModels, AiliaImageSource.Width, AiliaImageSource.Height);
switch (imageSegmentaionModels)
{
case ImageSegmentaionModels.segment_anything1:
case ImageSegmentaionModels.segment_anything2:
break;
default:
segModel.AllocateInputAndOutputTensor(imageSegmentaionModels, AiliaImageSource.Width, AiliaImageSource.Height);
break;
}
modelAllocated = true;
}
Expand All @@ -164,10 +177,20 @@ void Update()
// While the space key is held down, draw the original image
if (Input.GetKey(KeyCode.Space))
{
if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1){
samModel.ResetClickPoint();
boxRect = new();
oneshot = true;
switch (imageSegmentaionModels)
{
case ImageSegmentaionModels.segment_anything1:
samModel.ResetClickPoint();
boxRect = new();
oneshot = true;
break;
case ImageSegmentaionModels.segment_anything2:
sam2Model.ResetClickPoint();
boxRect = new();
oneshot = true;
break;
default:
break;
}
blendMaterial.SetFloat(blendFlagId, 0);
}
Expand Down Expand Up @@ -206,25 +229,38 @@ void Update()
long start_time2 = DateTime.UtcNow.Ticks / TimeSpan.TicksPerMillisecond;

bool result = false;
if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1)
switch (imageSegmentaionModels)
{
if (!samModel.EmbeddingExist()){
if (samModel.GetClickPoints(0).Length == 0)
{
samModel.AddClickPoint(inputImageWidth / 4, inputImageHeight / 4 + 30);
case ImageSegmentaionModels.segment_anything1:
if (!samModel.EmbeddingExist()){
if (samModel.GetClickPoints(0).Length == 0)
{
samModel.AddClickPoint(inputImageWidth / 4, inputImageHeight / 4 + 30);
}
}
}
if (camera_mode || !samModel.EmbeddingExist()){
samModel.ProcessEmbedding(inputImage, inputImageWidth, inputImageHeight);
}
samModel.ProcessMask(inputImage, inputImageWidth, inputImageHeight);
result = samModel.success;
}
else
{
result = segModel.ProcessFrame(imageSegmentaionModels, inputImage, inputImageWidth, inputImageHeight);
}

if (camera_mode || !samModel.EmbeddingExist()){
samModel.ProcessEmbedding(inputImage, inputImageWidth, inputImageHeight);
}
samModel.ProcessMask(inputImage, inputImageWidth, inputImageHeight);
result = samModel.success;
break;
case ImageSegmentaionModels.segment_anything2:
if (!sam2Model.EmbeddingExist()){
if (sam2Model.GetClickPoints(0).Length == 0)
{
sam2Model.AddClickPoint(inputImageWidth / 4, inputImageHeight / 4 + 30);
}
}
if (camera_mode || !sam2Model.EmbeddingExist()){
sam2Model.ProcessEmbedding(inputImage, inputImageWidth, inputImageHeight);
}
sam2Model.ProcessMask(inputImage, inputImageWidth, inputImageHeight);
result = sam2Model.success;
break;
default:
result = segModel.ProcessFrame(imageSegmentaionModels, inputImage, inputImageWidth, inputImageHeight);
break;
}

if (!result)
{
Expand All @@ -235,14 +271,22 @@ void Update()
// convert result to image
Color32[] outputImage;
int outputWidth, outputHeight;
if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1)
switch (imageSegmentaionModels)
{
outputImage = samModel.visualizedResult.GetPixels32();
outputWidth = samModel.visualizedResult.width;
outputHeight = samModel.visualizedResult.height;
}
else{
(outputImage, outputWidth, outputHeight) = segModel.PostProcesss(imageSegmentaionModels, inputImageWidth, inputImageHeight);
case ImageSegmentaionModels.segment_anything1:
outputImage = samModel.visualizedResult.GetPixels32();
outputWidth = samModel.visualizedResult.width;
outputHeight = samModel.visualizedResult.height;
break;
case ImageSegmentaionModels.segment_anything2:
outputImage = sam2Model.visualizedResult.GetPixels32();
outputWidth = sam2Model.visualizedResult.width;
outputHeight = sam2Model.visualizedResult.height;
break;
default:
(outputImage, outputWidth, outputHeight) = segModel.PostProcesss(imageSegmentaionModels, inputImageWidth, inputImageHeight);
break;

}

long end_time2 = DateTime.UtcNow.Ticks / TimeSpan.TicksPerMillisecond;
Expand Down Expand Up @@ -288,29 +332,39 @@ void CreateAiliaNet(ImageSegmentaionModels modelType, bool gpu_mode = true)
ailia_download.DownloaderProgressPanel = UICanvas.transform.Find("DownloaderProgressPanel").gameObject;
List<ModelDownloadURL> urlList = null;

if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1)
switch (imageSegmentaionModels)
{
samModel = new SegmentAnythingModel();
urlList = samModel.GetModelURLs(imageSegmentaionModels);
case ImageSegmentaionModels.segment_anything1:
samModel = new SegmentAnythingModel();
urlList = samModel.GetModelURLs(imageSegmentaionModels);
break;
case ImageSegmentaionModels.segment_anything2:
sam2Model = new SegmentAnything2Model();
urlList = sam2Model.GetModelURLs(imageSegmentaionModels);
break;
default:
segModel = new SegmentationModel();
urlList = segModel.GetModelURLs(imageSegmentaionModels);
break;
}
else
{
segModel = new SegmentationModel();
urlList = segModel.GetModelURLs(imageSegmentaionModels);
}

StartCoroutine(ailia_download.DownloadWithProgressFromURL(urlList, () =>
{
if (imageSegmentaionModels == ImageSegmentaionModels.segment_anything1)
{
modelPrepared = samModel.InitializeModels(imageSegmentaionModels, gpu_mode);
envName = samModel.EnvironmentName();
}
else
{
modelPrepared = segModel.InitializeModels(imageSegmentaionModels, gpu_mode);
envName = segModel.EnvironmentName();
}
switch (imageSegmentaionModels)
{
case ImageSegmentaionModels.segment_anything1:
modelPrepared = samModel.InitializeModels(imageSegmentaionModels, gpu_mode);
envName = samModel.EnvironmentName();
break;
case ImageSegmentaionModels.segment_anything2:
modelPrepared = sam2Model.InitializeModels(imageSegmentaionModels, gpu_mode);
envName = sam2Model.EnvironmentName();
break;
default:
modelPrepared = segModel.InitializeModels(imageSegmentaionModels, gpu_mode);
envName = segModel.EnvironmentName();
break;
}
}));
}

Expand Down Expand Up @@ -340,6 +394,7 @@ void LoadImage(ImageSegmentaionModels imageSegmentaionModels, AiliaImageSource a
ailiaImageSource.CreateSource(image_source_modnet);
break;
case ImageSegmentaionModels.segment_anything1:
case ImageSegmentaionModels.segment_anything2:
ailiaImageSource.CreateSource(image_source_sam1);
break;
}
Expand Down Expand Up @@ -376,6 +431,7 @@ void HandleClick(bool leftClick, bool rightClick, bool middleClick)
if (leftClick || rightClick)
{
samModel?.AddClickPoint(x, y, rightClick);
sam2Model?.AddClickPoint(x, y, rightClick);
oneshot = true;

Debug.Log($"Click registered at: {x}, {y}");
Expand All @@ -401,7 +457,8 @@ void HandleClick(bool leftClick, bool rightClick, bool middleClick)
boxRect.xMax = Math.Max(firstX, x);
boxRect.yMax = Math.Max(firstY, y);

samModel.SetBoxCoords(boxRect);
samModel?.SetBoxCoords(boxRect);
sam2Model?.SetBoxCoords(boxRect);
oneshot = true;
}
}
Expand Down Expand Up @@ -470,6 +527,9 @@ void DestroyAiliaDetector()
if (samModel != null){
samModel.Destroy();
}
if (sam2Model != null){
sam2Model.Destroy();
}
if (segModel != null){
segModel.Destroy();
}
Expand Down
Loading