Commit 876dd01a authored by alex yao's avatar alex yao

feat: 优化变量记忆方法,支持多属性同时配置

parent 3e64d624
......@@ -30,6 +30,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResp
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.GetLongMemory;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.LongMemoryEntity;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.GetMemoryVariable;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.GetValueMemory;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.SetValueMemoryConstants;
import cn.com.poc.thirdparty.service.LLMService;
......@@ -608,8 +609,8 @@ public class AgentApplicationInfoServiceImpl implements AgentApplicationInfoServ
}
// 变量记忆
if (LargeModelFunctionEnum.set_value_memory.name().equals(name)) {
Map<Object, Object> map = GetValueMemory.get(identifier + ":" + agentId);
if (LargeModelFunctionEnum.memory_variable_writer.name().equals(name)) {
Map<Object, Object> map = GetMemoryVariable.get(identifier + ":" + agentId);
StringBuilder stringBuilder = new StringBuilder("");
if (MapUtils.isNotEmpty(map)) {
Set<Object> keySet = map.keySet();
......
......@@ -17,6 +17,7 @@ import cn.com.poc.support.security.oauth.entity.UserBaseEntity;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Tool;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryConstants;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.GetMemoryVariable;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.GetValueMemory;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.com.yict.framemax.data.model.PagingInfo;
......@@ -218,21 +219,21 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
List<Tool> tools = new ArrayList<>();
//开启对话变量
if (CollectionUtils.isNotEmpty(infoEntity.getVariableStructure())) {
String functionName = LargeModelFunctionEnum.set_value_memory.name();
String functionName = LargeModelFunctionEnum.memory_variable_writer.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getVariableStructureLLMConfig(infoEntity.getVariableStructure()).get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
//初始化变量函数
Map<Object, Object> map = GetValueMemory.get(agentId + ":" + agentId);
Map<Object, Object> map = GetMemoryVariable.get(agentId + ":" + agentId);
List<Variable> variableStructure = infoEntity.getVariableStructure();
if (MapUtils.isEmpty(map)) {
for (Variable variable : variableStructure) {
String key = variable.getKey();
String variableDefault = variable.getVariableDefault();
JSONObject jsonObject = new JSONObject();
jsonObject.put("contentName", key);
jsonObject.put("contentValue", variableDefault);
jsonObject.put("key", key);
jsonObject.put("value", variableDefault);
LargeModelFunctionEnum.valueOf(functionName).getFunction().doFunction(jsonObject.toJSONString(), agentId + ":" + agentId);
}
}
......@@ -347,7 +348,7 @@ public class AgentApplicationInfoRestImpl implements AgentApplicationInfoRest {
public List<AgentApplicationValueMemoryDto> getVariableList(String agentId) {
List<AgentApplicationValueMemoryDto> result = new ArrayList<>();
BizAgentApplicationInfoEntity infoEntity = bizAgentApplicationInfoService.getByAgentId(agentId);
Map<Object, Object> map = GetValueMemory.get(agentId + ":" + agentId);
Map<Object, Object> map = GetMemoryVariable.get(agentId + ":" + agentId);
List<Variable> variableStructure = infoEntity.getVariableStructure();
if (MapUtils.isEmpty(map)) {
if (CollectionUtils.isEmpty(variableStructure)) {
......
......@@ -28,6 +28,7 @@ import cn.com.poc.thirdparty.resource.demand.ai.constants.LLMRoleEnum;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelDemandResult;
import cn.com.poc.thirdparty.resource.demand.ai.entity.largemodel.LargeModelResponse;
import cn.com.poc.thirdparty.resource.demand.ai.function.LargeModelFunctionEnum;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.GetMemoryVariable;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.GetValueMemory;
import cn.com.poc.thirdparty.service.LLMService;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
......@@ -381,21 +382,21 @@ public class AgentApplicationServiceImpl implements AgentApplicationService {
List<Tool> tools = new ArrayList<>();
//开启对话变量
if (CollectionUtils.isNotEmpty(infoEntity.getVariableStructure())) {
String functionName = LargeModelFunctionEnum.set_value_memory.name();
String functionName = LargeModelFunctionEnum.memory_variable_writer.name();
String llmConfig = LargeModelFunctionEnum.valueOf(functionName).getFunction().getVariableStructureLLMConfig(infoEntity.getVariableStructure()).get(0);
Tool tool = JsonUtils.deSerialize(llmConfig, Tool.class);
tools.add(tool);
//初始化变量函数
Map<Object, Object> map = GetValueMemory.get(identifier + ":" + infoEntity.getAgentId());
Map<Object, Object> map = GetMemoryVariable.get(identifier + ":" + infoEntity.getAgentId());
if (MapUtils.isEmpty(map)) {
List<Variable> variableStructure = infoEntity.getVariableStructure();
for (Variable variable : variableStructure) {
String key = variable.getKey();
String variableDefault = variable.getVariableDefault();
JSONObject jsonObject = new JSONObject();
jsonObject.put("contentName", key);
jsonObject.put("contentValue", variableDefault);
jsonObject.put("key", key);
jsonObject.put("value", variableDefault);
LargeModelFunctionEnum.valueOf(functionName).getFunction().doFunction(jsonObject.toJSONString(), identifier + ":" + infoEntity.getAgentId());
}
}
......
......@@ -2,12 +2,13 @@ package cn.com.poc.thirdparty.resource.demand.ai.function;
import cn.com.poc.common.utils.SpringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.SetValueMemoryFunction;
public enum LargeModelFunctionEnum {
set_long_memory(SetLongMemoryFunction.class),
set_value_memory(SetValueMemoryFunction.class),
memory_variable_writer(MemoryVariableWriterFunction.class),
;
private Class<? extends AbstractLargeModelFunction> function;
......
package cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.SpringUtils;
import java.util.Map;
public class GetMemoryVariable {
public static Map<Object, Object> get(String key) {
String redisKey = MemoryVariableWriterConstants.REDIS_PREFIX + key + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
RedisService redisService = SpringUtils.getBean(RedisService.class);
if (!redisService.hasKey(redisKey)) {
return null;
}
Map<Object, Object> result = redisService.hmget(redisKey);
return result;
}
}
package cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer;
public interface MemoryVariableWriterConstants {
String REDIS_PREFIX = "AGENT_APP_FUNCTION:MEMORY_VARIABLE:";
}
package cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.common.service.RedisService;
import cn.com.poc.common.utils.BlContext;
import cn.com.poc.common.utils.JsonUtils;
import cn.com.poc.common.utils.StringUtils;
import cn.com.poc.thirdparty.resource.demand.ai.function.AbstractLargeModelFunction;
import cn.com.yict.framemax.core.i18n.I18nMessageException;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONException;
import cn.hutool.json.JSONObject;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
public class MemoryVariableWriterFunction extends AbstractLargeModelFunction {
private final String DESC = "该方法仅用enum给定的内容名来保存用户想记录的内容值,不可使用该方法进行查询";
private final String LLM_JSON_SCHEMA = "{\n" +
" \"type\": \"function\",\n" +
" \"function\": {\n" +
" \"name\": \"memory_variable_writer\",\n" +
" \"parameters\": {\n" +
" \"type\": \"array\",\n" +
" \"properties\": {\n" +
" \"key\": {\n" +
" \"type\": \"string\",\n" +
" \"enum\": ${variableStructure},\n" +
" \"description\": \"内容键\"\n" +
" },\n" +
" \"value\": {\n" +
" \"type\": \"string\",\n" +
" \"description\": \"内容值\"\n" +
" }\n" +
" }\n" +
" }\n" +
" },\n" +
" \"description\": \"" + DESC + "\"\n" +
" }";
private final Long expireTime = 30 * 60 * 24L; // 30天有效期
@Resource
private RedisService redisService;
@Override
public String doFunction(String content, String identifier) {
if (StringUtils.isBlank(content) || StringUtils.isBlank(identifier)) {
return "FAIL";
}
String contentKey = MemoryVariableWriterConstants.REDIS_PREFIX + identifier + ":" + BlContext.getCurrentUserNotException().getUserId().toString();
// 创建 JSONObject 对象
Map<String, Object> result = new HashMap<>();
if (redisService.hasKey(contentKey)) {
Map<Object, Object> hmget = redisService.hmget(contentKey);
for (Map.Entry<Object, Object> entry : hmget.entrySet()) {
if (entry.getKey() instanceof String) {
String tempKey = (String) entry.getKey();
result.put(tempKey, entry.getValue());
}
}
}
if (isJsonArray(content)) {
JSONArray jsonArray = new JSONArray(content);
for (int i = 0; i < jsonArray.size(); i++) {
setMap(jsonArray.getJSONObject(i), result);
}
} else {
setMap(new JSONObject(content), result);
}
redisService.hmset(contentKey, result, expireTime);
return "SUCCESS";
}
private static void setMap(JSONObject jsonObject, Map<String, Object> result) {
String key = jsonObject.getStr("key");
String value = jsonObject.getStr("value");
result.put(key, value);
}
@Override
public String getDesc() {
return DESC;
}
@Override
public List<String> getLLMConfig() {
throw new I18nMessageException("exception/this.method.is.not.supported");
}
@Override
public List<String> getVariableStructureLLMConfig(List<Variable> variableStructure) {
List<String> enumList = new ArrayList<>();
for (Variable variable : variableStructure) {
enumList.add(variable.getKey());
}
String enums = JsonUtils.serialize(enumList);
String configStr = LLM_JSON_SCHEMA.replace("${variableStructure}", enums);
List<String> resultList = new ArrayList<>();
resultList.add(configStr);
return resultList;
}
private boolean isJsonArray(String json) {
try {
new JSONArray(json);
return true;
} catch (JSONException e) {
return false;
}
}
}
......@@ -7,20 +7,23 @@ import cn.com.poc.common.utils.BlContext;
import cn.com.poc.expose.aggregate.AgentApplicationService;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.DemandKnowledgeService;
import cn.com.poc.thirdparty.resource.demand.ai.function.long_memory.SetLongMemoryConstants;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.GetValueMemory;
import cn.com.poc.thirdparty.resource.demand.ai.function.value_memory.SetValueMemoryConstants;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import cn.com.yict.framemax.security.oauth.support.UsernameOauthTokenAuthenticationToken;
import com.google.common.collect.Lists;
import org.apache.commons.collections4.ListUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextImpl;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;
@RunWith(SpringJUnit4ClassRunner.class)
......@@ -33,14 +36,14 @@ public class AgentApplicationInfoTest {
private AgentApplicationInfoService applicationInfoService;
@Resource
private DemandKnowledgeService demandKnowledgeService;
private RedisService redisService;
@Resource
private RedisService redisService;
private AgentApplicationService agentApplicationService;
@Test
public void del(){
String contentKey = SetValueMemoryConstants.REDIS_PREFIX ;
public void del() {
String contentKey = SetValueMemoryConstants.REDIS_PREFIX;
redisService.del(contentKey);
}
......@@ -64,17 +67,27 @@ public class AgentApplicationInfoTest {
@Test
public void test() {
List<Object> list = Lists.newArrayList("1", "2", "3", "4", "5", "6", "7", "8", "9", "1");
redisService.lSet("key",list);
redisService.lSet("key", list);
// redisService.sSet("key",1);
}
@Resource
private AgentApplicationService agentApplicationService;
@Test
public void test2() throws InterruptedException {
public void createRecommendQuestion() throws InterruptedException {
agentApplicationService.createRecommendQuestion();
}
@Test
public void getValueMemory() {
String agentId = "d56d9a6a90ed435d831f84bc1af50547";
String contentKey = SetValueMemoryConstants.REDIS_PREFIX + agentId + ":" + "204";
System.out.println(redisService.hmget(contentKey));
// Set<Object> keySet = result.keySet();
// for (Object key : keySet) {
// Object value = result.get(key);
// System.out.println("key:" + key + ",value:" + value);
// }
}
}
package cn.com.poc.demand;
import cn.com.poc.agent_application.entity.Variable;
import cn.com.poc.thirdparty.resource.demand.ai.aggregate.AIDialogueService;
import cn.com.poc.thirdparty.resource.demand.ai.entity.dialogue.Function;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResponse;
import cn.com.poc.thirdparty.resource.demand.ai.entity.function.FunctionCallResult;
import cn.com.poc.thirdparty.resource.demand.ai.function.memory_variable_writer.MemoryVariableWriterFunction;
import cn.com.yict.framemax.core.spring.SingleContextInitializer;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONException;
import com.alibaba.fastjson.JSONObject;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.test.context.ContextConfiguration;
......@@ -59,4 +64,46 @@ public class AiDialogueTest {
System.out.println(functionCallResult);
}
@Resource
private MemoryVariableWriterFunction memoryVariableWriterFunction;
@Test
public void getMemoryVariableWriterFunctionConfig() {
Variable name = new Variable();
name.setKey("name");
name.setVariableDefault("");
Variable age = new Variable();
age.setKey("age");
age.setVariableDefault("");
List<Variable> variableStructure = new ArrayList<>();
variableStructure.add(name);
variableStructure.add(age);
List<String> variableStructureLLMConfig = memoryVariableWriterFunction.getVariableStructureLLMConfig(variableStructure);
System.out.println(variableStructureLLMConfig);
}
@Test
public void testJsonArray() {
String json1 = "[{\"key\": \"name\", \"value\": \"roger\"}, {\"key\": \"age\", \"value\": 12}]";
String json2 = "{\"key\": \"name\", \"value\": \"roger\"}";
System.out.println(isJsonArray(json1));
System.out.println(isJsonArray(json2));
}
public static boolean isJsonArray(String json) {
try {
new JSONArray(json);
return true;
} catch (JSONException e) {
return false;
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment