Skip to content

Commit a3c46f5

Browse files
matthew29tangcopybara-github
authored andcommitted
feat: Support Imagen image segmentation on Vertex
PiperOrigin-RevId: 795227899
1 parent 10f07cc commit a3c46f5

File tree

5 files changed

+727
-0
lines changed

5 files changed

+727
-0
lines changed

google/genai/models.py

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3025,6 +3025,120 @@ def _RecontextImageParameters_to_vertex(
30253025
return to_object
30263026

30273027

3028+
def _ScribbleImage_to_vertex(
3029+
from_object: Union[dict[str, Any], object],
3030+
parent_object: Optional[dict[str, Any]] = None,
3031+
) -> dict[str, Any]:
3032+
to_object: dict[str, Any] = {}
3033+
if getv(from_object, ['image']) is not None:
3034+
setv(
3035+
to_object,
3036+
['image'],
3037+
_Image_to_vertex(getv(from_object, ['image']), to_object),
3038+
)
3039+
3040+
return to_object
3041+
3042+
3043+
def _SegmentImageSource_to_vertex(
3044+
from_object: Union[dict[str, Any], object],
3045+
parent_object: Optional[dict[str, Any]] = None,
3046+
) -> dict[str, Any]:
3047+
to_object: dict[str, Any] = {}
3048+
if getv(from_object, ['prompt']) is not None:
3049+
setv(
3050+
parent_object, ['instances[0]', 'prompt'], getv(from_object, ['prompt'])
3051+
)
3052+
3053+
if getv(from_object, ['image']) is not None:
3054+
setv(
3055+
parent_object,
3056+
['instances[0]', 'image'],
3057+
_Image_to_vertex(getv(from_object, ['image']), to_object),
3058+
)
3059+
3060+
if getv(from_object, ['scribble_image']) is not None:
3061+
setv(
3062+
parent_object,
3063+
['instances[0]', 'scribble'],
3064+
_ScribbleImage_to_vertex(
3065+
getv(from_object, ['scribble_image']), to_object
3066+
),
3067+
)
3068+
3069+
return to_object
3070+
3071+
3072+
def _SegmentImageConfig_to_vertex(
3073+
from_object: Union[dict[str, Any], object],
3074+
parent_object: Optional[dict[str, Any]] = None,
3075+
) -> dict[str, Any]:
3076+
to_object: dict[str, Any] = {}
3077+
3078+
if getv(from_object, ['mode']) is not None:
3079+
setv(parent_object, ['parameters', 'mode'], getv(from_object, ['mode']))
3080+
3081+
if getv(from_object, ['max_predictions']) is not None:
3082+
setv(
3083+
parent_object,
3084+
['parameters', 'maxPredictions'],
3085+
getv(from_object, ['max_predictions']),
3086+
)
3087+
3088+
if getv(from_object, ['confidence_threshold']) is not None:
3089+
setv(
3090+
parent_object,
3091+
['parameters', 'confidenceThreshold'],
3092+
getv(from_object, ['confidence_threshold']),
3093+
)
3094+
3095+
if getv(from_object, ['mask_dilation']) is not None:
3096+
setv(
3097+
parent_object,
3098+
['parameters', 'maskDilation'],
3099+
getv(from_object, ['mask_dilation']),
3100+
)
3101+
3102+
if getv(from_object, ['binary_color_threshold']) is not None:
3103+
setv(
3104+
parent_object,
3105+
['parameters', 'binaryColorThreshold'],
3106+
getv(from_object, ['binary_color_threshold']),
3107+
)
3108+
3109+
return to_object
3110+
3111+
3112+
def _SegmentImageParameters_to_vertex(
3113+
api_client: BaseApiClient,
3114+
from_object: Union[dict[str, Any], object],
3115+
parent_object: Optional[dict[str, Any]] = None,
3116+
) -> dict[str, Any]:
3117+
to_object: dict[str, Any] = {}
3118+
if getv(from_object, ['model']) is not None:
3119+
setv(
3120+
to_object,
3121+
['_url', 'model'],
3122+
t.t_model(api_client, getv(from_object, ['model'])),
3123+
)
3124+
3125+
if getv(from_object, ['source']) is not None:
3126+
setv(
3127+
to_object,
3128+
['config'],
3129+
_SegmentImageSource_to_vertex(getv(from_object, ['source']), to_object),
3130+
)
3131+
3132+
if getv(from_object, ['config']) is not None:
3133+
setv(
3134+
to_object,
3135+
['config'],
3136+
_SegmentImageConfig_to_vertex(getv(from_object, ['config']), to_object),
3137+
)
3138+
3139+
return to_object
3140+
3141+
30283142
def _GetModelParameters_to_vertex(
30293143
api_client: BaseApiClient,
30303144
from_object: Union[dict[str, Any], object],
@@ -4618,6 +4732,63 @@ def _RecontextImageResponse_from_vertex(
46184732
return to_object
46194733

46204734

4735+
def _EntityLabel_from_vertex(
4736+
from_object: Union[dict[str, Any], object],
4737+
parent_object: Optional[dict[str, Any]] = None,
4738+
) -> dict[str, Any]:
4739+
to_object: dict[str, Any] = {}
4740+
if getv(from_object, ['label']) is not None:
4741+
setv(to_object, ['label'], getv(from_object, ['label']))
4742+
4743+
if getv(from_object, ['score']) is not None:
4744+
setv(to_object, ['score'], getv(from_object, ['score']))
4745+
4746+
return to_object
4747+
4748+
4749+
def _GeneratedImageMask_from_vertex(
4750+
from_object: Union[dict[str, Any], object],
4751+
parent_object: Optional[dict[str, Any]] = None,
4752+
) -> dict[str, Any]:
4753+
to_object: dict[str, Any] = {}
4754+
if getv(from_object, ['_self']) is not None:
4755+
setv(
4756+
to_object,
4757+
['mask'],
4758+
_Image_from_vertex(getv(from_object, ['_self']), to_object),
4759+
)
4760+
4761+
if getv(from_object, ['labels']) is not None:
4762+
setv(
4763+
to_object,
4764+
['labels'],
4765+
[
4766+
_EntityLabel_from_vertex(item, to_object)
4767+
for item in getv(from_object, ['labels'])
4768+
],
4769+
)
4770+
4771+
return to_object
4772+
4773+
4774+
def _SegmentImageResponse_from_vertex(
4775+
from_object: Union[dict[str, Any], object],
4776+
parent_object: Optional[dict[str, Any]] = None,
4777+
) -> dict[str, Any]:
4778+
to_object: dict[str, Any] = {}
4779+
if getv(from_object, ['predictions']) is not None:
4780+
setv(
4781+
to_object,
4782+
['generated_masks'],
4783+
[
4784+
_GeneratedImageMask_from_vertex(item, to_object)
4785+
for item in getv(from_object, ['predictions'])
4786+
],
4787+
)
4788+
4789+
return to_object
4790+
4791+
46214792
def _Endpoint_from_vertex(
46224793
from_object: Union[dict[str, Any], object],
46234794
parent_object: Optional[dict[str, Any]] = None,
@@ -5511,6 +5682,89 @@ def recontext_image(
55115682
self._api_client._verify_response(return_value)
55125683
return return_value
55135684

5685+
def segment_image(
5686+
self,
5687+
*,
5688+
model: str,
5689+
source: types.SegmentImageSourceOrDict,
5690+
config: Optional[types.SegmentImageConfigOrDict] = None,
5691+
) -> types.SegmentImageResponse:
5692+
"""Segments an image, creating a mask of a specified area.
5693+
5694+
Args:
5695+
model (str): The model to use.
5696+
source (SegmentImageSource): An object containing the source inputs
5697+
(prompt, image, scribble_image) for image segmentation. The prompt is
5698+
required for prompt mode and semantic mode, disallowed for other modes.
5699+
scribble_image is required for the interactive mode, disallowed for
5700+
other modes.
5701+
config (SegmentImageConfig): Configuration for segmentation.
5702+
5703+
Usage:
5704+
5705+
```
5706+
response = client.models.segment_image(
5707+
model="image-segmentation-001",
5708+
source=types.SegmentImageSource(
5709+
image=types.Image.from_file(IMAGE_FILE_PATH),
5710+
),
5711+
)
5712+
5713+
mask_image = response.generated_masks[0].mask
5714+
```
5715+
"""
5716+
5717+
parameter_model = types._SegmentImageParameters(
5718+
model=model,
5719+
source=source,
5720+
config=config,
5721+
)
5722+
5723+
request_url_dict: Optional[dict[str, str]]
5724+
if not self._api_client.vertexai:
5725+
raise ValueError('This method is only supported in the Vertex AI client.')
5726+
else:
5727+
request_dict = _SegmentImageParameters_to_vertex(
5728+
self._api_client, parameter_model
5729+
)
5730+
request_url_dict = request_dict.get('_url')
5731+
if request_url_dict:
5732+
path = '{model}:predict'.format_map(request_url_dict)
5733+
else:
5734+
path = '{model}:predict'
5735+
5736+
query_params = request_dict.get('_query')
5737+
if query_params:
5738+
path = f'{path}?{urlencode(query_params)}'
5739+
# TODO: remove the hack that pops config.
5740+
request_dict.pop('config', None)
5741+
5742+
http_options: Optional[types.HttpOptions] = None
5743+
if (
5744+
parameter_model.config is not None
5745+
and parameter_model.config.http_options is not None
5746+
):
5747+
http_options = parameter_model.config.http_options
5748+
5749+
request_dict = _common.convert_to_dict(request_dict)
5750+
request_dict = _common.encode_unserializable_types(request_dict)
5751+
5752+
response = self._api_client.request(
5753+
'post', path, request_dict, http_options
5754+
)
5755+
5756+
response_dict = '' if not response.body else json.loads(response.body)
5757+
5758+
if self._api_client.vertexai:
5759+
response_dict = _SegmentImageResponse_from_vertex(response_dict)
5760+
5761+
return_value = types.SegmentImageResponse._from_response(
5762+
response=response_dict, kwargs=parameter_model.model_dump()
5763+
)
5764+
5765+
self._api_client._verify_response(return_value)
5766+
return return_value
5767+
55145768
def get(
55155769
self, *, model: str, config: Optional[types.GetModelConfigOrDict] = None
55165770
) -> types.Model:
@@ -7240,6 +7494,92 @@ async def recontext_image(
72407494
self._api_client._verify_response(return_value)
72417495
return return_value
72427496

7497+
async def segment_image(
7498+
self,
7499+
*,
7500+
model: str,
7501+
source: types.SegmentImageSourceOrDict,
7502+
config: Optional[types.SegmentImageConfigOrDict] = None,
7503+
) -> types.SegmentImageResponse:
7504+
"""Segments an image, creating a mask of a specified area.
7505+
7506+
Args:
7507+
model (str): The model to use.
7508+
source (SegmentImageSource): An object containing the source inputs
7509+
(prompt, image, scribble_image) for image segmentation. The prompt is
7510+
required for prompt mode and semantic mode, disallowed for other modes.
7511+
scribble_image is required for the interactive mode, disallowed for
7512+
other modes.
7513+
config (SegmentImageConfig): Configuration for segmentation.
7514+
7515+
Usage:
7516+
7517+
```
7518+
response = client.models.segment_image(
7519+
model="image-segmentation-001",
7520+
source=types.SegmentImageSource(
7521+
image=types.Image.from_file(IMAGE_FILE_PATH),
7522+
),
7523+
config=types.SegmentImageConfig(
7524+
mode=types.SegmentMode.foreground,
7525+
),
7526+
)
7527+
7528+
mask_image = response.generated_masks[0].mask
7529+
```
7530+
"""
7531+
7532+
parameter_model = types._SegmentImageParameters(
7533+
model=model,
7534+
source=source,
7535+
config=config,
7536+
)
7537+
7538+
request_url_dict: Optional[dict[str, str]]
7539+
if not self._api_client.vertexai:
7540+
raise ValueError('This method is only supported in the Vertex AI client.')
7541+
else:
7542+
request_dict = _SegmentImageParameters_to_vertex(
7543+
self._api_client, parameter_model
7544+
)
7545+
request_url_dict = request_dict.get('_url')
7546+
if request_url_dict:
7547+
path = '{model}:predict'.format_map(request_url_dict)
7548+
else:
7549+
path = '{model}:predict'
7550+
7551+
query_params = request_dict.get('_query')
7552+
if query_params:
7553+
path = f'{path}?{urlencode(query_params)}'
7554+
# TODO: remove the hack that pops config.
7555+
request_dict.pop('config', None)
7556+
7557+
http_options: Optional[types.HttpOptions] = None
7558+
if (
7559+
parameter_model.config is not None
7560+
and parameter_model.config.http_options is not None
7561+
):
7562+
http_options = parameter_model.config.http_options
7563+
7564+
request_dict = _common.convert_to_dict(request_dict)
7565+
request_dict = _common.encode_unserializable_types(request_dict)
7566+
7567+
response = await self._api_client.async_request(
7568+
'post', path, request_dict, http_options
7569+
)
7570+
7571+
response_dict = '' if not response.body else json.loads(response.body)
7572+
7573+
if self._api_client.vertexai:
7574+
response_dict = _SegmentImageResponse_from_vertex(response_dict)
7575+
7576+
return_value = types.SegmentImageResponse._from_response(
7577+
response=response_dict, kwargs=parameter_model.model_dump()
7578+
)
7579+
7580+
self._api_client._verify_response(return_value)
7581+
return return_value
7582+
72437583
async def get(
72447584
self, *, model: str, config: Optional[types.GetModelConfigOrDict] = None
72457585
) -> types.Model:
7.7 KB
Loading
343 KB
Loading

0 commit comments

Comments
 (0)