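Preface
If you are not familiar with the MyBatis source code, you can read my article dedicated to reading the MyBatis source: https://juejin.cn/post/7017638866626543624
If you want to see how MyBatis plugins fit into a real project, see my open-source project https://gitee.com/zhuhuijie/base-platform
The plugin code lives under base-platform/base-common/common-db-mysql.
If you find it interesting, give it a star. It is continuously updated...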
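MyBatis's four built-in objects
- Executor: the executor, the object that actually executes SQL
- StatementHandler: the database statement handler, which compiles and processes SQL statements
  - PreparedStatementHandler: creates a PreparedStatement; the most commonly used, works with ? placeholders
  - CallableStatementHandler: creates a CallableStatement; executes stored procedures
  - SimpleStatementHandler: creates a plain Statement; the SQL is built by string concatenation, so there is a risk of SQL injection
- ParameterHandler: the parameter handler
public interface ParameterHandler {
    Object getParameterObject();

    void setParameters(PreparedStatement ps) throws SQLException;
}
- ResultSetHandler: processes result sets
public interface ResultSetHandler {
    <E> List<E> handleResultSets(Statement stmt) throws SQLException;

    <E> Cursor<E> handleCursorResultSets(Statement stmt) throws SQLException;

    void handleOutputParameters(CallableStatement cs) throws SQLException;
}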
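How MyBatis executes SQL
- Obtain a SqlSession object based on the configuration
- Obtain the Mapper's proxy object through a dynamic proxy
  StudentMapper mapper = sqlSession.getMapper(StudentMapper.class);
- Call the concrete SQL method through the proxy object
  Student student = mapper.getStudentById(id);
  - The proxy (MapperProxy) hands the call to
    mapperMethod.execute(sqlSession, args);
    - INSERT: sqlSession.insert()
    - UPDATE: sqlSession.update()
    - DELETE: sqlSession.delete()
    - SELECT: sqlSession.selectOne() / selectList() / ...
      - selectList
        executor.query() goes through CachingExecutor [decorator pattern], which delegates to the real SimpleExecutor ---> parent class BaseExecutor.query() ---> abstract doQuery() ---> SimpleExecutor.doQuery() [template method pattern]
        - Initialize the handler objects
          A delegate is created according to the StatementType (in RoutingStatementHandler), e.g. new PreparedStatementHandler()
          - Create the JDBC Statement: stmt = preparedStatementHandler.instantiateStatement() ---> connection.prepareStatement()
          - handler.parameterize(stmt): parameter handling
            - ParameterHandler
          - resultSetHandler.handleResultSets(preparedStatement): maps the result set
          - ...
        - Get the result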
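The executor chain above combines two classic patterns. The sketch below is a heavily simplified, hypothetical model of that chain, not MyBatis's real classes: `MiniCachingExecutor` decorates another executor with a result cache (roughly the role CachingExecutor plays for the second-level cache), while `MiniBaseExecutor.query()` is a template method that leaves the actual database step to the abstract `doQuery()`, which `MiniSimpleExecutor` implements. All class names and the in-memory "database" are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, heavily simplified stand-in for MyBatis's Executor interface.
interface MiniExecutor {
    List<String> query(String sql);
}

// Template method: query() fixes the overall algorithm, doQuery() is the variable step.
abstract class MiniBaseExecutor implements MiniExecutor {
    @Override
    public final List<String> query(String sql) {
        // shared steps (local cache lookup, error context, ...) would live here
        return doQuery(sql);
    }

    protected abstract List<String> doQuery(String sql);
}

class MiniSimpleExecutor extends MiniBaseExecutor {
    int dbHits = 0; // counts how often the pretend database is reached

    @Override
    protected List<String> doQuery(String sql) {
        dbHits++;
        List<String> rows = new ArrayList<>();
        rows.add("row for: " + sql);
        return rows;
    }
}

// Decorator: implements the same interface, wraps a delegate and adds caching around it.
class MiniCachingExecutor implements MiniExecutor {
    private final MiniExecutor delegate;
    private final Map<String, List<String>> cache = new HashMap<>();

    MiniCachingExecutor(MiniExecutor delegate) {
        this.delegate = delegate;
    }

    @Override
    public List<String> query(String sql) {
        // only reach the delegate on a cache miss
        return cache.computeIfAbsent(sql, delegate::query);
    }
}
```

Because the decorator and the decorated object share one interface, callers cannot tell whether caching is on; that is why MyBatis can insert CachingExecutor without the rest of the chain changing.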
How to develop a MyBatis plugin
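A MyBatis plugin is essentially an enhancement of MyBatis's four built-in objects.
It is built on MyBatis's interceptor (Interceptor) mechanism and is applied in an AOP style.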
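To make "enhancement in an AOP style" concrete, here is a JDK-only sketch of how a plugin chain behaves. It assumes a hypothetical `MiniStatementHandler` interface and a `LoggingPlugin` handler loosely modeled on MyBatis's `Plugin` class; it is not MyBatis's own code. Each wrap adds one more proxy layer, so the interceptor wrapped last sits outermost and runs its "before" logic first, which mirrors how MyBatis applies its registered interceptors one after another.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.List;

// Hypothetical stand-in for a built-in object such as StatementHandler.
interface MiniStatementHandler {
    String prepare(String sql);
}

class RealStatementHandler implements MiniStatementHandler {
    @Override
    public String prepare(String sql) {
        return sql;
    }
}

// Loosely modeled on org.apache.ibatis.plugin.Plugin: an InvocationHandler that keeps
// the wrapped target and runs extra logic around every call.
class LoggingPlugin implements InvocationHandler {
    private final Object target;
    private final String name;
    private final List<String> log;

    private LoggingPlugin(Object target, String name, List<String> log) {
        this.target = target;
        this.name = name;
        this.log = log;
    }

    // Wraps `target` in a JDK dynamic proxy, similar in spirit to Plugin.wrap(...).
    static MiniStatementHandler wrap(MiniStatementHandler target, String name, List<String> log) {
        return (MiniStatementHandler) Proxy.newProxyInstance(
                LoggingPlugin.class.getClassLoader(),
                new Class<?>[] {MiniStatementHandler.class},
                new LoggingPlugin(target, name, log));
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        log.add(name + " before " + method.getName());
        Object result;
        try {
            // Comparable to invocation.proceed(): call the next object in the chain
            result = method.invoke(target, args);
        } catch (InvocationTargetException e) {
            throw e.getCause(); // surface the target's own exception, not the reflection wrapper
        }
        log.add(name + " after " + method.getName());
        return result;
    }
}
```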
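Case 1: a SQL-printing plugin
Create the interceptor
Note that the interceptor implements the Interceptor interface from the org.apache.ibatis package. The annotations on the class determine where our interceptor cuts into MyBatis, and the extension is then applied in an AOP style.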
package com.zhj.common.db.mysql.plugins;

import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.sql.Statement;

/**
 * Prints SQL statements:
 * 1. logs the SQL statement
 * 2. logs the execution time
 * type: the type of built-in object to enhance (must be one of the four built-in objects; StatementHandler.class is the most commonly enhanced)
 * method: the name of the method to enhance
 * args: the parameter types, needed to pick the right method when it is overloaded
 * @author zhj
 */
@Slf4j
@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "query",
                args = {Statement.class, ResultHandler.class}
        ),
        @Signature(
                type = StatementHandler.class,
                method = "update",
                args = {Statement.class}
        )
})
public class PrintSQLPlugins implements Interceptor {

    /**
     * Intercepting method
     * @param invocation the intercepted call
     * @return the result of the target method
     * @throws Throwable whatever the target method throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();
        log.info("----------------------------[SQL]-------------------------------");
        log.info(sql.replace("\n", ""));
        long beginTime = System.currentTimeMillis();
        // Let the call through: run the corresponding method on the target object
        Object proceed = invocation.proceed();
        long endTime = System.currentTimeMillis();
        log.info("----------------------------[SQL execution time: {} s]",
                BigDecimal.valueOf(endTime - beginTime)
                        .divide(BigDecimal.valueOf(1000))
                        .setScale(6, RoundingMode.DOWN)
                        .doubleValue());
        return proceed;
    }
}
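`sql.replace("\n", "")` deletes line breaks outright, so when a line break in the mapper SQL is not followed by indentation, adjacent tokens get glued together (for example `select id\nfrom t` becomes `select idfrom t`). A small alternative is sketched below as a hypothetical `SqlLogFormat` helper, which is not part of the project: it collapses every whitespace run into a single space, and also shows a plain `String.format` way to print elapsed milliseconds as seconds.

```java
import java.util.Locale;

// Hypothetical logging helpers, not part of base-platform.
final class SqlLogFormat {
    private SqlLogFormat() {}

    // Collapse newlines, tabs and indentation into single spaces so multi-line
    // mapper SQL fits on one log line without gluing tokens together.
    static String oneLine(String sql) {
        return sql.replaceAll("\\s+", " ").trim();
    }

    // Format elapsed milliseconds as seconds with three decimals, e.g. "1.500 s".
    static String seconds(long elapsedMillis) {
        return String.format(Locale.ROOT, "%.3f s", elapsedMillis / 1000.0);
    }
}
```

In the interceptor this would amount to `log.info(SqlLogFormat.oneLine(sql))` in place of the `replace` call.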
Make the plugin take effect
package com.zhj.common.db.mysql.config;

import com.zhj.common.db.mysql.plugins.PrintSQLPlugins;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.transaction.annotation.EnableTransactionManagement;

/**
 * @author zhj
 */
@Configuration
@MapperScan("com.zhj.data.mapper")
@EnableTransactionManagement
public class DBAutoConfiguration {

    @Bean
    @ConditionalOnProperty(value = "zhj.plugins.printSql.enable", havingValue = "true", matchIfMissing = false)
    public PrintSQLPlugins getPrintSQLPlugins() {
        return new PrintSQLPlugins();
    }
}
- Whether the plugin is enabled is decided by configuration:
@ConditionalOnProperty(value = "zhj.plugins.printSql.enable", havingValue = "true", matchIfMissing = false)
Add the dependency and bind the property to a bean, so the IDE can auto-complete it when you edit the configuration file
package com.zhj.common.db.mysql.entity;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * @author zhj
 */
@Component
@ConfigurationProperties(prefix = "zhj.plugins.printSql")
@Data
public class ZhjConfigInfo {
    private Boolean enable;
}
Dependency:
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-configuration-processor</artifactId>
    <optional>true</optional>
</dependency>
Enable the plugin in the configuration file:
zhj:
  plugins:
    printSql:
      enable: true
Case 2: a pagination plugin
Implementing a basic pagination plugin:
Create the pagination object
package com.zhj.common.db.mysql.page;

import lombok.Data;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
 * Pagination information object
 * @author zhj
 */
@Data
@Accessors(chain = true)
public class Page implements Serializable {
    /**
     * Current page number
     */
    private Integer pageNo;
    /**
     * Rows per page
     */
    private Integer pageSize;
    /**
     * Total number of pages
     */
    private Integer pageTotal;
    /**
     * Total number of rows
     */
    private Integer pageCount;
}
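Create the pagination utility
Here we use a ThreadLocal to hold the pagination object, so each request thread carries its own paging parameters.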
package com.zhj.common.db.mysql.page;

/**
 * Pagination manager
 * @author zhj
 */
public class PageUtils {

    private static final ThreadLocal<Page> pageThreadLocal = new ThreadLocal<>();

    /**
     * Set the pagination object for the current thread
     * @param pageNo page number
     * @param pageSize rows per page
     */
    public static void setPage(Integer pageNo, Integer pageSize) {
        pageThreadLocal.set(new Page().setPageNo(pageNo).setPageSize(pageSize));
    }

    /**
     * Get the pagination object of the current thread
     * @return the Page, or null if none was set
     */
    public static Page getPage() {
        return pageThreadLocal.get();
    }

    /**
     * Clear the pagination info of the current thread
     */
    public static void clear() {
        pageThreadLocal.remove();
    }
}
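The reason `PageUtils` needs `clear()` is that servlet containers reuse pooled threads, so a value left in a `ThreadLocal` leaks into the next request handled by that thread. The sketch below demonstrates this with a single-thread pool; `PageHolder` and `runOn` are hypothetical names, and only a page number is stored to keep it short.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;

// Hypothetical, stripped-down stand-in for PageUtils that stores only a page number.
final class PageHolder {
    private static final ThreadLocal<Integer> PAGE_NO = new ThreadLocal<>();

    private PageHolder() {}

    static void set(Integer pageNo) { PAGE_NO.set(pageNo); }

    static Integer get() { return PAGE_NO.get(); }

    static void clear() { PAGE_NO.remove(); }

    // One simulated "request": set the value, use it, and always clear it in finally.
    static Integer handleRequest(Integer pageNo) {
        set(pageNo);
        try {
            return get();
        } finally {
            clear();
        }
    }

    // Demo helper: run a task on the pool and wait for its result.
    static <T> T runOn(ExecutorService pool, Callable<T> task) {
        try {
            return pool.submit(task).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
    }
}
```

With a single worker thread, the second task runs on the same thread as the first, so it would see the stale value if `handleRequest` skipped `clear()`. The main thread's value is never visible to the worker.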
Create the interceptor that implements pagination
package com.zhj.common.db.mysql.plugins;

import com.zhj.common.db.mysql.page.Page;
import com.zhj.common.db.mysql.page.PageUtils;
import com.zhj.common.db.mysql.util.MybatisUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;

/**
 * Pagination plugin
 * @author zhj
 */
@Slf4j
@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class, Integer.class} // must match the signature in your MyBatis version
        )
})
public class PagePlugins implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Get the real (non-proxy) target object
        StatementHandler target = MybatisUtils.getNoProxyTarget(invocation.getTarget());
        BoundSql boundSql = target.getBoundSql();
        // Take the SQL, lower-case it and trim surrounding whitespace
        String sql = boundSql.getSql().toLowerCase().trim();
        // Only SELECT statements are paginated
        if (!sql.startsWith("select")) {
            return invocation.proceed();
        }
        // Get the pagination parameters
        Page page = PageUtils.getPage();
        if (page == null) {
            return invocation.proceed();
        }
        log.info("[SQL to paginate: {}", sql.replace("\n", ""));
        // Build and run a SQL that counts the total number of rows
        Integer count = count(target, invocation, sql);
        log.info("[Total rows: {}", count);
        // Normalize pageNo
        if (page.getPageNo() == null || page.getPageNo() < 1) page.setPageNo(1);
        // Normalize pageSize
        if (page.getPageSize() == null || page.getPageSize() < 1) page.setPageSize(10);
        // Fill in the pagination object
        page.setPageCount(count);
        page.setPageTotal(page.getPageCount() % page.getPageSize() == 0
                ? page.getPageCount() / page.getPageSize()
                : page.getPageCount() / page.getPageSize() + 1);
        // Clamp pageNo to the last page, but never below 1 (pageTotal is 0 when there are no rows)
        if (page.getPageTotal() > 0 && page.getPageNo() > page.getPageTotal()) page.setPageNo(page.getPageTotal());
        log.info("[Processed Page: {}", page);
        // MySQL "LIMIT offset, rows": skip (pageNo - 1) * pageSize rows
        sql += " limit " + ((page.getPageNo() - 1) * page.getPageSize()) + "," + page.getPageSize();
        log.info("[Paginated SQL: {}", sql.replace("\n", ""));
        // Set the sql field of BoundSql via reflection,
        // using the reflection utility MyBatis provides
        MetaObject metaObject = SystemMetaObject.forObject(boundSql);
        metaObject.setValue("sql", sql);
        return invocation.proceed();
    }

    /**
     * Get the total number of rows for the SQL
     * @param sql the lower-cased SQL
     * @return the total row count, or -1 if the count query returned no row
     */
    private Integer count(StatementHandler statementHandler, Invocation invocation, String sql) throws SQLException {
        // Drop a trailing ORDER BY, which does not affect the count
        int orderByIndex = sql.lastIndexOf("order by");
        if (orderByIndex != -1) {
            sql = sql.substring(0, orderByIndex);
        }
        // Build the count SQL
        int fromIndex = sql.indexOf("from");
        String countSQL = "select count(*) " + sql.substring(fromIndex);
        log.info("[Count SQL: {}", countSQL);
        // The Connection is the first argument of the intercepted prepare() call
        Connection connection = (Connection) invocation.getArgs()[0];
        // try-with-resources closes the PreparedStatement and ResultSet
        try (PreparedStatement ps = connection.prepareStatement(countSQL)) {
            // Bind the same parameters as the original query
            statementHandler.parameterize(ps);
            try (ResultSet resultSet = ps.executeQuery()) {
                // next() works on forward-only result sets, where first() may throw
                if (resultSet.next()) {
                    return resultSet.getInt(1);
                }
            }
        } catch (SQLException sqlException) {
            log.error("[Count SQL failed!]");
            throw sqlException;
        }
        return -1;
    }
}
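The arithmetic inside `PagePlugins` is easy to get wrong (the offset for page n is `(n - 1) * pageSize`), so it can help to pull it into pure functions that can be unit-tested without a database. `PageMath` below is a hypothetical helper written for this purpose, not part of the project; `countSql` reproduces the same naive ORDER BY / FROM string handling and therefore inherits its limitations with subqueries.

```java
// Hypothetical pure helpers mirroring the pagination arithmetic, testable in isolation.
final class PageMath {
    private PageMath() {}

    // Number of pages needed for `count` rows: ceiling division.
    static int totalPages(int count, int pageSize) {
        return (count + pageSize - 1) / pageSize;
    }

    // Clamp pageNo into [1, totalPages]; with zero rows we stay on page 1.
    static int clampPageNo(int pageNo, int totalPages) {
        if (pageNo < 1) {
            return 1;
        }
        if (totalPages > 0 && pageNo > totalPages) {
            return totalPages;
        }
        return pageNo;
    }

    // MySQL "LIMIT offset, rows": number of rows to skip before the requested page.
    static int offset(int pageNo, int pageSize) {
        return (pageNo - 1) * pageSize;
    }

    // Naive count query: strip the last ORDER BY and replace the select list with count(*).
    // Assumes the first "from" belongs to the outer query, exactly like the plugin.
    static String countSql(String lowerCaseSql) {
        String sql = lowerCaseSql;
        int orderByIndex = sql.lastIndexOf("order by");
        if (orderByIndex != -1) {
            sql = sql.substring(0, orderByIndex);
        }
        return "select count(*) " + sql.substring(sql.indexOf("from")).trim();
    }
}
```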
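Because MyBatis enhances its four built-in objects through proxies, several plugins on the same object can interfere with one another: the target we get is sometimes not the real target object but a proxy created by another plugin. So we write a utility class that retrieves the real target object.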
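package com.zhj.common.db.mysql.util;

import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

/**
 * @author zhj
 */
public class MybatisUtils {

    /**
     * Get the non-proxy target object
     * @param target an object that may be wrapped by one or more plugin proxies
     * @param <T> expected type of the real target
     * @return the innermost, non-proxy target
     */
    @SuppressWarnings("unchecked")
    public static <T> T getNoProxyTarget(Object target) {
        MetaObject invocationMetaObject = SystemMetaObject.forObject(target);
        // A JDK dynamic proxy keeps its InvocationHandler in the field "h"; for MyBatis plugins
        // that handler is a Plugin whose "target" field holds the wrapped object
        while (invocationMetaObject.hasGetter("h")) {
            target = invocationMetaObject.getValue("h.target");
            invocationMetaObject = SystemMetaObject.forObject(target);
        }
        return (T) target;
    }
}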
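`getNoProxyTarget` walks the proxy chain by reading the `h` field of `java.lang.reflect.Proxy` and then the `target` field inside MyBatis's `Plugin` (a private field, hence the MetaObject access). For comparison, the JDK-only sketch below walks the same kind of chain through the public `Proxy.isProxyClass` and `Proxy.getInvocationHandler` API, so it does not rely on the field name `h`. `TargetHoldingHandler` is a hypothetical handler that exposes its target; with real MyBatis proxies you would still need reflection to read `Plugin`'s private `target`.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical handler shaped like MyBatis's Plugin: it stores the wrapped object in `target`.
class TargetHoldingHandler implements InvocationHandler {
    final Object target;

    TargetHoldingHandler(Object target) {
        this.target = target;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        return method.invoke(target, args);
    }
}

final class ProxyUnwrapper {
    private ProxyUnwrapper() {}

    // Peel off proxy layers until a non-proxy object (or an unknown handler) is reached.
    static Object unwrap(Object candidate) {
        Object current = candidate;
        while (Proxy.isProxyClass(current.getClass())) {
            InvocationHandler h = Proxy.getInvocationHandler(current);
            if (!(h instanceof TargetHoldingHandler)) {
                break; // some other kind of proxy: stop rather than guess its layout
            }
            current = ((TargetHoldingHandler) h).target;
        }
        return current;
    }

    // Wrap `target` in one proxy layer implementing `iface`.
    static Object wrap(Object target, Class<?> iface) {
        return Proxy.newProxyInstance(ProxyUnwrapper.class.getClassLoader(),
                new Class<?>[] {iface}, new TargetHoldingHandler(target));
    }
}
```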
Register the pagination plugin so it takes effect
package com.zhj.common.db.mysql.config;

import com.zhj.common.db.mysql.plugins.PagePlugins;
import com.zhj.common.db.mysql.plugins.PrintSQLPlugins;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.transaction.annotation.EnableTransactionManagement;

/**
 * @author zhj
 */
@Configuration
@MapperScan("com.zhj.data.mapper")
@EnableTransactionManagement
public class DBAutoConfiguration {

    @Bean
    @ConditionalOnProperty(value = "zhj.plugins.printSql.enable", havingValue = "true", matchIfMissing = false)
    public PrintSQLPlugins getPrintSQLPlugins() {
        return new PrintSQLPlugins();
    }

    @Bean
    public PagePlugins getPagePlugins() {
        return new PagePlugins();
    }
}
Turn on pagination in the Controller (or Service)
package com.zhj.business.controller;

import com.zhj.business.service.StudentService;
import com.zhj.common.core.result.Result;
import com.zhj.common.core.util.ResultUtils;
import com.zhj.common.db.mysql.page.PageUtils;
import com.zhj.data.entity.example.Student;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;

import java.util.List;

/**
 * @author zhj
 */
@Slf4j
@RestController
@RequestMapping("/student")
public class StudentController {

    @Autowired
    private StudentService studentService;

    @GetMapping("/list")
    public Result<List<Student>> list() {
        // Turn on pagination; the values sent by the front end can be put into the Page here
        PageUtils.setPage(1, 2);
        try {
            List<Student> list = studentService.list();
            return ResultUtils.createSuccess(list);
        } finally {
            // Always clear the ThreadLocal: servlet threads are pooled and reused
            PageUtils.clear();
        }
    }
}
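Making the pagination plugin more elegant:
Remove the intrusive code: enable pagination through AOP and return the pagination info together with the result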
package com.zhj.common.db.mysql.aop;

import com.zhj.common.db.mysql.page.BasePageResult;
import com.zhj.common.db.mysql.page.Page;
import com.zhj.common.db.mysql.page.PageUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;

/**
 * Automatically enables pagination if the request carries pageNo and pageSize
 * @author zhj
 */
@Aspect
public class WebPageAOP {

    @Around("@within(org.springframework.web.bind.annotation.RestController) || @within(org.springframework.stereotype.Controller)")
    public Object pageAOP(ProceedingJoinPoint joinPoint) throws Throwable {
        // Read the paging parameters from the current HTTP request
        ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes != null) {
            HttpServletRequest request = requestAttributes.getRequest();
            String pageNo = request.getParameter("pageNo");
            String pageSize = request.getParameter("pageSize");
            if (StringUtils.hasText(pageNo) && StringUtils.hasText(pageSize)) {
                PageUtils.setPage(Integer.parseInt(pageNo), Integer.parseInt(pageSize));
            }
        }
        try {
            Object proceed = joinPoint.proceed();
            // If the controller returns a BasePageResult, attach the filled-in Page to it
            Page page = PageUtils.getPage();
            if (proceed instanceof BasePageResult && page != null) {
                ((BasePageResult) proceed).setPage(page);
            }
            return proceed;
        } finally {
            // Always clear the ThreadLocal, even when the controller throws
            PageUtils.clear();
        }
    }
}
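`Integer.parseInt` throws `NumberFormatException` for input such as `pageNo=abc`, and that exception would propagate out of the aspect and fail the whole request. A defensive variant, sketched with the hypothetical `PageParams` and `PageRequest` types below (not part of the project), returns an empty `Optional` instead so pagination is simply skipped.

```java
import java.util.Optional;

// Hypothetical value type holding the two parsed parameters.
final class PageRequest {
    final int pageNo;
    final int pageSize;

    PageRequest(int pageNo, int pageSize) {
        this.pageNo = pageNo;
        this.pageSize = pageSize;
    }
}

final class PageParams {
    private PageParams() {}

    // Returns empty when either value is missing or not an integer, so the caller
    // can skip pagination instead of failing the request.
    static Optional<PageRequest> parse(String pageNo, String pageSize) {
        if (pageNo == null || pageSize == null) {
            return Optional.empty();
        }
        try {
            return Optional.of(new PageRequest(
                    Integer.parseInt(pageNo.trim()), Integer.parseInt(pageSize.trim())));
        } catch (NumberFormatException e) {
            return Optional.empty();
        }
    }
}
```

Inside `WebPageAOP` it could replace the two `parseInt` calls with `PageParams.parse(pageNo, pageSize).ifPresent(p -> PageUtils.setPage(p.pageNo, p.pageSize));`. Out-of-range values such as `0` are still left to the normalization in `PagePlugins`.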
package com.zhj.common.db.mysql.config;

import com.zhj.common.db.mysql.aop.WebPageAOP;
import com.zhj.common.db.mysql.plugins.PagePlugins;
import com.zhj.common.db.mysql.plugins.PrintSQLPlugins;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.transaction.annotation.EnableTransactionManagement;

/**
 * @author zhj
 */
@Configuration
@MapperScan("com.zhj.data.mapper")
@EnableTransactionManagement
public class DBAutoConfiguration {

    @Bean
    @ConditionalOnProperty(value = "zhj.plugins.printSql.enable", havingValue = "true", matchIfMissing = false)
    public PrintSQLPlugins getPrintSQLPlugins() {
        return new PrintSQLPlugins();
    }

    @Bean
    public PagePlugins getPagePlugins() {
        return new PagePlugins();
    }

    @Bean
    public WebPageAOP getWebPageAOP() {
        return new WebPageAOP();
    }
}
Use an annotation to control pagination at a finer granularity
Create the annotation
package com.zhj.common.db.mysql.annotation;

import java.lang.annotation.*;

/**
 * Marks a method whose queries should be paginated
 * @author zhj
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Page {
}
Add a switch to the Page object
package com.zhj.common.db.mysql.page;

import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
 * Pagination information object
 * @author zhj
 */
@Data
@Accessors(chain = true)
public class Page implements Serializable {
    /**
     * Current page number
     */
    private Integer pageNo;
    /**
     * Rows per page
     */
    private Integer pageSize;
    /**
     * Total number of pages
     */
    private Integer pageTotal;
    /**
     * Total number of rows
     */
    private Integer pageCount;
    /**
     * Whether pagination is enabled (not serialized to the client)
     */
    @JsonIgnore
    private boolean enable;
}
Add a check to the existing pagination interceptor
// Get the pagination parameters; skip pagination unless a @Page method switched it on
Page page = PageUtils.getPage();
if (page == null || !page.isEnable()) {
    return invocation.proceed();
}
Set the switch through AOP
package com.zhj.common.db.mysql.aop;

import com.zhj.common.db.mysql.page.Page;
import com.zhj.common.db.mysql.page.PageUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;

/**
 * @author zhj
 */
@Aspect
public class PageAOP {

    @Around("@annotation(com.zhj.common.db.mysql.annotation.Page)")
    public Object pageAOP(ProceedingJoinPoint joinPoint) throws Throwable {
        Page page = PageUtils.getPage();
        if (page != null) {
            page.setEnable(true);
        }
        try {
            return joinPoint.proceed();
        } finally {
            if (page != null) {
                page.setEnable(false);
            }
        }
    }
}
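One subtlety of `PageAOP`: its `finally` block always sets `enable` back to `false`, so if a `@Page` method calls another `@Page` method through a Spring proxy, queries the outer method runs after the inner call are no longer paginated. The hypothetical `PageSwitch` sketch below restores the previous state instead of forcing `false`; it is a variation to consider under that assumption, not the project's code.

```java
import java.util.function.Supplier;

// Hypothetical, simplified Page that only carries the on/off switch.
final class PageSwitch {
    private boolean enable;

    boolean isEnable() {
        return enable;
    }

    void setEnable(boolean enable) {
        this.enable = enable;
    }

    // Turn paging on while `body` runs, then restore the PREVIOUS state,
    // so an inner @Page call does not switch paging off for its caller.
    <T> T withPaging(Supplier<T> body) {
        boolean previous = enable;
        enable = true;
        try {
            return body.get();
        } finally {
            enable = previous;
        }
    }
}
```

In `PageAOP` the same idea means remembering `page.isEnable()` before `setEnable(true)` and writing that value back in `finally`.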
package com.zhj.common.db.mysql.config;

import com.zhj.common.db.mysql.aop.PageAOP;
import com.zhj.common.db.mysql.aop.WebPageAOP;
import com.zhj.common.db.mysql.plugins.PagePlugins;
import com.zhj.common.db.mysql.plugins.PrintSQLPlugins;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.transaction.annotation.EnableTransactionManagement;

/**
 * @author zhj
 */
@Configuration
@MapperScan("com.zhj.data.mapper")
@EnableTransactionManagement
public class DBAutoConfiguration {

    @Bean
    @ConditionalOnProperty(value = "zhj.plugins.printSql.enable", havingValue = "true", matchIfMissing = false)
    public PrintSQLPlugins getPrintSQLPlugins() {
        return new PrintSQLPlugins();
    }

    @Bean
    public PagePlugins getPagePlugins() {
        return new PagePlugins();
    }

    @Bean
    public WebPageAOP getWebPageAOP() {
        return new WebPageAOP();
    }

    @Bean
    public PageAOP getPageAOP() {
        return new PageAOP();
    }
}
Enable pagination on the corresponding Service or DAO method
package com.zhj.business.service.impl;

import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.zhj.business.service.StudentService;
import com.zhj.common.db.mysql.annotation.Page;
import com.zhj.data.entity.example.Student;
import com.zhj.data.mapper.example.dao.StudentDao;
import org.springframework.stereotype.Service;

import java.util.List;

/**
 * @author zhj
 */
@Service
public class StudentServiceImpl extends ServiceImpl<StudentDao, Student> implements StudentService {

    @Page
    @Override
    public List<Student> list() {
        return super.list();
    }
}
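Summary of MyBatis plugin development
To extend a framework you first have to understand its source code; only with a fairly deep understanding of the source can we judge where best to hook in. Both cases in this article are the simplest possible implementations and, honestly, still have plenty of holes. For example, the SQL-printing plugin neither obtains the parameters nor fills them into the SQL, and the pagination plugin only covers fairly simple scenarios: if the SQL is complex, bugs are quite likely. All of this calls for continued study of source code and open-source projects; the more we accumulate, the better the tools we write. You can refer to the open-source MyBatis pagination projects on GitHub to keep improving your own pagination plugin, and you are welcome to discuss in the comments so we can learn together.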