Commit fa9af058 authored by alex yao's avatar alex yao

refactor: 【长期记忆】【变量】function的实现

parent d6ea0e9f
......@@ -9,6 +9,7 @@ import cn.com.poc.agent_application.service.BizAgentApplicationGcConfigService;
import cn.com.poc.agent_application.service.BizAgentApplicationInfoService;
import cn.com.poc.agent_application.service.BizAgentApplicationLargeModelListService;
import cn.com.poc.agent_application.service.BizAgentApplicationPublishService;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.knowledge.entity.BizKnowledgeDocumentEntity;
......@@ -22,12 +23,15 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.generations.BaiduAISailsT
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.ai.function.*;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.LongMemoryEntity;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.GetLongMemory;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.GetValueMemory;
import cn.com.poc.thirdparty.service.LLMService;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.hutool.core.bean.BeanUtil;
import com.fasterxml.jackson.core.type.TypeReference;
import com.google.gson.Gson;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
......@@ -39,9 +43,7 @@ import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.*;
@Component
public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoService {
......@@ -76,6 +78,9 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
@Resource
private LLMService llmService;
@Resource
private RedisService redisService;
@Override
public boolean updateAndPublish(BizAgentApplicationInfoEntity bizAgentApplicationInfoEntity) throws Exception {
UserBaseEntity userBaseEntity = BlContext.getCurrentUserNotException();
......@@ -371,17 +376,39 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
List<String> knowledgeResults = demandKnowledgeService.searchKnowledge(messages.get(messages.size() - 1).getContent().get(0).getText(), knowledgeIds, 3);
promptTemplate = promptTemplate.replace("${knowledgeResults}", knowledgeResults.toString());
}
// todo 获取记忆
// 记忆
if (ArrayUtils.isNotEmpty(tools)) {
for (Tool tool : tools) {
String name = tool.getFunction().getName();
if ("set_long_memory".equals(name)) {
String searchMemoryContent = LargeModelFunctionEnum.valueOf("search_memory_content").getFunction().doFunction(null, identifier);
promptTemplate = promptTemplate.replace("${longMemory}", searchMemoryContent);
// 长期记忆
if (LargeModelFunctionEnum.set_long_memory.name().equals(name)) {
List<LongMemoryEntity> longMemoryEntities = GetLongMemory.get(identifier);
if (CollectionUtils.isNotEmpty(longMemoryEntities)) {
StringBuilder stringBuilder = new StringBuilder();
for (LongMemoryEntity longMemoryEntity : longMemoryEntities) {
stringBuilder
.append("Time").append(":").append(longMemoryEntity.getTimestamp())
.append(StringUtils.SPACE)
.append("Content").append(":").append(longMemoryEntity.getContent())
.append(StringUtils.LF);
}
String searchMemoryContent = stringBuilder.toString();
promptTemplate = promptTemplate.replace("${longMemoryResult}", searchMemoryContent);
}
}
if ("set_value_memory".equals(name)) {
// String searchMemoryContent = LargeModelFunctionEnum.valueOf("search_memory_content").getFunction().doFunction(null, identifier);
// 变量
if (LargeModelFunctionEnum.set_value_memory.name().equals(name)) {
Map<Object, Object> map = GetValueMemory.get(identifier);
StringBuilder stringBuilder = new StringBuilder();
if (MapUtils.isNotEmpty(map)) {
Set<Object> keySet = map.keySet();
for (Object key : keySet) {
stringBuilder.append(key.toString()).append(":").append(map.get(key).toString()).append(StringUtils.LF);
}
}
promptTemplate = promptTemplate.replace("${valueMemoryResult}", stringBuilder.toString());
}
}
......
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.common.utils.SpringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.SetValueMemoryFunction;
public enum LargeModelFunctionEnum {
set_long_memory(SetLongMemoryFunction.class),
set_value_memory(SetValueMemoryFunction.class),
search_memory_content(SearchMemoryContentFunction.class),
search_memory_content_by_enum(SearchMemoryContentByNameFunction.class);
;
private Class<? extends AbstractLargeModelFunction> function;
LargeModelFunctionEnum(Class<? extends AbstractLargeModelFunction> function) {
......
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.hutool.json.JSONObject;
import com.google.gson.Gson;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
public class SearchMemoryContentByNameFunction extends AbstractLargeModelFunction {
@Resource
private RedisService redisService;
@Override
public String doFunction(String content, String key) {
// 用enum给定的内容名来查询用户相关信息
// 创建 JSONObject 对象
JSONObject jsonObject = new JSONObject(content);
String contentName = jsonObject.getStr("content");
StringBuffer result = new StringBuffer();
// 先查询变量记忆
String contentKey = key + ":" + BlContext.getCurrentUserNotException().getUserId().toString() + ":" + contentName;
result.append(redisService.get(contentKey));
// 如果短期记忆没查到
return result.toString();
}
@Override
public List<String> getLLMConfig() {
return null;
}
@Override
public List<String> getVariableStructureLLMConfig(String[] variableStructure) {
Map<String, Object> config = new HashMap<>();
Map<String, Object> function = new HashMap<>();
Map<String, Object> parameters = new HashMap<>();
parameters.put("type", "object");
List<String> required = new ArrayList<>();
required.add("content");
parameters.put("required", required);
// 根据变量名查询记忆方法
Map<String, Object> content = new HashMap<>();
content.put("type", "string");
content.put("description", "内容名");
content.put("enum", variableStructure); // 设置变量
Map<String, Object> searchProperties = new HashMap<>();
searchProperties.put("content", content);
parameters.put("properties", searchProperties);
parameters.put("type","object");
function.put("name", "search_memory_content_by_Enum");
function.put("description", "用enum给定的内容名来查询用户信息(什么内容都可以)");
function.put("parameters", parameters);
config.put("type", "function");
config.put("function", function);
// 将 Map 转换为 JSON 字符串
Gson gson = new Gson();
String jsonString = gson.toJson(config);
List<String> resultList = new ArrayList<>();
resultList.add(jsonString);
return resultList;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import com.google.gson.Gson;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
public class SearchMemoryContentFunction extends AbstractLargeModelFunction {
@Resource
private RedisService redisService;
@Override
public String doFunction(String content, String key) {
// 查询用户相关信息(什么内容都可以)
String result;
String longMemoryKey = key + ":" + BlContext.getCurrentUserNotException().getUserId().toString() + ":" + "longMemory";
result = redisService.hmget(longMemoryKey).toString();
return result;
}
@Override
public List<String> getLLMConfig() {
Map<String, Object> config = new HashMap<>();
Map<String, Object> function = new HashMap<>();
Map<String, Object> content = new HashMap<>();
content.put("type", "string");
content.put("description","信息说明");
Map<String, Object> properties = new HashMap<>();
properties.put("content", content);
Map<String, Object> parameters = new HashMap<>();
parameters.put("type", "object");
parameters.put("properties",properties);
List<String> required = new ArrayList<>();
required.add("content");
parameters.put("required", required);
function.put("name", "search_memory_content");
function.put("description", "获取用户相关信息");
function.put("parameters", parameters);
config.put("type", "function");
config.put("function", function);
// 将 Map 转换为 JSON 字符串
Gson gson = new Gson();
String jsonString = gson.toJson(config);
List<String> resultList = new ArrayList<>();
resultList.add(jsonString);
return resultList;
}
@Override
public List<String> getVariableStructureLLMConfig(String[] variableStructure) {
return null;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.long_memory;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.SpringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class GetLongMemory {
public static List<LongMemoryEntity> get(String key) {
RedisService redisService = SpringUtils.getBean(RedisService.class);
List<LongMemoryEntity> result = new ArrayList<>();
// 查询用户相关信息(什么内容都可以)
String contentKey = SetLongMemoryConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
Map<Object, Object> map = redisService.hmget(contentKey);
Set<Object> keySet = map.keySet();
for (Object mapKey : keySet) {
LongMemoryEntity entity = new LongMemoryEntity();
entity.setContent(map.get(mapKey).toString());
entity.setTimestamp(mapKey.toString());
result.add(entity);
}
return result;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.long_memory;
import java.io.Serializable;
import java.util.Date;
public class LongMemoryEntity implements Serializable {
private String content;
private String timestamp;
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public String getTimestamp() {
return timestamp;
}
public void setTimestamp(String timestamp) {
this.timestamp = timestamp;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.long_memory;
public interface SetLongMemoryConstants {
String REDIS_PREFIX = "AGENT_APP_FUNCTION:LONG_MEMORY:";
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
package cn.com.poc.thirdparty.resource.demand.ai.function.long_memory;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.DateUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.hutool.json.JSONObject;
import com.google.gson.Gson;
import com.sun.org.apache.regexp.internal.RE;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
......@@ -13,6 +16,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* 长期记忆
*/
@Service
public class SetLongMemoryFunction extends AbstractLargeModelFunction {
......@@ -20,15 +26,22 @@ public class SetLongMemoryFunction extends AbstractLargeModelFunction {
@Resource
private RedisService redisService;
/**
* 执行函数
*
* @param content 内容
* @param key 唯一标识
* @return
*/
@Override
public String doFunction(String content, String key) {
// todo 执行保存长期记忆的操作
// 创建 JSONObject 对象
JSONObject jsonObject = new JSONObject(content);
// 提取 content
String contents = jsonObject.getStr("content");
String contentKey = key + ":" + BlContext.getCurrentUserNotException().getUserId().toString() + ":" + "longMemory";
Map<Object, Object> hmget = redisService.hmget(key);
String contentKey = SetLongMemoryConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
Map<Object, Object> hmget = redisService.hmget(contentKey);
Map<String, Object> result = new HashMap<>();
for (Map.Entry<Object, Object> entry : hmget.entrySet()) {
if (entry.getKey() instanceof String) {
......@@ -37,10 +50,7 @@ public class SetLongMemoryFunction extends AbstractLargeModelFunction {
result.put(tempKey, entry.getValue());
}
}
List<String> list = new ArrayList<>();
list.add("timestamp:" + DateUtils.getCurrTime());
list.add("content:" + contents);
result.put(Integer.toString(hmget.size()), list);
result.put(DateUtils.getCurrTime(), contents);
redisService.hmset(contentKey, result);
return "SUCCESS";
}
......@@ -84,6 +94,6 @@ public class SetLongMemoryFunction extends AbstractLargeModelFunction {
@Override
public List<String> getVariableStructureLLMConfig(String[] variableStructure) {
return null;
throw new BusinessException("暂不支持变量结构配置");
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.value_memory;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.SpringUtils;
import java.util.Map;
/**
* 获取【变量】内容
*/
public class GetValueMemory {
public static Map<Object, Object> get(String key) {
String contentKey = SetValueMemoryConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
RedisService redisService = SpringUtils.getBean(RedisService.class);
if (!redisService.hasKey(contentKey)) {
return null;
}
Map<Object, Object> result = redisService.hmget(contentKey);
return result;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.value_memory;
public interface SetValueMemoryConstants {
String REDIS_PREFIX = "AGENT_APP_FUNCTION:VALUE_MEMORY:";
}
package cn.com.poc.thirdparty.resource.demand.ai.function;
package cn.com.poc.thirdparty.resource.demand.ai.function.value_memory;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.yict.framemax.core.exception.BusinessException;
import cn.hutool.json.JSONObject;
import com.google.gson.Gson;
import org.springframework.stereotype.Service;
......@@ -18,22 +20,32 @@ public class SetValueMemoryFunction extends AbstractLargeModelFunction {
private RedisService redisService;
@Override
//todo 保存【变量】方法重构
public String doFunction(String content, String key) {
// todo 执行保存变量的操作
String contentKey = SetValueMemoryConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
// 创建 JSONObject 对象
JSONObject jsonObject = new JSONObject(content);
// 提取 contentName 和 contentValue
String contentName = jsonObject.getStr("contentName");
String contentValue = jsonObject.getStr("contentValue");
String contentKey = key + ":" + BlContext.getCurrentUserNotException().getUserId().toString() + ":" + contentName;
redisService.set(contentKey, contentValue);
Map<String, Object> result = new HashMap<>();
if (redisService.hasKey(contentKey)) {
Map<Object, Object> hmget = redisService.hmget(contentKey);
for (Map.Entry<Object, Object> entry : hmget.entrySet()) {
if (entry.getKey() instanceof String) {
String tempKey = (String) entry.getKey();
result.put(tempKey, entry.getValue());
}
}
}
result.put(contentName, contentValue);
redisService.hmset(contentKey, result);
return "SUCCESS";
}
@Override
public List<String> getLLMConfig() {
return null;
throw new BusinessException("不支持此方法");
}
@Override
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment