A Racket library for (non-LLM) AI algorithms
1#lang racket/base
2
3
4(require math/distributions
5 racket/match
6 racket/stream
7 racket/vector
8 rebellion/streaming/transducer
9 rebellion/collection/list)
10
11
;; One possible result of taking an action: the state reached and the reward
;; received. #:transparent so instances compare with equal? (required for pdf
;; lookups against discrete distributions of outcomes).
(struct transition-outcome (new-state reward) #:transparent)
;; A Markov decision process:
;;   transition-function — (state action) -> discrete distribution of
;;     transition-outcomes
;;   states — vector of all states
;;   action-space — (state) -> vector of the actions available in that state
;;   rewards — vector of every reward value the process can emit
(struct markov-process (transition-function states action-space rewards))
14
15
;; The distribution over transition-outcomes produced by taking ACTION in
;; STATE, as given by PROCESS's transition function.
(define (markov-process-outcomes process state action)
  (define transition (markov-process-transition-function process))
  (transition state action))
18
19
;; Samples one concrete transition-outcome of taking ACTION in STATE.
(define (markov-process-perform-action process state action)
  (define outcome-dist (markov-process-outcomes process state action))
  (sample outcome-dist))
22
23
24; Gridworld, from Reinforcement Learning: An Introduction Chapter 4.1:
25;
26; T 1 2 3
27; 4 5 6 7
28; 8 9 10 11
29; 12 13 14 T
30
31
;; The fourteen non-terminal cells, plus a single 'terminal symbol standing in
;; for both terminal corners of the grid.
(define gridworld-states (vector-immutable 1 2 3 4 5 6 7 8 9 10 11 12 13 14 'terminal))
;; The four deterministic moves available in every state.
(define gridworld-actions (vector-immutable 'left 'up 'right 'down))
;; Every reward the process can emit: -1 per ordinary step, +1 on a
;; transition into the terminal state.
(define gridworld-rewards (vector-immutable -1 1))
35
36
;; The gridworld Markov process from Reinforcement Learning: An Introduction,
;; Example 4.1. Cells are numbered 0-15 row-major on a 4x4 grid; the corner
;; cells 0 and 15 are both represented by the symbol 'terminal, leaving the
;; numbered states 1-14. Moves that would step off the grid leave the agent
;; in place. Every transition is deterministic, so each outcome distribution
;; has a single support point. (This replaces a hand-written 57-case match
;; table with the equivalent row/column arithmetic.)
(define gridworld
  (markov-process
   (λ (state action)
     (define new-state
       (cond
         [(equal? state 'terminal) 'terminal] ; terminal is absorbing
         [else
          (let* ([row (quotient state 4)]
                 [col (remainder state 4)]
                 ;; Attempted move; moves off an edge stay in place.
                 [cell (case action
                         [(left)  (if (zero? col) state (sub1 state))]
                         [(up)    (if (zero? row) state (- state 4))]
                         [(right) (if (= col 3) state (add1 state))]
                         [(down)  (if (= row 3) state (+ state 4))])])
            ;; Cells 0 and 15 are the terminal corners.
            (if (or (zero? cell) (= cell 15)) 'terminal cell))]))
     ;; NOTE(review): the book uses reward -1 on every transition; here a
     ;; transition into terminal pays +1 instead — confirm this is intended.
     (define reward (if (equal? new-state 'terminal) 1 -1))
     (discrete-dist (list (transition-outcome new-state reward))))
   gridworld-states
   (λ (_) gridworld-actions)
   gridworld-rewards))
104
105
106(struct markov-policy (process choice-function))
107
108
;; A policy that, in every state, draws uniformly from the available actions.
(define (random-policy process)
  (define (choose-uniformly _state actions)
    (discrete-dist actions))
  (markov-policy process choose-uniformly))
111
112
;; A policy that always chooses the action(s) with the highest one-step
;; lookahead value under STATE-VALUE-ESTIMATOR; when several actions tie for
;; the maximum, the policy draws among them uniformly.
(define (greedy-policy process state-value-estimator #:discount-rate [discount-rate #e0.95])
  (define all-states (markov-process-states process))
  (define all-rewards (markov-process-rewards process))

  ;; Expected return of taking ACTION in STATE: each possible
  ;; (next-state, reward) pair contributes its immediate reward plus the
  ;; discounted estimated value of the next state, weighted by the pair's
  ;; transition probability.
  (define (action-value state action)
    (define outcome-dist (markov-process-outcomes process state action))
    (for*/sum ([next (in-vector all-states)]
               [reward (in-vector all-rewards)])
      (define p (pdf outcome-dist (transition-outcome next reward)))
      (* p (+ reward (* discount-rate (state-value-estimator next))))))

  (define (choose state actions)
    (define best-actions
      (transduce actions
                 (taking-maxima #:key (λ (action) (action-value state action)))
                 #:into into-list))
    (discrete-dist best-actions))

  (markov-policy process choose))
130
131
;; The distribution over actions that POLICY assigns to STATE, computed by
;; handing the state's available actions to the policy's choice function.
(define (markov-policy-actions policy state)
  (define available
    ((markov-process-action-space (markov-policy-process policy)) state))
  ((markov-policy-choice-function policy) state available))
136
137
;; Samples one action from POLICY's action distribution for STATE.
(define (markov-policy-choose-action policy state)
  (define action-dist (markov-policy-actions policy state))
  (sample action-dist))
140
141
142(struct transition-step (action new-state reward) #:transparent)
143
144
;; Lets POLICY act once from STATE: draws an action from the policy, then
;; draws the resulting transition outcome from the process, and packages both
;; as a transition-step.
(define (markov-policy-next-step policy state)
  (define action (markov-policy-choose-action policy state))
  (define outcome
    (markov-process-perform-action (markov-policy-process policy) state action))
  (transition-step action
                   (transition-outcome-new-state outcome)
                   (transition-outcome-reward outcome)))
151
152
;; An infinite stream of transition-steps produced by repeatedly letting
;; POLICY act, starting from INIT-STATE. The first step is taken eagerly
;; (sampling happens before the stream is built); the rest is delayed.
(define (markov-policy-rollout policy init-state)
  (let ([step (markov-policy-next-step policy init-state)])
    (stream-cons step
                 (markov-policy-rollout policy (transition-step-new-state step)))))
157
158
;; An immutable hash mapping every state of POLICY's process to the action
;; distribution the policy assigns to it.
(define (markov-policy-table policy)
  (define states (markov-process-states (markov-policy-process policy)))
  (for/fold ([table (hash)]) ([s (in-vector states)])
    (hash-set table s (markov-policy-actions policy s))))
162
163
;; Estimates the value of every state under POLICY by iterative policy
;; evaluation (Sutton & Barto, Chapter 4.1). Updates are applied in place, so
;; later states in a sweep already see this sweep's earlier updates
;; (Gauss–Seidel style). Returns an immutable hash mapping each state to its
;; estimated value.
;;
;; #:discount-rate — discount factor γ applied to successor-state values.
;; #:tolerance — sweeps stop once no state's value changes by more than this
;;   amount (generalizes the previously hard-coded #e0.1 threshold).
(define (markov-policy-iterative-evaluation policy
                                            #:discount-rate [discount-rate #e0.95]
                                            #:tolerance [tolerance #e0.1])
  (define process (markov-policy-process policy))
  (define states (markov-process-states process))
  (define rewards (markov-process-rewards process))
  (define action-space (markov-process-action-space process))
  ;; Current value estimate per state, indexed in parallel with STATES.
  (define value-estimates (make-vector (vector-length states) 0))
  (define state-indices
    (for/hash ([s (in-vector states)]
               [i (in-naturals)])
      (values s i)))
  (define (estimate-state-value s)
    (vector-ref value-estimates (hash-ref state-indices s)))
  (let loop ()
    (define any-changed?
      (for/fold ([any-changed? #false])
                ([v (in-vector value-estimates)]
                 [s (in-vector states)]
                 [i (in-naturals)])
        (define policy-actions (markov-policy-actions policy s))
        ;; Bellman backup: expected immediate reward plus discounted successor
        ;; value, weighted by the policy's action probabilities and the
        ;; process's transition probabilities.
        (define v*
          (for/sum ([a (in-vector (action-space s))])
            (define outcomes (markov-process-outcomes process s a))
            (for*/sum ([s* (in-vector states)]
                       [r (in-vector rewards)])
              (define s*-v (estimate-state-value s*))
              (* (pdf policy-actions a)
                 (pdf outcomes (transition-outcome s* r))
                 (+ r (* discount-rate s*-v))))))
        (vector-set! value-estimates i v*)
        (define delta (abs (- v v*)))
        (or any-changed? (> delta tolerance))))
    (when any-changed?
      (loop)))
  ;; NOTE(review): terminal states are backed up like any other state; if the
  ;; process gives a terminal self-loop a nonzero reward r, the terminal value
  ;; converges to r/(1-γ) rather than 0 — confirm this is intended.
  (for/hash ([s (in-vector states)]
             [v (in-vector value-estimates)])
    (values s v)))
200
201
;; Evaluate the uniformly-random policy on gridworld, then act greedily with
;; respect to the resulting value estimates.
(define value-estimates (markov-policy-iterative-evaluation (random-policy gridworld)))
(define policy (greedy-policy gridworld (λ (s) (hash-ref value-estimates s))))

;; Print the greedy action distribution for every state.
(define policy-table (markov-policy-table policy))
(for ([(s actions) (in-hash policy-table)])
  (printf "state ~a: ~a\n" s actions))
207
208#;(module+ main
209 (printf "Starting in state ~a\n" 9)
210 (for ([step (in-stream (stream-take (markov-policy-rollout policy 9) 10))])
211 (match-define (transition-step action new-state reward) step)
212 (printf "Reached state ~a via action ~a, received reward ~a\n" new-state action reward)))