#lang racket/base

(require math/distributions
         racket/match
         racket/stream
         racket/vector
         rebellion/streaming/transducer
         rebellion/collection/list)

; A transition outcome pairs the state reached with the reward received.
(struct transition-outcome (new-state reward) #:transparent)

; A Markov decision process: transition-function maps a state and an action
; to a distribution over transition-outcomes, and action-space maps a state
; to the vector of actions available there.
(struct markov-process (transition-function states action-space rewards))

(define (markov-process-outcomes process state action)
  ((markov-process-transition-function process) state action))

(define (markov-process-perform-action process state action)
  (sample (markov-process-outcomes process state action)))

; Gridworld, from Reinforcement Learning: An Introduction, Section 4.1:
;
;  T  1  2  3
;  4  5  6  7
;  8  9 10 11
; 12 13 14  T
(define gridworld-states
  (vector-immutable 1 2 3 4 5 6 7 8 9 10 11 12 13 14 'terminal))
(define gridworld-actions (vector-immutable 'left 'up 'right 'down))
(define gridworld-rewards (vector-immutable -1 1))

(define gridworld
  (markov-process
   (λ (state action)
     ; Moves that would leave the grid leave the state unchanged, and the
     ; terminal state absorbs every action.
     (define new-state
       (match* (state action)
         [(1 'left) 'terminal] [(1 'up) 1] [(1 'right) 2] [(1 'down) 5]
         [(2 'left) 1] [(2 'up) 2] [(2 'right) 3] [(2 'down) 6]
         [(3 'left) 2] [(3 'up) 3] [(3 'right) 3] [(3 'down) 7]
         [(4 'left) 4] [(4 'up) 'terminal] [(4 'right) 5] [(4 'down) 8]
         [(5 'left) 4] [(5 'up) 1] [(5 'right) 6] [(5 'down) 9]
         [(6 'left) 5] [(6 'up) 2] [(6 'right) 7] [(6 'down) 10]
         [(7 'left) 6] [(7 'up) 3] [(7 'right) 7] [(7 'down) 11]
         [(8 'left) 8] [(8 'up) 4] [(8 'right) 9] [(8 'down) 12]
         [(9 'left) 8] [(9 'up) 5] [(9 'right) 10] [(9 'down) 13]
         [(10 'left) 9] [(10 'up) 6] [(10 'right) 11] [(10 'down) 14]
         [(11 'left) 10] [(11 'up) 7] [(11 'right) 11] [(11 'down) 'terminal]
         [(12 'left) 12] [(12 'up) 8] [(12 'right) 13] [(12 'down) 12]
         [(13 'left) 12] [(13 'up) 9] [(13 'right) 14] [(13 'down) 13]
         [(14 'left) 13] [(14 'up) 10] [(14 'right) 'terminal] [(14 'down) 14]
         [('terminal _) 'terminal]))
     ; Unlike the book's version, which rewards every transition with -1,
     ; reaching (or staying in) the terminal state is rewarded with +1 here.
     (define reward (if (equal? new-state 'terminal) 1 -1))
     ; The gridworld is deterministic: every state/action pair has exactly
     ; one possible outcome.
     (discrete-dist (list (transition-outcome new-state reward))))
   gridworld-states
   (λ (_) gridworld-actions)
   gridworld-rewards))
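
; A quick sanity check of the dynamics (a minimal sketch; run with
; `raco test`, and note that rackunit ships with Racket). Because the
; gridworld is deterministic, sampling an action always yields the same
; transition-outcome.
(module+ test
  (require rackunit)
  ; From state 9, moving up lands in state 5 with the usual -1 step reward.
  (check-equal? (markov-process-perform-action gridworld 9 'up)
                (transition-outcome 5 -1))
  ; Stepping left out of state 1 reaches the terminal state and earns +1.
  (check-equal? (markov-process-perform-action gridworld 1 'left)
                (transition-outcome 'terminal 1)))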

; A policy pairs a process with a choice function, which maps a state and
; the actions available there to a distribution over those actions.
(struct markov-policy (process choice-function))

; Chooses uniformly at random among the available actions.
(define (random-policy process)
  (markov-policy process (λ (_ actions) (discrete-dist actions))))

; Chooses uniformly among the actions that maximize the one-step lookahead
; value Q(s,a) = Σ_{s*,r} p(s*,r | s,a) (r + γ V(s*)), where V is given by
; state-value-estimator.
(define (greedy-policy process state-value-estimator
                       #:discount-rate [discount-rate #e0.95])
  (define (estimate-action-value state action)
    (define outcomes (markov-process-outcomes process state action))
    (for*/sum ([s* (in-vector (markov-process-states process))]
               [r (in-vector (markov-process-rewards process))])
      (* (pdf outcomes (transition-outcome s* r))
         (+ r (* discount-rate (state-value-estimator s*))))))
  (markov-policy
   process
   (λ (state actions)
     (define optimal-actions
       (transduce actions
                  (taking-maxima #:key (λ (a) (estimate-action-value state a)))
                  #:into into-list))
     (discrete-dist optimal-actions))))

(define (markov-policy-actions policy state)
  (define process (markov-policy-process policy))
  (define actions ((markov-process-action-space process) state))
  ((markov-policy-choice-function policy) state actions))

(define (markov-policy-choose-action policy state)
  (sample (markov-policy-actions policy state)))

(struct transition-step (action new-state reward) #:transparent)

(define (markov-policy-next-step policy state)
  (define process (markov-policy-process policy))
  (define action (markov-policy-choose-action policy state))
  (match-define (transition-outcome new-state reward)
    (markov-process-perform-action process state action))
  (transition-step action new-state reward))

; An infinite lazy stream of the steps taken by following the policy.
(define (markov-policy-rollout policy init-state)
  (define step (markov-policy-next-step policy init-state))
  (define new-state (transition-step-new-state step))
  (stream-cons step (markov-policy-rollout policy new-state)))

; Maps each state to the policy's action distribution for that state.
(define (markov-policy-table policy)
  (for/hash ([s (in-vector (markov-process-states (markov-policy-process policy)))])
    (values s (markov-policy-actions policy s))))

; Iterative policy evaluation (Sutton & Barto, Section 4.1): sweep over the
; states, replacing each value estimate with the Bellman expectation backup
;
;   V(s) <- Σ_a π(a|s) Σ_{s*,r} p(s*,r | s,a) (r + γ V(s*))
;
; until no estimate changes by more than 0.1. Estimates are updated in
; place, so later states within a sweep already see earlier states' fresh
; values.
(define (markov-policy-iterative-evaluation
         policy #:discount-rate [discount-rate #e0.95])
  (define process (markov-policy-process policy))
  (define states (markov-process-states process))
  (define rewards (markov-process-rewards process))
  (define action-space (markov-process-action-space process))
  (define value-estimates (make-vector (vector-length states) 0))
  (define state-indices
    (for/hash ([s (in-vector states)] [i (in-naturals)])
      (values s i)))
  (define (estimate-state-value s)
    (vector-ref value-estimates (hash-ref state-indices s)))
  (let loop ()
    (define any-changed?
      (for/fold ([any-changed? #false])
                ([v (in-vector value-estimates)]
                 [s (in-vector states)]
                 [i (in-naturals)])
        (define policy-actions (markov-policy-actions policy s))
        (define v*
          (for/sum ([a (in-vector (action-space s))])
            (define outcomes (markov-process-outcomes process s a))
            (for*/sum ([s* (in-vector states)]
                       [r (in-vector rewards)])
              (define s*-v (estimate-state-value s*))
              (* (pdf policy-actions a)
                 (pdf outcomes (transition-outcome s* r))
                 (+ r (* discount-rate s*-v))))))
        (vector-set! value-estimates i v*)
        (define delta (abs (- v v*)))
        (or any-changed? (> delta #e0.1))))
    (when any-changed?
      (loop)))
  (for/hash ([s (in-vector states)] [v (in-vector value-estimates)])
    (values s v)))
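
; Two quick checks of the policy machinery (a minimal sketch). The uniform
; random policy should give each of the four actions probability 1/4, and a
; rollout should be a stream of transition-steps.
(module+ test
  (require rackunit)
  (define test-random-policy (random-policy gridworld))
  (check-= (pdf (markov-policy-actions test-random-policy 5) 'up) 1/4 1e-9)
  (check-true (transition-step?
               (stream-first (markov-policy-rollout test-random-policy 9)))))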

; Evaluate the uniform random policy, then act greedily with respect to the
; resulting value estimates (one step of policy improvement).
(define value-estimates
  (markov-policy-iterative-evaluation (random-policy gridworld)))
(define policy
  (greedy-policy gridworld (λ (s) (hash-ref value-estimates s))))

(for ([(s actions) (in-hash (markov-policy-table policy))])
  (printf "state ~a: ~a\n" s actions))

; Disabled demo: follow the improved policy for ten steps starting from
; state 9. Remove the #; prefix to enable it.
#;(module+ main
    (printf "Starting in state ~a\n" 9)
    (for ([step (in-stream (stream-take (markov-policy-rollout policy 9) 10))])
      (match-define (transition-step action new-state reward) step)
      (printf "Reached state ~a via action ~a, received reward ~a\n"
              new-state action reward)))
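
; One more sketch of a check on the improved policy: greedy improvement
; should send state 1 straight to the adjacent terminal corner, making
; 'left its only supported action there.
(module+ test
  (require rackunit)
  (check-= (pdf (markov-policy-actions policy 1) 'left) 1.0 1e-9))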