Skip to content

Commit

Permalink
MBS-9476: Implement connector for Google Imagen
Browse files Browse the repository at this point in the history
  • Loading branch information
PhMemmel committed Dec 2, 2024
1 parent 70327d4 commit 0db9f97
Show file tree
Hide file tree
Showing 15 changed files with 542 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Currently supported AI tools:
- OpenAI Dall-E (also via Azure)
- Google Gemini
- Google Synthesize (text to speech)
- Google Imagen 3 (via Vertex AI)
- Ollama

Currently available AI purposes:
Expand Down
10 changes: 6 additions & 4 deletions classes/base_connector.php
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ public function get_available_options(): array {
*/
public function make_request(array $data): request_response {
$client = new http_client([
'timeout' => get_config('local_ai_manager', 'requesttimeout'),
'verify' => !empty(get_config('local_ai_manager', 'verifyssl')),
'timeout' => get_config('local_ai_manager', 'requesttimeout'),
'verify' => !empty(get_config('local_ai_manager', 'verifyssl')),
]);

$options['headers'] = $this->get_headers();
Expand All @@ -177,7 +177,8 @@ public function make_request(array $data): request_response {
$return = request_response::create_from_error(
$response->getStatusCode(),
get_string('error_sendingrequestfailed', 'local_ai_manager'),
$response->getBody(),
$response->getBody()->getContents(),
$response->getBody()
);
}
return $return;
Expand Down Expand Up @@ -231,7 +232,8 @@ final protected function create_error_response_from_exception(ClientExceptionInt
if (method_exists($exception, 'getResponse') && !empty($exception->getResponse())) {
$debuginfo .= $exception->getResponse()->getBody()->getContents();
}
return request_response::create_from_error($exception->getCode(), $message, $debuginfo);
return request_response::create_from_error($exception->getCode(), $message, $debuginfo,
$exception->getResponse()->getBody());
}

/**
Expand Down
8 changes: 6 additions & 2 deletions classes/base_instance.php
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,9 @@ final public function edit_form_definition(\MoodleQuickForm $mform, array $custo
*/
final public function store_formdata(stdClass $data): void {
$this->set_name(trim($data->name));
$this->set_endpoint(trim($data->endpoint));
if (!empty($data->endpoint)) {
$this->set_endpoint(trim($data->endpoint));
}
$this->set_apikey(trim($data->apikey));
$this->set_connector($data->connector);
$this->set_tenant(trim($data->tenant));
Expand Down Expand Up @@ -555,7 +557,9 @@ final public function validation(array $data, array $files): array {
if (empty($data['name'])) {
$errors['name'] = get_string('formvalidation_editinstance_name', 'local_ai_manager');
}
if (str_starts_with($data['endpoint'], 'http://') && !str_starts_with($data['endpoint'], 'https://')) {
if (!empty($data['endpoint'])
&& str_starts_with($data['endpoint'], 'http://')
&& !str_starts_with($data['endpoint'], 'https://')) {
$errors['endpoint'] = get_string('formvalidation_editinstance_endpointnossl', 'local_ai_manager');
}
return $errors + $this->extend_validation($data, $files);
Expand Down
2 changes: 1 addition & 1 deletion classes/local/prompt_response.php
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class prompt_response {
private int $code;

/** @var string If there has been an error, this variable contains the error message */
private string $errormessage;
private string $errormessage = '';

/** @var string If there has been an error, this variable contains additional debugging information */
private string $debuginfo;
Expand Down
7 changes: 6 additions & 1 deletion classes/local/request_response.php
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,18 @@ public function get_code(): int {
* @param int $code the status code
* @param string $errormessage the error message
* @param string $debuginfo the debug info
* @param ?StreamInterface $rawresponse the raw response object, or null if not available
* @return request_response the request_response object containing all information about the error
*/
public static function create_from_error(int $code, string $errormessage, string $debuginfo): request_response {
public static function create_from_error(int $code, string $errormessage, string $debuginfo,
?StreamInterface $rawresponse = null): request_response {
$requestresponse = new self();
$requestresponse->set_code($code);
$requestresponse->set_errormessage($errormessage);
$requestresponse->set_debuginfo($debuginfo);
if (!empty($rawresponse)) {
$requestresponse->set_response($rawresponse);
}
return $requestresponse;
}

Expand Down
5 changes: 4 additions & 1 deletion classes/manager.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
use context;
use context_system;
use core_plugin_manager;
use dml_exception;
use local_ai_manager\event\get_ai_response_failed;
use local_ai_manager\event\get_ai_response_succeeded;
use local_ai_manager\local\config_manager;
Expand Down Expand Up @@ -182,6 +181,10 @@ public function perform_request(string $prompttext, array $options = []): prompt
return $promptresponse;
}
$promptcompletion = $this->connector->execute_prompt_completion($requestresult->get_response(), $options);
if (!empty($promptcompletion->get_errormessage())) {
get_ai_response_failed::create_from_prompt_response($promptdata, $promptcompletion, $duration)->trigger();
return $promptcompletion;
}
if (!empty($options['forcenewitemid']) && !empty($options['component']) &&
!empty($options['contextid'] && !empty($options['itemid']))) {
if ($DB->record_exists('local_ai_manager_request_log',
Expand Down
1 change: 1 addition & 0 deletions tests/ai_manager_utils_test.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace local_ai_manager;

use Firebase\JWT\JWT;
use stdClass;

/**
Expand Down
1 change: 1 addition & 0 deletions tools/googlesynthesize/classes/connector.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public function make_request(array $data): request_response {
} else {
$return = request_response::create_from_error($response->getStatusCode(),
get_string('error_sendingrequestfailed', 'local_ai_manager'),
$response->getBody()->getContents(),
$response->getBody()
);
}
Expand Down
266 changes: 266 additions & 0 deletions tools/imagen/classes/connector.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
<?php
// This file is part of Moodle - http://moodle.org/
//
// Moodle is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Moodle is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Moodle. If not, see <http://www.gnu.org/licenses/>.

namespace aitool_imagen;

use core\http_client;
use Firebase\JWT\JWT;
use local_ai_manager\base_connector;
use local_ai_manager\local\prompt_response;
use local_ai_manager\local\request_response;
use local_ai_manager\local\unit;
use local_ai_manager\local\usage;
use local_ai_manager\manager;
use Psr\Http\Client\ClientExceptionInterface;
use Psr\Http\Message\StreamInterface;

/**
* Connector for Imagen.
*
* @package aitool_imagen
* @copyright 2024 ISB Bayern
* @author Philipp Memmel
* @license http://www.gnu.org/copyleft/gpl.html GNU GPL v3 or later
*/
class connector extends base_connector {

/** @var string The access token to use for authentication against the Google imagen API endpoint. */
private string $accesstoken = '';

#[\Override]
public function get_models_by_purpose(): array {
return [
'imggen' => ['imagegeneration@006', 'imagen-3.0-generate-001'],
];
}

#[\Override]
public function get_prompt_data(string $prompttext, array $requestoptions): array {
$promptdata = [
'instances' => [
[
'prompt' => $prompttext,
],
],
'parameters' => [
'sampleCount' => 1,
'safetySetting' => 'block_few',
'language' => 'en',
'aspectRatio' => $requestoptions['sizes'][0],
],
];

return $promptdata;
}

#[\Override]
protected function get_headers(): array {
$headers = parent::get_headers();
$headers['Authorization'] = 'Bearer ' . $this->accesstoken;
return $headers;
}

#[\Override]
public function get_unit(): unit {
return unit::COUNT;
}

#[\Override]
public function make_request(array $data): request_response {
// Currently, imagen does not support many languages. So we first translate the prompt into English and hardcode the
// language to "English" later on in the request options.
$translatemanager = new manager('translate');
$translaterequestresult = $translatemanager->perform_request(
'Translate the following words into English, only return the translated text: '
. $data['instances'][0]['prompt']);
if ($translaterequestresult->get_code() !== 200) {
return request_response::create_from_error($translaterequestresult->get_code(),
get_string('err_translationfailed', 'aitool_imagen'), $translaterequestresult->get_debuginfo());
}
$translatedprompt = $translaterequestresult->get_content();

// Subsitute the current prompt by the translated one.
$data['instances'][0]['prompt'] = $translatedprompt;

try {
// Composing the "Authorization" header is not that easy as just looking up a Bearer token in the database.
// So we here explicitly retrieve the access token from cache or the Google OAuth API and do some proper error handling.
// After we stored it in $this->accesstoken it can be properly set into the header by the self::get_headers method.
$this->accesstoken = $this->get_access_token();
} catch (\moodle_exception $exception) {
return request_response::create_from_error(0, $exception->getMessage(), $exception->getTraceAsString());
}
// We keep track of the time the cached access token expires. However, due latency, different clocks
// on different servers etc. we could end up sending a request with an actually expired access token.
// So we clear the cached access token and re-submit the request ONE TIME if we receive a 401 response code
// with an ACCESS_TOKEN_EXPIRED error code.

$requestresponse = parent::make_request($data);
if ($requestresponse->get_code() === 401) {
// We need to reset the stream, so we can again read it.
$requestresponse->get_response()->rewind();
$content = json_decode($requestresponse->get_response()->getContents(), true);
if (!empty(array_filter($content['error']['details'], fn($details) => $details['reason'] === 'ACCESS_TOKEN_EXPIRED'))) {
$authcache = \cache::make('aitool_imagen', 'auth');
$authcache->delete($this->instance->get_id());
$requestresponse = parent::make_request($data);
}
}
return $requestresponse;
}

#[\Override]
public function execute_prompt_completion(StreamInterface $result, array $options = []): prompt_response {
global $USER;
$content = json_decode($result->getContents(), true);
if (empty($content['predictions']) || !array_key_exists('bytesBase64Encoded', $content['predictions'][0])) {
return prompt_response::create_from_error(400,
get_string('err_predictionmissing', 'aitool_imagen'), '');
}
$fs = get_file_storage();
$fileinfo = [
'contextid' => \context_user::instance($USER->id)->id,
'component' => 'user',
'filearea' => 'draft',
'itemid' => $options['itemid'],
'filepath' => '/',
'filename' => $options['filename'],
];
$file = $fs->create_file_from_string($fileinfo, base64_decode($content['predictions'][0]['bytesBase64Encoded']));

$filepath = \moodle_url::make_draftfile_url(
$file->get_itemid(),
$file->get_filepath(),
$file->get_filename()
)->out();

return prompt_response::create_from_result($this->instance->get_model(), new usage(1.0), $filepath);
}

#[\Override]
public function get_available_options(): array {
$options['sizes'] = [
['key' => '1:1', 'displayname' => '1:1 (1536 x 1536)'],
['key' => '9:16', 'displayname' => '9:16 (1152 x 2016)'],
['key' => '16:9', 'displayname' => '16:9 (2016 x 1134)'],
['key' => '3:4', 'displayname' => '3:4 (1344 x 1792)'],
['key' => '4:3', 'displayname' => '4:3 (1792 x 1344)'],
];
return $options;
}

/**
* Retrieves a fresh access token from the Google oauth endpoint.
*
* @return array of the form ['access_token' => 'xxx', 'expires' => 1730805678] containing the access token and the time at
* which the token expires. If there has been an error, the array is of the form
* ['error' => 'more detailed info about the error']
* @throws \dml_exception
*/
public function retrieve_access_token(): array {
$clock = \core\di::get(\core\clock::class);
$serviceaccountinfo = json_decode($this->instance->get_customfield1());
$kid = $serviceaccountinfo->private_key_id;
$privatekey = $serviceaccountinfo->private_key;
$clientemail = $serviceaccountinfo->client_email;
$jwtpayload = [
'iss' => $clientemail,
'sub' => $clientemail,
'scope' => 'https://www.googleapis.com/auth/cloud-platform',
'aud' => 'https://oauth2.googleapis.com/token',
'iat' => $clock->time(),
'exp' => $clock->time() + HOURSECS,
];
$jwt = JWT::encode($jwtpayload, $privatekey, 'RS256', null, ['kid' => $kid]);

$client = new http_client([
'timeout' => get_config('local_ai_manager', 'requesttimeout'),
]);
$options['query'] = [
'assertion' => $jwt,
'grant_type' => 'urn:ietf:params:oauth:grant-type:jwt-bearer',
];

try {
$response = $client->post('https://oauth2.googleapis.com/token', $options);
} catch (ClientExceptionInterface $exception) {
return ['error' => $exception->getMessage()];
}
if ($response->getStatusCode() === 200) {
$content = $response->getBody()->getContents();
if (empty($content)) {
return ['error' => 'Empty response'];
}
$content = json_decode($content, true);
if (empty($content['access_token'])) {
return ['error' => 'Response does not contain "access_token" key'];
}
return [
'access_token' => $content['access_token'],
// We set the expiry time of the access token and reduce it by 10 seconds to avoid some errors caused
// by different clocks on different servers, latency etc.
'expires' => $clock->time() + intval($content['expires_in']) - 10,
];
} else {
return ['error' => 'Response status code is not OK 200, but ' . $response->getStatusCode() . ': ' .
$response->getBody()->getContents()];
}
}

/**
* Gets an access token for accessing the imagen API.
*
* This will check if the cached access token still has not expired. If cache is empty or the token has expired
* a new access token will be fetched by calling {@see self::retrieve_access_token} and the new token will be stored
* in the cache.
*
* @return string the access token as string, empty if no
*/
public function get_access_token(): string {
$clock = \core\di::get(\core\clock::class);
$authcache = \cache::make('aitool_imagen', 'auth');
$cachedauthinfo = $authcache->get($this->instance->get_id());
if (empty($cachedauthinfo) || json_decode($cachedauthinfo)->expires < $clock->time()) {
$authinfo = $this->retrieve_access_token();
if (!empty($authinfo['error'])) {
throw new \moodle_exception('Error retrieving access token', '', '', '', $authinfo['error']);
}
$cachedauthinfo = json_encode($authinfo);
$authcache->set($this->instance->get_id(), $cachedauthinfo);
$accesstoken = $authinfo['access_token'];
} else {
$accesstoken = json_decode($cachedauthinfo, true)['access_token'];
}
return $accesstoken;
}

#[\Override]
protected function get_custom_error_message(int $code, ?ClientExceptionInterface $exception = null): string {
$message = '';
switch ($code) {
case 400:
if (method_exists($exception, 'getResponse') && !empty($exception->getResponse())) {
$responsebody = json_decode($exception->getResponse()->getBody()->getContents());
if (property_exists($responsebody, 'error') && property_exists($responsebody->error, 'status')
&& $responsebody->error->status === 'INVALID_ARGUMENT') {
$message = get_string('err_contentpolicyviolation', 'aitool_imagen');
}
}
break;
}
return $message;
}
}
Loading

0 comments on commit 0db9f97

Please sign in to comment.