Dirichlet-Multinomial regression and sparse regression
A demo of Dirichlet-Multinomial regression and sparse regression
Contents

    Generate Dirichlet-Multinomial random vectors from covariates
    Fit Dirichlet-Multinomial regression
    Fit Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty
Generate Dirichlet-Multinomial random vectors from covariates
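The code below simulates count vectors from a Dirichlet-Multinomial model whose parameter vector is tied to the covariates through a log link. As a sketch of the generative model (with x_i the i-th row of X, B the p-by-d coefficient matrix, and m_i the batch size),

$$\alpha_i = \exp(x_i^T B), \qquad y_i \mid m_i \sim \mathrm{DirMN}(m_i, \alpha_i),$$

where the Dirichlet-Multinomial probability mass function is

$$P(y_i \mid m_i, \alpha_i) = \binom{m_i}{y_{i1},\ldots,y_{id}} \, \frac{\Gamma\!\big(\sum_j \alpha_{ij}\big)}{\Gamma\!\big(m_i + \sum_j \alpha_{ij}\big)} \, \prod_{j=1}^{d} \frac{\Gamma(y_{ij}+\alpha_{ij})}{\Gamma(\alpha_{ij})}.$$

Only rows 1, 3, and 5 of the true coefficient matrix B are nonzero, so only three of the fifteen covariates actually affect the counts.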
clear;
% reset random seed
s = RandStream('mt19937ar','Seed',1);
RandStream.setGlobalStream(s);
% sample size
n = 200;
% # covariates
p = 15;
% # bins
d = 5;
% design matrix
X = randn(n,p);
% true regression coefficients: only rows 1, 3, 5 are nonzero
B = zeros(p,d);
nzidx = [1 3 5];
B(nzidx,:) = ones(length(nzidx),d);
% Dirichlet-Multinomial parameters via the log link
alpha = exp(X*B);
% random batch sizes between 26 and 50
batchsize = 25+unidrnd(25,n,1);
% draw Dirichlet-Multinomial count vectors
Y = dirmnrnd(batchsize,alpha);
% drop observations whose counts are all zero
zerorows = sum(Y,2);
Y = Y(zerorows~=0,:);
X = X(zerorows~=0,:);
Fit Dirichlet-Multinomial regression
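dirmnreg fits the regression by maximum likelihood. Up to an additive term that does not depend on B (the multinomial coefficient), the log-likelihood being maximized is, in the notation above,

$$\ell(B) = \sum_{i=1}^{n} \Big[ \ln\Gamma\Big(\sum_j \alpha_{ij}\Big) - \ln\Gamma\Big(m_i + \sum_j \alpha_{ij}\Big) + \sum_{j=1}^{d} \big( \ln\Gamma(y_{ij}+\alpha_{ij}) - \ln\Gamma(\alpha_{ij}) \big) \Big], \qquad \alpha_{ij} = \exp(x_i^T b_j),$$

with b_j the j-th column of B. The returned stats structure reports this log-likelihood and its iteration trace, together with the gradient and observed information at the estimate, from which standard errors and Wald tests can be derived.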
tic;
[B_hat, stats_dm] = dirmnreg(X,Y);
toc;
% estimated regression coefficients and their standard errors
display(B_hat);
display(stats_dm.se);
display(stats_dm);
% Wald test of predictor significance
display('Wald test p-values:');
display(stats_dm.wald_pvalue);
% trace of the log-likelihood across iterations
figure;
plot(stats_dm.logL_iter);
xlabel('iteration');
ylabel('log-likelihood');
Elapsed time is 0.074675 seconds.

B_hat =

    1.0433    1.0273    1.0261    1.0718    0.9714
    0.0732   -0.0784    0.1226   -0.0149    0.0134
    1.2702    1.2427    1.0411    1.1994    1.2891
   -0.0160    0.0594    0.0091    0.0163    0.0922
    1.0395    0.9820    1.1324    0.9810    1.1368
   -0.0436    0.0190   -0.0366    0.0211    0.0049
    0.0587    0.1417    0.1773    0.2716    0.1444
   -0.0450   -0.0731   -0.0828   -0.0449   -0.0730
   -0.0067    0.0120    0.0913    0.1893   -0.0101
    0.0947    0.2141    0.1521    0.0770    0.0679
    0.0713    0.1136    0.0868    0.0604    0.1609
   -0.0220   -0.1262   -0.1545   -0.1011   -0.1984
   -0.1732   -0.1461   -0.0800   -0.0821    0.0261
   -0.1059   -0.0128   -0.0464   -0.0136   -0.1410
   -0.0926   -0.1360   -0.0706   -0.0864   -0.0902

    0.1057    0.1004    0.1013    0.1042    0.0994
    0.1047    0.1008    0.1023    0.1021    0.0994
    0.1034    0.0994    0.0992    0.1002    0.0992
    0.0928    0.0950    0.0940    0.0943    0.0921
    0.1021    0.0987    0.0994    0.1038    0.1003
    0.0847    0.0825    0.0866    0.0870    0.0807
    0.0861    0.0859    0.0886    0.0881    0.0886
    0.0858    0.0860    0.0887    0.0869    0.0852
    0.0912    0.0852    0.0866    0.0881    0.0841
    0.0910    0.0887    0.0932    0.0945    0.0864
    0.0996    0.0963    0.0963    0.1007    0.0950
    0.0910    0.0914    0.0967    0.0928    0.0945
    0.0915    0.0873    0.0945    0.0969    0.0911
    0.0895    0.0834    0.0911    0.0880    0.0835
    0.0866    0.0865    0.0873    0.0879    0.0858

stats_dm = 

  struct with fields:

                     BIC: 4.2088e+03
                     AIC: 3.9614e+03
                     dof: 75
              iterations: 6
                    logL: -1.9057e+03
               logL_iter: [1×6 double]
                    yhat: [200×5 double]
                      se: [15×5 double]
               wald_stat: [1×15 double]
             wald_pvalue: [1×15 double]
                       H: [75×75 double]
                gradient: [75×1 double]
    observed_information: [75×75 double]

Wald test p-values:

  Columns 1 through 7

         0    0.2639         0    0.6872         0    0.9274    0.0312

  Columns 8 through 14

    0.9412    0.0450    0.2049    0.6178    0.1170    0.1066    0.3226

  Column 15

    0.7632
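The Wald test checks, for each predictor k, the null hypothesis that the entire k-th row of B is zero (one p-value per predictor, as the 1×15 output above suggests). A quick usage sketch for reading off the flagged predictors (the 0.05 cutoff is an arbitrary choice, not part of the toolbox output):

% predictors whose Wald test rejects H0: B(k,:) = 0 at the 5% level
sig_idx = find(stats_dm.wald_pvalue < 0.05);
display(sig_idx);
% compare with the truly nonzero rows used in the simulation
display(nzidx);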
Fit Dirichlet-Multinomial sparse regression - lasso/group/nuclear penalty
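All three fits minimize a penalized negative log-likelihood of the form

$$\min_B \; -\ell(B) + \lambda\, J(B),$$

where lambda is the tuning parameter and J the penalty. As a rough sketch of the three penalty types named in the section title (see the toolbox documentation for the exact form each penalty option uses): the lasso-type penalty sums |b_kj| over all entries, the group penalty sums the Euclidean norms of the rows of B so that a predictor's coefficients enter or leave the model together, and the nuclear-norm penalty sums the singular values of B, encouraging a low-rank estimate. The code below sweeps each penalty over a grid of 20 values of lambda, log-spaced from the toolbox-reported maxlambda down to maxlambda/100, and keeps the value minimizing BIC.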
penalty = {'sweep','group','nuclear'};
ngridpt = 20;
dist = 'dirmn';
for i = 1:length(penalty)
    pen = penalty{i};
    % a lambda of inf is used only to obtain stats.maxlambda, the top of the penalty grid
    [~, stats] = mglm_sparsereg(X,Y,inf,'penalty',pen,'dist',dist);
    maxlambda = stats.maxlambda;
    % log-spaced grid of penalty values from maxlambda down to maxlambda/100
    lambdas = exp(linspace(log(maxlambda),log(maxlambda/100),ngridpt));
    BICs = zeros(1,ngridpt);
    tic;
    for j = 1:ngridpt
        if j==1
            B0 = zeros(p,d);
        else
            B0 = B_hat;   % warm start from the previous solution
        end
        [B_hat, stats] = mglm_sparsereg(X,Y,lambdas(j),'penalty',pen, ...
            'dist',dist,'B0',B0);
        BICs(j) = stats.BIC;
    end
    toc;
    % True signal versus estimated signal
    [bestbic,bestidx] = min(BICs);
    lambdas(bestidx)   % display the BIC-selected lambda
    B_best = mglm_sparsereg(X,Y,lambdas(bestidx),'penalty',pen,'dist',dist);
    figure;
    subplot(1,3,1);
    semilogx(lambdas,BICs);
    ylabel('BIC');
    xlabel('\lambda');
    xlim([min(lambdas) max(lambdas)]);
    subplot(1,3,2);
    imshow(mat2gray(-B));
    title('True B');
    subplot(1,3,3);
    imshow(mat2gray(-B_best));
    title([pen ' estimate']);
end
Elapsed time is 0.408279 seconds.

ans =

   10.8980

Elapsed time is 0.458002 seconds.

ans =

   30.2672

Elapsed time is 0.670406 seconds.

ans =

  194.8746
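The three printed values are the BIC-selected lambdas for the sweep, group, and nuclear penalties, in that order. After the loop, B_best holds the BIC-selected estimate from the last pass, i.e. the nuclear-penalty fit. A quick check of support recovery (a usage sketch; the 1e-8 threshold for calling an entry nonzero is arbitrary):

% rows of B_best with at least one numerically nonzero entry
recovered = find(any(abs(B_best) > 1e-8, 2))';
display(recovered);
% true nonzero rows used in the simulation
display(nzidx);

Rerun the loop body with the penalty of interest to inspect a different estimate the same way.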