Commit a3fd0216 authored by alex yao's avatar alex yao

feat:Agent对话 返回知识库和插件信息

parent d60ce0d1
......@@ -4,6 +4,7 @@ import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.entity.AgentResultEntity;
import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity;
import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
......@@ -51,10 +52,14 @@ public interface AgentApplicationInfoService {
* @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输
* @param score 知识库参数score
* @param topK 知识库参数topK
* @param knowledgeSearchType 知识库参数知识搜索类型
*/
AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception;
/**
* Agent应用对话
......@@ -73,10 +78,14 @@ public interface AgentApplicationInfoService {
* @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输
* @param score 知识库参数score
* @param topK 知识库参数topK
* @param knowledgeSearchType 知识库参数知识搜索类型
*/
AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception;
/**
* 应用下架
......
......@@ -5,6 +5,7 @@ import cn.com.poc.agent_application.constant.AgentApplicationConstants;
import cn.com.poc.agent_application.constant.AgentApplicationDialoguesRecordConstants;
import cn.com.poc.agent_application.constant.AgentApplicationGCConfigConstants;
import cn.com.poc.agent_application.domain.FunctionResult;
import cn.com.poc.agent_application.dto.KnowledgeContentResult;
import cn.com.poc.agent_application.entity.*;
import cn.com.poc.agent_application.query.DialogsIdsQueryByAgentIdQueryItem;
import cn.com.poc.agent_application.service.*;
......@@ -15,6 +16,7 @@ import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.constant.KnowledgeConstant;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.service.BizKnowledgeDocumentService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AICreateImageService;
......@@ -28,6 +30,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SearchKnowledgeResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
......@@ -57,6 +60,7 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
import java.util.stream.Collectors;
import static cn.com.poc.common.constant.XLangConstant.*;
......@@ -172,7 +176,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Override
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel,
String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception {
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception {
logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}"
, agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools);
......@@ -183,26 +188,31 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls, imageUrls);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId);
List<KnowledgeContentResult> knowledgeResult = knowledge(kdIds, messages, topK, score, KnowledgeSearchTypeEnum.MIX);
String promptTemplate = buildDialogsPrompt(functionResult, agentSystem, toolArray, dialogueId, agentId, knowledgeResult);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, knowledgeResult, httpServletResponse);
}
@Override
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception {
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, HttpServletResponse httpServletResponse) throws Exception {
String model = modelConvert(largeModel);
Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(dialogueId, functionCallResult, agentId);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId);
List<KnowledgeContentResult> knowledgeResult = knowledge(kdIds, messages, topK, score, KnowledgeSearchTypeEnum.MIX);
String promptTemplate = buildDialogsPrompt(functionResult, agentSystem, toolArray, dialogueId, agentId, knowledgeResult);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, knowledgeResult, httpServletResponse);
}
@Override
......@@ -524,15 +534,23 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param model 模型
* @param messageArray 消息
* @param functionResult 函数结果
* @param knowledgeResult 知识库结果
* @param httpServletResponse 响应
* @return 输出结果
* @throws Exception
*/
private AgentResultEntity llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, HttpServletResponse httpServletResponse) throws Exception {
private AgentResultEntity llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, List<KnowledgeContentResult> knowledgeResult, HttpServletResponse httpServletResponse) throws Exception {
if (stream) {
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter();
if (CollectionUtils.isNotEmpty(knowledgeResult)) {
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
result.setKnowledgeContentResult(knowledgeResult);
writer.write(EVENT_STREAM_PREFIX + JsonUtils.serialize(result) + "\n\n");
writer.flush();
}
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
......@@ -585,14 +603,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* 构建对话提示词
*
* @param functionResult 函数结果
* @param messages 对话消息
* @param agentSystem 应用角色指令
* @param kdIds 知识库id
* @param tools 组件
* @param dialogueId 对话标识符
* @param knowledgeContentResults 知识库结果
* @return
*/
private String buildDialogsPrompt(FunctionResult functionResult, List<Message> messages, String agentSystem, Integer[] kdIds, Tool[] tools, String dialogueId, String agentId) {
private String buildDialogsPrompt(FunctionResult functionResult, String agentSystem, Tool[] tools, String dialogueId, String agentId, List<KnowledgeContentResult> knowledgeContentResults) throws IOException {
String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem();
Locale currentLocale = Context.get().getMessageSource().getCurrentLocale();
// 系统语言
......@@ -600,7 +617,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
// 应用角色指令
promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY);
// 调用知识库
promptTemplate = buildKnowledgePrompt(messages, kdIds, promptTemplate, null);
promptTemplate = buildKnowledgePrompt(knowledgeContentResults, promptTemplate);
// 记忆
promptTemplate = buildMemoryPrompt(promptTemplate, tools, dialogueId, agentId);
// 函数调用
......@@ -611,38 +628,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
/**
* 构建知识库提示词
*
* @param messages
* @param kdIds
* @param promptTemplate
* @param knowledgeContentResults
* @return
*/
private String buildKnowledgePrompt(List<Message> messages, Integer[] kdIds, String promptTemplate, KnowledgeSearchTypeEnum searchTypeEnum) {
if (ArrayUtils.isNotEmpty(kdIds)) {
List<String> knowledgeIds = new ArrayList<>();
for (Integer kdId : kdIds) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId);
// 筛选训练完成的文档,否则跳过该文档
if (null == knowledgeDocumentEntity && KnowledgeConstant.TrainStatus.COMPLETE.equals(knowledgeDocumentEntity.getTrainStatus())) {
continue;
}
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
}
Object content = messages.get(messages.size() - 1).getContent();
String query = "";
if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString();
} else {
query = content.toString();
}
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(query, knowledgeIds, 3, searchTypeEnum);
StringBuilder knowledgeResultsBuilder = new StringBuilder();
if (CollectionUtils.isNotEmpty(knowledgeResults)) {
for (int i = 1; i <= knowledgeResults.size(); i++) {
knowledgeResultsBuilder.append("### Chunk ").append(i).append(":").append(StringUtils.LF).append(knowledgeResults.get(i - 1)).append(StringUtils.LF);
}
private String buildKnowledgePrompt(List<KnowledgeContentResult> knowledgeContentResults, String promptTemplate) {
StringBuilder knowledgePromptBuilder = new StringBuilder("");
if (CollectionUtils.isNotEmpty(knowledgeContentResults)) {
for (int i = 1; i <= knowledgeContentResults.size(); i++) {
knowledgePromptBuilder.append("<Chunk ").append(i).append(">").append(":<")
.append(StringUtils.LF)
.append(knowledgeContentResults.get(i - 1).getContent())
.append(">")
.append(StringUtils.LF);
}
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResultsBuilder.toString());
}
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgePromptBuilder.toString());
return promptTemplate;
}
......@@ -756,7 +756,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param httpServletResponse
* @throws IOException
*/
private AgentResultEntity textOutput(HttpServletResponse httpServletResponse, LargeModelDemandResult largeModelDemandResult) throws
private AgentResultEntity textOutput(HttpServletResponse httpServletResponse, LargeModelDemandResult
largeModelDemandResult) throws
IOException {
PrintWriter writer = httpServletResponse.getWriter();
writer.write(JsonUtils.serialize(largeModelDemandResult));
......@@ -776,8 +777,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param bufferedReader
* @throws IOException
*/
private AgentResultEntity textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws
IOException {
private AgentResultEntity textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws IOException {
String res = "";
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter();
......@@ -834,6 +834,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
break;
}
}
toolFunction.setResult(functionResult.getFunctionResult());
return toolFunction;
}
......@@ -899,7 +900,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param fileUrls
* @param imageUrls
*/
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls, List<String> imageUrls) {
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String
agentId, List<String> fileUrls, List<String> imageUrls) {
FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) {
return result;
......@@ -953,6 +955,52 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return result;
}
private List<KnowledgeContentResult> knowledge(Integer[] kdIds, List<Message> messages, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum) {
List<KnowledgeContentResult> knowledgeContentResults = new ArrayList<>();
if (ArrayUtils.isEmpty(kdIds)) {
return knowledgeContentResults;
}
List<String> knowledgeIds = new ArrayList<>();
for (Integer kdId : kdIds) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId);
// 筛选训练完成的文档,否则跳过该文档
if (null == knowledgeDocumentEntity && KnowledgeConstant.TrainStatus.COMPLETE.equals(knowledgeDocumentEntity.getTrainStatus())) {
continue;
}
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
}
Object content = messages.get(messages.size() - 1).getContent();
String query = "";
if (content instanceof List) {
query = ((List<HashMap>) content).get(0).get("text").toString();
} else {
query = content.toString();
}
SearchKnowledgeResult searchKnowledgeResult = demandKnowledgeService.searchKnowledge(query, knowledgeIds, topK, score, searchTypeEnum);
if (CollectionUtils.isNotEmpty(searchKnowledgeResult.getDocuments())) {
for (int i = 1; i <= searchKnowledgeResult.getDocuments().size(); i++) {
KnowledgeContentResult knowledgeContentResult = new KnowledgeContentResult();
knowledgeContentResult.setContent(searchKnowledgeResult.getDocuments().get(i - 1));
knowledgeContentResult.setKnowledgeId(searchKnowledgeResult.getKnowledgeIds().get(i - 1));
knowledgeContentResult.setScore(searchKnowledgeResult.getScore().get(i - 1));
knowledgeContentResults.add(knowledgeContentResult);
}
// 根据knowledgeId获取知识库名和文档名
knowledgeIds = knowledgeContentResults.stream().map(KnowledgeContentResult::getKnowledgeId).distinct().collect(Collectors.toList());
List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQueryItems = bizKnowledgeDocumentService.knowledgeDocumentRelationQuery(knowledgeIds, null);
for (KnowledgeContentResult result : knowledgeContentResults) {
String knowledgeId = result.getKnowledgeId();
KnowledgeDocumentRelationQueryItem item = knowledgeDocumentRelationQueryItems.stream().filter(v -> v.getKnowledgeId().equals(knowledgeId)).findFirst().get();
result.setKnowledgeName(item.getKnowledgeName());
result.setKdId(item.getKdId());
result.setDocumentName(item.getDocumentName());
}
}
return knowledgeContentResults;
}
/**
* 更新【记忆变量】结构
......
package cn.com.poc.agent_application.dto;
/**
* @author alex.yao
* @date 2025/2/28
*/
public class KnowledgeContentResult {
private String content;
private String knowledgeId;
private String knowledgeName;
private Integer kdId;
private String documentName;
private Double score;
public Integer getKdId() {
return kdId;
}
public void setKdId(Integer kdId) {
this.kdId = kdId;
}
public String getDocumentName() {
return documentName;
}
public void setDocumentName(String documentName) {
this.documentName = documentName;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public String getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(String knowledgeId) {
this.knowledgeId = knowledgeId;
}
public String getKnowledgeName() {
return knowledgeName;
}
public void setKnowledgeName(String knowledgeName) {
this.knowledgeName = knowledgeName;
}
public Double getScore() {
return score;
}
public void setScore(Double score) {
this.score = score;
}
}
......@@ -24,6 +24,7 @@ import cn.com.poc.equity.entity.BizPointDeductionRulesEntity;
import cn.com.poc.equity.service.BizPointDeductionRulesService;
import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.AgentLongMemoryEntity;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.LongMemory;
......@@ -35,6 +36,8 @@ import cn.hutool.core.collection.ListUtil;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
......@@ -48,6 +51,8 @@ import java.util.stream.Collectors;
@Component
public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
private Logger logger = LoggerFactory.getLogger(AgentApplicationInfoRest.class);
@Resource
private BizAgentApplicationInfoService bizAgentApplicationInfoService;
......@@ -256,11 +261,11 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP();
Float temperature = dto.getTemperature() == null ? infoEntity.getTemperature() : dto.getTemperature();
String agentSystem = StringUtils.isBlank(dto.getAgentSystem()) ? infoEntity.getAgentSystem() : dto.getAgentSystem();
Integer communicationTurn = dto.getCommunicationTurn() == null ? infoEntity.getCommunicationTurn(): dto.getCommunicationTurn();
Integer communicationTurn = dto.getCommunicationTurn() == null ? infoEntity.getCommunicationTurn() : dto.getCommunicationTurn();
// 判断是否调用function
//计算扣分数
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools,fileUrls,imageUrls);
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools, fileUrls, imageUrls);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, communicationTurn, checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId);
......@@ -269,7 +274,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务
agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), communicationTurn, topP,
temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls, httpServletResponse);
temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
//数据采集
if (StringUtils.isBlank(dto.getChannel())) {
dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel());
......@@ -281,7 +288,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
writer.write("data: {\"code\":-1,\"message\":\"" + e.getLocalizedMessage() + "\"} \n\n");
writer.write("data: [DONE]\n\n");
writer.flush();
writer.close();
memberEquityService.rollbackPoint(reduceSn);
logger.error("preview error", e);
}
}
......
......@@ -18,6 +18,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo;
import cn.com.poc.expose.aggregate.AgentApplicationApiService;
import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.constants.OauthConstants;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
......@@ -164,7 +165,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try {
AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls, httpServletResponse);
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, agentResultEntity.getMessage());
} catch (Exception e) {
memberEquityService.rollbackPoint(reduceSn);
......
......@@ -22,6 +22,7 @@ import cn.com.poc.equity.domain.modifyEquityInfo.AgentUseModifyEventInfo;
import cn.com.poc.expose.aggregate.AgentApplicationService;
import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
......@@ -182,7 +183,9 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//对话
AgentResultEntity agentResultEntity = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls, httpServletResponse);
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
httpServletResponse);
//保存对话记录
outputRecord.setContent(agentResultEntity.getMessage());
......
select
bkd.kd_id,
bki.knowledge_name,
bkd.document_name,
bkd.knowledge_id
from
biz_knowledge_document bkd
left join biz_knowledge_info bki on
JSON_CONTAINS(bki.kd_ids , json_array(bkd.kd_id))
where
bkd.is_deleted = 'N'
<<and bkd.knowledge_id in(:knowledgeIds)>>
\ No newline at end of file
package cn.com.poc.knowledge.query;
import java.io.Serializable;
import java.util.List;
/**
* Query Condition class for KnowledgeDocumentRelationQuery
*/
public class KnowledgeDocumentRelationQueryCondition implements Serializable {
private static final long serialVersionUID = 1L;
private List<String> knowledgeIds;
public List<String> getKnowledgeIds() {
return knowledgeIds;
}
public void setKnowledgeIds(List<String> knowledgeIds) {
this.knowledgeIds = knowledgeIds;
}
}
\ No newline at end of file
package cn.com.poc.knowledge.query;
import cn.com.yict.framemax.data.model.BaseItemClass;
import javax.persistence.Column;
import javax.persistence.Entity;
import java.io.Serializable;
/**
* Query Item class for KnowledgeDocumentRelationQuery
*/
@Entity
public class KnowledgeDocumentRelationQueryItem extends BaseItemClass implements Serializable {
private static final long serialVersionUID = 1L;
/**
* kd_id
* kd_id
*/
private Integer kdId;
@Column(name = "kd_id")
public Integer getKdId() {
return this.kdId;
}
public void setKdId(Integer kdId) {
this.kdId = kdId;
}
/**
* knowledge_id
*/
private String knowledgeId;
@Column(name = "knowledge_id")
public String getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(String knowledgeId) {
this.knowledgeId = knowledgeId;
}
/**
* knowledge_name
* knowledge_name
*/
private String knowledgeName;
@Column(name = "knowledge_name")
public String getKnowledgeName() {
return this.knowledgeName;
}
public void setKnowledgeName(String knowledgeName) {
this.knowledgeName = knowledgeName;
}
/**
* document_name
* document_name
*/
private String documentName;
@Column(name = "document_name")
public String getDocumentName() {
return this.documentName;
}
public void setDocumentName(String documentName) {
this.documentName = documentName;
}
}
\ No newline at end of file
package cn.com.poc.knowledge.service;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.query.KnowledgeQueryItem;
import cn.com.yict.framemax.core.service.BaseService;
import cn.com.yict.framemax.data.model.PagingInfo;
......@@ -20,4 +21,6 @@ public interface BizKnowledgeDocumentService extends BaseService {
Boolean deleted(Integer kdId);
List<KnowledgeQueryItem> searchKnowledge(String document, String trainStatus, String memberId, String documentName, List<Integer> kdIds, PagingInfo pagingInfo);
List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQuery(List<String> knowledgeIds, PagingInfo pagingInfo);
}
\ No newline at end of file
......@@ -6,6 +6,8 @@ import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.convert.KnowledgeDocumentConvert;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.model.BizKnowledgeDocumentModel;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryCondition;
import cn.com.poc.knowledge.query.KnowledgeDocumentRelationQueryItem;
import cn.com.poc.knowledge.query.KnowledgeQueryCondition;
import cn.com.poc.knowledge.query.KnowledgeQueryItem;
import cn.com.poc.knowledge.repository.BizKnowledgeDocumentRepository;
......@@ -95,4 +97,11 @@ public class BizKnowledgeDocumentServiceImpl extends BaseServiceImpl
condition.setKdIds(kdIds);
return this.sqlDao.query(condition, KnowledgeQueryItem.class, pagingInfo);
}
@Override
public List<KnowledgeDocumentRelationQueryItem> knowledgeDocumentRelationQuery(List<String> knowledgeIds, PagingInfo pagingInfo) {
KnowledgeDocumentRelationQueryCondition condition = new KnowledgeDocumentRelationQueryCondition();
condition.setKnowledgeIds(knowledgeIds);
return this.sqlDao.query(condition, KnowledgeDocumentRelationQueryItem.class, pagingInfo);
}
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ package cn.com.poc.thirdparty.resource.demand.ai.aggregate;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.GetKnowledgeChunkInfoResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SearchKnowledgeResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SegmentationConfigRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.qaknowledge.QAKnowledgeConfig;
import cn.com.yict.framemax.data.model.PagingInfo;
......@@ -50,10 +51,11 @@ public interface DemandKnowledgeService {
* @param query 查询文本
* @param knowledgeIds 知识库id
* @param topK 返回个数
* @param score 分数阈值
* @param searchTypeEnum 查询类型
* @return 查询结果
*/
List<String> searchKnowledge(String query, List<String> knowledgeIds, Integer topK, KnowledgeSearchTypeEnum searchTypeEnum);
SearchKnowledgeResult searchKnowledge(String query, List<String> knowledgeIds, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum);
/**
* 获取知识库分片
......
......@@ -15,6 +15,8 @@ import cn.hutool.core.lang.Assert;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.http.Header;
import org.apache.http.message.BasicHeader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
......@@ -24,6 +26,8 @@ import java.util.List;
@Service
public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
private Logger logger = LoggerFactory.getLogger(DemandKnowledgeService.class);
@Resource
private DgtoolsAbstractHttpClient dgToolsAbstractHttpClient;
......@@ -44,7 +48,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
}
@Override
public String trainKnowledgeEvent(String fileURL, String knowledgeType, SegmentationConfigRequest segmentationConfig, List<QAKnowledgeConfig> qaKnowledgeConfigs){
public String trainKnowledgeEvent(String fileURL, String knowledgeType, SegmentationConfigRequest segmentationConfig, List<QAKnowledgeConfig> qaKnowledgeConfigs) {
Assert.notBlank(fileURL);
TrainKnowledgeRequest request = new TrainKnowledgeRequest();
request.setDocumentUrl(fileURL);
......@@ -82,10 +86,14 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
}
@Override
public List<String> searchKnowledge(String query, List<String> knowledgeIds, Integer topK, KnowledgeSearchTypeEnum searchTypeEnum) {
public SearchKnowledgeResult searchKnowledge(String query, List<String> knowledgeIds, Integer topK, Double score, KnowledgeSearchTypeEnum searchTypeEnum) {
Assert.notBlank(query);
SearchKnowledgeResult result = new SearchKnowledgeResult();
result.setDocuments(new ArrayList<>());
result.setScore(new ArrayList<>());
result.setKnowledgeIds(new ArrayList<>());
if (CollectionUtils.isEmpty(knowledgeIds)) {
return new ArrayList<String>();
return new SearchKnowledgeResult();
}
if (topK == null) {
topK = 3;
......@@ -94,18 +102,16 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
searchKnowledgeRequest.setQuery(query);
searchKnowledgeRequest.setKnowLedgeIds(knowledgeIds);
searchKnowledgeRequest.setTopK(topK);
searchKnowledgeRequest.setScore(score);
VectorSearchConfig vectorSearchConfig = new VectorSearchConfig();
vectorSearchConfig.setSearchType(searchTypeEnum);
vectorSearchConfig.setApiKey(Config.get("large-model.apikey"));
searchKnowledgeRequest.setVectorSearchConfig(vectorSearchConfig);
SearchKnowledgeResult searchKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders());
if (null == searchKnowledgeResult) {
throw new I18nMessageException("exception/query.knowledge.base.exception");
}
if (CollectionUtils.isEmpty(searchKnowledgeResult.getDocuments())) {
return new ArrayList<String>();
result = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders());
if (null == result) {
logger.warn("query knowledge base exception");
}
return searchKnowledgeResult.getDocuments();
return result;
}
@Override
......
......@@ -5,6 +5,16 @@ public class ToolFunction {
private String arguments;
private String result;
public String getResult() {
return result;
}
public void setResult(String result) {
this.result = result;
}
public String getName() {
return name;
}
......
......@@ -14,6 +14,16 @@ public class SearchKnowledgeRequest extends AbstractRequest<SearchKnowledgeResul
private VectorSearchConfig vectorSearchConfig;
private Double score;
public Double getScore() {
return score;
}
public void setScore(Double score) {
this.score = score;
}
public String getQuery() {
return query;
}
......
......@@ -5,9 +5,13 @@ import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult;
import java.util.List;
public class SearchKnowledgeResult extends AbstractResult {
private List<String> documents;
private List<Double> score;
private List<String> knowledgeIds;
public List<String> getDocuments() {
return documents;
}
......@@ -15,4 +19,20 @@ public class SearchKnowledgeResult extends AbstractResult {
public void setDocuments(List<String> documents) {
this.documents = documents;
}
public List<Double> getScore() {
return score;
}
public void setScore(List<Double> score) {
this.score = score;
}
public List<String> getKnowledgeIds() {
return knowledgeIds;
}
public void setKnowledgeIds(List<String> knowledgeIds) {
this.knowledgeIds = knowledgeIds;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel;
import cn.com.poc.agent_application.dto.KnowledgeContentResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Usage;
import cn.com.poc.thirdparty.resource.demand.dgTools.result.AbstractResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import java.io.Serializable;
import java.util.List;
public class LargeModelDemandResult extends AbstractResult implements Serializable {
......@@ -16,10 +18,20 @@ public class LargeModelDemandResult extends AbstractResult implements Serializab
private ToolFunction function;
private List<KnowledgeContentResult> knowledgeContentResult;
private String finish_reason;
private Usage usage;
public List<KnowledgeContentResult> getKnowledgeContentResult() {
return knowledgeContentResult;
}
public void setKnowledgeContentResult(List<KnowledgeContentResult> knowledgeContentResult) {
this.knowledgeContentResult = knowledgeContentResult;
}
public String getCode() {
return code;
}
......
......@@ -15,6 +15,15 @@ public class QAKnowledgeChunkResult extends AbstractResult {
private List<Chunk> chunk;
private Integer totalChunk;
public Integer getTotalChunk() {
return totalChunk;
}
public void setTotalChunk(Integer totalChunk) {
this.totalChunk = totalChunk;
}
public List<QAChunkKey> getKey() {
return key;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment