Dirichlet-Multinomial regression and sparse regression

A demo of Dirichlet-Multinomial regression and sparse regression

Contents

Generate Dirichlet-Multinomial random vectors from covariates
Fit Dirichlet-Multinomial regression
Fit Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty

Generate Dirichlet-Multinomial random vectors from covariates
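The simulated data follow the Dirichlet-Multinomial regression model. In standard notation (a sketch; the toolbox may use a slightly different parameterization), an observation with covariate vector $x_i$, counts $y_i = (y_{i1},\dots,y_{id})$, and total count $m_i = \sum_j y_{ij}$ has

$$
P(y_i \mid x_i) = \binom{m_i}{y_{i1},\dots,y_{id}}
\frac{\Gamma\!\big(\sum_{j} \alpha_{ij}\big)}{\Gamma\!\big(\sum_{j} \alpha_{ij} + m_i\big)}
\prod_{j=1}^{d} \frac{\Gamma(\alpha_{ij} + y_{ij})}{\Gamma(\alpha_{ij})},
\qquad \alpha_{ij} = \exp(x_i^T \beta_j),
$$

where $\beta_j$ is the j-th column of the coefficient matrix B. The code below uses this log link directly: alpha = exp(X*B).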

clear;
% reset random seed
s = RandStream('mt19937ar','Seed',1);
RandStream.setGlobalStream(s);
% sample size
n = 200;
% # covariates
p = 15;
% # bins
d = 5;
% design matrix
X = randn(n,p);
% true regression coefficients
B = zeros(p,d);
nzidx = [1 3 5];
B(nzidx,:) = ones(length(nzidx),d);
% Dirichlet parameters through the log link
alpha = exp(X*B);
% total count for each observation, uniform on {26,...,50}
batchsize = 25+unidrnd(25,n,1);
% draw Dirichlet-Multinomial response vectors
Y = dirmnrnd(batchsize,alpha);

% drop observations whose total count is zero
rowsums = sum(Y,2);
Y = Y(rowsums~=0, :);
X = X(rowsums~=0, :);
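
For reference, a single Dirichlet-Multinomial draw can also be generated from scratch with the standard Gamma/multinomial construction; the sketch below is an illustration only and is not necessarily how the toolbox's dirmnrnd is implemented.

% Illustration only: one Dirichlet-Multinomial draw for a single observation
obs = 1;
g = gamrnd(alpha(obs,:), 1);          % independent Gamma(alpha_ij, 1) draws
prob = g ./ sum(g);                   % a Dirichlet(alpha(obs,:)) probability vector
y_obs = mnrnd(batchsize(obs), prob);  % multinomial draw given that probability vector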

Fit Dirichlet-Multinomial regression

tic;
[B_hat, stats_dm] = dirmnreg(X,Y);
toc;
display(B_hat);        % estimated regression coefficients (p-by-d)
display(stats_dm.se);  % standard errors of the estimates
display(stats_dm);
% Wald test of predictor significance
disp('Wald test p-values:');
display(stats_dm.wald_pvalue);

figure;
plot(stats_dm.logL_iter);
xlabel('iteration');
ylabel('log-likelihood');
Elapsed time is 0.074675 seconds.

B_hat =

    1.0433    1.0273    1.0261    1.0718    0.9714
    0.0732   -0.0784    0.1226   -0.0149    0.0134
    1.2702    1.2427    1.0411    1.1994    1.2891
   -0.0160    0.0594    0.0091    0.0163    0.0922
    1.0395    0.9820    1.1324    0.9810    1.1368
   -0.0436    0.0190   -0.0366    0.0211    0.0049
    0.0587    0.1417    0.1773    0.2716    0.1444
   -0.0450   -0.0731   -0.0828   -0.0449   -0.0730
   -0.0067    0.0120    0.0913    0.1893   -0.0101
    0.0947    0.2141    0.1521    0.0770    0.0679
    0.0713    0.1136    0.0868    0.0604    0.1609
   -0.0220   -0.1262   -0.1545   -0.1011   -0.1984
   -0.1732   -0.1461   -0.0800   -0.0821    0.0261
   -0.1059   -0.0128   -0.0464   -0.0136   -0.1410
   -0.0926   -0.1360   -0.0706   -0.0864   -0.0902

    0.1057    0.1004    0.1013    0.1042    0.0994
    0.1047    0.1008    0.1023    0.1021    0.0994
    0.1034    0.0994    0.0992    0.1002    0.0992
    0.0928    0.0950    0.0940    0.0943    0.0921
    0.1021    0.0987    0.0994    0.1038    0.1003
    0.0847    0.0825    0.0866    0.0870    0.0807
    0.0861    0.0859    0.0886    0.0881    0.0886
    0.0858    0.0860    0.0887    0.0869    0.0852
    0.0912    0.0852    0.0866    0.0881    0.0841
    0.0910    0.0887    0.0932    0.0945    0.0864
    0.0996    0.0963    0.0963    0.1007    0.0950
    0.0910    0.0914    0.0967    0.0928    0.0945
    0.0915    0.0873    0.0945    0.0969    0.0911
    0.0895    0.0834    0.0911    0.0880    0.0835
    0.0866    0.0865    0.0873    0.0879    0.0858


stats_dm = 

  struct with fields:

                     BIC: 4.2088e+03
                     AIC: 3.9614e+03
                     dof: 75
              iterations: 6
                    logL: -1.9057e+03
               logL_iter: [1×6 double]
                    yhat: [200×5 double]
                      se: [15×5 double]
               wald_stat: [1×15 double]
             wald_pvalue: [1×15 double]
                       H: [75×75 double]
                gradient: [75×1 double]
    observed_information: [75×75 double]

Wald test p-values:
  Columns 1 through 7

         0    0.2639         0    0.6872         0    0.9274    0.0312

  Columns 8 through 14

    0.9412    0.0450    0.2049    0.6178    0.1170    0.1066    0.3226

  Column 15

    0.7632
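
The near-zero p-values at predictors 1, 3, and 5 agree with the nonzero rows nzidx = [1 3 5] used to simulate the data. For intuition, the per-predictor Wald test can be reconstructed from the returned observed information; the sketch below assumes the 75-by-1 parameter vector is vec(B_hat) in column-major order, which is an assumption about the toolbox's internal layout, so the numbers may not match exactly.

% Illustration only: rebuild the Wald test for one predictor from the fit.
% Assumes the parameter vector is vec(B_hat) stacked column by column.
j = 1;                                        % predictor to test: H0 is B(j,:) = 0
covB = inv(stats_dm.observed_information);    % approximate covariance of vec(B_hat)
idx = j + (0:d-1)*p;                          % entries of vec(B) belonging to row j
bj = B_hat(j,:)';
W = bj' * (covB(idx,idx) \ bj);               % Wald statistic, approx. chi-square(d) under H0
pval = 1 - chi2cdf(W, d);                     % compare with stats_dm.wald_pvalue(j)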

Fit Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty
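
Three penalties are compared: an element-wise (lasso-type) penalty, specified as 'sweep' in the toolbox, a group penalty that groups the coefficients of each predictor (each row of B), and a nuclear-norm penalty. In standard notation these are roughly (a sketch; the exact weighting used by mglm_sparsereg may differ)

$$
\lambda \sum_{j,k} |b_{jk}|, \qquad
\lambda \sum_{j=1}^{p} \|b_{j\cdot}\|_2, \qquad
\lambda \|B\|_{*},
$$

where $\|B\|_{*}$ is the sum of the singular values of B.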

% penalty types: 'sweep' (element-wise lasso), 'group' (group penalty by predictor), 'nuclear' (nuclear norm)
penalty = {'sweep','group','nuclear'};
% number of grid points on the lambda path
ngridpt = 20;
% Dirichlet-Multinomial distribution
dist = 'dirmn';

for i = 1:length(penalty)

    pen = penalty{i};
    % fit at lambda = inf to obtain maxlambda, the top of the lambda grid
    [~, stats] = mglm_sparsereg(X,Y,inf,'penalty',pen,'dist',dist);
    maxlambda = stats.maxlambda;

    % log-spaced lambda grid from maxlambda down to maxlambda/100
    lambdas = exp(linspace(log(maxlambda),log(maxlambda/100),ngridpt));
    BICs = zeros(1,ngridpt);
    tic;
    for j=1:ngridpt
        if j==1
            % cold start at the largest lambda
            B0 = zeros(p,d);
        else
            % warm start from the previous solution on the path
            B0 = B_hat;
        end
        [B_hat, stats] = mglm_sparsereg(X,Y,lambdas(j),'penalty',pen, ...
            'dist',dist,'B0',B0);
        BICs(j) = stats.BIC;
    end
    toc;

    % refit at the lambda with the smallest BIC
    [~,bestidx] = min(BICs);
    lambdas(bestidx)   % display the selected lambda
    B_best = mglm_sparsereg(X,Y,lambdas(bestidx),'penalty',pen,'dist',dist);
    % plot BIC path, then true signal versus estimated signal
    figure;
    subplot(1,3,1);
    semilogx(lambdas,BICs);
    ylabel('BIC');
    xlabel('\lambda');
    xlim([min(lambdas) max(lambdas)]);
    subplot(1,3,2);
    imshow(mat2gray(-B)); title('True B');
    subplot(1,3,3);
    imshow(mat2gray(-B_best)); title([pen ' estimate']);

end
Elapsed time is 0.408279 seconds.

ans =

   10.8980

Elapsed time is 0.458002 seconds.

ans =

   30.2672

Elapsed time is 0.670406 seconds.

ans =

  194.8746
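
As a quick sanity check, the estimated signal can be compared with the truth by looking at row norms; the sketch below uses B_best from the last penalty fitted (the nuclear norm), whose estimate is low-rank rather than row-sparse, so the comparison is informal.

% Illustration only: compare estimated row norms with the true nonzero rows
rownorm_true = sqrt(sum(B.^2, 2));       % rows 1, 3, 5 carry the true signal
rownorm_est  = sqrt(sum(B_best.^2, 2));  % estimate from the last penalty fitted
disp([rownorm_true rownorm_est]);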