decorator to control returned values

This commit is contained in:
Benjamin Bertrand 2016-01-09 18:40:02 +03:00
parent f7b31425e2
commit e1cb61a1a2
3 changed files with 50 additions and 8 deletions

View File

@ -9,6 +9,7 @@
from math import sqrt, ceil
from random import randint, uniform, gauss
from .number_tools import number_factory
class Dataset(list):
""" A dataset (a list) with statistics and latex rendering methods
@ -19,11 +20,11 @@ class Dataset(list):
>>> s.mean()
49.5
>>> s.deviation()
83325.0
83325
>>> s.variance()
833.25
>>> s.sd()
28.86607004772212
28.87
"""
@classmethod
@ -118,20 +119,25 @@ class Dataset(list):
def effectif_total(self):
return len(self)
@number_factory
def sum(self):
return sum(self)
@number_factory
def mean(self):
return self.sum()/self.effectif_total()
@number_factory
def deviation(self):
""" Compute the deviation (not normalized) """
mean = self.mean()
return sum([(x - mean)**2 for x in self])
@number_factory
def variance(self):
return self.deviation()/self.effectif_total()
@number_factory
def sd(self):
""" Compute the standard deviation """
return sqrt(self.variance())
@ -148,6 +154,7 @@ class Dataset(list):
"""
return (min(self) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self))
@number_factory
def quartile(self, quartile = 1):
"""
Calcul un quartile de la série.

View File

@ -0,0 +1,27 @@
#/usr/bin/env python
# -*- coding:Utf-8 -*-
from functools import wraps
def number_factory(fun):
""" Decorator which format returned value """
@wraps(fun)
def wrapper(*args, **kwargs):
ans = fun(*args, **kwargs)
try:
if ans.is_integer():
return int(ans)
else:
return round(ans, 2)
except AttributeError:
return ans
return wrapper
# -----------------------------
# Reglages pour 'vim'
# vim:set autoindent expandtab tabstop=4 shiftwidth=4:
# cursor: 16 del

View File

@ -11,6 +11,7 @@ from math import sqrt, ceil
from collections import Counter
from .dataset import Dataset
from ..calculus.generic import flatten_list
from .number_tools import number_factory
class WeightedDataset(dict):
@ -24,13 +25,13 @@ class WeightedDataset(dict):
>>> w.sum()
120
>>> w.mean()
2.608695652173913
2.61
>>> w.deviation()
56.95652173913044
56.96
>>> w.variance()
1.2381852551984878
1.24
>>> w.sd()
1.1127377297451937
1.11
"""
@ -57,27 +58,33 @@ class WeightedDataset(dict):
except KeyError:
self[data] = weight
@number_factory
def total_weight(self):
return sum(self.values())
def effectif_total(self):
return self.total_weight()
@number_factory
def sum(self):
""" Not really a sum but the sum of the product of key and values """
return sum([k*v for (k,v) in self.items()])
@number_factory
def mean(self):
return self.sum()/self.effectif_total()
@number_factory
def deviation(self):
""" Compute the deviation (not normalized) """
mean = self.mean()
return sum([v*(k - mean)**2 for (k,v) in self.items()])
@number_factory
def variance(self):
return self.deviation()/self.effectif_total()
@number_factory
def sd(self):
""" Compute the standard deviation """
return sqrt(self.variance())
@ -92,7 +99,7 @@ class WeightedDataset(dict):
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(5)]))
>>> w.quartiles()
(1, 2, 3.0, 4, 4)
(1, 2, 3, 4, 4)
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)]))
>>> w.quartiles()
(1, 3, 4, 5, 5)
@ -100,6 +107,7 @@ class WeightedDataset(dict):
"""
return (min(self.keys()) , self.quartile(1) , self.quartile(2) , self.quartile(3), max(self.keys()))
@number_factory
def quartile(self, quartile = 1):
"""
Calcul un quartile de la série.
@ -114,7 +122,7 @@ class WeightedDataset(dict):
>>> w.quartile(1)
2
>>> w.quartile(2)
3.0
3
>>> w.quartile(3)
4
>>> w = WeightedDataset(flatten_list([i*[i] for i in range(6)]))