diff --git a/conf/openmetadata.yaml b/conf/openmetadata.yaml index 4150fc776f86..a00b29f1617a 100644 --- a/conf/openmetadata.yaml +++ b/conf/openmetadata.yaml @@ -494,7 +494,7 @@ elasticsearch: naturalLanguageSearch: enabled: ${NATURAL_LANGUAGE_SEARCH_ENABLED:-false} semanticSearchEnabled: ${SEMANTIC_SEARCH_ENABLED:-false} - embeddingProvider: ${EMBEDDING_PROVIDER:-bedrock} # Options: "openai", "bedrock", "djl" + embeddingProvider: ${EMBEDDING_PROVIDER:-bedrock} # Options: "openai", "bedrock", "google", "djl" maxConcurrentEmbeddingRequests: ${MAX_CONCURRENT_EMBEDDING_REQUESTS:-10} providerClass: ${NATURAL_LANGUAGE_SEARCH_PROVIDER_CLASS:-org.openmetadata.service.search.nlq.NoOpNLQService} bedrock: @@ -515,6 +515,11 @@ elasticsearch: apiVersion: ${OPENAI_API_VERSION:-"2024-02-01"} # Azure OpenAI API version embeddingModelId: ${OPENAI_EMBEDDING_MODEL_ID:-"text-embedding-3-small"} embeddingDimension: ${OPENAI_EMBEDDING_DIMENSION:-1536} + google: + apiKey: ${GOOGLE_API_KEY:-""} # API key from Google AI Studio + embeddingModelId: ${GOOGLE_EMBEDDING_MODEL_ID:-"gemini-embedding-001"} + embeddingDimension: ${GOOGLE_EMBEDDING_DIMENSION:-768} # Sent as outputDimensionality. gemini-embedding-001 supports 768/1536/3072; text-embedding-004 supports 768. + endpoint: ${GOOGLE_API_ENDPOINT:-""} # Optional override; full :embedContent URL. Leave empty to use the default Generative Language API endpoint. djl: embeddingModel: ${DJL_EMBEDDING_MODEL:-"ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L6-v2"} diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java index c1e606f2a210..5cfc44dd2ceb 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/jdbi3/SystemRepository.java @@ -50,6 +50,7 @@ import org.openmetadata.schema.security.client.OpenMetadataJWTClientConfig; import org.openmetadata.schema.security.scim.ScimConfiguration; import org.openmetadata.schema.service.configuration.elasticsearch.ElasticSearchConfiguration; +import org.openmetadata.schema.service.configuration.elasticsearch.Google; import org.openmetadata.schema.service.configuration.elasticsearch.NaturalLanguageSearchConfiguration; import org.openmetadata.schema.service.configuration.slackApp.SlackAppConfiguration; import org.openmetadata.schema.services.connections.metadata.AuthProvider; @@ -766,8 +767,21 @@ private String getEmbeddingConfigurationMessage(OpenMetadataApplicationConfig ap nlpConfig.getOpenai().getEmbeddingDimension(), deploymentInfo); } + case "google" -> { + Google googleCfg = nlpConfig.getGoogle(); + if (googleCfg == null) { + yield "Google provider selected but google configuration block is missing"; + } + String googleEndpoint = + nullOrEmpty(googleCfg.getEndpoint()) + ? "generativelanguage.googleapis.com" + : googleCfg.getEndpoint(); + yield String.format( + "Google configuration: endpoint: %s, embeddingModelId: %s, embeddingDimension: %s", + googleEndpoint, googleCfg.getEmbeddingModelId(), googleCfg.getEmbeddingDimension()); + } default -> String.format( - "Unknown provider '%s'. Supported providers: djl, bedrock, openai", provider); + "Unknown provider '%s'. Supported providers: djl, bedrock, openai, google", provider); }; } catch (Exception e) { LOG.error("Error getting embedding configuration", e); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java index 442ec0f2d99e..3e7c92802179 100644 --- a/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/SearchRepository.java @@ -142,6 +142,7 @@ import org.openmetadata.service.search.vector.client.BedrockEmbeddingClient; import org.openmetadata.service.search.vector.client.DjlEmbeddingClient; import org.openmetadata.service.search.vector.client.EmbeddingClient; +import org.openmetadata.service.search.vector.client.GoogleEmbeddingClient; import org.openmetadata.service.search.vector.client.OpenAIEmbeddingClient; import org.openmetadata.service.security.policyevaluator.SubjectContext; import org.openmetadata.service.util.EntityUtil; @@ -3227,6 +3228,13 @@ protected EmbeddingClient createEmbeddingClient(ElasticSearchConfiguration esCon } yield new OpenAIEmbeddingClient(esConfig); } + case "google" -> { + if (config.getGoogle() == null) { + throw new IllegalStateException( + "Google configuration is required when using google provider"); + } + yield new GoogleEmbeddingClient(esConfig); + } case "djl" -> { if (config.getDjl() == null) { throw new IllegalStateException("DJL configuration is required when using djl provider"); diff --git a/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClient.java b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClient.java new file mode 100644 index 000000000000..7ed656226923 --- /dev/null +++ b/openmetadata-service/src/main/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClient.java @@ -0,0 +1,208 @@ +/* + * Copyright 2024 Collate + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package org.openmetadata.service.search.vector.client; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.IOException; +import java.net.URI; +import java.net.URLEncoder; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import lombok.extern.slf4j.Slf4j; +import org.openmetadata.schema.service.configuration.elasticsearch.ElasticSearchConfiguration; +import org.openmetadata.schema.service.configuration.elasticsearch.Google; +import org.openmetadata.schema.service.configuration.elasticsearch.NaturalLanguageSearchConfiguration; + +@Slf4j +public final class GoogleEmbeddingClient extends EmbeddingClient { + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String MODELS_PREFIX = "models/"; + private static final String DEFAULT_BASE_URL = + "https://generativelanguage.googleapis.com/v1beta/" + MODELS_PREFIX; + + private final HttpClient httpClient; + private final String apiKey; + private final String modelId; + private final int dimension; + private final String endpoint; + + public GoogleEmbeddingClient(ElasticSearchConfiguration config) { + super(resolveMaxConcurrent(config)); + NaturalLanguageSearchConfiguration nlsCfg = config.getNaturalLanguageSearch(); + Google googleCfg = nlsCfg.getGoogle(); + if (googleCfg == null) { + throw new IllegalArgumentException("Google configuration is required"); + } + if (googleCfg.getApiKey() == null || googleCfg.getApiKey().isBlank()) { + throw new IllegalArgumentException("Google API key is required"); + } + if (googleCfg.getEmbeddingModelId() == null || googleCfg.getEmbeddingModelId().isBlank()) { + throw new IllegalArgumentException("Google embedding model ID is required"); + } + if (googleCfg.getEmbeddingDimension() == null || googleCfg.getEmbeddingDimension() <= 0) { + throw new IllegalArgumentException("Google embedding dimension must be positive"); + } + + this.apiKey = googleCfg.getApiKey(); + this.modelId = googleCfg.getEmbeddingModelId(); + this.dimension = googleCfg.getEmbeddingDimension(); + this.endpoint = resolveEndpoint(googleCfg); + this.httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(30)).build(); + + LOG.info( + "Initialized GoogleEmbeddingClient with model={}, dimension={}, endpoint={}", + modelId, + dimension, + endpoint); + } + + GoogleEmbeddingClient( + HttpClient httpClient, String apiKey, String modelId, int dimension, String endpoint) { + this(httpClient, apiKey, modelId, dimension, endpoint, DEFAULT_MAX_CONCURRENT_REQUESTS); + } + + GoogleEmbeddingClient( + HttpClient httpClient, + String apiKey, + String modelId, + int dimension, + String endpoint, + int maxConcurrentRequests) { + super(maxConcurrentRequests); + this.httpClient = httpClient; + this.apiKey = apiKey; + this.modelId = modelId; + this.dimension = dimension; + this.endpoint = endpoint; + } + + private String resolveEndpoint(Google config) { + String configured = config.getEndpoint(); + if (configured != null && !configured.isBlank()) { + String normalizedEndpoint = configured.replaceAll("/+$", ""); + if (!normalizedEndpoint.contains(":embedContent")) { + throw new IllegalArgumentException( + "Invalid google.endpoint configuration. Expected a full Google embedding endpoint " + + "URL containing ':embedContent', for example " + + "'https://generativelanguage.googleapis.com/v1beta/models/" + + config.getEmbeddingModelId() + + ":embedContent'."); + } + return normalizedEndpoint; + } + return DEFAULT_BASE_URL + config.getEmbeddingModelId() + ":embedContent"; + } + + @Override + protected float[] doEmbed(String text) { + if (text == null || text.isBlank()) { + throw new IllegalArgumentException("Input text must not be null or blank"); + } + + try { + String body = buildRequestBody(text); + HttpRequest request = buildRequest(body); + HttpResponse response = + httpClient.send(request, HttpResponse.BodyHandlers.ofString()); + + if (response.statusCode() != 200) { + String errorMsg = extractErrorMessage(response.body()); + throw new RuntimeException( + "Google API returned status " + response.statusCode() + ": " + errorMsg); + } + + return parseEmbeddingResponse(response.body()); + } catch (IOException e) { + LOG.error("IO error calling Google API: {}", e.getMessage(), e); + throw new RuntimeException("Google embedding generation failed due to IO error", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Google embedding generation was interrupted", e); + } + } + + private String buildRequestBody(String text) throws IOException { + ObjectNode payload = MAPPER.createObjectNode(); + payload.put("model", MODELS_PREFIX + modelId); + ObjectNode content = payload.putObject("content"); + ArrayNode parts = content.putArray("parts"); + ObjectNode part = parts.addObject(); + part.put("text", text); + // Pin the response vector size to the configured dimension. Required for `gemini-embedding-001` + // (defaults to 3072 otherwise); supported and silently truncating for `text-embedding-004`. + payload.put("outputDimensionality", dimension); + return MAPPER.writeValueAsString(payload); + } + + private HttpRequest buildRequest(String body) { + // Google's Generative Language API requires the API key as a `key=` query parameter; + // it does not accept Bearer/Authorization headers for AI Studio keys. + String encodedKey = URLEncoder.encode(apiKey, StandardCharsets.UTF_8); + String separator = endpoint.contains("?") ? "&" : "?"; + String url = endpoint + separator + "key=" + encodedKey; + return HttpRequest.newBuilder() + .uri(URI.create(url)) + .header("Content-Type", "application/json") + .timeout(Duration.ofSeconds(30)) + .POST(HttpRequest.BodyPublishers.ofString(body)) + .build(); + } + + @Override + public int getDimension() { + return dimension; + } + + @Override + public String getModelId() { + return modelId; + } + + private float[] parseEmbeddingResponse(String responseBody) { + try { + JsonNode root = MAPPER.readTree(responseBody); + JsonNode embedding = root.get("embedding"); + if (embedding == null || !embedding.isObject()) { + throw new RuntimeException("Invalid Google response: no embedding object found"); + } + JsonNode values = embedding.get("values"); + if (values == null || !values.isArray() || values.isEmpty()) { + throw new RuntimeException("Invalid Google response: no values array found"); + } + float[] result = new float[values.size()]; + for (int i = 0; i < values.size(); i++) { + result[i] = (float) values.get(i).asDouble(); + } + return result; + } catch (IOException e) { + throw new RuntimeException("Failed to parse Google embedding response", e); + } + } + + private String extractErrorMessage(String responseBody) { + try { + JsonNode root = MAPPER.readTree(responseBody); + JsonNode error = root.get("error"); + if (error != null && error.has("message")) { + return error.get("message").asText(); + } + } catch (Exception e) { + LOG.trace("Could not parse Google error envelope: {}", e.getMessage()); + } + return responseBody; + } +} diff --git a/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClientTest.java b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClientTest.java new file mode 100644 index 000000000000..bb05acdbb024 --- /dev/null +++ b/openmetadata-service/src/test/java/org/openmetadata/service/search/vector/client/GoogleEmbeddingClientTest.java @@ -0,0 +1,598 @@ +/* + * Copyright 2024 Collate + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file + * except in compliance with the License. You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package org.openmetadata.service.search.vector.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import javax.net.ssl.SSLSession; +import org.junit.jupiter.api.Test; +import org.openmetadata.schema.service.configuration.elasticsearch.ElasticSearchConfiguration; +import org.openmetadata.schema.service.configuration.elasticsearch.Google; +import org.openmetadata.schema.service.configuration.elasticsearch.NaturalLanguageSearchConfiguration; + +class GoogleEmbeddingClientTest { + + private static final String EMBED_ENDPOINT = + "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent"; + + private static class StubHttpResponse implements HttpResponse { + private final String body; + private final int statusCode; + private final HttpRequest request; + + StubHttpResponse(String body, int statusCode, HttpRequest request) { + this.body = body; + this.statusCode = statusCode; + this.request = request; + } + + @Override + public int statusCode() { + return statusCode; + } + + @Override + public HttpRequest request() { + return request; + } + + @Override + public Optional> previousResponse() { + return Optional.empty(); + } + + @Override + public HttpHeaders headers() { + return HttpHeaders.of(Map.of(), (a, b) -> true); + } + + @Override + public String body() { + return body; + } + + @Override + public Optional sslSession() { + return Optional.empty(); + } + + @Override + public URI uri() { + return request.uri(); + } + + @Override + public HttpClient.Version version() { + return HttpClient.Version.HTTP_2; + } + } + + private static class StubHttpClient extends HttpClient { + private final String responseBody; + private final int statusCode; + private final List capturedRequests = new ArrayList<>(); + + StubHttpClient(String responseBody, int statusCode) { + this.responseBody = responseBody; + this.statusCode = statusCode; + } + + List getCapturedRequests() { + return capturedRequests; + } + + @Override + public Optional authenticator() { + return Optional.empty(); + } + + @Override + public Optional connectTimeout() { + return Optional.empty(); + } + + @Override + public Optional cookieHandler() { + return Optional.empty(); + } + + @Override + public Redirect followRedirects() { + return Redirect.NEVER; + } + + @Override + public Optional proxy() { + return Optional.empty(); + } + + @Override + public javax.net.ssl.SSLContext sslContext() { + return null; + } + + @Override + public javax.net.ssl.SSLParameters sslParameters() { + return null; + } + + @Override + public Optional executor() { + return Optional.empty(); + } + + @Override + public Version version() { + return Version.HTTP_2; + } + + @Override + @SuppressWarnings("unchecked") + public HttpResponse send( + HttpRequest request, HttpResponse.BodyHandler responseBodyHandler) { + capturedRequests.add(request); + return (HttpResponse) new StubHttpResponse(responseBody, statusCode, request); + } + + @Override + public CompletableFuture> sendAsync( + HttpRequest request, HttpResponse.BodyHandler responseBodyHandler) { + return CompletableFuture.supplyAsync(() -> send(request, responseBodyHandler)); + } + + @Override + public CompletableFuture> sendAsync( + HttpRequest request, + HttpResponse.BodyHandler responseBodyHandler, + HttpResponse.PushPromiseHandler pushPromiseHandler) { + return sendAsync(request, responseBodyHandler); + } + } + + @Test + void testSuccessfulEmbeddingResponse() { + String response = "{\"embedding\":{\"values\":[0.1,0.2,0.3]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient(httpClient, "test-key", "text-embedding-004", 3, EMBED_ENDPOINT); + + float[] embedding = client.embed("hello world"); + + assertNotNull(embedding); + assertEquals(3, embedding.length); + assertEquals(0.1f, embedding[0], 0.001f); + assertEquals(0.2f, embedding[1], 0.001f); + assertEquals(0.3f, embedding[2], 0.001f); + } + + @Test + void testClientCreationWithConfig() { + ElasticSearchConfiguration config = buildConfig("test-key", "text-embedding-004", 768); + GoogleEmbeddingClient client = new GoogleEmbeddingClient(config); + + assertEquals(768, client.getDimension()); + assertEquals("text-embedding-004", client.getModelId()); + } + + @Test + void testClientCreationWithCustomModel() { + ElasticSearchConfiguration config = buildConfig("test-key", "gemini-embedding-001", 3072); + GoogleEmbeddingClient client = new GoogleEmbeddingClient(config); + + assertEquals(3072, client.getDimension()); + assertEquals("gemini-embedding-001", client.getModelId()); + } + + @Test + void testMissingGoogleConfigThrows() { + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testMissingApiKeyThrows() { + Google googleCfg = + new Google().withEmbeddingModelId("text-embedding-004").withEmbeddingDimension(768); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testBlankApiKeyThrows() { + Google googleCfg = + new Google() + .withApiKey(" ") + .withEmbeddingModelId("text-embedding-004") + .withEmbeddingDimension(768); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testMissingModelIdThrows() { + Google googleCfg = new Google().withApiKey("test-key").withEmbeddingDimension(768); + // Schema defaults embeddingModelId to "text-embedding-004"; force-null to exercise the guard. + googleCfg.setEmbeddingModelId(null); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testBlankModelIdThrows() { + Google googleCfg = + new Google().withApiKey("test-key").withEmbeddingModelId(" ").withEmbeddingDimension(768); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testZeroDimensionThrows() { + Google googleCfg = + new Google() + .withApiKey("test-key") + .withEmbeddingModelId("text-embedding-004") + .withEmbeddingDimension(0); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testNegativeDimensionThrows() { + Google googleCfg = + new Google() + .withApiKey("test-key") + .withEmbeddingModelId("text-embedding-004") + .withEmbeddingDimension(-100); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testNullDimensionThrows() { + Google googleCfg = + new Google().withApiKey("test-key").withEmbeddingModelId("text-embedding-004"); + googleCfg.setEmbeddingDimension(null); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + } + + @Test + void testCustomEndpointConstruction() { + Google googleCfg = + new Google() + .withApiKey("test-key") + .withEmbeddingModelId("text-embedding-004") + .withEmbeddingDimension(768) + .withEndpoint( + "https://proxy.example.com/v1beta/models/text-embedding-004:embedContent/"); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + GoogleEmbeddingClient client = new GoogleEmbeddingClient(config); + assertNotNull(client); + assertEquals("text-embedding-004", client.getModelId()); + assertEquals(768, client.getDimension()); + } + + @Test + void testCustomEndpointWithoutEmbedContentThrows() { + Google googleCfg = + new Google() + .withApiKey("test-key") + .withEmbeddingModelId("text-embedding-004") + .withEmbeddingDimension(768) + .withEndpoint("https://proxy.example.com/v1beta/models/"); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + + IllegalArgumentException ex = + assertThrows(IllegalArgumentException.class, () -> new GoogleEmbeddingClient(config)); + assertTrue(ex.getMessage().contains(":embedContent")); + } + + @Test + void testNullTextThrows() { + ElasticSearchConfiguration config = buildConfig("test-key", "text-embedding-004", 768); + GoogleEmbeddingClient client = new GoogleEmbeddingClient(config); + + assertThrows(IllegalArgumentException.class, () -> client.embed(null)); + } + + @Test + void testBlankTextThrows() { + ElasticSearchConfiguration config = buildConfig("test-key", "text-embedding-004", 768); + GoogleEmbeddingClient client = new GoogleEmbeddingClient(config); + + assertThrows(IllegalArgumentException.class, () -> client.embed(" ")); + } + + @Test + void testNon200StatusThrowsWithExtractedErrorMessage() { + String errorBody = + "{\"error\":{\"code\":429,\"message\":\"Quota exceeded\",\"status\":\"RESOURCE_EXHAUSTED\"}}"; + StubHttpClient httpClient = new StubHttpClient(errorBody, 429); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "test-key", "text-embedding-004", 768, EMBED_ENDPOINT); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> client.embed("hello")); + assertTrue(ex.getMessage().contains("429")); + assertTrue(ex.getMessage().contains("Quota exceeded")); + } + + @Test + void testNon200StatusWithNonJsonBodyEchoesBody() { + StubHttpClient httpClient = new StubHttpClient("Service Unavailable", 503); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "test-key", "text-embedding-004", 768, EMBED_ENDPOINT); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> client.embed("hello")); + assertTrue(ex.getMessage().contains("503")); + assertTrue(ex.getMessage().contains("Service Unavailable")); + } + + @Test + void testMissingEmbeddingObjectThrows() { + StubHttpClient httpClient = new StubHttpClient("{\"foo\":\"bar\"}", 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "test-key", "text-embedding-004", 768, EMBED_ENDPOINT); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> client.embed("hello")); + assertTrue(ex.getMessage().contains("no embedding object")); + } + + @Test + void testMissingValuesArrayThrows() { + StubHttpClient httpClient = new StubHttpClient("{\"embedding\":{}}", 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "test-key", "text-embedding-004", 768, EMBED_ENDPOINT); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> client.embed("hello")); + assertTrue(ex.getMessage().contains("no values array")); + } + + @Test + void testEmptyValuesArrayThrows() { + StubHttpClient httpClient = new StubHttpClient("{\"embedding\":{\"values\":[]}}", 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "test-key", "text-embedding-004", 768, EMBED_ENDPOINT); + + RuntimeException ex = assertThrows(RuntimeException.class, () -> client.embed("hello")); + assertTrue(ex.getMessage().contains("no values array")); + } + + @Test + void testRequestUrlContainsApiKeyAsQueryParam() { + String response = "{\"embedding\":{\"values\":[0.1]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "my-secret-key", "text-embedding-004", 1, EMBED_ENDPOINT); + + client.embed("hi"); + + assertEquals(1, httpClient.getCapturedRequests().size()); + HttpRequest request = httpClient.getCapturedRequests().get(0); + String url = request.uri().toString(); + assertTrue(url.endsWith("text-embedding-004:embedContent?key=my-secret-key"), url); + } + + @Test + void testRequestHasNoAuthorizationHeader() { + String response = "{\"embedding\":{\"values\":[0.1]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "my-secret-key", "text-embedding-004", 1, EMBED_ENDPOINT); + + client.embed("hi"); + + HttpRequest request = httpClient.getCapturedRequests().get(0); + assertTrue(request.headers().firstValue("Authorization").isEmpty()); + assertTrue(request.headers().firstValue("api-key").isEmpty()); + assertEquals("application/json", request.headers().firstValue("Content-Type").orElse(null)); + } + + @Test + void testRequestBodyShape() throws Exception { + String response = "{\"embedding\":{\"values\":[0.1]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "my-secret-key", "gemini-embedding-001", 768, EMBED_ENDPOINT); + + client.embed("the quick brown fox"); + + HttpRequest request = httpClient.getCapturedRequests().get(0); + String body = extractBody(request); + com.fasterxml.jackson.databind.JsonNode parsed = + new com.fasterxml.jackson.databind.ObjectMapper().readTree(body); + assertEquals("models/gemini-embedding-001", parsed.get("model").asText()); + assertEquals( + "the quick brown fox", parsed.get("content").get("parts").get(0).get("text").asText()); + assertEquals(768, parsed.get("outputDimensionality").asInt()); + } + + @Test + void testApiKeyIsUrlEncoded() { + String response = "{\"embedding\":{\"values\":[0.1]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, "key with spaces&chars", "text-embedding-004", 1, EMBED_ENDPOINT); + + client.embed("hi"); + + HttpRequest request = httpClient.getCapturedRequests().get(0); + String url = request.uri().toString(); + assertTrue(url.contains("key=key+with+spaces%26chars"), url); + } + + @Test + void testEndpointWithExistingQueryStringUsesAmpersand() { + String response = "{\"embedding\":{\"values\":[0.1]}}"; + StubHttpClient httpClient = new StubHttpClient(response, 200); + + GoogleEmbeddingClient client = + new GoogleEmbeddingClient( + httpClient, + "my-key", + "text-embedding-004", + 1, + "https://proxy.example.com/embed?alt=json"); + + client.embed("hi"); + + HttpRequest request = httpClient.getCapturedRequests().get(0); + String url = request.uri().toString(); + assertEquals("https://proxy.example.com/embed?alt=json&key=my-key", url); + } + + private static String extractBody(HttpRequest request) { + java.net.http.HttpRequest.BodyPublisher publisher = + request + .bodyPublisher() + .orElseThrow(() -> new IllegalStateException("Request had no body publisher")); + java.util.concurrent.CompletableFuture future = + new java.util.concurrent.CompletableFuture<>(); + publisher.subscribe( + new java.util.concurrent.Flow.Subscriber<>() { + private final java.io.ByteArrayOutputStream out = new java.io.ByteArrayOutputStream(); + + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(java.nio.ByteBuffer item) { + byte[] arr = new byte[item.remaining()]; + item.get(arr); + out.write(arr, 0, arr.length); + } + + @Override + public void onError(Throwable throwable) { + future.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + future.complete(out.toString(java.nio.charset.StandardCharsets.UTF_8)); + } + }); + try { + return future.get(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (java.util.concurrent.ExecutionException e) { + throw new RuntimeException("Body publisher failed", e.getCause()); + } catch (java.util.concurrent.TimeoutException e) { + throw new RuntimeException("Body publisher timed out after 5s", e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Body publisher interrupted", e); + } + } + + private ElasticSearchConfiguration buildConfig(String apiKey, String modelId, int dimension) { + Google googleCfg = + new Google() + .withApiKey(apiKey) + .withEmbeddingModelId(modelId) + .withEmbeddingDimension(dimension); + + NaturalLanguageSearchConfiguration nlsCfg = new NaturalLanguageSearchConfiguration(); + nlsCfg.setGoogle(googleCfg); + + ElasticSearchConfiguration config = new ElasticSearchConfiguration(); + config.setNaturalLanguageSearch(nlsCfg); + return config; + } +} diff --git a/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json b/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json index 4b7591e6244d..57e74512e3e4 100644 --- a/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json +++ b/openmetadata-spec/src/main/resources/json/schema/configuration/elasticSearchConfiguration.json @@ -149,7 +149,7 @@ "default": 0.6 }, "embeddingProvider": { - "description": "The provider to use for generating vector embeddings (e.g., bedrock, openai).", + "description": "The provider to use for generating vector embeddings (e.g., bedrock, openai, google, djl).", "type": "string", "default": "bedrock" }, @@ -240,6 +240,37 @@ } }, "additionalProperties": false + }, + "google": { + "description": "Google Gemini configuration for embedding generation via the Generative Language API.", + "type": "object", + "javaType": "org.openmetadata.schema.service.configuration.elasticsearch.Google", + "properties": { + "apiKey": { + "description": "API key from Google AI Studio for authenticating with the Generative Language API.", + "type": "string" + }, + "modelId": { + "description": "Gemini chat model identifier for query transformation (e.g., gemini-2.5-flash, gemini-1.5-flash).", + "type": "string", + "default": "gemini-2.5-flash" + }, + "embeddingModelId": { + "description": "Gemini embedding model identifier (e.g., gemini-embedding-001, text-embedding-004).", + "type": "string", + "default": "gemini-embedding-001" + }, + "embeddingDimension": { + "description": "Dimension of the embedding vector, sent to Google as `outputDimensionality`. For `gemini-embedding-001` valid values are 768, 1536, or 3072. For `text-embedding-004` use 768.", + "type": "integer", + "default": 768 + }, + "endpoint": { + "description": "Optional override for the full embedding endpoint URL. Must be the complete URL including the model and `:embedContent` action (e.g. `https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent`), not just a base URL. Leave empty to use the default Generative Language API endpoint, which is constructed from `embeddingModelId`. The `key` query parameter is appended automatically.", + "type": "string" + } + }, + "additionalProperties": false } }, "additionalProperties": false diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts index 4f798d792c90..71142cee19be 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/configuration/elasticSearchConfiguration.ts @@ -127,13 +127,17 @@ export interface NaturalLanguageSearch { */ djl?: Djl; /** - * The provider to use for generating vector embeddings (e.g., bedrock, openai). + * The provider to use for generating vector embeddings (e.g., bedrock, openai, google, djl). */ embeddingProvider?: string; /** * Enable or disable natural language search */ enabled?: boolean; + /** + * Google Gemini configuration for embedding generation via the Generative Language API. + */ + google?: Google; /** * Weight for BM25 keyword search results in hybrid RRF pipeline (0.0-1.0) */ @@ -238,6 +242,40 @@ export interface Djl { embeddingModel?: string; } +/** + * Google Gemini configuration for embedding generation via the Generative Language API. + */ +export interface Google { + /** + * API key from Google AI Studio for authenticating with the Generative Language API. + */ + apiKey?: string; + /** + * Dimension of the embedding vector, sent to Google as `outputDimensionality`. For + * `gemini-embedding-001` valid values are 768, 1536, or 3072. For `text-embedding-004` use + * 768. + */ + embeddingDimension?: number; + /** + * Gemini embedding model identifier (e.g., gemini-embedding-001, text-embedding-004). + */ + embeddingModelId?: string; + /** + * Optional override for the full embedding endpoint URL. Must be the complete URL including + * the model and `:embedContent` action (e.g. + * `https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent`), + * not just a base URL. Leave empty to use the default Generative Language API endpoint, + * which is constructed from `embeddingModelId`. The `key` query parameter is appended + * automatically. + */ + endpoint?: string; + /** + * Gemini chat model identifier for query transformation (e.g., gemini-2.5-flash, + * gemini-1.5-flash). + */ + modelId?: string; +} + /** * OpenAI configuration for embedding generation. Supports both OpenAI and Azure OpenAI * endpoints. diff --git a/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts b/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts index 891650b613c6..ad058cb5aa23 100644 --- a/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts +++ b/openmetadata-ui/src/main/resources/ui/src/generated/settings/settings.ts @@ -2173,13 +2173,17 @@ export interface NaturalLanguageSearch { */ djl?: Djl; /** - * The provider to use for generating vector embeddings (e.g., bedrock, openai). + * The provider to use for generating vector embeddings (e.g., bedrock, openai, google, djl). */ embeddingProvider?: string; /** * Enable or disable natural language search */ enabled?: boolean; + /** + * Google Gemini configuration for embedding generation via the Generative Language API. + */ + google?: Google; /** * Weight for BM25 keyword search results in hybrid RRF pipeline (0.0-1.0) */ @@ -2284,6 +2288,40 @@ export interface Djl { embeddingModel?: string; } +/** + * Google Gemini configuration for embedding generation via the Generative Language API. + */ +export interface Google { + /** + * API key from Google AI Studio for authenticating with the Generative Language API. + */ + apiKey?: string; + /** + * Dimension of the embedding vector, sent to Google as `outputDimensionality`. For + * `gemini-embedding-001` valid values are 768, 1536, or 3072. For `text-embedding-004` use + * 768. + */ + embeddingDimension?: number; + /** + * Gemini embedding model identifier (e.g., gemini-embedding-001, text-embedding-004). + */ + embeddingModelId?: string; + /** + * Optional override for the full embedding endpoint URL. Must be the complete URL including + * the model and `:embedContent` action (e.g. + * `https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent`), + * not just a base URL. Leave empty to use the default Generative Language API endpoint, + * which is constructed from `embeddingModelId`. The `key` query parameter is appended + * automatically. + */ + endpoint?: string; + /** + * Gemini chat model identifier for query transformation (e.g., gemini-2.5-flash, + * gemini-1.5-flash). + */ + modelId?: string; +} + /** * OpenAI configuration for embedding generation. Supports both OpenAI and Azure OpenAI * endpoints.