一个注解解决ShardingJdbc不支持复杂SQL

背景介绍

公司最近做分库分表业务,接入了 Sharding JDBC,接入完成后,回归测试时发现好几个 SQL 执行报错,关键这几个表都还不是分片表。报错如下:

这下糟了嘛。熟悉 Sharding JDBC 的同学应该知道,有很多 SQL 它是不支持的。官方截图如下:

如果要逐个修改这些复杂 SQL,可能要花费很多时间。那怎么办呢?只能从 Sharding JDBC 这里找突破口了。经过两天的研究,总结出了下面这个方案:只需要加一个注解,就能轻松解决 Sharding JDBC 不支持复杂 SQL 的问题。

问题复现

我本地写了一个复杂 SQL 进行测试:

/**
 * Fetches the orders by delegating to the repository's native query.
 *
 * @return the rows returned by {@code borderRepository.findOrders()}, one map per row
 */
public List<Map<String, Object>> queryOrder() {
    return borderRepository.findOrders();
}
/**
 * Spring Data JPA repository for {@code BOrder} entities keyed by {@code Long}.
 */
public interface BOrderRepository extends JpaRepository<BOrder, Long> {

    /**
     * Runs a native query with a derived table: the inner select reads {@code id}, a CASE label
     * computed from {@code company_id} (aliased {@code com}) and {@code user_id} (aliased
     * {@code userId}) from {@code b_order0}; the outer select keeps the rows whose label is '中'.
     *
     * @return one map per result row
     */
    @Query(
            nativeQuery = true,
            value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'")
    List<Map<String, Object>> findOrders();
}

写了个测试 controller 来调用,调用后果然报错了。



解决思路

因为查询的复杂 SQL 的表不是分片表,那能不能指定这几个复杂查询的时候不用 Sharding JDBC 的数据源呢?

  1. 在注入 Sharding JDBC 数据源的地方做处理,注入一个我们自定义的数据源
  2. 这样我们获取连接的时候就能返回原生数据源了
  3. 另外我们声明一个注解,对标识了注解的就返回原生数据源,否则还是返回 Sharding 数据源

具体实现

  1. 编写一个 autoConfig 类,来替换 ShardingSphereAutoConfiguration 类
/**
 * Core auto-configuration for the dynamic data source.
 *
 * <p>Builds the ShardingSphere data source from the Spring environment and exposes it wrapped in a
 * {@link WrapShardingDataSource}, together with the advisor that reacts to the {@link DS}
 * annotation.
 */
@Configuration
@ComponentScan("org.apache.shardingsphere.spring.boot.converter")
@EnableConfigurationProperties(SpringBootPropertiesConfiguration.class)
@ConditionalOnProperty(prefix = "spring.shardingsphere", name = "enabled", havingValue = "true", matchIfMissing = true)
@AutoConfigureBefore(DataSourceAutoConfiguration.class)
public class DynamicDataSourceAutoConfiguration implements EnvironmentAware {

    /** Database name read from the environment in {@link #setEnvironment(Environment)}. */
    private String databaseName;

    private final SpringBootPropertiesConfiguration props;

    /** Raw data sources by name, filled from the environment in {@link #setEnvironment(Environment)}. */
    private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();

    public DynamicDataSourceAutoConfiguration(SpringBootPropertiesConfiguration props) {
        this.props = props;
    }

    /**
     * Exposes the mode configuration.
     *
     * @return the mode configuration, or {@code null} when no mode is configured
     */
    @Bean
    public ModeConfiguration modeConfiguration() {
        if (null == props.getMode()) {
            return null;
        }
        return new ModeConfigurationYamlSwapper().swapToObject(props.getMode());
    }

    /**
     * Creates the data source bean from locally configured rules.
     *
     * @param rules rule configurations, if any are available
     * @param modeConfig mode configuration, if available
     * @return the ShardingSphere data source wrapped in a {@link WrapShardingDataSource}
     * @throws SQLException if the ShardingSphere data source cannot be created
     */
    @Bean
    @Conditional(LocalRulesCondition.class)
    @Autowired(required = false)
    public DataSource shardingSphereDataSource(final ObjectProvider<List<RuleConfiguration>> rules,
                                               final ObjectProvider<ModeConfiguration> modeConfig) throws SQLException {
        List<RuleConfiguration> availableRules = rules.getIfAvailable();
        Collection<RuleConfiguration> ruleConfigs =
                availableRules != null ? availableRules : Collections.<RuleConfiguration>emptyList();
        DataSource sharding = ShardingSphereDataSourceFactory.createDataSource(
                databaseName, modeConfig.getIfAvailable(), dataSourceMap, ruleConfigs, props.getProps());
        return wrap(sharding);
    }

    /**
     * Creates the data source bean when no other {@link DataSource} bean has been defined.
     *
     * @param modeConfig mode configuration
     * @return the ShardingSphere data source wrapped in a {@link WrapShardingDataSource}
     * @throws SQLException if the ShardingSphere data source cannot be created
     */
    @Bean
    @ConditionalOnMissingBean(DataSource.class)
    public DataSource dataSource(final ModeConfiguration modeConfig) throws SQLException {
        DataSource sharding;
        if (dataSourceMap.isEmpty()) {
            sharding = ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig);
        } else {
            sharding = ShardingSphereDataSourceFactory.createDataSource(
                    databaseName, modeConfig, dataSourceMap, Collections.emptyList(), props.getProps());
        }
        return wrap(sharding);
    }

    /**
     * Creates the transaction type scanner.
     *
     * @return transaction type scanner
     */
    @Bean
    public TransactionTypeScanner transactionTypeScanner() {
        return new TransactionTypeScanner();
    }

    /**
     * Reads the named data sources and the database name from the Spring environment.
     *
     * @param environment the environment this configuration runs in
     */
    @Override
    public final void setEnvironment(final Environment environment) {
        dataSourceMap.putAll(DataSourceMapSetter.getDataSourceMap(environment));
        databaseName = DatabaseNameSetter.getDatabaseName(environment);
    }

    /**
     * Registers the advisor that intercepts types and methods annotated with {@link DS}.
     *
     * @return the advisor, built around an interceptor created with {@code allowedPublicOnly = true}
     */
    @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
    @Bean
    @ConditionalOnProperty(prefix = "spring.datasource.dynamic.aop", name = "enabled", havingValue = "true", matchIfMissing = true)
    public Advisor dynamicDatasourceAnnotationAdvisor() {
        return new DynamicDataSourceAnnotationAdvisor(new DynamicDataSourceAnnotationInterceptor(true), DS.class);
    }

    /** Wraps the freshly created ShardingSphere data source together with the raw data source map. */
    private DataSource wrap(DataSource sharding) {
        return new WrapShardingDataSource((ShardingSphereDataSource) sharding, dataSourceMap);
    }

}
  2. 自定义数据源
/**
 * Data source facade that routes each call either to the wrapped ShardingSphere data source or to
 * one of the raw data sources, depending on the name currently on top of
 * {@link DynamicDataSourceContextHolder}.
 */
public class WrapShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable {

    /** Used whenever no data source name is set for the current thread. */
    private final ShardingSphereDataSource dataSource;

    /** Raw data sources by name; looked up with the name taken from the context holder. */
    private final Map<String, DataSource> dataSourceMap;

    public WrapShardingDataSource(ShardingSphereDataSource dataSource, Map<String, DataSource> dataSourceMap) {
        this.dataSource = dataSource;
        this.dataSourceMap = dataSourceMap;
    }

    /**
     * Resolves the data source for the current thread.
     *
     * @return the ShardingSphere data source when no name is set; otherwise the raw data source
     *         registered under that name, or {@code null} if the name is unknown
     */
    public DataSource getTargetDataSource() {
        String peek = DynamicDataSourceContextHolder.peek();
        if (StringUtils.isEmpty(peek)) {
            return dataSource;
        }
        return dataSourceMap.get(peek);
    }

    /**
     * Opens a connection on the data source selected for the current thread.
     *
     * @throws SQLException if the selected name is not a configured data source, or if the target
     *                      data source fails to provide a connection
     */
    @Override
    public Connection getConnection() throws SQLException {
        return requireTargetDataSource().getConnection();
    }

    /**
     * Same as {@link #getConnection()}; the supplied credentials are ignored, as the underlying
     * data sources are expected to carry their own.
     */
    @Override
    public Connection getConnection(final String username, final String password) throws SQLException {
        return getConnection();
    }

    /**
     * Closes the wrapped ShardingSphere data source.
     *
     * <p>This deliberately does not consult the thread-local selection: closing must not depend on
     * which data source name the calling thread happens to have pushed.
     */
    @Override
    public void close() throws Exception {
        DataSource wrapped = dataSource;
        if (wrapped instanceof AutoCloseable) {
            ((AutoCloseable) wrapped).close();
        }
    }

    /**
     * @return the login timeout of the selected data source, or {@code 0} when the selected name
     *         is not a configured data source
     */
    @Override
    public int getLoginTimeout() throws SQLException {
        DataSource targetDataSource = getTargetDataSource();
        return targetDataSource == null ? 0 : targetDataSource.getLoginTimeout();
    }

    /**
     * Sets the login timeout on the data source selected for the current thread.
     *
     * @throws SQLException if the selected name is not a configured data source
     */
    @Override
    public void setLoginTimeout(final int seconds) throws SQLException {
        requireTargetDataSource().setLoginTimeout(seconds);
    }

    /** Resolves the target data source, failing with a descriptive error instead of a bare NPE. */
    private DataSource requireTargetDataSource() throws SQLException {
        DataSource target = getTargetDataSource();
        if (target == null) {
            throw new SQLException("No data source is configured under the name '"
                    + DynamicDataSourceContextHolder.peek() + "'");
        }
        return target;
    }
}
  3. 声明指定数据源注解
/**
 * Names the data source to use for the annotated type or method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface DS {

    /**
     * Data source name.
     *
     * @return the name of the data source
     */
    String value();
}
  4. 另外,使用 AOP 拦截使用了注解的类或方法,并把注解指定的数据源名称存起来,在获取数据源连接时取出来进行判断。这就还要用到 ThreadLocal。

aop 拦截器:

/**
 * Method interceptor that pushes the data source key resolved for the intercepted method onto
 * {@link DynamicDataSourceContextHolder} for the duration of the call.
 */
public class DynamicDataSourceAnnotationInterceptor implements MethodInterceptor {

    private final DataSourceClassResolver dataSourceClassResolver;

    /**
     * @param allowedPublicOnly passed through to the {@link DataSourceClassResolver}
     */
    public DynamicDataSourceAnnotationInterceptor(Boolean allowedPublicOnly) {
        this.dataSourceClassResolver = new DataSourceClassResolver(allowedPublicOnly);
    }

    /**
     * Resolves the key, pushes it, proceeds with the invocation and always pops the key again,
     * even when the invocation throws.
     */
    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        DynamicDataSourceContextHolder.push(determineDatasourceKey(invocation));
        try {
            return invocation.proceed();
        } finally {
            DynamicDataSourceContextHolder.poll();
        }
    }

    /** Looks up the data source key for the invoked method on its target object. */
    private String determineDatasourceKey(MethodInvocation invocation) {
        return dataSourceClassResolver.findKey(invocation.getMethod(), invocation.getThis());
    }
}

aop 切面定义:

/**
 * aop Advisor
 */ public class DynamicDataSourceAnnotationAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware { private final Advice advice; private final Pointcut pointcut; private final Class<? extends Annotation> annotation; public DynamicDataSourceAnnotationAdvisor(MethodInterceptor advice,
                                               Class<? extends Annotation> annotation) { this.advice = advice; this.annotation = annotation; this.pointcut = buildPointcut();
    } @Override public Pointcut getPointcut() { return this.pointcut;
    } @Override public Advice getAdvice() { return this.advice;
    } @Override public void setBeanFactory(BeanFactory beanFactory) throws BeansException { if (this.advice instanceof BeanFactoryAware) {
            ((BeanFactoryAware) this.advice).setBeanFactory(beanFactory);
        }
    } private Pointcut buildPointcut() {
        Pointcut cpc = new AnnotationMatchingPointcut(annotation, true);
        Pointcut mpc = new AnnotationMethodPoint(annotation); return new ComposablePointcut(cpc).union(mpc);
    } /**
     * In order to be compatible with the spring lower than 5.0
     */ private static class AnnotationMethodPoint implements Pointcut { private final Class<? extends Annotation> annotationType; public AnnotationMethodPoint(Class<? extends Annotation> annotationType) {
            Assert.notNull(annotationType, "Annotation type must not be null"); this.annotationType = annotationType;
        } @Override public ClassFilter getClassFilter() { return ClassFilter.TRUE;
        } @Override public MethodMatcher getMethodMatcher() { return new AnnotationMethodMatcher(annotationType);
        } private static class AnnotationMethodMatcher extends StaticMethodMatcher { private final Class<? extends Annotation> annotationType; public AnnotationMethodMatcher(Class<? extends Annotation> annotationType) { this.annotationType = annotationType;
            } @Override public boolean matches(Method method, Class<?> targetClass) { if (matchesMethod(method)) { return true;
                } // Proxy classes never have annotations on their redeclared methods. if (Proxy.isProxyClass(targetClass)) { return false;
                } // The method may be on an interface, so let's check on the target class as well. Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); return (specificMethod != method && matchesMethod(specificMethod));
            } private boolean matchesMethod(Method method) { return AnnotatedElementUtils.hasAnnotation(method, this.annotationType);
            }
        }
    }
}
/**
 * 数据源解析器
 *
 */ public class DataSourceClassResolver { private static boolean mpEnabled = false; private static Field mapperInterfaceField; static {
        Class<?> proxyClass = null; try {
            proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.MybatisMapperProxy");
        } catch (ClassNotFoundException e1) { try {
                proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.PageMapperProxy");
            } catch (ClassNotFoundException e2) { try {
                    proxyClass = Class.forName("org.apache.ibatis.binding.MapperProxy");
                } catch (ClassNotFoundException ignored) {
                }
            }
        } if (proxyClass != null) { try {
                mapperInterfaceField = proxyClass.getDeclaredField("mapperInterface");
                mapperInterfaceField.setAccessible(true);
                mpEnabled = true;
            } catch (NoSuchFieldException e) {
                e.printStackTrace();
            }
        }
    } /**
     * 缓存方法对应的数据源
     */ private final Map<Object, String> dsCache = new ConcurrentHashMap<>(); private final boolean allowedPublicOnly; /**
     * 加入扩展, 给外部一个修改aop条件的机会
     *
     * @param allowedPublicOnly 只允许公共的方法, 默认为true
     */ public DataSourceClassResolver(boolean allowedPublicOnly) { this.allowedPublicOnly = allowedPublicOnly;
    } /**
     * 从缓存获取数据
     *
     * @param method       方法
     * @param targetObject 目标对象
     * @return ds
     */ public String findKey(Method method, Object targetObject) { if (method.getDeclaringClass() == Object.class) { return "";
        }
        Object cacheKey = new MethodClassKey(method, targetObject.getClass());
        String ds = this.dsCache.get(cacheKey); if (ds == null) {
            ds = computeDatasource(method, targetObject); if (ds == null) {
                ds = "";
            } this.dsCache.put(cacheKey, ds);
        } return ds;
    } /**
     * 查找注解的顺序
     * 1\. 当前方法
     * 2\. 桥接方法
     * 3\. 当前类开始一直找到Object
     * 4\. 支持mybatis-plus, mybatis-spring
     *
     * @param method       方法
     * @param targetObject 目标对象
     * @return ds
     */ private String computeDatasource(Method method, Object targetObject) { if (allowedPublicOnly && !Modifier.isPublic(method.getModifiers())) { return null;
        } //1\. 从当前方法接口中获取 String dsAttr = findDataSourceAttribute(method); if (dsAttr != null) { return dsAttr;
        }
        Class<?> targetClass = targetObject.getClass();
        Class<?> userClass = ClassUtils.getUserClass(targetClass); // JDK代理时,  获取实现类的方法声明.  method: 接口的方法, specificMethod: 实现类方法 Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);

        specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod); //2\. 从桥接方法查找 dsAttr = findDataSourceAttribute(specificMethod); if (dsAttr != null) { return dsAttr;
        } // 从当前方法声明的类查找 dsAttr = findDataSourceAttribute(userClass); if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) { return dsAttr;
        } //since 3.4.1 从接口查找,只取第一个找到的 for (Class<?> interfaceClazz : ClassUtils.getAllInterfacesForClassAsSet(userClass)) {
            dsAttr = findDataSourceAttribute(interfaceClazz); if (dsAttr != null) { return dsAttr;
            }
        } // 如果存在桥接方法 if (specificMethod != method) { // 从桥接方法查找 dsAttr = findDataSourceAttribute(method); if (dsAttr != null) { return dsAttr;
            } // 从桥接方法声明的类查找 dsAttr = findDataSourceAttribute(method.getDeclaringClass()); if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) { return dsAttr;
            }
        } return getDefaultDataSourceAttr(targetObject);
    } /**
     * 默认的获取数据源名称方式
     *
     * @param targetObject 目标对象
     * @return ds
     */ private String getDefaultDataSourceAttr(Object targetObject) {
        Class<?> targetClass = targetObject.getClass(); // 如果不是代理类, 从当前类开始, 不断的找父类的声明 if (!Proxy.isProxyClass(targetClass)) {
            Class<?> currentClass = targetClass; while (currentClass != Object.class) {
                String datasourceAttr = findDataSourceAttribute(currentClass); if (datasourceAttr != null) { return datasourceAttr;
                }
                currentClass = currentClass.getSuperclass();
            }
        } // mybatis-plus, mybatis-spring 的获取方式 if (mpEnabled) { final Class<?> clazz = getMapperInterfaceClass(targetObject); if (clazz != null) {
                String datasourceAttr = findDataSourceAttribute(clazz); if (datasourceAttr != null) { return datasourceAttr;
                } // 尝试从其父接口获取 return findDataSourceAttribute(clazz.getSuperclass());
            }
        } return null;
    } /**
     * 用于处理嵌套代理
     *
     * @param target JDK 代理类对象
     * @return InvocationHandler 的 Class
     */ private Class<?> getMapperInterfaceClass(Object target) {
        Object current = target; while (Proxy.isProxyClass(current.getClass())) {
            Object currentRefObject = AopProxyUtils.getSingletonTarget(current); if (currentRefObject == null) { break;
            }
            current = currentRefObject;
        } try { if (Proxy.isProxyClass(current.getClass())) { return (Class<?>) mapperInterfaceField.get(Proxy.getInvocationHandler(current));
            }
        } catch (IllegalAccessException ignore) {
        } return null;
    } /**
     * 通过 AnnotatedElement 查找标记的注解, 映射为  DatasourceHolder
     *
     * @param ae AnnotatedElement
     * @return 数据源映射持有者
     */ private String findDataSourceAttribute(AnnotatedElement ae) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(ae, DS.class); if (attributes != null) { return attributes.getString("value");
        } return null;
    }
}

ThreadLocal:

/**
 * Per-thread stack of data source names.
 */
public final class DynamicDataSourceContextHolder {

    /**
     * Why a stack rather than a single value:
     * <pre>
     * To support nested switching. Suppose services A, B and C each use a different data source,
     * A calls into B and B calls into C. Each level switches in turn, forming a chain.
     * Storing only one name per thread cannot express this; a last-in-first-out stack can.
     * </pre>
     */
    private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER =
            new NamedThreadLocal<Deque<String>>("dynamic-datasource") {
                @Override
                protected Deque<String> initialValue() {
                    return new ArrayDeque<>();
                }
            };

    private DynamicDataSourceContextHolder() {
    }

    /**
     * Returns the data source name currently in effect for this thread.
     *
     * @return the name on top of the stack, or {@code null} when the stack is empty
     */
    public static String peek() {
        Deque<String> stack = LOOKUP_KEY_HOLDER.get();
        return stack.peek();
    }

    /**
     * Sets the data source name for the current thread.
     * <p>
     * Avoid calling this by hand; if you do, make sure the name is eventually removed again.
     * </p>
     *
     * @param ds the data source name
     * @return the name that was actually pushed: {@code ds}, or the empty string when {@code ds}
     *         is empty
     */
    public static String push(String ds) {
        String dataSourceStr;
        if (StringUtils.isEmpty(ds)) {
            dataSourceStr = "";
        } else {
            dataSourceStr = ds;
        }
        Deque<String> stack = LOOKUP_KEY_HOLDER.get();
        stack.push(dataSourceStr);
        return dataSourceStr;
    }

    /**
     * Removes the data source name most recently pushed by the current thread.
     * <p>
     * With nested switches only the top name is removed; the thread-local itself is cleared once
     * the stack becomes empty.
     * </p>
     */
    public static void poll() {
        Deque<String> stack = LOOKUP_KEY_HOLDER.get();
        stack.poll();
        if (stack.isEmpty()) {
            LOOKUP_KEY_HOLDER.remove();
        }
    }

    /**
     * Unconditionally clears the thread-local.
     * <p>
     * Guards against memory leaks; call it after a manual {@link #push(String)} to be sure the
     * state is gone.
     * </p>
     */
    public static void clear() {
        LOOKUP_KEY_HOLDER.remove();
    }
}
  5. 启动类上做如下配置:

引入我们写的自动配置类,排除 ShardingJdbc 的自动配置类。

/**
 * Application entry point. ShardingSphere's own auto-configuration is excluded and
 * {@link DynamicDataSourceAutoConfiguration} is imported in its place.
 */
@SpringBootApplication(exclude = ShardingSphereAutoConfiguration.class)
@Import({DynamicDataSourceAutoConfiguration.class})
public class ShardingRunApplication {

    /**
     * Starts the Spring Boot application.
     *
     * @param args command-line arguments, forwarded to Spring Boot so that options given on the
     *             command line are not silently dropped
     */
    public static void main(String[] args) {
        SpringApplication.run(ShardingRunApplication.class, args);
    }
}

最后,我们给之前写的 Repository 加上注解:

/**
 * Spring Data JPA repository for {@code BOrder} entities keyed by {@code Long}.
 */
public interface BOrderRepository extends JpaRepository<BOrder, Long> {

    /**
     * Runs a native query with a derived table: the inner select reads {@code id}, a CASE label
     * computed from {@code company_id} (aliased {@code com}) and {@code user_id} (aliased
     * {@code userId}) from {@code b_order0}; the outer select keeps the rows whose label is '中'.
     * The method is annotated with {@code @DS("slave0")}, naming the data source to use.
     *
     * @return one map per result row
     */
    @DS("slave0")
    @Query(
            nativeQuery = true,
            value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'")
    List<Map<String, Object>> findOrders();
}
再次调用,查询成功!!!
#Java##程序员#
全部评论

相关推荐

评论
点赞
1
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务