Commit cbfb29ad authored by Roger Wu's avatar Roger Wu

Merge branch 'release' of ssh://gitlab.gsstcloud.com:10022/poc/poc-api into release

parents 43aaf3a4 a13fec29
......@@ -151,7 +151,18 @@
<version>4.12</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>3.7.7</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>3.7.7</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-test</artifactId>
......@@ -375,6 +386,13 @@
<version>0.10.328</version>
</dependency>
<dependency>
<groupId>io.github.furstenheim</groupId>
<artifactId>copy_down</artifactId>
<version>1.0</version>
</dependency>
</dependencies>
......
......@@ -12,12 +12,16 @@ public interface AgentApplicationInfoService {
/**
* 创建/更新应用
*
* @param entity 应用信息
*/
BizAgentApplicationInfoEntity saveOrUpdate(BizAgentApplicationInfoEntity entity) throws Exception;
/**
* 更新并发布应用
*
* @param entity 应用信息
*/
boolean updateAndPublish(BizAgentApplicationInfoEntity entity) throws Exception;
......@@ -29,16 +33,28 @@ public interface AgentApplicationInfoService {
boolean deletedAgentApplication(String agentId) throws Exception;
/**
* 应用预览
* Agent应用对话
*
* @param agentId 应用ID
* @param identifier 对话唯一标识
* @param largeModel 模型
* @param agentSystem 应用角色指令
* @param knowledgeIds 知识库ID
* @param communicationTurn 对话轮数
* @param topP 模型参数topP
* @param temperature 模型参数temperature
* @param messages 对话消息
* @param tools 插件配置
* @param fileUrls 文件URL
*/
String callAgentApplication(String identifier, String largeModel, String[] unitIds, String agentSystem,
String callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, HttpServletResponse httpServletResponse) throws Exception;
List<Message> messages, List<Tool> tools, List<String> fileUrls, HttpServletResponse httpServletResponse) throws Exception;
/**
* 应用下架
*
* @param agentId
* @param agentId 应用ID
* @return
*/
boolean unPublish(String agentId) throws Exception;
......@@ -46,7 +62,7 @@ public interface AgentApplicationInfoService {
/**
* 角色指令AI生成
*
* @param input
* @param input 用户输入内容
* @param httpServletResponse
* @return
*/
......
package cn.com.poc.agent_application.convert;
import cn.com.poc.agent_application.domain.AgentApplicationBaseInfo;
import cn.com.poc.agent_application.domain.AgentApplicationCommConfig;
import cn.com.poc.agent_application.domain.AgentApplicationCommModelConfig;
import cn.com.poc.agent_application.domain.AgentApplicationKnowledgeConfig;
import cn.com.poc.agent_application.domain.*;
import cn.com.poc.agent_application.dto.AgentApplicationInfoDto;
import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.agent_application.entity.VoiceConfig;
import cn.com.poc.agent_application.model.BizAgentApplicationInfoModel;
import cn.com.poc.agent_application.query.AgentApplicationInfoQueryItem;
import cn.com.poc.agent_application.query.PublishAgentApplicationQueryItem;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.expose.dto.SearchAgentApplicationDto;
import cn.hutool.core.util.ObjectUtil;
......@@ -65,12 +64,17 @@ public class AgentApplicationInfoConvert {
if (StringUtils.isNotBlank(model.getKnowledgeIds())) {
entity.setKnowledgeIds(JsonUtils.deSerialize(model.getKnowledgeIds(), Integer[].class));
}
entity.setIsDocumentParsing(model.getIsDocumentParsing());
entity.setLargeModel(model.getLargeModel());
entity.setTopP(model.getTopP());
entity.setTemperature(model.getTemperature());
if (StringUtils.isNotBlank(model.getUnitIds())) {
entity.setUnitIds(JsonUtils.deSerialize(model.getUnitIds(), String[].class));
}
if (StringUtils.isNotBlank(model.getVoiceConfig())) {
VoiceConfig voiceConfig = JsonUtils.deSerialize(model.getVoiceConfig(), VoiceConfig.class);
entity.setVoiceConfig(voiceConfig);
}
entity.setIsDeleted(model.getIsDeleted());
entity.setCreator(model.getCreator());
entity.setCreatedTime(model.getCreatedTime());
......@@ -93,6 +97,7 @@ public class AgentApplicationInfoConvert {
model.setPreamble(entity.getPreamble());
model.setPublishTime(entity.getPublishTime());
model.setIsLongMemory(entity.getIsLongMemory());
model.setIsDocumentParsing(entity.getIsDocumentParsing());
if (CollectionUtils.isNotEmpty(entity.getVariableStructure())) {
model.setVariableStructure(JsonUtils.serialize(entity.getVariableStructure()));
}
......@@ -112,6 +117,9 @@ public class AgentApplicationInfoConvert {
if (ArrayUtils.isNotEmpty(entity.getUnitIds())) {
model.setUnitIds(JsonUtil.toJson(entity.getUnitIds()));
}
if (ObjectUtil.isNotEmpty(entity.getVoiceConfig())) {
model.setVoiceConfig(JsonUtils.serialize(entity.getVoiceConfig()));
}
model.setIsDeleted(entity.getIsDeleted());
model.setCreator(entity.getCreator());
model.setCreatedTime(entity.getCreatedTime());
......@@ -144,6 +152,7 @@ public class AgentApplicationInfoConvert {
AgentApplicationKnowledgeConfig knowledgeConfig = new AgentApplicationKnowledgeConfig();
knowledgeConfig.setKnowledgeIds(entity.getKnowledgeIds());
knowledgeConfig.setIsDocumentParsing(entity.getIsDocumentParsing());
AgentApplicationCommModelConfig commModelConfig = new AgentApplicationCommModelConfig();
commModelConfig.setLargeModel(entity.getLargeModel());
......@@ -151,11 +160,21 @@ public class AgentApplicationInfoConvert {
commModelConfig.setTemperature(entity.getTemperature());
commModelConfig.setCommunicationTurn(entity.getCommunicationTurn());
AgentApplicationVoiceConfig voiceConfig = new AgentApplicationVoiceConfig();
if (ObjectUtil.isNotEmpty(entity.getVoiceConfig())) {
voiceConfig.setDefaultOpen(entity.getVoiceConfig().getDefaultOpen());
voiceConfig.setTimbreId(entity.getVoiceConfig().getTimbreId());
} else {
voiceConfig.setDefaultOpen(CommonConstant.YOrN.N);
voiceConfig.setTimbreId(StringUtils.EMPTY);
}
AgentApplicationInfoDto dto = new AgentApplicationInfoDto();
dto.setBaseInfo(baseInfo);
dto.setCommConfig(commConfig);
dto.setKnowledgeConfig(knowledgeConfig);
dto.setCommModelConfig(commModelConfig);
dto.setVoiceConfig(voiceConfig);
dto.setUnitIds(entity.getUnitIds());
dto.setCreator(entity.getCreator());
dto.setCreatedTime(entity.getCreatedTime());
......@@ -191,6 +210,7 @@ public class AgentApplicationInfoConvert {
if (ObjectUtil.isNotEmpty(dto.getKnowledgeConfig())) {
entity.setKnowledgeIds(dto.getKnowledgeConfig().getKnowledgeIds());
entity.setIsDocumentParsing(dto.getKnowledgeConfig().getIsDocumentParsing());
}
if (ObjectUtil.isNotEmpty(dto.getCommModelConfig())) {
......@@ -200,6 +220,16 @@ public class AgentApplicationInfoConvert {
entity.setCommunicationTurn(dto.getCommModelConfig().getCommunicationTurn());
}
VoiceConfig voiceConfig = new VoiceConfig();
if (ObjectUtil.isNotEmpty(dto.getVoiceConfig())) {
voiceConfig.setDefaultOpen(dto.getVoiceConfig().getDefaultOpen());
voiceConfig.setTimbreId(dto.getVoiceConfig().getTimbreId());
} else {
voiceConfig.setDefaultOpen(CommonConstant.YOrN.N);
voiceConfig.setTimbreId(StringUtils.EMPTY);
}
entity.setVoiceConfig(voiceConfig);
entity.setUnitIds(dto.getUnitIds());
entity.setCreator(dto.getCreator());
entity.setCreatedTime(dto.getCreatedTime());
......@@ -221,6 +251,7 @@ public class AgentApplicationInfoConvert {
entity.setPreamble(infoQueryItem.getPreamble());
entity.setPublishTime(infoQueryItem.getPublishTime());
entity.setIsLongMemory(infoQueryItem.getIsLongMemory());
entity.setIsDocumentParsing(infoQueryItem.getIsDocumentParsing());
if (StringUtils.isNotBlank(infoQueryItem.getVariableStructure())) {
entity.setVariableStructure(JsonUtils.deSerialize(infoQueryItem.getVariableStructure(), new TypeReference<List<Variable>>() {
......@@ -243,6 +274,9 @@ public class AgentApplicationInfoConvert {
if (StringUtils.isNotBlank(infoQueryItem.getUnitIds())) {
entity.setUnitIds(JsonUtils.deSerialize(infoQueryItem.getUnitIds(), String[].class));
}
if (StringUtils.isNotBlank(infoQueryItem.getVoiceConfig())) {
entity.setVoiceConfig(JsonUtils.deSerialize(infoQueryItem.getVoiceConfig(), VoiceConfig.class));
}
entity.setIsDeleted(infoQueryItem.getIsDeleted());
entity.setCreator(infoQueryItem.getCreator());
entity.setCreatedTime(infoQueryItem.getCreatedTime());
......
package cn.com.poc.agent_application.convert;
import cn.com.poc.agent_application.domain.AgentApplicationBaseInfo;
import cn.com.poc.agent_application.domain.AgentApplicationCommConfig;
import cn.com.poc.agent_application.domain.AgentApplicationCommModelConfig;
import cn.com.poc.agent_application.domain.AgentApplicationKnowledgeConfig;
import cn.com.poc.agent_application.domain.*;
import cn.com.poc.agent_application.dto.BizAgentApplicationPublishDto;
import cn.com.poc.agent_application.entity.BizAgentApplicationPublishEntity;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.agent_application.entity.VoiceConfig;
import cn.com.poc.agent_application.model.BizAgentApplicationPublishModel;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.JsonUtils;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.core.type.TypeReference;
......@@ -35,7 +34,8 @@ public class BizAgentApplicationPublishConvert {
entity.setAgentSystem(model.getAgentSystem());
entity.setPreamble(model.getPreamble());
entity.setIsLongMemory(model.getIsLongMemory());
entity.setPublishTime(model.getCreatedTime());
entity.setPublishTime(model.getModifiedTime());
entity.setIsDocumentParsing(model.getIsDocumentParsing());
if (StringUtils.isNotBlank(model.getVariableStructure())) {
entity.setVariableStructure(JsonUtils.deSerialize(model.getVariableStructure(), new TypeReference<List<Variable>>() {
}.getType()));
......@@ -57,12 +57,15 @@ public class BizAgentApplicationPublishConvert {
if (StringUtils.isNotBlank(model.getUnitIds())) {
entity.setUnitIds(JsonUtils.deSerialize(model.getUnitIds(), String[].class));
}
if (StringUtils.isNotBlank(model.getVoiceConfig())) {
VoiceConfig voiceConfig = JsonUtils.deSerialize(model.getVoiceConfig(), VoiceConfig.class);
entity.setVoiceConfig(voiceConfig);
}
entity.setIsDeleted(model.getIsDeleted());
entity.setCreator(model.getCreator());
entity.setCreatedTime(model.getCreatedTime());
entity.setModifier(model.getModifier());
entity.setModifiedTime(model.getModifiedTime());
entity.setSysVersion(model.getSysVersion());
return entity;
}
......@@ -77,6 +80,7 @@ public class BizAgentApplicationPublishConvert {
model.setAgentSystem(entity.getAgentSystem());
model.setPreamble(entity.getPreamble());
model.setIsLongMemory(entity.getIsLongMemory());
model.setIsDocumentParsing(entity.getIsDocumentParsing());
if (CollectionUtils.isNotEmpty(entity.getVariableStructure())) {
model.setVariableStructure(JsonUtils.serialize(entity.getVariableStructure()));
}
......@@ -96,12 +100,15 @@ public class BizAgentApplicationPublishConvert {
if (ArrayUtils.isNotEmpty(entity.getUnitIds())) {
model.setUnitIds(JsonUtil.toJson(entity.getUnitIds()));
}
if (ObjectUtil.isNotEmpty(entity.getVoiceConfig())) {
model.setVoiceConfig(JsonUtils.serialize(entity.getVoiceConfig()));
}
model.setIsDeleted(entity.getIsDeleted());
model.setCreator(entity.getCreator());
model.setCreatedTime(entity.getCreatedTime());
model.setModifier(entity.getModifier());
model.setModifiedTime(entity.getModifiedTime());
model.setSysVersion(entity.getSysVersion());
return model;
}
......@@ -128,6 +135,7 @@ public class BizAgentApplicationPublishConvert {
AgentApplicationKnowledgeConfig knowledgeConfig = new AgentApplicationKnowledgeConfig();
knowledgeConfig.setKnowledgeIds(entity.getKnowledgeIds());
knowledgeConfig.setIsDocumentParsing(entity.getIsDocumentParsing());
AgentApplicationCommModelConfig commModelConfig = new AgentApplicationCommModelConfig();
commModelConfig.setLargeModel(entity.getLargeModel());
......@@ -135,10 +143,21 @@ public class BizAgentApplicationPublishConvert {
commModelConfig.setTemperature(entity.getTemperature());
commModelConfig.setCommunicationTurn(entity.getCommunicationTurn());
AgentApplicationVoiceConfig voiceConfig = new AgentApplicationVoiceConfig();
if (ObjectUtil.isNotEmpty(entity.getVoiceConfig())) {
voiceConfig.setDefaultOpen(entity.getVoiceConfig().getDefaultOpen());
voiceConfig.setTimbreId(entity.getVoiceConfig().getTimbreId());
} else {
voiceConfig.setDefaultOpen(CommonConstant.YOrN.N);
voiceConfig.setTimbreId(StringUtils.EMPTY);
}
dto.setBaseInfo(baseInfo);
dto.setCommConfig(commConfig);
dto.setKnowledgeConfig(knowledgeConfig);
dto.setCommModelConfig(commModelConfig);
dto.setVoiceConfig(voiceConfig);
dto.setUnitIds(entity.getUnitIds());
dto.setCreator(entity.getCreator());
dto.setCreatedTime(entity.getCreatedTime());
......@@ -175,6 +194,7 @@ public class BizAgentApplicationPublishConvert {
if (ObjectUtil.isNotEmpty(dto.getKnowledgeConfig())) {
entity.setKnowledgeIds(dto.getKnowledgeConfig().getKnowledgeIds());
entity.setIsDocumentParsing(dto.getKnowledgeConfig().getIsDocumentParsing());
}
if (ObjectUtil.isNotEmpty(dto.getCommModelConfig())) {
......@@ -183,6 +203,14 @@ public class BizAgentApplicationPublishConvert {
entity.setTemperature(dto.getCommModelConfig().getTemperature());
}
VoiceConfig voiceConfig = new VoiceConfig();
if (ObjectUtil.isNotEmpty(dto.getVoiceConfig())) {
voiceConfig.setDefaultOpen(dto.getVoiceConfig().getDefaultOpen());
voiceConfig.setTimbreId(dto.getVoiceConfig().getTimbreId());
} else {
voiceConfig.setDefaultOpen(CommonConstant.YOrN.N);
voiceConfig.setTimbreId(StringUtils.EMPTY);
}
entity.setUnitIds(dto.getUnitIds());
entity.setCreator(dto.getCreator());
entity.setCreatedTime(dto.getCreatedTime());
......
......@@ -18,4 +18,18 @@ public class AgentApplicationKnowledgeConfig {
public void setKnowledgeIds(java.lang.Integer[] knowledgeIds) {
this.knowledgeIds = knowledgeIds;
}
/**
* is_document_parsing
* 是否开启文档解析 1、Y 是 2、N 否
*/
private java.lang.String isDocumentParsing;
public java.lang.String getIsDocumentParsing() {
return this.isDocumentParsing;
}
public void setIsDocumentParsing(java.lang.String isDocumentParsing) {
this.isDocumentParsing = isDocumentParsing;
}
}
package cn.com.poc.agent_application.domain;
public class AgentApplicationVoiceConfig {
/**
* 是否默认开启 Y-是 N-否
*/
private String defaultOpen;
/**
* 音色ID
*/
private String timbreId;
public String getDefaultOpen() {
return defaultOpen;
}
public void setDefaultOpen(String defaultOpen) {
this.defaultOpen = defaultOpen;
}
public String getTimbreId() {
return timbreId;
}
public void setTimbreId(String timbreId) {
this.timbreId = timbreId;
}
}
package cn.com.poc.agent_application.dto;
import cn.com.poc.agent_application.domain.AgentApplicationBaseInfo;
import cn.com.poc.agent_application.domain.AgentApplicationCommConfig;
import cn.com.poc.agent_application.domain.AgentApplicationCommModelConfig;
import cn.com.poc.agent_application.domain.AgentApplicationKnowledgeConfig;
import cn.com.poc.agent_application.domain.*;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
......@@ -21,6 +18,8 @@ public class AgentApplicationInfoDto implements java.io.Serializable {
private AgentApplicationCommModelConfig commModelConfig;
private AgentApplicationVoiceConfig voiceConfig;
public AgentApplicationBaseInfo getBaseInfo() {
return baseInfo;
}
......@@ -53,6 +52,14 @@ public class AgentApplicationInfoDto implements java.io.Serializable {
this.commModelConfig = commModelConfig;
}
public AgentApplicationVoiceConfig getVoiceConfig() {
return voiceConfig;
}
public void setVoiceConfig(AgentApplicationVoiceConfig voiceConfig) {
this.voiceConfig = voiceConfig;
}
/**
* unit_ids
* 组件ID
......@@ -113,7 +120,7 @@ public class AgentApplicationInfoDto implements java.io.Serializable {
/**
* isCollect
* 当前用户是否收藏 1、Y 是 2、N 否
* */
*/
private String isCollect;
public String getIsCollect() {
......@@ -124,29 +131,31 @@ public class AgentApplicationInfoDto implements java.io.Serializable {
this.isCollect = isCollect;
}
/** agent_publish_id
*发布应用的ID
/**
* agent_publish_id
* 发布应用的ID
*/
private java.lang.Integer agentPublishId;
public java.lang.Integer getAgentPublishId(){
public java.lang.Integer getAgentPublishId() {
return this.agentPublishId;
}
public void setAgentPublishId(java.lang.Integer agentPublishId){
public void setAgentPublishId(java.lang.Integer agentPublishId) {
this.agentPublishId = agentPublishId;
}
/** is_sale
*是否上架应用 1、Y 是 2、N 否
/**
* is_sale
* 是否上架应用 1、Y 是 2、N 否
*/
private java.lang.String isSale;
public java.lang.String getIsSale(){
public java.lang.String getIsSale() {
return this.isSale;
}
public void setIsSale(java.lang.String isSale){
public void setIsSale(java.lang.String isSale) {
this.isSale = isSale;
}
......
......@@ -19,6 +19,16 @@ public class AgentApplicationPreviewDto implements Serializable {
private List<Message> messages;
private List<String> fileUrls;
public List<String> getFileUrls() {
return fileUrls;
}
public void setFileUrls(List<String> fileUrls) {
this.fileUrls = fileUrls;
}
public Float getTopP() {
return topP;
}
......
package cn.com.poc.agent_application.entity;
import cn.com.poc.agent_application.domain.AgentApplicationVoiceConfig;
import javax.persistence.Column;
import java.util.Arrays;
import java.util.List;
......@@ -315,6 +317,35 @@ public class BizAgentApplicationInfoEntity {
this.isLongMemory = isLongMemory;
}
/**
* is_document_parsing
* 是否开启文档解析 1、Y 是 2、N 否
*/
private java.lang.String isDocumentParsing;
public java.lang.String getIsDocumentParsing() {
return this.isDocumentParsing;
}
public void setIsDocumentParsing(java.lang.String isDocumentParsing) {
this.isDocumentParsing = isDocumentParsing;
}
/**
* voice_config
* 声音配置 defaultOpen- 是否默认开启 timbreId-音色
*/
private VoiceConfig voiceConfig;
public VoiceConfig getVoiceConfig() {
return voiceConfig;
}
public void setVoiceConfig(VoiceConfig voiceConfig) {
this.voiceConfig = voiceConfig;
}
/**
* is_deleted
......
......@@ -2,8 +2,14 @@ package cn.com.poc.agent_application.entity;
public class CreateAgentTitleAndDescEntity {
/**
* 应用名称
*/
private String agentTitle;
/**
* 应用描述
*/
private String agentDesc;
public String getAgentTitle() {
......
package cn.com.poc.agent_application.entity;
public class VoiceConfig {
/**
* 是否默认开启 Y-是 N-否
*/
private String defaultOpen;
/**
* 音色ID
*/
private String timbreId;
public String getDefaultOpen() {
return defaultOpen;
}
public void setDefaultOpen(String defaultOpen) {
this.defaultOpen = defaultOpen;
}
public String getTimbreId() {
return timbreId;
}
public void setTimbreId(String timbreId) {
this.timbreId = timbreId;
}
}
......@@ -21,6 +21,8 @@ select distinct
unit_ids,
variable_structure,
is_long_memory,
is_document_parsing,
voice_config,
is_deleted,
CREATOR,
CREATED_TIME,
......
......@@ -285,18 +285,18 @@ public class AgentApplicationInfoQueryItem extends BaseItemClass implements Seri
}
/** temperature
*对话模型 温度 [0-1.00]
/**
* temperature
* 对话模型 温度 [0-1.00]
*/
private java.lang.Float temperature;
@Column(name = "temperature")
public java.lang.Float getTemperature(){
public java.lang.Float getTemperature() {
return this.temperature;
}
public void setTemperature(java.lang.Float temperature){
public void setTemperature(java.lang.Float temperature) {
this.temperature = temperature;
}
......@@ -347,6 +347,36 @@ public class AgentApplicationInfoQueryItem extends BaseItemClass implements Seri
this.isLongMemory = isLongMemory;
}
/**
* is_document_parsing
* 是否开启文档解析 1、Y 是 2、N 否
*/
private java.lang.String isDocumentParsing;
@Column(name = "is_document_parsing", length = 1)
public java.lang.String getIsDocumentParsing() {
return this.isDocumentParsing;
}
public void setIsDocumentParsing(java.lang.String isDocumentParsing) {
this.isDocumentParsing = isDocumentParsing;
}
/**
* voice_config
*/
private java.lang.String voiceConfig;
@Column(name = "voice_config")
public java.lang.String getVoiceConfig() {
return this.voiceConfig;
}
public void setVoiceConfig(java.lang.String voiceConfig) {
this.voiceConfig = voiceConfig;
}
/**
* is_deleted
* is_deleted
......
......@@ -18,8 +18,6 @@ public interface BizAgentApplicationPublishRest extends BaseRest {
BizAgentApplicationPublishDto save(@RequestBody BizAgentApplicationPublishDto dto) throws Exception;
BizAgentApplicationPublishDto update(@RequestBody BizAgentApplicationPublishDto dto) throws Exception;
void deletedById(@RequestParam java.lang.Integer id) throws Exception;
}
\ No newline at end of file
......@@ -38,12 +38,6 @@ public class BizAgentApplicationPublishRestImpl implements BizAgentApplicationPu
return BizAgentApplicationPublishConvert.entityToDto(bizAgentApplicationPublishService.save(entity));
}
public BizAgentApplicationPublishDto update(BizAgentApplicationPublishDto dto) throws Exception{
Assert.notNull(dto);
BizAgentApplicationPublishEntity entity = BizAgentApplicationPublishConvert.dtoToEntity(dto);
return BizAgentApplicationPublishConvert.entityToDto(bizAgentApplicationPublishService.update(entity));
}
public void deletedById(java.lang.Integer id) throws Exception{
Assert.notNull(id);
bizAgentApplicationPublishService.deletedById(id);
......
......@@ -37,7 +37,7 @@ public class AgentApplicationMallSchedule {
* 【用于更新DB的应用热度】
* 定时规则:每分钟触发一次
*/
@Scheduled(cron = "0 * * * * ?")
@Scheduled(cron = "0 0/5 * * * ?")
public void updateAgentPopularity() throws Exception {
// 查询应用市场表
List<BizAgentApplicationMallEntity> mallEntities = bizAgentApplicationMallService.getList();
......
......@@ -12,7 +12,7 @@ public interface BizAgentApplicationPublishService extends BaseService {
BizAgentApplicationPublishEntity get(java.lang.Integer id) throws Exception;
BizAgentApplicationPublishEntity getByAgentId(String agentId) throws Exception;
BizAgentApplicationPublishEntity getByAgentId(String agentId);
List<BizAgentApplicationPublishEntity> findByExample(BizAgentApplicationPublishEntity example, PagingInfo pagingInfo) throws Exception;
......
......@@ -213,7 +213,15 @@ public class BizAgentApplicationInfoServiceImpl extends BaseServiceImpl
if (entity.getTemperature() != null) {
Assert.isTrue(entity.getTemperature() > 0 && entity.getTemperature() <= 1.0, "temperature is error,must greater than 0, less than or equal to 1.9");
}
if (entity.getVoiceConfig() != null) {
model.setVoiceConfig(JsonUtils.serialize(entity.getVoiceConfig()));
}
if (StringUtils.isNotBlank(entity.getIsDocumentParsing())) {
model.setIsDocumentParsing(entity.getIsDocumentParsing());
}
if (StringUtils.isNotBlank(entity.getIsLongMemory())) {
model.setIsLongMemory(entity.getIsLongMemory());
}
model.setTemperature(entity.getTemperature());
model.setTopP(entity.getTopP());
model.setContinuousQuestionStatus(entity.getContinuousQuestionStatus());
......
......@@ -50,7 +50,7 @@ public class BizAgentApplicationPublishServiceImpl extends BaseServiceImpl
}
@Override
public BizAgentApplicationPublishEntity getByAgentId(String agentId) throws Exception {
public BizAgentApplicationPublishEntity getByAgentId(String agentId) {
Assert.notNull(agentId);
BizAgentApplicationPublishModel model = new BizAgentApplicationPublishModel();
model.setAgentId(agentId);
......@@ -89,15 +89,17 @@ public class BizAgentApplicationPublishServiceImpl extends BaseServiceImpl
Assert.notNull(entity);
Assert.notNull(entity.getId(), "update pk can not be null");
BizAgentApplicationPublishModel model = this.repository.get(entity.getId());
model.setAgentId(entity.getAgentId());
model.setIsDeleted(CommonConstant.IsDeleted.N);
List<BizAgentApplicationPublishModel> models = this.repository.findByExample(model);
if (CollectionUtils.isEmpty(models)) {
throw new I18nMessageException("exception/data.does.not.exist");
}
model = models.get(0);
paramVerificationAndConvert(entity, model);
BizAgentApplicationPublishModel saveModel = this.repository.save(model);
if (model == null) {
throw new I18nMessageException("exception/publication.failed");
}
BizAgentApplicationPublishModel updateModel = BizAgentApplicationPublishConvert.entityToModel(entity);
updateModel.setIsDeleted(CommonConstant.IsDeleted.N);
updateModel.setCreator(model.getCreator());
updateModel.setCreatedTime(model.getCreatedTime());
updateModel.setModifier(null);
updateModel.setModifiedTime(null);
updateModel.setSysVersion(model.getSysVersion());
BizAgentApplicationPublishModel saveModel = this.repository.save(updateModel);
return BizAgentApplicationPublishConvert.modelToEntity(saveModel);
}
......@@ -186,5 +188,6 @@ public class BizAgentApplicationPublishServiceImpl extends BaseServiceImpl
if (ArrayUtils.isNotEmpty(entity.getUnitIds())) {
model.setUnitIds(JsonUtils.serialize(entity.getUnitIds()));
}
model.setIsDocumentParsing(entity.getIsDocumentParsing());
}
}
\ No newline at end of file
package cn.com.poc.agent_application.utils;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriter;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* @author alex.yao
* @description agent应用配置工具类
*/
public class AgentApplicationTools {
/**
* 构造Agent应用 函数配置
*
* @param variableStructures 记忆变量
* @param isLongMemory 是否开启长期记忆
* @param identifier 标识符
* @param agentId 应用id
* @param unitIds 插件id
* @param isDocumentParsing 是否开启文档解析
* @return
*/
public static List<Tool> buildFunctionConfig(List<Variable> variableStructures, String isLongMemory, String identifier, String agentId, String[] unitIds, String isDocumentParsing) {
List<Tool> tools = new ArrayList<>();
//开启对话变量
if (CollectionUtils.isNotEmpty(variableStructures)) {
String functionName = LargeModelFunctionEnum.memory_variable_writer.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getLLMConfig(variableStructures).get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
//初始化变量函数
Map<Object, Object> map = MemoryVariableWriter.get(identifier(identifier, agentId));
if (MapUtils.isEmpty(map)) {
List<Variable> variableStructure = variableStructures;
for (Variable variable : variableStructure) {
String key = variable.getKey();
String variableDefault = variable.getVariableDefault();
JSONObject jsonObject = new JSONObject();
jsonObject.put("key", key);
jsonObject.put("value", variableDefault);
LargeModelFunctionEnum.valueOf(functionName).getFunction().doFunction(jsonObject.toJSONString(), identifier(identifier, agentId));
}
}
}
//开启长期记忆
if (CommonConstant.YOrN.Y.equals(isLongMemory)) {
String functionName = LargeModelFunctionEnum.set_long_memory.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getLLMConfig().get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
}
//开启文档解析-文档理解
if (CommonConstant.YOrN.Y.equals(isDocumentParsing)) {
String functionName = LargeModelFunctionEnum.document_understanding.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getLLMConfig().get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
}
//初始化插件函数
if (ArrayUtils.isNotEmpty(unitIds)) {
for (String unitId : unitIds) {
LargeModelFunctionEnum modelFunctionEnum = LargeModelFunctionEnum.getFunction(unitId);
if (modelFunctionEnum == null) {
continue;
}
String llmConfig = modelFunctionEnum.getFunction().getLLMConfig().get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
}
}
return tools;
}
/**
* 创建会话唯一标识符
*
* @param dialogueId 对话ID
* @param agentId 应用ID
* @return
*/
public static String identifier(String dialogueId, String agentId) {
return dialogueId + ":" + agentId;
}
/**
* 检查对话文件内容是否为空
*
* @throws I18nMessageException 文件内容为空时抛出异常
*/
public static void checkDialogueContentIsEmpty(List<String> fileUrls) {
// 判断文件是否为空,如果不为空
if (CollectionUtils.isNotEmpty(fileUrls)) {
for (String fileUrl : fileUrls) {
File file = DocumentLoad.downloadURLDocument(fileUrl);
String documentContent = DocumentLoad.documentToText(file);
if (StringUtils.isBlank(documentContent)) {
throw new I18nMessageException("exception/file.content.empty");
}
}
}
}
}
......@@ -7,7 +7,7 @@ import java.lang.annotation.*;
/**
* 限流
*/
@Target(ElementType.METHOD)
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Component
@Documented
......@@ -77,7 +77,8 @@ public @interface RedisLimit {
*/
MONTH_OF_YEAR;
LimitTimeUnit(){}
LimitTimeUnit() {
}
}
}
......@@ -2,6 +2,7 @@ package cn.com.poc.common.utils;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.core.io.FileUtil;
import io.github.furstenheim.*;
import org.apache.pdfbox.io.RandomAccessBufferedFileInputStream;
import org.apache.pdfbox.pdfparser.PDFParser;
import org.apache.pdfbox.pdmodel.PDDocument;
......@@ -12,10 +13,48 @@ import org.apache.poi.xwpf.usermodel.XWPFDocument;
import org.springframework.util.Assert;
import java.io.*;
import java.net.URL;
import java.net.URLConnection;
import java.nio.file.Files;
public class DocumentLoad {
final static OptionsBuilder optionsBuilder = OptionsBuilder.anOptions();
final static Options options = optionsBuilder.withBr("-")
.withLinkStyle(LinkStyle.REFERENCED)
.withLinkReferenceStyle(LinkReferenceStyle.SHORTCUT)
.build();
final static CopyDown converter = new CopyDown(options);
/**
* Html To Markdown
*/
public static String htmlToMarkdown(String url) {
try {
// 创建 资源符对象 连接
URLConnection conn = new URL(url).openConnection();
// 获取输入流
InputStream inputStream = conn.getInputStream();
// 缓冲区,读取输入流内容,64KB
char[] buffer = new char[1024 * 64];
int len;
StringBuilder sb = new StringBuilder();
// 转换为字符流
InputStreamReader isr = new InputStreamReader(inputStream);
// 循环读取
while ((len = isr.read(buffer)) != -1) {
sb.append(buffer, 0, len);
}
// 关闭资源
inputStream.close();
isr.close();
String htmlStr = sb.toString();
return converter.convert(htmlStr);
} catch (IOException e) {
throw new I18nMessageException(e.getMessage());
}
}
/**
* 读取文档
*
......@@ -92,4 +131,31 @@ public class DocumentLoad {
doc.close();
return stringBuilder.toString();
}
public static File downloadURLDocument(String path) {
// 下载网络文件
int bytesum = 0;
int byteread = 0;
try {
URL url = new URL(path);
URLConnection conn = url.openConnection();
String[] split = url.getFile().split("\\.");
String suffix = split[split.length - 1];
File tempFile = File.createTempFile(UUIDTool.getUUID(), "." + suffix);
FileOutputStream fs = new FileOutputStream(tempFile);
InputStream inStream = conn.getInputStream();
byte[] buffer = new byte[1024];
while ((byteread = inStream.read(buffer)) != -1) {
bytesum += byteread;
fs.write(buffer, 0, byteread);
}
fs.close();
return tempFile;
} catch (IOException e) {
throw new I18nMessageException("exception/file.load.error");
}
}
}
package cn.com.poc.common.utils;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Path;
public class PcmToWav {
/**
* @param src 待转换文件路径
* @param target 目标文件路径
* @throws IOException 抛出异常
*/
public static Path convertAudioFiles(String src) throws IOException {
FileInputStream fis = new FileInputStream(src);
File tempAudioFile = File.createTempFile(UUIDTool.getUUID(), "wav");
FileOutputStream fos = new FileOutputStream(tempAudioFile);
//计算长度
byte[] buf = new byte[1024 * 4];
int size = fis.read(buf);
int PCMSize = 0;
while (size != -1) {
PCMSize += size;
size = fis.read(buf);
}
fis.close();
//填入参数,比特率等等。这里用的是16位单声道 8000 hz
WaveHeader header = new WaveHeader();
//长度字段 = 内容的大小(PCMSize) + 头部字段的大小(不包括前面4字节的标识符RIFF以及fileLength本身的4字节)
header.fileLength = PCMSize + (44 - 8);
header.FmtHdrLeth = 16;
header.BitsPerSample = 16;
header.Channels = 1;
header.FormatTag = 0x0001;
header.SamplesPerSec = 16000;
header.BlockAlign = (short) (header.Channels * header.BitsPerSample / 16);
header.AvgBytesPerSec = header.BlockAlign * header.SamplesPerSec;
header.DataHdrLeth = PCMSize;
byte[] h = header.getHeader();
assert h.length == 44; //WAV标准,头部应该是44字节
fos.write(h, 0, h.length);
fis = new FileInputStream(src);
size = fis.read(buf);
while (size != -1) {
fos.write(buf, 0, size);
size = fis.read(buf);
}
fis.close();
fos.close();
return tempAudioFile.toPath();
}
}
......@@ -18,6 +18,14 @@ import java.util.stream.Collectors;
**/
public class SQLUtils {
/**
* 获取批量插入sql和参数
*
* @param modelClass
* @param models
* @return
* @throws Exception
*/
public static BatchInsertResult getInsertSql(Class<? extends BaseModelClass> modelClass, List<? extends BaseModelClass> models) throws Exception {
//获取表名
Table table = modelClass.getAnnotation(Table.class);
......@@ -39,7 +47,7 @@ public class SQLUtils {
tableFieldList.add(column.name());
modelGetMethodList.add(method.getName());
}
BatchInsertResult result = new BatchInsertResult();
//构造insert sql
String fieldStr = tableFieldList.stream().collect(Collectors.joining(","));
StringBuilder insertSQL = new StringBuilder();
......@@ -52,7 +60,10 @@ public class SQLUtils {
}
insertSQL.append(")");
result.setInsertSQL(insertSQL.toString());
//构造实体
if (models != null) {
List<Object[]> toInserteParams = new ArrayList<>(models.size());
for (BaseModelClass model : models) {
Object[] array = new Object[modelGetMethodList.size()];
......@@ -64,10 +75,8 @@ public class SQLUtils {
}
toInserteParams.add(array);
}
BatchInsertResult result = new BatchInsertResult();
result.setInsertSQL(insertSQL.toString());
result.setInsertParams(toInserteParams);
}
return result;
}
......
package cn.com.poc.common.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
public class WaveHeader {
public final char fileID[] = {'R', 'I', 'F', 'F'};
public int fileLength;
public char wavTag[] = {'W', 'A', 'V', 'E'};
public char FmtHdrID[] = {'f', 'm', 't', ' '};
public int FmtHdrLeth;
public short FormatTag;
public short Channels;
public int SamplesPerSec;
public int AvgBytesPerSec;
public short BlockAlign;
public short BitsPerSample;
public char DataHdrID[] = {'d', 'a', 't', 'a'};
public int DataHdrLeth;
public byte[] getHeader() throws IOException {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
WriteChar(bos, fileID);
WriteInt(bos, fileLength);
WriteChar(bos, wavTag);
WriteChar(bos, FmtHdrID);
WriteInt(bos, FmtHdrLeth);
WriteShort(bos, FormatTag);
WriteShort(bos, Channels);
WriteInt(bos, SamplesPerSec);
WriteInt(bos, AvgBytesPerSec);
WriteShort(bos, BlockAlign);
WriteShort(bos, BitsPerSample);
WriteChar(bos, DataHdrID);
WriteInt(bos, DataHdrLeth);
bos.flush();
byte[] r = bos.toByteArray();
bos.close();
return r;
}
private void WriteShort(ByteArrayOutputStream bos, int s) throws IOException {
byte[] mybyte = new byte[2];
mybyte[1] = (byte) ((s << 16) >> 24);
mybyte[0] = (byte) ((s << 24) >> 24);
bos.write(mybyte);
}
private void WriteInt(ByteArrayOutputStream bos, int n) throws IOException {
byte[] buf = new byte[4];
buf[3] = (byte) (n >> 24);
buf[2] = (byte) ((n << 8) >> 24);
buf[1] = (byte) ((n << 16) >> 24);
buf[0] = (byte) ((n << 24) >> 24);
bos.write(buf);
}
private void WriteChar(ByteArrayOutputStream bos, char[] id) {
for (char c : id) {
bos.write(c);
}
}
}
......@@ -2,6 +2,7 @@ package cn.com.poc.expose.aggregate;
import cn.com.poc.agent_application.query.MemberCollectQueryItem;
import cn.com.yict.framemax.data.model.PagingInfo;
import org.springframework.web.bind.annotation.RequestParam;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
......@@ -10,8 +11,13 @@ public interface AgentApplicationService {
/**
* 调用 已发布Agent应用
*
* @param agentId 应用ID
* @param dialogsId 对话ID
* @param input 用户输入
* @param fileUrls 文件URL
*/
void callAgentApplication(String agentId, String dialogsId, String input, HttpServletResponse httpServletResponse) throws Exception;
void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, HttpServletResponse httpServletResponse) throws Exception;
/**
* 追问AI生成
......@@ -41,4 +47,21 @@ public interface AgentApplicationService {
*/
List<MemberCollectQueryItem> getCollectedApplications(Long memberId, PagingInfo pagingInfo);
/**
* 获取用户在当前应用的自动播放配置
*
* @param memberId 用户ID
* @param agentId 应用ID
*/
String autoPlayByAgentId(Long memberId, String agentId);
/**
* 设置用户在当前应用的自动播放配置
*
* @param memberId 用户ID
* @param agentId 应用ID
* @param autoPlay 自动播放配置 Y/N
*/
String enableAutoPlay(Long memberId, String agentId, String autoPlay);
}
package cn.com.poc.expose.dto;
import java.util.List;
public class AgentApplicationDto {
private String dialogsId;
......@@ -32,4 +34,13 @@ public class AgentApplicationDto {
this.input = input;
}
private List<String> fileUrls;
public List<String> getFileUrls() {
return fileUrls;
}
public void setFileUrls(List<String> fileUrls) {
this.fileUrls = fileUrls;
}
}
......@@ -9,6 +9,12 @@ public class UserDialoguesDto implements Serializable {
*/
private java.lang.String dialogsId;
/**
* 应用ID
*/
private java.lang.String agentId;
/**
* 内容
*/
......@@ -30,4 +36,12 @@ public class UserDialoguesDto implements Serializable {
public void setContent(String content) {
this.content = content;
}
public String getAgentId() {
return agentId;
}
public void setAgentId(String agentId) {
this.agentId = agentId;
}
}
......@@ -86,4 +86,18 @@ public interface AgentApplicationRest extends BaseRest {
*/
List<DialoguesContextDto> getDialogueContext(@RequestParam String dialogueId) throws Exception;
/**
* 获取用户在当前应用的自动播放配置
*
* @param agentId 应用ID
*/
String autoPlayByAgentId(@RequestParam String agentId);
/**
* 设置用户在当前应用的自动播放配置
*
* @param agentId 应用ID
* @param autoPlay 自动播放配置 Y/N
*/
String enableAutoPlay(@RequestParam String agentId, @RequestParam String autoPlay);
}
package cn.com.poc.expose.rest.impl;
import cn.com.poc.agent_application.aggregate.AgentApplicationInfoService;
import cn.com.poc.agent_application.aggregate.AgentApplicationMallService;
import cn.com.poc.agent_application.convert.AgentApplicationInfoConvert;
import cn.com.poc.agent_application.convert.BizAgentApplicationPublishConvert;
import cn.com.poc.agent_application.dto.AgentApplicationCreateContinueQuesDto;
......@@ -104,7 +103,7 @@ public class AgentApplicationRestImpl implements AgentApplicationRest {
Assert.notNull(dto.getAgentId());
Assert.notNull(dto.getDialogsId());
try {
agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), httpServletResponse);
agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), dto.getFileUrls(), httpServletResponse);
} catch (Exception e) {
httpServletResponse.setContentType("text/event-stream");
PrintWriter writer = httpServletResponse.getWriter();
......@@ -185,6 +184,7 @@ public class AgentApplicationRestImpl implements AgentApplicationRest {
result = memberDialoguesQueryItems.stream().map(item -> {
UserDialoguesDto userDialoguesDto = new UserDialoguesDto();
userDialoguesDto.setDialogsId(item.getDialogsId());
userDialoguesDto.setAgentId(item.getAgentId());
String content = item.getContent().length() > 20 ? item.getContent().substring(0, 20) : item.getContent();
userDialoguesDto.setContent(content);
return userDialoguesDto;
......@@ -231,4 +231,20 @@ public class AgentApplicationRestImpl implements AgentApplicationRest {
}
return null;
}
@Override
public String autoPlayByAgentId(String agentId) {
Assert.notBlank(agentId);
UserBaseEntity currentUser = BlContext.getCurrentUser();
return agentApplicationService.autoPlayByAgentId(currentUser.getUserId(), agentId);
}
@Override
public String enableAutoPlay(String agentId, String autoPlay) {
Assert.notBlank(agentId);
Assert.notBlank(autoPlay);
Assert.isTrue(CommonConstant.YOrN.N.equals(autoPlay) || CommonConstant.YOrN.Y.equals(autoPlay));
UserBaseEntity currentUser = BlContext.getCurrentUser();
return agentApplicationService.enableAutoPlay(currentUser.getUserId(), agentId, autoPlay);
}
}
package cn.com.poc.expose.websocket;
import cn.com.poc.expose.websocket.constant.WsHandlerMatcher;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
import cn.com.poc.expose.websocket.holder.WSHandlerHolder;
import cn.com.yict.framemax.core.config.Config;
import cn.com.yict.framemax.core.exception.BusinessException;
import org.java_websocket.WebSocket;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.UUID;
/**
* @author alex.yao
* @date 2023/12/7
**/
@Component
public class SuperLinkWebSocketServer extends WebSocketServer {
private final Logger logger = LoggerFactory.getLogger(SuperLinkWebSocketServer.class);
private boolean serverStarted = false;
private final String LOG_KEY = "log_trace_id";
private SuperLinkWebSocketServer() {
}
public SuperLinkWebSocketServer(int port, int decodercount) {
super(new InetSocketAddress(port), decodercount);
}
@Override
public void onOpen(WebSocket webSocket, ClientHandshake clientHandshake) {
// auth(webSocket);
initWebSocketTraceId();
String path = webSocket.getResourceDescriptor();
logger.warn("websocket new connection open , address:{} ,path:{} ", webSocket.getRemoteSocketAddress().toString(), path);
AbstractWsHandler handlerClass = WsHandlerMatcher.getWsHandlerClass(path);
if (handlerClass != null) {
WSHandlerHolder.set(handlerClass);
} else {
webSocket.send("{\"code\":-1, \"message\":\"path mismatch\"}");
webSocket.close();
}
}
private void auth(WebSocket webSocket) {
String hostName = webSocket.getRemoteSocketAddress().getHostName();
logger.info("connection hostname:{}", hostName);
String[] split = Config.get("white.list.ip").split(",");
if (Arrays.stream(split).noneMatch(ip -> ip.equals(hostName))) {
throw new BusinessException("no authority");
}
}
@Override
public void onMessage(WebSocket webSocket, String s) {
logger.info("{} connection send message:{}", webSocket.getRemoteSocketAddress().toString(), s.length() < 1024 ? s : "response too long");
WSHandlerHolder.get().doHandler(webSocket, s);
}
@Override
public void onClose(WebSocket webSocket, int i, String s, boolean b) {
logger.warn("connection is close:{}", webSocket.getRemoteSocketAddress().toString() + webSocket.getResourceDescriptor());
WSHandlerHolder.clear();
}
@Override
public void onError(WebSocket webSocket, Exception e) {
logger.warn("connection is error:{}", webSocket.getRemoteSocketAddress().toString() + webSocket.getResourceDescriptor());
WSHandlerHolder.clear();
}
@Override
public void onStart() {
logger.warn("---------------------------------websocket start--------------------------------------");
}
@Override
public void start() {
serverStarted = true;
super.start();
}
@Override
public void stop() throws IOException, InterruptedException {
serverStarted = false;
super.stop();
}
@Override
public void stop(int i) throws InterruptedException {
serverStarted = false;
super.stop(i);
}
public boolean isRun() {
return this.serverStarted;
}
private void initWebSocketTraceId() {
String traceId = UUID.randomUUID().toString().replaceAll("-", "");
MDC.put(LOG_KEY, traceId);
}
}
package cn.com.poc.expose.websocket.config;
import cn.com.poc.expose.websocket.SuperLinkWebSocketServer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
* @author alex.yao
* @date 2023/12/7
**/
@Configuration
public class SuperLinkWebSocketConfig {
@Bean
public SuperLinkWebSocketServer superLinkWebSocketServer() {
return new SuperLinkWebSocketServer(40088, Runtime.getRuntime().availableProcessors() * 10);
}
}
package cn.com.poc.expose.websocket.constant;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
import cn.com.poc.expose.websocket.handler.TextToSpeechTencentHandler;
public class WsHandlerMatcher {
final private static String TEXT_TO_SPEECH_TC = "/websocket/textToSpeechTC.ws";
public static AbstractWsHandler getWsHandlerClass(String path) {
if (StringUtils.isBlank(path)) {
return null;
}
AbstractWsHandler handler = null;
switch (path) {
case TEXT_TO_SPEECH_TC:
handler = new TextToSpeechTencentHandler();
break;
}
return handler;
}
}
package cn.com.poc.expose.websocket.dto;
public class TextToSpeechTencentResponse {
private String codec = "mp3";
private Integer sampleRate = 16000;
private Integer voiceType = 1001;
private Integer volume = 0;
private Float speed = 0F;
private String content;
public Float getSpeed() {
return speed;
}
public void setSpeed(Float speed) {
this.speed = speed;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public String getCodec() {
return codec;
}
public void setCodec(String codec) {
this.codec = codec;
}
public Integer getSampleRate() {
return sampleRate;
}
public void setSampleRate(Integer sampleRate) {
this.sampleRate = sampleRate;
}
public Integer getVoiceType() {
return voiceType;
}
public void setVoiceType(Integer voiceType) {
this.voiceType = voiceType;
}
public Integer getVolume() {
return volume;
}
public void setVolume(Integer volume) {
this.volume = volume;
}
}
package cn.com.poc.expose.websocket.exception;
import cn.com.yict.framemax.core.exception.ErrorCoded;
import org.slf4j.MDC;
import java.io.Serializable;
/**
* @author alex.yao
* @date 2023/12/15
**/
public class WebsocketException extends RuntimeException implements ErrorCoded {
private static final long serialVersionUID = 2332618265610125980L;
private final String LOG_KEY = "log_trace_id";
private Serializable errorCode = -1;
private String traceId;
public WebsocketException() {
super();
}
public WebsocketException(Throwable cause) {
super(cause);
}
public WebsocketException(String message) {
super(message);
this.traceId = MDC.get(LOG_KEY);
}
public WebsocketException(Serializable code, String message) {
super(message);
this.errorCode = code;
this.traceId = MDC.get(LOG_KEY);
}
public WebsocketException(String message, String traceId) {
super(message);
this.traceId = traceId;
}
public WebsocketException(Serializable code, String message, String traceId) {
super(message);
this.errorCode = code;
this.traceId = traceId;
}
@Override
public Serializable getErrorCode() {
return this.errorCode;
}
@Override
public String getMessage() {
return super.getMessage();
}
@Override
public String toString() {
return "{\"code\":" + "\"" + errorCode + "\" ,\"message\":" + "\"" + super.getMessage() + "\" ,\"traceId\":" + "\"" + traceId + "\"}";
}
}
package cn.com.poc.expose.websocket.handler;
import org.java_websocket.WebSocket;
/**
* @author alex.yao
* @date 2023/12/7
**/
public abstract class AbstractWsHandler {
public abstract void doHandler(WebSocket webSocket,String message);
}
package cn.com.poc.expose.websocket.handler;
import cn.com.poc.common.service.BosConfigService;
import cn.com.poc.common.utils.*;
import cn.com.poc.expose.websocket.dto.TextToSpeechTencentResponse;
import cn.com.poc.expose.websocket.exception.WebsocketException;
import cn.com.yict.framemax.core.config.Config;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.core.i18n.I18nUtils;
import com.tencent.SpeechClient;
import com.tencent.tts.model.SpeechSynthesisRequest;
import com.tencent.tts.model.SpeechSynthesisResponse;
import com.tencent.tts.service.SpeechSynthesisListener;
import com.tencent.tts.service.SpeechSynthesizer;
import org.java_websocket.WebSocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public class TextToSpeechTencentHandler extends AbstractWsHandler {
final Logger logger = LoggerFactory.getLogger(TextToSpeechTencentHandler.class);
private SpeechSynthesizer speechSynthesizer;
private BosConfigService bosConfigService;
@Override
public void doHandler(WebSocket webSocket, String message) {
try {
bosConfigService = SpringUtils.getBean(BosConfigService.class);
TextToSpeechTencentResponse response = JsonUtils.deSerialize(message, TextToSpeechTencentResponse.class);
if (response == null) {
throw new WebsocketException("arg cannot null");
}
if (StringUtils.isBlank(response.getContent())) {
throw new WebsocketException("content cannot null");
}
initTTS(webSocket, response);
textToSpeech(response.getContent());
} catch (Exception e) {
WebsocketException websocketException = new WebsocketException(e);
webSocket.send(websocketException.getMessage());
throw websocketException;
}
}
/**
* 调用TTS
*/
private void textToSpeech(String text) {
this.speechSynthesizer.synthesisLongText(text);
}
/**
* 初始化TTS
*/
private void initTTS(WebSocket webSocket, TextToSpeechTencentResponse textToSpeechTencentResponse) throws IOException {
//从配置文件读取密钥
String appId = Config.get("tencent.speech.synthesizer.appid");
String secretId = Config.get("tencent.speech.synthesizer.secretId");
String secretKey = Config.get("tencent.speech.synthesizer.secretKey");
//创建SpeechSynthesizerClient实例,目前是单例
SpeechClient client = SpeechClient.newInstance(appId, secretId, secretKey);
//初始化SpeechSynthesizerRequest,SpeechSynthesizerRequest包含请求参数
SpeechSynthesisRequest request = SpeechSynthesisRequest.initialize();
request.setSampleRate(textToSpeechTencentResponse.getSampleRate());
request.setSpeed(textToSpeechTencentResponse.getSpeed());
request.setCodec(textToSpeechTencentResponse.getCodec());
request.setVolume(textToSpeechTencentResponse.getVolume());
request.setVoiceType(textToSpeechTencentResponse.getVoiceType());
//使用客户端client创建语音合成实例
if ("wav".equals(textToSpeechTencentResponse.getCodec())) {
request.setCodec("pcm");
}
speechSynthesizer = client.newSpeechSynthesizer(request, new SpeechSynthesisListener() {
List<byte[]> audioBytes = new ArrayList<>();
AtomicInteger sessionId = new AtomicInteger(0);
File tempAudioFile = File.createTempFile(UUIDTool.getUUID(), textToSpeechTencentResponse.getCodec());
AtomicInteger count = new AtomicInteger(0);
@Override
public void onComplete(SpeechSynthesisResponse response) {
logger.info("onComplete");
try (FileOutputStream fileOutputStream = new FileOutputStream(tempAudioFile, true)) {
for (byte[] audioByte : audioBytes) {
fileOutputStream.write(audioByte);
}
} catch (IOException e) {
logger.error("onComplete:{}", e.getMessage());
}
Path path = tempAudioFile.toPath();
if ("wav".equals(textToSpeechTencentResponse.getCodec())) {
try {
path = PcmToWav.convertAudioFiles(tempAudioFile.getPath());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
//上传音频
try (InputStream fileInputStream = Files.newInputStream(path);) {
String uploadUrl = bosConfigService.upload(fileInputStream, textToSpeechTencentResponse.getCodec(), null);
webSocket.send("{\"replyVoiceUrl\":\"" + uploadUrl + "\"}");
} catch (Exception e) {
throw new BusinessException(e);
}
webSocket.send("{\"final\":true}");
webSocket.close();
}
@Override
public void onMessage(byte[] data) {
//发送音频
sessionId.incrementAndGet();
Base64.Encoder encoder = Base64.getEncoder();
String base64 = encoder.encodeToString(data);
webSocket.send("{\"sessionId\":" + count.get() + ",\"audio\":\"" + base64 + "\"}");
audioBytes.add(data);
}
@Override
public void onFail(SpeechSynthesisResponse response) {
logger.warn("onFail:{}", response.getMessage());
WebsocketException websocketException = new WebsocketException(response.getMessage());
webSocket.send(websocketException.toString());
webSocket.close();
throw websocketException;
}
});
}
}
package cn.com.poc.expose.websocket.holder;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
/**
* @author alex.yao
* @date 2024/3/19 22:50
*/
public class WSHandlerHolder {
private final static ThreadLocal<AbstractWsHandler> WS_HANDLER_HOLDER = new ThreadLocal<>();
public static AbstractWsHandler get(){
return WS_HANDLER_HOLDER.get();
}
public static void set(AbstractWsHandler abstractWsHandler){
WS_HANDLER_HOLDER.set(abstractWsHandler);
}
public static void clear(){
WS_HANDLER_HOLDER.remove();
}
}
package cn.com.poc.expose.websocket.init;
import cn.com.poc.expose.websocket.SuperLinkWebSocketServer;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
/**
* @author alex.yao
* @date 2023/12/7
**/
@Component
public class SuperLinkWSRun implements InitializingBean {
@Resource
private SuperLinkWebSocketServer superLinkWebSocketServer;
@Override
public void afterPropertiesSet() throws Exception {
if (!superLinkWebSocketServer.isRun()) {
superLinkWebSocketServer.start();
}
}
}
......@@ -2,12 +2,10 @@ package cn.com.poc.message.service.impl;
import cn.com.poc.agent_application.aggregate.AgentApplicationMallService;
import cn.com.poc.agent_application.entity.BizAgentApplicationPublishEntity;
import cn.com.poc.agent_application.service.BizAgentApplicationDialoguesRecordService;
import cn.com.poc.agent_application.service.BizAgentApplicationPublishService;
import cn.com.poc.message.entity.AgentApplicationClickEventMessage;
import cn.com.poc.message.service.AgentApplicationConsumerService;
import cn.com.poc.message.topic.AgentApplicationTopic;
import cn.com.yict.framemax.core.service.BaseService;
import cn.com.yict.framemax.tumbleweed.client.annotation.Consumer;
import org.springframework.stereotype.Service;
......
......@@ -3,7 +3,7 @@ package cn.com.poc.support.dgTools;
import cn.com.poc.common.constant.FmxParamConfigConstant;
import cn.com.poc.common.utils.ListUtils;
import cn.com.poc.common.utils.http.LocalHttpClient;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.route.DgtoolsApiRoute;
import cn.com.poc.support.dgTools.request.AbstractParam;
import cn.com.poc.support.dgTools.request.AbstractRequest;
import cn.com.poc.support.dgTools.request.ProjectTokenRequest;
......@@ -117,8 +117,8 @@ public class DgtoolsAbstractHttpClient {
}
HttpUriRequest httpUriRequest = post
.setHeader(DgtoolsApiConstants.JSON_HEADER)
.setUri(dgtoolsDomainurl + DgtoolsApiConstants.BASE_URL + url)
.setHeader(DgtoolsApiRoute.JSON_HEADER)
.setUri(dgtoolsDomainurl + DgtoolsApiRoute.BASE_URL + url)
.setEntity(new StringEntity(json, StandardCharsets.UTF_8))
.build();
DgtoolsApiResult dgtoolsApiResult = LocalHttpClient.executeJsonResult(httpUriRequest, DgtoolsApiResult.class);
......
......@@ -4,7 +4,7 @@ import cn.com.poc.common.constant.FmxParamConfigConstant;
import cn.com.poc.common.constant.MkpRedisKeyConstant;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.route.DgtoolsApiRoute;
import cn.com.poc.support.dgTools.request.ProjectTokenRequest;
import cn.com.poc.support.dgTools.result.ProjectTokenResult;
import cn.com.poc.support.dgTools.service.AuthorizationService;
......@@ -51,8 +51,8 @@ public class AuthorizationServiceImpl implements AuthorizationService {
request.setProjectKey(projectKey);
request.setProjectSecret(projectSecret);
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.PAY_HEADER);
ProjectTokenResult projectTokenResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.GET_APP_TOKEN, request, headers);
headers.add(DgtoolsApiRoute.PAY_HEADER);
ProjectTokenResult projectTokenResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.GET_APP_TOKEN, request, headers);
return projectTokenResult.getAppToken();
}
......
......@@ -8,7 +8,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.OpenAiResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.*;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.route.DgtoolsApiRoute;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
......@@ -84,10 +84,10 @@ public class AICreateImageServiceImpl implements AICreateImageService {
Assert.isTrue(request.getN() <= 10 && request.getN() >= 1, "The number of images to generate. Must be between 1 and 10.");
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.JSON_HEADER);
headers.add(DgtoolsApiConstants.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
OpenAiResult openAiResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_OPENAI_CREATE_IMAGE, request, headers);
headers.add(DgtoolsApiRoute.JSON_HEADER);
headers.add(DgtoolsApiRoute.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
OpenAiResult openAiResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.AI_OPENAI_CREATE_IMAGE, request, headers);
if (openAiResult != null) {
GenerationsResult generationsResult = JSONObject.parseObject(openAiResult.getMessage(), GenerationsResult.class);
return generationsResult;
......@@ -114,10 +114,10 @@ public class AICreateImageServiceImpl implements AICreateImageService {
Assert.isTrue(request.getNum() <= 6 && request.getNum() >= 1, "图片生成数量,支持1-6张");
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.JSON_HEADER);
headers.add(DgtoolsApiConstants.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduGetImageResult baiduGetImageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_CREATE_IMAGE, request, headers);
headers.add(DgtoolsApiRoute.JSON_HEADER);
headers.add(DgtoolsApiRoute.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduGetImageResult baiduGetImageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.AI_BAIDU_CREATE_IMAGE, request, headers);
if (baiduGetImageResult != null) {
GenerationsResult generationsResult = new GenerationsResult();
List<ImgUrl> imgUrls = baiduGetImageResult.getData().getImgUrls();
......@@ -150,10 +150,10 @@ public class AICreateImageServiceImpl implements AICreateImageService {
}
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.JSON_HEADER);
headers.add(DgtoolsApiConstants.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduGetImageV2Result baiduGetImageV2Result = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_CREATE_IMAGE_V2, request, headers);
headers.add(DgtoolsApiRoute.JSON_HEADER);
headers.add(DgtoolsApiRoute.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduGetImageV2Result baiduGetImageV2Result = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.AI_BAIDU_CREATE_IMAGE_V2, request, headers);
if (baiduGetImageV2Result != null) {
if (FAILED.equals(baiduGetImageV2Result.getData().getTaskStatus())) {
......@@ -195,10 +195,10 @@ public class AICreateImageServiceImpl implements AICreateImageService {
Assert.notNull(request.getPrompt(), "文生图配置异常,请联系开发人员");
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.JSON_HEADER);
headers.add(DgtoolsApiConstants.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduAISailsText2ImageResult imageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_SAILS_TEXT_CREATE_IMAGE, request, headers);
headers.add(DgtoolsApiRoute.JSON_HEADER);
headers.add(DgtoolsApiRoute.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
BaiduAISailsText2ImageResult imageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.AI_BAIDU_SAILS_TEXT_CREATE_IMAGE, request, headers);
if (imageResult == null) {
throw new I18nMessageException("exception/middle.platform.is.unresponsive");
}
......@@ -234,10 +234,10 @@ public class AICreateImageServiceImpl implements AICreateImageService {
String jsonBody = dgToolsAbstractHttpClient.buildJson(request);
CloseableHttpClient httpClient = HttpClients.createDefault();
CloseableHttpResponse httpResponse = httpClient.execute(RequestBuilder.create(POST)
.setUri(DOMAIN_URL + DgtoolsApiConstants.BASE_URL + DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_SAILS_IMAGE_CREATE_TEXT)
.addHeader(DgtoolsApiConstants.JSON_HEADER)
.addHeader(DgtoolsApiConstants.AI_HEADER)
.addHeader(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()))
.setUri(DOMAIN_URL + DgtoolsApiRoute.BASE_URL + DgtoolsApiRoute.DgtoolsAI.AI_BAIDU_SAILS_IMAGE_CREATE_TEXT)
.addHeader(DgtoolsApiRoute.JSON_HEADER)
.addHeader(DgtoolsApiRoute.AI_HEADER)
.addHeader(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()))
.setEntity(new StringEntity(jsonBody, StandardCharsets.UTF_8))
.build()
);
......
......@@ -8,7 +8,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDema
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.route.DgtoolsApiRoute;
import org.apache.http.Header;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.RequestBuilder;
......@@ -70,12 +70,12 @@ public class AIDialogueServiceImpl implements AIDialogueService {
@Override
public FunctionCallResult functionCall(FunctionCallResponse response) {
String url = DgtoolsApiConstants.DgtoolsAI.FUNCTION_CALL;
String url = DgtoolsApiRoute.DgtoolsAI.FUNCTION_CALL;
response.setApiKey(API_KEY);
List<Header> headers = new ArrayList<Header>() {{
add(DgtoolsApiConstants.JSON_HEADER);
add(DgtoolsApiConstants.AI_HEADER);
add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
add(DgtoolsApiRoute.JSON_HEADER);
add(DgtoolsApiRoute.AI_HEADER);
add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
}};
return dgToolsAbstractHttpClient.doRequest(url, response, headers);
}
......@@ -84,10 +84,10 @@ public class AIDialogueServiceImpl implements AIDialogueService {
String jsonBody = dgToolsAbstractHttpClient.buildJson(request);
CloseableHttpClient httpClient = HttpClients.createDefault();
CloseableHttpResponse httpResponse = httpClient.execute(RequestBuilder.create(POST)
.setUri(DOMAIN_URL + DgtoolsApiConstants.BASE_URL + DgtoolsApiConstants.DgtoolsAI.LARGE_MODEL)
.addHeader(DgtoolsApiConstants.JSON_HEADER)
.addHeader(DgtoolsApiConstants.AI_HEADER)
.addHeader(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()))
.setUri(DOMAIN_URL + DgtoolsApiRoute.BASE_URL + DgtoolsApiRoute.DgtoolsAI.LARGE_MODEL)
.addHeader(DgtoolsApiRoute.JSON_HEADER)
.addHeader(DgtoolsApiRoute.AI_HEADER)
.addHeader(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()))
.setEntity(new StringEntity(jsonBody, StandardCharsets.UTF_8))
.build()
);
......@@ -97,11 +97,11 @@ public class AIDialogueServiceImpl implements AIDialogueService {
}
private LargeModelDemandResult largeModelRequest(LargeModelDemandResponse request) {
String url = DgtoolsApiConstants.DgtoolsAI.LARGE_MODEL;
String url = DgtoolsApiRoute.DgtoolsAI.LARGE_MODEL;
List<Header> headers = new ArrayList<Header>() {{
add(DgtoolsApiConstants.JSON_HEADER);
add(DgtoolsApiConstants.AI_HEADER);
add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
add(DgtoolsApiRoute.JSON_HEADER);
add(DgtoolsApiRoute.AI_HEADER);
add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
}};
return dgToolsAbstractHttpClient.doRequest(url, request, headers);
}
......
package cn.com.poc.thirdparty.resource.demand.ai.aggregate.impl;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.thirdparty.resource.demand.ai.common.DgtoolsApiConstants;
import cn.com.poc.thirdparty.resource.demand.ai.route.DgtoolsApiRoute;
import cn.com.poc.support.dgTools.result.AbstractResult;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge.*;
......@@ -33,7 +33,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
TrainKnowledgeRequest request = new TrainKnowledgeRequest();
request.setDocumentUrl(fileURL);
request.setSegmentationConfig(segmentationConfig);
TrainKnowledgeResult trainKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.TRAIN_KNOWLEDGE, request, getHeaders());
TrainKnowledgeResult trainKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.TRAIN_KNOWLEDGE, request, getHeaders());
if (null == trainKnowledgeResult) {
throw new I18nMessageException("exception/abnormal.knowledge.base.training");
}
......@@ -46,7 +46,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
TrainKnowledgeRequest request = new TrainKnowledgeRequest();
request.setDocumentUrl(fileURL);
request.setSegmentationConfig(segmentationConfig);
TrainKnowledgeResult trainKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.TRAIN_KNOWLEDGE_EVENT, request, getHeaders());
TrainKnowledgeResult trainKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.TRAIN_KNOWLEDGE_EVENT, request, getHeaders());
if (null == trainKnowledgeResult) {
throw new I18nMessageException("exception/abnormal.knowledge.base.training");
}
......@@ -58,7 +58,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
Assert.notBlank(knowledgeId);
TrainKnowledgeStatusRequest request = new TrainKnowledgeStatusRequest();
request.setKnowledgeId(knowledgeId);
TrainKnowledgeStatusResult trainKnowledgeStatusResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.TRAIN_KNOWLEDGE_STATUS, request, getHeaders());
TrainKnowledgeStatusResult trainKnowledgeStatusResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.TRAIN_KNOWLEDGE_STATUS, request, getHeaders());
if (null == trainKnowledgeStatusResult) {
throw new I18nMessageException("exception/abnormal.training.status.of.knowledge.base.acquisition");
}
......@@ -70,7 +70,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
Assert.notBlank(knowledgeId);
DelKnowledgeRequest request = new DelKnowledgeRequest();
request.setKnowledgeId(knowledgeId);
AbstractResult abstractResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.DEL_KNOWLEDGE, request, getHeaders());
AbstractResult abstractResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.DEL_KNOWLEDGE, request, getHeaders());
if (null == abstractResult) {
throw new I18nMessageException("exception/delete.knowledge.base.exception");
}
......@@ -89,7 +89,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
searchKnowledgeRequest.setQuery(query);
searchKnowledgeRequest.setKnowLedgeIds(knowledgeIds);
searchKnowledgeRequest.setTopK(topK);
SearchKnowledgeResult searchKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders());
SearchKnowledgeResult searchKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.SEARCH_KNOWLEDGE, searchKnowledgeRequest, getHeaders());
if (null == searchKnowledgeResult) {
throw new I18nMessageException("exception/query.knowledge.base.exception");
}
......@@ -104,7 +104,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
GetKnowledgeChunkInfoRequest request = new GetKnowledgeChunkInfoRequest();
request.setKnowledgeIds(knowledgeIds);
request.setQuery(query);
return dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.GET_KNOWLEDGE_CHUNK_INFOS, request, getHeaders(), pagingInfo);
return dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.GET_KNOWLEDGE_CHUNK_INFOS, request, getHeaders(), pagingInfo);
}
@Override
......@@ -113,7 +113,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
request.setKnowledgeId(knowledgeId);
request.setChunkRelationId(chunkRelationId);
request.setIsOpen(isOpen);
dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.OPEN_KNOWLEDGE_CHUNK, request, getHeaders());
dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.OPEN_KNOWLEDGE_CHUNK, request, getHeaders());
}
@Override
......@@ -121,7 +121,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
UpsertChunkInfoRequest request = new UpsertChunkInfoRequest();
request.setKnowledgeId(knowledgeId);
request.setChunkRelationId(chunkRelationId);
dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.DELETE_KNOWLEDGE_CHUNK, request, getHeaders());
dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.DELETE_KNOWLEDGE_CHUNK, request, getHeaders());
}
@Override
......@@ -130,7 +130,7 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
request.setKnowledgeId(knowledgeId);
request.setChunkRelationId(chunkRelationId);
request.setChunkContent(content);
dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.UPDATE_KNOWLEDGE_CHUNK_DOC, request, getHeaders());
dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.UPDATE_KNOWLEDGE_CHUNK_DOC, request, getHeaders());
}
@Override
......@@ -139,14 +139,14 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
request.setKnowledgeId(knowledgeId);
request.setChunkSort(chunkSort);
request.setChunkContent(content);
dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.ADD_KNOWLEDGE_CHUNK, request, getHeaders());
dgToolsAbstractHttpClient.doRequest(DgtoolsApiRoute.DgtoolsAI.ADD_KNOWLEDGE_CHUNK, request, getHeaders());
}
private List<Header> getHeaders() {
List<Header> headers = new ArrayList<>();
headers.add(DgtoolsApiConstants.JSON_HEADER);
headers.add(DgtoolsApiConstants.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
headers.add(DgtoolsApiRoute.JSON_HEADER);
headers.add(DgtoolsApiRoute.AI_HEADER);
headers.add(new BasicHeader(DgtoolsApiRoute.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
return headers;
}
}
......@@ -8,7 +8,14 @@ import java.util.List;
public abstract class AbstractLargeModelFunction {
public abstract String doFunction(String content, String key);
/**
* 执行函数
*
* @param content 入参
* @param identifier 唯一标识
* @return
*/
public abstract String doFunction(String content, String identifier);
/**
* 获取函数描述
......@@ -22,9 +29,11 @@ public abstract class AbstractLargeModelFunction {
*/
public abstract List<String> getLLMConfig();
/**
* 获取有关变量的配置
*/
public abstract List<String> getVariableStructureLLMConfig(List<Variable> variableStructure);
public abstract List<String> getLLMConfig(List<Variable> variableStructure);
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.common.utils.SpringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.document_reader.DocumentReaderFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.document_understanding.DocumentUnderstandIngFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.html_reader.HtmlReaderFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.SetValueMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
public enum LargeModelFunctionEnum {
set_long_memory(SetLongMemoryFunction.class),
set_value_memory(SetValueMemoryFunction.class),
memory_variable_writer(MemoryVariableWriterFunction.class),
html_reader(HtmlReaderFunction.class),
document_reader(DocumentReaderFunction.class),
document_understanding(DocumentUnderstandIngFunction.class),
bing_web_search(null),
;
private Class<? extends AbstractLargeModelFunction> function;
LargeModelFunctionEnum(Class<? extends AbstractLargeModelFunction> function) {
......@@ -23,4 +31,13 @@ public enum LargeModelFunctionEnum {
public void setFunction(Class<AbstractLargeModelFunction> function) {
this.function = function;
}
public static LargeModelFunctionEnum getFunction(String functionName) {
for (LargeModelFunctionEnum value : LargeModelFunctionEnum.values()) {
if (value.name().equals(functionName)) {
return value;
}
}
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.document_reader;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.springframework.stereotype.Component;
import java.io.File;
import java.util.List;
@Component
public class DocumentReaderFunction extends AbstractLargeModelFunction {
private final String DESC = "文档阅读器,读取PDF、docx、doc、txt、md格式文件";
private final FunctionLLMConfig functionLLMConfig = new FunctionLLMConfig
.FunctionLLMConfigBuilder()
.name("document_reader")
.description(DESC)
.parameters(new Parameters("object")
.addProperties("file_urls", new Properties("array", "doc、docx、pdf、txt、md文件地址"))
).build();
@Override
public String doFunction(String content, String identifier) {
if (StringUtils.isBlank(content)) {
return StringUtils.EMPTY;
}
StringBuilder sb = new StringBuilder();
JSONObject jsonObject = JSON.parseObject(content);
if (jsonObject.containsKey("file_urls")) {
JSONArray jsonArray = jsonObject.getJSONArray("file_urls");
int size = jsonArray.size();
if (size == 0) {
return StringUtils.EMPTY;
}
for (int i = 0; i < size; i++) {
String fileUrl = jsonArray.getString(i);
File file = DocumentLoad.downloadURLDocument(fileUrl);
sb.append(StringUtils.LF).append("## Document ").append((i + 1)).append(StringUtils.LF);
sb.append(DocumentLoad.documentToText(file));
}
return sb.toString();
}
return StringUtils.EMPTY;
}
@Override
public String getDesc() {
return DESC;
}
@Override
public List<String> getLLMConfig() {
return ListUtil.toList(JsonUtils.serialize(this.functionLLMConfig));
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return this.getLLMConfig();
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.entity;
public class Function {
/**
* 函数名
*/
private String name;
/**
* 函数参数
*/
private Parameters parameters;
/**
* 函数描述
*/
private String description;
public Function() {
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public Parameters getParameters() {
return parameters;
}
public void setParameters(Parameters parameters) {
this.parameters = parameters;
}
public String getDescription() {
return description;
}
public void setDescription(String description) {
this.description = description;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.long_memory;
import java.io.Serializable;
import java.util.Date;
public class LongMemoryEntity implements Serializable {
public class AgentLongMemoryEntity implements Serializable {
private String content;
......
package cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer;
public interface MemoryVariableWriterConstants {
String REDIS_PREFIX = "AGENT_APP_FUNCTION:MEMORY_VARIABLE:";
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment