WangHan
2024-09-12 d5855a4926926698b740bc6c7ba489de47adb68b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package tech.powerjob.official.processors.impl.sql;
 
import com.alibaba.fastjson.JSON;
import tech.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.h2.jdbc.JdbcSQLIntegrityConstraintViolationException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
 
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
 
/**
 * @author Echo009
 * @since 2021/3/11
 */
@Slf4j
class SpringDatasourceSqlProcessorTest {
 
    private static SpringDatasourceSqlProcessor springDatasourceSqlProcessor;
 
    @BeforeAll
    static void initSqlProcessor() {
 
        EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
        EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
                .addScript("classpath:db_init.sql")
                .build();
        springDatasourceSqlProcessor = new SpringDatasourceSqlProcessor(database);
        // do nothing
        springDatasourceSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
        // 排除掉包含 drop 的 SQL
        springDatasourceSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
        // add ';'
        springDatasourceSqlProcessor.setSqlParser((sql, taskContext) -> {
            if (!sql.endsWith(";")) {
                return sql + ";";
            }
            return sql;
        });
 
        // just invoke clean datasource method
        springDatasourceSqlProcessor.removeDataSource("NULL_DATASOURCE");
 
        log.info("init sql processor successfully!");
 
    }
 
 
    @Test
    void testSqlValidator() {
        SpringDatasourceSqlProcessor.SqlParams sqlParams = new SpringDatasourceSqlProcessor.SqlParams();
        sqlParams.setSql("drop table test_table");
        // 校验不通过
        assertThrows(IllegalArgumentException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
    }
 
    @Test
    void testIncorrectDataSourceName() {
        SpringDatasourceSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
        sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
        // 数据源名称非法
        assertThrows(IllegalArgumentException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
    }
 
    @Test
    void testExecDDL() {
        SpringDatasourceSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
        ProcessResult processResult = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
        assertTrue(processResult.isSuccess());
    }
 
    @Test
    void testExecSQL() {
 
        SpringDatasourceSqlProcessor.SqlParams sqlParams1 = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
        ProcessResult processResult1 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1)));
        assertTrue(processResult1.isSuccess());
 
        assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1))));
        // 第二条会失败回滚
        SpringDatasourceSqlProcessor.SqlParams sqlParams2 = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
        assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams2))));
        // 上方回滚,这里就能成功插入
        SpringDatasourceSqlProcessor.SqlParams sqlParams3 = constructSqlParam("insert into test_table (id, content) values (1, '?')");
        ProcessResult processResult3 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams3)));
        assertTrue(processResult3.isSuccess());
 
        SpringDatasourceSqlProcessor.SqlParams sqlParams4 = constructSqlParam("insert into test_table (id, content) values (2, '?');insert into test_table (id, content) values (3, '?')");
        ProcessResult processResult4 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams4)));
        assertTrue(processResult4.isSuccess());
 
    }
 
    @Test
    public void testQuery() {
        SpringDatasourceSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
        springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams)));
 
        SpringDatasourceSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table");
        springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams)));
    }
 
    static SpringDatasourceSqlProcessor.SqlParams constructSqlParam(String sql){
        SpringDatasourceSqlProcessor.SqlParams sqlParams = new SpringDatasourceSqlProcessor.SqlParams();
        sqlParams.setSql(sql);
        sqlParams.setShowResult(true);
        return sqlParams;
    }
 
}