From 71e2d84671b4885875db6b99f1f37fe67d8ea54f Mon Sep 17 00:00:00 2001 From: Victor Alfaro Date: Wed, 16 Oct 2024 20:46:17 -0600 Subject: [PATCH] #30361: allowing stream chat to dotAI --- .../dotcms/ai/client/AIClientStrategy.java | 14 ++++-- .../dotcms/ai/client/AIDefaultStrategy.java | 12 +++-- .../ai/client/AIModelFallbackStrategy.java | 49 ++++++++++++------- .../com/dotcms/ai/client/AIProxiedClient.java | 6 +-- .../dotcms/ai/client/openai/OpenAIClient.java | 24 +++++++-- .../java/com/dotcms/ai/domain/AIResponse.java | 1 - .../dotcms/ai/rest/CompletionsResource.java | 11 +++-- 7 files changed, 78 insertions(+), 39 deletions(-) diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java index 6ac784ef2a2a..d49ea9da7881 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIClientStrategy.java @@ -23,7 +23,10 @@ */ public interface AIClientStrategy { - AIClientStrategy NOOP = (client, handler, request, output) -> AIResponse.builder().build(); + AIClientStrategy NOOP = (client, handler, request, output) -> { + /* No-op: nothing is sent. Never hand back null, because the caller reads the returned stream. */ + return output != null ? output : new java.io.ByteArrayOutputStream(); + }; /** * Applies the strategy to the given AI client request and handles the response. 
@@ -32,10 +35,11 @@ public interface AIClientStrategy { * @param handler the response evaluator to handle the response * @param request the AI request to be processed * @param output the output stream to which the response will be written + * @return result output stream */ - void applyStrategy(AIClient client, - AIResponseEvaluator handler, - AIRequest request, - OutputStream output); + OutputStream applyStrategy(AIClient client, + AIResponseEvaluator handler, + AIRequest request, + OutputStream output); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java index 02149d98a7b1..3f58ca8bb3ea 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIDefaultStrategy.java @@ -1,7 +1,9 @@ package com.dotcms.ai.client; +import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.io.Serializable; +import java.util.Optional; /** * Default implementation of the {@link AIClientStrategy} interface. 
@@ -22,11 +24,13 @@ public class AIDefaultStrategy implements AIClientStrategy { @Override - public void applyStrategy(final AIClient client, - final AIResponseEvaluator handler, - final AIRequest request, - final OutputStream output) { + public OutputStream applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream incoming) { + final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new); client.sendRequest(request, output); + return output; } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java index 0553645ece58..c01a0ce5d6de 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIModelFallbackStrategy.java @@ -57,21 +57,24 @@ public class AIModelFallbackStrategy implements AIClientStrategy { * @param request the AI request to be processed * @param output the output stream to which the response will be written * @throws DotAIAllModelsExhaustedException if all models are exhausted and no successful response is obtained + * @return result output stream */ @Override - public void applyStrategy(final AIClient client, - final AIResponseEvaluator handler, - final AIRequest request, - final OutputStream output) { + public OutputStream applyStrategy(final AIClient client, + final AIResponseEvaluator handler, + final AIRequest request, + final OutputStream output) { final JSONObjectAIRequest jsonRequest = AIClient.useRequestOrThrow(request); final Tuple2 modelTuple = resolveModel(jsonRequest); final AIResponseData firstAttempt = sendAttempt(client, handler, jsonRequest, output, modelTuple); if (firstAttempt.isSuccess()) { - return; + return output; } runFallbacks(client, handler, jsonRequest, output, modelTuple); + + return output; } private static Tuple2 resolveModel(final JSONObjectAIRequest 
request) { @@ -96,11 +99,7 @@ private static Tuple2 resolveModel(final JSONObjectAIRequest req } private static boolean isSameAsFirst(final Model firstAttempt, final Model model) { - if (firstAttempt.equals(model)) { - return true; - } - - return false; + return firstAttempt.equals(model); } private static boolean isOperational(final Model model) { @@ -114,18 +113,32 @@ private static boolean isOperational(final Model model) { return true; } - private static AIResponseData doSend(final AIClient client, final AIRequest request) { - final ByteArrayOutputStream output = new ByteArrayOutputStream(); + private static boolean isStream(final JSONObjectAIRequest request) { + return request.getPayload().optBoolean(AiKeys.STREAM, false); + } + + private static AIResponseData doSend(final AIClient client, + final JSONObjectAIRequest request, + final OutputStream incoming) { + final OutputStream output = Optional.ofNullable(incoming).orElseGet(ByteArrayOutputStream::new); client.sendRequest(request, output); final AIResponseData responseData = new AIResponseData(); responseData.setResponse(output.toString()); - IOUtils.closeQuietly(output); + if (!isStream(request)) { + IOUtils.closeQuietly(output); + } return responseData; } - private static void redirectOutput(final OutputStream output, final String response) { + private static void redirectOutput(final JSONObjectAIRequest request, + final OutputStream output, + final String response) { + if (isStream(request)) { + return; + } + try (final InputStream input = new ByteArrayInputStream(response.getBytes(StandardCharsets.UTF_8))) { IOUtils.copy(input, output); } catch (IOException e) { @@ -133,12 +146,12 @@ private static void redirectOutput(final OutputStream output, final String respo } } - private static void notifyFailure(final AIModel aiModel, final AIRequest request) { + private static void notifyFailure(final AIModel aiModel, final JSONObjectAIRequest request) { AIAppValidator.get().validateModelsUsage(aiModel, 
request.getUserId()); } private static void handleFailure(final Tuple2 modelTuple, - final AIRequest request, + final JSONObjectAIRequest request, final AIResponseData responseData) { final AIModel aiModel = modelTuple._1; final Model model = modelTuple._2; @@ -177,7 +190,7 @@ private static AIResponseData sendAttempt(final AIClient client, final Tuple2 modelTuple) { final AIResponseData responseData = Try - .of(() -> doSend(client, request)) + .of(() -> doSend(client, request, output)) .getOrElseGet(exception -> fromException(evaluator, exception)); if (!responseData.isSuccess()) { @@ -200,7 +213,7 @@ private static AIResponseData sendAttempt(final AIClient client, AppConfig.debugLogger( AIModelFallbackStrategy.class, () -> String.format("Model [%s] succeeded. No need to fallback.", modelTuple._2.getName())); - redirectOutput(output, responseData.getResponse()); + redirectOutput(request, output, responseData.getResponse()); } else { logFailure(modelTuple, responseData); diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java index 73d675a3b90e..a37aff2fb311 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/AIProxiedClient.java @@ -73,13 +73,11 @@ public static AIProxiedClient of(final AIClient client, final AIProxyStrategy st * @return the AI response */ public AIResponse sendToAI(final AIRequest request, final OutputStream output) { - final OutputStream finalOutput = Optional.ofNullable(output).orElseGet(ByteArrayOutputStream::new); - - strategy.applyStrategy(client, responseEvaluator, request, finalOutput); + final OutputStream resultOutput = strategy.applyStrategy(client, responseEvaluator, request, output); return Optional.ofNullable(output) .map(out -> AIResponse.EMPTY) - .orElseGet(() -> AIResponse.builder().withResponse(finalOutput.toString()).build()); + .orElseGet(() -> 
AIResponse.builder().withResponse(resultOutput.toString()).build()); } } diff --git a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java index ab12dbba58f3..c705e6c00243 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java +++ b/dotCMS/src/main/java/com/dotcms/ai/client/openai/OpenAIClient.java @@ -20,6 +20,7 @@ import io.vavr.Tuple2; import io.vavr.control.Try; import org.apache.http.HttpHeaders; +import org.apache.http.HttpStatus; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; import org.apache.http.client.methods.HttpUriRequest; @@ -29,7 +30,9 @@ import org.apache.http.impl.client.HttpClients; import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import java.io.BufferedInputStream; +import java.io.ByteArrayOutputStream; import java.io.OutputStream; import java.io.Serializable; import java.util.Optional; @@ -129,17 +132,19 @@ public void sendRequest(final AIRequest request, fin lastRestCall.put(aiModel, System.currentTimeMillis()); - try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + try (final CloseableHttpClient httpClient = HttpClients.createDefault()) { final StringEntity jsonEntity = new StringEntity(payload.toString(), ContentType.APPLICATION_JSON); final HttpUriRequest httpRequest = AIClient.resolveMethod(jsonRequest.getMethod(), jsonRequest.getUrl()); httpRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON); httpRequest.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + appConfig.getApiKey()); if (!payload.getAsMap().isEmpty()) { - Try.run(() -> HttpEntityEnclosingRequestBase.class.cast(httpRequest).setEntity(jsonEntity)); + Try.run(() -> ((HttpEntityEnclosingRequestBase) httpRequest).setEntity(jsonEntity)); } - try (CloseableHttpResponse response = httpClient.execute(httpRequest)) { + try (final 
CloseableHttpResponse response = httpClient.execute(httpRequest)) { + onStreamCheckForStatusCode(modelName, payload, response); + final BufferedInputStream in = new BufferedInputStream(response.getEntity().getContent()); final byte[] buffer = new byte[1024]; int len; @@ -161,4 +166,17 @@ public void sendRequest(final AIRequest request, fin } } + private static void onStreamCheckForStatusCode(final String modelName, + final JSONObject payload, + final CloseableHttpResponse response) { + if (payload.optBoolean(AiKeys.STREAM, false)) { + final int statusCode = response.getStatusLine().getStatusCode(); + if (Response.Status.Family.familyOf(statusCode) == Response.Status.Family.CLIENT_ERROR) { + /* Any 4xx status on a stream request is reported as a missing model. */ + throw new DotAIModelNotFoundException(String.format( + "Model used [%s] in request in stream mode is not found", modelName)); + } + } + } + } diff --git a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java index 8d9887b24571..dff8cacca25d 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java +++ b/dotCMS/src/main/java/com/dotcms/ai/domain/AIResponse.java @@ -41,7 +41,6 @@ public Builder withResponse(final String response) { return this; } - public AIResponse build() { return new AIResponse(this); } diff --git a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java index 5499de4ce660..0351ec0bb4c8 100644 --- a/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java +++ b/dotCMS/src/main/java/com/dotcms/ai/rest/CompletionsResource.java @@ -56,14 +56,15 @@ public class CompletionsResource { public final Response summarizeFromContent(@Context final HttpServletRequest request, @Context final HttpServletResponse response, final CompletionsForm formIn) { + final CompletionsForm resolvedForm = resolveForm(request, response, formIn); return getResponse( request, response, formIn, - () -> 
APILocator.getDotAIAPI().getCompletionsAPI().summarize(formIn), + () -> APILocator.getDotAIAPI().getCompletionsAPI().summarize(resolvedForm), output -> APILocator.getDotAIAPI() .getCompletionsAPI() - .summarizeStream(formIn, new LineReadingOutputStream(output))); + .summarizeStream(resolvedForm, new LineReadingOutputStream(output))); } /** @@ -81,14 +82,15 @@ public final Response summarizeFromContent(@Context final HttpServletRequest req public final Response rawPrompt(@Context final HttpServletRequest request, @Context final HttpServletResponse response, final CompletionsForm formIn) { + final CompletionsForm resolvedForm = resolveForm(request, response, formIn); return getResponse( request, response, formIn, - () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(formIn), + () -> APILocator.getDotAIAPI().getCompletionsAPI().raw(resolvedForm), output -> APILocator.getDotAIAPI() .getCompletionsAPI() - .rawStream(formIn, new LineReadingOutputStream(output))); + .rawStream(resolvedForm, new LineReadingOutputStream(output))); } /** @@ -180,6 +182,7 @@ private static Response getResponse(final HttpServletRequest request, final JSONObject jsonResponse = noStream.get(); jsonResponse.put(AiKeys.TOTAL_TIME, System.currentTimeMillis() - startTime + "ms"); + return Response.ok(jsonResponse.toString(), MediaType.APPLICATION_JSON).build(); }