feat: tool abstraction
plutoless committed Oct 22, 2024
1 parent 75b0853 commit 01d95ac
Showing 14 changed files with 526 additions and 227 deletions.
4 changes: 2 additions & 2 deletions agents/property.json
@@ -1640,7 +1640,7 @@
},
{
"name": "camera_va_openai_azure",
"auto_start": false,
"auto_start": true,
"nodes": [
{
"type": "extension",
@@ -2408,7 +2408,7 @@
},
{
"name": "va_openai_v2v",
"auto_start": true,
"auto_start": false,
"nodes": [
{
"type": "extension",
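These two hunks swap which predefined graph auto-starts: camera_va_openai_azure now has auto_start set to true, while va_openai_v2v is switched off. To confirm which graphs will auto-start after an edit like this, a short check such as the sketch below can help; the file path comes from the diff header above, and since the surrounding JSON structure is not shown in this excerpt, the check simply walks the whole document rather than assuming a particular key.

```python
import json
from pathlib import Path


def iter_graphs(node):
    """Yield any dict that looks like a graph entry, i.e. has "name" and "auto_start"."""
    if isinstance(node, dict):
        if "name" in node and "auto_start" in node:
            yield node
        for value in node.values():
            yield from iter_graphs(value)
    elif isinstance(node, list):
        for item in node:
            yield from iter_graphs(item)


# Path taken from the diff header; adjust to your checkout.
prop = json.loads(Path("agents/property.json").read_text())
for graph in iter_graphs(prop):
    print(f'{graph["name"]}: auto_start={graph["auto_start"]}')
```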
185 changes: 51 additions & 134 deletions agents/ten_packages/extension/openai_chatgpt_python/extension.py
@@ -9,18 +9,22 @@
import json
import traceback

from .helper import AsyncEventEmitter, AsyncQueue, get_current_time, get_property_bool, get_property_float, get_property_int, get_property_string, parse_sentences, rgb2base64jpeg
from ten.async_ten_env import AsyncTenEnv
from ..ten_llm_base.helper import AsyncEventEmitter, AsyncQueue, get_properties_int, get_properties_string, get_properties_float, get_property_bool, get_property_int, get_property_string

from .helper import parse_sentences, rgb2base64jpeg
from .openai import OpenAIChatGPT, OpenAIChatGPTConfig
from ten import (
AudioFrame,
VideoFrame,
AsyncExtension,
TenEnv,
Cmd,
StatusCode,
CmdResult,
Data,
)
from ..ten_llm_base.extension import TenLLMAudioCompletionArgs, TenLLMBaseExtension, TenLLMDataType, TenLLMTextCompletionArgs

from .log import logger

CMD_IN_FLUSH = "flush"
@@ -52,21 +56,15 @@
TASK_TYPE_CHAT_COMPLETION_WITH_VISION = "chat_completion_with_vision"


class OpenAIChatGPTExtension(AsyncExtension):
class OpenAIChatGPTExtension(TenLLMBaseExtension):
memory = []
max_memory_length = 10
openai_chatgpt = None
enable_tools = False
image_data = None
image_width = 0
image_height = 0
checking_vision_text_items = []
loop = None
sentence_fragment = ""

# Create the queue for message processing
queue = AsyncQueue()

available_tools = [
{
"type": "function",
@@ -81,58 +79,35 @@ class OpenAIChatGPTExtension(AsyncExtension):

async def on_init(self, ten_env: TenEnv) -> None:
ten_env.log_info("on_init")
await super().on_init(ten_env)
ten_env.on_init_done()

async def on_start(self, ten_env: TenEnv) -> None:
ten_env.log_info("on_start")

self.loop = asyncio.get_event_loop()
self.loop.create_task(self._process_queue(ten_env))
await super().on_start(ten_env)

# Prepare configuration
openai_chatgpt_config = OpenAIChatGPTConfig.default_config()

# Mandatory properties
openai_chatgpt_config.base_url = get_property_string(
ten_env, PROPERTY_BASE_URL) or openai_chatgpt_config.base_url
openai_chatgpt_config.api_key = get_property_string(
ten_env, PROPERTY_API_KEY)
get_properties_string(ten_env, [PROPERTY_BASE_URL, PROPERTY_API_KEY], lambda name, value: setattr(
openai_chatgpt_config, name, value or getattr(openai_chatgpt_config, name)))
if not openai_chatgpt_config.api_key:
ten_env.log_info(f"API key is missing, exiting on_start")
return

# Optional properties
openai_chatgpt_config.model = get_property_string(
ten_env, PROPERTY_MODEL) or openai_chatgpt_config.model
openai_chatgpt_config.prompt = get_property_string(
ten_env, PROPERTY_PROMPT) or openai_chatgpt_config.prompt
openai_chatgpt_config.frequency_penalty = get_property_float(
ten_env, PROPERTY_FREQUENCY_PENALTY) or openai_chatgpt_config.frequency_penalty
openai_chatgpt_config.presence_penalty = get_property_float(
ten_env, PROPERTY_PRESENCE_PENALTY) or openai_chatgpt_config.presence_penalty
openai_chatgpt_config.temperature = get_property_float(
ten_env, PROPERTY_TEMPERATURE) or openai_chatgpt_config.temperature
openai_chatgpt_config.top_p = get_property_float(
ten_env, PROPERTY_TOP_P) or openai_chatgpt_config.top_p
openai_chatgpt_config.max_tokens = get_property_int(
ten_env, PROPERTY_MAX_TOKENS) or openai_chatgpt_config.max_tokens
openai_chatgpt_config.proxy_url = get_property_string(
ten_env, PROPERTY_PROXY_URL) or openai_chatgpt_config.proxy_url
get_properties_string(ten_env, [PROPERTY_MODEL, PROPERTY_PROMPT, PROPERTY_PROXY_URL], lambda name, value: setattr(
openai_chatgpt_config, name, value or getattr(openai_chatgpt_config, name)))
get_properties_float(ten_env, [PROPERTY_FREQUENCY_PENALTY, PROPERTY_PRESENCE_PENALTY, PROPERTY_TEMPERATURE, PROPERTY_TOP_P], lambda name, value: setattr(
openai_chatgpt_config, name, value or getattr(openai_chatgpt_config, name)))
get_properties_int(ten_env, [PROPERTY_MAX_TOKENS], lambda name, value: setattr(
openai_chatgpt_config, name, value or getattr(openai_chatgpt_config, name)))

# Properties that don't affect openai_chatgpt_config
self.greeting = get_property_string(ten_env, PROPERTY_GREETING)
self.enable_tools = get_property_bool(ten_env, PROPERTY_ENABLE_TOOLS)
self.max_memory_length = get_property_int(
ten_env, PROPERTY_MAX_MEMORY_LENGTH)
checking_vision_text_items_str = get_property_string(
ten_env, PROPERTY_CHECKING_VISION_TEXT_ITEMS)
if checking_vision_text_items_str:
try:
self.checking_vision_text_items = json.loads(
checking_vision_text_items_str)
except Exception as err:
ten_env.log_info(
f"Error parsing {PROPERTY_CHECKING_VISION_TEXT_ITEMS}: {err}")
self.users_count = 0

# Create instance
@@ -147,21 +122,20 @@ async def on_start(self, ten_env: TenEnv) -> None:

async def on_stop(self, ten_env: TenEnv) -> None:
ten_env.log_info("on_stop")

# TODO: clean up resources

await super().on_stop(ten_env)
ten_env.on_stop_done()

async def on_deinit(self, ten_env: TenEnv) -> None:
ten_env.log_info("on_deinit")
await super().on_deinit(ten_env)
ten_env.on_deinit_done()

async def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
cmd_name = cmd.get_name()
ten_env.log_info(f"on_cmd name: {cmd_name}")

if cmd_name == CMD_IN_FLUSH:
await self._flush_queue(ten_env)
await self.flush_input_items(ten_env)
ten_env.send_cmd(Cmd.create(CMD_OUT_FLUSH), None)
ten_env.log_info("on_cmd sent flush")
status_code, detail = StatusCode.OK, "success"
@@ -193,87 +167,47 @@ async def on_cmd(self, ten_env: TenEnv, cmd: Cmd) -> None:
cmd_result.set_property_string("detail", detail)
ten_env.return_result(cmd_result, cmd)

async def on_data(self, ten_env: TenEnv, data: Data) -> None:
async def on_data(self, ten_env: AsyncTenEnv, data: Data) -> None:
data_name = data.get_name()
ten_env.log_debug("on_data name {}".format(data_name))

# Get the necessary properties
is_final = get_property_bool(data, DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
input_text = get_property_string(data, DATA_IN_TEXT_DATA_PROPERTY_TEXT)
is_final = get_property_bool(data, "is_final")
input_text = get_property_string(data, "text")

if not is_final:
ten_env.log_info("ignore non-final input")
ten_env.log_debug("ignore non-final input")
return
if not input_text:
ten_env.log_info("ignore empty text")
ten_env.log_warn("ignore empty text")
return

ten_env.log_info(f"OnData input text: [{input_text}]")
ten_env.log_debug(f"OnData input text: [{input_text}]")

# Start an asynchronous task for handling chat completion
await self.queue.put([TASK_TYPE_CHAT_COMPLETION, input_text])

async def on_audio_frame(self, ten_env: TenEnv, audio_frame: AudioFrame) -> None:
# TODO: process pcm frame
pass

async def on_video_frame(self, ten_env: TenEnv, video_frame: VideoFrame) -> None:
# ten_env.log_info(f"OpenAIChatGPTExtension on_video_frame {frame.get_width()} {frame.get_height()}")
self.image_data = video_frame.get_buf()
self.image_width = video_frame.get_width()
self.image_height = video_frame.get_height()
return

async def _process_queue(self, ten_env: TenEnv):
"""Asynchronously process queue items one by one."""
while True:
# Wait for an item to be available in the queue
[task_type, message] = await self.queue.get()
try:
# Create a new task for the new message
self.current_task = asyncio.create_task(
self._run_chatflow(ten_env, task_type, message, self.memory))
await self.current_task # Wait for the current task to finish or be cancelled
except asyncio.CancelledError:
ten_env.log_info(f"Task cancelled: {message}")

async def _flush_queue(self, ten_env: TenEnv):
"""Flushes the self.queue and cancels the current task."""
# Flush the queue using the new flush method
await self.queue.flush()

# Cancel the current task if one is running
if self.current_task:
ten_env.log_info("Cancelling the current task during flush.")
self.current_task.cancel()

async def _run_chatflow(self, ten_env: TenEnv, task_type: str, input_text: str, memory):
await self.queue_input_item(TenLLMDataType.TEXT, [{
"type": "text",
"text": input_text,
}])

async def on_audio_completion(self, ten_env: TenEnv, message: any, **kargs: TenLLMAudioCompletionArgs) -> None:
return await super().on_audio_completion(ten_env, message, **kargs)

async def on_text_completion(self, ten_env: TenEnv, content: any, **kargs: TenLLMTextCompletionArgs) -> None:
"""Run the chatflow asynchronously."""
memory_cache = []
memory = self.memory
try:
ten_env.log_info(f"for input text: [{input_text}] memory: {memory}")
ten_env.log_info(f"for input text: [{content}] memory: {memory}")
message = None
tools = None
no_tool = kargs.get("no_tool", False)

# Prepare the message and tools based on the task type
if task_type == TASK_TYPE_CHAT_COMPLETION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
elif task_type == TASK_TYPE_CHAT_COMPLETION_WITH_VISION:
message = {"role": "user", "content": input_text}
memory_cache = memory_cache + \
[message, {"role": "assistant", "content": ""}]
tools = self.available_tools if self.enable_tools else None
if self.image_data is not None:
url = rgb2base64jpeg(
self.image_data, self.image_width, self.image_height)
message = {
"role": "user",
"content": [
{"type": "text", "text": input_text},
{"type": "image_url", "image_url": {"url": url}},
],
}
ten_env.log_info(f"msg with vision data: {message}")
message = {"role": "user", "content": content}
non_artifact_content = [item for item in content if item.get("type") == "text"]
non_artifact_message = {"role": "user", "content": non_artifact_content}
memory_cache = memory_cache + [non_artifact_message, {"role": "assistant", "content": ""}]
tools = self.available_tools if not no_tool else None

self.sentence_fragment = ""

@@ -285,7 +219,7 @@ async def handle_tool_call(tool_call):
ten_env.log_info(f"tool_call: {tool_call}")
if tool_call.function.name == "get_vision_image":
# Append the vision image to the last assistant message
await self.queue.put([TASK_TYPE_CHAT_COMPLETION_WITH_VISION, input_text], True)
pass

async def handle_content_update(content: str):
# Append the content to the last assistant message
@@ -296,7 +230,7 @@ async def handle_content_update(content: str):
sentences, self.sentence_fragment = parse_sentences(
self.sentence_fragment, content)
for s in sentences:
self._send_data(ten_env, s, False)
self.send_text_output(ten_env, s, False)

async def handle_content_finished(full_content: str):
content_finished_event.set()
@@ -312,12 +246,12 @@ async def handle_content_finished(full_content: str):
# Wait for the content to be finished
await content_finished_event.wait()
except asyncio.CancelledError:
ten_env.log_info(f"Task cancelled: {input_text}")
ten_env.log_info(f"Task cancelled: {content}")
except Exception as e:
logger.error(
f"Error in chat_completion: {traceback.format_exc()} for input text: {input_text}")
f"Error in chat_completion: {traceback.format_exc()} for input text: {content}")
finally:
self._send_data(ten_env, "", True)
self.send_text_output(ten_env, "", True)
# always append the memory
for m in memory_cache:
self._append_memory(m)
@@ -326,20 +260,3 @@ def _append_memory(self, message: str):
if len(self.memory) > self.max_memory_length:
self.memory.pop(0)
self.memory.append(message)

def _send_data(self, ten_env: TenEnv, sentence: str, end_of_segment: bool):
try:
output_data = Data.create("text_data")
output_data.set_property_string(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
output_data.set_property_bool(
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, end_of_segment
)
ten_env.send_data(output_data)
ten_env.log_info(
f"{'end of segment ' if end_of_segment else ''}sent sentence [{sentence}]"
)
except Exception as err:
ten_env.log_info(
f"send sentence [{sentence}] failed, err: {err}"
)