Commit e24d793b authored by alex yao's avatar alex yao

feat:Agent应用插件功能

parent 1366bf51
......@@ -333,6 +333,13 @@
<artifactId>google-api-client</artifactId>
<version>2.4.0</version>
</dependency>
<dependency>
<groupId>com.google.apis</groupId>
<artifactId>google-api-services-customsearch</artifactId>
<version>v1-rev20240821-2.0.0</version>
</dependency>
</dependencies>
......
......@@ -45,12 +45,13 @@ public interface AgentApplicationInfoService {
* @param temperature 模型参数temperature
* @param messages 对话消息
* @param tools 插件配置
* @param fileUrls 文件URL
* @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输
*/
String callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, HttpServletResponse httpServletResponse) throws Exception;
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception;
/**
* 应用下架
......
......@@ -23,6 +23,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.ToolFunction;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageRequest;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsText2ImageResult;
......@@ -37,6 +38,7 @@ import cn.com.yict.framemax.core.context.Context;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.com.yict.framemax.data.model.PagingInfo;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.ObjectUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
......@@ -96,6 +98,9 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Resource
private BizAgentApplicationDialoguesRecordService bizAgentApplicationDialoguesRecordService;
@Resource
private BizAgentApplicationPluginService bizAgentApplicationPluginService;
@Override
public BizAgentApplicationInfoEntity saveOrUpdate(BizAgentApplicationInfoEntity entity) throws Exception {
......@@ -163,7 +168,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Override
public String callAgentApplication(String agentId, String dialogueId, String largeModel,
String agentSystem, Integer[] kdIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, HttpServletResponse httpServletResponse) throws Exception {
List<Message> messages, List<Tool> tools, List<String> fileUrls, boolean stream, List<String> imageUrls, HttpServletResponse httpServletResponse) throws Exception {
logger.info("Call Agent Application, agentId:{}, dialogueId:{},largeModel:{},agentSystem:{},kdIds:{},communicationTurn:{},topP:{},temperature:{},messages:{}, tools:{}"
, agentId, dialogueId, largeModel, agentSystem, kdIds, communicationTurn, topP, temperature, messages, tools);
......@@ -172,13 +177,13 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
Tool[] toolArray = tools.toArray(new Tool[0]);
FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls);
FunctionResult functionResult = functionCall(dialogueId, messages, toolArray, agentId, fileUrls, imageUrls);
String promptTemplate = buildDialogsPrompt(functionResult, messages, agentSystem, kdIds, toolArray, dialogueId, agentId);
Message[] messageArray = buildMessages(messages, communicationTurn, promptTemplate);
return llmExecutorAndOutput(topP, stream, model, messageArray, httpServletResponse);
return llmExecutorAndOutput(topP, stream, model, messageArray, functionResult, httpServletResponse);
}
......@@ -207,7 +212,7 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
largeModelResponse.setUser("POC-CREATE-AGENT-SYSTEM");
BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse);
textOutputStream(httpServletResponse, bufferedReader);
textOutputStream(httpServletResponse, bufferedReader, null);
}
@Override
......@@ -500,16 +505,21 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param stream 是否流式输出
* @param model 模型
* @param messageArray 消息
* @param functionResult 函数结果
* @param httpServletResponse 响应
* @return 输出结果
* @throws Exception
*/
private String llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, HttpServletResponse httpServletResponse) throws Exception {
private String llmExecutorAndOutput(Float topP, boolean stream, String model, Message[] messageArray, FunctionResult functionResult, HttpServletResponse httpServletResponse) throws Exception {
if (stream) {
BufferedReader bufferedReader = invokeLLMStream(model, messageArray, topP);
return textOutputStream(httpServletResponse, bufferedReader);
return textOutputStream(httpServletResponse, bufferedReader, functionResult);
} else {
LargeModelDemandResult largeModelDemandResult = invokeLLM(model, messageArray, topP);
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
largeModelDemandResult.setFunction(toolFunction);
}
return textOutput(httpServletResponse, largeModelDemandResult);
}
}
......@@ -735,12 +745,20 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param bufferedReader
* @throws IOException
*/
private String textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader) throws
private String textOutputStream(HttpServletResponse httpServletResponse, BufferedReader bufferedReader, FunctionResult functionResult) throws
IOException {
String res = "";
httpServletResponse.setContentType(TEXT_EVENT_STREAM_CHARSET_UTF_8);
PrintWriter writer = httpServletResponse.getWriter();
StringBuilder output = new StringBuilder();
if (ObjectUtil.isNotNull(functionResult) && StringUtils.isNotBlank(functionResult.getFunctionName())) {
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
ToolFunction toolFunction = functionResultConvertToolFunction(functionResult);
result.setFunction(toolFunction);
writer.write(EVENT_STREAM_PREFIX + JsonUtils.serialize(result) + "\n\n");
writer.flush();
}
while ((res = bufferedReader.readLine()) != null) {
if (StringUtils.isBlank(res)) {
continue;
......@@ -767,6 +785,27 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
return output.toString();
}
private ToolFunction functionResultConvertToolFunction(FunctionResult functionResult) {
ToolFunction toolFunction = new ToolFunction();
BizAgentApplicationPluginEntity bizAgentApplicationPluginEntity = bizAgentApplicationPluginService.getInfoById(functionResult.getFunctionName());
if (bizAgentApplicationPluginEntity != null && !bizAgentApplicationPluginEntity.getClassification().equals("system")) {
Locale currentLocale = Context.get().getMessageSource().getCurrentLocale();
String languageTag = currentLocale.toLanguageTag();
switch (languageTag) {
case "zh-CN":
toolFunction.setName(bizAgentApplicationPluginEntity.getZhCnTitle());
break;
case "en":
toolFunction.setName(bizAgentApplicationPluginEntity.getZhTwTitle());
break;
case "zh-TW":
toolFunction.setName(bizAgentApplicationPluginEntity.getEnTitle());
break;
}
}
return toolFunction;
}
/**
* 构建消息【大模型参数】
*
......@@ -826,8 +865,10 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
* @param messages
* @param tools
* @param agentId
* @param fileUrls
* @param imageUrls
*/
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls) {
private FunctionResult functionCall(String dialogueId, List<Message> messages, Tool[] tools, String agentId, List<String> fileUrls, List<String> imageUrls) {
FunctionResult result = new FunctionResult();
if (ArrayUtils.isEmpty(tools)) {
return result;
......@@ -839,9 +880,12 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
} else {
query = content.toString();
}
query = "用户输入:" + query + "\n";
if (CollectionUtils.isNotEmpty(fileUrls)) {
query = JsonUtils.serialize(fileUrls) + query;
query = query + "用户上传文件地址:" + JsonUtils.serialize(fileUrls) + "\n";
}
if (CollectionUtils.isNotEmpty(imageUrls)) {
query = query + "用户上传图片地址:" + JsonUtils.serialize(imageUrls);
}
FunctionCallResult functionCallResult = llmService.functionCall(query, tools);
......
......@@ -14,7 +14,6 @@ import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService;
import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum;
import cn.com.poc.data_analyze.constants.DataAnalyzeTypeEnum;
import cn.com.poc.equity.aggregate.MemberEquityService;
import cn.com.poc.equity.aggregate.PointDeductionRulesService;
import cn.com.poc.equity.constants.ModifyEventEnum;
......@@ -249,6 +248,9 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogueId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.getMessageImageUrl(dto.getMessages());
//对话大模型配置
String model = StringUtils.isNotBlank(dto.getModelNickName()) ? dto.getModelNickName() : infoEntity.getLargeModel();
Float topP = dto.getTopP() == null ? infoEntity.getTopP() : dto.getTopP();
......@@ -264,7 +266,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务
agentApplicationInfoService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), infoEntity.getCommunicationTurn(), topP,
temperature, dto.getMessages(), tools, dto.getFileUrls(), true, httpServletResponse);
temperature, dto.getMessages(), tools, dto.getFileUrls(), true, imageUrls, httpServletResponse);
//数据采集
if (StringUtils.isBlank(dto.getChannel())) {
dto.setChannel(DataAnalyzeChannelEnum.preview.getChannel());
......
......@@ -4,10 +4,13 @@ import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.MultiContent;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriter;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.core.util.ObjectUtil;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
......@@ -16,6 +19,7 @@ import org.apache.commons.lang3.StringUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
......@@ -119,4 +123,36 @@ public class AgentApplicationTools {
}
}
}
/**
* 获取Message对话的图片地址
*
* @param messages 对话
* @return 返回图片地址列表,若无则返回null
*/
public static List<String> getMessageImageUrl(List<Message> messages) {
List<String> imageUrls = null;
Message mess = messages.get(messages.size() - 1);
if (!(mess.getContent() instanceof String) && ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0)).containsKey("image_url")) {
LinkedHashMap map = ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0));
MultiContent multiContent = JSONObject.parseObject(JsonUtils.serialize(map), MultiContent.class);
if (ObjectUtil.isNotNull(multiContent.getImageUrl()) && StringUtils.isNotBlank(multiContent.getImageUrl().getUrl())) {
imageUrls = new ArrayList<>();
imageUrls.add(multiContent.getImageUrl().getUrl());
}
}
return imageUrls;
}
/**
* 构建ImageUrls
*/
public static List<String> buildImageUrls(String imageUrl) {
List<String> imageUrls = null;
if (StringUtils.isNotBlank(imageUrl)) {
imageUrls = new ArrayList<>();
imageUrls.add(imageUrl);
}
return imageUrls;
}
}
......@@ -51,7 +51,7 @@ public class DocumentLoad {
String htmlStr = sb.toString();
return converter.convert(htmlStr);
} catch (IOException e) {
throw new I18nMessageException(e.getMessage());
return "";
}
}
......
......@@ -32,9 +32,10 @@ public interface AgentApplicationApiService {
* @param fileIds 文件ID列表
* @param query 消息
* @param stream 是否流式输出
* @param imageUrl 图片URL
* @param httpServletResponse
*/
void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception;
void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception;
/**
* 上传文件
......
......@@ -19,8 +19,9 @@ public interface AgentApplicationService {
* @param input 用户输入
* @param fileUrls 文件URL
* @param channel 渠道
* @param imageUrl 图片URL
*/
void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, HttpServletResponse httpServletResponse) throws Exception;
void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, String imageUrl, HttpServletResponse httpServletResponse) throws Exception;
/**
* 追问AI生成
......
......@@ -8,11 +8,9 @@ import cn.com.poc.agent_application.utils.AgentApplicationTools;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.service.BosConfigService;
import cn.com.poc.common.utils.DateUtils;
import cn.com.poc.common.utils.FileUtils;
import cn.com.poc.common.utils.UUIDTool;
import cn.com.poc.data_analyze.aggregate.DataAnalyzeReportService;
import cn.com.poc.data_analyze.constants.DataAnalyzeChannelEnum;
import cn.com.poc.data_analyze.constants.DataAnalyzeTypeEnum;
import cn.com.poc.equity.aggregate.MemberEquityService;
import cn.com.poc.equity.aggregate.PointDeductionRulesService;
import cn.com.poc.equity.constants.ModifyEventEnum;
......@@ -23,7 +21,6 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.core.io.FileUtil;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.stereotype.Component;
......@@ -33,8 +30,10 @@ import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream;
import java.io.File;
import java.math.BigDecimal;
import java.util.*;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.UUID;
/**
* @author alex.yao
......@@ -104,7 +103,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
}
@Override
public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception {
public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception {
BizAgentApplicationApiProfileEntity profileEntity = bizAgentApplicationApiProfileService.getByKeyAndSecret(apiKey, apiSecret);
if (profileEntity == null) {
throw new BusinessException("无效的API Key或Secret");
......@@ -136,6 +135,9 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
//配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), conversationId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
// 保存用户输入记录
Long inputTimestamp = System.currentTimeMillis();
......@@ -151,7 +153,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try {
String output = agentApplicationInfoService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, stream, httpServletResponse);
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, stream, imageUrls, httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, output);
} catch (Exception e) {
memberEquityService.rollbackPoint(reduceSn);
......
......@@ -110,7 +110,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
@Override
public void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, HttpServletResponse httpServletResponse) throws Exception {
public void callAgentApplication(String agentId, String dialogsId, String input, List<String> fileUrls, String channel, String imageUrl, HttpServletResponse httpServletResponse) throws Exception {
UserBaseEntity userBaseEntity = BlContext.getCurrentUserNotException();
if (userBaseEntity == null) {
......@@ -152,6 +152,8 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//配置对话function
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogsId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
// 获取图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
//记录输出时间戳
BizAgentApplicationDialoguesRecordEntity outputRecord = new BizAgentApplicationDialoguesRecordEntity();
......@@ -179,7 +181,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
//对话
String output = agentApplicationInfoService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, true, httpServletResponse);
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, fileUrls, true, imageUrls, httpServletResponse);
//保存对话记录
outputRecord.setContent(output);
......
......@@ -44,6 +44,16 @@ public class AgentApplicationDto {
this.fileUrls = fileUrls;
}
private String imageUrl;
public String getImageUrl() {
return imageUrl;
}
public void setImageUrl(String imageUrl) {
this.imageUrl = imageUrl;
}
private String channel;
public String getChannel() {
......
......@@ -23,11 +23,23 @@ public class CompletionsDto {
*/
private String query;
/**
* 图片
*/
private String imageUrl;
/**
* 是否流式传输
*/
private Boolean stream;
public String getImageUrl() {
return imageUrl;
}
public void setImageUrl(String imageUrl) {
this.imageUrl = imageUrl;
}
public String getConversationId() {
return conversationId;
......
......@@ -111,7 +111,7 @@ public class AgentApplicationRestImpl implements AgentApplicationRest {
Assert.notNull(dto.getAgentId());
Assert.notNull(dto.getDialogsId());
try {
agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), dto.getFileUrls(), dto.getChannel(), httpServletResponse);
agentApplicationService.callAgentApplication(dto.getAgentId(), dto.getDialogsId(), dto.getInput(), dto.getFileUrls(), dto.getChannel(), dto.getImageUrl(), httpServletResponse);
} catch (Exception e) {
httpServletResponse.setContentType("text/event-stream");
PrintWriter writer = httpServletResponse.getWriter();
......
......@@ -49,7 +49,7 @@ public class ModelLinkRestImpl implements ModelLinkRest {
if (StringUtils.isNotBlank(dto.getFileId())) {
fileIds.add(dto.getFileId());
}
agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), httpServletResponse);
agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), dto.getImageUrl(), httpServletResponse);
}
@Override
......
......@@ -6,6 +6,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.function.document_understanding.
import cn.com.poc.thirdparty.resource.demand.ai.function.html_reader.HtmlReaderFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.web_seach.WebSearchFunction;
public enum LargeModelFunctionEnum {
set_long_memory(SetLongMemoryFunction.class),
......@@ -13,7 +14,7 @@ public enum LargeModelFunctionEnum {
html_reader(HtmlReaderFunction.class),
document_reader(DocumentReaderFunction.class),
document_understanding(DocumentUnderstandIngFunction.class),
web_search(WebSearchFunction.class),
bing_web_search(null),
;
......
package cn.com.poc.thirdparty.resource.demand.ai.function.bing_web_search;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import org.springframework.stereotype.Component;
import java.util.List;
/**
* @author alex.yao
* @date 2025/1/14
*/
@Component
public class BingWebSearchFunction extends AbstractLargeModelFunction {
@Override
public String doFunction(String content, String identifier) {
return null;
}
@Override
public String getDesc() {
return null;
}
@Override
public List<String> getLLMConfig() {
return null;
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.web_seach;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.utils.DocumentLoad;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.fastjson.JSONObject;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.apache.v2.ApacheHttpTransport;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.gson.GsonFactory;
import com.google.api.services.customsearch.v1.CustomSearchAPI;
import com.google.api.services.customsearch.v1.CustomSearchAPIRequestInitializer;
import com.google.api.services.customsearch.v1.model.Result;
import com.google.api.services.customsearch.v1.model.Search;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
* @author alex.yao
* @date 2025/1/14
*/
@Component
public class WebSearchFunction extends AbstractLargeModelFunction {
private final HttpTransport transport = new ApacheHttpTransport();
private final JsonFactory jsonFactory = new GsonFactory();
private final String DESC = "该方法是通过Bing所有对用户给出的关键词进行网站、网页搜索检索获取网页内容";
private final FunctionLLMConfig functionLLMConfig = new FunctionLLMConfig.FunctionLLMConfigBuilder()
.name("web_search")
.description(DESC)
.parameters(new Parameters("object")
.addProperties("query", new Properties("string", "搜索关键词")))
.build();
@Override
public String doFunction(String content, String identifier) {
if (StringUtils.isBlank(content)) {
return "FAIL";
}
JSONObject jsonObject = JSONObject.parseObject(content);
String query = jsonObject.getString("query");
List<WebSearchFunctionResult> results = new ArrayList<>();
try {
CustomSearchAPI customSearchAPI = new CustomSearchAPI.Builder(transport, jsonFactory, GoogleNetHttpTransport.newTrustedTransport().createRequestFactory().getInitializer())
.setRootUrl("https://google-api.gsstcloud.com")
.setCustomSearchAPIRequestInitializer(new CustomSearchAPIRequestInitializer("AIzaSyCV8PTQ10rG5wo4E004dR3mcGD1RM_PrBw"))
.build();
Search execute = customSearchAPI
.cse().list().setCx("049026ecb26e840ed")
.setKey("AIzaSyCV8PTQ10rG5wo4E004dR3mcGD1RM_PrBw")
.setQ(query)
.setStart(1L)
.setNum(3)
.execute();
List<Result> items = execute.getItems();
for (Result item : items) {
String link = item.getLink();
String htmlContent = DocumentLoad.htmlToMarkdown(link);
if (StringUtils.isNotBlank(htmlContent)) {
htmlContent = htmlContent.replaceAll(StringUtils.SPACE, StringUtils.EMPTY);
}
String title = item.getTitle();
String snippet = item.getSnippet();
WebSearchFunctionResult webSearchResult = new WebSearchFunctionResult();
webSearchResult.setTitle(title);
webSearchResult.setUrl(link);
webSearchResult.setSnippet(snippet);
webSearchResult.setContent(htmlContent);
results.add(webSearchResult);
}
return JsonUtils.serialize(results);
} catch (Exception e) {
return JsonUtils.serialize(results);
}
}
@Override
public String getDesc() {
return DESC;
}
@Override
public List<String> getLLMConfig() {
return ListUtil.toList(JsonUtils.serialize(functionLLMConfig));
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return this.getLLMConfig();
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.web_seach;
/**
* @author alex.yao
* @date 2025/1/14
*/
public class WebSearchFunctionResult {
private String title;
private String url;
private String snippet;
private String content;
public String getTitle() {
return title;
}
public void setTitle(String title) {
this.title = title;
}
public String getUrl() {
return url;
}
public void setUrl(String url) {
this.url = url;
}
public String getSnippet() {
return snippet;
}
public void setSnippet(String snippet) {
this.snippet = snippet;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.thirdparty.resource.demand.ai.function.web_seach.WebSearchFunction;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import java.io.IOException;
import java.security.GeneralSecurityException;
/**
* @author alex.yao
* @date 2025/1/14
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(initializers = SingleContextInitializer.class)
@WebAppConfiguration
public class GoogleSearchFunctionTest {
@Test
public void testSearch() throws IOException, GeneralSecurityException {
WebSearchFunction webSearchFunction = new WebSearchFunction();
System.out.println(webSearchFunction.doFunction("{\"query\":\"什么是区块链\"}", "1"));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment