From e1cb61a1a25e37ae5a450ff1f99ae63e1aa714f5 Mon Sep 17 00:00:00 2001 From: Benjamin Bertrand Date: Sat, 9 Jan 2016 18:40:02 +0300 Subject: [PATCH] decorator to control returned values --- pymath/stat/dataset.py | 11 +++++++++-- pymath/stat/number_tools.py | 27 +++++++++++++++++++++++++++ pymath/stat/weightedDataset.py | 20 ++++++++++++++------ 3 files changed, 50 insertions(+), 8 deletions(-) create mode 100644 pymath/stat/number_tools.py diff --git a/pymath/stat/dataset.py b/pymath/stat/dataset.py index ddb2f31..b0932c8 100644 --- a/pymath/stat/dataset.py +++ b/pymath/stat/dataset.py @@ -9,6 +9,7 @@ from math import sqrt, ceil from random import randint, uniform, gauss +from .number_tools import number_factory class Dataset(list): """ A dataset (a list) with statistics and latex rendering methods @@ -19,11 +20,11 @@ class Dataset(list): >>> s.mean() 49.5 >>> s.deviation() - 83325.0 + 83325 >>> s.variance() 833.25 >>> s.sd() - 28.86607004772212 + 28.87 """ @classmethod @@ -118,20 +119,25 @@ class Dataset(list): def effectif_total(self): return len(self) + @number_factory def sum(self): return sum(self) + @number_factory def mean(self): return self.sum()/self.effectif_total() + @number_factory def deviation(self): """ Compute the deviation (not normalized) """ mean = self.mean() return sum([(x - mean)**2 for x in self]) + @number_factory def variance(self): return self.deviation()/self.effectif_total() + @number_factory def sd(self): """ Compute the standard deviation """ return sqrt(self.variance()) @@ -148,6 +154,7 @@ class Dataset(list): """ return (min(self) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self)) + @number_factory def quartile(self, quartile = 1): """ Calcul un quartile de la série. diff --git a/pymath/stat/number_tools.py b/pymath/stat/number_tools.py new file mode 100644 index 0000000..04a597b --- /dev/null +++ b/pymath/stat/number_tools.py @@ -0,0 +1,27 @@ +#/usr/bin/env python +# -*- coding:Utf-8 -*- + +from functools import wraps + + +def number_factory(fun): + """ Decorator which format returned value """ + @wraps(fun) + def wrapper(*args, **kwargs): + ans = fun(*args, **kwargs) + try: + if ans.is_integer(): + return int(ans) + else: + return round(ans, 2) + except AttributeError: + return ans + return wrapper + + + +# ----------------------------- +# Reglages pour 'vim' +# vim:set autoindent expandtab tabstop=4 shiftwidth=4: +# cursor: 16 del + diff --git a/pymath/stat/weightedDataset.py b/pymath/stat/weightedDataset.py index 3c0131c..63e4cff 100644 --- a/pymath/stat/weightedDataset.py +++ b/pymath/stat/weightedDataset.py @@ -11,6 +11,7 @@ from math import sqrt, ceil from collections import Counter from .dataset import Dataset from ..calculus.generic import flatten_list +from .number_tools import number_factory class WeightedDataset(dict): @@ -24,13 +25,13 @@ class WeightedDataset(dict): >>> w.sum() 120 >>> w.mean() - 2.608695652173913 + 2.61 >>> w.deviation() - 56.95652173913044 + 56.96 >>> w.variance() - 1.2381852551984878 + 1.24 >>> w.sd() - 1.1127377297451937 + 1.11 """ @@ -57,27 +58,33 @@ class WeightedDataset(dict): except KeyError: self[data] = weight + @number_factory def total_weight(self): return sum(self.values()) def effectif_total(self): return self.total_weight() + @number_factory def sum(self): """ Not really a sum but the sum of the product of key and values """ return sum([k*v for (k,v) in self.items()]) + @number_factory def mean(self): return self.sum()/self.effectif_total() + @number_factory def deviation(self): """ Compute the deviation (not normalized) """ mean = self.mean() return sum([v*(k - mean)**2 for (k,v) in self.items()]) + @number_factory def variance(self): return self.deviation()/self.effectif_total() + @number_factory def sd(self): """ Compute the standard deviation """ return sqrt(self.variance()) @@ -92,7 +99,7 @@ class WeightedDataset(dict): >>> w = WeightedDataset(flatten_list([i*[i] for i in range(5)])) >>> w.quartiles() - (1, 2, 3.0, 4, 4) + (1, 2, 3, 4, 4) >>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)])) >>> w.quartiles() (1, 3, 4, 5, 5) @@ -100,6 +107,7 @@ class WeightedDataset(dict): """ return (min(self.keys()) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self.keys())) + @number_factory def quartile(self, quartile = 1): """ Calcul un quartile de la série. @@ -114,7 +122,7 @@ class WeightedDataset(dict): >>> w.quartile(1) 2 >>> w.quartile(2) - 3.0 + 3 >>> w.quartile(3) 4 >>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)]))