ウチのコンパイラは微分もしますぜ

誰でも一度はコンパイラ微分してみたくなることがあったず。無性に何か書きたかったので書いてみました。


自動で微分するくらいなら、課題なり何なりで書いてみたことある人いるんじゃないでしょうか。template でもっとアレげに。

#include <iostream>
#include <cmath>

using namespace std;

template <int n>
class Num {
public:
  double calc(double x) const {
    return n;
  }

  double derivation(double x) const {
    return 0;
  }
};

class Var {
public:
  double calc(double x) const {
    return x;
  }

  double derivation(double x) const {
    return 1;
  }
};

template <int n>
class Power {
public:
  double calc(double x) const {
    return x * Power<n - 1>().calc(x);
  }

  double derivation(double x) const {
    return n * Power<n - 1>().calc(x);
  }
};

template <>
double Power<0>::calc(double x) const {
  return 1;
}

template <>
double Power<0>::derivation(double x) const {
  return 0;
}

class Exp {
public:
  double calc(double x) const {
    return exp(x);
  }

  double derivation(double x) const {
    return exp(x);
  }
};

template <typename F, typename E>
class Fun {
public:
  double calc(double x) const {
    return F().calc(E().calc(x));
  }

  double derivation(double x) const {
    return E().derivation(x) * F().derivation(E().calc(x));
  }
};

template <typename E1, typename E2>
class Add {
public:
  double calc(double x) const {
    return E1().calc(x) + E2().calc(x);
  }

  double derivation(double x) const {
    return E1().derivation(x) + E2().derivation(x);
  }
};

template <typename E1, typename E2>
class Mul {
public:
  double calc(double x) const {
    return E1().calc(x) * E2().calc(x);
  }

  double derivation(double x) const {
    return E1().derivation(x) * E2().calc(x) + E1().calc(x) * E2().derivation(x);
  }
};

int main() {
  cout << Fun<Power<3>, Var>().derivation(4) << " = " << 3 * 4 * 4 << endl; // d(x^3)/dx | x=4   == 48
  cout << Fun<Exp, Fun<Power<3>, Var> >().derivation(5) << " = " << 3 * 5 * 5 * exp(5 * 5 * 5) << endl; // d(exp(x^3))/dx | x=5   == 1.4e56
  cout << Add<Fun<Power<2>, Var>, Mul<Fun<Power<3>, Var>, Num<10> > >().derivation(4)
       << " = " << 2 * 4 + 10 * 3 * 4 * 4 << endl; // d(x^2 + 10x^3)/dx | x=4   == 488
}

calc はふつうにその関数を計算します。derivation は導関数の値を計算します。Mul の derivation とか見れば、何をしたいかわかりますよね。

で、肝心のコンパイラはどこまでがんばるか。-S でアセンブリ吐かせたら、ちゃんと 488 とか 48 という数字が入っていました! 偉い。