(Python)遗传算法解决旅行商问题(GA-TSP)

原视频:【算法】遗传算法解决旅行商(TSP)问题_哔哩哔哩_bilibili

代码链接:ga.py · main · mirrors / zifeiyu0531 / ga-tsp · GitCode

up主的讲解很有东西,这里对代码进行汇总,对一些小白看起来可能遇到的小问题进行备注。

一、代码构成

        共有三个py文件构成:main.py主程序入口、ga.py遗传算法的流程、config.py一些基本参数的配置。

1.main.py

import turtle
import numpy as np
import my_utils.ga_tsp_config as conf
import my_utils.ga
import matplotlib.pyplot as plt

config = conf.get_config()

# 城市距离矩阵计算方法
def build_dist_mat(input_list):
    n = config.city_num
    dist_mat = np.zeros([n, n])
    for i in range(n):
        for j in range(i + 1, n):
            d = input_list[i, :] - input_list[j, :]
            # 计算点积
            dist_mat[i, j] = np.dot(d, d)
            dist_mat[j, i] = dist_mat[i, j]
    return dist_mat


# 随机生成城市坐标
city_pos_list = np.random.rand(config.city_num, config.pos_dimension)
# 城市距离矩阵
city_dist_mat = build_dist_mat(city_pos_list)

print("\n", city_pos_list)
print(city_dist_mat)

# 遗传算法运行
ga = my_utils.ga.Ga(city_dist_mat)  # 实例化一个Ga类
result_list, fitness_list = ga.train()  # result_list路线结果图  fitness_list适应度结果,对应两个图
result = result_list[-1]    # list[-1]:返回最后一个数据,这个集合中存放每一代对应的解,最终的结果为最后一代基因的路线
result_pos_list = city_pos_list[result, :]

# 绘图
# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['KaiTi']  # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False  # 解决保存图像是负号'-'显示为方块的问题

fig = plt.figure()
plt.plot(result_pos_list[:, 0], result_pos_list[:, 1], 'o-r')
plt.title(u"路线")
# plt.legend()
plt.legend(["This is my legend"], fontsize="x-large")
fig.show()

fig = plt.figure()
plt.plot(fitness_list)
plt.title(u"适应度曲线")
# plt.legend()
# 这里为空的话,会显示错误No artists with labels found to put in legend.
# Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
plt.legend(["This is my legend"], fontsize="x-large")
fig.show()
turtle.done()  # 暂停程序,给时间来查看图形

2.ga.py

import my_utils.ga_tsp_config as conf
import random

city_dist_mat = None
config = conf.get_config()
# 各项参数
gene_len = config.city_num
individual_num = config.individual_num  # 种群个体数
gen_num = config.gen_num    # 迭代轮数
mutate_prob = config.mutate_prob

def copy_list(old_arr: [int]):
    new_arr = []
    for element in old_arr:
        new_arr.append(element)
    return new_arr


# 个体类
class Individual:
    def __init__(self, genes=None):
        # 随机生成序列
        if genes is None:
            genes = [i for i in range(gene_len)]    # 先生成一个从1-n的数组
            random.shuffle(genes)       # 再把他打乱,相当于随机生成旅行商的一次路径
        self.genes = genes  # 如果基因不为none,在交叉编译生成子代的时候会用到
        self.fitness = self.evaluate_fitness()  # 根据基因计算适应度

    def evaluate_fitness(self):
        # 计算个体适应度
        fitness = 0.0
        for i in range(gene_len - 1):
            # 起始城市和目标城市
            from_idx = self.genes[i]
            to_idx = self.genes[i + 1]
            fitness += city_dist_mat[from_idx, to_idx]
        # 连接首尾
        fitness += city_dist_mat[self.genes[-1], self.genes[0]]
        return fitness


class Ga:
    def __init__(self, input_): # 输入:城市距离矩阵
        global city_dist_mat    # city_distance_matrix :城市距离矩阵
        city_dist_mat = input_
        self.best = None  # 每一代的最佳个体
        self.individual_list = []  # 每一代的个体列表
        self.result_list = []  # 每一代对应的解,对应第一个输出
        self.fitness_list = []  # 每一代对应的适应度,对应第二个输出

    def cross(self):
        new_gen = []
        random.shuffle(self.individual_list)    # 把每一代的种群个体进行打乱
        for i in range(0, individual_num - 1, 2):   # 把打乱过的种群以步长为2进行循环遍历
            # 父代基因
            genes1 = copy_list(self.individual_list[i].genes)
            genes2 = copy_list(self.individual_list[i + 1].genes)
            index1 = random.randint(0, gene_len - 2)    # 相当于在交叉过程中选取基因片段
            index2 = random.randint(index1, gene_len - 1)   # 随机生成两个index,保证index2>index1
            pos1_recorder = {value: idx for idx, value in enumerate(genes1)}    # 记录初始基因片段对应的位置
            pos2_recorder = {value: idx for idx, value in enumerate(genes2)}    # 解决交叉时产生的位置冲突问题
            # 交叉
            # 如果g2准备给g1的基因中存在冲突,那就把g1当前交叉位置的基因与g1中产生冲突的基因交换位置
            for j in range(index1, index2):
                value1, value2 = genes1[j], genes2[j]
                pos1, pos2 = pos1_recorder[value2], pos2_recorder[value1]
                genes1[j], genes1[pos1] = genes1[pos1], genes1[j]
                genes2[j], genes2[pos2] = genes2[pos2], genes2[j]
                pos1_recorder[value1], pos1_recorder[value2] = pos1, j
                pos2_recorder[value1], pos2_recorder[value2] = j, pos2
            new_gen.append(Individual(genes1))  # 交叉之后生成两个新个体,放到new_gen中
            new_gen.append(Individual(genes2))
        return new_gen

    # 对交叉得到的新一代基因进行变异操作的方法
    def mutate(self, new_gen):
        for individual in new_gen:
            if random.random() < mutate_prob:
                # 翻转切片
                old_genes = copy_list(individual.genes)
                index1 = random.randint(0, gene_len - 2)
                index2 = random.randint(index1, gene_len - 1)
                genes_mutate = old_genes[index1:index2]
                genes_mutate.reverse()
                individual.genes = old_genes[:index1] + genes_mutate + old_genes[index2:]
        # 两代合并,后面还要对合并之后的种群进行选择
        self.individual_list += new_gen

    def select(self):
        # 锦标赛--防止陷入局部最优
        group_num = 10  # 小组数
        group_size = 10  # 每小组人数
        group_winner = individual_num // group_num  # 每小组获胜人数--每个小组最终要保留的人数60/10=6,即每组都是10进6
        winners = []  # 锦标赛结果
        for i in range(group_num):
            group = []
            for j in range(group_size):
                # 随机组成小组
                player = random.choice(self.individual_list)
                player = Individual(player.genes)
                group.append(player)
            group = Ga.rank(group)
            # 取出获胜者
            winners += group[:group_winner]
        self.individual_list = winners  # 获得这一代的最终结果

    # @staticmethod用于修饰类中的方法,使其可以再不创建类实例的情况下调用方法,执行效率比较高。也可以像一般方法一样用实例调用方法
    # 静态方法不可以应用类中的属性或方法,其参数列表也不需要约定的默认参数self。理解为类对外部函数的封装,有助于代码结构优化和可读性
    @staticmethod
    def rank(group):
        # 冒泡排序
        for i in range(1, len(group)):
            for j in range(0, len(group) - i):
                if group[j].fitness > group[j + 1].fitness:
                    group[j], group[j + 1] = group[j + 1], group[j]
        return group

    # 得到子代的方法
    def next_gen(self):
        # 交叉
        new_gen = self.cross()
        # 对交叉之后的子代基因进行变异
        self.mutate(new_gen)
        # 对变异之后两代基因合并之后的种群做进一步选择
        self.select()
        # 循环遍历获得这一代的best结果
        for individual in self.individual_list:
            if individual.fitness < self.best.fitness:
                self.best = individual

    def train(self):
        # 初代种群
        self.individual_list = [Individual() for _ in range(individual_num)]  # 通过个体的构造函数生成对应数量的种群
        self.best = self.individual_list[0] # 这里的best也是随机的
        # 迭代
        for i in range(gen_num):    # 迭代轮数
            self.next_gen() # 生成子一代
            # 连接首尾
            result = copy_list(self.best.genes)
            result.append(result[0])
            self.result_list.append(result)
            self.fitness_list.append(self.best.fitness)
        return self.result_list, self.fitness_list

3.config.py

# -*- coding: utf-8 -*-
import argparse

parser = argparse.ArgumentParser(description='Configuration file')
arg_lists = []


def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg


# Data
data_arg = add_argument_group('Data')
data_arg.add_argument('--city_num', type=int, default=20, help='city num')  # 城市数量
data_arg.add_argument('--pos_dimension', type=int, default=2, help='city num')  # 坐标维度
data_arg.add_argument('--individual_num', type=int, default=60, help='individual num')  # 种群个体数50-100
data_arg.add_argument('--gen_num', type=int, default=400, help='generation num')  # 迭代轮数,城市数量越多,这里也
data_arg.add_argument('--mutate_prob', type=float, default=0.25, help='probability of mutate')  # 变异概率


def get_config():
    config, unparsed = parser.parse_known_args()
    return config


def print_config():
    config = get_config()
    print('\n')
    print('Data Config:')
    print('* city num:', config.city_num)
    print('* individual num:', config.individual_num)
    print('* generation num:', config.gen_num)
    print('* probability of mutate:', config.cross_prob)

二、注意点说明

  1. 直接运行如果报错:
    Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
    就将main.py中plt.legend()改为plt.legend(["This is my legend"], fontsize="x-large")
  2. 最后结果闪退:在main.py最后加上turtle.done(),作用是暂停程序,给用户时间来查看图形
  3. @staticmethod用于修饰类中的方法,使其可以再不创建类实例的情况下调用方法,执行效率比较高。也可以像一般方法一样用实例调用方法
    静态方法不可以应用类中的属性或方法,其参数列表也不需要约定的默认参数self。理解为类对外部函数的封装,有助于代码结构优化和可读性
  4. list列表切片说明
#list[起始索引,结束索引]切片时包含起始索引位置的元素,但不包含结束索引位置的元素
# 索引为 0表示第一个,1表示第二个,-1表示最后一个,-2表示倒数第二个

# list[-1]:返回最后一个数据
# list[:1]:返回0到1的数据,故返回第一个数据
# list[1:]:返回从1到0的数据,故返回第二个到最后一个的数据(不包含结束索引位置0)
# list[-1:]:返回从-1到0的数据,故返回最后一个数据
# list[:-1]:返回从0到-1的数据,故返回第一个到倒数第二个的数据(不包含结束索引位置-1)
# list[::1]:表示步长为1,步长大于0时,返回序列为原顺序;。
# list[::-1]: 表示从右往左以步长为1进行切片。步长小于0时,返回序列为倒序
# list[::2]: 表示从左往右步长为2进行切片

list = [1, 2, 3, 4, 5]
print(list[-1])  # 5
print(list[:1])  # [1]
print(list[1:])  # [2, 3, 4, 5]
print(list[-1:])  # [5]
print(list[:-1])  # [1, 2, 3, 4]
print(list[::1])  # [1, 2, 3, 4, 5]
print(list[::-1])  # [5, 4, 3, 2, 1]
print(list[::2])  # [1, 3, 5]

你可能感兴趣的:(Python,算法,python,numpy)