import typing

import errr
import numpy as np
import scipy.stats.distributions as _distributions

from .. import config
from ..exceptions import DistributionCastError
from . import types

if typing.TYPE_CHECKING:
    from ..core import Scaffold

# Build the list of accepted distribution names by duck typing: every entry of the
# ``scipy.stats.distributions`` module namespace that has an ``rvs`` attribute counts.
# Entries whose name ends in ``_gen`` are excluded; in scipy those are the types the
# distribution objects are created from, not distributions to pick by name.
# ``"constant"`` is not a scipy distribution; ``Distribution`` handles it separately.
_available_distributions = [
    name
    for name, candidate in vars(_distributions).items()
    if hasattr(candidate, "rvs") and not name.endswith("_gen")
]
_available_distributions.append("constant")
@config.node
class Distribution:
    """
    Configuration node wrapping a ``scipy.stats`` distribution, or the special
    ``"constant"`` distribution, behind a small sampling/evaluation API.

    Attributes not defined on the node are forwarded to the wrapped distribution
    object (see ``__getattr__``).
    """

    scaffold: "Scaffold"
    distribution: str = config.attr(
        type=types.in_(_available_distributions), required=True
    )
    """
    Name of the scipy.stats distribution function.
    """
    parameters: dict[str, typing.Any] = config.catch_all(type=types.any_())
    """
    Parameters to pass to the distribution.
    """

    def __init__(self, **kwargs):
        """
        Instantiate the wrapped distribution from ``distribution`` and ``parameters``.

        Both attributes are read (not assigned) here, so they must already have been
        set by the config machinery before this runs.

        :raises DistributionCastError: if the ``"constant"`` distribution lacks its
          ``constant`` parameter, or if the scipy distribution cannot be constructed
          with the given parameters.
        """
        if self.distribution == "constant":
            try:
                value = self.parameters["constant"]
            except KeyError as e:
                # Report a missing parameter the same way as any other cast failure,
                # instead of leaking a bare ``KeyError`` out of config parsing.
                errr.wrap(
                    DistributionCastError,
                    e,
                    prepend="Can't cast to 'constant': missing required parameter ",
                )
            else:
                self._distr = _ConstantDistribution(value)
            return

        try:
            self._distr = getattr(_distributions, self.distribution)(**self.parameters)
        except Exception as e:
            # Any failure while building the scipy object (e.g. unexpected keyword
            # arguments) is surfaced as a DistributionCastError with context.
            errr.wrap(
                DistributionCastError,
                e,
                prepend=f"Can't cast to '{self.distribution}': ",
            )

    def draw(self, n):
        """
        Draw n random samples from the distribution.

        :param int n: number of samples, passed as ``size`` to the underlying ``rvs``.
        :returns: the samples returned by the wrapped distribution's ``rvs``.
        """
        return self._distr.rvs(size=n)

    def definition_interval(self, epsilon=0):
        """
        Returns the `epsilon` and 1 - `epsilon` values of the distribution Percent point
        function.

        :param float epsilon: ratio of the interval to ignore, in ``[0, 1]``.
        :returns: ``(ppf(epsilon), ppf(1 - epsilon))`` of the wrapped distribution.
        :raises ValueError: if ``epsilon`` is outside ``[0, 1]``.
        """
        if epsilon < 0 or epsilon > 1:
            raise ValueError("Epsilon must be between 0 and 1")
        return self._distr.ppf(epsilon), self._distr.ppf(1 - epsilon)

    def cdf(self, value):
        """
        Returns the result of the cumulative distribution function for `value`.

        :param float value: value to evaluate
        """
        return self._distr.cdf(value)

    def sf(self, value):
        """
        Returns the result of the Survival function for `value`.

        :param float value: value to evaluate
        """
        return self._distr.sf(value)

    def __getattr__(self, attr):
        """
        Forward unknown attribute lookups to the wrapped distribution.

        Python only calls this when normal lookup fails. ``_distr`` is checked in
        ``__dict__`` directly because reading ``self._distr`` when it is unset would
        re-enter ``__getattr__`` and recurse forever. That happens, for example, when
        construction failed or the instance was created without running ``__init__``.

        :raises AttributeError: if no wrapped distribution is present, or if the
          wrapped distribution lacks ``attr``.
        """
        if "_distr" not in self.__dict__:
            raise AttributeError("No underlying _distr found for distribution node.")
        return getattr(self._distr, attr)