Python Decorator

by Kardi Teknomo

A decorator wraps an existing function or class in another function (or class). It lets you modify the behavior of a function or a class without changing its code. The purpose is to add functionality (hence, to decorate) to existing objects without affecting other objects. If you want to keep the original functionality while adding new functionality, or to deliver core functionality and then “decorate” it with features unique to each customer, Python decorators are a good fit.

The technique is to create a wrapper function (or class) so that the wrapped function (or object) changes its behavior without any change to its code. This works because functions in Python are first-class objects: they can be passed as arguments to other functions or classes, just like variables. A decorator takes the function as its argument, defines a wrapper function that calls the original function inside it, and returns the wrapper.
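A minimal sketch of that mechanism follows. The decorator shout and the functions greet and greet_plain are hypothetical names chosen only for illustration; the point is that the @shout line is shorthand for calling shout on the function and rebinding its name.

```python
def shout(func):
    """Hypothetical decorator: upper-cases the string returned by func."""
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)  # call the original function
        return result.upper()           # add new behavior on top of it
    return wrapper

@shout
def greet(name):
    return f"hello, {name}"

# The @shout line above is equivalent to writing:
def greet_plain(name):
    return f"hello, {name}"

greet_plain = shout(greet_plain)  # pass the function in, rebind the name

print(greet("maya"))        # HELLO, MAYA
print(greet_plain("budi"))  # HELLO, BUDI
```

Neither greet nor greet_plain was edited; only the name now points to the wrapper.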

In this tutorial, you will learn about:

  • Built-In Decorators
    • Property Decorator
    • Class Method and Static Method Decorators
  • timer decorator
  • debug decorator
  • error_handler decorator
  • flatten argument decorator
  • count_calls decorator
  • cache/memoization decorator
  • plot decorator
  • ode_solver decorator

Other example applications of Python decorators:

  • Data validation
  • Database Transaction
  • Logging
  • Monitoring
  • Business rules
  • Compression
  • Encryption
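As an illustration of the first item, data validation, here is a minimal sketch of a hypothetical require_non_negative decorator that checks the arguments before the wrapped function runs. It assumes the function's numeric arguments should never be negative; rectangle_area is likewise a made-up example.

```python
import functools

def require_non_negative(func):
    """Hypothetical validation decorator: reject negative numeric arguments."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # check positional and keyword arguments before calling func
        for value in list(args) + list(kwargs.values()):
            if isinstance(value, (int, float)) and value < 0:
                raise ValueError(f"{func.__name__}() got a negative argument: {value}")
        return func(*args, **kwargs)
    return wrapper

@require_non_negative
def rectangle_area(width, height):
    return width * height

print(rectangle_area(2, 4))  # 8
try:
    rectangle_area(2, -4)
except ValueError as err:
    print(err)  # rectangle_area() got a negative argument: -4
```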

Built-In Decorators

Python comes with several built-in decorators, such as @property, @classmethod and @staticmethod. Let us discuss them.

Property Decorator

@property is a built-in decorator that turns a method into a property (a managed attribute), which lets us customize the getter and setter of a class attribute. We turn a method into a property by adding the @property decorator above it. In the example below, without the @property decorator we would have to access the area of a rectangle as rectangleA.area(). With the @property decorator, area behaves like an attribute and is accessed as rectangleA.area.

In [1]:
class Rectangle():
    def __init__(self, width, height):
        self.width = width
        self.height = height

    @property  # add this line
    def area(self):
        return self.width * self.height


rectangleA = Rectangle(width=2, height=4)
print('Area rectangle A is ',
      rectangleA.area)  # no parentheses: area is accessed like an attribute

rectangleB = Rectangle(width=3, height=7.5)
print('Area rectangle B is ',
      rectangleB.area)  # no parentheses: area is accessed like an attribute
Area rectangle A is  8
Area rectangle B is  22.5
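Because area has a getter but no setter, it is read-only, and it is recomputed from width and height on every access. A minimal check, using a copy of the Rectangle class so the snippet runs on its own (the exact AttributeError message differs between Python versions):

```python
class Rectangle:
    def __init__(self, width, height):
        self.width = width
        self.height = height

    @property
    def area(self):
        return self.width * self.height

r = Rectangle(width=2, height=4)
r.width = 5       # ordinary attribute: can be reassigned
print(r.area)     # 20, recomputed from the new width
try:
    r.area = 100  # the property has no setter
except AttributeError as err:
    print("AttributeError:", err)
```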

Class Method and Static Method Decorators

  • @classmethod: the method receives the class itself (conventionally named cls) as its first argument instead of an instance, so it can be called on the class. This is useful for writing factory methods that create objects of the class.
  • @staticmethod: groups a function that has a logical connection with a class into that class. A static method knows nothing about the class or instance it was called on: it receives neither self nor cls.
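A small sketch of the difference, with hypothetical class names Shape and Square: because a class method receives cls, a factory defined on the parent class builds an instance of whichever subclass it is called on, while a static method behaves like a plain function stored in the class.

```python
class Shape:
    def __init__(self, name):
        self.name = name

    @classmethod
    def default(cls):
        # cls is whichever class the method was called on
        return cls("default")

    @staticmethod
    def describe():
        # no self and no cls: just a function grouped with the class
        return "a geometric shape"

class Square(Shape):
    pass

s = Square.default()
print(type(s).__name__, s.name)  # Square default
print(Shape.describe())          # a geometric shape
print(s.describe())              # a static method can also be called on an instance
```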

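In the example below, a Person object must be created from a name and an age; there is no way to create a Person from a birth year.

Notice that the isAdult(age) method does not take self, so it is meant to act as a static method: it can be called directly on the class, as Person.isAdult(19), without creating an object. Without the @staticmethod decorator, however, person1.isAdult(16) fails: Python passes person1 automatically as the first argument, so the method receives two positional arguments instead of one.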

In [2]:
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age
    
    def isAdult(age):
        return age > 18
 
 
person1 = Person('Maya', 16)
person2 = Person('Budi', 24)
print(person1.name,person1.age)
print(person2.name,person2.age)
print(Person.isAdult(19))

try: 
    print(person1.isAdult(16))
except Exception as e:
    print(e)
Maya 16
Budi 24
True
Person.isAdult() takes 1 positional argument but 2 were given

Now let us improve the class above by adding a class method that creates a Person object from a birth year, and by explicitly marking isAdult(age) as a static method with @staticmethod.

In [3]:
from datetime import date
 
class Person:
    def __init__(self, name, age):
        self.name = name
        self.age = age
 
    # a class method to create a Person object by birth year.
    @classmethod
    def fromBirthYear(cls, name, year):
        return cls(name, date.today().year - year)
 
    # a static method to check if a Person is adult or not.
    @staticmethod
    def isAdult(age):
        return age > 18
 
 
person1 = Person('Maya', 16)

# alternative constructor: create the person from a birth year via the class method
# (the age depends on today's date; it was 24 when this notebook was run in 2023)
person2 = Person.fromBirthYear('Budi', 1999)
 
print(person1.name,person1.age)
print(person2.name,person2.age)
 
# a static method needs no person object; it can be called on the class or on an instance
print(Person.isAdult(22))
print(person1.isAdult(16))
Maya 16
Budi 24
True
False

Below is another example that uses all three decorators discussed above. It also uses @radius.setter, which defines the setter of the radius property.

In [4]:
class Circle:
    def __init__(self, radius):
        self.radius = radius  # go through the setter so negative values are rejected here too

    @property
    def radius(self):
        """Get value of radius"""
        return self._radius

    @radius.setter
    def radius(self, value):
        """Set radius, raise error if negative"""
        if value >= 0:
            self._radius = value
        else:
            raise ValueError("Radius must be non-negative")

    @property
    def area(self):
        """Calculate area inside circle"""
        return self.pi() * self.radius**2

    def cylinder_volume(self, height):
        """Calculate volume of cylinder with circle as base"""
        return self.area * height

    @classmethod
    def unit_circle(cls):
        """Factory method creating a circle with radius 1"""
        return cls(1)
    
    @staticmethod
    def pi():
        """Value of π, could also use math.pi instead of a constant here"""
        return 3.1415926535
In [5]:
print('---new circle---')
c=Circle(2)  # define circle with radius 2 units
print('area:',c.area)
print('radius:',c.radius)
print('volume:',c.cylinder_volume(3))
print('---new circle---')
c.radius=1   # redefine the radius
print('area:',c.area)
print('radius:',c.radius)
print('volume:',c.cylinder_volume(3))
print('pi:',c.pi())
print('---new circle---')
f=Circle.unit_circle() # we don't need to enter the radius of the circle
print('area:',f.area)
print('radius:',f.radius)
print('volume:',f.cylinder_volume(3))
---new circle---
area: 12.566370614
radius: 2
volume: 37.699111842
---new circle---
area: 3.1415926535
radius: 1
volume: 9.4247779605
pi: 3.1415926535
---new circle---
area: 3.1415926535
radius: 1
volume: 9.4247779605
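The setter is where @property pays off: every assignment to radius passes through the validation. If __init__ assigned self._radius directly, it would bypass the setter and Circle(-5) would be accepted; assigning self.radius = radius routes the constructor through the same check. A trimmed-down sketch (area and the other methods omitted):

```python
class Circle:
    def __init__(self, radius):
        self.radius = radius  # uses the setter below, so it is validated too

    @property
    def radius(self):
        return self._radius

    @radius.setter
    def radius(self, value):
        # runs on every assignment to .radius, including the one in __init__
        if value < 0:
            raise ValueError("Radius must be non-negative")
        self._radius = value

c = Circle(2)
try:
    c.radius = -1        # rejected by the setter
except ValueError as err:
    print(err)           # Radius must be non-negative
print(c.radius)          # 2: the old value is kept

try:
    Circle(-5)           # also rejected, because __init__ uses the setter
except ValueError as err:
    print(err)
```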

timer decorator

Let us define a useful decorator that prints the run time of any function. The timer function takes a function as its argument and defines a wrapper function that accepts all the positional arguments (*args) and keyword arguments (**kwargs) of the wrapped function. The wrapper runs the wrapped function, keeps the result, prints the run time, and then returns the result.

In [6]:
import time
import functools

def timer(func):
    @functools.wraps(func) # to preserve information about the original function
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()  # high-resolution clock meant for measuring durations
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"run time of {func.__name__}: {end_time-start_time:.4f} seconds")
        return result
    return wrapper
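What does functools.wraps buy us? This sketch compares two hypothetical pass-through decorators, one without and one with functools.wraps. Without it, the decorated function reports the wrapper's name and loses its docstring; with it, the metadata is copied over and the original function stays reachable through __wrapped__.

```python
import functools

def without_wraps(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def with_wraps(func):
    @functools.wraps(func)  # copies __name__, __doc__, etc. and sets __wrapped__
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@without_wraps
def f():
    """Docstring of f."""

@with_wraps
def g():
    """Docstring of g."""

print(f.__name__, f.__doc__)  # wrapper None
print(g.__name__, g.__doc__)  # g Docstring of g.
print(g.__wrapped__)          # the original, undecorated g
```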

To use the timer function above as a decorator, we define the function to be wrapped and put @timer on the line above its definition.

In [7]:
@timer
def my_wrapped_function(argument):
    time.sleep(1)
    return str(argument)+"abc"

my_wrapped_function(123)
run time of my_wrapped_function: 1.0094 seconds
Out[7]:
'123abc'

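Decorating a function with @timer is the same as passing it to timer and rebinding the name: my_wrapped_function = timer(my_wrapped_function). Because my_wrapped_function is already decorated at this point, the cell below re-wraps the original function, which functools.wraps stores in the __wrapped__ attribute.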

In [8]:
my_wrapped_function = timer(my_wrapped_function.__wrapped__)  # same as @timer on the original function
result = my_wrapped_function("123")
run time of my_wrapped_function: 1.0062 seconds

For a recursive function, the timer decorator above prints a line at every recursive call, which is not desirable. To solve that problem, let us also define a Timer class (source: Real Python) that we start and stop manually around the top-level call.

In [9]:
import time

class TimerError(Exception):
    """A custom exception used to report errors in use of Timer class"""

class Timer:
    def __init__(self):
        self._start_time = None
    
    
    def start(self):
        """Start a new timer"""
        if self._start_time is not None:
            raise TimerError(f"Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()
    
    
    def stop(self):
        """Stop the timer, and report the elapsed time"""
        if self._start_time is None:
            raise TimerError(f"Timer is not running. Use .start() to start it")

        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None
        print(f"Elapsed time: {elapsed_time:0.4f} seconds")
In [10]:
def my_wrapped_function(argument):
    time.sleep(1)
    return str(argument)+"abc"

t=Timer()
t.start()
my_wrapped_function("argument")
t.stop()
Elapsed time: 1.0025 seconds
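Another option, sketched here as an alternative rather than taken from the source, is to leave the recursive function undecorated and apply @timer to a thin non-recursive entry point, so the report is printed once per top-level call. The name timed_fib is hypothetical, and timer is repeated (using time.perf_counter) so the snippet runs on its own.

```python
import functools
import time

def timer(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        print(f"run time of {func.__name__}: {end_time - start_time:.4f} seconds")
        return result
    return wrapper

def fib(n):
    # plain recursive function: not decorated, so nothing is printed per call
    return n if n < 2 else fib(n - 1) + fib(n - 2)

@timer
def timed_fib(n):
    # non-recursive entry point: timed exactly once per top-level call
    return fib(n)

print(timed_fib(20))  # one "run time of timed_fib: ..." line, then 6765
```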

debug decorator

The debug decorator below (source: Real Python) prints the arguments each time the wrapped function is called, together with its return value. This is useful for seeing how a function works.

In [11]:
import functools

def debug(func):
    """Print the function arguments as string and the return value"""
    @functools.wraps(func)
    def wrapper_debug(*args, **kwargs):
        args_repr = [repr(a) for a in args]                      # list of the positional arguments
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]  # list of the keyword arguments
        signature = ", ".join(args_repr + kwargs_repr)           # join positional and keyword arguments
        print(f"calling {func.__name__}({signature})")
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")           # print return value
        return value
    return wrapper_debug

As an example, we can trace how the recursive Fibonacci function works, using the Timer object t from above to measure the total time.

In [12]:
@debug
def fib(n):
    if n<2:
        return n
    else:
        return fib(n-1)+fib(n-2)

t.start()
fib(7)
t.stop()
calling fib(7)
calling fib(6)
calling fib(5)
calling fib(4)
calling fib(3)
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
calling fib(1)
'fib' returned 1
'fib' returned 2
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
'fib' returned 3
calling fib(3)
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
calling fib(1)
'fib' returned 1
'fib' returned 2
'fib' returned 5
calling fib(4)
calling fib(3)
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
calling fib(1)
'fib' returned 1
'fib' returned 2
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
'fib' returned 3
'fib' returned 8
calling fib(5)
calling fib(4)
calling fib(3)
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
calling fib(1)
'fib' returned 1
'fib' returned 2
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
'fib' returned 3
calling fib(3)
calling fib(2)
calling fib(1)
'fib' returned 1
calling fib(0)
'fib' returned 0
'fib' returned 1
calling fib(1)
'fib' returned 1
'fib' returned 2
'fib' returned 5
'fib' returned 13
Elapsed time: 0.0006 seconds
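The Fibonacci trace above only has positional arguments. For keyword arguments, the kwargs_repr line formats each one as key=value. A short sketch with a hypothetical function make_greeting, repeating the debug decorator so the snippet runs on its own:

```python
import functools

def debug(func):
    """Compact copy of the debug decorator above."""
    @functools.wraps(func)
    def wrapper_debug(*args, **kwargs):
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        print(f"calling {func.__name__}({signature})")
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")
        return value
    return wrapper_debug

@debug
def make_greeting(name, greeting="Hello"):
    return f"{greeting}, {name}!"

make_greeting("Maya", greeting="Hi")
# calling make_greeting('Maya', greeting='Hi')
# 'make_greeting' returned 'Hi, Maya!'
```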

error_handler decorator

Sometimes we want a program to keep running instead of terminating when an error occurs, while still being warned about the error. An error-handler decorator is useful for that purpose.

In [13]:
import functools

def error_handler(func):
    @functools.wraps(func)  # preserve the name and docstring of func
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)  # pass the result back when no error occurs
        except TypeError:
            print(f"function '{func.__name__}' received wrong data types.")
        except NameError:
            print(f"function '{func.__name__}' uses an undefined variable.")
        except ZeroDivisionError:
            print(f"function '{func.__name__}' contains division by zero.")
        except OSError as err:
            print("OS error:", err)
        except SystemError:
            print("There were SystemErrors")
        except ValueError:
            print(f"function '{func.__name__}' received an invalid value.")
        except Exception as err:  # any other error
            print(f"Unexpected {err=}, {type(err)=}")
        # after an error has been reported, the wrapper returns None
    return wrapper
In [14]:
@error_handler
def mean(a,b):
    return 2*(a*b)/(a+b)  # harmonic mean of a and b

mean(0,0)
function 'mean' contains division by zero.
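After catching an error, error_handler prints a message and returns None. When the caller needs a specific fallback value, the decorator itself can take arguments. A decorator factory, sketched below with the hypothetical names on_error and harmonic_mean, is a function that receives the options and returns the actual decorator, so @on_error(default=0.0, ...) is equivalent to harmonic_mean = on_error(default=0.0, ...)(harmonic_mean). The @lru_cache(maxsize=None) decorator used later in this tutorial follows the same pattern.

```python
import functools

def on_error(default=None, exceptions=(Exception,)):
    """Hypothetical decorator factory: return `default` instead of raising."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exceptions as err:  # only the listed exception types are caught
                print(f"{func.__name__} failed with {type(err).__name__}: {err}")
                return default
        return wrapper
    return decorator

@on_error(default=0.0, exceptions=(ZeroDivisionError,))
def harmonic_mean(a, b):
    return 2 * a * b / (a + b)

print(harmonic_mean(2, 6))  # 3.0
print(harmonic_mean(0, 0))  # prints the warning, then 0.0
```

Exceptions outside the listed types, such as a TypeError from harmonic_mean("a", 2), still propagate, which keeps unexpected bugs visible.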

flatten argument decorator

The following decorator allows a multi-argument function to be called with its arguments packed into a single list or tuple (source: Stack Overflow).

In [15]:
import functools

def flatten_args(func):
    @functools.wraps(func)
    def wrapper(*args):
        if len(args) == 1:
            return func(*args[0])
        else:
            return func(*args)
    return wrapper
In [16]:
@flatten_args
def power(base, exp):  # named power to avoid shadowing the built-in pow
    return base**exp
In [17]:
power(3,4)
Out[17]:
81
In [18]:
power([3,4])  # the list [3, 4] is unpacked into base and exp
Out[18]:
81
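One caveat: flatten_args unpacks any single argument, so a one-parameter function decorated with it fails when called with a plain number (func(*5) raises TypeError). A variant that unpacks only lists and tuples avoids this; the function names power and square are illustrative.

```python
import functools

def flatten_args(func):
    """Variant: unpack only when the single argument is a list or tuple."""
    @functools.wraps(func)
    def wrapper(*args):
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            return func(*args[0])
        return func(*args)
    return wrapper

@flatten_args
def power(base, exp):
    return base ** exp

@flatten_args
def square(x):
    return x * x

print(power(3, 4), power([3, 4]), power((3, 4)))  # 81 81 81
print(square(5))  # 25; the original version would try to unpack 5 and raise TypeError
```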

count_calls decorator

The following decorator keeps track of how many times the function has been called, storing the count in the num_calls attribute of the wrapper (source: Real Python).

In [19]:
import functools

def count_calls(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.num_calls += 1
        if wrapper.num_calls==1:
            print(f"calling {func.__name__!r} for {wrapper.num_calls} time")
        else:
            print(f"calling {func.__name__!r} for {wrapper.num_calls} times ")
        return func(*args, **kwargs)
    wrapper.num_calls = 0
    return wrapper
In [20]:
@count_calls
def fib(n):
    if n<2:
        return n
    else:
        return fib(n-1)+fib(n-2)

fib(7)
calling 'fib' for 1 time
calling 'fib' for 2 times 
calling 'fib' for 3 times 
calling 'fib' for 4 times 
calling 'fib' for 5 times 
calling 'fib' for 6 times 
calling 'fib' for 7 times 
calling 'fib' for 8 times 
calling 'fib' for 9 times 
calling 'fib' for 10 times 
calling 'fib' for 11 times 
calling 'fib' for 12 times 
calling 'fib' for 13 times 
calling 'fib' for 14 times 
calling 'fib' for 15 times 
calling 'fib' for 16 times 
calling 'fib' for 17 times 
calling 'fib' for 18 times 
calling 'fib' for 19 times 
calling 'fib' for 20 times 
calling 'fib' for 21 times 
calling 'fib' for 22 times 
calling 'fib' for 23 times 
calling 'fib' for 24 times 
calling 'fib' for 25 times 
calling 'fib' for 26 times 
calling 'fib' for 27 times 
calling 'fib' for 28 times 
calling 'fib' for 29 times 
calling 'fib' for 30 times 
calling 'fib' for 31 times 
calling 'fib' for 32 times 
calling 'fib' for 33 times 
calling 'fib' for 34 times 
calling 'fib' for 35 times 
calling 'fib' for 36 times 
calling 'fib' for 37 times 
calling 'fib' for 38 times 
calling 'fib' for 39 times 
calling 'fib' for 40 times 
calling 'fib' for 41 times 
Out[20]:
13
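The introduction mentioned that a wrapper can also be a class. Here is a class-based sketch of the same call counter (the class name CountCalls is hypothetical): the decorated name becomes an instance whose __call__ runs the function, and the count lives in an ordinary instance attribute. functools.update_wrapper copies the function's metadata, as functools.wraps does for function-based decorators.

```python
import functools

class CountCalls:
    """Class-based sketch of count_calls: the instance replaces the function."""
    def __init__(self, func):
        functools.update_wrapper(self, func)  # copy __name__, __doc__, ...
        self.func = func
        self.num_calls = 0

    def __call__(self, *args, **kwargs):
        self.num_calls += 1
        return self.func(*args, **kwargs)

@CountCalls
def fib(n):
    # the recursive calls go through the CountCalls instance bound to the name fib
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(7), fib.num_calls)  # 13 41
```

One trade-off: this class does not implement __get__, so unlike the function-based version it does not bind self when used to decorate a method inside another class.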

memoization decorator

The memoize decorator caches the results of a function. Before calling the function, it checks whether the function has already been called with the same arguments; if so, it returns the cached value instead of calling the function again, which speeds up repeated computations. Because the arguments are used as dictionary keys, they must be hashable.

In [21]:
import functools

def memoize(func):
    cache = {}
    
    @functools.wraps(func)
    def wrapper(*args):
        if args in cache:
            return cache[args]
        else:
            result=func(*args)
            cache[args] = result
            return result
    return wrapper

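To see how it works, let us use this memoize decorator to compute a Fibonacci number. First, we use the count_calls decorator to print the function calls, then we chain it with the memoize decorator. Decorators are applied from the bottom up, so this stack means fib = count_calls(memoize(fib)): every call is counted, including calls answered from the cache.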

In [22]:
@count_calls
@memoize
def fib(n):
    if n<2:
        return n
    else:
        return fib(n-1)+fib(n-2)

t.start()
fib(7)
t.stop()
calling 'fib' for 1 time
calling 'fib' for 2 times 
calling 'fib' for 3 times 
calling 'fib' for 4 times 
calling 'fib' for 5 times 
calling 'fib' for 6 times 
calling 'fib' for 7 times 
calling 'fib' for 8 times 
calling 'fib' for 9 times 
calling 'fib' for 10 times 
calling 'fib' for 11 times 
calling 'fib' for 12 times 
calling 'fib' for 13 times 
Elapsed time: 0.0001 seconds

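The chain of decorators works differently if we reverse the order. Now fib = memoize(count_calls(fib)), so calls answered from the cache never reach count_calls and only the cache misses are counted.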

In [23]:
@memoize
@count_calls
def fib(n):
    if n<2:
        return n
    else:
        return fib(n-1)+fib(n-2)

t.start()
fib(7)
t.stop()
calling 'fib' for 1 time
calling 'fib' for 2 times 
calling 'fib' for 3 times 
calling 'fib' for 4 times 
calling 'fib' for 5 times 
calling 'fib' for 6 times 
calling 'fib' for 7 times 
calling 'fib' for 8 times 
Elapsed time: 0.0003 seconds
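To see the bottom-up order in isolation, here is a sketch with two hypothetical decorators, bold and italic, that wrap a string in HTML-like tags. The decorator closest to def is applied first, so it ends up innermost.

```python
def bold(func):
    def wrapper():
        return "<b>" + func() + "</b>"
    return wrapper

def italic(func):
    def wrapper():
        return "<i>" + func() + "</i>"
    return wrapper

@bold
@italic
def hello():
    return "hello"

# Decorators apply bottom-up: hello = bold(italic(hello))
print(hello())  # <b><i>hello</i></b>

def plain():
    return "hello"

reversed_order = italic(bold(plain))  # same as @italic above @bold
print(reversed_order())  # <i><b>hello</b></i>
```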

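The Python standard library already provides a memoizing decorator, functools.lru_cache; you only need to import it and apply it. With maxsize=None the cache has no size limit (since Python 3.9, functools.cache is a shorthand for this).

Let us define the Fibonacci function without memoization, with our memoize decorator, and with lru_cache (without the debug or count_calls decorators) to compare their run times for a larger Fibonacci number.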

In [24]:
# without memoization
def fib0(n):
    if n<2:
        return n
    else:
        return fib0(n-1)+fib0(n-2)

from functools import lru_cache


@memoize
def fib1(n):
    if n<2:
        return n
    else:
        return fib1(n-1)+fib1(n-2)

@lru_cache(maxsize=None)
def fib2(n):
    if n<2:
        return n
    else:
        return fib2(n-1)+fib2(n-2)
In [25]:
# let us test them to see the difference in run time
n=35

print('without memoize')
t=Timer()
t.start()
fib0(n)
t.stop()

print('with memoize')
t.start()
fib1(n)
t.stop()

print('with lru_cache')
t.start()
fib2(n)
t.stop()
without memoize
Elapsed time: 2.1269 seconds
with memoize
Elapsed time: 0.0000 seconds
with lru_cache
Elapsed time: 0.0000 seconds
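The gap in run time follows from the number of calls. Without memoization, fib0(n) makes C(n) = C(n-1) + C(n-2) + 1 calls with C(0) = C(1) = 1, which works out to C(n) = 2F(n+1) - 1 where F is the Fibonacci sequence: 41 calls for n = 7 (matching the count_calls output above) and 2 × 14930352 - 1 = 29860703 calls for n = 35. With a cache, each value from fib(0) to fib(n) is computed only once. lru_cache exposes these statistics through cache_info(), and cache_clear() empties the cache; a sketch with n = 30:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fib2(n):
    return n if n < 2 else fib2(n - 1) + fib2(n - 2)

print(fib2(30))           # 832040
print(fib2.cache_info())  # CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)
fib2.cache_clear()        # empty the cache and reset the statistics, e.g. before timing again
print(fib2.cache_info())  # CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
```

For a fresh cache, fib2(n) misses once for each of the n + 1 values and hits n - 2 times (every fib2(n-2) lookup for n ≥ 3).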

plot decorator

Using a decorator, we can plot the output of any function that returns an array and still return that array to the caller.

In [26]:
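import functools
import matplotlib.pyplot as plt

def plot(func):
    @functools.wraps(func)  # keep the name and docstring of the wrapped function
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        plt.plot(result)  # plot the returned array against its index
        plt.show()
        return result     # still hand the array back to the caller
    return wrapper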
In [27]:
import numpy as np

@plot
def cosinewave():
    x = np.linspace(0, 2*np.pi, 100)
    y = np.cos(x)
    return y

@plot
def sinewave():
    x = np.linspace(0, 2*np.pi, 100)
    y = np.sin(x)
    return y
In [28]:
cosinewave()
Out[28]:
array([ 1.        ,  0.99798668,  0.99195481,  0.9819287 ,  0.9679487 ,
        0.95007112,  0.92836793,  0.90292654,  0.87384938,  0.84125353,
        0.80527026,  0.76604444,  0.72373404,  0.67850941,  0.63055267,
        0.58005691,  0.52722547,  0.47227107,  0.41541501,  0.35688622,
        0.29692038,  0.23575894,  0.17364818,  0.1108382 ,  0.04758192,
       -0.01586596, -0.07924996, -0.14231484, -0.20480667, -0.26647381,
       -0.32706796, -0.38634513, -0.44406661, -0.5       , -0.55392006,
       -0.60560969, -0.65486073, -0.70147489, -0.74526445, -0.78605309,
       -0.82367658, -0.85798341, -0.88883545, -0.91610846, -0.93969262,
       -0.95949297, -0.97542979, -0.98743889, -0.99547192, -0.99949654,
       -0.99949654, -0.99547192, -0.98743889, -0.97542979, -0.95949297,
       -0.93969262, -0.91610846, -0.88883545, -0.85798341, -0.82367658,
       -0.78605309, -0.74526445, -0.70147489, -0.65486073, -0.60560969,
       -0.55392006, -0.5       , -0.44406661, -0.38634513, -0.32706796,
       -0.26647381, -0.20480667, -0.14231484, -0.07924996, -0.01586596,
        0.04758192,  0.1108382 ,  0.17364818,  0.23575894,  0.29692038,
        0.35688622,  0.41541501,  0.47227107,  0.52722547,  0.58005691,
        0.63055267,  0.67850941,  0.72373404,  0.76604444,  0.80527026,
        0.84125353,  0.87384938,  0.90292654,  0.92836793,  0.95007112,
        0.9679487 ,  0.9819287 ,  0.99195481,  0.99798668,  1.        ])
In [29]:
sinewave()
Out[29]:
array([ 0.00000000e+00,  6.34239197e-02,  1.26592454e-01,  1.89251244e-01,
        2.51147987e-01,  3.12033446e-01,  3.71662456e-01,  4.29794912e-01,
        4.86196736e-01,  5.40640817e-01,  5.92907929e-01,  6.42787610e-01,
        6.90079011e-01,  7.34591709e-01,  7.76146464e-01,  8.14575952e-01,
        8.49725430e-01,  8.81453363e-01,  9.09631995e-01,  9.34147860e-01,
        9.54902241e-01,  9.71811568e-01,  9.84807753e-01,  9.93838464e-01,
        9.98867339e-01,  9.99874128e-01,  9.96854776e-01,  9.89821442e-01,
        9.78802446e-01,  9.63842159e-01,  9.45000819e-01,  9.22354294e-01,
        8.95993774e-01,  8.66025404e-01,  8.32569855e-01,  7.95761841e-01,
        7.55749574e-01,  7.12694171e-01,  6.66769001e-01,  6.18158986e-01,
        5.67059864e-01,  5.13677392e-01,  4.58226522e-01,  4.00930535e-01,
        3.42020143e-01,  2.81732557e-01,  2.20310533e-01,  1.58001396e-01,
        9.50560433e-02,  3.17279335e-02, -3.17279335e-02, -9.50560433e-02,
       -1.58001396e-01, -2.20310533e-01, -2.81732557e-01, -3.42020143e-01,
       -4.00930535e-01, -4.58226522e-01, -5.13677392e-01, -5.67059864e-01,
       -6.18158986e-01, -6.66769001e-01, -7.12694171e-01, -7.55749574e-01,
       -7.95761841e-01, -8.32569855e-01, -8.66025404e-01, -8.95993774e-01,
       -9.22354294e-01, -9.45000819e-01, -9.63842159e-01, -9.78802446e-01,
       -9.89821442e-01, -9.96854776e-01, -9.99874128e-01, -9.98867339e-01,
       -9.93838464e-01, -9.84807753e-01, -9.71811568e-01, -9.54902241e-01,
       -9.34147860e-01, -9.09631995e-01, -8.81453363e-01, -8.49725430e-01,
       -8.14575952e-01, -7.76146464e-01, -7.34591709e-01, -6.90079011e-01,
       -6.42787610e-01, -5.92907929e-01, -5.40640817e-01, -4.86196736e-01,
       -4.29794912e-01, -3.71662456e-01, -3.12033446e-01, -2.51147987e-01,
       -1.89251244e-01, -1.26592454e-01, -6.34239197e-02, -2.44929360e-16])

ode_solver decorator

We can turn any reusable routine into a decorator. For instance, here we create an ordinary differential equation (ODE) solver and use it as a decorator on any function that returns the derivative dy/dt = f(t, y, ...).

In [30]:
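import functools
import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

def ode_solver(func):
    # func must follow solve_ivp's convention: func(t, y, *args) returns dy/dt
    @functools.wraps(func)
    def wrapper(t_span, y0, args=(), **kwargs):
        # extra keyword arguments (e.g. method="RK45") are passed on to solve_ivp
        solution = solve_ivp(func, t_span, y0, args=args, **kwargs)
        return solution.t, solution.y  # time points and solution values
    return wrapper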
In [31]:
@ode_solver
def exponential_func(t,y,k):
    return k*y
In [32]:
t_span=[0,10]
y0=[1]
k=0.3
t,y=exponential_func(t_span,y0,args=(k,),method="RK45")

plt.plot(t,y[0])
plt.xlabel('time')
plt.ylabel('population')
plt.show()

print(t,y)
[ 0.          0.1272514   1.39976539  4.88735235  9.11252593 10.        ] [[ 1.          1.03891346  1.52185507  4.33284482 15.39054055 20.08544355]]
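As a sanity check (not part of the original notebook), exponential_func with k = 0.3 and y0 = [1] encodes an ODE with a known closed-form solution, so the last value returned by solve_ivp can be compared with it:

```latex
\frac{dy}{dt} = k\,y, \quad y(0) = 1
\quad\Longrightarrow\quad
y(t) = e^{kt}, \qquad
y(10) = e^{0.3 \times 10} = e^{3} \approx 20.0855
```

The numerical value 20.08544355 at t = 10 agrees with e^3 ≈ 20.08554 to within about 1e-4; the small gap reflects the adaptive RK45 step control, whose default relative tolerance in solve_ivp is rtol=1e-3.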

References

  • https://realpython.com/primer-on-python-decorators/
  • https://realpython.com/python-timer/
  • https://www.geeksforgeeks.org/python-decorators-a-complete-guide/
  • https://rapd.wordpress.com/2008/07/02/python-staticmethod-vs-classmethod/

Last Update: 04 June 2023