Generalized Dirichlet-Multinomial regression and sparse regression

This demo fits generalized Dirichlet-Multinomial regression and sparse regression to simulated count data.

Generate generalized Dirichlet-Multinomial random vectors from covariates

clear;
% reset random seed
s = RandStream('mt19937ar','Seed',1);
RandStream.setGlobalStream(s);
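% sample size
n = 500;
% number of covariates
p = 15;
% number of bins (response categories)
d = 5;
% design matrix
X = randn(n,p);
% true regression coefficients: only covariates 1, 3, and 5 have an effect
A = zeros(p,d-1);
B = zeros(p,d-1);
nzidx = [1 3 5];
A(nzidx,:) = 0.5.*ones(length(nzidx),d-1);
B(nzidx,:) = 0.5.*ones(length(nzidx),d-1);
% log link: n-by-(d-1) matrices of positive parameters
alpha = exp(X*A);
beta = exp(X*B);
% batch size (total count) of each observation, uniform on 26,...,50
batchsize = 25+unidrnd(25,n,1);
% simulate the count responses
Y = gendirmnrnd(batchsize,alpha, beta);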
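
The log link used above maps each row of X to strictly positive parameters. As a minimal sketch of that construction on a smaller problem (the sizes below are illustrative, not those of the demo), the following checks that only the covariates listed in nzidx influence alpha, and that every column of alpha is identical because the nonzero rows of A are constant:

```matlab
% Illustrative sizes (smaller than the demo's n = 500, p = 15, d = 5)
n = 6; p = 4; d = 3;
X = randn(n,p);
nzidx = [1 3];                     % covariates with an effect
A = zeros(p,d-1);
A(nzidx,:) = 0.5;                  % same effect as 0.5.*ones(length(nzidx),d-1)
alpha = exp(X*A);                  % n-by-(d-1), strictly positive

% Each entry depends only on the covariates in nzidx
alpha_direct = exp(0.5*sum(X(:,nzidx),2));
max_diff = max(max(abs(alpha - repmat(alpha_direct,1,d-1))));
fprintf('size(alpha) = [%d %d], max difference = %g\n', size(alpha), max_diff);
```

The difference should be at the level of floating-point rounding, since both expressions compute the same quantity.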

Fit generalized Dirichlet-Multinomial regression

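% fit the model and time the fit
tic;
[Bhat1,Bhat2,stats_gdm] = gendirmnreg(X,Y);
toc;
% coefficient estimates, each p-by-(d-1); compare with the true A and B
display(Bhat1);
display(Bhat2);
% fit summary: log-likelihood, AIC, BIC, degrees of freedom, iterations
display(stats_gdm);
% standard errors, p-by-2(d-1)
display(stats_gdm.se);
% Wald test p-values, one per covariate
display(stats_gdm.wald_pvalue);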
Elapsed time is 0.279083 seconds.

Bhat1 =

    0.4044    0.4217    0.4496    0.3368
   -0.0299   -0.0123    0.0575   -0.0418
    0.3943    0.5709    0.5170    0.5504
   -0.1911    0.0241    0.0455    0.0265
    0.5738    0.6706    0.6396    0.4776
   -0.0031   -0.0937    0.0118   -0.0098
   -0.0145    0.0940   -0.0978    0.1049
    0.0084   -0.1723    0.0355   -0.0211
    0.0555   -0.1403   -0.0538   -0.0038
   -0.0237   -0.0899    0.0175    0.0732
    0.0632    0.1570    0.0418    0.1159
   -0.0457   -0.0442    0.0503    0.1275
   -0.0062   -0.0604    0.0016    0.0843
    0.0931    0.0649    0.0346    0.1222
   -0.0544   -0.1110    0.1017    0.0154


Bhat2 =

    0.3881    0.3333    0.4765    0.3418
    0.0041   -0.0254    0.1981   -0.1771
    0.4263    0.5280    0.4249    0.5621
   -0.2371    0.0633    0.1171    0.1089
    0.5609    0.6077    0.6085    0.5299
   -0.0288   -0.0711   -0.0084   -0.0862
    0.0126    0.0797   -0.0800    0.1977
   -0.0098   -0.2746    0.0930    0.0466
    0.0854   -0.1082   -0.1854   -0.0037
   -0.0059   -0.0076    0.0470    0.0271
    0.1780    0.0620   -0.0126    0.1101
   -0.0260   -0.0153   -0.0137    0.0446
   -0.0083   -0.0333    0.0510    0.0935
    0.0969    0.0572    0.0486    0.0835
   -0.0758   -0.0748    0.2051    0.0225


stats_gdm = 

  struct with fields:

                gradient: [15×8 double]
                      se: [15×8 double]
               wald_stat: [1×15 double]
             wald_pvalue: [1×15 double]
                       H: [120×120 double]
    observed_information: [120×120 double]
                    logL: -4.5644e+03
                     BIC: 9.8746e+03
                     AIC: 9.3689e+03
                     dof: 120
              iterations: 19

  Columns 1 through 7

    0.0715    0.0839    0.1085    0.1604    0.0800    0.0908    0.1027
    0.0743    0.0837    0.1015    0.1409    0.0787    0.0827    0.1074
    0.0667    0.0838    0.1043    0.1407    0.0709    0.0867    0.1074
    0.0676    0.0856    0.1065    0.1486    0.0701    0.0833    0.1066
    0.0727    0.0925    0.1177    0.1689    0.0772    0.0908    0.1144
    0.0678    0.0806    0.1059    0.1400    0.0755    0.0822    0.1011
    0.0704    0.0828    0.1042    0.1437    0.0751    0.0802    0.1105
    0.0730    0.0878    0.1107    0.1468    0.0790    0.0873    0.1162
    0.0714    0.0854    0.1071    0.1453    0.0757    0.0847    0.1101
    0.0739    0.0820    0.1113    0.1627    0.0765    0.0824    0.1082
    0.0761    0.0914    0.1063    0.1411    0.0808    0.0921    0.1209
    0.0722    0.0802    0.1039    0.1423    0.0769    0.0841    0.1074
    0.0675    0.0745    0.0975    0.1223    0.0701    0.0820    0.1058
    0.0672    0.0794    0.1019    0.1279    0.0779    0.0818    0.1078
    0.0745    0.0849    0.1134    0.1527    0.0795    0.0846    0.1141

  Column 8

    0.1490
    0.1400
    0.1397
    0.1437
    0.1654
    0.1389
    0.1375
    0.1461
    0.1352
    0.1508
    0.1418
    0.1367
    0.1281
    0.1232
    0.1522

  Columns 1 through 7

    0.0000    0.0701         0    0.0065         0    0.7036    0.3447

  Columns 8 through 14

    0.0223    0.0857    0.5563    0.0319    0.5872    0.8417    0.5690

  Column 15

    0.2119
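
As a reading aid for the output above, the sketch below copies the printed values (rounded to the displayed precision) and does two things. It recomputes AIC and BIC from logL, dof, and n, assuming the usual definitions AIC = -2 logL + 2 dof and BIC = -2 logL + log(n) dof. It also lists the covariates whose Wald p-values fall below 0.05, with and without a Bonferroni correction. The 0.05 threshold and the Bonferroni correction are choices made here for illustration and are not part of the demo.

```matlab
% Values copied from the printed output above (rounded as displayed)
logL = -4564.4; dof = 120; n = 500;
pval = [0.0000 0.0701 0 0.0065 0 0.7036 0.3447 0.0223 ...
        0.0857 0.5563 0.0319 0.5872 0.8417 0.5690 0.2119];

% Assumed definitions of the information criteria
aic = -2*logL + 2*dof;         % about 9368.8 (printed: 9.3689e+03)
bic = -2*logL + log(n)*dof;    % about 9874.6 (printed: 9.8746e+03)

% Covariates flagged at the 5% level
sig_raw  = find(pval < 0.05);                 % [1 3 4 5 8 11]
sig_bonf = find(pval < 0.05/numel(pval));     % [1 3 5]
disp(sig_raw); disp(sig_bonf);
```

Under these assumptions the Bonferroni-corrected set coincides with the true nonzero rows nzidx = [1 3 5], while the uncorrected threshold also flags covariates 4, 8, and 11.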

Fit generalized Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty

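% penalties to compare; 'sweep' selects the lasso penalty
penalty = {'sweep','group','nuclear'};
% number of points on the lambda grid
ngridpt = 10;
% response distribution: generalized Dirichlet-Multinomial
dist = 'gendirmn';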

for i = 1:length(penalty)

    pen = penalty{i};
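    % fit with lambda = inf to obtain stats.maxlambda, the top of the grid
    [~, stats] = mglm_sparsereg(X,Y,inf,'penalty',pen,'dist',dist);
    maxlambda = stats.maxlambda;

    % log-spaced grid from maxlambda down to maxlambda/100
    lambdas = exp(linspace(log(maxlambda),log(maxlambda/100),ngridpt));
    BICs = zeros(1,ngridpt);
    tic;
    for j=1:ngridpt
        % warm start: begin from zero, then from the previous solution
        if j==1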
            B0 = zeros(p,2*(d-1));
        else
            B0 = B_hat;
        end
        [B_hat, stats] = mglm_sparsereg(X,Y,lambdas(j),'penalty',pen, ...
            'dist',dist,'B0',B0);
        BICs(j) = stats.BIC;
    end
    toc;

    % pick the lambda with the smallest BIC, display it, and refit
    [~,bestidx] = min(BICs);
    lambdas(bestidx)
    B_best = mglm_sparsereg(X,Y,lambdas(bestidx),'penalty',pen,'dist',dist);
    % BIC path, true signal, and estimated signal side by side
    figure;
    subplot(1,3,1);
    semilogx(lambdas,BICs);
    ylabel('BIC');
    xlabel('\lambda');
    xlim([min(lambdas) max(lambdas)]);
    subplot(1,3,2);
    imshow(mat2gray(-[A,B])); title('True B');
    subplot(1,3,3);
    imshow(mat2gray(-B_best)); title([pen ' estimate']);

end
Elapsed time is 0.455057 seconds.

ans =

   11.4994

Elapsed time is 0.357541 seconds.

ans =

   29.0023

Elapsed time is 0.439528 seconds.

ans =

   77.4753
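
Two pieces of the loop above can be illustrated in isolation. First, the grid is geometric: consecutive values of lambda differ by the constant factor 100^(-1/(ngridpt-1)), and the built-in logspace produces the same grid. Second, an estimate can be summarized by the Euclidean norm of each row, which is one simple way to compare a row-sparse estimate with nzidx. Both maxlambda and the matrix B_toy below are made-up values for illustration; in the demo they would come from mglm_sparsereg.

```matlab
% Hypothetical value; in the demo this is stats.maxlambda
maxlambda = 100;
ngridpt = 10;
lambdas = exp(linspace(log(maxlambda), log(maxlambda/100), ngridpt));

% Constant ratio between consecutive grid points
ratios = lambdas(2:end) ./ lambdas(1:end-1);
lambdas_alt = logspace(log10(maxlambda), log10(maxlambda/100), ngridpt);
grid_diff = max(abs(lambdas - lambdas_alt));

% Toy estimate with nonzero rows 1, 3, and 5 (made-up numbers)
B_toy = zeros(6,4);
B_toy([1 3 5],:) = [0.4 0.5 0.4 0.3; 0.5 0.6 0.5 0.5; 0.6 0.6 0.6 0.5];
row_norms = sqrt(sum(B_toy.^2, 2));
selected = find(row_norms > 1e-8)';     % rows kept by the estimate
disp(selected);
```

The row-norm summary is most natural for the group penalty, which is expected to zero out whole rows. The lasso penalty acts on individual entries and the nuclear penalty targets low rank, so their estimates may call for a different summary.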