Quantile Regression — Part 2


An Overview of Tensorflow, Pytorch, LightGBM implementations

Ceshine Lee (Data Geek. Maker. Researcher. Twitter: @ceshine_en)
Published in Veritable

Quantile Regression — Part 1: What is it and How does it work? (medium.com)

We've discussed what quantile regression is and how it works in Part 1. In this Part 2, we're going to explore how to train quantile regression models with deep learning frameworks and gradient boosting trees.

Source Code

The source code for this post is provided in this repository: ceshine/quantile-regression-tensorflow (quantile-regression-tensorflow: Implementations of Quantile Regression, github.com)

It is a fork of strongio/quantile-regression-tensorflow, with the following modifications:

- Use the dataset from the scikit-learn example.
- The TensorFlow implementation is mostly the same as in strongio/quantile-regression-tensorflow.
- Add an example of a LightGBM model using the "quantile" objective (and a scikit-learn GBM example for comparison), based on this Github issue.
- Add a PyTorch implementation.
- Provide a Dockerfile to reproduce the environment and results.
- In addition to the .ipynb notebook files, provide a .py copy for command-line reading.

Tensorflow

The most important piece of the puzzle is the loss function, as introduced in Part 1:

[Figure: Loss Function of Quantile Regression (Source)]

The tricky part is how to deal with the indicator function. Using an if-else statement on each example would be very inefficient. The smarter way is to calculate both y*τ and y*(τ-1), where y here is the error (target minus prediction), and take the element-wise maximum. Since τ is in the (0, 1) range, this pair always has one positive and one negative number (except when y = 0), so the maximum picks y*τ for positive errors and y*(τ-1) for negative ones, which is exactly what the indicator function does.

The following implementation is copied directly from strongio/quantile-regression-tensorflow by @jacobzweig:

    error = tf.subtract(self.y, output)
    loss = tf.reduce_mean(tf.maximum(q * error, (q - 1) * error), axis=-1)

If you use this implementation, you'll have to calculate the losses for each desired quantile τ separately. But since we usually only want to predict 2 to 3 quantiles, I think the need to optimize this is insubstantial.

The rest is just regular Tensorflow neural network building. You can use whatever structure you want. We give an example here fitting the example dataset from scikit-learn.

The final layers can be made even simpler by replacing the separate output layer for each quantile with one big linear layer at the top. The underlying calculations are exactly the same; the big layer just packs them into one weight matrix (and a bias vector). The PyTorch implementation we'll see later takes this approach.

The fitted model looks like this:

[Figure: Tensorflow Implementation]

PyTorch

The loss function is implemented as a class:

    import torch
    from torch import nn


    class QuantileLoss(nn.Module):
        def __init__(self, quantiles):
            super().__init__()
            self.quantiles = quantiles

        def forward(self, preds, target):
            # preds: shape (N, Q), one column per quantile
            # target: shape (N,); a (N, 1) target would broadcast to (N, N) below
            assert not target.requires_grad
            assert preds.size(0) == target.size(0)
            losses = []
            for i, q in enumerate(self.quantiles):
                errors = target - preds[:, i]
                # element-wise max: q * errors if errors >= 0, else (q - 1) * errors
                losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(1))
            # sum over quantiles, then average over the batch
            loss = torch.mean(torch.sum(torch.cat(losses, dim=1), dim=1))
            return loss

2018-07-18 Edit: Fixed a bug inside the forward() method. It should have used self.quantiles instead of quantiles. Lesson learned: avoid reusing variable names in Jupyter notebooks, which are usually full of global variables.

It expects the predictions to come in one tensor of shape (N, Q), one column per quantile. The final torch.sum and torch.mean reductions follow the Tensorflow implementation. You can also choose to use different weights for different quantiles, but I'm not very sure how that would affect the result.

[Figure: PyTorch Implementation]

Batch Normalization

One interesting thing I noticed is that adding batch normalization makes the PyTorch model severely under-fit, while the Tensorflow model seems to fare better. Maybe the very small batch size (10) and the small training dataset (100 examples) are causing the problems. The final version has batch normalization removed from both models. Training the models on a larger real-world dataset, which has yet to be done, should help me figure it out.

Monte Carlo dropout

I've covered Monte Carlo dropout previously in this post: [Learning Note] Dropout in Recurrent Networks — Part 1: Theoretical Foundations (becominghuman.ai)

In short, it performs dropout at test/prediction time to approximate sampling from the posterior distribution. Then we can find the credible interval for a quantile. We show the credible interval of the median below:

[Figure: Tensorflow MC Dropout for the Median]

[Figure: PyTorch MC Dropout for the Median]

As we can see, the credible interval is much narrower than the prediction interval (check Part 1 if you're not sure what they mean). These two can be quite confusing, as discussed in the following thread (check Ian Osband's comment):

In Osband's comment, "the posterior distribution for an outcome" is used to construct the prediction interval, and the "posterior distribution for what you think is the mean of an outcome" is used to construct the credible interval. He also gave a clever example to demonstrate their differences:

"A fair coin, that you know is a fair coin. You can be 100% sure that the expected outcome is 0.5. This is the posterior distribution for the mean — a dirac at 0.5. On the other hand, for any single flip the distribution of outcomes is 50% at 0 and 50% at 1. The two distributions are completely distinct."

I think you'll have to find some way to sample from the posterior distribution of the error term to create prediction intervals with MC dropout. I'm not sure how to do it yet. Maybe we can estimate the distribution by collecting the errors when sampling the target (I need to do more research here).

LightGBM

For tree models, it's not possible to predict more than one value per model. Therefore, what we do here is essentially train Q independent models, each of which predicts one quantile.

Scikit-learn is the baseline here. All you need to do is pass loss='quantile' and alpha=ALPHA, where ALPHA (in the (0, 1) range) is the quantile we want to predict:

[Figure: Scikit-Learn GradientBoostingRegressor]

LightGBM uses the same alpha parameter for quantile regression (check the full list of parameters here). When using the scikit-learn API, the call would be something similar to:

    import lightgbm as lgb

    # Upper-case names are hyperparameter constants defined elsewhere.
    # alpha=1 - ALPHA targets the complementary quantile of ALPHA.
    clfl = lgb.LGBMRegressor(
        objective='quantile',
        alpha=1 - ALPHA,
        num_leaves=NUM_LEAVES,
        learning_rate=LEARNING_RATE,
        n_estimators=N_ESTIMATORS,
        min_data_in_leaf=5,
        reg_sqrt=REG_SQRT,
        max_depth=MAX_DEPTH)

[Figure: LightGBM LGBMRegressor]

One special parameter to tune for LightGBM is min_data_in_leaf. It defaults to 20, which is too large for this dataset (100 examples) and will cause under-fitting. Tune it down to get narrower prediction intervals.

There's not much else to say here. The gradients of the loss function are handled inside the library; we only need to pass the correct parameters. However, unlike with neural networks, we cannot easily get a confidence interval or a credible interval from tree models.

The End

Thank you for reading this far! Please consider giving this post some claps to show your support. It'll be much appreciated.

(This post is also published on my personal blog.)
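The element-wise maximum trick from the Tensorflow section and the sum-then-mean reduction in the PyTorch QuantileLoss class can both be sanity-checked without either framework. Below is a minimal plain-Python sketch. pinball_indicator, pinball_max and quantile_loss are hypothetical helper names introduced here for illustration. They are not functions from the repository.

```python
def pinball_indicator(e, tau):
    """Quantile (pinball) loss written with the indicator: e * (tau - 1[e < 0])."""
    return e * (tau - (1.0 if e < 0 else 0.0))


def pinball_max(e, tau):
    """Same loss via the element-wise maximum: max(tau * e, (tau - 1) * e)."""
    return max(tau * e, (tau - 1) * e)


def quantile_loss(preds, targets, quantiles):
    """Plain-Python mirror of the QuantileLoss reduction.

    preds: N rows, each holding Q predictions (column i is for quantiles[i]).
    targets: N target values.
    Per example, sum the losses over quantiles; then average over examples.
    """
    assert len(preds) == len(targets)
    per_example = [
        sum(pinball_max(y - row[i], q) for i, q in enumerate(quantiles))
        for row, y in zip(preds, targets)
    ]
    return sum(per_example) / len(per_example)


if __name__ == "__main__":
    # Both formulations agree for negative, zero and positive errors.
    for tau in (0.1, 0.5, 0.9):
        for e in (-2.0, -0.5, 0.0, 0.5, 2.0):
            assert abs(pinball_indicator(e, tau) - pinball_max(e, tau)) < 1e-12

    preds = [[0.0, 1.0], [2.0, 3.0]]  # N = 2 examples, Q = 2 quantiles
    targets = [0.5, 4.0]
    print(quantile_loss(preds, targets, quantiles=[0.1, 0.9]))  # ≈ 0.6
```

The two formulations agree for every sign of the error, so the max form can replace the indicator directly. The framework versions apply the same expression to whole tensors at once instead of looping over scalars.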

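The LightGBM section relies on alpha=ALPHA making a gradient boosting model predict the ALPHA quantile. One way to see why uses a simplifying assumption: take a "model" with no features that predicts a single constant. The constant that minimizes the average quantile loss over a sample is then that sample's ALPHA quantile. mean_loss and best_constant are hypothetical helpers. This sketch does not reproduce how scikit-learn or LightGBM actually grow their trees.

```python
def pinball_max(e, tau):
    """Quantile loss of a single error e = target - prediction."""
    return max(tau * e, (tau - 1) * e)


def mean_loss(c, ys, tau):
    """Average quantile loss when every target in ys is predicted by constant c."""
    return sum(pinball_max(y - c, tau) for y in ys) / len(ys)


def best_constant(ys, tau):
    """Constant prediction with the lowest average quantile loss.

    For tau in (0, 1) the average loss is convex and piecewise linear in c,
    with kinks only at the data points, so searching the observed values
    is enough to find a minimiser.
    """
    return min(ys, key=lambda c: mean_loss(c, ys, tau))


if __name__ == "__main__":
    ys = list(range(1, 12))  # hypothetical sample: 1, 2, ..., 11
    for tau in (0.1, 0.5, 0.9):
        print(tau, best_constant(ys, tau))
```

With ys = 1, ..., 11, this prints 2, 6 and 10 for τ = 0.1, 0.5 and 0.9. These are the nearest-rank 10th, 50th and 90th percentiles of the sample. So changing alpha moves the optimal prediction along the distribution instead of fixing it at the mean.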

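The Monte Carlo dropout procedure from the MC dropout section can be sketched without a deep learning framework: keep dropout on at prediction time, sample repeatedly, then read off percentiles. In the sketch below, the toy linear model is a hypothetical stand-in for a trained network's median output. The inputs, weights, dropout rate, and the stochastic_predict / percentile / mc_dropout_interval helpers are all assumptions made for illustration. The resulting interval is a credible interval for the model's prediction, not a prediction interval for the outcome. That is the distinction Osband's coin example draws.

```python
import random


def stochastic_predict(x, weights, p_drop, rng):
    """One forward pass of a toy linear model with dropout left on.

    Each input is dropped (zeroed) with probability p_drop; surviving
    inputs are scaled by 1 / (1 - p_drop) (inverted dropout).
    """
    keep = 1.0 - p_drop
    return sum(
        w * xi / keep
        for w, xi in zip(weights, x)
        if rng.random() >= p_drop  # this input survives the dropout mask
    )


def percentile(sorted_values, pct):
    """Simple index-based percentile of an already sorted list."""
    idx = round(pct / 100 * (len(sorted_values) - 1))
    return sorted_values[idx]


def mc_dropout_interval(x, weights, p_drop=0.1, n_samples=1000, level=95, seed=0):
    """Credible interval for the prediction from n_samples stochastic passes."""
    rng = random.Random(seed)
    samples = sorted(
        stochastic_predict(x, weights, p_drop, rng) for _ in range(n_samples)
    )
    tail = (100 - level) / 2
    return percentile(samples, tail), percentile(samples, 100 - tail)


if __name__ == "__main__":
    x = [1.0, 2.0, 0.5, -1.0]        # hypothetical input features
    weights = [0.5, -0.2, 0.8, 0.3]  # hypothetical "trained" weights
    lower, upper = mc_dropout_interval(x, weights)
    print(f"95% credible interval for the prediction: [{lower:.3f}, {upper:.3f}]")
```

In PyTorch, nn.Dropout is only active in training mode. An equivalent loop would therefore put the dropout modules into train mode before sampling, rather than calling model.eval() as usual.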
