Neuron and Synapse Simulation Not Learning from Sine Wave Input

Asked by: Lyft, asked: 11/4/2023, last edited by: Lyft, updated: 11/4/2023, views: 24

Question:

Hello Stack Overflow community,

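I am currently developing a simulation of neurons and synapses in C++, and I am having trouble implementing a learning mechanism driven by a sine wave input. Despite various parameter adjustments and debugging attempts, the neurons do not adapt as expected. In addition, when I set the time step (delta time) to 0.1 or higher, solve.cpp crashes because the variable V or the gating values become NaN or infinite.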

So my question is: why doesn't this learn correctly, and how can I make it learn?

Below are the relevant code snippets for each component of the project. The code is split into several files:

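- solve.cpp: contains the model equations.
- main.cpp: the entry point of the simulation.
- neuron.h and neuron.cpp: define the Neuron class and its methods.
- synapse.h and synapse.cpp: define the Synapse class and its methods.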

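Context for solve.cpp: this file contains the differential equations that govern the dynamics of each neuron and synapse in the simulation. They compute the membrane potential and the gating variables at every time step. The specific problem appears when the time step (dt) is set to 0.1 or higher, which sometimes makes the variables become NaN or infinite. This suggests that the numerical integration method may be unstable at that step size.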
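A rough way to see why dt = 0.1 ms can blow up: the voltage equation is stiff whenever the total conductance is large. Holding the gates fixed and differentiating dV/dt with respect to V gives a rate of -(gNa·m³·h + gK·n⁴ + gL)/Cm, i.e. a local time constant tau = Cm / g_total. Classical RK4 is only stable on the negative real axis for dt/tau up to about 2.785 (forward Euler: 2). The helpers below are hypothetical, not part of the project, and the estimate ignores the gating equations and the nonlinearity, so treat it as a rule of thumb rather than a proof.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: local membrane time constant (ms) for a frozen gating state.
// Linearising dV/dt = (I - gNa*m^3*h*(V-ENa) - gK*n^4*(V-EK) - gL*(V-El)) / Cm
// with respect to V (gates held fixed) gives the rate -g_total / Cm.
double membrane_time_constant(double m, double h, double n,
                              double gNa, double gK, double gL, double Cm) {
    const double g_total = gNa * std::pow(m, 3) * h + gK * std::pow(n, 4) + gL;
    assert(g_total > 0.0 && Cm > 0.0);
    return Cm / g_total;
}

// Largest dt (ms) for which a linear decay with time constant tau stays stable.
// Stability limit on the negative real axis: dt/tau <= ~2.785 for classical RK4,
// dt/tau <= 2 for forward Euler.
double max_stable_dt(double tau, bool runge_kutta) {
    const double limit = runge_kutta ? 2.785 : 2.0;
    return limit * tau;
}
```

For example, with the conductances from main.cpp (gNa = 120, gK = 36, gL = 0.3, Cm = 1) and an illustrative state m = h = n = 0.5, g_total = 10.05, so tau ≈ 0.0995 ms and the RK4 limit is about 0.28 ms. For a hypothetical state with m = 0.9, h = 0.5, n = 0.5 the limit drops to roughly 0.06 ms, below dt = 0.1. That would be consistent with dt = 0.1 working on some steps and failing on others.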

#include "solve.h"
#include <cmath>
#include <stdexcept> // std::runtime_error




// Function to calculate membrane potential derivative
double membrane_potential_derivative(double V, double m, double h, double n, double gNa, double gK, double gL, double ENa, double EK, double El, double Cm, double dt, double I) {
    double INa = gNa * std::pow(m, 3) * h * (V - ENa);
    double IK = gK * std::pow(n, 4) * (V - EK);
    double Il = gL * (V - El);

    return (I - INa - IK - Il) / Cm;
}


// Implementations of alpha and beta functions:
double alpha_n(double V) { return 0.01 * (V + 55) / (1 - exp(-(V + 55) / 10)); }
double beta_n(double V) { return 0.125 * exp(-(V + 65) / 80); }
double alpha_m(double V) { return 0.1 * (V + 40) / (1 - exp(-(V + 40) / 10)); }
double beta_m(double V) { return 4 * exp(-(V + 65) / 18); }
double alpha_h(double V) { return 0.07 * exp(-(V + 65) / 20); }
double beta_h(double V) { return 1 / (exp(-(V + 35) / 10) + 1); }



// Rate functions for gating variables
double rate_n_change(double V, double n) { return alpha_n(V) * (1 - n) - beta_n(V) * n; }
double rate_m_change(double V, double m) { return alpha_m(V) * (1 - m) - beta_m(V) * m; }
double rate_h_change(double V, double h) { return alpha_h(V) * (1 - h) - beta_h(V) * h; }


// Runge-Kutta method implementation
void hh_runge_kutta(double& V, double& m, double& h, double& n, double gNa, double gK, double gL, double ENa, double EK, double El, double Cm, double dt, double I) {
    // K1
    double k1_V = membrane_potential_derivative(V, m, h, n, gNa, gK, gL, ENa, EK, El, Cm, dt, I);
    double k1_n = rate_n_change(V, n);
    double k1_m = rate_m_change(V, m);
    double k1_h = rate_h_change(V, h);
    if (std::isnan(k1_V)) throw std::runtime_error("k1_V is non-numeric");

    // Midpoint values
    double V2 = V + k1_V * dt / 2;
    double n2 = n + k1_n * dt / 2;
    double m2 = m + k1_m * dt / 2;
    double h2 = h + k1_h * dt / 2;


    // K2
    double k2_V = membrane_potential_derivative(V2, m2, h2, n2, gNa, gK, gL, ENa, EK, El, Cm, dt, I);
    double k2_n = rate_n_change(V2, n2);
    double k2_m = rate_m_change(V2, m2);
    double k2_h = rate_h_change(V2, h2);
    if (std::isnan(k2_V)) throw std::runtime_error("k2_V is non-numeric");


    // New midpoint values
    V2 = V + k2_V * dt / 2;
    n2 = n + k2_n * dt / 2;
    m2 = m + k2_m * dt / 2;
    h2 = h + k2_h * dt / 2;

    // K3
    double k3_V = membrane_potential_derivative(V2, m2, h2, n2, gNa, gK, gL, ENa, EK, El, Cm, dt, I);
    double k3_n = rate_n_change(V2, n2);
    double k3_m = rate_m_change(V2, m2);
    double k3_h = rate_h_change(V2, h2);
    if (std::isnan(k3_V)) throw std::runtime_error("k3_V is non-numeric");


    // End values
    double V3 = V + k3_V * dt;
    double n3 = n + k3_n * dt;
    double m3 = m + k3_m * dt;
    double h3 = h + k3_h * dt;

    // K4
    double k4_V = membrane_potential_derivative(V3, m3, h3, n3, gNa, gK, gL, ENa, EK, El, Cm, dt, I);
    double k4_n = rate_n_change(V3, n3);
    double k4_m = rate_m_change(V3, m3);
    double k4_h = rate_h_change(V3, h3);
    if (std::isnan(k4_V)) throw std::runtime_error("k4_V is non-numeric");


    // Weighted sum of gradients
    V += dt / 6 * (k1_V + 2 * k2_V + 2 * k3_V + k4_V);
    n += dt / 6 * (k1_n + 2 * k2_n + 2 * k3_n + k4_n);
    m += dt / 6 * (k1_m + 2 * k2_m + 2 * k3_m + k4_m);
    h += dt / 6 * (k1_h + 2 * k2_h + 2 * k3_h + k4_h);
}

// Euler method implementation
void hh_euler(double& V, double& m, double& h, double& n, double gNa, double gK, double gL, double ENa, double EK, double El, double Cm, double dt, double I) {
    // Calculate the derivatives for V, m, h, and n using the current state
    double dV_dt = membrane_potential_derivative(V, m, h, n, gNa, gK, gL, ENa, EK, El, Cm, dt, I);
    double dn_dt = rate_n_change(V, n);
    double dm_dt = rate_m_change(V, m);
    double dh_dt = rate_h_change(V, h);

    // Check for non-numeric derivatives, which could indicate a problem with the inputs or the derivative functions
    if (std::isnan(dV_dt) || std::isnan(dn_dt) || std::isnan(dm_dt) || std::isnan(dh_dt)) {
        throw std::runtime_error("Non-numeric derivative encountered in Euler method");
    }

    // Update the state using the Euler method
    V += dV_dt * dt;
    n += dn_dt * dt;
    m += dm_dt * dt;
    h += dh_dt * dt;
}
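One NaN source worth ruling out: `alpha_n` and `alpha_m` have the form c·x / (1 - exp(-x/10)) with x = V + 55 and x = V + 40. At exactly V = -55 or V = -40 both numerator and denominator are zero, so the expression evaluates to 0/0 = NaN, and close to those points the subtraction loses precision. The limit as x → 0 is 10·c (0.1 for alpha_n, 1.0 for alpha_m). A minimal sketch of guarded versions, assuming `std::expm1` for the denominator; the names `linear_exp_rate`, `alpha_n_safe`, `alpha_m_safe` and the 1e-7 cutoff are hypothetical, not from the original code:

```cpp
#include <cassert>
#include <cmath>

// Shared form c * x / (1 - exp(-x / 10)), with the removable singularity at x = 0
// replaced by its limit 10 * c. std::expm1(y) computes exp(y) - 1 accurately for
// small y, so 1 - exp(-x / 10) == -std::expm1(-x / 10).
double linear_exp_rate(double c, double x) {
    assert(std::isfinite(x) && "non-finite voltage passed to rate function");
    if (std::fabs(x) < 1e-7) {
        return 10.0 * c;  // limit of c * x / (1 - exp(-x / 10)) as x -> 0
    }
    return c * x / -std::expm1(-x / 10.0);
}

// Candidate replacements for alpha_n / alpha_m (same constants as solve.cpp).
double alpha_n_safe(double V) { return linear_exp_rate(0.01, V + 55.0); }
double alpha_m_safe(double V) { return linear_exp_rate(0.1, V + 40.0); }
```

Away from the singular points these return the same values as the originals; the guard only changes behaviour in a tiny window around x = 0.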

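Context for neuron.h/neuron.cpp: these files define the Neuron class, which encapsulates the neuron's properties, such as its membrane potential, ion channels, and various state variables. Its methods include initialization routines, state updates, and any learning-rule implementations that modify the neuron's parameters in response to the simulation input (in this case, a sine wave).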

#ifndef NEURON_H
#define NEURON_H

#include <vector>
#include "solve.h"
#include "random.h"
#include "params.h"

const double MIN_V = -90; // hypothetical minimum potential
const double MAX_V = 40;  // hypothetical maximum potential (spike peak)


// Forward declare Synapse to use it in the Neuron header
class Synapse;

class Neuron {
private:
    double V; // Membrane potential
    double rV; // Resting membrane potential
    double m, h, n; // Gating variables
    double gNa, gK, gL; // Conductances
    double ENa, EK, El; // Reversal potentials, mV
    double Cm; // Membrane capacitance, uF

    const double spike_threshold = -45; // Threshold in mV -69.04
    double last_spike_time = -1; // Initialize to -1 to indicate no spike has occurred

    std::vector<Synapse*> incoming_synapses; // Vector to hold synapses where this neuron is postsynaptic

    std::vector<Synapse*> outgoing_synapses; // Vector to hold synapses where this neuron is presynaptic
public:
    Neuron(const NeuronParams& params, const InitialValues& init);

    void set_state(const NeuronParams& params, const InitialValues& init);

    // Fires neuron and updates values, takes current (uA) and delta_time (ms)
    void step(double current, double delta_time, double real_time);
    
    // Method to apply inhibition from another neuron
    void inhibit(double amount);

    // Function to call when the neuron spikes
    void on_spike(double time);

    // Handles receiving input from synapses
    void receive(double synaptic_current, double delta_time);

    // Connects this neuron to another neuron
    void connect_to(Neuron* other);

    // Returns all synapses connected to this neuron
    std::vector<Synapse*> get_all_synapses() const;

    // Checks if this neuron is connected to another neuron
    bool is_connected_to(Neuron* other) const;

    // Getter functions for neuron properties
    double get_membrane_potential() const;
    void reset_potential();
    double get_resting_potential() const;
    double get_normalized_state() const;
    double get_last_spike_time() const;

};

#endif // NEURON_H
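
// neuron.cpp
#include <algorithm> // std::max, std::any_of
#include <cmath>     // exp, pow
#include <iostream>  // std::cerr
#include <stdexcept> // std::runtime_error
#include "neuron.h"
#include "synapse.h"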


// Constructor implementation
Neuron::Neuron(const NeuronParams& params, const InitialValues& init) {
    // Initialize neuron parameters from the provided structs
    V = init.V;
    rV = V;
    m = init.m;
    h = init.h;
    n = init.n;
    gNa = params.gNa;
    gK = params.gK;
    gL = params.gL;
    ENa = params.ENa;
    EK = params.EK;
    El = params.El;
    Cm = params.Cm;
}

void Neuron::set_state(const NeuronParams& params, const InitialValues& init) {
    // Initialize neuron parameters from the provided structs
    V = init.V;
    rV = V;
    m = init.m;
    h = init.h;
    n = init.n;
    gNa = params.gNa;
    gK = params.gK;
    gL = params.gL;
    ENa = params.ENa;
    EK = params.EK;
    El = params.El;
    Cm = params.Cm;
}

void Neuron::step(double input_current, double delta_time, double real_time) {
   

    double previous_potential = V;
    hh_runge_kutta(V, m, h, n, gNa, gK, gL, ENa, EK, El, Cm, delta_time, input_current);


    if (V >= spike_threshold) {
        on_spike(real_time);
    }

    // After updating the neuron's own properties, sum the inputs from synapses
    for (auto synapse : incoming_synapses) {
        synapse->apply(delta_time);
        synapse->apply_stdp(delta_time);
        synapse->apply_hebbian_learning(delta_time);
    }
}

// Method to apply inhibition from another neuron
void Neuron::inhibit(double amount) {
    this->V = std::max(this->V - amount, this->rV);
}


// Define the on_spike and check_for_spike methods
void Neuron::on_spike(double time) {
    last_spike_time = time;
    for (auto& synapse : outgoing_synapses) {
        synapse->on_presynaptic_spike(time);
    }
}

void Neuron::receive(double synaptic_current, double delta_time) {
    if (gL <= 0) {
        std::cerr << "Error: Conductance (gL) is non-positive, cannot compute Rm." << std::endl;
        throw std::runtime_error("Invalid conductance (gL) value.");
    }
    // Calculate the membrane resistance (Rm)
    double Rm = (gL > 0) ? 1 / gL : 1;  // Use gL to calculate membrane resistance. Ensure gL is not zero to avoid division by zero error.

    // Calculate the membrane time constant (tau = Rm * Cm)
    double tau = Rm * Cm;

    // Use an exponential decay model for updating the membrane potential based on the synaptic current
    double delta_V = (synaptic_current / Cm) * (1 - exp(-delta_time / tau));

    // Update the membrane potential
    V += delta_V;
}



// Synaptic Connections
void Neuron::connect_to(Neuron* other) {
    // We only create one synapse that goes from this neuron to the other neuron
    Synapse* syn = new Synapse(this, other, random(0.5), spike_threshold);
    outgoing_synapses.push_back(syn);
    other->incoming_synapses.push_back(syn);
}
std::vector<Synapse*> Neuron::get_all_synapses() const {
    std::vector<Synapse*> all_synapses;
    all_synapses.insert(all_synapses.end(), incoming_synapses.begin(), incoming_synapses.end());
    all_synapses.insert(all_synapses.end(), outgoing_synapses.begin(), outgoing_synapses.end());
    return all_synapses;
}
bool Neuron::is_connected_to(Neuron* other) const {
    // Check both incoming and outgoing synapses for a connection to the other neuron.
    return std::any_of(incoming_synapses.begin(), incoming_synapses.end(),
        [other](Synapse* s) { return s->get_pre_neuron() == other; }) ||
        std::any_of(outgoing_synapses.begin(), outgoing_synapses.end(),
            [other](Synapse* s) { return s->get_post_neuron() == other; });
}

// Getter functions
double Neuron::get_membrane_potential() const {
    return V;
}
void Neuron::reset_potential() {
    V = rV;
}
double Neuron::get_resting_potential() const {
    return rV;
}
double Neuron::get_normalized_state() const {
    //return (V - MIN_V) / (MAX_V - MIN_V);
    // Normalization of membrane potential between 0 and 1
    double norm_V = (V - MIN_V) / (MAX_V - MIN_V);

    // Geometric mean of gating variables as a representation of the channel's combined state
    double geom_mean_gates = pow(m * h * n, 1.0 / 3.0);

    // Incorporate gating variables and membrane potential
    double combined_state = norm_V * geom_mean_gates;

    // Use sigmoid to keep the final value bounded between 0 and 1
    double sigmoid_state = 1.0 / (1.0 + exp(-combined_state));

    return sigmoid_state;
}

double Neuron::get_last_spike_time() const {
    return last_spike_time;
}
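In `Neuron::step`, `on_spike` is called on every step for which `V >= spike_threshold`. At `time_step = 0.01` ms that is 100 steps per millisecond, so a single action potential that stays above -45 mV for a while is reported as many consecutive "spikes", and the STDP code receives a new timestamp on each of them. A common alternative is to report a spike only on the upward threshold crossing. A minimal sketch, assuming a small helper struct (`ThresholdCrossingDetector` and `spike_times` are hypothetical names, not part of the project):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reports a spike only when V crosses the threshold from below, so one action
// potential yields one spike time instead of one per step spent above threshold.
struct ThresholdCrossingDetector {
    double threshold;
    bool above = false;  // whether V was at or above threshold on the previous step

    // Returns true exactly once per upward crossing.
    bool update(double V) {
        const bool now_above = V >= threshold;
        const bool crossed = now_above && !above;
        above = now_above;
        return crossed;
    }
};

// Spike times (ms) for a membrane-potential trace sampled every dt ms.
std::vector<double> spike_times(const std::vector<double>& trace, double dt, double threshold) {
    assert(dt > 0.0);
    ThresholdCrossingDetector detector{threshold};
    std::vector<double> times;
    for (std::size_t i = 0; i < trace.size(); ++i) {
        if (detector.update(trace[i])) {
            times.push_back(static_cast<double>(i) * dt);
        }
    }
    return times;
}
```

Inside `Neuron::step`, a detector held as a (hypothetical) member would replace the plain comparison, e.g. `if (detector.update(V)) on_spike(real_time);`.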

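Context for synapse.h/synapse.cpp: these files define the Synapse class, which represents a connection between neurons. Each synapse has methods that compute the synaptic current from the activity of the presynaptic neuron and that adjust the synapse's strength (its weight) as part of the learning process.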
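For comparison with the STDP code in synapse.cpp (`apply_stdp`, `on_presynaptic_spike`, `on_postsynaptic_spike`), here is a sketch of the textbook pair-based STDP rule. With delta_t = t_post - t_pre, the weight change is +A_plus·exp(-delta_t/tau_plus) when the presynaptic spike comes first (potentiation) and -A_minus·exp(delta_t/tau_minus) when it comes second (depression). In the question's code, `apply_stdp` decreases the weight when delta_t > 0 and `on_postsynaptic_spike` adds `A_minus` in its LTD branch, so the signs are worth checking against this convention. The function names are hypothetical; the default parameters reuse the values declared in synapse.h.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Textbook pair-based STDP: weight change for one pre/post spike pair (times in ms).
//   delta_t > 0 (pre before post): potentiation, +A_plus * exp(-delta_t / tau_plus)
//   delta_t < 0 (post before pre): depression,   -A_minus * exp(delta_t / tau_minus)
double stdp_weight_change(double t_pre, double t_post,
                          double A_plus = 0.05, double A_minus = 0.0675,
                          double tau_plus = 20.0, double tau_minus = 20.0) {
    assert(tau_plus > 0.0 && tau_minus > 0.0);
    const double delta_t = t_post - t_pre;
    if (delta_t > 0.0) return A_plus * std::exp(-delta_t / tau_plus);
    if (delta_t < 0.0) return -A_minus * std::exp(delta_t / tau_minus);
    return 0.0;  // simultaneous spikes: no change (a convention choice)
}

// Apply the change and keep an excitatory weight in [0, 1], as the question's code does.
double apply_stdp_pair(double weight, double t_pre, double t_post) {
    return std::clamp(weight + stdp_weight_change(t_pre, t_post), 0.0, 1.0);
}
```

Because each call consumes one pre/post pair, a rule like this should be applied once per spike pairing, not once per simulation step.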

#ifndef SYNAPSE_H
#define SYNAPSE_H



// Forward declare Neuron to use it in the Synapse header
class Neuron;

class Synapse {
private:
    Neuron* pre;     // Presynaptic neuron (source neuron)
    Neuron* post;    // Postsynaptic neuron (target neuron)
    double weight;   // Synaptic weight
    bool is_excitatory; // Flag to indicate whether the synapse is excitatory


    // Neurotransmitter Dynamics
    double nt_release = 1.0; // Initial neurotransmitter release amount
    double receptor_activated = 0.0; // Initially no receptors are activated
    double conductance = 0.002; // Maximum conductance in siemens (S)
    double receptor_binding_rate = 0.1; // Binding rate in ms^-1
    double nt_decay = 0.95; // Neurotransmitter decay factor per timestep
    double nt_max = 1.0; // Max neurotransmitter quantity normalized
    double E_syn; // Synaptic reversal potential

    // Synaptic Reversal Potential
    double E_syn_excitatory = 0.0;   // mV for excitatory synapses
    double E_syn_inhibitory = -80.0; // mV for inhibitory synapses

    // STDP (Spike-Timing Dependent Plasticity)
    double last_pre_spike_time = -1; // Initialize to -1 to indicate no spike has occurred
    double last_post_spike_time = -1; // Initialize to -1 to indicate no spike has occurred
    double spike_threshold;
    double stdp_time_constant = 20.0; // Common STDP time constant in ms
    double tau_plus = 20.0;  // ms
    double tau_minus = 20.0; // ms
    double A_plus = 0.05;  // Amplitude for LTP
    double A_minus = 0.0675; // Amplitude for LTD


public:
    // Constructor initializes the synapse with pre and post neurons and the synaptic weight
    Synapse(Neuron* pre_neuron, Neuron* post_neuron, bool excitatory, double spike_threshold);


    void apply_stdp(double current_time);
    void apply_hebbian_learning(double delta_time);

    // Apply the post-synaptic neuron based on the activity of the pre-synaptic neuron
    void apply(double delta_time);


    // Adjust the synaptic weight by a given delta
    void adjust_weight(double delta);


    // Update for STDP
    void on_presynaptic_spike(double time);
    void on_postsynaptic_spike(double time);
    void update_stdp(double delta_time);

    // Utility functions
    double get_weight() const;
    void set_weight(double new_weight);
    double get_synaptic_efficacy() const; // Get the current synaptic efficacy based on weight and neurotransmitter dynamics
    Neuron* get_pre_neuron() const;
    Neuron* get_post_neuron() const;
};

#endif // SYNAPSE_H
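
// synapse.cpp
#include <algorithm> // std::clamp (C++17)
#include <cmath>     // exp
#include <iostream>  // std::cerr
#include <stdexcept> // std::runtime_error
#include "synapse.h"
#include "neuron.h"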



Synapse::Synapse(Neuron* pre_neuron, Neuron* post_neuron, bool excitatory, double spike_threshold)
    : pre(pre_neuron), post(post_neuron), is_excitatory(excitatory), spike_threshold(spike_threshold)
{
    weight = (is_excitatory ? random<double>(0.0, 0.6) : random<double>(-0.6, 0.0));
    E_syn = (is_excitatory ? E_syn_excitatory : E_syn_inhibitory);
}

void Synapse::apply_stdp(double current_time) {
    // Assuming that 'last_pre_spike_time' and 'last_post_spike_time' are updated elsewhere
    double dt = last_post_spike_time - last_pre_spike_time;
    if (dt > 0) {
        weight -= A_minus * exp(-dt / tau_minus); // should decrease if post-neuron fires after pre-neuron
    }
    else if (dt < 0) {
        weight += A_plus * exp(dt / tau_plus); // should increase if pre-neuron fires before post-neuron
    }

    // Keep weight within bounds
    weight = std::clamp(weight, is_excitatory ? 0.0 : -1.0, is_excitatory ? 1.0 : 0.0);

    // Reset the timing since we've already accounted for the latest spike
    if (pre->get_last_spike_time() == current_time) {
        last_post_spike_time = -1;
    }
    if (post->get_last_spike_time() == current_time) {
        last_pre_spike_time = -1;
    }
}

void Synapse::apply_hebbian_learning(double delta_time) {
    double learning_rate = 0.01; // Consider making this a class member for better control
    if (pre->get_membrane_potential() > spike_threshold && post->get_membrane_potential() > spike_threshold) {
        weight += learning_rate + random<double>(-0.001, 0.001);
    }
    else {
        weight -= learning_rate + random<double>(-0.001, 0.001);
    }
    weight = std::clamp(weight, is_excitatory ? 0.0 : -1.0, is_excitatory ? 1.0 : 0.0);
}


void Synapse::apply(double delta_time) {
    // Neurotransmitter release and decay logic
    nt_release = (pre->get_membrane_potential() > spike_threshold) ? nt_max : nt_release * nt_decay;

    // Synaptic current calculation based on receptor activation and conductance
    receptor_activated += (nt_release - receptor_activated) * receptor_binding_rate * delta_time;
    double synaptic_current = receptor_activated * conductance * (post->get_membrane_potential() - E_syn);

    // Correct calculation above is wrong
    //receptor_activated += (nt_release - receptor_activated) * receptor_binding_rate * delta_time;
    //double synaptic_current = receptor_activated * conductance * (post->V - reversal_potential);

    // Check for NaN or Inf in synaptic_current after updating it
    if (std::isnan(synaptic_current) || std::isinf(synaptic_current)) {
        std::cerr << "Error: Synaptic current is NaN or Inf" << std::endl;
        throw std::runtime_error("Invalid synaptic current value.");
    }

    // Deliver the synaptic current to the post-synaptic neuron
    post->receive(synaptic_current, delta_time);

    // STDP and Hebbian learning are only applied when a spike occurs
    if (pre->get_membrane_potential() > spike_threshold) {
        on_presynaptic_spike(pre->get_last_spike_time());
    }

    if (post->get_membrane_potential() > spike_threshold) {
        on_postsynaptic_spike(post->get_last_spike_time());
    }

    // Call STDP for the current time step
    apply_stdp(delta_time);

    // Call Hebbian learning for the current time step
    apply_hebbian_learning(delta_time);
}

void Synapse::adjust_weight(double delta) {
    weight += delta;
    // Clamp the weight to ensure it stays within physiological limits
    weight = std::clamp(weight, is_excitatory ? 0.0 : -1.0, is_excitatory ? 1.0 : 0.0);
}

void Synapse::on_presynaptic_spike(double time) {
    double dt = time - last_post_spike_time;
    if (dt <= 0) {
        weight *= std::exp(dt / stdp_time_constant); // Depression
    }
    else {
        weight *= std::exp(-dt / stdp_time_constant); // Potentiation
    }
    // Ensure the weight remains within physiological limits
    weight = std::clamp(weight, is_excitatory ? 0.0 : -1.0, is_excitatory ? 1.0 : 0.0);

    last_pre_spike_time = time;
}


// Now for the on_postsynaptic_spike method
void Synapse::on_postsynaptic_spike(double time) {
    double dt = time - last_pre_spike_time;
    if (dt > 0) {
        // Long-Term Potentiation (LTP)
        weight += A_plus * exp(-dt / tau_plus);
    }
    else {
        // Long-Term Depression (LTD)
        weight += A_minus * exp(dt / tau_minus);
    }
    // Ensure the weight remains within physiological limits
    weight = std::clamp(weight, is_excitatory ? 0.0 : -1.0, is_excitatory ? 1.0 : 0.0);

    last_post_spike_time = time;
}



double Synapse::get_weight() const {
    return weight;
}
void Synapse::set_weight(double new_weight) {
    weight = new_weight;
}
double Synapse::get_synaptic_efficacy() const {
    return is_excitatory ? weight * receptor_activated : -weight * receptor_activated;
}
Neuron* Synapse::get_pre_neuron() const {
    return pre;
}
Neuron* Synapse::get_post_neuron() const {
    return post;
}
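A sign check on the synaptic current: `membrane_potential_derivative` treats ionic currents of the form g·(V - E) as outward-positive and subtracts them, dV/dt = (I - INa - IK - Il)/Cm. `Synapse::apply` computes `synaptic_current` with the same g·(V - E_syn) form, but `Neuron::receive` adds it to V. For an excitatory synapse (E_syn = 0 mV) at V = -65 mV that current is negative, so V goes down and the "excitatory" synapse hyperpolarizes. A sketch of a sign-consistent update, using a forward-Euler increment and a hypothetical function name:

```cpp
#include <cassert>

// Voltage change (mV) over one step from a conductance-based synapse, using the same
// outward-positive convention as the ionic currents in membrane_potential_derivative:
//   I_syn = g_syn * (V - E_syn),  dV/dt = -I_syn / Cm.
// g_syn is the effective conductance (e.g. receptor_activated * conductance).
double synaptic_voltage_step(double g_syn, double V, double E_syn, double Cm, double dt) {
    assert(Cm > 0.0 && dt > 0.0);
    const double I_syn = g_syn * (V - E_syn);  // outward-positive, like INa/IK/Il
    return -I_syn / Cm * dt;                   // forward-Euler increment of V
}
```

Another option, under the same convention, would be to pass the summed synaptic term into `hh_runge_kutta` as part of `I` (with a minus sign), so it is integrated together with the ionic currents instead of being added to V after the solver step.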

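Context for main.cpp: main.cpp is the entry point of the simulation, where the neuron and synapse instances are created and connected. It contains the main simulation loop, which repeatedly calls the solver functions, updates the neuron and synapse states, and applies the learning rules.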

#include <vector>
#include <cmath>
#include <iostream>
#include <unordered_map>
#include <numeric> // for std::accumulate
#include "neuron.h" // Your neuron class
#include "synapse.h" // Your synapse class
#include "solve.h" // Assuming this contains necessary integration methods
#include "params.h" // Assuming this contains necessary parameter structures
#include "random.h" // For random initialization

#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

const int num_neurons = 10; // Number of neurons in the network
const double time_step = 0.01; // The time step for simulation in milliseconds
const double learning_rate = 0.01; // The learning rate for synaptic updates

// Function to generate the sine wave values
std::vector<double> generate_sine_wave(int num_points, double frequency, double amplitude) {
    std::vector<double> sine_wave(num_points);
    for (int i = 0; i < num_points; ++i) {
        sine_wave[i] = amplitude * std::sin(2 * M_PI * frequency * i * time_step / 1000.0);
    }
    return sine_wave;
}

// Function to calculate mean squared error
double calculate_mse(const std::vector<double>& errors) {
    double mse = std::accumulate(errors.begin(), errors.end(), 0.0,
        [](double sum, double err) { return sum + err * err; });
    mse /= errors.size();
    return mse;
}

// Training function
void train_network(std::vector<Neuron>& neurons, const std::vector<double>& sine_wave, int max_epochs, double mse_threshold) {
    std::vector<double> errors;
    double mse = 0.0;
    for (int epoch = 0; epoch < max_epochs; ++epoch) {
        errors.clear();
        for (int t = 0; t < sine_wave.size(); ++t) {
            double target = sine_wave[t];

            // Inject the sine wave as input current to the first neuron
            neurons[0].step(target, time_step, time_step * t);
            for (int n = 1; n < neurons.size(); ++n) {
                neurons[n].step(0, time_step, time_step * t);
            }

            // Calculate error and update weights
            double output = neurons.back().get_membrane_potential();
            double error = target - output;
            errors.push_back(error);

            // Update synaptic weights
            for (auto& synapse : neurons.back().get_all_synapses()) {
                double delta_weight = learning_rate * error * synapse->get_synaptic_efficacy();
                synapse->adjust_weight(delta_weight);
            }
        }

        mse = calculate_mse(errors);

        // Display the MSE every epoch
        std::cout << "Epoch: " << epoch << " / " << max_epochs << ", MSE: " << mse << std::endl;

        if (mse < mse_threshold) {
            std::cout << "Convergence reached. Stopping training." << std::endl;
            break;
        }
    }
}



int main() {
    /*
    NeuronParams params{
        7.15,       // gNa
        50.0,       // ENa
        1.43,       // gK
        -95.0,      // EK
        0.02672,    // gL
        -63.563,    // El
        0.143       // Cm
    };

    InitialValues init{
        -60,      // V
        0.0529324,  // m
        0.3176767,  // h
        0.5961207   // n
    };
    */

    NeuronParams params{
        120,       // gNa
        112,       // ENa
        36,       // gK
        -12,      // EK
        0.3,    // gL
        10.613,    // El
        1       // Cm
    };

    InitialValues init{
        -65,      // V
        0.0003,  // m
        0.9998,  // h
        0.0011   // n
    };

    std::vector<Neuron> neurons;
    for (int i = 0; i < num_neurons; ++i) {
        neurons.emplace_back(params, init);
    }

    // Use an unordered map to store unique neuron connections.
    std::unordered_map<int, std::unordered_map<int, bool>> connections;

    // Connect neurons with synapses, ensuring no duplicates.
    int max_synapses = num_neurons * (num_neurons - 1); // Maximum possible synapses.
    for (int i = 0; i < max_synapses && i < num_neurons; ++i) {

        int n1 = random<int>(0, num_neurons - 1); // 0 to num_neurons - 1
        int n2 = random<int>(0, num_neurons - 1); // 0 to num_neurons - 1

        if (n1 != n2 && !connections[n1][n2]) { // Check if connection doesn't already exist
            connections[n1][n2] = true;
            neurons[n1].connect_to(&neurons[n2]);
        }
        else {
            i--;
        }
    }

   
    for (const auto& neuron : neurons) {
        for (const auto& synapse : neuron.get_all_synapses()) {
            if (synapse->get_pre_neuron() == &neuron) {  // This ensures we only print the synapse once
                std::cout << "Connection: " << synapse->get_pre_neuron() << " to " << synapse->get_post_neuron() << std::endl;
                std::cout << "Synapse weight: " << synapse->get_weight() << std::endl;
            }
        }
    }

    int elements = 1000;
    double frequency = 1.0;
    double amplitude = 1.0;

    // Generate sine wave data
    std::vector<double> sine_wave = generate_sine_wave(elements, frequency, amplitude);

    train_network(neurons, sine_wave, 1000, 0.01);
    return 0;
}
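With the values in `main` (`time_step = 0.01` ms, `elements = 1000`, `frequency = 1.0`, where the division by 1000.0 in `generate_sine_wave` treats the frequency as Hz), the generated input spans only 10 ms of a 1000 ms period. `sine_wave` is therefore essentially a ramp from 0 to about 0.063 rather than an oscillation. A small sketch for checking the coverage; the helper names are hypothetical:

```cpp
#include <cassert>
#include <cmath>

// Samples needed for one full period (frequency in Hz, time step in ms).
double samples_per_period(double frequency_hz, double time_step_ms) {
    assert(frequency_hz > 0.0 && time_step_ms > 0.0);
    return 1000.0 / (frequency_hz * time_step_ms);
}

// Fraction of one period covered by num_points samples.
double fraction_of_period(int num_points, double frequency_hz, double time_step_ms) {
    return num_points / samples_per_period(frequency_hz, time_step_ms);
}

// Largest |value| among the first num_points samples, using the same formula
// as generate_sine_wave in main.cpp.
double peak_of_samples(int num_points, double frequency_hz, double time_step_ms, double amplitude) {
    const double pi = std::acos(-1.0);
    double peak = 0.0;
    for (int i = 0; i < num_points; ++i) {
        const double v = amplitude * std::sin(2 * pi * frequency_hz * i * time_step_ms / 1000.0);
        peak = std::fmax(peak, std::fabs(v));
    }
    return peak;
}
```

Raising `frequency` to 100 Hz (one full period in 1000 samples at this step) or increasing `elements` to 100000 would make the input an actual oscillation.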
Tags: c++ algorithm machine-learning math neural-network

Comments:

0 votes, Sam Varshavchik, 11/4/2023
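Sorry to hear that you are "having difficulty implementing a learning mechanism with a sine wave input", but what is your specific question? Stack Overflow is not a crowd-sourced debugging service; we don't debug other people's code. We only answer specific programming questions. Before posting your first question on Stack Overflow, did you take the tour, read the help center, and learn how to ask?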
0 votes, Community, 11/5/2023
Please edit the question to limit it to a specific problem with enough detail to identify an adequate answer.

Answers: no answers yet