From 21cad8f793c1868e00118b58c37c722cdd0e4fe1 Mon Sep 17 00:00:00 2001
From: LemonHX
Date: Thu, 9 May 2024 18:05:28 +0800
Subject: [PATCH] support deepseek

---
 grpcServer/relay.go             |   6 +-
 grpcServer/relay_gin.go         |   3 +
 relay/reqTransformer/ChatGPT.go |   2 +
 tests_grpc/deepseek_test.go     | 133 ++++++++++++++++++++++++++++++++
 4 files changed, 143 insertions(+), 1 deletion(-)
 create mode 100644 tests_grpc/deepseek_test.go

diff --git a/grpcServer/relay.go b/grpcServer/relay.go
index 51bd56f..08370ae 100644
--- a/grpcServer/relay.go
+++ b/grpcServer/relay.go
@@ -28,6 +28,7 @@ const AZURE_OPENAI_LLM_API = "azure_openai"
 const BAICHUAN_LLM_API = "baichuan"
 const GEMINI_LLM_API = "gemini"
 const MOONSHOT_LLM_API = "moonshot"
+const DEEPSEEK_LLM_API = "deepseek"
 
 func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.LLMRequestSchema) (*model.LLMResponseSchema, error) {
 	info := rs.GetLlmRequestInfo()
@@ -35,7 +36,8 @@ func (uno *UnoForwardServer) BlockingRequestLLM(ctx context.Context, rs *model.L
 	case OPENAI_LLM_API:
 		cli := NewOpenAIClient(info)
 		return OpenAIChatCompletion(cli, rs)
-
+	case DEEPSEEK_LLM_API:
+		fallthrough
 	case MOONSHOT_LLM_API:
 		cli := NewOpenAIClient(info)
 		if functionCallingRequestMake(rs) {
@@ -82,6 +84,8 @@ func (uno *UnoForwardServer) StreamRequestLLM(rs *model.LLMRequestSchema, sv mod
 	case OPENAI_LLM_API:
 		cli := NewOpenAIClient(info)
 		return OpenAIChatCompletionStreaming(cli, rs, sv)
+	case DEEPSEEK_LLM_API:
+		fallthrough
 	case MOONSHOT_LLM_API:
 		cli := NewOpenAIClient(info)
 		if functionCallingRequestMake(rs) {
diff --git a/grpcServer/relay_gin.go b/grpcServer/relay_gin.go
index 5539a5f..5e3629b 100644
--- a/grpcServer/relay_gin.go
+++ b/grpcServer/relay_gin.go
@@ -29,6 +29,9 @@ func getProvider(m string) (string, error) {
 	if strings.Contains(m, "moonshot") {
 		return "moonshot", nil
 	}
+	if strings.Contains(m, "deepseek") {
+		return "deepseek", nil
+	}
 	return "", errors.New("could not get provider")
 }
 
diff --git a/relay/reqTransformer/ChatGPT.go b/relay/reqTransformer/ChatGPT.go
index 814e73b..29f3480 100644
--- a/relay/reqTransformer/ChatGPT.go
+++ b/relay/reqTransformer/ChatGPT.go
@@ -77,6 +77,8 @@ func ChatGPTToGrpcRequest(api string, model_type string, token string, req opena
 	switch api {
 	case "moonshot":
 		url = "https://api.moonshot.cn/v1"
+	case "deepseek":
+		url = "https://api.deepseek.com/v1"
 	}
 	return &model.LLMRequestSchema{
 		Messages: messages,
diff --git a/tests_grpc/deepseek_test.go b/tests_grpc/deepseek_test.go
new file mode 100644
index 0000000..68de986
--- /dev/null
+++ b/tests_grpc/deepseek_test.go
@@ -0,0 +1,133 @@
+package tests_grpc_test
+
+import (
+	"context"
+	"log"
+	"os"
+	"testing"
+
+	"github.com/joho/godotenv"
+	"go.limit.dev/unollm/grpcServer"
+	"go.limit.dev/unollm/model"
+	"go.limit.dev/unollm/utils"
+)
+
+func TestDeepSeek(t *testing.T) {
+	godotenv.Load("../.env")
+
+	messages := make([]*model.LLMChatCompletionMessage, 0)
+	messages = append(messages, &model.LLMChatCompletionMessage{
+		Role:    "user",
+		Content: "假如今天下大雨,我是否需要带伞?", // i.e. "If it rains heavily today, do I need to bring an umbrella?"
+	})
+	deepSeekApiKey := os.Getenv("TEST_DEEPSEEK_API")
+	req_info := model.LLMRequestInfo{
+		LlmApiType:  grpcServer.DEEPSEEK_LLM_API,
+		Model:       "deepseek-chat",
+		Temperature: 0.9,
+		TopP:        0.9,
+		TopK:        1,
+		Url:         "https://api.deepseek.com/v1",
+		Token:       deepSeekApiKey,
+	}
+	req := model.LLMRequestSchema{
+		Messages:       messages,
+		LlmRequestInfo: &req_info,
+	}
+	mockServer := grpcServer.UnoForwardServer{}
+	res, err := mockServer.BlockingRequestLLM(context.Background(), &req)
+	if err != nil {
+		t.Fatal(err)
+	}
log.Println("res: ", res) +} + +func TestDeepSeekStreaming(t *testing.T) { + godotenv.Load("../.env") + + messages := make([]*model.LLMChatCompletionMessage, 0) + messages = append(messages, &model.LLMChatCompletionMessage{ + Role: "user", + Content: "假如今天下大雨,我是否需要带伞?", + }) + OPENAIApiKey := os.Getenv("TEST_DEEPSEEK_API") + req_info := model.LLMRequestInfo{ + LlmApiType: grpcServer.DEEPSEEK_LLM_API, + Model: "deepseek-chat", + Temperature: 0.9, + TopP: 0.9, + TopK: 1, + Url: "https://api.deepseek.com/v1", + Token: OPENAIApiKey, + } + req := model.LLMRequestSchema{ + Messages: messages, + LlmRequestInfo: &req_info, + } + mockServer := grpcServer.UnoForwardServer{} + mockServerPipe := utils.MockServerStream{ + Stream: make(chan *model.PartialLLMResponse, 1000), + } + err := mockServer.StreamRequestLLM(&req, &mockServerPipe) + if err != nil { + t.Fatal(err) + } + for { + res := <-mockServerPipe.Stream + log.Println(res) + if res.LlmTokenCount != nil { + log.Println(res.LlmTokenCount) + return + } + } +} + +func TestDeepSeekFunctionCalling(t *testing.T) { + godotenv.Load("../.env") + + messages := make([]*model.LLMChatCompletionMessage, 0) + messages = append(messages, &model.LLMChatCompletionMessage{ + Role: "user", + Content: "whats the weather like in Poston?", + }) + OPENAIApiKey := os.Getenv("TEST_DEEPSEEK_API") + req_info := model.LLMRequestInfo{ + LlmApiType: grpcServer.DEEPSEEK_LLM_API, + Model: "deepseek-chat", + Temperature: 0.9, + TopP: 0.9, + TopK: 1, + Url: "https://api.deepseek.com/v1", + Token: OPENAIApiKey, + Functions: []*model.Function{ + { + Name: "get_weather", + Description: "Get the weather of a location", + Parameters: []*model.FunctionCallingParameter{ + { + Name: "location", + Type: "string", + Description: "The city and state, e.g. San Francisco, CA", + }, + { + Name: "unit", + Type: "string", + Enums: []string{"celsius", "fahrenheit"}, + }, + }, + Requireds: []string{"location", "unit"}, + }, + }, + UseFunctionCalling: true, + } + req := model.LLMRequestSchema{ + Messages: messages, + LlmRequestInfo: &req_info, + } + mockServer := grpcServer.UnoForwardServer{} + res, err := mockServer.BlockingRequestLLM(context.Background(), &req) + if err != nil { + t.Fatal(err) + } + log.Printf("res: %#v", res.ToolCalls[0]) +}