Commit a3fd0216 authored by alex yao's avatar alex yao

feat:Agent对话 返回知识库和插件信息

parent d60ce0d1
...@@ -4,6 +4,7 @@ import cn.com.poc.agent_application.domain.FunctionResult; ...@@ -4,6 +4,7 @@ import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.entity.AgentResultEntity; import cn.com.poc.agent_application.entity.AgentResultEntity;
import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity; import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity;
import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity; import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
...@@ -38,45 +39,53 @@ public interface AgentApplicationInfoService { ...@@ -38,45 +39,53 @@ public interface AgentApplicationInfoService {
/** /**
* Agent应用对话 * Agent应用对话
* *
* @param agentId 应用ID * @param agentId 应用ID
* @param identifier 对话唯一标识 * @param identifier 对话唯一标识
* @param largeModel 模型 * @param largeModel 模型
* @param agentSystem 应用角色指令 * @param agentSystem 应用角色指令
* @param knowledgeIds 知识库ID * @param knowledgeIds 知识库ID
* @param communicationTurn 对话轮数 * @param communicationTurn 对话轮数
* @param topP 模型参数topP * @param topP 模型参数topP
* @param temperature 模型参数temperature * @param temperature 模型参数temperature
* @param messages 对话消息 * @param messages 对话消息
* @param tools 插件配置 * @param tools 插件配置
* @param fileUrls 文件URLs * @param fileUrls 文件URLs
* @param imageUrls 图片URLs * @param imageUrls 图片URLs
* @param stream 是否流式传输 * @param stream 是否流式传输
* @param score 知识库参数score
* @param topK 知识库参数topK
* @param knowledgeSearchType 知识库参数知识搜索类型
*/ */
AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem, AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature, Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception; List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* Agent应用对话 * Agent应用对话
* *
* @param agentId 应用ID * @param agentId 应用ID
* @param identifier 对话唯一标识 * @param identifier 对话唯一标识
* @param largeModel 模型 * @param largeModel 模型
* @param agentSystem 应用角色指令 * @param agentSystem 应用角色指令
* @param knowledgeIds 知识库ID * @param knowledgeIds 知识库ID
* @param communicationTurn 对话轮数 * @param communicationTurn 对话轮数
* @param topP 模型参数topP * @param topP 模型参数topP
* @param temperature 模型参数temperature * @param temperature 模型参数temperature
* @param messages 对话消息 * @param messages 对话消息
* @param tools 插件配置 * @param tools 插件配置
* @param functionCallResult 插件回调结果 * @param functionCallResult 插件回调结果
* @param fileUrls 文件URLs * @param fileUrls 文件URLs
* @param imageUrls 图片URLs * @param imageUrls 图片URLs
* @param stream 是否流式传输 * @param stream 是否流式传输
* @param score 知识库参数score
* @param topK 知识库参数topK
* @param knowledgeSearchType 知识库参数知识搜索类型
*/ */
AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem, AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature, Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception; List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* 应用下架 * 应用下架
......
...@@ -5,6 +5,7 @@ import cn.com.poc.agent_application.constant.AgentApplicationConstants; ...@@ -5,6 +5,7 @@ import cn.com.poc.agent_application.constant.AgentApplicationConstants;
import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants; import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants;
import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants; import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants;
import cn.com.poc.agent_application.domain.FunctionResult; import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.dto.KnowledgeContentResult;
import cn.com.poc.agent_application.entity.*; import cn.com.poc.agent_application.entity.*;
import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem; import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem;
import cn.com.poc.agent_application.service.*; import cn.com.poc.agent_application.service.*;
...@@ -15,6 +16,7 @@ import cn.com.poc.common.utils.BlContext; ...@@ -15,6 +16,7 @@ import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils; import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.constant.KnowledgeConstant; import cn.com.poc.knowledge.constant.KnowledgeConstant;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity; import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.service.BizKnowledgeDocumentService; import cn.com.poc.knowledge.service.BizKnowledgeDocumentService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity; import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AICreateImageService; import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AICreateImageService;
...@@ -28,6 +30,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction; ...@@ -28,6 +30,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest; import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SearchKnowledgeResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse; import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum; import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
...@@ -57,6 +60,7 @@ import java.io.BufferedReader; ...@@ -57,6 +60,7 @@ import java.io.BufferedReader;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.*; import java.util.*;
import java.util.stream.Collectors;
import static cn.com.poc.common.constant.XLangConstant.*; import static cn.com.poc.common.constant.XLangConstant.*;
...@@ -172,7 +176,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -172,7 +176,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Override @Override
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel,
String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception { List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception {
logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}" logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}"
, agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools); , agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools);
...@@ -183,26 +188,31 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -183,26 +188,31 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls, imageUrls); FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls, imageUrls);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId); List<KnowledgeContentResult> knowledgeResult = knowledge(kdIds, messages, topK, score, KnowledgeSearchTypeEnum.MIX);
String promptTemplate = buildDialogsPrompt(functionResult, agentSystem, toolArray, dialogueId, agentId, knowledgeResult);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate); Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse); return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, knowledgeResult, httpServletResponse);
} }
@Override @Override
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception { public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception {
String model = modelConvert(largeModel); String model = modelConvert(largeModel);
Tool[] toolArray = tools.toArray(new Tool[0]); Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(dialogueId, functionCallResult, agentId); FunctionResult functionResult = functionCall(dialogueId, functionCallResult, agentId);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId); List<KnowledgeContentResult> knowledgeResult = knowledge(kdIds, messages, topK, score, KnowledgeSearchTypeEnum.MIX);
String promptTemplate = buildDialogsPrompt(functionResult, agentSystem, toolArray, dialogueId, agentId, knowledgeResult);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate); Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse); return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, knowledgeResult, httpServletResponse);
} }
@Override @Override
...@@ -524,15 +534,23 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -524,15 +534,23 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param model 模型 * @param model 模型
* @param messageArray 消息 * @param messageArray 消息
* @param functionResult 函数结果 * @param functionResult 函数结果
* @param knowledgeResult 知识库结果
* @param httpServletResponse 响应 * @param httpServletResponse 响应
* @return 输出结果 * @return 输出结果
* @throws Exception * @throws Exception
*/ */
private AgentResultEntity llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, HttpServletResponse httpServletResponse) throws Exception { private AgentResultEntity llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, List<KnowledgeContentResult> knowledgeResult, HttpServletResponse httpServletResponse) throws Exception {
if (stream) { if (stream) {
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter();
if (CollectionUtils.isNotEmpty(knowledgeResult)) {
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
result.setKnowledgeContentResult(knowledgeResult);
writer.write(EVENT_STREAM_PREFIX + JsonUtils.serialize(result) + "\n\n");
writer.flush();
}
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) { if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter();
LargeModelDemandResult result = new LargeModelDemandResult(); LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0"); result.setCode("0");
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult); ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
...@@ -584,15 +602,14 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -584,15 +602,14 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
/** /**
* 构建对话提示词 * 构建对话提示词
* *
* @param functionResult 函数结果 * @param functionResult 函数结果
* @param messages 对话消息 * @param agentSystem 应用角色指令
* @param agentSystem 应用角色指令 * @param tools 组件
* @param kdIds 知识库id * @param dialogueId 对话标识符
* @param tools 组件 * @param knowledgeContentResults 知识库结果
* @param dialogueId 对话标识符
* @return * @return
*/ */
private String buildDialogsPrompt(FunctionResult functionResult, List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String dialogueId, String agentId) { private String buildDialogsPrompt(FunctionResult functionResult, String agentSystem, Tool[] tools, String dialogueId, String agentId, List<KnowledgeContentResult> knowledgeContentResults) throws IOException {
String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem(); String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem();
Locale currentLocale = Context.get().getMessageSource().getCurrentLocale(); Locale currentLocale = Context.get().getMessageSource().getCurrentLocale();
// 系统语言 // 系统语言
...@@ -600,7 +617,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -600,7 +617,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
// 应用角色指令 // 应用角色指令
promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY); promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY);
// 调用知识库 // 调用知识库
promptTemplate = buildKnowledgePrompt(messages, kdIds, promptTemplate, null); promptTemplate = buildKnowledgePrompt(knowledgeContentResults, promptTemplate);
// 记忆 // 记忆
promptTemplate = buildMemoryPrompt(promptTemplate, tools, dialogueId, agentId); promptTemplate = buildMemoryPrompt(promptTemplate, tools, dialogueId, agentId);
// 函数调用 // 函数调用
...@@ -611,38 +628,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -611,38 +628,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
/** /**
* 构建知识库提示词 * 构建知识库提示词
* *
* @param messages * @param knowledgeContentResults
* @param kdIds
* @param promptTemplate
* @return * @return
*/ */
private String buildKnowledgePrompt(List<Message> messages, Integer[] kdIds, String promptTemplate, KnowledgeSearchTypeEnum searchTypeEnum) { private String buildKnowledgePrompt(List<KnowledgeContentResult> knowledgeContentResults, String promptTemplate) {
if (ArrayUtils.isNotEmpty(kdIds)) { StringBuilder knowledgePromptBuilder = new StringBuilder("");
List<String> knowledgeIds = new ArrayList<>(); if (CollectionUtils.isNotEmpty(knowledgeContentResults)) {
for (Integer kdId : kdIds) { for (int i = 1; i <= knowledgeContentResults.size(); i++) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId); knowledgePromptBuilder.append("<Chunk ").append(i).append(">").append(":<")
// 筛选训练完成的文档,否则跳过该文档 .append(StringUtils.LF)
if (null == knowledgeDocumentEntity && KnowledgeConstant.TrainStatus.COMPLETE.equals(knowledgeDocumentEntity.getTrainStatus())) { .append(knowledgeContentResults.get(i - 1).getContent())
continue; .append(">")
} .append(StringUtils.LF);
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
}
Object content = messages.get(messages.size() - 1).getContent();
String query = "";
if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString();
} else {
query = content.toString();
} }
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(query, knowledgeIds, 3, searchTypeEnum);
StringBuilder knowledgeResultsBuilder = new StringBuilder();
if (CollectionUtils.isNotEmpty(knowledgeResults)) {
for (int i = 1; i <= knowledgeResults.size(); i++) {
knowledgeResultsBuilder.append("### Chunk ").append(i).append(":").append(StringUtils.LF).append(knowledgeResults.get(i - 1)).append(StringUtils.LF);
}
}
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResultsBuilder.toString());
} }
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgePromptBuilder.toString());
return promptTemplate; return promptTemplate;
} }
...@@ -756,7 +756,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -756,7 +756,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param httpServletResponse * @param httpServletResponse
* @throws IOException * @throws IOException
*/ */
private AgentResultEntity textOutput(HttpServletResponse httpServletResponse, LargeModelDemandResult largeModelDemandResult) throws private AgentResultEntity textOutput(HttpServletResponse httpServletResponse, LargeModelDemandResult
largeModelDemandResult) throws
IOException { IOException {
PrintWriter writer = httpServletResponse.getWriter(); PrintWriter writer = httpServletResponse.getWriter();
writer.write(JsonUtils.serialize(largeModelDemandResult)); writer.write(JsonUtils.serialize(largeModelDemandResult));
...@@ -776,8 +777,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -776,8 +777,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param bufferedReader * @param bufferedReader
* @throws IOException * @throws IOException
*/ */
private AgentResultEntity textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws private AgentResultEntity textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws IOException {
IOException {
String res = ""; String res = "";
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8); httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter(); PrintWriter writer = httpServletResponse.getWriter();
...@@ -834,6 +834,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -834,6 +834,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
break; break;
} }
} }
toolFunction.setResult(functionResult.getFunctionResult());
return toolFunction; return toolFunction;
} }
...@@ -899,7 +900,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -899,7 +900,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param fileUrls * @param fileUrls
* @param imageUrls * @param imageUrls
*/ */
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls, List<String> imageUrls) { private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String
agentId, List<String> fileUrls, List<String> imageUrls) {
FunctionResult result = new FunctionResult(); FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) { if (ArrayUtils.isEmpty(tools)) {
return result; return result;
...@@ -953,6 +955,52 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -953,6 +955,52 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return result; return result;
} }
private List<KnowledgeContentResult> knowledge(Integer[] kdIds, List<Message> messages, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum) {
List<KnowledgeContentResult> knowledgeContentResults = new ArrayList<>();
if (ArrayUtils.isEmpty(kdIds)) {
return knowledgeContentResults;
}
List<String> knowledgeIds = new ArrayList<>();
for (Integer kdId : kdIds) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId);
// 筛选训练完成的文档,否则跳过该文档
if (null == knowledgeDocumentEntity && KnowledgeConstant.TrainStatus.COMPLETE.equals(knowledgeDocumentEntity.getTrainStatus())) {
continue;
}
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
}
Object content = messages.get(messages.size() - 1).getContent();
String query = "";
if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString();
} else {
query = content.toString();
}
SearchKnowledgeResult searchKnowledgeResult = demandKnowledgeService.searchKnowledge(query, knowledgeIds, topK, score, searchTypeEnum);
if (CollectionUtils.isNotEmpty(searchKnowledgeResult.getDocuments())) {
for (int i = 1; i <= searchKnowledgeResult.getDocuments().size(); i++) {
KnowledgeContentResult knowledgeContentResult = new KnowledgeContentResult();
knowledgeContentResult.setContent(searchKnowledgeResult.getDocuments().get(i - 1));
knowledgeContentResult.setKnowledgeId(searchKnowledgeResult.getKnowledgeIds().get(i - 1));
knowledgeContentResult.setScore(searchKnowledgeResult.getScore().get(i - 1));
knowledgeContentResults.add(knowledgeContentResult);
}
// 根据knowledgeId获取知识库名和文档名
knowledgeIds = knowledgeContentResults.stream().map(KnowledgeContentResult::getKnowledgeId).distinct().collect(Collectors.toList());
List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQueryItems = bizKnowledgeDocumentService.knowledgeDocumentRelationQuery(knowledgeIds, null);
for (KnowledgeContentResult result : knowledgeContentResults) {
String knowledgeId = result.getKnowledgeId();
KnowledgeDocumentRelationQueryItem item = knowledgeDocumentRelationQueryItems.stream().filter(v -> v.getKnowledgeId().equals(knowledgeId)).findFirst().get();
result.setKnowledgeName(item.getKnowledgeName());
result.setKdId(item.getKdId());
result.setDocumentName(item.getDocumentName());
}
}
return knowledgeContentResults;
}
/** /**
* 更新【记忆变量】结构 * 更新【记忆变量】结构
......
package cn.com.poc.agent_application.dto;
/**
* @author alex.yao
* @date 2025/2/28
*/
public class KnowledgeContentResult {
private String content;
private String knowledgeId;
private String knowledgeName;
private Integer kdId;
private String documentName;
private Double score;
public Integer getKdId() {
return kdId;
}
public void setKdId(Integer kdId) {
this.kdId = kdId;
}
public String getDocumentName() {
return documentName;
}
public void setDocumentName(String documentName) {
this.documentName = documentName;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public String getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(String knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getKnowledgeName() {
return knowledgeName;
}
public void setKnowledgeName(String knowledgeName) {
this.knowledgeName = knowledgeName;
}
public Double getScore() {
return score;
}
public void setScore(Double score) {
this.score = score;
}
}
...@@ -24,6 +24,7 @@ import cn.com.poc.equity.entity.BizPointDeductionRulesEntity; ...@@ -24,6 +24,7 @@ import cn.com.poc.equity.entity.BizPointDeductionRulesEntity;
import cn.com.poc.equity.service.BizPointDeductionRulesService; import cn.com.poc.equity.service.BizPointDeductionRulesService;
import cn.com.poc.knowledge.aggregate.KnowledgeService; import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity; import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.AgentLongMemoryEntity; import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.AgentLongMemoryEntity;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.LongMemory; import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.LongMemory;
...@@ -35,6 +36,8 @@ import cn.hutool.core.collection.ListUtil; ...@@ -35,6 +36,8 @@ import cn.hutool.core.collection.ListUtil;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.Assert; import org.springframework.util.Assert;
...@@ -48,6 +51,8 @@ import java.util.stream.Collectors; ...@@ -48,6 +51,8 @@ import java.util.stream.Collectors;
@Component @Component
public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
private Logger logger = LoggerFactory.getLogger(AgentApplicationInfoRest.class);
@Resource @Resource
private BizAgentApplicationInfoService bizAgentApplicationInfoService; private BizAgentApplicationInfoService bizAgentApplicationInfoService;
...@@ -256,11 +261,11 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -256,11 +261,11 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP(); Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP();
Float temperature = dto.getTemperature() == null ? infoEntity.getTemperature() : dto.getTemperature(); Float temperature = dto.getTemperature() == null ? infoEntity.getTemperature() : dto.getTemperature();
String agentSystem = StringUtils.isBlank(dto.getAgentSystem()) ? infoEntity.getAgentSystem() : dto.getAgentSystem(); String agentSystem = StringUtils.isBlank(dto.getAgentSystem()) ? infoEntity.getAgentSystem() : dto.getAgentSystem();
Integer communicationTurn = dto.getCommunicationTurn() == null ? infoEntity.getCommunicationTurn(): dto.getCommunicationTurn(); Integer communicationTurn = dto.getCommunicationTurn() == null ? infoEntity.getCommunicationTurn() : dto.getCommunicationTurn();
// 判断是否调用function // 判断是否调用function
//计算扣分数 //计算扣分数
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools,fileUrls,imageUrls); CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools, fileUrls, imageUrls);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, communicationTurn, checkPluginUseEntity.getDeductionTools()); Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, communicationTurn, checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo(); AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId); agentUseModifyEventInfo.setAgentId(agentId);
...@@ -269,7 +274,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -269,7 +274,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务 //调用应用服务
agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model, agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), communicationTurn, topP, agentSystem, kdIds.toArray(new Integer[0]), communicationTurn, topP,
temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls, httpServletResponse); temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
//数据采集 //数据采集
if (StringUtils.isBlank(dto.getChannel())) { if (StringUtils.isBlank(dto.getChannel())) {
dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel()); dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel());
...@@ -281,7 +288,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -281,7 +288,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
writer.write("data: {\"code\":-1,\"message\":\"" + e.getLocalizedMessage() + "\"} \n\n"); writer.write("data: {\"code\":-1,\"message\":\"" + e.getLocalizedMessage() + "\"} \n\n");
writer.write("data: [DONE]\n\n"); writer.write("data: [DONE]\n\n");
writer.flush(); writer.flush();
writer.close();
memberEquityService.rollbackPoint(reduceSn); memberEquityService.rollbackPoint(reduceSn);
logger.error("preview error", e);
} }
} }
......
...@@ -18,6 +18,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo; ...@@ -18,6 +18,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo;
import cn.com.poc.expose.aggregate.AgentApplicationApiService; import cn.com.poc.expose.aggregate.AgentApplicationApiService;
import cn.com.poc.knowledge.aggregate.KnowledgeService; import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.constants.OauthConstants; import cn.com.poc.support.security.oauth.constants.OauthConstants;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum; import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
...@@ -164,7 +165,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -164,7 +165,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try { try {
AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(), AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, agentResultEntity.getMessage()); saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, agentResultEntity.getMessage());
} catch (Exception e) { } catch (Exception e) {
memberEquityService.rollbackPoint(reduceSn); memberEquityService.rollbackPoint(reduceSn);
......
...@@ -22,6 +22,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo; ...@@ -22,6 +22,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo;
import cn.com.poc.expose.aggregate.AgentApplicationService; import cn.com.poc.expose.aggregate.AgentApplicationService;
import cn.com.poc.knowledge.aggregate.KnowledgeService; import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity; import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum; import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
...@@ -182,7 +183,9 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -182,7 +183,9 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//对话 //对话
AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(), AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
//保存对话记录 //保存对话记录
outputRecord.setContent(agentResultEntity.getMessage()); outputRecord.setContent(agentResultEntity.getMessage());
......
select
bkd.kd_id,
bki.knowledge_name,
bkd.document_name,
bkd.knowledge_id
from
biz_knowledge_document bkd
left join biz_knowledge_info bki on
JSON_CONTAINS(bki.kd_ids , json_array(bkd.kd_id))
where
bkd.is_deleted = 'N'
<<and bkd.knowledge_id in(:knowledgeIds)>>
\ No newline at end of file
package cn.com.poc.knowledge.query;
import java.io.Serializable;
import java.util.List;
/**
* Query Condition class for KnowledgeDocumentRelationQuery
*/
public class KnowledgeDocumentRelationQueryCondition implements Serializable {
private static final long serialVersionUID = 1L;
private List<String> knowledgeIds;
public List<String> getKnowledgeIds() {
return knowledgeIds;
}
public void setKnowledgeIds(List<String> knowledgeIds) {
this.knowledgeIds = knowledgeIds;
}
}
\ No newline at end of file
package cn.com.poc.knowledge.query;
import cn.com.yict.framemax.data.model.BaseItemClass;
import javax.persistence.Column;
import javax.persistence.Entity;
import java.io.Serializable;
/**
* Query Item class for KnowledgeDocumentRelationQuery
*/
@Entity
public class KnowledgeDocumentRelationQueryItem extends BaseItemClass implements Serializable {
private static final long serialVersionUID = 1L;
/**
* kd_id
* kd_id
*/
private Integer kdId;
@Column(name = "kd_id")
public Integer getKdId() {
return this.kdId;
}
public void setKdId(Integer kdId) {
this.kdId = kdId;
}
/**
* knowledge_id
*/
private String knowledgeId;
@Column(name = "knowledge_id")
public String getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(String knowledgeId) {
this.knowledgeId = knowledgeId;
}
/**
* knowledge_name
* knowledge_name
*/
private String knowledgeName;
@Column(name = "knowledge_name")
public String getKnowledgeName() {
return this.knowledgeName;
}
public void setKnowledgeName(String knowledgeName) {
this.knowledgeName = knowledgeName;
}
/**
* document_name
* document_name
*/
private String documentName;
@Column(name = "document_name")
public String getDocumentName() {
return this.documentName;
}
public void setDocumentName(String documentName) {
this.documentName = documentName;
}
}
\ No newline at end of file
package cn.com.poc.knowledge.service; package cn.com.poc.knowledge.service;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity; import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.query.KnowledgeQueryItem; import cn.com.poc.knowledge.query.KnowledgeQueryItem;
import cn.com.yict.framemax.core.service.BaseService; import cn.com.yict.framemax.core.service.BaseService;
import cn.com.yict.framemax.data.model.PagingInfo; import cn.com.yict.framemax.data.model.PagingInfo;
...@@ -20,4 +21,6 @@ public interface BizKnowledgeDocumentService extends BaseService { ...@@ -20,4 +21,6 @@ public interface BizKnowledgeDocumentService extends BaseService {
Boolean deleted(Integer kdId); Boolean deleted(Integer kdId);
List<KnowledgeQueryItem> searchKnowledge(String document, String trainStatus, String memberId, String documentName, List<Integer> kdIds, PagingInfo pagingInfo); List<KnowledgeQueryItem> searchKnowledge(String document, String trainStatus, String memberId, String documentName, List<Integer> kdIds, PagingInfo pagingInfo);
List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQuery(List<String> knowledgeIds, PagingInfo pagingInfo);
} }
\ No newline at end of file
...@@ -6,6 +6,8 @@ import cn.com.poc.common.utils.JsonUtils; ...@@ -6,6 +6,8 @@ import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.convert.KnowledgeDocumentConvert; import cn.com.poc.knowledge.convert.KnowledgeDocumentConvert;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity; import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.model.BizKnowledgeDocumentModel; import cn.com.poc.knowledge.model.BizKnowledgeDocumentModel;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryCondition;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.query.KnowledgeQueryCondition; import cn.com.poc.knowledge.query.KnowledgeQueryCondition;
import cn.com.poc.knowledge.query.KnowledgeQueryItem; import cn.com.poc.knowledge.query.KnowledgeQueryItem;
import cn.com.poc.knowledge.repository.BizKnowledgeDocumentRepository; import cn.com.poc.knowledge.repository.BizKnowledgeDocumentRepository;
...@@ -95,4 +97,11 @@ public class BizKnowledgeDocumentServiceImpl extends BaseServiceImpl ...@@ -95,4 +97,11 @@ public class BizKnowledgeDocumentServiceImpl extends BaseServiceImpl
condition.setKdIds(kdIds); condition.setKdIds(kdIds);
return this.sqlDao.query(condition, KnowledgeQueryItem.class, pagingInfo); return this.sqlDao.query(condition, KnowledgeQueryItem.class, pagingInfo);
} }
@Override
public List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQuery(List<String> knowledgeIds, PagingInfo pagingInfo) {
KnowledgeDocumentRelationQueryCondition condition = new KnowledgeDocumentRelationQueryCondition();
condition.setKnowledgeIds(knowledgeIds);
return this.sqlDao.query(condition, KnowledgeDocumentRelationQueryItem.class, pagingInfo);
}
} }
\ No newline at end of file
...@@ -2,6 +2,7 @@ package cn.com.poc.thirdparty.resource.demand.ai.aggregate; ...@@ -2,6 +2,7 @@ package cn.com.poc.thirdparty.resource.demand.ai.aggregate;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum; import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.GetKnowledgeChunkInfoResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.GetKnowledgeChunkInfoResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SearchKnowledgeResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SegmentationConfigRequest; import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SegmentationConfigRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.qaknowledge.QAKnowledgeConfig; import cn.com.poc.thirdparty.resource.demand.ai.entity.qaknowledge.QAKnowledgeConfig;
import cn.com.yict.framemax.data.model.PagingInfo; import cn.com.yict.framemax.data.model.PagingInfo;
...@@ -50,10 +51,11 @@ public interface DemandKnowledgeService { ...@@ -50,10 +51,11 @@ public interface DemandKnowledgeService {
* @param query 查询文本 * @param query 查询文本
* @param knowledgeIds 知识库id * @param knowledgeIds 知识库id
* @param topK 返回个数 * @param topK 返回个数
* @param score 分数阈值
* @param searchTypeEnum 查询类型 * @param searchTypeEnum 查询类型
* @return 查询结果 * @return 查询结果
*/ */
List<String> searchKnowledge(String query, List<String> knowledgeIds, Integer topK, KnowledgeSearchTypeEnum searchTypeEnum); SearchKnowledgeResult searchKnowledge(String query, List<String> knowledgeIds, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum);
/** /**
* 获取知识库分片 * 获取知识库分片
......
...@@ -15,6 +15,8 @@ import cn.hutool.core.lang.Assert; ...@@ -15,6 +15,8 @@ import cn.hutool.core.lang.Assert;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.apache.http.Header; import org.apache.http.Header;
import org.apache.http.message.BasicHeader; import org.apache.http.message.BasicHeader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import javax.annotation.Resource; import javax.annotation.Resource;
...@@ -24,6 +26,8 @@ import java.util.List; ...@@ -24,6 +26,8 @@ import java.util.List;
@Service @Service
public class DemandKnowledgeServiceImpl implements DemandKnowledgeService { public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
private Logger logger = LoggerFactory.getLogger(DemandKnowledgeService.class);
@Resource @Resource
private DgtoolsAbstractHttpClient dgToolsAbstractHttpClient; private DgtoolsAbstractHttpClient dgToolsAbstractHttpClient;
...@@ -44,7 +48,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService { ...@@ -44,7 +48,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
} }
@Override @Override
public String trainKnowledgeEvent(String fileURL, String knowledgeType, SegmentationConfigRequest segmentationConfig, List<QAKnowledgeConfig> qaKnowledgeConfigs){ public String trainKnowledgeEvent(String fileURL, String knowledgeType, SegmentationConfigRequest segmentationConfig, List<QAKnowledgeConfig> qaKnowledgeConfigs) {
Assert.notBlank(fileURL); Assert.notBlank(fileURL);
TrainKnowledgeRequest request = new TrainKnowledgeRequest(); TrainKnowledgeRequest request = new TrainKnowledgeRequest();
request.setDocumentUrl(fileURL); request.setDocumentUrl(fileURL);
...@@ -82,10 +86,14 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService { ...@@ -82,10 +86,14 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
} }
@Override @Override
public List<String> searchKnowledge(String query, List<String> knowledgeIds, Integer topK, KnowledgeSearchTypeEnum searchTypeEnum) { public SearchKnowledgeResult searchKnowledge(String query, List<String> knowledgeIds, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum) {
Assert.notBlank(query); Assert.notBlank(query);
SearchKnowledgeResult result = new SearchKnowledgeResult();
result.setDocuments(new ArrayList<>());
result.setScore(new ArrayList<>());
result.setKnowledgeIds(new ArrayList<>());
if (CollectionUtils.isEmpty(knowledgeIds)) { if (CollectionUtils.isEmpty(knowledgeIds)) {
return new ArrayList<String>(); return new SearchKnowledgeResult();
} }
if (topK == null) { if (topK == null) {
topK = 3; topK = 3;
...@@ -94,18 +102,16 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService { ...@@ -94,18 +102,16 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
searchKnowledgeRequest.setQuery(query); searchKnowledgeRequest.setQuery(query);
searchKnowledgeRequest.setKnowLedgeIds(knowledgeIds); searchKnowledgeRequest.setKnowLedgeIds(knowledgeIds);
searchKnowledgeRequest.setTopK(topK); searchKnowledgeRequest.setTopK(topK);
searchKnowledgeRequest.setScore(score);
VectorSearchConfig vectorSearchConfig = new VectorSearchConfig(); VectorSearchConfig vectorSearchConfig = new VectorSearchConfig();
vectorSearchConfig.setSearchType(searchTypeEnum); vectorSearchConfig.setSearchType(searchTypeEnum);
vectorSearchConfig.setApiKey(Config.get("large-model.apikey")); vectorSearchConfig.setApiKey(Config.get("large-model.apikey"));
searchKnowledgeRequest.setVectorSearchConfig(vectorSearchConfig); searchKnowledgeRequest.setVectorSearchConfig(vectorSearchConfig);
SearchKnowledgeResult searchKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders()); result = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders());
if (null == searchKnowledgeResult) { if (null == result) {
throw new I18nMessageException("exception/query.knowledge.base.exception"); logger.warn("query knowledge base exception");
}
if (CollectionUtils.isEmpty(searchKnowledgeResult.getDocuments())) {
return new ArrayList<String>();
} }
return searchKnowledgeResult.getDocuments(); return result;
} }
@Override @Override
......
...@@ -5,6 +5,16 @@ public class ToolFunction { ...@@ -5,6 +5,16 @@ public class ToolFunction {
private String arguments; private String arguments;
private String result;
public String getResult() {
return result;
}
public void setResult(String result) {
this.result = result;
}
public String getName() { public String getName() {
return name; return name;
} }
......
...@@ -14,6 +14,16 @@ public class SearchKnowledgeRequest extends AbstractRequest<SearchKnowledgeResul ...@@ -14,6 +14,16 @@ public class SearchKnowledgeRequest extends AbstractRequest<SearchKnowledgeResul
private VectorSearchConfig vectorSearchConfig; private VectorSearchConfig vectorSearchConfig;
private Double score;
public Double getScore() {
return score;
}
public void setScore(Double score) {
this.score = score;
}
public String getQuery() { public String getQuery() {
return query; return query;
} }
......
...@@ -5,9 +5,13 @@ import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult; ...@@ -5,9 +5,13 @@ import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult;
import java.util.List; import java.util.List;
public class SearchKnowledgeResult extends AbstractResult { public class SearchKnowledgeResult extends AbstractResult {
private List<String> documents; private List<String> documents;
private List<Double> score;
private List<String> knowledgeIds;
public List<String> getDocuments() { public List<String> getDocuments() {
return documents; return documents;
} }
...@@ -15,4 +19,20 @@ public class SearchKnowledgeResult extends AbstractResult { ...@@ -15,4 +19,20 @@ public class SearchKnowledgeResult extends AbstractResult {
public void setDocuments(List<String> documents) { public void setDocuments(List<String> documents) {
this.documents = documents; this.documents = documents;
} }
public List<Double> getScore() {
return score;
}
public void setScore(List<Double> score) {
this.score = score;
}
public List<String> getKnowledgeIds() {
return knowledgeIds;
}
public void setKnowledgeIds(List<String> knowledgeIds) {
this.knowledgeIds = knowledgeIds;
}
} }
package cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel; package cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel;
import cn.com.poc.agent_application.dto.KnowledgeContentResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Usage; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Usage;
import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult; import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import java.io.Serializable; import java.io.Serializable;
import java.util.List;
public class LargeModelDemandResult extends AbstractResult implements Serializable { public class LargeModelDemandResult extends AbstractResult implements Serializable {
...@@ -16,10 +18,20 @@ public class LargeModelDemandResult extends AbstractResult implements Serializab ...@@ -16,10 +18,20 @@ public class LargeModelDemandResult extends AbstractResult implements Serializab
private ToolFunction function; private ToolFunction function;
private List<KnowledgeContentResult> knowledgeContentResult;
private String finish_reason; private String finish_reason;
private Usage usage; private Usage usage;
public List<KnowledgeContentResult> getKnowledgeContentResult() {
return knowledgeContentResult;
}
public void setKnowledgeContentResult(List<KnowledgeContentResult> knowledgeContentResult) {
this.knowledgeContentResult = knowledgeContentResult;
}
public String getCode() { public String getCode() {
return code; return code;
} }
......
...@@ -15,6 +15,15 @@ public class QAKnowledgeChunkResult extends AbstractResult { ...@@ -15,6 +15,15 @@ public class QAKnowledgeChunkResult extends AbstractResult {
private List<Chunk> chunk; private List<Chunk> chunk;
private Integer totalChunk;
public Integer getTotalChunk() {
return totalChunk;
}
public void setTotalChunk(Integer totalChunk) {
this.totalChunk = totalChunk;
}
public List<QAChunkKey> getKey() { public List<QAChunkKey> getKey() {
return key; return key;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment