···11+//go:build billing
22+33+package billing
44+55+import (
66+ "encoding/json"
77+ "errors"
88+ "fmt"
99+ "io"
1010+ "log/slog"
1111+ "net/http"
1212+ "os"
1313+ "sort"
1414+ "strings"
1515+ "sync"
1616+ "time"
1717+1818+ "github.com/stripe/stripe-go/v84"
1919+ portalsession "github.com/stripe/stripe-go/v84/billingportal/session"
2020+ "github.com/stripe/stripe-go/v84/checkout/session"
2121+ "github.com/stripe/stripe-go/v84/customer"
2222+ "github.com/stripe/stripe-go/v84/price"
2323+ "github.com/stripe/stripe-go/v84/subscription"
2424+ "github.com/stripe/stripe-go/v84/webhook"
2525+2626+ "atcr.io/pkg/hold/quota"
2727+)
2828+2929+// Manager handles Stripe billing integration.
3030+type Manager struct {
3131+ quotaMgr *quota.Manager
3232+ billingCfg *BillingConfig
3333+ holdPublicURL string
3434+ stripeKey string
3535+ webhookSecret string
3636+ publishableKey string
3737+3838+ // In-memory cache for customer lookups (DID -> customer)
3939+ customerCache map[string]*cachedCustomer
4040+ customerCacheMu sync.RWMutex
4141+}
4242+4343+type cachedCustomer struct {
4444+ customer *stripe.Customer
4545+ expiresAt time.Time
4646+}
4747+4848+const customerCacheTTL = 10 * time.Minute
4949+5050+// New creates a new billing manager with Stripe integration.
5151+func New(quotaMgr *quota.Manager, holdPublicURL string) *Manager {
5252+ stripeKey := os.Getenv("STRIPE_SECRET_KEY")
5353+ if stripeKey != "" {
5454+ stripe.Key = stripeKey
5555+ }
5656+5757+ return &Manager{
5858+ quotaMgr: quotaMgr,
5959+ holdPublicURL: holdPublicURL,
6060+ stripeKey: stripeKey,
6161+ webhookSecret: os.Getenv("STRIPE_WEBHOOK_SECRET"),
6262+ publishableKey: os.Getenv("STRIPE_PUBLISHABLE_KEY"),
6363+ customerCache: make(map[string]*cachedCustomer),
6464+ }
6565+}
6666+6767+// Enabled returns true if billing is properly configured.
6868+func (m *Manager) Enabled() bool {
6969+ return m.billingCfg != nil && m.billingCfg.Enabled && m.stripeKey != ""
7070+}
7171+7272+// GetSubscriptionInfo returns subscription and quota information for a user.
7373+func (m *Manager) GetSubscriptionInfo(userDID string) (*SubscriptionInfo, error) {
7474+ if !m.Enabled() {
7575+ return nil, ErrBillingDisabled
7676+ }
7777+7878+ info := &SubscriptionInfo{
7979+ UserDID: userDID,
8080+ PaymentsEnabled: true,
8181+ Tiers: m.buildTierList(userDID),
8282+ }
8383+8484+ // Try to find existing customer
8585+ cust, err := m.findCustomerByDID(userDID)
8686+ if err != nil {
8787+ slog.Debug("No Stripe customer found for user", "userDid", userDID)
8888+ } else if cust != nil {
8989+ info.CustomerID = cust.ID
9090+9191+ // Get active subscription if any (check all nil pointers)
9292+ if cust.Subscriptions != nil && len(cust.Subscriptions.Data) > 0 {
9393+ sub := cust.Subscriptions.Data[0]
9494+ info.SubscriptionID = sub.ID
9595+9696+ // Safely access subscription items
9797+ if sub.Items != nil && len(sub.Items.Data) > 0 && sub.Items.Data[0].Price != nil {
9898+ info.CurrentTier = m.billingCfg.GetTierByPriceID(sub.Items.Data[0].Price.ID)
9999+100100+ if sub.Items.Data[0].Price.Recurring != nil {
101101+ switch sub.Items.Data[0].Price.Recurring.Interval {
102102+ case stripe.PriceRecurringIntervalMonth:
103103+ info.BillingInterval = "monthly"
104104+ case stripe.PriceRecurringIntervalYear:
105105+ info.BillingInterval = "yearly"
106106+ }
107107+ }
108108+ }
109109+ }
110110+ }
111111+112112+ // If no subscription, use default tier
113113+ if info.CurrentTier == "" {
114114+ info.CurrentTier = m.quotaMgr.GetDefaultTier()
115115+ }
116116+117117+ // Get quota limit for current tier
118118+ limit := m.quotaMgr.GetTierLimit(info.CurrentTier)
119119+ info.CurrentLimit = limit
120120+121121+ // Mark current tier in tier list
122122+ for i := range info.Tiers {
123123+ if info.Tiers[i].ID == info.CurrentTier {
124124+ info.Tiers[i].IsCurrent = true
125125+ }
126126+ }
127127+128128+ return info, nil
129129+}
130130+131131+// buildTierList creates the list of available tiers by merging quota limits
132132+// from the quota manager with billing metadata from the billing config.
133133+func (m *Manager) buildTierList(userDID string) []TierInfo {
134134+ quotaTiers := m.quotaMgr.ListTiers()
135135+ if len(quotaTiers) == 0 {
136136+ return nil
137137+ }
138138+139139+ result := make([]TierInfo, 0, len(quotaTiers))
140140+ for _, qt := range quotaTiers {
141141+ var quotaBytes int64
142142+ if qt.Limit != nil {
143143+ quotaBytes = *qt.Limit
144144+ }
145145+146146+ // Capitalize tier ID for display name (e.g., "swabbie" -> "Swabbie")
147147+ name := strings.ToUpper(qt.Key[:1]) + qt.Key[1:]
148148+149149+ tier := TierInfo{
150150+ ID: qt.Key,
151151+ Name: name,
152152+ QuotaBytes: quotaBytes,
153153+ QuotaFormatted: quota.FormatHumanBytes(quotaBytes),
154154+ }
155155+156156+ // Merge billing metadata if available
157157+ if bt := m.billingCfg.GetTierPricing(qt.Key); bt != nil {
158158+ tier.Description = bt.Description
159159+160160+ // Fetch actual prices from Stripe
161161+ if bt.StripePriceMonthly != "" {
162162+ if p, err := price.Get(bt.StripePriceMonthly, nil); err == nil && p != nil {
163163+ tier.PriceCentsMonthly = int(p.UnitAmount)
164164+ } else {
165165+ slog.Debug("Failed to fetch monthly price", "priceId", bt.StripePriceMonthly, "error", err)
166166+ tier.PriceCentsMonthly = -1
167167+ }
168168+ }
169169+ if bt.StripePriceYearly != "" {
170170+ if p, err := price.Get(bt.StripePriceYearly, nil); err == nil && p != nil {
171171+ tier.PriceCentsYearly = int(p.UnitAmount)
172172+ } else {
173173+ slog.Debug("Failed to fetch yearly price", "priceId", bt.StripePriceYearly, "error", err)
174174+ tier.PriceCentsYearly = -1
175175+ }
176176+ }
177177+ }
178178+179179+ result = append(result, tier)
180180+ }
181181+182182+ // Sort tiers by quota size (ascending)
183183+ sort.Slice(result, func(i, j int) bool {
184184+ return result[i].QuotaBytes < result[j].QuotaBytes
185185+ })
186186+187187+ return result
188188+}
189189+190190+// CreateCheckoutSession creates a Stripe checkout session for subscription.
191191+func (m *Manager) CreateCheckoutSession(r *http.Request, req *CheckoutSessionRequest) (*CheckoutSessionResponse, error) {
192192+ if !m.Enabled() {
193193+ return nil, ErrBillingDisabled
194194+ }
195195+196196+ // Get user DID from request context (set by auth middleware)
197197+ userDID := r.Header.Get("X-User-DID")
198198+ if userDID == "" {
199199+ return nil, errors.New("user not authenticated")
200200+ }
201201+202202+ // Get tier config
203203+ tierCfg := m.billingCfg.GetTierPricing(req.Tier)
204204+ if tierCfg == nil {
205205+ return nil, fmt.Errorf("tier not found: %s", req.Tier)
206206+ }
207207+208208+ // Determine price ID - prefer requested interval, fall back to what's available
209209+ var priceID string
210210+ switch req.Interval {
211211+ case "monthly":
212212+ priceID = tierCfg.StripePriceMonthly
213213+ case "yearly":
214214+ priceID = tierCfg.StripePriceYearly
215215+ default:
216216+ // No interval specified - prefer monthly, fall back to yearly
217217+ if tierCfg.StripePriceMonthly != "" {
218218+ priceID = tierCfg.StripePriceMonthly
219219+ } else {
220220+ priceID = tierCfg.StripePriceYearly
221221+ }
222222+ }
223223+224224+ if priceID == "" {
225225+ return nil, fmt.Errorf("tier %s has no Stripe price configured", req.Tier)
226226+ }
227227+228228+ // Get or create customer
229229+ cust, err := m.getOrCreateCustomer(userDID)
230230+ if err != nil {
231231+ return nil, fmt.Errorf("failed to get/create customer: %w", err)
232232+ }
233233+234234+ // Build success/cancel URLs
235235+ successURL := strings.ReplaceAll(m.billingCfg.SuccessURL, "{hold_url}", m.holdPublicURL)
236236+ cancelURL := strings.ReplaceAll(m.billingCfg.CancelURL, "{hold_url}", m.holdPublicURL)
237237+238238+ if req.ReturnURL != "" {
239239+ successURL = req.ReturnURL + "?success=true"
240240+ cancelURL = req.ReturnURL + "?cancelled=true"
241241+ }
242242+243243+ // Create checkout session
244244+ params := &stripe.CheckoutSessionParams{
245245+ Customer: stripe.String(cust.ID),
246246+ Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
247247+ LineItems: []*stripe.CheckoutSessionLineItemParams{
248248+ {
249249+ Price: stripe.String(priceID),
250250+ Quantity: stripe.Int64(1),
251251+ },
252252+ },
253253+ SuccessURL: stripe.String(successURL),
254254+ CancelURL: stripe.String(cancelURL),
255255+ }
256256+257257+ sess, err := session.New(params)
258258+ if err != nil {
259259+ return nil, fmt.Errorf("failed to create checkout session: %w", err)
260260+ }
261261+262262+ return &CheckoutSessionResponse{
263263+ CheckoutURL: sess.URL,
264264+ SessionID: sess.ID,
265265+ }, nil
266266+}
267267+268268+// GetBillingPortalURL returns a URL to the Stripe billing portal.
269269+func (m *Manager) GetBillingPortalURL(userDID string, returnURL string) (*BillingPortalResponse, error) {
270270+ if !m.Enabled() {
271271+ return nil, ErrBillingDisabled
272272+ }
273273+274274+ // Find existing customer
275275+ cust, err := m.findCustomerByDID(userDID)
276276+ if err != nil || cust == nil {
277277+ return nil, errors.New("no billing account found")
278278+ }
279279+280280+ if returnURL == "" {
281281+ returnURL = m.holdPublicURL
282282+ }
283283+284284+ params := &stripe.BillingPortalSessionParams{
285285+ Customer: stripe.String(cust.ID),
286286+ ReturnURL: stripe.String(returnURL),
287287+ }
288288+289289+ sess, err := portalsession.New(params)
290290+ if err != nil {
291291+ return nil, fmt.Errorf("failed to create portal session: %w", err)
292292+ }
293293+294294+ return &BillingPortalResponse{
295295+ PortalURL: sess.URL,
296296+ }, nil
297297+}
298298+299299+// HandleWebhook processes a Stripe webhook event.
300300+func (m *Manager) HandleWebhook(r *http.Request) (*WebhookEvent, error) {
301301+ if !m.Enabled() {
302302+ return nil, ErrBillingDisabled
303303+ }
304304+305305+ body, err := io.ReadAll(r.Body)
306306+ if err != nil {
307307+ return nil, fmt.Errorf("failed to read request body: %w", err)
308308+ }
309309+310310+ // Verify webhook signature
311311+ event, err := webhook.ConstructEvent(body, r.Header.Get("Stripe-Signature"), m.webhookSecret)
312312+ if err != nil {
313313+ return nil, fmt.Errorf("failed to verify webhook signature: %w", err)
314314+ }
315315+316316+ result := &WebhookEvent{
317317+ Type: string(event.Type),
318318+ }
319319+320320+ switch event.Type {
321321+ case "checkout.session.completed":
322322+ var sess stripe.CheckoutSession
323323+ if err := json.Unmarshal(event.Data.Raw, &sess); err != nil {
324324+ return nil, fmt.Errorf("failed to parse checkout session: %w", err)
325325+ }
326326+327327+ result.CustomerID = sess.Customer.ID
328328+ result.SubscriptionID = sess.Subscription.ID
329329+ result.Status = "active"
330330+331331+ // Fetch customer to get DID from metadata
332332+ result.UserDID = m.getCustomerDID(sess.Customer.ID)
333333+334334+ // Get subscription to find the price/tier
335335+ if sess.Subscription != nil && sess.Subscription.ID != "" {
336336+ if sub, err := m.getSubscription(sess.Subscription.ID); err == nil && sub != nil {
337337+ if len(sub.Items.Data) > 0 {
338338+ result.PriceID = sub.Items.Data[0].Price.ID
339339+ result.NewTier = m.billingCfg.GetTierByPriceID(result.PriceID)
340340+ }
341341+ }
342342+ }
343343+344344+ if result.UserDID != "" && result.NewTier != "" {
345345+ slog.Info("Checkout completed",
346346+ "userDid", result.UserDID,
347347+ "tier", result.NewTier,
348348+ "subscriptionId", result.SubscriptionID,
349349+ )
350350+ }
351351+352352+ case "customer.subscription.created", "customer.subscription.updated":
353353+ var sub stripe.Subscription
354354+ if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
355355+ return nil, fmt.Errorf("failed to parse subscription: %w", err)
356356+ }
357357+358358+ result.SubscriptionID = sub.ID
359359+ result.CustomerID = sub.Customer.ID
360360+ result.Status = string(sub.Status)
361361+362362+ if len(sub.Items.Data) > 0 {
363363+ result.PriceID = sub.Items.Data[0].Price.ID
364364+ result.NewTier = m.billingCfg.GetTierByPriceID(result.PriceID)
365365+ }
366366+367367+ // Fetch customer to get DID from metadata (webhook doesn't include expanded customer)
368368+ result.UserDID = m.getCustomerDID(sub.Customer.ID)
369369+370370+ // If we have user DID and new tier, this signals that crew tier should be updated
371371+ if result.UserDID != "" && result.NewTier != "" && sub.Status == stripe.SubscriptionStatusActive {
372372+ slog.Info("Subscription activated",
373373+ "userDid", result.UserDID,
374374+ "tier", result.NewTier,
375375+ "subscriptionId", result.SubscriptionID,
376376+ )
377377+ }
378378+379379+ case "customer.subscription.deleted", "customer.subscription.paused":
380380+ var sub stripe.Subscription
381381+ if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
382382+ return nil, fmt.Errorf("failed to parse subscription: %w", err)
383383+ }
384384+385385+ result.SubscriptionID = sub.ID
386386+ result.CustomerID = sub.Customer.ID
387387+ if event.Type == "customer.subscription.deleted" {
388388+ result.Status = "cancelled"
389389+ } else {
390390+ result.Status = "paused"
391391+ }
392392+393393+ // Fetch customer to get DID from metadata
394394+ result.UserDID = m.getCustomerDID(sub.Customer.ID)
395395+396396+ // Set tier to default (downgrade on cancellation/pause)
397397+ result.NewTier = m.quotaMgr.GetDefaultTier()
398398+399399+ if result.UserDID != "" {
400400+ slog.Info("Subscription inactive, downgrading to default tier",
401401+ "userDid", result.UserDID,
402402+ "tier", result.NewTier,
403403+ "status", result.Status,
404404+ )
405405+ }
406406+407407+ case "customer.subscription.resumed":
408408+ var sub stripe.Subscription
409409+ if err := json.Unmarshal(event.Data.Raw, &sub); err != nil {
410410+ return nil, fmt.Errorf("failed to parse subscription: %w", err)
411411+ }
412412+413413+ result.SubscriptionID = sub.ID
414414+ result.CustomerID = sub.Customer.ID
415415+ result.Status = "active"
416416+417417+ if len(sub.Items.Data) > 0 {
418418+ result.PriceID = sub.Items.Data[0].Price.ID
419419+ result.NewTier = m.billingCfg.GetTierByPriceID(result.PriceID)
420420+ }
421421+422422+ // Fetch customer to get DID from metadata
423423+ result.UserDID = m.getCustomerDID(sub.Customer.ID)
424424+425425+ if result.UserDID != "" && result.NewTier != "" {
426426+ slog.Info("Subscription resumed, restoring tier",
427427+ "userDid", result.UserDID,
428428+ "tier", result.NewTier,
429429+ )
430430+ }
431431+ }
432432+433433+ return result, nil
434434+}
435435+436436+// getOrCreateCustomer finds or creates a Stripe customer for the given DID.
437437+func (m *Manager) getOrCreateCustomer(userDID string) (*stripe.Customer, error) {
438438+ // Check cache first
439439+ m.customerCacheMu.RLock()
440440+ if cached, ok := m.customerCache[userDID]; ok && time.Now().Before(cached.expiresAt) {
441441+ m.customerCacheMu.RUnlock()
442442+ return cached.customer, nil
443443+ }
444444+ m.customerCacheMu.RUnlock()
445445+446446+ // Try to find existing customer
447447+ cust, err := m.findCustomerByDID(userDID)
448448+ if err == nil && cust != nil {
449449+ m.cacheCustomer(userDID, cust)
450450+ return cust, nil
451451+ }
452452+453453+ // Create new customer
454454+ params := &stripe.CustomerParams{
455455+ Metadata: map[string]string{
456456+ "user_did": userDID,
457457+ "hold_did": m.holdPublicURL, // Not actually a DID but useful for tracking
458458+ },
459459+ }
460460+461461+ cust, err = customer.New(params)
462462+ if err != nil {
463463+ return nil, fmt.Errorf("failed to create customer: %w", err)
464464+ }
465465+466466+ m.cacheCustomer(userDID, cust)
467467+ return cust, nil
468468+}
469469+470470+// findCustomerByDID searches Stripe for a customer with the given DID in metadata.
471471+func (m *Manager) findCustomerByDID(userDID string) (*stripe.Customer, error) {
472472+ // Check cache first
473473+ m.customerCacheMu.RLock()
474474+ if cached, ok := m.customerCache[userDID]; ok && time.Now().Before(cached.expiresAt) {
475475+ m.customerCacheMu.RUnlock()
476476+ return cached.customer, nil
477477+ }
478478+ m.customerCacheMu.RUnlock()
479479+480480+ // Search Stripe by metadata
481481+ params := &stripe.CustomerSearchParams{
482482+ SearchParams: stripe.SearchParams{
483483+ Query: fmt.Sprintf("metadata['user_did']:'%s'", userDID),
484484+ },
485485+ }
486486+ params.AddExpand("data.subscriptions")
487487+488488+ iter := customer.Search(params)
489489+ if iter.Next() {
490490+ cust := iter.Customer()
491491+ m.cacheCustomer(userDID, cust)
492492+ return cust, nil
493493+ }
494494+495495+ if err := iter.Err(); err != nil {
496496+ return nil, err
497497+ }
498498+499499+ return nil, nil // Not found
500500+}
501501+502502+// cacheCustomer adds a customer to the in-memory cache.
503503+func (m *Manager) cacheCustomer(userDID string, cust *stripe.Customer) {
504504+ m.customerCacheMu.Lock()
505505+ defer m.customerCacheMu.Unlock()
506506+507507+ m.customerCache[userDID] = &cachedCustomer{
508508+ customer: cust,
509509+ expiresAt: time.Now().Add(customerCacheTTL),
510510+ }
511511+}
512512+513513+// InvalidateCustomerCache removes a customer from the cache.
514514+func (m *Manager) InvalidateCustomerCache(userDID string) {
515515+ m.customerCacheMu.Lock()
516516+ defer m.customerCacheMu.Unlock()
517517+518518+ delete(m.customerCache, userDID)
519519+}
520520+521521+// getCustomerDID fetches a customer by ID and returns the user_did from metadata.
522522+func (m *Manager) getCustomerDID(customerID string) string {
523523+ if customerID == "" {
524524+ return ""
525525+ }
526526+527527+ cust, err := customer.Get(customerID, nil)
528528+ if err != nil {
529529+ slog.Debug("Failed to fetch customer", "customerId", customerID, "error", err)
530530+ return ""
531531+ }
532532+533533+ if cust.Metadata != nil {
534534+ return cust.Metadata["user_did"]
535535+ }
536536+ return ""
537537+}
538538+539539+// getSubscription fetches a subscription by ID.
540540+func (m *Manager) getSubscription(subscriptionID string) (*stripe.Subscription, error) {
541541+ if subscriptionID == "" {
542542+ return nil, nil
543543+ }
544544+545545+ params := &stripe.SubscriptionParams{}
546546+ params.AddExpand("items.data.price")
547547+548548+ return subscription.Get(subscriptionID, params)
549549+}
+60
pkg/hold/billing/billing_stub.go
···11+//go:build !billing
22+33+package billing
44+55+import (
66+ "net/http"
77+88+ "github.com/go-chi/chi/v5"
99+1010+ "atcr.io/pkg/hold/pds"
1111+ "atcr.io/pkg/hold/quota"
1212+)
1313+1414+// Manager is a no-op billing manager when billing is not compiled in.
1515+type Manager struct{}
1616+1717+// New creates a new no-op billing manager.
1818+// This is used when the billing build tag is not set.
1919+func New(_ *quota.Manager, _ string) *Manager {
2020+ return &Manager{}
2121+}
2222+2323+// Enabled returns false when billing is not compiled in.
2424+func (m *Manager) Enabled() bool {
2525+ return false
2626+}
2727+2828+// RegisterHandlers is a no-op when billing is not compiled in.
2929+func (m *Manager) RegisterHandlers(_ chi.Router) {}
3030+3131+// GetSubscriptionInfo returns an error when billing is not compiled in.
3232+func (m *Manager) GetSubscriptionInfo(_ string) (*SubscriptionInfo, error) {
3333+ return nil, ErrBillingDisabled
3434+}
3535+3636+// CreateCheckoutSession returns an error when billing is not compiled in.
3737+func (m *Manager) CreateCheckoutSession(_ *http.Request, _ *CheckoutSessionRequest) (*CheckoutSessionResponse, error) {
3838+ return nil, ErrBillingDisabled
3939+}
4040+4141+// GetBillingPortalURL returns an error when billing is not compiled in.
4242+func (m *Manager) GetBillingPortalURL(_ string, _ string) (*BillingPortalResponse, error) {
4343+ return nil, ErrBillingDisabled
4444+}
4545+4646+// HandleWebhook returns an error when billing is not compiled in.
4747+func (m *Manager) HandleWebhook(_ *http.Request) (*WebhookEvent, error) {
4848+ return nil, ErrBillingDisabled
4949+}
5050+5151+// XRPCHandler is a no-op handler when billing is not compiled in.
5252+type XRPCHandler struct{}
5353+5454+// NewXRPCHandler creates a new no-op XRPC handler.
5555+func NewXRPCHandler(_ *Manager, _ *pds.HoldPDS, _ *http.Client) *XRPCHandler {
5656+ return &XRPCHandler{}
5757+}
5858+5959+// RegisterHandlers is a no-op when billing is not compiled in.
6060+func (h *XRPCHandler) RegisterHandlers(_ chi.Router) {}
+303
pkg/hold/billing/config.go
···11+//go:build billing
22+33+package billing
44+55+import (
66+ "fmt"
77+ "os"
88+99+ "go.yaml.in/yaml/v4"
1010+1111+ "atcr.io/pkg/hold"
1212+)
1313+1414+// BillingConfig holds billing/Stripe settings parsed from the hold config YAML.
1515+// The billing fields live in the same YAML file as the hold config, but are
1616+// ignored by atcr.io's parser (Go YAML ignores unknown fields by default).
1717+type BillingConfig struct {
1818+ Enabled bool
1919+ Currency string
2020+ SuccessURL string
2121+ CancelURL string
2222+2323+ // Tier-level billing info keyed by tier name (same keys as quota tiers).
2424+ Tiers map[string]BillingTierConfig
2525+2626+ // Tier assigned to plankowner crew members.
2727+ PlankOwnerCrewTier string
2828+}
2929+3030+// BillingTierConfig holds Stripe pricing for a single tier.
3131+type BillingTierConfig struct {
3232+ Description string
3333+ StripePriceMonthly string
3434+ StripePriceYearly string
3535+}
3636+3737+// --- internal YAML structs for parsing the extended hold config ---
3838+3939+// extendedHoldConfig mirrors the hold config but only the quota section.
4040+type extendedHoldConfig struct {
4141+ Quota extendedQuotaConfig `yaml:"quota"`
4242+}
4343+4444+type extendedQuotaConfig struct {
4545+ Tiers map[string]extendedTierConfig `yaml:"tiers"`
4646+ Defaults extendedDefaults `yaml:"defaults"`
4747+ Billing rawBillingConfig `yaml:"billing"`
4848+}
4949+5050+type extendedTierConfig struct {
5151+ Description string `yaml:"description,omitempty"`
5252+ StripePriceMonthly string `yaml:"stripe_price_monthly,omitempty"`
5353+ StripePriceYearly string `yaml:"stripe_price_yearly,omitempty"`
5454+}
5555+5656+type extendedDefaults struct {
5757+ PlankOwnerCrewTier string `yaml:"plankowner_crew_tier,omitempty"`
5858+}
5959+6060+type rawBillingConfig struct {
6161+ Enabled bool `yaml:"enabled"`
6262+ Currency string `yaml:"currency,omitempty"`
6363+ SuccessURL string `yaml:"success_url,omitempty"`
6464+ CancelURL string `yaml:"cancel_url,omitempty"`
6565+}
6666+6767+// LoadBillingConfig reads the hold config YAML and extracts billing fields.
6868+// Returns (nil, nil) if the file is missing or billing is not enabled.
6969+// Returns (nil, err) if the file exists with billing enabled but is misconfigured.
7070+func LoadBillingConfig(configPath string) (*BillingConfig, error) {
7171+ if configPath == "" {
7272+ return nil, nil
7373+ }
7474+7575+ data, err := os.ReadFile(configPath)
7676+ if err != nil {
7777+ if os.IsNotExist(err) {
7878+ return nil, nil
7979+ }
8080+ return nil, fmt.Errorf("failed to read config: %w", err)
8181+ }
8282+8383+ return parseBillingConfig(data)
8484+}
8585+8686+// parseBillingConfig extracts billing fields from hold config YAML bytes.
8787+// Returns (nil, nil) if billing is not enabled.
8888+// Returns (nil, err) if billing is enabled but misconfigured.
8989+func parseBillingConfig(data []byte) (*BillingConfig, error) {
9090+ var ext extendedHoldConfig
9191+ if err := yaml.Unmarshal(data, &ext); err != nil {
9292+ return nil, fmt.Errorf("failed to parse config: %w", err)
9393+ }
9494+9595+ if !ext.Quota.Billing.Enabled {
9696+ return nil, nil
9797+ }
9898+9999+ cfg := &BillingConfig{
100100+ Enabled: true,
101101+ Currency: ext.Quota.Billing.Currency,
102102+ SuccessURL: ext.Quota.Billing.SuccessURL,
103103+ CancelURL: ext.Quota.Billing.CancelURL,
104104+ PlankOwnerCrewTier: ext.Quota.Defaults.PlankOwnerCrewTier,
105105+ Tiers: make(map[string]BillingTierConfig, len(ext.Quota.Tiers)),
106106+ }
107107+108108+ for name, tier := range ext.Quota.Tiers {
109109+ cfg.Tiers[name] = BillingTierConfig{
110110+ Description: tier.Description,
111111+ StripePriceMonthly: tier.StripePriceMonthly,
112112+ StripePriceYearly: tier.StripePriceYearly,
113113+ }
114114+ }
115115+116116+ // Validate: billing enabled but no tiers have any Stripe prices configured
117117+ hasAnyPrice := false
118118+ for _, tier := range cfg.Tiers {
119119+ if tier.StripePriceMonthly != "" || tier.StripePriceYearly != "" {
120120+ hasAnyPrice = true
121121+ break
122122+ }
123123+ }
124124+ if !hasAnyPrice {
125125+ return nil, fmt.Errorf("billing is enabled but no tiers have Stripe prices configured")
126126+ }
127127+128128+ return cfg, nil
129129+}
130130+131131+// GetTierPricing returns billing info for a tier, or nil if not found.
132132+func (c *BillingConfig) GetTierPricing(tierKey string) *BillingTierConfig {
133133+ if c == nil {
134134+ return nil
135135+ }
136136+ t, ok := c.Tiers[tierKey]
137137+ if !ok {
138138+ return nil
139139+ }
140140+ return &t
141141+}
142142+143143+// GetTierByPriceID finds the tier key that contains the given Stripe price ID.
144144+// Returns empty string if no match.
145145+func (c *BillingConfig) GetTierByPriceID(priceID string) string {
146146+ if c == nil || priceID == "" {
147147+ return ""
148148+ }
149149+ for key, tier := range c.Tiers {
150150+ if tier.StripePriceMonthly == priceID || tier.StripePriceYearly == priceID {
151151+ return key
152152+ }
153153+ }
154154+ return ""
155155+}
156156+157157+// ExampleHoldYAML generates a complete hold config example including billing fields.
158158+// It calls hold.ExampleYAML() for the base config, then injects billing-specific
159159+// fields into the YAML node tree before re-marshalling.
160160+func ExampleHoldYAML() ([]byte, error) {
161161+ base, err := hold.ExampleYAML()
162162+ if err != nil {
163163+ return nil, fmt.Errorf("failed to generate base hold config: %w", err)
164164+ }
165165+166166+ var doc yaml.Node
167167+ if err := yaml.Unmarshal(base, &doc); err != nil {
168168+ return nil, fmt.Errorf("failed to parse base hold config: %w", err)
169169+ }
170170+171171+ // doc is DocumentNode -> Content[0] is the root MappingNode
172172+ if doc.Kind != yaml.DocumentNode || len(doc.Content) == 0 {
173173+ return nil, fmt.Errorf("unexpected YAML structure")
174174+ }
175175+ root := doc.Content[0]
176176+177177+ // Find the "quota" mapping inside root
178178+ quotaNode := findMappingValue(root, "quota")
179179+ if quotaNode == nil {
180180+ return nil, fmt.Errorf("quota section not found in base config")
181181+ }
182182+183183+ // Inject billing fields into tier entries
184184+ tiersNode := findMappingValue(quotaNode, "tiers")
185185+ if tiersNode != nil {
186186+ injectTierBillingFields(tiersNode)
187187+ }
188188+189189+ // Inject plankowner_crew_tier into defaults
190190+ defaultsNode := findMappingValue(quotaNode, "defaults")
191191+ if defaultsNode != nil {
192192+ injectPlankOwnerDefault(defaultsNode)
193193+ }
194194+195195+ // Inject billing section under quota
196196+ injectBillingSection(quotaNode)
197197+198198+ return yaml.Marshal(&doc)
199199+}
200200+201201+// findMappingValue finds a value node in a YAML mapping by key.
202202+func findMappingValue(mapping *yaml.Node, key string) *yaml.Node {
203203+ if mapping.Kind != yaml.MappingNode {
204204+ return nil
205205+ }
206206+ for i := 0; i < len(mapping.Content)-1; i += 2 {
207207+ if mapping.Content[i].Value == key {
208208+ return mapping.Content[i+1]
209209+ }
210210+ }
211211+ return nil
212212+}
213213+214214+// injectTierBillingFields adds description and stripe_price fields to each tier entry.
215215+func injectTierBillingFields(tiersNode *yaml.Node) {
216216+ if tiersNode.Kind != yaml.MappingNode {
217217+ return
218218+ }
219219+220220+ examples := map[string]struct {
221221+ description string
222222+ monthly string
223223+ yearly string
224224+ }{
225225+ "bosun": {"Standard tier — recommended for most users.", "price_bosun_monthly_id", "price_bosun_yearly_id"},
226226+ "deckhand": {"Starter tier — free for new crew members.", "", ""},
227227+ "quartermaster": {"Professional tier — for power users and teams.", "price_qm_monthly_id", "price_qm_yearly_id"},
228228+ }
229229+230230+ for i := 0; i < len(tiersNode.Content)-1; i += 2 {
231231+ tierKey := tiersNode.Content[i].Value
232232+ tierVal := tiersNode.Content[i+1]
233233+ if tierVal.Kind != yaml.MappingNode {
234234+ continue
235235+ }
236236+237237+ ex, ok := examples[tierKey]
238238+ if !ok {
239239+ continue
240240+ }
241241+242242+ // Add description
243243+ tierVal.Content = append(tierVal.Content,
244244+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "description", HeadComment: "Human-readable tier description (used in billing UI)."},
245245+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: ex.description},
246246+ )
247247+248248+ // Add stripe prices if applicable
249249+ if ex.monthly != "" {
250250+ tierVal.Content = append(tierVal.Content,
251251+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "stripe_price_monthly", HeadComment: "Stripe Price ID for monthly billing."},
252252+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: ex.monthly},
253253+ )
254254+ }
255255+ if ex.yearly != "" {
256256+ tierVal.Content = append(tierVal.Content,
257257+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "stripe_price_yearly", HeadComment: "Stripe Price ID for yearly billing (optional)."},
258258+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: ex.yearly},
259259+ )
260260+ }
261261+ }
262262+}
263263+264264+// injectPlankOwnerDefault adds plankowner_crew_tier to the defaults section.
265265+func injectPlankOwnerDefault(defaultsNode *yaml.Node) {
266266+ if defaultsNode.Kind != yaml.MappingNode {
267267+ return
268268+ }
269269+ defaultsNode.Content = append(defaultsNode.Content,
270270+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "plankowner_crew_tier", HeadComment: "Tier granted to early crew members (plankowners). Ignored by base hold service."},
271271+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "bosun"},
272272+ )
273273+}
274274+275275+// injectBillingSection adds the billing subsection under quota.
276276+func injectBillingSection(quotaNode *yaml.Node) {
277277+ if quotaNode.Kind != yaml.MappingNode {
278278+ return
279279+ }
280280+281281+ billing := &yaml.Node{
282282+ Kind: yaml.MappingNode,
283283+ Tag: "!!map",
284284+ }
285285+ billing.Content = append(billing.Content,
286286+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "enabled"},
287287+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "false"},
288288+289289+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "currency", HeadComment: "ISO 4217 currency code for Stripe charges."},
290290+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "usd"},
291291+292292+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "success_url", HeadComment: "Redirect URL after successful checkout. {hold_url} is replaced at runtime."},
293293+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "{hold_url}/billing/success"},
294294+295295+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "cancel_url", HeadComment: "Redirect URL when checkout is cancelled."},
296296+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "{hold_url}/billing/cancel"},
297297+ )
298298+299299+ quotaNode.Content = append(quotaNode.Content,
300300+ &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "billing", HeadComment: "Stripe billing settings. Ignored by base hold service (seamark.dev only)."},
301301+ billing,
302302+ )
303303+}
+332
pkg/hold/billing/config_test.go
···11+package billing
22+33+import (
44+ "os"
55+ "path/filepath"
66+ "testing"
77+88+ "go.yaml.in/yaml/v4"
99+1010+ "atcr.io/pkg/hold/quota"
1111+)
1212+1313+// yamlUnmarshal is a thin wrapper to avoid shadowing the yaml package import.
1414+func yamlUnmarshal(data []byte, v any) error {
1515+ return yaml.Unmarshal(data, v)
1616+}
1717+1818+func TestParseBillingConfig_Disabled(t *testing.T) {
1919+ yaml := []byte(`
2020+quota:
2121+ tiers:
2222+ deckhand:
2323+ quota: 5GB
2424+ billing:
2525+ enabled: false
2626+`)
2727+ cfg, err := parseBillingConfig(yaml)
2828+ if err != nil {
2929+ t.Fatalf("unexpected error: %v", err)
3030+ }
3131+ if cfg != nil {
3232+ t.Error("expected nil config when billing disabled")
3333+ }
3434+}
3535+3636+func TestParseBillingConfig_NoBillingSection(t *testing.T) {
3737+ yaml := []byte(`
3838+quota:
3939+ tiers:
4040+ deckhand:
4141+ quota: 5GB
4242+`)
4343+ cfg, err := parseBillingConfig(yaml)
4444+ if err != nil {
4545+ t.Fatalf("unexpected error: %v", err)
4646+ }
4747+ if cfg != nil {
4848+ t.Error("expected nil config when no billing section")
4949+ }
5050+}
5151+5252+func TestParseBillingConfig_Enabled(t *testing.T) {
5353+ yaml := []byte(`
5454+quota:
5555+ tiers:
5656+ deckhand:
5757+ quota: 5GB
5858+ description: Starter tier
5959+ bosun:
6060+ quota: 50GB
6161+ description: Standard tier
6262+ stripe_price_monthly: price_bosun_monthly
6363+ stripe_price_yearly: price_bosun_yearly
6464+ defaults:
6565+ new_crew_tier: deckhand
6666+ plankowner_crew_tier: bosun
6767+ billing:
6868+ enabled: true
6969+ currency: usd
7070+ success_url: "{hold_url}/billing/success"
7171+ cancel_url: "{hold_url}/billing/cancel"
7272+`)
7373+ cfg, err := parseBillingConfig(yaml)
7474+ if err != nil {
7575+ t.Fatalf("unexpected error: %v", err)
7676+ }
7777+ if cfg == nil {
7878+ t.Fatal("expected non-nil config")
7979+ }
8080+8181+ if !cfg.Enabled {
8282+ t.Error("expected Enabled=true")
8383+ }
8484+ if cfg.Currency != "usd" {
8585+ t.Errorf("expected currency 'usd', got %q", cfg.Currency)
8686+ }
8787+ if cfg.PlankOwnerCrewTier != "bosun" {
8888+ t.Errorf("expected plankowner_crew_tier 'bosun', got %q", cfg.PlankOwnerCrewTier)
8989+ }
9090+ if cfg.SuccessURL != "{hold_url}/billing/success" {
9191+ t.Errorf("unexpected success_url: %q", cfg.SuccessURL)
9292+ }
9393+9494+ // Check tier pricing
9595+ bosun := cfg.GetTierPricing("bosun")
9696+ if bosun == nil {
9797+ t.Fatal("expected bosun tier pricing")
9898+ }
9999+ if bosun.StripePriceMonthly != "price_bosun_monthly" {
100100+ t.Errorf("expected bosun monthly price 'price_bosun_monthly', got %q", bosun.StripePriceMonthly)
101101+ }
102102+ if bosun.StripePriceYearly != "price_bosun_yearly" {
103103+ t.Errorf("expected bosun yearly price 'price_bosun_yearly', got %q", bosun.StripePriceYearly)
104104+ }
105105+ if bosun.Description != "Standard tier" {
106106+ t.Errorf("expected bosun description 'Standard tier', got %q", bosun.Description)
107107+ }
108108+109109+ // Deckhand has no prices
110110+ deckhand := cfg.GetTierPricing("deckhand")
111111+ if deckhand == nil {
112112+ t.Fatal("expected deckhand tier pricing entry")
113113+ }
114114+ if deckhand.StripePriceMonthly != "" {
115115+ t.Error("expected no monthly price for deckhand")
116116+ }
117117+}
118118+119119+func TestParseBillingConfig_EnabledButNoPrices(t *testing.T) {
120120+ yaml := []byte(`
121121+quota:
122122+ tiers:
123123+ deckhand:
124124+ quota: 5GB
125125+ billing:
126126+ enabled: true
127127+ currency: usd
128128+`)
129129+ cfg, err := parseBillingConfig(yaml)
130130+ if err == nil {
131131+ t.Error("expected error when billing enabled but no prices configured")
132132+ }
133133+ if cfg != nil {
134134+ t.Error("expected nil config on error")
135135+ }
136136+}
137137+138138+func TestGetTierByPriceID(t *testing.T) {
139139+ cfg := &BillingConfig{
140140+ Tiers: map[string]BillingTierConfig{
141141+ "deckhand": {},
142142+ "bosun": {StripePriceMonthly: "price_m", StripePriceYearly: "price_y"},
143143+ },
144144+ }
145145+146146+ if got := cfg.GetTierByPriceID("price_m"); got != "bosun" {
147147+ t.Errorf("expected 'bosun' for monthly price, got %q", got)
148148+ }
149149+ if got := cfg.GetTierByPriceID("price_y"); got != "bosun" {
150150+ t.Errorf("expected 'bosun' for yearly price, got %q", got)
151151+ }
152152+ if got := cfg.GetTierByPriceID("price_unknown"); got != "" {
153153+ t.Errorf("expected empty for unknown price, got %q", got)
154154+ }
155155+ if got := cfg.GetTierByPriceID(""); got != "" {
156156+ t.Errorf("expected empty for empty price, got %q", got)
157157+ }
158158+159159+ // nil receiver
160160+ var nilCfg *BillingConfig
161161+ if got := nilCfg.GetTierByPriceID("price_m"); got != "" {
162162+ t.Errorf("expected empty from nil config, got %q", got)
163163+ }
164164+}
165165+166166+func TestGetTierPricing_NilConfig(t *testing.T) {
167167+ var cfg *BillingConfig
168168+ if cfg.GetTierPricing("anything") != nil {
169169+ t.Error("expected nil from nil config")
170170+ }
171171+}
172172+173173+func TestLoadBillingConfig_MissingFile(t *testing.T) {
174174+ cfg, err := LoadBillingConfig("/nonexistent/config.yaml")
175175+ if err != nil {
176176+ t.Fatalf("expected no error for missing file, got: %v", err)
177177+ }
178178+ if cfg != nil {
179179+ t.Error("expected nil config for missing file")
180180+ }
181181+}
182182+183183+func TestLoadBillingConfig_EmptyPath(t *testing.T) {
184184+ cfg, err := LoadBillingConfig("")
185185+ if err != nil {
186186+ t.Fatalf("unexpected error: %v", err)
187187+ }
188188+ if cfg != nil {
189189+ t.Error("expected nil config for empty path")
190190+ }
191191+}
192192+193193+func TestLoadBillingConfig_FromFile(t *testing.T) {
194194+ dir := t.TempDir()
195195+ path := filepath.Join(dir, "config.yaml")
196196+197197+ content := `
198198+quota:
199199+ tiers:
200200+ bosun:
201201+ quota: 50GB
202202+ stripe_price_monthly: price_test
203203+ billing:
204204+ enabled: true
205205+ currency: usd
206206+`
207207+ if err := os.WriteFile(path, []byte(content), 0644); err != nil {
208208+ t.Fatal(err)
209209+ }
210210+211211+ cfg, err := LoadBillingConfig(path)
212212+ if err != nil {
213213+ t.Fatalf("unexpected error: %v", err)
214214+ }
215215+ if cfg == nil {
216216+ t.Fatal("expected non-nil config")
217217+ }
218218+ if cfg.GetTierByPriceID("price_test") != "bosun" {
219219+ t.Error("expected bosun tier for price_test")
220220+ }
221221+}
222222+223223+// holdQuotaWrapper mirrors the hold config structure just enough to extract
224224+// the quota section for testing. This avoids importing the full hold package.
225225+type holdQuotaWrapper struct {
226226+ Quota quota.Config `yaml:"quota"`
227227+}
228228+229229+// TestExampleHoldYAMLRoundTrip verifies that the generated example config
230230+// can be parsed by both atcr.io's quota parser and seamark.dev's billing parser.
231231+// This catches silent breakage if atcr.io renames or restructures the quota section.
232232+func TestExampleHoldYAMLRoundTrip(t *testing.T) {
233233+ yamlBytes, err := ExampleHoldYAML()
234234+ if err != nil {
235235+ t.Fatalf("ExampleHoldYAML failed: %v", err)
236236+ }
237237+238238+ // Verify atcr.io's quota parser can read the quota section.
239239+ // The full hold config nests tiers under "quota:", so we parse with
240240+ // a wrapper struct (same as hold.Config does) then use NewManagerFromConfig.
241241+ var wrapper holdQuotaWrapper
242242+ if err := yamlUnmarshal(yamlBytes, &wrapper); err != nil {
243243+ t.Fatalf("failed to parse generated config for quota: %v", err)
244244+ }
245245+246246+ quotaMgr, err := quota.NewManagerFromConfig(&wrapper.Quota)
247247+ if err != nil {
248248+ t.Fatalf("quota.NewManagerFromConfig failed: %v", err)
249249+ }
250250+ if !quotaMgr.IsEnabled() {
251251+ t.Error("expected quotas to be enabled in generated config")
252252+ }
253253+ if quotaMgr.TierCount() != 3 {
254254+ t.Errorf("expected 3 quota tiers, got %d", quotaMgr.TierCount())
255255+ }
256256+ if quotaMgr.GetDefaultTier() != "deckhand" {
257257+ t.Errorf("expected default tier 'deckhand', got %q", quotaMgr.GetDefaultTier())
258258+ }
259259+260260+ // The generated example has billing.enabled: false, so parseBillingConfig
261261+ // returns nil. Enable it to verify the billing fields were injected correctly.
262262+ // Use the full "billing:\n...enabled:" pattern to avoid replacing admin.enabled.
263263+ enabledYAML := replaceOnce(string(yamlBytes), "billing:\n enabled: false", "billing:\n enabled: true")
264264+265265+ billingCfg, err := parseBillingConfig([]byte(enabledYAML))
266266+ if err != nil {
267267+ t.Fatalf("parseBillingConfig failed on generated config: %v", err)
268268+ }
269269+ if billingCfg == nil {
270270+ t.Fatal("expected non-nil billing config after enabling")
271271+ }
272272+273273+ // Verify billing fields were injected into the YAML
274274+ if billingCfg.Currency != "usd" {
275275+ t.Errorf("expected currency 'usd', got %q", billingCfg.Currency)
276276+ }
277277+ if billingCfg.PlankOwnerCrewTier != "bosun" {
278278+ t.Errorf("expected plankowner_crew_tier 'bosun', got %q", billingCfg.PlankOwnerCrewTier)
279279+ }
280280+281281+ // Verify tier-level billing fields
282282+ bosun := billingCfg.GetTierPricing("bosun")
283283+ if bosun == nil {
284284+ t.Fatal("expected bosun billing tier")
285285+ }
286286+ if bosun.StripePriceMonthly == "" {
287287+ t.Error("expected bosun to have stripe_price_monthly")
288288+ }
289289+ if bosun.Description == "" {
290290+ t.Error("expected bosun to have description")
291291+ }
292292+293293+ qm := billingCfg.GetTierPricing("quartermaster")
294294+ if qm == nil {
295295+ t.Fatal("expected quartermaster billing tier")
296296+ }
297297+ if qm.StripePriceMonthly == "" {
298298+ t.Error("expected quartermaster to have stripe_price_monthly")
299299+ }
300300+301301+ // Deckhand is the free tier — no Stripe prices expected
302302+ deckhand := billingCfg.GetTierPricing("deckhand")
303303+ if deckhand == nil {
304304+ t.Fatal("expected deckhand billing tier entry")
305305+ }
306306+ if deckhand.StripePriceMonthly != "" {
307307+ t.Error("expected no stripe_price_monthly for deckhand")
308308+ }
309309+310310+ // Verify the price ID reverse lookup works
311311+ if billingCfg.GetTierByPriceID(bosun.StripePriceMonthly) != "bosun" {
312312+ t.Error("GetTierByPriceID failed for bosun monthly price")
313313+ }
314314+}
315315+316316+// replaceOnce replaces the first occurrence of old with new in s.
317317+func replaceOnce(s, old, new string) string {
318318+ i := indexOf(s, old)
319319+ if i < 0 {
320320+ return s
321321+ }
322322+ return s[:i] + new + s[i+len(old):]
323323+}
324324+325325+func indexOf(s, substr string) int {
326326+ for i := 0; i <= len(s)-len(substr); i++ {
327327+ if s[i:i+len(substr)] == substr {
328328+ return i
329329+ }
330330+ }
331331+ return -1
332332+}
+222
pkg/hold/billing/handlers.go
···11+//go:build billing
22+33+package billing
44+55+import (
66+ "encoding/json"
77+ "log/slog"
88+ "net/http"
99+1010+ "github.com/go-chi/chi/v5"
1111+1212+ "atcr.io/pkg/hold/pds"
1313+)
1414+1515+// XRPCHandler handles billing-related XRPC endpoints.
1616+type XRPCHandler struct {
1717+ manager *Manager
1818+ pdsServer *pds.HoldPDS
1919+ httpClient *http.Client
2020+}
2121+2222+// NewXRPCHandler creates a new billing XRPC handler.
2323+func NewXRPCHandler(manager *Manager, pdsServer *pds.HoldPDS, httpClient *http.Client) *XRPCHandler {
2424+ return &XRPCHandler{
2525+ manager: manager,
2626+ pdsServer: pdsServer,
2727+ httpClient: httpClient,
2828+ }
2929+}
3030+3131+// RegisterHandlers registers billing XRPC endpoints on the router.
3232+func (m *Manager) RegisterHandlers(r chi.Router) {
3333+ // This is a no-op for the Manager itself
3434+ // Use NewXRPCHandler and call its RegisterHandlers method
3535+}
3636+3737+// RegisterHandlers registers billing endpoints on the router.
3838+func (h *XRPCHandler) RegisterHandlers(r chi.Router) {
3939+ if !h.manager.Enabled() {
4040+ slog.Info("Billing endpoints disabled (not configured)")
4141+ return
4242+ }
4343+4444+ slog.Info("Registering billing XRPC endpoints")
4545+4646+ // Public endpoint - get subscription info (auth optional for tiers list)
4747+ r.Get("/xrpc/io.atcr.hold.getSubscriptionInfo", h.HandleGetSubscriptionInfo)
4848+4949+ // Authenticated endpoints
5050+ r.Group(func(r chi.Router) {
5151+ r.Use(h.requireAuth)
5252+ r.Post("/xrpc/io.atcr.hold.createCheckoutSession", h.HandleCreateCheckoutSession)
5353+ r.Get("/xrpc/io.atcr.hold.getBillingPortalUrl", h.HandleGetBillingPortalURL)
5454+ })
5555+5656+ // Stripe webhook (authenticated by Stripe signature)
5757+ r.Post("/xrpc/io.atcr.hold.stripeWebhook", h.HandleStripeWebhook)
5858+}
5959+6060+// requireAuth is middleware that validates user authentication.
6161+func (h *XRPCHandler) requireAuth(next http.Handler) http.Handler {
6262+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
6363+ // Use the same auth validation as other hold endpoints
6464+ user, err := pds.ValidateDPoPRequest(r, h.httpClient)
6565+ if err != nil {
6666+ // Try service token
6767+ user, err = pds.ValidateServiceToken(r, h.pdsServer.DID(), h.httpClient)
6868+ }
6969+ if err != nil {
7070+ respondError(w, http.StatusUnauthorized, "authentication required")
7171+ return
7272+ }
7373+7474+ // Store user DID in header for handlers
7575+ r.Header.Set("X-User-DID", user.DID)
7676+ next.ServeHTTP(w, r)
7777+ })
7878+}
7979+8080+// HandleGetSubscriptionInfo returns subscription and quota information.
8181+// GET /xrpc/io.atcr.hold.getSubscriptionInfo?userDid=did:plc:xxx
8282+func (h *XRPCHandler) HandleGetSubscriptionInfo(w http.ResponseWriter, r *http.Request) {
8383+ userDID := r.URL.Query().Get("userDid")
8484+8585+ // If no userDID provided, try to get from auth
8686+ if userDID == "" {
8787+ // Try to authenticate (optional)
8888+ user, err := pds.ValidateDPoPRequest(r, h.httpClient)
8989+ if err != nil {
9090+ user, _ = pds.ValidateServiceToken(r, h.pdsServer.DID(), h.httpClient)
9191+ }
9292+ if user != nil {
9393+ userDID = user.DID
9494+ }
9595+ }
9696+9797+ info, err := h.manager.GetSubscriptionInfo(userDID)
9898+ if err != nil {
9999+ if err == ErrBillingDisabled {
100100+ // Return basic info with payments disabled
101101+ respondJSON(w, http.StatusOK, &SubscriptionInfo{
102102+ UserDID: userDID,
103103+ PaymentsEnabled: false,
104104+ Tiers: h.manager.buildTierList(userDID),
105105+ })
106106+ return
107107+ }
108108+ respondError(w, http.StatusInternalServerError, err.Error())
109109+ return
110110+ }
111111+112112+ // Get current usage and crew tier from PDS quota stats
113113+ if userDID != "" {
114114+ stats, err := h.pdsServer.GetQuotaForUserWithTier(r.Context(), userDID, h.manager.quotaMgr)
115115+ if err == nil {
116116+ info.CurrentUsage = stats.TotalSize
117117+ info.CrewTier = stats.Tier // tier from local crew record (what's actually enforced)
118118+ info.CurrentLimit = stats.Limit
119119+120120+ // If no subscription but crew has a tier, show that as current
121121+ if info.SubscriptionID == "" && info.CrewTier != "" {
122122+ info.CurrentTier = info.CrewTier
123123+ }
124124+ }
125125+ }
126126+127127+ // Mark which tier is actually current (use crew tier if available, otherwise subscription tier)
128128+ effectiveTier := info.CurrentTier
129129+ if info.CrewTier != "" {
130130+ effectiveTier = info.CrewTier
131131+ }
132132+ for i := range info.Tiers {
133133+ info.Tiers[i].IsCurrent = info.Tiers[i].ID == effectiveTier
134134+ }
135135+136136+ respondJSON(w, http.StatusOK, info)
137137+}
138138+139139+// HandleCreateCheckoutSession creates a Stripe checkout session.
140140+// POST /xrpc/io.atcr.hold.createCheckoutSession
141141+func (h *XRPCHandler) HandleCreateCheckoutSession(w http.ResponseWriter, r *http.Request) {
142142+ var req CheckoutSessionRequest
143143+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
144144+ respondError(w, http.StatusBadRequest, "invalid request body")
145145+ return
146146+ }
147147+148148+ if req.Tier == "" {
149149+ respondError(w, http.StatusBadRequest, "tier is required")
150150+ return
151151+ }
152152+153153+ resp, err := h.manager.CreateCheckoutSession(r, &req)
154154+ if err != nil {
155155+ slog.Error("Failed to create checkout session", "error", err)
156156+ respondError(w, http.StatusInternalServerError, err.Error())
157157+ return
158158+ }
159159+160160+ respondJSON(w, http.StatusOK, resp)
161161+}
162162+163163+// HandleGetBillingPortalURL returns a URL to the Stripe billing portal.
164164+// GET /xrpc/io.atcr.hold.getBillingPortalUrl?returnUrl=https://...
165165+func (h *XRPCHandler) HandleGetBillingPortalURL(w http.ResponseWriter, r *http.Request) {
166166+ userDID := r.Header.Get("X-User-DID")
167167+ returnURL := r.URL.Query().Get("returnUrl")
168168+169169+ resp, err := h.manager.GetBillingPortalURL(userDID, returnURL)
170170+ if err != nil {
171171+ slog.Error("Failed to get billing portal URL", "error", err, "userDid", userDID)
172172+ respondError(w, http.StatusInternalServerError, err.Error())
173173+ return
174174+ }
175175+176176+ respondJSON(w, http.StatusOK, resp)
177177+}
178178+179179+// HandleStripeWebhook processes Stripe webhook events.
180180+// POST /xrpc/io.atcr.hold.stripeWebhook
181181+func (h *XRPCHandler) HandleStripeWebhook(w http.ResponseWriter, r *http.Request) {
182182+ event, err := h.manager.HandleWebhook(r)
183183+ if err != nil {
184184+ slog.Error("Failed to process webhook", "error", err)
185185+ respondError(w, http.StatusBadRequest, err.Error())
186186+ return
187187+ }
188188+189189+ // If we have a tier update, apply it to the crew record
190190+ if event.UserDID != "" && event.NewTier != "" {
191191+ if err := h.pdsServer.UpdateCrewMemberTier(r.Context(), event.UserDID, event.NewTier); err != nil {
192192+ slog.Error("Failed to update crew tier", "error", err, "userDid", event.UserDID, "tier", event.NewTier)
193193+ // Don't fail the webhook - Stripe will retry
194194+ } else {
195195+ slog.Info("Updated crew tier from subscription",
196196+ "userDid", event.UserDID,
197197+ "tier", event.NewTier,
198198+ "event", event.Type,
199199+ )
200200+ }
201201+202202+ // Invalidate customer cache since subscription changed
203203+ h.manager.InvalidateCustomerCache(event.UserDID)
204204+ }
205205+206206+ // Return 200 to acknowledge receipt
207207+ respondJSON(w, http.StatusOK, map[string]string{"received": "true"})
208208+}
209209+210210+// respondJSON writes a JSON response.
211211+func respondJSON(w http.ResponseWriter, status int, v any) {
212212+ w.Header().Set("Content-Type", "application/json")
213213+ w.WriteHeader(status)
214214+ if err := json.NewEncoder(w).Encode(v); err != nil {
215215+ slog.Error("Failed to encode JSON response", "error", err)
216216+ }
217217+}
218218+219219+// respondError writes a JSON error response.
220220+func respondError(w http.ResponseWriter, status int, message string) {
221221+ respondJSON(w, status, map[string]string{"error": message})
222222+}
+65
pkg/hold/billing/types.go
···11+// Package billing provides optional Stripe billing integration for hold services.
22+// This package uses build tags to conditionally compile Stripe support.
33+// Build with -tags billing to enable Stripe integration.
44+package billing
55+66+import "errors"
77+88+// ErrBillingDisabled is returned when billing operations are attempted
99+// but billing is not enabled (either not compiled in or disabled at runtime).
1010+var ErrBillingDisabled = errors.New("billing not enabled")
1111+1212+// SubscriptionInfo contains subscription and quota information for a user.
1313+type SubscriptionInfo struct {
1414+ UserDID string `json:"userDid"`
1515+ CurrentTier string `json:"currentTier"` // tier from Stripe subscription (or default)
1616+ CrewTier string `json:"crewTier,omitempty"` // tier from local crew record (what's actually enforced)
1717+ CurrentUsage int64 `json:"currentUsage"` // bytes used
1818+ CurrentLimit *int64 `json:"currentLimit,omitempty"` // nil = unlimited
1919+ PaymentsEnabled bool `json:"paymentsEnabled"` // whether online payments are available
2020+ Tiers []TierInfo `json:"tiers"` // available tiers
2121+ SubscriptionID string `json:"subscriptionId,omitempty"` // Stripe subscription ID if active
2222+ CustomerID string `json:"customerId,omitempty"` // Stripe customer ID if exists
2323+ BillingInterval string `json:"billingInterval,omitempty"` // "monthly" or "yearly"
2424+}
2525+2626+// TierInfo describes a single tier available for subscription.
2727+type TierInfo struct {
2828+ ID string `json:"id"` // tier key (e.g., "deckhand", "bosun")
2929+ Name string `json:"name"` // display name (same as ID if not specified)
3030+ Description string `json:"description,omitempty"` // human-readable description
3131+ QuotaBytes int64 `json:"quotaBytes"` // quota limit in bytes
3232+ QuotaFormatted string `json:"quotaFormatted"` // human-readable quota (e.g., "5 GB")
3333+ PriceCentsMonthly int `json:"priceCentsMonthly,omitempty"` // monthly price in cents (0 = free)
3434+ PriceCentsYearly int `json:"priceCentsYearly,omitempty"` // yearly price in cents (0 = not available)
3535+ IsCurrent bool `json:"isCurrent,omitempty"` // whether this is user's current tier
3636+}
3737+3838+// CheckoutSessionRequest is the request to create a Stripe checkout session.
3939+type CheckoutSessionRequest struct {
4040+ Tier string `json:"tier"` // tier to subscribe to
4141+ Interval string `json:"interval,omitempty"` // "monthly" or "yearly" (default: monthly)
4242+ ReturnURL string `json:"returnUrl,omitempty"` // URL to return to after checkout
4343+}
4444+4545+// CheckoutSessionResponse is the response with the Stripe checkout URL.
4646+type CheckoutSessionResponse struct {
4747+ CheckoutURL string `json:"checkoutUrl"`
4848+ SessionID string `json:"sessionId"`
4949+}
5050+5151+// BillingPortalResponse is the response with the Stripe billing portal URL.
5252+type BillingPortalResponse struct {
5353+ PortalURL string `json:"portalUrl"`
5454+}
5555+5656+// WebhookEvent represents a processed Stripe webhook event.
5757+type WebhookEvent struct {
5858+ Type string `json:"type"` // e.g., "customer.subscription.updated"
5959+ CustomerID string `json:"customerId"` // Stripe customer ID
6060+ UserDID string `json:"userDid"` // user's DID from customer metadata
6161+ SubscriptionID string `json:"subscriptionId,omitempty"` // Stripe subscription ID
6262+ PriceID string `json:"priceId,omitempty"` // Stripe price ID
6363+ NewTier string `json:"newTier,omitempty"` // resolved tier name
6464+ Status string `json:"status,omitempty"` // subscription status
6565+}