Commit e112325f authored by alex yao's avatar alex yao

Merge branch 'task/textin-function' into 'release'

feat: 1.新增对话插件 PDF2MD 。 2.取消图片上传,统一使用文件上传。 3.文档理解插件从system改为utils, 应用需添加文档理解插件

See merge request !11
parents 2237f5aa 2345193e
......@@ -4,7 +4,6 @@ import cn.com.poc.agent_application.entity.AgentResultEntity;
import cn.com.poc.agent_application.entity.BizAgentApplicationInfoEntity;
import cn.com.poc.agent_application.entity.CreateAgentTitleAndDescEntity;
import cn.com.poc.agent_application.entity.KnowledgeSuperclassProblemConfig;
import cn.com.poc.agent_application.request.ArticleRewriteRequest;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeSearchTypeEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Message;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
......@@ -52,8 +51,6 @@ public interface AgentApplicationService {
* @param messages 对话消息
* @param tools 插件配置
* @param functionCallResult 插件回调结果
* @param fileUrls 文件URLs
* @param imageUrls 图片URLs
* @param stream 是否流式传输
* @param score 知识库参数score
* @param topK 知识库参数topK
......@@ -62,7 +59,7 @@ public interface AgentApplicationService {
*/
AgentResultEntity callAgentApplication(String agentId, String identifier, String largeModel, String agentSystem,
Integer[] knowledgeIds, Integer[] databaseIds, Integer communicationTurn, Float topP, Float temperature,
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, boolean stream,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, KnowledgeSuperclassProblemConfig superclassProblemConfig, HttpServletResponse httpServletResponse) throws Exception;
......
package cn.com.poc.agent_application.aggregate.impl;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dbchain.DBChainResult;
import com.google.common.collect.Lists;
import cn.com.poc.agent_application.aggregate.AgentApplicationService;
import cn.com.poc.agent_application.constant.*;
......@@ -202,7 +201,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
@Override
public AgentResultEntity callAgentApplication(String agentId, String dialogueId, String largeModel, String agentSystem, Integer[] kdIds, Integer[] databaseIds,
Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, List<String> fileUrls, boolean stream, List<String> imageUrls,
Integer communicationTurn, Float topP, Float temperature, List<Message> messages, List<Tool> tools, FunctionCallResult functionCallResult, boolean stream,
Double score, Integer topK, KnowledgeSearchTypeEnum knowledgeSearchType, KnowledgeSuperclassProblemConfig superclassProblemConfig, HttpServletResponse httpServletResponse) throws Exception {
String model = modelConvert(largeModel);
......@@ -1013,7 +1012,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
private ToolFunction functionResultConvertToolFunction(FunctionResult functionResult) {
ToolFunction toolFunction = new ToolFunction();
BizAgentApplicationPluginEntity bizAgentApplicationPluginEntity = bizAgentApplicationPluginService.getInfoById(functionResult.getFunctionName());
BizAgentApplicationPluginEntity bizAgentApplicationPluginEntity = bizAgentApplicationPluginService.getInfoByPluginId(functionResult.getFunctionName());
if (bizAgentApplicationPluginEntity != null && !bizAgentApplicationPluginEntity.getClassification().equals("system")) {
String lang = BlContext.getCurrentLocaleLanguageToLowerCase();
switch (lang) {
......@@ -1027,6 +1026,7 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
toolFunction.setName(bizAgentApplicationPluginEntity.getEnTitle());
break;
}
toolFunction.setDisplayFormat(bizAgentApplicationPluginEntity.getDisplayFormat());
}
toolFunction.setArguments(functionResult.getFunctionArg());
toolFunction.setResult(JsonUtils.serialize(functionResult.getFunctionResult()));
......
......@@ -21,6 +21,7 @@ public class BizAgentApplicationPluginConvert {
entity.setPoints(model.getPoints());
entity.setClassification(model.getClassification());
entity.setPluginId(model.getPluginId());
entity.setDisplayFormat(model.getDisplayFormat());
entity.setParentZhCnName(model.getParentZhCnName());
entity.setParentZhTwName(model.getParentZhTwName());
entity.setParentEnName(model.getParentEnName());
......@@ -44,6 +45,7 @@ public class BizAgentApplicationPluginConvert {
model.setId(entity.getId());
model.setIcon(entity.getIcon());
model.setPoints(entity.getPoints());
model.setDisplayFormat(entity.getDisplayFormat());
model.setClassification(entity.getClassification());
model.setPluginId(entity.getPluginId());
model.setParentZhCnName(entity.getParentZhCnName());
......
......@@ -40,7 +40,22 @@ public class BizAgentApplicationPluginDto {
public void setPluginId(java.lang.String pluginId){
this.pluginId = pluginId;
}
/** zh_cn_title
/**
* displayFormat
* 显示格式
*/
private java.lang.String displayFormat;
public String getDisplayFormat() {
return displayFormat;
}
public void setDisplayFormat(String displayFormat) {
this.displayFormat = displayFormat;
}
/** zh_cn_title
*插件标题(简体)
*/
private java.lang.String zhCnTitle;
......
......@@ -44,6 +44,20 @@ public class BizAgentApplicationPluginEntity {
this.classification = classification;
}
/**
* displayFormat
* 显示格式
*/
private java.lang.String displayFormat;
public String getDisplayFormat() {
return displayFormat;
}
public void setDisplayFormat(String displayFormat) {
this.displayFormat = displayFormat;
}
/**
* icon
*/
......
......@@ -75,6 +75,22 @@ public class BizAgentApplicationPluginModel extends BaseModelClass implements Se
}
/**
* displayFormat
* 显示格式
*/
private java.lang.String displayFormat;
@Column(name = "display_format", length = 10)
public String getDisplayFormat() {
return displayFormat;
}
public void setDisplayFormat(String displayFormat) {
this.displayFormat = displayFormat;
super.addValidField("displayFormat");
}
/**
* icon
*/
......
......@@ -261,7 +261,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogueId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.getMessageImageUrl(dto.getMessages());
// List<String> imageUrls = AgentApplicationTools.getMessageImageUrl(dto.getMessages());
//对话大模型配置
String model = StringUtils.isNotBlank(dto.getModelNickName()) ? dto.getModelNickName() : infoEntity.getLargeModel();
......@@ -272,7 +272,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
// 判断是否调用function
//计算扣分数
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools, fileUrls, imageUrls);
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(dto.getMessages(), tools, fileUrls);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(model, communicationTurn, checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId);
......@@ -281,7 +281,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
//调用应用服务
agentApplicationService.callAgentApplication(agentId, dialogueId, model,
agentSystem, kdIds.toArray(new Integer[0]), databaseIds, communicationTurn, topP,
temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(), dto.getFileUrls(), true, imageUrls,
temperature, dto.getMessages(), tools, checkPluginUseEntity.getFunctionCallResult(),true,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
superclassProblemConfig, httpServletResponse);
//数据采集
......
......@@ -11,14 +11,11 @@ import cn.com.poc.agent_application.service.BizAgentApplicationPluginService;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.StringUtils;
import cn.com.yict.framemax.core.context.Context;
import cn.com.yict.framemax.data.model.PagingInfo;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import static cn.com.poc.common.constant.XLangConstant.*;
......@@ -33,7 +30,7 @@ public class BizAgentApplicationPluginRestImpl implements BizAgentApplicationPlu
public AgentApplicationPluginDto getByPluginId(String pluginId) {
BizAgentApplicationPluginEntity bizAgentApplicationPluginEntity = new BizAgentApplicationPluginEntity();
bizAgentApplicationPluginEntity.setPluginId(pluginId);
BizAgentApplicationPluginEntity entity = bizAgentApplicationPluginService.getInfoById(pluginId);
BizAgentApplicationPluginEntity entity = bizAgentApplicationPluginService.getInfoByPluginId(pluginId);
return BizAgentApplicationPluginConvert.entityToDto(entity, BlContext.getCurrentLocaleLanguageToLowerCase());
}
......
......@@ -20,7 +20,7 @@ public interface BizAgentApplicationPluginService extends BaseService {
void deletedById(java.lang.Long id) throws Exception;
BizAgentApplicationPluginEntity getInfoById(java.lang.String pluginId);
BizAgentApplicationPluginEntity getInfoByPluginId(java.lang.String pluginId);
List<AgentPluginQueryItem> agentPluginQuery(AgentPluginQueryCondition condition, PagingInfo pagingInfo);
......
......@@ -127,7 +127,7 @@ public class BizAgentApplicationPluginServiceImpl extends BaseServiceImpl
}
@Override
public BizAgentApplicationPluginEntity getInfoById(String pluginId) {
public BizAgentApplicationPluginEntity getInfoByPluginId(String pluginId) {
BizAgentApplicationPluginModel model = new BizAgentApplicationPluginModel();
model.setPluginId(pluginId);
model.setIsDeleted(CommonConstant.IsDeleted.N);
......
......@@ -74,10 +74,10 @@ public class AgentApplicationTools {
//开启文档解析-文档理解
if (CommonConstant.YOrN.Y.equals(isDocumentParsing)) {
String functionName = LargeModelFunctionEnum.document_understanding.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getLLMConfig().get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
// String functionName = LargeModelFunctionEnum.document_understanding.name();
// String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getLLMConfig().get(0);
// Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
// tools.add(tool);
}
//初始化插件函数
......@@ -130,36 +130,36 @@ public class AgentApplicationTools {
* @param messages 对话
* @return 返回图片地址列表,若无则返回null
*/
public static List<String> getMessageImageUrl(List<Message> messages) {
List<String> imageUrls = null;
Message mess = messages.get(messages.size() - 1);
if (!(mess.getContent() instanceof String) && ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0)).containsKey("image_url")) {
LinkedHashMap map = ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0));
MultiContent multiContent = JSONObject.parseObject(JsonUtils.serialize(map), MultiContent.class);
if (ObjectUtil.isNotNull(multiContent.getImageUrl()) && StringUtils.isNotBlank(multiContent.getImageUrl().getUrl())) {
imageUrls = new ArrayList<>();
imageUrls.add(multiContent.getImageUrl().getUrl());
}
}
return imageUrls;
}
// public static List<String> getMessageImageUrl(List<Message> messages) {
// List<String> imageUrls = null;
// Message mess = messages.get(messages.size() - 1);
// if (!(mess.getContent() instanceof String) && ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0)).containsKey("image_url")) {
// LinkedHashMap map = ((LinkedHashMap) ((ArrayList) mess.getContent()).get(0));
// MultiContent multiContent = JSONObject.parseObject(JsonUtils.serialize(map), MultiContent.class);
// if (ObjectUtil.isNotNull(multiContent.getImageUrl()) && StringUtils.isNotBlank(multiContent.getImageUrl().getUrl())) {
// imageUrls = new ArrayList<>();
// imageUrls.add(multiContent.getImageUrl().getUrl());
// }
// }
// return imageUrls;
// }
/**
* 构建ImageUrls
*/
public static List<String> buildImageUrls(String imageUrl) {
List<String> imageUrls = null;
if (StringUtils.isNotBlank(imageUrl)) {
imageUrls = new ArrayList<>();
imageUrls.add(imageUrl);
}
return imageUrls;
}
// public static List<String> buildImageUrls(String imageUrl) {
// List<String> imageUrls = null;
// if (StringUtils.isNotBlank(imageUrl)) {
// imageUrls = new ArrayList<>();
// imageUrls.add(imageUrl);
// }
// return imageUrls;
// }
/**
* 判断将会调用的插件-用于扣减积分
*/
public static CheckPluginUseEntity checkPluginUse(List<Message> messages, List<Tool> tools, List<String> fileUrls, List<String> imageUrls) {
public static CheckPluginUseEntity checkPluginUse(List<Message> messages, List<Tool> tools, List<String> fileUrls) {
CheckPluginUseEntity checkPluginUseEntity = new CheckPluginUseEntity();
if (CollectionUtils.isEmpty(messages) || CollectionUtils.isEmpty(tools)) {
return checkPluginUseEntity;
......@@ -174,10 +174,7 @@ public class AgentApplicationTools {
}
query = "用户输入:" + query + "\n";
if (CollectionUtils.isNotEmpty(fileUrls)) {
query = query + "用户上传文件地址:" + JsonUtils.serialize(fileUrls) + "\n";
}
if (CollectionUtils.isNotEmpty(imageUrls)) {
query = query + "用户上传图片地址:" + JsonUtils.serialize(imageUrls);
query = query + "用户上传文件地址:" + JsonUtils.serialize(fileUrls) + "\n" + "文件格式:" + fileUrls.get(0).substring(fileUrls.get(0).lastIndexOf(".")) + "\n";
}
List<Tool> deductionTools = new ArrayList<>();
......
......@@ -32,10 +32,9 @@ public interface AgentApplicationApiService {
* @param fileIds 文件ID列表
* @param query 消息
* @param stream 是否流式输出
* @param imageUrl 图片URL
* @param httpServletResponse
*/
void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception;
void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception;
/**
* 上传文件
......
......@@ -111,7 +111,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
}
@Override
public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, String imageUrl, HttpServletResponse httpServletResponse) throws Exception {
public void completions(String apiKey, String apiSecret, String conversationId, List<String> fileIds, String query, boolean stream, HttpServletResponse httpServletResponse) throws Exception {
BizAgentApplicationApiProfileEntity profileEntity = bizAgentApplicationApiProfileService.getByKeyAndSecret(apiKey, apiSecret);
if (profileEntity == null) {
throw new BusinessException("无效的API Key或Secret");
......@@ -153,14 +153,14 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), conversationId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
//获取对话图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
// List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);/**/
// 保存用户输入记录
Long inputTimestamp = System.currentTimeMillis();
//计算扣分数
// 判断是否调用function
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools, fileUrls, imageUrls);
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools, fileUrls);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId);
......@@ -172,7 +172,7 @@ public class AgentApplicationApiServiceImpl implements AgentApplicationApiServic
try {
AgentResultEntity agentResultEntity = agentApplicationService.callAgentApplication(agentId, conversationId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), databaseIds, infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, stream, imageUrls,
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), stream,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
superclassProblemConfig, httpServletResponse);
saveRecord(conversationId, query, agentId, profileEntity, inputTimestamp, infoEntity, agentResultEntity.getMessage());
......
......@@ -160,7 +160,7 @@ public class AgentApplicationExposeServiceImpl implements AgentApplicationExpose
List<Tool> tools = AgentApplicationTools.buildFunctionConfig(infoEntity.getVariableStructure(), infoEntity.getIsLongMemory(), dialogsId, agentId, infoEntity.getUnitIds(), infoEntity.getIsDocumentParsing());
// 获取图片
List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
// List<String> imageUrls = AgentApplicationTools.buildImageUrls(imageUrl);
//记录输出时间戳
BizAgentApplicationDialoguesRecordEntity outputRecord = new BizAgentApplicationDialoguesRecordEntity();
......@@ -172,7 +172,7 @@ public class AgentApplicationExposeServiceImpl implements AgentApplicationExpose
//计算扣分数
// 判断是否调用function
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools, fileUrls, imageUrls);
CheckPluginUseEntity checkPluginUseEntity = AgentApplicationTools.checkPluginUse(messages, tools, fileUrls);
Long pointDeductionNum = pointDeductionRulesService.calculatePointDeductionNum(infoEntity.getLargeModel(), infoEntity.getCommunicationTurn(), checkPluginUseEntity.getDeductionTools());
AgentUseModifyEventInfo agentUseModifyEventInfo = new AgentUseModifyEventInfo();
agentUseModifyEventInfo.setAgentId(agentId);
......@@ -190,7 +190,7 @@ public class AgentApplicationExposeServiceImpl implements AgentApplicationExpose
//对话
AgentResultEntity agentResultEntity = agentApplicationService.callAgentApplication(agentId, dialogsId, infoEntity.getLargeModel(),
infoEntity.getAgentSystem(), kdIdList.toArray(new Integer[0]), databaseIds, infoEntity.getCommunicationTurn(),
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), fileUrls, true, imageUrls,
infoEntity.getTopP(), infoEntity.getTemperature(), messages, tools, checkPluginUseEntity.getFunctionCallResult(), true,
infoEntity.getKnowledgeSimilarity(), infoEntity.getKnowledgeNResult(), KnowledgeSearchTypeEnum.valueOf(infoEntity.getKnowledgeSearchType()),
knowledgeSuperclassProblemConfig, httpServletResponse);
......
......@@ -23,23 +23,12 @@ public class CompletionsDto {
*/
private String query;
/**
* 图片
*/
private String imageUrl;
/**
* 是否流式传输
*/
private Boolean stream;
public String getImageUrl() {
return imageUrl;
}
public void setImageUrl(String imageUrl) {
this.imageUrl = imageUrl;
}
public String getConversationId() {
return conversationId;
......
......@@ -45,7 +45,7 @@ public class ModelLinkRestImpl implements ModelLinkRest {
if (StringUtils.isNotBlank(dto.getFileId())) {
fileIds.add(dto.getFileId());
}
agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), dto.getImageUrl(), httpServletResponse);
agentApplicationApiService.completions(apiKey, apiSecret, dto.getConversationId(), fileIds, dto.getQuery(), dto.getStream(), httpServletResponse);
}
@Override
......
......@@ -7,6 +7,8 @@ public class ToolFunction {
private String result;
private String displayFormat;
public String getResult() {
return result;
}
......@@ -30,4 +32,12 @@ public class ToolFunction {
public void setArguments(String arguments) {
this.arguments = arguments;
}
public String getDisplayFormat() {
return displayFormat;
}
public void setDisplayFormat(String displayFormat) {
this.displayFormat = displayFormat;
}
}
......@@ -9,6 +9,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.function.image_ocr.ImageOCRFunct
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.notification_reminder.NotificationReminderFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.PdfToMDFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.top_search.DouyinTopSearchFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.top_search.ToutiaoTopSearchFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.top_search.WeiboTopSearchFunction;
......@@ -34,6 +35,8 @@ public enum LargeModelFunctionEnum {
bing_web_search(null),
pdf_to_md(PdfToMDFunction.class),
;
private Class<? extends AbstractLargeModelFunction> function;
......
package cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractFunctionResult;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.FunctionLLMConfig;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Parameters;
import cn.com.poc.thirdparty.resource.demand.ai.function.entity.Properties;
import cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.api.OCRClient;
import cn.hutool.core.collection.ListUtil;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
/**
* @author alex.yao
* @date 2025/5/7
*/
@Component
public class PdfToMDFunction extends AbstractLargeModelFunction {
private Logger logger = LoggerFactory.getLogger(PdfToMDFunction.class);
private final String DESC = "该方法是通过OCR获取PDF或者图片的表格内容提取并转为Markdown格式。";
private final FunctionLLMConfig functionLLMConfig = new FunctionLLMConfig.FunctionLLMConfigBuilder()
.name("pdf_to_md")
.description(DESC)
.parameters(new Parameters("object")
.addProperties("file_url", new Properties("string", "文件地址")))
.build();
@Override
public AbstractFunctionResult<String> doFunction(String content, String identifier) {
AbstractFunctionResult<String> result = new AbstractFunctionResult<String>();
JSONObject jsonObject = JSONObject.parseObject(content);
String url = jsonObject.getString("file_url");
byte[] fileContent = url.getBytes(StandardCharsets.UTF_8);
HashMap<String, Object> options = new HashMap<>();
options.put("apply_document_tree", 1);
options.put("catalog_details", 1);
options.put("dpi", 144);
options.put("get_excel", 1);
options.put("get_image", "objects");
options.put("markdown_details", 1);
options.put("page_start", 1);
options.put("page_count", 1);
options.put("page_details", 1);
options.put("paratext_mode", "annotation");
options.put("parse_mode", "auto");
options.put("table_flavor", "md");
OCRClient client = new OCRClient();
try {
String response = client.recognize(fileContent, options);
ObjectMapper mapper = new ObjectMapper();
JsonNode jsonNode = mapper.readTree(response);
if (jsonNode.has("result") && jsonNode.get("result").has("markdown")) {
String markdown = jsonNode.get("result").get("markdown").asText();
result.setPromptContent(markdown);
result.setFunctionResult(markdown);
}
return result;
} catch (Exception e) {
logger.error("Error occurred during PDF to MD conversion", e);
result.setPromptContent("FAIL");
result.setFunctionResult(e.getMessage());
return result;
}
}
@Override
public String getDesc() {
return DESC;
}
@Override
public List<String> getLLMConfig() {
return ListUtil.toList(JsonUtils.serialize(functionLLMConfig));
}
@Override
public List<String> getLLMConfig(List<Variable> variableStructure) {
return this.getLLMConfig();
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.api;
/**
* @author alex.yao
* @date 2025/5/7
*/
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.HashMap;
import java.util.Map;
public class OCRClient {
private Logger logger = LoggerFactory.getLogger(OCRClient.class);
private final String appId = "dafd04a574230c00ccba61132160de0c";
private final String secretCode = "3bc03c7e6f9402963e6e71d16d786a9c";
private final String baseUrl = "https://api.textin.com/ai/service/v1/pdf_to_markdown";
public OCRClient() {
}
public String recognize(byte[] fileContent, HashMap<String, Object> options) throws IOException {
StringBuilder queryParams = new StringBuilder();
for (Map.Entry<String, Object> entry : options.entrySet()) {
if (queryParams.length() > 0) {
queryParams.append("&");
}
queryParams.append(URLEncoder.encode(entry.getKey(), "UTF-8"))
.append("=")
.append(URLEncoder.encode(entry.getValue().toString(), "UTF-8"));
}
String fullUrl = baseUrl + (queryParams.length() > 0 ? "?" + queryParams : "");
URL url = new URL(fullUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("POST");
connection.setRequestProperty("x-ti-app-id", appId);
connection.setRequestProperty("x-ti-secret-code", secretCode);
connection.setRequestProperty("Content-Type", "text/plain");
connection.setDoOutput(true);
try (OutputStream os = connection.getOutputStream()) {
os.write(fileContent);
os.flush();
}
int responseCode = connection.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader in = new BufferedReader(
new InputStreamReader(connection.getInputStream()))) {
StringBuilder response = new StringBuilder();
String inputLine;
while ((inputLine = in.readLine()) != null) {
response.append(inputLine);
}
return response.toString();
}
} else {
logger.error("HTTP request failed with code: {}, Error message view :{}", responseCode, "https://www.textin.com/document/pdf_to_markdown");
throw new IOException("HTTP request failed with code: " + responseCode);
}
}
}
\ No newline at end of file
package cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.entity;
/**
* @author alex.yao
* @date 2025/5/7
*/
public class PdfToMDResponse {
public String pdfPwd;
public Integer dpi;
public Integer pageStart;
public Integer pageCount;
public Boolean applyDocumentTree;
public String markdownDetails;
public String tableFlavor;
public String getImage;
public String parseMode;
public String getPdfPwd() {
return pdfPwd;
}
public void setPdfPwd(String pdfPwd) {
this.pdfPwd = pdfPwd;
}
public Integer getDpi() {
return dpi;
}
public void setDpi(Integer dpi) {
this.dpi = dpi;
}
public Integer getPageStart() {
return pageStart;
}
public void setPageStart(Integer pageStart) {
this.pageStart = pageStart;
}
public Integer getPageCount() {
return pageCount;
}
public void setPageCount(Integer pageCount) {
this.pageCount = pageCount;
}
public Boolean getApplyDocumentTree() {
return applyDocumentTree;
}
public void setApplyDocumentTree(Boolean applyDocumentTree) {
this.applyDocumentTree = applyDocumentTree;
}
public String getMarkdownDetails() {
return markdownDetails;
}
public void setMarkdownDetails(String markdownDetails) {
this.markdownDetails = markdownDetails;
}
public String getTableFlavor() {
return tableFlavor;
}
public void setTableFlavor(String tableFlavor) {
this.tableFlavor = tableFlavor;
}
public String getGetImage() {
return getImage;
}
public void setGetImage(String getImage) {
this.getImage = getImage;
}
public String getParseMode() {
return parseMode;
}
public void setParseMode(String parseMode) {
this.parseMode = parseMode;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.entity;
/**
* @author alex.yao
* @date 2025/5/7
*/
public class PdfToMDResult {
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.thirdparty.resource.demand.ai.function.text_in_pdf2md.api.OCRClient;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.runner.RunWith;
import org.junit.Test;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
/**
* @author alex.yao
* @date 2025/5/7
*/
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(initializers = SingleContextInitializer.class)
@WebAppConfiguration
public class PdfToMdFunctionTest {
@Test
public void test_pdfToMd() {
String url = "https://gsst-poe-sit.gz.bcebos.com/v1/1f606155c112a421bfa818e8d89b9ee%281%29.jpg?authorization=bce-auth-v1%2Fae40d293ec92452789c8f0c25a3d4e32%2F2025-05-07T06%3A51%3A07Z%2F300%2Fhost%2F7a55d6672fddf9ba7d47bac62d7dc9e597eaa0fb52c959cb3953c05371d37065";
byte[] fileContent = url.getBytes(StandardCharsets.UTF_8);
HashMap<String, Object> options = new HashMap<>();
options.put("apply_document_tree", 1);
options.put("catalog_details", 1);
options.put("dpi", 144);
options.put("get_excel", 1);
options.put("get_image", "objects");
options.put("markdown_details", 1);
options.put("page_start", 1);
options.put("page_count", 1);
options.put("page_details", 1);
options.put("paratext_mode", "annotation");
options.put("parse_mode", "auto");
options.put("table_flavor", "md");
OCRClient client = new OCRClient();
try {
String response = client.recognize(fileContent, options);
ObjectMapper mapper = new ObjectMapper();
JsonNode jsonNode = mapper.readTree(response);
if (jsonNode.has("result") && jsonNode.get("result").has("markdown")) {
String markdown = jsonNode.get("result").get("markdown").asText();
System.out.println(markdown);
}
} catch (Exception e) {
System.out.println("1111111");
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment