Skip to content

Commit

Permalink
增加ToolChoice java&&go单测 (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj authored Sep 11, 2024
1 parent 2aa73f0 commit a6b22ec
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 14 deletions.
4 changes: 2 additions & 2 deletions go/appbuilder/app_builder_client_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ type AppBuilderClientRunRequest struct {
AppID string `json:"app_id"`
Query string `json:"query"`
Stream bool `json:"stream"`
EndUserID string `json:"end_user_id"`
EndUserID *string `json:"end_user_id"`
ConversationID string `json:"conversation_id"`
Tools []Tool `json:"tools"`
ToolOutputs []ToolOutput `json:"tool_outputs"`
ToolChoice ToolChoice `json:"tool_choice"`
ToolChoice *ToolChoice `json:"tool_choice"`
}

type Tool struct {
Expand Down
54 changes: 51 additions & 3 deletions go/appbuilder/app_builder_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,12 @@ func TestNewAppBuilderClient(t *testing.T) {
func TestAppBuilderClientRunWithToolCall(t *testing.T) {
os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")
os.Setenv("APPBUILDER_LOGFILE", "")
os.Setenv("GATEWAY_URL_V2", "https://apaas-api-sandbox.baidu-int.com/")
config, err := NewSDKConfig("", "bce-v3/ALTAK-vGrDN4BvjP15rDrXBI9OC/6d435ece62ed09b396e1b051bd87869c11861332")
config, err := NewSDKConfig("", "")
if err != nil {
t.Fatalf("new http client config failed: %v", err)
}

appID := "4d4b1b27-d607-4d2a-9002-206134217a9f"
appID := ""
client, err := NewAppBuilderClient(appID, config)
if err != nil {
t.Fatalf("new AgentBuidler instance failed")
Expand Down Expand Up @@ -160,3 +159,52 @@ func TestAppBuilderClientRunWithToolCall(t *testing.T) {
fmt.Println("----------------answer-------------------")
fmt.Println(totalAnswer)
}

func TestAppBuilderClientRunToolChoice(t *testing.T) {
os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")
os.Setenv("APPBUILDER_LOGFILE", "")
config, err := NewSDKConfig("", "")
if err != nil {
t.Fatalf("new http client config failed: %v", err)
}

appID := ""
client, err := NewAppBuilderClient(appID, config)
if err != nil {
t.Fatalf("new AgentBuidler instance failed")
}

conversationID, err := client.CreateConversation()
if err != nil {
t.Fatalf("create conversation failed: %v", err)
}

input := make(map[string]interface{})
input["city"] = "北京"
end_user_id := "go_user_id_0"
i, err := client.RunWithToolCall(AppBuilderClientRunRequest{
ConversationID: conversationID,
AppID: appID,
Query: "你能干什么",
EndUserID: &end_user_id,
Stream: false,
ToolChoice: &ToolChoice{
Type: "function",
Function: ToolChoiceFunction{
Name: "WeatherQuery",
Input: input,
},
},
})

if err != nil {
fmt.Println("run failed: ", err)
}

for answer, err := i.Next(); err == nil; answer, err = i.Next() {
for _, ev := range answer.Events {
evJSON, _ := json.Marshal(ev)
fmt.Println(string(evJSON))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.tools.Tool;
import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator;
import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult;
import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest;
Expand Down Expand Up @@ -100,11 +99,10 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx
required.add("location");
parameters.put("required", required);

AppBuilderClientRunRequest.Tool.Function func =
new AppBuilderClientRunRequest.Tool.Function(name, desc, parameters);
AppBuilderClientRunRequest.Tool tool =
new AppBuilderClientRunRequest.Tool("function", func);
request.setTools(new AppBuilderClientRunRequest.Tool[] {tool});
AppBuilderClientRunRequest.Tool.Function func = new AppBuilderClientRunRequest.Tool.Function(name, desc,
parameters);
AppBuilderClientRunRequest.Tool tool = new AppBuilderClientRunRequest.Tool("function", func);
request.setTools(new AppBuilderClientRunRequest.Tool[] { tool });

AppBuilderClientIterator itor = builder.run(request);
assertTrue(itor.hasNext());
Expand All @@ -119,14 +117,40 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx
request2.setAppId(appId);
request2.setConversationID(conversationId);

AppBuilderClientRunRequest.ToolOutput output =
new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度");
request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] {output});
AppBuilderClientRunRequest.ToolOutput output = new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度");
request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] { output });
AppBuilderClientIterator itor2 = builder.run(request2);
assertTrue(itor2.hasNext());
while (itor2.hasNext()) {
AppBuilderClientResult result = itor2.next();
System.out.println(result);
}
}

@Test
public void AppBuilderClientRunToolChoiceTest() throws IOException, AppBuilderServerException {
AppBuilderClient builder = new AppBuilderClient(appId);
String conversationId = builder.createConversation();
assertNotNull(conversationId);

AppBuilderClientRunRequest request = new AppBuilderClientRunRequest();
request.setAppId(appId);
request.setConversationID(conversationId);
request.setQuery("你能干什么");
request.setStream(false);
request.setEndUserId("java_test_user_0");
Map<String, Object> input = new HashMap<>();
input.put("city", "北京");
AppBuilderClientRunRequest.ToolChoice.Function func = new AppBuilderClientRunRequest.ToolChoice.Function(
"WeatherQuery", input);
AppBuilderClientRunRequest.ToolChoice choice = new AppBuilderClientRunRequest.ToolChoice("function", func);
request.setToolChoice(choice);

AppBuilderClientIterator itor = builder.run(request);
assertTrue(itor.hasNext());
while (itor.hasNext()) {
AppBuilderClientResult result = itor.next();
System.out.println(result);
}
}
}

0 comments on commit a6b22ec

Please sign in to comment.