// Package db implements the Postgres-backed queries used by the application:
// the DID/handle cache, message and image history, channels, and bans.
package db

import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/url"
	"os"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"rvcx/internal/atputils"
	"rvcx/internal/lex"
	"rvcx/internal/types"
)

// Store wraps a pgx connection pool and exposes the application's queries.
type Store struct {
	pool *pgxpool.Pool
}

// Init connects to Postgres using the POSTGRES_* environment variables and
// returns a ready Store. On failure it returns a nil Store and the error.
func Init() (*Store, error) {
	pool, err := initialize()
	if err != nil {
		return nil, err
	}
	return &Store{pool: pool}, nil
}

// Close releases all pooled connections.
func (s *Store) Close() {
	s.pool.Close()
}

// initialize opens a connection pool and verifies it with a ping.
//
// Environment: POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_PORT, POSTGRES_DB,
// and optionally POSTGRES_HOST (defaults to "localhost").
func initialize() (*pgxpool.Pool, error) {
	host := os.Getenv("POSTGRES_HOST")
	if host == "" {
		host = "localhost"
	}
	// Build the DSN with net/url so credentials or a database name containing
	// reserved characters (@ : / # ? ...) are percent-encoded instead of
	// silently corrupting the connection string.
	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(os.Getenv("POSTGRES_USER"), os.Getenv("POSTGRES_PASSWORD")),
		Host:     net.JoinHostPort(host, os.Getenv("POSTGRES_PORT")),
		Path:     "/" + os.Getenv("POSTGRES_DB"),
		RawQuery: "sslmode=disable",
	}

	ctx := context.Background()
	pool, err := pgxpool.New(ctx, dsn.String())
	if err != nil {
		return nil, err
	}
	if err := pool.Ping(ctx); err != nil {
		// The pool was created; release it rather than leaking it.
		pool.Close()
		return nil, err
	}
	fmt.Println("connected!")
	return pool, nil
}

// ResolveHandle returns the DID cached for handle in did_handles.
// When the handle is not cached, the error is pgx.ErrNoRows.
func (s *Store) ResolveHandle(handle string, ctx context.Context) (string, error) {
	var did string
	err := s.pool.QueryRow(ctx,
		`SELECT h.did FROM did_handles h WHERE h.handle = $1`, handle).Scan(&did)
	if err != nil {
		return "", err
	}
	return did, nil
}

// FullResolveHandle resolves hdl to a DID. It tries the local cache first and,
// on any cache failure, falls back to atputils.TryLookupHandle. A successful
// remote lookup is written back to the cache on a best-effort basis.
func (s *Store) FullResolveHandle(hdl string, ctx context.Context) (string, error) {
	if did, err := s.ResolveHandle(hdl, ctx); err == nil {
		return did, nil
	}
	did, err := atputils.TryLookupHandle(ctx, hdl)
	if err != nil {
		return "", fmt.Errorf("couldn't resolve: %w", err)
	}
	// Best-effort cache fill: the lookup itself succeeded, so a failed write
	// must not fail the resolution.
	_ = s.StoreDidHandle(did, hdl, ctx)
	return did, nil
}

// ResolveDid returns the handle cached for did in did_handles. The returned
// error wraps the scan error, so errors.Is(err, pgx.ErrNoRows) detects a miss.
func (s *Store) ResolveDid(did string, ctx context.Context) (string, error) {
	var handle string
	err := s.pool.QueryRow(ctx,
		`SELECT h.handle FROM did_handles h WHERE h.did = $1`, did).Scan(&handle)
	if err != nil {
		return "", fmt.Errorf("error scanning row for handle: %w", err)
	}
	return handle, nil
}

// FullResolveDid resolves did to a handle. It tries the local cache first and,
// on any cache failure, falls back to atputils.TryLookupDid. A successful
// remote lookup is written back to the cache on a best-effort basis.
func (s *Store) FullResolveDid(did string, ctx context.Context) (string, error) {
	if hdl, err := s.ResolveDid(did, ctx); err == nil {
		return hdl, nil
	}
	hdl, err := atputils.TryLookupDid(ctx, did)
	if err != nil {
		return "", fmt.Errorf("couldn't resolve: %w", err)
	}
	// Best-effort cache fill; see FullResolveHandle.
	_ = s.StoreDidHandle(did, hdl, ctx)
	return hdl, nil
}

// StoreDidHandle inserts a handle→DID mapping into did_handles.
//
// NOTE(review): ON CONFLICT (handle) DO NOTHING leaves an existing row for the
// handle untouched, so a handle that moves to a different DID is not refreshed
// here. Confirm that is intended.
func (s *Store) StoreDidHandle(did string, handle string, ctx context.Context) error {
	_, err := s.pool.Exec(ctx, `
		INSERT INTO did_handles (handle, did)
		VALUES ($1, $2)
		ON CONFLICT (handle) DO NOTHING
	`, handle, did)
	if err != nil {
		return fmt.Errorf("error storing did/handle: %w", err)
	}
	return nil
}

// GetLastSeen returns the channel URI and posted_at time of the most recent
// message by did whose signet author_handle matches did's cached handle.
// Both results are nil when there is no such message or the lookup fails.
func (s *Store) GetLastSeen(did string, ctx context.Context) (where *string, when *time.Time) {
	err := s.pool.QueryRow(ctx, `
		SELECT s.channel_uri, m.posted_at
		FROM messages m
		JOIN signets s ON m.signet_uri = s.uri
		JOIN did_handles dh ON m.did = dh.did
		WHERE m.did = $1 AND dh.handle = s.author_handle
		ORDER BY m.posted_at DESC
		LIMIT 1
	`, did).Scan(&where, &when)
	if err != nil {
		// Best-effort lookup: report "unknown" as nil/nil and never return a
		// half-populated pair.
		return nil, nil
	}
	return where, when
}

// GetHistory returns up to limit messages and images posted in channelURI,
// ordered by signet message_id descending. If cursor is non-nil, only items
// with message_id < *cursor are returned (keyset pagination).
func (s *Store) GetHistory(channelURI string, limit int, cursor *int, ctx context.Context) ([]types.SignedItemView, error) {
	// Placeholders: $1 = limit, $2 = channel URI, $3 = cursor (optional).
	// Both %s verbs receive the same optional cursor clause.
	const queryFmt = `
		SELECT 'message' AS content_type, m.uri, m.did, dh.handle, p.display_name,
			p.status, p.color, p.avatar_cid, p.default_nick, m.body,
			NULL AS blob_cid, NULL AS blob_mime, NULL AS alt, NULL AS height, NULL AS width,
			m.nick, m.color, s.uri, s.issuer_did, s.channel_uri, s.message_id,
			s.author, s.author_handle, s.started_at, m.posted_at
		FROM signets s
		JOIN messages m ON s.uri = m.signet_uri
		JOIN did_handles dh ON m.did = dh.did
		JOIN profiles p ON m.did = p.did
		WHERE s.channel_uri = $2 AND m.did = s.author %s
		UNION ALL
		SELECT 'image' AS content_type, i.uri, i.did, dh.handle, p.display_name,
			p.status, p.color, p.avatar_cid, p.default_nick, NULL AS body,
			i.blob_cid, i.blob_mime, i.alt, i.height, i.width,
			i.nick, i.color, s.uri, s.issuer_did, s.channel_uri, s.message_id,
			s.author, s.author_handle, s.started_at, i.posted_at
		FROM signets s
		JOIN images i ON s.uri = i.signet_uri
		JOIN did_handles dh ON i.did = dh.did
		JOIN profiles p ON i.did = p.did
		WHERE s.channel_uri = $2 AND i.did = s.author %s
		ORDER BY message_id DESC
		LIMIT $1
	`
	cursorClause := ""
	params := []any{channelURI}
	if cursor != nil {
		cursorClause = "AND s.message_id < $3"
		params = append(params, *cursor)
	}
	query := fmt.Sprintf(queryFmt, cursorClause, cursorClause)
	return s.evalGetItems(query, ctx, limit, params...)
}

// evalGetItems runs a GetHistory-shaped query with args (limit, params...) and
// converts each row into a SignedMessageView or SignedMediaView, depending on
// the content_type column ("message" or "image").
func (s *Store) evalGetItems(query string, ctx context.Context, limit int, params ...any) ([]types.SignedItemView, error) {
	args := append([]any{limit}, params...)
	rows, err := s.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// The image source host is the same for every row; read it once.
	identity := os.Getenv("MY_IDENTITY")

	items := make([]types.SignedItemView, 0)
	for rows.Next() {
		var (
			kind     string
			uri      string
			author   types.ProfileView
			body     *string
			image    types.Image
			alt      *string
			nick     *string // pointer: tolerate NULL, as evalGetMessages does
			color    *uint32 // pointer: tolerate NULL, as evalGetMessages does
			signet   types.SignetView
			postedAt time.Time
		)
		if err := rows.Scan(
			&kind, &uri,
			&author.DID, &author.Handle, &author.DisplayName, &author.Status,
			&author.Color, &author.Avatar, &author.DefaultNick,
			&body,
			&image.BlobCID, &image.BlobMIME, &alt, &image.Height, &image.Width,
			&nick, &color,
			&signet.URI, &signet.Issuer, &signet.ChannelURI, &signet.LrcId,
			&signet.Author, &signet.AuthorHandle, &signet.StartedAt,
			&postedAt,
		); err != nil {
			return nil, err
		}
		// An empty nick or zero color is treated the same as "not set".
		if nick != nil && *nick == "" {
			nick = nil
		}
		if color != nil && *color == 0 {
			color = nil
		}

		switch kind {
		case "message":
			var msg types.SignedMessageView
			if body != nil {
				msg.Body = *body
			}
			msg.Nick = nick
			msg.Color = color
			msg.Author = author
			msg.Signet = signet
			msg.PostedAt = postedAt
			msg.URI = uri
			items = append(items, msg)
		case "image":
			var imgview types.ImageView
			if image.Height != nil && image.Width != nil {
				imgview.AspectRatio = &lex.AspectRatio{
					Width:  *image.Width,
					Height: *image.Height,
				}
			}
			if alt != nil {
				imgview.Alt = *alt
			}
			src := fmt.Sprintf("https://%s/xrpc/org.xcvr.lrc.getImage?uri=%s", identity, uri)
			imgview.Src = &src

			var img types.SignedMediaView
			img.Image = &imgview
			img.Nick = nick
			img.Color = color
			img.Author = author
			img.Signet = signet
			img.PostedAt = postedAt
			img.URI = uri
			items = append(items, img)
		default:
			return nil, errors.New("received strange type t: " + kind)
		}
	}
	// rows.Next returns false on error as well as on exhaustion; without this
	// check a mid-stream failure would be returned as a truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}

// GetMessages returns up to limit messages posted in channelURI, ordered by
// signet message_id descending. If cursor is non-nil, only messages with
// message_id < *cursor are returned.
func (s *Store) GetMessages(channelURI string, limit int, cursor *int, ctx context.Context) ([]types.SignedMessageView, error) {
	// Placeholders: $1 = limit, $2 = channel URI, $3 = cursor (optional).
	const queryFmt = `
		SELECT m.uri, m.did, dh.handle, p.display_name, p.status, p.color,
			p.avatar_cid, p.default_nick, m.body, m.nick, m.color,
			s.uri, issuer_dh.handle, s.channel_uri, s.message_id,
			s.author_handle, s.started_at, m.posted_at
		FROM messages m
		JOIN signets s ON m.signet_uri = s.uri
		JOIN did_handles dh ON m.did = dh.did
		LEFT JOIN profiles p ON m.did = p.did
		JOIN did_handles issuer_dh ON s.issuer_did = issuer_dh.did
		WHERE s.channel_uri = $2 AND dh.handle = s.author_handle %s
		ORDER BY s.message_id DESC
		LIMIT $1
	`
	cursorClause := ""
	params := []any{channelURI}
	if cursor != nil {
		cursorClause = "AND s.message_id < $3"
		params = append(params, *cursor)
	}
	return s.evalGetMessages(fmt.Sprintf(queryFmt, cursorClause), ctx, limit, params...)
}

// evalGetMessages runs a GetMessages-shaped query with args (limit, params...)
// and scans each row into a SignedMessageView.
func (s *Store) evalGetMessages(query string, ctx context.Context, limit int, params ...any) ([]types.SignedMessageView, error) {
	args := append([]any{limit}, params...)
	rows, err := s.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	msgs := make([]types.SignedMessageView, 0)
	for rows.Next() {
		var msg types.SignedMessageView
		if err := rows.Scan(
			&msg.URI,
			&msg.Author.DID, &msg.Author.Handle, &msg.Author.DisplayName,
			&msg.Author.Status, &msg.Author.Color, &msg.Author.Avatar,
			&msg.Author.DefaultNick,
			&msg.Body, &msg.Nick, &msg.Color,
			&msg.Signet.URI, &msg.Signet.Issuer, &msg.Signet.ChannelURI,
			&msg.Signet.LrcId, &msg.Signet.AuthorHandle, &msg.Signet.StartedAt,
			&msg.PostedAt,
		); err != nil {
			return nil, err
		}
		msgs = append(msgs, msg)
	}
	// Surface mid-stream errors instead of returning a truncated slice.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return msgs, nil
}

// GetChannelURI returns the URI of the most recently created channel titled
// title whose creator's cached handle is handle. When no channel matches, the
// error is pgx.ErrNoRows.
func (s *Store) GetChannelURI(handle string, title string, ctx context.Context) (string, error) {
	var uri string
	// The WHERE on did_handles.handle already requires a match, so a plain
	// JOIN is equivalent to the LEFT JOIN this query used previously.
	err := s.pool.QueryRow(ctx, `
		SELECT channels.uri
		FROM channels
		JOIN did_handles ON channels.did = did_handles.did
		WHERE channels.title = $1 AND did_handles.handle = $2
		ORDER BY channels.created_at DESC
		LIMIT 1
	`, title, handle).Scan(&uri)
	if err != nil {
		return "", err
	}
	return uri, nil
}

// URIHost describes a channel's URI, host and topic, together with the
// highest signet message_id recorded for it (0 when it has none).
type URIHost struct {
	URI    string
	Host   string
	Topic  string
	LastID uint32
}

// GetChannelURIs lists every channel with its highest signet message_id.
func (s *Store) GetChannelURIs(ctx context.Context) ([]URIHost, error) {
	// A single correlated subquery replaces the previous per-row QueryRow,
	// which needed a second pool connection while Rows still held the first.
	rows, err := s.pool.Query(ctx, `
		SELECT c.uri, c.host, c.topic,
			COALESCE((SELECT MAX(sg.message_id) FROM signets sg WHERE sg.channel_uri = c.uri), 0)
		FROM channels c
	`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	urihosts := make([]URIHost, 0, 100)
	for rows.Next() {
		var uh URIHost
		if err := rows.Scan(&uh.URI, &uh.Host, &uh.Topic, &uh.LastID); err != nil {
			return nil, err
		}
		urihosts = append(urihosts, uh)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return urihosts, nil
}

// channelViewQuery selects the columns scanned by scanChannelView; callers
// append their own WHERE / ORDER BY / LIMIT clauses.
const channelViewQuery = `
	SELECT channels.uri, channels.host, channels.title, channels.topic, channels.created_at,
		did_handles.did, did_handles.handle,
		profiles.display_name, profiles.status, profiles.color, profiles.avatar_cid
	FROM channels
	LEFT JOIN profiles ON channels.did = profiles.did
	LEFT JOIN did_handles ON profiles.did = did_handles.did
`

// scanChannelView scans one channelViewQuery row (from QueryRow or Rows) into
// a ChannelView with its Creator populated.
func scanChannelView(row pgx.Row) (types.ChannelView, error) {
	var c types.ChannelView
	var p types.ProfileView
	if err := row.Scan(
		&c.URI, &c.Host, &c.Title, &c.Topic, &c.CreatedAt,
		&p.DID, &p.Handle, &p.DisplayName, &p.Status, &p.Color, &p.Avatar,
	); err != nil {
		return types.ChannelView{}, err
	}
	c.Creator = p
	return c, nil
}

// GetChannelViews returns up to limit channels, newest first.
func (s *Store) GetChannelViews(limit int, ctx context.Context) ([]types.ChannelView, error) {
	rows, err := s.pool.Query(ctx, channelViewQuery+`
		ORDER BY channels.created_at DESC
		LIMIT $1
	`, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Pre-size from limit, but bounded: a huge caller-supplied limit must not
	// trigger a huge allocation before any row has been read.
	capHint := limit
	if capHint < 0 {
		capHint = 0
	}
	if capHint > 100 {
		capHint = 100
	}
	chans := make([]types.ChannelView, 0, capHint)
	for rows.Next() {
		c, err := scanChannelView(rows)
		if err != nil {
			return nil, err
		}
		chans = append(chans, c)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chans, nil
}

// GetChannelView returns the channel with the given AT URI. When the channel
// does not exist, the error is pgx.ErrNoRows.
func (s *Store) GetChannelView(uri string, ctx context.Context) (*types.ChannelView, error) {
	c, err := scanChannelView(s.pool.QueryRow(ctx, channelViewQuery+`
		WHERE channels.uri = $1
	`, uri))
	if err != nil {
		return nil, err
	}
	return &c, nil
}

// GetChannelViewHR looks up a channel by its creator's handle and record key.
// The handle is resolved through the local cache, and the lookup uses the URI
// at://<did>/org.xcvr.feed.channel/<rkey>.
func (s *Store) GetChannelViewHR(handle string, rkey string, ctx context.Context) (*types.ChannelView, error) {
	did, err := s.ResolveHandle(handle, ctx)
	if err != nil {
		return nil, err
	}
	uri := fmt.Sprintf("at://%s/org.xcvr.feed.channel/%s", did, rkey)
	return s.GetChannelView(uri, ctx)
}

// DeleteChannel deletes the channel with the given URI. Deleting a channel
// that does not exist is not an error.
func (s *Store) DeleteChannel(uri string, ctx context.Context) error {
	_, err := s.pool.Exec(ctx, `DELETE FROM channels WHERE uri = $1`, uri)
	return err
}

// GetBanned returns the most recent ban (highest id) recorded for did. When
// did has never been banned, the error is pgx.ErrNoRows.
func (s *Store) GetBanned(did string, ctx context.Context) (*types.Ban, error) {
	var ban types.Ban
	err := s.pool.QueryRow(ctx, `
		SELECT id, reason, till, banned_at
		FROM bans
		WHERE did = $1
		ORDER BY id DESC
		LIMIT 1
	`, did).Scan(&ban.Id, &ban.Reason, &ban.Till, &ban.BannedAt)
	if err != nil {
		return nil, err
	}
	ban.Did = did
	return &ban, nil
}

// GetBanId returns the ban with the given id. When no such ban exists, the
// error is pgx.ErrNoRows.
func (s *Store) GetBanId(id int, ctx context.Context) (*types.Ban, error) {
	var ban types.Ban
	err := s.pool.QueryRow(ctx,
		`SELECT did, reason, till, banned_at FROM bans WHERE id = $1`, id,
	).Scan(&ban.Did, &ban.Reason, &ban.Till, &ban.BannedAt)
	if err != nil {
		return nil, err
	}
	ban.Id = id
	return &ban, nil
}

// AddBan records a ban for did. A nil reason or till is stored as NULL; a nil
// till is treated as permanent by IsBanned.
func (s *Store) AddBan(did string, reason *string, till *time.Time, ctx context.Context) error {
	_, err := s.pool.Exec(ctx, `
		INSERT INTO bans (did, reason, till)
		VALUES ($1, $2, $3)
	`, did, reason, till)
	return err
}

// IsBanned reports whether did's most recent ban is still in effect. A ban
// with no till time is permanent; otherwise it is in effect until till.
func (s *Store) IsBanned(did string, ctx context.Context) (bool, error) {
	ban, err := s.GetBanned(did, ctx)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return false, nil // never banned
		}
		return false, err
	}
	if ban.Till == nil {
		return true, nil
	}
	return time.Now().Before(*ban.Till), nil
}