Commit 27cf6370 authored by alex yao's avatar alex yao

refactor:重构Function Call流程

parent 3b17f40e
......@@ -5,6 +5,7 @@ import cn.com.poc.agent_application.constant.AgentApplicationConstants;
import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants;
import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants;
import cn.com.poc.agent_application.convert.AgentApplicationInfoConvert;
import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.entity.*;
import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem;
import cn.com.poc.agent_application.service.*;
......@@ -22,6 +23,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.FunctionCall;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
......@@ -164,11 +166,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
Tool[] toolArray = tools.toArray(new Tool[0]);
String promptTemplate = buildDialogsPrompt(messages, agentSystem, kdIds, toolArray, identifier);
FunctionResult functionResult = functionCall(identifier, messages, toolArray);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, identifier);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
BufferedReader bufferedReader = invokeLLM(model, messageArray, topP, toolArray, identifier);
BufferedReader bufferedReader = invokeLLM(model, messageArray, topP);
return textOutput(httpServletResponse, bufferedReader);
}
......@@ -507,12 +511,15 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
/**
* 构建对话提示词
*
* @param messages
* @param agentSystem
* @param kdIds
* @param functionResult 函数结果
* @param messages 对话消息
* @param agentSystem 应用角色指令
* @param kdIds 知识库id
* @param tools 组件
* @param identifier 对话标识符
* @return
*/
private String buildDialogsPrompt(List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String identifier) {
private String buildDialogsPrompt(FunctionResult functionResult, List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String identifier) {
String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem();
promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY);
// 调用知识库
......@@ -577,7 +584,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
promptTemplate = promptTemplate.replace("${valueMemoryResult}", StringUtils.EMPTY);
}
}
// 函数
if (functionResult != null) {
promptTemplate = promptTemplate.replace("${functionName}", StringUtils.isNotBlank(functionResult.getFunctionName()) ? functionResult.getFunctionName() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionArg}", StringUtils.isNotBlank(functionResult.getFunctionArg()) ? functionResult.getFunctionArg() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionDesc}", StringUtils.isNotBlank(functionResult.getFunctionDesc()) ? functionResult.getFunctionDesc() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionResult}", StringUtils.isNotBlank(functionResult.getFunctionResult()) ? functionResult.getFunctionResult() : StringUtils.EMPTY);
}
}
return promptTemplate;
}
......@@ -591,85 +604,79 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @return
* @throws Exception
*/
private BufferedReader invokeLLM(String largeModel, Message[] messageArray, Float topP, Tool[] tools, String identifier) throws Exception {
private BufferedReader invokeLLM(String largeModel, Message[] messageArray, Float topP) throws Exception {
LargeModelResponse largeModelResponse = new LargeModelResponse();
largeModelResponse.setModel(largeModel);
largeModelResponse.setMessages(messageArray);
largeModelResponse.setTopP(topP);
largeModelResponse.setStream(true);
largeModelResponse.setUser("POE");
if (ArrayUtils.isNotEmpty(tools)) {
largeModelResponse.setTools(tools);
largeModelResponse.setTool_choice("auto");
}
BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse);
bufferedReader.mark(200);
String res = "";
boolean isFunctionCall = false;
String functionResult = null;
StringBuffer finishReason = new StringBuffer();
StringBuffer functionName = new StringBuffer();
StringBuffer functionArguments = new StringBuffer();
while ((res = bufferedReader.readLine()) != null) {
if (StringUtils.isBlank(res)) {
continue;
}
// 添加function 参数 LargeModelDemandResult 和中台接口一致
LargeModelDemandResult result = JsonUtils.deSerialize(res.replaceFirst(EVENT_STREAM_PREFIX, StringUtils.EMPTY), LargeModelDemandResult.class);
if (!"0".equals(result.getCode())) {
logger.error("LLM Error,code:{}", result.getCode());
throw new I18nMessageException("exception/call.failure");
}
if (StringUtils.isBlank(result.getMessage())
&& StringUtils.isBlank(result.getFunction().getName())
&& StringUtils.isBlank(result.getFunction().getArguments())
&& StringUtils.isBlank(result.getFinish_reason())) {
continue;
}
if (!(result.getFunction() != null && (StringUtils.isNotBlank(result.getFunction().getName()) || StringUtils.isNotBlank(result.getFunction().getArguments()))) && !isFunctionCall) {
bufferedReader.reset();
break;
} else {
isFunctionCall = true;
if (result.getFunction().getName() != null && !result.getFunction().getName().isEmpty()) {
functionName.append(result.getFunction().getName());
}
if (result.getFunction().getArguments() != null && !result.getFunction().getArguments().isEmpty()) {
functionArguments.append(result.getFunction().getArguments());
}
}
if (result.getFinish_reason() != null) {
finishReason.append(result.getFinish_reason());
}
// 走到了最后一个流,判断是否为function_call
// 如果是function_call,则处理function
if (finishReason.toString().equals("tool_calls")) {
// 如果是function_call,则处理function
if (!functionName.toString().isEmpty() && !functionArguments.toString().isEmpty()) {
// 执行函数返回结果
LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionName.toString());
functionResult = functionEnum.getFunction().doFunction(functionArguments.toString(), identifier);
}
if (functionResult != null) {
// 处理function - 1,获取function_call 2. 调用对应function 3. 返回function_call并且调用LLM 4.获取bufferedReader 返回
String functionRole = largeModelResponse.getModel().startsWith("qwen") ? LLMRoleEnum.FUNCTION.getRole() : LLMRoleEnum.TOOL.getRole();
Message[] sendMessage = buildFunctionMessage(messageArray, functionName.toString(), functionArguments.toString(), functionResult, functionRole);
largeModelResponse.setMessages(sendMessage);
return llmService.chatChunk(largeModelResponse);
}
}
}
// 若不为function_call,则直接返回bufferedReader
return bufferedReader;
return llmService.chatChunk(largeModelResponse);
// bufferedReader.mark(200);
//
// String res = "";
//
// boolean isFunctionCall = false;
// String functionResult = null;
// StringBuffer finishReason = new StringBuffer();
// StringBuffer functionName = new StringBuffer();
// StringBuffer functionArguments = new StringBuffer();
//
// while ((res = bufferedReader.readLine()) != null) {
// if (StringUtils.isBlank(res)) {
// continue;
// }
// // 添加function 参数 LargeModelDemandResult 和中台接口一致
// LargeModelDemandResult result = JsonUtils.deSerialize(res.replaceFirst(EVENT_STREAM_PREFIX, StringUtils.EMPTY), LargeModelDemandResult.class);
// if (!"0".equals(result.getCode())) {
// logger.error("LLM Error,code:{}", result.getCode());
// throw new I18nMessageException("exception/call.failure");
// }
//
// if (StringUtils.isBlank(result.getMessage())
// && StringUtils.isBlank(result.getFunction().getName())
// && StringUtils.isBlank(result.getFunction().getArguments())
// && StringUtils.isBlank(result.getFinish_reason())) {
// continue;
// }
//
// if (!(result.getFunction() != null && (StringUtils.isNotBlank(result.getFunction().getName()) || StringUtils.isNotBlank(result.getFunction().getArguments()))) && !isFunctionCall) {
// bufferedReader.reset();
// break;
// } else {
// isFunctionCall = true;
// if (result.getFunction().getName() != null && !result.getFunction().getName().isEmpty()) {
// functionName.append(result.getFunction().getName());
// }
// if (result.getFunction().getArguments() != null && !result.getFunction().getArguments().isEmpty()) {
// functionArguments.append(result.getFunction().getArguments());
// }
// }
//
// if (result.getFinish_reason() != null) {
// finishReason.append(result.getFinish_reason());
// }
// // 走到了最后一个流,判断是否为function_call
// // 如果是function_call,则处理function
// if (finishReason.toString().equals("tool_calls")) {
// // 如果是function_call,则处理function
// if (!functionName.toString().isEmpty() && !functionArguments.toString().isEmpty()) {
// // 执行函数返回结果
// LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionName.toString());
// functionResult = functionEnum.getFunction().doFunction(functionArguments.toString(), identifier);
// }
//
// if (functionResult != null) {
// // 处理function - 1,获取function_call 2. 调用对应function 3. 返回function_call并且调用LLM 4.获取bufferedReader 返回
// String functionRole = largeModelResponse.getModel().startsWith("qwen") ? LLMRoleEnum.FUNCTION.getRole() : LLMRoleEnum.TOOL.getRole();
// Message[] sendMessage = buildFunctionMessage(messageArray, functionName.toString(), functionArguments.toString(), functionResult, functionRole);
// largeModelResponse.setMessages(sendMessage);
// return llmService.chatChunk(largeModelResponse);
// }
// }
// }
// // 若不为function_call,则直接返回bufferedReader
// return bufferedReader;
}
......@@ -798,4 +805,39 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return largeModelEntity.getModelName();
}
/**
* 判断是否需要FunctionCall
*
* @param identifier
* @param messages
* @param tools
*/
private FunctionResult functionCall(String identifier, List<Message> messages, Tool[] tools) {
FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) {
return result;
}
Object content = messages.get(messages.size() - 1).getContent();
String query;
if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString();
} else {
query = content.toString();
}
FunctionCallResult functionCallResult = llmService.functionCall(query, tools);
if (functionCallResult != null && functionCallResult.isNeed()) {
// 执行函数返回结果
LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionCallResult.getFunctionCall().getName());
String functionResult = functionEnum.getFunction().doFunction(functionCallResult.getFunctionCall().getArguments(), identifier);
//构造返回结果
result.setFunctionName(functionCallResult.getFunctionCall().getName());
result.setFunctionArg(functionCallResult.getFunctionCall().getArguments());
result.setFunctionDesc(functionEnum.getFunction().getDesc());
result.setFunctionResult(functionResult);
}
return result;
}
}
package cn.com.poc.agent_application.domain;
public class FunctionResult {
private String functionName;
private String functionArg;
private String functionDesc;
private String functionResult;
public String getFunctionName() {
return functionName;
}
public void setFunctionName(String functionName) {
this.functionName = functionName;
}
public String getFunctionArg() {
return functionArg;
}
public void setFunctionArg(String functionArg) {
this.functionArg = functionArg;
}
public String getFunctionDesc() {
return functionDesc;
}
public void setFunctionDesc(String functionDesc) {
this.functionDesc = functionDesc;
}
public String getFunctionResult() {
return functionResult;
}
public void setFunctionResult(String functionResult) {
this.functionResult = functionResult;
}
}
......@@ -3,7 +3,7 @@ package cn.com.poc.support.dgTools;
import cn.com.poc.common.constant.FmxParamConfigConstant;
import cn.com.poc.common.utils.ListUtils;
import cn.com.poc.common.utils.http.LocalHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.support.dgTools.request.AbstractParam;
import cn.com.poc.support.dgTools.request.AbstractRequest;
import cn.com.poc.support.dgTools.request.ProjectTokenRequest;
......
......@@ -4,7 +4,7 @@ import cn.com.poc.common.constant.FmxParamConfigConstant;
import cn.com.poc.common.constant.MkpRedisKeyConstant;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.support.dgTools.request.ProjectTokenRequest;
import cn.com.poc.support.dgTools.result.ProjectTokenResult;
import cn.com.poc.support.dgTools.service.AuthorizationService;
......
package cn.com.poc.thirdparty.resource.demand.ai.aggregate;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
......@@ -23,5 +25,11 @@ public interface AIDialogueService {
/**
* 调用中台通用大模型接口 [非流]
*/
LargeModelDemandResult poly(LargeModelResponse largeResponse) ;
LargeModelDemandResult poly(LargeModelResponse largeResponse);
/**
* 判断是否需要Function Call
*/
FunctionCallResult functionCall(FunctionCallResponse response);
}
......@@ -8,7 +8,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.OpenAiResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.*;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
......
package cn.com.poc.thirdparty.resource.demand.ai.aggregate.impl;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AIDialogueService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import org.apache.http.Header;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.RequestBuilder;
......@@ -58,7 +60,7 @@ public class AIDialogueServiceImpl implements AIDialogueService {
}
@Override
public LargeModelDemandResult poly(LargeModelResponse largeResponse) {
public LargeModelDemandResult poly(LargeModelResponse largeResponse) {
largeResponse.setStream(false);
LargeModelDemandResponse response = new LargeModelDemandResponse();
response.setApiKey(API_KEY);
......@@ -66,6 +68,18 @@ public class AIDialogueServiceImpl implements AIDialogueService {
return largeModelRequest(response);
}
@Override
public FunctionCallResult functionCall(FunctionCallResponse response) {
String url = DgtoolsApiConstants.DgtoolsAI.FUNCTION_CALL;
response.setApiKey(API_KEY);
List<Header> headers = new ArrayList<Header>() {{
add(DgtoolsApiConstants.JSON_HEADER);
add(DgtoolsApiConstants.AI_HEADER);
add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
}};
return dgToolsAbstractHttpClient.doRequest(url, response, headers);
}
private BufferedReader largeModelStream(LargeModelDemandResponse request) throws IOException {
String jsonBody = dgToolsAbstractHttpClient.buildJson(request);
CloseableHttpClient httpClient = HttpClients.createDefault();
......@@ -89,7 +103,7 @@ public class AIDialogueServiceImpl implements AIDialogueService {
add(DgtoolsApiConstants.AI_HEADER);
add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
}};
return dgToolsAbstractHttpClient.doRequest(url, request, headers);
return dgToolsAbstractHttpClient.doRequest(url, request, headers);
}
......
package cn.com.poc.thirdparty.resource.demand.ai.aggregate.impl;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.support.dgTools.result.AbstractResult;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.*;
......
package cn.com.poc.support.dgTools.constants;
package cn.com.poc.thirdparty.resource.demand.ai.common;
import org.apache.http.Header;
import org.apache.http.HttpHeaders;
......@@ -257,6 +257,11 @@ public interface DgtoolsApiConstants {
* 大模型【通用】
*/
String LARGE_MODEL = "largeModelRest/completion.json";
/**
* Function Call 判断,判断是否需要Function Call
*/
String FUNCTION_CALL = "largeModelRest/functionCall.json";
}
interface ClickHouse {
......
......@@ -20,4 +20,12 @@ public class FunctionCall {
public void setArguments(String arguments) {
this.arguments = arguments;
}
@Override
public String toString() {
return "FunctionCall{" +
"name='" + name + '\'' +
", arguments='" + arguments + '\'' +
'}';
}
}
package cn.com.poc.thirdparty.resource.demand.ai.entity.function;
import cn.com.poc.support.dgTools.request.AbstractRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import java.io.Serializable;
import java.util.List;
public class FunctionCallResponse extends AbstractRequest<FunctionCallResult> implements Serializable {
/**
* 模型apiKey
*/
private String apiKey;
/**
* 提问内容
*/
private String query;
/**
* function 配置
*/
List<Function> functions;
public String getQuery() {
return query;
}
public void setQuery(String query) {
this.query = query;
}
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
public List<Function> getFunctions() {
return functions;
}
public void setFunctions(List<Function> functions) {
this.functions = functions;
}
@Override
public String getMethod() throws Exception {
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.entity.function;
import cn.com.poc.support.dgTools.result.AbstractResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.FunctionCall;
import java.io.Serializable;
public class FunctionCallResult extends AbstractResult implements Serializable {
/**
* 是否为函数调用
*/
private boolean need;
/**
* 函数信息
*/
private FunctionCall functionCall;
public boolean isNeed() {
return need;
}
public void setNeed(boolean need) {
this.need = need;
}
public FunctionCall getFunctionCall() {
return functionCall;
}
public void setFunctionCall(FunctionCall functionCall) {
this.functionCall = functionCall;
}
@Override
public String toString() {
return "FunctionCallResult{" +
"need=" + need +
", functionCall=" + functionCall +
'}';
}
}
......@@ -2,6 +2,7 @@ package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import java.util.List;
......@@ -9,6 +10,13 @@ public abstract class AbstractLargeModelFunction {
public abstract String doFunction(String content, String key);
/**
* 获取函数描述
*
* @return
*/
public abstract String getDesc();
/**
* 获取配置
*/
......
......@@ -24,9 +24,16 @@ import java.util.Map;
@Service
public class SetLongMemoryFunction extends AbstractLargeModelFunction {
private String desc = "该方法仅用来保存用户想记录的内容,不能通过该方法进行查询。";
@Resource
private RedisService redisService;
@Override
public String getDesc() {
return desc;
}
/**
* 执行函数
*
......@@ -67,7 +74,7 @@ public class SetLongMemoryFunction extends AbstractLargeModelFunction {
function.put("name", "set_long_memory");
function.put("description", "该方法仅用来保存用户想记录的内容,不能通过该方法进行查询。");
function.put("description", desc);
function.put("parameters", parameters);
config.put("type", "function");
......
......@@ -4,6 +4,7 @@ import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.json.JSONObject;
......@@ -18,9 +19,16 @@ import java.util.Map;
@Service
public class SetValueMemoryFunction extends AbstractLargeModelFunction {
private String desc = "该方法仅用enum给定的内容名来保存用户想记录的内容值,不可使用该方法进行查询";
@Resource
private RedisService redisService;
@Override
public String getDesc() {
return desc;
}
@Override
public String doFunction(String content, String key) {
String contentKey = SetValueMemoryConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
......@@ -80,7 +88,7 @@ public class SetValueMemoryFunction extends AbstractLargeModelFunction {
parameters.put("type", "object");
function.put("name", "set_value_memory");
function.put("description", "该方法仅用enum给定的内容名来保存用户想记录的内容值,不可使用该方法进行查询");
function.put("description", desc);
function.put("parameters", parameters);
config.put("type", "function");
......
package cn.com.poc.thirdparty.resource.demand.clickhouse.service.impl;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.clickhouse.entity.WebBrowseHarvestEntity;
import cn.com.poc.thirdparty.resource.demand.clickhouse.service.DataReportService;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
......
......@@ -3,7 +3,7 @@ package cn.com.poc.thirdparty.resource.demand.member.api;
import cn.com.poc.thirdparty.resource.demand.member.entity.DemandAuthResponse;
import cn.com.poc.thirdparty.resource.demand.member.entity.DemandAuthResult;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import org.apache.http.Header;
import org.springframework.stereotype.Service;
......
......@@ -2,7 +2,7 @@ package cn.com.poc.thirdparty.resource.demand.member.api;
import cn.com.poc.common.utils.http.LocalHttpClient;
import cn.com.poc.thirdparty.resource.demand.member.entity.DemandMemberResult;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
import org.springframework.beans.factory.annotation.Value;
......
package cn.com.poc.thirdparty.service;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
......@@ -29,5 +31,13 @@ public interface LLMService {
*/
BufferedReader chatChunk(LargeModelResponse request) throws Exception;
/**
* functionCall判断
*
* @param query
* @param tools
*/
FunctionCallResult functionCall(String query, Tool[] tools);
}
package cn.com.poc.thirdparty.service.impl;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AIDialogueService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.service.LLMService;
import cn.hutool.core.lang.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.List;
/**
* @author alex.yao
......@@ -38,4 +45,21 @@ public class LLMServiceImpl implements LLMService {
public BufferedReader chatChunk(LargeModelResponse request) throws Exception {
return aiDialogueService.polyStream(request);
}
@Override
public FunctionCallResult functionCall(String query, Tool[] tools) {
Assert.notEmpty(tools);
FunctionCallResponse response = new FunctionCallResponse();
response.setQuery(query);
List<Function> functions = new ArrayList<>();
for (Tool tool : tools) {
Function function = new Function();
function.setName(tool.getFunction().getName());
function.setDescription(tool.getFunction().getDescription());
function.setParameters(tool.getFunction().getParameters());
functions.add(function);
}
response.setFunctions(functions);
return aiDialogueService.functionCall(response);
}
}
package cn.com.poc.demand;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AIDialogueService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(initializers = SingleContextInitializer.class)
@WebAppConfiguration
public class AiDialogueTest {
@Resource
AIDialogueService aiDialogueService;
@Test
public void functionCall() {
String query = "帮我记一下,今天下午需要开会";
Map<String, Object> content = new HashMap<>();
content.put("type", "string");
content.put("description", "内容的详细说明");
Map<String, Object> properties = new HashMap<>();
properties.put("content", content);
Map<String, Object> parameters = new HashMap<>();
parameters.put("type", "object");
parameters.put("properties", properties);
List<String> required = new ArrayList<>();
required.add("content");
parameters.put("required", required);
Function function = new Function();
function.setName("set_long_memory");
function.setDescription("该方法仅用来保存用户想记录的内容,不能通过该方法进行查询。");
function.setParameters(parameters);
List<Function> functions = new ArrayList<>();
functions.add(function);
FunctionCallResponse functionCallResponse = new FunctionCallResponse();
functionCallResponse.setQuery(query);
functionCallResponse.setFunctions(functions);
FunctionCallResult functionCallResult = aiDialogueService.functionCall(functionCallResponse);
System.out.println(functionCallResult);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment