forked from
tangled.org/core
Monorepo for Tangled
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 lexutil "github.com/bluesky-social/indigo/lex/util"
16 "github.com/ipfs/go-cid"
17 "tangled.org/core/appview/models"
18 "tangled.org/core/appview/pagination"
19 "tangled.org/core/orm"
20 "tangled.org/core/sets"
21)
22
23func comparePullSource(existing, new *models.PullSource) bool {
24 if existing == nil && new == nil {
25 return true
26 }
27 if existing == nil || new == nil {
28 return false
29 }
30 if existing.Branch != new.Branch {
31 return false
32 }
33 if existing.RepoAt == nil && new.RepoAt == nil {
34 return true
35 }
36 if existing.RepoAt == nil || new.RepoAt == nil {
37 return false
38 }
39 return *existing.RepoAt == *new.RepoAt
40}
41
42func compareSubmissions(existing, new []*models.PullSubmission) bool {
43 if len(existing) != len(new) {
44 return false
45 }
46 for i := range existing {
47 if existing[i].Blob.Ref.String() != new[i].Blob.Ref.String() {
48 return false
49 }
50 if existing[i].Blob.MimeType != new[i].Blob.MimeType {
51 return false
52 }
53 if existing[i].Blob.Size != new[i].Blob.Size {
54 return false
55 }
56 }
57 return true
58}
59
60func PutPull(tx *sql.Tx, pull *models.Pull) error {
61 // ensure sequence exists
62 _, err := tx.Exec(`
63 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
64 values (?, 1)
65 `, pull.RepoAt)
66 if err != nil {
67 return err
68 }
69
70 pulls, err := GetPulls(
71 tx,
72 orm.FilterEq("owner_did", pull.OwnerDid),
73 orm.FilterEq("rkey", pull.Rkey),
74 )
75 switch {
76 case err != nil:
77 return err
78 case len(pulls) == 0:
79 return createNewPull(tx, pull)
80 case len(pulls) != 1: // should be unreachable
81 return fmt.Errorf("invalid number of pulls returned: %d", len(pulls))
82 default:
83 existingPull := pulls[0]
84 if existingPull.State == models.PullMerged {
85 return nil
86 }
87
88 dependentOnEqual := (existingPull.DependentOn == nil && pull.DependentOn == nil) ||
89 (existingPull.DependentOn != nil && pull.DependentOn != nil && *existingPull.DependentOn == *pull.DependentOn)
90
91 pullSourceEqual := comparePullSource(existingPull.PullSource, pull.PullSource)
92 submissionsEqual := compareSubmissions(existingPull.Submissions, pull.Submissions)
93
94 if existingPull.Title == pull.Title &&
95 existingPull.Body == pull.Body &&
96 existingPull.TargetBranch == pull.TargetBranch &&
97 existingPull.RepoAt == pull.RepoAt &&
98 dependentOnEqual &&
99 pullSourceEqual &&
100 submissionsEqual {
101 return nil
102 }
103
104 isLonger := len(existingPull.Submissions) < len(pull.Submissions)
105 if isLonger {
106 isAppendOnly := compareSubmissions(existingPull.Submissions, pull.Submissions[:len(existingPull.Submissions)])
107 if !isAppendOnly {
108 return fmt.Errorf("the new pull does not treat submissions as append-only")
109 }
110 } else if !submissionsEqual {
111 return fmt.Errorf("the new pull does not treat submissions as append-only")
112 }
113
114 pull.ID = existingPull.ID
115 pull.PullId = existingPull.PullId
116 return updatePull(tx, pull, existingPull)
117 }
118}
119
120func createNewPull(tx *sql.Tx, pull *models.Pull) error {
121 _, err := tx.Exec(`
122 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
123 values (?, 1)
124 `, pull.RepoAt)
125 if err != nil {
126 return err
127 }
128
129 var nextId int
130 err = tx.QueryRow(`
131 update repo_pull_seqs
132 set next_pull_id = next_pull_id + 1
133 where repo_at = ?
134 returning next_pull_id - 1
135 `, pull.RepoAt).Scan(&nextId)
136 if err != nil {
137 return err
138 }
139
140 pull.PullId = nextId
141 pull.State = models.PullOpen
142
143 var sourceBranch, sourceRepoAt *string
144 if pull.PullSource != nil {
145 sourceBranch = &pull.PullSource.Branch
146 if pull.PullSource.RepoAt != nil {
147 x := pull.PullSource.RepoAt.String()
148 sourceRepoAt = &x
149 }
150 }
151
152 result, err := tx.Exec(
153 `
154 insert into pulls (
155 repo_at,
156 owner_did,
157 pull_id,
158 title,
159 target_branch,
160 body,
161 rkey,
162 state,
163 dependent_on,
164 source_branch,
165 source_repo_at
166 )
167 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
168 pull.RepoAt,
169 pull.OwnerDid,
170 pull.PullId,
171 pull.Title,
172 pull.TargetBranch,
173 pull.Body,
174 pull.Rkey,
175 pull.State,
176 pull.DependentOn,
177 sourceBranch,
178 sourceRepoAt,
179 )
180 if err != nil {
181 return err
182 }
183
184 // Set the database primary key ID
185 id, err := result.LastInsertId()
186 if err != nil {
187 return err
188 }
189 pull.ID = int(id)
190
191 for i, s := range pull.Submissions {
192 _, err = tx.Exec(`
193 insert into pull_submissions (
194 pull_at,
195 round_number,
196 patch,
197 combined,
198 source_rev,
199 patch_blob_ref,
200 patch_blob_mime,
201 patch_blob_size
202 )
203 values (?, ?, ?, ?, ?, ?, ?, ?)
204 `,
205 pull.AtUri(),
206 i,
207 s.Patch,
208 s.Combined,
209 s.SourceRev,
210 s.Blob.Ref.String(),
211 s.Blob.MimeType,
212 s.Blob.Size,
213 )
214 if err != nil {
215 return err
216 }
217 }
218
219 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
220 return fmt.Errorf("put reference_links: %w", err)
221 }
222
223 return nil
224}
225
226func updatePull(tx *sql.Tx, pull *models.Pull, existingPull *models.Pull) error {
227 var sourceBranch, sourceRepoAt *string
228 if pull.PullSource != nil {
229 sourceBranch = &pull.PullSource.Branch
230 if pull.PullSource.RepoAt != nil {
231 x := pull.PullSource.RepoAt.String()
232 sourceRepoAt = &x
233 }
234 }
235
236 _, err := tx.Exec(`
237 update pulls set
238 title = ?,
239 body = ?,
240 target_branch = ?,
241 dependent_on = ?,
242 source_branch = ?,
243 source_repo_at = ?
244 where owner_did = ? and rkey = ?
245 `, pull.Title, pull.Body, pull.TargetBranch, pull.DependentOn, sourceBranch, sourceRepoAt, pull.OwnerDid, pull.Rkey)
246 if err != nil {
247 return err
248 }
249
250 // insert new submissions (append-only)
251 for i := len(existingPull.Submissions); i < len(pull.Submissions); i++ {
252 s := pull.Submissions[i]
253 _, err = tx.Exec(`
254 insert into pull_submissions (
255 pull_at,
256 round_number,
257 patch,
258 combined,
259 source_rev,
260 patch_blob_ref,
261 patch_blob_mime,
262 patch_blob_size
263 )
264 values (?, ?, ?, ?, ?, ?, ?, ?)
265 `,
266 pull.AtUri(),
267 i,
268 s.Patch,
269 s.Combined,
270 s.SourceRev,
271 s.Blob.Ref.String(),
272 s.Blob.MimeType,
273 s.Blob.Size,
274 )
275 if err != nil {
276 return err
277 }
278 }
279
280 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
281 return fmt.Errorf("put reference_links: %w", err)
282 }
283 return nil
284}
285
286func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
287 var pullId int
288 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
289 return pullId - 1, err
290}
291
292func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
293 pulls := make(map[syntax.ATURI]*models.Pull)
294
295 var conditions []string
296 var args []any
297 for _, filter := range filters {
298 conditions = append(conditions, filter.Condition())
299 args = append(args, filter.Arg()...)
300 }
301
302 whereClause := ""
303 if conditions != nil {
304 whereClause = " where " + strings.Join(conditions, " and ")
305 }
306 pageClause := ""
307 if page.Limit != 0 {
308 pageClause = fmt.Sprintf(
309 " limit %d offset %d ",
310 page.Limit,
311 page.Offset,
312 )
313 }
314
315 query := fmt.Sprintf(`
316 select
317 id,
318 owner_did,
319 repo_at,
320 pull_id,
321 created,
322 title,
323 state,
324 target_branch,
325 body,
326 rkey,
327 source_branch,
328 source_repo_at,
329 dependent_on
330 from
331 pulls
332 %s
333 order by
334 created desc
335 %s
336 `, whereClause, pageClause)
337
338 rows, err := e.Query(query, args...)
339 if err != nil {
340 return nil, err
341 }
342 defer rows.Close()
343
344 for rows.Next() {
345 var pull models.Pull
346 var createdAt string
347 var sourceBranch, sourceRepoAt, dependentOn sql.NullString
348 err := rows.Scan(
349 &pull.ID,
350 &pull.OwnerDid,
351 &pull.RepoAt,
352 &pull.PullId,
353 &createdAt,
354 &pull.Title,
355 &pull.State,
356 &pull.TargetBranch,
357 &pull.Body,
358 &pull.Rkey,
359 &sourceBranch,
360 &sourceRepoAt,
361 &dependentOn,
362 )
363 if err != nil {
364 return nil, err
365 }
366
367 createdTime, err := time.Parse(time.RFC3339, createdAt)
368 if err != nil {
369 return nil, err
370 }
371 pull.Created = createdTime
372
373 if sourceBranch.Valid {
374 pull.PullSource = &models.PullSource{
375 Branch: sourceBranch.String,
376 }
377 if sourceRepoAt.Valid {
378 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
379 if err != nil {
380 return nil, err
381 }
382 pull.PullSource.RepoAt = &sourceRepoAtParsed
383 }
384 }
385
386 if dependentOn.Valid {
387 x := syntax.ATURI(dependentOn.String)
388 pull.DependentOn = &x
389 }
390
391 pulls[pull.AtUri()] = &pull
392 }
393
394 var pullAts []syntax.ATURI
395 for _, p := range pulls {
396 pullAts = append(pullAts, p.AtUri())
397 }
398 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
399 if err != nil {
400 return nil, fmt.Errorf("failed to get submissions: %w", err)
401 }
402
403 for pullAt, submissions := range submissionsMap {
404 if p, ok := pulls[pullAt]; ok {
405 p.Submissions = submissions
406 }
407 }
408
409 // collect allLabels for each issue
410 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
411 if err != nil {
412 return nil, fmt.Errorf("failed to query labels: %w", err)
413 }
414 for pullAt, labels := range allLabels {
415 if p, ok := pulls[pullAt]; ok {
416 p.Labels = labels
417 }
418 }
419
420 // build up reverse mappings: p.Repo and p.PullSource
421 var repoAts []syntax.ATURI
422 for _, p := range pulls {
423 repoAts = append(repoAts, p.RepoAt)
424 if p.PullSource != nil && p.PullSource.RepoAt != nil {
425 repoAts = append(repoAts, *p.PullSource.RepoAt)
426 }
427 }
428
429 repos, err := GetRepos(e, orm.FilterIn("at_uri", repoAts))
430 if err != nil && !errors.Is(err, sql.ErrNoRows) {
431 return nil, fmt.Errorf("failed to get source repos: %w", err)
432 }
433
434 repoMap := make(map[syntax.ATURI]*models.Repo)
435 for _, r := range repos {
436 repoMap[r.RepoAt()] = &r
437 }
438
439 for _, p := range pulls {
440 if repo, ok := repoMap[p.RepoAt]; ok {
441 p.Repo = repo
442 }
443
444 if p.PullSource != nil && p.PullSource.RepoAt != nil {
445 if sourceRepo, ok := repoMap[*p.PullSource.RepoAt]; ok {
446 p.PullSource.Repo = sourceRepo
447 }
448 }
449 }
450
451 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
452 if err != nil {
453 return nil, fmt.Errorf("failed to query reference_links: %w", err)
454 }
455 for pullAt, references := range allReferences {
456 if pull, ok := pulls[pullAt]; ok {
457 pull.References = references
458 }
459 }
460
461 orderedByPullId := []*models.Pull{}
462 for _, p := range pulls {
463 orderedByPullId = append(orderedByPullId, p)
464 }
465 sort.Slice(orderedByPullId, func(i, j int) bool {
466 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
467 })
468
469 return orderedByPullId, nil
470}
471
472func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
473 return GetPullsPaginated(e, pagination.Page{}, filters...)
474}
475
476func GetPull(e Execer, filters ...orm.Filter) (*models.Pull, error) {
477 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, filters...)
478 if err != nil {
479 return nil, err
480 }
481 if len(pulls) == 0 {
482 return nil, sql.ErrNoRows
483 }
484
485 return pulls[0], nil
486}
487
488// mapping from pull -> pull submissions
489func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
490 var conditions []string
491 var args []any
492 for _, filter := range filters {
493 conditions = append(conditions, filter.Condition())
494 args = append(args, filter.Arg()...)
495 }
496
497 whereClause := ""
498 if conditions != nil {
499 whereClause = " where " + strings.Join(conditions, " and ")
500 }
501
502 query := fmt.Sprintf(`
503 select
504 id,
505 pull_at,
506 round_number,
507 patch,
508 combined,
509 created,
510 source_rev,
511 patch_blob_ref,
512 patch_blob_mime,
513 patch_blob_size
514 from
515 pull_submissions
516 %s
517 order by
518 round_number asc
519 `, whereClause)
520
521 rows, err := e.Query(query, args...)
522 if err != nil {
523 return nil, err
524 }
525 defer rows.Close()
526
527 submissionMap := make(map[int]*models.PullSubmission)
528
529 for rows.Next() {
530 var submission models.PullSubmission
531 var submissionCreatedStr string
532 var submissionSourceRev, submissionCombined sql.Null[string]
533 var patchBlobRef, patchBlobMime sql.Null[string]
534 var patchBlobSize sql.Null[int64]
535 err := rows.Scan(
536 &submission.ID,
537 &submission.PullAt,
538 &submission.RoundNumber,
539 &submission.Patch,
540 &submissionCombined,
541 &submissionCreatedStr,
542 &submissionSourceRev,
543 &patchBlobRef,
544 &patchBlobMime,
545 &patchBlobSize,
546 )
547 if err != nil {
548 return nil, err
549 }
550
551 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
552 submission.Created = t
553 }
554
555 if submissionSourceRev.Valid {
556 submission.SourceRev = submissionSourceRev.V
557 }
558
559 if submissionCombined.Valid {
560 submission.Combined = submissionCombined.V
561 }
562
563 if patchBlobRef.Valid {
564 submission.Blob.Ref = lexutil.LexLink(cid.MustParse(patchBlobRef.V))
565 }
566
567 if patchBlobMime.Valid {
568 submission.Blob.MimeType = patchBlobMime.V
569 }
570
571 if patchBlobSize.Valid {
572 submission.Blob.Size = patchBlobSize.V
573 }
574
575 submissionMap[submission.ID] = &submission
576 }
577
578 if err := rows.Err(); err != nil {
579 return nil, err
580 }
581
582 // Get comments for all submissions using GetPullComments
583 submissionIds := slices.Collect(maps.Keys(submissionMap))
584 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
585 if err != nil {
586 return nil, fmt.Errorf("failed to get pull comments: %w", err)
587 }
588 for _, comment := range comments {
589 if submission, ok := submissionMap[comment.SubmissionId]; ok {
590 submission.Comments = append(submission.Comments, comment)
591 }
592 }
593
594 // group the submissions by pull_at
595 m := make(map[syntax.ATURI][]*models.PullSubmission)
596 for _, s := range submissionMap {
597 m[s.PullAt] = append(m[s.PullAt], s)
598 }
599
600 // sort each one by round number
601 for _, s := range m {
602 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
603 return cmp.Compare(a.RoundNumber, b.RoundNumber)
604 })
605 }
606
607 return m, nil
608}
609
610func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
611 var conditions []string
612 var args []any
613 for _, filter := range filters {
614 conditions = append(conditions, filter.Condition())
615 args = append(args, filter.Arg()...)
616 }
617
618 whereClause := ""
619 if conditions != nil {
620 whereClause = " where " + strings.Join(conditions, " and ")
621 }
622
623 query := fmt.Sprintf(`
624 select
625 id,
626 pull_id,
627 submission_id,
628 repo_at,
629 owner_did,
630 comment_at,
631 body,
632 created
633 from
634 pull_comments
635 %s
636 order by
637 created asc
638 `, whereClause)
639
640 rows, err := e.Query(query, args...)
641 if err != nil {
642 return nil, err
643 }
644 defer rows.Close()
645
646 commentMap := make(map[string]*models.PullComment)
647 for rows.Next() {
648 var comment models.PullComment
649 var createdAt string
650 err := rows.Scan(
651 &comment.ID,
652 &comment.PullId,
653 &comment.SubmissionId,
654 &comment.RepoAt,
655 &comment.OwnerDid,
656 &comment.CommentAt,
657 &comment.Body,
658 &createdAt,
659 )
660 if err != nil {
661 return nil, err
662 }
663
664 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
665 comment.Created = t
666 }
667
668 atUri := comment.AtUri().String()
669 commentMap[atUri] = &comment
670 }
671
672 if err := rows.Err(); err != nil {
673 return nil, err
674 }
675
676 // collect references for each comments
677 commentAts := slices.Collect(maps.Keys(commentMap))
678 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
679 if err != nil {
680 return nil, fmt.Errorf("failed to query reference_links: %w", err)
681 }
682 for commentAt, references := range allReferences {
683 if comment, ok := commentMap[commentAt.String()]; ok {
684 comment.References = references
685 }
686 }
687
688 var comments []models.PullComment
689 for _, c := range commentMap {
690 comments = append(comments, *c)
691 }
692
693 sort.Slice(comments, func(i, j int) bool {
694 return comments[i].Created.Before(comments[j].Created)
695 })
696
697 return comments, nil
698}
699
700// timeframe here is directly passed into the sql query filter, and any
701// timeframe in the past should be negative; e.g.: "-3 months"
702func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
703 var pulls []models.Pull
704
705 rows, err := e.Query(`
706 select
707 p.owner_did,
708 p.repo_at,
709 p.pull_id,
710 p.created,
711 p.title,
712 p.state,
713 r.did,
714 r.name,
715 r.knot,
716 r.rkey,
717 r.created
718 from
719 pulls p
720 join
721 repos r on p.repo_at = r.at_uri
722 where
723 p.owner_did = ? and p.created >= date ('now', ?)
724 order by
725 p.created desc`, did, timeframe)
726 if err != nil {
727 return nil, err
728 }
729 defer rows.Close()
730
731 for rows.Next() {
732 var pull models.Pull
733 var repo models.Repo
734 var pullCreatedAt, repoCreatedAt string
735 err := rows.Scan(
736 &pull.OwnerDid,
737 &pull.RepoAt,
738 &pull.PullId,
739 &pullCreatedAt,
740 &pull.Title,
741 &pull.State,
742 &repo.Did,
743 &repo.Name,
744 &repo.Knot,
745 &repo.Rkey,
746 &repoCreatedAt,
747 )
748 if err != nil {
749 return nil, err
750 }
751
752 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
753 if err != nil {
754 return nil, err
755 }
756 pull.Created = pullCreatedTime
757
758 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
759 if err != nil {
760 return nil, err
761 }
762 repo.Created = repoCreatedTime
763
764 pull.Repo = &repo
765
766 pulls = append(pulls, pull)
767 }
768
769 if err := rows.Err(); err != nil {
770 return nil, err
771 }
772
773 return pulls, nil
774}
775
776func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
777 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
778 res, err := tx.Exec(
779 query,
780 comment.OwnerDid,
781 comment.RepoAt,
782 comment.SubmissionId,
783 comment.CommentAt,
784 comment.PullId,
785 comment.Body,
786 )
787 if err != nil {
788 return 0, err
789 }
790
791 i, err := res.LastInsertId()
792 if err != nil {
793 return 0, err
794 }
795
796 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
797 return 0, fmt.Errorf("put reference_links: %w", err)
798 }
799
800 return i, nil
801}
802
803// use with transaction
804func SetPullsState(e Execer, pullState models.PullState, filters ...orm.Filter) error {
805 var conditions []string
806 var args []any
807
808 args = append(args, pullState)
809 for _, filter := range filters {
810 conditions = append(conditions, filter.Condition())
811 args = append(args, filter.Arg()...)
812 }
813 args = append(args, models.PullAbandoned) // only update state of non-deleted pulls
814 args = append(args, models.PullMerged) // only update state of non-merged pulls
815
816 whereClause := ""
817 if conditions != nil {
818 whereClause = " where " + strings.Join(conditions, " and ")
819 }
820
821 query := fmt.Sprintf("update pulls set state = ? %s and state <> ? and state <> ?", whereClause)
822
823 _, err := e.Exec(query, args...)
824 return err
825}
826
827func ClosePulls(e Execer, filters ...orm.Filter) error {
828 return SetPullsState(e, models.PullClosed, filters...)
829}
830
831func ReopenPulls(e Execer, filters ...orm.Filter) error {
832 return SetPullsState(e, models.PullOpen, filters...)
833}
834
835func MergePulls(e Execer, filters ...orm.Filter) error {
836 return SetPullsState(e, models.PullMerged, filters...)
837}
838
839func AbandonPulls(e Execer, filters ...orm.Filter) error {
840 return SetPullsState(e, models.PullAbandoned, filters...)
841}
842
843func ResubmitPull(
844 e Execer,
845 pullAt syntax.ATURI,
846 newRoundNumber int,
847 newPatch string,
848 combinedPatch string,
849 newSourceRev string,
850 blob *lexutil.LexBlob,
851) error {
852 _, err := e.Exec(`
853 insert into pull_submissions (
854 pull_at,
855 round_number,
856 patch,
857 combined,
858 source_rev,
859 patch_blob_ref,
860 patch_blob_mime,
861 patch_blob_size
862 )
863 values (?, ?, ?, ?, ?, ?, ?, ?)
864 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev, blob.Ref.String(), blob.MimeType, blob.Size)
865
866 return err
867}
868
869func SetDependentOn(e Execer, dependentOn syntax.ATURI, filters ...orm.Filter) error {
870 var conditions []string
871 var args []any
872
873 args = append(args, dependentOn)
874
875 for _, filter := range filters {
876 conditions = append(conditions, filter.Condition())
877 args = append(args, filter.Arg()...)
878 }
879
880 whereClause := ""
881 if conditions != nil {
882 whereClause = " where " + strings.Join(conditions, " and ")
883 }
884
885 query := fmt.Sprintf("update pulls set dependent_on = ? %s", whereClause)
886 _, err := e.Exec(query, args...)
887
888 return err
889}
890
891func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
892 row := e.QueryRow(`
893 select
894 count(case when state = ? then 1 end) as open_count,
895 count(case when state = ? then 1 end) as merged_count,
896 count(case when state = ? then 1 end) as closed_count,
897 count(case when state = ? then 1 end) as deleted_count
898 from pulls
899 where repo_at = ?`,
900 models.PullOpen,
901 models.PullMerged,
902 models.PullClosed,
903 models.PullAbandoned,
904 repoAt,
905 )
906
907 var count models.PullCount
908 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
909 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
910 }
911
912 return count, nil
913}
914
915// change-id dependent_on
916//
917// 4 w ,-------- at_uri(z) (TOP)
918// 3 z <----',------- at_uri(y)
919// 2 y <-----',------ at_uri(x)
920// 1 x <------' nil (BOT)
921//
922// `w` has no dependents, so it is the top of the stack
923//
924// this unfortunately does a db query for *each* pull of the stack,
925// ideally this would be a recursive query, but in the interest of implementation simplicity,
926// we took the less performant route
927//
928// TODO: make this less bad
929func GetStack(e Execer, atUri syntax.ATURI) (models.Stack, error) {
930 // first get the pull for the given at-uri
931 pull, err := GetPull(e, orm.FilterEq("at_uri", atUri))
932 if err != nil {
933 return nil, err
934 }
935
936 // Collect all pulls in the stack by traversing up and down
937 allPulls := []*models.Pull{pull}
938 visited := sets.New[syntax.ATURI]()
939
940 // Traverse up to find all dependents
941 current := pull
942 for {
943 dependent, err := GetPull(e,
944 orm.FilterEq("dependent_on", current.AtUri()),
945 orm.FilterNotEq("state", models.PullAbandoned),
946 )
947 if err != nil || dependent == nil {
948 break
949 }
950 if visited.Contains(dependent.AtUri()) {
951 return allPulls, fmt.Errorf("circular dependency detected in stack")
952 }
953 allPulls = append(allPulls, dependent)
954 visited.Insert(dependent.AtUri())
955 current = dependent
956 }
957
958 // Traverse down to find all dependencies
959 current = pull
960 for current.DependentOn != nil {
961 dependency, err := GetPull(
962 e,
963 orm.FilterEq("at_uri", current.DependentOn),
964 orm.FilterNotEq("state", models.PullAbandoned),
965 )
966
967 if err != nil {
968 return allPulls, fmt.Errorf("failed to find parent pull request, stack is malformed, missing PR: %s", current.DependentOn)
969 }
970 if visited.Contains(dependency.AtUri()) {
971 return allPulls, fmt.Errorf("circular dependency detected in stack")
972 }
973 allPulls = append(allPulls, dependency)
974 visited.Insert(dependency.AtUri())
975 current = dependency
976 }
977
978 // sort the list: find the top and build ordered list
979 atUriMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
980 dependentMap := make(map[syntax.ATURI]*models.Pull, len(allPulls))
981
982 for _, p := range allPulls {
983 atUriMap[p.AtUri()] = p
984 if p.DependentOn != nil {
985 dependentMap[*p.DependentOn] = p
986 }
987 }
988
989 // the top of the stack is the pull that no other pull depends on
990 var topPull *models.Pull
991 for _, maybeTop := range allPulls {
992 if _, ok := dependentMap[maybeTop.AtUri()]; !ok {
993 topPull = maybeTop
994 break
995 }
996 }
997
998 pulls := []*models.Pull{}
999 for {
1000 pulls = append(pulls, topPull)
1001 if topPull.DependentOn != nil {
1002 if next, ok := atUriMap[*topPull.DependentOn]; ok {
1003 topPull = next
1004 } else {
1005 return pulls, fmt.Errorf("failed to find parent pull request, stack is malformed")
1006 }
1007 } else {
1008 break
1009 }
1010 }
1011
1012 return pulls, nil
1013}
1014
1015func GetAbandonedPulls(e Execer, atUri syntax.ATURI) ([]*models.Pull, error) {
1016 stack, err := GetStack(e, atUri)
1017 if err != nil {
1018 return nil, err
1019 }
1020
1021 var abandoned []*models.Pull
1022 for _, p := range stack {
1023 if p.State == models.PullAbandoned {
1024 abandoned = append(abandoned, p)
1025 }
1026 }
1027
1028 return abandoned, nil
1029}