利用IDEA 自带功能生成POJOs

利用IDEA 自带功能生成POJOs

一、实现

1、脚本

以下是根据我对 DDD(领域驱动设计)思想的理解修改后的脚本。

import com.intellij.database.model.DasTable
import com.intellij.database.model.ObjectKind
import com.intellij.database.util.Case
import com.intellij.database.util.DasUtil

/**
 * Available context bindings:
 * SELECTION Iterable
 * PROJECT project
 * FILES files helper
 */

// Package of the generated sources; reassigned from the chosen directory when the save dialog closes.
packageName = ""
// Value written into the "@author" tag of every generated file.
author = "minliang.xing"
// Database column type -> Java type. calcFields takes the FIRST entry whose pattern is found in the
// lower-cased column type, so the order of the entries matters; the final empty pattern matches
// everything and acts as the String fallback.
typeMapping = [
        (~/(?i)tinyint|smallint|mediumint/)       : "Integer",
        (~/(?i)int/)                              : "Long",
        (~/(?i)bool|bit/)                         : "Boolean",
        (~/(?i)float|double|decimal|real/)        : "Double",
        (~/(?i)year|datetime|timestamp|date|time/): "Date",
        (~/(?i)blob|binary|bfile|clob|raw|image/) : "InputStream",
        (~/(?i)/)                                 : "String"
]

/**
 * Script entry point. The closure passed to FILES.chooseDirectoryAndSave receives the
 * directory picked in the dialog: the package name is derived from it, the DDD folder
 * layout is created, and sources are generated for every selected object that is a
 * DasTable of kind TABLE.
 */
FILES.chooseDirectoryAndSave("Choose directory", "Choose where to store generated files") { dir ->
    packageName = getPackageNameForTemp(dir)
    initDDDStructure(dir)
    def selectedTables = SELECTION.filter { candidate ->
        candidate instanceof DasTable && candidate.getKind() == ObjectKind.TABLE
    }
    selectedTables.each { table ->
        generate(table, dir)
    }
}

/**
 * Creates the DDD layer directories below the chosen root and writes a
 * package-info.java into every directory that did not exist yet.
 *
 * @param dir root directory picked by the user (anything whose toString() is its path)
 */
def initDDDStructure(dir) {
    // Relative paths use "/" because java.io.File accepts it on every platform, whereas a
    // hard-coded "\\" is just an ordinary file-name character outside Windows.
    // The list is ordered parent-first so a parent is never created implicitly by a
    // child's mkdirs() and then skipped without its package-info.java.
    def layers = [
            "web",
            "web/controller",
            "application/command/cmd",
            "application/command/service",
            "application/command/service/impl",
            "application/query/qry",
            "application/query/service",
            "application/query/service/impl",
            "domain",
            "domain/aggregate",
            "domain/type",
            "infrastructure",
            "infrastructure/db",
            "infrastructure/db/dataobject",
            "infrastructure/db/converter",
            "infrastructure/db/repository"
    ]

    layers.each { relativePath ->
        def folder = new File(dir.toString(), relativePath)
        if (!folder.exists()) {
            folder.mkdirs()
            // getPath() uses the platform separator; the package-name helpers accept both "\" and "/".
            new File(folder, "package-info.java").withPrintWriter("UTF-8") { out ->
                generatePackageInfo(out, folder.getPath())
            }
        }
    }
}


/**
 * Generates every source file for one table below the chosen root directory:
 * command/query services and their implementations, the domain interface and
 * abstract class, the JPA entity and the repository. Files that already exist
 * are left untouched.
 *
 * @param table the selected table; its name and columns drive the generated code
 * @param dir   root directory picked by the user (anything whose toString() is its path)
 */
def generate(table, dir) {
    def className = javaClassName(table.getName(), true)
    def fields = calcFields(table)
    // "/" is accepted by java.io.File on every platform; the former "\\" only worked on Windows.
    def root = dir.toString()

    def cmdServiceDir = root + "/application/command/service"
    writeIfAbsent(new File(cmdServiceDir, className + "CmdService.java")) { out ->
        generateCmdService(out, className, cmdServiceDir)
    }
    // The impl generator receives the service directory and appends ".impl" itself.
    writeIfAbsent(new File(cmdServiceDir + "/impl", className + "CmdServiceImpl.java")) { out ->
        generateCmdServiceImpl(out, className, cmdServiceDir)
    }

    def qryServiceDir = root + "/application/query/service"
    writeIfAbsent(new File(qryServiceDir, className + "QryService.java")) { out ->
        generateQryService(out, className, qryServiceDir)
    }
    writeIfAbsent(new File(qryServiceDir + "/impl", className + "QryServiceImpl.java")) { out ->
        generateQryServiceImpl(out, className, qryServiceDir)
    }

    def aggregateDir = root + "/domain/aggregate"
    writeIfAbsent(new File(aggregateDir, "Abstract" + className + ".java")) { out ->
        generateAbstract(out, className, fields, table, aggregateDir)
    }
    writeIfAbsent(new File(aggregateDir, className + ".java")) { out ->
        generateInterface(out, className, fields, table, aggregateDir)
    }

    def entityDir = root + "/infrastructure/db/dataobject"
    writeIfAbsent(new File(entityDir, className + "Entity.java")) { out ->
        generateEntity(out, className, fields, table, entityDir)
    }

    def repositoryDir = root + "/infrastructure/db/repository"
    writeIfAbsent(new File(repositoryDir, className + "Repository.java")) { out ->
        generateRepository(out, className, fields, table, repositoryDir)
    }
}

// Runs the writer closure against a UTF-8 PrintWriter for target, unless target already exists.
private void writeIfAbsent(File target, Closure writer) {
    if (target.exists()) {
        return
    }
    // Make sure the layer directory is there so the writer cannot fail on a missing parent.
    target.getParentFile().mkdirs()
    target.withPrintWriter("UTF-8", writer)
}

/**
 * Writes the single "package ...;" line of a package-info.java.
 *
 * @param out writer of the file being generated
 * @param dir directory of the package; converted to a package name by getPackageName
 */
def generatePackageInfo(out, dir) {
    def declaration = "package " + getPackageName(dir)
    out.println(declaration)
}


/**
 * Writes the empty query-service implementation class for a table.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param dir       directory of the query-service interface; ".impl" is appended for this class
 */
def generateQryServiceImpl(out, className, dir) {
    def servicePackage = getPackageNameForTemp(dir)
    def lines = [
            "package " + servicePackage + ".impl;",
            "",
            "import " + servicePackage + ".${className}QryService;",
            "",
            " /**",
            "  * @author: ${author}",
            "  */",
            "",
            "",
            "public class ${className}QryServiceImpl implements ${className}QryService {",
            "",
            "",
            "}"
    ]
    for (line in lines) {
        out.println(line)
    }
}

/**
 * Writes the empty query-service interface for a table.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param dir       directory the interface is written to; determines its package
 */
def generateQryService(out, className, dir) {
    def lines = [
            "package " + getPackageName(dir),
            "",
            " /**",
            "  * @author: ${author}",
            "  */",
            "",
            "",
            "public interface ${className}QryService {",
            "",
            "",
            "}"
    ]
    for (line in lines) {
        out.println(line)
    }
}


/**
 * Writes the empty command-service implementation class for a table.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param dir       directory of the command-service interface; ".impl" is appended for this class
 */
def generateCmdServiceImpl(out, className, dir) {
    def servicePackage = getPackageNameForTemp(dir)
    def lines = [
            "package " + servicePackage + ".impl;",
            "",
            "import " + servicePackage + ".${className}CmdService;",
            "",
            " /**",
            "  * @author: ${author}",
            "  */",
            "",
            "",
            "public class ${className}CmdServiceImpl implements ${className}CmdService {",
            "",
            "",
            "}"
    ]
    for (line in lines) {
        out.println(line)
    }
}

/**
 * Writes the empty command-service interface for a table.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param dir       directory the interface is written to; determines its package
 */
def generateCmdService(out, className, dir) {
    def lines = [
            "package " + getPackageName(dir),
            "",
            "",
            " /**",
            "  * @author: ${author}",
            "  */",
            "",
            "",
            "public interface ${className}CmdService {",
            "",
            "",
            "}"
    ]
    for (line in lines) {
        out.println(line)
    }
}


/**
 * Writes the Spring Data repository interface for a table's entity.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param fields    column descriptors of the table (not used by this generator)
 * @param table     the table itself (not used by this generator)
 * @param dir       directory the interface is written to; determines its package
 */
def generateRepository(out, className, fields, table, dir) {
    def entityName = "${className}Entity"
    def lines = [
            "package " + getPackageName(dir),
            "",
            "import ${packageName}.infrastructure.db.dataobject.${entityName};",
            "import org.springframework.data.jpa.repository.JpaRepository;",
            "import org.springframework.data.jpa.repository.JpaSpecificationExecutor;",
            "",
            "",
            " /**",
            "  * @author: ${author}",
            "  */",
            "",
            "",
            "public interface ${className}Repository extends JpaRepository<${entityName}, String>, JpaSpecificationExecutor<${entityName}>  {",
            "",
            "",
            "}"
    ]
    for (line in lines) {
        out.println(line)
    }
}


/**
 * Writes the JPA entity class for a table. The entity extends the generated
 * Abstract&lt;className&gt; class, implements the generated &lt;className&gt; interface and
 * overrides one annotated getter/setter pair per column, except for the audit
 * columns that are skipped here.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param fields    column descriptors produced by calcFields (name, type, ...)
 * @param table     the table; its name goes into the @Table annotation
 * @param dir       directory the entity is written to; determines its package
 */
def generateEntity(out, className, fields, table, dir) {
    out.println "package " + getPackageName(dir)
    out.println ""
    out.println "import lombok.ToString;"
    out.println "import cn.medbanks.trading.common.domain.BaseDomain;"
    out.println "import $packageName" + ".domain.aggregate.Abstract" + "$className;"
    out.println "import $packageName" + ".domain.aggregate.$className;"

    out.println "import org.hibernate.annotations.DynamicInsert;"
    out.println "import org.hibernate.annotations.DynamicUpdate;"
    out.println "import org.hibernate.annotations.GenericGenerator;"
    out.println "import javax.persistence.*;"
    out.println "import java.util.List;"
    out.println ""
    out.println "import org.springframework.data.jpa.domain.support.AuditingEntityListener;"

    // Only import the non-java.lang types that the mapped columns actually use.
    Set types = new HashSet()
    fields.each() {
        types.add(it.type)
    }

    if (types.contains("Date")) {
        out.println "import java.util.Date;"
    }

    if (types.contains("InputStream")) {
        out.println "import java.io.InputStream;"
    }

    out.println ""
    out.println ""

    out.println " /**"
    out.println "  * @author: $author"
    out.println "  */"

    out.println ""
    out.println ""
    out.println "@DynamicInsert(value = true)"
    out.println "@DynamicUpdate(value = true)"
    out.println "@EntityListeners({AuditingEntityListener.class})"
    out.println "@Table(name = \"${table.getName()}\")"
    out.println "@Entity"
    out.println "@ToString"
    out.println "public class $className" + "Entity extends Abstract" + "$className implements $className {"
    out.println ""

    // Audit columns are skipped: no accessor is generated for them in this class.
    Set bases = ["createdBy", "dateCreated", "dateUpdated", "updatedBy"]

    // Emit the annotated getter and the setter for every remaining column.
    fields.each() {
        if (!bases.contains(it.name)) {
            if (it.name.capitalize() == className.capitalize() + "Id" || it.name == "id") {
                // One println per line: previously ")" and "@Column(length = 32)" were glued
                // onto the same line because the "\n" between them was missing.
                out.println "\t@Id"
                out.println "\t@GeneratedValue(generator = \"uuid\")"
                out.println "\t@GenericGenerator("
                out.println "            name = \"uuid\","
                out.println "            strategy = \"cn.medbanks.trading.common.domain.CustomUUIDGenerator\""
                out.println "\t)"
                out.println "\t@Column(length = 32)"
            } else {
                out.println "\t@Column"
            }

            if (it.type == "Date") {
                out.println "\t@Temporal(TemporalType.TIMESTAMP)"
            }
            out.println "\t@Override"
            out.println "\tpublic ${it.type} get${it.name.capitalize()}() {"
            out.println "\t\treturn this.${it.name};"
            out.println "\t}"
            out.println ""

            out.println "\t@Override"
            out.println "\tpublic void set${it.name.capitalize()}(${it.type} ${it.name}) {"
            out.println "\t\tthis.${it.name} = ${it.name};"
            out.println "\t}"
            out.println ""
        }
    }
    out.println ""
    out.println "}"
}


/**
 * Writes the domain interface for a table: one getter and one setter signature
 * per column, except for the audit columns that are skipped here. Accessors of
 * columns that have a comment are documented with that comment.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param fields    column descriptors produced by calcFields (name, type, commoent, ...)
 * @param table     the table itself (not used by this generator)
 * @param dir       directory the interface is written to; determines its package
 */
def generateInterface(out, className, fields, table, dir) {
    def emit = { text -> out.println(text) }

    emit("package " + getPackageName(dir))
    emit("")

    // Only import the non-java.lang types that the columns actually use.
    def usedTypes = fields.collect { it.type } as Set
    if (usedTypes.contains("Date")) {
        emit("import java.util.Date;")
    }
    if (usedTypes.contains("InputStream")) {
        emit("import java.io.InputStream;")
    }

    emit("")
    emit("")
    emit(" /**")
    emit("  * @author: ${author}")
    emit("  */")
    emit("")
    emit("")

    emit("public interface ${className}  {")
    emit("")

    // Audit columns are skipped: no accessor signature is declared for them here.
    def auditNames = ["createdBy", "dateCreated", "dateUpdated", "updatedBy"] as Set
    for (field in fields) {
        if (auditNames.contains(field.name)) {
            continue
        }
        def accessorSuffix = field.name.capitalize()
        def documented = isNotEmpty(field.commoent)

        if (documented) {
            emit("\t/**")
            emit("\t * 获取 ${field.commoent}")
            emit("\t *")
            emit("\t * @return  ${field.commoent}")
            emit("\t */ ")
        }
        emit("")
        emit("\t ${field.type} get${accessorSuffix}();")
        emit("")

        if (documented) {
            emit("\t/**")
            emit("\t * 设置 ${field.commoent}")
            emit("\t *")
            emit("\t *  @param ${field.name}  ${field.commoent}")
            emit("\t */ ")
        }
        emit("\t void set${accessorSuffix}(${field.type} ${field.name});")
    }

    emit("")
    emit("}")
}


/**
 * Writes the abstract domain class for a table: one protected field per column,
 * except for the audit columns that are skipped here. Fields of columns that
 * have a comment are documented with that comment.
 *
 * @param out       writer of the file being generated
 * @param className class name derived from the table name
 * @param fields    column descriptors produced by calcFields (name, type, commoent, ...)
 * @param table     the table itself (not used by this generator)
 * @param dir       directory the class is written to; determines its package
 */
def generateAbstract(out, className, fields, table, dir) {
    def emit = { text -> out.println(text) }

    emit("package " + getPackageName(dir))
    emit("")
    emit("import cn.medbanks.trading.common.domain.AbstractDomain;")

    // Only import the non-java.lang types that the columns actually use.
    def usedTypes = fields.collect { it.type } as Set
    if (usedTypes.contains("Date")) {
        emit("import java.util.Date;")
    }
    if (usedTypes.contains("InputStream")) {
        emit("import java.io.InputStream;")
    }

    emit("")
    emit("")
    emit(" /**")
    emit("  * @author: ${author}")
    emit("  */")
    emit("")
    emit("")

    emit("public abstract class Abstract${className}  extends AbstractDomain {")
    emit("")

    // Audit columns are skipped: no field is declared for them in the generated class.
    def auditNames = ["createdBy", "dateCreated", "dateUpdated", "updatedBy"] as Set
    for (field in fields) {
        if (auditNames.contains(field.name)) {
            continue
        }
        emit("")
        if (isNotEmpty(field.commoent)) {
            emit("\t/**")
            emit("\t * ${field.commoent}")
            emit("\t */")
        }
        emit("\tprotected ${field.type} ${field.name};")
    }

    emit("")
    emit("}")
}


/**
 * Returns the package name of a directory followed by ";", ready to be appended
 * to the "package " keyword.
 */
def getPackageName(dir) {
    def barePackage = getPackageNameForTemp(dir)
    return barePackage + ";"
}

/**
 * Converts a source directory into a Java package name (without trailing ";").
 *
 * Path separators ("\" or "/") become dots, then everything up to and including
 * the last "src" path segment - plus an optional "main/java" right behind it - is
 * removed. The separator that follows the removed prefix is removed as well, so
 * "src/com/foo" gives "com.foo" rather than ".com.foo", and a directory that is
 * exactly "src/main/java" gives the empty string.
 *
 * @param dir directory (anything whose toString() is its path)
 * @return the dotted package name; the whole dotted path if it has no "src" segment
 */
def getPackageNameForTemp(dir) {
    def dotted = dir.toString().replaceAll('[\\\\/]+', '.')
    return dotted.replaceAll('^.*src(\\.main\\.java)?(\\.|$)', '')
}


/**
 * Builds one descriptor map per column of the table and returns them as a list.
 *
 * Each descriptor holds: colName (raw column name), name (camel-case Java field
 * name from javaName), type (Java type picked from typeMapping by the first
 * matching pattern), commoent (column comment; the key spelling is what the
 * generators read) and annos (a @Column annotation string, extended with @Id and
 * @GeneratedValue when the column is called "id", ignoring case).
 */
def calcFields(table) {
    DasUtil.getColumns(table).reduce([]) { collected, column ->
        def sqlType = Case.LOWER.apply(column.getDataType().getSpecification())
        def javaType = typeMapping.find { pattern, mapped -> pattern.matcher(sqlType).find() }.value

        def descriptor = [
                colName : column.getName(),
                name    : javaName(column.getName(), false),
                type    : javaType,
                commoent: column.getComment(),
                annos   : "\t@Column(name = \"" + column.getName() + "\" )"
        ]
        if ("id".equals(Case.LOWER.apply(column.getName()))) {
            descriptor.annos += "\n\t@Id\n\t@GeneratedValue"
        }
        collected + [descriptor]
    }
}

/**
 * Turns a table name into a Java class name: the name is split on every
 * non-alphanumeric character and the lower-cased, capitalised parts are joined.
 * No table prefix (such as a leading "t_") is stripped; if you need that, remove
 * it from the result here, or call javaName instead of this function.
 *
 * @param str        table name
 * @param capitalize true to keep the leading capital, false to lower-case the first letter
 * @return the class name; names of at most one character are returned as built,
 *         the same way javaName treats them
 */
def javaClassName(str, capitalize) {
    def s = str.split(/[^\p{Alnum}]/).collect { Case.LOWER.apply(it).capitalize() }.join("")
    // Guard short names: s[0] / s[1..-1] are out of range for empty or one-character names.
    capitalize || s.length() <= 1 ? s : Case.LOWER.apply(s[0]) + s[1..-1]
}
/**
 * Maps a database column name to a camel-case Java identifier.
 *
 * The name is split into words by NameUtil.splitNameIntoWords, every word is
 * lower-cased and capitalised, the words are joined, and characters matched by
 * the replacement pattern below are replaced with "_".
 *
 * @param str        column name
 * @param capitalize true to keep the leading capital, false to lower-case the first letter
 */
def javaName(str, capitalize) {
    def words = com.intellij.psi.codeStyle.NameUtil.splitNameIntoWords(str)
    def joined = words.collect { Case.LOWER.apply(it).capitalize() }.join("")
    def identifier = joined.replaceAll(/[^\p{javaJavaIdentifierPart}[_]]/, "_")
    if (capitalize || identifier.length() == 1) {
        return identifier
    }
    return Case.LOWER.apply(identifier[0]) + identifier[1..-1]
}
/**
 * 类名以表名(其中表名的首字母大写)命名
 */
/*def javaName(str, capitalize) {
    def s = str.split(/(?<=[^\p{IsLetter}])/).collect { Case.LOWER.apply(it).capitalize() }
            .join("").replaceAll(/[^\p{javaJavaIdentifierPart}]/, "_")
    capitalize || s.length() == 1? s : Case.LOWER.apply(s[0]) + s[1..-1]
}*/

/**
 * Tells whether a value has visible text: true unless it is null or its
 * string form is empty after trimming.
 */
def isNotEmpty(content) {
    if (content == null) {
        return false
    }
    def trimmed = content.toString().trim()
    return trimmed.length() > 0
}


/**
 * Returns a tab-indented "private static final long serialVersionUID = ...L;"
 * declaration whose value is Math.abs of a random long.
 */
static String genSerialID() {
    long uid = Math.abs(new Random().nextLong())
    return "\tprivate static final long serialVersionUID = " + uid + "L;"
}

2、基础实体

/**
 * Mapped superclass holding the four audit attributes (creator, creation time,
 * last updater, last update time). The column mapping and the auditing
 * annotations are placed on the getters.
 */
@MappedSuperclass
public abstract class AbstractDomain implements Serializable {

    /** User who created the record. */
    protected String createdBy;

    /** Time the record was created. */
    protected Date dateCreated;

    /** User who last updated the record. */
    protected String updatedBy;

    /** Time the record was last updated. */
    protected Date dateUpdated;

    // ---- creator ----

    @CreatedBy
    @Column(name = "created_by", columnDefinition = "varchar(255) comment '创建人'")
    public String getCreatedBy() {
        return this.createdBy;
    }

    public void setCreatedBy(String createdBy) {
        this.createdBy = createdBy;
    }

    // ---- creation time ----

    @Temporal(TemporalType.TIMESTAMP)
    @CreatedDate
    @Column(name = "date_created", columnDefinition = "datetime(0) null comment '创建时间'")
    public Date getDateCreated() {
        return this.dateCreated;
    }

    public void setDateCreated(Date dateCreated) {
        this.dateCreated = dateCreated;
    }

    // ---- last updater ----

    @LastModifiedBy
    @Column(name = "updated_by", columnDefinition = "varchar(255) comment '更新人'")
    public String getUpdatedBy() {
        return this.updatedBy;
    }

    public void setUpdatedBy(String updatedBy) {
        this.updatedBy = updatedBy;
    }

    // ---- last update time ----

    @Temporal(TemporalType.TIMESTAMP)
    @LastModifiedDate
    @Column(name = "date_updated", columnDefinition = "datetime(0) null comment '更新时间'")
    public Date getDateUpdated() {
        return this.dateUpdated;
    }

    public void setDateUpdated(Date dateUpdated) {
        this.dateUpdated = dateUpdated;
    }

}

3、审计

自动更新创建人、创建时间、更新人、更新时间

/**
 * Supplies the current auditor name used to fill the created-by / updated-by
 * attributes. Falls back to {@link #DEFAULT_AUDITOR} when no user name can be
 * obtained.
 */
@Slf4j
@Configuration
public class UserAuditor implements AuditorAware<String> {
    private static final String DEFAULT_AUDITOR = "default_auditor";

    /**
     * Returns the name of the current user, or the default auditor name when the
     * user lookup throws, returns no user, or the user has no name.
     *
     * @return a non-empty Optional holding the auditor name
     */
    @Override
    public Optional<String> getCurrentAuditor() {
        // NOTE(review): assumes UserInfo.getUserName() returns String - confirm against UserInfo.
        String userName = null;
        try {
            UserInfo userInfo = new UserInfoUtils().getUserInfo();
            // getUserInfo() may return null; treat that like a missing name instead of failing with an NPE.
            if (userInfo != null) {
                userName = userInfo.getUserName();
            }
        } catch (Exception e) {
            // Best effort: the lookup failure is not propagated, but the cause is kept in the log.
            log.info("getUserInfo null", e);
        }
        return Optional.of(userName != null ? userName : DEFAULT_AUDITOR);
    }
}

二、使用

1、用idea 连接数据库,右键点击数据库中要用来生成的表,如下图所示

利用IDEA 自带功能生成POJOs_第1张图片

2、找到 Generate POJOs.groovy,并将上面的脚本内容复制到该文件中。

3、重复第 1 步的操作,在右键菜单中选择 Generate POJOs.groovy。

利用IDEA 自带功能生成POJOs_第2张图片

4、在弹出的选择框内选择文件需要生成的位置

利用IDEA 自带功能生成POJOs_第3张图片

5、生成文件,调整实体关系

利用IDEA 自带功能生成POJOs_第4张图片

三、其他

可以根据数据库及自己业务类型调整文件结构和实体对应关系。

你可能感兴趣的:(工具)