github代码分析——PSO粒子群优化算法代码(python)

文章目录

  • 1 readme阅读
  • 2 代码

  • 本篇博文是建立在阅读Gao老师的repository代码之上,老规矩,分享转发,请老师收下我的膝盖!
    github代码分析——PSO粒子群优化算法代码(python)_第1张图片
  • repository的地址:

https://github.com/EddyGao/PSO

1 readme阅读

        粒子群优化算法(PSO:Particle swarm optimization) 是一种进化计算技术(evolutionary computation)。
       源于对鸟群捕食的行为研究。粒子群优化算法的基本思想:是通过群体中个体之间的协作和信息共享来寻找最优解。

       鸟被抽象为没有质量和体积的微粒(点),并延伸到N维空间,粒子i在N维空间的位置表示为矢量Xi=(x1,x2,…,xN),飞行速度表示为矢量Vi=(v1,v2,…,vN)。
       每个粒子都有一个由目标函数决定的适应值(fitness value),并且知道自己到目前为止发现的最好位置(pbest)和现在的位置Xi。这个可以看作是粒子自己的飞行经验。
       除此之外,每个粒子还知道到目前为止整个群体中所有粒子发现的最好位置(gbest)(gbest是pbest中的最好值),这个可以看作是粒子同伴的经验。

       粒子就是通过自己的经验和同伴中最好的经验来决定下一步的运动。

  • 标准PSO算法的流程:
    1)初始化一群微粒(群体规模为N),包括随机位置和速度;
    2)评价每个微粒的适应度(fitness);
    3)对每个微粒,将其适应值与其经过的最好位置pbest作比较,如果较好,则将其作为当前的最好位置pbest;
    4)对每个微粒,将其适应值与其经过的最好位置gbest作比较,如果较好,则将其作为当前的最好位置gbest;
    5)根据公式(2)、(3)调整微粒速度和位置;
    6)未达到结束条件则转第2)步。

PSO的优势:在于简单容易实现并且没有许多参数的调节。目前已被广泛应用于函数优化、神经网络训练、模糊系统控制以及其他遗传算法的应用领域

2 代码

其中遇到了很多问题,下面的代码也是我自己修改过的!

  1. 适应度函数fitness的理解还是不够深入,这个函数有怎么样的要求呢?
  2. 是求fitness的最大值还是最小值呢?体现在算法的什么地方呢?
  3. 在最后,如何把gbestpop在300次迭代的位置都能画出来?这应该是一个三维的图了,x轴表示1-300,y轴表示x1 z轴表示x2
  4. 注意:粒子的位置是2d的,速度也是2d的!
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import math

w = 1  # 惯性权重
lr = (0.49445, 1.49445)  # 个体学习率和群体学习率
maxgen = 300  # 最大迭代次数
sizepop = 50  # 种群规模
rangepop = (-2 * math.pi, 2 * math.pi)  # 粒子的位置的范围限制, x和y方向相同
rangespeed = (-0.5, 0.5)  # 粒子的速度范围限制


def func(x):
    # x 输入粒子位置
    # y 粒子适应度值
    if (x[0] == 0) & (x[1] == 0):
        y = x[0]**2 + x[1]**2
    else:
        y = x[0] ** 2 + x[1] ** 2
    return y


def initpopvfit(sizepop):
    # 初始化为0
    pop = np.zeros((sizepop, 2))  # 2维
    v = np.zeros((sizepop, 2))  # 2维
    fitness = np.zeros(sizepop)

    # rand初始化50个粒子的位置pop[] 速度v[] 适应度fitness[]
    for i in range(sizepop):
        pop[i] = [(np.random.rand() - 0.5) * rangepop[0] * 2, (np.random.rand() - 0.5) * rangepop[1] * 2]
        v[i] = [(np.random.rand() - 0.5) * rangepop[0] * 2, (np.random.rand() - 0.5) * rangepop[1] * 2]
        fitness[i] = func(pop[i])
    return pop, v, fitness


def getinitbest(fitness, pop):
    # 群体最优的粒子位置及其适应度值
    gbestpop, gbestfitness = pop[fitness.argmax()].copy(), fitness.max()  # argmax# 取出a中元素最大值所对应的索引
    # 个体最优的粒子位置及其适应度值,使用copy()使得对pop的改变不影响pbestpop,pbestfitness类似
    pbestpop, pbestfitness = pop.copy(), fitness.copy()

    return gbestpop, gbestfitness, pbestpop, pbestfitness


pop, v, fitness = initpopvfit(sizepop)  # v.shape=(50*2)
print(pop)
print(fitness)
gbestpop, gbestfitness, pbestpop, pbestfitness = getinitbest(fitness, pop)
# print(v)

result = np.zeros(maxgen)  # result 就是放结果的
resultpop = np.zeros((maxgen, sizepop))

for i in range(maxgen):  # 迭代300次
    t = 0.5

    # 速度更新
    for j in range(sizepop):
        v[j] += lr[0] * np.random.rand() * (pbestpop[j] - pop[j]) + lr[1] * np.random.rand() * (gbestpop - pop[j])

        # 对于越界的速度,要进行合法性调整
        flag0 = v[j][0] < rangespeed[0]  # flag0=1表示axis=0处越界
        flag1 = v[j][1] > rangespeed[1]  # flag1=1表示axis=1处越界

        if flag0:
            v[j][0] = rangespeed[0]
        if flag1:
            v[j][1] = rangespeed[1]

    # 粒子位置更新
    for j in range(sizepop):
        # pop[j] += 0.5*v[j]
        pop[j] = (1 - t) * pop[j] + t * v[j]

        # 对于越界的位置,要进行合法性调整
        flag0 = pop[j][0] < rangepop[0]  # flag0=1表示axis=0处越界
        flag1 = pop[j][1] > rangepop[1]  # flag1=1表示axis=1处越界

        if flag0:
            pop[j][0] = rangepop[0]
        if flag1:
            pop[j][1] = rangepop[1]

    # 适应度更新
    for j in range(sizepop):
        fitness[j] = func(pop[j])

    for j in range(sizepop):
        if fitness[j] > pbestfitness[j]:
            pbestfitness[j] = fitness[j].copy()
            pbestpop[j] = pop[j].copy()

    if pbestfitness.max() > gbestfitness:
        gbestfitness = pbestfitness.max()
        gbestpop = pop[pbestfitness.argmax()].copy()
        print(gbestpop)

    result[i] = gbestfitness
    # resultpop[i] = gbestpop


plt.plot(result)
plt.show()

你可能感兴趣的:(机器学习算法,算法,python,开发语言)