import numpy as np
from scipy.special import betaln
import time

def nal_to_beta(f, c):
    """Map a (frequency, confidence) pair to Beta shape parameters.

    The total weight is ``c / (1 - c)``. This presumably treats ``(f, c)`` as
    a NAL truth value with evidential horizon k = 1, since c = w / (w + 1)
    inverts to w = c / (1 - c); confirm against the intended NAL variant. That
    weight is split between the two shapes in proportion ``f : (1 - f)``.

    Args:
        f: Frequency, expected in [0, 1].
        c: Confidence. ``c == 1`` divides by zero; ``c == 0`` yields zero
            shapes, which are not a valid Beta distribution.

    Returns:
        ``(alpha, beta)`` as a 2-tuple.
    """
    evidence = c / (1 - c)
    positive = evidence * f
    negative = evidence * (1 - f)
    return positive, negative

def hellinger2(a1, b1, a2, b2):
    """Return the squared Hellinger distance between Beta(a1, b1) and Beta(a2, b2).

    Uses the closed form

        H^2 = 1 - B((a1 + a2) / 2, (b1 + b2) / 2) / sqrt(B(a1, b1) * B(a2, b2))

    where the ratio is the Bhattacharyya coefficient. The ratio is evaluated
    in log space via ``betaln`` so that large shape parameters do not
    overflow the Beta function.

    Args:
        a1, b1: Shape parameters of the first Beta distribution (positive).
        a2, b2: Shape parameters of the second Beta distribution (positive).
            Scalars or broadcastable NumPy arrays.

    Returns:
        H^2 clamped to [0, 1]. NaN inputs still propagate as NaN.
    """
    log_bc = betaln((a1 + a2) / 2, (b1 + b2) / 2) - 0.5 * (betaln(a1, b1) + betaln(a2, b2))
    # Mathematically log_bc <= 0, so H^2 is in [0, 1]. For nearly identical
    # distributions, rounding in betaln can make log_bc a hair positive,
    # giving a tiny negative H^2. Callers such as rao_distance take its
    # square root, which would then produce NaN, so clamp to the true range.
    return np.clip(1 - np.exp(log_bc), 0.0, 1.0)

def rao_distance(f1, c1, f2, c2):
    """Distance between two (frequency, confidence) truth values.

    Each pair is mapped to Beta shape parameters with ``nal_to_beta``. The
    squared Hellinger distance H^2 between the two Beta distributions is then
    transformed as ``2 * arcsinh(sqrt(H^2 / (1 - H^2 + 1e-15)))``.

    The transform is strictly increasing in H^2 on [0, 1], so ordering by
    this value is the same as ordering by H^2. The ``1e-15`` keeps the
    denominator non-zero when H^2 == 1.

    NOTE(review): this is a transform of the Hellinger distance rather than a
    direct Fisher-Rao geodesic computation; confirm it is the intended
    approximation.
    """
    alpha1, beta1 = nal_to_beta(f1, c1)
    alpha2, beta2 = nal_to_beta(f2, c2)
    sq_hel = hellinger2(alpha1, beta1, alpha2, beta2)
    ratio = sq_hel / (1 - sq_hel + 1e-15)
    return 2 * np.arcsinh(np.sqrt(ratio))

def c_star(d, o):
    """Return the root of ``d*c**2 + (o + 1 - d)*c - o = 0`` on the +sqrt branch.

    For ``d > 0`` and ``o >= 0`` this is the non-negative root. When
    ``d == 0`` the equation degenerates to a linear one, and the result is
    ``o / (o + 1)`` (for ``o > -1``).

    NOTE(review): the meaning of ``d`` and ``o`` is not documented here.

    Args:
        d, o: Scalars or broadcastable NumPy arrays.

    Returns:
        A NumPy scalar for scalar inputs, otherwise an array of the
        broadcast shape.
    """
    d = np.asarray(d, dtype=float)
    o = np.asarray(o, dtype=float)
    b = o + 1 - d
    root = np.sqrt(b * b + 4 * d * o)
    # Both candidate forms are evaluated (np.where is element-wise); the
    # branch that is not selected may divide by zero, so silence those warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # For b > 0, the textbook (-b + root) / (2d) subtracts two nearly
        # equal numbers when d*o is small. The algebraically equal
        # 2o / (b + root) has no cancellation and stays finite at d == 0.
        via_o = 2 * o / (b + root)
        # For b <= 0, -b and root are both non-negative, so the textbook
        # form has no cancellation (and d >= o + 1 keeps it away from d == 0).
        via_d = (root - b) / (2 * d)
    out = np.where(b > 0, via_o, via_d)
    # [()] unwraps a 0-d result to a NumPy scalar; n-d arrays pass through.
    return out[()]

def filtered_select(goal_pred, goal_f, goal_c, kb, thresh, top_k=5):
    """Pick up to ``top_k`` beliefs for a goal, in layered relevance order.

    Each belief in ``kb`` is a 5-tuple ``(name, subject, predicate, f, c)``.
    Beliefs whose confidence is not strictly above ``thresh`` are discarded.
    The survivors are split into three layers:

    1. exact: ``predicate == goal_pred``;
    2. subject match: ``subject == goal_pred`` and not already exact
       (intended for chaining);
    3. rest: every other surviving belief.

    Each layer is sorted (stably) by ascending ``rao_distance`` between the
    goal truth value ``(goal_f, goal_c)`` and the belief's ``(f, c)``. The
    layers are then concatenated in the order above.

    Returns:
        ``(top, n_exact, n_subj)``:
        - ``top`` is the first ``top_k`` entries of the concatenation, each a
          6-tuple ``(name, subject, predicate, f, c, distance)``;
        - ``n_exact`` and ``n_subj`` are the sizes of layers 1 and 2.
    """
    exact, subj_match, rest = [], [], []
    # One pass assigns each belief to exactly one layer. This replaces the
    # quadratic `b not in exact` list-membership checks with the equivalent
    # elif chain.
    for belief in kb:
        if not belief[4] > thresh:
            continue
        if belief[2] == goal_pred:
            exact.append(belief)
        elif belief[1] == goal_pred:
            subj_match.append(belief)
        else:
            rest.append(belief)

    def rank(pool):
        # Attach the distance to the goal and order nearest-first.
        scored = [(n, s, p, f, c, rao_distance(goal_f, goal_c, f, c)) for n, s, p, f, c in pool]
        scored.sort(key=lambda x: x[5])
        return scored

    result = rank(exact) + rank(subj_match) + rank(rest)
    return result[:top_k], len(exact), len(subj_match)

# Benchmark on 50-belief KB.
# The fixed seed makes the generated KB (and therefore the printed selections)
# reproducible. Each iteration draws beta(2, 1) for f, then beta(2, 2) for c,
# and that draw order must stay fixed.
np.random.seed(7)
domains = ['cat', 'dog', 'fish', 'bird', 'snake', 'tree', 'rock', 'human', 'robot', 'alien']
props = ['animal', 'living', 'entity', 'danger', 'pet', 'swim', 'fly', 'smart', 'old', 'fast']
kb = []
for i in range(50):
    s = domains[i % 10]
    p = props[(i // 10 + i % 7) % 10]
    f = np.clip(np.random.beta(2, 1), 0.05, 0.95)
    c = np.clip(np.random.beta(2, 2), 0.05, 0.95)
    # Belief tuple layout: (name, subject, predicate, f, c).
    kb.append((f'{s}_{p}_{i}', s, p, round(f, 3), round(c, 3)))

thresh = c_star(0.05, 0.1)
goals = [('entity', 0.7, 0.5), ('danger', 0.3, 0.4), ('living', 0.9, 0.8), ('animal', 0.8, 0.7), ('swim', 0.6, 0.3)]

# Time only the selection work. perf_counter is the monotonic,
# high-resolution clock intended for elapsed-time measurement, and the
# printing is done afterwards so console I/O is not counted.
t0 = time.perf_counter()
results = [(gp, gf, gc, *filtered_select(gp, gf, gc, kb, thresh)) for gp, gf, gc in goals]
elapsed = time.perf_counter() - t0

for gp, gf, gc, top, n_exact, n_subj in results:
    print(f'Goal {gp} stv({gf},{gc}): {n_exact} exact, {n_subj} chainable')
    for n, s, p, f, c, d in top[:3]:
        print(f'  {n} [{p}] stv({f},{c}) Rao={d:.4f}')
print(f'Benchmark: {len(goals)} goals x {len(kb)} beliefs in {elapsed*1000:.1f}ms')