this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

more lexgen updates

+400 -213
+17 -2
cmd/lexgen/main.go
··· 70 70 &cli.BoolFlag{ 71 71 Name: "gen-server", 72 72 }, 73 + &cli.BoolFlag{ 74 + Name: "gen-handlers", 75 + }, 73 76 &cli.StringSliceFlag{ 74 77 Name: "types-import", 75 78 }, ··· 103 106 104 107 pkgname := cctx.String("package") 105 108 109 + imports := map[string]string{ 110 + "app.bsky": "github.com/whyrusleeping/gosky/api/bsky", 111 + "com.atproto": "github.com/whyrusleeping/gosky/api/atproto", 112 + } 113 + 106 114 if cctx.Bool("gen-server") { 107 115 paths := cctx.StringSlice("types-import") 108 116 importmap := make(map[string]string) ··· 111 119 importmap[parts[0]] = parts[1] 112 120 } 113 121 114 - if err := lex.CreateHandlerStub(pkgname, importmap, outdir, schemas); err != nil { 122 + handlers := cctx.Bool("gen-handlers") 123 + 124 + if err := lex.CreateHandlerStub(pkgname, importmap, outdir, schemas, handlers); err != nil { 115 125 return err 116 126 } 117 127 118 128 } else { 129 + defmap := lex.BuildExtDefMap(schemas, []string{"com.atproto", "app.bsky"}) 119 130 for i, s := range schemas { 131 + if !strings.HasPrefix(s.ID, prefix) { 132 + continue 133 + } 134 + 120 135 fname := filepath.Join(outdir, s.Name()+".go") 121 136 122 - if err := lex.GenCodeForSchema(pkgname, prefix, fname, true, s); err != nil { 137 + if err := lex.GenCodeForSchema(pkgname, prefix, fname, true, s, defmap, imports); err != nil { 123 138 return fmt.Errorf("failed to process schema %q: %w", paths[i], err) 124 139 } 125 140 }
+374 -209
lex/gen.go
··· 22 22 type Schema struct { 23 23 prefix string 24 24 25 - Lexicon int `json:"lexicon"` 26 - ID string `json:"id"` 27 - Type string `json:"type"` 28 - Key string `json:"key"` 29 - Description string `json:"description"` 30 - Parameters *TypeSchema `json:"parameters"` 31 - Input *InputType `json:"input"` 32 - Output *OutputType `json:"output"` 33 - Defs map[string]TypeSchema `json:"defs"` 34 - Record *TypeSchema `json:"record"` 25 + Lexicon int `json:"lexicon"` 26 + ID string `json:"id"` 27 + Defs map[string]*TypeSchema `json:"defs"` 35 28 } 36 29 37 30 type Param struct { ··· 41 34 } 42 35 43 36 type OutputType struct { 44 - Encoding string `json:"encoding"` 45 - Schema TypeSchema `json:"schema"` 37 + Encoding string `json:"encoding"` 38 + Schema *TypeSchema `json:"schema"` 46 39 } 47 40 48 41 type InputType struct { 49 - Encoding string `json:"encoding"` 50 - Schema TypeSchema `json:"schema"` 42 + Encoding string `json:"encoding"` 43 + Schema *TypeSchema `json:"schema"` 51 44 } 52 45 53 46 type TypeSchema struct { 54 - Type string `json:"type"` 55 - Ref string `json:"$ref"` 56 - Required []string `json:"required"` 57 - Properties map[string]TypeSchema `json:"properties"` 58 - MaxLength int `json:"maxLength"` 59 - Items *TypeSchema `json:"items"` 60 - OneOf []TypeSchema `json:"oneOf"` 61 - Const *string `json:"const"` 62 - Enum []string `json:"enum"` 63 - Not *TypeSchema `json:"not"` 47 + prefix string 48 + id string 49 + defName string 50 + defMap map[string]*ExtDef 51 + 52 + Type string `json:"type"` 53 + Key string `json:"key"` 54 + Description string `json:"description"` 55 + Parameters *TypeSchema `json:"parameters"` 56 + Input *InputType `json:"input"` 57 + Output *OutputType `json:"output"` 58 + Record *TypeSchema `json:"record"` 59 + 60 + Ref string `json:"ref"` 61 + Refs []string `json:"refs"` 62 + Required []string `json:"required"` 63 + Properties map[string]*TypeSchema `json:"properties"` 64 + MaxLength int `json:"maxLength"` 65 + Items *TypeSchema `json:"items"` 66 + Const any `json:"const"` 67 + Enum []string `json:"enum"` 68 + Closed bool `json:"closed"` 64 69 } 65 70 66 71 func (s *Schema) Name() string { ··· 69 74 } 70 75 71 76 type outputType struct { 72 - Name string 73 - Type TypeSchema 74 - Record bool 77 + Name string 78 + DefName string 79 + Type *TypeSchema 80 + Record bool 75 81 } 76 82 77 - func (s *Schema) AllTypes(prefix string) []outputType { 83 + func (s *Schema) AllTypes(prefix string, defMap map[string]*ExtDef) []outputType { 78 84 var out []outputType 79 85 80 - var walk func(name string, ts TypeSchema, record bool) 81 - walk = func(name string, ts TypeSchema, record bool) { 86 + var walk func(name string, ts *TypeSchema, record bool) 87 + walk = func(name string, ts *TypeSchema, record bool) { 88 + if ts == nil { 89 + panic(fmt.Sprintf("nil type schema in %q (%s)", name, s.ID)) 90 + 91 + } 92 + ts.prefix = prefix 93 + ts.id = s.ID 94 + ts.defMap = defMap 82 95 if ts.Type == "object" || 83 - (ts.Type == "" && len(ts.OneOf) > 0) { 96 + (ts.Type == "union" && len(ts.Refs) > 0) { 84 97 out = append(out, outputType{ 85 98 Name: name, 86 99 Type: ts, ··· 93 106 } 94 107 95 108 if ts.Items != nil { 96 - walk(name+"_Elem", *ts.Items, false) 109 + walk(name+"_Elem", ts.Items, false) 110 + } 111 + 112 + if ts.Input != nil { 113 + if ts.Input.Schema == nil { 114 + if ts.Input.Encoding != "application/cbor" { 115 + panic(fmt.Sprintf("strange input type def in %s", s.ID)) 116 + } 117 + } else { 118 + walk(name+"_Input", ts.Input.Schema, false) 119 + } 120 + } 121 + 122 + if ts.Output != nil { 123 + if ts.Output.Schema == nil { 124 + if ts.Output.Encoding != "application/cbor" { 125 + panic(fmt.Sprintf("strange output type def in %s", s.ID)) 126 + } 127 + } else { 128 + walk(name+"_Output", ts.Output.Schema, false) 129 + } 97 130 } 98 - } 99 131 100 - tname := s.nameFromID(s.ID, prefix) 132 + if ts.Type == "record" { 133 + walk(name, ts.Record, false) 134 + } 101 135 102 - for name, def := range s.Defs { 103 - walk(tname+"_"+strings.Title(name), def, false) 104 136 } 105 137 106 - if s.Input != nil { 107 - walk(tname+"_Input", s.Input.Schema, false) 108 - } 109 - if s.Output != nil { 110 - walk(tname+"_Output", s.Output.Schema, false) 111 - } 138 + tname := nameFromID(s.ID, prefix) 112 139 113 - if s.Type == "record" { 114 - walk(tname, *s.Record, false) 140 + for name, def := range s.Defs { 141 + n := tname + "_" + strings.Title(name) 142 + if name == "main" { 143 + n = tname 144 + } 145 + walk(n, def, false) 115 146 } 116 147 117 148 return out ··· 131 162 return &s, nil 132 163 } 133 164 134 - func BuildExtDefMap(ss []*Schema) map[string]*ExtDef { 165 + func BuildExtDefMap(ss []*Schema, prefixes []string) map[string]*ExtDef { 135 166 out := make(map[string]*ExtDef) 136 167 for _, s := range ss { 137 168 for k, d := range s.Defs { 138 - out[s.ID+"#"+k] = &ExtDef{ 169 + d.defMap = out 170 + d.id = s.ID 171 + d.defName = k 172 + 173 + var pref string 174 + for _, p := range prefixes { 175 + if strings.HasPrefix(s.ID, p) { 176 + pref = p 177 + break 178 + } 179 + } 180 + d.prefix = pref 181 + 182 + n := s.ID 183 + if k != "main" { 184 + n = s.ID + "#" + k 185 + } 186 + fmt.Println("def map: ", n) 187 + out[n] = &ExtDef{ 139 188 Type: d, 140 189 } 141 190 } ··· 144 193 } 145 194 146 195 type ExtDef struct { 147 - Type TypeSchema 196 + Type *TypeSchema 148 197 } 149 198 150 - func GenCodeForSchema(pkg string, prefix string, fname string, reqcode bool, s *Schema, defmap map[string]*ExtDef) error { 199 + func GenCodeForSchema(pkg string, prefix string, fname string, reqcode bool, s *Schema, defmap map[string]*ExtDef, imports map[string]string) error { 151 200 buf := new(bytes.Buffer) 152 201 153 202 s.prefix = prefix 203 + for _, d := range s.Defs { 204 + fmt.Println("def id: ", d.id) 205 + d.prefix = prefix 206 + } 154 207 155 208 fmt.Fprintf(buf, "package %s\n\n", pkg) 156 209 fmt.Fprintf(buf, "import (\n") ··· 159 212 fmt.Fprintf(buf, "\t\"encoding/json\"\n") 160 213 fmt.Fprintf(buf, "\t\"github.com/whyrusleeping/gosky/xrpc\"\n") 161 214 fmt.Fprintf(buf, "\t\"github.com/whyrusleeping/gosky/lex/util\"\n") 215 + for k, v := range imports { 216 + if k != prefix { 217 + fmt.Fprintf(buf, "\t%s %q\n", importNameForPrefix(k), v) 218 + } 219 + } 162 220 fmt.Fprintf(buf, ")\n\n") 163 221 fmt.Fprintf(buf, "// schema: %s\n\n", s.ID) 164 222 165 - tps := s.AllTypes(prefix) 223 + tps := s.AllTypes(prefix, defmap) 166 224 167 225 for _, ot := range tps { 168 - if err := s.WriteType(ot.Name, ot.Type, buf); err != nil { 226 + if err := ot.Type.WriteType(ot.Name, buf); err != nil { 169 227 return err 170 228 } 171 229 } 172 230 173 231 if reqcode { 174 - if err := writeMethods(prefix, s, buf); err != nil { 232 + name := nameFromID(s.ID, prefix) 233 + if err := writeMethods(name, s.Defs["main"], buf); err != nil { 175 234 return err 176 235 } 177 236 } ··· 216 275 return buf.Bytes(), nil 217 276 } 218 277 219 - func writeMethods(prefix string, s *Schema, w io.Writer) error { 220 - switch s.Type { 278 + func writeMethods(typename string, ts *TypeSchema, w io.Writer) error { 279 + switch ts.Type { 221 280 case "token": 222 - n := s.nameFromID(s.ID, prefix) 223 - fmt.Fprintf(w, "const %s = %q\n", n, s.ID) 281 + fmt.Fprintf(w, "const %s = %q\n", typename, ts.id+"#"+ts.defName) 224 282 return nil 225 283 case "record": 226 284 return nil 227 285 case "query": 228 - return s.WriteRPC(w, prefix) 286 + return ts.WriteRPC(w, typename) 229 287 case "procedure": 230 - return s.WriteRPC(w, prefix) 288 + return ts.WriteRPC(w, typename) 289 + case "object": 290 + return nil 231 291 default: 232 - return fmt.Errorf("unrecognized lexicon type %q", s.Type) 292 + return fmt.Errorf("unrecognized lexicon type %q", ts.Type) 233 293 } 234 294 } 235 295 236 - func (s *Schema) nameFromID(id, prefix string) string { 296 + func nameFromID(id, prefix string) string { 237 297 parts := strings.Split(strings.TrimPrefix(id, prefix), ".") 238 298 var tname string 239 299 for _, s := range parts { ··· 244 304 245 305 } 246 306 247 - func orderedMapIter[T any](m map[string]T, cb func(string, T) error) error { 307 + func orderedMapIter[T any](m map[string]*T, cb func(string, T) error) error { 248 308 var keys []string 249 309 for k := range m { 250 310 keys = append(keys, k) ··· 253 313 sort.Strings(keys) 254 314 255 315 for _, k := range keys { 256 - if err := cb(k, m[k]); err != nil { 316 + if err := cb(k, *m[k]); err != nil { 257 317 return err 258 318 } 259 319 } 260 320 return nil 261 321 } 262 322 263 - func (s *Schema) WriteRPC(w io.Writer, prefix string) error { 264 - fname := s.nameFromID(s.ID, s.prefix) 323 + func (s *TypeSchema) WriteRPC(w io.Writer, typename string) error { 324 + fname := typename 265 325 266 326 params := "ctx context.Context, c *xrpc.Client" 267 327 inpvar := "nil" 268 328 inpenc := "" 329 + 269 330 if s.Input != nil { 270 331 inpvar = "input" 271 332 inpenc = s.Input.Encoding ··· 355 416 return fmt.Errorf("can only generate RPC for Query or Procedure (got %s)", s.Type) 356 417 } 357 418 358 - fmt.Fprintf(w, "\tif err := c.Do(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.ID, queryparams, inpvar, outvar) 419 + fmt.Fprintf(w, "\tif err := c.Do(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.id, queryparams, inpvar, outvar) 359 420 fmt.Fprintf(w, "\t\treturn %s\n", errRet) 360 421 fmt.Fprintf(w, "\t}\n\n") 361 422 fmt.Fprintf(w, "\treturn %s\n", outRet) ··· 375 436 return t.Execute(w, info) 376 437 } 377 438 378 - func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema) error { 439 + func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error { 379 440 buf := new(bytes.Buffer) 380 441 381 442 if err := WriteXrpcServer(buf, schemas, pkg, impmap); err != nil { ··· 387 448 return err 388 449 } 389 450 451 + if handlers { 452 + buf := new(bytes.Buffer) 453 + 454 + if err := WriteServerHandlers(buf, schemas, pkg, impmap); err != nil { 455 + return err 456 + } 457 + 458 + fname := filepath.Join(dir, "handlers.go") 459 + if err := writeCodeFile(buf.Bytes(), fname); err != nil { 460 + return err 461 + } 462 + 463 + 464 + } 465 + 390 466 return nil 391 467 } 392 468 ··· 394 470 return strings.Join(strings.Split(prefix, "."), "") + "types" 395 471 } 396 472 473 + func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 474 + fmt.Fprintf(w, "package %s\n\n", pkg) 475 + fmt.Fprintf(w, "import (\n") 476 + fmt.Fprintf(w, "\t\"context\"\n") 477 + fmt.Fprintf(w, "\t\"fmt\"\n") 478 + fmt.Fprintf(w, "\t\"encoding/json\"\n") 479 + fmt.Fprintf(w, "\t\"github.com/whyrusleeping/gosky/xrpc\"\n") 480 + for k, v := range impmap { 481 + fmt.Fprintf(w, "\t%s\"%s\"\n", importNameForPrefix(k), v) 482 + } 483 + fmt.Fprintf(w, ")\n\n") 484 + 485 + 486 + for _, s := range schemas { 487 + 488 + var prefix string 489 + for k := range impmap { 490 + if strings.HasPrefix(s.ID, k) { 491 + prefix = k 492 + break 493 + } 494 + } 495 + 496 + main, ok := s.Defs["main"] 497 + if !ok { 498 + return fmt.Errorf("schema %q doesnt have a main def", s.ID) 499 + } 500 + 501 + if main.Type == "procedure" || main.Type == "query" { 502 + fname := idToTitle(s.ID) 503 + tname := nameFromID(s.ID, prefix) 504 + impname := importNameForPrefix(prefix) 505 + if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil { 506 + return err 507 + } 508 + } 509 + } 510 + 511 + return nil 512 + } 513 + 397 514 func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 398 515 fmt.Fprintf(w, "package %s\n\n", pkg) 399 516 fmt.Fprintf(w, "import (\n") ··· 409 526 410 527 fmt.Fprintf(w, "func (s *Server) RegisterHandlers(e echo.Echo) error {\n") 411 528 for _, s := range schemas { 529 + 530 + main, ok := s.Defs["main"] 531 + if !ok { 532 + return fmt.Errorf("schema %q has no main", s.ID) 533 + } 534 + 412 535 var verb string 413 - switch s.Type { 536 + switch main.Type { 414 537 case "query": 415 538 verb = "GET" 416 539 case "procedure": ··· 425 548 fmt.Fprintf(w, "return nil\n}\n\n") 426 549 427 550 for _, s := range schemas { 551 + 428 552 var prefix string 429 553 for k := range impmap { 430 554 if strings.HasPrefix(s.ID, k) { ··· 433 557 } 434 558 } 435 559 436 - if s.Type == "procedure" || s.Type == "query" { 437 - if err := s.WriteRPCHandler(w, prefix); err != nil { 560 + main, ok := s.Defs["main"] 561 + if !ok { 562 + return fmt.Errorf("schema %q doesnt have a main def", s.ID) 563 + } 564 + 565 + if main.Type == "procedure" || main.Type == "query" { 566 + fname := idToTitle(s.ID) 567 + tname := nameFromID(s.ID, prefix) 568 + impname := importNameForPrefix(prefix) 569 + if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil { 438 570 return err 439 571 } 440 572 } ··· 451 583 return fname 452 584 } 453 585 454 - func (s *Schema) WriteRPCHandler(w io.Writer, prefix string) error { 455 - fname := idToTitle(s.ID) 586 + func (s *TypeSchema) WriteHandlerStub(w io.Writer, fname, shortname, impname string) error { 587 + paramtypes := []string{"ctx context.Context"} 588 + if s.Type == "query" { 589 + if s.Parameters != nil { 590 + orderedMapIter[TypeSchema](s.Parameters.Properties, func(k string, t TypeSchema) error { 591 + switch t.Type { 592 + case "string": 593 + paramtypes = append(paramtypes, k+" string") 594 + case "integer": 595 + paramtypes = append(paramtypes, k+" int") 596 + case "number": 597 + return fmt.Errorf("non-integer numbers currently unsupported") 598 + default: 599 + return fmt.Errorf("unsupported handler parameter type: %s", t.Type) 600 + } 601 + return nil 602 + }) 603 + } 604 + } 605 + 606 + returndef := "error" 607 + if s.Output != nil { 608 + switch s.Output.Encoding { 609 + case "application/json": 610 + returndef = fmt.Sprintf("(*%s.%s_Output, error)", impname, shortname) 611 + case "application/cbor": 612 + returndef = fmt.Sprintf("([]byte, error)" ) 613 + default: 614 + return fmt.Errorf("unsupported encoding: %q", s.Output.Encoding) 615 + } 616 + } 617 + 618 + fmt.Fprintf(w, "func (s *Server) handle%s(%s) %s\n", fname, strings.Join(paramtypes, ","), returndef) 619 + 620 + return nil 621 + } 456 622 457 - tname := s.nameFromID(s.ID, prefix) 623 + func (s *TypeSchema) WriteRPCHandler(w io.Writer, fname, shortname, impname string) error { 624 + tname := shortname 458 625 459 626 fmt.Fprintf(w, "func (s *Server) Handle%s(c echo.Context) error {\n", fname) 460 627 461 628 fmt.Fprintf(w, "ctx, span := otel.Tracer(\"server\").Start(c.Request().Context(), %q)\n", "Handle"+fname) 462 629 fmt.Fprintf(w, "defer span.End()\n") 463 - 464 - impname := importNameForPrefix(prefix) 465 630 466 631 paramtypes := []string{"ctx context.Context"} 467 632 params := []string{"ctx"} ··· 513 678 fmt.Fprintf(w, "%s = s.handle%s(%s)\n", assign, fname, strings.Join(params, ",")) 514 679 fmt.Fprintf(w, "if handleErr != nil {\nreturn handleErr\n}\n") 515 680 681 + if s.Output != nil { 516 682 fmt.Fprintf(w, "return c.JSON(200, out)\n}\n\n") 683 + } else { 684 + fmt.Fprintf(w, "return nil\n}\n\n") 685 + } 517 686 518 687 return nil 519 688 } 520 689 521 - func (s *Schema) typeNameFromRef(r string) string { 522 - sname := s.nameFromID(s.ID, s.prefix) 523 - p := strings.Split(r, "/") 524 - return sname + "_" + strings.Title(p[len(p)-1]) 690 + func (s *TypeSchema) typeNameFromRef(r string) string { 691 + ts, err := s.lookupRef(r) 692 + if err != nil { 693 + panic(err) 694 + } 695 + 696 + if ts.prefix == "" { 697 + panic(fmt.Sprintf("no prefix for referenced type: %s", ts.id)) 698 + } 699 + 700 + if s.prefix == "" { 701 + panic(fmt.Sprintf("no prefix for referencing type: %q %q", s.id, s.defName)) 702 + } 703 + 704 + var pkg string 705 + if ts.prefix != s.prefix { 706 + pkg = importNameForPrefix(ts.prefix) + "." 707 + } 708 + 709 + return pkg + ts.TypeName() 525 710 } 526 711 527 - func (s *Schema) typeNameForField(name, k string, v TypeSchema) (string, error) { 712 + func (s *TypeSchema) TypeName() string { 713 + if s.id == "" { 714 + panic("type schema hint fields not set") 715 + } 716 + if s.prefix == "" { 717 + panic("why no prefix?") 718 + } 719 + n := nameFromID(s.id, s.prefix) 720 + if s.defName != "main" { 721 + n += "_" + strings.Title(s.defName) 722 + } 723 + 724 + return n 725 + } 726 + 727 + func (s *TypeSchema) typeNameForField(name, k string, v TypeSchema) (string, error) { 528 728 switch v.Type { 529 729 case "string": 530 730 return "string", nil ··· 536 736 return "bool", nil 537 737 case "object": 538 738 return "*" + name + "_" + strings.Title(k), nil 539 - case "": 540 - if v.Ref != "" { 541 - return "*" + s.typeNameFromRef(v.Ref), nil 542 - } 543 - 544 - if len(v.OneOf) > 0 { 545 - return "*" + name + "_" + strings.Title(k), nil 546 - } 547 - 548 - if v.Const != nil { 549 - return "string", nil 550 - } 551 - 552 - return "", fmt.Errorf("field %q in %s does not have discernable type name", k, name) 739 + case "ref": 740 + return "*" + s.typeNameFromRef(v.Ref), nil 741 + case "datetime": 742 + // TODO: maybe do a native type? 743 + return "string", nil 744 + case "unknown": 745 + return "any", nil 746 + case "union": 747 + return "*" + name + "_" + strings.Title(k), nil 748 + case "image": 749 + return "*util.Blob", nil 750 + case "blob": 751 + return "*util.Blob", nil 553 752 case "array": 554 753 subt, err := s.typeNameForField(name+"_"+strings.Title(k), "Elem", *v.Items) 555 754 if err != nil { ··· 562 761 } 563 762 } 564 763 565 - func (s *Schema) lookupRef(ref string) (*TypeSchema, error) { 566 - parts := strings.Split(ref, "/") 567 - if len(parts) < 3 { 568 - return nil, fmt.Errorf("invalid ref: %q", ref) 569 - } 570 - 571 - if parts[1] != "defs" { 572 - return nil, fmt.Errorf("ref lookups outside of defs not supported") 764 + func (ts *TypeSchema) lookupRef(ref string) (*TypeSchema, error) { 765 + fqref := ref 766 + if strings.HasPrefix(ref, "#") { 767 + fmt.Println("updating fqref: ", ts.id) 768 + fqref = ts.id + ref 573 769 } 574 - t, ok := s.Defs[parts[2]] 770 + rr, ok := ts.defMap[fqref] 575 771 if !ok { 576 - return nil, fmt.Errorf("no such def: %q", ref) 772 + fmt.Println(ts.defMap) 773 + panic(fmt.Sprintf("no such ref: %q", fqref)) 577 774 } 578 775 579 - return &t, nil 776 + return rr.Type, nil 580 777 } 581 778 582 - func (s *Schema) WriteType(name string, t TypeSchema, w io.Writer) error { 779 + func (ts *TypeSchema) WriteType(name string, w io.Writer) error { 583 780 name = strings.Title(name) 584 - if err := s.writeTypeDefinition(name, t, w); err != nil { 781 + if err := ts.writeTypeDefinition(name, w); err != nil { 585 782 return err 586 783 } 587 784 588 - if err := s.writeTypeMethods(name, t, w); err != nil { 785 + if err := ts.writeTypeMethods(name, w); err != nil { 589 786 return err 590 787 } 591 788 592 789 return nil 593 790 } 594 791 595 - func (s *Schema) writeTypeDefinition(name string, t TypeSchema, w io.Writer) error { 596 - switch t.Type { 792 + func (ts *TypeSchema) writeTypeDefinition(name string, w io.Writer) error { 793 + switch ts.Type { 597 794 case "string": 598 795 // TODO: deal with max length 599 796 fmt.Fprintf(w, "type %s string\n", name) ··· 604 801 case "boolean": 605 802 fmt.Fprintf(w, "type %s bool\n", name) 606 803 case "object": 607 - if len(t.Properties) == 0 { 804 + if len(ts.Properties) == 0 { 608 805 fmt.Fprintf(w, "type %s interface{}\n", name) 609 806 return nil 610 807 } 611 808 612 809 fmt.Fprintf(w, "type %s struct {\n", name) 613 810 614 - for k, v := range t.Properties { 811 + for k, v := range ts.Properties { 615 812 goname := strings.Title(k) 616 813 617 - tname, err := s.typeNameForField(name, k, v) 814 + tname, err := ts.typeNameForField(name, k, *v) 618 815 if err != nil { 619 816 return err 620 817 } ··· 624 821 fmt.Fprintf(w, "}\n\n") 625 822 626 823 case "array": 627 - tname, err := s.typeNameForField(name, "elem", *t.Items) 824 + tname, err := ts.typeNameForField(name, "elem", *ts.Items) 628 825 if err != nil { 629 826 return err 630 827 } 631 828 632 829 fmt.Fprintf(w, "type %s []%s\n", name, tname) 633 830 634 - case "": 635 - if len(t.OneOf) > 0 { 636 - // check if this is actually just a string enum 637 - first, err := s.lookupRef(t.OneOf[0].Ref) 638 - if err != nil { 639 - return fmt.Errorf("oneOf pre-check failed: %w", err) 831 + case "union": 832 + if len(ts.Refs) > 0 { 833 + fmt.Fprintf(w, "type %s struct {\n", name) 834 + for _, r := range ts.Refs { 835 + tname := ts.typeNameFromRef(r) 836 + fmt.Fprintf(w, "\t%s *%s\n", tname, tname) 640 837 } 641 - 642 - if first.Type == "string" { 643 - // okay, this is just a string enum, do something different 644 - fmt.Fprintf(w, "type %s string\n", name) 645 - } else { 646 - 647 - fmt.Fprintf(w, "type %s struct {\n", name) 648 - for _, e := range t.OneOf { 649 - // TODO: for now, asserting that all enum options are refs 650 - if e.Ref == "" { 651 - return fmt.Errorf("Enums must only contain refs") 652 - } 653 - 654 - tname := s.typeNameFromRef(e.Ref) 655 - fmt.Fprintf(w, "\t%s *%s\n", tname, tname) 656 - } 657 - fmt.Fprintf(w, "}\n\n") 658 - } 659 - 838 + fmt.Fprintf(w, "}\n\n") 660 839 } 661 840 default: 662 - return fmt.Errorf("%s has unrecognized type type %s", name, t.Type) 841 + return fmt.Errorf("%s has unrecognized type type %s", name, ts.Type) 663 842 } 664 843 665 844 return nil 666 845 } 667 846 668 - func (s *Schema) writeTypeMethods(name string, t TypeSchema, w io.Writer) error { 669 - switch t.Type { 847 + func (ts *TypeSchema) writeTypeMethods(name string, w io.Writer) error { 848 + switch ts.Type { 670 849 case "string", "number", "array", "boolean", "integer": 671 850 return nil 672 851 case "object": 673 - if err := s.writeJsonMarshalerObject(name, t, w); err != nil { 852 + if err := ts.writeJsonMarshalerObject(name, w); err != nil { 674 853 return err 675 854 } 676 855 677 - if err := s.writeJsonUnmarshalerObject(name, t, w); err != nil { 856 + if err := ts.writeJsonUnmarshalerObject(name, w); err != nil { 678 857 return err 679 858 } 680 859 681 860 return nil 682 - case "": 683 - if len(t.OneOf) > 0 { 684 - reft, err := s.lookupRef(t.OneOf[0].Ref) 861 + case "union": 862 + if len(ts.Refs) > 0 { 863 + reft, err := ts.lookupRef(ts.Refs[0]) 685 864 if err != nil { 686 865 return err 687 866 } ··· 690 869 return nil 691 870 } 692 871 693 - if err := s.writeJsonMarshalerEnum(name, t, w); err != nil { 872 + if err := ts.writeJsonMarshalerEnum(name, w); err != nil { 694 873 return err 695 874 } 696 875 697 - if err := s.writeJsonUnmarshalerEnum(name, t, w); err != nil { 876 + if err := ts.writeJsonUnmarshalerEnum(name, w); err != nil { 698 877 return err 699 878 } 700 879 ··· 703 882 704 883 return fmt.Errorf("%q unsupported for marshaling", name) 705 884 default: 706 - return fmt.Errorf("%q has unrecognized type type %s", name, t.Type) 885 + return fmt.Errorf("%q has unrecognized type type %s", name, ts.Type) 707 886 } 708 887 } 709 888 ··· 718 897 for _, k := range keys { 719 898 subv := t.Properties[k] 720 899 721 - if err := cb(k, subv); err != nil { 900 + if err := cb(k, *subv); err != nil { 722 901 return err 723 902 } 724 903 } 725 904 return nil 726 905 } 727 906 728 - func (s *Schema) writeJsonMarshalerObject(name string, t TypeSchema, w io.Writer) error { 729 - if len(t.Properties) == 0 { 907 + func (ts *TypeSchema) writeJsonMarshalerObject(name string, w io.Writer) error { 908 + if len(ts.Properties) == 0 { 730 909 // TODO: this is a hacky special casing of record types... 731 910 return nil 732 911 } 733 912 734 913 fmt.Fprintf(w, "func (t *%s) MarshalJSON() ([]byte, error) {\n", name) 735 914 736 - if err := forEachProp(t, func(k string, ts TypeSchema) error { 915 + if err := forEachProp(*ts, func(k string, ts TypeSchema) error { 737 916 if ts.Const != nil { 738 917 // TODO: maybe check for mutations before overwriting? mutations would mean bad code 739 - fmt.Fprintf(w, "\tt.%s = %q\n", strings.Title(k), *ts.Const) 918 + switch ts.Const.(type) { 919 + case string: 920 + fmt.Fprintf(w, "\tt.%s = %q\n", strings.Title(k), ts.Const) 921 + case bool: 922 + fmt.Fprintf(w, "\tt.%s = %v\n", strings.Title(k), ts.Const) 923 + default: 924 + return fmt.Errorf("unsupported const type: %T", ts.Const) 925 + 926 + } 740 927 } 741 928 742 929 return nil ··· 746 933 747 934 // TODO: this is ugly since i can't just pass things through to json.Marshal without causing an infinite recursion... 748 935 fmt.Fprintf(w, "\tout := make(map[string]interface{})\n") 749 - if err := forEachProp(t, func(k string, ts TypeSchema) error { 936 + if err := forEachProp(*ts, func(k string, ts TypeSchema) error { 750 937 fmt.Fprintf(w, "\tout[%q] = t.%s\n", k, strings.Title(k)) 751 938 return nil 752 939 }); err != nil { ··· 757 944 return nil 758 945 } 759 946 760 - func (s *Schema) writeJsonMarshalerEnum(name string, t TypeSchema, w io.Writer) error { 947 + func (ts *TypeSchema) writeJsonMarshalerEnum(name string, w io.Writer) error { 761 948 fmt.Fprintf(w, "func (t *%s) MarshalJSON() ([]byte, error) {\n", name) 762 949 763 - for _, e := range t.OneOf { 764 - tname := s.typeNameFromRef(e.Ref) 950 + for _, e := range ts.Refs { 951 + tname := ts.typeNameFromRef(e) 765 952 fmt.Fprintf(w, "\tif t.%s != nil {\n", tname) 766 953 fmt.Fprintf(w, "\t\treturn json.Marshal(t.%s)\n\t}\n", tname) 767 954 } ··· 770 957 return nil 771 958 } 772 959 773 - func (s *Schema) writeJsonUnmarshalerObject(name string, t TypeSchema, w io.Writer) error { 960 + func (s *TypeSchema) writeJsonUnmarshalerObject(name string, w io.Writer) error { 774 961 // TODO: would be nice to add some validation... 775 962 return nil 776 963 //fmt.Fprintf(w, "func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name) 777 964 } 778 965 779 - func (s *Schema) getTypeConstValueForType(t TypeSchema) (string, []string, error) { 780 - parts := strings.Split(t.Ref, "/") 781 - if len(parts) == 3 && parts[0] == "#" && parts[1] == "defs" { 782 - def, ok := s.Defs[parts[2]] 783 - if !ok { 784 - return "", nil, fmt.Errorf("bad reference %q", parts[2]) 785 - } 966 + func (ts *TypeSchema) getTypeConstValueForType(ref string) (any, error) { 967 + rr, err := ts.lookupRef(ref) 968 + if err != nil { 969 + return nil, err 970 + } 786 971 787 - typ, ok := def.Properties["type"] 788 - if !ok { 789 - return "", nil, fmt.Errorf("referenced enum value %q does not have type property", parts[2]) 790 - } 791 - 792 - if typ.Const == nil && typ.Not == nil { 793 - return "", nil, fmt.Errorf("referenced enum value %q has non-const type property and no not", parts[2]) 794 - } 795 - 796 - if typ.Const != nil { 797 - return *typ.Const, nil, nil 798 - } 799 - 800 - if len(typ.Not.Enum) == 0 { 801 - return "", nil, fmt.Errorf("final clause 'not' enum must not be empty") 802 - } 803 - 804 - return "", typ.Not.Enum, nil 972 + reft, ok := rr.Properties["type"] 973 + if !ok { 974 + return nil, nil 805 975 } 806 976 807 - return "", nil, fmt.Errorf("type had bad Ref value: %q", t.Ref) 977 + return reft.Const, nil 808 978 } 809 979 810 - func (s *Schema) writeJsonUnmarshalerEnum(name string, t TypeSchema, w io.Writer) error { 980 + func (ts *TypeSchema) writeJsonUnmarshalerEnum(name string, w io.Writer) error { 811 981 fmt.Fprintf(w, "func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name) 812 982 fmt.Fprintf(w, "\ttyp, err := util.EnumTypeExtract(b)\n") 813 983 fmt.Fprintf(w, "\tif err != nil {\n\t\treturn err\n\t}\n\n") 814 984 fmt.Fprintf(w, "\tswitch typ {\n") 815 - for i, e := range t.OneOf { 816 - tc, nots, err := s.getTypeConstValueForType(e) 817 - if err != nil { 818 - return err 819 - } 820 - 821 - if len(nots) > 0 { 822 - if i == len(t.OneOf)-1 { 823 - tnref := s.typeNameFromRef(e.Ref) 824 - fmt.Fprintf(w, ` 825 - default: 826 - var out %s 827 - if err := json.Unmarshal(b, &out); err != nil { 828 - return err 985 + for _, e := range ts.Refs { 986 + if strings.HasPrefix(e, "#") { 987 + e = ts.id + e 829 988 } 830 - t.%s = &out 831 - return nil 832 - `, tnref, tnref) 833 989 834 - } else { 835 - return fmt.Errorf("enum member with a not clause must be the last in a oneOf") 836 - } 837 - break 838 - } 839 990 840 - goname := s.typeNameFromRef(e.Ref) 991 + goname := ts.typeNameFromRef(e) 841 992 842 - fmt.Fprintf(w, "\t\tcase \"%s\":\n", tc) 993 + fmt.Fprintf(w, "\t\tcase \"%s\":\n", e) 843 994 fmt.Fprintf(w, "\t\t\tt.%s = new(%s)\n", goname, goname) 844 995 fmt.Fprintf(w, "\t\t\treturn json.Unmarshal(b, t.%s)\n", goname) 845 996 } 997 + 998 + if ts.Closed { 999 + fmt.Fprintf(w, ` 1000 + default: 1001 + return fmt.Errorf("closed enums must have a matching value") 1002 + `) 1003 + } else { 1004 + fmt.Fprintf(w, ` 1005 + default: 1006 + return nil 1007 + `) 1008 + 1009 + } 1010 + 846 1011 fmt.Fprintf(w, "\t}\n") 847 1012 fmt.Fprintf(w, "}\n\n") 848 1013
+9 -2
lex/util/util.go
··· 1 1 package util 2 2 3 - import "encoding/json" 3 + import ( 4 + "encoding/json" 5 + ) 4 6 5 7 type typeExtractor struct { 6 - Type string `json:"type"` 8 + Type string `json:"$type"` 7 9 } 8 10 9 11 func EnumTypeExtract(b []byte) (string, error) { ··· 14 16 15 17 return te.Type, nil 16 18 } 19 + 20 + type Blob struct { 21 + Cid string `json:"cid"` 22 + MimeType string `json:"mimeType"` 23 + }