2014-10-30 113 views
1

我有一个简单的函数,定义如下:与函数参数的Python装饰

def simple_function(x): 
    """ x is an input numpy array""" 
    return x + 0.1 

我想通过应用一些边界条件,以它来修改这个功能。这些边界条件是x的本身功能:

def upper_bound(x): 
    return x**2 

def lower_bound(x): 
    return np.zeros(len(x)) 

特别是,如果simple_function(x)超过upper_bound(x)值,或低于lower_bound(x),我想的simple_function(x)装饰的版本返回值upper_bound(x),同样为LOWER_BOUND 。我如何使用python中的@decorator语法来完成这种行为?

回答

0

除了修改__doc__等其他事项外,在这里你是:

def constrain(lower, upper): 
    def outer(f): 
    def inner(x): 
     r = f(x) 
     u = upper(x) 
     if r > u: 
     return u 
     l = lower(x) 
     if r < l: 
     return l 
     return r 
    return inner 
    return outer 

...

@constrain(lower_bound, upper_bound) 
def simple_function(x): 
    ... 

不同类型和下限比上限不高的处理。

+0

伟大的答案!这很好,是单变量参数的优秀语法模板,但它不适用于numpy数组。知道如何为此目的调整语法? – aph 2014-10-30 05:31:00

+0

NumPy数组是一个参数。您是否看到任何*特定*问题? – 2014-10-30 05:36:20

+0

我怀疑问题不在于参数,而是返回值。你不能像比较标量一样比较numpy数组(你可能需要使用'all',或者按照元素的方式来替换越界值,而不是只返回'l'或'u ')。 – Blckknght 2014-10-30 05:40:17

2

如果你的参数,边界和结果都是numpy数组,你可以做几个数组赋值来将每个元素夹在你的upper_boundlower_bound函数返回的对应值之间。核心部分是:

r = f(x) 
l = lower_bound(x) 
u = upper_bound(x) 

i = r < l 
j = r > u 

r[i] = l 
r[j] = u 

ij将布尔数组即说哪些索引需要分别被夹持到下限和上限。为了使此代码工作作为装饰,你只需把它一对嵌套函数里面,像这样:

def clamp(f): 
    @functools.wraps(f) 
    def wrapper(x): 
     r = f(x) 
     l = lower_bound(x) 
     u = upper_bound(x) 

     i = r < l 
     j = r > u 

     r[i] = l 
     r[j] = u 

     return r 

    return wrapper 

functools.wraps使装饰功能的它,因此包装函数副本的名称,注释和文档字符串。

上面的代码假定您总是使用相同的upper_boundlower_bound函数。如果你需要的是定制为你装饰不同的功能,你可以添加嵌套一个额外层,并定义一个“装饰工厂”就像在伊格纳西奥巴斯克斯 - 艾布拉姆斯的回答是:

def clamp(lower_bound, upper_bound): # this is the decorator factory function 
    def decotator(f):     # this is the decorator function 
     @functools.wraps(f) 
     def wrapper(x):    # this is the wrapper function 
      ... # same code here as above 
      return r 

     return wrapper 

    return decorator 
+0

Bickknght,这是一些漂亮的代码。我的实现非常有效,谢谢! – aph 2014-10-30 13:11:56

+0

我唯一的超级小调整就是让您的布尔条件在numpy数组中更好地工作:r = np.where(r aph 2014-10-30 13:14:09