[Python小工具]Python批量生成数据到MySQL

[Python小工具]Python批量生成数据到MySQL

base.py

#!/usr/bin/python
# -*- coding:utf-8 -*-
import time
import random
from datetime import datetime, timedelta
from faker import Faker
import string

# Shared Faker instance using the Chinese (zh_CN) locale; the Table_* classes
# below use it for mock names and gender values.
fake = Faker('zh_CN')


class Base_conn:
    """Plain holder for MySQL connection settings.

    The four values are stored unchanged as attributes (``DB_HOST``,
    ``DB_USER``, ``DB_PASSWORD``, ``DB_NAME``); no connection is opened
    and nothing is validated here.
    """

    def __init__(self, DB_HOST, DB_USER, DB_PASSWORD, DB_NAME):
        # Keep the settings exactly as given; consumers read them later.
        (self.DB_HOST, self.DB_USER,
         self.DB_PASSWORD, self.DB_NAME) = (DB_HOST, DB_USER, DB_PASSWORD, DB_NAME)


class TimeShift:
    """Small stopwatch: records start/end timestamps and formats durations."""

    def starttime(self):
        """Record and return the start timestamp (``time.time()``, seconds)."""
        self.start_time = time.time()

        return self.start_time

    def endtime(self):
        """Record and return the end timestamp (``time.time()``, seconds)."""
        self.end_time = time.time()

        return self.end_time

    def executiontime(self):
        """Return ``end_time - start_time`` in seconds and store it.

        ``starttime()`` and ``endtime()`` must have been called first;
        otherwise the missing attribute raises ``AttributeError``.
        """
        self.execution_time = self.end_time - self.start_time

        return self.execution_time

    def timeshift(self, execution_time):
        """Format a duration in seconds as a short Chinese string.

        - below 60 s            -> whole seconds, e.g. ``"42秒"``
        - 60 s up to 3600 s     -> minutes, one decimal, e.g. ``"2.5分钟"``
        - above 3600 s          -> hours, one decimal, e.g. ``"1.5小时"``
        """
        if execution_time < 60:
            return f"{int(execution_time)}秒"
        # The previous test `>= 60 or <= 3600` was always true at this point,
        # so the hour branch below could never run.
        if execution_time <= 3600:
            return f"{execution_time / 60:.1f}分钟"
        # Divide before formatting; truncating with int() first turned
        # 1.5 hours into "1.0小时".
        return f"{execution_time / 3600:.1f}小时"


class Table_Employee:
    """One mock row for the ``employee`` table.

    Each constructor argument is used as given when it is truthy; any other
    field is filled with random mock data (name and gender come from the
    module-level ``fake`` Faker instance).
    """

    # Weights and check characters of the 18-digit resident ID checksum
    # (GB 11643-1999, ISO 7064 MOD 11-2), applied to the first 17 digits.
    _ID_WEIGHTS = (7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2)
    _ID_CHECK_CHARS = "10X98765432"

    @classmethod
    def _id_check_code(cls, first17):
        """Return the check character ('0'-'9' or 'X') for a 17-digit ID prefix."""
        total = sum(int(digit) * weight for digit, weight in zip(first17, cls._ID_WEIGHTS))
        return cls._ID_CHECK_CHARS[total % 11]

    def generate_id_card(self):
        """Generate a mock 18-character ID card number (adjust as needed).

        Layout: 6-digit random region code + birth date as YYYYMMDD (between
        1990-01-01 and today) + 2-digit sequence + 1 gender digit + 1 check
        character from the MOD 11-2 checksum (may be ``'X'``).
        """
        # Random 6-digit region code taken from the decimals of random.random().
        region_code = "{:.6f}".format(random.random()).split('.')[1]
        earliest = datetime.strptime("19900101", "%Y%m%d")
        # Cap at today: the old range (1990 + 50 years) produced future dates.
        birth_date = earliest + timedelta(days=random.randint(0, (datetime.now() - earliest).days))
        seq_code = "{:02d}".format(random.randint(0, 99))
        gender_code = str(random.randint(0, 9))
        # 6 + 8 + 2 + 1 = 17 digits; the old 3-digit sequence made the string
        # 19 long and the [:18] slice silently dropped the check code.
        first17 = f"{region_code}{birth_date.strftime('%Y%m%d')}{seq_code}{gender_code}"
        return first17 + self._id_check_code(first17)

    def __init__(self, work_id=None, name=None, age=None, gender=None, id_card=None, entry_date=None, department=None):
        # work_id: random 6-digit string from the decimals of random.random()
        self.work_id = work_id if work_id else "{:.6f}".format(random.random()).split('.')[1]
        self.name = name if name else fake.name()
        self.age = age if age else random.randint(18, 65)
        self.gender = gender if gender else fake.random_element(elements=('男', '女'))
        self.id_card = id_card if id_card else self.generate_id_card()
        # entry_date: a random date within roughly the last 10 years
        self.entry_date = entry_date if entry_date else (
                datetime.now() - timedelta(days=random.randint(0, 365 * 10))).date()
        # department: random integer id in 1..7
        self.department = department if department else random.randint(1, 7)


class Table_Students:
    """One mock row for the ``Student`` table.

    Each constructor argument is used as given when it is truthy; any other
    field is filled with random mock data (name and gender come from the
    module-level ``fake`` Faker instance).
    """

    def Email_data(self):
        """Return a random mock e-mail address such as ``'aBcDeF@example.com'``.

        The local part is 5-10 random ASCII letters.
        """
        num_letters = random.randint(5, 10)
        local_part = ''.join(random.choices(string.ascii_letters, k=num_letters))
        # The old value "www.<letters>.com" was a host name, not an e-mail.
        # example.com is reserved for testing (RFC 2606), so it never hits a real mailbox.
        return local_part + '@example.com'

    def Address_data(self):
        """Return a random mock address: province + district/city + road + number + '号'.

        Example: ``'北京市海淀区中山路123号'``.
        """
        provinces = ['北京市', '天津市', '山西省']
        cities = {
            '北京市': ['东城区', '西城区', '朝阳区', '海淀区', '丰台区'],
            '天津市': ['和平区', '河东区', '河西区', '南开区', '红桥区'],
            '山西省': ['太原市', '大同市', '阳泉市', '长治市', '晋城市']
        }
        roads = ['长安街', '和平路', '人民大街', '建国路', '中山路', '解放街', '青年路', '光明街', '文化路', '新华街']

        province = random.choice(provinces)
        city = random.choice(cities[province])
        road = random.choice(roads)
        house_number = str(random.randint(100, 999))
        # Previously a second, independently chosen province was inserted
        # between the city and the road (e.g. "北京市东城区天津市长安街...").
        return province + city + road + house_number + '号'

    def __init__(self, SNo=None, Sname=None, Gender=None, Birthday=None, Mobile=None, Email=None, Address=None,
                 Image='null'):
        # SNo: year of a random date in the last ~10 years + a 3-digit number, e.g. "2019123"
        self.SNo = SNo if SNo else str((datetime.now() - timedelta(days=random.randint(0, 365 * 10))).date()).split('-')[0] + str(random.randint(100,999))
        self.Sname = Sname if Sname else fake.name()
        self.Gender = Gender if Gender else fake.random_element(elements=('男', '女'))
        # Birthday: a random date within roughly the last 10 years
        self.Birthday = Birthday if Birthday else (
                datetime.now() - timedelta(days=random.randint(0, 365 * 10))).date()
        # Mobile: 3-digit prefix 130-199 + 8 random digits = 11 digits
        self.Mobile = Mobile if Mobile else str(random.randint(130, 199)) + "{:.8f}".format(random.random()).split('.')[
            1]
        self.Email = Email if Email else self.Email_data()
        self.Address = Address if Address else self.Address_data()
        # NOTE(review): the default is the literal string 'null', not SQL NULL;
        # pass Image=None explicitly if the column should be NULL.
        self.Image = Image

mysql_data.py

import re
import time
import pymysql
import threading
from faker import Faker
from pymysql.constants import CLIENT
from base import Base_conn, TimeShift, Table_Employee, Table_Students

# Faker instance (zh_CN locale); it is not referenced elsewhere in this module.
fake = Faker('zh_CN')
# Database connection settings (host, user, password, database name), read by connect_db().
# NOTE(review): credentials are hard-coded; consider loading them from the environment/config.
# connect = Base_conn('localhost', 'root', 'iotplatform', 'test')
connect = Base_conn('localhost', 'root', 'iotplatform', 'student_profile_db')


def connect_db():
    """Open a new pymysql connection from the module-level ``connect`` settings.

    The connection uses ``DictCursor`` (rows come back as dicts), allows
    multiple statements per query, and uses the utf8mb4 charset.

    Returns:
        The open connection, or ``None`` when pymysql raises ``MySQLError``
        (the error is printed first).
    """
    settings = {
        'host': connect.DB_HOST,
        'user': connect.DB_USER,
        'password': connect.DB_PASSWORD,
        'database': connect.DB_NAME,
        'client_flag': CLIENT.MULTI_STATEMENTS,
        'cursorclass': pymysql.cursors.DictCursor,
        'charset': 'utf8mb4',
    }
    try:
        return pymysql.connect(**settings)
    except pymysql.MySQLError as err:
        print(f"Database connection failed: {err}")
        return None


# 数据插入函数
def insert_data(start, end, thread_id, sql, batch_size=1000):
    """Insert ``end - start`` randomly generated ``Student`` rows with ``sql``.

    Called once per worker thread by ``insert_sql``; each call opens and
    closes its own connection. Rows go through ``executemany`` and are
    committed every ``batch_size`` rows rather than once per row, which is
    much cheaper for large loads.

    Args:
        start, end: half-open index range assigned to this thread; only the
            count ``end - start`` is used.
        thread_id: label used in log messages.
        sql: parameterised INSERT with 8 ``%s`` placeholders, in the order
            SNo, Sname, Gender, Birthday, Mobile, Email, Address, Image.
        batch_size: rows per ``executemany`` call and commit (values < 1
            are treated as 1).

    Errors are printed, not raised. On error the current uncommitted batch
    is rolled back; batches committed earlier remain in the table.
    """
    connection = connect_db()
    if connection is None:
        # connect_db already printed the reason. Without this guard the code
        # raised AttributeError on None.cursor() and again on None.close().
        print(f"Thread {thread_id}: Error occurred: no database connection")
        return
    batch_size = max(1, int(batch_size))
    try:
        with connection.cursor() as cursor:
            batch = []
            for _ in range(start, end):
                # For the employee table, build Table_Employee() instead and pass
                # (work_id, name, age, gender, id_card, entry_date, department).
                student = Table_Students()
                batch.append((
                    student.SNo,
                    student.Sname,
                    student.Gender,
                    student.Birthday,
                    student.Mobile,
                    student.Email,
                    student.Address,
                    student.Image,
                ))
                if len(batch) >= batch_size:
                    cursor.executemany(sql, batch)
                    connection.commit()
                    batch.clear()
            # Flush the final partial batch.
            if batch:
                cursor.executemany(sql, batch)
                connection.commit()
    except Exception as e:
        print(f"Thread {thread_id}: Error occurred: {e}")
        try:
            connection.rollback()
        except pymysql.MySQLError:
            pass  # connection already broken; nothing left to undo
    finally:
        connection.close()


def perform_sql_operation(sql, operation):
    """Execute one non-INSERT statement and print the outcome.

    Args:
        sql: SQL text to execute.
        operation: lower-case keyword. ``'select'``/``'show'`` print every
            fetched row; ``'update'``/``'delete'`` execute and commit. Any
            other value executes nothing.

    Prints ``执行sql数: N``, where N is the number of update/delete statements
    executed (0 or 1). MySQL errors are printed and the transaction is rolled
    back. Returns immediately if no connection could be opened.
    """
    conn = connect_db()
    if conn is None:
        # connect_db already printed the reason. Without this guard the code
        # raised an uncaught AttributeError on None.cursor() and None.close().
        return
    num = 0
    try:
        with conn.cursor() as cursor:
            if operation in ('select', 'show'):
                cursor.execute(sql)
                for row in cursor.fetchall():
                    print(row)
            elif operation in ('update', 'delete'):
                cursor.execute(sql)
                num += 1
                conn.commit()

        print('执行sql数:', num)

    except pymysql.MySQLError as e:
        print(f"Error: {e}")
        conn.rollback()
    finally:
        conn.close()


def insert_sql(sql_command, total_records, num_threads):
    """Insert ``total_records`` generated rows using ``num_threads`` threads.

    The index range ``[0, total_records)`` is cut into contiguous slices
    whose sizes differ by at most one and always sum to ``total_records``.
    The old ``total // threads`` split silently dropped the remainder
    (e.g. 10 rows on 3 threads inserted only 9). Each slice runs
    ``insert_data`` in its own thread, and the call blocks until every
    thread has finished. Threads whose slice would be empty are not started.

    Raises:
        ValueError: if ``num_threads`` is less than 1.
    """
    if num_threads < 1:
        raise ValueError(f"num_threads must be >= 1, got {num_threads}")
    threads = []
    for i in range(num_threads):
        # Integer-scaled boundaries spread the remainder across threads.
        start_index = i * total_records // num_threads
        end_index = (i + 1) * total_records // num_threads
        if end_index <= start_index:
            continue
        thread_id = i + 1
        thread = threading.Thread(target=insert_data, args=(start_index, end_index, thread_id, sql_command))
        threads.append(thread)
        thread.start()

    # Wait for all worker threads to finish.
    for thread in threads:
        thread.join()


def up_del_sel_sql(sql_cmd, operation=None):
    """Run an UPDATE/DELETE/SELECT/SHOW statement via ``perform_sql_operation``.

    Args:
        sql_cmd: SQL text to execute.
        operation: statement keyword. When omitted, it is detected from
            ``sql_cmd`` (first whole-word insert/update/delete/select/show,
            lower-cased). Previously this read a global ``command`` that only
            exists when the file runs as a script, so calling it after an
            import raised NameError.
    """
    if operation is None:
        match = re.search(r'\b(?:insert|update|delete|select|show)\b', sql_cmd, re.IGNORECASE)
        operation = match.group().lower() if match else None
    perform_sql_operation(sql_cmd, operation)


# 主函数,负责创建线程并分配任务
def main(sql, sql_cmd=None, total_records=None, num_threads=None):
    """Dispatch one SQL job by statement type; creates threads for inserts.

    Args:
        sql: statement keyword, case-insensitive: 'insert', 'update',
            'delete', 'select' or 'show'.
        sql_cmd: SQL text to run.
        total_records: number of rows to generate (insert only).
        num_threads: number of worker threads (insert only).

    Any of the last three left as ``None`` falls back to the module-level
    variable of the same name, which the ``__main__`` block defines. The
    script's ``main(command)`` call therefore keeps working, and importers can
    pass everything explicitly instead of relying on hidden globals. A missing
    fallback raises ``KeyError``.
    """
    def _resolve(value, name):
        # Look the module global up only when it is actually needed.
        return globals()[name] if value is None else value

    kind = sql.lower()
    if kind == 'insert':
        insert_sql(_resolve(sql_cmd, 'sql_cmd'),
                   _resolve(total_records, 'total_records'),
                   _resolve(num_threads, 'num_threads'))
    elif kind in ('update', 'delete', 'select', 'show'):
        # Pass the operation explicitly rather than relying on a global `command`.
        perform_sql_operation(_resolve(sql_cmd, 'sql_cmd'), kind)
    else:
        print('请传入执行的sql类型: [insert|update|delete|select|show]')


if __name__ == '__main__':
    # Total number of rows to insert
    total_records = 1000000
    # Number of worker threads
    num_threads = 1
    # SQL to execute (keep exactly one sql_cmd assignment active)
    # sql_cmd = """ insert into `employee` (`workid`, `name`, `age`, `gender`, `idcard`, `entrydate`, `department`)
    #               values (%s, %s, %s, %s, %s, %s, %s) """
    sql_cmd = """ insert into `Student` (`SNo`, `Sname`, `Gender`, `Birthday`, `Mobile`, `Email`, `Address` , `Image`)
                      values (%s, %s, %s, %s, %s, %s, %s,%s) """
    # sql_cmd = "delete from employee limit 100000"
    # sql_cmd = "select count(*) from employee"
    # sql_cmd = "show tables like '%employee%'"
    # sql_cmd = ""

    # Detect the statement type. The alternation must be grouped: the old
    # pattern `\binsert|update|...|show\b` attached \b only to the first and
    # last alternatives, so e.g. "updated_at" in an ALTER statement was
    # detected as an UPDATE.
    command = re.search(r'\b(?:insert|update|delete|select|show)\b', sql_cmd, re.IGNORECASE)
    command = command.group().lower() if command else None
    if command:
        time_shift = TimeShift()
        time_shift.starttime()
        main(command)
        time_shift.endtime()
        execution_time = time_shift.timeshift(time_shift.executiontime())
        print(f"执行时间: {execution_time}")
    else:
        print('未找到匹配的命令')

你可能感兴趣的:(python,开发语言)