Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/asplain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def reify_program(
Returns:
The reified program as a string.
"""
constants = constants or {}
extensions = [
TagExtension(include_program=True, include_loc=True, include_id=True),
ShowExtension(),
Expand All @@ -55,6 +56,7 @@ def reify_program(


# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments
def construct_program_graph(
file_paths: List[str],
prg: str = "",
Expand Down Expand Up @@ -104,7 +106,9 @@ def construct_program_graph(
return symbols_to_prg(list(model_symbols))


def set_model_subgraphs_ctl(pg, ctl=None, model_symbols: Optional[List[str]] = None) -> Control:
def set_model_subgraphs_ctl(
pg: str, ctl: Optional[Control] = None, model_symbols: Optional[List[str]] = None
) -> Control:
"""
Sets the control object for computing model subgraphs.
Args:
Expand Down
78 changes: 47 additions & 31 deletions src/asplain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import sys
from textwrap import dedent
from time import time
from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

from clingo import Application, ApplicationOptions, Control, Flag, Model, parse_term
from clingo import Application, ApplicationOptions, Control, Flag, Model, Symbol, parse_term

from asplain import (
construct_program_graph,
Expand All @@ -27,8 +27,6 @@
from asplain.utils.logging import colored, configure_logging, save_out
from asplain.utils.viz import viz_graph

# from asplain.utils.viz import viz_graph_mock as viz_graph

try:
from asplain.llm.models import ModelTag, OpenAIModel
from asplain.llm.models.google import GoogleModel
Expand All @@ -43,30 +41,31 @@
log = logging.getLogger(__name__)


# pylint: disable=too-many-instance-attributes, too-many-public-methods
class AsplainApp(Application):
"""Application for reification with extensions."""

def __init__(self, name, constants: Optional[dict[str, str]] = None) -> None:
def __init__(self, name: str, constants: Optional[dict[str, str]] = None) -> None:
"""Initialize AsplainApp."""
self.program_name = name
self._log_level = "WARNING"
self._constants = constants or {}
self._query_include = []
self._query_exclude = []
self._assumptions = []
self._query_include: List[Symbol] = []
self._query_exclude: List[Symbol] = []
self._assumptions: List[Tuple[str, bool]] = []
self._number_explanations = 1

self._dynamic_tags = []
self._cost_encoding = []
self._model_symbols = None
self._dynamic_tags: list[str] = []
self._cost_encoding: list[str] = []
self._model_symbols: Optional[list[str]] = None

self._open: Flag = Flag()

self._pruning_methods: list[PruningMethod] = []
if INSTALLED_LLMS:
self._llm_tag: Optional[ModelTag] = None

self.statistics = {
self.statistics: dict[str, Any] = {
"Program Graph": {},
"Reference Graph": {},
"Contrastive Graph": {},
Expand All @@ -81,15 +80,15 @@ def setter(value: Any) -> bool:
if not os.path.isfile(value):
raise ValueError(f"File '{value}' does not exist.")
if not multi:
self.__setattr__(attr_name, value)
setattr(self, attr_name, value)
else:
current_value = getattr(self, attr_name, [])
if not isinstance(current_value, list):
log.error("Attribute %s is not a list", attr_name)
log.error("Setting value to list")
current_value = [current_value]
current_value.append(value)
self.__setattr__(attr_name, current_value) # Use direct assignment instead of __setattr__
setattr(self, attr_name, current_value)
return True

return setter
Expand All @@ -104,7 +103,7 @@ def parse_log_level(self, log_level: str) -> bool:

return True

def parse_assumptions(self, value) -> bool:
def parse_assumptions(self, value: str) -> bool:
"""
Parse assumptions string
"""
Expand All @@ -115,25 +114,28 @@ def parse_assumptions(self, value) -> bool:

return True

def parse_number_explanations(self, value) -> bool:
def parse_number_explanations(self, value: str) -> bool:
"""
Parse number of explanations
"""
self._number_explanations = int(value)
return True

def parse_query(self, value) -> bool:
def parse_query(self, value: str) -> bool:
"""
Parse query string
"""

true_queries, false_queries = divide_space_string(value)
self._query_include = [str(parse_term(s)) for s in true_queries]
self._query_exclude = [str(parse_term(s)) for s in false_queries]
self._query_include = [parse_term(s) for s in true_queries]
self._query_exclude = [parse_term(s) for s in false_queries]

return True

def parse_model(self, value: str) -> bool:
"""
Save the model command line in the object
"""
self.parse_file("_model_file")(value)
ctl = Control(["1", "--warn=none"])
ctl.load(value)
Expand All @@ -147,6 +149,9 @@ def parse_model(self, value: str) -> bool:
return False

def parse_llm_tag(self, value: str) -> bool:
"""
Save the LLM tag for prompting in the object
"""
if INSTALLED_LLMS:
if value in [str(m) for m in ModelTag.__members__]:
tag = ModelTag[value]
Expand All @@ -155,13 +160,17 @@ def parse_llm_tag(self, value: str) -> bool:
return False

def parse_pruning(self, value: str) -> bool:
"""
Save the pruning method in the object
"""
if value in [str(m) for m in PruningMethod.__members__]:
method = PruningMethod[value]
self._pruning_methods.append(method)
return True
return False

def register_options(self, options: ApplicationOptions) -> None:
"""Register command line options."""
group = colored("blue", "Asplain Options")

options.add(
Expand Down Expand Up @@ -289,7 +298,7 @@ def register_options(self, options: ApplicationOptions) -> None:
self._open,
)

def size_for_statistics(self, name: str, pg: str) -> dict[str, int]:
def size_for_statistics(self, name: str, pg: str) -> dict[str, Any]:
"""
Compute size statistics for a program graph.
"""
Expand All @@ -305,26 +314,33 @@ def size_for_statistics(self, name: str, pg: str) -> dict[str, int]:
count = s.arguments[1].number
self.statistics[name][category] = count

def on_statistics(self, step, accu) -> None:
self.statistics["Cost encoding"] = len(self._cost_encoding)
self.statistics["Pruning methods"] = len(self._pruning_methods)
self.statistics["Explanations"] = self._number_explanations
return self.statistics[name] # type: ignore

def on_statistics(self, _: Any, accu: dict[str, Any]) -> None:
"""
Callback to collect statistics after solving
"""
self.statistics["Cost encoding"] = {"count": len(self._cost_encoding)}
self.statistics["Pruning methods"] = {"count": len(self._pruning_methods)}
self.statistics["Explanations"] = {"count": self._number_explanations}
self.statistics["Number of changes"] = (
{"added": len(self._foil_inspection[1]), "removed": len(self._foil_inspection[2])}
if self._foil_inspection
else None
)
accu["Asplain"] = self.statistics

def print_model(self, model: Model, _) -> None:
def print_model(self, model: Model, _) -> None: # type: ignore
"""Print the model's symbols."""
symbols = model.symbols(shown=True)
print(" ".join([str(s) for s in model_symbols(symbols)]))

def main(self, ctl: Control, files: Sequence[str]) -> None:
def main(self, control: Control, files: Sequence[str]) -> None:
"""
Main entry point.
"""
# pylint: disable=W0201
# pylint: disable=too-many-branches, too-many-statements

configure_logging(sys.stderr, self._log_level, sys.stderr.isatty()) # type: ignore
query_prg = get_query_prg(self._query_include, self._query_exclude)
Expand All @@ -351,8 +367,8 @@ def main(self, ctl: Control, files: Sequence[str]) -> None:
)
self.size_for_statistics("Program Graph", reference_pg)
start_time = time()
model_subgraphs_ctl = set_model_subgraphs_ctl(pg=reference_pg, ctl=ctl, model_symbols=self._model_symbols)
with model_subgraphs_ctl.solve(yield_=True, on_statistics=self.on_statistics) as hnd:
model_subgraphs_ctl = set_model_subgraphs_ctl(pg=reference_pg, ctl=control, model_symbols=self._model_symbols)
with model_subgraphs_ctl.solve(yield_=True, on_statistics=self.on_statistics) as hnd: # type: ignore
model_found = False
for model in hnd:
model_found = True
Expand Down Expand Up @@ -406,14 +422,14 @@ def main(self, ctl: Control, files: Sequence[str]) -> None:
pg=explanation_graph,
title="Contrastive Graph",
name=f"contrastive_pg_{model.number}_{foil_model.number}",
open=self._open.flag,
show=self._open.flag,
)
if INSTALLED_LLMS:
if self._llm_tag is not None:
# Prompt the LLM
if self._llm_tag.value.openai is not None:
log.info("Using OpenAI API")
llm = OpenAIModel(model_tag=self._llm_tag)
llm: Union[OpenAIModel, GoogleModel] = OpenAIModel(model_tag=self._llm_tag)
elif self._llm_tag.value.google is not None:
log.info("Using Google API")
llm = GoogleModel(model_tag=self._llm_tag)
Expand Down Expand Up @@ -451,7 +467,7 @@ def main(self, ctl: Control, files: Sequence[str]) -> None:
pg=explanation_graph,
title="Contrastive Graph",
name=f"contrastive_pg_UNSAT_{foil_model.number}",
open=self._open.flag,
show=self._open.flag,
)
explanation_symbols = list(foil_model.symbols(shown=True))
for method in self._pruning_methods:
Expand Down
4 changes: 2 additions & 2 deletions src/asplain/encodings/construct-foil.lp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

{removed(R)}:-abducible(remove, R), program(R, ref).
{added(R)}:-abducible(add, R), not program(R, ref).
{removed(R)}:-optional(remove, R), program(R, ref).
{added(R)}:-optional(add, R), not program(R, ref).

% Nodes that are in the reference and not removed, plus added nodes
program(N, foil):- program(N, ref), not removed(N).
Expand Down
2 changes: 0 additions & 2 deletions src/asplain/encodings/pruning/abduced.lp

This file was deleted.

2 changes: 2 additions & 0 deletions src/asplain/encodings/pruning/change.lp
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
change(rm,N) :- node(N,_), program(N, ref), not program(N, foil), solved(foil).
change(add,N) :- node(N,_), program(N, foil), not program(N, ref), solved(foil).
6 changes: 3 additions & 3 deletions src/asplain/encodings/pruning/changes.lp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "./abduced.lp".
#include "./change.lp".

% Always include query
node_included(N,T) :- query(N, _), node(N,T).
% Always include abductions
node_included(N,T) :- abduced(_, N), node(N,T).
% Always include changes
node_included(N,T) :- change(_, N), node(N,T).
node_included(N,T) :- fired(N), node(N,T).

node_included(N,T) :- node(N, T), model(N, ref), not model(N, foil).
Expand Down
4 changes: 2 additions & 2 deletions src/asplain/encodings/pruning/inclusion_filter.lp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ node_show(N,T) :- node(N,T), node_included(N,T), not tag(N,hide).
node_show(N,T) :- node(N,T), tag(N,show).

program_included(N,P) :- node_show(N,_), program(N,P).
abducible_included(A,N) :- node_show(N,_), abducible(A,N).
optional_included(A,N) :- node_show(N,_), optional(A,N).
model_included(N,M) :- node_show(N,_), model(N,M).
tag_included(N,T) :- node_show(N,_), tag(N,T).

Expand All @@ -18,7 +18,7 @@ solved_included(S) :- solved(S).
#show edge(E,T) : edge_included(E,T).

#show program(N,P) : program_included(N,P).
#show abducible(A,N) : abducible_included(A,N).
#show optional(A,N) : optional_included(A,N).
#show model(N,M) : model_included(N,M).
#show tag(N,T) : tag_included(N,T).
#show query(N,Q) : query_included(N,Q).
Expand Down
6 changes: 3 additions & 3 deletions src/asplain/encodings/pruning/orphans.lp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "./abduced.lp".
#include "./change.lp".

% Always include query
node_included(N,T) :- query(N, _), node(N,T).
% Always include abductions
node_included(N,T) :- abduced(_, N), node(N,T).
% Always include changes
node_included(N,T) :- change(_, N), node(N,T).

node_included(N2,T) :- node_included(N1,_), node(N2,T), edge((N1,N2), _).
node_included(N2,T) :- node_included(N1,_), node(N2,T), edge((N2,N1), _).
Expand Down
6 changes: 3 additions & 3 deletions src/asplain/encodings/pruning/path_test.lp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "./abduced.lp".
#include "./change.lp".

% EXTERNAL: path_depth(N).

Expand All @@ -8,7 +8,7 @@ reachable(N,T) :- query(N,_), node(N,T).
reachable(M,T) :- reachable(N,_), node_path(M,T), edge((N,M),_).
reachable(M,T) :- reachable(N,_), node_path(M,T), edge((M,N),_).

:- abduced(_,N), node(N,T), not reachable(N,T).
:- change(_,N), node(N,T), not reachable(N,T).

node_included(N,T) :- node_path(N,T).
edge_included((N1,N2),T) :- edge((N1,N2),T), node_included(N1,_), node_included(N2,_).
Expand All @@ -17,4 +17,4 @@ edge_included((N1,N2),T) :- edge((N1,N2),T), node_included(N1,_), node_included(

%#show .
%#show q(N) : reachable_query(N,_).
%#show a(N) : reachable_abduced(N,_).
%#show a(N) : reachable_change(N,_).
16 changes: 8 additions & 8 deletions src/asplain/encodings/pruning/paths.lp
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
#include "./abduced.lp".
#include "./change.lp".

% EXTERNAL: path_depth(N).

% Always include query with distance 0
reachable_query(N,T) :- query(N, _), node(N,T).
% Always include abductions with distance 0
reachable_abduced(N,T) :- abduced(_, N), node(N,T).
% Always include changes with distance 0
reachable_change(N,T) :- change(_, N), node(N,T).

reachable_query(M,T) :- edge((N,M),_), node(M,T), reachable_query(N,_).
reachable_abduced(M,T) :- edge((N,M),_), node(M,T), reachable_abduced(N,_).
reachable_change(M,T) :- edge((N,M),_), node(M,T), reachable_change(N,_).

% FIND INTERSECTIONS

intersection(N,T) :- reachable_query(N,T), reachable_abduced(N,T).
intersection(N,T) :- reachable_query(N,T), reachable_change(N,T).
reachable_query_back(N,T) :- intersection(N,T).
reachable_abduced_back(N,T) :- intersection(N,T).
reachable_change_back(N,T) :- intersection(N,T).

% BACK DIRECTION

reachable_query_back(N,T) :- edge((N,M),_), node(N,T), reachable_query_back(M,_).
reachable_abduced_back(N,T) :- edge((N,M),_), node(N,T), reachable_abduced_back(M,_).
reachable_change_back(N,T) :- edge((N,M),_), node(N,T), reachable_change_back(M,_).

% PATH INCLUSIONS

node_path(N,T) :- reachable_query(N,T), reachable_query_back(N,T).
node_path(N,T) :- reachable_abduced(N,T), reachable_abduced_back(N,T).
node_path(N,T) :- reachable_change(N,T), reachable_change_back(N,T).

% DEPTH

Expand Down
Loading
Loading