import numpy as np
from scipy.special import betaln

# === Rao geodesic distance (g129-g130) ===
def nal_to_beta(f, c):
    """Map a NAL truth value (frequency ``f``, confidence ``c``) to Beta parameters.

    Confidence is turned into a weight ``w = c / (1 - c)``. The weight is then
    split by frequency into ``(w * f, w * (1 - f))``.

    ``c`` must be below 1: at ``c == 1`` the weight divides by zero.
    """
    weight = c / (1 - c)
    alpha = weight * f
    beta = weight * (1 - f)
    return alpha, beta

def hellinger2(a1, b1, a2, b2):
    """Squared Hellinger distance between Beta(a1, b1) and Beta(a2, b2).

    The Bhattacharyya coefficient has the closed form
    ``BC = B((a1+a2)/2, (b1+b2)/2) / sqrt(B(a1, b1) * B(a2, b2))``.
    It is evaluated in log space with ``betaln``, and ``1 - BC`` is returned.

    Works elementwise on numpy arrays as well as on scalars.

    The result is clipped to [0, 1]. Mathematically BC <= 1, but floating-point
    cancellation between the log-beta terms can push it slightly above 1 for
    nearly identical distributions. That would make the value a tiny negative
    number, and a later ``np.sqrt`` would turn it into nan.
    """
    log_bc = betaln((a1 + a2) / 2, (b1 + b2) / 2) - 0.5 * (betaln(a1, b1) + betaln(a2, b2))
    return np.clip(1 - np.exp(log_bc), 0.0, 1.0)

def rao_distance(f1, c1, f2, c2):
    """Distance between two NAL truth values ``(f1, c1)`` and ``(f2, c2)``.

    Both truth values are mapped to Beta distributions with ``nal_to_beta``.
    The squared Hellinger distance ``h2`` between them is then transformed as
    ``2 * arcsinh(sqrt(h2 / (1 - h2)))``. The result is 0 for identical truth
    values and grows without bound as the Betas become disjoint.
    """
    a1, b1 = nal_to_beta(f1, c1)
    a2, b2 = nal_to_beta(f2, c2)
    # Clamp: rounding in the log-beta differences can leave h2 a hair below 0,
    # and np.sqrt of a negative value would produce nan.
    h2 = np.clip(hellinger2(a1, b1, a2, b2), 0.0, 1.0)
    # The 1e-15 keeps the ratio finite when h2 == 1 (non-overlapping Betas).
    return 2 * np.arcsinh(np.sqrt(h2 / (1 - h2 + 1e-15)))

# === Homeostasis threshold (g120) ===
def c_star(d, o):
    """Homeostasis confidence threshold: the positive root of
    ``d * c**2 + (o + 1 - d) * c - o = 0``.

    Parameters
    ----------
    d : float or array-like
        Decay rate.
    o : float or array-like
        Observation rate.

    Returns
    -------
    The root ``(-(o+1-d) + sqrt((o+1-d)**2 + 4*d*o)) / (2*d)``, as a numpy
    scalar for scalar inputs or an array for array inputs.

    When ``o + 1 - d > 0`` the equivalent rationalised form
    ``2*o / ((o+1-d) + sqrt(...))`` is used instead. It avoids catastrophic
    cancellation for small ``d``. It also yields the correct limit
    ``o / (o + 1)`` at ``d == 0``, where the textbook form is 0/0 = nan (which
    previously made every confidence compare as "dead").
    """
    d = np.asarray(d, dtype=float)
    o = np.asarray(o, dtype=float)
    b = o + 1 - d
    disc = np.sqrt(b * b + 4 * d * o)
    # np.where evaluates both branches, so silence warnings from the discarded
    # one (e.g. division by d == 0, or 0/0 in the rationalised form when b < 0).
    with np.errstate(divide="ignore", invalid="ignore"):
        stable = 2 * o / (b + disc)
        direct = (disc - b) / (2 * d)
    # [()] unwraps a 0-d result to a numpy scalar; arrays pass through unchanged.
    return np.where(b > 0, stable, direct)[()]

# === Unified backward chainer ===
class UnifiedChainer:
    """Backward chainer over a base of (subject -> predicate) statements
    carrying NAL-style truth values (frequency f, confidence c).

    Statements whose confidence is not above ``c_star(decay, obs_rate)`` are
    treated as dead and ignored. Live statements are ranked by ``rao_distance``
    between their truth value and the current goal's truth value.
    """

    def __init__(self, kb, decay=0.05, obs_rate=0.1):
        self.kb = kb  # list of (name, subj, pred, f, c)
        self.decay = decay
        self.obs_rate = obs_rate
        self.threshold = c_star(decay, obs_rate)

    def alive(self, c):
        """Return whether confidence ``c`` is strictly above the homeostasis threshold."""
        return c > self.threshold

    def select_premises(self, goal_f, goal_c, top_k=3, pred=None):
        """Return up to ``top_k`` live premises closest to the truth value (goal_f, goal_c).

        Each result is ``(name, subj, pred, f, c, rao_dist)``, sorted by
        increasing distance (ties keep knowledge-base order). If ``pred`` is
        given, only premises whose predicate equals it are considered.
        """
        scored = [
            (n, s, p, f, c, rao_distance(goal_f, goal_c, f, c))
            for n, s, p, f, c in self.kb
            if self.alive(c) and (pred is None or p == pred)
        ]
        scored.sort(key=lambda x: x[5])
        return scored[:top_k]

    def backward_chain(self, goal_pred, goal_f, goal_c, depth=3):
        """Return chains of premises that conclude ``goal_pred``.

        Each chain is a list of ``(name, f, c, rao_dist)`` steps. Steps are
        ordered from the earliest premise to the one whose predicate is
        ``goal_pred``. Each step's predicate is the next step's subject, so a
        chain reads e.g. ``a->b, b->c, c->goal``.

        At each level, up to 5 live premises concluding the current sub-goal
        are picked by Rao distance to the sub-goal's truth value. A picked
        premise's own (f, c) becomes the truth value for the next level, and
        its subject becomes the next sub-goal. ``depth`` bounds chain length;
        ``depth <= 0`` returns no chains.
        """
        if depth <= 0:
            return []
        chains = []
        for n, s, p, f, c, d in self.select_premises(goal_f, goal_c, top_k=5, pred=goal_pred):
            step = (n, f, c, d)
            chains.append([step])
            if depth > 1:
                # To extend "s -> goal" backwards we need premises concluding s.
                for sub in self.backward_chain(s, f, c, depth - 1):
                    chains.append(sub + [step])
        return chains

# === Test ===
# Demo knowledge base: (name, subject, predicate, frequency, confidence).
kb = [
    ("cat_animal", "cat", "animal", 0.9, 0.8),
    ("animal_living", "animal", "living", 0.85, 0.7),
    ("living_entity", "living", "entity", 0.8, 0.6),
    ("cat_danger", "cat", "danger", 0.1, 0.15),
    ("fish_swim", "fish", "swim", 0.95, 0.9),
    ("dead_belief", "cat", "old", 0.5, 0.03),
]
chainer = UnifiedChainer(kb, decay=0.05, obs_rate=0.1)
print(f"Homeostasis threshold c*={chainer.threshold:.4f}")
print(f"dead_belief alive? {chainer.alive(0.03)}")
print(f"cat_danger alive? {chainer.alive(0.15)}")
print("\nBackward chain to entity:")
for chain in chainer.backward_chain("entity", 0.8, 0.5, depth=3):
    rendered = [f"{n}(f={f},c={c},rao={d:.3f})" for n, f, c, d in chain]
    print("  " + " -> ".join(rendered))