package cn.com.poc.thirdparty.resource.demand.ai.aggregate.impl;

import cn.com.poc.common.service.BosConfigService;
import cn.com.poc.common.utils.Assert;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AICreateImageService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.OpenAiResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.*;
import cn.com.poc.thirdparty.resource.demand.member.service.DemandAuthService;
import cn.com.poc.support.dgTools.DgtoolsAbstractHttpClient;
import cn.com.poc.support.dgTools.constants.DgtoolsApiConstants;
import cn.com.yict.framemax.core.exception.BusinessException;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.collections.CollectionUtils;
import org.apache.http.Header;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.RequestBuilder;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicHeader;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletResponse;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @author alex.yao
 * @date 2023/5/30
 **/
@Service
public class AICreateImageServiceImpl implements AICreateImageService {

    private Logger logger = LoggerFactory.getLogger(AICreateImageService.class);

    private final Integer CHAR_BUFFER_SIZE = 100;

    private final String EVENT_STREAM_PREFIX = "data: ";

    private final String EVENT_STREAM_SUFFIX = org.apache.commons.lang3.StringUtils.LF + org.apache.commons.lang3.StringUtils.LF;

    private String FAILED = "FAILED";

    private final String POST = "POST";

    @Value("${dgtools.domain.url}")
    private String DOMAIN_URL;

    @Resource
    private BosConfigService bosConfigService;

    @Resource
    private DemandAuthService demandAuthService;

    @Resource
    private DgtoolsAbstractHttpClient dgToolsAbstractHttpClient;

    @Override
    public GenerationsResult invokeCreateImage(GenerationsRequest request) {
        String minSize = "256x256";
        String midSize = "512x512";
        String maxSize = "1024x1024";
        if (request.getSize() == null) {
            request.setSize(maxSize);
        }
        if (request.getN() == null) {
            request.setN(1);
        }
        Assert.notNull(request.getPrompt());
        Assert.isTrue(request.getPrompt().length() < 1000, "A text description of the desired image(s). The maximum length is 1000 characters.");
        Assert.isTrue(minSize.equals(request.getSize()) || midSize.equals(request.getSize()) || maxSize.equals(request.getSize()), "The size of the generated images. Must be one of 256x256, 512x512, or 1024x1024.");
        Assert.isTrue(request.getN() <= 10 && request.getN() >= 1, "The number of images to generate. Must be between 1 and 10.");

        List<Header> headers = new ArrayList<>();
        headers.add(DgtoolsApiConstants.JSON_HEADER);
        headers.add(DgtoolsApiConstants.AI_HEADER);
        headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
        OpenAiResult openAiResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_OPENAI_CREATE_IMAGE, request, headers);
        if (openAiResult != null) {
            GenerationsResult generationsResult = JSONObject.parseObject(openAiResult.getMessage(), GenerationsResult.class);
            return generationsResult;
        }

        return null;
    }

    @Override
    public GenerationsResult invokeCreateImage(BaiduGenerationsRequest request) throws IOException {
        Assert.notNull(request);
        String minSize = "1024*1024";
        String midSize = "1024*1536";
        String maxSize = "1536*1024";
        if (request.getResolution() == null) {
            request.setResolution(maxSize);
        }
        if (request.getNum() == null) {
            request.setNum(1);
        }
        Assert.notNull(request.getText());
        Assert.isTrue(request.getText().length() < 100, "输入内容，长度不超过100个字");
        Assert.isTrue(minSize.equals(request.getResolution()) || midSize.equals(request.getResolution()) || maxSize.equals(request.getResolution()), "图片分辨率，可支持1024*1024、1024*1536、1536*1024");
        Assert.isTrue(request.getNum() <= 6 && request.getNum() >= 1, "图片生成数量，支持1-6张");

        List<Header> headers = new ArrayList<>();
        headers.add(DgtoolsApiConstants.JSON_HEADER);
        headers.add(DgtoolsApiConstants.AI_HEADER);
        headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
        BaiduGetImageResult baiduGetImageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_CREATE_IMAGE, request, headers);
        if (baiduGetImageResult != null) {
            GenerationsResult generationsResult = new GenerationsResult();
            List<ImgUrl> imgUrls = baiduGetImageResult.getData().getImgUrls();
            if (CollectionUtils.isEmpty(imgUrls)) {
                return null;
            }
            List<Datum> data = new ArrayList<>();
            for (ImgUrl imgUrl : imgUrls) {
                Datum datum = new Datum();
                String url = bosConfigService.uploadImageByUrl2Bos(imgUrl.getImage());
                datum.setUrl(url);
                data.add(datum);
            }
            generationsResult.setData(data);
            return generationsResult;
        }
        return null;
    }


    @Override
    public GenerationsResult invokeCreateImageV2(BaiduGenerationsV2Request request) throws IOException {
        Assert.notNull(request);
        Assert.notBlank(request.getPrompt());
        Assert.notNull(request.getWidth());
        Assert.notNull(request.getHeight());

        if (request.getImageNum() == null) {
            request.setImageNum(1);
        }

        List<Header> headers = new ArrayList<>();
        headers.add(DgtoolsApiConstants.JSON_HEADER);
        headers.add(DgtoolsApiConstants.AI_HEADER);
        headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
        BaiduGetImageV2Result baiduGetImageV2Result = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_CREATE_IMAGE_V2, request, headers);

        if (baiduGetImageV2Result != null) {
            if (FAILED.equals(baiduGetImageV2Result.getData().getTaskStatus())) {
                logger.error("-----------调用百度AI作图高级错误：{}-------------", JSONObject.toJSONString(baiduGetImageV2Result));
                throw new BusinessException("生成图片失败");
            }
            GenerationsResult generationsResult = new GenerationsResult();
            List<SubTaskResult> subTaskResultList = baiduGetImageV2Result.getData().getSubTaskResultList();
            if (CollectionUtils.isEmpty(subTaskResultList)) {
                return null;
            }
            List<String> imgUrls = new ArrayList<>();
            for (SubTaskResult subTaskResult : subTaskResultList) {
                List<FinalImage> finalImageList = subTaskResult.getFinalImageList();
                finalImageList.forEach(finalImage -> {
                    imgUrls.add(finalImage.getImgUrl());
                });
            }
            if (CollectionUtils.isEmpty(imgUrls)) {
                return null;
            }
            List<Datum> data = new ArrayList<>();
            for (String imgUrl : imgUrls) {
                Datum datum = new Datum();
                String url = bosConfigService.uploadImageByUrl2Bos(imgUrl);
                datum.setUrl(url);
                data.add(datum);
            }
            generationsResult.setData(data);
            return generationsResult;
        }
        return null;
    }

    @Override
    public BaiduAISailsText2ImageResult executeSailsText2Image(BaiduAISailsText2ImageRequest request) throws Exception {

        Assert.notNull(request, "文生图配置异常，请联系开发人员");
        Assert.notNull(request.getPrompt(), "文生图配置异常，请联系开发人员");

        List<Header> headers = new ArrayList<>();
        headers.add(DgtoolsApiConstants.JSON_HEADER);
        headers.add(DgtoolsApiConstants.AI_HEADER);
        headers.add(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()));
        BaiduAISailsText2ImageResult imageResult = dgToolsAbstractHttpClient.doRequest(DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_SAILS_TEXT_CREATE_IMAGE, request, headers);
        if (imageResult == null) {
            throw new BusinessException("中台无响应，请联系开发人员");
        }
        if (StringUtils.isNotBlank(imageResult.getErrorMsg())) {
            logger.error("调用中台百度千帆大模型错误，错误信息：{}", JSONObject.toJSONString(imageResult));
            throw new BusinessException(imageResult.getErrorMsg());
        }
        List<BaiduAISailsText2ImageDataItem> dataList = imageResult.getData();
        if (CollectionUtils.isEmpty(dataList)) {
            throw new BusinessException("中台无响应，请联系开发人员");
        }

        //上传到oss
        for (BaiduAISailsText2ImageDataItem imageData : dataList) {
            if (StringUtils.isBlank(imageData.getImageBase64())) {
                throw new BusinessException("中台无响应，请联系开发人员");
            }
            String url = bosConfigService.uploadImageByBase64(imageData.getImageBase64());
            imageData.setUrl(url);
        }
        imageResult.setData(dataList);
        logger.info("调用中台百度大模型成功，响应信息：{}", JSONObject.toJSONString(imageResult));
        return imageResult;

    }

    @Override
    public BaiduAISailsImage2TextResult executeSailsImage2Text(BaiduAISailsImage2TextRequest request, HttpServletResponse httpServletResponse) throws Exception {

        Assert.notNull(request, "图生文配置异常，请联系开发人员");
        Assert.notNull(request.getPrompt(), "图生文配置异常，请联系开发人员");

        String jsonBody = dgToolsAbstractHttpClient.buildJson(request);
        CloseableHttpClient httpClient = HttpClients.createDefault();
        CloseableHttpResponse httpResponse = httpClient.execute(RequestBuilder.create(POST)
                .setUri(DOMAIN_URL + DgtoolsApiConstants.BASE_URL + DgtoolsApiConstants.DgtoolsAI.AI_BAIDU_SAILS_IMAGE_CREATE_TEXT)
                .addHeader(DgtoolsApiConstants.JSON_HEADER)
                .addHeader(DgtoolsApiConstants.AI_HEADER)
                .addHeader(new BasicHeader(DgtoolsApiConstants.HEADER_X_PLATFORM_AUTHORIZATION, demandAuthService.getToken()))
                .setEntity(new StringEntity(jsonBody, StandardCharsets.UTF_8))
                .build()
        );
        if (request.getStream() == null || !request.getStream()) {
            //非流
            String responseBodyStr = EntityUtils.toString(httpResponse.getEntity());
            if (StringUtils.isBlank(responseBodyStr)) {
                throw new BusinessException("无响应，请联系开发人员");
            }
            JSONObject jsonObject = JSONObject.parseObject(responseBodyStr);
            Integer code = (Integer) jsonObject.get("code");
            if (code == 0) {
                String data = jsonObject.getString("data");
                return JSON.parseObject(data, BaiduAISailsImage2TextResult.class);
            } else {
                throw new BusinessException(jsonObject.getString("message"));
            }
        }
        //流式
        return sailsImage2TextEventStreamResponse(httpResponse, httpServletResponse);

    }

    /**
     * 处理千帆大模型图生文流式响应
     */
    BaiduAISailsImage2TextResult sailsImage2TextEventStreamResponse(CloseableHttpResponse httpResponse, HttpServletResponse httpServletResponse) throws Exception {
        PrintWriter writer = null;
        InputStream inputStream = null;
        BufferedReader bufferedReader = null;
        StringBuilder completeResponse = new StringBuilder();
        //返回实体，用作记录保存
        BaiduAISailsImage2TextResult returnResult = new BaiduAISailsImage2TextResult();
        try {
            writer = httpServletResponse.getWriter();
            inputStream = httpResponse.getEntity().getContent();
            bufferedReader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8), CHAR_BUFFER_SIZE);
            String str;
            while ((str = bufferedReader.readLine()) != null) {
                if (org.apache.commons.lang3.StringUtils.isBlank(str)) {
                    continue;
                }
                //去除第一个data:
                Pattern pattern = Pattern.compile("^" + EVENT_STREAM_PREFIX);
                Matcher matcher = pattern.matcher(str);
                if (matcher.find()) {
                    str = matcher.replaceFirst(org.apache.commons.lang3.StringUtils.EMPTY);
                }
                BaiduAISailsImage2TextResult image2TextResult = JSONObject.parseObject(str, BaiduAISailsImage2TextResult.class);
                if (image2TextResult == null || image2TextResult.getId() == null) {
                    continue;
                }
                //不同的结果赋值到返回实体
                if (image2TextResult.getErrorCode() != null) {
                    BeanUtils.copyProperties(image2TextResult, returnResult);
                } else {
                    //拼装完整回答
                    completeResponse.append(image2TextResult.getResult());
                }
                writer.write(EVENT_STREAM_PREFIX + str + EVENT_STREAM_SUFFIX);
                writer.flush();
            }
            returnResult.setResult(completeResponse.toString());
        } catch (Exception e) {
            throw new BusinessException(e.getMessage());
        } finally {
            //关闭资源
            if (writer != null) {
                writer.close();
            }
            if (inputStream != null) {
                inputStream.close();
            }
            if (bufferedReader != null) {
                bufferedReader.close();
            }
        }
        return returnResult;
    }

}
