Commit 8b3a3522 authored by alex yao's avatar alex yao

feat: 应用知识库调用配置

parent e81ddd54
......@@ -18,7 +18,7 @@ public interface AgentApplicationInfoService {
* 应用预览
*/
String callAgentApplication(String largeModel, String[] unitIds, String agentSystem,
String[] knowledgeIds, Integer communicationTurn, Float topP,
Integer[] kdIds, Integer communicationTurn, Float topP,
List<Message> messages, HttpServletResponse httpServletResponse) throws Exception;
......
......@@ -11,6 +11,8 @@ import cn.com.poc.agent_application.service.BizAgentApplicationLargeModelListSer
import cn.com.poc.agent_application.service.BizAgentApplicationPublishService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.service.BizKnowledgeDocumentService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AICreateImageService;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
......@@ -51,6 +53,9 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
final private Logger logger = LoggerFactory.getLogger(AgentApplicationInfoService.class);
@Resource
private BizKnowledgeDocumentService bizKnowledgeDocumentService;
@Resource
private BizAgentApplicationLargeModelListService bizAgentApplicationLargeModelListService;
......@@ -93,12 +98,12 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
}
@Override
public String callAgentApplication(String largeModel, String[] unitIds, String agentSystem, String[] knowledgeIds, Integer communicationTurn, Float topP, List<Message> messages, HttpServletResponse httpServletResponse) throws Exception {
logger.info("--------- Call Agent Application large model:{},unitIds:{},agentSystem:{},knowledgeIds:{}" + " communicationTurn:{},topP:{},messages:{}--------------", largeModel, unitIds, agentSystem, knowledgeIds, communicationTurn, topP, messages);
public String callAgentApplication(String largeModel, String[] unitIds, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, List<Message> messages, HttpServletResponse httpServletResponse) throws Exception {
logger.info("--------- Call Agent Application large model:{},unitIds:{},agentSystem:{},kdIds:{}" + " communicationTurn:{},topP:{},messages:{}--------------", largeModel, unitIds, agentSystem, kdIds, communicationTurn, topP, messages);
String model = modelConvert(largeModel);
String promptTemplate = buildDialogsPrompt(messages, agentSystem, knowledgeIds);
String promptTemplate = buildDialogsPrompt(messages, agentSystem, kdIds);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
......@@ -332,12 +337,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return JsonUtils.deSerialize(res.substring(start, end + 1), CreateAgentTitleAndDescEntity.class);
}
private String buildDialogsPrompt(List<Message> messages, String agentSystem, String[] knowledgeIds) {
private String buildDialogsPrompt(List<Message> messages, String agentSystem, Integer[] kdIds) {
String promptTemplate = bizAgentApplicationGcConfigService.getByConfigCode(AgentApplicationGCConfigConstants.AGENT_BASE_SYSTEM).getConfigSystem();
promptTemplate = promptTemplate.replace("${agentSystem}", StringUtils.isNotBlank(agentSystem) ? agentSystem : StringUtils.EMPTY);
// 调用知识库
if (ArrayUtils.isNotEmpty(knowledgeIds)) {
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(messages.get(messages.size() - 1).getContent().get(0).getText(), Arrays.stream(knowledgeIds).collect(Collectors.toList()), 10);
if (ArrayUtils.isNotEmpty(kdIds)) {
List<String> knowledgeIds = new ArrayList<>();
for (Integer kdId : kdIds) {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = bizKnowledgeDocumentService.get(kdId);
if (null == knowledgeDocumentEntity) {
continue;
}
knowledgeIds.add(knowledgeDocumentEntity.getKnowledgeId());
}
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(messages.get(messages.size() - 1).getContent().get(0).getText(), knowledgeIds, 10);
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResults.toString());
}
return promptTemplate;
......
......@@ -12,12 +12,18 @@ import cn.com.poc.agent_application.service.BizAgentApplicationInfoService;
import cn.com.poc.agent_application.service.BizAgentApplicationLargeModelListService;
import cn.com.poc.agent_application.service.BizAgentApplicationPublishService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.common.utils.ListUtils;
import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.knowledge.entity.BizKnowledgeInfoEntity;
import cn.com.poc.knowledge.service.BizKnowledgeInfoService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.data.model.PagingInfo;
import cn.hutool.core.collection.ListUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
......@@ -47,6 +53,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
@Resource
private BizAgentApplicationLargeModelListService bizAgentApplicationLargeModelListService;
@Resource
private KnowledgeService knowledgeService;
public List<AgentApplicationInfoDto> getListByMember(AgentApplicationInfoSearchDto dto, PagingInfo pagingInfo) throws Exception {
UserBaseEntity userBaseEntity = BlContext.getCurrentUserNotException();
Long userId = userBaseEntity.getUserId();
......@@ -132,8 +141,12 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
if (infoEntity == null) {
throw new BusinessException("应用不存在");
}
//获取知识库配置
List<Integer> kdIds = knowledgeService.getKdIdsByKnowledgeInfoIds(infoEntity.getKnowledgeIds());
//调用应用服务
agentApplicationInfoService.callAgentApplication(infoEntity.getLargeModel(), infoEntity.getUnitIds()
, infoEntity.getAgentSystem(), infoEntity.getKnowledgeIds(), infoEntity.getCommunicationTurn(), infoEntity.getTopP()
, infoEntity.getAgentSystem(), kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getTopP()
, dto.getMessages(), httpServletResponse);
} catch (Exception e) {
httpServletResponse.setContentType("text/event-stream");
......
......@@ -16,6 +16,9 @@ import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.expose.aggregate.AgentApplicationService;
import cn.com.poc.expose.dto.AgentApplicationDto;
import cn.com.poc.expose.rest.AgentApplicationRest;
import cn.com.poc.knowledge.aggregate.KnowledgeService;
import cn.com.poc.knowledge.entity.BizKnowledgeInfoEntity;
import cn.com.poc.knowledge.service.BizKnowledgeInfoService;
import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.common.domain.Message;
import cn.com.poc.thirdparty.resource.demand.ai.common.domain.MultiContent;
......@@ -25,6 +28,7 @@ import cn.com.poc.thirdparty.service.LLMService;
import cn.com.yict.framemax.core.exception.BusinessException;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -40,6 +44,10 @@ import java.util.List;
public class AgentApplicationServiceImpl implements AgentApplicationService {
final private Logger logger = LoggerFactory.getLogger(AgentApplicationService.class);
@Resource
private KnowledgeService knowledgeService;
@Resource
private BizAgentApplicationPublishService bizAgentApplicationPublishService;
......@@ -73,13 +81,16 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
throw new BusinessException("未找到应用");
}
//获取知识库配置
List<Integer> kdIdList = knowledgeService.getKdIdsByKnowledgeInfoIds(infoEntity.getKnowledgeIds());
// 构造对话参数
List<Message> messages = new ArrayList<>();
buildMessages(dialogsId, agentId, userBaseEntity.getUserId(), messages, input);
String output = agentApplicationInfoService.callAgentApplication(infoEntity.getLargeModel(), infoEntity.getUnitIds()
, infoEntity.getAgentSystem(), infoEntity.getKnowledgeIds(), infoEntity.getCommunicationTurn(), infoEntity.getTopP()
, infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getTopP()
, messages, httpServletResponse);
......
......@@ -97,5 +97,11 @@ public interface KnowledgeService {
*/
void addKnowledgeChunk(Integer kdId, String content, Integer chunkSort);
/**
* 获取知识库文档ID列表
*
* @param knowledgeInfoIds 知识库信息ID列表
*/
List<Integer> getKdIdsByKnowledgeInfoIds(String[] knowledgeInfoIds) throws Exception;
}
......@@ -22,7 +22,9 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.SegmentationCon
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.data.model.PagingInfo;
import cn.hutool.core.bean.BeanUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
......@@ -258,4 +260,23 @@ public class KnowledgeServiceImpl implements KnowledgeService {
String knowledgeId = bizKnowledgeDocumentEntity.getKnowledgeId();
demandKnowledgeService.addKnowledgeChunk(knowledgeId, content, chunkSort);
}
@Override
public List<Integer> getKdIdsByKnowledgeInfoIds(String[] knowledgeInfoIds) throws Exception {
//获取知识库配置
List<Integer> kdIdList = new ArrayList<>();
if (ArrayUtils.isNotEmpty(knowledgeInfoIds)) {
String[] knowledgeIds = knowledgeInfoIds;
for (String knowledgeId : knowledgeIds) {
BizKnowledgeInfoEntity bizKnowledgeInfoEntity = bizKnowledgeInfoService.get(Integer.valueOf(knowledgeId));
if (bizKnowledgeInfoEntity == null || org.apache.commons.lang3.StringUtils.isBlank(bizKnowledgeInfoEntity.getKdIds())) {
continue;
}
List<Integer> kdIds = JsonUtils.deSerialize(bizKnowledgeInfoEntity.getKdIds(), new TypeReference<List<Integer>>() {
}.getType());
kdIdList.addAll(kdIds);
}
}
return kdIdList;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment