A Racket library for (non-LLM) AI algorithms
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 212 lines 6.7 kB view raw
#lang racket/base

;; Markov decision processes: a small tabular-RL playground.
;; Implements the 4x4 gridworld and iterative policy evaluation from
;; Sutton & Barto, "Reinforcement Learning: An Introduction", §4.1,
;; followed by one step of greedy policy improvement.

(require math/distributions
         racket/match
         racket/stream
         racket/vector
         rebellion/streaming/transducer
         rebellion/collection/list)


;; One possible result of taking an action: the successor state and the
;; scalar reward received on that transition.
(struct transition-outcome (new-state reward) #:transparent)

;; An MDP:
;;   transition-function : state action -> discrete-dist of transition-outcome
;;   states              : vector of all states
;;   action-space        : state -> vector of actions available in that state
;;   rewards             : vector of all possible reward values
(struct markov-process (transition-function states action-space rewards))


;; Distribution over transition-outcomes for taking ACTION in STATE.
(define (markov-process-outcomes process state action)
  ((markov-process-transition-function process) state action))

;; Sample a single transition-outcome for taking ACTION in STATE.
(define (markov-process-perform-action process state action)
  (sample (markov-process-outcomes process state action)))


; Gridworld, from Reinforcement Learning: An Introduction Chapter 4.1:
;
;  T  1  2  3
;  4  5  6  7
;  8  9 10 11
; 12 13 14  T

(define gridworld-states (vector-immutable 1 2 3 4 5 6 7 8 9 10 11 12 13 14 'terminal))
(define gridworld-actions (vector-immutable 'left 'up 'right 'down))
(define gridworld-rewards (vector-immutable -1 1))


;; The gridworld MDP. Transitions are deterministic (a one-outcome
;; discrete-dist); moves off the grid edge leave the state unchanged.
;; NOTE(review): reward here is +1 on reaching 'terminal and -1 otherwise;
;; the textbook version uses -1 for every transition. The +1 appears
;; intentional in this code — confirm before "fixing".
(define gridworld
  (markov-process
   (λ (state action)
     (define new-state
       (match* (state action)
         [(1 'left) 'terminal]
         [(1 'up) 1]
         [(1 'right) 2]
         [(1 'down) 5]
         [(2 'left) 1]
         [(2 'up) 2]
         [(2 'right) 3]
         [(2 'down) 6]
         [(3 'left) 2]
         [(3 'up) 3]
         [(3 'right) 3]
         [(3 'down) 7]
         [(4 'left) 4]
         [(4 'up) 'terminal]
         [(4 'right) 5]
         [(4 'down) 8]
         [(5 'left) 4]
         [(5 'up) 1]
         [(5 'right) 6]
         [(5 'down) 9]
         [(6 'left) 5]
         [(6 'up) 2]
         [(6 'right) 7]
         [(6 'down) 10]
         [(7 'left) 6]
         [(7 'up) 3]
         [(7 'right) 7]
         [(7 'down) 11]
         [(8 'left) 8]
         [(8 'up) 4]
         [(8 'right) 9]
         [(8 'down) 12]
         [(9 'left) 8]
         [(9 'up) 5]
         [(9 'right) 10]
         [(9 'down) 13]
         [(10 'left) 9]
         [(10 'up) 6]
         [(10 'right) 11]
         [(10 'down) 14]
         [(11 'left) 10]
         [(11 'up) 7]
         [(11 'right) 11]
         [(11 'down) 'terminal]
         [(12 'left) 12]
         [(12 'up) 8]
         [(12 'right) 13]
         [(12 'down) 12]
         [(13 'left) 12]
         [(13 'up) 9]
         [(13 'right) 14]
         [(13 'down) 13]
         [(14 'left) 13]
         [(14 'up) 10]
         [(14 'right) 'terminal]
         [(14 'down) 14]
         [('terminal _) 'terminal]))
     (define reward (if (equal? new-state 'terminal) 1 -1))
     (discrete-dist (list (transition-outcome new-state reward))))
   gridworld-states
   (λ (_) gridworld-actions)
   gridworld-rewards))


;; A policy for an MDP:
;;   choice-function : state actions-vector -> discrete-dist of actions
(struct markov-policy (process choice-function))


;; Equiprobable-random policy: uniform over the available actions.
(define (random-policy process)
  (markov-policy process (λ (_ actions) (discrete-dist actions))))


;; Policy that is greedy with respect to STATE-VALUE-ESTIMATOR: in each
;; state it picks uniformly among the actions maximizing the one-step
;; lookahead value  Σ_{s',r} p(s',r|s,a) (r + γ V(s')).
(define (greedy-policy process state-value-estimator #:discount-rate [discount-rate #e0.95])

  ;; Expected discounted value of taking ACTION in STATE.
  (define (estimate-action-value state action)
    (define outcomes (markov-process-outcomes process state action))
    (for*/sum ([s* (in-vector (markov-process-states process))]
               [r (in-vector (markov-process-rewards process))])
      (* (pdf outcomes (transition-outcome s* r))
         (+ r (* discount-rate (state-value-estimator s*))))))

  (markov-policy
   process
   (λ (state actions)
     ;; taking-maxima keeps *all* tied-for-best actions, so the resulting
     ;; distribution is uniform over the argmax set.
     (define optimal-actions
       (transduce actions
                  (taking-maxima #:key (λ (a) (estimate-action-value state a)))
                  #:into into-list))
     (discrete-dist optimal-actions))))


;; Distribution over actions that POLICY assigns to STATE.
(define (markov-policy-actions policy state)
  (define process (markov-policy-process policy))
  (define actions ((markov-process-action-space process) state))
  ((markov-policy-choice-function policy) state actions))


;; Sample one action from POLICY in STATE.
(define (markov-policy-choose-action policy state)
  (sample (markov-policy-actions policy state)))


;; One sampled step of a rollout: the action taken plus its outcome.
(struct transition-step (action new-state reward) #:transparent)


;; Sample an action from POLICY in STATE and perform it.
(define (markov-policy-next-step policy state)
  (define process (markov-policy-process policy))
  (define action (markov-policy-choose-action policy state))
  (match-define (transition-outcome new-state reward)
    (markov-process-perform-action process state action))
  (transition-step action new-state reward))


;; Infinite lazy stream of transition-steps starting from INIT-STATE.
;; (Does not stop at 'terminal; callers take as many steps as they want.)
(define (markov-policy-rollout policy init-state)
  (define step (markov-policy-next-step policy init-state))
  (define new-state (transition-step-new-state step))
  (stream-cons step (markov-policy-rollout policy new-state)))


;; Hash from each state to the action distribution POLICY assigns it.
(define (markov-policy-table policy)
  (for/hash ([s (in-vector (markov-process-states (markov-policy-process policy)))])
    (values s (markov-policy-actions policy s))))


;; Iterative policy evaluation (Sutton & Barto §4.1): repeatedly apply the
;; Bellman expectation backup
;;   V(s) <- Σ_a π(a|s) Σ_{s',r} p(s',r|s,a) (r + γ V(s'))
;; updating estimates in place (Gauss–Seidel style), until no state's value
;; changes by more than 0.1. Returns a hash from state to estimated value.
(define (markov-policy-iterative-evaluation policy #:discount-rate [discount-rate #e0.95])
  (define process (markov-policy-process policy))
  (define states (markov-process-states process))
  (define rewards (markov-process-rewards process))
  (define action-space (markov-process-action-space process))
  (define value-estimates (make-vector (vector-length states) 0))
  ;; Map each state to its index in the estimates vector.
  (define state-indices
    (for/hash ([s (in-vector states)]
               [i (in-naturals)])
      (values s i)))
  (define (estimate-state-value s)
    (vector-ref value-estimates (hash-ref state-indices s)))
  (let loop ()
    (define any-changed?
      (for/fold ([any-changed? #false])
                ([v (in-vector value-estimates)]
                 [s (in-vector states)]
                 [i (in-naturals)])
        (define policy-actions (markov-policy-actions policy s))
        (define v*
          (for/sum ([a (in-vector (action-space s))])
            (define outcomes (markov-process-outcomes process s a))
            (for*/sum ([s* (in-vector states)]
                       [r (in-vector rewards)])
              (define s*-v (estimate-state-value s*))
              (* (pdf policy-actions a)
                 (pdf outcomes (transition-outcome s* r))
                 (+ r (* discount-rate s*-v))))))
        (vector-set! value-estimates i v*)
        (define delta (abs (- v v*)))
        (or any-changed? (> delta #e0.1))))
    (when any-changed?
      (loop)))
  (for/hash ([s (in-vector states)]
             [v (in-vector value-estimates)])
    (values s v)))


;; Evaluate the random policy, then improve it greedily and print the
;; resulting action distribution for every state.
(define value-estimates (markov-policy-iterative-evaluation (random-policy gridworld)))
(define policy (greedy-policy gridworld (λ (s) (hash-ref value-estimates s))))

(for ([(s actions) (in-hash (markov-policy-table policy))])
  (printf "state ~a: ~a\n" s actions))

#;(module+ main
    (printf "Starting in state ~a\n" 9)
    (for ([step (in-stream (stream-take (markov-policy-rollout policy 9) 10))])
      (match-define (transition-step action new-state reward) step)
      (printf "Reached state ~a via action ~a, received reward ~a\n" new-state action reward)))