Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

agents: internationalize and add Chinese prompt #1001

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,19 @@ package agents

import (
"context"
_ "embed"
"fmt"
"regexp"
"strings"

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)

const (
_conversationalFinalAnswerAction = "AI:"
)

// ConversationalAgent is a struct that represents an agent responsible for deciding
// what to do or give the final output if the task is finished given a set of inputs
// and previous steps taken.
Expand All @@ -34,6 +30,10 @@ type ConversationalAgent struct {
Tools []tools.Tool
// Output key is the key where the final output is placed.
OutputKey string
// FinalAnswer is the final answer in various languages.
FinalAnswer string
// Lang is the language the prompt will use.
Lang i18n.Lang
// CallbacksHandler is the handler for callbacks.
CallbacksHandler callbacks.Handler
}
Expand All @@ -45,6 +45,7 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
for _, opt := range opts {
opt(&options)
}
options.loadConversationalTranslatable()

return &ConversationalAgent{
Chain: chains.NewLLMChain(
Expand All @@ -54,6 +55,8 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
),
Tools: tools,
OutputKey: options.outputKey,
FinalAnswer: i18n.AgentsMustPhrase(options.lang, "conversational final answer"),
Lang: options.lang,
CallbacksHandler: options.callbacksHandler,
}
}
Expand All @@ -69,7 +72,7 @@ func (a *ConversationalAgent) Plan(
fullInputs[key] = value
}

fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps)
fullInputs["agent_scratchpad"] = constructScratchPad(intermediateSteps, a.Lang)

var stream func(ctx context.Context, chunk []byte) error

Expand All @@ -84,7 +87,10 @@ func (a *ConversationalAgent) Plan(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStopWords([]string{
fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
fmt.Sprintf("\n\t%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
}),
chains.WithStreamingFunc(stream),
)
if err != nil {
Expand Down Expand Up @@ -117,22 +123,22 @@ func (a *ConversationalAgent) GetTools() []tools.Tool {
return a.Tools
}

func constructScratchPad(steps []schema.AgentStep) string {
func constructScratchPad(steps []schema.AgentStep, lang i18n.Lang) string {
var scratchPad string
if len(steps) > 0 {
for _, step := range steps {
scratchPad += step.Action.Log
scratchPad += "\nObservation: " + step.Observation
scratchPad += fmt.Sprintf("\n%s %s", i18n.AgentsMustPhrase(lang, "observation"), step.Observation)
}
scratchPad += "\n" + "Thought:"
scratchPad += fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(lang, "thought"))
}

return scratchPad
}

func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
if strings.Contains(output, _conversationalFinalAnswerAction) {
splits := strings.Split(output, _conversationalFinalAnswerAction)
if strings.Contains(output, a.FinalAnswer) {
splits := strings.Split(output, a.FinalAnswer)

finishAction := &schema.AgentFinish{
ReturnValues: map[string]any{
Expand All @@ -144,7 +150,9 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
return nil, finishAction, nil
}

r := regexp.MustCompile(`Action: (.*?)[\n]*Action Input: (.*)`)
action, actionInput := i18n.AgentsMustPhrase(a.Lang, "action"),
i18n.AgentsMustPhrase(a.Lang, "action input")
r := regexp.MustCompile(fmt.Sprintf(`%s (.*?)[\n]*%s (.*)`, action, actionInput))
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
Expand All @@ -155,15 +163,6 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
}, nil, nil
}

//go:embed prompts/conversational_prefix.txt
var _defaultConversationalPrefix string //nolint:gochecknoglobals

//go:embed prompts/conversational_format_instructions.txt
var _defaultConversationalFormatInstructions string //nolint:gochecknoglobals

//go:embed prompts/conversational_suffix.txt
var _defaultConversationalSuffix string //nolint:gochecknoglobals

func createConversationalPrompt(tools []tools.Tool, prefix, instructions, suffix string) prompts.PromptTemplate {
template := strings.Join([]string{prefix, instructions, suffix}, "\n\n")

Expand Down
1 change: 1 addition & 0 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func NewExecutor(agent Agent, opts ...Option) *Executor {
for _, opt := range opts {
opt(&options)
}
options.loadExecutorTranslatable()

return &Executor{
Agent: agent,
Expand Down
40 changes: 27 additions & 13 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,12 @@ import (

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)

const (
_finalAnswerAction = "Final Answer:"
_defaultOutputKey = "output"
)

// OneShotZeroAgent is a struct that represents an agent responsible for deciding
// what to do or give the final output if the task is finished given a set of inputs
// and previous steps taken.
Expand All @@ -32,6 +28,10 @@ type OneShotZeroAgent struct {
Tools []tools.Tool
// Output key is the key where the final output is placed.
OutputKey string
// FinalAnswer is the final answer in various languages.
FinalAnswer string
// Lang is the language the prompt will use.
Lang i18n.Lang
// CallbacksHandler is the handler for callbacks.
CallbacksHandler callbacks.Handler
}
Expand All @@ -46,6 +46,7 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
for _, opt := range opts {
opt(&options)
}
options.loadMrklTranslatable()

return &OneShotZeroAgent{
Chain: chains.NewLLMChain(
Expand All @@ -55,6 +56,8 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
),
Tools: tools,
OutputKey: options.outputKey,
FinalAnswer: i18n.AgentsMustPhrase(options.lang, "mrkl final answer"),
Lang: options.lang,
CallbacksHandler: options.callbacksHandler,
}
}
Expand All @@ -70,8 +73,8 @@ func (a *OneShotZeroAgent) Plan(
fullInputs[key] = value
}

fullInputs["agent_scratchpad"] = constructMrklScratchPad(intermediateSteps)
fullInputs["today"] = time.Now().Format("January 02, 2006")
fullInputs["agent_scratchpad"] = constructMrklScratchPad(intermediateSteps, a.Lang)
fullInputs["today"] = time.Now().Format(i18n.AgentsMustPhrase(a.Lang, "today format"))

var stream func(ctx context.Context, chunk []byte) error

Expand All @@ -86,7 +89,10 @@ func (a *OneShotZeroAgent) Plan(
ctx,
a.Chain,
fullInputs,
chains.WithStopWords([]string{"\nObservation:", "\n\tObservation:"}),
chains.WithStopWords([]string{
fmt.Sprintf("\n%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
fmt.Sprintf("\n\t%s", i18n.AgentsMustPhrase(a.Lang, "observation")),
}),
chains.WithStreamingFunc(stream),
)
if err != nil {
Expand Down Expand Up @@ -119,21 +125,21 @@ func (a *OneShotZeroAgent) GetTools() []tools.Tool {
return a.Tools
}

func constructMrklScratchPad(steps []schema.AgentStep) string {
func constructMrklScratchPad(steps []schema.AgentStep, lang i18n.Lang) string {
var scratchPad string
if len(steps) > 0 {
for _, step := range steps {
scratchPad += "\n" + step.Action.Log
scratchPad += "\nObservation: " + step.Observation + "\n"
scratchPad += fmt.Sprintf("\n%s %s\n", i18n.AgentsMustPhrase(lang, "observation"), step.Observation)
}
}

return scratchPad
}

func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
if strings.Contains(output, _finalAnswerAction) {
splits := strings.Split(output, _finalAnswerAction)
if strings.Contains(output, a.FinalAnswer) {
splits := strings.Split(output, a.FinalAnswer)

return nil, &schema.AgentFinish{
ReturnValues: map[string]any{
Expand All @@ -143,7 +149,15 @@ func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *sc
}, nil
}

r := regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`)
action, actionInput, observation := i18n.AgentsMustPhrase(a.Lang, "action"),
i18n.AgentsMustPhrase(a.Lang, "action input"),
i18n.AgentsMustPhrase(a.Lang, "observation")
var r *regexp.Regexp
if strings.Contains(output, observation) {
r = regexp.MustCompile(fmt.Sprintf(`%s\s*(.+)\s*%s\s(?s)*(.+)%s`, action, actionInput, observation))
} else {
r = regexp.MustCompile(fmt.Sprintf(`%s\s*(.+)\s*%s\s(?s)*(.+)`, action, actionInput))
}
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
Expand Down
23 changes: 0 additions & 23 deletions agents/mrkl_prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,6 @@ import (
"github.com/tmc/langchaingo/tools"
)

const (
_defaultMrklPrefix = `Today is {{.today}}.
Answer the following questions as best you can. You have access to the following tools:

{{.tool_descriptions}}`

_defaultMrklFormatInstructions = `Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [ {{.tool_names}} ]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question`

_defaultMrklSuffix = `Begin!

Question: {{.input}}
{{.agent_scratchpad}}`
)

func createMRKLPrompt(tools []tools.Tool, prefix, instructions, suffix string) prompts.PromptTemplate {
template := strings.Join([]string{prefix, instructions, suffix}, "\n\n")

Expand Down
7 changes: 6 additions & 1 deletion agents/markl_test.go → agents/mrkl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/i18n"
"github.com/tmc/langchaingo/schema"
)

Expand Down Expand Up @@ -38,7 +39,11 @@ func TestMRKLOutputParser(t *testing.T) {
},
}

a := OneShotZeroAgent{}
lang := i18n.EN
a := OneShotZeroAgent{
FinalAnswer: i18n.AgentsMustPhrase(lang, "mrkl final answer"),
Lang: lang,
}
for _, tc := range testCases {
actions, finish, err := a.parseOutput(tc.input)
require.ErrorIs(t, tc.expectedErr, err)
Expand Down
1 change: 1 addition & 0 deletions agents/openai_functions_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
for _, opt := range opts {
opt(&options)
}
options.loadOpenAIFunctionsTranslatable()

return &OpenAIFunctionsAgent{
LLM: llm,
Expand Down
Loading