#lang racket

(provide (all-defined-out))

;; Under the Hood

;; Build a fresh logic variable tagged with NAME.
;; Variables are freshly allocated one-element vectors, so two calls with
;; the same NAME still produce distinct (non-eqv?) variables.
(define var
  (lambda (name)
    (vector name)))

;; Recognize a logic variable. Any vector counts, matching `var`'s
;; representation.
(define var?
  (lambda (x)
    (vector? x)))

;; The empty substitution: an association list with no bindings.
(define empty-s (list))

;; Resolve V against substitution S.
;; While V is a variable bound in S (looked up with assv, i.e. eqv?),
;; replace it by its binding. The result is either a non-variable or a
;; variable with no binding in S.
(define (walk v s)
  (define binding (and (var? v) (assv v s)))
  (if (pair? binding)
      (walk (cdr binding) s)
      v))

;; Extend substitution S with the association X -> V.
;; Returns #f instead when X occurs inside V under S, because that binding
;; would create a cyclic term.
(define (ext-s x v s)
  (if (occurs? x v s)
      #f
      (cons (cons x v) s)))

;; Does variable X appear anywhere in V once V is walked through S?
;; Recurses into both halves of pairs. Atoms never contain X.
(define (occurs? x v s)
  (let ([w (walk v s)])
    (if (var? w)
        (eqv? w x)
        (and (pair? w)
             (or (occurs? x (car w) s)
                 (occurs? x (cdr w) s))))))

;; Unify terms U and V under substitution S.
;; Returns the (possibly extended) substitution, or #f when the terms
;; cannot be made equal. Pairs are unified car-first, then cdr under the
;; substitution produced by the car.
(define (unify u v s)
  (let ([wu (walk u s)]
        [wv (walk v s)])
    (cond
      [(eqv? wu wv) s]
      [(var? wu) (ext-s wu wv s)]
      [(var? wv) (ext-s wv wu s)]
      [(and (pair? wu) (pair? wv))
       (let ([s-car (unify (car wu) (car wv) s)])
         (and s-car (unify (cdr wu) (cdr wv) s-car)))]
      [else #f])))

;; Goal constructor: succeed exactly once, with S extended so that U and V
;; are equal. Fail (return '()) when they cannot be unified.
(define (== u v)
  (lambda (s)
    (cond
      [(unify u v s) => list]
      [else '()])))

;; Goal that succeeds once, leaving the substitution unchanged.
(define (succeed s)
  (cons s '()))

;; Goal that never succeeds: returns the empty stream for any substitution.
(define (fail _s)
  null)

;; Disjunction of two goals. Both goals are applied to S (G1 first), and
;; their streams are merged with append-inf.
(define disj2
  (lambda (g1 g2)
    (lambda (s)
      (let* ([left (g1 s)]
             [right (g2 s)])
        (append-inf left right)))))

;; Append two streams. A stream is '(), a pair whose cdr is a stream, or a
;; suspension (a thunk).
;; On reaching a suspension in S-INF, return a new suspension that swaps
;; the argument order. This interleaves the two streams, so an infinite
;; S-INF cannot starve T-INF.
(define (append-inf s-inf t-inf)
  (match s-inf
    ['() t-inf]
    [(cons answer more) (cons answer (append-inf more t-inf))]
    [suspension (lambda () (append-inf t-inf (suspension)))]))

;; A goal that never produces an answer. Every time its suspension is
;; forced, it just builds a new unproductiveo goal and suspends again.
(define (unproductiveo)
  (define (goal s)
    (lambda ()
      ((unproductiveo) s)))
  goal)

;; Force suspensions at the head of a stream until it becomes either '()
;; or a pair, i.e. until its first answer, if any, is available.
(define (pull-inf s-inf)
  (let loop ([stream s-inf])
    (if (or (null? stream) (pair? stream))
        stream
        (loop (stream)))))

;; Apply goal G to the empty substitution and pull the resulting stream
;; until it is '() or has a first answer ready.
(define run-goal
  (lambda (g)
    (let ([stream (g empty-s)])
      (pull-inf stream))))

;; A goal whose stream yields S over and over without end.
;; Each forcing builds (disj2 succeed (productiveo)), which produces one
;; answer followed by a further suspension.
(define (productiveo)
  (define (goal s)
    (lambda ()
      (let ([alternatives (disj2 succeed (productiveo))])
        (alternatives s))))
  goal)

;; Collect up to N answers from stream S-INF, forcing suspensions only as
;; far as needed.
;;   N       -- a count, or #f meaning "take every answer" (this diverges
;;              on an infinite stream).
;;   S-INF   -- a stream: '(), a pair, or a suspension. It need not be
;;              pulled already; pull-inf is applied here before looking
;;              at it.
;; Returns a list of at most N answers.
;;
;; N is checked before the stream is forced. A non-positive N therefore
;; returns '() without touching S-INF, and after the Nth answer the rest
;; of the stream is never forced. (Previously N = 0 still returned answers,
;; because the count was only checked after consing.)
(define (take-inf n S-inf)
  (cond
    ((and n (<= n 0)) '())
    (else
     (let ((S-inf (pull-inf S-inf)))
       (cond
         ((null? S-inf) '())
         (else
          (cons (car S-inf)
                (take-inf (and n (sub1 n)) (cdr S-inf)))))))))

;; Conjunction of two goals. Run G1 on S, then run G2 on every answer G1
;; produces, merging the results with append-map-inf.
(define conj2
  (lambda (g1 g2)
    (lambda (s)
      (let ([first-stream (g1 s)])
        (append-map-inf g2 first-stream)))))

;; Apply goal G to each substitution in stream S-INF and append the
;; resulting streams with append-inf.
;; A suspension in S-INF becomes a suspension in the result.
(define (append-map-inf g s-inf)
  (match s-inf
    ['() '()]
    [(cons answer more)
     (append-inf (g answer) (append-map-inf g more))]
    [suspension
     (lambda () (append-map-inf g (suspension)))]))

;; Create a new logic variable tagged NAME and pass it to F. F is expected
;; to return a goal.
(define call/fresh
  (lambda (name f)
    (let ([new-var (var name)])
      (f new-var))))

;; Turn a number N into the symbol used to display a fresh variable,
;; e.g. 0 -> _0, 12 -> _12.
(define (reify-name n)
  (let ([digits (number->string n)])
    (string->symbol (string-append "_" digits))))

;; Deep walk: resolve V through S and recurse into pairs, so every bound
;; variable inside the term is replaced by its value.
;; Unbound variables and atoms are returned as-is.
(define (walk* v s)
  (let ([w (walk v s)])
    (if (pair? w)
        (cons (walk* (car w) s)
              (walk* (cdr w) s))
        w)))

;; (project (x ...) g ...) builds a goal. It rebinds each X to its value
;; fully walked under the current substitution, then runs the conjunction
;; of the G's. This lets ordinary Racket code inside the goals see
;; concrete values.
(define-syntax project
  (syntax-rules ()
    [(_ (x ...) g ...)
     (lambda (s)
       (let ([x (walk* x s)] ...)
         ((conj g ...) s)))]))

;; Build a "reified-name" substitution for term V, extending R.
;; Each variable still fresh under R is bound to the symbol _k, where k is
;; the number of entries already in R. Names are therefore assigned in
;; left-to-right (car before cdr) order of first appearance.
(define (reify-s v r)
  (let ([w (walk v r)])
    (cond
      [(var? w) (ext-s w (reify-name (length r)) r)]
      [(pair? w) (reify-s (cdr w) (reify-s (car w) r))]
      [else r])))

;; Return a procedure that takes a substitution S and produces V's final
;; presentation: V is deep-walked under S, and any remaining fresh
;; variables are replaced by _0, _1, ...
(define (reify v)
  (lambda (s)
    (let* ([resolved (walk* v s)]
           [names (reify-s resolved empty-s)])
      (walk* resolved names))))

;; Relation: OUT is the list L appended with T.
;; The goal is wrapped in a suspension, so recursive calls interleave fairly.
;;   Base case: L is '() and T equals OUT.
;;   Step case: L = (a . d), OUT = (a . res), and (appendo d t res).
(define (appendo l t out)
  ;; Goal for the case where L is empty.
  (define (empty-case)
    (conj2 (== '() l) (== t out)))
  ;; Goal for the case where L is a pair (a . d).
  (define (pair-case)
    (call/fresh 'a
      (lambda (a)
        (call/fresh 'd
          (lambda (d)
            (call/fresh 'res
              (lambda (res)
                (conj2 (== (cons a d) l)
                       (conj2 (== (cons a res) out)
                              (appendo d t res))))))))))
  (lambda (s)
    (lambda ()
      ((disj2 (empty-case) (pair-case)) s))))

;; If-then-else over goals.
;; Run G1 on S. If G1 yields at least one answer, run G2 on every answer
;; G1 produces. If G1 yields none, run G3 on the ORIGINAL substitution S.
(define (ifte g1 g2 g3)
  (lambda (s)
    (ifte-help (g1 s) g2 g3 s)))

;; Drive G1's stream S-INF for ifte:
;;   '()         -> run G3 on S (the substitution ifte started from)
;;   pair        -> append-map G2 over the answers
;;   suspension  -> suspend again, carrying S along
;; Before this fix, the empty case called (g3 s-inf), passing the empty
;; stream '() as the substitution, which silently dropped every earlier
;; binding. S is optional and defaults to empty-s only so that any
;; existing three-argument caller sees exactly the old behavior.
;; ifte always passes S explicitly.
(define (ifte-help s-inf g2 g3 [s empty-s])
  (cond
    ((null? s-inf) (g3 s))
    ((pair? s-inf) (append-map-inf g2 s-inf))
    (else
     (lambda () (ifte-help (s-inf) g2 g3 s)))))

;; Goal that behaves like G but keeps at most its first answer.
(define once
  (lambda (g)
    (lambda (s)
      (let ([stream (g s)])
        (once-help stream)))))

;; Truncate stream S-INF to its first answer, forcing suspensions lazily.
;; The result is '(), a one-element list, or a suspension.
(define (once-help s-inf)
  (match s-inf
    ['() '()]
    [(cons first-answer _) (list first-answer)]
    [suspension (lambda () (once-help (suspension)))]))

;; Connecting the Wires

;; n-ary disjunction, right-nested over disj2.
;; (disj) is fail, and (disj g) is just g.
(define-syntax disj
  (syntax-rules ()
    [(_) fail]
    [(_ g) g]
    [(_ g0 g1 g ...) (disj2 g0 (disj g1 g ...))]))

;; n-ary conjunction, right-nested over conj2.
;; (conj) is succeed, and (conj g) is just g.
(define-syntax conj
  (syntax-rules ()
    [(_) succeed]
    [(_ g) g]
    [(_ g0 g1 g ...) (conj2 g0 (conj g1 g ...))]))

;; (fresh (x ...) g ...) introduces one new logic variable per X, each
;; tagged with its own name, and runs the conjunction of the G's.
(define-syntax fresh
  (syntax-rules ()
    [(_ () g ...) (conj g ...)]
    [(_ (x0 x ...) g ...)
     (call/fresh 'x0
                 (lambda (x0)
                   (fresh (x ...) g ...)))]))

;; (run n q g ...) returns up to N reified values of Q; N = #f means all.
;; (run n (x0 x ...) g ...) is sugar: it introduces the variables and
;; reports each answer as the list of their values.
(define-syntax run
  (syntax-rules ()
    [(_ n (x0 x ...) g ...)
     (run n q
          (fresh (x0 x ...)
            (== (list x0 x ...) q)
            g ...))]
    [(_ n q g ...)
     (let ([q (var 'q)])
       (map (reify q)
            (take-inf n (run-goal (conj g ...)))))]))

;; run with no answer limit: equivalent to (run #f q g ...).
(define-syntax run*
  (syntax-rules ()
    [(_ q g ...) (run #f q g ...)]))

;; Disjunction of clauses, where each clause is a conjunction of goals.
(define-syntax conde
  (syntax-rules ()
    [(_ (g ...) ...)
     (disj (conj g ...) ...)]))

;; (defrel (name a ...) g ...) defines a relation. NAME takes A ... and
;; returns a goal. That goal suspends before running the conjunction of
;; the G's, so recursive relations interleave instead of looping eagerly.
(define-syntax defrel
  (syntax-rules ()
    [(_ (name a ...) g ...)
     (define name
       (lambda (a ...)
         (lambda (s)
           (lambda ()
             ((conj g ...) s)))))]))

;; Committed-choice conditional ("soft cut").
;; Try each clause's first goal in order. The first one that succeeds
;; commits: the rest of that clause runs on its answers, and later clauses
;; are ignored.
;; The final clause may be written either as (else g ...) or as an ordinary
;; clause (g0 g ...), which is then just (conj g0 g ...). Previously only
;; the `else` form was accepted, and an ordinary last clause expanded to an
;; unmatched (conda) and was a syntax error.
(define-syntax conda
  (syntax-rules (else)
    ((conda (else g^ ...)) (conj g^ ...))
    ((conda (g0 g ...)) (conj g0 g ...))
    ((conda (g0 g ...) l ...)
     (ifte g0 (conj g ...) (conda l ...)))))

;; Like conda, but each clause's first goal is wrapped in `once`, so it
;; contributes at most one answer ("committed choice").
;; A trailing (else g ...) clause is optional. Without it, the clauses are
;; followed by (else fail). For the last clause (g0 g ...) this behaves the
;; same as (conj (once g0) g ...). Previously, omitting `else` was a
;; syntax error.
(define-syntax condu
  (syntax-rules (else)
    ((condu (g0 g ...) ... (else g^ ...))
     (conda ((once g0) g ...) ... (else g^ ...)))
    ((condu (g0 g ...) ...)
     (conda ((once g0) g ...) ... (else fail)))))