(V2)SpringAI——会话记忆优化版本

前言

会话历史记忆原版本采用直接的原始信息存储,过长的记忆可能导致会话效果降低。为了优化,我们采用算法端提供的五轮原始会话+历史摘要。

优化版本流程

  1. 用户打字、说话,前端携带生成的文本内容请求后端
  2. 后端接受请求,调用大模型
  3. 检查历史记忆,非摘要信息是否达到了五轮
  4. 达到了五轮,触发优化逻辑,将五轮会话进行历史总结,进行压缩
  5. 压缩的信息存储到另一个位置,同时永久出现在会话记忆中
  6. 开启新的会话记忆循环
  7. 大模型生成回复,返回前端
  8. 前端展示大模型回复

优化版本的优势

  1. 通过摘要设计,在控制会话记忆长度的同时,提升大模型的回复效果
  2. 同时让大模型的响应速度变得更加平稳
  3. 保证大模型的回复质量,避免上下文截断或者上下文过长
  4. 适合于对大模型准确度有一定要求的场景
  5. 后端的实现逻辑

1.ChatMemoryRepository实现类技术选型

    后端在原来已经实现的RedisChatMemoryRepository设计的基础上,进行进一步优化。

    原RedisChatMemoryRepository设计:SpringAI的初步用法

    为什么选择Redis作为会话记忆存储?
  • 会话记忆存储属于高频率读写,每次会话都会进行读写操作,Redis的内存模型天然适配当前业务场景。
  • 但是SpringAI目前不提供Redis的底层实现,我们可以自己进行设计,构造出来符合当前优化场景的会话记忆逻辑。

2.Redis实现类技术架构设计

原RedisChatMemoryRepository设计
  • 对于获取所有的conversationId,额外使用list数据结构进行存储,方便直接读取使用
  • 对于读写的Message,统一使用自定义的业务模型MessageInfo,同时进行适当转换。
  • 对于Message读写,使用String类型,便于全量存储
优化RedisChatMemoryRepository设计
  • 保留原来的设计
  • 同时新增会话摘要,在存储会话历史的时候,检查是否达到了我们要求的条数。
  • 达到阈值后,我们进行会话记忆优化
  • 对当前的n条原始信息,我们不进行存储,转而调用总结模型进行摘要总结
  • 将摘要总结放入会话摘要专栏
  • 优化过程采取消息队列进行异步解耦

3. 坑点说明

  1. 我们通过接口实现findByConversationId方法,主要是查找包括摘要在内的所有信息
  2. 但是这里进行查找后,我们的SpringAI框架会进行max-message-size检查并进行截断
  3. 后续才会进入saveAll方法,这时候出现的情况就是,如果我们的触发阈值与SpringAI的一样,就会被截断
  4. 因此我们需要额外维护我们设置的阈值,防止SpringAI框架截断我们的所有历史摘要总结

4. 代码示例

4.1 核心代码

package com.project.demo.common.config.ai.common;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.json.JSONUtil;
import com.project.demo.chat.factory.MessageConvertor;
import com.project.demo.chat.model.bean.MessageInfo;
import com.project.demo.common.factory.MessageFactory;
import com.project.demo.common.model.bean.SummaryMessage;
import com.project.demo.common.model.constant.RabbitMQConstant;
import com.project.demo.common.provider.CacheProvider;
import com.project.demo.common.util.RabbitUtils;
import com.project.demo.common.util.StringUtils;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Repository;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * @author by 王玉涛
 * @Classname RedisSummeryMemoryRepository
 * @Description 基于Redis实现的对话记忆仓库,用于存储和管理AI对话历史记录。
 *               支持以下功能:
 *               1. 存储、查询、删除对话记录
 *               2. 对超出最大容量的对话进行摘要处理,并通过消息队列发送给AI模型生成总结
 *               3. 使用缓存提供者(CacheProvider)操作Redis,保证数据读写一致性
 * @Date 2025/7/13 13:00
 */
@Repository
@Slf4j
@RequiredArgsConstructor
public class RedisSummeryMemoryRepository implements ChatMemoryRepository {

    /**
     * RabbitMQ工具类,用于异步发送消息到AI模型服务
     */
    private final RabbitUtils rabbitUtils;

    /**
     * 最大对话记录数,超过该数量后将触发摘要逻辑
     * 配置来源:application.yml 中的 doubao.memory.max-messages 属性
     */
    @Value("${doubao.memory.max-memory-size}")
    private Integer maxSize;

    /**
     * 存储所有对话ID集合的缓存键
     * 用于快速获取系统中所有的 conversationId
     */
    private static final String AI_MEMORY_CONVERSATION_IDS_KEY = "ai:memory:conversation:ids";

    /**
     * 对话记录的缓存键前缀
     * 后接具体的 conversationId 构成完整键名,如:ai:memory:conversation:12345
     */
    private static final String AI_MEMORY_CONVERSATION_KEY_PREFIX = "ai:memory:conversation:";

    /**
     * 对话摘要的缓存键前缀
     * 后接具体的 conversationId 构成完整键名,如:ai:memory:summary:12345
     */
    private static final String AI_MEMORY_SUMMARY_KEY_PREFIX = "ai:memory:summary:";

    /**
     * 缓存提供者接口实例,封装了对Redis的基本操作
     */
    private final CacheProvider cacheProvider;

    /**
     * 获取所有存在的对话ID列表
     *
     * @return 包含所有 conversationId 的字符串列表
     */
    @Override
    public List<String> findConversationIds() {
        log.info("基于Redis内存的聊天历史记忆:查询所有的id");
        return cacheProvider.getSet(AI_MEMORY_CONVERSATION_IDS_KEY, String.class);
    }

    /**
     * 根据指定的 conversationId 查询完整的对话内容
     *
     * @param conversationId 对话唯一标识符
     * @return Spring AI 的 Message 列表,表示整个对话历史
     * 因为当前查找的信息,会被截断,所以说,我们目前有一个比较确定的是:我们自己设置的对话窗口以及上下文限制,不能一样。
     */
    @Override
    public List<Message> findByConversationId(String conversationId) {
        log.info("基于Redis内存的聊天历史记忆:查询id为{}的历史记忆+信息摘要", conversationId);

        // 从 Redis 中获取原始对话
        String data = cacheProvider.getString(AI_MEMORY_CONVERSATION_KEY_PREFIX, conversationId);
        List<String> summaryData = cacheProvider.getListString(AI_MEMORY_SUMMARY_KEY_PREFIX, conversationId);

        // 如果主数据和摘要都为空,则直接返回空列表,避免后续无意义处理
        if (StringUtils.isNullOrEmpty(data) && CollectionUtil.isEmpty(summaryData)) {
            return Collections.emptyList();
        }

        // 将主数据反序列化为 MessageInfo 列表
        List<MessageInfo> myMessages = new ArrayList<>();
        if (!StringUtils.isNullOrEmpty(data)) {
            myMessages.addAll(JSONUtil.toList(data, MessageInfo.class));
        }

        // 将摘要字符串转换为 MessageInfo 列表
        List<MessageInfo> messageInfos = summaryData.stream()
                .map(MessageConvertor::contentToMyMessage)
                .toList();

        // 合并主消息和摘要消息
        myMessages.addAll(messageInfos);

        // 最终将 MessageInfo 转换为 Spring AI 的 Message 类型并返回
        return myMessages.stream()
                .map(MessageConvertor::myToAiMessage)
                .toList();
    }

    /**
     * 保存与指定 conversationId 关联的所有对话消息
     *
     * @param conversationId 对话唯一标识符
     * @param messages       要保存的 Spring AI Message 列表
     */
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        List<MessageInfo> messageInfos = messages.stream().map(MessageConvertor::aiMessageToMy).toList();
        if (countSystemSize(messageInfos) >= maxSize) {
            log.info("基于Redis内存的聊天历史记忆:对话记录数超出最大限制,进行摘要总结并停止记录、删除记忆");

            String key = AI_MEMORY_SUMMARY_KEY_PREFIX + conversationId;
            SummaryMessage summaryMessage = MessageFactory.createSummaryRequest(key, messageInfos);

            rabbitUtils.sendToQueue(
                    RabbitMQConstant.AI_MODEL_EXCHANGE,
                    RabbitMQConstant.AI_MODEL_SUMMARY_KEY,
                    summaryMessage
            );

            key = AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId;
            cacheProvider.deleteByKey(key);
            return;
        }
        log.info("基于Redis内存的聊天历史记忆:保存id为{}的历史记忆", conversationId);

        // 添加 conversationId 到全局对话ID集合中
        cacheProvider.addSet(AI_MEMORY_CONVERSATION_IDS_KEY, conversationId);

        // 序列化消息列表并保存到对应 conversationId 的缓存中
        List<MessageInfo> myMessages = messages.stream()
                .map(MessageConvertor::aiMessageToMy)
                .toList();

        myMessages = excludeSystemMessages(myMessages);
        cacheProvider.set(AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId, JSONUtil.toJsonStr(myMessages));
    }

    /**
     * 删除与指定 conversationId 关联的所有对话记录
     *
     * @param myMessages 待处理的消息列表
     * @return
     */
    List<MessageInfo> excludeSystemMessages(List<MessageInfo> myMessages) {
        int size = myMessages.size();
        List<MessageInfo> messageInfos = new ArrayList<>(size);
        for (MessageInfo messageInfo : myMessages) {
            if (Objects.isNull(messageInfo)) {
                continue;
            }
            if (messageInfo.getSenderIdentity() != 3) {
                messageInfos.add(messageInfo);
            }
        }
        return messageInfos;
    }

    /**
     * 计算系统消息的个数
     *
     * @param messageInfos 消息列表
     * @return 系统消息的个数
     */
    private int countSystemSize(List<MessageInfo> messageInfos) {
        int count = 0;
        for (MessageInfo messageInfo : messageInfos) {
            if (messageInfo.getSenderIdentity() != 3) {
                count++;
            }
        }
        return count;
    }

    /**
     * 删除与指定 conversationId 关联的所有对话记录
     *
     * @param conversationId 要删除的对话唯一标识符
     */
    @Override
    public void deleteByConversationId(String conversationId) {
        log.info("基于Redis内存的聊天历史记忆:删除id为{}的历史记忆", conversationId);

        // 删除该 conversationId 对应的对话记录和全局ID集合中的条目
        cacheProvider.deleteByKey(AI_MEMORY_CONVERSATION_KEY_PREFIX + conversationId);
        cacheProvider.deleteByKey(AI_MEMORY_SUMMARY_KEY_PREFIX + conversationId);
        cacheProvider.deleteByKey(AI_MEMORY_CONVERSATION_IDS_KEY, conversationId);
    }
}

4.2 CacheProvider实现类代码

package com.project.demo.common.provider.impl;

import com.project.demo.common.exception.UserException;
import com.project.demo.common.model.constant.RedisConstant;
import com.project.demo.common.model.constant.RedisTtlConstant;
import com.project.demo.common.provider.CacheProvider;
import com.project.demo.common.util.StringUtils;
import com.project.demo.user.constant.UserErrorConstant;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * @className: CacheRedisProvider
 * @author: 顾漂亮
 * @date: 2025/5/27 15:14
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class CacheRedisProvider implements CacheProvider {

    /**
     * RedisTemple : 先将被存储的数据转换成字节数组(不可读), 再存储到redis中,读取的时候按照字节数组读取
     * StringRedisTemplate: 直接存放的就是String(可读)
     */
    private final StringRedisTemplate redisTemplate;

    /**
     * 向redis中存入值
     * @param key
     * @param value
     * @return
     */
    @Override
    public boolean set(String key, String value){
        try{
            redisTemplate.opsForValue().set(key, value);
            return true;
        }catch (Exception e){
            log.error("RedisUtil error, set{}, {}", key, value,e);
            return false;
        }
    }


    /**
     * 向redis中存入值,并且设置过期时间
     * @param key
     * @param value
     * @param time 过期时间,单位秒
     * @return
     */
    @Override
    public boolean set(String key, String value, Long time){
        try{
            redisTemplate.opsForValue().set(key, value);
            return true;
        }catch (Exception e){
            log.error("RedisUtil error, set{}, {}, {}", key, value,time, e);
            return false;
        }
    }


    /**
     * 获取redis中的值
     * @param key
     * @return
     */
    @Override
    public String get(String key){
        try{
            return !StringUtils.isNullOrEmpty(key)
                    ? redisTemplate.opsForValue().get(key)
                    : null;
        }catch (Exception e){
            log.error("RedisUtil error, get({})", key, e);
            return null;
        }
    }

    /**
     * 获取redis中的值
     * @return
     */
    @Override
    public String getString(String keyPrefix, String keyEnd){
        String key = keyPrefix + keyEnd;
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * 获取指定键对应的列表值
     *
     * @param key   缓存键
     * @param clazz 列表元素类型
     * @return 列表值
     */
    @Override
    public <T> List<T> getList(String key, Class<T> clazz) {
        List<String> stringList = redisTemplate.opsForList().range(key, 0, -1);
        if (CollectionUtils.isEmpty(stringList)) {
            return Collections.emptyList();
        }
        return stringList
                .stream()
                .map(string -> StringUtils.jsonToClass(string, clazz))
                .toList();
    }

    /**
     * 添加列表元素
     *
     * @param key   缓存键
     * @param value 列表元素
     */
    @Override
    public void addList(String key, String value) {
        redisTemplate.opsForList().rightPush(key, value);
    }

    /**
     * 添加集合元素
     *
     * @param key   缓存键
     * @param value 集合元素
     */
    @Override
    public void addSet(String key, String value) {
        redisTemplate.opsForSet().add(key, value);
    }

    /**
     * 获取指定键对应的集合值
     *
     * @param key         缓存键
     * @param stringClass 集合元素类型
     * @return 集合值
     */
    @Override
    public List<String> getSet(String key, Class<String> stringClass) {
        Set<String> members = redisTemplate.opsForSet().members(key);
        if (CollectionUtils.isEmpty(members)) {
            return Collections.emptyList();
        }
        return members.stream().toList();
    }

    /**
     * 获取指定前缀和后缀的列表值
     *
     * @param keyPrefix 前缀
     * @param keyEnd    后缀
     * @return 列表值
     */
    @Override
    public List<String> getListString(String keyPrefix, String keyEnd) {
        List<String> stringList = redisTemplate.opsForList().range(keyPrefix + keyEnd, 0, -1);
        if (CollectionUtils.isEmpty(stringList)) {
            return Collections.emptyList();
        }
        return stringList;
    }

    /**
     * 删除redis中的值
     * @param key String... 本质上还是数组,但是对于参数的传递来说,参数的个数是可变参数,所以可以传入多个参数
     * @return
     */
    @Override
    public boolean delete(String... key){
        try {
            if(key != null && key.length > 0){
                if(key.length == 1){
                    redisTemplate.delete(key[0]);
                }else{
                    redisTemplate.delete((Collection<String>) CollectionUtils.arrayToList(key));
                }
            }
            return true;
        }catch (Exception e){
            log.error("RedisUtil error, delete({})", key, e);
            return false;
        }
    }

    /**
     * 获取指定键对应的值,并删除该键。
     *
     * @param key 缓存键
     * @return 缓存值,如果不存在则返回 null
     */
    @Override
    public void deleteByKey(String key) {
        redisTemplate.delete(key);
    }

    /**
     * 删除指定键对应的值,并验证值是否匹配。
     *
     * @param key   缓存键
     * @param value 缓存值
     */
    @Override
    public void deleteByKey(String key, String value) {
        /**
         * 这行代码 redisTemplate.opsForList().remove(key, 0, value);
         * 会从 Redis 中的一个列表(由 key 指定)移除所有等于 value 的元素,
         * 且最多移除一次(因为 count = 0)
         */
        redisTemplate.opsForList().remove(key, 0, value);
    }

    @Override
    public void validateAndStorageCode(String phone, String code) {
        // 1. 获取key
        String key = RedisConstant.PHONE_CODE_PREFIX + phone;
        // 2. 获取过期时间
        Long expire = redisTemplate.getExpire(key, TimeUnit.SECONDS);
        // 3. 检验频繁操作
        if (expire > RedisTtlConstant.MIN_PHONE_TTL) {
            throw new UserException(UserErrorConstant.FREQUENT_OPERATIONS);
        }
        // 4. 覆盖存储
        redisTemplate.opsForValue().set(
                key,
                code,
                RedisTtlConstant.PHONE_CODE_TTL,
                TimeUnit.SECONDS);
    }
}

4.3 Consumer代码

/**
 * @author by 王玉涛
 * @Classname AiModelConsumer
 * @Description AI模型消息消费者,用于处理AI对话摘要任务。
 *              从RabbitMQ队列中消费待处理的对话消息,调用豆包大模型生成对话摘要,
 *              并将生成的摘要内容存储到缓存中。
 * @Date 2025/7/13 14:30
 */
@Component
@RequiredArgsConstructor
@Slf4j
public class AiModelConsumer {
    
    /**
     * 豆包AI提供者,用于获取历史对话摘要客户端
     */
    private final DoubaoAiProvider doubaoAiProvider;

    /**
     * 缓存服务提供者,用于存储生成的摘要内容
     */
    private final CacheProvider cacheProvider;


    /**
     * 监听AI模型摘要队列,处理对话摘要任务
     *
     * @param summaryMessage 包含对话历史和存储键的摘要消息对象
     */
    @RabbitListener(queues = RabbitMQConstant.AI_MODEL_SUMMARY_QUEUE, concurrency = "5")
    public void handleAiModelSummary(SummaryMessage summaryMessage) {
        log.info("处理消息: {}", summaryMessage);

        // 使用豆包AI的历史对话摘要客户端生成摘要内容
        String summaryContent = doubaoAiProvider.historySummaryClient()
                .prompt(summaryMessage.myMessagesToString())
                .call()
                .content();
        
        log.info("摘要内容: {}", summaryContent);

        // 将生成的摘要内容添加到缓存列表中
        cacheProvider.addList(summaryMessage.getStorageKey(), summaryContent);
        log.info("存储摘要内容成功");
    }
}

#SpringBoot##SpringAI##Java#
全部评论

相关推荐

11-03 17:42
门头沟学院 Java
点赞 评论 收藏
分享
点赞 评论 收藏
分享
点赞 评论 收藏
分享
评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务