Skip to content

Commit

Permalink
Merge pull request #862 from bjwswang/main
Browse files Browse the repository at this point in the history
chore: use same stream options between mpchain and others
  • Loading branch information
bjwswang authored Mar 15, 2024
2 parents b4d1808 + 289a7f5 commit 9fd5da9
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 64 deletions.
8 changes: 4 additions & 4 deletions deploy/charts/gpu-operator/Chart.lock
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dependencies:
- name: node-feature-discovery
repository: https://kubernetes-sigs.github.io/node-feature-discovery/charts
version: 0.14.2
digest: sha256:84ec59c0c12da825ca7dc25bdac63d0f2106822a129f7fe1f9d60a4023a543ce
generated: "2023-10-10T11:26:00.823757+02:00"
repository: ""
version: v0.14.2
digest: sha256:8c36de373825e52e288835695ed9298337359d6b65f56971feecb6ac907dcb89
generated: "2024-03-15T02:54:39.089270185Z"
37 changes: 0 additions & 37 deletions deploy/gpu-operator/nvidia_gpu.yaml

This file was deleted.

20 changes: 11 additions & 9 deletions pkg/appruntime/chain/llmchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any
instance := l.Instance
options := GetChainOptions(instance.Spec.CommonChainConfig)

needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
}

// Check if have files as input
v3, ok := args["documents"]
if ok {
Expand Down Expand Up @@ -117,18 +123,14 @@ func (l *LLMChain) Run(ctx context.Context, _ client.Client, args map[string]any
l.LLMChain = *chain

var out string
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))

// Predict based on options
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.LLMChain, args, options...)
} else {
out, err = chains.Predict(ctx, l.LLMChain, args)
}
out, err = chains.Predict(ctx, l.LLMChain, args)
}

out, err = handleNoErrNoOut(ctx, needStream, out, err, l.LLMChain, args, options)
klog.FromContext(ctx).V(5).Info("use llmchain, blocking out:" + out)
if err == nil {
Expand Down
12 changes: 7 additions & 5 deletions pkg/appruntime/chain/mpchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,18 @@ import (
const (
// For map-reduce
DefaultPromptTemplateForMap = `
Content: {{.context}}
As an expert document summarizer, please provide a concise summary of the following content based on your expertise. Don't worry about the length of the summary:
With above content, please summarize it with only 1/5 size of the content.Please remind that your answer must use same language(中文或English) of the content.
Content: {{.context}}
Please note that your response should be in the same language as the content (English or Chinese).
`
DefaultPromptTemplatForReduce = `
Below is the sub-summaries that each is based on a piece of a complete document:
After segmenting the document and generating sub-summaries for each section, it is now time to create a comprehensive summary. Below are the sub-summaries, each based on a specific part of the complete document:
{{.context}}
{{.context}}
Please generate a single summary based on above sub-summaries.
Please generate a cohesive summary that encapsulates the main points from the provided sub-summaries.
`
)

Expand Down
19 changes: 10 additions & 9 deletions pkg/appruntime/chain/retrievalqachain.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st

instance := l.Instance
options := GetChainOptions(instance.Spec.CommonChainConfig)
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))
}

// Check if have files as input
v5, ok := args["documents"]
Expand Down Expand Up @@ -119,18 +124,14 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ client.Client, args map[st
l.ConversationalRetrievalQA = chain
args["query"] = args["question"]
var out string
needStream := false
needStream, ok = args[base.InputIsNeedStreamKeyInArg].(bool)
if ok && needStream {
options = append(options, chains.WithStreamingFunc(stream(args)))

// Predict based on options
if len(options) > 0 {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
} else {
if len(options) > 0 {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
} else {
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args)
}
out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args)
}

out, err = handleNoErrNoOut(ctx, needStream, out, err, l.ConversationalRetrievalQA, args, options)
klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out)
if err == nil {
Expand Down

0 comments on commit 9fd5da9

Please sign in to comment.