Commit 102dceb9 authored by alex yao's avatar alex yao

refactor: 重构知识库训练-异步

parent 814d1100
...@@ -17,6 +17,7 @@ import cn.com.poc.message.service.KnowledgeConsumerService; ...@@ -17,6 +17,7 @@ import cn.com.poc.message.service.KnowledgeConsumerService;
import cn.com.poc.message.service.KnowledgeProducerService; import cn.com.poc.message.service.KnowledgeProducerService;
import cn.com.poc.message.topic.KnowledgeTopic; import cn.com.poc.message.topic.KnowledgeTopic;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService; import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
import cn.com.poc.thirdparty.resource.demand.ai.constants.KnowledgeTrainStatusConstant;
import cn.com.yict.framemax.core.exception.BusinessException; import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.tumbleweed.client.annotation.Consumer; import cn.com.yict.framemax.tumbleweed.client.annotation.Consumer;
import com.alibaba.fastjson.TypeReference; import com.alibaba.fastjson.TypeReference;
...@@ -50,69 +51,81 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService { ...@@ -50,69 +51,81 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
private BizKnowledgeInfoService bizKnowledgeInfoService; private BizKnowledgeInfoService bizKnowledgeInfoService;
/**
* 训练知识库
*
* @param message
* @return
*/
@Override @Override
@Consumer(topic = KnowledgeTopic.TRAIN_KNOWLEDGE, scale = 3, retry = true) @Consumer(topic = KnowledgeTopic.TRAIN_KNOWLEDGE, scale = 3, retry = true)
public void trainKnowledge(TrainKnowledgeMessage message) throws Exception { public void trainKnowledge(TrainKnowledgeMessage message) throws Exception {
//修改训练状态 //修改训练状态
String knowledgeId = demandKnowledgeService.trainKnowledgeEvent(message.getFileUrl(), message.getSegmentationConfig());
KnowledgeTrainStatusMessage trainStatusMessage = new KnowledgeTrainStatusMessage(); KnowledgeTrainStatusMessage trainStatusMessage = new KnowledgeTrainStatusMessage();
trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.TRAINING); trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.TRAINING);
trainStatusMessage.setKdId(message.getKid()); trainStatusMessage.setKdId(message.getKid());
trainStatusMessage.setKnowledgeInfoId(message.getKnowledgeInfoId()); trainStatusMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
trainStatusMessage.setKnowledgeId(knowledgeId);
knowledgeProducerService.trainStatusUpdate(trainStatusMessage); knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
try {
String knowledgeId = demandKnowledgeService.trainKnowledge(message.getFileUrl(), message.getSegmentationConfig());
//训练日志
BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
bizKnowledgeTrainLogEntity.setKdId(message.getKid());
bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.COMPLETE);
bizKnowledgeTrainLogEntity.setFailureLog("");
bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
//训练完成
KnowledgeTrainStatusMessage completeMessage = new KnowledgeTrainStatusMessage();
completeMessage.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
completeMessage.setKdId(message.getKid());
completeMessage.setKnowledgeId(knowledgeId);
completeMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
knowledgeProducerService.trainStatusUpdate(completeMessage);
} catch (BusinessException e) {
logger.warn("--------------message:{},知识库训练失败:{}------------", message, e.getMessage());
//记录状态 训练失败
trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.FAIL);
trainStatusMessage.setKdId(message.getKid());
knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
logger.warn("-------保存知识库训练失败状态----------");
//训练日志
BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
bizKnowledgeTrainLogEntity.setKdId(message.getKid());
bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
bizKnowledgeTrainLogEntity.setFailureLog(e.getMessage());
bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
logger.warn("-------保存知识库训练失败日志----------");
BizKnowledgeInfoEntity bizKnowledgeInfoEntity = bizKnowledgeInfoService.get(message.getKnowledgeInfoId());
bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
logger.warn("-------保存知识库训练失败状态----------");
}
} }
/**
* 训练知识库
*
* @param message
* @return
*/
// @Override
// @Consumer(topic = KnowledgeTopic.TRAIN_KNOWLEDGE, scale = 3, retry = true)
// public void trainKnowledge(TrainKnowledgeMessage message) throws Exception {
// //修改训练状态
// KnowledgeTrainStatusMessage trainStatusMessage = new KnowledgeTrainStatusMessage();
// trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.TRAINING);
// trainStatusMessage.setKdId(message.getKid());
// trainStatusMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
// knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
// try {
// String knowledgeId = demandKnowledgeService.trainKnowledge(message.getFileUrl(), message.getSegmentationConfig());
//
// //训练日志
// BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
// bizKnowledgeTrainLogEntity.setKdId(message.getKid());
// bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
// bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.COMPLETE);
// bizKnowledgeTrainLogEntity.setFailureLog("");
// bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
// bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
//
// //训练完成
// KnowledgeTrainStatusMessage completeMessage = new KnowledgeTrainStatusMessage();
// completeMessage.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
// completeMessage.setKdId(message.getKid());
// completeMessage.setKnowledgeId(knowledgeId);
// completeMessage.setKnowledgeInfoId(message.getKnowledgeInfoId());
// knowledgeProducerService.trainStatusUpdate(completeMessage);
// } catch (BusinessException e) {
// logger.warn("--------------message:{},知识库训练失败:{}------------", message, e.getMessage());
//
// //记录状态 训练失败
// trainStatusMessage.setStatus(KnowledgeConstant.TrainStatus.FAIL);
// trainStatusMessage.setKdId(message.getKid());
// knowledgeProducerService.trainStatusUpdate(trainStatusMessage);
//
// logger.warn("-------保存知识库训练失败状态----------");
// //训练日志
// BizKnowledgeTrainLogEntity bizKnowledgeTrainLogEntity = new BizKnowledgeTrainLogEntity();
// bizKnowledgeTrainLogEntity.setKdId(message.getKid());
// bizKnowledgeTrainLogEntity.setTimestamp(System.currentTimeMillis());
// bizKnowledgeTrainLogEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
// bizKnowledgeTrainLogEntity.setFailureLog(e.getMessage());
// bizKnowledgeTrainLogEntity.setIsDeleted(CommonConstant.IsDeleted.N);
// bizKnowledgeTrainLogService.save(bizKnowledgeTrainLogEntity);
// logger.warn("-------保存知识库训练失败日志----------");
//
//
// BizKnowledgeInfoEntity bizKnowledgeInfoEntity = bizKnowledgeInfoService.get(message.getKnowledgeInfoId());
// bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
// bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
// logger.warn("-------保存知识库训练失败状态----------");
//
// }
// }
@Override @Override
@Consumer(topic = KnowledgeTopic.TRAIN_STATUS, scale = 1, retry = true) @Consumer(topic = KnowledgeTopic.TRAIN_STATUS, scale = 1, retry = true)
public void trainStatusUpdate(KnowledgeTrainStatusMessage message) throws Exception { public void trainStatusUpdate(KnowledgeTrainStatusMessage message) throws Exception {
...@@ -120,7 +133,6 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService { ...@@ -120,7 +133,6 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
bizKnowledgeDocumentEntity.setTrainStatus(message.getStatus()); bizKnowledgeDocumentEntity.setTrainStatus(message.getStatus());
bizKnowledgeDocumentEntity.setKnowledgeId(message.getKnowledgeId()); bizKnowledgeDocumentEntity.setKnowledgeId(message.getKnowledgeId());
bizKnowledgeDocumentService.update(message.getKdId(), bizKnowledgeDocumentEntity); bizKnowledgeDocumentService.update(message.getKdId(), bizKnowledgeDocumentEntity);
knowledgeProducerService.knowledgeInfoStatusCheck(message);
} }
...@@ -141,6 +153,29 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService { ...@@ -141,6 +153,29 @@ public class KnowledgeConsumerServiceImpl implements KnowledgeConsumerService {
logger.info("-------知识库训练状态检查,kdIds:{}-------", kdIds); logger.info("-------知识库训练状态检查,kdIds:{}-------", kdIds);
for (Integer kdId : kdIdList) { for (Integer kdId : kdIdList) {
BizKnowledgeDocumentEntity documentEntity = bizKnowledgeDocumentService.get(kdId); BizKnowledgeDocumentEntity documentEntity = bizKnowledgeDocumentService.get(kdId);
if (KnowledgeConstant.TrainStatus.TRAINING.equals(documentEntity.getTrainStatus())) {
String knowledgeId = documentEntity.getKnowledgeId();
String trainKnowledgeStatus = demandKnowledgeService.trainKnowledgeStatus(knowledgeId);
if (KnowledgeTrainStatusConstant.fail.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(kdId);
message.setKnowledgeId(knowledgeId);
message.setStatus(KnowledgeConstant.TrainStatus.FAIL);
knowledgeProducerService.trainStatusUpdate(message);
//更新知识库训练状态
bizKnowledgeInfoEntity.setTrainStatus(KnowledgeConstant.TrainStatus.FAIL);
bizKnowledgeInfoService.update(bizKnowledgeInfoEntity);
} else if (KnowledgeTrainStatusConstant.success.equals(trainKnowledgeStatus)) {
KnowledgeTrainStatusMessage message = new KnowledgeTrainStatusMessage();
message.setKdId(kdId);
message.setKnowledgeId(knowledgeId);
message.setStatus(KnowledgeConstant.TrainStatus.COMPLETE);
knowledgeProducerService.trainStatusUpdate(message);
}
}
if (!documentEntity.getTrainStatus().equals(KnowledgeConstant.TrainStatus.COMPLETE)) { if (!documentEntity.getTrainStatus().equals(KnowledgeConstant.TrainStatus.COMPLETE)) {
isAllComplete = false; isAllComplete = false;
break; break;
......
...@@ -240,7 +240,10 @@ public interface DgtoolsApiConstants { ...@@ -240,7 +240,10 @@ public interface DgtoolsApiConstants {
* 知识库 * 知识库
*/ */
String TRAIN_KNOWLEDGE = "knowLedgeRest/trainKnowLedge.json"; String TRAIN_KNOWLEDGE = "knowLedgeRest/trainKnowLedge.json";
String TRAIN_KNOWLEDGE_EVENT = "knowLedgeRest/trainKnowLedgeEvent.json";
String TRAIN_KNOWLEDGE_STATUS = "/knowLedgeRest/trainKnowLedgeStatus.json";
String DEL_KNOWLEDGE = "knowLedgeRest/delKnowLedge.json"; String DEL_KNOWLEDGE = "knowLedgeRest/delKnowLedge.json";
String SEARCH_KNOWLEDGE = "knowLedgeRest/searchKnowledge.json"; String SEARCH_KNOWLEDGE = "knowLedgeRest/searchKnowledge.json";
......
...@@ -19,6 +19,22 @@ public interface DemandKnowledgeService { ...@@ -19,6 +19,22 @@ public interface DemandKnowledgeService {
*/ */
String trainKnowledge(String fileURL, SegmentationConfigRequest segmentationConfig); String trainKnowledge(String fileURL, SegmentationConfigRequest segmentationConfig);
/**
* 训练知识库-异步
*
* @param fileURL 训练文档
* @return 知识库id
*/
String trainKnowledgeEvent(String fileURL, SegmentationConfigRequest segmentationConfig);
/**
* 获取知识库训练状态
*
* @param knowledgeId
* @return 训练状态
*/
String trainKnowledgeStatus(String knowledgeId);
/** /**
* 删除知识库 * 删除知识库
* *
......
...@@ -40,6 +40,31 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService { ...@@ -40,6 +40,31 @@ public class DemandKnowledgeServiceImpl implements DemandKnowledgeService {
return trainKnowledgeResult.getKnowledgeId(); return trainKnowledgeResult.getKnowledgeId();
} }
@Override
public String trainKnowledgeEvent(String fileURL, SegmentationConfigRequest segmentationConfig) {
Assert.notBlank(fileURL);
TrainKnowledgeRequest request = new TrainKnowledgeRequest();
request.setDocumentUrl(fileURL);
request.setSegmentationConfig(segmentationConfig);
TrainKnowledgeResult trainKnowledgeResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.TRAIN_KNOWLEDGE_EVENT, request, getHeaders());
if (null == trainKnowledgeResult) {
throw new BusinessException("train knowledge error");
}
return trainKnowledgeResult.getKnowledgeId();
}
@Override
public String trainKnowledgeStatus(String knowledgeId) {
Assert.notBlank(knowledgeId);
TrainKnowledgeStatusRequest request = new TrainKnowledgeStatusRequest();
request.setKnowledgeId(knowledgeId);
TrainKnowledgeStatusResult trainKnowledgeStatusResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.TRAIN_KNOWLEDGE_STATUS, request, getHeaders());
if (null == trainKnowledgeStatusResult) {
throw new BusinessException("get knowledge train status error");
}
return trainKnowledgeStatusResult.getTrainStatus();
}
@Override @Override
public void delKnowledge(String knowledgeId) { public void delKnowledge(String knowledgeId) {
Assert.notBlank(knowledgeId); Assert.notBlank(knowledgeId);
......
package cn.com.poc.thirdparty.resource.demand.ai.constants;
public interface KnowledgeTrainStatusConstant {
String unTrain = "unTrain";
String line = "line";
String train = "train";
String fail = "fail";
String success = "success";
}
package cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge;
import cn.com.poc.support.dgTools.request.AbstractRequest;
import java.io.Serializable;
public class TrainKnowledgeStatusRequest extends AbstractRequest<TrainKnowledgeStatusResult> implements Serializable {
private String knowledgeId;
public String getKnowledgeId() {
return knowledgeId;
}
public void setKnowledgeId(String knowledgeId) {
this.knowledgeId = knowledgeId;
}
@Override
public String getMethod() throws Exception {
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.entity.knowledge;
import cn.com.poc.support.dgTools.result.AbstractResult;
import java.io.Serializable;
public class TrainKnowledgeStatusResult extends AbstractResult implements Serializable {
private String trainStatus;
public String getTrainStatus() {
return trainStatus;
}
public void setTrainStatus(String trainStatus) {
this.trainStatus = trainStatus;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment