# Supervised Contrastive Learning


**Author:** Khalid Salama
**Date created:** 2020/11/30
**Last modified:** 2020/11/30
**Description:** Using supervised contrastive learning for image classification.

## Introduction

Supervised Contrastive Learning (Prannay Khosla et al.) is a training methodology that outperforms supervised training with crossentropy on classification tasks.

Essentially, training an image classification model with Supervised Contrastive Learning is performed in two phases:

1. Training an encoder to learn to produce vector representations of input images such that representations of images in the same class will be more similar compared to representations of images in different classes.
2. Training a classifier on top of the frozen encoder.

Note that this example requires TensorFlow Addons, which you can install using the following command:

```
pip install tensorflow-addons
```

## Setup

```python
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
```

## Prepare the data

```python
num_classes = 10
input_shape = (32, 32, 3)

# Load the train and test data splits
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Display shapes of train and test datasets
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
```

```
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
```

## Using image data augmentation

```python
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.02),
        layers.RandomWidth(0.2),
        layers.RandomHeight(0.2),
    ]
)

# Setting the state of the normalization layer.
data_augmentation.layers[0].adapt(x_train)
```

## Build the encoder model

The encoder model takes the image as input and turns it into a 2048-dimensional feature vector.

```python
def create_encoder():
    resnet = keras.applications.ResNet50V2(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg"
    )

    inputs = keras.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    outputs = resnet(augmented)
    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-encoder")
    return model


encoder = create_encoder()
encoder.summary()

learning_rate = 0.001
batch_size = 265
hidden_units = 512
projection_units = 128
num_epochs = 50
dropout_rate = 0.5
temperature = 0.05
```

```
Model: "cifar10-encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
sequential (Sequential)      (None, None, None, 3)     7
_________________________________________________________________
resnet50v2 (Functional)      (None, 2048)              23564800
=================================================================
Total params: 23,564,807
Trainable params: 23,519,360
Non-trainable params: 45,447
_________________________________________________________________
```
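As a quick sanity check (a minimal sketch, not part of the original example), you can confirm that the encoder maps a batch of images to 2048-dimensional feature vectors:

```python
# Hypothetical sanity check: run a few training images through the (still
# untrained) encoder and verify the output shape is (batch_size, 2048).
sample_embeddings = encoder(x_train[:8], training=False)
print(sample_embeddings.shape)  # expected: (8, 2048)
```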
## Build the classification model

The classification model adds a fully-connected layer on top of the encoder, plus a softmax layer with the target classes.

```python
def create_classifier(encoder, trainable=True):

    for layer in encoder.layers:
        layer.trainable = trainable

    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dropout(dropout_rate)(features)
    features = layers.Dense(hidden_units, activation="relu")(features)
    features = layers.Dropout(dropout_rate)(features)
    outputs = layers.Dense(num_classes, activation="softmax")(features)

    model = keras.Model(inputs=inputs, outputs=outputs, name="cifar10-classifier")
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
    return model
```

## Experiment 1: Train the baseline classification model

In this experiment, a baseline classifier is trained as usual, i.e., the encoder and the classifier parts are trained together as a single model to minimize the crossentropy loss.

```python
encoder = create_encoder()
classifier = create_classifier(encoder)
classifier.summary()

history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)

accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
```

```
Model: "cifar10-classifier"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
cifar10-encoder (Functional) (None, 2048)              23564807
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0
_________________________________________________________________
dense (Dense)                (None, 512)               1049088
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 24,619,025
Trainable params: 24,573,578
Non-trainable params: 45,447
_________________________________________________________________
```
```
Epoch 1/50
189/189 [==============================] - 15s 77ms/step - loss: 1.9369 - sparse_categorical_accuracy: 0.2874
Epoch 2/50
189/189 [==============================] - 11s 57ms/step - loss: 1.5133 - sparse_categorical_accuracy: 0.4505
Epoch 3/50
189/189 [==============================] - 11s 57ms/step - loss: 1.3468 - sparse_categorical_accuracy: 0.5204
Epoch 4/50
189/189 [==============================] - 11s 60ms/step - loss: 1.2159 - sparse_categorical_accuracy: 0.5733
Epoch 5/50
189/189 [==============================] - 11s 56ms/step - loss: 1.1516 - sparse_categorical_accuracy: 0.6032
Epoch 6/50
189/189 [==============================] - 11s 58ms/step - loss: 1.0769 - sparse_categorical_accuracy: 0.6254
Epoch 7/50
189/189 [==============================] - 11s 58ms/step - loss: 0.9964 - sparse_categorical_accuracy: 0.6547
Epoch 8/50
189/189 [==============================] - 10s 55ms/step - loss: 0.9563 - sparse_categorical_accuracy: 0.6703
Epoch 9/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8952 - sparse_categorical_accuracy: 0.6925
Epoch 10/50
189/189 [==============================] - 11s 56ms/step - loss: 0.8986 - sparse_categorical_accuracy: 0.6922
Epoch 11/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8381 - sparse_categorical_accuracy: 0.7145
Epoch 12/50
189/189 [==============================] - 10s 55ms/step - loss: 0.8513 - sparse_categorical_accuracy: 0.7086
Epoch 13/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7557 - sparse_categorical_accuracy: 0.7448
Epoch 14/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7168 - sparse_categorical_accuracy: 0.7548
Epoch 15/50
189/189 [==============================] - 10s 55ms/step - loss: 0.6772 - sparse_categorical_accuracy: 0.7690
Epoch 16/50
189/189 [==============================] - 11s 56ms/step - loss: 0.7587 - sparse_categorical_accuracy: 0.7416
Epoch 17/50
189/189 [==============================] - 10s 55ms/step - loss: 0.6873 - sparse_categorical_accuracy: 0.7665
Epoch 18/50
189/189 [==============================] - 11s 56ms/step - loss: 0.6418 - sparse_categorical_accuracy: 0.7804
Epoch 19/50
189/189 [==============================] - 11s 56ms/step - loss: 0.6086 - sparse_categorical_accuracy: 0.7927
Epoch 20/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5903 - sparse_categorical_accuracy: 0.7978
Epoch 21/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5636 - sparse_categorical_accuracy: 0.8083
Epoch 22/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5527 - sparse_categorical_accuracy: 0.8123
Epoch 23/50
189/189 [==============================] - 11s 56ms/step - loss: 0.5308 - sparse_categorical_accuracy: 0.8191
Epoch 24/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5282 - sparse_categorical_accuracy: 0.8223
Epoch 25/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5090 - sparse_categorical_accuracy: 0.8263
Epoch 26/50
189/189 [==============================] - 10s 55ms/step - loss: 0.5497 - sparse_categorical_accuracy: 0.8181
Epoch 27/50
189/189 [==============================] - 10s 55ms/step - loss: 0.4950 - sparse_categorical_accuracy: 0.8332
Epoch 28/50
189/189 [==============================] - 11s 56ms/step - loss: 0.4727 - sparse_categorical_accuracy: 0.8391
Epoch 29/50
167/189 [=========================>....] - ETA: 1s - loss: 0.4594 - sparse_categorical_accuracy: 0.8444
```

## Experiment 2: Use supervised contrastive learning

In this experiment, the model is trained in two phases. In the first phase, the encoder is pretrained to optimize the supervised contrastive loss, described in Prannay Khosla et al.

In the second phase, the classifier is trained using the trained encoder with its weights frozen; only the weights of the fully-connected layers with the softmax are optimized.

### 1. Supervised contrastive learning loss function
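For reference, the supervised contrastive loss described in the Khosla et al. paper, for an anchor sample $i$ with L2-normalized projection $z_i$, temperature $\tau$, positives $P(i)$ (the other samples in the batch sharing $i$'s label), and $A(i)$ (all other samples in the batch), is roughly:

$$
\mathcal{L}^{sup} = \sum_{i} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}
$$

The implementation below follows the same idea, computing temperature-scaled pairwise similarities between normalized projections, but for simplicity feeds the similarity matrix to `tfa.losses.npairs_loss` rather than transcribing the paper's formula line by line.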
```python
class SupervisedContrastiveLoss(keras.losses.Loss):
    def __init__(self, temperature=1, name=None):
        super(SupervisedContrastiveLoss, self).__init__(name=name)
        self.temperature = temperature

    def __call__(self, labels, feature_vectors, sample_weight=None):
        # Normalize feature vectors
        feature_vectors_normalized = tf.math.l2_normalize(feature_vectors, axis=1)
        # Compute logits
        logits = tf.divide(
            tf.matmul(
                feature_vectors_normalized, tf.transpose(feature_vectors_normalized)
            ),
            self.temperature,
        )
        return tfa.losses.npairs_loss(tf.squeeze(labels), logits)


def add_projection_head(encoder):
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    model = keras.Model(
        inputs=inputs, outputs=outputs, name="cifar-encoder_with_projection-head"
    )
    return model
```
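To see the calling convention (a minimal sketch, not part of the original example), the loss takes a batch of integer class labels and projection vectors and returns a scalar:

```python
# Hypothetical smoke test: 8 random projection vectors spread over 4 classes.
dummy_labels = tf.constant([[0], [0], [1], [1], [2], [2], [3], [3]])
dummy_features = tf.random.normal((8, projection_units))
loss_fn = SupervisedContrastiveLoss(temperature)
print(float(loss_fn(dummy_labels, dummy_features)))  # a single scalar loss value
```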
### 2. Pretrain the encoder

```python
encoder = create_encoder()

encoder_with_projection_head = add_projection_head(encoder)
encoder_with_projection_head.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=SupervisedContrastiveLoss(temperature),
)

encoder_with_projection_head.summary()

history = encoder_with_projection_head.fit(
    x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs
)
```

```
Model: "cifar-encoder_with_projection-head"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_8 (InputLayer)         [(None, 32, 32, 3)]       0
_________________________________________________________________
cifar10-encoder (Functional) (None, 2048)              23564807
_________________________________________________________________
dense_2 (Dense)              (None, 128)               262272
=================================================================
Total params: 23,827,079
Trainable params: 23,781,632
Non-trainable params: 45,447
_________________________________________________________________
Epoch 1/50
189/189 [==============================] - 11s 56ms/step - loss: 5.3730
Epoch 2/50
189/189 [==============================] - 11s 56ms/step - loss: 5.1583
Epoch 3/50
189/189 [==============================] - 10s 55ms/step - loss: 5.0368
Epoch 4/50
189/189 [==============================] - 11s 56ms/step - loss: 4.9349
Epoch 5/50
189/189 [==============================] - 10s 55ms/step - loss: 4.8262
Epoch 6/50
189/189 [==============================] - 11s 56ms/step - loss: 4.7470
Epoch 7/50
189/189 [==============================] - 11s 56ms/step - loss: 4.6835
Epoch 8/50
189/189 [==============================] - 11s 56ms/step - loss: 4.6120
Epoch 9/50
189/189 [==============================] - 11s 56ms/step - loss: 4.5608
Epoch 10/50
189/189 [==============================] - 10s 55ms/step - loss: 4.5075
Epoch 11/50
189/189 [==============================] - 11s 56ms/step - loss: 4.4674
Epoch 12/50
189/189 [==============================] - 10s 56ms/step - loss: 4.4362
Epoch 13/50
189/189 [==============================] - 11s 56ms/step - loss: 4.3899
Epoch 14/50
189/189 [==============================] - 10s 55ms/step - loss: 4.3664
Epoch 15/50
189/189 [==============================] - 11s 56ms/step - loss: 4.3188
Epoch 16/50
189/189 [==============================] - 10s 56ms/step - loss: 4.3030
Epoch 17/50
189/189 [==============================] - 11s 57ms/step - loss: 4.2725
Epoch 18/50
189/189 [==============================] - 10s 55ms/step - loss: 4.2523
Epoch 19/50
189/189 [==============================] - 11s 56ms/step - loss: 4.2100
Epoch 20/50
189/189 [==============================] - 10s 55ms/step - loss: 4.2033
Epoch 21/50
189/189 [==============================] - 11s 56ms/step - loss: 4.1741
Epoch 22/50
189/189 [==============================] - 11s 56ms/step - loss: 4.1443
Epoch 23/50
189/189 [==============================] - 11s 56ms/step - loss: 4.1350
Epoch 24/50
189/189 [==============================] - 11s 57ms/step - loss: 4.1192
Epoch 25/50
189/189 [==============================] - 11s 56ms/step - loss: 4.1002
Epoch 26/50
189/189 [==============================] - 11s 57ms/step - loss: 4.0797
Epoch 27/50
189/189 [==============================] - 11s 56ms/step - loss: 4.0547
Epoch 28/50
189/189 [==============================] - 11s 56ms/step - loss: 4.0336
Epoch 29/50
189/189 [==============================] - 11s 56ms/step - loss: 4.0299
Epoch 30/50
189/189 [==============================] - 11s 56ms/step - loss: 4.0031
Epoch 31/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9979
Epoch 32/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9777
Epoch 33/50
189/189 [==============================] - 10s 55ms/step - loss: 3.9800
Epoch 34/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9538
Epoch 35/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9298
Epoch 36/50
189/189 [==============================] - 11s 57ms/step - loss: 3.9241
Epoch 37/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9102
Epoch 38/50
189/189 [==============================] - 11s 56ms/step - loss: 3.9075
Epoch 39/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8897
Epoch 40/50
189/189 [==============================] - 11s 57ms/step - loss: 3.8871
Epoch 41/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8596
Epoch 42/50
189/189 [==============================] - 10s 56ms/step - loss: 3.8526
Epoch 43/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8417
Epoch 44/50
189/189 [==============================] - 10s 55ms/step - loss: 3.8239
Epoch 45/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8178
Epoch 46/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8065
Epoch 47/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8185
Epoch 48/50
189/189 [==============================] - 11s 56ms/step - loss: 3.8022
Epoch 49/50
189/189 [==============================] - 11s 56ms/step - loss: 3.7815
Epoch 50/50
189/189 [==============================] - 11s 56ms/step - loss: 3.7601
```
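As an optional sanity check (a sketch not included in the original example), you can probe whether the pretrained encoder clusters embeddings by class, for instance by comparing average cosine similarity within and across classes on a handful of test images:

```python
# Hypothetical check: after contrastive pretraining, same-class embeddings
# should on average be more similar than different-class embeddings.
sample_images, sample_labels = x_test[:256], y_test[:256].squeeze()
embeddings = tf.math.l2_normalize(encoder.predict(sample_images), axis=1)
similarities = tf.matmul(embeddings, embeddings, transpose_b=True).numpy()
same_class = sample_labels[:, None] == sample_labels[None, :]
off_diagonal = ~np.eye(len(sample_labels), dtype=bool)
print("mean same-class similarity: ", similarities[same_class & off_diagonal].mean())
print("mean cross-class similarity:", similarities[~same_class].mean())
```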
### 3. Train the classifier with the frozen encoder

```python
classifier = create_classifier(encoder, trainable=False)

history = classifier.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs)

accuracy = classifier.evaluate(x_test, y_test)[1]
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
```

```
Epoch 1/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3979 - sparse_categorical_accuracy: 0.8869
Epoch 2/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3422 - sparse_categorical_accuracy: 0.8959
Epoch 3/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3251 - sparse_categorical_accuracy: 0.9004
Epoch 4/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3313 - sparse_categorical_accuracy: 0.8963
Epoch 5/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3213 - sparse_categorical_accuracy: 0.9006
Epoch 6/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3221 - sparse_categorical_accuracy: 0.9001
Epoch 7/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3134 - sparse_categorical_accuracy: 0.9001
Epoch 8/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3245 - sparse_categorical_accuracy: 0.8978
Epoch 9/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3144 - sparse_categorical_accuracy: 0.9001
Epoch 10/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3191 - sparse_categorical_accuracy: 0.8984
Epoch 11/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3104 - sparse_categorical_accuracy: 0.9025
Epoch 12/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3261 - sparse_categorical_accuracy: 0.8958
Epoch 13/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3130 - sparse_categorical_accuracy: 0.9001
Epoch 14/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3147 - sparse_categorical_accuracy: 0.9003
Epoch 15/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3113 - sparse_categorical_accuracy: 0.9016
Epoch 16/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3114 - sparse_categorical_accuracy: 0.9008
Epoch 17/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3044 - sparse_categorical_accuracy: 0.9026
Epoch 18/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3142 - sparse_categorical_accuracy: 0.8987
Epoch 19/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3139 - sparse_categorical_accuracy: 0.9018
Epoch 20/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3199 - sparse_categorical_accuracy: 0.8987
Epoch 21/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3125 - sparse_categorical_accuracy: 0.8994
Epoch 22/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3291 - sparse_categorical_accuracy: 0.8967
Epoch 23/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3208 - sparse_categorical_accuracy: 0.8963
Epoch 24/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3065 - sparse_categorical_accuracy: 0.9041
Epoch 25/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3099 - sparse_categorical_accuracy: 0.9006
Epoch 26/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3181 - sparse_categorical_accuracy: 0.8986
Epoch 27/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3112 - sparse_categorical_accuracy: 0.9013
Epoch 28/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3136 - sparse_categorical_accuracy: 0.8996
Epoch 29/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3217 - sparse_categorical_accuracy: 0.8969
Epoch 30/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3161 - sparse_categorical_accuracy: 0.8998
Epoch 31/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3151 - sparse_categorical_accuracy: 0.8999
Epoch 32/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3092 - sparse_categorical_accuracy: 0.9009
Epoch 33/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3246 - sparse_categorical_accuracy: 0.8961
Epoch 34/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3143 - sparse_categorical_accuracy: 0.8995
Epoch 35/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3106 - sparse_categorical_accuracy: 0.9002
Epoch 36/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3210 - sparse_categorical_accuracy: 0.8980
Epoch 37/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3178 - sparse_categorical_accuracy: 0.9009
Epoch 38/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3064 - sparse_categorical_accuracy: 0.9032
Epoch 39/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3196 - sparse_categorical_accuracy: 0.8981
Epoch 40/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3177 - sparse_categorical_accuracy: 0.8988
Epoch 41/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3167 - sparse_categorical_accuracy: 0.8987
Epoch 42/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3110 - sparse_categorical_accuracy: 0.9014
Epoch 43/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3124 - sparse_categorical_accuracy: 0.9002
Epoch 44/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3128 - sparse_categorical_accuracy: 0.8999
Epoch 45/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3131 - sparse_categorical_accuracy: 0.8991
Epoch 46/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3149 - sparse_categorical_accuracy: 0.8992
Epoch 47/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3082 - sparse_categorical_accuracy: 0.9021
Epoch 48/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3223 - sparse_categorical_accuracy: 0.8959
Epoch 49/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3195 - sparse_categorical_accuracy: 0.8981
Epoch 50/50
189/189 [==============================] - 3s 16ms/step - loss: 0.3240 - sparse_categorical_accuracy: 0.8962
313/313 [==============================] - 2s 7ms/step - loss: 0.7332 - sparse_categorical_accuracy: 0.8162
Test accuracy: 81.62%
```

We get to an improved test accuracy.

## Conclusion

As shown in the experiments, using the supervised contrastive learning technique outperformed the conventional technique in terms of the test accuracy. Note that the same training budget (i.e., number of epochs) was given to each technique. Supervised contrastive learning pays off when the encoder involves a complex architecture, like ResNet, and multi-class problems with many labels. In addition, large batch sizes and multi-layer projection heads improve its effectiveness. See the Supervised Contrastive Learning paper for more details.

You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.
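The multi-layer projection head mentioned in the conclusion could look like the sketch below. This is a hypothetical variant of `add_projection_head` (the helper name and hidden size are assumptions, not part of the original example); it can be compiled and fit with `SupervisedContrastiveLoss` exactly like the single-layer head used above.

```python
def add_mlp_projection_head(encoder, head_hidden_units=512):
    # Hypothetical two-layer (MLP) projection head, as suggested in the
    # Supervised Contrastive Learning paper; it is used only during the
    # contrastive pretraining phase and discarded before classification.
    inputs = keras.Input(shape=input_shape)
    features = encoder(inputs)
    features = layers.Dense(head_hidden_units, activation="relu")(features)
    outputs = layers.Dense(projection_units, activation="relu")(features)
    return keras.Model(
        inputs=inputs, outputs=outputs, name="cifar10-encoder_with_mlp-head"
    )
```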


