Example 2

This is an example that reproduces Figure 5 from Genkin, M., Engel, T.A., Nat Mach Intell 2, 674–683 (2020). Here we apply the proposed model selection method based on feature consistency to synthetic data generated from the double-well potential. This example uses an optional neuralflow module, neuralflow.utilities.FC_stationary, which provides the utility functions for stationary feature-consistency model selection.

[1]:
import neuralflow
from neuralflow.utilities.FC_stationary import FeatureConsistencyAnalysis_st, FeatureComplexity_st
import numpy as np
import matplotlib.pyplot as plt, matplotlib.gridspec as gridspec

Step 1: Load data and calculate Feature complexities (FC)

In this example, data will be loaded from npz files. To generate these files, we used the same double-well ground-truth model as in Example 1 (also see FIG. 3, 4, 5 in the main text). We trained and validated the model on two data samples, D1 and D2. Each data sample contains ~20,000 spikes and was split into two equal non-overlapping parts (D1 = D11 + D12, D2 = D21 + D22), where the first part was used for training and the second for validation.

Fitting results were saved in two data files (one for each data sample). Each data file contains the following entries:

iter_num: array of integers, iteration numbers at which peqs were recorded,

peqs: 2D array of fitted peqs (only recorded at iterations specified by iter_num array),

logliks: negative training log-likelihoods recorded on each iteration,

logliksCV: negative validation log-likelihoods recorded on each iteration.

For each data sample, we trained a model for 50,000 iterations. These data files can be generated straightforwardly by extending the code from Example 1. Here we include the precalculated fitting results, since fitting for 50,000 iterations may take a long time. This is the same data that was used in FIG. 3, 4, 5, but the fitting results (peqs) are saved sparsely to keep the data files small. Thus, the generated figures may slightly differ from FIG. 5 in the main text.
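As a rough sketch, such a data file could be assembled with np.savez from arrays collected while extending the training loop of Example 1. All array names and sizes below are illustrative placeholders, not part of the neuralflow API:

import numpy as np

# Illustrative placeholders for quantities collected during fitting:
rec_iters = np.array([1, 10, 100, 1000])        # iterations at which peqs were recorded
peqs_history = np.ones((1793, rec_iters.size))  # fitted peqs on the latent grid (grid size illustrative)
lls = np.zeros(1001)                            # negative training log-likelihoods, every iteration
llsCV = np.zeros(1001)                          # negative validation log-likelihoods, every iteration

np.savez('data/Ex2_datasample1.npz', iter_num=rec_iters, peqs=peqs_history,
         logliks=lls, logliksCV=llsCV)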

  1. Specify the ground-truth model. Extract spectral element method (SEM) integration weights and differentiation matrix.

  2. Load data files with the fitting results and convert them to dictionaries.

  3. For each fitted peq and the ground-truth peq, calculate the corresponding feature complexities.

[2]:
EnergyModelParams = {'pde_solve_param':{'method':{'name': 'SEM', 'gridsize': {'Np': 8, 'Ne': 256}}},
               'Nv': 64,
               'peq_model':{"model": "double_well", "params": {"xmin": 0.6, "xmax": 0.0, "depth": 2}},
               'D0': 10,
               'num_neuron':1,
               'firing_model':[{"model": "rectified_linear", "params": {"r_slope": 100, "x_thresh": -1}}],
               'verbose':True
               }

em_gt = neuralflow.EnergyModel(**EnergyModelParams)
grid_params = {'w':em_gt.w_d_, 'dmat':em_gt.dmat_d_}

data1 = dict(np.load('data/Ex2_datasample1.npz',allow_pickle=True))
data2 = dict(np.load('data/Ex2_datasample2.npz',allow_pickle=True))

data1['FCs'] = FeatureComplexity_st(data1['peqs'],grid_params)
data2['FCs'] = FeatureComplexity_st(data2['peqs'],grid_params)
FC_gt = FeatureComplexity_st(em_gt.peq_,grid_params)

Step 2: Define the hyperparameters for feature consistency analysis

Here we define the following hyperparameters of our feature consistency method:

KL_thres: threshold Kullback-Leibler (KL) divergence that defines the point at which the two models start to diverge (see Methods); a minimal sketch of this KL computation is given below, after the hyperparameter definitions.

FC_radius: feature complexity radius that allows some slack when matching feature complexities: instead of comparing models with exactly the same feature complexity, we compare models whose feature complexities lie within this radius of each other (see Methods).

In addition, we define the following hyperparameters:

KL_thres_late: same as KL_thres, but with a higher value. This threshold will be used to demonstrate that high KL thresholds lead to disagreement between the selected potentials.

FC_final: maximum feature complexity explored by the feature consistency analysis.

FC_stride: resolution (step size) of the shared feature complexity axis.
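In this method, the KL divergence quantifies how much the two fitted models (matched by feature complexity) differ. As a minimal, hedged sketch of how such a divergence could be evaluated on the SEM grid with the integration weights w (the actual computation is performed inside FeatureConsistencyAnalysis_st):

import numpy as np

def kl_divergence_sem(p1, p2, w, eps=1e-10):
    # KL(p1 || p2) ~ sum_i w_i * p1_i * log(p1_i / p2_i) for densities sampled on the SEM grid;
    # eps guards against taking the log of zero.
    return np.sum(w * p1 * np.log((p1 + eps) / (p2 + eps)))

For example, kl_divergence_sem(data1['peqs'][..., 0], data2['peqs'][..., 0], em_gt.w_d_) would compare the first recorded peqs of the two data samples.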

[3]:
KL_thres = 0.01
FC_radius = 1

KL_thres_late = 0.03
FC_final = 25
FC_stride = 0.1

FC_options = [KL_thres, FC_radius, FC_final, FC_stride]

Step 3: Perform feature consistency analysis

  1. For the analysis, use the FeatureConsistencyAnalysis_st function, which returns the shared FC axis, the KL divergences, the index of the optimal FC in FC_shared, and the indices of the peqs and FCs in the original data arrays that correspond to each FC in the FC_shared array (a quick sanity check of this output structure is sketched after the code cell below).

  2. Determine the optimal FC (FC_opt), as well as the early and late FCs. The late feature complexity is found by thresholding the KL divergence with KL_thres_late. The early FC is defined as FC_opt - 4.

[4]:
FC_shared, KL, FC_opt_ind, ind1_shared, ind2_shared = FeatureConsistencyAnalysis_st(data1, data2, grid_params, *FC_options)
FC_opt = FC_shared[FC_opt_ind]

# Late FC: the last feature complexity before the KL divergence exceeds KL_thres_late
FC_late_ind = np.where(KL > KL_thres_late)[0][0] - 1
FC_late = FC_shared[FC_late_ind]

# Early FC: the first feature complexity greater than FC_opt - 4
FC_early_ind = np.where(FC_shared > FC_opt - 4)[0][0]
FC_early = FC_shared[FC_early_ind]
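Before visualising the results, one can sanity-check that the returned arrays are aligned as described in step 1 above (a hedged sketch; it relies only on the output structure stated there):

# Each entry of the shared FC axis has a matching KL value and a matching index into each
# data sample's original peqs/FCs arrays.
assert len(FC_shared) == len(KL) == len(ind1_shared) == len(ind2_shared)
assert 0 <= FC_opt_ind < len(FC_shared)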

Step 4: Visualise the results

[5]:
fig=plt.figure(figsize=(20,7))
gs=gridspec.GridSpec(2,3,height_ratios=[3,2],hspace=0.5)
line_colors = [[0, 127/255, 1], [239/255, 48/255, 84/255], [0.5, 0.5, 0.5]]
dot_colors = [[0.6,0.6,0.6], [1, 169/255, 135/255],  [147/255, 192/255, 164/255]]

# Panel 1: feature complexity vs. iteration number for both data samples and the ground truth
ax = plt.subplot(gs[0])
ax.plot(data1['iter_num'],data1['FCs'],color=line_colors[0],linewidth=3,label='Data sample 1')
ax.plot(data2['iter_num'],data2['FCs'],color=line_colors[1],linewidth=3,label='Data sample 2')
ax.hlines(FC_gt,data1['iter_num'][0],data1['iter_num'][-1],color=line_colors[2],linewidth=2,label='Ground truth')
plt.xscale('log')
plt.xlabel('iteration number')
plt.ylabel('Feature complexity')

# Panel 2: normalized negative validated log-likelihood vs. feature complexity
ax = plt.subplot(gs[1])
llCV=data1['logliksCV'][data1['iter_num']]
llCV = (llCV-llCV[0])/(np.max(llCV)-np.min(llCV))
ax.plot(data1['FCs'], llCV,color=line_colors[0],linewidth=3)
llCV=data2['logliksCV'][data2['iter_num']]
llCV = (llCV-llCV[0])/(np.max(llCV)-np.min(llCV))
ax.plot(data2['FCs'], llCV,color=line_colors[1],linewidth=3)
ax.plot(FC_early,llCV[np.argmin(np.abs(data2['FCs']-FC_early))],'.',markersize=20,color=dot_colors[0])
ax.plot(FC_opt,llCV[np.argmin(np.abs(data2['FCs']-FC_opt))],'.',markersize=20,color=dot_colors[1])
ax.plot(FC_late,llCV[np.argmin(np.abs(data2['FCs']-FC_late))],'.',markersize=20,color=dot_colors[2])
plt.xlabel('Feature complexity')
plt.ylabel(r'$-\log\mathcal{L}$', fontsize=18)


# Panel 3: KL divergence between the two fitted models vs. shared feature complexity
ax = plt.subplot(gs[2])
ax.plot(FC_shared,KL, color = [0.47, 0.34, 0.66],linewidth=3)
ax.plot(FC_early,KL[np.argmin(np.abs(FC_shared-FC_early))],'.',markersize=20,color=dot_colors[0])
ax.plot(FC_opt,KL[np.argmin(np.abs(FC_shared-FC_opt))],'.',markersize=20,color=dot_colors[1])
ax.plot(FC_late,KL[np.argmin(np.abs(FC_shared-FC_late))],'.',markersize=20,color=dot_colors[2])
plt.xlabel('Feature complexity')
plt.ylabel('KL divergence')


# Panel 4: fitted potentials (-log peq, clipped at 6) at the early feature complexity
ax = plt.subplot(gs[3])
ax.plot(em_gt.x_d_,np.minimum(-np.log(data1['peqs'][...,ind1_shared[FC_early_ind]]),6),color=line_colors[0],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(data2['peqs'][...,ind2_shared[FC_early_ind]]),6),color=line_colors[1],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(em_gt.peq_),6),color=[0.5, 0.5, 0.5],linewidth=2)
plt.xlabel('latent state, x')
plt.ylabel(r'$-\log p_{\mathrm{eq}}(x)$', fontsize=18)


# Panel 5: fitted potentials (-log peq, clipped at 6) at the optimal feature complexity
ax = plt.subplot(gs[4])
ax.plot(em_gt.x_d_,np.minimum(-np.log(data1['peqs'][...,ind1_shared[FC_opt_ind]]),6),color=line_colors[0],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(data2['peqs'][...,ind2_shared[FC_opt_ind]]),6),color=line_colors[1],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(em_gt.peq_),6),color=[0.5, 0.5, 0.5],linewidth=2)
plt.xlabel('latent state, x')
plt.ylabel(r'$-\log p_{\mathrm{eq}}(x)$', fontsize=18)

# Panel 6: fitted potentials (-log peq, clipped at 6) at the late feature complexity
ax = plt.subplot(gs[5])
ax.plot(em_gt.x_d_,np.minimum(-np.log(data1['peqs'][...,ind1_shared[FC_late_ind]]),6),color=line_colors[0],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(data2['peqs'][...,ind2_shared[FC_late_ind]]),6),color=line_colors[1],linewidth=3)
ax.plot(em_gt.x_d_,np.minimum(-np.log(em_gt.peq_),6),color=[0.5, 0.5, 0.5],linewidth=2)
plt.xlabel('latent state, x')
plt.ylabel(r'$-\log p_{\mathrm{eq}}(x)$', fontsize=18)

The code above should produce a figure analogous to FIG. 5 in the main text (the output image is shown in the accompanying Jupyter notebook).