STDP突触的设计(二)

 之前写过一个简单STDP突触的窗口函数,但是实际上的STDP突触远比一个窗口函数难得多,然后在拜读了论文Spike-Timing-Dependent Plasticity in Balanced Random Networks中有具体的分布式突触的设计,然后我按照他们给出的思路,复现了STDP突触的代码

#include 
#include "LIF.h"
#include "stdp_connection.h"
#include

using namespace std;

int main()
{
	vector pre_spike = {0,1,0,0,0,0,0,1,0,0 };
	vector post_spike= { 0,0,0,0,1,0,0,1,0,0 };

	LIF pre_neuron;
	LIF post_neuron;
	pre_neuron.update(pre_spike);
	post_neuron.update(post_spike);
	cout << "-----------pre neuron history is------------" << endl;
	pre_neuron.print_history();
	cout << "-----------post neuron history is------------" << endl;
	post_neuron.print_history();
	
	STDP stdp;
	for (int i = 0; i < pre_spike.size(); i++)
	{
		if (pre_spike[i] == 1)
		{
			stdp.update_weight(i, post_neuron);
		}
	}
	
	
	return 0;
}

然年是神经元

#pragma once
#include
#include
#include"Histentry.h"

using namespace std;
class LIF
{
public:
	LIF();
    //记录神经元的脉冲发放历史
    void set_spiketime(double t_spike);
    //返回(t1,t2]时刻的历史发放记录
    void get_history(double t1,double t2,deque::iterator*start, deque::iterator* end);
    //返回t处的Kminus(突触轨迹)值
    double get_K_value(double t);

    void update(vector spike);//神经元的更新步骤
    void print_history();//打印History表
private:
    int N_syn;//STDP突触传入的数量
    double tau;//时间常数,超参数

    double t_lastspike;//神经元上次发放的时间
    //t_sp是脉冲发生的事件,K_minus是t_sp时刻的K_,counter_sp计算脉冲信息被突触访问的次数
    deque history;
    //每一个神经元都维护这一个脉冲历史记录发放表
    double Kminus;//抑制因子
};

LIF::LIF()
    :tau(10.0),t_lastspike(0.0),Kminus(0.0),N_syn(1)
{

}
void LIF::print_history()
{
    deque::iterator it;
    for (it = history.begin(); it < history.end(); it++)
    {
        cout << "t_sp = " << it->t_sp << ",Kminus = " << it->Kminus << ",count_sp = " << it->count_sp << endl;
    }
}


void LIF::update(vector spike)
{
    for (int i = 0; i < spike.size(); i++)
    {
        if (spike[i] == 1)
        {
            //神经元发放脉冲,记录时间
            set_spiketime(i);
        }
    }
}

void LIF::set_spiketime(double t_spike)
{
    //当神经元的入度不为零时
    // 查看历史记录的某一条记录是否被至少使用N_syn次
    //如果超出这个,则证明此条记录不会使用,放出队列
    if (N_syn > 0)
    {

        while (history.size() > 1)
        {
            if (history.front().count_sp >= N_syn)
            {
                history.pop_front();    
            }
            else
            {
                break;
            }
        }
        //将此时刻的Kminus和时间记录到记录表中
        Kminus = Kminus * std::exp(-(t_spike - t_lastspike) / tau) + 1;
        history.push_back(Histentry(t_spike, Kminus, 0));
        t_lastspike = t_spike;
    }
    else
    {
        t_lastspike = t_spike;
    }
    
}

double LIF::get_K_value(double t)
{
    double K_value;
    //当记录表为零时,返回0
    if (history.empty())
    {
        K_value = 0.0;
        return K_value;
    }
    //从队列的队尾开始,加快访问速度
    deque::reverse_iterator it;
    for (it = history.rbegin(); it != history.rend(); it++)
    {
        if (t - it->t_sp > 0)
        {
            K_value = it->Kminus * std::exp((it->t_sp - t) / tau);
            return K_value;
        }
    }
    //如果遍历表没有的话,默认返回0
    K_value = 0.0;
    return K_value;
}

void LIF::get_history(double t1, double t2, deque::iterator* start, deque::iterator* end)
{
    *end = history.end();
    //当队列为空,返回空指针
    while (history.empty())
    {
        *start = *end;
        return;
    }

    deque::reverse_iterator it = history.rbegin();
    //寻找(t1,t2]时刻的记录,返回指向这两个位置的指针
    while (it != history.rend() && it->t_sp > t2)
    {
        it++;
    }
    *end = it.base();
    while (it != history.rend() && it->t_sp > t1)
    {
        it->count_sp += 1;
        it++;
    }
    *start = it.base();
}

然后是STDP突触

#pragma once
#include"LIF.h"

#include
#include "Histentry.h"
class STDP
{
public:
	STDP();

    double facilitate(double w, double Kplus);
    double depresss(double w, double Kminus);
    void update_weight(double t_spike,LIF target);

    double weight;//STDP权重

private:
    //以下为STDP公式的超参数
    double lamda;
    double alpha;
    double mu;
    double Wmax;
    double tau;


    double Kplus;//增强因子,每一个STDP突触都有

    double t_lastspike;//该突触上次发放脉冲的时间


};
STDP::STDP()
    : weight(10.0),tau(20.0),lamda(0.1),alpha(1.0),Wmax(100.0),t_lastspike(0.0),Kplus(0.0),mu(1.0)
{

}
//w代表当前权重,Kplus时增强程度,返回增强后的权重
double STDP::facilitate(double w, double Kplus)
{
    double norm_w = (w / Wmax) + (lamda * std::pow(1.0 - (w / Wmax), mu) * Kplus);
    return norm_w < 1.0 ? norm_w * Wmax : Wmax;
}
//w代表当前权重,Kminus时抑制程度,返回抑制后的权重
double STDP::depresss(double w, double Kminus)
{
    double norm_w = (w / Wmax) - (alpha * lamda * std::pow(w / Wmax, mu) * Kminus);
    return norm_w > 0.0 ? norm_w * Wmax : 0.0;
}
//当脉冲传递到后神经元时,调用此函数,t_spike代表源神经元的发放时间,
//target代表后神经元
void STDP::update_weight(double t_spike, LIF target)
{
    double delay = 1.0;
    std::deque::iterator start;
    std::deque::iterator end;
    //从突触后神经元获取相关范围(t1、t2]内的尖峰历史
    //历史记录(t_lastspike - delay,…,t_spike-delay]
    target.get_history(t_lastspike - delay, t_spike - delay, &start, &end);

    //在前神经元的脉冲发放时间后后神经元发放对突触权重起到增强作用
    double dt;
    while (start != end)
    {
        dt = t_lastspike - (start->t_sp + delay);
        ++start;
        weight = facilitate(weight, Kplus * std::exp(dt / tau));
    }

    //突触后神经元脉冲发作时的抑制应
    // 仅取决于自最近一次突触前尖峰以来的时间。
    double Kminus = target.get_K_value(t_spike - delay);
    cout << "Kminus" <

最后是记录各种历史的类

#pragma once
class Histentry
{
public:
	double t_sp; //脉冲发生的时间
	double Kminus;//此时刻对应的K_
	int count_sp;//计数器
	Histentry(double t_sp,double Kminus,int count_sp);

};

Histentry::Histentry(double t_sp, double Kminus, int count_sp)
	:t_sp(t_sp),
	Kminus(Kminus),
	count_sp(count_sp)
{
}

执行效果为

STDP突触的设计(二)_第1张图片

你可能感兴趣的:(分布式模拟脉冲神经网络,神经网络)