Generalized Dirichlet-Multinomial regression and sparse regression
A demo of gen. Dirichlet-Multinomial regression and sparse regression
Contents
Generate generalized Dirichlet-Multinomial random vectors from covariates
clear; % reset random seed s = RandStream('mt19937ar','Seed',1); RandStream.setGlobalStream(s); % sample size n = 500; % # covariates p = 15; % # bins d = 5; % design matrix X = randn(n,p); % true regression coefficients A = zeros(p,d-1); B = zeros(p,d-1); nzidx = [1 3 5]; A(nzidx,:) = 0.5.*ones(length(nzidx),d-1); B(nzidx,:) = 0.5.*ones(length(nzidx),d-1); alpha = exp(X*A); beta = exp(X*B); batchsize = 25+unidrnd(25,n,1); Y = gendirmnrnd(batchsize,alpha, beta);
Fit generalized Dirichlet-Multinomial regression
tic; [Bhat1,Bhat2,stats_gdm] = gendirmnreg(X,Y); toc; display(Bhat1); display(Bhat2); display(stats_gdm); display(stats_gdm.se); display(stats_gdm.wald_pvalue);
Elapsed time is 0.279083 seconds. Bhat1 = 0.4044 0.4217 0.4496 0.3368 -0.0299 -0.0123 0.0575 -0.0418 0.3943 0.5709 0.5170 0.5504 -0.1911 0.0241 0.0455 0.0265 0.5738 0.6706 0.6396 0.4776 -0.0031 -0.0937 0.0118 -0.0098 -0.0145 0.0940 -0.0978 0.1049 0.0084 -0.1723 0.0355 -0.0211 0.0555 -0.1403 -0.0538 -0.0038 -0.0237 -0.0899 0.0175 0.0732 0.0632 0.1570 0.0418 0.1159 -0.0457 -0.0442 0.0503 0.1275 -0.0062 -0.0604 0.0016 0.0843 0.0931 0.0649 0.0346 0.1222 -0.0544 -0.1110 0.1017 0.0154 Bhat2 = 0.3881 0.3333 0.4765 0.3418 0.0041 -0.0254 0.1981 -0.1771 0.4263 0.5280 0.4249 0.5621 -0.2371 0.0633 0.1171 0.1089 0.5609 0.6077 0.6085 0.5299 -0.0288 -0.0711 -0.0084 -0.0862 0.0126 0.0797 -0.0800 0.1977 -0.0098 -0.2746 0.0930 0.0466 0.0854 -0.1082 -0.1854 -0.0037 -0.0059 -0.0076 0.0470 0.0271 0.1780 0.0620 -0.0126 0.1101 -0.0260 -0.0153 -0.0137 0.0446 -0.0083 -0.0333 0.0510 0.0935 0.0969 0.0572 0.0486 0.0835 -0.0758 -0.0748 0.2051 0.0225 stats_gdm = struct with fields: gradient: [15×8 double] se: [15×8 double] wald_stat: [1×15 double] wald_pvalue: [1×15 double] H: [120×120 double] observed_information: [120×120 double] logL: -4.5644e+03 BIC: 9.8746e+03 AIC: 9.3689e+03 dof: 120 iterations: 19 Columns 1 through 7 0.0715 0.0839 0.1085 0.1604 0.0800 0.0908 0.1027 0.0743 0.0837 0.1015 0.1409 0.0787 0.0827 0.1074 0.0667 0.0838 0.1043 0.1407 0.0709 0.0867 0.1074 0.0676 0.0856 0.1065 0.1486 0.0701 0.0833 0.1066 0.0727 0.0925 0.1177 0.1689 0.0772 0.0908 0.1144 0.0678 0.0806 0.1059 0.1400 0.0755 0.0822 0.1011 0.0704 0.0828 0.1042 0.1437 0.0751 0.0802 0.1105 0.0730 0.0878 0.1107 0.1468 0.0790 0.0873 0.1162 0.0714 0.0854 0.1071 0.1453 0.0757 0.0847 0.1101 0.0739 0.0820 0.1113 0.1627 0.0765 0.0824 0.1082 0.0761 0.0914 0.1063 0.1411 0.0808 0.0921 0.1209 0.0722 0.0802 0.1039 0.1423 0.0769 0.0841 0.1074 0.0675 0.0745 0.0975 0.1223 0.0701 0.0820 0.1058 0.0672 0.0794 0.1019 0.1279 0.0779 0.0818 0.1078 0.0745 0.0849 0.1134 0.1527 0.0795 0.0846 0.1141 Column 8 0.1490 0.1400 0.1397 0.1437 0.1654 0.1389 0.1375 0.1461 0.1352 0.1508 0.1418 0.1367 0.1281 0.1232 0.1522 Columns 1 through 7 0.0000 0.0701 0 0.0065 0 0.7036 0.3447 Columns 8 through 14 0.0223 0.0857 0.5563 0.0319 0.5872 0.8417 0.5690 Column 15 0.2119
Fit generalized Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty
penalty = {'sweep','group','nuclear'}; ngridpt = 10; dist = 'gendirmn'; for i = 1:length(penalty) pen = penalty{i}; [~, stats] = mglm_sparsereg(X,Y,inf,'penalty',pen,'dist',dist); maxlambda = stats.maxlambda; lambdas = exp(linspace(log(maxlambda),log(maxlambda/100),ngridpt)); BICs = zeros(1,ngridpt); tic; for j=1:ngridpt if j==1 B0 = zeros(p,2*(d-1)); else B0 = B_hat; end [B_hat, stats] = mglm_sparsereg(X,Y,lambdas(j),'penalty',pen, ... 'dist',dist,'B0',B0); BICs(j) = stats.BIC; end toc; % True signal versus estimated signal [bestbic,bestidx] = min(BICs); lambdas(bestidx) B_best = mglm_sparsereg(X,Y,lambdas(bestidx),'penalty',pen,'dist',dist); figure; subplot(1,3,1); semilogx(lambdas,BICs); ylabel('BIC'); xlabel('\lambda'); xlim([min(lambdas) max(lambdas)]); subplot(1,3,2); imshow(mat2gray(-[A,B])); title('True B'); subplot(1,3,3); imshow(mat2gray(-B_best)); title([pen ' estimate']); end
Elapsed time is 0.455057 seconds. ans = 11.4994 Elapsed time is 0.357541 seconds. ans = 29.0023 Elapsed time is 0.439528 seconds. ans = 77.4753