CLI + TUI to publish to Leaflet (WIP) and manage tasks, notes, and watch/read lists 🍃
charm
leaflet
readability
golang
1package store
2
3import (
4 "database/sql"
5 "fmt"
6 "io/fs"
7 "sort"
8 "strings"
9)
10
// Migration describes one versioned schema change as seen either in the
// migration scripts on disk or in the migrations table of the database.
type Migration struct {
	// Version is the part of the script filename before the first
	// underscore, e.g. "0001" for "0001_create_tasks_up.sql".
	Version string
	// Name is the descriptive part of the filename between the version and
	// the trailing "_up.sql"/"_down.sql" suffix, e.g. "create_tasks".
	Name string
	// UpSQL and DownSQL hold the contents of the forward and reverse scripts.
	UpSQL, DownSQL string
	// Applied reports whether the version is recorded in the migrations table.
	Applied bool
	// AppliedAt is the applied_at value read from the migrations table.
	AppliedAt string
}
20
// FileSystem is the minimal read-only view of the migration sources that a
// MigrationRunner needs: listing a directory and reading a whole file.
// Standard-library types such as embed.FS and fstest.MapFS satisfy it.
type FileSystem interface {
	// ReadDir lists the entries of the directory dir.
	ReadDir(dir string) ([]fs.DirEntry, error)
	// ReadFile returns the full contents of the file at path.
	ReadFile(path string) ([]byte, error)
}
26
// MigrationRunner handles database migrations: it applies and rolls back
// SQLite schema migrations whose scripts live under "sql/migrations" in
// migrationFiles, tracking progress in the database's migrations table.
type MigrationRunner struct {
	db             *sql.DB    // database the migration SQL is executed against
	migrationFiles FileSystem // source of the sql/migrations/*_up.sql and *_down.sql scripts
}
32
33// NewMigrationRunner creates a new migration runner
34func NewMigrationRunner(db *sql.DB, files FileSystem) *MigrationRunner {
35 return &MigrationRunner{
36 db: db,
37 migrationFiles: files,
38 }
39}
40
41// RunMigrations applies all pending migrations
42func (mr *MigrationRunner) RunMigrations() error {
43 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
44 if err != nil {
45 return fmt.Errorf("failed to read migrations directory: %w", err)
46 }
47
48 var upMigrations []string
49 for _, entry := range entries {
50 if strings.HasSuffix(entry.Name(), "_up.sql") {
51 upMigrations = append(upMigrations, entry.Name())
52 }
53 }
54 sort.Strings(upMigrations)
55
56 for _, migrationFile := range upMigrations {
57 version := extractVersionFromFilename(migrationFile)
58
59 var count int
60 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count)
61 if err != nil {
62 return fmt.Errorf("failed to check migrations table: %w", err)
63 }
64
65 if count == 0 && version != "0000" {
66 continue
67 }
68
69 if count > 0 {
70 var applied int
71 err = mr.db.QueryRow("SELECT COUNT(*) FROM migrations WHERE version = ?", version).Scan(&applied)
72 if err != nil {
73 return fmt.Errorf("failed to check migration %s: %w", version, err)
74 }
75 if applied > 0 {
76 continue
77 }
78 }
79
80 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + migrationFile)
81 if err != nil {
82 return fmt.Errorf("failed to read migration %s: %w", migrationFile, err)
83 }
84
85 if _, err := mr.db.Exec(string(content)); err != nil {
86 return fmt.Errorf("failed to execute migration %s: %w", migrationFile, err)
87 }
88
89 if count > 0 || version == "0000" {
90 if _, err := mr.db.Exec("INSERT INTO migrations (version) VALUES (?)", version); err != nil {
91 return fmt.Errorf("failed to record migration %s: %w", version, err)
92 }
93 }
94 }
95
96 return nil
97}
98
99// GetAppliedMigrations returns a list of all applied migrations
100func (mr *MigrationRunner) GetAppliedMigrations() ([]Migration, error) {
101 var count int
102 err := mr.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='migrations'").Scan(&count)
103 if err != nil {
104 return nil, fmt.Errorf("failed to check migrations table: %w", err)
105 }
106
107 if count == 0 {
108 return []Migration{}, nil
109 }
110
111 rows, err := mr.db.Query("SELECT version, applied_at FROM migrations ORDER BY version")
112 if err != nil {
113 return nil, fmt.Errorf("failed to query migrations: %w", err)
114 }
115 defer rows.Close()
116
117 var migrations []Migration
118 for rows.Next() {
119 var m Migration
120 if err := rows.Scan(&m.Version, &m.AppliedAt); err != nil {
121 return nil, fmt.Errorf("failed to scan migration: %w", err)
122 }
123 m.Applied = true
124 migrations = append(migrations, m)
125 }
126
127 return migrations, nil
128}
129
130// GetAvailableMigrations returns all available migrations from embedded files
131func (mr *MigrationRunner) GetAvailableMigrations() ([]Migration, error) {
132 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
133 if err != nil {
134 return nil, fmt.Errorf("failed to read migrations directory: %w", err)
135 }
136
137 migrationMap := make(map[string]*Migration)
138
139 for _, entry := range entries {
140 version := extractVersionFromFilename(entry.Name())
141 if version == "" {
142 continue
143 }
144
145 if migrationMap[version] == nil {
146 migrationMap[version] = &Migration{
147 Version: version,
148 Name: extractNameFromFilename(entry.Name()),
149 }
150 }
151
152 content, err := mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name())
153 if err != nil {
154 return nil, fmt.Errorf("failed to read migration file %s: %w", entry.Name(), err)
155 }
156
157 if strings.HasSuffix(entry.Name(), "_up.sql") {
158 migrationMap[version].UpSQL = string(content)
159 } else if strings.HasSuffix(entry.Name(), "_down.sql") {
160 migrationMap[version].DownSQL = string(content)
161 }
162 }
163
164 var migrations []Migration
165 for _, m := range migrationMap {
166 migrations = append(migrations, *m)
167 }
168 sort.Slice(migrations, func(i, j int) bool {
169 return migrations[i].Version < migrations[j].Version
170 })
171
172 return migrations, nil
173}
174
175// Rollback rolls back the last applied migration
176func (mr *MigrationRunner) Rollback() error {
177 var version string
178 err := mr.db.QueryRow("SELECT version FROM migrations ORDER BY version DESC LIMIT 1").Scan(&version)
179 if err != nil {
180 if err == sql.ErrNoRows {
181 return fmt.Errorf("no migrations to rollback")
182 }
183 return fmt.Errorf("failed to get last migration: %w", err)
184 }
185
186 entries, err := mr.migrationFiles.ReadDir("sql/migrations")
187 if err != nil {
188 return fmt.Errorf("failed to read migrations directory: %w", err)
189 }
190
191 var downContent []byte
192 for _, entry := range entries {
193 if strings.HasPrefix(entry.Name(), version) && strings.HasSuffix(entry.Name(), "_down.sql") {
194 downContent, err = mr.migrationFiles.ReadFile("sql/migrations/" + entry.Name())
195 if err != nil {
196 return fmt.Errorf("failed to read down migration: %w", err)
197 }
198 break
199 }
200 }
201
202 if downContent == nil {
203 return fmt.Errorf("down migration not found for version %s", version)
204 }
205
206 if _, err := mr.db.Exec(string(downContent)); err != nil {
207 return fmt.Errorf("failed to execute down migration: %w", err)
208 }
209
210 if _, err := mr.db.Exec("DELETE FROM migrations WHERE version = ?", version); err != nil {
211 return fmt.Errorf("failed to remove migration record: %w", err)
212 }
213
214 return nil
215}
216
// extractVersionFromFilename returns the part of filename before its first
// underscore, e.g. "0001" for "0001_create_tasks_up.sql". The format of that
// prefix is not validated. A filename without an underscore is returned
// unchanged, so the result is "" only when filename is empty or starts with
// an underscore.
func extractVersionFromFilename(filename string) string {
	if cut := strings.IndexByte(filename, '_'); cut >= 0 {
		return filename[:cut]
	}
	return filename
}
225
// extractNameFromFilename returns the descriptive part of a migration
// filename: the text between the first and the last underscore, with one
// trailing "_up" removed. For "0001_create_tasks_up.sql" it returns
// "create_tasks". Filenames with fewer than two underscores yield "".
func extractNameFromFilename(filename string) string {
	first := strings.Index(filename, "_")
	last := strings.LastIndex(filename, "_")
	if first < 0 || first == last {
		return ""
	}

	return strings.TrimSuffix(filename[first+1:last], "_up")
}