package tech.powerjob.official.processors.impl.sql;
|
|
import com.alibaba.fastjson.JSON;

import tech.powerjob.worker.core.processor.ProcessResult;

import lombok.extern.slf4j.Slf4j;

import org.h2.jdbc.JdbcSQLIntegrityConstraintViolationException;

import org.junit.jupiter.api.AfterAll;

import org.junit.jupiter.api.BeforeAll;

import org.junit.jupiter.api.BeforeEach;

import org.junit.jupiter.api.Test;

import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;

import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;

import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;

import tech.powerjob.official.processors.TestUtils;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.regex.Pattern;

import static org.junit.jupiter.api.Assertions.assertThrows;

import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
/**
|
* @author Echo009
|
* @since 2021/3/11
|
*/
|
@Slf4j
|
class SpringDatasourceSqlProcessorTest {
|
|
private static SpringDatasourceSqlProcessor springDatasourceSqlProcessor;
|
|
@BeforeAll
|
static void initSqlProcessor() {
|
|
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
|
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
|
.addScript("classpath:db_init.sql")
|
.build();
|
springDatasourceSqlProcessor = new SpringDatasourceSqlProcessor(database);
|
// do nothing
|
springDatasourceSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
// 排除掉包含 drop 的 SQL
|
springDatasourceSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
// add ';'
|
springDatasourceSqlProcessor.setSqlParser((sql, taskContext) -> {
|
if (!sql.endsWith(";")) {
|
return sql + ";";
|
}
|
return sql;
|
});
|
|
// just invoke clean datasource method
|
springDatasourceSqlProcessor.removeDataSource("NULL_DATASOURCE");
|
|
log.info("init sql processor successfully!");
|
|
}
|
|
|
@Test
|
void testSqlValidator() {
|
SpringDatasourceSqlProcessor.SqlParams sqlParams = new SpringDatasourceSqlProcessor.SqlParams();
|
sqlParams.setSql("drop table test_table");
|
// 校验不通过
|
assertThrows(IllegalArgumentException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
}
|
|
@Test
|
void testIncorrectDataSourceName() {
|
SpringDatasourceSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
|
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
|
// 数据源名称非法
|
assertThrows(IllegalArgumentException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
}
|
|
@Test
|
void testExecDDL() {
|
SpringDatasourceSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
|
ProcessResult processResult = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
assertTrue(processResult.isSuccess());
|
}
|
|
@Test
|
void testExecSQL() {
|
|
SpringDatasourceSqlProcessor.SqlParams sqlParams1 = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
ProcessResult processResult1 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1)));
|
assertTrue(processResult1.isSuccess());
|
|
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1))));
|
// 第二条会失败回滚
|
SpringDatasourceSqlProcessor.SqlParams sqlParams2 = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams2))));
|
// 上方回滚,这里就能成功插入
|
SpringDatasourceSqlProcessor.SqlParams sqlParams3 = constructSqlParam("insert into test_table (id, content) values (1, '?')");
|
ProcessResult processResult3 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams3)));
|
assertTrue(processResult3.isSuccess());
|
|
SpringDatasourceSqlProcessor.SqlParams sqlParams4 = constructSqlParam("insert into test_table (id, content) values (2, '?');insert into test_table (id, content) values (3, '?')");
|
ProcessResult processResult4 = springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams4)));
|
assertTrue(processResult4.isSuccess());
|
|
}
|
|
@Test
|
public void testQuery() {
|
SpringDatasourceSqlProcessor.SqlParams insertParams = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(insertParams)));
|
|
SpringDatasourceSqlProcessor.SqlParams queryParams = constructSqlParam("select * from test_table");
|
springDatasourceSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(queryParams)));
|
}
|
|
static SpringDatasourceSqlProcessor.SqlParams constructSqlParam(String sql){
|
SpringDatasourceSqlProcessor.SqlParams sqlParams = new SpringDatasourceSqlProcessor.SqlParams();
|
sqlParams.setSql(sql);
|
sqlParams.setShowResult(true);
|
return sqlParams;
|
}
|
|
}
|