%------------------------------------------------------
% /* vim: set filetype=prolog ; */
% flags
% Clauses for want/2, got/1 and set/4 need not be adjacent in the file:
% xpand/2 (below) emits one got/1 fact, one set/4 rule and any number of
% want/2 facts for every `Head = Body` term it expands.
:- discontiguous want/2, got/1,set/4.

%------------------------------------------------------
%% runtime stuff
% querying semantics
% wants/0, gots/0: report the sorted Name/Statuses lists built by wants1/1
% and gots1/1, one entry per line.
% Both now go through println/1 so the two reports share a layout; gots/0
% previously dumped the whole list as a single term with print/1.
% Each fails when the underlying setof/3 finds nothing to report.
wants :-
	wants1(Report),
	println(Report).
gots :-
	gots1(Report),
	println(Report).

% wants1(-Sorted), gots1(-Sorted): collect every solution of wants2/1
% (resp. gots2/1) into a sorted, duplicate-free list.
% Being setof/3, each fails when there are no solutions.
wants1(Sorted) :-
	setof(Entry, wants2(Entry), Sorted).
gots1(Sorted) :-
	setof(Entry, gots2(Entry), Sorted).

% wants2(-Name/Statuses): Name appears as first argument of some want/2 fact.
% gots2(-Name/Statuses):  Name has a got/1 fact.
% Statuses is the list of status/2 labels that hold for Name, or [] when
% none do (myBagof/3 never fails).
wants2(Name/Statuses) :-
	want(Name,_),
	myBagof(Status, Name^status(Name,Status), Statuses).
gots2(Name/Statuses) :-
	got(Name),
	myBagof(Status, Name^status(Name,Status), Statuses).

% status(?Name, ?Label): consistency labels relating want/2 and got/1.
%   ungot - some want/2 fact mentions Name but there is no got(Name).
%   want  - got(Name) holds but no want/2 fact mentions Name.
status(Name,ungot) :-
	want(Name,_),
	\+ got(Name).
status(Name,want) :-
	got(Name),
	\+ want(Name,_).

% root(?Name): Name has a set/4 clause whose body is a conjunction that
% starts with a tr/4 goal.
% roots(-Names): sorted list of all such names; fails if there are none.
roots(Names) :-
	setof(Name, root(Name), Names).
root(Name) :-
	clause(set(Name,_,_,_), (tr(_,_,_,_),_)).

% independent(?Name=Value, +World): Name is a root/1 and Name=Value is a
% member of the list World.
% independents(+World, -Pairs): sorted list of all such Name=Value pairs;
% fails (setof/3) when World contains none.
independents(World,Pairs) :-
	setof(Pair, World^independent(Pair,World), Pairs).
independent(Name=Value,World) :-
	root(Name),
	member(Name=Value,World).
 
% DCG-based assumption space management
% =(+Name, ?Value, +Seen0, -Seen): the goal `Name = Value` as it appears
% inside an expanded set//2 body, threading a list of Name=Value pairs.
%   1. Name already in the list: reuse its value, list unchanged.
%   2. Name has a set/4 clause: compute it via set/4, clamp the result with
%      ones/2 and push Name=Value onto the list returned by set/4.
%   3. Otherwise: draw random(1000)/999 and push it onto the list.
% The cuts commit to the first case that applies.
=(Name,Value,Seen,Seen) :-
	member(Name=Known,Seen), !,
	Value = Known.
=(Name,Value,Seen,[Name=Value|Rest]) :-
	clause(set(Name,_,_,_),_), !,
	set(Name,Raw,Seen,Rest),
	ones(Raw,Clamped),
	Value = Clamped.
=(Name,Value,Seen,[Name=Value|Seen]) :-
	Value is random(1000)/999.

% dot
% dot/0: write the dependency graph to '/tmp/graph.dot'.
% dot(+File): write the graph produced by dot1/0 to a temporary file and
% then rename that file to File.
% The redirection is wrapped in setup_call_cleanup/3 so that told/0 runs
% even when dot1/0 fails or throws; otherwise current output would stay
% pointed at the temporary file.  once/1 makes dot1/0 deterministic so the
% cleanup runs before rename_file/2.
dot :- dot('/tmp/graph.dot').
dot(F) :-
	tmp_file(dot,F1),
	setup_call_cleanup(tell(F1), once(dot1), told),
	rename_file(F1,F).
% dot1/0: emit the want/2 graph on current output as a Graphviz digraph.
% The item wanted by `goal` is drawn as a double octagon; every other
% want(To,From) fact becomes an edge To -> From.
% Fails if there is no want(_,goal) fact.
dot1 :-
	% Literal text goes through write/1: print/1 may quote atoms
	% (it does on current SWI-Prolog), which would corrupt the DOT syntax.
	write('digraph G { '),nl,
	write('\trankdir=LR;'),nl,
	% Was '\ranksep=0.25;': the \r escape swallowed the leading "r" of the
	% attribute name and emitted a carriage return instead of a tab.
	write('\tranksep=0.25;'),nl,
	write('\tnode [fontsize=12,shape=none]; '),nl,
	want(X,goal),
	% Argument wrapped in a list so a list-valued X is not mistaken for
	% the format/2 argument list.
	format('\t "    ~p" [shape=doubleoctagon];\n',[X]),
	forall((want(To,From),From \= goal),
	        format('\t"    ~p" -> "    ~p";\n',[To,From])),
	write('}'),nl.

% misc
% ones(+Raw, -Clamped): clamp Raw into the interval 0..1 using within/4.
ones(Raw,Clamped) :-
	within(0,1,Raw,Clamped).

%------------------------------------------------------
% load-time stuff

% xpand(+Head=Body0, -Clauses): load-time translation of a `Head = Body0`
% term.  Clauses is
%   [got(Head), Rule | Wants]
% where Wants are the want/2 facts collected by xpand1//5 and Rule is the
% DCG-expanded clause
%   set(Head,Value) --> Goals, {Value is Numerator/Denominator}.
% Denominator is the weight expression returned by xpand1//5, evaluated
% once here at load time.
xpand(Head=Body0,[got(Head),Rule|Wants]) :-
	xpand1(Body0,Head,Numerator,Weights,Goals,Wants,[]),
	Denominator is Weights,
	expand_term((set(Head,Value) --> Goals,{Value is Numerator/Denominator}),Rule).

% xpand1(+Body, +Head, -Expr, -Weight, -Goal)//
% Walk the right-hand side of a `Head = Body` term.
%   Expr   - arithmetic expression over fresh variables, mirroring Body.
%   Weight - weight expression; xpand/2 evaluates it and divides Expr by it.
%   Goal   - DCG body that binds the fresh variables at run time.
% The difference list collects one want(Term,Head) item per leaf term.
% Clause order matters: the operator and tr/trig patterns commit with a cut,
% and the last clause catches every remaining term as a plain leaf.
% A + B and A - B: recurse on both sides; weights are added in both cases.
xpand1(A0 + B0, H,A1 + B1, A2+B2,(A,B)) --> !, xpand1(A0,H,A1,A2,A), xpand1(B0,H,B1,B2,B).
xpand1(A0 - B0, H,A1 - B1, A2+B2,(A,B)) --> !, xpand1(A0,H,A1,A2,A), xpand1(B0,H,B1,B2,B).
% A * N and A / N: scaled leaf, with weight N and 1/N respectively.
% NOTE(review): want/2 records A as written, including any leading minus
% that xpand2/4 strips from the lookup goal - confirm this is intended.
xpand1(A * N,   H,P, N   ,       Goal ) --> !, [want(A,H)], {xpand2(A=X,X*N,P,Goal)}.
xpand1(A / N,   H,P, 1/N ,     Goal   ) --> !, [want(A,H)], {xpand2(A=X,X/N,P,Goal)}.
% tr(...) and trig(...): become a {}/1 call to tr/4 or trig/4 with weight 1;
% no want/2 item is emitted for them.
xpand1(tr(Min,Max,Mode),_,X,1,{tr(Min,Max,Mode,X)}) --> [],!.
xpand1(trig(Min,Max,Mode),_,X,1,{trig(Min,Max,Mode,X)}) --> [],!.
% Anything else: a plain leaf with weight 1.
xpand1(A    ,   H,P, 1,       Goal  ) -->    [want(A,H)], {xpand2(A=X,X,P,Goal)}.

% xpand2(+Leaf=Var, +Expr0, -Expr, -Goal)
% Build the lookup goal for one leaf.  A leaf written as -Name is looked up
% as Name and its expression is multiplied by -1; any other leaf passes
% through unchanged.
xpand2(-Name = Var, Expr, -1*Expr, Name=Var) :- !.
xpand2( Name = Var, Expr,    Expr, Name=Var).


%-----------------------------------------------------------------------
%% numerics

% ranf: arithmetic function usable inside is/2.
% Evaluates to random(65535)/65536.
:- arithmetic_function(ranf/0).
ranf(Sample) :-
	Sample is random(65535)/65536.

% sum(+Numbers, -Total): Total is the sum of the list Numbers.
% The empty list now sums to 0; previously sum([],S) simply failed.
% sum(+Numbers, +Acc0, -Total): accumulator form used by sum/2.
sum([],0).
sum([H|T],S) :- sum(T,H,S).
sum([],In,In).
sum([H|T],In,Out) :- Temp is H+ In, sum(T,Temp,Out).


% within(+Bottom, +Top, +N, ?Clamped): clamp N into Bottom..Top.
%   N inside the range -> Clamped = N
%   N below Bottom     -> Clamped = Bottom
%   N above Top        -> Clamped = Top
within(Bottom,Top,N,Clamped) :-
	Clamped = N,
	N >= Bottom,
	N =< Top.
within(Bottom,_,N,Clamped) :-
	Clamped = Bottom,
	N < Bottom.
within(_,Top,N,Clamped) :-
	Clamped = Top,
	N > Top.

% add2(+Key=N1, +Key=N2, -Key=Sum): add the values of two pairs that share
% the same key; fails if the keys do not unify.
add2(Key=Left, Key=Right, Key=Sum) :-
	Sum is Right + Left.

% normal(+Mean, +StdDev, -Sample): thin wrapper around box_muller/3.
normal(Mean,StdDev,Sample) :-
	box_muller(Mean,StdDev,Sample).
% box_muller(+Mean, +StdDev, -Sample): polar form of the Box-Muller
% transform.  w/2 supplies a point inside the unit circle (its squared
% radius and first coordinate); the coordinate is rescaled by
% sqrt(-2*ln(R2)/R2), then scaled by StdDev and shifted by Mean.
box_muller(Mean,StdDev,Sample) :-
	w(Radius2,Coord),
	Scale is sqrt((-2.0 * log(Radius2))/Radius2),
	Offset is Coord * Scale,
	Sample is Mean + Offset*StdDev.

% w(-W, -X): rejection-sample a point (X1,X2) with both coordinates drawn
% as 2.0*ranf - 1, returning W = X1^2 + X2^2 and X = X1.
% Points with W >= 1.0 are redrawn.  The origin (W = 0) is now redrawn as
% well: box_muller/3 evaluates log(W) and divides by W, so accepting it
% raised an arithmetic error.
w(W,X) :-
	X1 is 2.0 * ranf - 1,
	X2 is 2.0 * ranf - 1,
	W0 is X1*X1 + X2*X2,
	(   ( W0 >= 1.0 ; W0 =:= 0.0 )
	->  w(W,X)
	;   W0 = W, X = X1
	).

% showCdf(+Cdf): print a list of X=P pairs, one per line, as
%   X = P | bar
% where bar is round(P*100/4) tilde characters (built by str/3).
showCdf(Cdf) :-
	forall(member(Point=Prob,Cdf),
	       ( str(Prob*100/4,'~',Bar),
	         format('~2f = ~2f | ~p\n',[Point,Prob,Bar])
	       )).

% str(+Expr, +Char, -Atom): Atom is Char repeated round(Expr) times.
% str1(+Count, +Char, -Atom): same with an integer count; a count of zero
% or less yields the empty atom.
str(Expr,Char,Atom) :-
	Count is round(Expr),
	str1(Count,Char,Atom).
str1(Count,_,'') :-
	Count =< 0, !.
str1(Count,Char,Atom) :-
	findall(Char,between(1,Count,_),Chars),
	concat_atom(Chars,Atom).

% cdfPlus(+Lists, -Combined): add the non-empty list of Key=Number lists
% pair-wise (cdfPlus1/3), then turn the summed list into a normalised
% running total with nums2cdf/2.
cdfPlus([First|Others],Combined) :-
	cdfPlus1(Others,First,Summed),
	nums2cdf(Summed,Combined).

% cdfPlus1(+Lists, +Acc0, -Acc): fold add2/3 pair-wise over every list in
% Lists, starting from Acc0.
cdfPlus1([],Acc,Acc).
cdfPlus1([List|Lists],Acc0,Acc) :-
	maplist(add2,List,Acc0,Acc1),
	cdfPlus1(Lists,Acc1,Acc).

% tr2Cdf(+A, +B, +C, +Samples, -Cdf): collect the X=Density pairs that
% tr2Dist1/5 yields for Samples+1 sample points and convert them to a
% normalised running total with nums2cdf/2.
tr2Cdf(A,B,C,Samples,Cdf) :-
	Points is Samples + 1,
	bagof(Pair,A^B^C^Samples^tr2Dist1(A,B,C,Points,Pair),Pairs),
	nums2cdf(Pairs,Cdf).

% nums2cdf(+Pairs, -Cdf): replace each number in the Key=Number list by the
% running total up to and including it (accumulate/4), then divide every
% running total by the grand total (nums2cdf1/3).
nums2cdf(Pairs,Cdf) :-
	accumulate(Pairs,Running,0,Total),
	maplist(nums2cdf1(Total),Running,Cdf).

% accumulate(+Pairs, -Running, +Acc0, -Total): map each Key=Number to
% Key=RunningSum, where the running sum starts from Acc0; Total is the
% final running sum.
accumulate([],[],Total,Total).
accumulate([Key=Num|Rest0],[Key=Sum|Rest],Acc,Total) :-
	Sum is Num + Acc,
	accumulate(Rest0,Rest,Sum,Total).

% nums2cdf1(+Total, +Key=Num, -Key=Fraction): Fraction is Num / Total.
nums2cdf1(Total,Key=Num,Key=Fraction) :-
	Fraction is Num / Total.

% tr2Dist1(+A, +B, +C, +Samples, -X=Y): on backtracking, enumerate the
% points X = (1/Samples)*(N-1) for N in 1..Samples, each paired with the
% density Y that trPDF/5 gives at X.
% NOTE(review): the last point is (Samples-1)/Samples, so X = 1 is never
% sampled - confirm that is intended.
tr2Dist1(A,B,C,Samples,Point=Density) :-
	between(1,Samples,Index),
	Point is 1/Samples * (Index-1),
	trPDF(A,B,C,Point,Density).


% trCDF(+A, +B, +C, +X, -Y): Y is the cumulative probability at X of the
% triangular distribution with minimum A, maximum B and mode C.
% Each sloped branch is used only when it has non-zero width (C > A for the
% rising side, B > C for the falling side).  Previously a mode equal to the
% minimum made X = A evaluate 0/0 and raise an arithmetic error.
% For a single-point distribution (A = B = C) the result is 0 below the
% point and 1 from it onwards.
trCDF(A,_,_,X,0) :- X < A,!.
trCDF(A,B,C,X,Y) :- X >= A, X =< C, C > A,!, Y is      (X - A)^2/((B-A)*(C-A)).
trCDF(A,B,C,X,Y) :- X >= C, X =< B, B > C,!, Y is 1 -  (B - X)^2/((B-A)*(B-C)).
trCDF(_,B,_,X,1) :- X >= B.

% trPDF(+A, +B, +C, +X, -Y): Y is the density at X of the triangular
% distribution with minimum A, maximum B and mode C; 0 outside A..B.
% Each sloped branch is used only when it has non-zero width (C > A for the
% rising side, B > C for the falling side).  Previously a mode equal to the
% minimum made X = A evaluate 0/0 and raise an arithmetic error; it now
% yields the peak density from the falling side.
% For a single-point distribution (A = B = C) no density exists at that
% point and the call fails there.
trPDF(A,_,_,X,0) :- X < A,!.
trPDF(A,B,C,X,Y) :- X >= A, X =< C, C > A,!, Y is 2*(X - A)/((B-A)*(C-A)).
trPDF(A,B,C,X,Y) :- X >= C, X =< B, B > C,!, Y is 2*(B - X)/((B-A)*(B-C)).
trPDF(_,B,_,X,0) :- X > B.

%trig(A,B,C,Z) :- tr(A,B,C,Z0), Z is 2*(Z0 - A)/(B-A) - 1.
% tr(+Min, +Max, +Mode, -Z): draw Z from the triangular distribution with
% the given minimum, maximum and mode.
%   - If Mode lies outside Min..Max, print a bad(...) term and fail.
%   - If Min and Max are numerically equal, Z is Min.  The test is =:=
%     rather than head unification so that tr(10,10.0,...) is caught here
%     too; it used to fall through to tr1/8 and divide by zero in trCDF/5.
%   - Otherwise draw R with ranf and bisect Min..Max with tr1/8 until the
%     CDF at the midpoint is within 0.01 of R.
tr(A,B,C,_)   :- (C < A; C > B),!,print(bad(tr(min = A,max = B, mode = C))),nl,fail.
tr(A,B,_,Z)   :- A =:= B, !, Z = A.
tr(A,B,C,Z)   :- R is ranf, tr1(R,0.01, A,B,C,A,B,Z).

% tr1(+Target, +Tol, +A, +B, +C, +Lo, +Hi, -Z): one bisection step.
% Evaluate trCDF/5 at the midpoint of Lo..Hi, compare Target with it using
% myCompare/4 (tolerance Tol) and let tr2/11 either stop at the midpoint or
% continue in the appropriate half.
tr1(Target,Tol,A,B,C,Lo,Hi,Z) :-
	Mid is Lo + (Hi - Lo)/2,
	trCDF(A,B,C,Mid,CdfAtMid),
	myCompare(Order,Tol,Target,CdfAtMid),
	tr2(Order,Target,Tol,A,B,C,Lo,Mid,Hi,Mid,Z).

% tr2(+Order, +Target, +Tol, +A, +B, +C, +Lo, +Mid, +Hi, +Candidate, -Z)
% Dispatch on the comparison result from tr1/8:
%   =  Target matches the CDF at the midpoint: answer is Candidate.
%   >  Target is larger: continue bisecting Mid..Hi.
%   <  Target is smaller: continue bisecting Lo..Mid.
tr2(=,_,_,_,_,_,_,_,_,Answer,Answer).
tr2(>,Target,Tol,A,B,C,_,Mid,Hi,_,Z) :-
	tr1(Target,Tol,A,B,C,Mid,Hi,Z).
tr2(<,Target,Tol,A,B,C,Lo,Mid,_,_,Z) :-
	tr1(Target,Tol,A,B,C,Lo,Mid,Z).

% myCompare(?Order, +Tol, +X, +Y): like compare/3, except that numbers
% whose absolute difference is at most Tol compare as equal (=).
myCompare(=,Tol,X,Y) :-
	Gap is abs(X - Y),
	Gap =< Tol, !.
myCompare(Order,_,X,Y) :-
	compare(Order,X,Y).

%------------------------------------------------------
%% standard stuff

% printl(+List): print the elements of a non-empty list on current output,
% separated by commas, with no trailing newline.  Fails on the empty list.
printl([First|Rest]) :-
	print(First),
	forall(member(Item,Rest),
	       format(',~p',Item)).
 
% println(+List): print each element on its own line, prefixed by a tab.
% println(+List, +Prefix): same, with Prefix written before every element.
println(Items) :-
	println(Items,'\t').
println(Items,Prefix) :-
	forall(member(Item,Items),
	       format('~w~p\n',[Prefix,Item])).

% myBagof(+Template, :Goal, -Bag): bagof/3 that commits to its first
% solution and yields [] instead of failing when Goal has no solutions.
myBagof(Template,Goal,Bag) :-
	bagof(Template,Goal,Bag), !.
myBagof(_,_,[]).

% countDown(+Top, +Now, +Show): progress meter.  Whenever Now is a multiple
% of Show, write Top - Now and a newline to user_error and flush it;
% otherwise succeed silently.
countDown(Top,Now,Show) :-
	0 is Now mod Show, !,
	Remaining is Top - Now,
	print(user_error,Remaining),
	nl(user_error),
	flush_output(user_error).
countDown(_,_,_).

% repeatN(+N, :Goal): run Goal N times via forall/2.  Bindings made by Goal
% are not kept, and the call fails as soon as one run of Goal fails.
repeatN(Times,Goal) :-
	forall(between(1,Times,_),
	       Goal).

%------------------------------------------------------
%% learning stuff

% tar3(+Stem, +N): write N sampled rows to Stem.csv (csv/1) and the column
% declarations to Stem.names (names/0), then print the two file names.
% Each redirection is wrapped in setup_call_cleanup/3 so told/0 runs even
% if csv/1 or names/0 fails or throws; otherwise current output would stay
% pointed at the data file.  once/1 keeps the goals deterministic so the
% cleanup runs before the next step.
tar3(Stem,N) :-
	atom_concat(Stem,'.csv',Data),
	setup_call_cleanup(tell(Data), once(csv(N)), told),
	atom_concat(Stem,'.names',Names),
	setup_call_cleanup(tell(Names), once(names), told),
	print(files(csv(Data),names(Names))),
	nl.

% csv(+N): write a header line followed by N data rows on current output.
% The header lists every root/1 name prefixed with `$`, then `>$goal`.
% Each row comes from csv1/0; countDown/3 reports progress on user_error
% every 100 rows.
csv(N) :-
	roots([FirstRoot|OtherRoots]),
	format('$~p',FirstRoot),
	forall(member(Root,OtherRoots),
	       format(',$~p',Root)),
	write(',>$'),
	print(goal),
	nl,
	forall(between(1,N,Row),
	       ( countDown(N,Row,100),
	         csv1
	       )).

% csv1/0: evaluate `goal` once from an empty assumption list, then print
% the values of the root variables that were used (in independents/2
% order), followed by the goal's value, as one comma-separated line.
csv1 :-
	set(goal,Score,[],World),
	independents(World,RootPairs),
	maplist(arg(2),RootPairs,RootValues),
	printl(RootValues),
	write(','),
	print(Score),
	nl.

% names/0: write the line `_0, _1`, a blank line, and then one
% `Name : continuous.` line per root/1 name on current output.
names :-
	roots(Roots),
	write('_0, _1'),
	nl,
	nl,
	forall(member(Root,Roots),
	       format('~p : continuous.\n',Root)).

%------------------------------------------------------
% main stuff

% set1(+Name): evaluate Name with set/4 from an empty assumption list and
% print Name = Value, every assumption made, the marker `ind`, and the
% subset of assumptions that belong to root variables.
set1(Name) :-
	set(Name,Value,[],World),
	print(Name = Value),
	nl,
	println(World),
	print(ind),
	nl,
	independents(World,Roots),
	println(Roots).
% eg(+N): numbered demos.
%   1  evaluate and show `goal` with set1/1.
%   2  write 1000 tr(10,20,10.0001) samples to file `dat`, then sort them
%      numerically into `dat.sorted` with the shell.
%   3  generate /tmp/db.csv and /tmp/db.names with 1000 rows (tar3/2).
%   4  print 100 trig(8,10,9) samples separated by spaces.
%   5  write /tmp/graph.dot and convert it with dot and epstopdf.
eg(1) :-
	set1(goal).

eg(2) :-
	tell(dat),
	forall(between(1,1000,_),
	       ( tr(10,20,10.0001,Sample),
	         print(Sample),
	         nl
	       )),
	told,
	shell("sort -n dat > dat.sorted").

eg(3) :-
	tar3('/tmp/db',1000).

eg(4) :-
	repeatN(100,
	        ( trig(8,10,9,Sample),
	          print(Sample),
	          print(' ')
	        )).

eg(5) :-
	dot,
	shell('dot -Tps -o /tmp/graph.eps /tmp/graph.dot'),
	shell('epstopdf /tmp/graph.eps').

%------------------------------------------------------
% change the compiler

% term_expansion(+Head = Body, -Clauses): compiler hook.  Every term of the
% form `Head = Body` read at load time is replaced by the clauses that
% xpand/2 builds for it.
term_expansion(Head = Body, Clauses) :-
	xpand(Head = Body, Clauses).
% goal_expansion(+trig(Mode0,Down,Up,X), -tr(Min,Max,Mode,X)): compiler
% hook that rewrites a trig/4 goal into a tr/4 goal at load time.
%   - Only applies when the three numeric arguments are already numbers;
%     without this guard is/2 raised an instantiation error while compiling
%     any clause that calls trig/4 with variables.
%   - The arguments are emitted as (Min,Max,Mode), the order tr/4 expects.
%     They used to be emitted as (Min,Mode,Max), which made tr/4 reject the
%     call as "bad" whenever Max exceeded Mode.
% NOTE(review): Mode is divided by 10 before Down and Up are applied and
% the result is divided by 10 again - confirm this scaling is intended.
goal_expansion(
	 trig(Mode0,Down,Up,X),
	 tr(Min,Max,Mode,X)) :-
	number(Mode0), number(Down), number(Up),
	Mode is Mode0/10,
	Min  is (Mode - Down)/10,
	Max  is (Mode + Up)/10.

% twenty2zeroOne(+In, -Out): linear rescaling, Out = -1 + 2*In/10
% (so 0 maps to -1 and 10 maps to 1).
twenty2zeroOne(In,Out) :-
	Out is -1 + 2*In/10.
	
