2009-10-27

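Generated methods for polynomial evaluation

I'm trying to come up with an elegant way to handle some generated polynomials. Here is the situation this question focuses on (exclusively):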

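  1. order is a parameter for generating an nth-order polynomial, where n := order + 1.
  2. i is an integer parameter in the range 0..order.
  3. The polynomial has zeros at x_j, where j = 0..order and j ≠ i (by this point it should be clear that StackOverflow needs a new feature, or it has one I don't know about).
  4. The polynomial evaluates to 1 at x_i.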

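Since this particular code example generates x_0 .. x_order itself, I'll explain how they are defined in the code: the points are evenly spaced at x_j = j * elementSize / order for j = 0..order, so there are n = order + 1 of them.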
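The four conditions describe the i-th Lagrange basis polynomial on the nodes x_j. As a plain numeric reference for checking generated functions, ψ_i can be evaluated directly from its factored form. This is a sketch outside the expression-tree approach; LagrangeReference, Node and EvaluatePsi are hypothetical names.

```csharp
using System;

// Hypothetical helper, used only as a numeric reference for the generated psi functions.
public static class LagrangeReference
{
    // Node j of the element: x_j = j * elementSize / order, for j = 0..order.
    public static double Node(double elementSize, int order, int j)
    {
        return j * elementSize / order;
    }

    // Evaluates psi_i(x) = prod_{j != i} (x_j - x) / (x_j - x_i).
    // The result is 1 at x_i and 0 at every other node x_j.
    public static double EvaluatePsi(double elementSize, int order, int i, double x)
    {
        if (order < 1)
            throw new ArgumentOutOfRangeException(nameof(order), "order must be greater than 0.");
        if (i < 0 || i > order)
            throw new ArgumentOutOfRangeException(nameof(i), "i must be in the range 0..order.");

        double xi = Node(elementSize, order, i);
        double result = 1.0;
        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = Node(elementSize, order, j);
            result *= (xj - x) / (xj - xi);
        }

        return result;
    }
}
```

With order = 2 and elementSize = 2.0 the nodes are 0, 1 and 2, and EvaluatePsi(2.0, 2, 1, x) reduces to x(2 - x): it is 1 at x = 1 and 0 at x = 0 and x = 2.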

I generate a Func<double, double> that evaluates this polynomial¹.

// Namespaces used: System, System.Collections.Generic, System.Linq, System.Linq.Expressions
private static Func<double, double> GeneratePsi(double elementSize, int order, int i)
{
    if (order < 1)
        throw new ArgumentOutOfRangeException("order", "order must be greater than 0.");

    if (i < 0)
        throw new ArgumentOutOfRangeException("i", "i cannot be less than zero.");
    if (i > order)
        throw new ArgumentOutOfRangeException("i", "i cannot be greater than order.");

    ParameterExpression xp = Expression.Parameter(typeof(double), "x");

    // generate the terms of the factored polynomial in the form (x_j - x)
    List<Expression> factors = new List<Expression>();
    for (int j = 0; j <= order; j++)
    {
        if (j == i)
            continue;

        double p = j * elementSize / order;
        factors.Add(Expression.Subtract(Expression.Constant(p), xp));
    }

    // evaluate the product at the point x_i to get scaleInv = 1.0/scale.
    // The product must be seeded with 1.0; a seed of 0.0 would make scaleInv zero.
    double xi = i * elementSize / order;
    double scaleInv = Enumerable.Range(0, order + 1)
        .Aggregate(1.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize / order - xi)));

    /* generate an expression to evaluate
     * (x_0 - x) * (x_1 - x) .. (x_n - x) / (x_i - x)
     * the term (x_i - x) cancels (it is skipped in the loop above); it is written here only to make the result clear
     */
    Expression expr = factors.Skip(1).Aggregate(factors[0], Expression.Multiply);
    // multiplying by scale forces the condition f(x_i) = 1
    expr = Expression.Multiply(Expression.Constant(1.0 / scaleInv), expr);

    Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp);
    return lambdaMethod.Compile();
}
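Written as a formula, the expression GeneratePsi builds is the i-th Lagrange basis polynomial on the nodes x_j = j · elementSize / order, j = 0..order. The denominator is the quantity the code accumulates as scaleInv; because it is a product, its accumulation has to start from 1:

```latex
\psi_i(x) = \frac{\prod_{j \ne i} (x_j - x)}{\prod_{j \ne i} (x_j - x_i)},
\qquad j = 0, \dots, \mathrm{order},
\qquad \psi_i(x_k) = \delta_{ik}
```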

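The problem: I also need to evaluate ψ' = dψ/dx. To do this, I can rewrite ψ = scale × (x_0 - x)(x_1 - x) × .. × (x_n - x)/(x_i - x) in the form ψ = α_n × x^n + α_(n-1) × x^(n-1) + .. + α_1 × x + α_0. This gives ψ' = n × α_n × x^(n-1) + (n-1) × α_(n-1) × x^(n-2) + .. + 1 × α_1. Finally, the answer can be rewritten as ψ' = x × (x × (x × (..) - β_2) - β_1) - β_0, which needs no calls to Math.Pow.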
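If numeric coefficients are good enough (rather than a symbolic Expression), the rewrite described above can be done directly on arrays. The following is a sketch under that assumption, not the expression-tree approach the question asks for; PsiCoefficients, Expand, Differentiate and Evaluate are hypothetical names. It multiplies out the factors (x_j - x) into coefficients α_k, rescales them so that ψ(x_i) = 1, differentiates term by term, and evaluates with Horner's rule, so Math.Pow is never called.

```csharp
using System;

// Hypothetical numeric alternative: psi as a coefficient array instead of an expression tree.
public static class PsiCoefficients
{
    // Returns alpha[k], the coefficient of x^k, for
    // psi(x) = scale * prod_{j != i} (x_j - x), with x_j = j * elementSize / order.
    public static double[] Expand(double elementSize, int order, int i)
    {
        double[] alpha = { 1.0 };           // start from the constant polynomial 1
        double xi = i * elementSize / order;
        double scaleInv = 1.0;              // product of (x_j - x_i), seeded with 1

        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;
            scaleInv *= xj - xi;

            // Multiply the current polynomial by (xj - x).
            var next = new double[alpha.Length + 1];
            for (int k = 0; k < alpha.Length; k++)
            {
                next[k] += xj * alpha[k];   // xj * alpha_k x^k
                next[k + 1] -= alpha[k];    // -x * alpha_k x^k
            }
            alpha = next;
        }

        for (int k = 0; k < alpha.Length; k++)
            alpha[k] /= scaleInv;           // enforce psi(x_i) = 1

        return alpha;
    }

    // Coefficients of the derivative: d/dx (alpha_k x^k) = k * alpha_k x^(k-1).
    public static double[] Differentiate(double[] alpha)
    {
        if (alpha.Length <= 1)
            return new[] { 0.0 };

        var beta = new double[alpha.Length - 1];
        for (int k = 1; k < alpha.Length; k++)
            beta[k - 1] = k * alpha[k];
        return beta;
    }

    // Horner's rule: c_0 + x*(c_1 + x*(c_2 + ...)), with no calls to Math.Pow.
    public static double Evaluate(double[] c, double x)
    {
        double result = 0.0;
        for (int k = c.Length - 1; k >= 0; k--)
            result = result * x + c[k];
        return result;
    }
}
```

Horner's rule here nests additions, c_0 + x × (c_1 + x × (c_2 + ..)); the question's form with subtractions is the same nesting with β_k = -c_k.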

To do all of this grunt work (all very basic algebra), I need a clean way to:

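  1. Expand a factored expression that contains ConstantExpression and ParameterExpression leaves and basic math operations (which end up as a BinaryExpression with its NodeType set to the operation). The result here can include calls to Math.Pow (MethodCallExpression nodes whose Method is the MethodInfo for Math.Pow), which we'll handle in a special way throughout.
  2. Then take the derivative with respect to some specific ParameterExpression. Terms in the result where the right-hand argument of a Math.Pow call was the constant 2 are replaced by ConstantExpression(2) multiplied by whatever was on the left-hand side (the resulting call Math.Pow(x, 1) is removed). Terms in the result that become zero because they were constant with respect to x are removed.
  3. Then factor out the instances of some specific ParameterExpression where they occur as the left-hand argument of a Math.Pow call. When the right-hand side of a call becomes a ConstantExpression with the value 1, we replace the call with just the ParameterExpression itself.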
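The last rule of step 3 (a Math.Pow call whose exponent has become the constant 1 collapses to its base) can be written as a small ExpressionVisitor on its own. This sketch assumes powers are represented as calls to Math.Pow(double, double), which Expression.Call produces as a MethodCallExpression; PowOneRewriter is a hypothetical name, and the x^0 rule is an extra assumption.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical visitor: rewrites Math.Pow(e, 1.0) to e and Math.Pow(e, 0.0) to 1.0.
public class PowOneRewriter : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Rewrite the arguments first, so nested calls are simplified bottom-up.
        var visited = (MethodCallExpression)base.VisitMethodCall(node);

        if (visited.Method.DeclaringType == typeof(Math) && visited.Method.Name == "Pow")
        {
            var exponent = visited.Arguments[1] as ConstantExpression;
            if (exponent != null && exponent.Value is double)
            {
                double n = (double)exponent.Value;
                if (n == 1.0)
                    return visited.Arguments[0];
                if (n == 0.0)
                    return Expression.Constant(1.0);
            }
        }

        return visited;
    }
}
```

Because the base visitor rebuilds every parent node, a Math.Pow(x, 1.0) buried inside a larger product is replaced in place without any extra bookkeeping.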

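¹ In the future, I'd like the method to take a ParameterExpression and return an Expression that evaluates based on that parameter. That way I could aggregate generated functions. I'm not there yet.
² In the future, I hope to release a general-purpose library for working with LINQ expressions as symbolic math.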
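The idea in footnote ¹ can be sketched by moving the Expression.Parameter call out of the generator: the caller supplies the ParameterExpression, and several generated bodies can be combined before a single Compile. PsiBuilder, PsiBody and SumOfAll are hypothetical names; this is one possible shape, not the author's eventual design.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class PsiBuilder
{
    // Builds scale * prod_{j != i} (x_j - x) over a caller-supplied parameter x.
    public static Expression PsiBody(ParameterExpression x, double elementSize, int order, int i)
    {
        double xi = i * elementSize / order;
        double scaleInv = 1.0;
        Expression product = Expression.Constant(1.0);

        for (int j = 0; j <= order; j++)
        {
            if (j == i)
                continue;

            double xj = j * elementSize / order;
            scaleInv *= xj - xi;
            product = Expression.Multiply(product, Expression.Subtract(Expression.Constant(xj), x));
        }

        return Expression.Multiply(Expression.Constant(1.0 / scaleInv), product);
    }

    // Example of aggregation: the sum of all psi_i over one shared parameter.
    public static Expression SumOfAll(ParameterExpression x, double elementSize, int order)
    {
        return Enumerable.Range(0, order + 1)
            .Select(i => PsiBody(x, elementSize, order, i))
            .Aggregate((sum, term) => Expression.Add(sum, term));
    }
}
```

Because every body refers to the same ParameterExpression, the aggregate can be wrapped in one Expression.Lambda over that parameter; the sum of all ψ_i is a convenient check, since Lagrange basis polynomials add up to 1.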

+1, even though you lost me by line 5... this must be a really clever question ;) – 2009-10-27 15:06:40

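On the other hand, I understand all the math and know nothing about LINQ! Still, it looks like you already have the algorithm worked out. Good luck! – Cascabel 2009-10-27 15:14:45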

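@Jefromi: I can already generate the expression tree. What I want to build is an elegant way to transform these trees by treating them as symbolic math expressions. :) – 2009-10-27 15:23:35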

Answer

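I wrote the basics of a few symbolic math features using the ExpressionVisitor type in .NET 4. It's not perfect, but it looks like the foundation of a workable solution.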

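  • Symbolic is a public static class that exposes methods such as Expand, Simplify, and PartialDerivative
  • ExpandVisitor is an internal helper type that expands expressions
  • SimplifyVisitor is an internal helper type that simplifies expressions
  • DerivativeVisitor is an internal helper type that takes the derivative of an expression
  • ListPrintVisitor is an internal helper type that converts an Expression to prefix notation with a Lisp syntax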
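All of these types follow the same ExpressionVisitor pattern: override the Visit* method for the node kinds of interest and let the base class rebuild everything else. As a minimal illustration of that pattern (not one of the answer's types; ReplaceParameterVisitor is a hypothetical name), here is a visitor that substitutes an expression for a parameter:

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical example visitor: replaces every occurrence of one parameter with another expression.
public class ReplaceParameterVisitor : ExpressionVisitor
{
    private readonly ParameterExpression _parameter;
    private readonly Expression _replacement;

    public ReplaceParameterVisitor(ParameterExpression parameter, Expression replacement)
    {
        _parameter = parameter;
        _replacement = replacement;
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        // Only the targeted parameter is replaced; other parameters are left alone.
        return node == _parameter ? _replacement : node;
    }
}
```

Visiting x * x with x replaced by y + 1 yields (y + 1) * (y + 1); the base VisitBinary rebuilds the Multiply node around the new operands automatically.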

Symbolic

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class Symbolic
{
    public static Expression Expand(Expression expression)
    {
        return new ExpandVisitor().Visit(expression);
    }

    public static Expression Simplify(Expression expression)
    {
        return new SimplifyVisitor().Visit(expression);
    }

    public static Expression PartialDerivative(Expression expression, ParameterExpression parameter)
    {
        bool totalDerivative = false;
        return new DerivativeVisitor(parameter, totalDerivative).Visit(expression);
    }

    // Renders the expression in Lisp-style prefix notation, e.g. "(+ 3 x)".
    public static string ToString(Expression expression)
    {
        ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression);
        return result.Value.ToString();
    }
}

Expanding expressions with ExpandVisitor

internal class ExpandVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        var left = Visit(node.Left);
        var right = Visit(node.Right);

        if (node.NodeType == ExpressionType.Multiply)
        {
            // distribute: every addend on the left times every addend on the right
            Expression[] leftNodes = GetAddedNodes(left).ToArray();
            Expression[] rightNodes = GetAddedNodes(right).ToArray();
            var result =
                leftNodes
                .SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y)))
                .Aggregate((sum, term) => Expression.Add(sum, term));

            return result;
        }

        if (node.Left == left && node.Right == right)
            return node;

        return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion);
    }

    /// <summary>
    /// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns
    /// the individual addends in the sum.
    /// </summary>
    private static IEnumerable<Expression> GetAddedNodes(Expression node)
    {
        BinaryExpression binary = node as BinaryExpression;
        if (binary != null)
        {
            switch (binary.NodeType)
            {
                case ExpressionType.Add:
                    foreach (var n in GetAddedNodes(binary.Left))
                        yield return n;

                    foreach (var n in GetAddedNodes(binary.Right))
                        yield return n;

                    yield break;

                case ExpressionType.Subtract:
                    foreach (var n in GetAddedNodes(binary.Left))
                        yield return n;

                    // addends of the subtrahend are negated
                    foreach (var n in GetAddedNodes(binary.Right))
                        yield return Expression.Negate(n);

                    yield break;

                default:
                    break;
            }
        }

        yield return node;
    }
}
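For the quadratic in the test at the end of this answer, ExpandVisitor distributes each addend on the left over each addend on the right and sums the products left to right. A Subtract contributes its right-hand addends wrapped in Negate:

```latex
(3 + x)(2 + x) = \big((3 \cdot 2 + 3 \cdot x) + x \cdot 2\big) + x \cdot x
\qquad
(a - b)(c + d) = \big((a c + a d) + (-b) c\big) + (-b) d
```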

Taking partial derivatives with DerivativeVisitor

internal class DerivativeVisitor : ExpressionVisitor
{
    private readonly ParameterExpression _parameter;
    private readonly bool _totalDerivative;

    public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative)
    {
        // check the argument, not the (still unassigned) field
        if (totalDerivative)
            throw new NotImplementedException();

        _parameter = parameter;
        _totalDerivative = totalDerivative;
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        switch (node.NodeType)
        {
            case ExpressionType.Add:
            case ExpressionType.Subtract:
                // (f ± g)' = f' ± g'
                return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right));

            case ExpressionType.Multiply:
                // product rule: (f g)' = f g' + f' g
                return Expression.Add(Expression.Multiply(node.Left, Visit(node.Right)), Expression.Multiply(Visit(node.Left), node.Right));

            case ExpressionType.Divide:
                // quotient rule: (f / g)' = (f' g - f g') / g^2; Power needs a double exponent, so use 2.0
                return Expression.Divide(
                    Expression.Subtract(Expression.Multiply(Visit(node.Left), node.Right), Expression.Multiply(node.Left, Visit(node.Right))),
                    Expression.Power(node.Right, Expression.Constant(2.0)));

            case ExpressionType.Power:
                if (node.Right is ConstantExpression)
                {
                    // power rule: (f^c)' = c * f^(c - 1) * f'
                    Expression reduced = Expression.Power(node.Left, Expression.Subtract(node.Right, Expression.Constant(1.0)));
                    return Expression.Multiply(Expression.Multiply(node.Right, reduced), Visit(node.Left));
                }
                else if (node.Left is ConstantExpression)
                {
                    // (a^g)' = a^g * ln(a) * g'
                    return Expression.Multiply(Expression.Multiply(node, MathExpressions.Log(node.Left)), Visit(node.Right));
                }
                else
                {
                    // general case: (f^g)' = f^g * (f' * g / f + g' * ln f)
                    return Expression.Multiply(node, Expression.Add(
                        Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)),
                        Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left))));
                }

            default:
                throw new NotImplementedException();
        }
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        // the derivative of a constant is zero
        return MathExpressions.Zero;
    }

    // Calls such as Math.Log(u) are MethodCallExpression nodes (not InvocationExpression),
    // and the chain rule needs the derivative of the argument u.
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.DeclaringType != typeof(Math) || node.Arguments.Count != 1)
            throw new NotImplementedException();

        Expression u = node.Arguments[0];
        switch (node.Method.Name)
        {
            case "Log":
                // (ln u)' = u' / u
                return Expression.Divide(Visit(u), u);

            case "Log10":
                // (log10 u)' = u' / (ln(10) * u)
                return Expression.Divide(Visit(u), Expression.Multiply(Expression.Constant(Math.Log(10)), u));

            case "Exp":
            case "Sin":
            case "Cos":
            default:
                throw new NotImplementedException();
        }
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        // dx/dx = 1; any other parameter is treated as a constant
        if (node == _parameter)
            return MathExpressions.One;

        return MathExpressions.Zero;
    }
}
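The code above and SimplifyVisitor refer to a MathExpressions helper that the answer doesn't show. Judging only from how it is used (Zero and One must print as 0 and 1 and compare equal to 0.0 and 1.0, and Log must build a natural logarithm), a minimal assumed definition could be:

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Assumed definition: the original answer uses MathExpressions without showing it.
internal static class MathExpressions
{
    // Shared double constants; SimplifyVisitor compares their values against 0.0 and 1.0.
    public static readonly ConstantExpression Zero = Expression.Constant(0.0);
    public static readonly ConstantExpression One = Expression.Constant(1.0);

    private static readonly MethodInfo LogMethod =
        typeof(Math).GetMethod("Log", new[] { typeof(double) });

    // Natural logarithm as a MethodCallExpression: Math.Log(operand).
    public static Expression Log(Expression operand)
    {
        return Expression.Call(LogMethod, operand);
    }
}
```

Building Log as Expression.Call keeps it a MethodCallExpression, so a visitor that overrides VisitMethodCall can recognize and differentiate it.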

Simplifying expressions with SimplifyVisitor

internal class SimplifyVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        var left = Visit(node.Left);
        var right = Visit(node.Right);

        // fold operations whose operands are both double constants
        ConstantExpression leftConstant = left as ConstantExpression;
        ConstantExpression rightConstant = right as ConstantExpression;
        if (leftConstant != null && rightConstant != null
            && (leftConstant.Value is double) && (rightConstant.Value is double))
        {
            double leftValue = (double)leftConstant.Value;
            double rightValue = (double)rightConstant.Value;

            switch (node.NodeType)
            {
                case ExpressionType.Add:
                    return Expression.Constant(leftValue + rightValue);
                case ExpressionType.Subtract:
                    return Expression.Constant(leftValue - rightValue);
                case ExpressionType.Multiply:
                    return Expression.Constant(leftValue * rightValue);
                case ExpressionType.Divide:
                    return Expression.Constant(leftValue / rightValue);
                case ExpressionType.Power:
                    return Expression.Constant(Math.Pow(leftValue, rightValue));
                default:
                    throw new NotImplementedException();
            }
        }

        // apply identities involving 0 and 1
        switch (node.NodeType)
        {
            case ExpressionType.Add:
                if (IsZero(left))
                    return right;
                if (IsZero(right))
                    return left;
                break;

            case ExpressionType.Subtract:
                if (IsZero(left))
                    return Expression.Negate(right);
                if (IsZero(right))
                    return left;
                break;

            case ExpressionType.Multiply:
                if (IsZero(left) || IsZero(right))
                    return MathExpressions.Zero;
                if (IsOne(left))
                    return right;
                if (IsOne(right))
                    return left;
                break;

            case ExpressionType.Divide:
                if (IsZero(right))
                    throw new DivideByZeroException();
                if (IsZero(left))
                    return MathExpressions.Zero;
                if (IsOne(right))
                    return left;
                break;

            case ExpressionType.Power:
                // DerivativeVisitor produces Power nodes: x^0 = 1, x^1 = x
                if (IsZero(right))
                    return MathExpressions.One;
                if (IsOne(right))
                    return left;
                break;

            default:
                throw new NotImplementedException();
        }

        return Expression.MakeBinary(node.NodeType, left, right);
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        var operand = Visit(node.Operand);

        ConstantExpression operandConstant = operand as ConstantExpression;
        if (operandConstant != null && (operandConstant.Value is double))
        {
            double operandValue = (double)operandConstant.Value;

            switch (node.NodeType)
            {
                case ExpressionType.Negate:
                    // avoid producing -0
                    if (operandValue == 0.0)
                        return MathExpressions.Zero;

                    return Expression.Constant(-operandValue);

                default:
                    throw new NotImplementedException();
            }
        }

        switch (node.NodeType)
        {
            case ExpressionType.Negate:
                // -(-x) = x
                if (operand.NodeType == ExpressionType.Negate)
                {
                    return ((UnaryExpression)operand).Operand;
                }

                break;

            default:
                throw new NotImplementedException();
        }

        return Expression.MakeUnary(node.NodeType, operand, node.Type);
    }

    private static bool IsZero(Expression expression)
    {
        ConstantExpression constant = expression as ConstantExpression;
        return constant != null && constant.Value is double && (double)constant.Value == 0.0;
    }

    private static bool IsOne(Expression expression)
    {
        ConstantExpression constant = expression as ConstantExpression;
        return constant != null && constant.Value is double && (double)constant.Value == 1.0;
    }
}
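SimplifyVisitor folds constants but does not combine like terms, which is why the test below ends with (+ x x) rather than (* 2 x). For the polynomial case in the question, one way to finish the job, and to recover the α_k coefficients it asks for, is to read the tree into a map from power of x to coefficient. A sketch, assuming the tree contains only double constants, a single parameter, Add, Subtract, Multiply and Negate; PolynomialReader is a hypothetical name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical helper (not part of the answer): reads a tree built from double constants,
// one parameter, Add, Subtract, Multiply and Negate into coefficients keyed by power of x.
public static class PolynomialReader
{
    public static SortedDictionary<int, double> ToCoefficients(Expression node, ParameterExpression x)
    {
        switch (node.NodeType)
        {
            case ExpressionType.Constant:
                return Single(0, (double)((ConstantExpression)node).Value);

            case ExpressionType.Parameter:
                if (node != x)
                    throw new NotSupportedException("Only one parameter is supported.");
                return Single(1, 1.0);

            case ExpressionType.Negate:
                return Scale(ToCoefficients(((UnaryExpression)node).Operand, x), -1.0);

            case ExpressionType.Add:
            case ExpressionType.Subtract:
            {
                var binary = (BinaryExpression)node;
                double sign = node.NodeType == ExpressionType.Add ? 1.0 : -1.0;
                var result = ToCoefficients(binary.Left, x);
                foreach (var term in ToCoefficients(binary.Right, x))
                    AddTo(result, term.Key, sign * term.Value);
                return result;
            }

            case ExpressionType.Multiply:
            {
                // Like terms are combined as the products are accumulated.
                var binary = (BinaryExpression)node;
                var left = ToCoefficients(binary.Left, x);
                var right = ToCoefficients(binary.Right, x);
                var result = new SortedDictionary<int, double>();
                foreach (var l in left)
                    foreach (var r in right)
                        AddTo(result, l.Key + r.Key, l.Value * r.Value);
                return result;
            }

            default:
                throw new NotSupportedException("Unsupported node type: " + node.NodeType);
        }
    }

    private static SortedDictionary<int, double> Single(int power, double coefficient)
    {
        var result = new SortedDictionary<int, double>();
        result[power] = coefficient;
        return result;
    }

    private static SortedDictionary<int, double> Scale(SortedDictionary<int, double> terms, double factor)
    {
        var result = new SortedDictionary<int, double>();
        foreach (var term in terms)
            result[term.Key] = term.Value * factor;
        return result;
    }

    private static void AddTo(SortedDictionary<int, double> terms, int power, double coefficient)
    {
        double existing;
        terms.TryGetValue(power, out existing);
        terms[power] = existing + coefficient;
    }
}
```

For (3 + x)(2 + x) this yields {0: 6, 1: 5, 2: 1}, i.e. 6 + 5x + x², and for (+ x x) it yields {1: 2}. Terms whose coefficients cancel stay in the map with the value 0.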

Formatting expressions for display with ListPrintVisitor

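internal class ListPrintVisitor : ExpressionVisitor
{
    protected override Expression VisitBinary(BinaryExpression node)
    {
        string op = null;

        switch (node.NodeType)
        {
            case ExpressionType.Add:
                op = "+";
                break;
            case ExpressionType.Subtract:
                op = "-";
                break;
            case ExpressionType.Multiply:
                op = "*";
                break;
            case ExpressionType.Divide:
                op = "/";
                break;
            case ExpressionType.Power:
                op = "^";
                break;
            default:
                throw new NotImplementedException();
        }

        // each child is visited into a ConstantExpression holding its string form
        var left = Visit(node.Left);
        var right = Visit(node.Right);
        string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value);
        return Expression.Constant(result);
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        // ExpandVisitor and SimplifyVisitor produce Negate nodes; print them as (- operand)
        if (node.NodeType != ExpressionType.Negate)
            throw new NotImplementedException();

        var operand = Visit(node.Operand);
        return Expression.Constant(string.Format("(- {0})", ((ConstantExpression)operand).Value));
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value is string)
            return node;

        return Expression.Constant(node.Value.ToString());
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        return Expression.Constant(node.Name);
    }
}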
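For quick debugging, expression nodes also override ToString with a built-in infix rendering, so a tree can be inspected without a custom visitor; ListPrintVisitor's prefix form is what the tests below compare against. A small sketch (InfixDemo is a hypothetical name) that builds the test's quadratic:

```csharp
using System;
using System.Linq.Expressions;

public static class InfixDemo
{
    // Builds (3 + x) * (2 + x), the quadratic used in the test below.
    public static Expression Quadratic(ParameterExpression x)
    {
        return Expression.Multiply(
            Expression.Add(Expression.Constant(3.0), x),
            Expression.Add(Expression.Constant(2.0), x));
    }
}
```

Its ToString() gives ((3 + x) * (2 + x)), while Symbolic.ToString gives (* (+ 3 x) (+ 2 x)) for the same tree.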

Test results

[TestMethod] 
public void BasicSymbolicTest() 
{ 
    ParameterExpression x = Expression.Parameter(typeof(double), "x"); 
    Expression linear = Expression.Add(Expression.Constant(3.0), x); 
    Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear)); 

    Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x)); 
    Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic)); 

    Expression expanded = Symbolic.Expand(quadratic); 
    Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded)); 
    Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded))); 

    Expression derivative = Symbolic.PartialDerivative(expanded, x); 
    Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative)); 

    Expression simplified = Symbolic.Simplify(derivative); 
    Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified)); 
}
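The test stops at a symbolic result. To evaluate ψ' numerically, which is the question's goal, the simplified expression can be wrapped in a lambda over the same ParameterExpression and compiled, as GeneratePsi does for ψ. A self-contained sketch that builds the test's simplified derivative (+ 5 (+ x x)) by hand instead of calling Symbolic (DerivativeEvaluation is a hypothetical name):

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical example: compiles the simplified derivative from the test, 5 + (x + x).
public static class DerivativeEvaluation
{
    public static Func<double, double> CompileExample()
    {
        ParameterExpression x = Expression.Parameter(typeof(double), "x");
        Expression derivative = Expression.Add(Expression.Constant(5.0), Expression.Add(x, x));

        // The lambda must reuse the same ParameterExpression instance that occurs in the body;
        // a new parameter that merely has the same name would not bind to it.
        return Expression.Lambda<Func<double, double>>(derivative, x).Compile();
    }
}
```

CompileExample()(1.0) returns 7, which matches d/dx[(3 + x)(2 + x)] = 5 + 2x at x = 1.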