6248 字
31 分钟

Python 装饰器进阶完全指南:带参数的装饰器 / 类装饰器 / functools 深度解析 + 实战场景

《Python 装饰器详解》 中,我们已经掌握了装饰器的基本概念、语法糖 @decorator 的用法、functools.wraps 的基本作用,以及如何处理带参数和返回值的函数。

然而在实际项目中,你可能会遇到更复杂的场景:需要根据不同配置动态调整装饰器行为、用类来实现装饰器以便维护状态、或者在一个函数上叠加多个装饰器。本文将深入讲解这些进阶用法,并通过实战场景让你掌握装饰器的完整能力。

Python装饰器进阶示意图


🔁 回顾:装饰器基础#

在深入之前,我们用一张图快速回顾装饰器的核心机制:

┌─────────────────────────────────────────┐
│ @decorator │
│ def func(): │
│ pass │
│ │
│ 等价于:func = decorator(func) │
│ │
│ 装饰器 = 接收函数,返回新函数的函数 │
└─────────────────────────────────────────┘

如果你还不熟悉上面的概念,建议先阅读 《Python 装饰器详解》 打好基础。


🎯 一、带参数的装饰器#

1.1 为什么需要带参数的装饰器?#

在基础篇中我们已经了解了如何写一个带参数的装饰器,但实际应用场景远比你想象的丰富。考虑以下需求:

  • 你写了一个计时器装饰器,但有时希望输出毫秒,有时希望输出秒
  • 你写了一个日志装饰器,希望能灵活选择输出到控制台还是文件
  • 你写了一个重试装饰器,希望能自定义重试次数和间隔

这类场景下,装饰器本身需要接收配置参数来改变行为。

1.2 三层嵌套函数结构#

带参数的装饰器本质上是一个返回装饰器的函数,形成了三层嵌套结构:

outer(参数) → 返回一个装饰器 decorator
└── decorator(func) → 返回包装函数 wrapper
└── wrapper(*args, **kwargs) → 实际执行逻辑

代码示例:

def repeat(times=3):
"""让被装饰的函数重复执行 times 次"""
def decorator(func):
def wrapper(*args, **kwargs):
results = []
for i in range(times):
result = func(*args, **kwargs)
results.append(result)
return results
return wrapper
return decorator

使用方式:

@repeat(times=5)
def greet(name):
return f"Hello, {name}!"
print(greet("Python"))
# 输出: ['Hello, Python!', 'Hello, Python!', 'Hello, Python!', 'Hello, Python!', 'Hello, Python!']

1.3 更复杂的示例:可配置的日志装饰器#

import logging
from functools import wraps
def log(level=logging.INFO, logger_name=None, message=None):
"""
可配置的日志装饰器
参数:
level: 日志级别 (logging.DEBUG, INFO, WARNING, ERROR)
logger_name: 使用哪个 logger(None 表示使用 root logger)
message: 自定义日志消息(支持 {func_name} 占位符)
"""
def decorator(func):
logger = logging.getLogger(logger_name) if logger_name else logging.getLogger(__name__)
log_message = message or "{func_name} called"
@wraps(func)
def wrapper(*args, **kwargs):
# 调用前记录日志
logger.log(level, log_message.format(func_name=func.__name__))
logger.log(level, f" args: {args}")
logger.log(level, f" kwargs: {kwargs}")
result = func(*args, **kwargs)
# 调用后记录日志
logger.log(level, f" result: {result}")
return result
return wrapper
return decorator

使用示例:

@log(level=logging.DEBUG, message="调用函数: {func_name}")
def add(a, b):
return a + b
@log(level=logging.WARNING, logger_name="myapp", message="【警告】执行 {func_name}")
def risky_operation(data):
return process(data)

1.4 灵活的参数解析:支持省略括号#

有时候你希望装饰器既可以带参数使用 @log(level=DEBUG),也可以不带参数直接使用 @log。这需要在装饰器中进行参数类型判断:

import logging
from functools import wraps
def log(arg=None, *, level=logging.INFO, message=None):
"""
支持两种用法:
@log # 无参数
@log(level=DEBUG) # 带参数
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
msg = message or f"{func.__name__} called"
logging.log(level, msg)
return func(*args, **kwargs)
return wrapper
# 如果第一个参数是函数,说明是 @log 的用法
if callable(arg):
return decorator(arg)
# 如果第一个参数是 None 或其他值,说明是 @log(...) 的用法
return decorator

两种用法都能正常工作:

@log
def func1():
pass
@log(level=logging.ERROR, message="错误发生")
def func2():
pass

🎯 二、类装饰器#

2.1 什么是类装饰器?#

类装饰器(Class Decorator)使用类来实现装饰器功能。当类实现了 __call__ 方法后,它的实例就可以像函数一样被调用,这就是类装饰器的核心机制。

class Decorator:
def __init__(self, func):
self.func = func # 保存被装饰的函数
def __call__(self, *args, **kwargs):
# 这里写装饰器逻辑,相当于 wrapper 函数
print("Before call")
result = self.func(*args, **kwargs)
print("After call")
return result

使用方式和函数装饰器一样:

@Decorator
def greet(name):
return f"Hello, {name}!"
print(greet("World"))
# Before call
# After call
# Hello, World!

2.2 类装饰器的优势#

类装饰器相比函数装饰器有以下优势:

特性类装饰器函数装饰器
维护状态✅ 易于使用实例变量⚠️ 需要使用闭包或 nonlocal
代码组织✅ 逻辑清晰,可分组方法❌ 所有逻辑嵌套在函数中
继承复用✅ 可通过继承扩展装饰器❌ 难以复用
复杂场景✅ 适合处理复杂逻辑⚠️ 嵌套过深难以维护

2.3 类装饰器实战:带计数的装饰器#

from functools import wraps
class CountCalls:
"""记录函数被调用的次数"""
def __init__(self, func):
self.func = func
self.count = 0 # 计数器,维护在实例中
def __call__(self, *args, **kwargs):
self.count += 1
print(f"{self.func.__name__} 已被调用 {self.count} 次")
return self.func(*args, **kwargs)

使用效果:

@CountCalls
def fibonacci(n):
"""计算第 n 个斐波那契数"""
if n < 2:
return n
return fibonacci(n - 1) + fibonacci(n - 2)
print(fibonacci(10))
# fibonacci 已被调用 1 次
# fibonacci 已被调用 2 次
# ...
# 55
print(f"函数总共被调用了 {fibonacci.count} 次")
# 函数总共被调用了 177 次

💡 注意:类装饰器中要保留原函数的元信息(__name____doc__ 等),需要在 __init__ 中使用 functools.update_wrapper

改进版:

from functools import update_wrapper
class CountCalls:
def __init__(self, func):
update_wrapper(self, func) # 保留原函数的元信息
self.func = func
self.count = 0
def __call__(self, *args, **kwargs):
self.count += 1
return self.func(*args, **kwargs)

2.4 带参数的类装饰器#

类装饰器也可以接收参数,方法是在 __init__ 中接收参数,在 __call__ 中接收函数:

import time
from functools import update_wrapper
class Timer:
"""
可配置的计时器装饰器
参数:
unit: 's' | 'ms' | 'us',输出时间单位
threshold: 超过该阈值的调用才输出日志(秒)
"""
def __init__(self, unit="s", threshold=0):
self.unit = unit
self.threshold = threshold
def __call__(self, func):
update_wrapper(self, func)
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
# 根据单位转换
if self.unit == "ms":
elapsed_display = elapsed * 1000
unit_label = "ms"
elif self.unit == "us":
elapsed_display = elapsed * 1_000_000
unit_label = "μs"
else:
elapsed_display = elapsed
unit_label = "s"
# 超过阈值才输出
if elapsed >= self.threshold:
print(f"{func.__name__}: {elapsed_display:.3f} {unit_label}")
return result
return wrapper

使用方式:

@Timer(unit="ms", threshold=0.1) # 超过 0.1 秒才输出日志,毫秒为单位
def slow_function(n):
total = sum(i * i for i in range(n))
return total
slow_function(1_000_000)
# slow_function: 128.456 ms
slow_function(1_000)
# (无输出,因为执行时间低于阈值)

🎯 三、functools 模块深度解析#

functools 是 Python 标准库中与装饰器最密切相关的模块。除了 wraps,还有几个非常强大的工具值得深入掌握。

3.1 functools.wraps:装饰器的标配#

@wraps 是装饰器的标配,作用是将被包装函数 func 的元信息(__name____doc____module__ 等)复制到包装函数 wrapper 上:

from functools import wraps
# ❌ 没有使用 wraps
def bad_decorator(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
# ✅ 使用了 wraps
def good_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@bad_decorator
def func_a():
"""这是一个测试函数"""
pass
@good_decorator
def func_b():
"""这是一个测试函数"""
pass
print(func_a.__name__) # 'wrapper' ❌ 信息丢失了!
print(func_a.__doc__) # None
print(func_b.__name__) # 'func_b' ✅ 保留了原信息
print(func_b.__doc__) # '这是一个测试函数'

3.2 functools.lru_cache:缓存装饰器#

lru_cache(Least Recently Used Cache)是 Python 内置的缓存装饰器,可以自动缓存函数的调用结果,是最实用的装饰器之一。

基本用法#

from functools import lru_cache
@lru_cache(maxsize=128)
def fibonacci(n):
"""计算第 n 个斐波那契数"""
if n < 2:
return n
return fibonacci(n - 1) + fibonacci(n - 2)

我们来对比一下有无缓存的性能差异:

# 无缓存
def fib_slow(n):
if n < 2:
return n
return fib_slow(n - 1) + fib_slow(n - 2)
# 有缓存
@lru_cache(maxsize=None)
def fib_fast(n):
if n < 2:
return n
return fib_fast(n - 1) + fib_fast(n - 2)
import time
start = time.perf_counter()
fib_slow(35)
print(f"无缓存: {time.perf_counter() - start:.2f}s")
# 无缓存: 2.34s
start = time.perf_counter()
fib_fast(35)
print(f"有缓存: {time.perf_counter() - start:.6f}s")
# 有缓存: 0.000012s

🚀 性能提升:使用 lru_cache 后,递归版斐波那契数列的时间复杂度从 O(2ⁿ) 降到了 O(n)

lru_cache 的参数详解#

@lru_cache(maxsize=128, typed=False)
def your_function(a, b):
...
参数类型默认值说明
maxsizeint / None128缓存的最大条目数。设为 None 表示不限制,即缓存所有结果
typedboolFalse如果为 True,则不同类型的参数会分别缓存(例如 f(3)f(3.0) 视为不同调用)

查看和操作缓存#

@lru_cache(maxsize=100)
def expensive_func(n):
return n ** n
# 调用几次
expensive_func(10)
expensive_func(20)
expensive_func(30)
# 查看缓存信息
info = expensive_func.cache_info()
print(info)
# CacheInfo(hits=0, misses=3, maxsize=100, currsize=3)
# 清空缓存
expensive_func.cache_clear()
# 查看清空后的状态
print(expensive_func.cache_info())
# CacheInfo(hits=0, misses=0, maxsize=100, currsize=0)

lru_cache 的限制#

lru_cache 虽然强大,但有一个关键限制:函数的所有参数必须是可哈希的(hashable)。

@lru_cache(maxsize=None)
def process_list(lst): # ❌ 错误!list 是不可哈希的
return sum(lst)
process_list([1, 2, 3])
# TypeError: unhashable type: 'list'

解决方案:将可变参数转换为不可变类型,或者手动实现缓存逻辑。

# 方案一:传参时转换
@lru_cache(maxsize=None)
def process_tuple(tpl): # ✅ tuple 是可哈希的
return sum(tpl)
process_tuple(tuple([1, 2, 3]))
# 方案二:使用 functools.cache (Python 3.9+)
# cache 等价于 lru_cache(maxsize=None)
from functools import cache
@cache
def process_list(lst):
# 内部手动处理
key = tuple(lst)
...

⚠️ Python 3.9 新增functools.cachelru_cache(maxsize=None) 的简写形式,写法更简洁。

3.3 functools.singledispatch:单分派泛函数#

singledispatch 允许你根据第一个参数的类型来选择不同的函数实现,这在处理多种输入类型时非常有用。

from functools import singledispatch
@singledispatch
def format_data(data):
"""根据数据类型以不同方式格式化输出"""
raise NotImplementedError(f"不支持的类型: {type(data)}")
# 注册针对 str 的实现
@format_data.register(str)
def _(data):
return f"字符串: '{data}'"
# 注册针对 int 的实现
@format_data.register(int)
def _(data):
return f"整数: {data} (二进制: {bin(data)})"
# 注册针对 list 的实现
@format_data.register(list)
def _(data):
formatted = ", ".join(str(item) for item in data)
return f"列表(长度{len(data)}): [{formatted}]"
# 注册针对 dict 的实现
@format_data.register(dict)
def _(data):
items = ", ".join(f"{k}: {v}" for k, v in data.items())
return f"字典: {{{items}}}"

测试一下:

print(format_data("hello")) # 字符串: 'hello'
print(format_data(42)) # 整数: 42 (二进制: 0b101010)
print(format_data([1, 2, 3])) # 列表(长度3): [1, 2, 3]
print(format_data({"a": 1, "b": 2})) # 字典: {a: 1, b: 2}
print(format_data(3.14)) # NotImplementedError: 不支持的类型: <class 'float'>

singledispatch 相比传统的 if/elif 链有这些好处:

  • 可扩展:任何人都可以在任何地方注册新的类型处理
  • 模块化:每种类型的处理逻辑可以单独定义
  • 清晰:主函数只做分发,不关心具体实现
# 传统写法:if/elif 链,难以扩展
def format_data_traditional(data):
if isinstance(data, str):
return f"字符串: '{data}'"
elif isinstance(data, int):
return f"整数: {data}"
elif isinstance(data, list):
return f"列表: {data}"
# ... 越来越长
else:
raise TypeError()

3.4 functools.partial:偏函数#

partial 不是装饰器,但它与装饰器配合使用非常常见。它的作用是固定函数的某些参数,创建一个新的简化版本。

from functools import partial
# 原始函数
def power(base, exponent):
return base ** exponent
# 创建偏函数:固定 exponent=2
square = partial(power, exponent=2)
# 创建偏函数:固定 exponent=3
cube = partial(power, exponent=3)
print(square(5)) # 25 (等价于 power(5, exponent=2))
print(cube(5)) # 125 (等价于 power(5, exponent=3))

与装饰器配合使用的场景:

from functools import wraps, partial
def decorator_with_partial(func=None, *, prefix="LOG"):
"""使用 partial 实现可省略括号的装饰器"""
if func is None:
return partial(decorator_with_partial, prefix=prefix)
@wraps(func)
def wrapper(*args, **kwargs):
print(f"[{prefix}] 调用 {func.__name__}")
return func(*args, **kwargs)
return wrapper
# 两种用法都支持
@decorator_with_partial
def func1():
pass
@decorator_with_partial(prefix="WARNING")
def func2():
pass

3.5 functools.total_ordering:自动生成比较方法#

total_ordering 是一个类装饰器(装饰整个类),它可以根据你定义的少数比较方法,自动为类生成完整的比较运算符。

from functools import total_ordering
@total_ordering
class Version:
def __init__(self, major, minor, patch=0):
self.major = major
self.minor = minor
self.patch = patch
def _key(self):
return (self.major, self.minor, self.patch)
# 只需定义两个方法:__eq__ 和 __lt__
def __eq__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._key() == other._key()
def __lt__(self, other):
if not isinstance(other, Version):
return NotImplemented
return self._key() < other._key()
# total_ordering 自动生成:
# __le__ (<=), __gt__ (>), __ge__ (>=), __ne__ (!=)

测试自动生成的比较方法:

v1 = Version(2, 1, 0)
v2 = Version(2, 5, 0)
v3 = Version(2, 1, 0)
print(v1 < v2) # True (你定义的)
print(v1 <= v2) # True (自动生成)
print(v1 > v2) # False (自动生成)
print(v1 >= v3) # True (自动生成)
print(v1 != v2) # True (自动生成)
print(v1 == v3) # True (你定义的)

💡 优点:减少重复代码,确保比较运算符的一致性

⚠️ 注意total_ordering 生成的方法会通过调用你定义的 __eq____lt__ 来实现,可能比手写所有方法略慢一点,但通常可以忽略不计。


🎯 四、装饰器叠加与执行顺序#

4.1 多个装饰器叠加#

Python 允许在一个函数上叠加多个装饰器,语法如下:

@decorator1
@decorator2
@decorator3
def func():
pass

等价于:

func = decorator1(decorator2(decorator3(func)))

4.2 执行顺序详解#

理解装饰器叠加的关键在于分清两个阶段:

阶段执行顺序说明
装饰阶段(定义时)从下到上decorator3decorator2decorator1
执行阶段(调用时)从上到下decorator1 前置 → decorator2 前置 → decorator3 前置 → 原函数 → decorator3 后置 → decorator2 后置 → decorator1 后置

通过代码示例验证:

def decorator1(func):
print("装饰阶段: decorator1 被调用")
def wrapper():
print("执行阶段: decorator1 前置代码")
func()
print("执行阶段: decorator1 后置代码")
return wrapper
def decorator2(func):
print("装饰阶段: decorator2 被调用")
def wrapper():
print("执行阶段: decorator2 前置代码")
func()
print("执行阶段: decorator2 后置代码")
return wrapper
def decorator3(func):
print("装饰阶段: decorator3 被调用")
def wrapper():
print("执行阶段: decorator3 前置代码")
func()
print("执行阶段: decorator3 后置代码")
return wrapper
@decorator1
@decorator2
@decorator3
def func():
print("原始函数被调用")
# === 定义阶段输出(按从上到下阅读代码时输出)===
# 装饰阶段: decorator3 被调用
# 装饰阶段: decorator2 被调用
# 装饰阶段: decorator1 被调用
# === 调用时输出 ===
# 执行阶段: decorator1 前置代码
# 执行阶段: decorator2 前置代码
# 执行阶段: decorator3 前置代码
# 原始函数被调用
# 执行阶段: decorator3 后置代码
# 执行阶段: decorator2 后置代码
# 执行阶段: decorator1 后置代码

用图示理解执行阶段的流程:

调用 func()
┌─ decorator1.wrapper ────────────────────┐
│ print("decorator1 前置") │
│ ┌─ decorator2.wrapper ────────────────┐│
│ │ print("decorator2 前置") ││
│ │ ┌─ decorator3.wrapper ────────────┐││
│ │ │ print("decorator3 前置") │││
│ │ │ func() # 原始函数 │││
│ │ │ print("decorator3 后置") │││
│ │ └─────────────────────────────────┘││
│ │ print("decorator2 后置") ││
│ └─────────────────────────────────────┘│
│ print("decorator1 后置") │
└─────────────────────────────────────────┘

4.3 实战:叠加日志 + 计时 + 缓存装饰器#

from functools import wraps, lru_cache
import time
import logging
logging.basicConfig(level=logging.INFO)
def log_calls(func):
@wraps(func)
def wrapper(*args, **kwargs):
logging.info(f"CALL: {func.__name__}{args}")
result = func(*args, **kwargs)
logging.info(f"DONE: {func.__name__}, result={result}")
return result
return wrapper
def measure_time(func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
logging.info(f"TIME: {func.__name__} 耗时 {elapsed*1000:.2f}ms")
return result
return wrapper
# 顺序很重要!
# 从内到外:缓存 → 计时 → 日志
@log_calls # 最外层:每次调用都记录日志
@measure_time # 中间层:计时(包括缓存命中的快速返回)
@lru_cache(maxsize=128) # 最内层:缓存(最接近原函数)
def fibonacci(n):
if n < 2:
return n
return fibonacci(n - 1) + fibonacci(n - 2)

⚠️ 装饰器顺序的重要性:在上面的例子中,lru_cache 放在最内层是正确的——它直接作用于计算逻辑,避免重复计算。如果把它放到最外层,虽然缓存依然有效,但每次调用都会先经过 log_callsmeasure_time,导致日志和计时信息不准确(缓存命中也会产生日志和计时)。

经验法则

  • 函数逻辑密切相关的装饰器放内层(如缓存、参数验证)
  • 通用处理的装饰器放外层(如日志、计时、错误处理)

🎯 五、装饰器实战场景#

下面是装饰器在真实项目中的 5 个经典应用场景。

5.1 场景一:计时器(性能分析)#

from functools import wraps
import time
import statistics
class timer:
"""
智能计时器装饰器
功能:
- 记录每次调用的耗时
- 统计平均/最大/最小耗时
- 可设置超时警告
"""
def __init__(self, warn_threshold=None):
self.warn_threshold = warn_threshold
self.timings = []
def __call__(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time.perf_counter()
result = func(*args, **kwargs)
elapsed = time.perf_counter() - start
self.timings.append(elapsed)
msg = f"[{func.__name__}] {elapsed*1000:.2f}ms"
if self.warn_threshold and elapsed > self.warn_threshold:
msg += f" ⚠️ 超过阈值 {self.warn_threshold*1000:.0f}ms"
print(msg)
return result
# 给 wrapper 添加统计方法
wrapper.stats = lambda: self._get_stats(func.__name__)
wrapper.reset = lambda: self.timings.clear()
return wrapper
def _get_stats(self, name):
if not self.timings:
return f"[{name}] 尚无记录"
stats = {
"count": len(self.timings),
"avg_ms": statistics.mean(self.timings) * 1000,
"max_ms": max(self.timings) * 1000,
"min_ms": min(self.timings) * 1000,
"std_ms": statistics.stdev(self.timings) * 1000 if len(self.timings) > 1 else 0,
}
return (f"[{name}] 共 {stats['count']} 次调用, "
f"平均 {stats['avg_ms']:.2f}ms, "
f"最大 {stats['max_ms']:.2f}ms, "
f"最小 {stats['min_ms']:.2f}ms, "
f"标准差 {stats['std_ms']:.2f}ms")

使用示例:

@timer(warn_threshold=0.1) # 超过 100ms 警告
def process_data(n):
total = sum(i * i for i in range(n))
return total
for n in [1000, 10000, 100000, 1000000]:
process_data(n)
# 输出:
# [process_data] 0.05ms
# [process_data] 0.48ms
# [process_data] 5.12ms
# [process_data] 58.34ms ⚠️ 超过阈值 100ms
# 获取统计信息
print(process_data.stats())
# [process_data] 共 4 次调用, 平均 16.00ms, 最大 58.34ms, ...

5.2 场景二:重试装饰器(网络请求/不稳定操作)#

from functools import wraps
import time
import random
def retry(max_attempts=3, delay=1, backoff=2, exceptions=(Exception,)):
"""
自动重试装饰器
参数:
max_attempts: 最大尝试次数
delay: 初始重试间隔(秒)
backoff: 每次重试后延迟的倍数(指数退避)
exceptions: 需要捕获的异常类型(元组)
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
current_delay = delay
last_exception = None
for attempt in range(1, max_attempts + 1):
try:
return func(*args, **kwargs)
except exceptions as e:
last_exception = e
if attempt == max_attempts:
print(f"[{func.__name__}] 第 {attempt} 次失败,已达最大次数,放弃")
raise
print(f"[{func.__name__}] 第 {attempt} 次失败: {e},"
f"{current_delay} 秒后重试...")
time.sleep(current_delay)
current_delay *= backoff # 指数退避
return wrapper
return decorator

模拟一个不稳定的网络请求:

@retry(max_attempts=5, delay=1, backoff=2)
def fetch_data_from_api():
"""模拟一个不稳定的 API 请求"""
if random.random() < 0.7: # 70% 概率失败
raise ConnectionError("网络连接超时")
return "成功获取数据"
try:
result = fetch_data_from_api()
print(f"结果: {result}")
except Exception as e:
print(f"最终失败: {e}")

可能的输出:

[fetch_data_from_api] 第 1 次失败: 网络连接超时,1 秒后重试...
[fetch_data_from_api] 第 2 次失败: 网络连接超时,2 秒后重试...
结果: 成功获取数据

5.3 场景三:权限验证装饰器(Web 框架常用)#

from functools import wraps
from enum import Enum
class UserRole(Enum):
GUEST = "guest"
USER = "user"
ADMIN = "admin"
class PermissionDenied(Exception):
pass
def require_role(*allowed_roles):
"""
权限验证装饰器
使用:
@require_role(UserRole.USER, UserRole.ADMIN)
def view_profile(user):
...
"""
def decorator(func):
@wraps(func)
def wrapper(user, *args, **kwargs):
if not hasattr(user, "role") or user.role not in allowed_roles:
allowed = ", ".join(r.value for r in allowed_roles)
raise PermissionDenied(
f"需要以下角色之一: {allowed},当前角色: {getattr(user, 'role', 'N/A')}"
)
return func(user, *args, **kwargs)
return wrapper
return decorator

使用示例:

class User:
def __init__(self, name, role):
self.name = name
self.role = role
@require_role(UserRole.ADMIN)
def delete_user(admin_user, target_user_id):
return f"{admin_user.name} 删除了用户 {target_user_id}"
@require_role(UserRole.USER, UserRole.ADMIN)
def view_profile(user, profile_id):
return f"{user.name} 查看了 {profile_id} 的资料"
# 测试
admin = User("超级管理员", UserRole.ADMIN)
normal_user = User("张三", UserRole.USER)
guest = User("访客", UserRole.GUEST)
print(delete_user(admin, 123)) # 成功: 超级管理员 删除了用户 123
print(view_profile(normal_user, 456)) # 成功: 张三 查看了 456 的资料
try:
delete_user(normal_user, 123) # 失败: 权限不足
except PermissionDenied as e:
print(f"错误: {e}")

5.4 场景四:数据验证装饰器#

from functools import wraps
from typing import Callable, Any
class ValidationError(Exception):
pass
def validate(**validators):
"""
参数验证装饰器
使用:
@validate(age=lambda x: x > 0, name=lambda x: len(x) > 0)
def create_user(name, age):
...
"""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# 将位置参数和关键字参数合并为命名参数字典
import inspect
sig = inspect.signature(func)
bound = sig.bind_partial(*args, **kwargs)
bound.apply_defaults()
# 逐个验证
for param_name, validator in validators.items():
if param_name in bound.arguments:
value = bound.arguments[param_name]
try:
if not validator(value):
raise ValidationError(
f"参数 '{param_name}' = {repr(value)} 未通过验证"
)
except ValidationError:
raise
except Exception as e:
raise ValidationError(
f"参数 '{param_name}' 验证时发生错误: {e}"
)
return func(*args, **kwargs)
return wrapper
return decorator

使用示例:

@validate(
username=lambda s: isinstance(s, str) and 3 <= len(s) <= 20,
age=lambda n: isinstance(n, int) and 0 < n < 150,
email=lambda e: isinstance(e, str) and "@" in e and "." in e,
)
def register_user(username, age, email):
return f"用户 {username} 注册成功"
print(register_user("alice", 25, "alice@example.com")) # 成功
try:
register_user("ab", 200, "invalid-email")
except ValidationError as e:
print(f"验证失败: {e}")

5.5 场景五:速率限制装饰器(API 限流)#

from functools import wraps
import time
from collections import defaultdict
class RateLimit:
"""
速率限制装饰器(滑动窗口算法)
参数:
max_calls: 时间窗口内的最大调用次数
period: 时间窗口(秒)
key_func: 生成限流键的函数(用于区分不同调用者)
"""
def __init__(self, max_calls=10, period=60, key_func=lambda *a, **kw: "default"):
self.max_calls = max_calls
self.period = period
self.key_func = key_func
self._calls = defaultdict(list) # key -> 调用时间戳列表
def __call__(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
key = self.key_func(*args, **kwargs)
now = time.time()
# 清理过期的调用记录
self._calls[key] = [
t for t in self._calls[key]
if now - t < self.period
]
# 检查是否超限
if len(self._calls[key]) >= self.max_calls:
wait_time = self.period - (now - self._calls[key][0])
raise Exception(
f"调用频率过高!请等待 {wait_time:.1f} 秒后重试。"
f"({self.max_calls} 次/{self.period} 秒)"
)
self._calls[key].append(now)
return func(*args, **kwargs)
return wrapper

实际应用:模拟 API 调用限流

@RateLimit(max_calls=3, period=5, key_func=lambda user, **kw: user)
def api_call(user, data):
"""模拟 API 调用,每个用户 5 秒内最多调用 3 次"""
return f"{user} 调用 API,数据: {data}"
# 测试:用户 A 在短时间内多次调用
for i in range(5):
try:
print(api_call("用户A", f"请求{i+1}"))
except Exception as e:
print(f" 错误: {e}")
# 输出:
# 用户A 调用 API,数据: 请求1
# 用户A 调用 API,数据: 请求2
# 用户A 调用 API,数据: 请求3
# 错误: 调用频率过高!请等待 5.0 秒后重试。(3 次/5 秒)
# 错误: 调用频率过高!请等待 5.0 秒后重试。(3 次/5 秒)
# 但用户 B 可以正常调用(不同的限流键)
print(api_call("用户B", "请求1")) # 用户B 调用 API,数据: 请求1

🎯 六、装饰器常见陷阱#

6.1 陷阱一:忘记使用 functools.wraps#

问题:装饰器覆盖了原函数的元信息。

def decorator(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@decorator
def my_func():
"""这是一个重要的函数"""
pass
print(my_func.__name__) # 'wrapper' ❌
print(my_func.__doc__) # None ❌
help(my_func) # 显示 wrapper 的信息 ❌

解决:总是使用 @wraps(func)

6.2 陷阱二:装饰器无法正确处理递归#

问题:如果装饰器应用到递归函数上,每次递归调用都会触发装饰器逻辑,可能导致意外的性能开销或行为。

@timer() # 每次递归都会计时并输出日志
def factorial(n):
if n <= 1:
return 1
return n * factorial(n - 1)
factorial(5)
# [factorial] 0.00ms
# [factorial] 0.01ms
# [factorial] 0.02ms
# [factorial] 0.03ms
# [factorial] 0.05ms

解决:将递归逻辑分离到一个未被装饰的内部函数。

@timer() # 只在最外层计时一次
def factorial(n):
def _factorial(k): # 内部函数未被装饰,递归时不触发装饰器
if k <= 1:
return 1
return k * _factorial(k - 1)
return _factorial(n)
factorial(5)
# [factorial] 0.00ms ✅ 只输出一次

6.3 陷阱三:lru_cache 对可变参数失效#

@lru_cache(maxsize=None)
def sum_list(lst):
return sum(lst)
sum_list([1, 2, 3]) # TypeError: unhashable type: 'list' ❌

解决:参考前面 3.2 节中的方案。

6.4 陷阱四:装饰器无法处理异步函数#

问题:普通装饰器的 wrapper 是同步函数,无法正确处理 async def 定义的异步函数。

from functools import wraps
import asyncio
def simple_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs): # ❌ 这是同步函数
print("Before")
result = func(*args, **kwargs) # func 是 async,返回 coroutine
print("After")
return result
return wrapper
@simple_decorator
async def async_func():
await asyncio.sleep(1)
return "done"
# 虽然能运行,但 "After" 会在协程完成前就输出
# 正确的做法是 wrapper 也用 async def

解决:为异步函数编写异步装饰器。

def async_decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs): # ✅ async wrapper
print("Before")
result = await func(*args, **kwargs) # await 异步函数
print("After")
return result
return wrapper
@async_decorator
async def async_func():
await asyncio.sleep(1)
return "done"

如果你需要一个装饰器同时支持同步和异步函数,可以这样实现:

import asyncio
from functools import wraps
import inspect
def universal_decorator(func):
"""同时支持同步和异步函数的装饰器"""
if inspect.iscoroutinefunction(func):
# 异步版本
@wraps(func)
async def async_wrapper(*args, **kwargs):
print(f"[async] 调用 {func.__name__}")
result = await func(*args, **kwargs)
print(f"[async] 完成 {func.__name__}")
return result
return async_wrapper
else:
# 同步版本
@wraps(func)
def sync_wrapper(*args, **kwargs):
print(f"[sync] 调用 {func.__name__}")
result = func(*args, **kwargs)
print(f"[sync] 完成 {func.__name__}")
return result
return sync_wrapper

6.5 陷阱五:在类方法上使用装饰器时的 self 参数#

问题:普通装饰器在类方法上使用时,self 会被当作 args 的第一个参数,这通常没问题,但装饰器内部如果需要处理 self,可能会产生困惑。

def log_calls(func):
@wraps(func)
def wrapper(*args, **kwargs):
# args[0] 是 self!
print(f"调用 {func.__name__}, 参数: {args[1:]}, kwargs: {kwargs}")
return func(*args, **kwargs)
return wrapper
class MyClass:
@log_calls
def method(self, x, y):
return x + y
obj = MyClass()
obj.method(1, 2) # 输出: 调用 method, 参数: (1, 2), kwargs: {}
# 注意 self 被排除在输出之外了

解决:如果装饰器专门用于方法,可以显式处理 self

def log_method(func):
@wraps(func)
def wrapper(self, *args, **kwargs): # 显式声明 self
print(f"[{self.__class__.__name__}] 调用 {func.__name__}")
return func(self, *args, **kwargs)
return wrapper

📋 总结#

让我们用一张表格总结装饰器的进阶知识点:

概念用途关键技术
带参数装饰器灵活配置装饰器行为三层嵌套函数 (outer → decorator → wrapper)
类装饰器维护状态、继承复用__call__ 方法、update_wrapper
functools.wraps保留原函数元信息装饰器标配,必须使用
functools.lru_cache函数结果缓存参数需可哈希,注意 maxsize 和内存
functools.singledispatch根据类型分发逻辑注册不同类型的实现,扩展性强
functools.partial固定部分参数简化函数签名,与装饰器配合
functools.total_ordering自动生成比较方法只需定义 __eq__ 和一个比较运算符
装饰器叠加组合多种功能注意顺序:内层放业务相关,外层放通用处理

装饰器是 Python 中最强大的特性之一。掌握基础后,通过这些进阶技巧,你可以编写出更加优雅、可复用、可维护的代码。


相关阅读

Python 装饰器进阶完全指南:带参数的装饰器 / 类装饰器 / functools 深度解析 + 实战场景
https://971918.xyz/posts/python-guide/python-decorator-advanced/
作者
九所长
发布于
2026-06-16
许可协议
CC BY-NC-SA 4.0