Generalized Dirichlet-Multinomial regression and sparse regression
A demo of generalized Dirichlet-Multinomial regression and sparse regression.
Generate generalized Dirichlet-Multinomial random vectors from covariates
clear;
% reset random seed
s = RandStream('mt19937ar','Seed',1);
RandStream.setGlobalStream(s);
% sample size
n = 500;
% number of covariates
p = 15;
% number of categories (bins)
d = 5;
% design matrix
X = randn(n,p);
% true regression coefficients: only covariates 1, 3, 5 carry signal
A = zeros(p,d-1);
B = zeros(p,d-1);
nzidx = [1 3 5];
A(nzidx,:) = 0.5.*ones(length(nzidx),d-1);
B(nzidx,:) = 0.5.*ones(length(nzidx),d-1);
% n-by-(d-1) GDM parameters via the log link
alpha = exp(X*A);
beta = exp(X*B);
% batch sizes drawn uniformly from 26 to 50
batchsize = 25+unidrnd(25,n,1);
Y = gendirmnrnd(batchsize,alpha, beta);
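For reference, the sketch below evaluates the generalized Dirichlet-Multinomial log-probability of a single count vector. It assumes the Connor-Mosimann (stick-breaking) construction, in which y_j, given the remaining total z_j = y_j + ... + y_d, is Beta-Binomial(z_j, alpha_j, beta_j); this parameterization is an assumption and is not taken from the `gendirmnrnd` source. The anonymous functions `logmultinom`, `gdm_terms`, and `gdm_logpmf` are hypothetical helpers, not toolbox functions.

```matlab
% Hypothetical helpers (not toolbox functions). Assumes the Connor-Mosimann
% construction: y_j | z_j ~ BetaBinomial(z_j, alpha_j, beta_j), z_j = sum(y(j:d)).
% y is a 1-by-d count vector; a and b are 1-by-(d-1) parameter vectors.

% log multinomial coefficient m! / (y_1! ... y_d!)
logmultinom = @(y) gammaln(sum(y) + 1) - sum(gammaln(y + 1));

% sum over j = 1..d-1 of the log Beta-Binomial ratio terms;
% z(j) is the tail sum y(j) + ... + y(d), so z(1) is the batch size
gdm_terms = @(y, z, a, b) sum(gammaln(a + y(1:end-1)) + gammaln(b + z(2:end)) ...
    + gammaln(a + b) - gammaln(a) - gammaln(b) - gammaln(a + b + z(1:end-1)));

% full log-pmf; tail sums come from a reversed cumulative sum
gdm_logpmf = @(y, a, b) logmultinom(y) + gdm_terms(y, fliplr(cumsum(fliplr(y))), a, b);

% example: d = 3 categories, batch size 4, made-up parameters
lp_toy = gdm_logpmf([1 2 1], [0.7 1.3], [2.1 0.5]);
```

With d = 2 this reduces to the Beta-Binomial; for alpha = beta = 1 every count in 0..m then has probability 1/(m+1).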
Fit generalized Dirichlet-Multinomial regression
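The simulation above links covariates to the GDM parameters through a log link, alpha = exp(X*A) and beta = exp(X*B). The fitted Bhat1 and Bhat2 shown below recover A and B (rows 1, 3, and 5 near 0.5, the rest near 0), which suggests `gendirmnreg` uses the same link. To read such coefficients on the proportion scale, the sketch below computes expected category proportions, assuming the stick-breaking construction with independent Z_j ~ Beta(alpha_j, beta_j): E[p_j] = alpha_j/(alpha_j+beta_j) * prod_{k<j} beta_k/(alpha_k+beta_k) for j < d, and E[p_d] = prod_{k<d} beta_k/(alpha_k+beta_k). All inputs are made-up toy values with a `_toy` suffix so they do not overwrite the demo's variables.

```matlab
% Made-up toy inputs: 3 observations, 2 covariates, d = 3 categories
X_toy = [0 0; 1 0; 0 1];
A_toy = [0.5 0.5; 0 0];      % p-by-(d-1)
B_toy = [0 0; 0.5 -0.5];     % p-by-(d-1)
n_toy = size(X_toy, 1);

% log link, as in the simulation: n-by-(d-1) parameter matrices
alpha_toy = exp(X_toy*A_toy);
beta_toy  = exp(X_toy*B_toy);

% expected stick-breaking fractions E[Z_j] and E[1 - Z_j]
ez   = alpha_toy ./ (alpha_toy + beta_toy);
e1mz = beta_toy  ./ (alpha_toy + beta_toy);

% remaining stick before category j: prod_{k<j} E[1 - Z_k], j = 1..d
stick = [ones(n_toy, 1), cumprod(e1mz, 2)];

% expected proportions (n-by-d); category d takes whatever stick remains
meanprop = stick .* [ez, ones(n_toy, 1)];
```

Each row of `meanprop` sums to 1; an all-zero covariate row gives alpha = beta = 1 and hence proportions [0.5 0.25 0.25].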
tic;
[Bhat1,Bhat2,stats_gdm] = gendirmnreg(X,Y);
toc;
display(Bhat1);
display(Bhat2);
display(stats_gdm);
display(stats_gdm.se);
display(stats_gdm.wald_pvalue);
Elapsed time is 0.279083 seconds.
Bhat1 =
0.4044 0.4217 0.4496 0.3368
-0.0299 -0.0123 0.0575 -0.0418
0.3943 0.5709 0.5170 0.5504
-0.1911 0.0241 0.0455 0.0265
0.5738 0.6706 0.6396 0.4776
-0.0031 -0.0937 0.0118 -0.0098
-0.0145 0.0940 -0.0978 0.1049
0.0084 -0.1723 0.0355 -0.0211
0.0555 -0.1403 -0.0538 -0.0038
-0.0237 -0.0899 0.0175 0.0732
0.0632 0.1570 0.0418 0.1159
-0.0457 -0.0442 0.0503 0.1275
-0.0062 -0.0604 0.0016 0.0843
0.0931 0.0649 0.0346 0.1222
-0.0544 -0.1110 0.1017 0.0154
Bhat2 =
0.3881 0.3333 0.4765 0.3418
0.0041 -0.0254 0.1981 -0.1771
0.4263 0.5280 0.4249 0.5621
-0.2371 0.0633 0.1171 0.1089
0.5609 0.6077 0.6085 0.5299
-0.0288 -0.0711 -0.0084 -0.0862
0.0126 0.0797 -0.0800 0.1977
-0.0098 -0.2746 0.0930 0.0466
0.0854 -0.1082 -0.1854 -0.0037
-0.0059 -0.0076 0.0470 0.0271
0.1780 0.0620 -0.0126 0.1101
-0.0260 -0.0153 -0.0137 0.0446
-0.0083 -0.0333 0.0510 0.0935
0.0969 0.0572 0.0486 0.0835
-0.0758 -0.0748 0.2051 0.0225
stats_gdm =
struct with fields:
gradient: [15×8 double]
se: [15×8 double]
wald_stat: [1×15 double]
wald_pvalue: [1×15 double]
H: [120×120 double]
observed_information: [120×120 double]
logL: -4.5644e+03
BIC: 9.8746e+03
AIC: 9.3689e+03
dof: 120
iterations: 19
0.0715 0.0839 0.1085 0.1604 0.0800 0.0908 0.1027 0.1490
0.0743 0.0837 0.1015 0.1409 0.0787 0.0827 0.1074 0.1400
0.0667 0.0838 0.1043 0.1407 0.0709 0.0867 0.1074 0.1397
0.0676 0.0856 0.1065 0.1486 0.0701 0.0833 0.1066 0.1437
0.0727 0.0925 0.1177 0.1689 0.0772 0.0908 0.1144 0.1654
0.0678 0.0806 0.1059 0.1400 0.0755 0.0822 0.1011 0.1389
0.0704 0.0828 0.1042 0.1437 0.0751 0.0802 0.1105 0.1375
0.0730 0.0878 0.1107 0.1468 0.0790 0.0873 0.1162 0.1461
0.0714 0.0854 0.1071 0.1453 0.0757 0.0847 0.1101 0.1352
0.0739 0.0820 0.1113 0.1627 0.0765 0.0824 0.1082 0.1508
0.0761 0.0914 0.1063 0.1411 0.0808 0.0921 0.1209 0.1418
0.0722 0.0802 0.1039 0.1423 0.0769 0.0841 0.1074 0.1367
0.0675 0.0745 0.0975 0.1223 0.0701 0.0820 0.1058 0.1281
0.0672 0.0794 0.1019 0.1279 0.0779 0.0818 0.1078 0.1232
0.0745 0.0849 0.1134 0.1527 0.0795 0.0846 0.1141 0.1522
0.0000 0.0701 0 0.0065 0 0.7036 0.3447 0.0223 0.0857 0.5563 0.0319 0.5872 0.8417 0.5690 0.2119
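The summary statistics above can be cross-checked by hand. Assuming the conventional definitions AIC = -2 logL + 2 dof and BIC = -2 logL + log(n) dof, the reported logL = -4.5644e+03 with dof = 120 (15 covariates times 2(d-1) = 8 coefficients) and n = 500 reproduces the displayed AIC = 9.3689e+03 and BIC = 9.8746e+03. The Wald p-values have one entry per covariate, which is consistent with a joint test of that covariate's 8 coefficients against a chi-square distribution with 8 degrees of freedom; the helper `wald_p` below is a hypothetical sketch of such a test, not the toolbox's code.

```matlab
% Values copied from the stats_gdm display above
logL = -4.5644e+03;
dof  = 120;                    % p * 2*(d-1) = 15 * 8
nobs = 500;

aic = -2*logL + 2*dof;         % conventional AIC
bic = -2*logL + log(nobs)*dof; % conventional BIC
fprintf('AIC = %.1f, BIC = %.1f\n', aic, bic);

% Hypothetical Wald test of H0: b = 0 for a coefficient vector b with
% covariance V: the statistic b'*inv(V)*b is compared with a chi-square
% distribution with numel(b) degrees of freedom (upper tail via gammainc).
wald_p = @(b, V) gammainc((b(:)' / V * b(:)) / 2, numel(b) / 2, 'upper');

% toy example: two coefficients, each 2 standard errors from zero
ptoy = wald_p([0.2 -0.2], diag([0.1 0.1].^2));
```

The toy statistic is 4 + 4 = 8 on 2 degrees of freedom, so `ptoy` equals exp(-4), about 0.018.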
Fit generalized Dirichlet-Multinomial sparse regression: lasso, group, and nuclear norm penalties
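The three penalties shrink the p-by-2(d-1) coefficient matrix in different ways. One common way to see the difference, assuming a proximal-gradient-type solver (which this demo does not confirm for `mglm_sparsereg`), is through each penalty's proximal map: 'sweep' (lasso) soft-thresholds every entry, 'group' shrinks each covariate's row toward zero as a block, and 'nuclear' soft-thresholds the singular values. The toy matrix `M` and threshold `t` below are made up.

```matlab
% Toy coefficient matrix (rows = covariates) and threshold; made-up values
M = [3 -1; 0.5 0.2; 2 2];
t = 1;

% lasso ('sweep'): entrywise soft-thresholding
M_lasso = sign(M) .* max(abs(M) - t, 0);

% group lasso ('group'): shrink each row's Euclidean norm by t, zeroing small rows
rnorm   = sqrt(sum(M.^2, 2));
scale   = max(rnorm - t, 0) ./ max(rnorm, eps);
M_group = bsxfun(@times, M, scale);

% nuclear norm ('nuclear'): soft-threshold the singular values
[U, S, V] = svd(M, 'econ');
M_nuc = U * diag(max(diag(S) - t, 0)) * V';
```

Lasso zeroes individual entries, the group penalty zeroes whole covariates (row 2 here), and the nuclear penalty lowers the rank, which matches the distinct sparsity patterns one would expect in the three estimate images.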
% 'sweep' = elementwise lasso, 'group' = group lasso, 'nuclear' = nuclear norm
penalty = {'sweep','group','nuclear'};
ngridpt = 10;           % number of lambda values on the grid
dist = 'gendirmn';
for i = 1:length(penalty)
    pen = penalty{i};
    % call with lambda = inf only to read the maximal tuning parameter
    [~, stats] = mglm_sparsereg(X,Y,inf,'penalty',pen,'dist',dist);
    maxlambda = stats.maxlambda;
    % log-spaced grid from maxlambda down to maxlambda/100
    lambdas = exp(linspace(log(maxlambda),log(maxlambda/100),ngridpt));
    BICs = zeros(1,ngridpt);
    tic;
    for j=1:ngridpt
        % warm start each fit from the previous solution on the path
        if j==1
            B0 = zeros(p,2*(d-1));
        else
            B0 = B_hat;
        end
        [B_hat, stats] = mglm_sparsereg(X,Y,lambdas(j),'penalty',pen, ...
            'dist',dist,'B0',B0);
        BICs(j) = stats.BIC;
    end
    toc;
    % True signal versus estimated signal: refit at the BIC-optimal lambda
    [~,bestidx] = min(BICs);
    lambdas(bestidx)
    B_best = mglm_sparsereg(X,Y,lambdas(bestidx),'penalty',pen,'dist',dist);
    figure;
    subplot(1,3,1);
    semilogx(lambdas,BICs);
    ylabel('BIC');
    xlabel('\lambda');
    xlim([min(lambdas) max(lambdas)]);
    subplot(1,3,2);
    imshow(mat2gray(-[A,B])); title('True B');
    subplot(1,3,3);
    imshow(mat2gray(-B_best)); title([pen ' estimate']);
end
Elapsed time is 0.455057 seconds.
ans =
11.4994
Elapsed time is 0.357541 seconds.
ans =
29.0023
Elapsed time is 0.439528 seconds.
ans =
77.4753
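The tuning grid in the loop is geometric: exp(linspace(log(maxlambda), log(maxlambda/100), ngridpt)) runs from maxlambda down to maxlambda/100, so consecutive values differ by the constant factor 100^(-1/(ngridpt-1)), about 0.599 for ngridpt = 10. The sketch below checks this with a made-up maxlambda and picks the BIC minimizer from a made-up BIC path, mirroring the selection step; the `_toy` names are hypothetical and avoid overwriting the demo's variables.

```matlab
% Made-up inputs; in the demo maxlambda comes from stats.maxlambda
maxlambda_toy = 50;
ngrid_toy = 10;

% geometric grid from maxlambda_toy down to maxlambda_toy/100
lambdas_toy = exp(linspace(log(maxlambda_toy), log(maxlambda_toy/100), ngrid_toy));

% successive ratios are constant: 100^(-1/(ngrid_toy-1))
ratios = lambdas_toy(2:end) ./ lambdas_toy(1:end-1);

% made-up BIC path; choose the lambda with the smallest BIC
BICs_toy = [980 960 940 925 918 921 930 945 960 975];
[~, bestidx_toy] = min(BICs_toy);
lambda_best_toy = lambdas_toy(bestidx_toy);
```

Because the spacing is uniform on the log scale, the `semilogx` BIC plot shows evenly spaced points, and the selected lambda is always one of the grid values.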