decorator to control returned values

This commit is contained in:
Benjamin Bertrand 2016-01-09 18:40:02 +03:00
parent f7b31425e2
commit e1cb61a1a2
3 changed files with 50 additions and 8 deletions

View File

@ -9,6 +9,7 @@
from math import sqrt, ceil from math import sqrt, ceil
from random import randint, uniform, gauss from random import randint, uniform, gauss
from .number_tools import number_factory
class Dataset(list): class Dataset(list):
""" A dataset (a list) with statistics and latex rendering methods """ A dataset (a list) with statistics and latex rendering methods
@ -19,11 +20,11 @@ class Dataset(list):
>>> s.mean() >>> s.mean()
49.5 49.5
>>> s.deviation() >>> s.deviation()
83325.0 83325
>>> s.variance() >>> s.variance()
833.25 833.25
>>> s.sd() >>> s.sd()
28.86607004772212 28.87
""" """
@classmethod @classmethod
@ -118,20 +119,25 @@ class Dataset(list):
def effectif_total(self): def effectif_total(self):
return len(self) return len(self)
@number_factory
def sum(self): def sum(self):
return sum(self) return sum(self)
@number_factory
def mean(self): def mean(self):
return self.sum()/self.effectif_total() return self.sum()/self.effectif_total()
@number_factory
def deviation(self): def deviation(self):
""" Compute the deviation (not normalized) """ """ Compute the deviation (not normalized) """
mean = self.mean() mean = self.mean()
return sum([(x - mean)**2 for x in self]) return sum([(x - mean)**2 for x in self])
@number_factory
def variance(self): def variance(self):
return self.deviation()/self.effectif_total() return self.deviation()/self.effectif_total()
@number_factory
def sd(self): def sd(self):
""" Compute the standard deviation """ """ Compute the standard deviation """
return sqrt(self.variance()) return sqrt(self.variance())
@ -148,6 +154,7 @@ class Dataset(list):
""" """
return (min(self) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self)) return (min(self) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self))
@number_factory
def quartile(self, quartile = 1): def quartile(self, quartile = 1):
""" """
Calcul un quartile de la série. Calcul un quartile de la série.

View File

@ -0,0 +1,27 @@
#/usr/bin/env python
# -*- coding:Utf-8 -*-
from functools import wraps
def number_factory(fun):
""" Decorator which format returned value """
@wraps(fun)
def wrapper(*args, **kwargs):
ans = fun(*args, **kwargs)
try:
if ans.is_integer():
return int(ans)
else:
return round(ans, 2)
except AttributeError:
return ans
return wrapper
# -----------------------------
# Reglages pour 'vim'
# vim:set autoindent expandtab tabstop=4 shiftwidth=4:
# cursor: 16 del

View File

@ -11,6 +11,7 @@ from math import sqrt, ceil
from collections import Counter from collections import Counter
from .dataset import Dataset from .dataset import Dataset
from ..calculus.generic import flatten_list from ..calculus.generic import flatten_list
from .number_tools import number_factory
class WeightedDataset(dict): class WeightedDataset(dict):
@ -24,13 +25,13 @@ class WeightedDataset(dict):
>>> w.sum() >>> w.sum()
120 120
>>> w.mean() >>> w.mean()
2.608695652173913 2.61
>>> w.deviation() >>> w.deviation()
56.95652173913044 56.96
>>> w.variance() >>> w.variance()
1.2381852551984878 1.24
>>> w.sd() >>> w.sd()
1.1127377297451937 1.11
""" """
@ -57,27 +58,33 @@ class WeightedDataset(dict):
except KeyError: except KeyError:
self[data] = weight self[data] = weight
@number_factory
def total_weight(self): def total_weight(self):
return sum(self.values()) return sum(self.values())
def effectif_total(self): def effectif_total(self):
return self.total_weight() return self.total_weight()
@number_factory
def sum(self): def sum(self):
""" Not really a sum but the sum of the product of key and values """ """ Not really a sum but the sum of the product of key and values """
return sum([k*v for (k,v) in self.items()]) return sum([k*v for (k,v) in self.items()])
@number_factory
def mean(self): def mean(self):
return self.sum()/self.effectif_total() return self.sum()/self.effectif_total()
@number_factory
def deviation(self): def deviation(self):
""" Compute the deviation (not normalized) """ """ Compute the deviation (not normalized) """
mean = self.mean() mean = self.mean()
return sum([v*(k - mean)**2 for (k,v) in self.items()]) return sum([v*(k - mean)**2 for (k,v) in self.items()])
@number_factory
def variance(self): def variance(self):
return self.deviation()/self.effectif_total() return self.deviation()/self.effectif_total()
@number_factory
def sd(self): def sd(self):
""" Compute the standard deviation """ """ Compute the standard deviation """
return sqrt(self.variance()) return sqrt(self.variance())
@ -92,7 +99,7 @@ class WeightedDataset(dict):
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(5)])) >>> w = WeightedDataset(flatten_list([i*[i] for i in range(5)]))
>>> w.quartiles() >>> w.quartiles()
(1, 2, 3.0, 4, 4) (1, 2, 3, 4, 4)
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)])) >>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)]))
>>> w.quartiles() >>> w.quartiles()
(1, 3, 4, 5, 5) (1, 3, 4, 5, 5)
@ -100,6 +107,7 @@ class WeightedDataset(dict):
""" """
return (min(self.keys()) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self.keys())) return (min(self.keys()) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self.keys()))
@number_factory
def quartile(self, quartile = 1): def quartile(self, quartile = 1):
""" """
Calcul un quartile de la série. Calcul un quartile de la série.
@ -114,7 +122,7 @@ class WeightedDataset(dict):
>>> w.quartile(1) >>> w.quartile(1)
2 2
>>> w.quartile(2) >>> w.quartile(2)
3.0 3
>>> w.quartile(3) >>> w.quartile(3)
4 4
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)])) >>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)]))