package tech.powerjob.worker.persistence.db; import tech.powerjob.worker.common.constants.TaskStatus; import tech.powerjob.worker.core.processor.TaskResult; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import tech.powerjob.worker.persistence.TaskDO; import java.sql.*; import java.util.Collection; import java.util.List; import java.util.Map; /** * 任务持久化实现层,表名:task_info * * @author tjq * @since 2020/3/17 */ public class TaskDAOImpl implements TaskDAO { private final boolean useIndex; private final ConnectionFactory connectionFactory; public TaskDAOImpl(ConnectionFactory connectionFactory) { this(false, connectionFactory); } public TaskDAOImpl(boolean useIndex, ConnectionFactory connectionFactory) { this.useIndex = useIndex; this.connectionFactory = connectionFactory; } @Override public void initTable() throws Exception { String delTableSQL = "drop table if exists task_info"; // 感谢 Gitee 用户 @Linfly 反馈的 BUG // bigint(20) 与 Java Long 取值范围完全一致 String createTableSQL = "create table task_info (task_id varchar(255), instance_id bigint, sub_instance_id bigint, task_name varchar(255), task_content blob, address varchar(255), status int, result text, failed_cnt int, created_time bigint, last_modified_time bigint, last_report_time bigint, constraint pkey unique (instance_id, task_id))"; String createIndexSQL = "create INDEX idx_status ON task_info (status)"; try (Connection conn = connectionFactory.getConnection(); Statement stat = conn.createStatement()) { stat.execute(delTableSQL); stat.execute(createTableSQL); if (useIndex) { stat.execute(createIndexSQL); } } } @Override public boolean save(TaskDO task) throws SQLException { String insertSQL = "insert into task_info(task_id, instance_id, sub_instance_id, task_name, task_content, address, status, result, failed_cnt, created_time, last_modified_time, last_report_time) values (?,?,?,?,?,?,?,?,?,?,?,?)"; try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(insertSQL)) { fillInsertPreparedStatement(task, ps); return ps.executeUpdate() == 1; } } @Override public boolean batchSave(Collection tasks) throws SQLException { String insertSql = "insert into task_info(task_id, instance_id, sub_instance_id, task_name, task_content, address, status, result, failed_cnt, created_time, last_modified_time, last_report_time) values (?,?,?,?,?,?,?,?,?,?,?,?)"; boolean originAutoCommitFlag ; try (Connection conn = connectionFactory.getConnection()) { originAutoCommitFlag = conn.getAutoCommit(); conn.setAutoCommit(false); try ( PreparedStatement ps = conn.prepareStatement(insertSql)) { for (TaskDO task : tasks) { fillInsertPreparedStatement(task, ps); ps.addBatch(); } ps.executeBatch(); return true; } catch (Throwable e) { conn.rollback(); throw e; } finally { conn.setAutoCommit(originAutoCommitFlag); } } } @Override public boolean simpleDelete(SimpleTaskQuery condition) throws SQLException { String deleteSQL = "delete from task_info where %s"; String sql = String.format(deleteSQL, condition.getQueryCondition()); try (Connection conn = connectionFactory.getConnection(); Statement stat = conn.createStatement()) { stat.executeUpdate(sql); return true; } } @Override public List simpleQuery(SimpleTaskQuery query) throws SQLException { ResultSet rs = null; String sql = "select * from task_info where " + query.getQueryCondition(); List result = Lists.newLinkedList(); try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { if (query.isReadOnly()) { conn.setReadOnly(true); } rs = ps.executeQuery(); while (rs.next()) { result.add(convert(rs)); } } finally { if (rs != null) { try { rs.close(); }catch (Exception ignore) { } } } return result; } @Override public List> simpleQueryPlus(SimpleTaskQuery query) throws SQLException { ResultSet rs = null; String sqlFormat = "select %s from task_info where %s"; String sql = String.format(sqlFormat, query.getQueryContent(), query.getQueryCondition()); List> result = Lists.newLinkedList(); try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { rs = ps.executeQuery(); // 原数据,包含了列名 ResultSetMetaData metaData = rs.getMetaData(); while (rs.next()) { Map row = Maps.newHashMap(); result.add(row); for (int i = 0; i < metaData.getColumnCount(); i++) { String colName = metaData.getColumnName(i + 1); Object colValue = rs.getObject(colName); row.put(colName, colValue); } } } finally { if (rs != null) { try { rs.close(); }catch (Exception ignore) { } } } return result; } @Override public boolean simpleUpdate(SimpleTaskQuery condition, TaskDO updateField) throws SQLException { String sqlFormat = "update task_info set %s where %s"; String updateSQL = String.format(sqlFormat, updateField.fetchUpdateSQL(), condition.getQueryCondition()); try (Connection conn = connectionFactory.getConnection(); PreparedStatement stat = conn.prepareStatement(updateSQL)) { stat.executeUpdate(); return true; } } @Override public List getAllTaskResult(Long instanceId, Long subInstanceId) throws SQLException { ResultSet rs = null; List taskResults = Lists.newLinkedList(); String sql = "select task_id, status, result from task_info where instance_id = ? and sub_instance_id = ?"; try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setLong(1, instanceId); ps.setLong(2, subInstanceId); rs = ps.executeQuery(); while (rs.next()) { int taskStatus = rs.getInt(2); // 只需要完成的结果 if (taskStatus == TaskStatus.WORKER_PROCESS_SUCCESS.getValue() || taskStatus == TaskStatus.WORKER_PROCESS_FAILED.getValue()) { TaskResult result = new TaskResult(); taskResults.add(result); result.setTaskId(rs.getString(1)); result.setSuccess(taskStatus == TaskStatus.WORKER_PROCESS_SUCCESS.getValue()); result.setResult(rs.getString(3)); } } }finally { if (rs != null) { try { rs.close(); }catch (Exception ignore) { } } } return taskResults; } @Override public boolean updateTaskStatus(Long instanceId, String taskId, int status, long lastReportTime, String result) throws SQLException { String sql = "update task_info set status = ?, last_report_time = ?, result = ?, last_modified_time = ? where instance_id = ? and task_id = ?"; try (Connection conn = connectionFactory.getConnection(); PreparedStatement ps = conn.prepareStatement(sql)) { ps.setInt(1, status); ps.setLong(2, lastReportTime); ps.setString(3, result); ps.setLong(4, lastReportTime); ps.setLong(5, instanceId); ps.setString(6, taskId); ps.executeUpdate(); return true; } } private static TaskDO convert(ResultSet rs) throws SQLException { TaskDO task = new TaskDO(); task.setTaskId(rs.getString("task_id")); task.setInstanceId(rs.getLong("instance_id")); task.setSubInstanceId(rs.getLong("sub_instance_id")); task.setTaskName(rs.getString("task_name")); task.setTaskContent(rs.getBytes("task_content")); task.setAddress(rs.getString("address")); task.setStatus(rs.getInt("status")); task.setResult(rs.getString("result")); task.setFailedCnt(rs.getInt("failed_cnt")); task.setCreatedTime(rs.getLong("created_time")); task.setLastModifiedTime(rs.getLong("last_modified_time")); task.setLastReportTime(rs.getLong("last_report_time")); return task; } // 填充插入字段 private static void fillInsertPreparedStatement(TaskDO task, PreparedStatement ps) throws SQLException { ps.setString(1, task.getTaskId()); ps.setLong(2, task.getInstanceId()); ps.setLong(3, task.getSubInstanceId()); ps.setString(4, task.getTaskName()); ps.setBytes(5, task.getTaskContent()); ps.setString(6, task.getAddress()); ps.setInt(7, task.getStatus()); ps.setString(8, task.getResult()); ps.setInt(9, task.getFailedCnt()); ps.setLong(10, task.getCreatedTime()); ps.setLong(11, task.getLastModifiedTime()); ps.setLong(12, task.getLastReportTime()); } }