diff --git a/.gitattributes b/.gitattributes index 9936f69d..7d098507 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,6 +2,7 @@ .gitattributes export-ignore /.github/ export-ignore .gitignore export-ignore +/cli.php export-ignore /*.md export-ignore /LICENSE.md -export-ignore /README.md -export-ignore diff --git a/cli.php b/cli.php new file mode 100755 index 00000000..17f913f9 --- /dev/null +++ b/cli.php @@ -0,0 +1,181 @@ + $value) { + if (!isset($schema[$key])) { + continue; + } + + $property_schema = $schema[$key]; + $type = $property_schema['type'] ?? null; + + $processed_value = $value; + if ($type === 'array' || $type === 'object') { + $decoded = json_decode((string) $value, true); + if (json_last_error() !== JSON_ERROR_NONE) { + logWarning("Invalid JSON for argument --{$key}: " . json_last_error_msg()); + continue; + } + $processed_value = $decoded; + } elseif ($type === 'integer') { + $processed_value = (int) $value; + } elseif ($type === 'number') { + $processed_value = (float) $value; + } elseif ($type === 'boolean') { + $processed_value = filter_var($value, FILTER_VALIDATE_BOOLEAN, FILTER_NULL_ON_FAILURE); + if (null === $processed_value) { + logWarning("Invalid boolean for argument --{$key}: {$value}"); + continue; + } + } + + $model_config_data[$key] = $processed_value; +} + +// --- Main logic --- + +try { + $modelConfig = ModelConfig::fromArray($model_config_data); + + $promptBuilder = AiClient::prompt($promptInput); + $promptBuilder = $promptBuilder->usingModelConfig($modelConfig); + if ($providerId && $modelId) { + $providerClassName = AiClient::defaultRegistry()->getProviderClassName($providerId); + $promptBuilder = $promptBuilder->usingModel($providerClassName::model($modelId)); + } elseif ($providerId) { + $promptBuilder = $promptBuilder->usingProvider($providerId); + } +} catch (InvalidArgumentException $e) { + logError('Invalid arguments while trying to set up prompt builder: ' . $e->getMessage()); +} catch (ResponseException $e) { + logError('Request failed while trying to set up prompt builder: ' . $e->getMessage()); +} + +try { + $result = $promptBuilder->generateTextResult(); +} catch (InvalidArgumentException $e) { + logError('Invalid arguments while trying to generate text result: ' . $e->getMessage()); +} catch (ResponseException $e) { + logError('Request failed while trying to generate text result: ' . $e->getMessage()); +} + +logInfo("Using provider ID: \"{$result->getProviderMetadata()->getId()}\""); +logInfo("Using model ID: \"{$result->getModelMetadata()->getId()}\""); + +switch ($outputFormat) { + case 'result-json': + $output = json_encode($result, JSON_PRETTY_PRINT); + break; + case 'candidates-json': + $output = json_encode($result->getCandidates(), JSON_PRETTY_PRINT); + break; + case 'message-text': + default: + $output = $result->toText(); +} + +printOutput($output); diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 895d2e03..03654156 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -893,7 +893,7 @@ direction LR +send(Request $request) Response } class RequestAuthenticationInterface { - +authenticate(Request $request) void + +authenticateRequest(Request $request) Request +getJsonSchema() array< string, mixed >$ } class WithHttpTransporterInterface { diff --git a/src/AiClient.php b/src/AiClient.php index dbafd5d3..80ffbd62 100644 --- a/src/AiClient.php +++ b/src/AiClient.php @@ -5,7 +5,11 @@ namespace WordPress\AiClient; use WordPress\AiClient\Builders\PromptBuilder; +use WordPress\AiClient\ProviderImplementations\Anthropic\AnthropicProvider; +use WordPress\AiClient\ProviderImplementations\Google\GoogleProvider; +use WordPress\AiClient\ProviderImplementations\OpenAi\OpenAiProvider; use WordPress\AiClient\Providers\Contracts\ProviderAvailabilityInterface; +use WordPress\AiClient\Providers\Http\HttpTransporterFactory; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\ProviderRegistry; @@ -95,12 +99,11 @@ public static function defaultRegistry(): ProviderRegistry if (self::$defaultRegistry === null) { $registry = new ProviderRegistry(); - // Provider registration will be enabled once concrete provider implementations are available. - // This follows the pattern established in the provider registry architecture. - //$registry->setHttpTransporter(HttpTransporterFactory::createTransporter()); - //$registry->registerProvider(AnthropicProvider::class); - //$registry->registerProvider(GoogleProvider::class); - //$registry->registerProvider(OpenAiProvider::class); + // Set up default HTTP transporter and register built-in providers. + $registry->setHttpTransporter(HttpTransporterFactory::createTransporter()); + $registry->registerProvider(AnthropicProvider::class); + $registry->registerProvider(GoogleProvider::class); + $registry->registerProvider(OpenAiProvider::class); self::$defaultRegistry = $registry; } diff --git a/src/Builders/PromptBuilder.php b/src/Builders/PromptBuilder.php index d90e4644..08de3d1e 100644 --- a/src/Builders/PromptBuilder.php +++ b/src/Builders/PromptBuilder.php @@ -191,6 +191,9 @@ public function withHistory(Message ...$messages): self /** * Sets the model to use for generation. * + * The model's configuration will be merged with the builder's configuration, + * with the builder's configuration taking precedence for any overlapping settings. + * * @since n.e.x.t * * @param ModelInterface $model The model to use. @@ -199,6 +202,14 @@ public function withHistory(Message ...$messages): self public function usingModel(ModelInterface $model): self { $this->model = $model; + + // Merge model's config with builder's config, with builder's config taking precedence + $modelConfigArray = $model->getConfig()->toArray(); + $builderConfigArray = $this->modelConfig->toArray(); + $mergedConfigArray = array_merge($modelConfigArray, $builderConfigArray); + + $this->modelConfig = ModelConfig::fromArray($mergedConfigArray); + return $this; } @@ -1009,6 +1020,7 @@ private function getConfiguredModel(CapabilityEnum $capability): ModelInterface // If a model has been explicitly set, return it if ($this->model !== null) { $this->model->setConfig($this->modelConfig); + $this->registry->bindModelDependencies($this->model); return $this->model; } diff --git a/src/Common/AbstractEnum.php b/src/Common/AbstractEnum.php index 6b4a2cf4..685002b9 100644 --- a/src/Common/AbstractEnum.php +++ b/src/Common/AbstractEnum.php @@ -6,6 +6,7 @@ use BadMethodCallException; use InvalidArgumentException; +use JsonSerializable; use ReflectionClass; use RuntimeException; @@ -36,7 +37,7 @@ * * @since n.e.x.t */ -abstract class AbstractEnum +abstract class AbstractEnum implements JsonSerializable { /** * @var string The value of the enum instance. @@ -393,4 +394,17 @@ final public function __toString(): string { return $this->value; } + + /** + * Converts the enum to a JSON-serializable format. + * + * @since n.e.x.t + * + * @return string The enum value. + */ + #[\ReturnTypeWillChange] + public function jsonSerialize() + { + return $this->value; + } } diff --git a/src/Files/ValueObjects/MimeType.php b/src/Files/ValueObjects/MimeType.php index 3dfd18da..966f7423 100644 --- a/src/Files/ValueObjects/MimeType.php +++ b/src/Files/ValueObjects/MimeType.php @@ -72,6 +72,7 @@ final class MimeType 'ogg' => 'audio/ogg', 'flac' => 'audio/flac', 'm4a' => 'audio/m4a', + 'aac' => 'audio/aac', // Video 'mp4' => 'video/mp4', @@ -130,6 +131,27 @@ public function __construct(string $value) $this->value = strtolower($value); } + /** + * Gets the primary known file extension for this MIME type. + * + * @since n.e.x.t + * + * @return string The file extension (without the dot). + * @throws InvalidArgumentException If no known extension exists for this MIME type. + */ + public function toExtension(): string + { + // Reverse lookup for the MIME type to find the extension. + $extension = array_search($this->value, self::$extensionMap, true); + if ($extension === false) { + throw new InvalidArgumentException( + sprintf('No known extension for MIME type: %s', $this->value) + ); + } + + return $extension; + } + /** * Creates a MimeType from a file extension. * diff --git a/src/ProviderImplementations/Anthropic/AnthropicApiKeyRequestAuthentication.php b/src/ProviderImplementations/Anthropic/AnthropicApiKeyRequestAuthentication.php new file mode 100644 index 00000000..468cefc5 --- /dev/null +++ b/src/ProviderImplementations/Anthropic/AnthropicApiKeyRequestAuthentication.php @@ -0,0 +1,32 @@ +withHeader('anthropic-version', self::ANTHROPIC_API_VERSION); + + // Add the API key to the request headers. + return $request->withHeader('x-api-key', $this->apiKey); + } +} diff --git a/src/ProviderImplementations/Anthropic/AnthropicModelMetadataDirectory.php b/src/ProviderImplementations/Anthropic/AnthropicModelMetadataDirectory.php new file mode 100644 index 00000000..0e75c714 --- /dev/null +++ b/src/ProviderImplementations/Anthropic/AnthropicModelMetadataDirectory.php @@ -0,0 +1,143 @@ + + * } + */ +class AnthropicModelMetadataDirectory extends AbstractOpenAiCompatibleModelMetadataDirectory +{ + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getRequestAuthentication(): RequestAuthenticationInterface + { + /* + * Since we're calling the primary Anthropic API models endpoint here, we need to use the Anthropic specific + * API key authentication class. + */ + $requestAuthentication = parent::getRequestAuthentication(); + if (!$requestAuthentication instanceof ApiKeyRequestAuthentication) { + return $requestAuthentication; + } + return new AnthropicApiKeyRequestAuthentication($requestAuthentication->getApiKey()); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function createRequest(HttpMethodEnum $method, string $path, array $headers = [], $data = null): Request + { + return new Request( + $method, + AnthropicProvider::BASE_URI . '/' . ltrim($path, '/'), + $headers, + $data + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function parseResponseToModelMetadataList(Response $response): array + { + /** @var ModelsResponseData $responseData */ + $responseData = $response->getData(); + if (!isset($responseData['data']) || !$responseData['data']) { + throw new RuntimeException( + 'Unexpected API response: Missing the data key.' + ); + } + + // Unfortunately, the Anthropic API does not return model capabilities, so we have to hardcode them here. + $anthropicCapabilities = [ + CapabilityEnum::textGeneration(), + CapabilityEnum::chatHistory(), + ]; + $anthropicOptions = [ + new SupportedOption(OptionEnum::systemInstruction()), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::maxTokens()), + new SupportedOption(OptionEnum::temperature()), + new SupportedOption(OptionEnum::topP()), + new SupportedOption(OptionEnum::stopSequences()), + new SupportedOption(OptionEnum::presencePenalty()), + new SupportedOption(OptionEnum::frequencyPenalty()), + new SupportedOption(OptionEnum::logprobs()), + new SupportedOption(OptionEnum::topLogprobs()), + new SupportedOption(OptionEnum::outputMimeType(), ['text/plain', 'application/json']), + new SupportedOption(OptionEnum::outputSchema()), + new SupportedOption(OptionEnum::functionDeclarations()), + new SupportedOption(OptionEnum::customOptions()), + new SupportedOption( + OptionEnum::inputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + ] + ), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::text()]]), + ]; + $anthropicWebSearchOptions = array_merge($anthropicOptions, [ + new SupportedOption(OptionEnum::webSearch()), + ]); + + $modelsData = (array) $responseData['data']; + + return array_values( + array_map( + static function (array $modelData) use ( + $anthropicCapabilities, + $anthropicOptions, + $anthropicWebSearchOptions + ): ModelMetadata { + $modelId = $modelData['id']; + $modelCaps = $anthropicCapabilities; + if (!preg_match('/^claude-3-[a-z]+/', $modelId)) { + // Only models newer than Claude 3 support web search. + $modelOptions = $anthropicWebSearchOptions; + } else { + $modelOptions = $anthropicOptions; + } + + $modelName = $modelData['display_name'] ?? $modelId; + + return new ModelMetadata( + $modelId, + $modelName, + $modelCaps, + $modelOptions + ); + }, + $modelsData + ) + ); + } +} diff --git a/src/ProviderImplementations/Anthropic/AnthropicProvider.php b/src/ProviderImplementations/Anthropic/AnthropicProvider.php new file mode 100644 index 00000000..af598bfe --- /dev/null +++ b/src/ProviderImplementations/Anthropic/AnthropicProvider.php @@ -0,0 +1,83 @@ +getSupportedCapabilities(); + foreach ($capabilities as $capability) { + if ($capability->isTextGeneration()) { + return new AnthropicTextGenerationModel($modelMetadata, $providerMetadata); + } + } + + throw new RuntimeException( + 'Unsupported model capabilities: ' . implode(', ', $capabilities) + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata( + 'anthropic', + 'Anthropic', + ProviderTypeEnum::cloud() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + // Check valid API access by attempting to list models. + return new ListModelsApiBasedProviderAvailability( + static::modelMetadataDirectory() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new AnthropicModelMetadataDirectory(); + } +} diff --git a/src/ProviderImplementations/Anthropic/AnthropicTextGenerationModel.php b/src/ProviderImplementations/Anthropic/AnthropicTextGenerationModel.php new file mode 100644 index 00000000..e7f48fc7 --- /dev/null +++ b/src/ProviderImplementations/Anthropic/AnthropicTextGenerationModel.php @@ -0,0 +1,32 @@ +withHeader('X-Goog-Api-Key', $this->apiKey); + } +} diff --git a/src/ProviderImplementations/Google/GoogleModelMetadataDirectory.php b/src/ProviderImplementations/Google/GoogleModelMetadataDirectory.php new file mode 100644 index 00000000..99f72ea6 --- /dev/null +++ b/src/ProviderImplementations/Google/GoogleModelMetadataDirectory.php @@ -0,0 +1,237 @@ +, + * displayName?: string + * }> + * } + */ +class GoogleModelMetadataDirectory extends AbstractOpenAiCompatibleModelMetadataDirectory +{ + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getRequestAuthentication(): RequestAuthenticationInterface + { + /* + * Since we're calling the primary Google API models endpoint here, we need to use the Google specific API key + * authentication class. + */ + $requestAuthentication = parent::getRequestAuthentication(); + if (!$requestAuthentication instanceof ApiKeyRequestAuthentication) { + return $requestAuthentication; + } + return new GoogleApiKeyRequestAuthentication($requestAuthentication->getApiKey()); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function createRequest(HttpMethodEnum $method, string $path, array $headers = [], $data = null): Request + { + /* + * We don't call Google's OpenAI compatible models endpoint here because it provides fewer details about the + * models than the primary models endpoint. + * For Google's models endpoint, set pageSize=1000 which is the maximum page size. + * This allows us to retrieve all models in one go. + */ + if ($path === 'models' && $data === null) { + $data = ['pageSize' => 1000]; + } + return new Request( + $method, + GoogleProvider::BASE_URI . '/' . ltrim($path, '/'), + $headers, + $data + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function parseResponseToModelMetadataList(Response $response): array + { + /** @var ModelsResponseData $responseData */ + $responseData = $response->getData(); + if (!isset($responseData['models']) || !$responseData['models']) { + throw new RuntimeException( + 'Unexpected API response: Missing the models key.' + ); + } + + $geminiCapabilities = [ + CapabilityEnum::textGeneration(), + CapabilityEnum::chatHistory(), + ]; + $geminiBaseOptions = [ + new SupportedOption(OptionEnum::systemInstruction()), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::maxTokens()), + new SupportedOption(OptionEnum::temperature()), + new SupportedOption(OptionEnum::topP()), + new SupportedOption(OptionEnum::stopSequences()), + new SupportedOption(OptionEnum::presencePenalty()), + new SupportedOption(OptionEnum::frequencyPenalty()), + new SupportedOption(OptionEnum::logprobs()), + new SupportedOption(OptionEnum::topLogprobs()), + new SupportedOption(OptionEnum::outputMimeType(), ['text/plain', 'application/json']), + new SupportedOption(OptionEnum::outputSchema()), + new SupportedOption(OptionEnum::functionDeclarations()), + new SupportedOption(OptionEnum::customOptions()), + ]; + $geminiLegacyOptions = array_merge($geminiBaseOptions, [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::text()]]), + ]); + $geminiOptions = array_merge($geminiBaseOptions, [ + new SupportedOption( + OptionEnum::inputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], + ] + ), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::text()]]), + ]); + $geminiWebSearchOptions = array_merge($geminiOptions, [ + new SupportedOption(OptionEnum::webSearch()), + ]); + $geminiMultimodalImageOutputOptions = array_merge($geminiBaseOptions, [ + new SupportedOption( + OptionEnum::inputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], + ] + ), + new SupportedOption( + OptionEnum::outputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + ] + ), + ]); + $imagenCapabilities = [ + CapabilityEnum::imageGeneration(), + ]; + $imagenOptions = [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::image()]]), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::outputMimeType(), ['image/png', 'image/jpeg', 'image/webp']), + new SupportedOption(OptionEnum::outputFileType(), [FileTypeEnum::inline()]), + new SupportedOption(OptionEnum::outputMediaOrientation(), [ + MediaOrientationEnum::square(), + MediaOrientationEnum::landscape(), + MediaOrientationEnum::portrait(), + ]), + new SupportedOption(OptionEnum::outputMediaAspectRatio(), ['1:1', '16:9', '4:3', '9:16', '3:4']), + ]; + + $modelsData = (array) $responseData['models']; + + return array_values( + array_map( + static function (array $modelData) use ( + $geminiCapabilities, + $geminiLegacyOptions, + $geminiOptions, + $geminiWebSearchOptions, + $geminiMultimodalImageOutputOptions, + $imagenCapabilities, + $imagenOptions + ): ModelMetadata { + $modelId = $modelData['baseModelId'] ?? $modelData['name']; + if (str_starts_with($modelId, 'models/')) { + $modelId = substr($modelId, 7); + } + if ( + isset($modelData['supportedGenerationMethods']) && + is_array($modelData['supportedGenerationMethods']) && + in_array('generateContent', $modelData['supportedGenerationMethods'], true) + ) { + $modelCaps = $geminiCapabilities; + if ( + str_starts_with($modelId, 'gemini-1.0') || + str_starts_with($modelId, 'gemini-pro') // 'gemini-pro' without version refers to 1.0. + ) { + $modelOptions = $geminiLegacyOptions; + } else { + if ( + // Web search is supported by Gemini 2.0 and newer. + str_starts_with($modelId, 'gemini-') && + ! str_starts_with($modelId, 'gemini-1.5-') + ) { + $modelOptions = $geminiWebSearchOptions; + } elseif ( + // New multimodal output model for image generation. + str_contains($modelId, 'image-generation') || + str_starts_with($modelId, 'gemini-2.0-flash-exp') + ) { + $modelOptions = $geminiMultimodalImageOutputOptions; + } else { + $modelOptions = $geminiOptions; + } + } + } elseif ( + isset($modelData['supportedGenerationMethods']) && + is_array($modelData['supportedGenerationMethods']) && + in_array('predict', $modelData['supportedGenerationMethods'], true) + ) { + $modelCaps = $imagenCapabilities; + $modelOptions = $imagenOptions; + } else { + $modelCaps = []; + $modelOptions = []; + } + + $modelName = $modelData['displayName'] ?? $modelId; + + return new ModelMetadata( + $modelId, + $modelName, + $modelCaps, + $modelOptions + ); + }, + $modelsData + ) + ); + } +} diff --git a/src/ProviderImplementations/Google/GoogleProvider.php b/src/ProviderImplementations/Google/GoogleProvider.php new file mode 100644 index 00000000..2aa7bc79 --- /dev/null +++ b/src/ProviderImplementations/Google/GoogleProvider.php @@ -0,0 +1,89 @@ +getSupportedCapabilities(); + foreach ($capabilities as $capability) { + if ($capability->isTextGeneration()) { + return new GoogleTextGenerationModel($modelMetadata, $providerMetadata); + } + if ($capability->isImageGeneration()) { + // TODO: Implement GoogleImageGenerationModel. + throw new RuntimeException( + 'Google image generation model class is not yet implemented.' + ); + } + } + + throw new RuntimeException( + 'Unsupported model capabilities: ' . implode(', ', $capabilities) + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata( + 'google', + 'Google', + ProviderTypeEnum::cloud() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + // Check valid API access by attempting to list models. + return new ListModelsApiBasedProviderAvailability( + static::modelMetadataDirectory() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new GoogleModelMetadataDirectory(); + } +} diff --git a/src/ProviderImplementations/Google/GoogleTextGenerationModel.php b/src/ProviderImplementations/Google/GoogleTextGenerationModel.php new file mode 100644 index 00000000..e8066681 --- /dev/null +++ b/src/ProviderImplementations/Google/GoogleTextGenerationModel.php @@ -0,0 +1,32 @@ + + * } + */ +class OpenAiModelMetadataDirectory extends AbstractOpenAiCompatibleModelMetadataDirectory +{ + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function createRequest(HttpMethodEnum $method, string $path, array $headers = [], $data = null): Request + { + return new Request( + $method, + OpenAiProvider::BASE_URI . '/' . ltrim($path, '/'), + $headers, + $data + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected function parseResponseToModelMetadataList(Response $response): array + { + /** @var ModelsResponseData $responseData */ + $responseData = $response->getData(); + if (!isset($responseData['data']) || !$responseData['data']) { + throw new RuntimeException( + 'Unexpected API response: Missing the data key.' + ); + } + + // Unfortunately, the OpenAI API does not return model capabilities, so we have to hardcode them here. + $gptCapabilities = [ + CapabilityEnum::textGeneration(), + CapabilityEnum::chatHistory(), + ]; + $gptBaseOptions = [ + new SupportedOption(OptionEnum::systemInstruction()), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::maxTokens()), + new SupportedOption(OptionEnum::temperature()), + new SupportedOption(OptionEnum::topP()), + new SupportedOption(OptionEnum::stopSequences()), + new SupportedOption(OptionEnum::presencePenalty()), + new SupportedOption(OptionEnum::frequencyPenalty()), + new SupportedOption(OptionEnum::logprobs()), + new SupportedOption(OptionEnum::topLogprobs()), + new SupportedOption(OptionEnum::outputMimeType(), ['text/plain', 'application/json']), + new SupportedOption(OptionEnum::outputSchema()), + new SupportedOption(OptionEnum::functionDeclarations()), + new SupportedOption(OptionEnum::customOptions()), + ]; + $gptOptions = array_merge($gptBaseOptions, [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::text()]]), + ]); + $gptMultimodalInputOptions = array_merge($gptBaseOptions, [ + new SupportedOption( + OptionEnum::inputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], + ] + ), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::text()]]), + ]); + $gptMultimodalSpeechOutputOptions = array_merge($gptBaseOptions, [ + new SupportedOption( + OptionEnum::inputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], + ] + ), + new SupportedOption( + OptionEnum::outputModalities(), + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::audio()], + ] + ), + ]); + $imageCapabilities = [ + CapabilityEnum::imageGeneration(), + ]; + $dalleImageOptions = [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::image()]]), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::outputMimeType(), ['image/png']), + new SupportedOption(OptionEnum::outputFileType(), [FileTypeEnum::inline(), FileTypeEnum::remote()]), + new SupportedOption(OptionEnum::outputMediaOrientation(), [ + MediaOrientationEnum::square(), + MediaOrientationEnum::landscape(), + MediaOrientationEnum::portrait(), + ]), + new SupportedOption(OptionEnum::outputMediaAspectRatio(), ['1:1', '7:4', '4:7']), + ]; + $gptImageOptions = [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::image()]]), + new SupportedOption(OptionEnum::candidateCount()), + new SupportedOption(OptionEnum::outputMimeType(), ['image/png', 'image/jpeg', 'image/webp']), + new SupportedOption(OptionEnum::outputFileType(), [FileTypeEnum::inline()]), + new SupportedOption(OptionEnum::outputMediaOrientation(), [ + MediaOrientationEnum::square(), + MediaOrientationEnum::landscape(), + MediaOrientationEnum::portrait(), + ]), + new SupportedOption(OptionEnum::outputMediaAspectRatio(), ['1:1', '3:2', '2:3']), + ]; + $ttsCapabilities = [ + CapabilityEnum::textToSpeechConversion(), + ]; + $ttsOptions = [ + new SupportedOption(OptionEnum::inputModalities(), [[ModalityEnum::text()]]), + new SupportedOption(OptionEnum::outputModalities(), [[ModalityEnum::audio()]]), + new SupportedOption(OptionEnum::outputMimeType(), [ + 'audio/mpeg', + 'audio/ogg', + 'audio/wav', + 'audio/flac', + 'audio/aac', + ]), + new SupportedOption(OptionEnum::outputSpeechVoice()), + ]; + + $modelsData = (array) $responseData['data']; + + return array_values( + array_map( + static function (array $modelData) use ( + $gptCapabilities, + $gptOptions, + $gptMultimodalInputOptions, + $gptMultimodalSpeechOutputOptions, + $imageCapabilities, + $dalleImageOptions, + $gptImageOptions, + $ttsCapabilities, + $ttsOptions + ): ModelMetadata { + $modelId = $modelData['id']; + if ( + str_starts_with($modelId, 'dall-e-') || + str_starts_with($modelId, 'gpt-image-') + ) { + $modelCaps = $imageCapabilities; + if (str_starts_with($modelId, 'gpt-image-')) { + $modelOptions = $gptImageOptions; + } else { + $modelOptions = $dalleImageOptions; + } + } elseif ( + str_starts_with($modelId, 'tts-') || + str_contains($modelId, '-tts') + ) { + $modelCaps = $ttsCapabilities; + $modelOptions = $ttsOptions; + } elseif ( + (str_starts_with($modelId, 'gpt-') || str_starts_with($modelId, 'o1-')) + && !str_contains($modelId, '-instruct') + && !str_contains($modelId, '-realtime') + ) { + if (str_starts_with($modelId, 'gpt-4o')) { + $modelCaps = $gptCapabilities; + $modelOptions = $gptMultimodalInputOptions; + // New multimodal output model for audio generation. + if (str_contains($modelId, '-audio')) { + $modelOptions = $gptMultimodalSpeechOutputOptions; + } + } elseif (!str_contains($modelId, '-audio')) { + $modelCaps = $gptCapabilities; + $modelOptions = $gptOptions; + } else { + $modelCaps = []; + $modelOptions = []; + } + } else { + $modelCaps = []; + $modelOptions = []; + } + + return new ModelMetadata( + $modelId, + $modelId, // The OpenAI API does not return a display name. + $modelCaps, + $modelOptions + ); + }, + $modelsData + ) + ); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiProvider.php b/src/ProviderImplementations/OpenAi/OpenAiProvider.php new file mode 100644 index 00000000..1021664a --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiProvider.php @@ -0,0 +1,95 @@ +getSupportedCapabilities(); + foreach ($capabilities as $capability) { + if ($capability->isTextGeneration()) { + return new OpenAiTextGenerationModel($modelMetadata, $providerMetadata); + } + if ($capability->isImageGeneration()) { + // TODO: Implement OpenAiImageGenerationModel. + throw new RuntimeException( + 'OpenAI image generation model class is not yet implemented.' + ); + } + if ($capability->isTextToSpeechConversion()) { + // TODO: Implement OpenAiTextToSpeechConversionModel. + throw new RuntimeException( + 'OpenAI text to speech conversion model class is not yet implemented.' + ); + } + } + + throw new RuntimeException( + 'Unsupported model capabilities: ' . implode(', ', $capabilities) + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata( + 'openai', + 'OpenAI', + ProviderTypeEnum::cloud() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + // Check valid API access by attempting to list models. + return new ListModelsApiBasedProviderAvailability( + static::modelMetadataDirectory() + ); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new OpenAiModelMetadataDirectory(); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php b/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php new file mode 100644 index 00000000..66823ad7 --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php @@ -0,0 +1,32 @@ + Cache for provider metadata per class. + */ + private static array $metadataCache = []; + + /** + * @var array Cache for provider availability per class. + */ + private static array $availabilityCache = []; + + /** + * @var array Cache for model metadata directory per class. + */ + private static array $modelMetadataDirectoryCache = []; + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public static function metadata(): ProviderMetadata + { + $className = static::class; + if (!isset(self::$metadataCache[$className])) { + self::$metadataCache[$className] = static::createProviderMetadata(); + } + return self::$metadataCache[$className]; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public static function model(string $modelId, ?ModelConfig $modelConfig = null): ModelInterface + { + $providerMetadata = static::metadata(); + $modelMetadata = static::modelMetadataDirectory()->getModelMetadata($modelId); + + $model = static::createModel($modelMetadata, $providerMetadata); + if ($modelConfig) { + $model->setConfig($modelConfig); + } + return $model; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public static function availability(): ProviderAvailabilityInterface + { + $className = static::class; + if (!isset(self::$availabilityCache[$className])) { + self::$availabilityCache[$className] = static::createProviderAvailability(); + } + return self::$availabilityCache[$className]; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public static function modelMetadataDirectory(): ModelMetadataDirectoryInterface + { + $className = static::class; + if (!isset(self::$modelMetadataDirectoryCache[$className])) { + self::$modelMetadataDirectoryCache[$className] = static::createModelMetadataDirectory(); + } + return self::$modelMetadataDirectoryCache[$className]; + } + + /** + * Creates a model instance based on the given model metadata and provider metadata. + * + * @since n.e.x.t + * + * @param ModelMetadata $modelMetadata The model metadata. + * @param ProviderMetadata $providerMetadata The provider metadata. + * @return ModelInterface The new model instance. + */ + abstract protected static function createModel( + ModelMetadata $modelMetadata, + ProviderMetadata $providerMetadata + ): ModelInterface; + + /** + * Creates the provider metadata instance. + * + * @since n.e.x.t + * + * @return ProviderMetadata The provider metadata. + */ + abstract protected static function createProviderMetadata(): ProviderMetadata; + + /** + * Creates the provider availability instance. + * + * @since n.e.x.t + * + * @return ProviderAvailabilityInterface The provider availability. + */ + abstract protected static function createProviderAvailability(): ProviderAvailabilityInterface; + + /** + * Creates the model metadata directory instance. + * + * @since n.e.x.t + * + * @return ModelMetadataDirectoryInterface The model metadata directory. + */ + abstract protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface; +} diff --git a/src/Providers/ApiBasedImplementation/AbstractApiBasedModel.php b/src/Providers/ApiBasedImplementation/AbstractApiBasedModel.php new file mode 100644 index 00000000..7d6a34ef --- /dev/null +++ b/src/Providers/ApiBasedImplementation/AbstractApiBasedModel.php @@ -0,0 +1,101 @@ +metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = ModelConfig::fromArray([]); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function metadata(): ModelMetadata + { + return $this->metadata; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function getConfig(): ModelConfig + { + return $this->config; + } +} diff --git a/src/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectory.php b/src/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectory.php new file mode 100644 index 00000000..aa27265f --- /dev/null +++ b/src/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectory.php @@ -0,0 +1,94 @@ + Map of model ID to model metadata, effectively for caching. + */ + private ?array $modelMetadataMap = null; + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function listModelMetadata(): array + { + $modelsMetadata = $this->getModelMetadataMap(); + return array_values($modelsMetadata); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function hasModelMetadata(string $modelId): bool + { + $modelsMetadata = $this->getModelMetadataMap(); + return isset($modelsMetadata[$modelId]); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function getModelMetadata(string $modelId): ModelMetadata + { + $modelsMetadata = $this->getModelMetadataMap(); + if (!isset($modelsMetadata[$modelId])) { + throw new InvalidArgumentException( + sprintf('No model with ID %s was found in the provider', $modelId) + ); + } + return $modelsMetadata[$modelId]; + } + + /** + * Returns the map of model ID to model metadata for all models from the provider. + * + * @since n.e.x.t + * + * @return array Map of model ID to model metadata. + */ + private function getModelMetadataMap(): array + { + if ($this->modelMetadataMap === null) { + $this->modelMetadataMap = $this->sendListModelsRequest(); + } + return $this->modelMetadataMap; + } + + /** + * Sends the API request to list models from the provider and returns the map of model ID to model metadata. + * + * @since n.e.x.t + * + * @return array Map of model ID to model metadata. + */ + abstract protected function sendListModelsRequest(): array; +} diff --git a/src/Providers/ApiBasedImplementation/GenerateTextApiBasedProviderAvailability.php b/src/Providers/ApiBasedImplementation/GenerateTextApiBasedProviderAvailability.php new file mode 100644 index 00000000..37fcf998 --- /dev/null +++ b/src/Providers/ApiBasedImplementation/GenerateTextApiBasedProviderAvailability.php @@ -0,0 +1,76 @@ +model = $model; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function isConfigured(): bool + { + // Set config to use as few resources as possible for the test. + $modelConfig = ModelConfig::fromArray([ + ModelConfig::KEY_MAX_TOKENS => 1, + ]); + $this->model->setConfig($modelConfig); + + try { + // Attempt to generate text to check if the provider is available. + $this->model->generateTextResult([ + new Message( + MessageRoleEnum::user(), + [new MessagePart('a')] + ), + ]); + return true; + } catch (Exception $e) { + // If an exception occurs, the provider is not available. + return false; + } + } +} diff --git a/src/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailability.php b/src/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailability.php new file mode 100644 index 00000000..ff1010ac --- /dev/null +++ b/src/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailability.php @@ -0,0 +1,56 @@ +modelMetadataDirectory = $modelMetadataDirectory; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function isConfigured(): bool + { + try { + // Attempt to list models to check if the provider is available. + $this->modelMetadataDirectory->listModelMetadata(); + return true; + } catch (Exception $e) { + // If an exception occurs, the provider is not available. + return false; + } + } +} diff --git a/src/Providers/Contracts/ProviderInterface.php b/src/Providers/Contracts/ProviderInterface.php index 42ec540d..5ff5c808 100644 --- a/src/Providers/Contracts/ProviderInterface.php +++ b/src/Providers/Contracts/ProviderInterface.php @@ -33,7 +33,7 @@ public static function metadata(): ProviderMetadata; * * @since n.e.x.t * - * @param string $modelId Model identifier. + * @param string $modelId Model identifier. * @param ?ModelConfig $modelConfig Model configuration. * @return ModelInterface Model instance. * @throws InvalidArgumentException If model not found or configuration invalid. diff --git a/src/Providers/Http/Contracts/RequestAuthenticationInterface.php b/src/Providers/Http/Contracts/RequestAuthenticationInterface.php index a49aedb2..038481ae 100644 --- a/src/Providers/Http/Contracts/RequestAuthenticationInterface.php +++ b/src/Providers/Http/Contracts/RequestAuthenticationInterface.php @@ -12,7 +12,8 @@ * * @since n.e.x.t */ -interface RequestAuthenticationInterface extends WithJsonSchemaInterface +interface RequestAuthenticationInterface extends + WithJsonSchemaInterface { /** * Authenticates an HTTP request. @@ -20,7 +21,7 @@ interface RequestAuthenticationInterface extends WithJsonSchemaInterface * @since n.e.x.t * * @param Request $request The request to authenticate. - * @return void + * @return Request The authenticated request. */ - public function authenticate(Request $request): void; + public function authenticateRequest(Request $request): Request; } diff --git a/src/Providers/Http/DTO/ApiKeyRequestAuthentication.php b/src/Providers/Http/DTO/ApiKeyRequestAuthentication.php new file mode 100644 index 00000000..7a726ff2 --- /dev/null +++ b/src/Providers/Http/DTO/ApiKeyRequestAuthentication.php @@ -0,0 +1,114 @@ + + */ +class ApiKeyRequestAuthentication extends AbstractDataTransferObject implements RequestAuthenticationInterface +{ + public const KEY_API_KEY = 'apiKey'; + + /** + * @var string The API key used for authentication. + */ + protected string $apiKey; + + /** + * Constructor. + * + * @since n.e.x.t + * + * @param string $apiKey The API key used for authentication. + */ + public function __construct(string $apiKey) + { + $this->apiKey = $apiKey; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function authenticateRequest(Request $request): Request + { + // Add the API key to the request headers. + return $request->withHeader('Authorization', 'Bearer ' . $this->apiKey); + } + + /** + * Gets the API key. + * + * @since n.e.x.t + * + * @return string The API key. + */ + public function getApiKey(): string + { + return $this->apiKey; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + * + * @since n.e.x.t + * + * @return ApiKeyRequestAuthenticationArrayShape + */ + public function toArray(): array + { + return [ + self::KEY_API_KEY => $this->apiKey, + ]; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + * + * @since n.e.x.t + */ + public static function fromArray(array $array): self + { + static::validateFromArrayData($array, [self::KEY_API_KEY]); + + return new self($array[self::KEY_API_KEY]); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [ + self::KEY_API_KEY => [ + 'type' => 'string', + 'title' => 'API Key', + 'description' => 'The API key used for authentication.', + ], + ], + 'required' => [self::KEY_API_KEY], + ]; + } +} diff --git a/src/Providers/Http/Exception/ResponseException.php b/src/Providers/Http/Exception/ResponseException.php new file mode 100644 index 00000000..a835995a --- /dev/null +++ b/src/Providers/Http/Exception/ResponseException.php @@ -0,0 +1,16 @@ +httpTransporter = $httpTransporter; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getHttpTransporter(): HttpTransporterInterface + { + if ($this->httpTransporter === null) { + throw new RuntimeException( + 'HttpTransporterInterface instance not set. Make sure you use the AiClient class for all requests.' + ); + } + return $this->httpTransporter; + } +} diff --git a/src/Providers/Http/Traits/WithRequestAuthenticationTrait.php b/src/Providers/Http/Traits/WithRequestAuthenticationTrait.php new file mode 100644 index 00000000..3f5dc2d5 --- /dev/null +++ b/src/Providers/Http/Traits/WithRequestAuthenticationTrait.php @@ -0,0 +1,47 @@ +requestAuthentication = $requestAuthentication; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function getRequestAuthentication(): RequestAuthenticationInterface + { + if ($this->requestAuthentication === null) { + throw new RuntimeException( + 'RequestAuthenticationInterface instance not set. ' . + 'Make sure you use the AiClient class for all requests.' + ); + } + return $this->requestAuthentication; + } +} diff --git a/src/Providers/Http/Util/ResponseUtil.php b/src/Providers/Http/Util/ResponseUtil.php new file mode 100644 index 00000000..3bde5d22 --- /dev/null +++ b/src/Providers/Http/Util/ResponseUtil.php @@ -0,0 +1,67 @@ +isSuccessful()) { + return; + } + + $errorMessage = sprintf( + 'Bad status code: %d.', + $response->getStatusCode() + ); + + // Handle common error formats in API responses. + $data = $response->getData(); + if ( + is_array($data) && + isset($data['error']) && + is_array($data['error']) && + isset($data['error']['message']) && + is_string($data['error']['message']) + ) { + $errorMessage .= ' ' . $data['error']['message']; + } elseif ( + is_array($data) && + isset($data['error']) && + is_string($data['error']) + ) { + $errorMessage .= ' ' . $data['error']; + } elseif ( + is_array($data) && + isset($data['message']) && + is_string($data['message']) + ) { + $errorMessage .= ' ' . $data['message']; + } + + throw new ResponseException($errorMessage, $response->getStatusCode()); + } +} diff --git a/src/Providers/Models/Contracts/ModelInterface.php b/src/Providers/Models/Contracts/ModelInterface.php index e0448e0f..1c6adbc7 100644 --- a/src/Providers/Models/Contracts/ModelInterface.php +++ b/src/Providers/Models/Contracts/ModelInterface.php @@ -4,6 +4,7 @@ namespace WordPress\AiClient\Providers\Models\Contracts; +use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; @@ -26,6 +27,15 @@ interface ModelInterface */ public function metadata(): ModelMetadata; + /** + * Returns the metadata for the model's provider. + * + * @since n.e.x.t + * + * @return ProviderMetadata The provider metadata. + */ + public function providerMetadata(): ProviderMetadata; + /** * Sets model configuration. * diff --git a/src/Providers/Models/DTO/ModelConfig.php b/src/Providers/Models/DTO/ModelConfig.php index 86a8e63e..63cf6aba 100644 --- a/src/Providers/Models/DTO/ModelConfig.php +++ b/src/Providers/Models/DTO/ModelConfig.php @@ -45,6 +45,7 @@ * outputSchema?: array, * outputMediaOrientation?: string, * outputMediaAspectRatio?: string, + * outputSpeechVoice?: string, * customOptions?: array * } * @@ -71,8 +72,16 @@ class ModelConfig extends AbstractDataTransferObject public const KEY_OUTPUT_SCHEMA = 'outputSchema'; public const KEY_OUTPUT_MEDIA_ORIENTATION = 'outputMediaOrientation'; public const KEY_OUTPUT_MEDIA_ASPECT_RATIO = 'outputMediaAspectRatio'; + public const KEY_OUTPUT_SPEECH_VOICE = 'outputSpeechVoice'; public const KEY_CUSTOM_OPTIONS = 'customOptions'; + /* + * Note: This key is not an actual model config key, but specified here for convenience. + * It is relevant for model discovery, to determine which models support which input modalities. + * The actual input modalities are part of the message sent to the model, not the model config. + */ + public const KEY_INPUT_MODALITIES = 'inputModalities'; + /** * @var list|null Output modalities for the model. */ @@ -168,6 +177,11 @@ class ModelConfig extends AbstractDataTransferObject */ protected ?string $outputMediaAspectRatio = null; + /** + * @var string|null Output speech voice. + */ + protected ?string $outputSpeechVoice = null; + /** * @var array Custom provider-specific options. */ @@ -662,6 +676,30 @@ public function getOutputMediaAspectRatio(): ?string return $this->outputMediaAspectRatio; } + /** + * Sets the output speech voice. + * + * @since n.e.x.t + * + * @param string $outputSpeechVoice The output speech voice. + */ + public function setOutputSpeechVoice(string $outputSpeechVoice): void + { + $this->outputSpeechVoice = $outputSpeechVoice; + } + + /** + * Gets the output speech voice. + * + * @since n.e.x.t + * + * @return string|null The output speech voice. + */ + public function getOutputSpeechVoice(): ?string + { + return $this->outputSpeechVoice; + } + /** * Sets a single custom option. * @@ -802,6 +840,10 @@ public static function getJsonSchema(): array 'pattern' => '^\d+:\d+$', 'description' => 'Output media aspect ratio.', ], + self::KEY_OUTPUT_SPEECH_VOICE => [ + 'type' => 'string', + 'description' => 'Output speech voice.', + ], self::KEY_CUSTOM_OPTIONS => [ 'type' => 'object', 'additionalProperties' => true, @@ -909,6 +951,10 @@ static function (FunctionDeclaration $function_declaration): array { $data[self::KEY_OUTPUT_MEDIA_ASPECT_RATIO] = $this->outputMediaAspectRatio; } + if ($this->outputSpeechVoice !== null) { + $data[self::KEY_OUTPUT_SPEECH_VOICE] = $this->outputSpeechVoice; + } + if (!empty($this->customOptions)) { $data[self::KEY_CUSTOM_OPTIONS] = $this->customOptions; } @@ -1142,6 +1188,10 @@ static function (array $function_declaration_data): FunctionDeclaration { $config->setOutputMediaAspectRatio($array[self::KEY_OUTPUT_MEDIA_ASPECT_RATIO]); } + if (isset($array[self::KEY_OUTPUT_SPEECH_VOICE])) { + $config->setOutputSpeechVoice($array[self::KEY_OUTPUT_SPEECH_VOICE]); + } + if (isset($array[self::KEY_CUSTOM_OPTIONS])) { $config->setCustomOptions($array[self::KEY_CUSTOM_OPTIONS]); } diff --git a/src/Providers/Models/DTO/SupportedOption.php b/src/Providers/Models/DTO/SupportedOption.php index 1a4d121f..e9291b37 100644 --- a/src/Providers/Models/DTO/SupportedOption.php +++ b/src/Providers/Models/DTO/SupportedOption.php @@ -85,6 +85,21 @@ public function isSupportedValue($value): bool return true; } + // If the value is an array, consider it a set (i.e. order doesn't matter). + if (is_array($value)) { + sort($value); + foreach ($this->supportedValues as $supportedValue) { + if (!is_array($supportedValue)) { + continue; + } + sort($supportedValue); + if ($value === $supportedValue) { + return true; + } + } + return false; + } + return in_array($value, $this->supportedValues, true); } diff --git a/src/Providers/Models/Enums/OptionEnum.php b/src/Providers/Models/Enums/OptionEnum.php index 508cc67d..7c28f593 100644 --- a/src/Providers/Models/Enums/OptionEnum.php +++ b/src/Providers/Models/Enums/OptionEnum.php @@ -31,6 +31,7 @@ * @method static self outputMimeType() Creates an instance for OUTPUT_MIME_TYPE option. * @method static self outputModalities() Creates an instance for OUTPUT_MODALITIES option. * @method static self outputSchema() Creates an instance for OUTPUT_SCHEMA option. + * @method static self outputSpeechVoice() Creates an instance for OUTPUT_SPEECH_VOICE option. * @method static self presencePenalty() Creates an instance for PRESENCE_PENALTY option. * @method static self stopSequences() Creates an instance for STOP_SEQUENCES option. * @method static self systemInstruction() Creates an instance for SYSTEM_INSTRUCTION option. @@ -51,6 +52,7 @@ * @method bool isOutputMimeType() Checks if the option is OUTPUT_MIME_TYPE. * @method bool isOutputModalities() Checks if the option is OUTPUT_MODALITIES. * @method bool isOutputSchema() Checks if the option is OUTPUT_SCHEMA. + * @method bool isOutputSpeechVoice() Checks if the option is OUTPUT_SPEECH_VOICE. * @method bool isPresencePenalty() Checks if the option is PRESENCE_PENALTY. * @method bool isStopSequences() Checks if the option is STOP_SEQUENCES. * @method bool isSystemInstruction() Checks if the option is SYSTEM_INSTRUCTION. diff --git a/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectory.php b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectory.php new file mode 100644 index 00000000..a267eeeb --- /dev/null +++ b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectory.php @@ -0,0 +1,90 @@ +getHttpTransporter(); + + $request = $this->createRequest(HttpMethodEnum::GET(), 'models'); + $request = $this->getRequestAuthentication()->authenticateRequest($request); + $response = $httpTransporter->send($request); + + $this->throwIfNotSuccessful($response); + $modelsMetadataList = $this->parseResponseToModelMetadataList($response); + + // Parse list to map. + $modelMetadataMap = []; + foreach ($modelsMetadataList as $modelMetadata) { + $modelMetadataMap[$modelMetadata->getId()] = $modelMetadata; + } + return $modelMetadataMap; + } + + /** + * Creates a request object for the provider's API. + * + * @since n.e.x.t + * + * @param HttpMethodEnum $method The HTTP method. + * @param string $path The API endpoint path, relative to the base URI. + * @param array> $headers The request headers. + * @param string|array|null $data The request data. + * @return Request The request object. + */ + abstract protected function createRequest( + HttpMethodEnum $method, + string $path, + array $headers = [], + $data = null + ): Request; + + /** + * Throws an exception if the response is not successful. + * + * @since n.e.x.t + * + * @param Response $response The HTTP response to check. + * @throws ResponseException If the response is not successful. + */ + protected function throwIfNotSuccessful(Response $response): void + { + /* + * While this method only calls the utility method, it's important to have it here as a protected method so + * that child classes can override it if needed. + */ + ResponseUtil::throwIfNotSuccessful($response); + } + + /** + * Parses the response from the API endpoint to list models into a list of model metadata objects. + * + * @since n.e.x.t + * + * @param Response $response The response from the API endpoint to list models. + * @return list List of model metadata objects. + */ + abstract protected function parseResponseToModelMetadataList(Response $response): array; +} diff --git a/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php new file mode 100644 index 00000000..d0df3476 --- /dev/null +++ b/src/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModel.php @@ -0,0 +1,756 @@ + + * } + * } + * @phpstan-type MessageData array{ + * role?: string, + * reasoning_content?: string, + * content?: string, + * tool_calls?: list + * } + * @phpstan-type ChoiceData array{ + * message?: MessageData, + * finish_reason?: string + * } + * @phpstan-type UsageData array{ + * prompt_tokens?: int, + * completion_tokens?: int, + * total_tokens?: int + * } + * @phpstan-type ResponseData array{ + * id?: string, + * choices?: list, + * usage?: UsageData + * } + */ +abstract class AbstractOpenAiCompatibleTextGenerationModel extends AbstractApiBasedModel implements + TextGenerationModelInterface +{ + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function generateTextResult(array $prompt): GenerativeAiResult + { + $httpTransporter = $this->getHttpTransporter(); + + $params = $this->prepareGenerateTextParams($prompt); + + $request = $this->createRequest( + HttpMethodEnum::POST(), + 'chat/completions', + ['Content-Type' => 'application/json'], + $params + ); + + // Add authentication credentials to the request. + $request = $this->getRequestAuthentication()->authenticateRequest($request); + + // Send and process the request. + $response = $httpTransporter->send($request); + $this->throwIfNotSuccessful($response); + return $this->parseResponseToGenerativeAiResult($response); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + final public function streamGenerateTextResult(array $prompt): Generator + { + $params = $this->prepareGenerateTextParams($prompt); + + // TODO: Implement streaming support. + throw new RuntimeException( + 'Streaming is not yet implemented.' + ); + } + + /** + * Prepares the given prompt and the model configuration into parameters for the API request. + * + * @since n.e.x.t + * + * @param list $prompt The prompt to generate text for. Either a single message or a list of messages + * from a chat. + * @return array The parameters for the API request. + */ + protected function prepareGenerateTextParams(array $prompt): array + { + $config = $this->getConfig(); + + $params = [ + 'model' => $this->metadata()->getId(), + 'messages' => $this->prepareMessagesParam($prompt, $config->getSystemInstruction()), + ]; + + $outputModalities = $config->getOutputModalities(); + if (is_array($outputModalities)) { + $this->validateOutputModalities($outputModalities); + if (count($outputModalities) > 1) { + $params['modalities'] = $this->prepareOutputModalitiesParam($outputModalities); + } + } + + $candidateCount = $config->getCandidateCount(); + if ($candidateCount !== null) { + $params['n'] = $candidateCount; + } + + $maxTokens = $config->getMaxTokens(); + if ($maxTokens !== null) { + $params['max_tokens'] = $maxTokens; + } + + $temperature = $config->getTemperature(); + if ($temperature !== null) { + $params['temperature'] = $temperature; + } + + $topP = $config->getTopP(); + if ($topP !== null) { + $params['top_p'] = $topP; + } + + $stopSequences = $config->getStopSequences(); + if (is_array($stopSequences)) { + $params['stop'] = $stopSequences; + } + + $presencePenalty = $config->getPresencePenalty(); + if ($presencePenalty !== null) { + $params['presence_penalty'] = $presencePenalty; + } + + $frequencyPenalty = $config->getFrequencyPenalty(); + if ($frequencyPenalty !== null) { + $params['frequency_penalty'] = $frequencyPenalty; + } + + $logprobs = $config->getLogprobs(); + if ($logprobs !== null) { + $params['logprobs'] = $logprobs; + } + + $topLogprobs = $config->getTopLogprobs(); + if ($topLogprobs !== null) { + $params['top_logprobs'] = $topLogprobs; + } + + $functionDeclarations = $config->getFunctionDeclarations(); + if (is_array($functionDeclarations)) { + $params['tools'] = $this->prepareToolsParam($functionDeclarations); + } + + $outputMimeType = $config->getOutputMimeType(); + if ('application/json' === $outputMimeType) { + $outputSchema = $config->getOutputSchema(); + $params['response_format'] = $this->prepareResponseFormatParam($outputSchema); + } + + /* + * Any custom options are added to the parameters as well. + * This allows developers to pass other options that may be more niche or not yet supported by the SDK. + */ + $customOptions = $config->getCustomOptions(); + foreach ($customOptions as $key => $value) { + if (isset($params[$key])) { + throw new InvalidArgumentException( + sprintf( + 'The custom option "%s" conflicts with an existing parameter.', + $key + ) + ); + } + $params[$key] = $value; + } + + return $params; + } + + /** + * Prepares the messages parameter for the API request. + * + * @since n.e.x.t + * + * @param list $messages The messages to prepare. + * @param string|null $systemInstruction An optional system instruction to prepend to the messages. + * @return list> The prepared messages parameter. + */ + protected function prepareMessagesParam(array $messages, ?string $systemInstruction = null): array + { + $messagesParam = array_map( + function (Message $message): array { + // Special case: Function response. + $messageParts = $message->getParts(); + if (count($messageParts) === 1 && $messageParts[0]->getType()->isFunctionResponse()) { + $functionResponse = $messageParts[0]->getFunctionResponse(); + if (!$functionResponse) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The function response typed message part must contain a function response.' + ); + } + return [ + 'role' => 'tool', + 'content' => json_encode($functionResponse->getResponse()), + 'tool_call_id' => $functionResponse->getId(), + ]; + } + return [ + 'role' => $this->getMessageRoleString($message->getRole()), + 'content' => array_values(array_filter(array_map( + [$this, 'getMessagePartContentData'], + $messageParts + ))), + 'tool_calls' => array_values(array_filter(array_map( + [$this, 'getMessagePartToolCallData'], + $messageParts + ))), + ]; + }, + $messages + ); + + if ($systemInstruction) { + array_unshift( + $messagesParam, + [ + /* + * TODO: Replace this with 'developer' in the future. + * See https://platform.openai.com/docs/api-reference/chat/create#chat_create-messages + */ + 'role' => 'system', + 'content' => [ + [ + 'type' => 'text', + 'text' => $systemInstruction, + ], + ], + ] + ); + } + + return $messagesParam; + } + + /** + * Returns the OpenAI API specific role string for the given message role. + * + * @since n.e.x.t + * + * @param MessageRoleEnum $role The message role. + * @return string The role for the API request. + */ + protected function getMessageRoleString(MessageRoleEnum $role): string + { + if ($role === MessageRoleEnum::model()) { + return 'assistant'; + } + return 'user'; + } + + /** + * Returns the OpenAI API specific content data for a message part. + * + * @since n.e.x.t + * + * @param MessagePart $part The message part to get the data for. + * @return ?array The data for the message content part, or null if not applicable. + * @throws InvalidArgumentException If the message part type or data is unsupported. + */ + protected function getMessagePartContentData(MessagePart $part): ?array + { + $type = $part->getType(); + if ($type->isText()) { + return [ + 'type' => 'text', + 'text' => $part->getText(), + ]; + } + if ($type->isFile()) { + $file = $part->getFile(); + if (!$file) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The file typed message part must contain a file.' + ); + } + if ($file->isRemote()) { + if ($file->isImage()) { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $file->getUrl(), + ], + ]; + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported MIME type "%s" for remote file message part.', + $file->getMimeType() + ) + ); + } + // Else, it is an inline file. + if ($file->isImage()) { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $file->getBase64Data(), + ], + ]; + } + if ($file->isAudio()) { + return [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $file->getBase64Data(), + 'format' => $file->getMimeTypeObject()->toExtension(), + ], + ]; + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported MIME type "%s" for inline file message part.', + $file->getMimeType() + ) + ); + } + if ($type->isFunctionCall()) { + // Skip, as this is separately included. See `getMessagePartToolCallData()`. + return null; + } + if ($type->isFunctionResponse()) { + // Special case: Function response. + throw new InvalidArgumentException( + 'The API only allows a single function response, as the only content of the message.' + ); + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported message part type "%s".', + $type + ) + ); + } + + /** + * Returns the OpenAI API specific tool calls data for a message part. + * + * @since n.e.x.t + * + * @param MessagePart $part The message part to get the data for. + * @return ?array The data for the message tool call part, or null if not applicable. + * @throws InvalidArgumentException If the message part type or data is unsupported. + */ + protected function getMessagePartToolCallData(MessagePart $part): ?array + { + $type = $part->getType(); + if ($type->isFunctionCall()) { + $functionCall = $part->getFunctionCall(); + if (!$functionCall) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The function call typed message part must contain a function call.' + ); + } + return [ + 'type' => 'function', + 'id' => $functionCall->getId(), + 'function' => [ + 'name' => $functionCall->getName(), + 'arguments' => json_encode($functionCall->getArgs()), + ], + ]; + } + // All other types are handled in `getMessagePartContentData()`. + return null; + } + + /** + * Validates that the given output modalities to ensure that at least one output modality is text. + * + * @since n.e.x.t + * + * @param array $outputModalities The output modalities to validate. + * @throws InvalidArgumentException If no text output modality is present. + */ + protected function validateOutputModalities(array $outputModalities): void + { + // If no output modalities are set, it's fine, as we can assume text. + if (count($outputModalities) === 0) { + return; + } + + foreach ($outputModalities as $modality) { + if ($modality->isText()) { + return; + } + } + + throw new InvalidArgumentException( + 'A text output modality must be present when generating text.' + ); + } + + /** + * Prepares the output modalities parameter for the API request. + * + * @since n.e.x.t + * + * @param array $modalities The modalities to prepare. + * @return list The prepared modalities parameter. + */ + protected function prepareOutputModalitiesParam(array $modalities): array + { + $prepared = []; + foreach ($modalities as $modality) { + if ($modality->isText()) { + $prepared[] = 'text'; + } elseif ($modality->isImage()) { + $prepared[] = 'image'; + } elseif ($modality->isAudio()) { + $prepared[] = 'audio'; + } else { + throw new InvalidArgumentException( + sprintf( + 'Unsupported output modality "%s".', + $modality + ) + ); + } + } + return $prepared; + } + + /** + * Prepares the tools parameter for the API request. + * + * @since n.e.x.t + * + * @param list $functionDeclarations The function declarations. + * @return list> The prepared tools parameter. + */ + protected function prepareToolsParam(array $functionDeclarations): array + { + $tools = []; + foreach ($functionDeclarations as $functionDeclaration) { + $tools[] = [ + 'type' => 'function', + 'function' => $functionDeclaration->toArray(), + ]; + } + + return $tools; + } + + /** + * Prepares the response format parameter for the API request. + * + * This is only called if the output MIME type is `application/json`. + * + * @since n.e.x.t + * + * @param array|null $outputSchema The output schema. + * @return array The prepared response format parameter. + */ + protected function prepareResponseFormatParam(?array $outputSchema): array + { + if (is_array($outputSchema)) { + return [ + 'type' => 'json_schema', + 'json_schema' => $outputSchema, + ]; + } + + return [ + 'type' => 'json_object', + ]; + } + + /** + * Creates a request object for the provider's API. + * + * @since n.e.x.t + * + * @param HttpMethodEnum $method The HTTP method. + * @param string $path The API endpoint path, relative to the base URI. + * @param array> $headers The request headers. + * @param string|array|null $data The request data. + * @return Request The request object. + */ + abstract protected function createRequest( + HttpMethodEnum $method, + string $path, + array $headers = [], + $data = null + ): Request; + + /** + * Throws an exception if the response is not successful. + * + * @since n.e.x.t + * + * @param Response $response The HTTP response to check. + * @throws ResponseException If the response is not successful. + */ + protected function throwIfNotSuccessful(Response $response): void + { + /* + * While this method only calls the utility method, it's important to have it here as a protected method so + * that child classes can override it if needed. + */ + ResponseUtil::throwIfNotSuccessful($response); + } + + /** + * Parses the response from the API endpoint to a generative AI result. + * + * @since n.e.x.t + * + * @param Response $response The response from the API endpoint. + * @return GenerativeAiResult The parsed generative AI result. + */ + protected function parseResponseToGenerativeAiResult(Response $response): GenerativeAiResult + { + /** @var ResponseData $responseData */ + $responseData = $response->getData(); + if (!isset($responseData['choices']) || !$responseData['choices']) { + throw new RuntimeException( + 'Unexpected API response: Missing the choices key.' + ); + } + if (!is_array($responseData['choices'])) { + throw new RuntimeException( + 'Unexpected API response: The choices key must contain an array.' + ); + } + + $candidates = []; + foreach ($responseData['choices'] as $choiceData) { + if (!is_array($choiceData) || array_is_list($choiceData)) { + throw new RuntimeException( + 'Unexpected API response: Each element in the choices key must be an associative array.' + ); + } + + $candidates[] = $this->parseResponseChoiceToCandidate($choiceData); + } + + $id = isset($responseData['id']) && is_string($responseData['id']) ? $responseData['id'] : ''; + + if (isset($responseData['usage']) && is_array($responseData['usage'])) { + $usage = $responseData['usage']; + + $tokenUsage = new TokenUsage( + $usage['prompt_tokens'] ?? 0, + $usage['completion_tokens'] ?? 0, + $usage['total_tokens'] ?? 0 + ); + } else { + $tokenUsage = new TokenUsage(0, 0, 0); + } + + // Use any other data from the response as provider metadata. + $additionalData = $responseData; + unset($additionalData['id'], $additionalData['choices'], $additionalData['usage']); + + return new GenerativeAiResult( + $id, + $candidates, + $tokenUsage, + $this->providerMetadata(), + $this->metadata(), + $additionalData + ); + } + + /** + * Parses a single choice from the API response into a Candidate object. + * + * @since n.e.x.t + * + * @param ChoiceData $choiceData The choice data from the API response. + * @return Candidate The parsed candidate. + * @throws RuntimeException If the choice data is invalid. + */ + protected function parseResponseChoiceToCandidate(array $choiceData): Candidate + { + if ( + !isset($choiceData['message']) || + !is_array($choiceData['message']) || + array_is_list($choiceData['message']) + ) { + throw new RuntimeException( + 'Unexpected API response: Each choice must contain a message key with an associative array.' + ); + } + + if (!isset($choiceData['finish_reason']) || !is_string($choiceData['finish_reason'])) { + throw new RuntimeException( + 'Unexpected API response: Each choice must contain a finish_reason key with a string value.' + ); + } + + $messageData = $choiceData['message']; + $message = $this->parseResponseChoiceMessage($messageData); + + switch ($choiceData['finish_reason']) { + case 'stop': + $finishReason = FinishReasonEnum::stop(); + break; + case 'length': + $finishReason = FinishReasonEnum::length(); + break; + case 'content_filter': + $finishReason = FinishReasonEnum::contentFilter(); + break; + case 'tool_calls': + $finishReason = FinishReasonEnum::toolCalls(); + break; + default: + throw new RuntimeException( + sprintf( + 'Unexpected API response: Invalid finish reason "%s".', + $choiceData['finish_reason'] + ) + ); + } + + return new Candidate($message, $finishReason); + } + + /** + * Parses the message from a choice in the API response. + * + * @since n.e.x.t + * + * @param MessageData $messageData The message data from the API response. + * @return Message The parsed message. + */ + protected function parseResponseChoiceMessage(array $messageData): Message + { + $role = isset($messageData['role']) && 'user' === $messageData['role'] + ? MessageRoleEnum::user() + : MessageRoleEnum::model(); + + $parts = $this->parseResponseChoiceMessageParts($messageData); + + return new Message($role, $parts); + } + + /** + * Parses the message parts from a choice in the API response. + * + * @since n.e.x.t + * + * @param MessageData $messageData The message data from the API response. + * @return MessagePart[] The parsed message parts. + */ + protected function parseResponseChoiceMessageParts(array $messageData): array + { + $parts = []; + + if (isset($messageData['reasoning_content']) && is_string($messageData['reasoning_content'])) { + $parts[] = new MessagePart($messageData['reasoning_content'], MessagePartChannelEnum::thought()); + } + + if (isset($messageData['content']) && is_string($messageData['content'])) { + $parts[] = new MessagePart($messageData['content']); + } + + if (isset($messageData['tool_calls']) && is_array($messageData['tool_calls'])) { + foreach ($messageData['tool_calls'] as $toolCallData) { + $toolCallPart = $this->parseResponseChoiceMessageToolCallPart($toolCallData); + if (!$toolCallPart) { + throw new RuntimeException( + 'Unexpected API response: The response includes a tool call of an unexpected type.' + ); + } + $parts[] = $toolCallPart; + } + } + + return $parts; + } + + /** + * Parses a tool call part from the API response. + * + * @since n.e.x.t + * + * @param ToolCallData $toolCallData The tool call data from the API response. + * @return MessagePart|null The parsed message part for the tool call, or null if not applicable. + */ + protected function parseResponseChoiceMessageToolCallPart(array $toolCallData): ?MessagePart + { + /* + * For now, only function calls are supported. + * + * Not all OpenAI compatible APIs include a 'type' key, so we only check its value if it is set. + */ + if ( + (isset($toolCallData['type']) && 'function' !== $toolCallData['type']) || + !isset($toolCallData['function']) || + !is_array($toolCallData['function']) + ) { + return null; + } + + $functionArguments = is_string($toolCallData['function']['arguments']) + ? json_decode($toolCallData['function']['arguments'], true) + : $toolCallData['function']['arguments']; + + $functionCall = new FunctionCall( + isset($toolCallData['id']) && is_string($toolCallData['id']) ? + $toolCallData['id'] : + null, + isset($toolCallData['function']['name']) && is_string($toolCallData['function']['name']) ? + $toolCallData['function']['name'] : + null, + $functionArguments + ); + + return new MessagePart($functionCall); + } +} diff --git a/src/Providers/ProviderRegistry.php b/src/Providers/ProviderRegistry.php index 8fc39c76..ff352714 100644 --- a/src/Providers/ProviderRegistry.php +++ b/src/Providers/ProviderRegistry.php @@ -5,9 +5,17 @@ namespace WordPress\AiClient\Providers; use InvalidArgumentException; +use RuntimeException; use WordPress\AiClient\Providers\Contracts\ProviderInterface; +use WordPress\AiClient\Providers\Contracts\ProviderWithOperationsHandlerInterface; use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\DTO\ProviderModelsMetadata; +use WordPress\AiClient\Providers\Http\Contracts\HttpTransporterInterface; +use WordPress\AiClient\Providers\Http\Contracts\RequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithHttpTransporterInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithRequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\DTO\ApiKeyRequestAuthentication; +use WordPress\AiClient\Providers\Http\Traits\WithHttpTransporterTrait; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; @@ -21,8 +29,12 @@ * * @since n.e.x.t */ -class ProviderRegistry +class ProviderRegistry implements WithHttpTransporterInterface { + use WithHttpTransporterTrait { + setHttpTransporter as setHttpTransporterOriginal; + } + /** * @var array> Mapping of provider IDs to class names. */ @@ -33,6 +45,11 @@ class ProviderRegistry */ private array $registeredClassNames = []; + /** + * @var array, RequestAuthenticationInterface> Mapping of provider class names to + * authentication instances. + */ + private array $providerAuthenticationInstances = []; /** * Registers a provider class with the registry. @@ -66,6 +83,32 @@ public function registerProvider(string $className): void ); } + // If there is already a HTTP transporter instance set, hook it up to the provider as needed. + try { + $httpTransporter = $this->getHttpTransporter(); + $this->setHttpTransporterForProvider($className, $httpTransporter); + } catch (RuntimeException $e) { + /* + * If this fails, it's okay. There is no defined sequence between setting the HTTP transporter in the + * registry and registering providers in it, so it might be that the transporter is set later. It will be + * hooked up then. + * Therefore we can simply ignore this exception. + */ + } + + // Hook up the request authentication instance, using a default if not set. + if (!isset($this->providerAuthenticationInstances[$className])) { + $defaultProviderAuthentication = $this->createDefaultProviderRequestAuthentication( + $className + ); + if ($defaultProviderAuthentication !== null) { + $this->providerAuthenticationInstances[$className] = $defaultProviderAuthentication; + } + } + if (isset($this->providerAuthenticationInstances[$className])) { + $this->setRequestAuthenticationForProvider($className, $this->providerAuthenticationInstances[$className]); + } + $this->providerClassNames[$metadata->getId()] = $className; $this->registeredClassNames[$className] = true; } @@ -128,7 +171,7 @@ public function isProviderConfigured(string $idOrClassName): bool } /** - * Finds models across all providers that support the given requirements. + * Finds models across all available providers that support the given requirements. * * @since n.e.x.t * @@ -157,7 +200,7 @@ public function findModelsMetadataForSupport(ModelRequirements $modelRequirement } /** - * Finds models within a specific provider that support the given requirements. + * Finds models within a specific available provider that support the given requirements. * * @since n.e.x.t * @@ -171,6 +214,11 @@ public function findProviderModelsMetadataForSupport( ): array { $className = $this->resolveProviderClassName($idOrClassName); + // If the provider is not configured, there is no way to use it, so it is considered unavailable. + if (!$this->isProviderConfigured($className)) { + return []; + } + $modelMetadataDirectory = $className::modelMetadataDirectory(); // Filter models that meet requirements @@ -202,9 +250,38 @@ public function getProviderModel( ): ModelInterface { $className = $this->resolveProviderClassName($idOrClassName); - // Use static method from ProviderInterface - /** @var class-string $className */ - return $className::model($modelId, $modelConfig); + $modelInstance = $className::model($modelId, $modelConfig); + + $this->bindModelDependencies($modelInstance); + + return $modelInstance; + } + + /** + * Binds dependencies to a model instance. + * + * This method injects required dependencies such as HTTP transporter + * and authentication into model instances that need them. + * + * @since n.e.x.t + * + * @param ModelInterface $modelInstance The model instance to bind dependencies to. + * @return void + */ + public function bindModelDependencies(ModelInterface $modelInstance): void + { + $className = $this->resolveProviderClassName($modelInstance->providerMetadata()->getId()); + + if ($modelInstance instanceof WithHttpTransporterInterface) { + $modelInstance->setHttpTransporter($this->getHttpTransporter()); + } + + if ($modelInstance instanceof WithRequestAuthenticationInterface) { + $requestAuthentication = $this->getProviderRequestAuthentication($className); + if ($requestAuthentication !== null) { + $modelInstance->setRequestAuthentication($requestAuthentication); + } + } } /** @@ -228,4 +305,209 @@ private function resolveProviderClassName(string $idOrClassName): string // @phpstan-ignore-next-line return.type (Interface implementation guaranteed by registration validation) return $className; } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function setHttpTransporter(HttpTransporterInterface $httpTransporter): void + { + $this->setHttpTransporterOriginal($httpTransporter); + + // Make sure all registered providers have the HTTP transporter hooked up as needed. + foreach ($this->providerClassNames as $className) { + $this->setHttpTransporterForProvider($className, $httpTransporter); + } + } + + /** + * Sets the request authentication instance for the given provider. + * + * @since n.e.x.t + * + * @param string|class-string $idOrClassName The provider ID or class name. + * @param RequestAuthenticationInterface $requestAuthentication The request authentication instance. + */ + public function setProviderRequestAuthentication( + string $idOrClassName, + RequestAuthenticationInterface $requestAuthentication + ): void { + $className = $this->resolveProviderClassName($idOrClassName); + + $this->providerAuthenticationInstances[$className] = $requestAuthentication; + + $this->setRequestAuthenticationForProvider($className, $requestAuthentication); + } + + /** + * Gets the request authentication instance for the given provider, if set. + * + * @since n.e.x.t + * + * @param string|class-string $idOrClassName The provider ID or class name. + * @return ?RequestAuthenticationInterface The request authentication instance, or null if not set. + */ + public function getProviderRequestAuthentication(string $idOrClassName): ?RequestAuthenticationInterface + { + $className = $this->resolveProviderClassName($idOrClassName); + if (!isset($this->providerAuthenticationInstances[$className])) { + return null; + } + return $this->providerAuthenticationInstances[$className]; + } + + /** + * Sets the HTTP transporter for a specific provider, hooking up its class instances. + * + * @since n.e.x.t + * + * @param class-string $className The provider class name. + * @param HttpTransporterInterface $httpTransporter The HTTP transporter instance. + */ + private function setHttpTransporterForProvider( + string $className, + HttpTransporterInterface $httpTransporter + ): void { + $availability = $className::availability(); + if ($availability instanceof WithHttpTransporterInterface) { + $availability->setHttpTransporter($httpTransporter); + } + + $modelMetadataDirectory = $className::modelMetadataDirectory(); + if ($modelMetadataDirectory instanceof WithHttpTransporterInterface) { + $modelMetadataDirectory->setHttpTransporter($httpTransporter); + } + + if (is_subclass_of($className, ProviderWithOperationsHandlerInterface::class)) { + $operationsHandler = $className::operationsHandler(); + if ($operationsHandler instanceof WithHttpTransporterInterface) { + $operationsHandler->setHttpTransporter($httpTransporter); + } + } + } + + /** + * Sets the request authentication for a specific provider, hooking up its class instances. + * + * @since n.e.x.t + * + * @param class-string $className The provider class name. + * @param RequestAuthenticationInterface $requestAuthentication The authentication instance. + */ + private function setRequestAuthenticationForProvider( + string $className, + RequestAuthenticationInterface $requestAuthentication + ): void { + $availability = $className::availability(); + if ($availability instanceof WithRequestAuthenticationInterface) { + $availability->setRequestAuthentication($requestAuthentication); + } + + $modelMetadataDirectory = $className::modelMetadataDirectory(); + if ($modelMetadataDirectory instanceof WithRequestAuthenticationInterface) { + $modelMetadataDirectory->setRequestAuthentication($requestAuthentication); + } + + if (is_subclass_of($className, ProviderWithOperationsHandlerInterface::class)) { + $operationsHandler = $className::operationsHandler(); + if ($operationsHandler instanceof WithRequestAuthenticationInterface) { + $operationsHandler->setRequestAuthentication($requestAuthentication); + } + } + } + + /** + * Creates a default request authentication instance for a provider. + * + * @since n.e.x.t + * + * @param class-string $className The provider class name. + * @return ?RequestAuthenticationInterface The default request authentication instance, or null if not required or + * if no credential data can be found. + */ + private function createDefaultProviderRequestAuthentication( + string $className + ): ?RequestAuthenticationInterface { + $providerId = $className::metadata()->getId(); + + /* + * For now, we assume API key authentication is used by default. + * In the future, this could be made more flexible by allowing the provider to express a specific type of + * request authentication to use. + */ + $authenticationClass = ApiKeyRequestAuthentication::class; + $authenticationSchema = $authenticationClass::getJsonSchema(); + + // Iterate over all JSON schema object properties to try to determine the necessary authentication data. + $authenticationData = []; + if (isset($authenticationSchema['properties']) && is_array($authenticationSchema['properties'])) { + /** @var array $details */ + foreach ($authenticationSchema['properties'] as $property => $details) { + $envVarName = $this->getEnvVarName($providerId, $property); + + // Try to get the value from environment variable or constant. + $envValue = getenv($envVarName); + if ($envValue === false) { + if (!defined($envVarName)) { + continue; // Skip if neither environment variable nor constant is defined. + } + $envValue = constant($envVarName); + if (!is_scalar($envValue)) { + continue; + } + } + + if (isset($details['type'])) { + switch ($details['type']) { + case 'boolean': + $authenticationData[$property] = filter_var($envValue, FILTER_VALIDATE_BOOLEAN); + break; + case 'number': + $authenticationData[$property] = (int) $envValue; + break; + case 'string': + default: + $authenticationData[$property] = (string) $envValue; + } + } else { + // Default to string if no type is specified. + $authenticationData[$property] = (string) $envValue; + } + } + + // If any required fields are missing, return null to avoid immediate errors. + if (isset($authenticationSchema['required']) && is_array($authenticationSchema['required'])) { + /** @var list $requiredProperties */ + $requiredProperties = $authenticationSchema['required']; + if (array_diff_key(array_flip($requiredProperties), $authenticationData)) { + return null; + } + } + } + + return $authenticationClass::fromArray($authenticationData); + } + + /** + * Converts a provider ID and field name to a constant case environment variable name. + * + * @since n.e.x.t + * + * @param string $providerId The provider ID. + * @param string $field The field name. + * @return string The environment variable name in CONSTANT_CASE. + */ + private function getEnvVarName(string $providerId, string $field): string + { + // Convert camelCase or kebab-case or snake_case to CONSTANT_CASE. + $constantCaseProviderId = strtoupper( + (string) preg_replace('/([a-z])([A-Z])/', '$1_$2', str_replace('-', '_', $providerId)) + ); + $constantCaseField = strtoupper( + (string) preg_replace('/([a-z])([A-Z])/', '$1_$2', str_replace('-', '_', $field)) + ); + + return "{$constantCaseProviderId}_{$constantCaseField}"; + } } diff --git a/src/Results/DTO/Candidate.php b/src/Results/DTO/Candidate.php index 11172336..3c5e6a95 100644 --- a/src/Results/DTO/Candidate.php +++ b/src/Results/DTO/Candidate.php @@ -130,7 +130,7 @@ public static function fromArray(array $array): self return new self( Message::fromArray($messageData), - FinishReasonEnum::from($array[self::KEY_FINISH_REASON]), + FinishReasonEnum::from($array[self::KEY_FINISH_REASON]) ); } } diff --git a/src/Tools/DTO/FunctionResponse.php b/src/Tools/DTO/FunctionResponse.php index 1078ddc1..4d6d4143 100644 --- a/src/Tools/DTO/FunctionResponse.php +++ b/src/Tools/DTO/FunctionResponse.php @@ -4,6 +4,7 @@ namespace WordPress\AiClient\Tools\DTO; +use InvalidArgumentException; use WordPress\AiClient\Common\AbstractDataTransferObject; /** @@ -151,7 +152,7 @@ public static function fromArray(array $array): self // Validate that at least one of id or name is provided if (!array_key_exists(self::KEY_ID, $array) && !array_key_exists(self::KEY_NAME, $array)) { - throw new \InvalidArgumentException('At least one of id or name must be provided.'); + throw new InvalidArgumentException('At least one of id or name must be provided.'); } return new self( diff --git a/tests/mocks/MockAbstractProvider.php b/tests/mocks/MockAbstractProvider.php new file mode 100644 index 00000000..9a800f4c --- /dev/null +++ b/tests/mocks/MockAbstractProvider.php @@ -0,0 +1,82 @@ +lastRequest = $request; + return $this->responseToReturn ?? new Response(200, [], '{"status":"success"}'); + } + + /** + * Gets the last request that was sent. + * + * @return Request|null + */ + public function getLastRequest(): ?Request + { + return $this->lastRequest; + } + + /** + * Sets the response to return for subsequent requests. + * + * @param Response $response + */ + public function setResponseToReturn(Response $response): void + { + $this->responseToReturn = $response; + } +} diff --git a/tests/mocks/MockModel.php b/tests/mocks/MockModel.php index 57541368..7623b122 100644 --- a/tests/mocks/MockModel.php +++ b/tests/mocks/MockModel.php @@ -4,6 +4,11 @@ namespace WordPress\AiClient\Tests\mocks; +use WordPress\AiClient\Providers\DTO\ProviderMetadata; +use WordPress\AiClient\Providers\Http\Contracts\WithHttpTransporterInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithRequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\Traits\WithHttpTransporterTrait; +use WordPress\AiClient\Providers\Http\Traits\WithRequestAuthenticationTrait; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; @@ -13,8 +18,11 @@ * * @since n.e.x.t */ -class MockModel implements ModelInterface +class MockModel implements ModelInterface, WithHttpTransporterInterface, WithRequestAuthenticationInterface { + use WithHttpTransporterTrait; + use WithRequestAuthenticationTrait; + /** * @var ModelMetadata The model metadata. */ @@ -45,6 +53,15 @@ public function metadata(): ModelMetadata return $this->metadata; } + /** + * {@inheritDoc} + */ + public function providerMetadata(): ProviderMetadata + { + // Return the MockProvider's metadata + return MockProvider::metadata(); + } + /** * {@inheritDoc} */ diff --git a/tests/mocks/MockModelMetadataDirectory.php b/tests/mocks/MockModelMetadataDirectory.php index e7e19e3a..f2b21603 100644 --- a/tests/mocks/MockModelMetadataDirectory.php +++ b/tests/mocks/MockModelMetadataDirectory.php @@ -6,6 +6,10 @@ use InvalidArgumentException; use WordPress\AiClient\Providers\Contracts\ModelMetadataDirectoryInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithHttpTransporterInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithRequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\Traits\WithHttpTransporterTrait; +use WordPress\AiClient\Providers\Http\Traits\WithRequestAuthenticationTrait; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; /** @@ -13,8 +17,14 @@ * * @since n.e.x.t */ -class MockModelMetadataDirectory implements ModelMetadataDirectoryInterface +class MockModelMetadataDirectory implements + ModelMetadataDirectoryInterface, + WithHttpTransporterInterface, + WithRequestAuthenticationInterface { + use WithHttpTransporterTrait; + use WithRequestAuthenticationTrait; + /** * @var array Available models. */ diff --git a/tests/mocks/MockProvider.php b/tests/mocks/MockProvider.php index 7b0a07cc..7ad38a3c 100644 --- a/tests/mocks/MockProvider.php +++ b/tests/mocks/MockProvider.php @@ -11,6 +11,7 @@ use WordPress\AiClient\Providers\Enums\ProviderTypeEnum; use WordPress\AiClient\Providers\Models\Contracts\ModelInterface; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; +use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; /** @@ -74,7 +75,7 @@ public static function modelMetadataDirectory(): ModelMetadataDirectoryInterface if (static::$modelMetadataDirectory === null) { // Create some mock models for testing $mockModels = [ - 'mock-text-model' => new \WordPress\AiClient\Providers\Models\DTO\ModelMetadata( + 'mock-text-model' => new ModelMetadata( 'mock-text-model', 'Mock Text Model', [CapabilityEnum::textGeneration()], diff --git a/tests/mocks/MockProviderAvailability.php b/tests/mocks/MockProviderAvailability.php index 805aa3df..0e4a8f0b 100644 --- a/tests/mocks/MockProviderAvailability.php +++ b/tests/mocks/MockProviderAvailability.php @@ -5,14 +5,24 @@ namespace WordPress\AiClient\Tests\mocks; use WordPress\AiClient\Providers\Contracts\ProviderAvailabilityInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithHttpTransporterInterface; +use WordPress\AiClient\Providers\Http\Contracts\WithRequestAuthenticationInterface; +use WordPress\AiClient\Providers\Http\Traits\WithHttpTransporterTrait; +use WordPress\AiClient\Providers\Http\Traits\WithRequestAuthenticationTrait; /** * Mock provider availability for testing. * * @since n.e.x.t */ -class MockProviderAvailability implements ProviderAvailabilityInterface +class MockProviderAvailability implements + ProviderAvailabilityInterface, + WithHttpTransporterInterface, + WithRequestAuthenticationInterface { + use WithHttpTransporterTrait; + use WithRequestAuthenticationTrait; + /** * @var bool Whether the provider is configured. */ diff --git a/tests/mocks/MockRequestAuthentication.php b/tests/mocks/MockRequestAuthentication.php new file mode 100644 index 00000000..8c1eafb1 --- /dev/null +++ b/tests/mocks/MockRequestAuthentication.php @@ -0,0 +1,54 @@ +token = $token; + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public function authenticateRequest(Request $request): Request + { + return $request->withHeader('X-Mock-Auth', $this->token); + } + + /** + * {@inheritDoc} + * + * @since n.e.x.t + */ + public static function getJsonSchema(): array + { + return [ + 'type' => 'object', + 'properties' => [], + 'required' => [], + ]; + } +} diff --git a/tests/traits/ArrayTransformationTestTrait.php b/tests/traits/ArrayTransformationTestTrait.php index d670c0f2..159b573f 100644 --- a/tests/traits/ArrayTransformationTestTrait.php +++ b/tests/traits/ArrayTransformationTestTrait.php @@ -4,6 +4,8 @@ namespace WordPress\AiClient\Tests\traits; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; + /** * Trait for testing array transformation functionality. * @@ -20,7 +22,7 @@ trait ArrayTransformationTestTrait protected function assertImplementsArrayTransformation($object): void { $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $object, 'Object should implement WithArrayTransformationInterface interface' ); diff --git a/tests/traits/MockModelCreationTrait.php b/tests/traits/MockModelCreationTrait.php index 0f4517fe..3423f950 100644 --- a/tests/traits/MockModelCreationTrait.php +++ b/tests/traits/MockModelCreationTrait.php @@ -15,10 +15,12 @@ use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; +use WordPress\AiClient\Providers\ProviderRegistry; use WordPress\AiClient\Results\DTO\Candidate; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Tests\mocks\MockProvider; /** * Trait providing shared mock model creation methods for testing. @@ -30,6 +32,19 @@ */ trait MockModelCreationTrait { + /** + * Creates a provider registry with the mock provider registered. + * + * @since n.e.x.t + * + * @return ProviderRegistry The registry with mock provider. + */ + protected function createRegistryWithMockProvider(): ProviderRegistry + { + $registry = new ProviderRegistry(); + $registry->registerProvider(MockProvider::class); + return $registry; + } /** * Creates a test GenerativeAiResult for testing purposes. * @@ -45,7 +60,7 @@ protected function createTestResult(string $content = 'Test response'): Generati $tokenUsage = new TokenUsage(10, 20, 30); $providerMetadata = new ProviderMetadata( - 'mock-provider', + 'mock', 'Mock Provider', ProviderTypeEnum::cloud() ); @@ -116,14 +131,29 @@ protected function createMockTextGenerationModel( ): ModelInterface { $metadata = $metadata ?? $this->createTestTextModelMetadata(); - return new class ($metadata, $result) implements ModelInterface, TextGenerationModelInterface { + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, TextGenerationModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private GenerativeAiResult $result; private ModelConfig $config; - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + GenerativeAiResult $result + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->result = $result; $this->config = new ModelConfig(); } @@ -133,6 +163,11 @@ public function metadata(): ModelMetadata return $this->metadata; } + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + public function setConfig(ModelConfig $config): void { $this->config = $config; @@ -168,14 +203,29 @@ protected function createMockImageGenerationModel( ): ModelInterface { $metadata = $metadata ?? $this->createTestImageModelMetadata(); - return new class ($metadata, $result) implements ModelInterface, ImageGenerationModelInterface { + $providerMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, ImageGenerationModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private GenerativeAiResult $result; private ModelConfig $config; - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + GenerativeAiResult $result + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->result = $result; $this->config = new ModelConfig(); } @@ -185,6 +235,11 @@ public function metadata(): ModelMetadata return $this->metadata; } + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + public function setConfig(ModelConfig $config): void { $this->config = $config; @@ -212,6 +267,11 @@ protected function createMockUnsupportedModel(string $modelId = 'unsupported-mod { $mockModel = $this->createMock(ModelInterface::class); $mockMetadata = $this->createMock(ModelMetadata::class); + $mockProviderMetadata = new ProviderMetadata( + 'mock', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); $mockMetadata->expects($this->any()) ->method('getId') @@ -221,6 +281,14 @@ protected function createMockUnsupportedModel(string $modelId = 'unsupported-mod ->method('metadata') ->willReturn($mockMetadata); + $mockModel->expects($this->any()) + ->method('providerMetadata') + ->willReturn($mockProviderMetadata); + + $mockModel->expects($this->any()) + ->method('getConfig') + ->willReturn(new ModelConfig()); + return $mockModel; } } diff --git a/tests/unit/AiClientTest.php b/tests/unit/AiClientTest.php index 6cb76d7d..32f5b21b 100644 --- a/tests/unit/AiClientTest.php +++ b/tests/unit/AiClientTest.php @@ -94,8 +94,9 @@ public function testGenerateTextResultWithStringAndModel(): void $prompt = 'Generate text'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateTextResult($prompt, $mockModel); + $result = AiClient::generateTextResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -107,11 +108,12 @@ public function testGenerateTextResultWithInvalidModel(): void { $prompt = 'Generate text'; $invalidModel = $this->createMockUnsupportedModel('invalid-text-model'); + $registry = $this->createRegistryWithMockProvider(); $this->expectException(RuntimeException::class); $this->expectExceptionMessage('Model "invalid-text-model" does not support text generation.'); - AiClient::generateTextResult($prompt, $invalidModel); + AiClient::generateTextResult($prompt, $invalidModel, $registry); } /** @@ -122,8 +124,9 @@ public function testGenerateImageResultWithStringAndModel(): void $prompt = 'Generate image'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockImageGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateImageResult($prompt, $mockModel); + $result = AiClient::generateImageResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -135,11 +138,12 @@ public function testGenerateImageResultWithInvalidModel(): void { $prompt = 'Generate image'; $invalidModel = $this->createMockUnsupportedModel('invalid-image-model'); + $registry = $this->createRegistryWithMockProvider(); $this->expectException(RuntimeException::class); $this->expectExceptionMessage('Model "invalid-image-model" does not support image generation.'); - AiClient::generateImageResult($prompt, $invalidModel); + AiClient::generateImageResult($prompt, $invalidModel, $registry); } @@ -152,8 +156,9 @@ public function testGenerateTextResultWithMessage(): void $message = new UserMessage([$messagePart]); $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateTextResult($message, $mockModel); + $result = AiClient::generateTextResult($message, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -166,8 +171,9 @@ public function testGenerateTextResultWithMessagePart(): void $messagePart = new MessagePart('Test message part'); $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateTextResult($messagePart, $mockModel); + $result = AiClient::generateTextResult($messagePart, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -185,8 +191,9 @@ public function testGenerateTextResultWithMessageArray(): void $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateTextResult($messages, $mockModel); + $result = AiClient::generateTextResult($messages, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -202,8 +209,9 @@ public function testGenerateTextResultWithMessagePartArray(): void $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateTextResult($messageParts, $mockModel); + $result = AiClient::generateTextResult($messageParts, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -246,8 +254,9 @@ public function testGenerateResultDelegatesToTextGeneration(): void $prompt = 'Test prompt'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateResult($prompt, $mockModel); + $result = AiClient::generateResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -260,8 +269,9 @@ public function testGenerateResultDelegatesToImageGeneration(): void $prompt = 'Generate image prompt'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockImageGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateResult($prompt, $mockModel); + $result = AiClient::generateResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -288,8 +298,9 @@ public function testGenerateResultWithTextGenerationModel(): void $prompt = 'Generate text content'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockTextGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateResult($prompt, $mockModel); + $result = AiClient::generateResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } @@ -302,8 +313,9 @@ public function testGenerateResultWithImageGenerationModel(): void $prompt = 'Generate an image'; $expectedResult = $this->createTestResult(); $mockModel = $this->createMockImageGenerationModel($expectedResult); + $registry = $this->createRegistryWithMockProvider(); - $result = AiClient::generateResult($prompt, $mockModel); + $result = AiClient::generateResult($prompt, $mockModel, $registry); $this->assertSame($expectedResult, $result); } diff --git a/tests/unit/Builders/PromptBuilderTest.php b/tests/unit/Builders/PromptBuilderTest.php index 437674e3..a17db192 100644 --- a/tests/unit/Builders/PromptBuilderTest.php +++ b/tests/unit/Builders/PromptBuilderTest.php @@ -22,8 +22,6 @@ use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; -use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; -use WordPress\AiClient\Providers\Models\ImageGeneration\Contracts\ImageGenerationModelInterface; use WordPress\AiClient\Providers\Models\SpeechGeneration\Contracts\SpeechGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextGeneration\Contracts\TextGenerationModelInterface; use WordPress\AiClient\Providers\Models\TextToSpeechConversion\Contracts\TextToSpeechConversionModelInterface; @@ -32,6 +30,7 @@ use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; use WordPress\AiClient\Results\Enums\FinishReasonEnum; +use WordPress\AiClient\Tests\traits\MockModelCreationTrait; use WordPress\AiClient\Tools\DTO\FunctionResponse; /** @@ -39,6 +38,8 @@ */ class PromptBuilderTest extends TestCase { + use MockModelCreationTrait; + /** * @var ProviderRegistry */ @@ -55,85 +56,37 @@ private function createTestProviderMetadata(): ProviderMetadata } /** - * Creates a test model metadata instance. - * - * @return ModelMetadata - */ - private function createTestModelMetadata(): ModelMetadata - { - return new ModelMetadata( - 'test-model', - 'Test Model', - [CapabilityEnum::textGeneration()], - [] - ); - } - - /** - * Creates a mock model that implements both ModelInterface and TextGenerationModelInterface. + * Creates a mock model that implements both ModelInterface and SpeechGenerationModelInterface. * * @param ModelMetadata $metadata The metadata for the model. * @param GenerativeAiResult $result The result to return from generation. - * @return ModelInterface&TextGenerationModelInterface The mock model. + * @return ModelInterface&SpeechGenerationModelInterface The mock model. */ - private function createTextGenerationModel(ModelMetadata $metadata, GenerativeAiResult $result): ModelInterface + private function createSpeechGenerationModel(ModelMetadata $metadata, GenerativeAiResult $result): ModelInterface { - return new class ($metadata, $result) implements ModelInterface, TextGenerationModelInterface { - private ModelMetadata $metadata; - private GenerativeAiResult $result; - private ModelConfig $config; - - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { - $this->metadata = $metadata; - $this->result = $result; - $this->config = new ModelConfig(); - } - - public function metadata(): ModelMetadata - { - return $this->metadata; - } - - public function setConfig(ModelConfig $config): void - { - $this->config = $config; - } - - public function getConfig(): ModelConfig - { - return $this->config; - } - - public function generateTextResult(array $prompt): GenerativeAiResult - { - return $this->result; - } - - public function streamGenerateTextResult(array $prompt): Generator - { - yield $this->result; - } - }; - } + $providerMetadata = new ProviderMetadata( + 'mock-provider', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); - /** - * Creates a mock model that implements both ModelInterface and ImageGenerationModelInterface. - * - * @param ModelMetadata $metadata The metadata for the model. - * @param GenerativeAiResult $result The result to return from generation. - * @return ModelInterface&ImageGenerationModelInterface The mock model. - */ - private function createImageGenerationModel(ModelMetadata $metadata, GenerativeAiResult $result): ModelInterface - { - return new class ($metadata, $result) implements ModelInterface, ImageGenerationModelInterface { + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, SpeechGenerationModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private GenerativeAiResult $result; private ModelConfig $config; - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + GenerativeAiResult $result + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->result = $result; $this->config = new ModelConfig(); } @@ -143,47 +96,9 @@ public function metadata(): ModelMetadata return $this->metadata; } - public function setConfig(ModelConfig $config): void - { - $this->config = $config; - } - - public function getConfig(): ModelConfig - { - return $this->config; - } - - public function generateImageResult(array $prompt): GenerativeAiResult + public function providerMetadata(): ProviderMetadata { - return $this->result; - } - }; - } - - /** - * Creates a mock model that implements both ModelInterface and SpeechGenerationModelInterface. - * - * @param ModelMetadata $metadata The metadata for the model. - * @param GenerativeAiResult $result The result to return from generation. - * @return ModelInterface&SpeechGenerationModelInterface The mock model. - */ - private function createSpeechGenerationModel(ModelMetadata $metadata, GenerativeAiResult $result): ModelInterface - { - return new class ($metadata, $result) implements ModelInterface, SpeechGenerationModelInterface { - private ModelMetadata $metadata; - private GenerativeAiResult $result; - private ModelConfig $config; - - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { - $this->metadata = $metadata; - $this->result = $result; - $this->config = new ModelConfig(); - } - - public function metadata(): ModelMetadata - { - return $this->metadata; + return $this->providerMetadata; } public function setConfig(ModelConfig $config): void @@ -212,14 +127,29 @@ public function generateSpeechResult(array $prompt): GenerativeAiResult */ private function createTextToSpeechModel(ModelMetadata $metadata, GenerativeAiResult $result): ModelInterface { - return new class ($metadata, $result) implements ModelInterface, TextToSpeechConversionModelInterface { + $providerMetadata = new ProviderMetadata( + 'mock-provider', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + return new class ( + $metadata, + $providerMetadata, + $result + ) implements ModelInterface, TextToSpeechConversionModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private GenerativeAiResult $result; private ModelConfig $config; - public function __construct(ModelMetadata $metadata, GenerativeAiResult $result) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata, + GenerativeAiResult $result + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->result = $result; $this->config = new ModelConfig(); } @@ -229,6 +159,11 @@ public function metadata(): ModelMetadata return $this->metadata; } + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + public function setConfig(ModelConfig $config): void { $this->config = $config; @@ -627,7 +562,11 @@ public function testWithHistory(): void */ public function testUsingModel(): void { + // Create a model with empty config + $modelConfig = new ModelConfig(); $model = $this->createMock(ModelInterface::class); + $model->method('getConfig')->willReturn($modelConfig); + $builder = new PromptBuilder($this->registry); $result = $builder->usingModel($model); @@ -676,6 +615,7 @@ public function testUsingModelConfig(): void /** @var ModelConfig $mergedConfig */ $mergedConfig = $configProperty->getValue($builder); + // Check that builder's additional config was included // Assert builder values take precedence $this->assertEquals('Builder instruction', $mergedConfig->getSystemInstruction()); $this->assertEquals(500, $mergedConfig->getMaxTokens()); @@ -1194,7 +1134,7 @@ public function testGenerateResultWithTextModality(): void $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Test prompt'); $builder->usingModel($model); @@ -1218,14 +1158,14 @@ public function testGenerateResultWithImageModality(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate an image'); $builder->usingModel($model); @@ -1250,7 +1190,7 @@ public function testGenerateResultWithAudioModality(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -1279,14 +1219,14 @@ public function testGenerateResultWithMultimodalOutput(): void [new Candidate(new ModelMessage([new MessagePart('Generated text')]), FinishReasonEnum::stop())], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate multimodal'); $builder->usingModel($model); @@ -1356,14 +1296,14 @@ public function testGenerateTextResult(): void [new Candidate(new ModelMessage([new MessagePart('Generated text')]), FinishReasonEnum::stop())], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Test prompt'); $builder->usingModel($model); @@ -1398,14 +1338,14 @@ public function testGenerateImageResult(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate image'); $builder->usingModel($model); @@ -1440,7 +1380,7 @@ public function testGenerateSpeechResult(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -1482,7 +1422,7 @@ public function testConvertTextToSpeechResult(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -1549,14 +1489,14 @@ public function testGenerateText(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate text'); $builder->usingModel($model); @@ -1578,13 +1518,26 @@ public function testGenerateTextThrowsExceptionWhenNoCandidates(): void $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = new class ($metadata) implements ModelInterface, TextGenerationModelInterface { + $providerMetadata = new ProviderMetadata( + 'mock-provider', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + $model = new class ( + $metadata, + $providerMetadata + ) implements ModelInterface, TextGenerationModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private ModelConfig $config; - public function __construct(ModelMetadata $metadata) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->config = new ModelConfig(); } @@ -1593,6 +1546,11 @@ public function metadata(): ModelMetadata return $this->metadata; } + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + public function setConfig(ModelConfig $config): void { $this->config = $config; @@ -1638,14 +1596,14 @@ public function testGenerateTextThrowsExceptionWhenNoParts(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate text'); $builder->usingModel($model); @@ -1673,14 +1631,14 @@ public function testGenerateTextThrowsExceptionWhenPartHasNoText(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate text'); $builder->usingModel($model); @@ -1718,14 +1676,14 @@ public function testGenerateTexts(): void $candidates, new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate texts'); $builder->usingModel($model); @@ -1758,13 +1716,26 @@ public function testGenerateTextsThrowsExceptionWhenNoTextGenerated(): void $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = new class ($metadata) implements ModelInterface, TextGenerationModelInterface { + $providerMetadata = new ProviderMetadata( + 'mock-provider', + 'Mock Provider', + ProviderTypeEnum::cloud() + ); + + $model = new class ( + $metadata, + $providerMetadata + ) implements ModelInterface, TextGenerationModelInterface { private ModelMetadata $metadata; + private ProviderMetadata $providerMetadata; private ModelConfig $config; - public function __construct(ModelMetadata $metadata) - { + public function __construct( + ModelMetadata $metadata, + ProviderMetadata $providerMetadata + ) { $this->metadata = $metadata; + $this->providerMetadata = $providerMetadata; $this->config = new ModelConfig(); } @@ -1773,6 +1744,11 @@ public function metadata(): ModelMetadata return $this->metadata; } + public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + public function setConfig(ModelConfig $config): void { $this->config = $config; @@ -1820,14 +1796,14 @@ public function testGenerateImage(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate image'); $builder->usingModel($model); @@ -1852,14 +1828,14 @@ public function testGenerateImageThrowsExceptionWhenNoFile(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate image'); $builder->usingModel($model); @@ -1895,14 +1871,14 @@ public function testGenerateImages(): void $candidates, new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate images'); $builder->usingModel($model); @@ -1931,7 +1907,7 @@ public function testConvertTextToSpeech(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -1972,7 +1948,7 @@ public function testConvertTextToSpeeches(): void $candidates, new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -2008,7 +1984,7 @@ public function testGenerateSpeech(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -2051,7 +2027,7 @@ public function testGenerateSpeeches(): void $candidates, new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); @@ -2259,14 +2235,14 @@ public function testIncludeOutputModalityPreservesExisting(): void [new Candidate(new ModelMessage([new MessagePart('Generated text')]), FinishReasonEnum::stop())], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Test'); $builder->usingModel($model); @@ -2403,14 +2379,14 @@ public function testGenerateImageResultCreatesProperOperation(): void )], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate an image'); $builder->usingModel($model); @@ -2468,14 +2444,14 @@ public function testGenerateImageReturnsFileDirectly(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createImageGenerationModel($metadata, $result); + $model = $this->createMockImageGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate an image'); $builder->usingModel($model); @@ -2539,7 +2515,7 @@ public function testGenerateTextWithNoCandidatesThrowsException(): void $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate text'); $builder->usingModel($model); @@ -2568,14 +2544,14 @@ public function testGenerateTextWithNonStringPartThrowsException(): void [$candidate], new TokenUsage(100, 50, 150), $this->createTestProviderMetadata(), - $this->createTestModelMetadata() + $this->createTestTextModelMetadata() ); $metadata = $this->createMock(ModelMetadata::class); $metadata->method('getId')->willReturn('test-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Generate text'); $builder->usingModel($model); @@ -2613,9 +2589,9 @@ public function testIsSupportedForText(): void new ModelMessage([new MessagePart('Test')]), FinishReasonEnum::stop() ) - ], new TokenUsage(10, 5, 15), $this->createTestProviderMetadata(), $this->createTestModelMetadata()); + ], new TokenUsage(10, 5, 15), $this->createTestProviderMetadata(), $this->createTestTextModelMetadata()); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); $builder = new PromptBuilder($this->registry, 'Test prompt'); $builder->usingModel($model); @@ -2687,7 +2663,7 @@ public function testIsSupportedForSpeechGeneration(): void new ModelMessage([new MessagePart(new File('https://example.com/speech.mp3', 'audio/mp3'))]), FinishReasonEnum::stop() ) - ], new TokenUsage(10, 5, 15), $this->createTestProviderMetadata(), $this->createTestModelMetadata()); + ], new TokenUsage(10, 5, 15), $this->createTestProviderMetadata(), $this->createTestTextModelMetadata()); $model = $this->createSpeechGenerationModel($metadata, $result); @@ -2710,7 +2686,7 @@ public function testGenerateResultWithProvider(): void $modelMetadata->method('getId')->willReturn('provider-model'); $modelMetadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($modelMetadata, $result); + $model = $this->createMockTextGenerationModel($result, $modelMetadata); // Mock the registry to return the model when provider is specified $this->registry->expects($this->once()) @@ -2765,7 +2741,7 @@ public function testModelTakesPrecedenceOverProvider(): void $metadata->method('getId')->willReturn('explicit-model'); $metadata->method('meetsRequirements')->willReturn(true); - $model = $this->createTextGenerationModel($metadata, $result); + $model = $this->createMockTextGenerationModel($result, $metadata); // Registry should not be called when model is explicitly set $this->registry->expects($this->never()) diff --git a/tests/unit/Common/AbstractDataTransferObjectTest.php b/tests/unit/Common/AbstractDataTransferObjectTest.php index fa7c954d..f26292d4 100644 --- a/tests/unit/Common/AbstractDataTransferObjectTest.php +++ b/tests/unit/Common/AbstractDataTransferObjectTest.php @@ -8,6 +8,8 @@ use PHPUnit\Framework\TestCase; use stdClass; use WordPress\AiClient\Common\AbstractDataTransferObject; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; /** * Tests for the AbstractDataTransferObject class. @@ -546,10 +548,10 @@ public static function getJsonSchema(): array // Verify interface implementations $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $testObject ); - $this->assertInstanceOf(\WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, $testObject); + $this->assertInstanceOf(WithJsonSchemaInterface::class, $testObject); $this->assertInstanceOf(JsonSerializable::class, $testObject); // Verify methods exist and work diff --git a/tests/unit/Files/DTO/FileTest.php b/tests/unit/Files/DTO/FileTest.php index 84d6adcf..a249832b 100644 --- a/tests/unit/Files/DTO/FileTest.php +++ b/tests/unit/Files/DTO/FileTest.php @@ -6,8 +6,10 @@ use InvalidArgumentException; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Files\Enums\FileTypeEnum; +use WordPress\AiClient\Files\ValueObjects\MimeType; /** * @covers \WordPress\AiClient\Files\DTO\File @@ -206,7 +208,7 @@ public function testMimeTypeMethods(): void $file = new File('https://example.com/video.mp4'); $this->assertEquals('video/mp4', $file->getMimeType()); - $this->assertInstanceOf(\WordPress\AiClient\Files\ValueObjects\MimeType::class, $file->getMimeTypeObject()); + $this->assertInstanceOf(MimeType::class, $file->getMimeTypeObject()); $this->assertTrue($file->isVideo()); $this->assertFalse($file->isImage()); $this->assertFalse($file->isAudio()); @@ -294,7 +296,7 @@ public function testToArrayRemoteFile(): void $json = $file->toArray(); $this->assertIsArray($json); - $this->assertEquals(\WordPress\AiClient\Files\Enums\FileTypeEnum::remote()->value, $json[File::KEY_FILE_TYPE]); + $this->assertEquals(FileTypeEnum::remote()->value, $json[File::KEY_FILE_TYPE]); $this->assertEquals('image/jpeg', $json[File::KEY_MIME_TYPE]); $this->assertEquals('https://example.com/image.jpg', $json[File::KEY_URL]); $this->assertArrayNotHasKey(File::KEY_BASE64_DATA, $json); @@ -313,7 +315,7 @@ public function testToArrayInlineFile(): void $json = $file->toArray(); $this->assertIsArray($json); - $this->assertEquals(\WordPress\AiClient\Files\Enums\FileTypeEnum::inline()->value, $json[File::KEY_FILE_TYPE]); + $this->assertEquals(FileTypeEnum::inline()->value, $json[File::KEY_FILE_TYPE]); $this->assertEquals('text/plain', $json[File::KEY_MIME_TYPE]); $this->assertEquals($base64Data, $json[File::KEY_BASE64_DATA]); $this->assertArrayNotHasKey(File::KEY_URL, $json); @@ -327,7 +329,7 @@ public function testToArrayInlineFile(): void public function testFromArrayRemoteFile(): void { $json = [ - File::KEY_FILE_TYPE => \WordPress\AiClient\Files\Enums\FileTypeEnum::remote()->value, + File::KEY_FILE_TYPE => FileTypeEnum::remote()->value, File::KEY_MIME_TYPE => 'image/png', File::KEY_URL => 'https://example.com/test.png' ]; @@ -350,7 +352,7 @@ public function testFromArrayInlineFile(): void { $base64Data = 'SGVsbG8gV29ybGQ='; $json = [ - File::KEY_FILE_TYPE => \WordPress\AiClient\Files\Enums\FileTypeEnum::inline()->value, + File::KEY_FILE_TYPE => FileTypeEnum::inline()->value, File::KEY_MIME_TYPE => 'text/plain', File::KEY_BASE64_DATA => $base64Data ]; @@ -401,7 +403,7 @@ public function testImplementsWithArrayTransformationInterface(): void $file = new File('https://example.com/test.jpg'); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $file ); } diff --git a/tests/unit/Files/ValueObjects/MimeTypeTest.php b/tests/unit/Files/ValueObjects/MimeTypeTest.php index 306d5286..4481aef1 100644 --- a/tests/unit/Files/ValueObjects/MimeTypeTest.php +++ b/tests/unit/Files/ValueObjects/MimeTypeTest.php @@ -153,6 +153,82 @@ public function testUnknownExtensionThrowsException(): void MimeType::fromExtension('xyz'); } + /** + * Tests toExtension method. + * + * @dataProvider mimeTypeToExtensionProvider + * @param string $mimeType + * @param string $expectedExtension + * @return void + */ + public function testToExtension(string $mimeType, string $expectedExtension): void + { + $mimeType = new MimeType($mimeType); + $this->assertEquals($expectedExtension, $mimeType->toExtension()); + } + + /** + * Provides MIME types and expected extensions. + * + * @return array + */ + public function mimeTypeToExtensionProvider(): array + { + return [ + // Text + ['text/plain', 'txt'], + ['text/html', 'html'], + ['text/css', 'css'], + ['application/javascript', 'js'], + ['application/json', 'json'], + ['application/xml', 'xml'], + ['text/csv', 'csv'], + + // Images + ['image/jpeg', 'jpg'], + ['image/png', 'png'], + ['image/gif', 'gif'], + ['image/webp', 'webp'], + ['image/svg+xml', 'svg'], + ['image/x-icon', 'ico'], + + // Documents + ['application/pdf', 'pdf'], + ['application/msword', 'doc'], + ['application/vnd.openxmlformats-officedocument.wordprocessingml.document', 'docx'], + ['application/vnd.ms-excel', 'xls'], + ['application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', 'xlsx'], + + // Audio + ['audio/mpeg', 'mp3'], + ['audio/wav', 'wav'], + ['audio/ogg', 'ogg'], + + // Video + ['video/mp4', 'mp4'], + ['video/x-msvideo', 'avi'], + ['video/webm', 'webm'], + + // Archives + ['application/zip', 'zip'], + ['application/x-tar', 'tar'], + ['application/gzip', 'gz'], + ]; + } + + /** + * Tests toExtension throws exception for unknown MIME type. + * + * @return void + */ + public function testToExtensionThrowsExceptionForUnknownMimeType(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('No known extension for MIME type: application/octet-stream'); + + (new MimeType('application/octet-stream'))->toExtension(); + } + /** * Tests isValid method. * diff --git a/tests/unit/Messages/DTO/MessagePartTest.php b/tests/unit/Messages/DTO/MessagePartTest.php index 971ffcd2..063d2a85 100644 --- a/tests/unit/Messages/DTO/MessagePartTest.php +++ b/tests/unit/Messages/DTO/MessagePartTest.php @@ -7,6 +7,7 @@ use InvalidArgumentException; use PHPUnit\Framework\TestCase; use stdClass; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Files\Enums\FileTypeEnum; use WordPress\AiClient\Messages\DTO\MessagePart; @@ -388,7 +389,7 @@ public function testImplementsWithArrayTransformationInterface(): void $part = new MessagePart('test'); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $part ); } diff --git a/tests/unit/Messages/DTO/MessageTest.php b/tests/unit/Messages/DTO/MessageTest.php index db77cad8..757967f3 100644 --- a/tests/unit/Messages/DTO/MessageTest.php +++ b/tests/unit/Messages/DTO/MessageTest.php @@ -5,6 +5,7 @@ namespace WordPress\AiClient\Tests\unit\Messages\DTO; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; use WordPress\AiClient\Files\DTO\File; use WordPress\AiClient\Messages\DTO\Message; use WordPress\AiClient\Messages\DTO\MessagePart; @@ -375,7 +376,7 @@ public function testImplementsWithArrayTransformationInterface(): void $message = new Message(MessageRoleEnum::user(), [new MessagePart('test')]); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $message ); } diff --git a/tests/unit/Messages/DTO/ModelMessageTest.php b/tests/unit/Messages/DTO/ModelMessageTest.php index c7c25e25..db6c6390 100644 --- a/tests/unit/Messages/DTO/ModelMessageTest.php +++ b/tests/unit/Messages/DTO/ModelMessageTest.php @@ -6,6 +6,7 @@ use PHPUnit\Framework\TestCase; use WordPress\AiClient\Files\DTO\File; +use WordPress\AiClient\Messages\DTO\Message; use WordPress\AiClient\Messages\DTO\MessagePart; use WordPress\AiClient\Messages\DTO\ModelMessage; use WordPress\AiClient\Messages\Enums\MessagePartTypeEnum; @@ -14,7 +15,7 @@ use WordPress\AiClient\Tools\DTO\FunctionCall; /** - * @covers \WordPress\AiClient\Messages\DTO\ModelMessage + * @covers ModelMessage */ class ModelMessageTest extends TestCase { @@ -79,7 +80,7 @@ public function testInheritsFromMessage(): void { $message = new ModelMessage([]); - $this->assertInstanceOf(\WordPress\AiClient\Messages\DTO\Message::class, $message); + $this->assertInstanceOf(Message::class, $message); } /** @@ -115,7 +116,7 @@ public function testWithVariousContentTypes(): void public function testJsonSchemaInheritance(): void { $schema = ModelMessage::getJsonSchema(); - $parentSchema = \WordPress\AiClient\Messages\DTO\Message::getJsonSchema(); + $parentSchema = Message::getJsonSchema(); $this->assertEquals($parentSchema, $schema); } diff --git a/tests/unit/Messages/DTO/UserMessageTest.php b/tests/unit/Messages/DTO/UserMessageTest.php index 5b5df92f..a7be9dda 100644 --- a/tests/unit/Messages/DTO/UserMessageTest.php +++ b/tests/unit/Messages/DTO/UserMessageTest.php @@ -14,7 +14,7 @@ use WordPress\AiClient\Tests\traits\ArrayTransformationTestTrait; /** - * @covers \WordPress\AiClient\Messages\DTO\UserMessage + * @covers UserMessage */ class UserMessageTest extends TestCase { @@ -82,7 +82,7 @@ public function testInheritsFromMessage(): void { $message = new UserMessage([]); - $this->assertInstanceOf(\WordPress\AiClient\Messages\DTO\Message::class, $message); + $this->assertInstanceOf(Message::class, $message); } /** @@ -138,7 +138,7 @@ public function testWithImageAndText(): void public function testJsonSchemaInheritance(): void { $schema = UserMessage::getJsonSchema(); - $parentSchema = \WordPress\AiClient\Messages\DTO\Message::getJsonSchema(); + $parentSchema = Message::getJsonSchema(); $this->assertEquals($parentSchema, $schema); } diff --git a/tests/unit/Operations/DTO/GenerativeAiOperationTest.php b/tests/unit/Operations/DTO/GenerativeAiOperationTest.php index c663b045..6b281ccf 100644 --- a/tests/unit/Operations/DTO/GenerativeAiOperationTest.php +++ b/tests/unit/Operations/DTO/GenerativeAiOperationTest.php @@ -9,6 +9,7 @@ use WordPress\AiClient\Messages\DTO\MessagePart; use WordPress\AiClient\Messages\DTO\ModelMessage; use WordPress\AiClient\Messages\Enums\MessageRoleEnum; +use WordPress\AiClient\Operations\Contracts\OperationInterface; use WordPress\AiClient\Operations\DTO\GenerativeAiOperation; use WordPress\AiClient\Operations\Enums\OperationStateEnum; use WordPress\AiClient\Providers\DTO\ProviderMetadata; @@ -174,7 +175,7 @@ public function testImplementsOperationInterface(): void ); $this->assertInstanceOf( - \WordPress\AiClient\Operations\Contracts\OperationInterface::class, + OperationInterface::class, $operation ); } diff --git a/tests/unit/Providers/AbstractProviderTest.php b/tests/unit/Providers/AbstractProviderTest.php new file mode 100644 index 00000000..10bbaf6a --- /dev/null +++ b/tests/unit/Providers/AbstractProviderTest.php @@ -0,0 +1,249 @@ +getProperty('metadataCache'); + $metadataCacheProperty->setAccessible(true); + $metadataCacheProperty->setValue(null, []); + + $availabilityCacheProperty = $reflectionClass->getProperty('availabilityCache'); + $availabilityCacheProperty->setAccessible(true); + $availabilityCacheProperty->setValue(null, []); + + $modelMetadataDirectoryCacheProperty = $reflectionClass->getProperty('modelMetadataDirectoryCache'); + $modelMetadataDirectoryCacheProperty->setAccessible(true); + $modelMetadataDirectoryCacheProperty->setValue(null, []); + } + + /** + * Tests the metadata() method. + * + * @return void + */ + public function testMetadata(): void + { + $providerMetadata = $this->createMock(ProviderMetadata::class); + MockAbstractProvider::$mockProviderMetadata = $providerMetadata; + + // Call metadata twice to ensure caching works + $result1 = MockAbstractProvider::metadata(); + $result2 = MockAbstractProvider::metadata(); + + $this->assertSame($providerMetadata, $result1); + $this->assertSame($providerMetadata, $result2); + } + + /** + * Tests the model() method without ModelConfig. + * + * @return void + */ + public function testModelWithoutModelConfig(): void + { + $modelId = 'test-model'; + $modelMetadata = $this->createMock(ModelMetadata::class); + $providerMetadata = $this->createMock(ProviderMetadata::class); + $model = $this->createMock(ModelInterface::class); // Use ModelInterface for the mock + $mockModelMetadataDirectory = $this->createMock(ModelMetadataDirectoryInterface::class); + + // Set expectations on the mock that will be used by MockAbstractProvider + $mockModelMetadataDirectory->expects($this->once()) + ->method('getModelMetadata') + ->with($modelId) + ->willReturn($modelMetadata); + + MockAbstractProvider::$mockProviderMetadata = $providerMetadata; + MockAbstractProvider::$mockModelMetadataDirectory = $mockModelMetadataDirectory; + MockAbstractProvider::$mockModel = $model; + + $model->expects($this->never())->method('setConfig'); + + $result = MockAbstractProvider::model($modelId); + + $this->assertSame($model, $result); + } + + /** + * Tests the model() method with ModelConfig. + * + * @return void + */ + public function testModelWithModelConfig(): void + { + $modelId = 'test-model'; + $modelConfig = $this->createMock(ModelConfig::class); + $modelMetadata = $this->createMock(ModelMetadata::class); + $providerMetadata = $this->createMock(ProviderMetadata::class); + $model = $this->createMock(ModelInterface::class); // Use ModelInterface for the mock + $mockModelMetadataDirectory = $this->createMock(ModelMetadataDirectoryInterface::class); + + // Set expectations on the mock that will be used by MockAbstractProvider + $mockModelMetadataDirectory->expects($this->once()) + ->method('getModelMetadata') + ->with($modelId) + ->willReturn($modelMetadata); + + MockAbstractProvider::$mockProviderMetadata = $providerMetadata; + MockAbstractProvider::$mockModelMetadataDirectory = $mockModelMetadataDirectory; + MockAbstractProvider::$mockModel = $model; + + $model->expects($this->once())->method('setConfig')->with($modelConfig); + + $result = MockAbstractProvider::model($modelId, $modelConfig); + + $this->assertSame($model, $result); + } + + /** + * Tests the availability() method. + * + * @return void + */ + public function testAvailability(): void + { + $providerAvailability = $this->createMock(ProviderAvailabilityInterface::class); + MockAbstractProvider::$mockProviderAvailability = $providerAvailability; + + // Call availability twice to ensure caching works + $result1 = MockAbstractProvider::availability(); + $result2 = MockAbstractProvider::availability(); + + $this->assertSame($providerAvailability, $result1); + $this->assertSame($providerAvailability, $result2); + } + + /** + * Tests the modelMetadataDirectory() method. + * + * @return void + */ + public function testModelMetadataDirectory(): void + { + $modelMetadataDirectory = $this->createMock(ModelMetadataDirectoryInterface::class); + MockAbstractProvider::$mockModelMetadataDirectory = $modelMetadataDirectory; + + // Call modelMetadataDirectory twice to ensure caching works + $result1 = MockAbstractProvider::modelMetadataDirectory(); + $result2 = MockAbstractProvider::modelMetadataDirectory(); + + $this->assertSame($modelMetadataDirectory, $result1); + $this->assertSame($modelMetadataDirectory, $result2); + } + + /** + * Tests that the caches are reset between tests for different concrete provider classes. + * + * @return void + */ + public function testCachesArePerConcreteClass(): void + { + // Create two distinct anonymous classes extending AbstractProvider + $mockProviderClass1 = new class extends AbstractProvider { + protected static function createModel( + ModelMetadata $modelMetadata, + ProviderMetadata $providerMetadata + ): ModelInterface { + return new MockModel(); + } + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata('mock-provider-1', 'Mock Provider 1', ProviderTypeEnum::cloud()); + } + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + return new MockProviderAvailability(); + } + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new MockModelMetadataDirectory(); + } + }; + + $mockProviderClass2 = new class extends AbstractProvider { + protected static function createModel( + ModelMetadata $modelMetadata, + ProviderMetadata $providerMetadata + ): ModelInterface { + return new MockModel(); + } + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata('mock-provider-2', 'Mock Provider 2', ProviderTypeEnum::cloud()); + } + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + return new MockProviderAvailability(); + } + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new MockModelMetadataDirectory(); + } + }; + + // Get metadata for the first provider + $metadata1_1 = $mockProviderClass1::metadata(); + $metadata1_2 = $mockProviderClass1::metadata(); // Should be cached + + // Get metadata for the second provider + $metadata2_1 = $mockProviderClass2::metadata(); + $metadata2_2 = $mockProviderClass2::metadata(); // Should be cached + + // Assert that the first provider's metadata is consistent and distinct from the second + $this->assertSame($metadata1_1, $metadata1_2); + $this->assertEquals('mock-provider-1', $metadata1_1->getId()); + $this->assertNotSame($metadata1_1, $metadata2_1); // Ensure they are different instances + + // Assert that the second provider's metadata is consistent + $this->assertSame($metadata2_1, $metadata2_2); + $this->assertEquals('mock-provider-2', $metadata2_1->getId()); + + // Repeat for availability + $availability1_1 = $mockProviderClass1::availability(); + $availability1_2 = $mockProviderClass1::availability(); + $availability2_1 = $mockProviderClass2::availability(); + $availability2_2 = $mockProviderClass2::availability(); + + $this->assertSame($availability1_1, $availability1_2); + $this->assertNotSame($availability1_1, $availability2_1); + $this->assertSame($availability2_1, $availability2_2); + + // Repeat for modelMetadataDirectory + $modelMetadataDirectory1_1 = $mockProviderClass1::modelMetadataDirectory(); + $modelMetadataDirectory1_2 = $mockProviderClass1::modelMetadataDirectory(); + $modelMetadataDirectory2_1 = $mockProviderClass2::modelMetadataDirectory(); + $modelMetadataDirectory2_2 = $mockProviderClass2::modelMetadataDirectory(); + + $this->assertSame($modelMetadataDirectory1_1, $modelMetadataDirectory1_2); + $this->assertNotSame($modelMetadataDirectory1_1, $modelMetadataDirectory2_1); + $this->assertSame($modelMetadataDirectory2_1, $modelMetadataDirectory2_2); + } +} diff --git a/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectoryTest.php b/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectoryTest.php new file mode 100644 index 00000000..664127f8 --- /dev/null +++ b/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelMetadataDirectoryTest.php @@ -0,0 +1,86 @@ +mockModels = [ + 'model-1' => $this->createStub(ModelMetadata::class), + 'model-2' => $this->createStub(ModelMetadata::class), + ]; + } + + /** + * Tests listModelMetadata() method. + * + * @return void + */ + public function testListModelMetadata(): void + { + $directory = new MockApiBasedModelMetadataDirectory($this->mockModels); + $models = $directory->listModelMetadata(); + + $this->assertIsArray($models); + $this->assertCount(2, $models); + $this->assertContains($this->mockModels['model-1'], $models); + $this->assertContains($this->mockModels['model-2'], $models); + } + + /** + * Tests hasModelMetadata() method. + * + * @return void + */ + public function testHasModelMetadata(): void + { + $directory = new MockApiBasedModelMetadataDirectory($this->mockModels); + + $this->assertTrue($directory->hasModelMetadata('model-1')); + $this->assertFalse($directory->hasModelMetadata('non-existent-model')); + } + + /** + * Tests getModelMetadata() method. + * + * @return void + */ + public function testGetModelMetadata(): void + { + $directory = new MockApiBasedModelMetadataDirectory($this->mockModels); + + $this->assertSame($this->mockModels['model-1'], $directory->getModelMetadata('model-1')); + } + + /** + * Tests getModelMetadata() method with non-existent model. + * + * @return void + */ + public function testGetModelMetadataThrowsExceptionForNonExistentModel(): void + { + $directory = new MockApiBasedModelMetadataDirectory($this->mockModels); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('No model with ID non-existent-model was found in the provider'); + + $directory->getModelMetadata('non-existent-model'); + } +} diff --git a/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelTest.php b/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelTest.php new file mode 100644 index 00000000..cd559bfa --- /dev/null +++ b/tests/unit/Providers/ApiBasedImplementation/AbstractApiBasedModelTest.php @@ -0,0 +1,89 @@ +modelMetadata = $this->createStub(ModelMetadata::class); + $this->providerMetadata = $this->createStub(ProviderMetadata::class); + } + + /** + * Tests the constructor and initial state. + * + * @return void + */ + public function testConstructorAndInitialState(): void + { + $model = new MockApiBasedModel($this->modelMetadata, $this->providerMetadata); + + $this->assertSame($this->modelMetadata, $model->metadata()); + $this->assertSame($this->providerMetadata, $model->providerMetadata()); + $this->assertInstanceOf(ModelConfig::class, $model->getConfig()); + $this->assertEquals([], $model->getConfig()->toArray()); + } + + /** + * Tests the metadata() method. + * + * @return void + */ + public function testMetadata(): void + { + $model = new MockApiBasedModel($this->modelMetadata, $this->providerMetadata); + $this->assertSame($this->modelMetadata, $model->metadata()); + } + + /** + * Tests the providerMetadata() method. + * + * @return void + */ + public function testProviderMetadata(): void + { + $model = new MockApiBasedModel($this->modelMetadata, $this->providerMetadata); + $this->assertSame($this->providerMetadata, $model->providerMetadata()); + } + + /** + * Tests the setConfig() and getConfig() methods. + * + * @return void + */ + public function testSetConfigAndGetConfig(): void + { + $model = new MockApiBasedModel($this->modelMetadata, $this->providerMetadata); + $initialConfig = $model->getConfig(); + + $newConfig = ModelConfig::fromArray(['temperature' => 0.7]); + $model->setConfig($newConfig); + + $this->assertSame($newConfig, $model->getConfig()); + $this->assertNotSame($initialConfig, $model->getConfig()); + $this->assertEquals(['temperature' => 0.7], $model->getConfig()->toArray()); + } +} diff --git a/tests/unit/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailabilityTest.php b/tests/unit/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailabilityTest.php new file mode 100644 index 00000000..549879a5 --- /dev/null +++ b/tests/unit/Providers/ApiBasedImplementation/ListModelsApiBasedProviderAvailabilityTest.php @@ -0,0 +1,62 @@ +modelMetadataDirectory = $this->createMock(ModelMetadataDirectoryInterface::class); + } + + /** + * Tests isConfigured() method when listing models succeeds. + * + * @return void + */ + public function testIsConfiguredReturnsTrueOnSuccess(): void + { + $this->modelMetadataDirectory + ->expects($this->once()) + ->method('listModelMetadata') + ->willReturn([]); + + $availability = new ListModelsApiBasedProviderAvailability($this->modelMetadataDirectory); + + $this->assertTrue($availability->isConfigured()); + } + + /** + * Tests isConfigured() method when listing models throws an exception. + * + * @return void + */ + public function testIsConfiguredReturnsFalseOnException(): void + { + $this->modelMetadataDirectory + ->expects($this->once()) + ->method('listModelMetadata') + ->willThrowException(new Exception('API error')); + + $availability = new ListModelsApiBasedProviderAvailability($this->modelMetadataDirectory); + + $this->assertFalse($availability->isConfigured()); + } +} diff --git a/tests/unit/Providers/ApiBasedImplementation/MockApiBasedModel.php b/tests/unit/Providers/ApiBasedImplementation/MockApiBasedModel.php new file mode 100644 index 00000000..70c9b624 --- /dev/null +++ b/tests/unit/Providers/ApiBasedImplementation/MockApiBasedModel.php @@ -0,0 +1,21 @@ + + */ + private array $mockModels; + + /** + * Constructor. + * + * @param array $mockModels + */ + public function __construct(array $mockModels = []) + { + $this->mockModels = $mockModels; + } + + /** + * @inheritdoc + */ + protected function sendListModelsRequest(): array + { + return $this->mockModels; + } +} diff --git a/tests/unit/Providers/DTO/ProviderMetadataTest.php b/tests/unit/Providers/DTO/ProviderMetadataTest.php index 2118b2bc..bf54cfd4 100644 --- a/tests/unit/Providers/DTO/ProviderMetadataTest.php +++ b/tests/unit/Providers/DTO/ProviderMetadataTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\Enums\ProviderTypeEnum; @@ -213,11 +215,11 @@ public function testImplementsCorrectInterfaces(): void $metadata = new ProviderMetadata('test', 'Test', ProviderTypeEnum::cloud()); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $metadata ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $metadata ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/DTO/ProviderModelsMetadataTest.php b/tests/unit/Providers/DTO/ProviderModelsMetadataTest.php index cd46472e..cdbd5bdc 100644 --- a/tests/unit/Providers/DTO/ProviderModelsMetadataTest.php +++ b/tests/unit/Providers/DTO/ProviderModelsMetadataTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\DTO\ProviderModelsMetadata; use WordPress\AiClient\Providers\Enums\ProviderTypeEnum; @@ -357,11 +359,11 @@ public function testImplementsCorrectInterfaces(): void $metadata = new ProviderModelsMetadata($this->createProviderMetadata(), []); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $metadata ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $metadata ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Http/DTO/ApiKeyRequestAuthenticationTest.php b/tests/unit/Providers/Http/DTO/ApiKeyRequestAuthenticationTest.php new file mode 100644 index 00000000..7f2c82b9 --- /dev/null +++ b/tests/unit/Providers/Http/DTO/ApiKeyRequestAuthenticationTest.php @@ -0,0 +1,116 @@ +assertEquals($apiKey, $auth->getApiKey()); + } + + /** + * Tests authenticateRequest method. + * + * @return void + */ + public function testAuthenticateRequest(): void + { + $apiKey = 'test_api_key_456'; + $auth = new ApiKeyRequestAuthentication($apiKey); + + $request = new Request(HttpMethodEnum::get(), 'https://example.com/api'); + $authenticatedRequest = $auth->authenticateRequest($request); + + $this->assertNotSame($request, $authenticatedRequest); // Ensure immutability + $this->assertTrue($authenticatedRequest->hasHeader('Authorization')); + $this->assertEquals('Bearer ' . $apiKey, $authenticatedRequest->getHeaderAsString('Authorization')); + } + + /** + * Tests toArray method. + * + * @return void + */ + public function testToArray(): void + { + $apiKey = 'test_api_key_789'; + $auth = new ApiKeyRequestAuthentication($apiKey); + + $array = $auth->toArray(); + + $this->assertIsArray($array); + $this->assertArrayHasKey(ApiKeyRequestAuthentication::KEY_API_KEY, $array); + $this->assertEquals($apiKey, $array[ApiKeyRequestAuthentication::KEY_API_KEY]); + } + + /** + * Tests fromArray method. + * + * @return void + */ + public function testFromArray(): void + { + $apiKey = 'test_api_key_abc'; + $array = [ + ApiKeyRequestAuthentication::KEY_API_KEY => $apiKey, + ]; + + $auth = ApiKeyRequestAuthentication::fromArray($array); + + $this->assertInstanceOf(ApiKeyRequestAuthentication::class, $auth); + $this->assertEquals($apiKey, $auth->getApiKey()); + } + + /** + * Tests fromArray method with missing API key. + * + * @return void + */ + public function testFromArrayWithMissingApiKey(): void + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage( + ApiKeyRequestAuthentication::class . '::fromArray() missing required keys: apiKey' + ); + + ApiKeyRequestAuthentication::fromArray([]); + } + + /** + * Tests getJsonSchema method. + * + * @return void + */ + public function testGetJsonSchema(): void + { + $schema = ApiKeyRequestAuthentication::getJsonSchema(); + + $this->assertIsArray($schema); + $this->assertEquals('object', $schema['type']); + $this->assertArrayHasKey('properties', $schema); + $this->assertArrayHasKey(ApiKeyRequestAuthentication::KEY_API_KEY, $schema['properties']); + $this->assertEquals('string', $schema['properties'][ApiKeyRequestAuthentication::KEY_API_KEY]['type']); + $this->assertArrayHasKey('required', $schema); + $this->assertEquals([ApiKeyRequestAuthentication::KEY_API_KEY], $schema['required']); + } +} diff --git a/tests/unit/Providers/Http/Util/ResponseUtilTest.php b/tests/unit/Providers/Http/Util/ResponseUtilTest.php new file mode 100644 index 00000000..68b81ea3 --- /dev/null +++ b/tests/unit/Providers/Http/Util/ResponseUtilTest.php @@ -0,0 +1,115 @@ +createMock(Response::class); + $response->method('isSuccessful')->willReturn(true); + $response->method('getStatusCode')->willReturn($statusCode); + + // Expect no exception to be thrown + $this->expectNotToPerformAssertions(); + ResponseUtil::throwIfNotSuccessful($response); + } + + /** + * Provides successful HTTP status codes. + * + * @return array + */ + public function successfulResponseStatusCodeProvider(): array + { + return [ + '200 OK' => [200], + '201 Created' => [201], + '204 No Content' => [204], + ]; + } + + /** + * Tests that throwIfNotSuccessful throws an exception for unsuccessful responses. + * + * @dataProvider unsuccessfulResponseStatusCodeProvider + * @param int $statusCode The unsuccessful HTTP status code. + * @param array $data The response data. + * @param string $expectedMessagePart The expected part of the exception message. + * @return void + */ + public function testThrowIfNotSuccessfulThrowsForUnsuccessfulResponses( + int $statusCode, + array $data, + string $expectedMessagePart + ): void { + $response = $this->createMock(Response::class); + $response->method('isSuccessful')->willReturn(false); + $response->method('getStatusCode')->willReturn($statusCode); + $response->method('getData')->willReturn($data); + + $this->expectException(ResponseException::class); + $this->expectExceptionCode($statusCode); + $this->expectExceptionMessageMatches("/^Bad status code: {$statusCode}\.($| {$expectedMessagePart})$/"); + + ResponseUtil::throwIfNotSuccessful($response); + } + + /** + * Provides unsuccessful HTTP status codes and corresponding data for testing. + * + * @return array + */ + public function unsuccessfulResponseStatusCodeProvider(): array + { + return [ + '400 Bad Request (no extra message)' => [ + 400, + [], + '', + ], + '401 Unauthorized (error.message)' => [ + 401, + ['error' => ['message' => 'Invalid API key.']], + 'Invalid API key\.', + ], + '403 Forbidden (error string)' => [ + 403, + ['error' => 'Access denied.'], + 'Access denied\.', + ], + '404 Not Found (message string)' => [ + 404, + ['message' => 'Resource not found.'], + 'Resource not found\.', + ], + '500 Internal Server Error (no extra message)' => [ + 500, + [], + '', + ], + '503 Service Unavailable (error.message with special chars)' => [ + 503, + ['error' => ['message' => 'Service is temporarily unavailable. Please try again later.']], + 'Service is temporarily unavailable\. Please try again later\.', + ], + ]; + } +} diff --git a/tests/unit/Providers/Models/DTO/ModelConfigTest.php b/tests/unit/Providers/Models/DTO/ModelConfigTest.php index 742132a4..e875f993 100644 --- a/tests/unit/Providers/Models/DTO/ModelConfigTest.php +++ b/tests/unit/Providers/Models/DTO/ModelConfigTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Files\Enums\FileTypeEnum; use WordPress\AiClient\Files\Enums\MediaOrientationEnum; use WordPress\AiClient\Messages\Enums\ModalityEnum; @@ -70,6 +72,7 @@ public function testDefaultConstructor(): void $this->assertNull($config->getOutputSchema()); $this->assertNull($config->getOutputMediaOrientation()); $this->assertNull($config->getOutputMediaAspectRatio()); + $this->assertNull($config->getOutputSpeechVoice()); $this->assertEquals([], $config->getCustomOptions()); } @@ -169,6 +172,10 @@ public function testSettersAndGetters(): void $config->setOutputMediaAspectRatio('4:3'); $this->assertEquals('4:3', $config->getOutputMediaAspectRatio()); + // Test output speech voice + $config->setOutputSpeechVoice('alloy'); + $this->assertEquals('alloy', $config->getOutputSpeechVoice()); + // Test custom options $customOptions = ['custom_param' => 'value', 'another_param' => 123]; $config->setCustomOptions($customOptions); @@ -210,6 +217,7 @@ public function testGetJsonSchema(): void ModelConfig::KEY_OUTPUT_SCHEMA, ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION, ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO, + ModelConfig::KEY_OUTPUT_SPEECH_VOICE, ModelConfig::KEY_CUSTOM_OPTIONS ]; @@ -228,6 +236,7 @@ public function testGetJsonSchema(): void $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_FILE_TYPE]['type']); $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION]['type']); $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO]['type']); + $this->assertEquals('string', $schema['properties'][ModelConfig::KEY_OUTPUT_SPEECH_VOICE]['type']); $this->assertEquals('object', $schema['properties'][ModelConfig::KEY_CUSTOM_OPTIONS]['type']); // Check constraints @@ -266,6 +275,7 @@ public function testToArrayAllProperties(): void $config->setOutputSchema(['type' => 'object']); $config->setOutputMediaOrientation(MediaOrientationEnum::portrait()); $config->setOutputMediaAspectRatio('9:16'); + $config->setOutputSpeechVoice('onyx'); $config->setCustomOptions(['key' => 'value']); $array = $config->toArray(); @@ -290,6 +300,7 @@ public function testToArrayAllProperties(): void $this->assertEquals(['type' => 'object'], $array[ModelConfig::KEY_OUTPUT_SCHEMA]); $this->assertEquals('portrait', $array[ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION]); $this->assertEquals('9:16', $array[ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO]); + $this->assertEquals('onyx', $array[ModelConfig::KEY_OUTPUT_SPEECH_VOICE]); $this->assertEquals(['key' => 'value'], $array[ModelConfig::KEY_CUSTOM_OPTIONS]); } @@ -401,6 +412,7 @@ public function testFromArrayAllProperties(): void ModelConfig::KEY_OUTPUT_FILE_TYPE => 'inline', ModelConfig::KEY_OUTPUT_MEDIA_ORIENTATION => 'landscape', ModelConfig::KEY_OUTPUT_MEDIA_ASPECT_RATIO => '16:9', + ModelConfig::KEY_OUTPUT_SPEECH_VOICE => 'fable', ModelConfig::KEY_CUSTOM_OPTIONS => ['custom' => true] ]; @@ -430,6 +442,7 @@ public function testFromArrayAllProperties(): void $this->assertEquals(FileTypeEnum::inline(), $config->getOutputFileType()); $this->assertEquals(MediaOrientationEnum::landscape(), $config->getOutputMediaOrientation()); $this->assertEquals('16:9', $config->getOutputMediaAspectRatio()); + $this->assertEquals('fable', $config->getOutputSpeechVoice()); $this->assertEquals(['custom' => true], $config->getCustomOptions()); } @@ -464,6 +477,7 @@ public function testArrayRoundTrip(): void $original->setOutputFileType(FileTypeEnum::inline()); $original->setOutputMediaOrientation(MediaOrientationEnum::square()); $original->setOutputMediaAspectRatio('1:1'); + $original->setOutputSpeechVoice('shimmer'); $original->setCustomOptions(['test' => 'value']); $array = $original->toArray(); @@ -477,6 +491,7 @@ public function testArrayRoundTrip(): void $this->assertEquals($original->getOutputFileType(), $restored->getOutputFileType()); $this->assertEquals($original->getOutputMediaOrientation(), $restored->getOutputMediaOrientation()); $this->assertEquals($original->getOutputMediaAspectRatio(), $restored->getOutputMediaAspectRatio()); + $this->assertEquals($original->getOutputSpeechVoice(), $restored->getOutputSpeechVoice()); $this->assertEquals($original->getCustomOptions(), $restored->getCustomOptions()); } @@ -601,11 +616,11 @@ public function testImplementsCorrectInterfaces(): void $config = new ModelConfig(); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $config ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $config ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Models/DTO/ModelMetadataTest.php b/tests/unit/Providers/Models/DTO/ModelMetadataTest.php index 766e2dd1..46fb4622 100644 --- a/tests/unit/Providers/Models/DTO/ModelMetadataTest.php +++ b/tests/unit/Providers/Models/DTO/ModelMetadataTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\SupportedOption; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; @@ -436,11 +438,11 @@ public function testImplementsCorrectInterfaces(): void $metadata = new ModelMetadata('test', 'Test', [], []); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $metadata ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $metadata ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php index 21f71b2b..75ec6c35 100644 --- a/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php +++ b/tests/unit/Providers/Models/DTO/ModelRequirementsTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; use WordPress\AiClient\Providers\Models\DTO\RequiredOption; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; @@ -378,11 +380,11 @@ public function testImplementsCorrectInterfaces(): void $requirements = new ModelRequirements([], []); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $requirements ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $requirements ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Models/DTO/RequiredOptionTest.php b/tests/unit/Providers/Models/DTO/RequiredOptionTest.php index 3b20551b..b10187af 100644 --- a/tests/unit/Providers/Models/DTO/RequiredOptionTest.php +++ b/tests/unit/Providers/Models/DTO/RequiredOptionTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\Models\DTO\RequiredOption; use WordPress\AiClient\Providers\Models\Enums\OptionEnum; @@ -432,11 +434,11 @@ public function testImplementsCorrectInterfaces(): void $option = new RequiredOption(OptionEnum::maxTokens(), 'value'); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $option ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $option ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Models/DTO/SupportedOptionTest.php b/tests/unit/Providers/Models/DTO/SupportedOptionTest.php index 79c44ca1..a66a6ed8 100644 --- a/tests/unit/Providers/Models/DTO/SupportedOptionTest.php +++ b/tests/unit/Providers/Models/DTO/SupportedOptionTest.php @@ -6,6 +6,8 @@ use JsonSerializable; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Providers\Models\DTO\SupportedOption; use WordPress\AiClient\Providers\Models\Enums\OptionEnum; @@ -138,6 +140,29 @@ public function testWithObjectValues(): void $this->assertFalse($option->isSupportedValue(['type' => 'json_object', 'extra' => 'field'])); } + /** + * Tests that isSupportedValue correctly handles unordered array values. + * + * @return void + */ + public function testIsSupportedValueWithUnorderedArray(): void + { + // Just use any option enum value for the name. + $option = new SupportedOption( + OptionEnum::outputSpeechVoice(), + [['red', 'green', 'blue'], ['yellow', 'orange']] + ); + + // Test with an array that has the same elements but in a different order + $this->assertTrue($option->isSupportedValue(['blue', 'red', 'green'])); + $this->assertTrue($option->isSupportedValue(['orange', 'yellow'])); + + // Test with an array that has different elements or missing elements + $this->assertFalse($option->isSupportedValue(['red', 'green'])); + $this->assertFalse($option->isSupportedValue(['red', 'green', 'blue', 'purple'])); + $this->assertFalse($option->isSupportedValue(['red', 'yellow', 'blue'])); + } + /** * Tests JSON schema generation. * @@ -400,11 +425,11 @@ public function testImplementsCorrectInterfaces(): void $option = new SupportedOption(OptionEnum::maxTokens(), ['value']); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $option ); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $option ); $this->assertInstanceOf( diff --git a/tests/unit/Providers/Models/Enums/OptionEnumTest.php b/tests/unit/Providers/Models/Enums/OptionEnumTest.php index dd82b746..9248a9ad 100644 --- a/tests/unit/Providers/Models/Enums/OptionEnumTest.php +++ b/tests/unit/Providers/Models/Enums/OptionEnumTest.php @@ -34,7 +34,7 @@ protected function getExpectedValues(): array { return [ // Explicitly defined constant (not in ModelConfig) - 'INPUT_MODALITIES' => 'input_modalities', + 'INPUT_MODALITIES' => 'inputModalities', // Dynamically added from ModelConfig KEY_* constants 'OUTPUT_MODALITIES' => 'outputModalities', @@ -56,6 +56,7 @@ protected function getExpectedValues(): array 'OUTPUT_SCHEMA' => 'outputSchema', 'OUTPUT_MEDIA_ORIENTATION' => 'outputMediaOrientation', 'OUTPUT_MEDIA_ASPECT_RATIO' => 'outputMediaAspectRatio', + 'OUTPUT_SPEECH_VOICE' => 'outputSpeechVoice', 'CUSTOM_OPTIONS' => 'customOptions', ]; } diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectoryTest.php b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectoryTest.php new file mode 100644 index 00000000..4924fbeb --- /dev/null +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleModelMetadataDirectoryTest.php @@ -0,0 +1,116 @@ +mockHttpTransporter = $this->createMock(HttpTransporterInterface::class); + $this->mockRequestAuthentication = $this->createMock(RequestAuthenticationInterface::class); + } + + /** + * Tests sendListModelsRequest() method on success. + * + * @return void + */ + public function testSendListModelsRequestSuccess(): void + { + $response = new Response(200, [], '{"data": [{"id": "model-a"}, {"id": "model-b"}]}'); + + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); // Return the request as is. + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn($response); + + $directory = new MockOpenAiCompatibleModelMetadataDirectory( + $this->mockHttpTransporter, + $this->mockRequestAuthentication, + function (string $modelId) { + return $this->createModelMetadataStub($modelId); + } + ); + + $modelsMetadata = $directory->listModelMetadata(); // Calls sendListModelsRequest internally. + + $this->assertCount(2, $modelsMetadata); + $this->assertEquals('model-a', $modelsMetadata[0]->getId()); + $this->assertEquals('model-b', $modelsMetadata[1]->getId()); + } + + /** + * Creates a ModelMetadata stub with the given ID. + * + * @param string $modelId + * @return ModelMetadata&\PHPUnit\Framework\MockObject\Stub + */ + public function createModelMetadataStub(string $modelId) + { + $modelMetadata = $this->createStub(ModelMetadata::class); + $modelMetadata->method('getId')->willReturn($modelId); + return $modelMetadata; + } + + /** + * Tests sendListModelsRequest() method on failure. + * + * @return void + */ + public function testSendListModelsRequestFailure(): void + { + $response = new Response(400, [], '{"error": "Bad Request"}'); + + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn($response); + + $directory = new MockOpenAiCompatibleModelMetadataDirectory( + $this->mockHttpTransporter, + $this->mockRequestAuthentication, + function (string $modelId) { + return $this->createModelMetadataStub($modelId); + } + ); + + $this->expectException(ResponseException::class); + $this->expectExceptionMessage('Bad status code: 400. Bad Request'); + + $directory->listModelMetadata(); + } +} diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php new file mode 100644 index 00000000..cae0b6a9 --- /dev/null +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/AbstractOpenAiCompatibleTextGenerationModelTest.php @@ -0,0 +1,1252 @@ +modelMetadata = $this->createStub(ModelMetadata::class); + $this->modelMetadata->method('getId')->willReturn('test-model'); + $this->providerMetadata = $this->createStub(ProviderMetadata::class); + $this->mockHttpTransporter = $this->createMock(HttpTransporterInterface::class); + $this->mockRequestAuthentication = $this->createMock(RequestAuthenticationInterface::class); + } + + /** + * Creates a mock instance of AbstractOpenAiCompatibleTextGenerationModel. + * + * @param ModelConfig|null $modelConfig + * @return MockOpenAiCompatibleTextGenerationModel + */ + private function createModel(?ModelConfig $modelConfig = null): MockOpenAiCompatibleTextGenerationModel + { + $model = new MockOpenAiCompatibleTextGenerationModel( + $this->modelMetadata, + $this->providerMetadata, + $this->mockHttpTransporter, + $this->mockRequestAuthentication + ); + if ($modelConfig) { + $model->setConfig($modelConfig); + } + return $model; + } + + /** + * Tests generateTextResult() method on success. + * + * @return void + */ + public function testGenerateTextResultSuccess(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Hello')])]; + $response = new Response( + 200, + [], + json_encode([ + 'id' => 'chatcmpl-123', + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Hi there!', + ], + 'finish_reason' => 'stop', + ], + ], + 'usage' => [ + 'prompt_tokens' => 10, + 'completion_tokens' => 5, + 'total_tokens' => 15, + ], + ]) + ); + + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn($response); + + $model = $this->createModel(); + $result = $model->generateTextResult($prompt); + + $this->assertInstanceOf(GenerativeAiResult::class, $result); + $this->assertEquals('chatcmpl-123', $result->getId()); + $this->assertCount(1, $result->getCandidates()); + $this->assertEquals('Hi there!', $result->getCandidates()[0]->getMessage()->getParts()[0]->getText()); + $this->assertEquals(FinishReasonEnum::stop(), $result->getCandidates()[0]->getFinishReason()); + $this->assertEquals(10, $result->getTokenUsage()->getPromptTokens()); + $this->assertEquals(5, $result->getTokenUsage()->getCompletionTokens()); + $this->assertEquals(15, $result->getTokenUsage()->getTotalTokens()); + } + + /** + * Tests generateTextResult() method on API failure. + * + * @return void + */ + public function testGenerateTextResultApiFailure(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Hello')])]; + $response = new Response(400, [], '{"error": "Bad Request"}'); + + $this->mockRequestAuthentication + ->expects($this->once()) + ->method('authenticateRequest') + ->willReturnArgument(0); + + $this->mockHttpTransporter + ->expects($this->once()) + ->method('send') + ->willReturn($response); + + $model = $this->createModel(); + + $this->expectException(ResponseException::class); + $this->expectExceptionMessage('Bad status code: 400. Bad Request'); + + $model->generateTextResult($prompt); + } + + /** + * Tests streamGenerateTextResult() method. + * + * @return void + */ + public function testStreamGenerateTextResult(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Hello')])]; + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Streaming is not yet implemented.'); + + $generator = $model->streamGenerateTextResult($prompt); + $generator->current(); // Attempt to get the first value to trigger the exception. + } + + /** + * Tests prepareGenerateTextParams() with basic text prompt. + * + * @return void + */ + public function testPrepareGenerateTextParamsBasicText(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test message')])]; + $model = $this->createModel(); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('model', $params); + $this->assertEquals('test-model', $params['model']); + $this->assertArrayHasKey('messages', $params); + $this->assertCount(1, $params['messages']); + $this->assertEquals('user', $params['messages'][0]['role']); + $this->assertCount(1, $params['messages'][0]['content']); + $this->assertEquals('text', $params['messages'][0]['content'][0]['type']); + $this->assertEquals('Test message', $params['messages'][0]['content'][0]['text']); + $this->assertArrayNotHasKey('customOptions', $params); // customOptions should not be present if empty + } + + /** + * Tests prepareGenerateTextParams() with system instruction. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithSystemInstruction(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('User message')])]; + $modelConfig = ModelConfig::fromArray(['systemInstruction' => 'You are a helpful assistant.']); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertCount(2, $params['messages']); + $this->assertEquals('system', $params['messages'][0]['role']); + $this->assertEquals('You are a helpful assistant.', $params['messages'][0]['content'][0]['text']); + $this->assertEquals('user', $params['messages'][1]['role']); + } + + /** + * Tests prepareGenerateTextParams() with candidate count. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithCandidateCount(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['candidateCount' => 2]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('n', $params); + $this->assertEquals(2, $params['n']); + } + + /** + * Tests prepareGenerateTextParams() with max tokens. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithMaxTokens(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['maxTokens' => 100]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('max_tokens', $params); + $this->assertEquals(100, $params['max_tokens']); + } + + /** + * Tests prepareGenerateTextParams() with temperature. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithTemperature(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['temperature' => 0.5]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('temperature', $params); + $this->assertEquals(0.5, $params['temperature']); + } + + /** + * Tests prepareGenerateTextParams() with topP. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithTopP(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['topP' => 0.9]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('top_p', $params); + $this->assertEquals(0.9, $params['top_p']); + } + + /** + * Tests prepareGenerateTextParams() with stop sequences. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithStopSequences(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['stopSequences' => ['stop1', 'stop2']]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('stop', $params); + $this->assertEquals(['stop1', 'stop2'], $params['stop']); + } + + /** + * Tests prepareGenerateTextParams() with presence penalty. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithPresencePenalty(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['presencePenalty' => 0.1]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('presence_penalty', $params); + $this->assertEquals(0.1, $params['presence_penalty']); + } + + /** + * Tests prepareGenerateTextParams() with frequency penalty. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithFrequencyPenalty(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['frequencyPenalty' => 0.2]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('frequency_penalty', $params); + $this->assertEquals(0.2, $params['frequency_penalty']); + } + + /** + * Tests prepareGenerateTextParams() with logprobs. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithLogprobs(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['logprobs' => true]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('logprobs', $params); + $this->assertTrue($params['logprobs']); + } + + /** + * Tests prepareGenerateTextParams() with top logprobs. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithTopLogprobs(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['topLogprobs' => 5]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('top_logprobs', $params); + $this->assertEquals(5, $params['top_logprobs']); + } + + /** + * Tests prepareGenerateTextParams() with function declarations. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithFunctionDeclarations(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $functionDeclaration = new FunctionDeclaration( + 'my_function', + 'My function', + ['type' => 'object'] + ); + $modelConfig = ModelConfig::fromArray( + ['functionDeclarations' => [$functionDeclaration->toArray()]] + ); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('tools', $params); + $this->assertCount(1, $params['tools']); + $this->assertEquals('function', $params['tools'][0]['type']); + $this->assertEquals($functionDeclaration->toArray(), $params['tools'][0]['function']); + } + + /** + * Tests prepareGenerateTextParams() with JSON output. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithJsonOutput(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['outputMimeType' => 'application/json']); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('response_format', $params); + $this->assertEquals(['type' => 'json_object'], $params['response_format']); + } + + /** + * Tests prepareGenerateTextParams() with JSON output schema. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithJsonOutputSchema(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $schema = ['type' => 'object', 'properties' => ['name' => ['type' => 'string']]]; + $modelConfig = ModelConfig::fromArray(['outputMimeType' => 'application/json', 'outputSchema' => $schema]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('response_format', $params); + $this->assertEquals(['type' => 'json_schema', 'json_schema' => $schema], $params['response_format']); + } + + /** + * Tests prepareGenerateTextParams() with custom options. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithCustomOptions(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['customOptions' => ['my_custom_key' => 'my_custom_value']]); + $model = $this->createModel($modelConfig); + + $params = $model->exposePrepareGenerateTextParams($prompt); + + $this->assertArrayHasKey('my_custom_key', $params); + $this->assertEquals('my_custom_value', $params['my_custom_key']); + } + + /** + * Tests prepareGenerateTextParams() with conflicting custom option. + * + * @return void + */ + public function testPrepareGenerateTextParamsWithConflictingCustomOption(): void + { + $prompt = [new Message(MessageRoleEnum::user(), [new MessagePart('Test')])]; + $modelConfig = ModelConfig::fromArray(['customOptions' => ['model' => 'conflicting-model']]); + $model = $this->createModel($modelConfig); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The custom option "model" conflicts with an existing parameter.'); + + $model->exposePrepareGenerateTextParams($prompt); + } + + /** + * Tests prepareMessagesParam() with text message. + * + * @return void + */ + public function testPrepareMessagesParamTextMessage(): void + { + $message = new Message(MessageRoleEnum::user(), [new MessagePart('Hello')]); + $model = $this->createModel(); + + $prepared = $model->exposePrepareMessagesParam([$message]); + + $this->assertCount(1, $prepared); + $this->assertEquals('user', $prepared[0]['role']); + $this->assertCount(1, $prepared[0]['content']); + $this->assertEquals('text', $prepared[0]['content'][0]['type']); + $this->assertEquals('Hello', $prepared[0]['content'][0]['text']); + } + + /** + * Tests prepareMessagesParam() with model message and function call. + * + * @return void + */ + public function testPrepareMessagesParamModelMessageWithFunctionCall(): void + { + $functionCall = new FunctionCall('call_1', 'my_function', ['arg1' => 'value1']); + $message = new Message( + MessageRoleEnum::model(), + [new MessagePart($functionCall)] + ); + $model = $this->createModel(); + + $prepared = $model->exposePrepareMessagesParam([$message]); + + $this->assertCount(1, $prepared); + $this->assertEquals('assistant', $prepared[0]['role']); + $this->assertCount(1, $prepared[0]['tool_calls']); + $this->assertEquals('function', $prepared[0]['tool_calls'][0]['type']); + $this->assertEquals('call_1', $prepared[0]['tool_calls'][0]['id']); + $this->assertEquals('my_function', $prepared[0]['tool_calls'][0]['function']['name']); + $this->assertEquals( + json_encode(['arg1' => 'value1']), + $prepared[0]['tool_calls'][0]['function']['arguments'] + ); + } + + /** + * Tests prepareMessagesParam() with function response. + * + * @return void + */ + public function testPrepareMessagesParamFunctionResponse(): void + { + $functionResponse = new FunctionResponse( + 'call_1', + 'my_function', + ['result' => 'success'] + ); + $message = new Message( + MessageRoleEnum::user(), + [new MessagePart($functionResponse)] + ); // Changed to user role + $model = $this->createModel(); + + $prepared = $model->exposePrepareMessagesParam([$message]); + + $this->assertCount(1, $prepared); + $this->assertEquals('tool', $prepared[0]['role']); + $this->assertEquals(json_encode(['result' => 'success']), $prepared[0]['content']); + $this->assertEquals('call_1', $prepared[0]['tool_call_id']); + } + + /** + * Tests getMessageRoleString() method. + * + * @dataProvider messageRoleProvider + * @param MessageRoleEnum $role + * @param string $expected + * @return void + */ + public function testGetMessageRoleString(MessageRoleEnum $role, string $expected): void + { + $model = $this->createModel(); + $this->assertEquals($expected, $model->exposeGetMessageRoleString($role)); + } + + /** + * Provides message roles and their expected string representations. + * + * @return array> + */ + public function messageRoleProvider(): array + { + return [ + 'user' => [MessageRoleEnum::user(), 'user'], + 'model' => [MessageRoleEnum::model(), 'assistant'], + ]; + } + + /** + * Tests getMessagePartContentData() with text part. + * + * @return void + */ + public function testGetMessagePartContentDataTextPart(): void + { + $part = new MessagePart('Hello'); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartContentData($part); + + $this->assertEquals(['type' => 'text', 'text' => 'Hello'], $data); + } + + /** + * Tests getMessagePartContentData() with remote image file. + * + * @return void + */ + public function testGetMessagePartContentDataRemoteImageFile(): void + { + $file = new File('https://example.com/image.png', 'image/png'); + $part = new MessagePart($file); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartContentData($part); + + $this->assertEquals(['type' => 'image_url', 'image_url' => ['url' => 'https://example.com/image.png']], $data); + } + + /** + * Tests getMessagePartContentData() with inline image file. + * + * @return void + */ + public function testGetMessagePartContentDataInlineImageFile(): void + { + $base64Image = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII='; + $file = new File( + $base64Image, + 'image/png' + ); + $part = new MessagePart($file); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartContentData($part); + + $this->assertEquals( + [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $base64Image + ] + ], + $data + ); + } + + /** + * Tests getMessagePartContentData() with inline audio file. + * + * @return void + */ + public function testGetMessagePartContentDataInlineAudioFile(): void + { + $file = new File( + 'data:audio/mpeg;base64,SUQzBAAAAAAA', + 'audio/mpeg' + ); + $part = new MessagePart($file); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartContentData($part); + + $this->assertEquals([ + 'type' => 'input_audio', + 'input_audio' => ['data' => 'SUQzBAAAAAAA', 'format' => 'mp3'] + ], $data); + } + + /** + * Tests getMessagePartContentData() with unsupported remote file type. + * + * @return void + */ + public function testGetMessagePartContentDataUnsupportedRemoteFile(): void + { + $file = new File('https://example.com/doc.pdf', 'application/pdf'); + $part = new MessagePart($file); + $model = $this->createModel(); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Unsupported MIME type "application/pdf" for remote file message part.'); + + $model->exposeGetMessagePartContentData($part); + } + + /** + * Tests getMessagePartContentData() with unsupported inline file type. + * + * @return void + */ + public function testGetMessagePartContentDataUnsupportedInlineFile(): void + { + $file = new File('data:text/plain;base64,SGVsbG8=', 'text/plain'); + $part = new MessagePart($file); + $model = $this->createModel(); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Unsupported MIME type "text/plain" for inline file message part.'); + + $model->exposeGetMessagePartContentData($part); + } + + /** + * Tests getMessagePartContentData() with function call part (should return null). + * + * @return void + */ + public function testGetMessagePartContentDataFunctionCallPart(): void + { + $functionCall = new FunctionCall('call_1', 'my_function', []); + $part = new MessagePart($functionCall); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartContentData($part); + + $this->assertNull($data); + } + + /** + * Tests getMessagePartContentData() with function response part (should throw exception). + * + * @return void + */ + public function testGetMessagePartContentDataFunctionResponsePart(): void + { + $functionResponse = new FunctionResponse('call_1', 'my_function', []); + $part = new MessagePart($functionResponse); + $model = $this->createModel(); + + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage( + 'The API only allows a single function response, as the only content of the message.' + ); + + $model->exposeGetMessagePartContentData($part); + } + + /** + * Tests getMessagePartToolCallData() with function call part. + * + * @return void + */ + public function testGetMessagePartToolCallDataFunctionCallPart(): void + { + $functionCall = new FunctionCall( + 'call_1', + 'my_function', + ['arg1' => 'value1'] + ); + $part = new MessagePart($functionCall); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartToolCallData($part); + + $this->assertEquals([ + 'type' => 'function', + 'id' => 'call_1', + 'function' => [ + 'name' => 'my_function', + 'arguments' => json_encode(['arg1' => 'value1']), + ], + ], $data); + } + + /** + * Tests getMessagePartToolCallData() with text part (should return null). + * + * @return void + */ + public function testGetMessagePartToolCallDataTextPart(): void + { + $part = new MessagePart('Hello'); + $model = $this->createModel(); + $data = $model->exposeGetMessagePartToolCallData($part); + + $this->assertNull($data); + } + + /** + * Tests validateOutputModalities() with text modality. + * + * @return void + */ + public function testValidateOutputModalitiesWithText(): void + { + $model = $this->createModel(); + $model->exposeValidateOutputModalities([ModalityEnum::text()]); + $this->assertTrue(true); // No exception means success. + } + + /** + * Tests validateOutputModalities() with multiple modalities including text. + * + * @return void + */ + public function testValidateOutputModalitiesWithMultipleIncludingText(): void + { + $model = $this->createModel(); + $model->exposeValidateOutputModalities([ModalityEnum::text(), ModalityEnum::image()]); + $this->assertTrue(true); // No exception means success. + } + + /** + * Tests validateOutputModalities() with no modalities. + * + * @return void + */ + public function testValidateOutputModalitiesWithNoModalities(): void + { + $model = $this->createModel(); + $model->exposeValidateOutputModalities([]); + $this->assertTrue(true); // No exception means success. + } + + /** + * Tests validateOutputModalities() without text modality. + * + * @return void + */ + public function testValidateOutputModalitiesWithoutText(): void + { + $model = $this->createModel(); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('A text output modality must be present when generating text.'); + $model->exposeValidateOutputModalities([ModalityEnum::image()]); + } + + /** + * Tests prepareOutputModalitiesParam() method. + * + * @dataProvider outputModalitiesProvider + * @param array $modalities + * @param array $expected + * @return void + */ + public function testPrepareOutputModalitiesParam( + array $modalities, + array $expected + ): void { + $model = $this->createModel(); + $this->assertEquals($expected, $model->exposePrepareOutputModalitiesParam($modalities)); + } + + /** + * Provides output modalities and their expected API parameter representations. + * + * @return array> + */ + public function outputModalitiesProvider(): array + { + return [ + 'text only' => [ + [ModalityEnum::text()], ['text'] + ], + 'image only' => [ + [ModalityEnum::image()], ['image'] + ], + 'audio only' => [ + [ModalityEnum::audio()], ['audio'] + ], + 'text and image' => [ + [ModalityEnum::text(), ModalityEnum::image()], ['text', 'image'] + ], + 'all modalities' => [ + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], ['text', 'image', 'audio'] + ], + ]; + } + + + /** + * Tests prepareToolsParam() method. + * + * @return void + */ + public function testPrepareToolsParam(): void + { + $functionDeclaration1 = new FunctionDeclaration('func1', 'Description 1', ['type' => 'object']); + $functionDeclaration2 = new FunctionDeclaration('func2', 'Description 2', ['type' => 'object']); + $functionDeclarations = [$functionDeclaration1, $functionDeclaration2]; + $model = $this->createModel(); + + $prepared = $model->exposePrepareToolsParam($functionDeclarations); + + $this->assertCount(2, $prepared); + $this->assertEquals('function', $prepared[0]['type']); + $this->assertEquals($functionDeclaration1->toArray(), $prepared[0]['function']); + $this->assertEquals('function', $prepared[1]['type']); + $this->assertEquals($functionDeclaration2->toArray(), $prepared[1]['function']); + } + + /** + * Tests prepareResponseFormatParam() with null schema. + * + * @return void + */ + public function testPrepareResponseFormatParamNullSchema(): void + { + $model = $this->createModel(); + $format = $model->exposePrepareResponseFormatParam(null); + + $this->assertEquals(['type' => 'json_object'], $format); + } + + /** + * Tests prepareResponseFormatParam() with schema. + * + * @return void + */ + public function testPrepareResponseFormatParamWithSchema(): void + { + $schema = ['type' => 'object', 'properties' => ['key' => ['type' => 'string']]]; + $model = $this->createModel(); + $format = $model->exposePrepareResponseFormatParam($schema); + + $this->assertEquals(['type' => 'json_schema', 'json_schema' => $schema], $format); + } + + /** + * Tests parseResponseToGenerativeAiResult() with valid response. + * + * @return void + */ + public function testParseResponseToGenerativeAiResultValidResponse(): void + { + $response = new Response( + 200, + [], + json_encode([ + 'id' => 'test-id', + 'choices' => [ + [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Test content', + ], + 'finish_reason' => 'stop', + ], + ], + 'usage' => [ + 'prompt_tokens' => 10, + 'completion_tokens' => 20, + 'total_tokens' => 30, + ], + 'model' => 'test-model', + ]) + ); + $model = $this->createModel(); + $result = $model->parseResponseToGenerativeAiResult($response); + + $this->assertInstanceOf(GenerativeAiResult::class, $result); + $this->assertEquals('test-id', $result->getId()); + $this->assertCount(1, $result->getCandidates()); + $this->assertEquals('Test content', $result->getCandidates()[0]->getMessage()->getParts()[0]->getText()); + $this->assertEquals(FinishReasonEnum::stop(), $result->getCandidates()[0]->getFinishReason()); + $this->assertEquals(10, $result->getTokenUsage()->getPromptTokens()); + $this->assertEquals(20, $result->getTokenUsage()->getCompletionTokens()); + $this->assertEquals(30, $result->getTokenUsage()->getTotalTokens()); + $this->assertEquals(['model' => 'test-model'], $result->getAdditionalData()); + } + + /** + * Tests parseResponseToGenerativeAiResult() with missing choices. + * + * @return void + */ + public function testParseResponseToGenerativeAiResultMissingChoices(): void + { + $response = new Response(200, [], json_encode(['id' => 'test-id'])); + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Unexpected API response: Missing the choices key.'); + + $model->parseResponseToGenerativeAiResult($response); + } + + /** + * Tests parseResponseToGenerativeAiResult() with invalid choices type. + * + * @return void + */ + public function testParseResponseToGenerativeAiResultInvalidChoicesType(): void + { + $response = new Response( + 200, + [], + json_encode(['choices' => 'invalid']) + ); + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Unexpected API response: The choices key must contain an array.'); + + $model->parseResponseToGenerativeAiResult($response); + } + + /** + * Tests parseResponseToGenerativeAiResult() with invalid choice element type. + * + * @return void + */ + public function testParseResponseToGenerativeAiResultInvalidChoiceElementType(): void + { + $response = new Response(200, [], json_encode(['choices' => ['invalid']])); + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage( + 'Unexpected API response: Each element in the choices key must be an associative array.' + ); + + $model->parseResponseToGenerativeAiResult($response); + } + + /** + * Tests parseResponseChoiceToCandidate() with valid data. + * + * @return void + */ + public function testParseResponseChoiceToCandidateValidData(): void + { + $choiceData = [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Hello from AI', + ], + 'finish_reason' => 'stop', + ]; + $model = $this->createModel(); + $candidate = $model->exposeParseResponseChoiceToCandidate($choiceData); + + $this->assertInstanceOf(Candidate::class, $candidate); + $this->assertEquals('Hello from AI', $candidate->getMessage()->getParts()[0]->getText()); + $this->assertEquals(FinishReasonEnum::stop(), $candidate->getFinishReason()); + } + + /** + * Tests parseResponseChoiceToCandidate() with missing message. + * + * @return void + */ + public function testParseResponseChoiceToCandidateMissingMessage(): void + { + $choiceData = [ + 'finish_reason' => 'stop', + ]; + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage( + 'Unexpected API response: Each choice must contain a message key with an associative array.' + ); + + $model->exposeParseResponseChoiceToCandidate($choiceData); + } + + /** + * Tests parseResponseChoiceToCandidate() with invalid message type. + * + * @return void + */ + public function testParseResponseChoiceToCandidateInvalidMessageType(): void + { + $choiceData = [ + 'message' => 'invalid', + 'finish_reason' => 'stop', + ]; + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage( + 'Unexpected API response: Each choice must contain a message key with an associative array.' + ); + + $model->exposeParseResponseChoiceToCandidate($choiceData); + } + + /** + * Tests parseResponseChoiceToCandidate() with missing finish reason. + * + * @return void + */ + public function testParseResponseChoiceToCandidateMissingFinishReason(): void + { + $choiceData = [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Hello from AI', + ], + ]; + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage( + 'Unexpected API response: Each choice must contain a finish_reason key with a string value.' + ); + + $model->exposeParseResponseChoiceToCandidate($choiceData); + } + + /** + * Tests parseResponseChoiceToCandidate() with invalid finish reason. + * + * @return void + */ + public function testParseResponseChoiceToCandidateInvalidFinishReason(): void + { + $choiceData = [ + 'message' => [ + 'role' => 'assistant', + 'content' => 'Hello from AI', + ], + 'finish_reason' => 'unknown', + ]; + $model = $this->createModel(); + + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Unexpected API response: Invalid finish reason "unknown".'); + + $model->exposeParseResponseChoiceToCandidate($choiceData); + } + + /** + * Tests parseResponseChoiceMessage() with assistant message. + * + * @return void + */ + public function testParseResponseChoiceMessageAssistant(): void + { + $messageData = [ + 'role' => 'assistant', + 'content' => 'Assistant response', + ]; + $model = $this->createModel(); + $message = $model->exposeParseResponseChoiceMessage($messageData); + + $this->assertEquals(MessageRoleEnum::model(), $message->getRole()); + $this->assertCount(1, $message->getParts()); + $this->assertEquals('Assistant response', $message->getParts()[0]->getText()); + } + + /** + * Tests parseResponseChoiceMessage() with user message. + * + * @return void + */ + public function testParseResponseChoiceMessageUser(): void + { + $messageData = [ + 'role' => 'user', + 'content' => 'User response', + ]; + $model = $this->createModel(); + $message = $model->exposeParseResponseChoiceMessage($messageData); + + $this->assertEquals(MessageRoleEnum::user(), $message->getRole()); + $this->assertCount(1, $message->getParts()); + $this->assertEquals('User response', $message->getParts()[0]->getText()); + } + + /** + * Tests parseResponseChoiceMessageParts() with content and reasoning. + * + * @return void + */ + public function testParseResponseChoiceMessagePartsContentAndReasoning(): void + { + $messageData = [ + 'reasoning_content' => 'Thinking process', + 'content' => 'Final answer', + ]; + $model = $this->createModel(); + $parts = $model->exposeParseResponseChoiceMessageParts($messageData); + + $this->assertCount(2, $parts); + $this->assertEquals('Thinking process', $parts[0]->getText()); + $this->assertEquals(MessagePartChannelEnum::thought(), $parts[0]->getChannel()); + $this->assertEquals('Final answer', $parts[1]->getText()); + $this->assertEquals(MessagePartChannelEnum::content(), $parts[1]->getChannel()); + } + + /** + * Tests parseResponseChoiceMessageParts() with tool calls. + * + * @return void + */ + public function testParseResponseChoiceMessagePartsToolCalls(): void + { + $messageData = [ + 'tool_calls' => [ + [ + 'id' => 'call_1', + 'type' => 'function', + 'function' => [ + 'name' => 'my_function', + 'arguments' => '{"param":"value"}', + ], + ], + ], + ]; + $model = $this->createModel(); + $parts = $model->exposeParseResponseChoiceMessageParts($messageData); + + $this->assertCount(1, $parts); + $this->assertInstanceOf(FunctionCall::class, $parts[0]->getFunctionCall()); + $this->assertEquals('call_1', $parts[0]->getFunctionCall()->getId()); + } + + /** + * Tests parseResponseChoiceMessageToolCallPart() with valid function call. + * + * @return void + */ + public function testParseResponseChoiceMessageToolCallPartValidFunctionCall(): void + { + $toolCallData = [ + 'id' => 'call_123', + 'type' => 'function', + 'function' => [ + 'name' => 'test_function', + 'arguments' => '{"key":"value"}', + ], + ]; + $model = $this->createModel(); + $part = $model->exposeParseResponseChoiceMessageToolCallPart($toolCallData); + + $this->assertInstanceOf(MessagePart::class, $part); + $this->assertInstanceOf(FunctionCall::class, $part->getFunctionCall()); + $this->assertEquals('call_123', $part->getFunctionCall()->getId()); + $this->assertEquals('test_function', $part->getFunctionCall()->getName()); + $this->assertEquals(['key' => 'value'], $part->getFunctionCall()->getArgs()); + } + + /** + * Tests parseResponseChoiceMessageToolCallPart() with missing function data. + * + * @return void + */ + public function testParseResponseChoiceMessageToolCallPartMissingFunctionData(): void + { + $toolCallData = [ + 'id' => 'call_123', + 'type' => 'function', + ]; + $model = $this->createModel(); + $part = $model->exposeParseResponseChoiceMessageToolCallPart($toolCallData); + + $this->assertNull($part); + } + + /** + * Tests parseResponseChoiceMessageToolCallPart() with non-function type. + * + * @return void + */ + public function testParseResponseChoiceMessageToolCallPartNonFunctionType(): void + { + $toolCallData = [ + 'id' => 'call_123', + 'type' => 'unknown', + 'function' => [ + 'name' => 'test_function', + 'arguments' => '{"key":"value"}', + ], + ]; + $model = $this->createModel(); + $part = $model->exposeParseResponseChoiceMessageToolCallPart($toolCallData); + + $this->assertNull($part); + } +} diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleModelMetadataDirectory.php b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleModelMetadataDirectory.php new file mode 100644 index 00000000..97f36787 --- /dev/null +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleModelMetadataDirectory.php @@ -0,0 +1,106 @@ + + */ + private array $mockModels; + + /** + * @var callable + */ + private $modelMetadataStubFactory; + + /** + * Constructor. + * + * @param HttpTransporterInterface&\PHPUnit\Framework\MockObject\MockObject $mockHttpTransporter + * @param RequestAuthenticationInterface&\PHPUnit\Framework\MockObject\MockObject $mockRequestAuthentication + * @param callable $modelMetadataStubFactory + * @param array $mockModels + */ + public function __construct( + $mockHttpTransporter, + $mockRequestAuthentication, + callable $modelMetadataStubFactory, + array $mockModels = [] + ) { + $this->mockHttpTransporter = $mockHttpTransporter; + $this->mockRequestAuthentication = $mockRequestAuthentication; + $this->modelMetadataStubFactory = $modelMetadataStubFactory; + $this->mockModels = $mockModels; + } + + /** + * @inheritdoc + */ + public function getHttpTransporter(): HttpTransporterInterface + { + return $this->mockHttpTransporter; + } + + /** + * @inheritdoc + */ + public function getRequestAuthentication(): RequestAuthenticationInterface + { + return $this->mockRequestAuthentication; + } + + /** + * @inheritdoc + */ + protected function createRequest( + HttpMethodEnum $method, + string $path, + array $headers = [], + $data = null + ): Request { + return new Request($method, 'https://example.com/' . $path, $headers, $data); + } + + /** + * @inheritdoc + */ + protected function parseResponseToModelMetadataList(Response $response): array + { + $data = $response->getData(); + $modelsMetadata = []; + if (isset($data['data']) && is_array($data['data'])) { + foreach ($data['data'] as $modelData) { + if (isset($modelData['id']) && is_string($modelData['id'])) { + $factory = $this->modelMetadataStubFactory; + $modelMetadata = $factory($modelData['id']); + $modelsMetadata[] = $modelMetadata; + } + } + } + return $modelsMetadata; + } +} diff --git a/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php new file mode 100644 index 00000000..0b7e8549 --- /dev/null +++ b/tests/unit/Providers/OpenAiCompatibleImplementation/MockOpenAiCompatibleTextGenerationModel.php @@ -0,0 +1,180 @@ +mockHttpTransporter = $mockHttpTransporter; + $this->mockRequestAuthentication = $mockRequestAuthentication; + } + + /** + * @inheritdoc + */ + public function getHttpTransporter(): HttpTransporterInterface + { + return $this->mockHttpTransporter; + } + + /** + * @inheritdoc + */ + public function getRequestAuthentication(): RequestAuthenticationInterface + { + return $this->mockRequestAuthentication; + } + + /** + * @inheritdoc + */ + protected function createRequest( + HttpMethodEnum $method, + string $path, + array $headers = [], + $data = null + ): Request { + return new Request($method, 'https://example.com/' . $path, $headers, $data); + } + + /** + * Sets a mock generative AI result to be returned by parseResponseToGenerativeAiResult. + * + * @param GenerativeAiResult $result + */ + public function setMockGenerativeAiResult(GenerativeAiResult $result): void + { + $this->mockGenerativeAiResult = $result; + } + + /** + * @inheritdoc + */ + public function parseResponseToGenerativeAiResult(Response $response): GenerativeAiResult + { + if ($this->mockGenerativeAiResult) { + return $this->mockGenerativeAiResult; + } + // Fallback to parent if no mock is set, or implement a basic parsing for testing. + return parent::parseResponseToGenerativeAiResult($response); + } + + // Expose protected methods for testing. + public function exposePrepareGenerateTextParams(array $prompt): array + { + return $this->prepareGenerateTextParams($prompt); + } + + public function exposeMergeSystemInstruction(array $prompt, string $systemInstruction): array + { + return $this->mergeSystemInstruction($prompt, $systemInstruction); + } + + public function exposePrepareMessagesParam(array $messages): array + { + return $this->prepareMessagesParam($messages); + } + + public function exposeGetMessageRoleString(MessageRoleEnum $role): string + { + return $this->getMessageRoleString($role); + } + + public function exposeGetMessagePartContentData(MessagePart $part): ?array + { + return $this->getMessagePartContentData($part); + } + + public function exposeGetMessagePartToolCallData(MessagePart $part): ?array + { + return $this->getMessagePartToolCallData($part); + } + + public function exposeValidateOutputModalities(array $outputModalities): void + { + $this->validateOutputModalities($outputModalities); + } + + public function exposePrepareOutputModalitiesParam(array $modalities): array + { + return $this->prepareOutputModalitiesParam($modalities); + } + + public function exposePrepareToolsParam(array $functionDeclarations): array + { + return $this->prepareToolsParam($functionDeclarations); + } + + public function exposePrepareResponseFormatParam(?array $outputSchema): array + { + return $this->prepareResponseFormatParam($outputSchema); + } + + public function exposeParseResponseChoiceToCandidate(array $choiceData): Candidate + { + return $this->parseResponseChoiceToCandidate($choiceData); + } + + public function exposeParseResponseChoiceMessage(array $messageData): Message + { + return $this->parseResponseChoiceMessage($messageData); + } + + public function exposeParseResponseChoiceMessageParts(array $messageData): array + { + return $this->parseResponseChoiceMessageParts($messageData); + } + + public function exposeParseResponseChoiceMessageToolCallPart(array $toolCallData): ?MessagePart + { + return $this->parseResponseChoiceMessageToolCallPart($toolCallData); + } +} diff --git a/tests/unit/Providers/ProviderRegistryTest.php b/tests/unit/Providers/ProviderRegistryTest.php index 832980d4..91126997 100644 --- a/tests/unit/Providers/ProviderRegistryTest.php +++ b/tests/unit/Providers/ProviderRegistryTest.php @@ -6,10 +6,18 @@ use InvalidArgumentException; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Providers\Http\DTO\ApiKeyRequestAuthentication; +use WordPress\AiClient\Providers\Models\DTO\ModelConfig; +use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelRequirements; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; use WordPress\AiClient\Providers\ProviderRegistry; +use WordPress\AiClient\Tests\mocks\MockHttpTransporter; +use WordPress\AiClient\Tests\mocks\MockModel; +use WordPress\AiClient\Tests\mocks\MockModelMetadataDirectory; use WordPress\AiClient\Tests\mocks\MockProvider; +use WordPress\AiClient\Tests\mocks\MockProviderAvailability; +use WordPress\AiClient\Tests\mocks\MockRequestAuthentication; /** * @covers \WordPress\AiClient\Providers\ProviderRegistry @@ -20,7 +28,15 @@ class ProviderRegistryTest extends TestCase protected function setUp(): void { + parent::setUp(); $this->registry = new ProviderRegistry(); + MockProvider::reset(); // Reset static state of mock provider before each test. + } + + protected function tearDown(): void + { + MockProvider::reset(); // Reset static state of mock provider after each test. + parent::tearDown(); } /** @@ -173,7 +189,7 @@ public function testGetProviderModelThrowsException(): void $this->expectException(InvalidArgumentException::class); $this->expectExceptionMessage('Model not found: test-model'); - $modelConfig = new \WordPress\AiClient\Providers\Models\DTO\ModelConfig([]); + $modelConfig = new ModelConfig([]); $this->registry->getProviderModel('mock', 'test-model', $modelConfig); } @@ -209,4 +225,336 @@ public function testProviderInstanceCaching(): void // Should not throw any errors and should reuse cached instance $this->addToAssertionCount(1); } + + /** + * Tests that setHttpTransporter hooks up the transporter to registered providers. + * + * @return void + */ + public function testSetHttpTransporterHooksUpToProviders(): void + { + $mockTransporter = new MockHttpTransporter(); + $mockAvailability = new MockProviderAvailability(); + $mockModelMetadataDirectory = new MockModelMetadataDirectory([ + 'mock-text-model' => new ModelMetadata( + 'mock-text-model', + 'Mock Text Model', + [CapabilityEnum::textGeneration()], + [] + ) + ]); + $mockModel = new MockModel( + new ModelMetadata('mock-model', 'Mock Model', [], []), + new ModelConfig([]) + ); + + MockProvider::setAvailability($mockAvailability); + MockProvider::setModelMetadataDirectory($mockModelMetadataDirectory); + + // Register the provider AFTER setting up mocks, so it uses these mocks. + $this->registry->registerProvider(MockProvider::class); + + // Set the transporter on the registry. + $this->registry->setHttpTransporter($mockTransporter); + + // Get a model instance from the provider. + $modelConfig = new ModelConfig([]); + $retrievedModel = $this->registry->getProviderModel('mock', 'mock-text-model', $modelConfig); + + // Verify that the transporter was set on the relevant instances. + $this->assertSame($mockTransporter, $mockAvailability->getHttpTransporter()); + $this->assertSame($mockTransporter, $mockModelMetadataDirectory->getHttpTransporter()); + $this->assertSame($mockTransporter, $retrievedModel->getHttpTransporter()); + } + + /** + * Tests that setProviderRequestAuthentication hooks up the authentication to registered providers. + * + * @return void + */ + public function testSetProviderRequestAuthenticationHooksUpToProviders(): void + { + $mockTransporter = new MockHttpTransporter(); // Add this line + $this->registry->setHttpTransporter($mockTransporter); // Add this line + + $mockAuth = new MockRequestAuthentication('custom_token'); + $mockAvailability = new MockProviderAvailability(); + $mockModelMetadataDirectory = new MockModelMetadataDirectory([ + 'mock-text-model' => new ModelMetadata( + 'mock-text-model', + 'Mock Text Model', + [CapabilityEnum::textGeneration()], + [] + ) + ]); + $mockModel = new MockModel( + new ModelMetadata('mock-model', 'Mock Model', [], []), + new ModelConfig([]) + ); + + MockProvider::setAvailability($mockAvailability); + MockProvider::setModelMetadataDirectory($mockModelMetadataDirectory); + + // Register the provider AFTER setting up mocks, so it uses these mocks. + $this->registry->registerProvider(MockProvider::class); + + // Set the authentication on the specific provider. + $this->registry->setProviderRequestAuthentication('mock', $mockAuth); + + // Get a model instance from the provider. + $modelConfig = new ModelConfig([]); + $retrievedModel = $this->registry->getProviderModel('mock', 'mock-text-model', $modelConfig); + + // Verify that the authentication was set on the relevant instances. + $this->assertSame($mockAuth, $mockAvailability->getRequestAuthentication()); + $this->assertSame($mockAuth, $mockModelMetadataDirectory->getRequestAuthentication()); + $this->assertSame($mockAuth, $retrievedModel->getRequestAuthentication()); + } + + /** + * Tests that getProviderRequestAuthentication returns the correct instance. + * + * @return void + */ + public function testGetProviderRequestAuthentication(): void + { + $this->registry->registerProvider(MockProvider::class); + $mockAuth = new MockRequestAuthentication('another_token'); + $this->registry->setProviderRequestAuthentication('mock', $mockAuth); + + $retrievedAuth = $this->registry->getProviderRequestAuthentication('mock'); + $this->assertSame($mockAuth, $retrievedAuth); + } + + /** + * Tests that getProviderRequestAuthentication returns a default instance if not explicitly set. + * + * @return void + */ + public function testGetProviderRequestAuthenticationReturnsDefault(): void + { + $this->registry->registerProvider(MockProvider::class); + $retrievedAuth = $this->registry->getProviderRequestAuthentication('mock'); + + // By default, it should create an ApiKeyRequestAuthentication if environment variables are set. + // Since no env vars are set in tests, it should fall back to null. + $this->assertNull($retrievedAuth); + } + + /** + * Tests the internal getEnvVarName method using reflection. + * + * @dataProvider envVarNameProvider + * @param string $providerId The provider ID. + * @param string $field The field name. + * @param string $expected The expected environment variable name. + * @return void + */ + public function testGetEnvVarName(string $providerId, string $field, string $expected): void + { + $method = new \ReflectionMethod(ProviderRegistry::class, 'getEnvVarName'); + $method->setAccessible(true); + + $result = $method->invoke($this->registry, $providerId, $field); // Invoke on instance + + $this->assertEquals($expected, $result); + } + + /** + * Provides data for testing getEnvVarName. + * + * @return array + */ + public function envVarNameProvider(): array + { + return [ + 'camelCase provider and field' => ['myProvider', 'apiKey', 'MY_PROVIDER_API_KEY'], + 'kebab-case provider and field' => ['my-provider', 'api-key', 'MY_PROVIDER_API_KEY'], + 'snake_case provider and field' => ['my_provider', 'api_key', 'MY_PROVIDER_API_KEY'], + 'mixed case' => ['AnotherProvider', 'someOtherField', 'ANOTHER_PROVIDER_SOME_OTHER_FIELD'], + 'simple names' => ['openai', 'key', 'OPENAI_KEY'], + ]; + } + + /** + * Tests that createDefaultProviderRequestAuthentication creates ApiKeyRequestAuthentication when env var is set. + * + * @return void + */ + public function testCreateDefaultProviderRequestAuthenticationWithEnvVar(): void + { + // Temporarily set an environment variable. + putenv('MOCK_API_KEY=test_env_api_key'); + + $this->registry->registerProvider(MockProvider::class); + + $method = new \ReflectionMethod(ProviderRegistry::class, 'createDefaultProviderRequestAuthentication'); + $method->setAccessible(true); + + $auth = $method->invoke($this->registry, MockProvider::class); + + $this->assertInstanceOf(ApiKeyRequestAuthentication::class, $auth); + $this->assertEquals('test_env_api_key', $auth->getApiKey()); + + // Clean up environment variable. + putenv('MOCK_API_KEY'); + } + + /** + * Tests that createDefaultProviderRequestAuthentication returns null when env var is not set. + * + * @return void + */ + public function testCreateDefaultProviderRequestAuthenticationWithoutEnvVar(): void + { + // Ensure environment variable is not set. + putenv('MOCK_API_KEY'); + + $this->registry->registerProvider(MockProvider::class); + + $method = new \ReflectionMethod(ProviderRegistry::class, 'createDefaultProviderRequestAuthentication'); + $method->setAccessible(true); + + $auth = $method->invoke($this->registry, MockProvider::class); + + $this->assertNull($auth); + } + + /** + * Tests bindModelDependencies with HTTP transporter. + * + * @return void + */ + public function testBindModelDependenciesWithHttpTransporter(): void + { + // Register provider and set HTTP transporter + $this->registry->registerProvider(MockProvider::class); + $httpTransporter = new MockHttpTransporter(); + $this->registry->setHttpTransporter($httpTransporter); + + // Create a mock model + $modelMetadata = new ModelMetadata( + 'test-model', + 'Test Model', + [CapabilityEnum::textGeneration()], + [] + ); + $modelConfig = new ModelConfig(); + + // Create a mock model instance that implements WithHttpTransporterInterface + $modelInstance = $this->createMock(MockModel::class); + $modelInstance->expects($this->once()) + ->method('providerMetadata') + ->willReturn(MockProvider::metadata()); + + $modelInstance->expects($this->once()) + ->method('setHttpTransporter') + ->with($httpTransporter); + + // Call bindModelDependencies + $this->registry->bindModelDependencies($modelInstance); + } + + /** + * Tests bindModelDependencies with request authentication. + * + * @return void + */ + public function testBindModelDependenciesWithRequestAuthentication(): void + { + // Register provider and set authentication + $this->registry->registerProvider(MockProvider::class); + $authentication = new MockRequestAuthentication('test-api-key'); + $this->registry->setProviderRequestAuthentication('mock', $authentication); + + // Set HTTP transporter (required by registry) + $httpTransporter = new MockHttpTransporter(); + $this->registry->setHttpTransporter($httpTransporter); + + // Create a mock model instance that implements WithRequestAuthenticationInterface + $modelInstance = $this->createMock(MockModel::class); + $modelInstance->expects($this->once()) + ->method('providerMetadata') + ->willReturn(MockProvider::metadata()); + + $modelInstance->expects($this->once()) + ->method('setHttpTransporter') + ->with($httpTransporter); + + $modelInstance->expects($this->once()) + ->method('setRequestAuthentication') + ->with($authentication); + + // Call bindModelDependencies + $this->registry->bindModelDependencies($modelInstance); + } + + /** + * Tests bindModelDependencies with model that doesn't need dependencies. + * + * @return void + */ + public function testBindModelDependenciesWithSimpleModel(): void + { + // Register provider + $this->registry->registerProvider(MockProvider::class); + + // Create a mock model that doesn't implement dependency interfaces + $modelInstance = $this->createMock(\WordPress\AiClient\Providers\Models\Contracts\ModelInterface::class); + $modelInstance->expects($this->once()) + ->method('providerMetadata') + ->willReturn(MockProvider::metadata()); + + // Call bindModelDependencies - should not throw any errors + $this->registry->bindModelDependencies($modelInstance); + + // Test passes if no exceptions are thrown + $this->assertTrue(true); + } + + /** + * Tests bindModelDependencies with unregistered provider. + * + * @return void + */ + public function testBindModelDependenciesWithUnregisteredProvider(): void + { + // Create a mock model with a provider that's not registered + $providerMetadata = $this->createMock(\WordPress\AiClient\Providers\DTO\ProviderMetadata::class); + $providerMetadata->method('getId')->willReturn('unregistered-provider'); + + $modelInstance = $this->createMock(\WordPress\AiClient\Providers\Models\Contracts\ModelInterface::class); + $modelInstance->expects($this->once()) + ->method('providerMetadata') + ->willReturn($providerMetadata); + + // Expect exception when trying to bind dependencies for unregistered provider + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Provider not registered: unregistered-provider'); + + $this->registry->bindModelDependencies($modelInstance); + } + + /** + * Tests bindModelDependencies without HTTP transporter when model needs it. + * + * @return void + */ + public function testBindModelDependenciesWithoutHttpTransporter(): void + { + // Register provider but don't set HTTP transporter + $this->registry->registerProvider(MockProvider::class); + + // Create a mock model instance that implements WithHttpTransporterInterface + $modelInstance = $this->createMock(MockModel::class); + $modelInstance->expects($this->once()) + ->method('providerMetadata') + ->willReturn(MockProvider::metadata()); + + // Expect runtime exception when trying to get HTTP transporter that isn't set + $this->expectException(\RuntimeException::class); + $this->expectExceptionMessage('HttpTransporterInterface instance not set'); + + $this->registry->bindModelDependencies($modelInstance); + } } diff --git a/tests/unit/Results/DTO/GenerativeAiResultTest.php b/tests/unit/Results/DTO/GenerativeAiResultTest.php index ddde10e3..f4de622b 100644 --- a/tests/unit/Results/DTO/GenerativeAiResultTest.php +++ b/tests/unit/Results/DTO/GenerativeAiResultTest.php @@ -17,6 +17,7 @@ use WordPress\AiClient\Providers\Enums\ProviderTypeEnum; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; use WordPress\AiClient\Providers\Models\Enums\CapabilityEnum; +use WordPress\AiClient\Results\Contracts\ResultInterface; use WordPress\AiClient\Results\DTO\Candidate; use WordPress\AiClient\Results\DTO\GenerativeAiResult; use WordPress\AiClient\Results\DTO\TokenUsage; @@ -659,7 +660,7 @@ public function testImplementsResultInterface(): void ); $this->assertInstanceOf( - \WordPress\AiClient\Results\Contracts\ResultInterface::class, + ResultInterface::class, $result ); } diff --git a/tests/unit/Results/DTO/TokenUsageTest.php b/tests/unit/Results/DTO/TokenUsageTest.php index d65f8974..8cff5033 100644 --- a/tests/unit/Results/DTO/TokenUsageTest.php +++ b/tests/unit/Results/DTO/TokenUsageTest.php @@ -5,6 +5,8 @@ namespace WordPress\AiClient\Tests\unit\Results\DTO; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Results\DTO\TokenUsage; /** @@ -184,7 +186,7 @@ public function testImplementsWithJsonSchemaInterface(): void $tokenUsage = new TokenUsage(10, 20, 30); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $tokenUsage ); } @@ -278,7 +280,7 @@ public function testImplementsWithArrayTransformationInterface(): void $tokenUsage = new TokenUsage(10, 20, 30); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithArrayTransformationInterface::class, + WithArrayTransformationInterface::class, $tokenUsage ); } diff --git a/tests/unit/Tools/DTO/WebSearchTest.php b/tests/unit/Tools/DTO/WebSearchTest.php index 93dc6962..465478c9 100644 --- a/tests/unit/Tools/DTO/WebSearchTest.php +++ b/tests/unit/Tools/DTO/WebSearchTest.php @@ -5,6 +5,7 @@ namespace WordPress\AiClient\Tests\unit\Tools\DTO; use PHPUnit\Framework\TestCase; +use WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface; use WordPress\AiClient\Tests\traits\ArrayTransformationTestTrait; use WordPress\AiClient\Tools\DTO\WebSearch; @@ -222,7 +223,7 @@ public function testImplementsWithJsonSchemaInterface(): void $webSearch = new WebSearch(); $this->assertInstanceOf( - \WordPress\AiClient\Common\Contracts\WithJsonSchemaInterface::class, + WithJsonSchemaInterface::class, $webSearch ); }