Skip to content

Commit

Permalink
reimplement completions api
Browse files Browse the repository at this point in the history
  • Loading branch information
CJCrafter committed Nov 11, 2023
1 parent de6be93 commit b6b117a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 55 deletions.
4 changes: 2 additions & 2 deletions examples/src/main/kotlin/completion/Completion.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import com.cjcrafter.openai.openAI
import io.github.cdimascio.dotenv.dotenv

/**
* In this Kotlin example, we will be using the Chat API to create a simple chatbot.
* In this Kotlin example, we will be using the Completions API to generate a response.
*/
fun main() {

Expand All @@ -17,7 +17,7 @@ fun main() {
// Here you can change the model's settings, add tools, and more.
val request = completionRequest {
model("davinci")
prompt("What is 9+10")
prompt("The wheels on the bus go")
}

val completion = openai.createCompletion(request)[0]
Expand Down
28 changes: 28 additions & 0 deletions examples/src/main/kotlin/completion/StreamCompletion.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package completion

import com.cjcrafter.openai.completions.completionRequest
import com.cjcrafter.openai.openAI
import io.github.cdimascio.dotenv.dotenv

/**
* In this Kotlin example, we will be using the Completions API to generate a
* response. We will stream the tokens 1 at a time for a faster response time.
*/
fun main() {

// To use dotenv, you need to add the "io.github.cdimascio:dotenv-kotlin:version"
// dependency. Then you can add a .env file in your project directory.
val key = dotenv()["OPENAI_TOKEN"]
val openai = openAI { apiKey(key) }

// Here you can change the model's settings, add tools, and more.
val request = completionRequest {
model("davinci")
prompt("The wheels on the bus go")
maxTokens(500)
}

for (chunk in openai.streamCompletion(request)) {
print(chunk.choices[0].text)
}
}
115 changes: 62 additions & 53 deletions src/main/kotlin/com/cjcrafter/openai/OpenAIImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import com.cjcrafter.openai.chat.*
import com.cjcrafter.openai.completions.CompletionRequest
import com.cjcrafter.openai.completions.CompletionResponse
import com.cjcrafter.openai.completions.CompletionResponseChunk
import com.cjcrafter.openai.completions.CompletionUsage
import com.fasterxml.jackson.databind.JavaType
import com.fasterxml.jackson.databind.node.ObjectNode
import okhttp3.*
import okhttp3.MediaType.Companion.toMediaType
Expand Down Expand Up @@ -32,62 +32,40 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
.post(body).build()
}

override fun createCompletion(request: CompletionRequest): CompletionResponse {
@Suppress("DEPRECATION")
request.stream = false // use streamCompletion for stream=true
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)

val httpResponse = client.newCall(httpRequest).execute()
println(httpResponse)

return CompletionResponse("1", 1, "1", listOf(), CompletionUsage(1, 1, 1))
}

override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
@Suppress("DEPRECATION")
request.stream = true // use createCompletion for stream=false
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)

return listOf()
}

override fun createChatCompletion(request: ChatRequest): ChatResponse {
@Suppress("DEPRECATION")
request.stream = false // use streamChatCompletion for stream=true
val httpRequest = buildRequest(request, CHAT_ENDPOINT)

protected open fun <T> executeRequest(httpRequest: Request, responseType: Class<T>): T {
val httpResponse = client.newCall(httpRequest).execute()
if (!httpResponse.isSuccessful) {
val json = httpResponse.body?.byteStream()?.bufferedReader()?.readText()
httpResponse.close()
throw IOException("Unexpected code $httpResponse, recieved: $json")
throw IOException("Unexpected code $httpResponse, received: $json")
}

val json = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
val str = json.readText()
return objectMapper.readValue(str, ChatResponse::class.java)
val jsonReader = httpResponse.body?.byteStream()?.bufferedReader()
?: throw IOException("Response body is null")
val responseStr = jsonReader.readText()
return objectMapper.readValue(responseStr, responseType)
}

override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
request.stream = true // Set streaming to true
val httpRequest = buildRequest(request, CHAT_ENDPOINT)

return object : Iterable<ChatResponseChunk> {
override fun iterator(): Iterator<ChatResponseChunk> {
val httpResponse = client.newCall(httpRequest).execute()
private fun <T> streamResponses(
request: Request,
responseType: JavaType,
updateResponse: (T, String) -> T
): Iterable<T> {
return object : Iterable<T> {
override fun iterator(): Iterator<T> {
val httpResponse = client.newCall(request).execute()

if (!httpResponse.isSuccessful) {
httpResponse.close()
throw IOException("Unexpected code $httpResponse")
}

val reader = httpResponse.body?.byteStream()?.bufferedReader() ?: throw IOException("Response body is null")
val reader = httpResponse.body?.byteStream()?.bufferedReader()
?: throw IOException("Response body is null")

// Only instantiate 1 ChatResponseChunk, otherwise simply update
// the existing one. This lets us accumulate the message.
var chunk: ChatResponseChunk? = null
var currentResponse: T? = null

return object : Iterator<ChatResponseChunk> {
return object : Iterator<T> {
private var nextLine: String? = readNextLine(reader)

private fun readNextLine(reader: BufferedReader): String? {
Expand All @@ -98,8 +76,6 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
reader.close()
return null
}

// Check if the line starts with 'data:' and skip empty lines
} while (line != null && (line.isEmpty() || !line.startsWith("data: ")))
return line?.removePrefix("data: ")
}
Expand All @@ -108,24 +84,57 @@ open class OpenAIImpl @ApiStatus.Internal constructor(
return nextLine != null
}

override fun next(): ChatResponseChunk {
val currentLine = nextLine ?: throw NoSuchElementException("No more lines")
//println(" $currentLine")
chunk = chunk?.apply { update(objectMapper.readTree(currentLine) as ObjectNode) } ?: objectMapper.readValue(currentLine, ChatResponseChunk::class.java)
nextLine = readNextLine(reader) // Prepare the next line
return chunk!!
//return ChatResponseChunk("1", 1, listOf())
override fun next(): T {
val line = nextLine ?: throw NoSuchElementException("No more lines")
currentResponse = if (currentResponse == null) {
objectMapper.readValue(line, responseType)
} else {
updateResponse(currentResponse!!, line)
}
nextLine = readNextLine(reader)
return currentResponse!!
}
}
}
}
}

override fun createCompletion(request: CompletionRequest): CompletionResponse {
@Suppress("DEPRECATION")
request.stream = false // use streamCompletion for stream=true
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
return executeRequest(httpRequest, CompletionResponse::class.java)
}

override fun streamCompletion(request: CompletionRequest): Iterable<CompletionResponseChunk> {
@Suppress("DEPRECATION")
request.stream = true
val httpRequest = buildRequest(request, COMPLETIONS_ENDPOINT)
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(CompletionResponseChunk::class.java)) { response, newLine ->
// We don't have any update logic, so we should ignore the old response and just return a new one
objectMapper.readValue(newLine, CompletionResponseChunk::class.java)
}
}

override fun createChatCompletion(request: ChatRequest): ChatResponse {
@Suppress("DEPRECATION")
request.stream = false // use streamChatCompletion for stream=true
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
return executeRequest(httpRequest, ChatResponse::class.java)
}

override fun streamChatCompletion(request: ChatRequest): Iterable<ChatResponseChunk> {
@Suppress("DEPRECATION")
request.stream = true
val httpRequest = buildRequest(request, CHAT_ENDPOINT)
return streamResponses(httpRequest, objectMapper.typeFactory.constructType(ChatResponseChunk::class.java)) { response, newLine ->
response.update(objectMapper.readTree(newLine) as ObjectNode)
response
}
}

companion object {
const val COMPLETIONS_ENDPOINT = "v1/completions"
const val CHAT_ENDPOINT = "v1/chat/completions"
const val IMAGE_CREATE_ENDPOINT = "v1/images/generations"
const val IMAGE_EDIT_ENDPOINT = "v1/images/edits"
const val IMAGE_VARIATION_ENDPOINT = "v1/images/variations"
}
}

0 comments on commit b6b117a

Please sign in to comment.