From a6b22ec4e07522e41c9d352024dd3b4dd4c8b6db Mon Sep 17 00:00:00 2001 From: userpj Date: Wed, 11 Sep 2024 16:04:45 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0ToolChoice=20java&&go?= =?UTF-8?q?=E5=8D=95=E6=B5=8B=20(#511)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go/appbuilder/app_builder_client_data.go | 4 +- go/appbuilder/app_builder_client_test.go | 54 +++++++++++++++++-- .../appbuilder/AppBuilderClientTest.java | 42 +++++++++++---- 3 files changed, 86 insertions(+), 14 deletions(-) diff --git a/go/appbuilder/app_builder_client_data.go b/go/appbuilder/app_builder_client_data.go index 45f00043..2b572443 100644 --- a/go/appbuilder/app_builder_client_data.go +++ b/go/appbuilder/app_builder_client_data.go @@ -48,11 +48,11 @@ type AppBuilderClientRunRequest struct { AppID string `json:"app_id"` Query string `json:"query"` Stream bool `json:"stream"` - EndUserID string `json:"end_user_id"` + EndUserID *string `json:"end_user_id"` ConversationID string `json:"conversation_id"` Tools []Tool `json:"tools"` ToolOutputs []ToolOutput `json:"tool_outputs"` - ToolChoice ToolChoice `json:"tool_choice"` + ToolChoice *ToolChoice `json:"tool_choice"` } type Tool struct { diff --git a/go/appbuilder/app_builder_client_test.go b/go/appbuilder/app_builder_client_test.go index 2405ddc1..b70e43dc 100644 --- a/go/appbuilder/app_builder_client_test.go +++ b/go/appbuilder/app_builder_client_test.go @@ -69,13 +69,12 @@ func TestNewAppBuilderClient(t *testing.T) { func TestAppBuilderClientRunWithToolCall(t *testing.T) { os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") os.Setenv("APPBUILDER_LOGFILE", "") - os.Setenv("GATEWAY_URL_V2", "https://apaas-api-sandbox.baidu-int.com/") - config, err := NewSDKConfig("", "bce-v3/ALTAK-vGrDN4BvjP15rDrXBI9OC/6d435ece62ed09b396e1b051bd87869c11861332") + config, err := NewSDKConfig("", "") if err != nil { t.Fatalf("new http client config failed: %v", err) } - appID := "4d4b1b27-d607-4d2a-9002-206134217a9f" + appID := "" client, err := NewAppBuilderClient(appID, config) if err != nil { t.Fatalf("new AgentBuidler instance failed") @@ -160,3 +159,52 @@ func TestAppBuilderClientRunWithToolCall(t *testing.T) { fmt.Println("----------------answer-------------------") fmt.Println(totalAnswer) } + +func TestAppBuilderClientRunToolChoice(t *testing.T) { + os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") + os.Setenv("APPBUILDER_LOGFILE", "") + config, err := NewSDKConfig("", "") + if err != nil { + t.Fatalf("new http client config failed: %v", err) + } + + appID := "" + client, err := NewAppBuilderClient(appID, config) + if err != nil { + t.Fatalf("new AgentBuidler instance failed") + } + + conversationID, err := client.CreateConversation() + if err != nil { + t.Fatalf("create conversation failed: %v", err) + } + + input := make(map[string]interface{}) + input["city"] = "北京" + end_user_id := "go_user_id_0" + i, err := client.RunWithToolCall(AppBuilderClientRunRequest{ + ConversationID: conversationID, + AppID: appID, + Query: "你能干什么", + EndUserID: &end_user_id, + Stream: false, + ToolChoice: &ToolChoice{ + Type: "function", + Function: ToolChoiceFunction{ + Name: "WeatherQuery", + Input: input, + }, + }, + }) + + if err != nil { + fmt.Println("run failed: ", err) + } + + for answer, err := i.Next(); err == nil; answer, err = i.Next() { + for _, ev := range answer.Events { + evJSON, _ := json.Marshal(ev) + fmt.Println(string(evJSON)) + } + } +} diff --git a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java index 409c4884..df96628e 100644 --- a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java @@ -9,7 +9,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import javax.tools.Tool; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult; import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest; @@ -100,11 +99,10 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx required.add("location"); parameters.put("required", required); - AppBuilderClientRunRequest.Tool.Function func = - new AppBuilderClientRunRequest.Tool.Function(name, desc, parameters); - AppBuilderClientRunRequest.Tool tool = - new AppBuilderClientRunRequest.Tool("function", func); - request.setTools(new AppBuilderClientRunRequest.Tool[] {tool}); + AppBuilderClientRunRequest.Tool.Function func = new AppBuilderClientRunRequest.Tool.Function(name, desc, + parameters); + AppBuilderClientRunRequest.Tool tool = new AppBuilderClientRunRequest.Tool("function", func); + request.setTools(new AppBuilderClientRunRequest.Tool[] { tool }); AppBuilderClientIterator itor = builder.run(request); assertTrue(itor.hasNext()); @@ -119,9 +117,8 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx request2.setAppId(appId); request2.setConversationID(conversationId); - AppBuilderClientRunRequest.ToolOutput output = - new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度"); - request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] {output}); + AppBuilderClientRunRequest.ToolOutput output = new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度"); + request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] { output }); AppBuilderClientIterator itor2 = builder.run(request2); assertTrue(itor2.hasNext()); while (itor2.hasNext()) { @@ -129,4 +126,31 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx System.out.println(result); } } + + @Test + public void AppBuilderClientRunToolChoiceTest() throws IOException, AppBuilderServerException { + AppBuilderClient builder = new AppBuilderClient(appId); + String conversationId = builder.createConversation(); + assertNotNull(conversationId); + + AppBuilderClientRunRequest request = new AppBuilderClientRunRequest(); + request.setAppId(appId); + request.setConversationID(conversationId); + request.setQuery("你能干什么"); + request.setStream(false); + request.setEndUserId("java_test_user_0"); + Map input = new HashMap<>(); + input.put("city", "北京"); + AppBuilderClientRunRequest.ToolChoice.Function func = new AppBuilderClientRunRequest.ToolChoice.Function( + "WeatherQuery", input); + AppBuilderClientRunRequest.ToolChoice choice = new AppBuilderClientRunRequest.ToolChoice("function", func); + request.setToolChoice(choice); + + AppBuilderClientIterator itor = builder.run(request); + assertTrue(itor.hasNext()); + while (itor.hasNext()) { + AppBuilderClientResult result = itor.next(); + System.out.println(result); + } + } }