自动生成JPA bean及repository的简陋工具

因为现有的生成工具不太灵活,所以手写了一个;没啥技术难度,纯堆代码量

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.*;
import java.util.*;

/**
 * Generates JPA entity classes and Spring Data repository interfaces from the
 * tables of a MySQL schema.
 *
 * <p>Column metadata is read from {@code information_schema.COLUMNS}. For each
 * table one {@code <Table>.java} entity (annotated with Lombok {@code @Data})
 * and one {@code <Table>Repository.java} interface are written below
 * {@code <user.dir>/src/main/java}. Files that already exist are never
 * overwritten.
 *
 * <p>Limitations: MySQL only; composite primary keys get one {@code @Id} per
 * key column and no {@code @IdClass}; columns whose type has no mapping are
 * skipped with a warning on stderr.
 */
public class JpaGenerate {
    // Package of the generated entities, e.g. com.jpa.bean
    static String entityPackage = "com.datasource.entity";
    // Directory the entity sources are written to; computed in main()
    static String entityPath = "";
    // Package of the generated repositories, e.g. com.jpa.repository
    static String repositoryPackage = "com.datasource.dao";
    // Directory the repository sources are written to; computed in main()
    static String repositoryPath = "";
    // MySQL connection settings
    final static String url = "jdbc:mysql://localhost:3306/test";
    final static String user = "root";
    final static String password = "123456";
    // Schema whose tables are turned into entities
    final static String database = "operation_platform";


    /**
     * Reads the column metadata of {@link #database} and generates the entity
     * and repository sources for every table found.
     *
     * @param args ignored
     * @throws SQLException if the metadata query fails or no JDBC driver accepts {@link #url}
     * @throws RuntimeException if the files for a table cannot be generated
     */
    public static void main(String[] args) throws SQLException {
        Map<String, List<DatasourceTableField>> fieldMap = loadColumns();

        // Normalise Windows separators so the paths are built consistently.
        String projectPath = System.getProperty("user.dir").replace('\\', '/');
        entityPath = projectPath + "/src/main/java/" + entityPackage.replace('.', '/');
        repositoryPath = projectPath + "/src/main/java/" + repositoryPackage.replace('.', '/');

        for (Map.Entry<String, List<DatasourceTableField>> entry : fieldMap.entrySet()) {
            try {
                createJpaFiles(entry.getKey(), entry.getValue());
            } catch (Exception e) {
                throw new RuntimeException("Failed to generate JPA files for table " + entry.getKey(), e);
            }
        }
    }

    /**
     * Loads the columns of every table in {@link #database}, grouped by table
     * name and kept in the order in which the columns are defined.
     */
    private static Map<String, List<DatasourceTableField>> loadColumns() throws SQLException {
        String sql = """
                SELECT
                    table_name,
                    column_name,
                    column_type,
                    column_key,
                    extra,
                    column_comment
                FROM
                    information_schema.COLUMNS
                WHERE
                    table_schema = ?
                ORDER BY
                    table_name, ordinal_position
                """;
        Map<String, List<DatasourceTableField>> fieldMap = new LinkedHashMap<>();
        // DriverManager locates the driver from the classpath and throws
        // SQLException (instead of returning null) when the URL is not accepted.
        try (Connection conn = DriverManager.getConnection(url, user, password);
             PreparedStatement stat = conn.prepareStatement(sql)) {
            stat.setString(1, database);
            try (ResultSet rs = stat.executeQuery()) {
                while (rs.next()) {
                    fieldMap.computeIfAbsent(rs.getString("table_name"), key -> new ArrayList<>())
                            .add(new DatasourceTableField(
                                    rs.getString("column_name"),
                                    rs.getString("column_type"),
                                    rs.getString("column_comment"),
                                    rs.getString("column_key"),
                                    rs.getString("extra")
                            ));
                }
            }
        }
        return fieldMap;
    }

    /**
     * Writes the entity and the repository source for one table. A file that
     * already exists is left untouched; missing parent directories are created.
     *
     * @param table database table name, e.g. {@code user_info}
     * @param list  columns of the table in declaration order
     * @throws Exception if a file or directory cannot be created or written
     */
    static void createJpaFiles(String table, List<DatasourceTableField> list) throws Exception {
        String tableName = convertTableName(table);

        Path entityFile = Path.of(entityPath, tableName + ".java");
        if (Files.notExists(entityFile)) {
            writeSource(entityFile, buildEntitySource(table, tableName, list));
        }

        Path repositoryFile = Path.of(repositoryPath, tableName + "Repository.java");
        if (Files.notExists(repositoryFile)) {
            writeSource(repositoryFile, buildRepositorySource(tableName, resolveIdType(list)));
        }
    }

    /** Writes {@code source} to {@code file} as UTF-8, creating parent directories first. */
    private static void writeSource(Path file, String source) throws IOException {
        Path parent = file.getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        Files.writeString(file, source, StandardCharsets.UTF_8);
    }

    /** Builds the source text of the entity class {@code className} mapped to {@code table}. */
    private static String buildEntitySource(String table, String className, List<DatasourceTableField> fields) {
        StringBuilder body = new StringBuilder("\n");
        // Sorted so the generated import block is deterministic.
        Set<String> extraImports = new TreeSet<>();
        for (DatasourceTableField field : fields) {
            String javaType = javaTypeFor(field.columnType);
            if (javaType == null) {
                System.err.println("Skipping column " + table + "." + field.columnName
                        + ": unsupported column type " + field.columnType);
                continue;
            }
            if ("PRI".equals(field.columnKey)) {
                if ("auto_increment".equals(field.extra)) {
                    body.append("    @GeneratedValue(strategy = GenerationType.IDENTITY)\n");
                }
                body.append("    @Id\n");
            }
            if (javaType.contains(".")) {
                extraImports.add(javaType);
            }
            body.append("    private ").append(simpleName(javaType)).append(' ')
                    .append(convertFieldName(field.columnName)).append(";\n");
        }

        StringBuilder source = new StringBuilder();
        source.append("package ").append(entityPackage).append(";\n\n");
        source.append("import jakarta.persistence.*;\nimport lombok.Data;\n\n");
        for (String extraImport : extraImports) {
            source.append("import ").append(extraImport).append(";\n");
        }
        if (!extraImports.isEmpty()) {
            source.append('\n');
        }
        source.append("@Entity\n@Data\n@Table(name = \"").append(table)
                .append("\")\npublic class ").append(className).append("{\n")
                .append(body).append("}");
        return source.toString();
    }

    /** Builds the source text of the repository interface for entity {@code className}. */
    private static String buildRepositorySource(String className, String idType) {
        // Key types outside java.lang need their own import in the repository.
        String idImport = idType.contains(".") ? "import " + idType + ";\n" : "";
        return "package " + repositoryPackage + ";\n\n"
                + "import " + entityPackage + "." + className + ";\n"
                + idImport
                + "import org.springframework.data.jpa.repository.JpaRepository;\n"
                + "import org.springframework.stereotype.Repository;\n\n"
                + "@Repository\n"
                + "public interface " + className + "Repository extends JpaRepository<"
                + className + ", " + simpleName(idType) + "> {\n}";
    }

    /**
     * Returns the Java type of the first primary-key column that has a type
     * mapping, or {@code Integer} when there is none.
     */
    private static String resolveIdType(List<DatasourceTableField> fields) {
        for (DatasourceTableField field : fields) {
            if ("PRI".equals(field.columnKey)) {
                String javaType = javaTypeFor(field.columnType);
                if (javaType != null) {
                    return javaType;
                }
            }
        }
        return "Integer";
    }

    /**
     * Maps a MySQL {@code column_type} such as {@code varchar(32)} or
     * {@code int unsigned} to a Java type name. Types outside
     * {@code java.lang} are returned fully qualified.
     *
     * @return the Java type name, or {@code null} when the type is not supported
     */
    private static String javaTypeFor(String columnType) {
        return switch (baseType(columnType)) {
            case "int", "smallint" -> "Integer";
            case "bigint" -> "Long";
            case "varchar", "char", "text", "tinytext", "mediumtext", "longtext" -> "String";
            case "datetime", "timestamp" -> "java.time.LocalDateTime";
            case "date" -> "java.time.LocalDate";
            case "double" -> "Double";
            case "float" -> "Float";
            case "decimal" -> "java.math.BigDecimal";
            default -> null;
        };
    }

    /** Returns the leading alphabetic part of a column type, lower-cased, e.g. {@code "int"} for {@code "int(11) unsigned"}. */
    private static String baseType(String columnType) {
        if (columnType == null) {
            return "";
        }
        String normalized = columnType.strip().toLowerCase(Locale.ROOT);
        int end = 0;
        while (end < normalized.length() && Character.isLetter(normalized.charAt(end))) {
            end++;
        }
        return normalized.substring(0, end);
    }

    /** Returns the part of a possibly qualified type name after the last dot. */
    private static String simpleName(String typeName) {
        return typeName.substring(typeName.lastIndexOf('.') + 1);
    }

    /**
     * Converts a snake_case table name to an UpperCamelCase class name, e.g.
     * {@code user_info} to {@code UserInfo}. Empty segments produced by
     * leading or repeated underscores are ignored.
     *
     * @param table database table name
     * @return the class name; empty when {@code table} contains no name characters
     */
    static String convertTableName(String table) {
        StringBuilder tableName = new StringBuilder();
        for (String part : table.split("_")) {
            if (part.isEmpty()) {
                continue;
            }
            tableName.append(Character.toUpperCase(part.charAt(0))).append(part, 1, part.length());
        }
        return tableName.toString();
    }

    /**
     * Converts a snake_case column name to a lowerCamelCase field name, e.g.
     * {@code create_time} to {@code createTime}. The first non-empty segment
     * is kept as is; empty segments are ignored.
     *
     * @param tableField database column name
     * @return the field name
     */
    static String convertFieldName(String tableField) {
        StringBuilder fieldName = new StringBuilder();
        boolean head = true;
        for (String part : tableField.split("_")) {
            if (part.isEmpty()) {
                continue;
            }
            if (head) {
                fieldName.append(part);
                head = false;
            } else {
                fieldName.append(Character.toUpperCase(part.charAt(0))).append(part, 1, part.length());
            }
        }
        return fieldName.toString();
    }

    /** One row of {@code information_schema.COLUMNS} describing a single table column. */
    static class DatasourceTableField {

        // column_name
        String columnName;
        // column_type, e.g. varchar(32)
        String columnType;
        // column_comment; currently not used by the generator
        String columnComment;
        // column_key; "PRI" marks a primary-key column
        String columnKey;
        // extra; "auto_increment" marks a generated key
        String extra;

        public DatasourceTableField(String columnName, String columnType, String columnComment, String columnKey, String extra) {
            this.columnName = columnName;
            this.columnType = columnType;
            this.columnComment = columnComment;
            this.columnKey = columnKey;
            this.extra = extra;
        }
    }

}

由于项目中使用了lombok,所以生成的实体里写死了 `import lombok.Data;` 和 `@Data`;不需要lombok的话把这两处去掉,然后在遍历字段的循环之后再遍历一次list,往body里追加get/set函数即可。这里有一些坑:这个工具只能用于mysql,我没写其他数据库的连接和查询;输出目录是按spring boot的结构(src/main/java)来的,如果不是spring boot项目,可以修改main里给entityPath、repositoryPath赋值的那两行;由于偷懒,mysql的数据类型也没写完,有其他类型的可以在类型映射的switch中添加。最后,此代码需要jdk17,因为使用了多行字符串(文本块)和箭头形式的switch;低版本的话改下字符串和switch写法也能用。

你可能感兴趣的:(java)