(V2)SpringAI——会话记忆优化版本
前言
会话历史记忆原版本采用直接的原始信息存储,过长的记忆可能导致会话效果降低。为了优化,我们采用算法端提供的五轮原始会话+历史摘要。
优化版本流程
- 用户打字、说话,前端携带生成的文本内容请求后端
- 后端接受请求,调用大模型
- 检查历史记忆,非摘要信息是否达到了五轮
- 达到了五轮,触发优化逻辑,将五轮会话进行历史总结,进行压缩
- 压缩的信息存储到另一个位置,同时永久出现在会话记忆中
- 开启新的会话记忆循环
- 大模型生成回复,返回前端
- 前端展示大模型回复

优化版本的优势
- 通过摘要设计,在控制会话记忆长度的同时,提升大模型的回复效果
- 同时让大模型的响应速度变得更加平稳
- 保证大模型的回复质量,避免上下文截断或者上下文过长
- 适合于对大模型准确度有一定要求的场景
后端的实现逻辑
1.ChatMemoryRepository实现类技术选型
- 会话记忆存储属于高频率读写,每次会话都会进行读写操作,Redis的内存模型天然适配当前业务场景。
- 但是SpringAI目前不提供Redis的底层实现,我们可以自己进行设计,构造出来符合当前优化场景的会话记忆逻辑。
后端在原来已经实现的RedisChatMemoryRepository设计的基础上,进行进一步优化。
原RedisChatMemoryRepository设计:SpringAI的初步用法
为什么选择Redis作为会话记忆存储?
2.Redis实现类技术架构设计
原RedisChatMemoryRepository设计
- 对于获取所有的conversationId,额外使用list数据结构进行存储,方便直接读取使用
- 对于读写的Message,统一使用自定义的业务模型MessageInfo,同时进行适当转换。
- 对于Message读写,使用String类型,便于全量存储
优化RedisChatMemoryRepository设计
- 保留原来的设计
- 同时新增会话摘要,在存储会话历史的时候,检查是否达到了我们要求的条数。
- 达到阈值后,我们进行会话记忆优化
- 对当前的n条原始信息,我们不进行存储,转而调用总结模型进行摘要总结
- 将摘要总结放入会话摘要专栏
- 优化过程采取消息队列进行异步解耦
3. 坑点说明
- 我们通过接口实现findByConversationId方法,主要是查找包括摘要在内的所有信息
- 但是这里进行查找后,我们的SpringAI框架会进行max-message-size检查并进行截断
- 后续才会进入saveAll方法,这时候出现的情况就是,如果我们的触发阈值与SpringAI的一样,就会被截断
- 因此我们需要额外维护我们设置的阈值,防止SpringAI框架截断我们的所有历史摘要总结
4. 代码示例
4.1 核心代码
package com.project.demo.common.config.ai.common;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONUtil;
import com.project.demo.chat.factory.MessageConvertor;
import com.project.demo.chat.model.bean.MessageInfo;
import com.project.demo.common.factory.MessageFactory;
import com.project.demo.common.model.bean.SummaryMessage;
import com.project.demo.common.model.constant.RabbitMQConstant;
import com.project.demo.common.provider.CacheProvider;
import com.project.demo.common.util.RabbitUtils;
import com.project.demo.common.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* @author by 王玉涛
* @Classname RedisSummeryMemoryRepository
* @Description 基于Redis实现的对话记忆仓库,用于存储和管理AI对话历史记录。
* 支持以下功能:
* 1. 存储、查询、删除对话记录
* 2. 对超出最大容量的对话进行摘要处理,并通过消息队列发送给AI模型生成总结
* 3. 使用缓存提供者(CacheProvider)操作Redis,保证数据读写一致性
* @Date 2025/7/13 13:00
*/
@Repository
@Slf4j
@RequiredArgsConstructor
public class RedisSummeryMemoryRepository implements ChatMemoryRepository {
/**
* RabbitMQ工具类,用于异步发送消息到AI模型服务
*/
private final RabbitUtils rabbitUtils;
/**
* 最大对话记录数,超过该数量后将触发摘要逻辑
* 配置来源:application.yml 中的 doubao.memory.max-messages 属性
*/
@Value("${doubao.memory.max-memory-size}")
private Integer maxSize;
/**
* 存储所有对话ID集合的缓存键
* 用于快速获取系统中所有的 conversationId
*/
private static final String AI_MEMORY_CONVERSATION_IDS_KEY = "ai:memory:conversation:ids";
/**
* 对话记录的缓存键前缀
* 后接具体的 conversationId 构成完整键名,如:ai:memory:conversation:12345
*/
private static final String AI_MEMORY_CONVERSATION_KEY_PREFIX = "ai:memory:conversation:";
/**
* 对话摘要的缓存键前缀
* 后接具体的 conversationId 构成完整键名,如:ai:memory:summary:12345
*/
private static final String AI_MEMORY_SUMMARY_KEY_PREFIX = "ai:memory:summary:";
/**
* 缓存提供者接口实例,封装了对Redis的基本操作
*/
private final CacheProvider cacheProvider;
/**
* 获取所有存在的对话ID列表
*
* @return 包含所有 conversationId 的字符串列表
*/
@Override
public List<String> findConversationIds() {
log.info("基于Redis内存的聊天历史记忆:查询所有的id");
return cacheProvider.getSet(AI_MEMORY_CONVERSATION_IDS_KEY, String.class);
}
/**
* 根据指定的 conversationId 查询完整的对话内容
*
* @param conversationId 对话唯一标识符
* @return Spring AI 的 Message 列表,表示整个对话历史
* 因为当前查找的信息,会被截断,所以说,我们目前有一个比较确定的是:我们自己设置的对话窗口以及上下文限制,不能一样。
*/
@Override
public List<Message> findByConversationId(String conversationId) {
log.info("基于Redis内存的聊天历史记忆:查询id为{}的历史记忆+信息摘要", conversationId);
// 从 Redis 中获取原始对话
String data = cacheProvider.getString(AI_MEMORY_CONVERSATION_KEY_PREFIX, conversationId);
List<String> summaryData = cacheProvider.getListString(AI_MEMORY_SUMMARY_KEY_PREFIX, conversationId);
// 如果主数据和摘要都为空,则直接返回空列表,避免后续无意义处理
if (StringUtils.isNullOrEmpty(data) && CollectionUtil.isEmpty(summaryData)) {
return Collections.emptyList();
}
// 将主数据反序列化为 MessageInfo 列表
List<MessageInfo> myMessages = new ArrayList<>();
if (!StringUtils.isNullOrEmpty(data)) {
myMessages.addAll(JSONUtil.toList(data, MessageInfo.class));
}
// 将摘要字符串转换为 MessageInfo 列表
List<MessageInfo> messageInfos = summaryData.stream()
.map(MessageConvertor::contentToMyMessage)
.toList();
// 合并主消息和摘要消息
myMessages.addAll(messageInfos);
// 最终将 MessageInfo 转换为 Spring AI 的 Message 类型并返回
return myMessages.stream()
.map(MessageConvertor::myToAiMessage)
.toList();
}
/**
* 保存与指定 conversationId 关联的所有对话消息
*
* @param conversationId 对话唯一标识符
* @param messages 要保存的 Spring AI Message 列表
*/
@Override
public void saveAll(String conversationId, List<Message> messages) {
List<MessageInfo> messageInfos = messages.stream().map(MessageConvertor::aiMessageToMy).toList();
if (countSystemSize(messageInfos) >= maxSize) {
log.info("基于Redis内存的聊天历史记忆:对话记录数超出最大限制,进行摘要总结并停止记录、删除记忆");
String key = AI_MEMORY_SUMMARY_KEY_PREFIX + conversationId;
SummaryMessage summaryMessage = MessageFactory.createSummaryRequest(key, messageInfos);
rabbitUtils.sendToQueue(
RabbitMQConstant.AI_MODEL_EXCHANGE,
RabbitMQConstant.AI_MODEL_SUMMARY_KEY,
summaryMessage
);
key = AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId;
cacheProvider.deleteByKey(key);
return;
}
log.info("基于Redis内存的聊天历史记忆:保存id为{}的历史记忆", conversationId);
// 添加 conversationId 到全局对话ID集合中
cacheProvider.addSet(AI_MEMORY_CONVERSATION_IDS_KEY, conversationId);
// 序列化消息列表并保存到对应 conversationId 的缓存中
List<MessageInfo> myMessages = messages.stream()
.map(MessageConvertor::aiMessageToMy)
.toList();
myMessages = excludeSystemMessages(myMessages);
cacheProvider.set(AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId, JSONUtil.toJsonStr(myMessages));
}
/**
* 删除与指定 conversationId 关联的所有对话记录
*
* @param myMessages 待处理的消息列表
* @return
*/
List<MessageInfo> excludeSystemMessages(List<MessageInfo> myMessages) {
int size = myMessages.size();
List<MessageInfo> messageInfos = new ArrayList<>(size);
for (MessageInfo messageInfo : myMessages) {
if (Objects.isNull(messageInfo)) {
continue;
}
if (messageInfo.getSenderIdentity() != 3) {
messageInfos.add(messageInfo);
}
}
return messageInfos;
}
/**
* 计算系统消息的个数
*
* @param messageInfos 消息列表
* @return 系统消息的个数
*/
private int countSystemSize(List<MessageInfo> messageInfos) {
int count = 0;
for (MessageInfo messageInfo : messageInfos) {
if (messageInfo.getSenderIdentity() != 3) {
count++;
}
}
return count;
}
/**
* 删除与指定 conversationId 关联的所有对话记录
*
* @param conversationId 要删除的对话唯一标识符
*/
@Override
public void deleteByConversationId(String conversationId) {
log.info("基于Redis内存的聊天历史记忆:删除id为{}的历史记忆", conversationId);
// 删除该 conversationId 对应的对话记录和全局ID集合中的条目
cacheProvider.deleteByKey(AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId);
cacheProvider.deleteByKey(AI_MEMORY_SUMMARY_KEY_PREFIX + conversationId);
cacheProvider.deleteByKey(AI_MEMORY_CONVERSATION_IDS_KEY, conversationId);
}
}
4.2 CacheProvider实现类代码
package com.project.demo.common.provider.impl;
import com.project.demo.common.exception.UserException;
import com.project.demo.common.model.constant.RedisConstant;
import com.project.demo.common.model.constant.RedisTtlConstant;
import com.project.demo.common.provider.CacheProvider;
import com.project.demo.common.util.StringUtils;
import com.project.demo.user.constant.UserErrorConstant;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
* @className: CacheRedisProvider
* @author: 顾漂亮
* @date: 2025/5/27 15:14
*/
@Slf4j
@Component
@RequiredArgsConstructor
public class CacheRedisProvider implements CacheProvider {
/**
* RedisTemple : 先将被存储的数据转换成字节数组(不可读), 再存储到redis中,读取的时候按照字节数组读取
* StringRedisTemplate: 直接存放的就是String(可读)
*/
private final StringRedisTemplate redisTemplate;
/**
* 向redis中存入值
* @param key
* @param value
* @return
*/
@Override
public boolean set(String key, String value){
try{
redisTemplate.opsForValue().set(key, value);
return true;
}catch (Exception e){
log.error("RedisUtil error, set{}, {}", key, value,e);
return false;
}
}
/**
* 向redis中存入值,并且设置过期时间
* @param key
* @param value
* @param time 过期时间,单位秒
* @return
*/
@Override
public boolean set(String key, String value, Long time){
try{
redisTemplate.opsForValue().set(key, value);
return true;
}catch (Exception e){
log.error("RedisUtil error, set{}, {}, {}", key, value,time, e);
return false;
}
}
/**
* 获取redis中的值
* @param key
* @return
*/
@Override
public String get(String key){
try{
return !StringUtils.isNullOrEmpty(key)
? redisTemplate.opsForValue().get(key)
: null;
}catch (Exception e){
log.error("RedisUtil error, get({})", key, e);
return null;
}
}
/**
* 获取redis中的值
* @return
*/
@Override
public String getString(String keyPrefix, String keyEnd){
String key = keyPrefix + keyEnd;
return redisTemplate.opsForValue().get(key);
}
/**
* 获取指定键对应的列表值
*
* @param key 缓存键
* @param clazz 列表元素类型
* @return 列表值
*/
@Override
public <T> List<T> getList(String key, Class<T> clazz) {
List<String> stringList = redisTemplate.opsForList().range(key, 0, -1);
if (CollectionUtils.isEmpty(stringList)) {
return Collections.emptyList();
}
return stringList
.stream()
.map(string -> StringUtils.jsonToClass(string, clazz))
.toList();
}
/**
* 添加列表元素
*
* @param key 缓存键
* @param value 列表元素
*/
@Override
public void addList(String key, String value) {
redisTemplate.opsForList().rightPush(key, value);
}
/**
* 添加集合元素
*
* @param key 缓存键
* @param value 集合元素
*/
@Override
public void addSet(String key, String value) {
redisTemplate.opsForSet().add(key, value);
}
/**
* 获取指定键对应的集合值
*
* @param key 缓存键
* @param stringClass 集合元素类型
* @return 集合值
*/
@Override
public List<String> getSet(String key, Class<String> stringClass) {
Set<String> members = redisTemplate.opsForSet().members(key);
if (CollectionUtils.isEmpty(members)) {
return Collections.emptyList();
}
return members.stream().toList();
}
/**
* 获取指定前缀和后缀的列表值
*
* @param keyPrefix 前缀
* @param keyEnd 后缀
* @return 列表值
*/
@Override
public List<String> getListString(String keyPrefix, String keyEnd) {
List<String> stringList = redisTemplate.opsForList().range(keyPrefix + keyEnd, 0, -1);
if (CollectionUtils.isEmpty(stringList)) {
return Collections.emptyList();
}
return stringList;
}
/**
* 删除redis中的值
* @param key String... 本质上还是数组,但是对于参数的传递来说,参数的个数是可变参数,所以可以传入多个参数
* @return
*/
@Override
public boolean delete(String... key){
try {
if(key != null && key.length > 0){
if(key.length == 1){
redisTemplate.delete(key[0]);
}else{
redisTemplate.delete((Collection<String>) CollectionUtils.arrayToList(key));
}
}
return true;
}catch (Exception e){
log.error("RedisUtil error, delete({})", key, e);
return false;
}
}
/**
* 获取指定键对应的值,并删除该键。
*
* @param key 缓存键
* @return 缓存值,如果不存在则返回 null
*/
@Override
public void deleteByKey(String key) {
redisTemplate.delete(key);
}
/**
* 删除指定键对应的值,并验证值是否匹配。
*
* @param key 缓存键
* @param value 缓存值
*/
@Override
public void deleteByKey(String key, String value) {
/**
* 这行代码 redisTemplate.opsForList().remove(key, 0, value);
* 会从 Redis 中的一个列表(由 key 指定)移除所有等于 value 的元素,
* 且最多移除一次(因为 count = 0)
*/
redisTemplate.opsForList().remove(key, 0, value);
}
@Override
public void validateAndStorageCode(String phone, String code) {
// 1. 获取key
String key = RedisConstant.PHONE_CODE_PREFIX + phone;
// 2. 获取过期时间
Long expire = redisTemplate.getExpire(key, TimeUnit.SECONDS);
// 3. 检验频繁操作
if (expire > RedisTtlConstant.MIN_PHONE_TTL) {
throw new UserException(UserErrorConstant.FREQUENT_OPERATIONS);
}
// 4. 覆盖存储
redisTemplate.opsForValue().set(
key,
code,
RedisTtlConstant.PHONE_CODE_TTL,
TimeUnit.SECONDS);
}
}
4.3 Consumer代码
/**
* @author by 王玉涛
* @Classname AiModelConsumer
* @Description AI模型消息消费者,用于处理AI对话摘要任务。
* 从RabbitMQ队列中消费待处理的对话消息,调用豆包大模型生成对话摘要,
* 并将生成的摘要内容存储到缓存中。
* @Date 2025/7/13 14:30
*/
@Component
@RequiredArgsConstructor
@Slf4j
public class AiModelConsumer {
/**
* 豆包AI提供者,用于获取历史对话摘要客户端
*/
private final DoubaoAiProvider doubaoAiProvider;
/**
* 缓存服务提供者,用于存储生成的摘要内容
*/
private final CacheProvider cacheProvider;
/**
* 监听AI模型摘要队列,处理对话摘要任务
*
* @param summaryMessage 包含对话历史和存储键的摘要消息对象
*/
@RabbitListener(queues = RabbitMQConstant.AI_MODEL_SUMMARY_QUEUE, concurrency = "5")
public void handleAiModelSummary(SummaryMessage summaryMessage) {
log.info("处理消息: {}", summaryMessage);
// 使用豆包AI的历史对话摘要客户端生成摘要内容
String summaryContent = doubaoAiProvider.historySummaryClient()
.prompt(summaryMessage.myMessagesToString())
.call()
.content();
log.info("摘要内容: {}", summaryContent);
// 将生成的摘要内容添加到缓存列表中
cacheProvider.addList(summaryMessage.getStorageKey(), summaryContent);
log.info("存储摘要内容成功");
}
}
#SpringBoot##SpringAI##Java#