Commit 64582e53 authored by alex yao's avatar alex yao

fix[智能问数]: 修复[数据库查询异常]断连问题

parent cf7758b9
......@@ -212,159 +212,162 @@ public class AiBiServiceImpl implements AiBiService {
@Override
public void callV2(String dialoguesId, String input, String fileUrl, Integer[] knowledgeIds, Integer[] databaseIds, Long userId) throws Exception {
public void callV2(String dialoguesId, String input, String fileUrl, Integer[] knowledgeIds, Integer[] databaseIds, Long userId) throws IOException {
SSEUtil sseUtil = new SSEUtil();
// 保存对话基础信息
BizAiDialoguesEntity bizAiDialoguesEntity = new BizAiDialoguesEntity();
bizAiDialoguesEntity.setMemberId(userId);
bizAiDialoguesEntity.setDialoguesId(dialoguesId);
bizAiDialoguesEntity.setIsDeleted(CommonConstant.IsDeleted.N);
List<BizAiDialoguesEntity> bizAiDialoguesEntities = bizAiDialoguesService.findByExample(bizAiDialoguesEntity, null);
if (CollectionUtils.isEmpty(bizAiDialoguesEntities)) {
logger.error("dialogues id 不存在,请重新创建");
throw new BusinessException("对话异常");
}
try {
BizAiDialoguesEntity bizAiDialoguesEntity = new BizAiDialoguesEntity();
bizAiDialoguesEntity.setMemberId(userId);
bizAiDialoguesEntity.setDialoguesId(dialoguesId);
bizAiDialoguesEntity.setIsDeleted(CommonConstant.IsDeleted.N);
List<BizAiDialoguesEntity> bizAiDialoguesEntities = bizAiDialoguesService.findByExample(bizAiDialoguesEntity, null);
if (CollectionUtils.isEmpty(bizAiDialoguesEntities)) {
logger.error("dialogues id 不存在,请重新创建");
throw new BusinessException("对话异常");
}
// 保存标题信息
if (StringUtils.isBlank(bizAiDialoguesEntities.get(0).getTitle())) {
BizAiDialoguesEntity portalDialoguesEntity = bizAiDialoguesEntities.get(0);
portalDialoguesEntity.setTitle(input);
bizAiDialoguesService.update(portalDialoguesEntity);
}
// 保存标题信息
if (StringUtils.isBlank(bizAiDialoguesEntities.get(0).getTitle())) {
BizAiDialoguesEntity portalDialoguesEntity = bizAiDialoguesEntities.get(0);
portalDialoguesEntity.setTitle(input);
bizAiDialoguesService.update(portalDialoguesEntity);
}
if (databaseIds.length > 1) {
throw new BusinessException("仅支持单个数据库");
}
long inputTime = System.currentTimeMillis();
SSEUtil sseUtil = new SSEUtil();
CSVChainResult csvChainResult = new CSVChainResult();
List<DBChainResult> dbChainResults = new ArrayList<>();
String functionRecord = StringUtils.EMPTY;
if (databaseIds.length > 1) {
throw new BusinessException("仅支持单个数据库");
}
long inputTime = System.currentTimeMillis();
BizAgentApplicationGcConfigEntity gcConfigEntity = bizAgentApplicationGcConfigService.getByConfigCode("AIBIPrompt");
if (gcConfigEntity == null) {
throw new BusinessException("无法找到【智能问数】配置");
}
CSVChainResult csvChainResult = new CSVChainResult();
List<DBChainResult> dbChainResults = new ArrayList<>();
String functionRecord = StringUtils.EMPTY;
//1. 文件/数据库 chain
if (ArrayUtils.isNotEmpty(databaseIds)) {
for (Integer databaseId : databaseIds) {
BizKnowledgeDatabaseEntity bizKnowledgeDatabaseEntity = bizKnowledgeDatabaseService.get(databaseId.longValue());
if (bizKnowledgeDatabaseEntity == null) {
continue;
}
DBChainResponse dbChainResponse = new DBChainResponse();
dbChainResponse.setQuestion(input);
dbChainResponse.setContext(null);
dbChainResponse.setPrompt(null);
dbChainResponse.setTableFilters(null);
dbChainResponse.setMysqlUser(bizKnowledgeDatabaseEntity.getDbUsername());
dbChainResponse.setMysqlPassword(bizKnowledgeDatabaseEntity.getDbPassword());
dbChainResponse.setMysqlHost(bizKnowledgeDatabaseEntity.getDbHost());
dbChainResponse.setMysqlPort(bizKnowledgeDatabaseEntity.getDbPort());
dbChainResponse.setMysqlDatabase(bizKnowledgeDatabaseEntity.getDbName());
DBChainResult dbChainResult = chainService.dbChain(dbChainResponse);
if (dbChainResult != null) {
dbChainResult.setSql("```SQL\n" + dbChainResult.getSql() + "\n```");
dbChainResults.add(dbChainResult);
}
BizAgentApplicationGcConfigEntity gcConfigEntity = bizAgentApplicationGcConfigService.getByConfigCode("AIBIPrompt");
if (gcConfigEntity == null) {
throw new BusinessException("无法找到【智能问数】配置");
}
//2. 若有数据, 执行[EChart Agent] 判断并生成EChart Option
if (CollectionUtils.isNotEmpty(dbChainResults)) {
DBChainResult dbChainResult = dbChainResults.get(0);
//输出SQL
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
result.setFunction(null);
result.setDbChainResult(dbChainResults);
result.setKnowledgeContentResult(null);
sseUtil.send(JsonUtils.serialize(result));
JSONObject echartJSONObject = new JSONObject();
echartJSONObject.put("sql", dbChainResult.getSql());
echartJSONObject.put("sql_result", dbChainResult.getSqlResult());
echartJSONObject.put("question", input);
AbstractFunctionResult<String> functionResult = eChartGenerateFunction.doFunction(echartJSONObject.toJSONString(), null, null, null);
String eChartOption = functionResult.getFunctionResult();
//输出EChart
if (eChartOption != null) {
JSONObject jsonObject = JSONObject.parseObject(eChartOption);
if (!jsonObject.containsKey("skip")) {
ToolFunction toolFunction = new ToolFunction();
toolFunction.setName("echart_function");
toolFunction.setResult(eChartOption);
toolFunction.setDisplayFormat("json");
toolFunction.setArguments(null);
result = new LargeModelDemandResult();
result.setCode("0");
result.setFunction(toolFunction);
result.setDbChainResult(null);
result.setKnowledgeContentResult(null);
String toolFunctionJson = JsonUtils.serialize(result);
sseUtil.send(toolFunctionJson);
functionRecord = toolFunctionJson;
//1. 文件/数据库 chain
if (ArrayUtils.isNotEmpty(databaseIds)) {
for (Integer databaseId : databaseIds) {
BizKnowledgeDatabaseEntity bizKnowledgeDatabaseEntity = bizKnowledgeDatabaseService.get(databaseId.longValue());
if (bizKnowledgeDatabaseEntity == null) {
continue;
}
DBChainResponse dbChainResponse = new DBChainResponse();
dbChainResponse.setQuestion(input);
dbChainResponse.setContext(null);
dbChainResponse.setPrompt(null);
dbChainResponse.setTableFilters(null);
dbChainResponse.setMysqlUser(bizKnowledgeDatabaseEntity.getDbUsername());
dbChainResponse.setMysqlPassword(bizKnowledgeDatabaseEntity.getDbPassword());
dbChainResponse.setMysqlHost(bizKnowledgeDatabaseEntity.getDbHost());
dbChainResponse.setMysqlPort(bizKnowledgeDatabaseEntity.getDbPort());
dbChainResponse.setMysqlDatabase(bizKnowledgeDatabaseEntity.getDbName());
DBChainResult dbChainResult = chainService.dbChain(dbChainResponse);
if (dbChainResult != null) {
dbChainResult.setSql("```SQL\n" + dbChainResult.getSql() + "\n```");
dbChainResults.add(dbChainResult);
}
}
}
} else if (StringUtils.isNotBlank(fileUrl)) {
CSVChainResponse csvChainResponse = new CSVChainResponse();
csvChainResponse.setQuestion(input);
csvChainResponse.setContext(null);
csvChainResponse.setFilePath(fileUrl);
csvChainResult = chainService.csvChain(csvChainResponse);
if (csvChainResult != null) {
String[] csvRes = JsonUtils.deSerialize(csvChainResult.getResult(), String[].class);
if (ArrayUtils.isNotEmpty(csvRes)) {
//2. 若有数据, 执行[EChart Agent] 判断并生成EChart Option
if (CollectionUtils.isNotEmpty(dbChainResults)) {
DBChainResult dbChainResult = dbChainResults.get(0);
//输出SQL
LargeModelDemandResult result = new LargeModelDemandResult();
result.setCode("0");
result.setFunction(null);
result.setDbChainResult(dbChainResults);
result.setKnowledgeContentResult(null);
sseUtil.send(JsonUtils.serialize(result));
JSONObject echartJSONObject = new JSONObject();
echartJSONObject.put("sql", null);
echartJSONObject.put("sql_result", csvRes.toString());
echartJSONObject.put("sql", dbChainResult.getSql());
echartJSONObject.put("sql_result", dbChainResult.getSqlResult());
echartJSONObject.put("question", input);
AbstractFunctionResult<String> functionResult = eChartGenerateFunction.doFunction(echartJSONObject.toJSONString(), null, null, null);
String eChartOption = functionResult.getFunctionResult();
functionRecord = outputECharts(eChartOption, sseUtil, functionRecord);
//输出EChart
if (functionResult != null && StringUtils.isNotBlank(functionResult.getFunctionResult())) {
JSONObject jsonObject = JSONObject.parseObject(functionResult.getFunctionResult());
if (!jsonObject.containsKey("skip")) {
ToolFunction toolFunction = new ToolFunction();
toolFunction.setName("echart_function");
toolFunction.setResult(functionResult.getFunctionResult());
toolFunction.setDisplayFormat("json");
toolFunction.setArguments(null);
result = new LargeModelDemandResult();
result.setCode("0");
result.setFunction(toolFunction);
result.setDbChainResult(null);
result.setKnowledgeContentResult(null);
String toolFunctionJson = JsonUtils.serialize(result);
sseUtil.send(toolFunctionJson);
functionRecord = toolFunctionJson;
}
}
}
} else if (StringUtils.isNotBlank(fileUrl)) {
CSVChainResponse csvChainResponse = new CSVChainResponse();
csvChainResponse.setQuestion(input);
csvChainResponse.setContext(null);
csvChainResponse.setFilePath(fileUrl);
csvChainResult = chainService.csvChain(csvChainResponse);
if (csvChainResult != null) {
String[] csvRes = JsonUtils.deSerialize(csvChainResult.getResult(), String[].class);
if (ArrayUtils.isNotEmpty(csvRes)) {
JSONObject echartJSONObject = new JSONObject();
echartJSONObject.put("sql", null);
echartJSONObject.put("sql_result", csvRes.toString());
echartJSONObject.put("question", input);
AbstractFunctionResult<String> functionResult = eChartGenerateFunction.doFunction(echartJSONObject.toJSONString(), null, null, null);
String eChartOption = functionResult.getFunctionResult();
functionRecord = outputECharts(eChartOption, sseUtil, functionRecord);
}
}
}
}
//3. 执行对话大模型
List<Message> messages = buildMessage(userId, dialoguesId, gcConfigEntity.getConfigSystem(), input, dbChainResults, csvChainResult);
LargeModelResponse largeModelResponse = new LargeModelResponse();
largeModelResponse.setModel(gcConfigEntity.getLargeModel());
largeModelResponse.setMessages(messages.toArray(new Message[0]));
largeModelResponse.setStream(true);
largeModelResponse.setUser("AI_BI_CHAT");
BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse);
LongtextDialoguesResult longtextDialoguesResult = textOutputStream(sseUtil, bufferedReader);
//保存对话记录
BizAgentApplicationDialoguesRecordEntity inputRecord = new BizAgentApplicationDialoguesRecordEntity();
inputRecord.setMemberId(userId);
inputRecord.setContent(input);
if (StringUtils.isNotBlank(fileUrl)) {
inputRecord.setFileUrl(fileUrl);
//3. 执行对话大模型
List<Message> messages = buildMessage(userId, dialoguesId, gcConfigEntity.getConfigSystem(), input, dbChainResults, csvChainResult);
LargeModelResponse largeModelResponse = new LargeModelResponse();
largeModelResponse.setModel(gcConfigEntity.getLargeModel());
largeModelResponse.setMessages(messages.toArray(new Message[0]));
largeModelResponse.setStream(true);
largeModelResponse.setUser("AI_BI_CHAT");
BufferedReader bufferedReader = llmService.chatChunk(largeModelResponse);
LongtextDialoguesResult longtextDialoguesResult = textOutputStream(sseUtil, bufferedReader);
//保存对话记录
BizAgentApplicationDialoguesRecordEntity inputRecord = new BizAgentApplicationDialoguesRecordEntity();
inputRecord.setMemberId(userId);
inputRecord.setContent(input);
if (StringUtils.isNotBlank(fileUrl)) {
inputRecord.setFileUrl(fileUrl);
}
inputRecord.setDialogsId(dialoguesId);
inputRecord.setRole("user");
inputRecord.setTimestamp(inputTime);
// 保存AI回复记录
BizAgentApplicationDialoguesRecordEntity assistantRecord = new BizAgentApplicationDialoguesRecordEntity();
assistantRecord.setMemberId(userId);
assistantRecord.setContent(longtextDialoguesResult.getMessage());
assistantRecord.setReasoningContent(longtextDialoguesResult.getReasoningContent());
assistantRecord.setDialogsId(dialoguesId);
assistantRecord.setFunction(functionRecord);
assistantRecord.setRole("assistant");
assistantRecord.setTimestamp(System.currentTimeMillis());
bizAgentApplicationDialoguesRecordService.save(inputRecord);
bizAgentApplicationDialoguesRecordService.save(assistantRecord);
} catch (Exception e) {
sseUtil.completeByError(e.getMessage());
}
inputRecord.setDialogsId(dialoguesId);
inputRecord.setRole("user");
inputRecord.setTimestamp(inputTime);
// 保存AI回复记录
BizAgentApplicationDialoguesRecordEntity assistantRecord = new BizAgentApplicationDialoguesRecordEntity();
assistantRecord.setMemberId(userId);
assistantRecord.setContent(longtextDialoguesResult.getMessage());
assistantRecord.setReasoningContent(longtextDialoguesResult.getReasoningContent());
assistantRecord.setDialogsId(dialoguesId);
assistantRecord.setFunction(functionRecord);
assistantRecord.setRole("assistant");
assistantRecord.setTimestamp(System.currentTimeMillis());
bizAgentApplicationDialoguesRecordService.save(inputRecord);
bizAgentApplicationDialoguesRecordService.save(assistantRecord);
}
private String outputECharts(String eChartOption, SSEUtil sseUtil, String functionRecord) throws IOException {
private String outputECharts(String eChartOption, SSEUtil sseUtil, String functionRecord) throws IOException {
//输出EChart
if (eChartOption != null) {
JSONObject jsonObject = JSONObject.parseObject(eChartOption);
......
......@@ -33,16 +33,21 @@ public class ChainServiceImpl implements ChainService {
@Override
public DBChainResult dbChain(DBChainResponse response) {
logger.info("dbChain response : {}", response);
DBChainResult dbChainResult = aiDialogueService.dbChain(response);
if (dbChainResult == null || dbChainResult.getStatus().equals("error")) {
logger.error("dbChain result error : {} , response:{}", dbChainResult, response);
try {
DBChainResult dbChainResult = aiDialogueService.dbChain(response);
if (dbChainResult == null || dbChainResult.getStatus().equals("error")) {
logger.error("dbChain result error : {} , response:{}", dbChainResult, response);
return null;
}
if (StringUtils.isBlank(dbChainResult.getSqlResult())) {
logger.warn("dbChain result sqlResult is blank : {} , response:{}", dbChainResult, response);
return null;
}
return dbChainResult;
} catch (Exception e) {
logger.error("调用数据库失败:{}", e.getMessage());
return null;
}
if (StringUtils.isBlank(dbChainResult.getSqlResult())) {
logger.warn("dbChain result sqlResult is blank : {} , response:{}", dbChainResult, response);
return null;
}
return dbChainResult;
}
@Override
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment