Commit 549c6d87 authored by alex yao's avatar alex yao

fix:修复调用扣减多个插件积分问题

parent 8964d53d
package cn.com.poc.agent_application.aggregate; package cn.com.poc.agent_application.aggregate;
import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity; import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity;
import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity; import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.List; import java.util.List;
...@@ -53,6 +55,28 @@ public interface AgentApplicationInfoService { ...@@ -53,6 +55,28 @@ public interface AgentApplicationInfoService {
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature, Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception; List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
/**
* Agent应用对话
*
* @param agentId 应用ID
* @param identifier 对话唯一标识
* @param largeModel 模型
* @param agentSystem 应用角色指令
* @param knowledgeIds 知识库ID
* @param communicationTurn 对话轮数
* @param topP 模型参数topP
* @param temperature 模型参数temperature
* @param messages 对话消息
* @param tools 插件配置
* @param functionCallResult 插件回调结果
* @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输
*/
String callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* 应用下架 * 应用下架
* *
......
...@@ -188,6 +188,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -188,6 +188,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse); return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
} }
@Override
public String callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception {
String model = modelConvert(largeModel);
Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(dialogueId, functionCallResult, agentId);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
}
@Override @Override
public void createAgentSystem(String input, HttpServletResponse httpServletResponse) throws Exception { public void createAgentSystem(String input, HttpServletResponse httpServletResponse) throws Exception {
...@@ -907,6 +921,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -907,6 +921,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
} }
private FunctionResult functionCall(String dialogueId, FunctionCallResult functionCallResult, String agentId) {
FunctionResult result = new FunctionResult();
if (functionCallResult != null && functionCallResult.isNeed()) {
// 执行函数返回结果
LargeModelFunctionEnum functionEnum = LargeModelFunctionEnum.valueOf(functionCallResult.getFunctionCall().getName());
String functionResult = functionEnum.getFunction().doFunction(functionCallResult.getFunctionCall().getArguments(), AgentApplicationTools.identifier(dialogueId, agentId));
//构造返回结果
result.setFunctionName(functionCallResult.getFunctionCall().getName());
result.setFunctionArg(functionCallResult.getFunctionCall().getArguments());
result.setFunctionDesc(functionEnum.getFunction().getDesc());
result.setFunctionResult(functionResult);
}
return result;
}
/** /**
* 更新【记忆变量】结构 * 更新【记忆变量】结构
......
package cn.com.poc.agent_application.entity;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import java.util.List;
/**
* @author alex.yao
* @date 2025/1/22
*/
public class CheckPluginUseEntity {
private List<Tool> deductionTools;
private FunctionCallResult functionCallResult;
public List<Tool> getDeductionTools() {
return deductionTools;
}
public void setDeductionTools(List<Tool> deductionTools) {
this.deductionTools = deductionTools;
}
public FunctionCallResult getFunctionCallResult() {
return functionCallResult;
}
public void setFunctionCallResult(FunctionCallResult functionCallResult) {
this.functionCallResult = functionCallResult;
}
}
...@@ -258,9 +258,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -258,9 +258,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
String agentSystem = StringUtils.isBlank(dto.getAgentSystem()) ? infoEntity.getAgentSystem() : dto.getAgentSystem(); String agentSystem = StringUtils.isBlank(dto.getAgentSystem()) ? infoEntity.getAgentSystem() : dto.getAgentSystem();
// 判断是否调用function // 判断是否调用function
List<Tool> deductionTools = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools);
//计算扣分数 //计算扣分数
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, infoEntity.getCommunicationTurn(), deductionTools); CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, infoEntity.getCommunicationTurn(), checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo(); AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId); agentUseModifyEventInfo.setAgentId(agentId);
agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.N); agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.N);
...@@ -268,7 +268,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -268,7 +268,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务 //调用应用服务
agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model, agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), topP, agentSystem, kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), topP,
temperature, dto.getMessages(), tools, dto.getFileUrls(), true, imageUrls, httpServletResponse); temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls, httpServletResponse);
//数据采集 //数据采集
if (StringUtils.isBlank(dto.getChannel())) { if (StringUtils.isBlank(dto.getChannel())) {
dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel()); dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel());
......
package cn.com.poc.agent_application.utils; package cn.com.poc.agent_application.utils;
import cn.com.poc.agent_application.entity.CheckPluginUseEntity;
import cn.com.poc.agent_application.entity.Variable; import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.constant.CommonConstant; import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.DocumentLoad; import cn.com.poc.common.utils.DocumentLoad;
...@@ -32,7 +33,6 @@ import java.util.Map; ...@@ -32,7 +33,6 @@ import java.util.Map;
*/ */
public class AgentApplicationTools { public class AgentApplicationTools {
/** /**
* 构造Agent应用 函数配置 * 构造Agent应用 函数配置
* *
...@@ -162,9 +162,10 @@ public class AgentApplicationTools { ...@@ -162,9 +162,10 @@ public class AgentApplicationTools {
/** /**
* 判断将会调用的插件-用于扣减积分 * 判断将会调用的插件-用于扣减积分
*/ */
public static List<Tool> checkPluginUse(List<Message> messages, List<Tool> tools) { public static CheckPluginUseEntity checkPluginUse(List<Message> messages, List<Tool> tools) {
CheckPluginUseEntity checkPluginUseEntity = new CheckPluginUseEntity();
if (CollectionUtils.isEmpty(messages) || CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(messages) || CollectionUtils.isEmpty(tools)) {
return null; return checkPluginUseEntity;
} }
LLMService llmService = SpringUtils.getBean(LLMService.class); LLMService llmService = SpringUtils.getBean(LLMService.class);
String query = messages.get(messages.size() - 1).getContent().toString(); String query = messages.get(messages.size() - 1).getContent().toString();
...@@ -176,7 +177,9 @@ public class AgentApplicationTools { ...@@ -176,7 +177,9 @@ public class AgentApplicationTools {
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class); Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
deductionTools.add(tool); deductionTools.add(tool);
} }
return deductionTools; checkPluginUseEntity.setDeductionTools(deductionTools);
checkPluginUseEntity.setFunctionCallResult(functionCallResult);
return checkPluginUseEntity;
} }
} }
...@@ -143,8 +143,8 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -143,8 +143,8 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
//计算扣分数 //计算扣分数
// 判断是否调用function // 判断是否调用function
List<Tool> deductionTools = AgentApplicationTools.checkPluginUse(messages, tools); CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), deductionTools); Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo(); AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId); agentUseModifyEventInfo.setAgentId(agentId);
agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.Y); agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.Y);
...@@ -155,7 +155,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -155,7 +155,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try { try {
String output = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(), String output = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, stream, imageUrls, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls, httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, output); saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, output);
} catch (Exception e) { } catch (Exception e) {
memberEquityService.rollbackPoint(reduceSn); memberEquityService.rollbackPoint(reduceSn);
......
...@@ -165,10 +165,9 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -165,10 +165,9 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
outputRecord.setTimestamp(System.currentTimeMillis()); outputRecord.setTimestamp(System.currentTimeMillis());
//计算扣分数 //计算扣分数
// 判断是否调用function // 判断是否调用function
List<Tool> deductionTools = AgentApplicationTools.checkPluginUse(messages, tools); CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), deductionTools); Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo(); AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId); agentUseModifyEventInfo.setAgentId(agentId);
agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.Y); agentUseModifyEventInfo.setIsPublish(CommonConstant.IsDeleted.Y);
...@@ -185,7 +184,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -185,7 +184,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//对话 //对话
String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(), String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, true, imageUrls, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls, httpServletResponse);
//保存对话记录 //保存对话记录
outputRecord.setContent(output); outputRecord.setContent(output);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment