from dataclasses import dataclass

@dataclass
class TruthValue:
    """Two-component truth value attached to a statement.

    ``c`` is treated as a confidence by the rest of this module; ``f`` is the
    accompanying truth strength (presumably a NARS-style frequency). Both are
    expected to lie in [0, 1]; ``clamp`` enforces that range.
    """

    f: float
    c: float

    def clamp(self):
        """Force ``f`` and ``c`` into [0.0, 1.0] in place and return ``self``."""
        for attr in ("f", "c"):
            raw = getattr(self, attr)
            # Floor first, then ceiling. A value that fails the ``>`` test
            # (negative numbers, zero, NaN) collapses to 0.0.
            floored = raw if raw > 0.0 else 0.0
            bounded = floored if floored < 1.0 else 1.0
            setattr(self, attr, bounded)
        return self

def deduction(a, b, confidence_discount=0.9):
    """Combine two premise truth values into the truth value of a conclusion.

    The conclusion gets ``f = a.f * b.f`` and
    ``c = a.c * b.c * confidence_discount``, both clamped to [0, 1].

    Args:
        a: Truth value of the first premise (must expose ``f`` and ``c``).
        b: Truth value of the second premise (must expose ``f`` and ``c``).
        confidence_discount: Multiplier applied to the product of the premise
            confidences. Defaults to 0.9, the value that used to be
            hard-coded here.

    Returns:
        A new, clamped ``TruthValue``; neither premise is modified.
    """
    strength = a.f * b.f
    confidence = a.c * b.c * confidence_discount
    return TruthValue(strength, confidence).clamp()

def revision(x, y):
    """Merge two truth values about the same statement into one.

    Each input is weighted by ``w = c / (1 - c)``. The merged ``f`` is the
    weighted mean of the input ``f`` values and the merged confidence is
    ``(w1 + w2) / (w1 + w2 + 1)``.

    Args:
        x: Existing truth value, or ``None`` if there is none.
        y: New truth value, or ``None`` if there is none.

    Returns:
        ``None`` when both inputs are ``None``; otherwise a new, clamped
        ``TruthValue``. The inputs are never mutated and the result never
        aliases either of them.
    """
    if x is None and y is None:
        return None
    if x is None or y is None:
        only = y if x is None else x
        # Clamp a copy: calling .clamp() on the input would mutate the
        # caller's object and hand back an alias to it.
        return TruthValue(only.f, only.c).clamp()

    # Work on clamped copies so out-of-range inputs are treated the same way
    # as in the single-input case (and can never yield negative weights).
    a = TruthValue(x.f, x.c).clamp()
    b = TruthValue(y.f, y.c).clamp()

    # The 1e-9 floor keeps the weight finite when c == 1.0.
    w1 = a.c / max(1e-9, 1.0 - a.c)
    w2 = b.c / max(1e-9, 1.0 - b.c)
    total = w1 + w2

    if total > 0.0:
        f = (w1 * a.f + w2 * b.f) / total
    else:
        # Neither side carries any evidence; a weighted mean is undefined, so
        # fall back to the plain average instead of forcing f to 0.0.
        f = (a.f + b.f) / 2.0
    c = total / (total + 1.0)
    return TruthValue(f, c).clamp()

@dataclass
class Rule:
    """One deduction rule: the beliefs ``left`` and ``link`` yield ``right``.

    ``Reasoner.candidate_trades`` only considers a rule when both ``left`` and
    ``link`` are keys of the reasoner's beliefs.
    """

    # Belief key of the first premise.
    left: str
    # Belief key of the second premise.
    link: str
    # Belief key under which the deduced conclusion is stored.
    right: str
    # Attention charged against the per-step budget when the rule fires;
    # also the divisor when computing the rule's ROI.
    cost: float = 1.0
    # Multiplier applied to the conclusion's market value to get its utility.
    value: float = 1.0

@dataclass
class AttentionRecord:
    """Running attention ledger for one conclusion statement.

    ``Reasoner.attention`` keeps one record per ``Rule.right`` key.
    """

    # Sum of the costs of the rules fired for this statement.
    spent: float = 0.0
    # Sum of the utilities credited for those firings.
    earned: float = 0.0
    # Number of rule firings recorded for this statement.
    hits: int = 0

class Reasoner:
    """Forward-chaining reasoner that spends a per-step attention budget.

    Beliefs map statement keys to truth values. Each rule deduces a
    conclusion from two premise beliefs. In every step the applicable rules
    are ranked by return on investment (utility / cost) and fired in that
    order until the attention budget for the step is used up.

    Attributes:
        beliefs: dict mapping statement key -> ``TruthValue``.
        rules: list of registered ``Rule`` objects.
        attention_budget: maximum total rule cost that may fire in one step.
        attention: dict mapping conclusion key -> ``AttentionRecord``.
    """

    def __init__(self, attention_budget=2.0):
        self.beliefs = {}
        self.rules = []
        self.attention_budget = attention_budget
        self.attention = {}
        # (left, link, right) -> premise values the rule last fired on.
        # Used to avoid revising the same derived evidence in again.
        self._fired_on = {}

    def add_fact(self, stmt: str, tv: "TruthValue") -> None:
        """Revise the belief for ``stmt`` with ``tv`` (or create it)."""
        merged = revision(self.beliefs.get(stmt), tv)
        if merged is not None:
            self.beliefs[stmt] = merged

    def add_rule(self, left: str, link: str, right: str, cost=1.0, value=1.0) -> None:
        """Register a rule and make sure its conclusion has an attention record."""
        self.rules.append(Rule(left, link, right, cost, value))
        self.attention.setdefault(right, AttentionRecord())

    def market_value(self, stmt: str, inferred: "TruthValue") -> float:
        """Score how much ``inferred`` would add to the belief for ``stmt``.

        With no existing belief the score is ``1.0 + inferred.c``. Otherwise
        it is the absolute change in ``f`` and ``c`` plus any increase in
        confidence.
        """
        old = self.beliefs.get(stmt)
        novelty = 1.0 if old is None else abs(old.f - inferred.f) + abs(old.c - inferred.c)
        confidence_gain = inferred.c if old is None else max(0.0, inferred.c - old.c)
        return novelty + confidence_gain

    def candidate_trades(self) -> list:
        """Return ``(roi, utility, rule, inferred)`` tuples, best ROI first.

        Only rules whose two premises are both current beliefs are included.
        """
        trades = []
        for rule in self.rules:
            if rule.left in self.beliefs and rule.link in self.beliefs:
                inferred = deduction(self.beliefs[rule.left], self.beliefs[rule.link])
                utility = rule.value * self.market_value(rule.right, inferred)
                # Floor the divisor so a zero cost cannot divide by zero.
                roi = utility / max(1e-9, rule.cost)
                trades.append((roi, utility, rule, inferred))
        trades.sort(key=lambda x: x[0], reverse=True)
        return trades

    def _premise_signature(self, rule: "Rule") -> tuple:
        """Return the current (f, c) values of both premises of ``rule``."""
        left = self.beliefs[rule.left]
        link = self.beliefs[rule.link]
        return (left.f, left.c, link.f, link.c)

    def step(self) -> bool:
        """Run one budgeted round of inference.

        A rule is skipped when its conclusion already exists and its premises
        are unchanged since it last fired. Re-firing it would revise the same
        evidence into the conclusion again, inflating confidence on every
        round and preventing ``run`` from ever converging.

        Returns:
            True if at least one belief was created or changed.
        """
        trades = self.candidate_trades()
        # Snapshot premise values before any trade in this step touches the
        # beliefs: these are the values each ``inferred`` was computed from.
        signatures = [self._premise_signature(rule) for _, _, rule, _ in trades]

        spent = 0.0
        changed = False
        for (_roi, utility, rule, inferred), signature in zip(trades, signatures):
            key = (rule.left, rule.link, rule.right)
            old = self.beliefs.get(rule.right)
            if old is not None and self._fired_on.get(key) == signature:
                # Same evidence as last time; nothing new to merge.
                continue
            if spent + rule.cost > self.attention_budget:
                # Too expensive for what is left; a cheaper trade further
                # down the list may still fit.
                continue
            merged = revision(old, inferred)
            spent += rule.cost
            self._fired_on[key] = signature
            book = self.attention.setdefault(rule.right, AttentionRecord())
            book.spent += rule.cost
            book.earned += utility
            book.hits += 1
            if old is None or abs(old.f - merged.f) > 1e-9 or abs(old.c - merged.c) > 1e-9:
                self.beliefs[rule.right] = merged
                changed = True
        return changed

    def run(self, steps=8) -> None:
        """Call ``step`` up to ``steps`` times, stopping once nothing changes."""
        for _ in range(steps):
            if not self.step():
                break

if __name__ == "__main__":
    # Demo: seed a few facts and rules, reason, add direct evidence for
    # "animal:robin", reason again, then dump beliefs and attention books.
    engine = Reasoner(attention_budget=2.0)

    seed_facts = [
        ("bird:robin", 0.95, 0.90),
        ("bird_implies_animal", 0.98, 0.85),
        ("animal_implies_mortal", 0.97, 0.80),
        ("bird_implies_can_fly", 0.90, 0.70),
    ]
    for statement, strength, confidence in seed_facts:
        engine.add_fact(statement, TruthValue(strength, confidence))

    seed_rules = [
        ("bird:robin", "bird_implies_animal", "animal:robin", 0.7, 1.4),
        ("animal:robin", "animal_implies_mortal", "mortal:robin", 0.9, 1.2),
        ("bird:robin", "bird_implies_can_fly", "can_fly:robin", 1.6, 0.8),
    ]
    for left, link, right, cost, value in seed_rules:
        engine.add_rule(left, link, right, cost=cost, value=value)
    engine.run()

    # Two direct observations of the same statement, revised in order.
    for strength, confidence in ((0.70, 0.60), (0.90, 0.70)):
        engine.add_fact("animal:robin", TruthValue(strength, confidence))
    engine.run()

    for k, v in sorted(engine.beliefs.items()):
        print(f"{k} -> f={v.f:.3f} c={v.c:.3f}")
    print("attention")
    for k, a in sorted(engine.attention.items()):
        print(f"{k} -> spent={a.spent:.3f} earned={a.earned:.3f} hits={a.hits}")
