(** * Library SLPA *)

Require Import ZArith.
Require Import String Ascii.
Require Import Bool.
Require Import PAInfTheoryDi.
Require Import Sets.Ensembles.

(** Signature of variables that can be converted to and from strings.
    It is checked against [VARIABLE] (from [PAInfTheoryDi]) and adds a
    string conversion pair that round-trips in both directions, plus a
    variable constant [freshvar]. *)
Module Type STRVAR <: VARIABLE.
  Parameter var : Type.

  (** Decidable equality on variables. *)
  Parameter var_eq_dec : forall x y : var, {x = y} + {x <> y}.

  (** Conversions between a variable and its string name. *)
  Parameter var2string : var -> string.
  Parameter string2var : string -> var.

  (** A variable constant; this signature places no constraint on it. *)
  Parameter freshvar : var.

  (** The two conversions are mutually inverse. *)
  Axiom var2String2var : forall x : var, string2var (var2string x) = x.
  Axiom String2var2String : forall s : string, var2string (string2var s) = s.
End STRVAR.

(** * Heap formulas and their semantics

    [HeapSolver] is parameterised by a variable module [sv].  The
    arithmetic-formula development [ArithSemantics] from [PAInfTheoryDi]
    is instantiated with [PureNat] and [sv] and named [PA]; all pure
    expressions and formulas below come from it. *)
Module HeapSolver(sv:STRVAR).

Import sv.

Module PA := ArithSemantics PureNat sv.

(** Syntax of heap formulas:
    - [H_Emp]: the empty heap;
    - [H_Ptto e1 e2]: a single cell at address [e1] (the semantics in
      [dvalid_hform'] does not use [e2]);
    - [H_Star f1 f2]: separating conjunction of [f1] and [f2];
    - [H_List e]: a list starting at [e] (see [LL]);
    - [H_List_Size e n]: a list of [n] cells starting at [e] (see [LLSIZE]);
    - [H_Exists v f]: existential quantification of [v] in [f];
    - [H_And f g]: heap formula [f] together with pure formula [g];
    - [H_Pure g]: a pure arithmetic formula.
    Keep the constructor arguments unnamed: proof scripts below refer to
    the hypothesis names Coq generates for them ([z], [v], [f1], ...). *)
Inductive HF : Type :=
  | H_Emp : HF
  | H_Ptto : PA.ZExp -> PA.ZExp -> HF
  | H_Star : HF -> HF -> HF
  | H_List : PA.ZExp -> HF
  | H_List_Size : PA.ZExp -> nat -> HF
  | H_Exists : var -> HF -> HF
  | H_And : HF -> PA.ZF -> HF
  | H_Pure : PA.ZF -> HF.

(** A heap is modelled only by its set of allocated addresses. *)
Definition heap := Ensemble nat.

(** The heap with no allocated address. *)
Definition empty_heap : heap := Empty_set nat.

(** The heap whose only address is the value of [e]. *)
Definition single_heap (e : PA.ZExp) : heap :=
  Singleton nat (PA.dexp2ZE e).

(** Set union of two heaps; disjointness is stated separately. *)
Definition heap_union (h1 h2 : heap) : heap :=
  Union nat h1 h2.

(** The two heaps share no address. *)
Definition heap_is_disjoint (h1 h2 : heap) : Prop :=
  Disjoint nat h1 h2.

(** [LL e h]: heap [h] is the address set of a list starting at [e].
    The empty heap is a list.  Otherwise [h] is the disjoint union of the
    single cell at [e] and a list [h2] starting at some [e1].
    NOTE(review): unlike [LLSIZE], neither constructor constrains the
    value of [e] (no [e = 0] for the empty list, no [e > 0] for a cell).
    [e1] is also not linked to any stored next pointer.  Confirm this
    looseness is intended. *)
Inductive LL (e:PA.ZExp) : heap -> Prop :=
  | NIL_LL : LL e empty_heap
  | CONS_LL : forall h h1 h2 e1, h = heap_union h1 h2
              -> heap_is_disjoint h1 h2
              -> h1 = single_heap e
              -> LL e1 h2 -> LL e h.

(** [LLSIZE e n h]: heap [h] is the address set of a list of exactly
    [n] cells starting at [e].
    - Empty list: [e] evaluates to 0 and [n = 0].
    - Non-empty list: [e] evaluates to a positive address.  The cell at
      [e] is disjoint from a list of [n1] cells starting at some [e1],
      with [n = n1 + 1].
    The premises are unnamed arrows; [LLSIZE_implies_LL] refers to the
    hypothesis names that [induction] generates for them. *)
Inductive LLSIZE (e:PA.ZExp) (n: nat) : heap -> Prop :=
  | NIL_LLSIZE : (PA.dexp2ZE e) = 0 -> n = 0 -> LLSIZE e n empty_heap
  | CONS_LLSIZE : forall h h1 h2 e1 n1, h = heap_union h1 h2
              -> heap_is_disjoint h1 h2
              -> h1 = single_heap e
              -> (PA.dexp2ZE e) > 0
              -> n = n1 + 1
              -> LLSIZE e1 n1 h2 -> LLSIZE e n h.

(** Every sized list is a list: forget the length and the constraints
    on the value of [e]. *)
Theorem LLSIZE_implies_LL: forall e h n, LLSIZE e n h -> LL e h.
Proof.
intros.
induction H.
(* Empty list. *)
subst.
apply NIL_LL.
(* Cons cell: reuse the same heap split and the induction hypothesis. *)
eapply CONS_LL.
apply H. apply H0. apply H1. apply IHLLSIZE.
Qed.

(** [subs p form] substitutes the value [snd p] for the variable [fst p]
    throughout [form].  Expressions are handled by [PA.subst_exp] and
    pure formulas by [PA.substitute].  Substitution does not descend under
    an [H_Exists] that rebinds the same variable.  The [nat] size in
    [H_List_Size] is left untouched. *)
Fixpoint subs (p : var * PureNat.N.A) (form : HF) : HF :=
match form with
  | H_Emp => form
  | H_Ptto e1 e2 => H_Ptto (PA.subst_exp p e1) (PA.subst_exp p e2)
  | H_Star f1 f2 => H_Star (subs p f1) (subs p f2)
  | H_List e => H_List (PA.subst_exp p e)
  | H_List_Size e n => H_List_Size (PA.subst_exp p e) n
  (* The bound variable shadows [fst p]: return the formula unchanged. *)
  | H_Exists v g => if var_eq_dec (fst p) v then form else H_Exists v (subs p g)
  | H_And f g => H_And (subs p f) (PA.substitute p g)
  | H_Pure g => H_Pure (PA.substitute p g)
end.

(** Size of a heap formula, used as the fuel for [dvalid_hform'].
    Every node counts 1, plus the sizes of its heap sub-formulas.  The
    pure part of [H_And] is not counted. *)
Fixpoint length_hform (form : HF) : nat :=
  match form with
  | H_Emp => 1
  | H_Ptto _ _ => 1
  | H_Star f1 f2 => S (length_hform f1 + length_hform f2)
  | H_List _ => 1
  | H_List_Size _ _ => 1
  | H_Exists _ g => S (length_hform g)
  | H_And f _ => S (length_hform f)
  | H_Pure _ => 1
  end.

(** The size of a heap formula is always at least 1. *)
Lemma length_hform_gteq_one: forall f, length_hform f >= 1.
Proof.
  intro f.
  (* Each constructor contributes an explicit 1 or a successor. *)
  case f; intros; simpl; omega.
Qed.

(** Fuel-indexed satisfaction: [dvalid_hform' form h c] says that heap [h]
    satisfies [form] using at most [c] levels of recursion.  Running out
    of fuel yields [False].  [large_c_holds] shows that every fuel
    [c >= length_hform form] gives the same answer.  The fuel keeps the
    definition structurally recursive, because the [H_Exists] case
    recurses on [subs (v,x) g], which is not a sub-term of [form]. *)
Fixpoint dvalid_hform' (form: HF) (h:heap) (c:nat): Prop :=
match c with
  0 => False
| S c' => match form with
            H_Emp => h = empty_heap
            (* A cell occupies only its positive address; [e2] is unused. *)
          | H_Ptto e1 e2 => h = (single_heap e1) /\ (PA.dexp2ZE e1) > O
            (* Split [h] into two disjoint parts, one per sub-formula. *)
          | H_Star f1 f2 => exists h1 h2,
                            (dvalid_hform' f1 h1 c') /\ (dvalid_hform' f2 h2 c')
                            /\ (heap_is_disjoint h1 h2) /\ h = (heap_union h1 h2)
          | H_List e => LL e h
          | H_List_Size e n => LLSIZE e n h
            (* Existential: some value [x] substituted for [v] works. *)
          | H_Exists v g => exists x, dvalid_hform' (subs (v,x) g) h c'
            (* Pure parts are checked independently of the heap. *)
          | H_And f g => (dvalid_hform' f h c') /\ (PA.dvalid_zform g)
          | H_Pure g => (PA.dvalid_zform g)
        end
end.

(** Satisfaction of [f] by [h], using exactly [length_hform f] fuel. *)
Definition dvalid_hform (f : HF) (h : heap) : Prop :=
  dvalid_hform' f h (length_hform f).

(** A pure formula does not constrain the heap: [H_Pure g] holds on any
    [h] exactly when [g] is valid. *)
Lemma pure_valid_in_all_heap: forall h g, (dvalid_hform (H_Pure g) h) <-> PA.dvalid_zform g.
Proof.
  intros h g.
  unfold dvalid_hform.
  (* The fuel is [length_hform (H_Pure g) = 1], so the left-hand side is
     convertible to [PA.dvalid_zform g]; [exact] checks up to conversion. *)
  split; intro Hg; exact Hg.
Qed.

(** Substitution does not change the size of a heap formula. *)
Lemma subs_length_inv : forall f x v, length_hform f = length_hform (subs (v, x) f).
  Proof.
    induction f; simpl; try tauto; intros;
    try (rewrite <- IHf1; rewrite <- IHf2);
    try rewrite <- IHf;
    (* [H_Exists]: both outcomes of the variable test keep the size. *)
    try (case (var_eq_dec v0 v); intros; simpl); auto.
  Qed.

(** Fuel independence: any two fuel values that are both at least
    [length_hform f] give equivalent satisfaction judgements.  The proof
    is by induction on [c1]. *)
Lemma large_c_holds : forall f h c1 c2, c1 >= length_hform f -> c2 >= length_hform f ->
                                        (dvalid_hform' f h c1 <-> dvalid_hform' f h c2).
  Proof.
    intros f h c1.
    revert f h.
    induction c1.
    intros.
    (* c1 = 0 contradicts c1 >= length_hform f >= 1. *)
    exfalso; destruct f; simpl in H; omega.

    (* Non-recursive constructors give the same formula for both fuels;
       c2 = 0 is impossible for the same reason as above.  Remaining
       goals, in order: H_Star, H_Exists, H_And. *)
    destruct f; intros; simpl in H; simpl in H0;
    destruct c2;
    simpl;
    try (exfalso; omega);
    try tauto.

   (* H_Star: keep the same heap split and change fuel via the IH. *)
   split; intros;
   destruct H1,H1;
   exists x,x0.
   rewrite <- (IHc1 f1 x c2) , <- (IHc1 f2 x0 c2).
   apply H1.
   destruct c2;
   simpl; omega. omega. omega. omega.
   rewrite -> (IHc1 f1 x c2) , -> (IHc1 f2 x0 c2).
   apply H1. omega. omega. omega. omega.

   (* H_Exists: keep the same witness; [subs_length_inv] shows that the
      substituted body has the same size. *)
   split;intros;destruct H1; exists x.
   rewrite <- (IHc1 (subs (v, x) f) h c2).
   apply H1.
   rewrite <- subs_length_inv with (x := x) (v := v).
   omega.
   rewrite <- subs_length_inv with (x := x) (v := v).
   omega.
   rewrite -> (IHc1 (subs (v, x) f) h c2).
   apply H1.
   rewrite <- subs_length_inv with (x := x) (v := v).
   omega.
   rewrite <- subs_length_inv with (x := x) (v := v).
   omega.

   (* H_And: only the heap part depends on the fuel. *)
   split; intros.
   rewrite <- IHc1. tauto. omega. omega.
   rewrite (IHc1 f h c2). tauto. omega. omega.
  Qed.

(** Pure consequence of [H_List e]: the start address [e] is either
    equal to [PureNat.N.Const0] or strictly greater than it. *)
Definition unfold_list_pure (e: PA.ZExp) : PA.ZF :=
  PA.ZF_Or
    (PA.ZF_BF (PA.ZBF_Eq e (PA.ZExp_Const PureNat.N.Const0)))
    (PA.ZF_BF (PA.ZBF_Gt e (PA.ZExp_Const PureNat.N.Const0))).

(** Pure consequence of [H_List_Size e n], mirroring the two [LLSIZE]
    constructors: either [e] and [n] both equal [PureNat.N.Const0] (empty
    list) or both are strictly greater than it (non-empty list). *)
Definition unfold_list_size_pure (e: PA.ZExp) (n: nat) : PA.ZF :=
  PA.ZF_Or
    (PA.ZF_And
       (PA.ZF_BF (PA.ZBF_Eq e (PA.ZExp_Const PureNat.N.Const0)))
       (PA.ZF_BF (PA.ZBF_Eq (PA.ZExp_Const n) (PA.ZExp_Const PureNat.N.Const0))))
    (PA.ZF_And
       (PA.ZF_BF (PA.ZBF_Gt e (PA.ZExp_Const PureNat.N.Const0)))
       (PA.ZF_BF (PA.ZBF_Gt (PA.ZExp_Const n) (PA.ZExp_Const PureNat.N.Const0)))).

(** [xpure' form] is an arithmetic formula implied by [form] (see
    [xpure_valid]).  Heap structure is dropped, keeping only facts about
    addresses and list sizes. *)
Fixpoint xpure' (form: HF) : PA.ZF :=
  match form with
  | H_Emp => PA.ZF_BF (PA.ZBF_Const true)
  | H_Ptto e1 _ => PA.ZF_BF (PA.ZBF_Gt e1 (PA.ZExp_Const PureNat.N.Const0))
  | H_Star f1 f2 => PA.ZF_And (xpure' f1) (xpure' f2)
  | H_List e => unfold_list_pure e
  | H_List_Size e n => unfold_list_size_pure e n
  | H_Exists v g => PA.ZF_Exists v tt (xpure' g)
  | H_And f g => PA.ZF_And (xpure' f) g
  | H_Pure g => g
  end.

(** The pure part of [f], wrapped back as a heap formula. *)
Definition xpure (f : HF) : HF := H_Pure (xpure' f).

(** Expression values are natural numbers, so they are never below 0.
    (Despite the lemma's name, the statement is non-strict.) *)
Lemma PA_dexp2ZE_always_positive: forall e, (PA.dexp2ZE e) >= 0.
Proof.
  intro e.
  unfold ge.
  apply le_0_n.
Qed.

(** Taking the pure part never increases the size of a formula.
    (Despite the name, the inequality is non-strict.) *)
Lemma xpure_length_gt : forall f, (length_hform f) >= length_hform (xpure f).
Proof.
  intro f.
  (* [xpure f] is an [H_Pure] node, so its size reduces to 1. *)
  unfold xpure; simpl.
  apply length_hform_gteq_one.
Qed.

(** The pure part of any formula has size exactly 1. *)
Lemma xpure_length_one : forall f, (length_hform (xpure f)) = 1.
Proof.
  intro f.
  (* [xpure f] unfolds to an [H_Pure] node, whose size is 1 by computation. *)
  reflexivity.
Qed.

(** Substitution commutes with taking the pure part of a formula. *)
Lemma substitute_xpure'_eq_xpure'_subs : forall v x f,
                                PA.substitute (v, x) (xpure' f) =
                                xpure' (subs (v, x) f).
Proof.

intros; induction f;
(* Constructors without heap sub-formulas are closed directly. *)
try (simpl; tauto).
(* H_Star: follows from the two induction hypotheses. *)
simpl.
congruence.
(* H_Exists: rewrite with the IH, then split on whether [v] is rebound. *)
simpl.
rewrite IHf.
destruct var_eq_dec.
tauto.
simpl. tauto. (* H_And: *) simpl.
rewrite IHf. tauto.
Qed.

(** Soundness of [xpure]: whenever [f] holds on a heap [h], its pure part
    [xpure f] holds on [h] as well.  The proof is by induction on an upper
    bound [n] of [length_hform f], so the induction hypothesis also covers
    the substituted body in the [H_Exists] case. *)
Theorem xpure_valid: forall f h,
(dvalid_hform f h) -> dvalid_hform (xpure f) h.
Proof.
  (* Generalise over a size bound [n >= length_hform f]. *)
  intros f; remember (length_hform f); assert (length_hform f <= n) by omega; clear Heqn; revert f H.
  induction n; intros.
  (* n = 0 is impossible: every formula has size at least 1. *)
  exfalso; destruct f; simpl in H; omega.
 destruct f; simpl in *.
 (* Case H_Emp. *)
 unfold xpure, xpure'.
 unfold dvalid_hform, length_hform, dvalid_hform'.
 unfold PA.dvalid_zform, PA.length_zform, PA.dvalid_zform'.
 try (simpl; tauto).

 (* Case H_Ptto: the address is positive by hypothesis. *)
 unfold xpure, xpure'.
 unfold dvalid_hform, length_hform, dvalid_hform'.
  unfold PA.dvalid_zform, PA.length_zform, PA.dvalid_zform'.
  simpl.
  destruct H0.
  unfold negb.
  destruct (PureNat.N.num_leq (PA.dexp2ZE z) PureNat.N.Const0) eqn: ?.
  destruct Heqb.
  unfold PureNat.N.num_leq.
  destruct (le_dec (PA.dexp2ZE z) PureNat.N.Const0).
  assert (PureNat.N.Const0 = 0). trivial. omega. trivial. trivial.

  (* Case H_Star: treat each conjunct separately, adjusting fuel with
     [PA.large_c_holds] and [large_c_holds].  First conjunct: [f1]. *)
  unfold dvalid_hform.
  simpl.

  unfold PA.dvalid_zform. simpl.
  split.
  rewrite (PA.large_c_holds (xpure' f1)
  (PA.length_zform (xpure' f1) + PA.length_zform (xpure' f2))
  (PA.length_zform (xpure' f1))).

  unfold dvalid_hform in IHn.
  simpl in IHn.
  unfold PA.dvalid_zform in IHn.
  unfold dvalid_hform in H0.
  simpl in H0. destruct H0. destruct H0.
  destruct H0.

  apply IHn with (h:=x). omega.
  rewrite (large_c_holds f1 x (length_hform f1) (length_hform f1 + length_hform f2)).
  tauto.
  omega. omega. omega. omega.

  (* Second conjunct: [f2]. *)
  rewrite (PA.large_c_holds (xpure' f2)
  (PA.length_zform (xpure' f1) + PA.length_zform (xpure' f2))
  (PA.length_zform (xpure' f2))).

  unfold dvalid_hform in IHn.
  simpl in IHn.
  unfold PA.dvalid_zform in IHn.
  unfold dvalid_hform in H0.
  simpl in H0. destruct H0. destruct H0.
  destruct H0.

  apply IHn with (h:=x0). omega.
  rewrite (large_c_holds f2 x0 (length_hform f2) (length_hform f1 + length_hform f2)).
  destruct H1. trivial.
  omega. omega. omega. omega.
 (* Case H_List: [unfold_list_pure z] follows from the value being >= 0. *)
 unfold dvalid_hform in IHn. simpl in IHn.
 unfold xpure.

 unfold dvalid_hform. simpl.
 unfold PA.dvalid_zform. simpl.
 assert (PA.dexp2ZE z >= 0) by apply PA_dexp2ZE_always_positive.
 unfold PureNat.N.num_leq.
 destruct (le_dec (PA.dexp2ZE z) PureNat.N.Const0).
 destruct (le_dec PureNat.N.Const0 (PA.dexp2ZE z)).
 tauto. tauto. right. tauto.

 (* Case H_List_Size: the two [LLSIZE] constructors give the two
    disjuncts of [unfold_list_size_pure]. *)
 unfold dvalid_hform in IHn. simpl in IHn.
 unfold xpure.

 unfold dvalid_hform. simpl.
 unfold PA.dvalid_zform. simpl.
 unfold dvalid_hform in H0. simpl in H0.
 destruct H0.
  (* Empty list: both values are 0. *)
  left. split. unfold PureNat.N.num_leq.
  assert (PureNat.N.Const0=0). trivial.
    destruct le_dec. destruct le_dec.
    tauto. exfalso. omega.
  destruct le_dec. exfalso. omega. exfalso. omega.
  unfold PureNat.N.num_leq.
  assert (PureNat.N.Const0=0). trivial.
  destruct le_dec. destruct le_dec.
  tauto. exfalso. omega.
  destruct le_dec. exfalso. omega. exfalso. omega.

  (* Non-empty list: both values are positive. *)
  right. split. unfold PureNat.N.num_leq.
  assert (PureNat.N.Const0=0). trivial.
    destruct le_dec. exfalso;omega. tauto.
    unfold PureNat.N.num_leq.
    destruct le_dec.
    assert (PureNat.N.Const0=0). trivial.
    exfalso;omega.
    tauto.

  (* Case H_Exists: reuse the witness, commute substitution with
     [xpure'], and apply the induction hypothesis. *)
  unfold dvalid_hform in IHn.
  simpl in IHn.
  unfold PA.dvalid_zform in IHn.
  unfold dvalid_hform in H0.
  simpl in H0.

 unfold dvalid_hform. simpl.
 unfold PA.dvalid_zform. simpl.

 unfold dvalid_hform in H0. simpl in H0.
 destruct H0. exists x.

  rewrite (PA.large_c_holds (PA.substitute (v, PureNat.conv x) (xpure' f))
  (PA.length_zform (xpure' f))
  (PA.length_zform ((PA.substitute (v, @PureNat.conv tt x) (xpure' f))))).

 rewrite substitute_xpure'_eq_xpure'_subs.
 apply IHn with (h:= h).
 rewrite <- subs_length_inv with (v:=v) (x:=@PureNat.conv tt x).
 omega.

rewrite (large_c_holds (subs (v, x) f) h (length_hform f)
(length_hform (subs (v, x) f))) in H0. tauto.
 rewrite subs_length_inv with (v:=v) (x:=x).
omega. omega. rewrite <- PA.substitute_length_inv with (v:=v) (x:=x).
omega. omega.

  (* Case H_And: IH for the heap part; the pure part is carried over. *)
  unfold dvalid_hform in IHn.
  simpl in IHn.
  unfold PA.dvalid_zform in IHn.
  unfold dvalid_hform in H0.
  simpl in H0.
 unfold dvalid_hform. simpl.
 unfold PA.dvalid_zform. simpl.
 split.
 rewrite (PA.large_c_holds (xpure' f)
 (PA.length_zform (xpure' f) + PA.length_zform z)
 (PA.length_zform (xpure' f))).
 apply IHn with (h:=h). omega. destruct H0. tauto. omega.
 omega.

 rewrite (PA.large_c_holds z
 (PA.length_zform (xpure' f) + PA.length_zform z)
 (PA.length_zform (z))).
 destruct H0. tauto. omega. omega.

 (* Case H_Pure: [xpure] leaves the pure formula unchanged. *)
 unfold dvalid_hform. simpl.
 unfold dvalid_hform in H0. simpl in H0. tauto.

Qed.

(** [entail P Q]: every heap that satisfies [P] also satisfies [Q]. *)
Definition entail (P Q : HF) : Prop :=
  forall h : heap, dvalid_hform P h -> dvalid_hform Q h.

End HeapSolver.