Commit e24d793b authored by alex yao's avatar alex yao

feat:Agent应用插件功能

parent 1366bf51
...@@ -333,6 +333,13 @@ ...@@ -333,6 +333,13 @@
<artifactId>google-api-client</artifactId> <artifactId>google-api-client</artifactId>
<version>2.4.0</version> <version>2.4.0</version>
</dependency> </dependency>
<dependency>
<groupId>com.google.apis</groupId>
<artifactId>google-api-services-customsearch</artifactId>
<version>v1-rev20240821-2.0.0</version>
</dependency>
</dependencies> </dependencies>
......
...@@ -45,12 +45,13 @@ public interface AgentApplicationInfoService { ...@@ -45,12 +45,13 @@ public interface AgentApplicationInfoService {
* @param temperature 模型参数temperature * @param temperature 模型参数temperature
* @param messages 对话消息 * @param messages 对话消息
* @param tools 插件配置 * @param tools 插件配置
* @param fileUrls 文件URL * @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输 * @param stream 是否流式传输
*/ */
String callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem, String callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature, Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, HttpServletResponse httpServletResponse) throws Exception; List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* 应用下架 * 应用下架
......
...@@ -23,6 +23,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum; ...@@ -23,6 +23,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest; import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult; import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult;
...@@ -37,6 +38,7 @@ import cn.com.yict.framemax.core.context.Context; ...@@ -37,6 +38,7 @@ import cn.com.yict.framemax.core.context.Context;
import cn.com.yict.framemax.core.i18n.I18nMessageException; import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.com.yict.framemax.data.model.PagingInfo; import cn.com.yict.framemax.data.model.PagingInfo;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
...@@ -96,6 +98,9 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -96,6 +98,9 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Resource @Resource
private BizAgentApplicationDialoguesRecordService bizAgentApplicationDialoguesRecordService; private BizAgentApplicationDialoguesRecordService bizAgentApplicationDialoguesRecordService;
@Resource
private BizAgentApplicationPluginService bizAgentApplicationPluginService;
@Override @Override
public BizAgentApplicationInfoEntity saveOrUpdate(BizAgentApplicationInfoEntity entity) throws Exception { public BizAgentApplicationInfoEntity saveOrUpdate(BizAgentApplicationInfoEntity entity) throws Exception {
...@@ -163,7 +168,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -163,7 +168,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Override @Override
public String callAgentApplication(String agentId, String dialogueId, String largeModel, public String callAgentApplication(String agentId, String dialogueId, String largeModel,
String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature, String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, HttpServletResponse httpServletResponse) throws Exception { List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception {
logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}" logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}"
, agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools); , agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools);
...@@ -172,13 +177,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -172,13 +177,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
Tool[] toolArray = tools.toArray(new Tool[0]); Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls); FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls, imageUrls);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId); String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate); Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, httpServletResponse); return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
} }
...@@ -207,7 +212,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -207,7 +212,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
largeModelResponse.setUser("POC-CREATE-AGENT-SYSTEM"); largeModelResponse.setUser("POC-CREATE-AGENT-SYSTEM");
BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse); BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse);
textOutputStream(httpServletResponse, bufferedReader); textOutputStream(httpServletResponse, bufferedReader, null);
} }
@Override @Override
...@@ -500,16 +505,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -500,16 +505,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param stream 是否流式输出 * @param stream 是否流式输出
* @param model 模型 * @param model 模型
* @param messageArray 消息 * @param messageArray 消息
* @param functionResult 函数结果
* @param httpServletResponse 响应 * @param httpServletResponse 响应
* @return 输出结果 * @return 输出结果
* @throws Exception * @throws Exception
*/ */
private String llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, HttpServletResponse httpServletResponse) throws Exception { private String llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, HttpServletResponse httpServletResponse) throws Exception {
if (stream) { if (stream) {
BufferedReader bufferedReader = invokeLLMStream(model, messageArray, topP); BufferedReader bufferedReader = invokeLLMStream(model, messageArray, topP);
return textOutputStream(httpServletResponse, bufferedReader); return textOutputStream(httpServletResponse, bufferedReader, functionResult);
} else { } else {
LargeModelDemandResult largeModelDemandResult = invokeLLM(model, messageArray, topP); LargeModelDemandResult largeModelDemandResult = invokeLLM(model, messageArray, topP);
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
largeModelDemandResult.setFunction(toolFunction);
}
return textOutput(httpServletResponse, largeModelDemandResult); return textOutput(httpServletResponse, largeModelDemandResult);
} }
} }
...@@ -735,12 +745,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -735,12 +745,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param bufferedReader * @param bufferedReader
* @throws IOException * @throws IOException
*/ */
private String textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws private String textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader, FunctionResult functionResult) throws
IOException { IOException {
String res = ""; String res = "";
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8); httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter(); PrintWriter writer = httpServletResponse.getWriter();
StringBuilder output = new StringBuilder(); StringBuilder output = new StringBuilder();
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
result.setFunction(toolFunction);
writer.write(EVENT_STREAM_PREFIX + JsonUtils.serialize(result) + "\n\n");
writer.flush();
}
while ((res = bufferedReader.readLine()) != null) { while ((res = bufferedReader.readLine()) != null) {
if (StringUtils.isBlank(res)) { if (StringUtils.isBlank(res)) {
continue; continue;
...@@ -767,6 +785,27 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -767,6 +785,27 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return output.toString(); return output.toString();
} }
private ToolFunction functionResultConvertToolFunction(FunctionResult functionResult) {
ToolFunction toolFunction = new ToolFunction();
BizAgentApplicationPluginEntity bizAgentApplicationPluginEntity = bizAgentApplicationPluginService.getInfoById(functionResult.getFunctionName());
if (bizAgentApplicationPluginEntity != null && !bizAgentApplicationPluginEntity.getClassification().equals("system")) {
Locale currentLocale = Context.get().getMessageSource().getCurrentLocale();
String languageTag = currentLocale.toLanguageTag();
switch (languageTag) {
case "zh-CN":
toolFunction.setName(bizAgentApplicationPluginEntity.getZhCnTitle());
break;
case "en":
toolFunction.setName(bizAgentApplicationPluginEntity.getZhTwTitle());
break;
case "zh-TW":
toolFunction.setName(bizAgentApplicationPluginEntity.getEnTitle());
break;
}
}
return toolFunction;
}
/** /**
* 构建消息【大模型参数】 * 构建消息【大模型参数】
* *
...@@ -826,8 +865,10 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -826,8 +865,10 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param messages * @param messages
* @param tools * @param tools
* @param agentId * @param agentId
* @param fileUrls
* @param imageUrls
*/ */
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls) { private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls, List<String> imageUrls) {
FunctionResult result = new FunctionResult(); FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) { if (ArrayUtils.isEmpty(tools)) {
return result; return result;
...@@ -839,9 +880,12 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ ...@@ -839,9 +880,12 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
} else { } else {
query = content.toString(); query = content.toString();
} }
query = "用户输入:" + query + "\n";
if (CollectionUtils.isNotEmpty(fileUrls)) { if (CollectionUtils.isNotEmpty(fileUrls)) {
query = JsonUtils.serialize(fileUrls) + query; query = query + "用户上传文件地址:" + JsonUtils.serialize(fileUrls) + "\n";
}
if (CollectionUtils.isNotEmpty(imageUrls)) {
query = query + "用户上传图片地址:" + JsonUtils.serialize(imageUrls);
} }
FunctionCallResult functionCallResult = llmService.functionCall(query, tools); FunctionCallResult functionCallResult = llmService.functionCall(query, tools);
......
...@@ -14,7 +14,6 @@ import cn.com.poc.common.constant.CommonConstant; ...@@ -14,7 +14,6 @@ import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.BlContext; import cn.com.poc.common.utils.BlContext;
import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService; import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService;
import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum; import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum;
import cn.com.poc.data_analyze.constants.DataAnalyzeTypeEnum;
import cn.com.poc.equity.aggregate.MemberEquityService; import cn.com.poc.equity.aggregate.MemberEquityService;
import cn.com.poc.equity.aggregate.PointDeductionRulesService; import cn.com.poc.equity.aggregate.PointDeductionRulesService;
import cn.com.poc.equity.constants.ModifyEventEnum; import cn.com.poc.equity.constants.ModifyEventEnum;
...@@ -249,6 +248,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -249,6 +248,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//配置对话function //配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogueId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing()); List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogueId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.getMessageImageUrl(dto.getMessages());
//对话大模型配置 //对话大模型配置
String model = StringUtils.isNotBlank(dto.getModelNickName()) ? dto.getModelNickName() : infoEntity.getLargeModel(); String model = StringUtils.isNotBlank(dto.getModelNickName()) ? dto.getModelNickName() : infoEntity.getLargeModel();
Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP(); Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP();
...@@ -264,7 +266,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest { ...@@ -264,7 +266,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务 //调用应用服务
agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model, agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), topP, agentSystem, kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), topP,
temperature, dto.getMessages(), tools, dto.getFileUrls(), true, httpServletResponse); temperature, dto.getMessages(), tools, dto.getFileUrls(), true, imageUrls, httpServletResponse);
//数据采集 //数据采集
if (StringUtils.isBlank(dto.getChannel())) { if (StringUtils.isBlank(dto.getChannel())) {
dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel()); dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel());
......
...@@ -4,10 +4,13 @@ import cn.com.poc.agent_application.entity.Variable; ...@@ -4,10 +4,13 @@ import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.constant.CommonConstant; import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.DocumentLoad; import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils; import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum; import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriter; import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriter;
import cn.com.yict.framemax.core.i18n.I18nMessageException; import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
...@@ -16,6 +19,7 @@ import org.apache.commons.lang3.StringUtils; ...@@ -16,6 +19,7 @@ import org.apache.commons.lang3.StringUtils;
import java.io.File; import java.io.File;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
...@@ -119,4 +123,36 @@ public class AgentApplicationTools { ...@@ -119,4 +123,36 @@ public class AgentApplicationTools {
} }
} }
} }
/**
* 获取Message对话的图片地址
*
* @param messages 对话
* @return 返回图片地址列表,若无则返回null
*/
public static List<String> getMessageImageUrl(List<Message> messages) {
List<String> imageUrls = null;
Message mess = messages.get(messages.size() - 1);
if (!(mess.getContent() instanceof String) && ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0)).containsKey("image_url")) {
LinkedHashMap map = ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0));
MultiContent multiContent = JSONObject.parseObject(JsonUtils.serialize(map), MultiContent.class);
if (ObjectUtil.isNotNull(multiContent.getImageUrl()) && StringUtils.isNotBlank(multiContent.getImageUrl().getUrl())) {
imageUrls = new ArrayList<>();
imageUrls.add(multiContent.getImageUrl().getUrl());
}
}
return imageUrls;
}
/**
* 构建ImageUrls
*/
public static List<String> buildImageUrls(String imageUrl) {
List<String> imageUrls = null;
if (StringUtils.isNotBlank(imageUrl)) {
imageUrls = new ArrayList<>();
imageUrls.add(imageUrl);
}
return imageUrls;
}
} }
...@@ -51,7 +51,7 @@ public class DocumentLoad { ...@@ -51,7 +51,7 @@ public class DocumentLoad {
String htmlStr = sb.toString(); String htmlStr = sb.toString();
return converter.convert(htmlStr); return converter.convert(htmlStr);
} catch (IOException e) { } catch (IOException e) {
throw new I18nMessageException(e.getMessage()); return "";
} }
} }
......
...@@ -32,9 +32,10 @@ public interface AgentApplicationApiService { ...@@ -32,9 +32,10 @@ public interface AgentApplicationApiService {
* @param fileIds 文件ID列表 * @param fileIds 文件ID列表
* @param query 消息 * @param query 消息
* @param stream 是否流式输出 * @param stream 是否流式输出
* @param imageUrl 图片URL
* @param httpServletResponse * @param httpServletResponse
*/ */
void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception; void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* 上传文件 * 上传文件
......
...@@ -19,8 +19,9 @@ public interface AgentApplicationService { ...@@ -19,8 +19,9 @@ public interface AgentApplicationService {
* @param input 用户输入 * @param input 用户输入
* @param fileUrls 文件URL * @param fileUrls 文件URL
* @param channel 渠道 * @param channel 渠道
* @param imageUrl 图片URL
*/ */
void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, HttpServletResponse httpServletResponse) throws Exception; void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, String imageUrl, HttpServletResponse httpServletResponse) throws Exception;
/** /**
* 追问AI生成 * 追问AI生成
......
...@@ -8,11 +8,9 @@ import cn.com.poc.agent_application.utils.AgentApplicationTools; ...@@ -8,11 +8,9 @@ import cn.com.poc.agent_application.utils.AgentApplicationTools;
import cn.com.poc.common.constant.CommonConstant; import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.service.BosConfigService; import cn.com.poc.common.service.BosConfigService;
import cn.com.poc.common.utils.DateUtils; import cn.com.poc.common.utils.DateUtils;
import cn.com.poc.common.utils.FileUtils;
import cn.com.poc.common.utils.UUIDTool; import cn.com.poc.common.utils.UUIDTool;
import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService; import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService;
import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum; import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum;
import cn.com.poc.data_analyze.constants.DataAnalyzeTypeEnum;
import cn.com.poc.equity.aggregate.MemberEquityService; import cn.com.poc.equity.aggregate.MemberEquityService;
import cn.com.poc.equity.aggregate.PointDeductionRulesService; import cn.com.poc.equity.aggregate.PointDeductionRulesService;
import cn.com.poc.equity.constants.ModifyEventEnum; import cn.com.poc.equity.constants.ModifyEventEnum;
...@@ -23,7 +21,6 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum; ...@@ -23,7 +21,6 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool; import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.yict.framemax.core.exception.BusinessException; import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.FileUtil;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
...@@ -33,8 +30,10 @@ import javax.annotation.Resource; ...@@ -33,8 +30,10 @@ import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream; import java.io.BufferedInputStream;
import java.io.File; import java.io.File;
import java.math.BigDecimal; import java.util.ArrayList;
import java.util.*; import java.util.Date;
import java.util.List;
import java.util.UUID;
/** /**
* @author alex.yao * @author alex.yao
...@@ -104,7 +103,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -104,7 +103,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
} }
@Override @Override
public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception { public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception {
BizAgentApplicationApiProfileEntity profileEntity = bizAgentApplicationApiProfileService.getByKeyAndSecret(apiKey, apiSecret); BizAgentApplicationApiProfileEntity profileEntity = bizAgentApplicationApiProfileService.getByKeyAndSecret(apiKey, apiSecret);
if (profileEntity == null) { if (profileEntity == null) {
throw new BusinessException("无效的API Key或Secret"); throw new BusinessException("无效的API Key或Secret");
...@@ -136,6 +135,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -136,6 +135,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
//配置对话function //配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), conversationId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing()); List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), conversationId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
// 保存用户输入记录 // 保存用户输入记录
Long inputTimestamp = System.currentTimeMillis(); Long inputTimestamp = System.currentTimeMillis();
...@@ -151,7 +153,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic ...@@ -151,7 +153,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try { try {
String output = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(), String output = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, stream, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, stream, imageUrls, httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, output); saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, output);
} catch (Exception e) { } catch (Exception e) {
memberEquityService.rollbackPoint(reduceSn); memberEquityService.rollbackPoint(reduceSn);
......
...@@ -110,7 +110,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -110,7 +110,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
@Override @Override
public void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, HttpServletResponse httpServletResponse) throws Exception { public void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, String imageUrl, HttpServletResponse httpServletResponse) throws Exception {
UserBaseEntity userBaseEntity = BlContext.getCurrentUserNotException(); UserBaseEntity userBaseEntity = BlContext.getCurrentUserNotException();
if (userBaseEntity == null) { if (userBaseEntity == null) {
...@@ -152,6 +152,8 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -152,6 +152,8 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//配置对话function //配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogsId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing()); List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogsId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
// 获取图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
//记录输出时间戳 //记录输出时间戳
BizAgentApplicationDialoguesRecordEntity outputRecord = new BizAgentApplicationDialoguesRecordEntity(); BizAgentApplicationDialoguesRecordEntity outputRecord = new BizAgentApplicationDialoguesRecordEntity();
...@@ -179,7 +181,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService { ...@@ -179,7 +181,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//对话 //对话
String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(), String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, true, httpServletResponse); infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, true, imageUrls, httpServletResponse);
//保存对话记录 //保存对话记录
outputRecord.setContent(output); outputRecord.setContent(output);
......
...@@ -44,6 +44,16 @@ public class AgentApplicationDto { ...@@ -44,6 +44,16 @@ public class AgentApplicationDto {
this.fileUrls = fileUrls; this.fileUrls = fileUrls;
} }
private String imageUrl;
public String getImageUrl() {
return imageUrl;
}
public void setImageUrl(String imageUrl) {
this.imageUrl = imageUrl;
}
private String channel; private String channel;
public String getChannel() { public String getChannel() {
......
...@@ -23,11 +23,23 @@ public class CompletionsDto { ...@@ -23,11 +23,23 @@ public class CompletionsDto {
*/ */
private String query; private String query;
/**
* 图片
*/
private String imageUrl;
/** /**
* 是否流式传输 * 是否流式传输
*/ */
private Boolean stream; private Boolean stream;
public String getImageUrl() {
return imageUrl;
}
public void setImageUrl(String imageUrl) {
this.imageUrl = imageUrl;
}
public String getConversationId() { public String getConversationId() {
return conversationId; return conversationId;
......
...@@ -111,7 +111,7 @@ public class AgentApplicationRestImpl implements AgentApplicationRest { ...@@ -111,7 +111,7 @@ public class AgentApplicationRestImpl implements AgentApplicationRest {
Assert.notNull(dto.getAgentId()); Assert.notNull(dto.getAgentId());
Assert.notNull(dto.getDialogsId()); Assert.notNull(dto.getDialogsId());
try { try {
agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), dto.getFileUrls(), dto.getChannel(), httpServletResponse); agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), dto.getFileUrls(), dto.getChannel(), dto.getImageUrl(), httpServletResponse);
} catch (Exception e) { } catch (Exception e) {
httpServletResponse.setContentType("text/event-stream"); httpServletResponse.setContentType("text/event-stream");
PrintWriter writer = httpServletResponse.getWriter(); PrintWriter writer = httpServletResponse.getWriter();
......
...@@ -49,7 +49,7 @@ public class ModelLinkRestImpl implements ModelLinkRest { ...@@ -49,7 +49,7 @@ public class ModelLinkRestImpl implements ModelLinkRest {
if (StringUtils.isNotBlank(dto.getFileId())) { if (StringUtils.isNotBlank(dto.getFileId())) {
fileIds.add(dto.getFileId()); fileIds.add(dto.getFileId());
} }
agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), httpServletResponse); agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), dto.getImageUrl(), httpServletResponse);
} }
@Override @Override
......
...@@ -6,6 +6,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.function.document_understanding. ...@@ -6,6 +6,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.function.document_understanding.
import cn.com.poc.thirdparty.resource.demand.ai.function.html_reader.HtmlReaderFunction; import cn.com.poc.thirdparty.resource.demand.ai.function.html_reader.HtmlReaderFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction; import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction; import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.web_seach.WebSearchFunction;
public enum LargeModelFunctionEnum { public enum LargeModelFunctionEnum {
set_long_memory(SetLongMemoryFunction.class), set_long_memory(SetLongMemoryFunction.class),
...@@ -13,7 +14,7 @@ public enum LargeModelFunctionEnum { ...@@ -13,7 +14,7 @@ public enum LargeModelFunctionEnum {
html_reader(HtmlReaderFunction.class), html_reader(HtmlReaderFunction.class),
document_reader(DocumentReaderFunction.class), document_reader(DocumentReaderFunction.class),
document_understanding(DocumentUnderstandIngFunction.class), document_understanding(DocumentUnderstandIngFunction.class),
web_search(WebSearchFunction.class),
bing_web_search(null), bing_web_search(null),
; ;
......
package cn.com.poc.thirdparty.resource.demand.ai.function.bing_web_search;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import org.springframework.stereotype.Component;
import java.util.List;
/**
* @author alex.yao
* @date 2025/1/14
*/
@Component
public class BingWebSearchFunction extends AbstractLargeModelFunction {
@Override
public String doFunction(String content, String identifier) {
return null;
}
@Override
public String getDesc() {
return null;
}
@Override
public List<String> getLLMConfig() {
return null;
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.web_seach;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.fastjson.JSONObject;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.apache.v2.ApacheHttpTransport;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.gson.GsonFactory;
import com.google.api.services.customsearch.v1.CustomSearchAPI;
import com.google.api.services.customsearch.v1.CustomSearchAPIRequestInitializer;
import com.google.api.services.customsearch.v1.model.Result;
import com.google.api.services.customsearch.v1.model.Search;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
* @author alex.yao
* @date 2025/1/14
*/
@Component
public class WebSearchFunction extends AbstractLargeModelFunction {
private final HttpTransport transport = new ApacheHttpTransport();
private final JsonFactory jsonFactory = new GsonFactory();
private final String DESC = "该方法是通过Bing所有对用户给出的关键词进行网站、网页搜索检索获取网页内容";
private final FunctionLLMConfig functionLLMConfig = new FunctionLLMConfig.FunctionLLMConfigBuilder()
.name("web_search")
.description(DESC)
.parameters(new Parameters("object")
.addProperties("query", new Properties("string", "搜索关键词")))
.build();
@Override
public String doFunction(String content, String identifier) {
if (StringUtils.isBlank(content)) {
return "FAIL";
}
JSONObject jsonObject = JSONObject.parseObject(content);
String query = jsonObject.getString("query");
List<WebSearchFunctionResult> results = new ArrayList<>();
try {
CustomSearchAPI customSearchAPI = new CustomSearchAPI.Builder(transport, jsonFactory, GoogleNetHttpTransport.newTrustedTransport().createRequestFactory().getInitializer())
.setRootUrl("https://google-api.gsstcloud.com")
.setCustomSearchAPIRequestInitializer(new CustomSearchAPIRequestInitializer("AIzaSyCV8PTQ10rG5wo4E004dR3mcGD1RM_PrBw"))
.build();
Search execute = customSearchAPI
.cse().list().setCx("049026ecb26e840ed")
.setKey("AIzaSyCV8PTQ10rG5wo4E004dR3mcGD1RM_PrBw")
.setQ(query)
.setStart(1L)
.setNum(3)
.execute();
List<Result> items = execute.getItems();
for (Result item : items) {
String link = item.getLink();
String htmlContent = DocumentLoad.htmlToMarkdown(link);
if (StringUtils.isNotBlank(htmlContent)) {
htmlContent = htmlContent.replaceAll(StringUtils.SPACE, StringUtils.EMPTY);
}
String title = item.getTitle();
String snippet = item.getSnippet();
WebSearchFunctionResult webSearchResult = new WebSearchFunctionResult();
webSearchResult.setTitle(title);
webSearchResult.setUrl(link);
webSearchResult.setSnippet(snippet);
webSearchResult.setContent(htmlContent);
results.add(webSearchResult);
}
return JsonUtils.serialize(results);
} catch (Exception e) {
return JsonUtils.serialize(results);
}
}
@Override
public String getDesc() {
return DESC;
}
@Override
public List<String> getLLMConfig() {
return ListUtil.toList(JsonUtils.serialize(functionLLMConfig));
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return this.getLLMConfig();
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.web_seach;
/**
* @author alex.yao
* @date 2025/1/14
*/
public class WebSearchFunctionResult {
private String title;
private String url;
private String snippet;
private String content;
public String getTitle() {
return title;
}
public void setTitle(String title) {
this.title = title;
}
public String getUrl() {
return url;
}
public void setUrl(String url) {
this.url = url;
}
public String getSnippet() {
return snippet;
}
public void setSnippet(String snippet) {
this.snippet = snippet;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.thirdparty.resource.demand.ai.function.web_seach.WebSearchFunction;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import java.io.IOException;
import java.security.GeneralSecurityException;
/**
* @author alex.yao
* @date 2025/1/14
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(initializers = SingleContextInitializer.class)
@WebAppConfiguration
public class GoogleSearchFunctionTest {
@Test
public void testSearch() throws IOException, GeneralSecurityException {
WebSearchFunction webSearchFunction = new WebSearchFunction();
System.out.println(webSearchFunction.doFunction("{\"query\":\"什么是区块链\"}", "1"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment