ねぇうしくんうしくん

今週のまとめ (一週間で自分が見た技術系サイトのログ)が今のところメインです。プログラミング言語、人工知能、セキュリティ 等

Coq 末尾再帰の等価性の証明

レイトレーシングを実装する際、光の反射の計算に再帰を用いる。擬似コードを以下に示す。

let ref f ... argN =
   let a = ~~~~ in (* 加算:光源など *)
   let b = ~~~~ in (* 乗算:反射での減衰など *)
   let finished = ~~~~ in (* 計算を終了するか *)
   if ( finished ) then 
      a
   else
      a + b * (f ...)

しかし、このような実装ではスタックオーバーフローが起こる可能性がある。 (特に不偏レンダリングをする際) これを回避するためには再帰しないような関数に書き直す必要がある。

https://github.com/githole/simple-pathtracer/blob/96c24d104ea6e30e3b1265d8443840a0aea8a104/simplept.cpp#L114

これはwhile を用いた C++ の例であるが、関数型っぽく実装したい場合は末尾再帰呼び出しに書き換える事が考えられる。

let ref f arg0 arg1 ... argN (s, t) =
   let a = ~~~~ in
   let b = ~~~~ in
   let finished = ~~~~ in
   if ( finished ) then 
      s + a * t
   else
      f ... (s + a * t, t * b)

本当に等価なのか怪しいので Coq で証明する。Coq だと上のように停止するかわからない関数は書けないので適当な変数 n を導入し停止する関数にエンコードする。

(* 再帰 *)
Fixpoint f
  (a b : nat -> nat) (n: nat) :=
  match n with
    | O => a 0
    | S n' => (a (S n')) + (b (S n')) * f a b n'
  end.

(* 末尾再帰 *)
Fixpoint f_tr
  (a b : nat -> nat) (n: nat) (acc : nat * nat) :=
  match n with
    | O => (fst acc) + (a 0) * (snd acc)
    | S n' => f_tr a b n' ((fst acc) + (a (S n')) * (snd acc),
                         (snd acc) * (b (S n')))
  end.

あとは証明を行う。特に難しくない。*1。 肝は a b を一般化したままにしておくぐらいか。

Lemma tail_rec_equiv_st : forall (a b : nat -> nat) (n s t : nat) ,
  t * (f a b n) + s = f_tr a b n (s, t).
Proof.
  intros a b n.
  induction n.
  - intros s t. simpl. rewrite plus_comm, mult_comm. reflexivity.
  - simpl. intros s t.
    rewrite <- IHn.
    (* やるだけ *)
    rewrite Nat.mul_add_distr_l.
    rewrite mult_assoc, mult_comm.
    assert(H: forall x y z, x + y + z = y + (z + x)). {
      intros x y z.
      omega.
    }
    apply H.
Qed.

Theorem tail_rec_equiv : forall (a b : nat -> nat) (n: nat),
  f a b n = f_tr a b n (0, 1).
Proof.
  intros a b n.
  rewrite <- tail_rec_equiv_st.
  omega.
Qed.

この証明を行う上において、もっと簡単な fact の末尾再帰の例を先に証明しそれを参考にした。

Fixpoint f (n: nat) :=
  match n with
    | O => 1
    | S n' => n' * f n'
  end.

Fixpoint f_tr (n: nat) (acc : nat) :=
  match n with
    | O => acc
    | S n' => f_tr n' (acc * n')
  end.

Theorem tail_rec_equiv: forall (n : nat) (p : nat),
  p * f n = f_tr n p.
Proof.
  intros n.
  induction n.
  - intros p. simpl. omega. 
  - simpl.
    intros p.
    rewrite <- IHn.
    rewrite Nat.mul_assoc.
    auto.
Qed.

*1: 僕は詰まりました