2017-08-10 62 views
1

我写了一个装饰器来打印由某个函数调用产生的递归树。Python:为特定函数调用修补打印函数? (用于打印递归树的装饰器)

from functools import wraps 

def printRecursionTree(func): 
    global _recursiondepth 
    _print = print 
    _recursiondepth = 0 

    def getpads(): 
     if _recursiondepth == 0: 
      strFn = '{} └──'.format(' │ ' * (_recursiondepth-1)) 
      strOther = '{} ▒▒'.format(' │ ' * (_recursiondepth-1)) 
      strRet = '{} '.format(' │ ' * (_recursiondepth-1)) 
     else: 
      strFn = ' {} ├──'.format(' │ ' * (_recursiondepth-1)) 
      strOther = ' {} │▒▒'.format(' │ ' * (_recursiondepth-1)) 
      strRet = ' {} │ '.format(' │ ' * (_recursiondepth-1)) 

     return strFn, strRet, strOther 

    def indentedprint(): 
     @wraps(print) 
     def wrapper(*args, **kwargs): 
      strFn, strRet, strOther = getpads() 
      _print(strOther, end=' ') 
      _print(*args, **kwargs) 
     return wrapper 


    @wraps(func) 
    def wrapper(*args, **kwargs): 
     global _recursiondepth 
     global print 

     strFn, strRet, strOther = getpads() 

     if args and kwargs: 
      _print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs)) 
     else: 
      _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else '')) 
     _recursiondepth += 1 
     print, backup = indentedprint(), print 
     retval = func(*args, **kwargs) 
     print = backup 
     _recursiondepth -= 1 
     _print(strRet, '╰', retval) 
     if _recursiondepth == 0: 
      _print() 
     return retval 

    return wrapper 

实例:

@printRecursionTree 
def fib(n): 
    if n <= 1: 
     print('Base Case') 
     return n 
    print('Recursive Case') 
    return fib(n-1) + fib(n-2) 

# This works with mutually recursive functions too, 
# since the variable _recursiondepth is global 
@printRecursionTree 
def iseven(n): 
    print('checking if even') 
    if n == 0: return True 
    return isodd(n-1) 

@printRecursionTree 
def isodd(n): 
    print('checking if odd') 
    if n == 0: return False 
    return iseven(n-1) 

iseven(5) 
fib(5) 

'''Prints: 

└── iseven(5): 
    │▒▒ checking if even 
    │▒▒ Note how the print 
    │▒▒ statements get nicely indented 
    ├── isodd(4): 
    │ │▒▒ checking if odd 
    │ ├── iseven(3): 
    │ │ │▒▒ checking if even 
    │ │ │▒▒ Note how the print 
    │ │ │▒▒ statements get nicely indented 
    │ │ ├── isodd(2): 
    │ │ │ │▒▒ checking if odd 
    │ │ │ ├── iseven(1): 
    │ │ │ │ │▒▒ checking if even 
    │ │ │ │ │▒▒ Note how the print 
    │ │ │ │ │▒▒ statements get nicely indented 
    │ │ │ │ ├── isodd(0): 
    │ │ │ │ │ │▒▒ checking if odd 
    │ │ │ │ │ ╰ False 
    │ │ │ │ ╰ False 
    │ │ │ ╰ False 
    │ │ ╰ False 
    │ ╰ False 
    ╰ False 

└── fib(5): 
    │▒▒ Recursive Case 
    ├── fib(4): 
    │ │▒▒ Recursive Case 
    │ ├── fib(3): 
    │ │ │▒▒ Recursive Case 
    │ │ ├── fib(2): 
    │ │ │ │▒▒ Recursive Case 
    │ │ │ ├── fib(1): 
    │ │ │ │ │▒▒ Base Case 
    │ │ │ │ ╰ 1 
    │ │ │ ├── fib(0): 
    │ │ │ │ │▒▒ Base Case 
    │ │ │ │ ╰ 0 
    │ │ │ ╰ 1 
    │ │ ├── fib(1): 
    │ │ │ │▒▒ Base Case 
    │ │ │ ╰ 1 
    │ │ ╰ 2 
    │ ├── fib(2): 
    │ │ │▒▒ Recursive Case 
    │ │ ├── fib(1): 
    │ │ │ │▒▒ Base Case 
    │ │ │ ╰ 1 
    │ │ ├── fib(0): 
    │ │ │ │▒▒ Base Case 
    │ │ │ ╰ 0 
    │ │ ╰ 1 
    │ ╰ 3 
    ├── fib(3): 
    │ │▒▒ Recursive Case 
    │ ├── fib(2): 
    │ │ │▒▒ Recursive Case 
    │ │ ├── fib(1): 
    │ │ │ │▒▒ Base Case 
    │ │ │ ╰ 1 
    │ │ ├── fib(0): 
    │ │ │ │▒▒ Base Case 
    │ │ │ ╰ 0 
    │ │ ╰ 1 
    │ ├── fib(1): 
    │ │ │▒▒ Base Case 
    │ │ ╰ 1 
    │ ╰ 2 
    ╰ 5 
''' 

此示例代码工作正常,只要它是在将装饰定义相同的文件。

但是,如果从某个模块导入装饰器,则打印语句不再缩进。

我知道这种行为的产生是因为由decorator修补的print语句对于它自己的模块是全局的,并不是跨模块共享。

  1. 我该如何解决这个问题?
  2. 有没有更好的方法来修补一个函数只适用于另一个函数的特定调用?

回答

1

您可以通过在builtins模块中替换它来更改所有模块的内置打印功能的行为。

因此改变你分配到全局变量print与分配到builtins.print(进口builtins后):

import builtins 

... 

    @wraps(func) 
    def wrapper(*args, **kwargs): 
     global _recursiondepth # no more need for global print up here 

     strFn, strRet, strOther = getpads() 

     if args and kwargs: 
      _print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs)) 
     else: 
      _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else '')) 
     _recursiondepth += 1 
     builtins.print, backup = indentedprint(), print # change here 
     retval = func(*args, **kwargs) 
     builtins.print = backup       # and here 
     _recursiondepth -= 1 
     _print(strRet, '╰', retval) 
     if _recursiondepth == 0: 
      _print() 
     return retval 
+0

快速短!谢谢。修补建筑这样认为是危险的吗? –

+1

只要您保留对“print”原始版本的引用,我不认为这太危险,但要安全地执行此操作,您可能需要在代码中处理异常情况时使用一些额外的逻辑。例如,我建议在调用'func'(它需要缩进一个级别)之前放置'try'语句,并将'builtins.print = backup'行放在'finally'子句中。您可能希望'_recursiondepth'更新也可以这样处理。 – Blckknght