手写@Component和@Resource实现IoC和DI
个人博客:zhenganwen.top
IoC & DI
IoC
即Inversion of Control
,控制反转,核心思想是面向接口编程,让上下游组件面向接口编程从而实现业务的灵活切换(上下游组件解耦)。
DI
即Dependency Injection
,依赖注入,所有的基础组件都面向接口编程,由Spring
帮我们注入接口实现类对象。
以Controller
层和Service
层为例,Controller
为调用方,是为上游。如果按照传统方式,在Controller
中显式地通过new
来引入某Service
进行编程,那么这两者是紧耦合的:
public class UserController{ UserService userService = new UserService(); public void add(User user){ return userService.add(User user); } public void query(){ return userService.query(); } } public class UserService{ public void add(User user){ // insert user } public void query(){ // select user } }
如此,一旦业务变更就需要更改UserService
的代码。如现在系统要做读写分离,那么就需要将UserService
修改如下:
public class UserService{ public void add(User user){ // insert user in db_1 } public void query(){ // select user in db_2 } }
这违背了“开闭原则”(对扩展开发对修改关闭)。
于是在Spring
中,所有的基础组件都面向接口编程,将基础组件的实例交给Spring
创建和管理。通过DI
(依赖注入)来灵活地控制组件之间的依赖关系。
首先针对特定的业务域定义统一的接口:
public interface IUserService{ void add(); void query(); }
根据接口可以有不同的实现类,如读写同库、读写分离:
public class UserServiceImpl implements IUserService{ public void add(User user){ // insert user in db_1 } public void query(){ // select user in db_1 } } public class UserReadWriteServiceImpl implements IUserService{ public void add(User user){ // insert user in db_1 } public void query(){ // select user in db_2 } }
将上述基础业务组件纳入SpringIoC
容器管理
<bean id="userServiceImpl" class="xx.xx.xx.UserServiceImpl"></bean> <bean id="userServiceReadWriteImpl" class="xx.xx.xx.UserServiceReadWriteImpl"></bean>
如此在上游组件中可以做到灵活切换:
@Controller public class UserController{ @Resource(name = "userServiceImpl") //@Resource(name = "userServiceReadWriteImpl") UserService userService; public void add(User user){ return userService.add(User user); } public void query(){ return userService.query(); } }
这种编程模式在业务变更时,只需扩展一个UserServiceReadWriteImpl
,并通过@Resource
更改注入源就可实现业务的切换,而无需更改原有的UserServiceImpl
的代码。
手写xml版本IoC容器
IoC容器:
package cn.tuhu.springioc.ioc; import org.dom4j.Document; import org.dom4j.DocumentException; import org.dom4j.Element; import org.dom4j.io.SAXReader; import java.io.InputStream; import java.util.List; import java.util.concurrent.ConcurrentHashMap; public class CustomClassPathXmlApplicationContext { private String xmlPath; private static ConcurrentHashMap<String, Object> beanFactory; /** * 实例化bean * 提取xml中的<bean></bean>节点,根据其中的class属性实例化bean,以beanId为缓存到beanFactory * 注入依赖 * 遍历bean的属性,将标注有依赖注入的属性到beanFactory中找相关依赖赋值 * @param xmlPath */ public CustomClassPathXmlApplicationContext(String xmlPath) { this.xmlPath = xmlPath; beanFactory = new ConcurrentHashMap<String, Object>(); try { initBeanFactory(); } catch (DocumentException e) { System.out.println("xml解析失败,请检查编写是否正确"); e.printStackTrace(); } catch (ClassNotFoundException e) { System.out.println("bean配置的class不存在"); e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (InstantiationException e) { System.out.println("bean初始化失败,请确保有无参构造方法"); e.printStackTrace(); } } public Object getBean(String beanId) { return beanFactory.get(beanId); } private List<Element> parseXml() throws DocumentException { InputStream inputStream = this.getClass().getClassLoader().getResourceAsStream(xmlPath); SAXReader saxReader = new SAXReader(); Document document = saxReader.read(inputStream); Element rootElement = document.getRootElement(); List<Element> beanElements = rootElement.elements(); return beanElements; } private void initBeanFactory() throws DocumentException, ClassNotFoundException, IllegalAccessException, InstantiationException { List<Element> beanElements = parseXml(); for (Element beanElement : beanElements) { String beanId = beanElement.attributeValue("id"); String className = beanElement.attributeValue("class"); Class<?> clz = Class.forName(className); beanFactory.put(beanId, clz.newInstance()); } } }
spring.xml
<?xml version="1.0" encoding="UTF-8"?> <beans xmlns="http://www.springframework.org/schema/beans" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.springframework.org/schema/beans http://www.springframework.org/schema/beans/spring-beans.xsd"> <bean id="userService" class="cn.tuhu.springioc.service.impl.UserServiceImpl"/> </beans>
组件:
package cn.tuhu.springioc.service; public interface UserService { void add(); } package cn.tuhu.springioc.service.impl; import cn.tuhu.springioc.service.UserService; public class UserServiceImpl implements UserService { public void add() { System.out.println("insert user"); } }
测试:
package cn.tuhu.springioc.service.ioc; import cn.tuhu.springioc.ioc.CustomClassPathXmlApplicationContext; import cn.tuhu.springioc.service.UserService; import org.junit.Test; public class CustomClassPathXmlApplicationContextTest { @Test public void getBean() { CustomClassPathXmlApplicationContext context = new CustomClassPathXmlApplicationContext("spring.xml"); UserService userService = (UserService) context.getBean("userService"); userService.add(); } } insert user
手写注解版IoC容器和注入注解
自定义注解
组件注解
package cn.tuhu.springioc.annotation; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; //可以在类上标注 @Target(ElementType.TYPE) //运行时保留此注解信息 @Retention(RetentionPolicy.RUNTIME) public @interface CustomComponent { }
依赖注入
package cn.tuhu.springioc.annotation; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; //可以在成员变量上标注 @Target(ElementType.FIELD) //运行时保留此注解信息 @Retention(RetentionPolicy.RUNTIME) public @interface CustomResource { }
IoC容器
package cn.tuhu.springioc.ioc; import cn.tuhu.springioc.annotation.CustomComponent; import cn.tuhu.springioc.annotation.CustomResource; import cn.tuhu.springioc.util.ClassUtils; import java.lang.reflect.Field; import java.util.Collection; import java.util.List; import java.util.concurrent.ConcurrentHashMap; public class CustomAnnotationApplicationContext { private String basePackage; private static ConcurrentHashMap<String, Object> beanFactory = new ConcurrentHashMap<String, Object>(); public CustomAnnotationApplicationContext(String basePackage) { this.basePackage = basePackage; this.initBeans(); Collection<Object> beans = beanFactory.values(); for (Object bean : beans) { this.injectDependencies(bean); } } private void injectDependencies(Object bean) { Class<?> clz = bean.getClass(); Field[] fields = clz.getDeclaredFields(); for (Field field : fields) { CustomResource annotation = field.getAnnotation(CustomResource.class); if (annotation != null) { //如果标注了@CustomResource注解,则需要依赖注入 Object obj = beanFactory.get(field.getName()); field.setAccessible(true);//如果访问权限不够,需要设置此项 try { field.set(bean, obj); //依赖注入 } catch (IllegalAccessException e) { //设置了Accessible则不会抛此异常 e.printStackTrace(); } } } } private void initBeans() { List<Class<?>> classes = ClassUtils.getClasses(basePackage); //ClassUtils工具类见下 for (Class<?> clz : classes) { CustomComponent annotation = clz.getAnnotation(CustomComponent.class); if (annotation != null) { // 标注了@CustomComponent注解,需要纳入IoC容器管理 Object bean = null; try { bean = clz.newInstance(); } catch (InstantiationException e) { System.out.printf("实例化bean:%s 失败", clz.toString()); e.printStackTrace(); } catch (IllegalAccessException e) { System.out.printf("%s访问权限不够", clz.toString()); e.printStackTrace(); } String beanId = this.toLowerCaseFirstChar(clz.getSimpleName()); beanFactory.put(beanId, bean); } } } public Object getBean(String beanId) { return beanFactory.get(beanId); } private String toLowerCaseFirstChar(String className) { StringBuilder stringBuilder = new StringBuilder(className.substring(0,1).toLowerCase()); stringBuilder.append(className.substring(1)); return stringBuilder.toString(); } }
通过反射扫描某package
下所有类的工具类
package cn.tuhu.springioc.util; import java.io.File; import java.io.FileFilter; import java.io.IOException; import java.net.JarURLConnection; import java.net.URL; import java.net.URLDecoder; import java.util.ArrayList; import java.util.Enumeration; import java.util.List; import java.util.jar.JarEntry; import java.util.jar.JarFile; public class ClassUtils { /** * 取得某个接口下所有实现这个接口的类 */ public static List<Class> getAllClassByInterface(Class c) { List<Class> returnClassList = null; if (c.isInterface()) { // 获取当前的包名 String packageName = c.getPackage().getName(); // 获取当前包下以及子包下所以的类 List<Class<?>> allClass = getClasses(packageName); if (allClass != null) { returnClassList = new ArrayList<Class>(); for (Class classes : allClass) { // 判断是否是同一个接口 if (c.isAssignableFrom(classes)) { // 本身不加入进去 if (!c.equals(classes)) { returnClassList.add(classes); } } } } } return returnClassList; } /* * 取得某一类所在包的所有类名 不含迭代 */ public static String[] getPackageAllClassName(String classLocation, String packageName) { // 将packageName分解 String[] packagePathSplit = packageName.split("[.]"); String realClassLocation = classLocation; int packageLength = packagePathSplit.length; for (int i = 0; i < packageLength; i++) { realClassLocation = realClassLocation + File.separator + packagePathSplit[i]; } File packeageDir = new File(realClassLocation); if (packeageDir.isDirectory()) { String[] allClassName = packeageDir.list(); return allClassName; } return null; } /** * 从包package中获取所有的Class * * @param pack * @return */ public static List<Class<?>> getClasses(String packageName) { // 第一个class类的集合 List<Class<?>> classes = new ArrayList<Class<?>>(); // 是否循环迭代 boolean recursive = true; // 获取包的名字 并进行替换 String packageDirName = packageName.replace('.', '/'); // 定义一个枚举的集合 并进行循环来处理这个目录下的things Enumeration<URL> dirs; try { dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName); // 循环迭代下去 while (dirs.hasMoreElements()) { // 获取下一个元素 URL url = dirs.nextElement(); // 得到协议的名称 String protocol = url.getProtocol(); // 如果是以文件的形式保存在服务器上 if ("file".equals(protocol)) { // 获取包的物理路径 String filePath = URLDecoder.decode(url.getFile(), "UTF-8"); // 以文件的方式扫描整个包下的文件 并添加到集合中 findAndAddClassesInPackageByFile(packageName, filePath, recursive, classes); } else if ("jar".equals(protocol)) { // 如果是jar包文件 // 定义一个JarFile JarFile jar; try { // 获取jar jar = ((JarURLConnection) url.openConnection()).getJarFile(); // 从此jar包 得到一个枚举类 Enumeration<JarEntry> entries = jar.entries(); // 同样的进行循环迭代 while (entries.hasMoreElements()) { // 获取jar里的一个实体 可以是目录 和一些jar包里的其他文件 如META-INF等文件 JarEntry entry = entries.nextElement(); String name = entry.getName(); // 如果是以/开头的 if (name.charAt(0) == '/') { // 获取后面的字符串 name = name.substring(1); } // 如果前半部分和定义的包名相同 if (name.startsWith(packageDirName)) { int idx = name.lastIndexOf('/'); // 如果以"/"结尾 是一个包 if (idx != -1) { // 获取包名 把"/"替换成"." packageName = name.substring(0, idx).replace('/', '.'); } // 如果可以迭代下去 并且是一个包 if ((idx != -1) || recursive) { // 如果是一个.class文件 而且不是目录 if (name.endsWith(".class") && !entry.isDirectory()) { // 去掉后面的".class" 获取真正的类名 String className = name.substring(packageName.length() + 1, name.length() - 6); try { // 添加到classes classes.add(Class.forName(packageName + '.' + className)); } catch (ClassNotFoundException e) { e.printStackTrace(); } } } } } } catch (IOException e) { e.printStackTrace(); } } } } catch (IOException e) { e.printStackTrace(); } return classes; } /** * 以文件的形式来获取包下的所有Class * * @param packageName * @param packagePath * @param recursive * @param classes */ public static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, List<Class<?>> classes) { // 获取此包的目录 建立一个File File dir = new File(packagePath); // 如果不存在或者 也不是目录就直接返回 if (!dir.exists() || !dir.isDirectory()) { return; } // 如果存在 就获取包下的所有文件 包括目录 File[] dirfiles = dir.listFiles(new FileFilter() { // 自定义过滤规则 如果可以循环(包含子目录) 或则是以.class结尾的文件(编译好的java类文件) public boolean accept(File file) { return (recursive && file.isDirectory()) || (file.getName().endsWith(".class")); } }); // 循环所有文件 for (File file : dirfiles) { // 如果是目录 则继续扫描 if (file.isDirectory()) { findAndAddClassesInPackageByFile(packageName + "." + file.getName(), file.getAbsolutePath(), recursive, classes); } else { // 如果是java类文件 去掉后面的.class 只留下类名 String className = file.getName().substring(0, file.getName().length() - 6); try { // 添加到集合中去 classes.add(Class.forName(packageName + '.' + className)); } catch (ClassNotFoundException e) { e.printStackTrace(); } } } } }
基础组件
Service
package cn.tuhu.springioc.service.impl; import cn.tuhu.springioc.annotation.CustomComponent; import cn.tuhu.springioc.service.UserService; @CustomComponent public class UserServiceImpl implements UserService { public void add() { System.out.println("UserServiceImp: insert user"); } }
Controller
package cn.tuhu.springioc.controller; import cn.tuhu.springioc.annotation.CustomComponent; import cn.tuhu.springioc.annotation.CustomResource; import cn.tuhu.springioc.service.UserService; @CustomComponent public class UserController { @CustomResource private UserService userServiceImpl; public void add() { System.out.println("UserController: receive request for add user"); userServiceImpl.add(); } }
测试
package cn.tuhu.springioc.annotation; import cn.tuhu.springioc.controller.UserController; import cn.tuhu.springioc.ioc.CustomAnnotationApplicationContext; import org.junit.Test; public class UserControllerTest { @Test public void add() { CustomAnnotationApplicationContext context = new CustomAnnotationApplicationContext("cn.tuhu.springioc"); UserController userController = (UserController) context.getBean("userController"); userController.add(); } } UserController: receive request for add user UserServiceImp: insert user#Java##Spring#