this repo has no description
9
fork

Configure Feed

Select the types of activity you want to include in your feed.

add create table statement

Signed-off-by: oppiliappan <me@oppi.li>

+548
+209
create_table.go
··· 1 + package norm 2 + 3 + import ( 4 + "context" 5 + "database/sql" 6 + "fmt" 7 + "strings" 8 + ) 9 + 10 + type SQLiteType string 11 + 12 + const ( 13 + Integer SQLiteType = "INTEGER" 14 + Text SQLiteType = "TEXT" 15 + Real SQLiteType = "REAL" 16 + Blob SQLiteType = "BLOB" 17 + Numeric SQLiteType = "NUMERIC" 18 + ) 19 + 20 + type ColumnConstraint interface { 21 + applyConstraint(*columnDef) 22 + } 23 + 24 + type constraintFunc func(*columnDef) 25 + 26 + func (f constraintFunc) applyConstraint(col *columnDef) { 27 + f(col) 28 + } 29 + 30 + var ( 31 + PrimaryKey = constraintFunc(func(col *columnDef) { 32 + col.constraints = append(col.constraints, "PRIMARY KEY") 33 + }) 34 + 35 + AutoIncrement = constraintFunc(func(col *columnDef) { 36 + col.constraints = append(col.constraints, "AUTOINCREMENT") 37 + }) 38 + 39 + NotNull = constraintFunc(func(col *columnDef) { 40 + col.constraints = append(col.constraints, "NOT NULL") 41 + }) 42 + 43 + Unique = constraintFunc(func(col *columnDef) { 44 + col.constraints = append(col.constraints, "UNIQUE") 45 + }) 46 + ) 47 + 48 + func Default(val any) ColumnConstraint { 49 + return constraintFunc(func(col *columnDef) { 50 + col.constraints = append(col.constraints, fmt.Sprintf("DEFAULT %v", val)) 51 + }) 52 + } 53 + 54 + func Check(expr string) ColumnConstraint { 55 + return constraintFunc(func(col *columnDef) { 56 + col.constraints = append(col.constraints, fmt.Sprintf("CHECK (%s)", expr)) 57 + }) 58 + } 59 + 60 + func Collate(collation string) ColumnConstraint { 61 + return constraintFunc(func(col *columnDef) { 62 + col.constraints = append(col.constraints, fmt.Sprintf("COLLATE %s", collation)) 63 + }) 64 + } 65 + 66 + type columnDef struct { 67 + name string 68 + dataType SQLiteType 69 + constraints []string 70 + } 71 + 72 + type createTable struct { 73 + table string 74 + ifNotExists bool 75 + columns []columnDef 76 + tableConstraints []string 77 + withoutRowid bool 78 + strict bool 79 + } 80 + 81 + func CreateTable(name string) createTable { 82 + return createTable{table: name} 83 + } 84 + 85 + func (c createTable) IfNotExists() createTable { 86 + c.ifNotExists = true 87 + return c 88 + } 89 + 90 + func (c createTable) Column(name string, dataType SQLiteType, constraints ...ColumnConstraint) createTable { 91 + col := columnDef{ 92 + name: name, 93 + dataType: dataType, 94 + } 95 + for _, constraint := range constraints { 96 + constraint.applyConstraint(&col) 97 + } 98 + c.columns = append(c.columns, col) 99 + return c 100 + } 101 + 102 + func (c createTable) PrimaryKey(cols ...string) createTable { 103 + c.tableConstraints = append(c.tableConstraints, 104 + fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(cols, ", "))) 105 + return c 106 + } 107 + 108 + func (c createTable) UniqueConstraint(cols ...string) createTable { 109 + c.tableConstraints = append(c.tableConstraints, 110 + fmt.Sprintf("UNIQUE (%s)", strings.Join(cols, ", "))) 111 + return c 112 + } 113 + 114 + func (c createTable) CheckConstraint(expr string) createTable { 115 + c.tableConstraints = append(c.tableConstraints, 116 + fmt.Sprintf("CHECK (%s)", expr)) 117 + return c 118 + } 119 + 120 + func (c createTable) ForeignKey(col, refTable, refCol string) createTable { 121 + c.tableConstraints = append(c.tableConstraints, 122 + fmt.Sprintf("FOREIGN KEY (%s) REFERENCES %s(%s)", col, refTable, refCol)) 123 + return c 124 + } 125 + 126 + func (c createTable) WithoutRowid() createTable { 127 + c.withoutRowid = true 128 + return c 129 + } 130 + 131 + func (c createTable) Strict() createTable { 132 + c.strict = true 133 + return c 134 + } 135 + 136 + func (c createTable) Compile() (string, []any, error) { 137 + var sql strings.Builder 138 + 139 + sql.WriteString("CREATE TABLE ") 140 + 141 + if c.ifNotExists { 142 + sql.WriteString("IF NOT EXISTS ") 143 + } 144 + 145 + if c.table == "" { 146 + return "", nil, fmt.Errorf("table name is required") 147 + } 148 + sql.WriteString(c.table) 149 + 150 + if len(c.columns) == 0 { 151 + return "", nil, fmt.Errorf("at least one column is required") 152 + } 153 + 154 + sql.WriteString(" (") 155 + 156 + // Column definitions 157 + for i, col := range c.columns { 158 + if i > 0 { 159 + sql.WriteString(", ") 160 + } 161 + 162 + sql.WriteString(col.name) 163 + sql.WriteString(" ") 164 + sql.WriteString(string(col.dataType)) 165 + 166 + for _, constraint := range col.constraints { 167 + sql.WriteString(" ") 168 + sql.WriteString(constraint) 169 + } 170 + } 171 + 172 + // Table-level constraints 173 + for _, constraint := range c.tableConstraints { 174 + sql.WriteString(", ") 175 + sql.WriteString(constraint) 176 + } 177 + 178 + sql.WriteString(")") 179 + 180 + if c.strict { 181 + sql.WriteString(" STRICT") 182 + } 183 + 184 + if c.withoutRowid { 185 + sql.WriteString(" WITHOUT ROWID") 186 + } 187 + 188 + return sql.String(), nil, nil 189 + } 190 + 191 + func (c createTable) MustCompile() (string, []any) { 192 + sql, args, err := c.Compile() 193 + if err != nil { 194 + panic(err) 195 + } 196 + return sql, args 197 + } 198 + 199 + func (c createTable) Build(p Database) (*sql.Stmt, []any, error) { return Build(c, p) } 200 + func (c createTable) MustBuild(p Database) (*sql.Stmt, []any) { return MustBuild(c, p) } 201 + 202 + func (c createTable) Exec(p Database) (sql.Result, error) { return Exec(c, p) } 203 + func (c createTable) ExecContext(ctx context.Context, p Database) (sql.Result, error) { 204 + return ExecContext(ctx, c, p) 205 + } 206 + func (c createTable) MustExec(p Database) sql.Result { return MustExec(c, p) } 207 + func (c createTable) MustExecContext(ctx context.Context, p Database) sql.Result { 208 + return MustExecContext(ctx, c, p) 209 + }
+339
create_table_test.go
··· 1 + package norm 2 + 3 + import ( 4 + "database/sql" 5 + "testing" 6 + 7 + _ "github.com/mattn/go-sqlite3" 8 + ) 9 + 10 + func TestCreateTableCompileSuccess(t *testing.T) { 11 + tests := []struct { 12 + name string 13 + stmt Compiler 14 + expectedSql string 15 + }{ 16 + { 17 + name: "Simple table", 18 + stmt: CreateTable("users"). 19 + Column("id", Integer), 20 + expectedSql: "CREATE TABLE users (id INTEGER)", 21 + }, 22 + { 23 + name: "Table with primary key", 24 + stmt: CreateTable("users"). 25 + Column("id", Integer, PrimaryKey), 26 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY)", 27 + }, 28 + { 29 + name: "Table with autoincrement", 30 + stmt: CreateTable("users"). 31 + Column("id", Integer, PrimaryKey, AutoIncrement), 32 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT)", 33 + }, 34 + { 35 + name: "Table with multiple columns", 36 + stmt: CreateTable("users"). 37 + Column("id", Integer, PrimaryKey). 38 + Column("name", Text, NotNull). 39 + Column("age", Integer), 40 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER)", 41 + }, 42 + { 43 + name: "Table with IF NOT EXISTS", 44 + stmt: CreateTable("users"). 45 + IfNotExists(). 46 + Column("id", Integer, PrimaryKey), 47 + expectedSql: "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY)", 48 + }, 49 + { 50 + name: "Table with unique constraint", 51 + stmt: CreateTable("users"). 52 + Column("id", Integer, PrimaryKey). 53 + Column("email", Text, Unique), 54 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE)", 55 + }, 56 + { 57 + name: "Table with default value", 58 + stmt: CreateTable("users"). 59 + Column("id", Integer, PrimaryKey). 60 + Column("active", Integer, Default(1)), 61 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, active INTEGER DEFAULT 1)", 62 + }, 63 + { 64 + name: "Table with check constraint", 65 + stmt: CreateTable("users"). 66 + Column("id", Integer, PrimaryKey). 67 + Column("age", Integer, Check("age >= 18")), 68 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, age INTEGER CHECK (age >= 18))", 69 + }, 70 + { 71 + name: "Table with collate", 72 + stmt: CreateTable("users"). 73 + Column("id", Integer, PrimaryKey). 74 + Column("name", Text, Collate("NOCASE")), 75 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT COLLATE NOCASE)", 76 + }, 77 + { 78 + name: "Table with composite primary key", 79 + stmt: CreateTable("user_roles"). 80 + Column("user_id", Integer). 81 + Column("role_id", Integer). 82 + PrimaryKey("user_id", "role_id"), 83 + expectedSql: "CREATE TABLE user_roles (user_id INTEGER, role_id INTEGER, PRIMARY KEY (user_id, role_id))", 84 + }, 85 + { 86 + name: "Table with table-level unique constraint", 87 + stmt: CreateTable("users"). 88 + Column("id", Integer, PrimaryKey). 89 + Column("email", Text). 90 + Column("username", Text). 91 + UniqueConstraint("email", "username"), 92 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT, username TEXT, UNIQUE (email, username))", 93 + }, 94 + { 95 + name: "Table with table-level check constraint", 96 + stmt: CreateTable("users"). 97 + Column("id", Integer, PrimaryKey). 98 + Column("age", Integer). 99 + CheckConstraint("age >= 0 AND age <= 150"), 100 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, age INTEGER, CHECK (age >= 0 AND age <= 150))", 101 + }, 102 + { 103 + name: "Table with foreign key", 104 + stmt: CreateTable("posts"). 105 + Column("id", Integer, PrimaryKey). 106 + Column("user_id", Integer, NotNull). 107 + ForeignKey("user_id", "users", "id"), 108 + expectedSql: "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER NOT NULL, FOREIGN KEY (user_id) REFERENCES users(id))", 109 + }, 110 + { 111 + name: "Table WITHOUT ROWID", 112 + stmt: CreateTable("cache"). 113 + Column("key", Text, PrimaryKey). 114 + Column("value", Blob). 115 + WithoutRowid(), 116 + expectedSql: "CREATE TABLE cache (key TEXT PRIMARY KEY, value BLOB) WITHOUT ROWID", 117 + }, 118 + { 119 + name: "Table STRICT", 120 + stmt: CreateTable("users"). 121 + Column("id", Integer, PrimaryKey). 122 + Column("name", Text). 123 + Strict(), 124 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT) STRICT", 125 + }, 126 + { 127 + name: "Table STRICT WITHOUT ROWID", 128 + stmt: CreateTable("cache"). 129 + Column("key", Text, PrimaryKey). 130 + Column("value", Blob). 131 + Strict(). 132 + WithoutRowid(), 133 + expectedSql: "CREATE TABLE cache (key TEXT PRIMARY KEY, value BLOB) STRICT WITHOUT ROWID", 134 + }, 135 + { 136 + name: "All data types", 137 + stmt: CreateTable("types_test"). 138 + Column("col_int", Integer). 139 + Column("col_text", Text). 140 + Column("col_real", Real). 141 + Column("col_blob", Blob). 142 + Column("col_numeric", Numeric), 143 + expectedSql: "CREATE TABLE types_test (col_int INTEGER, col_text TEXT, col_real REAL, col_blob BLOB, col_numeric NUMERIC)", 144 + }, 145 + { 146 + name: "Multiple constraints on column", 147 + stmt: CreateTable("users"). 148 + Column("id", Integer, PrimaryKey, AutoIncrement). 149 + Column("email", Text, NotNull, Unique), 150 + expectedSql: "CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL UNIQUE)", 151 + }, 152 + { 153 + name: "Complex table", 154 + stmt: CreateTable("users"). 155 + IfNotExists(). 156 + Column("id", Integer, PrimaryKey, AutoIncrement). 157 + Column("email", Text, NotNull). 158 + Column("username", Text, NotNull). 159 + Column("age", Integer, Check("age >= 18")). 160 + Column("created_at", Text, Default("CURRENT_TIMESTAMP")). 161 + UniqueConstraint("email", "username"). 162 + CheckConstraint("age < 150"), 163 + expectedSql: "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL, username TEXT NOT NULL, age INTEGER CHECK (age >= 18), created_at TEXT DEFAULT CURRENT_TIMESTAMP, UNIQUE (email, username), CHECK (age < 150))", 164 + }, 165 + } 166 + 167 + for _, test := range tests { 168 + t.Run(test.name, func(t *testing.T) { 169 + sql, args := test.stmt.MustCompile() 170 + 171 + if sql != test.expectedSql { 172 + t.Errorf("Expected SQL:\n%s\nGot:\n%s", test.expectedSql, sql) 173 + } 174 + 175 + if len(args) != 0 { 176 + t.Errorf("Expected 0 args, got %d args", len(args)) 177 + } 178 + }) 179 + } 180 + } 181 + 182 + func TestCreateTableCompileFail(t *testing.T) { 183 + tests := []struct { 184 + name string 185 + stmt Compiler 186 + expectedError string 187 + }{ 188 + { 189 + name: "No table name", 190 + stmt: CreateTable(""), 191 + expectedError: "table name is required", 192 + }, 193 + { 194 + name: "No columns", 195 + stmt: CreateTable("users"), 196 + expectedError: "at least one column is required", 197 + }, 198 + } 199 + 200 + for _, test := range tests { 201 + t.Run(test.name, func(t *testing.T) { 202 + _, _, err := test.stmt.Compile() 203 + if err == nil { 204 + t.Error("Expected error, got nil") 205 + } 206 + if err.Error() != test.expectedError { 207 + t.Errorf("Expected error '%s', got '%s'", test.expectedError, err.Error()) 208 + } 209 + }) 210 + } 211 + } 212 + 213 + func TestCreateTableIntegration(t *testing.T) { 214 + tests := []struct { 215 + name string 216 + createStmt Execer 217 + insertStmt Execer 218 + selectStmt Querier 219 + verify func(t *testing.T, db *sql.DB) 220 + }{ 221 + { 222 + name: "Create simple table and insert", 223 + createStmt: CreateTable("test_users"). 224 + Column("id", Integer, PrimaryKey). 225 + Column("name", Text), 226 + insertStmt: Insert().Into("test_users").Value("id", 1).Value("name", "Alice"), 227 + selectStmt: Select("id", "name").From("test_users"), 228 + verify: func(t *testing.T, db *sql.DB) { 229 + var count int 230 + err := db.QueryRow("SELECT COUNT(*) FROM test_users").Scan(&count) 231 + if err != nil { 232 + t.Fatalf("Failed to count rows: %v", err) 233 + } 234 + if count != 1 { 235 + t.Errorf("Expected 1 row, got %d", count) 236 + } 237 + }, 238 + }, 239 + { 240 + name: "Create table with IF NOT EXISTS", 241 + createStmt: CreateTable("test_users2"). 242 + IfNotExists(). 243 + Column("id", Integer, PrimaryKey), 244 + verify: func(t *testing.T, db *sql.DB) { 245 + _, err := CreateTable("test_users2"). 246 + IfNotExists(). 247 + Column("id", Integer, PrimaryKey). 248 + Exec(db) 249 + if err != nil { 250 + t.Errorf("Second create with IF NOT EXISTS should not fail: %v", err) 251 + } 252 + }, 253 + }, 254 + { 255 + name: "Create table with autoincrement", 256 + createStmt: CreateTable("test_users3"). 257 + Column("id", Integer, PrimaryKey, AutoIncrement). 258 + Column("name", Text), 259 + insertStmt: Insert().Into("test_users3").Value("name", "Bob"), 260 + verify: func(t *testing.T, db *sql.DB) { 261 + var id int 262 + var name string 263 + err := db.QueryRow("SELECT id, name FROM test_users3").Scan(&id, &name) 264 + if err != nil { 265 + t.Fatalf("Failed to query: %v", err) 266 + } 267 + if id != 1 { 268 + t.Errorf("Expected auto-incremented id=1, got %d", id) 269 + } 270 + if name != "Bob" { 271 + t.Errorf("Expected name=Bob, got %s", name) 272 + } 273 + }, 274 + }, 275 + { 276 + name: "Create table with default value", 277 + createStmt: CreateTable("test_users4"). 278 + Column("id", Integer, PrimaryKey). 279 + Column("active", Integer, Default(1)), 280 + insertStmt: Insert().Into("test_users4").Value("id", 1), 281 + verify: func(t *testing.T, db *sql.DB) { 282 + var active int 283 + err := db.QueryRow("SELECT active FROM test_users4 WHERE id = 1").Scan(&active) 284 + if err != nil { 285 + t.Fatalf("Failed to query: %v", err) 286 + } 287 + if active != 1 { 288 + t.Errorf("Expected default active=1, got %d", active) 289 + } 290 + }, 291 + }, 292 + { 293 + name: "Create table with unique constraint", 294 + createStmt: CreateTable("test_users5"). 295 + Column("id", Integer, PrimaryKey). 296 + Column("email", Text, Unique), 297 + verify: func(t *testing.T, db *sql.DB) { 298 + _, err := Insert().Into("test_users5"). 299 + Value("id", 1). 300 + Value("email", "test@example.com"). 301 + Exec(db) 302 + if err != nil { 303 + t.Fatalf("First insert should succeed: %v", err) 304 + } 305 + 306 + _, err = Insert().Into("test_users5"). 307 + Value("id", 2). 308 + Value("email", "test@example.com"). 309 + Exec(db) 310 + if err == nil { 311 + t.Error("Expected unique constraint violation, got nil") 312 + } 313 + }, 314 + }, 315 + } 316 + 317 + for _, test := range tests { 318 + t.Run(test.name, func(t *testing.T) { 319 + db := setupTestDB(t) 320 + defer db.Close() 321 + 322 + _, err := test.createStmt.Exec(db) 323 + if err != nil { 324 + t.Fatalf("Failed to create table: %v", err) 325 + } 326 + 327 + if test.insertStmt != nil { 328 + _, err = test.insertStmt.Exec(db) 329 + if err != nil { 330 + t.Fatalf("Failed to insert: %v", err) 331 + } 332 + } 333 + 334 + if test.verify != nil { 335 + test.verify(t, db) 336 + } 337 + }) 338 + } 339 + }