Commit e525cb03 authored by alex yao's avatar alex yao

feat:优化知识库训练流程

parent 52684761
package cn.com.poc.knowledge.scheduler;
import cn.com.poc.common.constant.CommonConstant;
import cn.com.poc.knowledge.constant.KnowledgeConstant;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
import cn.com.poc.knowledge.query.KnowledgeInfosQueryCondition;
import cn.com.poc.knowledge.query.KnowledgeInfosQueryItem;
import cn.com.poc.knowledge.service.BizKnowledgeDocumentService;
import cn.com.poc.knowledge.service.BizKnowledgeInfoService;
import cn.com.poc.message.entity.KnowledgeTrainStatusMessage;
import cn.com.poc.message.service.KnowledgeProducerService;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeTrainStatusConstant;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.Collection;
import java.util.List;
@Component
......@@ -22,6 +29,12 @@ public class KnowledgeInfoScheduler {
@Resource
private BizKnowledgeInfoService bizKnowledgeInfoService;
@Resource
private BizKnowledgeDocumentService bizKnowledgeDocumentService;
@Resource
private DemandKnowledgeService demandKnowledgeService;
@Scheduled(cron = "0 0/1 * * * ?")
public void knowledgeInfoStatusUpdateScheduler() throws Exception {
KnowledgeInfosQueryCondition condition = new KnowledgeInfosQueryCondition();
......@@ -34,4 +47,34 @@ public class KnowledgeInfoScheduler {
knowledgeProducerService.knowledgeInfoStatusCheck(knowledgeTrainStatusMessage);
}
}
@Scheduled(cron = "0 0/1 * * * ?")
public void knowledgeDocumentStatusUpdateScheduler() throws Exception {
BizKnowledgeDocumentEntity knowledgeDocumentEntity = new BizKnowledgeDocumentEntity();
knowledgeDocumentEntity.setTrainStatus(KnowledgeConstant.TrainStatus.TRAINING);
knowledgeDocumentEntity.setIsDeleted(CommonConstant.IsDeleted.N);
List<BizKnowledgeDocumentEntity> entities = bizKnowledgeDocumentService.findByExample(knowledgeDocumentEntity, null);
if (CollectionUtils.isEmpty(entities)) {
return;
}
for (BizKnowledgeDocumentEntity entity : entities) {
String trainKnowledgeStatus = demandKnowledgeService.trainKnowledgeStatus(entity.getKnowledgeId());
if (KnowledgeTrainStatusConstant.fail.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(entity.getKdId());
message.setKnowledgeId(entity.getKnowledgeId());
message.setStatus(KnowledgeConstant.TrainStatus.FAIL);
knowledgeProducerService.trainStatusUpdate(message);
//记录失败原因
} else if (KnowledgeTrainStatusConstant.success.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(entity.getKdId());
message.setKnowledgeId(entity.getKnowledgeId());
message.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
knowledgeProducerService.trainStatusUpdate(message);
}
}
}
}
......@@ -56,7 +56,6 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
public void trainKnowledge(TrainKnowledgeMessage message) throws Exception {
//修改训练状态
String knowledgeId = demandKnowledgeService.trainKnowledgeEvent(message.getFileUrl(), message.getSegmentationConfig());
KnowledgeTrainStatusMessage trainStatusMessage = new KnowledgeTrainStatusMessage();
trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.TRAINING);
trainStatusMessage.setKdId(message.getKid());
......@@ -65,67 +64,7 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
}
/**
* 训练知识库
*
* @param message
* @return
*/
// @Override
// @Consumer(topic = KnowledgeTopic.TRAIN_KNOWLEDGE, scale = 3, retry = true)
// public void trainKnowledge(TrainKnowledgeMessage message) throws Exception {
// //修改训练状态
// KnowledgeTrainStatusMessage trainStatusMessage = new KnowledgeTrainStatusMessage();
// trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.TRAINING);
// trainStatusMessage.setKdId(message.getKid());
// trainStatusMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
// knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
// try {
// String knowledgeId = demandKnowledgeService.trainKnowledge(message.getFileUrl(), message.getSegmentationConfig());
//
// //训练日志
// BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
// bizKnowledgeTrainLogEntity.setKdId(message.getKid());
// bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
// bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.COMPLETE);
// bizKnowledgeTrainLogEntity.setFailureLog("");
// bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
// bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
//
// //训练完成
// KnowledgeTrainStatusMessage completeMessage = new KnowledgeTrainStatusMessage();
// completeMessage.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
// completeMessage.setKdId(message.getKid());
// completeMessage.setKnowledgeId(knowledgeId);
// completeMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
// knowledgeProducerService.trainStatusUpdate(completeMessage);
// } catch (BusinessException e) {
// logger.warn("--------------message:{},知识库训练失败:{}------------", message, e.getMessage());
//
// //记录状态 训练失败
// trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.FAIL);
// trainStatusMessage.setKdId(message.getKid());
// knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
//
// logger.warn("-------保存知识库训练失败状态----------");
// //训练日志
// BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
// bizKnowledgeTrainLogEntity.setKdId(message.getKid());
// bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
// bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
// bizKnowledgeTrainLogEntity.setFailureLog(e.getMessage());
// bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
// bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
// logger.warn("-------保存知识库训练失败日志----------");
//
//
// BizKnowledgeInfoEntity bizKnowledgeInfoEntity = bizKnowledgeInfoService.get(message.getKnowledgeInfoId());
// bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
// bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
// logger.warn("-------保存知识库训练失败状态----------");
//
// }
// }
@Override
@Consumer(topic = KnowledgeTopic.TRAIN_STATUS, scale = 1, retry = true)
public void trainStatusUpdate(KnowledgeTrainStatusMessage message) throws Exception {
......@@ -153,27 +92,9 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
logger.info("-------知识库训练状态检查,kdIds:{}-------", kdIds);
for (Integer kdId : kdIdList) {
BizKnowledgeDocumentEntity documentEntity = bizKnowledgeDocumentService.get(kdId);
if (KnowledgeConstant.TrainStatus.TRAINING.equals(documentEntity.getTrainStatus())) {
String knowledgeId = documentEntity.getKnowledgeId();
String trainKnowledgeStatus = demandKnowledgeService.trainKnowledgeStatus(knowledgeId);
if (KnowledgeTrainStatusConstant.fail.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(kdId);
message.setKnowledgeId(knowledgeId);
message.setStatus(KnowledgeConstant.TrainStatus.FAIL);
knowledgeProducerService.trainStatusUpdate(message);
//更新知识库训练状态
bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
} else if (KnowledgeTrainStatusConstant.success.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(kdId);
message.setKnowledgeId(knowledgeId);
message.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
knowledgeProducerService.trainStatusUpdate(message);
}
if (KnowledgeConstant.TrainStatus.FAIL.equals(documentEntity.getTrainStatus())) {
bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
}
if (!documentEntity.getTrainStatus().equals(KnowledgeConstant.TrainStatus.COMPLETE)) {
......@@ -181,12 +102,10 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
break;
}
}
if (isAllComplete) {
logger.info("-------知识库训练状态检查,全部完成, knowledgeInfoId:{}-------", knowledgeTrainStatusMessage.getKnowledgeInfoId());
bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.COMPLETE);
bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment