场景
在项目中如有一条查询语句,需要将其中表名所在的库名补充上(库名已知),
例如:select id from a
,表a所在库叫d,那么需要修改为select id from d.a
初步思考
所以最开始打算用正则,不区分大小写的形式匹配到from+空格+不带点、不带空格的字符串,然后将第三部分替换。
有了如下代码:
private String dealSql(String sql, String dbName) {
sql = sql.replaceAll("\\s+", " ").replace(";"," ");
String result = sql;
String p = "(from)(\\s+)([^\\s]+)";
Pattern pattern = Pattern.compile(p, Pattern.CASE_INSENSITIVE);
Matcher matcher = pattern.matcher(sql);
while (matcher.find()) {
if(matcher.groupCount() == 3) {
if(!matcher.group(3).contains(".")) {
result = matcher.replaceAll("$1$2" + dbName + ".$3");
}
}
}
return result;
}
经过测试
深入思考
貌似结果是正确的,但是由于sql不可能只有这样一种情况,还会有诸如以下形式产生:
- 多表查询,使用where关联
select a.name from a,b where a.id = b.id
- 使用join的联合查询
select a.name from a left join b on a.id = b.id
- 带有别名
select a.name from a t1 left join b t2 on t1.id = t2.id
- 虽然有from,但是是子查询
select t.name from (select a.name from a left join b on a.id = b.id) t
也许还会有其他复杂情况,不过短时间只想到了这几种情况,要是按照上面那种方式来逐一匹配、解决的话,费键盘不说,也很费头发。
作为一个优秀的百度工程师,决定搜索一下,发现可以通过druid中的一些工具类来实现,而且正巧项目中使用到了druid依赖。但是这种小需求,网上资料比较少,所以研究了一下druid的源码。
Druid中的Vistor
Druid中主要使用访问者模式来解析sql,并且封装了对sql语句解析生成的sql树的一系列操作。
访问者模式,是行为型设计模式之一。访问者模式是一种将数据操作与数据结构分离的设计模式,
在SQLObject接口中提供了很多使用Visitor的方法。
平时用到的SQLStatement就是实现了这个接口
由于这里用到的语法是Hive的,Druid中默认已经支持标准sql,只要看一下通用的sql实现方式即可。
它的accept方法会调用visitor,并且执行里面的preVisit(),visit(),postVisit();
翻看了SQLASTVisitor,里面默认实现了上面那些对于ast语法树操作的接口,虽然没有现成符合需求的,需要自己实现,但是也不需要实现全部方法,只要继承了这个类以后,重写关键部分就可以了。
这里已经清楚了大概逻辑,所以实现起来就比较容易了,首先继承HiveASTVisitorAdapter这个类(还是SQLASTVisitor的实现类)。
关系图如下:
主要用到图中的SQLASTVisitorAdapter这个类。
从字面意思可以猜出,对应sql语句的各个部分都有一个visit方法。我在这里需要的是修改表名,所以在这个文件中搜索了一下Table。
不出意外的话,这个方法就是正在寻找的东西。
于是写了一个类去继承,然后重写:
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.dialect.hive.visitor.HiveASTVisitorAdapter;
import com.cestc.basictool.utils.StringUtils;
/**
* ExportTableAliasVisitor
*
* @author Mwg
* @date 2020/09/08 23:47
*/
public class ExportTableAliasVisitor extends HiveASTVisitorAdapter {
private static ThreadLocal<String> dbName = new ThreadLocal<>();
public void set(String s) {
dbName.set(s);
}
public void remove() {
dbName.remove();
}
@Override
public boolean visit(SQLExprTableSource x) {
//别名,如果有别名,别名保持不变
String s = StringUtils.isEmpty(x.getAlias()) ? x.getExpr().toString() : x.getAlias();
// 修改表名,不包含点才加 select id from c left join d on c.id = d.id 中的c 和 d
if(!x.getExpr().toString().contains(".")) {
x.setExpr("`" + dbName.get() + "`." + x.getExpr());
}
x.setAlias(s);
return true;
}
}
在这里使用ThreadLocal去传递库名是因为对其他代码不熟悉,稳妥起见,先这样尝试一下。
调用的main方法:
public static void main(String[] args) {
List<String> list = new ArrayList<>();
String s0 = "select id from a";
String s1 = "select a.id from a,b where a.id = b.id";
String s2 = "select a.name from a left join b on a.id = b.id";
String s3 = "select a.name from a t1 left join b t2 on t1.id = t2.id";
String s4 = "select t.name from (select a.name from a left join b on a.id = b.id) t";
list.add(s0);
list.add(s1);
list.add(s2);
list.add(s3);
list.add(s4);
String dbType = JdbcConstants.HIVE;
int i = 0;
for (String sql : list) {
System.out.println(i++);
System.out.println(sql);
ExportTableAliasVisitor visitor = new ExportTableAliasVisitor();
visitor.set("d");
List<SQLStatement> stmtList = SQLUtils.parseStatements(sql, dbType);
for (SQLStatement stmt : stmtList) {
stmt.accept(visitor);
}
String s = SQLUtils.toSQLString(stmtList, dbType);
visitor.remove();
System.out.println(s);
}
}
输出结果:
暂时看起来没发现有什么问题,还是很好用的,后续发现问题再补充。
总结
通过parse()将sql转换为ast语法树 => 通过Visitor修改语法树 => 将语法树转换为SQL