Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions packages/firebase_ai/firebase_ai/lib/src/base_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,20 @@ abstract class BaseModel {

/// Returns a function that generates Firebase auth tokens.
static FutureOr<Map<String, String>> Function() firebaseTokens(
FirebaseAppCheck? appCheck, FirebaseAuth? auth, FirebaseApp? app) {
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
FirebaseApp? app,
bool? useLimitedUseAppCheckTokens,
) {
return () async {
Map<String, String> headers = {};
// Override the client name in Google AI SDK
headers['x-goog-api-client'] =
'gl-dart/$packageVersion fire/$packageVersion';
if (appCheck != null) {
final appCheckToken = await appCheck.getToken();
final appCheckToken = useLimitedUseAppCheckTokens == true
? await appCheck.getLimitedUseToken()
: await appCheck.getToken();
if (appCheckToken != null) {
headers['X-Firebase-AppCheck'] = appCheckToken;
}
Expand Down
27 changes: 19 additions & 8 deletions packages/firebase_ai/firebase_ai/lib/src/firebase_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ const _defaultLocation = 'us-central1';

/// The entrypoint for generative models.
class FirebaseAI extends FirebasePluginPlatform {
FirebaseAI._(
{required this.app,
required this.location,
required bool useVertexBackend,
this.appCheck,
this.auth})
: _useVertexBackend = useVertexBackend,
FirebaseAI._({
required this.app,
required this.location,
required bool useVertexBackend,
this.appCheck,
this.auth,
this.useLimitedUseAppCheckTokens = false,
}) : _useVertexBackend = useVertexBackend,
super(app.name, 'plugins.flutter.io/firebase_vertexai');

/// The [FirebaseApp] for this current [FirebaseAI] instance.
Expand All @@ -48,6 +49,9 @@ class FirebaseAI extends FirebasePluginPlatform {
/// The service location for this [FirebaseAI] instance.
String location;

/// Whether to use App Check limited use tokens. Defaults to false.
final bool useLimitedUseAppCheckTokens;

final bool _useVertexBackend;

static final Map<String, FirebaseAI> _cachedInstances = {};
Expand All @@ -61,6 +65,7 @@ class FirebaseAI extends FirebasePluginPlatform {
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
String? location,
bool? useLimitedUseAppCheckTokens,
}) {
app ??= Firebase.app();
var instanceKey = '${app.name}::vertexai';
Expand All @@ -77,6 +82,7 @@ class FirebaseAI extends FirebasePluginPlatform {
appCheck: appCheck,
auth: auth,
useVertexBackend: true,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ?? false,
);
_cachedInstances[instanceKey] = newInstance;

Expand All @@ -91,6 +97,7 @@ class FirebaseAI extends FirebasePluginPlatform {
FirebaseApp? app,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
bool? useLimitedUseAppCheckTokens,
}) {
app ??= Firebase.app();
var instanceKey = '${app.name}::googleai';
Expand All @@ -105,6 +112,7 @@ class FirebaseAI extends FirebasePluginPlatform {
appCheck: appCheck,
auth: auth,
useVertexBackend: false,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens ?? false,
);
_cachedInstances[instanceKey] = newInstance;

Expand Down Expand Up @@ -142,6 +150,7 @@ class FirebaseAI extends FirebasePluginPlatform {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
);
}

Expand All @@ -162,7 +171,8 @@ class FirebaseAI extends FirebasePluginPlatform {
generationConfig: generationConfig,
safetySettings: safetySettings,
appCheck: appCheck,
auth: auth);
auth: auth,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens);
}

/// Create a [LiveGenerativeModel] for real-time interaction.
Expand All @@ -185,6 +195,7 @@ class FirebaseAI extends FirebasePluginPlatform {
systemInstruction: systemInstruction,
appCheck: appCheck,
auth: auth,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
);
}
}
13 changes: 10 additions & 3 deletions packages/firebase_ai/firebase_ai/lib/src/generative_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ final class GenerativeModel extends BaseApiClientModel {
required String location,
required FirebaseApp app,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
List<SafetySetting>? safetySettings,
Expand All @@ -60,13 +61,15 @@ final class GenerativeModel extends BaseApiClientModel {
client: HttpApiClient(
apiKey: app.options.apiKey,
httpClient: httpClient,
requestHeaders: BaseModel.firebaseTokens(appCheck, auth, app)));
requestHeaders: BaseModel.firebaseTokens(
appCheck, auth, app, useLimitedUseAppCheckTokens)));

GenerativeModel._constructTestModel({
required String model,
required String location,
required FirebaseApp app,
required useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
List<SafetySetting>? safetySettings,
Expand All @@ -90,8 +93,8 @@ final class GenerativeModel extends BaseApiClientModel {
client: apiClient ??
HttpApiClient(
apiKey: app.options.apiKey,
requestHeaders:
BaseModel.firebaseTokens(appCheck, auth, app)));
requestHeaders: BaseModel.firebaseTokens(
appCheck, auth, app, useLimitedUseAppCheckTokens)));

final List<SafetySetting> _safetySettings;
final GenerationConfig? _generationConfig;
Expand Down Expand Up @@ -199,6 +202,7 @@ GenerativeModel createGenerativeModel({
required String location,
required String model,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
GenerationConfig? generationConfig,
Expand All @@ -212,6 +216,7 @@ GenerativeModel createGenerativeModel({
app: app,
appCheck: appCheck,
useVertexBackend: useVertexBackend,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
auth: auth,
location: location,
safetySettings: safetySettings,
Expand All @@ -230,6 +235,7 @@ GenerativeModel createModelWithClient({
required String model,
required ApiClient client,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
Content? systemInstruction,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
Expand All @@ -243,6 +249,7 @@ GenerativeModel createModelWithClient({
app: app,
appCheck: appCheck,
useVertexBackend: useVertexBackend,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
auth: auth,
location: location,
safetySettings: safetySettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ final class ImagenModel extends BaseApiClientModel {
required String model,
required String location,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
ImagenGenerationConfig? generationConfig,
Expand All @@ -45,7 +46,8 @@ final class ImagenModel extends BaseApiClientModel {
: _GoogleAIUri(app: app, model: model),
client: HttpApiClient(
apiKey: app.options.apiKey,
requestHeaders: BaseModel.firebaseTokens(appCheck, auth, app)));
requestHeaders: BaseModel.firebaseTokens(
appCheck, auth, app, useLimitedUseAppCheckTokens)));

final ImagenGenerationConfig? _generationConfig;
final ImagenSafetySettings? _safetySettings;
Expand Down Expand Up @@ -198,6 +200,7 @@ ImagenModel createImagenModel({
required String location,
required String model,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
ImagenGenerationConfig? generationConfig,
Expand All @@ -210,6 +213,7 @@ ImagenModel createImagenModel({
auth: auth,
location: location,
useVertexBackend: useVertexBackend,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
safetySettings: safetySettings,
generationConfig: generationConfig,
);
12 changes: 11 additions & 1 deletion packages/firebase_ai/firebase_ai/lib/src/live_model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ final class LiveGenerativeModel extends BaseModel {
required String location,
required FirebaseApp app,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
Expand All @@ -47,6 +48,7 @@ final class LiveGenerativeModel extends BaseModel {
_liveGenerationConfig = liveGenerationConfig,
_tools = tools,
_systemInstruction = systemInstruction,
_useLimitedUseAppCheckTokens = useLimitedUseAppCheckTokens,
super._(
serializationStrategy: VertexSerialization(),
modelUri: useVertexBackend
Expand All @@ -69,6 +71,7 @@ final class LiveGenerativeModel extends BaseModel {
final LiveGenerationConfig? _liveGenerationConfig;
final List<Tool>? _tools;
final Content? _systemInstruction;
final bool? _useLimitedUseAppCheckTokens;

String _vertexAIUri() => 'wss://${_modelUri.baseAuthority}/'
'$_apiUrl.${_modelUri.apiVersion}.$_apiUrlSuffixVertexAI/'
Expand Down Expand Up @@ -107,7 +110,12 @@ final class LiveGenerativeModel extends BaseModel {
};

final request = jsonEncode(setupJson);
final headers = await BaseModel.firebaseTokens(_appCheck, _auth, _app)();
final headers = await BaseModel.firebaseTokens(
_appCheck,
_auth,
_app,
_useLimitedUseAppCheckTokens,
)();

var ws = kIsWeb
? WebSocketChannel.connect(Uri.parse(uri))
Expand All @@ -126,6 +134,7 @@ LiveGenerativeModel createLiveGenerativeModel({
required String location,
required String model,
required bool useVertexBackend,
bool? useLimitedUseAppCheckTokens,
FirebaseAppCheck? appCheck,
FirebaseAuth? auth,
LiveGenerationConfig? liveGenerationConfig,
Expand All @@ -139,6 +148,7 @@ LiveGenerativeModel createLiveGenerativeModel({
auth: auth,
location: location,
useVertexBackend: useVertexBackend,
useLimitedUseAppCheckTokens: useLimitedUseAppCheckTokens,
liveGenerationConfig: liveGenerationConfig,
tools: tools,
systemInstruction: systemInstruction,
Expand Down
32 changes: 27 additions & 5 deletions packages/firebase_ai/firebase_ai/test/base_model_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class MockFirebaseAppCheck extends Mock implements FirebaseAppCheck {
@override
Future<String?> getToken([bool? forceRefresh = false]) async =>
super.noSuchMethod(Invocation.method(#getToken, [forceRefresh]));

@override
Future<String> getLimitedUseToken() async =>
super.noSuchMethod(Invocation.method(#getLimitedUseToken, [])) ?? '';
}

// Mock Firebase Auth
Expand Down Expand Up @@ -72,7 +76,7 @@ class MockApiClient extends Mock implements ApiClient {
void main() {
group('BaseModel', () {
test('firebaseTokens returns a function that generates headers', () async {
final tokenFunction = BaseModel.firebaseTokens(null, null, null);
final tokenFunction = BaseModel.firebaseTokens(null, null, null, false);
final headers = await tokenFunction();
expect(headers['x-goog-api-client'], contains('gl-dart'));
expect(headers['x-goog-api-client'], contains('fire'));
Expand All @@ -83,7 +87,8 @@ void main() {
final mockAppCheck = MockFirebaseAppCheck();
when(mockAppCheck.getToken())
.thenAnswer((_) async => 'test-app-check-token');
final tokenFunction = BaseModel.firebaseTokens(mockAppCheck, null, null);
final tokenFunction =
BaseModel.firebaseTokens(mockAppCheck, null, null, false);
final headers = await tokenFunction();
expect(headers['X-Firebase-AppCheck'], 'test-app-check-token');
expect(headers['x-goog-api-client'], contains('gl-dart'));
Expand All @@ -96,7 +101,8 @@ void main() {
final mockUser = MockUser();
when(mockUser.getIdToken()).thenAnswer((_) async => 'test-id-token');
when(mockAuth.currentUser).thenReturn(mockUser);
final tokenFunction = BaseModel.firebaseTokens(null, mockAuth, null);
final tokenFunction =
BaseModel.firebaseTokens(null, mockAuth, null, false);
final headers = await tokenFunction();
expect(headers['Authorization'], 'Firebase test-id-token');
expect(headers['x-goog-api-client'], contains('gl-dart'));
Expand All @@ -109,7 +115,8 @@ void main() {
() async {
final mockApp = MockFirebaseApp();

final tokenFunction = BaseModel.firebaseTokens(null, null, mockApp);
final tokenFunction =
BaseModel.firebaseTokens(null, null, mockApp, false);
final headers = await tokenFunction();
expect(headers['X-Firebase-AppId'], 'test-app-id');
expect(headers['x-goog-api-client'], contains('gl-dart'));
Expand All @@ -128,7 +135,7 @@ void main() {
final mockApp = MockFirebaseApp();

final tokenFunction =
BaseModel.firebaseTokens(mockAppCheck, mockAuth, mockApp);
BaseModel.firebaseTokens(mockAppCheck, mockAuth, mockApp, false);
final headers = await tokenFunction();
expect(headers['X-Firebase-AppCheck'], 'test-app-check-token');
expect(headers['Authorization'], 'Firebase test-id-token');
Expand All @@ -137,5 +144,20 @@ void main() {
expect(headers['x-goog-api-client'], contains('fire'));
expect(headers.length, 4);
});

test('firebaseTokens includes limited use App Check token if specified',
() async {
final mockAppCheck = MockFirebaseAppCheck();
when(mockAppCheck.getLimitedUseToken())
.thenAnswer((_) async => 'test-limited-use-app-check-token');
final tokenFunction =
BaseModel.firebaseTokens(mockAppCheck, null, null, true);
final headers = await tokenFunction();
expect(
headers['X-Firebase-AppCheck'], 'test-limited-use-app-check-token');
expect(headers['x-goog-api-client'], contains('gl-dart'));
expect(headers['x-goog-api-client'], contains('fire'));
expect(headers.length, 2);
});
});
}
20 changes: 20 additions & 0 deletions packages/firebase_ai/firebase_ai/test/firebase_vertexai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ void main() {
// ignore: unused_local_variable
late FirebaseAppCheck appCheck;
late FirebaseApp customApp;
late FirebaseApp limitTokenApp;
late FirebaseAppCheck customAppCheck;
late FirebaseAppCheck limitTokenAppCheck;

group('FirebaseAI Tests', () {
late FirebaseApp app;
Expand All @@ -38,8 +40,13 @@ void main() {
name: 'custom-app',
options: Firebase.app().options,
);
limitTokenApp = await Firebase.initializeApp(
name: 'limit-token-app',
options: Firebase.app().options,
);
appCheck = FirebaseAppCheck.instance;
customAppCheck = FirebaseAppCheck.instanceFor(app: customApp);
limitTokenAppCheck = FirebaseAppCheck.instanceFor(app: limitTokenApp);
});

test('Singleton behavior', () {
Expand Down Expand Up @@ -76,6 +83,19 @@ void main() {
expect(model, isA<GenerativeModel>());
});

test('Instance creation with useLimitedUseAppCheckTokens', () {
final vertexAIAppCheck = FirebaseAI.vertexAI(
app: limitTokenApp,
appCheck: limitTokenAppCheck,
location: 'limit-token-location',
useLimitedUseAppCheckTokens: true,
);
expect(vertexAIAppCheck.app, equals(limitTokenApp));
expect(vertexAIAppCheck.appCheck, equals(limitTokenAppCheck));
expect(vertexAIAppCheck.location, equals('limit-token-location'));
expect(vertexAIAppCheck.useLimitedUseAppCheckTokens, true);
});

// ... other tests (e.g., with different parameters)
});
}
Loading