Commit 4428dadb authored by alex yao's avatar alex yao

优化记忆函数主键

parent 8f2a35c7
...@@ -31,7 +31,7 @@ public interface AgentApplicationInfoService { ...@@ -31,7 +31,7 @@ public interface AgentApplicationInfoService {
/** /**
* 应用预览 * 应用预览
*/ */
String callAgentApplication(String identifier, String largeModel, String[] unitIds, String agentSystem, String callAgentApplication(String agentId, String identifier, String largeModel, String[] unitIds, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature, Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, HttpServletResponse httpServletResponse) throws Exception; List<Message> messages, List<Tool> tools, HttpServletResponse httpServletResponse) throws Exception;
......
...@@ -4,7 +4,6 @@ import cn.com.poc.agent_application.aggregate.AgentApplicationInfoService; ...@@ -4,7 +4,6 @@ import cn.com.poc.agent_application.aggregate.AgentApplicationInfoService;
import cn.com.poc.agent_application.constant.AgentApplicationConstants; import cn.com.poc.agent_application.constant.AgentApplicationConstants;
import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants; import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants;
import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants; import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants;
import cn.com.poc.agent_application.convert.AgentApplicationInfoConvert;
import cn.com.poc.agent_application.domain.FunctionResult; import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.entity.*; import cn.com.poc.agent_application.entity.*;
import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem; import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem;
...@@ -174,16 +173,16 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -174,16 +173,16 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
} }
@Override @Override
public String callAgentApplication(String identifier, String largeModel, String[] unitIds, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, HttpServletResponse httpServletResponse) throws Exception { public String callAgentApplication(String agentId, String identifier, String largeModel, String[] unitIds, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, HttpServletResponse httpServletResponse) throws Exception {
logger.info("--------- Call Agent Application large model:{},unitIds:{},agentSystem:{},knowledgeIds:{}" + " communicationTurn:{},topP:{},messages:{}--------------", largeModel, unitIds, agentSystem, kdIds, communicationTurn, topP, messages); logger.info("--------- Call Agent Application large model:{},unitIds:{},agentSystem:{},knowledgeIds:{}" + " communicationTurn:{},topP:{},messages:{}--------------", largeModel, unitIds, agentSystem, kdIds, communicationTurn, topP, messages);
String model = modelConvert(largeModel); String model = modelConvert(largeModel);
Tool[] toolArray = tools.toArray(new Tool[0]); Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(identifier, messages, toolArray); FunctionResult functionResult = functionCall(identifier, messages, toolArray, agentId);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, identifier); String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, identifier, agentId);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate); Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
...@@ -534,13 +533,30 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -534,13 +533,30 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param identifier 对话标识符 * @param identifier 对话标识符
* @return * @return
*/ */
private String buildDialogsPrompt(FunctionResult functionResult, List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String identifier) { private String buildDialogsPrompt(FunctionResult functionResult, List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String identifier, String agentId) {
String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem(); String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem();
// 应用角色指令
promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY); promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY);
// 调用知识库 // 调用知识库
promptTemplate = buildKnowledgePrompt(messages, kdIds, promptTemplate);
// 记忆
promptTemplate = buildMemoryPrompt(promptTemplate, tools, identifier, agentId);
// 函数调用
promptTemplate = buildFunctionPrompt(functionResult, promptTemplate);
return promptTemplate;
}
/**
* 构建知识库提示词
*
* @param messages
* @param kdIds
* @param promptTemplate
* @return
*/
private String buildKnowledgePrompt(List<Message> messages, Integer[] kdIds, String promptTemplate) {
if (ArrayUtils.isNotEmpty(kdIds)) { if (ArrayUtils.isNotEmpty(kdIds)) {
List<String> knowledgeIds = new ArrayList<>(); List<String> knowledgeIds = new ArrayList<>();
for (Integer kdId : kdIds) { for (Integer kdId : kdIds) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId); BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId);
if (null == knowledgeDocumentEntity) { if (null == knowledgeDocumentEntity) {
...@@ -548,10 +564,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -548,10 +564,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
} }
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId()); knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
} }
Object content = messages.get(messages.size() - 1).getContent(); Object content = messages.get(messages.size() - 1).getContent();
String query = ""; String query = "";
if (content instanceof List) { if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString(); query = ((List<HashMap>) content).get(0).get("text").toString();
} else { } else {
...@@ -560,14 +574,26 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -560,14 +574,26 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(query, knowledgeIds, 3); List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(query, knowledgeIds, 3);
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResults.toString()); promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResults.toString());
} }
// 记忆 return promptTemplate;
}
/**
* 构造记忆体提示词
*
* @param promptTemplate
* @param tools
* @param identifier
* @return
*/
private String buildMemoryPrompt(String promptTemplate, Tool[] tools, String identifier, String agentId) {
String longMemoryContent = StringUtils.EMPTY;
String valueMemoryContent = StringUtils.EMPTY;
if (ArrayUtils.isNotEmpty(tools)) { if (ArrayUtils.isNotEmpty(tools)) {
for (Tool tool : tools) { for (Tool tool : tools) {
String name = tool.getFunction().getName(); String name = tool.getFunction().getName();
// 长期记忆 // 长期记忆
if (LargeModelFunctionEnum.set_long_memory.name().equals(name)) { if (LargeModelFunctionEnum.set_long_memory.name().equals(name)) {
List<LongMemoryEntity> longMemoryEntities = GetLongMemory.get(identifier); List<LongMemoryEntity> longMemoryEntities = GetLongMemory.get(identifier + ":" + agentId);
if (CollectionUtils.isNotEmpty(longMemoryEntities)) { if (CollectionUtils.isNotEmpty(longMemoryEntities)) {
StringBuilder stringBuilder = new StringBuilder(""); StringBuilder stringBuilder = new StringBuilder("");
for (LongMemoryEntity longMemoryEntity : longMemoryEntities) { for (LongMemoryEntity longMemoryEntity : longMemoryEntities) {
...@@ -577,16 +603,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -577,16 +603,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
.append("Content").append(":").append(longMemoryEntity.getContent()) .append("Content").append(":").append(longMemoryEntity.getContent())
.append(StringUtils.LF); .append(StringUtils.LF);
} }
String searchMemoryContent = stringBuilder.toString(); longMemoryContent = stringBuilder.toString();
promptTemplate = promptTemplate.replace("${longMemoryResult}", searchMemoryContent);
} else {
promptTemplate = promptTemplate.replace("${longMemoryResult}", StringUtils.EMPTY);
} }
} }
// 变量 // 变量记忆
if (LargeModelFunctionEnum.set_value_memory.name().equals(name)) { if (LargeModelFunctionEnum.set_value_memory.name().equals(name)) {
Map<Object, Object> map = GetValueMemory.get(identifier); Map<Object, Object> map = GetValueMemory.get(identifier + ":" + agentId);
StringBuilder stringBuilder = new StringBuilder(""); StringBuilder stringBuilder = new StringBuilder("");
if (MapUtils.isNotEmpty(map)) { if (MapUtils.isNotEmpty(map)) {
Set<Object> keySet = map.keySet(); Set<Object> keySet = map.keySet();
...@@ -594,18 +617,26 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -594,18 +617,26 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
stringBuilder.append(key.toString()).append(":").append(map.get(key)).append(StringUtils.LF); stringBuilder.append(key.toString()).append(":").append(map.get(key)).append(StringUtils.LF);
} }
} }
promptTemplate = promptTemplate.replace("${valueMemoryResult}", stringBuilder.toString()); valueMemoryContent = stringBuilder.toString();
} else {
promptTemplate = promptTemplate.replace("${valueMemoryResult}", StringUtils.EMPTY);
} }
} }
// 函数
if (functionResult != null) {
promptTemplate = promptTemplate.replace("${functionName}", StringUtils.isNotBlank(functionResult.getFunctionName()) ? functionResult.getFunctionName() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionArg}", StringUtils.isNotBlank(functionResult.getFunctionArg()) ? functionResult.getFunctionArg() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionDesc}", StringUtils.isNotBlank(functionResult.getFunctionDesc()) ? functionResult.getFunctionDesc() : StringUtils.EMPTY);
promptTemplate = promptTemplate.replace("${functionResult}", StringUtils.isNotBlank(functionResult.getFunctionResult()) ? functionResult.getFunctionResult() : StringUtils.EMPTY);
} }
return promptTemplate.replace("${longMemoryResult}", longMemoryContent).replace("${valueMemoryResult}", valueMemoryContent);
}
/**
* 构建函数调用提示词
*
* @param functionResult
* @param promptTemplate
* @return
*/
private String buildFunctionPrompt(FunctionResult functionResult, String promptTemplate) {
if (functionResult != null) {
promptTemplate = promptTemplate.replace("${functionName}", StringUtils.isNotBlank(functionResult.getFunctionName()) ? functionResult.getFunctionName() : StringUtils.EMPTY)
.replace("${functionArg}", StringUtils.isNotBlank(functionResult.getFunctionArg()) ? functionResult.getFunctionArg() : StringUtils.EMPTY)
.replace("${functionDesc}", StringUtils.isNotBlank(functionResult.getFunctionDesc()) ? functionResult.getFunctionDesc() : StringUtils.EMPTY)
.replace("${functionResult}", StringUtils.isNotBlank(functionResult.getFunctionResult()) ? functionResult.getFunctionResult() : StringUtils.EMPTY);
} }
return promptTemplate; return promptTemplate;
} }
...@@ -627,72 +658,6 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -627,72 +658,6 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
largeModelResponse.setStream(true); largeModelResponse.setStream(true);
largeModelResponse.setUser("POE"); largeModelResponse.setUser("POE");
return llmService.chatChunk(largeModelResponse); return llmService.chatChunk(largeModelResponse);
// bufferedReader.mark(200);
//
// String res = "";
//
// boolean isFunctionCall = false;
// String functionResult = null;
// StringBuffer finishReason = new StringBuffer();
// StringBuffer functionName = new StringBuffer();
// StringBuffer functionArguments = new StringBuffer();
//
// while ((res = bufferedReader.readLine()) != null) {
// if (StringUtils.isBlank(res)) {
// continue;
// }
// // 添加function 参数 LargeModelDemandResult 和中台接口一致
// LargeModelDemandResult result = JsonUtils.deSerialize(res.replaceFirst(EVENT_STREAM_PREFIX, StringUtils.EMPTY), LargeModelDemandResult.class);
// if (!"0".equals(result.getCode())) {
// logger.error("LLM Error,code:{}", result.getCode());
// throw new I18nMessageException("exception/call.failure");
// }
//
// if (StringUtils.isBlank(result.getMessage())
// && StringUtils.isBlank(result.getFunction().getName())
// && StringUtils.isBlank(result.getFunction().getArguments())
// && StringUtils.isBlank(result.getFinish_reason())) {
// continue;
// }
//
// if (!(result.getFunction() != null && (StringUtils.isNotBlank(result.getFunction().getName()) || StringUtils.isNotBlank(result.getFunction().getArguments()))) && !isFunctionCall) {
// bufferedReader.reset();
// break;
// } else {
// isFunctionCall = true;
// if (result.getFunction().getName() != null && !result.getFunction().getName().isEmpty()) {
// functionName.append(result.getFunction().getName());
// }
// if (result.getFunction().getArguments() != null && !result.getFunction().getArguments().isEmpty()) {
// functionArguments.append(result.getFunction().getArguments());
// }
// }
//
// if (result.getFinish_reason() != null) {
// finishReason.append(result.getFinish_reason());
// }
// // 走到了最后一个流,判断是否为function_call
// // 如果是function_call,则处理function
// if (finishReason.toString().equals("tool_calls")) {
// // 如果是function_call,则处理function
// if (!functionName.toString().isEmpty() && !functionArguments.toString().isEmpty()) {
// // 执行函数返回结果
// LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionName.toString());
// functionResult = functionEnum.getFunction().doFunction(functionArguments.toString(), identifier);
// }
//
// if (functionResult != null) {
// // 处理function - 1,获取function_call 2. 调用对应function 3. 返回function_call并且调用LLM 4.获取bufferedReader 返回
// String functionRole = largeModelResponse.getModel().startsWith("qwen") ? LLMRoleEnum.FUNCTION.getRole() : LLMRoleEnum.TOOL.getRole();
// Message[] sendMessage = buildFunctionMessage(messageArray, functionName.toString(), functionArguments.toString(), functionResult, functionRole);
// largeModelResponse.setMessages(sendMessage);
// return llmService.chatChunk(largeModelResponse);
// }
// }
// }
// // 若不为function_call,则直接返回bufferedReader
// return bufferedReader;
} }
/** /**
...@@ -704,7 +669,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -704,7 +669,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param functionResult * @param functionResult
* @return * @return
*/ */
private Message[] buildFunctionMessage(Message[] messageArray, String functionName, String functionArguments, String functionResult, String functionRole) { private Message[] buildFunctionMessage(Message[] messageArray, String functionName, String
functionArguments, String functionResult, String functionRole) {
List<FunctionCall> functionCalls = new ArrayList<>(); List<FunctionCall> functionCalls = new ArrayList<>();
FunctionCall functionCall = new FunctionCall(); FunctionCall functionCall = new FunctionCall();
...@@ -827,7 +793,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -827,7 +793,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param messages * @param messages
* @param tools * @param tools
*/ */
private FunctionResult functionCall(String identifier, List<Message> messages, Tool[] tools) { private FunctionResult functionCall(String identifier, List<Message> messages, Tool[] tools, String agentId) {
FunctionResult result = new FunctionResult(); FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) { if (ArrayUtils.isEmpty(tools)) {
return result; return result;
...@@ -843,7 +809,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -843,7 +809,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
if (functionCallResult != null && functionCallResult.isNeed()) { if (functionCallResult != null && functionCallResult.isNeed()) {
// 执行函数返回结果 // 执行函数返回结果
LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionCallResult.getFunctionCall().getName()); LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionCallResult.getFunctionCall().getName());
String functionResult = functionEnum.getFunction().doFunction(functionCallResult.getFunctionCall().getArguments(), identifier); String functionResult = functionEnum.getFunction().doFunction(functionCallResult.getFunctionCall().getArguments(), identifier + "_" + agentId);
//构造返回结果 //构造返回结果
result.setFunctionName(functionCallResult.getFunctionCall().getName()); result.setFunctionName(functionCallResult.getFunctionCall().getName());
......
...@@ -347,7 +347,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -347,7 +347,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
public List<AgentApplicationValueMemoryDto> getVariableList(String agentId) { public List<AgentApplicationValueMemoryDto> getVariableList(String agentId) {
List<AgentApplicationValueMemoryDto> result = new ArrayList<>(); List<AgentApplicationValueMemoryDto> result = new ArrayList<>();
BizAgentApplicationInfoEntity infoEntity = bizAgentApplicationInfoService.getByAgentId(agentId); BizAgentApplicationInfoEntity infoEntity = bizAgentApplicationInfoService.getByAgentId(agentId);
Map<Object, Object> map = GetValueMemory.get(agentId); Map<Object, Object> map = GetValueMemory.get(agentId + ":" + agentId);
List<Variable> variableStructure = infoEntity.getVariableStructure(); List<Variable> variableStructure = infoEntity.getVariableStructure();
if (MapUtils.isEmpty(map)) { if (MapUtils.isEmpty(map)) {
if (CollectionUtils.isEmpty(variableStructure)) { if (CollectionUtils.isEmpty(variableStructure)) {
......
...@@ -136,7 +136,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -136,7 +136,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
outputRecord.setTimestamp(System.currentTimeMillis()); outputRecord.setTimestamp(System.currentTimeMillis());
//对话 //对话
String output = agentApplicationInfoService.callAgentApplication(dialogsId, infoEntity.getLargeModel(), String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getUnitIds(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getUnitIds(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, httpServletResponse);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment