diff --git a/examples/src/main/kotlin/completion/Completion.kt b/examples/src/main/kotlin/completion/Completion.kt index 81e5b9f..377a5ff 100644 --- a/examples/src/main/kotlin/completion/Completion.kt +++ b/examples/src/main/kotlin/completion/Completion.kt @@ -5,7 +5,7 @@ import com.cjcrafter.openai.openAI import io.github.cdimascio.dotenv.dotenv /** - * In this Kotlin example, we will be using the Chat API to create a simple chatbot. + * In this Kotlin example, we will be using the Completions API to generate a response. */ fun main() { @@ -17,7 +17,7 @@ fun main() { // Here you can change the model's settings, add tools, and more. val request = completionRequest { model("davinci") - prompt("What is 9+10") + prompt("The wheels on the bus go") } val completion = openai.createCompletion(request)[0] diff --git a/examples/src/main/kotlin/completion/StreamCompletion.kt b/examples/src/main/kotlin/completion/StreamCompletion.kt new file mode 100644 index 0000000..ba39d52 --- /dev/null +++ b/examples/src/main/kotlin/completion/StreamCompletion.kt @@ -0,0 +1,28 @@ +package completion + +import com.cjcrafter.openai.completions.completionRequest +import com.cjcrafter.openai.openAI +import io.github.cdimascio.dotenv.dotenv + +/** + * In this Kotlin example, we will be using the Completions API to generate a + * response. We will stream the tokens 1 at a time for a faster response time. + */ +fun main() { + + // To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version" + // dependency. Then you can add a .env file in your project directory. + val key = dotenv()["OPENAI_TOKEN"] + val openai = openAI { apiKey(key) } + + // Here you can change the model's settings, add tools, and more. + val request = completionRequest { + model("davinci") + prompt("The wheels on the bus go") + maxTokens(500) + } + + for (chunk in openai.streamCompletion(request)) { + print(chunk.choices[0].text) + } +} diff --git a/src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt b/src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt index 81297c4..11b1f9c 100644 --- a/src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt +++ b/src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt @@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.* import com.cjcrafter.openai.completions.CompletionRequest import com.cjcrafter.openai.completions.CompletionResponse import com.cjcrafter.openai.completions.CompletionResponseChunk -import com.cjcrafter.openai.completions.CompletionUsage +import com.fasterxml.jackson.databind.JavaType import com.fasterxml.jackson.databind.node.ObjectNode import okhttp3.* import okhttp3.MediaType.Companion.toMediaType @@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor( .post(body).build() } - override fun createCompletion(request: CompletionRequest): CompletionResponse { - @Suppress("DEPRECATION") - request.stream = false // use streamCompletion for stream=true - val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - - val httpResponse = client.newCall(httpRequest).execute() - println(httpResponse) - - return CompletionResponse("1", 1, "1", listOf(), CompletionUsage(1, 1, 1)) - } - - override fun streamCompletion(request: CompletionRequest): Iterable { - @Suppress("DEPRECATION") - request.stream = true // use createCompletion for stream=false - val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) - - return listOf() - } - - override fun createChatCompletion(request: ChatRequest): ChatResponse { - @Suppress("DEPRECATION") - request.stream = false // use streamChatCompletion for stream=true - val httpRequest = buildRequest(request, CHAT_ENDPOINT) - + protected open fun executeRequest(httpRequest: Request, responseType: Class): T { val httpResponse = client.newCall(httpRequest).execute() if (!httpResponse.isSuccessful) { val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText() httpResponse.close() - throw IOException("Unexpected code $httpResponse, recieved: $json") + throw IOException("Unexpected code $httpResponse, received: $json") } - val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null") - val str = json.readText() - return objectMapper.readValue(str, ChatResponse::class.java) + val jsonReader = httpResponse.body?.byteStream()?.bufferedReader() + ?: throw IOException("Response body is null") + val responseStr = jsonReader.readText() + return objectMapper.readValue(responseStr, responseType) } - override fun streamChatCompletion(request: ChatRequest): Iterable { - request.stream = true // Set streaming to true - val httpRequest = buildRequest(request, CHAT_ENDPOINT) - - return object : Iterable { - override fun iterator(): Iterator { - val httpResponse = client.newCall(httpRequest).execute() + private fun streamResponses( + request: Request, + responseType: JavaType, + updateResponse: (T, String) -> T + ): Iterable { + return object : Iterable { + override fun iterator(): Iterator { + val httpResponse = client.newCall(request).execute() if (!httpResponse.isSuccessful) { httpResponse.close() throw IOException("Unexpected code $httpResponse") } - val reader = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null") + val reader = httpResponse.body?.byteStream()?.bufferedReader() + ?: throw IOException("Response body is null") - // Only instantiate 1 ChatResponseChunk, otherwise simply update - // the existing one. This lets us accumulate the message. - var chunk: ChatResponseChunk? = null + var currentResponse: T? = null - return object : Iterator { + return object : Iterator { private var nextLine: String? = readNextLine(reader) private fun readNextLine(reader: BufferedReader): String? { @@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor( reader.close() return null } - - // Check if the line starts with 'data:' and skip empty lines } while (line != null && (line.isEmpty() || !line.startsWith("data: "))) return line?.removePrefix("data: ") } @@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor( return nextLine != null } - override fun next(): ChatResponseChunk { - val currentLine = nextLine ?: throw NoSuchElementException("No more lines") - //println(" $currentLine") - chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode) } ?: objectMapper.readValue(currentLine, ChatResponseChunk::class.java) - nextLine = readNextLine(reader) // Prepare the next line - return chunk!! - //return ChatResponseChunk("1", 1, listOf()) + override fun next(): T { + val line = nextLine ?: throw NoSuchElementException("No more lines") + currentResponse = if (currentResponse == null) { + objectMapper.readValue(line, responseType) + } else { + updateResponse(currentResponse!!, line) + } + nextLine = readNextLine(reader) + return currentResponse!! } } } } } + override fun createCompletion(request: CompletionRequest): CompletionResponse { + @Suppress("DEPRECATION") + request.stream = false // use streamCompletion for stream=true + val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) + return executeRequest(httpRequest, CompletionResponse::class.java) + } + + override fun streamCompletion(request: CompletionRequest): Iterable { + @Suppress("DEPRECATION") + request.stream = true + val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT) + return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk::class.java)) { response, newLine -> + // We don't have any update logic, so we should ignore the old response and just return a new one + objectMapper.readValue(newLine, CompletionResponseChunk::class.java) + } + } + + override fun createChatCompletion(request: ChatRequest): ChatResponse { + @Suppress("DEPRECATION") + request.stream = false // use streamChatCompletion for stream=true + val httpRequest = buildRequest(request, CHAT_ENDPOINT) + return executeRequest(httpRequest, ChatResponse::class.java) + } + + override fun streamChatCompletion(request: ChatRequest): Iterable { + @Suppress("DEPRECATION") + request.stream = true + val httpRequest = buildRequest(request, CHAT_ENDPOINT) + return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk::class.java)) { response, newLine -> + response.update(objectMapper.readTree(newLine) as ObjectNode) + response + } + } + companion object { const val COMPLETIONS_ENDPOINT = "v1/completions" const val CHAT_ENDPOINT = "v1/chat/completions" - const val IMAGE_CREATE_ENDPOINT = "v1/images/generations" - const val IMAGE_EDIT_ENDPOINT = "v1/images/edits" - const val IMAGE_VARIATION_ENDPOINT = "v1/images/variations" } } \ No newline at end of file