treex.losses.huber
Huber (Loss)

Computes the Huber loss between target and predictions.

For each value x in error = target - preds:

$$
loss =
\begin{cases}
0.5 \times x^2, & \text{if } |x| \le d \\
0.5 \times d^2 + d \times (|x| - d), & \text{otherwise}
\end{cases}
$$

where d is delta. The two branches agree in value and slope at |x| = d, so the loss is continuously differentiable. See: https://en.wikipedia.org/wiki/Huber_loss

Usage:

```python
target = jnp.array([[0, 1], [0, 0]])
preds = jnp.array([[0.6, 0.4], [0.4, 0.6]])

# Using 'auto'/'sum_over_batch_size' reduction type.
huber_loss = tx.losses.Huber()
assert huber_loss(target, preds) == 0.155

# Calling with 'sample_weight'.
assert (
    huber_loss(target, preds, sample_weight=jnp.array([0.8, 0.2])) == 0.08500001
)

# Using 'sum' reduction type.
huber_loss = tx.losses.Huber(reduction=tx.losses.Reduction.SUM)
assert huber_loss(target, preds) == 0.31

# Using 'none' reduction type.
huber_loss = tx.losses.Huber(reduction=tx.losses.Reduction.NONE)
assert jnp.equal(huber_loss(target, preds), jnp.array([0.18, 0.13000001])).all()
```

Usage with the Elegy API:

```python
model = elegy.Model(
    module_fn,
    loss=tx.losses.Huber(delta=1.0),
    metrics=elegy.metrics.Mean(),
)
```
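To make the 0.155 value above concrete, here is a minimal sketch (an editorial addition, assuming only `jax.numpy`) that reproduces the default `'sum_over_batch_size'` reduction by hand. Every absolute error in this example is below `delta`, so only the quadratic branch applies:

```python
import jax.numpy as jnp

target = jnp.array([[0.0, 1.0], [0.0, 0.0]])
preds = jnp.array([[0.6, 0.4], [0.4, 0.6]])
delta = 1.0

abs_error = jnp.abs(preds - target)   # [[0.6, 0.6], [0.4, 0.6]], all <= delta
pointwise = 0.5 * abs_error**2        # quadratic branch everywhere
per_sample = pointwise.mean(axis=-1)  # [0.18, 0.13], matches Reduction.NONE
batch_mean = per_sample.mean()        # (0.18 + 0.13) / 2 = 0.155
print(per_sample, batch_mean)
```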
__init__(self, delta=1.0, reduction=None, weight=None, on=None, **kwargs) special

Initializes the `Huber` class.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| delta | float | A float, the point where the Huber loss function changes from quadratic to linear. | 1.0 |
| reduction | Optional[treex.losses.loss.Reduction] | Type of `tx.losses.Reduction` to apply to the loss. Defaults to `SUM_OVER_BATCH_SIZE`. | None |
| weight | Optional[float] | Optional weight contribution for the total loss. Defaults to 1. | None |
| on | Union[str, int, Sequence[Union[str, int]]] | A string or integer, or an iterable of strings or integers, indicating how to index/filter the `target` and `preds` arguments before passing them to `call`. For example, if `on="a"` then `target = target["a"]`. If `on` is an iterable, the structures are indexed iteratively; for example, if `on=["a", 0, "b"]` then `target = target["a"][0]["b"]`, and likewise for `preds`. For more information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior). See also the sketch below. | None |

Source code in treex/losses/huber.py

```python
def __init__(
    self,
    delta: float = 1.0,
    reduction: tp.Optional[Reduction] = None,
    weight: tp.Optional[float] = None,
    on: tp.Optional[types.IndexLike] = None,
    **kwargs
):
    """
    Initializes `Huber` class.

    Arguments:
        delta: (Optional) Defaults to 1.0. A float, the point where the Huber
            loss function changes from quadratic to linear.
        reduction: (Optional) Type of `tx.losses.Reduction` to apply to the
            loss. Defaults to `SUM_OVER_BATCH_SIZE`.
        weight: Optional weight contribution for the total loss. Defaults to `1`.
        on: A string or integer, or an iterable of strings or integers, that
            indicates how to index/filter the `target` and `preds` arguments
            before passing them to `call`. For example, if `on="a"` then
            `target = target["a"]`. If `on` is an iterable, the structures are
            indexed iteratively; for example, if `on=["a", 0, "b"]` then
            `target = target["a"][0]["b"]`, and likewise for `preds`. For more
            information check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
    """
    self.delta = delta
    return super().__init__(reduction=reduction, weight=weight, on=on, **kwargs)
```
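The `on` argument re-indexes structured inputs before the loss is computed. A hypothetical sketch of the behavior described above (the key `"a"` and the dict layout are invented for illustration):

```python
import jax.numpy as jnp
import treex as tx

# Nested inputs: the loss should only see the arrays under key "a".
target = {"a": jnp.array([[0.0, 1.0], [0.0, 0.0]])}
preds = {"a": jnp.array([[0.6, 0.4], [0.4, 0.6]])}

# on="a" indexes both structures first: target -> target["a"], preds -> preds["a"].
huber_loss = tx.losses.Huber(on="a")
assert huber_loss(target, preds) == 0.155  # same value as the flat example above
```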
call(self, target, preds, sample_weight=None)

Invokes the `Huber` instance.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| target | ndarray | Ground truth values. shape = `[batch_size, d0, .. dN]`, except for sparse loss functions such as sparse categorical crossentropy, where shape = `[batch_size, d0, .. dN-1]`. | required |
| preds | ndarray | The predicted values. shape = `[batch_size, d0, .. dN]`. | required |
| sample_weight | Optional[ndarray] | Acts as a coefficient for the loss. If a scalar is provided, the loss is simply scaled by the given value. If `sample_weight` is a tensor of size `[batch_size]`, the total loss for each sample of the batch is rescaled by the corresponding element in the `sample_weight` vector. If the shape of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcast to this shape), each loss element of `preds` is scaled by the corresponding value of `sample_weight`. (Note on `dN-1`: all loss functions reduce by 1 dimension, usually axis=-1.) | None |

Returns:

| Type | Description |
| ---- | ----------- |
| ndarray | Weighted loss float `Tensor`. If `reduction` is `NONE`, this has shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1` because all loss functions reduce by 1 dimension, usually axis=-1.) |

Exceptions:

| Type | Description |
| ---- | ----------- |
| ValueError | If the shape of `sample_weight` is invalid. |

Source code in treex/losses/huber.py

```python
def call(
    self,
    target: jnp.ndarray,
    preds: jnp.ndarray,
    sample_weight: tp.Optional[
        jnp.ndarray
    ] = None,  # not used, __call__ handles it, left for documentation purposes.
) -> jnp.ndarray:
    """
    Invokes the `Huber` instance.

    Arguments:
        target: Ground truth values. shape = `[batch_size, d0, .. dN]`, except
            for sparse loss functions such as sparse categorical crossentropy,
            where shape = `[batch_size, d0, .. dN-1]`.
        preds: The predicted values. shape = `[batch_size, d0, .. dN]`.
        sample_weight: Optional `sample_weight` acts as a coefficient for the
            loss. If a scalar is provided, the loss is simply scaled by the
            given value. If `sample_weight` is a tensor of size `[batch_size]`,
            the total loss for each sample of the batch is rescaled by the
            corresponding element in the `sample_weight` vector. If the shape
            of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be
            broadcast to this shape), each loss element of `preds` is scaled
            by the corresponding value of `sample_weight`. (Note on `dN-1`:
            all loss functions reduce by 1 dimension, usually axis=-1.)

    Returns:
        Weighted loss float `Tensor`. If `reduction` is `NONE`, this has shape
        `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1`
        because all loss functions reduce by 1 dimension, usually axis=-1.)

    Raises:
        ValueError: If the shape of `sample_weight` is invalid.
    """
    return huber(target, preds, self.delta)
```
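The following minimal sketch (an editorial addition, reusing the `target`/`preds` from the class usage example) shows how a `[batch_size]` `sample_weight` rescales the per-sample losses before the batch reduction:

```python
import jax.numpy as jnp
import treex as tx

target = jnp.array([[0.0, 1.0], [0.0, 0.0]])
preds = jnp.array([[0.6, 0.4], [0.4, 0.6]])

huber_loss = tx.losses.Huber()  # default 'sum_over_batch_size' reduction
# Unweighted per-sample losses are [0.18, 0.13]; weighting by [0.8, 0.2]
# gives (0.18 * 0.8 + 0.13 * 0.2) / 2 = 0.085, matching the usage example.
loss = huber_loss(target, preds, sample_weight=jnp.array([0.8, 0.2]))
print(loss)  # ~0.085
```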
huber(target, preds, delta)

Computes the Huber loss between target and predictions.

For each value x in error = target - preds:

$$
loss =
\begin{cases}
0.5 \times x^2, & \text{if } |x| \le d \\
0.5 \times d^2 + d \times (|x| - d), & \text{otherwise}
\end{cases}
$$

where d is delta. See: https://en.wikipedia.org/wiki/Huber_loss

Usage:

```python
rng = jax.random.PRNGKey(42)

target = jax.random.randint(rng, shape=(2, 3), minval=0, maxval=2)
preds = jax.random.uniform(rng, shape=(2, 3))

loss = tx.losses.huber(target, preds, delta=1.0)
assert loss.shape == (2,)

preds = preds.astype(float)
target = target.astype(float)
delta = 1.0
error = jnp.subtract(preds, target)
abs_error = jnp.abs(error)
quadratic = jnp.minimum(abs_error, delta)
linear = jnp.subtract(abs_error, quadratic)
assert jnp.array_equal(
    loss,
    jnp.mean(
        jnp.add(
            jnp.multiply(0.5, jnp.multiply(quadratic, quadratic)),
            jnp.multiply(delta, linear),
        ),
        axis=-1,
    ),
)
```

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| target | ndarray | Ground truth values. shape = `[batch_size, d0, .. dN]`. | required |
| preds | ndarray | The predicted values. shape = `[batch_size, d0, .. dN]`. | required |
| delta | float | A float, the point where the Huber loss function changes from quadratic to linear. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| ndarray | Huber loss values. If `reduction` is `NONE`, this has shape `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1` because all loss functions reduce by 1 dimension, usually axis=-1.) |
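The implementation below is branch-free: rather than the piecewise definition, it uses the `min`-based decomposition seen in the usage example above, which is the same function. Substituting min(|x|, d) = |x| when |x| <= d, and min(|x|, d) = d otherwise, recovers the two branches:

$$
0.5 \times \min(|x|, d)^2 + d \times \bigl(|x| - \min(|x|, d)\bigr) =
\begin{cases}
0.5 \times x^2, & \text{if } |x| \le d \\
0.5 \times d^2 + d \times (|x| - d), & \text{otherwise}
\end{cases}
$$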
Source code in treex/losses/huber.py

```python
def huber(target: jnp.ndarray, preds: jnp.ndarray, delta: float) -> jnp.ndarray:
    r"""
    Computes the Huber loss between target and predictions.

    For each value x in error = target - preds:

    $$
    loss =
    \begin{cases}
    0.5 \times x^2, & \text{if } |x| \le d \\
    0.5 \times d^2 + d \times (|x| - d), & \text{otherwise}
    \end{cases}
    $$

    where d is delta. See: https://en.wikipedia.org/wiki/Huber_loss

    Arguments:
        target: Ground truth values. shape = `[batch_size, d0, .. dN]`.
        preds: The predicted values. shape = `[batch_size, d0, .. dN]`.
        delta: A float, the point where the Huber loss function changes from
            quadratic to linear.

    Returns:
        Huber loss values. If `reduction` is `NONE`, this has shape
        `[batch_size, d0, .. dN-1]`; otherwise, it is scalar. (Note `dN-1`
        because all loss functions reduce by 1 dimension, usually axis=-1.)
    """
    preds = preds.astype(float)
    target = target.astype(float)
    delta = float(delta)
    error = jnp.subtract(preds, target)
    abs_error = jnp.abs(error)
    quadratic = jnp.minimum(abs_error, delta)
    linear = jnp.subtract(abs_error, quadratic)
    return jnp.mean(
        jnp.add(
            jnp.multiply(0.5, jnp.multiply(quadratic, quadratic)),
            jnp.multiply(delta, linear),
        ),
        axis=-1,
    )
```
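As a quick illustration of the role of `delta` (an editorial sketch, reusing the `tx`/`jnp` imports from the examples above): errors below `delta` are penalized quadratically, while errors above it grow only linearly, which makes the loss robust to outliers.

```python
import jax.numpy as jnp
import treex as tx

target = jnp.zeros((1, 4))
preds = jnp.array([[0.5, 1.0, 2.0, 4.0]])

loss = tx.losses.huber(target, preds, delta=1.0)
# Per-element terms: 0.5 * 0.5**2 = 0.125 (quadratic), 0.5 (at the boundary),
# 0.5 + 1.0 * (2 - 1) = 1.5 and 0.5 + 1.0 * (4 - 1) = 3.5 (linear growth).
print(loss)  # [(0.125 + 0.5 + 1.5 + 3.5) / 4] = [1.40625]
```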