Commit 3fe4bcf9 authored by alex yao's avatar alex yao

feat: 新增Text to Speech接口

parent 14e44c7c
package cn.com.poc.common.utils;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Path;
public class PcmToWav {
    /** Size in bytes of the canonical RIFF/WAVE header placed in front of the PCM samples. */
    private static final int WAV_HEADER_SIZE = 44;

    /** Bytes of the header not counted by the RIFF length field ("RIFF" tag + the length field itself). */
    private static final int RIFF_PREFIX_SIZE = 8;

    /**
     * Wraps a raw PCM file in a WAV container and writes the result to a new temporary file.
     * The header describes the data as 16-bit, mono, 16000 Hz linear PCM.
     *
     * @param src path of the raw PCM file to convert
     * @return path of the newly created temporary WAV file; the caller owns it and should delete it
     * @throws IOException if the source cannot be read, is too large for the header's
     *                     32-bit length fields, or the target cannot be written
     */
    public static Path convertAudioFiles(String src) throws IOException {
        File pcmFile = new File(src);
        // Open the source first so a missing file fails before any temp file is created.
        try (FileInputStream fis = new FileInputStream(pcmFile)) {
            long pcmLength = pcmFile.length();
            if (pcmLength > Integer.MAX_VALUE - (WAV_HEADER_SIZE - RIFF_PREFIX_SIZE)) {
                throw new IOException("PCM file too large for a WAV header: " + pcmLength + " bytes");
            }
            int pcmSize = (int) pcmLength;

            WaveHeader header = new WaveHeader();
            // RIFF length = payload size + header size, excluding the "RIFF" tag and the length field.
            header.fileLength = pcmSize + (WAV_HEADER_SIZE - RIFF_PREFIX_SIZE);
            header.FmtHdrLeth = 16;
            header.BitsPerSample = 16;
            header.Channels = 1;
            header.FormatTag = 0x0001;
            header.SamplesPerSec = 16000;
            // Block align is the size of one sample frame in BYTES: channels * bitsPerSample / 8.
            header.BlockAlign = (short) (header.Channels * header.BitsPerSample / 8);
            header.AvgBytesPerSec = header.BlockAlign * header.SamplesPerSec;
            header.DataHdrLeth = pcmSize;
            byte[] headerBytes = header.getHeader();
            if (headerBytes.length != WAV_HEADER_SIZE) {
                throw new IllegalStateException("unexpected WAV header size: " + headerBytes.length);
            }

            File tempAudioFile = File.createTempFile(UUIDTool.getUUID(), ".wav");
            try (FileOutputStream fos = new FileOutputStream(tempAudioFile)) {
                fos.write(headerBytes, 0, headerBytes.length);
                byte[] buf = new byte[1024 * 4];
                int size;
                while ((size = fis.read(buf)) != -1) {
                    fos.write(buf, 0, size);
                }
            } catch (IOException e) {
                // Do not leave a truncated WAV file behind when the copy fails.
                tempAudioFile.delete();
                throw e;
            }
            return tempAudioFile.toPath();
        }
    }
}
package cn.com.poc.common.utils;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
/**
 * Field-by-field model of a 44-byte RIFF/WAVE header. Callers fill in the public
 * fields and call {@link #getHeader()} to obtain the serialised bytes.
 */
public class WaveHeader {
    public final char fileID[] = {'R', 'I', 'F', 'F'};
    public int fileLength;
    public char wavTag[] = {'W', 'A', 'V', 'E'};
    public char FmtHdrID[] = {'f', 'm', 't', ' '};
    public int FmtHdrLeth;
    public short FormatTag;
    public short Channels;
    public int SamplesPerSec;
    public int AvgBytesPerSec;
    public short BlockAlign;
    public short BitsPerSample;
    public char DataHdrID[] = {'d', 'a', 't', 'a'};
    public int DataHdrLeth;

    /**
     * Serialises the header fields in declaration order. Tags are written one byte
     * per character; numeric fields are written least-significant byte first.
     *
     * @return the serialised header bytes
     * @throws IOException declared for compatibility with existing callers
     */
    public byte[] getHeader() throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        putTag(out, fileID);
        putLittleEndian(out, fileLength, 4);
        putTag(out, wavTag);
        putTag(out, FmtHdrID);
        putLittleEndian(out, FmtHdrLeth, 4);
        putLittleEndian(out, FormatTag, 2);
        putLittleEndian(out, Channels, 2);
        putLittleEndian(out, SamplesPerSec, 4);
        putLittleEndian(out, AvgBytesPerSec, 4);
        putLittleEndian(out, BlockAlign, 2);
        putLittleEndian(out, BitsPerSample, 2);
        putTag(out, DataHdrID);
        putLittleEndian(out, DataHdrLeth, 4);
        return out.toByteArray();
    }

    /** Writes the lowest {@code byteCount} bytes of {@code value}, least-significant byte first. */
    private static void putLittleEndian(ByteArrayOutputStream out, int value, int byteCount) {
        for (int i = 0; i < byteCount; i++) {
            out.write((value >>> (8 * i)) & 0xFF);
        }
    }

    /** Writes each character of {@code tag} as a single byte (its low 8 bits). */
    private static void putTag(ByteArrayOutputStream out, char[] tag) {
        for (int i = 0; i < tag.length; i++) {
            out.write(tag[i]);
        }
    }
}
package cn.com.poc.expose.websocket;
import cn.com.poc.expose.websocket.constant.WsHandlerMatcher;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
import cn.com.poc.expose.websocket.holder.WSHandlerHolder;
import cn.com.yict.framemax.core.config.Config;
import cn.com.yict.framemax.core.exception.BusinessException;
import org.java_websocket.WebSocket;
import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
import org.springframework.stereotype.Component;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.UUID;
/**
 * WebSocket server that routes each connection to an {@link AbstractWsHandler}
 * chosen from the request path, and tracks whether the server has been started.
 *
 * @author alex.yao
 * @date 2023/12/7
 **/
@Component
public class SuperLinkWebSocketServer extends WebSocketServer {
    /** Reply sent before closing a connection whose path matches no handler. */
    private static final String PATH_MISMATCH_REPLY = "{\"code\":-1, \"message\":\"path mismatch\"}";

    private final Logger logger = LoggerFactory.getLogger(SuperLinkWebSocketServer.class);
    private boolean serverStarted = false;
    private final String LOG_KEY = "log_trace_id";

    private SuperLinkWebSocketServer() {
    }

    /**
     * @param port         TCP port to listen on
     * @param decodercount value forwarded to the {@link WebSocketServer} constructor
     */
    public SuperLinkWebSocketServer(int port, int decodercount) {
        super(new InetSocketAddress(port), decodercount);
    }

    /**
     * Assigns a trace id, resolves the handler for the connection's path and stores it
     * in {@link WSHandlerHolder}; connections with an unknown path are told so and closed.
     */
    @Override
    public void onOpen(WebSocket webSocket, ClientHandshake clientHandshake) {
        // auth(webSocket);
        initWebSocketTraceId();
        String path = webSocket.getResourceDescriptor();
        // Pass the address object itself: the placeholder formats null safely, toString() would not.
        logger.warn("websocket new connection open , address:{} ,path:{} ", webSocket.getRemoteSocketAddress(), path);
        AbstractWsHandler handlerClass = WsHandlerMatcher.getWsHandlerClass(path);
        if (handlerClass != null) {
            WSHandlerHolder.set(handlerClass);
        } else {
            webSocket.send(PATH_MISMATCH_REPLY);
            webSocket.close();
        }
    }

    /**
     * Rejects connections whose remote host name is not in the {@code white.list.ip}
     * configuration entry. Currently not invoked (the call in onOpen is commented out).
     */
    private void auth(WebSocket webSocket) {
        String hostName = webSocket.getRemoteSocketAddress().getHostName();
        logger.info("connection hostname:{}", hostName);
        String[] split = Config.get("white.list.ip").split(",");
        if (Arrays.stream(split).noneMatch(ip -> ip.equals(hostName))) {
            throw new BusinessException("no authority");
        }
    }

    /**
     * Dispatches a text message to the connection's handler.
     *
     * <p>The handler is normally the one stored by {@link #onOpen}. Because that store is a
     * ThreadLocal, it can be empty here (for example after another connection's close or
     * error cleared it on this thread); in that case the handler is resolved again from the
     * connection's path instead of failing with a NullPointerException.
     * NOTE(review): a ThreadLocal can also hold a handler set by a different connection on
     * the same thread; consider keeping the handler per connection instead.
     */
    @Override
    public void onMessage(WebSocket webSocket, String s) {
        logger.info("{} connection send message:{}", webSocket.getRemoteSocketAddress(), s.length() < 1024 ? s : "response too long");
        AbstractWsHandler handler = WSHandlerHolder.get();
        if (handler == null) {
            handler = WsHandlerMatcher.getWsHandlerClass(webSocket.getResourceDescriptor());
        }
        if (handler == null) {
            webSocket.send(PATH_MISMATCH_REPLY);
            webSocket.close();
            return;
        }
        handler.doHandler(webSocket, s);
    }

    /** Logs the closed connection and clears the current thread's handler. */
    @Override
    public void onClose(WebSocket webSocket, int i, String s, boolean b) {
        // String concatenation tolerates a null remote address, unlike calling toString() on it.
        logger.warn("connection is close:{}", webSocket.getRemoteSocketAddress() + webSocket.getResourceDescriptor());
        WSHandlerHolder.clear();
    }

    /**
     * Logs the error (including the exception) and clears the current thread's handler.
     * The connection argument is checked for null because server-level errors are not
     * tied to a specific connection.
     */
    @Override
    public void onError(WebSocket webSocket, Exception e) {
        if (webSocket == null) {
            logger.warn("connection is error:{}", "server", e);
        } else {
            logger.warn("connection is error:{}", webSocket.getRemoteSocketAddress() + webSocket.getResourceDescriptor(), e);
        }
        WSHandlerHolder.clear();
    }

    @Override
    public void onStart() {
        logger.warn("---------------------------------websocket start--------------------------------------");
    }

    /** Marks the server as started, then delegates to the superclass. */
    @Override
    public void start() {
        serverStarted = true;
        super.start();
    }

    /** Marks the server as stopped, then delegates to the superclass. */
    @Override
    public void stop() throws IOException, InterruptedException {
        serverStarted = false;
        super.stop();
    }

    /** Marks the server as stopped, then delegates to the superclass with the given timeout. */
    @Override
    public void stop(int i) throws InterruptedException {
        serverStarted = false;
        super.stop(i);
    }

    /**
     * @return {@code true} if {@link #start()} has been called and no stop has followed
     */
    public boolean isRun() {
        return this.serverStarted;
    }

    /** Stores a fresh dash-less UUID in the logging MDC under the trace-id key. */
    private void initWebSocketTraceId() {
        String traceId = UUID.randomUUID().toString().replaceAll("-", "");
        MDC.put(LOG_KEY, traceId);
    }
}
package cn.com.poc.expose.websocket.config;
import cn.com.poc.expose.websocket.SuperLinkWebSocketServer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Spring configuration that registers the {@link SuperLinkWebSocketServer} bean.
 *
 * @author alex.yao
 * @date 2023/12/7
 **/
@Configuration
public class SuperLinkWebSocketConfig {

    /**
     * Builds the WebSocket server listening on port 40088, with the second constructor
     * argument set to ten times the number of available processors.
     *
     * @return the server instance (not yet started)
     */
    @Bean
    public SuperLinkWebSocketServer superLinkWebSocketServer() {
        final int listenPort = 40088;
        final int decoderCount = Runtime.getRuntime().availableProcessors() * 10;
        return new SuperLinkWebSocketServer(listenPort, decoderCount);
    }
}
package cn.com.poc.expose.websocket.constant;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
import cn.com.poc.expose.websocket.handler.TextToSpeechTencentHandler;
/**
 * Maps a WebSocket request path to the handler that serves it.
 */
public class WsHandlerMatcher {
    final private static String TEXT_TO_SPEECH_TC = "/websocket/textToSpeechTC.ws";

    /**
     * Creates the handler registered for the given path.
     *
     * @param path resource path of the WebSocket connection
     * @return a new handler instance, or {@code null} when the path is blank or unknown
     */
    public static AbstractWsHandler getWsHandlerClass(String path) {
        if (StringUtils.isBlank(path)) {
            return null;
        }
        if (TEXT_TO_SPEECH_TC.equals(path)) {
            return new TextToSpeechTencentHandler();
        }
        return null;
    }
}
package cn.com.poc.expose.websocket.dto;
/**
 * Payload of a text-to-speech request received over the WebSocket. Every field
 * except {@code content} has a default that applies when the field is not set.
 */
public class TextToSpeechTencentResponse {
    /** Text to synthesise; no default. */
    private String content;
    /** Audio codec name; defaults to "mp3". */
    private String codec = "mp3";
    /** Sample rate; defaults to 16000. */
    private Integer sampleRate = Integer.valueOf(16000);
    /** Voice type id; defaults to 1001. */
    private Integer voiceType = Integer.valueOf(1001);
    /** Volume; defaults to 0. */
    private Integer volume = Integer.valueOf(0);
    /** Speed; defaults to 0. */
    private Float speed = Float.valueOf(0F);

    public String getContent() {
        return this.content;
    }

    public void setContent(String content) {
        this.content = content;
    }

    public String getCodec() {
        return this.codec;
    }

    public void setCodec(String codec) {
        this.codec = codec;
    }

    public Integer getSampleRate() {
        return this.sampleRate;
    }

    public void setSampleRate(Integer sampleRate) {
        this.sampleRate = sampleRate;
    }

    public Integer getVoiceType() {
        return this.voiceType;
    }

    public void setVoiceType(Integer voiceType) {
        this.voiceType = voiceType;
    }

    public Integer getVolume() {
        return this.volume;
    }

    public void setVolume(Integer volume) {
        this.volume = volume;
    }

    public Float getSpeed() {
        return this.speed;
    }

    public void setSpeed(Float speed) {
        this.speed = speed;
    }
}
package cn.com.poc.expose.websocket.exception;
import cn.com.yict.framemax.core.exception.ErrorCoded;
import org.slf4j.MDC;
import java.io.Serializable;
/**
 * Runtime exception for WebSocket handling that carries an error code and the
 * logging trace id, and renders itself as a small JSON object via {@link #toString()}.
 *
 * @author alex.yao
 * @date 2023/12/15
 **/
public class WebsocketException extends RuntimeException implements ErrorCoded {
    private static final long serialVersionUID = 2332618265610125980L;
    /** MDC key under which the trace id of the current connection is stored. */
    private static final String LOG_KEY = "log_trace_id";
    private Serializable errorCode = -1;
    private String traceId;

    /** Creates an exception with no message; the trace id is read from the MDC. */
    public WebsocketException() {
        super();
        this.traceId = MDC.get(LOG_KEY);
    }

    /**
     * Wraps another throwable; the trace id is read from the MDC.
     *
     * @param cause the underlying failure
     */
    public WebsocketException(Throwable cause) {
        super(cause);
        this.traceId = MDC.get(LOG_KEY);
    }

    /**
     * @param message error description; the trace id is read from the MDC
     */
    public WebsocketException(String message) {
        super(message);
        this.traceId = MDC.get(LOG_KEY);
    }

    /**
     * @param code    error code reported by {@link #getErrorCode()}
     * @param message error description; the trace id is read from the MDC
     */
    public WebsocketException(Serializable code, String message) {
        super(message);
        this.errorCode = code;
        this.traceId = MDC.get(LOG_KEY);
    }

    /**
     * @param message error description
     * @param traceId explicit trace id to report instead of the MDC value
     */
    public WebsocketException(String message, String traceId) {
        super(message);
        this.traceId = traceId;
    }

    /**
     * @param code    error code reported by {@link #getErrorCode()}
     * @param message error description
     * @param traceId explicit trace id to report instead of the MDC value
     */
    public WebsocketException(Serializable code, String message, String traceId) {
        super(message);
        this.errorCode = code;
        this.traceId = traceId;
    }

    @Override
    public Serializable getErrorCode() {
        return this.errorCode;
    }

    @Override
    public String getMessage() {
        return super.getMessage();
    }

    /**
     * Renders the exception as {@code {"code":"..." ,"message":"..." ,"traceId":"..."}}.
     * Values are JSON-escaped so that quotes, backslashes or control characters in the
     * message cannot produce malformed JSON.
     */
    @Override
    public String toString() {
        return "{\"code\":" + "\"" + escapeJson(String.valueOf(errorCode))
                + "\" ,\"message\":" + "\"" + escapeJson(super.getMessage())
                + "\" ,\"traceId\":" + "\"" + escapeJson(traceId) + "\"}";
    }

    /** Escapes a value for use inside a JSON string literal; {@code null} becomes "null". */
    private static String escapeJson(String value) {
        String text = String.valueOf(value);
        StringBuilder escaped = new StringBuilder(text.length() + 8);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        escaped.append(String.format("\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
            }
        }
        return escaped.toString();
    }
}
package cn.com.poc.expose.websocket.handler;
import org.java_websocket.WebSocket;
/**
 * Base type for handlers that process text messages arriving on a WebSocket connection.
 *
 * @author alex.yao
 * @date 2023/12/7
 **/
public abstract class AbstractWsHandler {

    /**
     * Processes one text message received on the given connection.
     *
     * @param webSocket connection the message arrived on
     * @param message   raw text of the message
     */
    public abstract void doHandler(WebSocket webSocket, String message);
}
package cn.com.poc.expose.websocket.handler;
import cn.com.poc.common.service.BosConfigService;
import cn.com.poc.common.utils.*;
import cn.com.poc.expose.websocket.dto.TextToSpeechTencentResponse;
import cn.com.poc.expose.websocket.exception.WebsocketException;
import cn.com.yict.framemax.core.config.Config;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.com.yict.framemax.core.i18n.I18nUtils;
import com.tencent.SpeechClient;
import com.tencent.tts.model.SpeechSynthesisRequest;
import com.tencent.tts.model.SpeechSynthesisResponse;
import com.tencent.tts.service.SpeechSynthesisListener;
import com.tencent.tts.service.SpeechSynthesizer;
import org.java_websocket.WebSocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Handles a text-to-speech request: parses the JSON message, streams the synthesised
 * audio back to the client in base64 chunks, then uploads the complete audio file and
 * sends its URL.
 */
public class TextToSpeechTencentHandler extends AbstractWsHandler {
    final Logger logger = LoggerFactory.getLogger(TextToSpeechTencentHandler.class);

    /**
     * Parses the request, validates it and starts synthesis. Any failure is reported to
     * the client as the JSON rendering of a {@link WebsocketException} and then rethrown.
     *
     * @param webSocket connection to reply on
     * @param message   JSON text deserialised into {@link TextToSpeechTencentResponse}
     * @throws WebsocketException when the request is invalid or synthesis cannot be started
     */
    @Override
    public void doHandler(WebSocket webSocket, String message) {
        try {
            // Resolved per call and kept local so concurrent messages never share state.
            BosConfigService bosConfigService = SpringUtils.getBean(BosConfigService.class);
            TextToSpeechTencentResponse response = JsonUtils.deSerialize(message, TextToSpeechTencentResponse.class);
            if (response == null) {
                throw new WebsocketException("arg cannot null");
            }
            if (StringUtils.isBlank(response.getContent())) {
                throw new WebsocketException("content cannot null");
            }
            SpeechSynthesizer synthesizer = initTTS(webSocket, response, bosConfigService);
            textToSpeech(synthesizer, response.getContent());
        } catch (WebsocketException e) {
            // Already in the client-facing form; send its JSON rendering without re-wrapping.
            webSocket.send(e.toString());
            throw e;
        } catch (Exception e) {
            logger.error("text to speech failed", e);
            WebsocketException websocketException = new WebsocketException(e);
            webSocket.send(websocketException.toString());
            throw websocketException;
        }
    }

    /**
     * Starts synthesis of the given text on the given synthesizer.
     */
    private void textToSpeech(SpeechSynthesizer synthesizer, String text) {
        synthesizer.synthesisLongText(text);
    }

    /**
     * Builds a synthesizer configured from the request, with a listener that forwards
     * audio chunks to the client and uploads the full audio on completion.
     *
     * @throws IOException if the temporary audio file cannot be created
     */
    private SpeechSynthesizer initTTS(WebSocket webSocket, TextToSpeechTencentResponse textToSpeechTencentResponse,
                                      BosConfigService bosConfigService) throws IOException {
        // Credentials are read from configuration.
        String appId = Config.get("tencent.speech.synthesizer.appid");
        String secretId = Config.get("tencent.speech.synthesizer.secretId");
        String secretKey = Config.get("tencent.speech.synthesizer.secretKey");
        SpeechClient client = SpeechClient.newInstance(appId, secretId, secretKey);

        final String codec = textToSpeechTencentResponse.getCodec();
        final boolean wavRequested = "wav".equals(codec);

        SpeechSynthesisRequest request = SpeechSynthesisRequest.initialize();
        request.setSampleRate(textToSpeechTencentResponse.getSampleRate());
        request.setSpeed(textToSpeechTencentResponse.getSpeed());
        // WAV is produced locally from PCM, so PCM is what gets requested in that case.
        request.setCodec(wavRequested ? "pcm" : codec);
        request.setVolume(textToSpeechTencentResponse.getVolume());
        request.setVoiceType(textToSpeechTencentResponse.getVoiceType());

        // A null suffix lets createTempFile fall back to its default extension.
        final File tempAudioFile = File.createTempFile(UUIDTool.getUUID(), codec == null ? null : "." + codec);

        return client.newSpeechSynthesizer(request, new SpeechSynthesisListener() {
            final List<byte[]> audioBytes = new ArrayList<>();
            /** Index of the next audio chunk sent to the client, starting at 0. */
            final AtomicInteger sessionId = new AtomicInteger(0);

            @Override
            public void onComplete(SpeechSynthesisResponse response) {
                logger.info("onComplete");
                Path uploadPath = tempAudioFile.toPath();
                try {
                    try (FileOutputStream fileOutputStream = new FileOutputStream(tempAudioFile, true)) {
                        for (byte[] audioByte : audioBytes) {
                            fileOutputStream.write(audioByte);
                        }
                    }
                    if (wavRequested) {
                        uploadPath = PcmToWav.convertAudioFiles(tempAudioFile.getPath());
                    }
                    // Upload the audio and tell the client where it is.
                    try (InputStream fileInputStream = Files.newInputStream(uploadPath)) {
                        String uploadUrl = bosConfigService.upload(fileInputStream, codec, null);
                        webSocket.send("{\"replyVoiceUrl\":\"" + uploadUrl + "\"}");
                    }
                    webSocket.send("{\"final\":true}");
                } catch (Exception e) {
                    logger.error("onComplete failed", e);
                    throw new BusinessException(e);
                } finally {
                    // Always release the temp files and the connection, even when the upload fails.
                    deleteTempFile(tempAudioFile.toPath());
                    if (!uploadPath.equals(tempAudioFile.toPath())) {
                        deleteTempFile(uploadPath);
                    }
                    webSocket.close();
                }
            }

            @Override
            public void onMessage(byte[] data) {
                // Forward the chunk to the client with its running index, and keep it for the upload.
                int chunkIndex = sessionId.getAndIncrement();
                String base64 = Base64.getEncoder().encodeToString(data);
                webSocket.send("{\"sessionId\":" + chunkIndex + ",\"audio\":\"" + base64 + "\"}");
                audioBytes.add(data);
            }

            @Override
            public void onFail(SpeechSynthesisResponse response) {
                logger.warn("onFail:{}", response.getMessage());
                deleteTempFile(tempAudioFile.toPath());
                webSocket.close();
            }
        });
    }

    /** Deletes a temporary file, logging instead of failing when deletion is not possible. */
    private void deleteTempFile(Path path) {
        try {
            Files.deleteIfExists(path);
        } catch (IOException e) {
            logger.warn("failed to delete temp audio file:{}", path, e);
        }
    }
}
package cn.com.poc.expose.websocket.holder;
import cn.com.poc.expose.websocket.handler.AbstractWsHandler;
/**
 * Static accessor around a {@link ThreadLocal} that holds the WebSocket handler
 * associated with the current thread.
 *
 * @author alex.yao
 * @date 2024/3/19 22:50
 */
public class WSHandlerHolder {
    private final static ThreadLocal<AbstractWsHandler> CURRENT_HANDLER = new ThreadLocal<>();

    /**
     * @return the handler stored for the current thread, or {@code null} if none is set
     */
    public static AbstractWsHandler get() {
        AbstractWsHandler current = CURRENT_HANDLER.get();
        return current;
    }

    /**
     * Stores the handler for the current thread, replacing any previous one.
     */
    public static void set(AbstractWsHandler abstractWsHandler) {
        CURRENT_HANDLER.set(abstractWsHandler);
    }

    /**
     * Removes the current thread's handler.
     */
    public static void clear() {
        CURRENT_HANDLER.remove();
    }
}
package cn.com.poc.expose.websocket.init;
import cn.com.poc.expose.websocket.SuperLinkWebSocketServer;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
/**
 * Starts the WebSocket server once this bean's dependencies have been injected.
 *
 * @author alex.yao
 * @date 2023/12/7
 **/
@Component
public class SuperLinkWSRun implements InitializingBean {
    @Resource
    private SuperLinkWebSocketServer superLinkWebSocketServer;

    /**
     * Starts the server unless it reports that it is already running.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        if (superLinkWebSocketServer.isRun()) {
            return;
        }
        superLinkWebSocketServer.start();
    }
}
...@@ -72,4 +72,6 @@ collect.limit.message=Click too fast, do not repeat the operation ...@@ -72,4 +72,6 @@ collect.limit.message=Click too fast, do not repeat the operation
sms.limit.message=Do not send it again. Try again later sms.limit.message=Do not send it again. Try again later
phone.is.exist=The mobile number already exists phone.is.exist=The mobile number already exists
email.is.exist=The email already exists email.is.exist=The email already exists
file.load.error=File loading failure file.load.error=File loading failure
\ No newline at end of file content.cannot.null=Content cannot be null
arg.cannot.null=Argument cannot be null
\ No newline at end of file
...@@ -73,3 +73,5 @@ sms.limit.message=\u8BF7\u52FF\u91CD\u590D\u53D1\u9001\uFF0C\u8BF7\u7A0D\u540E\u ...@@ -73,3 +73,5 @@ sms.limit.message=\u8BF7\u52FF\u91CD\u590D\u53D1\u9001\uFF0C\u8BF7\u7A0D\u540E\u
phone.is.exist=\u8BE5\u624B\u673A\u53F7\u5DF2\u5B58\u5728 phone.is.exist=\u8BE5\u624B\u673A\u53F7\u5DF2\u5B58\u5728
email.is.exist=\u8BE5\u90AE\u7BB1\u5DF2\u5B58\u5728 email.is.exist=\u8BE5\u90AE\u7BB1\u5DF2\u5B58\u5728
file.load.error=\u6587\u4EF6\u52A0\u8F7D\u5931\u8D25 file.load.error=\u6587\u4EF6\u52A0\u8F7D\u5931\u8D25
content.cannot.null=\u5185\u5BB9\u4E0D\u80FD\u4E3A\u7A7A
arg.cannot.null=\u8BF7\u6C42\u53C2\u6570\u4E0D\u80FD\u4E3A\u7A7A
\ No newline at end of file
...@@ -72,4 +72,6 @@ collect.limit.message=\u9EDE\u64CA\u904E\u5FEB\uFF0C\u8ACB\u52FF\u91CD\u8907\u64 ...@@ -72,4 +72,6 @@ collect.limit.message=\u9EDE\u64CA\u904E\u5FEB\uFF0C\u8ACB\u52FF\u91CD\u8907\u64
sms.limit.message=\u8ACB\u52FF\u91CD\u8907\u767C\u9001\uFF0C\u7A0D\u5F8C\u91CD\u8A66 sms.limit.message=\u8ACB\u52FF\u91CD\u8907\u767C\u9001\uFF0C\u7A0D\u5F8C\u91CD\u8A66
phone.is.exist=\u8A72\u624B\u6A5F\u865F\u5DF2\u5B58\u5728 phone.is.exist=\u8A72\u624B\u6A5F\u865F\u5DF2\u5B58\u5728
email.is.exist=\u8A72\u90F5\u7BB1\u5DF2\u5B58\u5728 email.is.exist=\u8A72\u90F5\u7BB1\u5DF2\u5B58\u5728
file.load.error=\u6587\u4EF6\u52A0\u8F09\u5931\u6557 file.load.error=\u6587\u4EF6\u52A0\u8F09\u5931\u6557
\ No newline at end of file content.cannot.null=\u5167\u5BB9\u4E0D\u80FD\u7232\u7A7A
arg.cannot.null=\u8ACB\u6C42\u53C3\u6578\u4E0D\u80FD\u7232\u7A7A
\ No newline at end of file
package cn.com.poc.demand;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AIDialogueService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONException;
import com.alibaba.fastjson.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Spring-context tests that exercise the AI dialogue function-call API, the
 * memory-variable writer configuration, and a small JSON-array detection helper.
 * Results are printed rather than asserted.
 */
@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(initializers = SingleContextInitializer.class)
@WebAppConfiguration
public class AiDialogueTest {
    @Resource
    AIDialogueService aiDialogueService;

    @Resource
    private MemoryVariableWriterFunction memoryVariableWriterFunction;

    /**
     * Sends a query together with a single "set_long_memory" function definition
     * and prints the function-call result.
     */
    @Test
    public void functionCall() {
        // Schema of the function's single required string parameter.
        Map<String, Object> contentSchema = new HashMap<>();
        contentSchema.put("type", "string");
        contentSchema.put("description", "内容的详细说明");

        Map<String, Object> propertySchemas = new HashMap<>();
        propertySchemas.put("content", contentSchema);

        List<String> requiredNames = new ArrayList<>();
        requiredNames.add("content");

        Map<String, Object> parameterSchema = new HashMap<>();
        parameterSchema.put("type", "object");
        parameterSchema.put("properties", propertySchemas);
        parameterSchema.put("required", requiredNames);

        Function memoryFunction = new Function();
        memoryFunction.setName("set_long_memory");
        memoryFunction.setDescription("该方法仅用来保存用户想记录的内容,不能通过该方法进行查询。");
        memoryFunction.setParameters(parameterSchema);

        List<Function> availableFunctions = new ArrayList<>();
        availableFunctions.add(memoryFunction);

        FunctionCallResponse callRequest = new FunctionCallResponse();
        callRequest.setQuery("帮我记一下,今天下午需要开会");
        callRequest.setFunctions(availableFunctions);

        FunctionCallResult callResult = aiDialogueService.functionCall(callRequest);
        System.out.println(callResult);
    }

    /**
     * Builds a two-variable structure ("name", "age") and prints the LLM configuration
     * derived from it.
     */
    @Test
    public void getMemoryVariableWriterFunctionConfig() {
        Variable nameVariable = new Variable();
        nameVariable.setKey("name");
        nameVariable.setVariableDefault("");

        Variable ageVariable = new Variable();
        ageVariable.setKey("age");
        ageVariable.setVariableDefault("");

        List<Variable> variableStructure = new ArrayList<>();
        variableStructure.add(nameVariable);
        variableStructure.add(ageVariable);

        System.out.println(memoryVariableWriterFunction.getVariableStructureLLMConfig(variableStructure));
    }

    /**
     * Prints the result of {@link #isJsonArray(String)} for a JSON array and for a JSON object.
     */
    @Test
    public void testJsonArray() {
        String arrayJson = "[{\"key\": \"name\", \"value\": \"roger\"}, {\"key\": \"age\", \"value\": 12}]";
        String objectJson = "{\"key\": \"name\", \"value\": \"roger\"}";
        System.out.println(isJsonArray(arrayJson));
        System.out.println(isJsonArray(objectJson));
    }

    /**
     * Reports whether the text can be constructed into a {@link JSONArray}.
     *
     * @param json text to check
     * @return {@code true} if constructing a {@link JSONArray} from it succeeds,
     *         {@code false} if it throws {@link JSONException}
     */
    public static boolean isJsonArray(String json) {
        boolean parsed;
        try {
            new JSONArray(json);
            parsed = true;
        } catch (JSONException e) {
            parsed = false;
        }
        return parsed;
    }
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment