Skip to content

Commit 763405d

Browse files
authored
feat(go/plugins/googlegenai): add image-generation models (#2903)
1 parent ea88f5d commit 763405d

File tree

4 files changed

+244
-8
lines changed

4 files changed

+244
-8
lines changed

go/plugins/googlegenai/gemini.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,30 @@ func defineModel(g *genkit.Genkit, client *genai.Client, name string, info ai.Mo
140140
provider = vertexAIProvider
141141
}
142142

143+
var config any
144+
config = &genai.GenerateContentConfig{}
145+
if mi, found := supportedImagenModels[name]; found {
146+
config = &genai.GenerateImagesConfig{}
147+
info = mi
148+
}
143149
meta := &ai.ModelInfo{
144150
Label: info.Label,
145151
Supports: info.Supports,
146152
Versions: info.Versions,
147-
ConfigSchema: configToMap(genai.GenerateContentConfig{}),
153+
ConfigSchema: configToMap(config),
148154
}
149155

150156
fn := func(
151157
ctx context.Context,
152158
input *ai.ModelRequest,
153159
cb func(context.Context, *ai.ModelResponseChunk) error,
154160
) (*ai.ModelResponse, error) {
155-
return generate(ctx, client, name, input, cb)
161+
switch config.(type) {
162+
case *genai.GenerateImagesConfig:
163+
return generateImage(ctx, client, name, input, cb)
164+
default:
165+
return generate(ctx, client, name, input, cb)
166+
}
156167
}
157168
// the gemini api doesn't support downloading media from http(s)
158169
if info.Supports.Media {

go/plugins/googlegenai/imagen.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
17+
package googlegenai
18+
19+
import (
20+
"context"
21+
"encoding/base64"
22+
"fmt"
23+
24+
"github.com/firebase/genkit/go/ai"
25+
"google.golang.org/genai"
26+
)
27+
28+
// Media describes model capabilities for Gemini models with media and text
29+
// input and image only output
30+
var Media = ai.ModelSupports{
31+
Media: true,
32+
Multiturn: false,
33+
Tools: false,
34+
ToolChoice: false,
35+
SystemRole: false,
36+
}
37+
38+
// imagenConfigFromRequest translates an [*ai.ModelRequest] configuration to [*genai.GenerateImagesConfig]
39+
func imagenConfigFromRequest(input *ai.ModelRequest) (*genai.GenerateImagesConfig, error) {
40+
var result genai.GenerateImagesConfig
41+
42+
switch config := input.Config.(type) {
43+
case genai.GenerateImagesConfig:
44+
result = config
45+
case *genai.GenerateImagesConfig:
46+
result = *config
47+
case map[string]any:
48+
if err := mapToStruct(config, &result); err != nil {
49+
return nil, err
50+
}
51+
case nil:
52+
// empty but valid config
53+
default:
54+
return nil, fmt.Errorf("unexpected config type: %T", input.Config)
55+
}
56+
57+
return &result, nil
58+
}
59+
60+
// translateImagenCandidates translates the image generation response to [*ai.ModelResponse]
61+
func translateImagenCandidates(images []*genai.GeneratedImage) *ai.ModelResponse {
62+
m := &ai.ModelResponse{}
63+
m.FinishReason = ai.FinishReasonStop
64+
65+
msg := &ai.Message{}
66+
msg.Role = ai.RoleModel
67+
68+
for _, img := range images {
69+
msg.Content = append(msg.Content, ai.NewMediaPart(img.Image.MIMEType, "data:"+img.Image.MIMEType+";base64,"+base64.StdEncoding.EncodeToString(img.Image.ImageBytes)))
70+
}
71+
72+
m.Message = msg
73+
return m
74+
}
75+
76+
// translateImagenResponse translates [*genai.GenerateImagesResponse] to an [*ai.ModelResponse]
77+
func translateImagenResponse(resp *genai.GenerateImagesResponse) *ai.ModelResponse {
78+
return translateImagenCandidates(resp.GeneratedImages)
79+
}
80+
81+
// generateImage requests a generate call to the specified imagen model with the
82+
// provided configuration
83+
func generateImage(
84+
ctx context.Context,
85+
client *genai.Client,
86+
model string,
87+
input *ai.ModelRequest,
88+
cb func(context.Context, *ai.ModelResponseChunk) error,
89+
) (*ai.ModelResponse, error) {
90+
gic, err := imagenConfigFromRequest(input)
91+
if err != nil {
92+
return nil, err
93+
}
94+
95+
var userPrompt string
96+
for _, m := range input.Messages {
97+
if m.Role == ai.RoleUser {
98+
userPrompt += m.Text()
99+
}
100+
}
101+
if userPrompt == "" {
102+
return nil, fmt.Errorf("error generating images: empty prompt detected")
103+
}
104+
105+
if cb != nil {
106+
return nil, fmt.Errorf("streaming mode not supported for image generation")
107+
}
108+
109+
resp, err := client.Models.GenerateImages(ctx, model, userPrompt, gic)
110+
if err != nil {
111+
return nil, err
112+
}
113+
114+
r := translateImagenResponse(resp)
115+
r.Request = input
116+
return r, nil
117+
}

go/plugins/googlegenai/models.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ const (
3232
gemini25ProExp0325 = "gemini-2.5-pro-exp-03-25"
3333
gemini25ProPreview0325 = "gemini-2.5-pro-preview-03-25"
3434
gemini25ProPreview0506 = "gemini-2.5-pro-preview-05-06"
35+
36+
imagen3Generate001 = "imagen-3.0-generate-001"
37+
imagen3Generate002 = "imagen-3.0-generate-002"
38+
imagen3FastGenerate001 = "imagen-3.0-fast-generate-001"
3539
)
3640

3741
var (
@@ -50,6 +54,10 @@ var (
5054
gemini25ProExp0325,
5155
gemini25ProPreview0325,
5256
gemini25ProPreview0506,
57+
58+
imagen3Generate001,
59+
imagen3Generate002,
60+
imagen3FastGenerate001,
5361
}
5462

5563
googleAIModels = []string{
@@ -66,9 +74,11 @@ var (
6674
gemini25ProExp0325,
6775
gemini25ProPreview0325,
6876
gemini25ProPreview0506,
77+
78+
imagen3Generate002,
6979
}
7080

71-
// models with native image support generation
81+
// Gemini models with native image support generation
7282
imageGenModels = []string{
7383
gemini20FlashPrevImageGen,
7484
}
@@ -175,6 +185,27 @@ var (
175185
},
176186
}
177187

188+
supportedImagenModels = map[string]ai.ModelInfo{
189+
imagen3Generate001: {
190+
Label: "Imagen 3 Generate 001",
191+
Versions: []string{},
192+
Supports: &Media,
193+
Stage: ai.ModelStageStable,
194+
},
195+
imagen3Generate002: {
196+
Label: "Imagen 3 Generate 002",
197+
Versions: []string{},
198+
Supports: &Media,
199+
Stage: ai.ModelStageStable,
200+
},
201+
imagen3FastGenerate001: {
202+
Label: "Imagen 3 Fast Generate 001",
203+
Versions: []string{},
204+
Supports: &Media,
205+
Stage: ai.ModelStageStable,
206+
},
207+
}
208+
178209
googleAIEmbedders = []string{
179210
"text-embedding-004",
180211
"embedding-001",
@@ -194,7 +225,7 @@ var (
194225
// listModels returns a map of supported models and their capabilities
195226
// based on the detected backend
196227
func listModels(provider string) (map[string]ai.ModelInfo, error) {
197-
names := []string{}
228+
var names []string
198229
var prefix string
199230

200231
switch provider {
@@ -210,7 +241,13 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {
210241

211242
models := make(map[string]ai.ModelInfo, 0)
212243
for _, n := range names {
213-
m, ok := supportedGeminiModels[n]
244+
var m ai.ModelInfo
245+
var ok bool
246+
if strings.HasPrefix(n, "image") {
247+
m, ok = supportedImagenModels[n]
248+
} else {
249+
m, ok = supportedGeminiModels[n]
250+
}
214251
if !ok {
215252
return nil, fmt.Errorf("model %s not found for provider %s", n, provider)
216253
}
@@ -227,7 +264,7 @@ func listModels(provider string) (map[string]ai.ModelInfo, error) {
227264
// listEmbedders returns a list of supported embedders based on the
228265
// detected backend
229266
func listEmbedders(backend genai.Backend) ([]string, error) {
230-
embedders := []string{}
267+
var embedders []string
231268

232269
switch backend {
233270
case genai.BackendGeminiAPI:
@@ -242,9 +279,10 @@ func listEmbedders(backend genai.Backend) ([]string, error) {
242279
}
243280

244281
// genaiModels collects all the available models in go-genai SDK
245-
// TODO: add imagen and veo models
282+
// TODO: add veo models
246283
type genaiModels struct {
247284
gemini []string
285+
imagen []string
248286
embedders []string
249287
}
250288

@@ -253,6 +291,7 @@ type genaiModels struct {
253291
func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, error) {
254292
models := genaiModels{}
255293
allowedModels := []string{"gemini", "gemma"}
294+
allowedImagenModels := []string{"imagen"}
256295

257296
for item, err := range client.Models.All(ctx) {
258297
var name string
@@ -283,7 +322,15 @@ func listGenaiModels(ctx context.Context, client *genai.Client) (genaiModels, er
283322
continue
284323
}
285324

286-
// TODO: add imagen and veo models
325+
found = slices.ContainsFunc(allowedImagenModels, func(s string) bool {
326+
return strings.Contains(name, s)
327+
})
328+
// filter out: Aqa, Text-bison, Chat, learnlm
329+
if found {
330+
models.imagen = append(models.imagen, name)
331+
continue
332+
}
287333
}
334+
288335
return models, nil
289336
}

go/samples/imagen/main.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"context"
19+
"log"
20+
21+
"github.com/firebase/genkit/go/ai"
22+
"github.com/firebase/genkit/go/genkit"
23+
"github.com/firebase/genkit/go/plugins/googlegenai"
24+
"google.golang.org/genai"
25+
)
26+
27+
func main() {
28+
ctx := context.Background()
29+
g, err := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.VertexAI{}))
30+
if err != nil {
31+
log.Fatal(err)
32+
}
33+
34+
genkit.DefineFlow(g, "image-generation", func(ctx context.Context, input string) ([]string, error) {
35+
r, err := genkit.Generate(ctx, g,
36+
ai.WithModelName("vertexai/imagen-3.0-generate-001"),
37+
ai.WithPrompt("Generate an image of %s", input),
38+
ai.WithConfig(&genai.GenerateImagesConfig{
39+
NumberOfImages: 2,
40+
NegativePrompt: "night",
41+
AspectRatio: "9:16",
42+
SafetyFilterLevel: genai.SafetyFilterLevelBlockLowAndAbove,
43+
PersonGeneration: genai.PersonGenerationAllowAll,
44+
Language: genai.ImagePromptLanguageEn,
45+
AddWatermark: true,
46+
OutputMIMEType: "image/jpeg",
47+
}),
48+
)
49+
if err != nil {
50+
log.Fatal(err)
51+
}
52+
53+
var images []string
54+
for _, m := range r.Message.Content {
55+
images = append(images, m.Text)
56+
}
57+
return images, nil
58+
})
59+
60+
<-ctx.Done()
61+
}

0 commit comments

Comments
 (0)