import numpy as np
from scipy.special import betaln
import time

def nal_to_beta(f, c, k=1.0):
    """Convert an NAL truth value (frequency, confidence) to Beta parameters.

    The evidence weight ``w`` is recovered by inverting ``c = w / (w + k)``,
    i.e. ``w = k * c / (1 - c)``, and split into positive and negative
    evidence: ``(alpha, beta) = (w * f, w * (1 - f))``.

    Args:
        f: frequency, expected in [0, 1].
        c: confidence; must satisfy 0 <= c < 1 (c == 1 divides by zero).
        k: evidential horizon. The default 1.0 reproduces the plain
            ``c / (1 - c)`` weight exactly.

    Returns:
        Tuple ``(alpha, beta)``. Works on scalars or elementwise on numpy arrays.
        Note that f == 0 or f == 1 yields a zero Beta parameter.
    """
    w = k * c / (1 - c)
    return w * f, w * (1 - f)

def hellinger2(a1, b1, a2, b2):
    """Return the squared Hellinger distance between Beta(a1, b1) and Beta(a2, b2).

    Uses ``H^2 = 1 - BC``, where the Bhattacharyya coefficient is
    ``BC = B((a1+a2)/2, (b1+b2)/2) / sqrt(B(a1, b1) * B(a2, b2))``.
    BC is evaluated in log space with ``betaln`` so large parameters do not
    overflow. Works on scalars or elementwise on numpy arrays. The result is
    in [0, 1].
    """
    log_bc = betaln((a1 + a2) / 2, (b1 + b2) / 2) - 0.5 * (betaln(a1, b1) + betaln(a2, b2))
    # -expm1(x) == 1 - exp(x) but keeps precision when BC is close to 1.
    # log_bc <= 0 in exact arithmetic; clamp so rounding noise cannot make
    # the squared distance negative.
    return np.maximum(0.0, -np.expm1(log_bc))

def rao_distance(f1, c1, f2, c2):
    """Return the distance between truth values (f1, c1) and (f2, c2).

    Both truth values are mapped to Beta parameters with ``nal_to_beta``. The
    distance is ``2 * asinh(sqrt(H2 / (1 - H2)))``, where ``H2`` is the
    squared Hellinger distance of the two Beta distributions and
    ``BC = 1 - H2`` is their Bhattacharyya coefficient. The result is 0 for
    identical inputs and grows without bound as the distributions separate.

    NOTE(review): this is a closed-form surrogate built on the Hellinger
    affinity; confirm it is the metric intended by the name "Rao".
    """
    a1, b1 = nal_to_beta(f1, c1)
    a2, b2 = nal_to_beta(f2, c2)
    # gap = -log(BC), computed in log space. It is >= 0 in exact arithmetic.
    # Clamp rounding noise that would otherwise make H2 negative and sqrt
    # return NaN (a NaN sort key silently scrambles premise ranking).
    gap = np.maximum(
        0.5 * (betaln(a1, b1) + betaln(a2, b2)) - betaln((a1 + a2) / 2, (b1 + b2) / 2),
        0.0,
    )
    h2 = -np.expm1(-gap)  # 1 - BC, accurate even when BC is close to 1
    # Identity: 2*asinh(sqrt(h2 / BC)) == gap + 2*log1p(sqrt(h2)).
    # This form needs no epsilon in a denominator, so it neither divides by
    # zero when BC underflows nor saturates for far-apart distributions.
    return gap + 2 * np.log1p(np.sqrt(h2))

def c_star(d, o):
    """Return the '+' root c of ``d*c**2 + (o + 1 - d)*c - o = 0``.

    When d > 0 and o > 0 the roots have opposite signs, so this is the
    positive root. ``d`` must be nonzero (it appears in the denominator).
    NOTE(review): presumably ``d`` is a decay rate and ``o`` an observation
    rate, giving an equilibrium confidence — confirm with the model author.
    """
    lin = o + 1 - d
    disc = lin ** 2 + 4 * d * o
    return (-lin + np.sqrt(disc)) / (2 * d)

# Build a reproducible synthetic knowledge base of 50 beliefs. Each belief is
# stored as (name, subject, predicate, frequency, confidence), with both
# numbers rounded to 3 decimals. Frequency ~ Beta(2, 1) and confidence ~ Beta(2, 2),
# both clipped to [0.05, 0.95]. The two draws per belief happen in this order,
# so the seeded sequence is stable.
np.random.seed(7)
domains = ['cat', 'dog', 'fish', 'bird', 'snake',
           'tree', 'rock', 'human', 'robot', 'alien']
props = ['animal', 'living', 'entity', 'danger', 'pet',
         'swim', 'fly', 'smart', 'old', 'fast']

kb = []
for idx in range(50):
    subject = domains[idx % 10]
    predicate = props[(idx // 10 + idx % 7) % 10]
    freq = np.clip(np.random.beta(2, 1), 0.05, 0.95)
    conf = np.clip(np.random.beta(2, 2), 0.05, 0.95)
    kb.append((f'{subject}_{predicate}_{idx}', subject, predicate,
               round(freq, 3), round(conf, 3)))

# Keep only beliefs whose confidence is strictly above the c_star threshold.
# NOTE(review): the meaning of the (0.05, 0.1) arguments is not visible here.
thresh = c_star(0.05, 0.1)
alive = [belief for belief in kb if belief[4] > thresh]
print(f'KB: {len(kb)} beliefs, {len(alive)} alive (c*={thresh:.4f})')

def select_premises(gf, gc, kb_alive, top_k=5):
    """Rank beliefs by Rao distance to a goal truth value and keep the closest.

    Only truth values are compared. The subject/predicate fields are carried
    along but never matched against the goal.

    Args:
        gf, gc: goal frequency and confidence.
        kb_alive: iterable of (name, subject, predicate, f, c) belief tuples.
        top_k: number of entries to return (applied as a list slice).

    Returns:
        List of (name, subject, predicate, f, c, distance) tuples in ascending
        distance order. Ties keep their input order because sorting is stable.
    """
    ranked = []
    for name, subj, pred, freq, conf in kb_alive:
        dist = rao_distance(gf, gc, freq, conf)
        ranked.append((name, subj, pred, freq, conf, dist))
    return sorted(ranked, key=lambda row: row[5])[:top_k]

# Benchmark: rank the alive beliefs against a few sample goals given as
# (predicate, f, c), then report the top matches and the ranking time.
goals = [('entity', 0.7, 0.5), ('danger', 0.3, 0.4), ('living', 0.9, 0.8)]
# perf_counter is monotonic and high-resolution, unlike the wall clock
# time.time(). Only the ranking is timed, so console I/O does not skew it.
t0 = time.perf_counter()
results = [(gp, gf, gc, select_premises(gf, gc, alive)) for gp, gf, gc in goals]
elapsed = time.perf_counter() - t0

for gp, gf, gc, premises in results:
    # How many of the top-ranked premises share the goal's predicate.
    matching = [p for p in premises if p[2] == gp]
    print(f'Goal {gp} stv({gf},{gc}): {len(premises)} ranked, {len(matching)} match pred')
    for n, s, p, f, c, d in premises[:3]:
        print(f'  {n} stv({f},{c}) Rao={d:.4f}')
# Report the sizes actually processed: only the alive beliefs are ranked.
print(f'Benchmark: {len(goals)} goals x {len(alive)} beliefs in {elapsed*1000:.1f}ms')