feat(billing): implement epic 12 billing and subscription
This commit is contained in:
19
.env.example
19
.env.example
@@ -29,6 +29,25 @@ CORS_ALLOWED_ORIGINS=http://localhost:5173
|
||||
# Rate limiting — requests per minute per IP
|
||||
RATE_LIMIT_RPM=100
|
||||
|
||||
# Stripe Integration (data sync OAuth/webhooks)
|
||||
STRIPE_CLIENT_ID=
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
STRIPE_OAUTH_REDIRECT_URL=http://localhost:8080/api/v1/integrations/stripe/callback
|
||||
STRIPE_ENCRYPTION_KEY=
|
||||
STRIPE_SYNC_INTERVAL_MIN=15
|
||||
STRIPE_PAYMENT_SYNC_DAYS=90
|
||||
|
||||
# Stripe Billing (PulseScore subscriptions)
|
||||
STRIPE_BILLING_SECRET_KEY=
|
||||
STRIPE_BILLING_PUBLISHABLE_KEY=
|
||||
STRIPE_BILLING_WEBHOOK_SECRET=
|
||||
STRIPE_BILLING_PORTAL_RETURN_URL=http://localhost:5173/settings/billing
|
||||
STRIPE_BILLING_PRICE_GROWTH_MONTHLY=
|
||||
STRIPE_BILLING_PRICE_GROWTH_ANNUAL=
|
||||
STRIPE_BILLING_PRICE_SCALE_MONTHLY=
|
||||
STRIPE_BILLING_PRICE_SCALE_ANNUAL=
|
||||
|
||||
# HubSpot Integration
|
||||
HUBSPOT_CLIENT_ID=
|
||||
HUBSPOT_CLIENT_SECRET=
|
||||
|
||||
@@ -16,3 +16,13 @@ CORS_ALLOWED_ORIGINS=https://yourdomain.com
|
||||
|
||||
# Rate limiting — requests per minute per IP
|
||||
RATE_LIMIT_RPM=100
|
||||
|
||||
# Stripe Billing (PulseScore subscriptions)
|
||||
STRIPE_BILLING_SECRET_KEY=
|
||||
STRIPE_BILLING_PUBLISHABLE_KEY=
|
||||
STRIPE_BILLING_WEBHOOK_SECRET=
|
||||
STRIPE_BILLING_PORTAL_RETURN_URL=https://yourdomain.com/settings/billing
|
||||
STRIPE_BILLING_PRICE_GROWTH_MONTHLY=
|
||||
STRIPE_BILLING_PRICE_GROWTH_ANNUAL=
|
||||
STRIPE_BILLING_PRICE_SCALE_MONTHLY=
|
||||
STRIPE_BILLING_PRICE_SCALE_ANNUAL=
|
||||
|
||||
27
README.md
27
README.md
@@ -96,3 +96,30 @@ The frontend starts on http://localhost:5173.
|
||||
| `npm run lint` | ESLint check |
|
||||
| `npm run format` | Format with Prettier |
|
||||
| `npm run preview` | Preview production build |
|
||||
|
||||
## Billing & Subscription (Epic 12)
|
||||
|
||||
PulseScore now includes a dedicated Stripe billing domain (separate from Stripe customer-data integration):
|
||||
|
||||
- Plan catalog: `free`, `growth`, `scale` (`internal/billing/plans.go`)
|
||||
- Protected billing APIs:
|
||||
- `GET /api/v1/billing/subscription`
|
||||
- `POST /api/v1/billing/checkout` (admin)
|
||||
- `POST /api/v1/billing/portal-session` (admin)
|
||||
- `POST /api/v1/billing/cancel` (admin)
|
||||
- Public billing webhook:
|
||||
- `POST /api/v1/webhooks/stripe-billing`
|
||||
|
||||
### Required production billing env vars
|
||||
|
||||
- `STRIPE_BILLING_SECRET_KEY`
|
||||
- `STRIPE_BILLING_PUBLISHABLE_KEY`
|
||||
- `STRIPE_BILLING_WEBHOOK_SECRET`
|
||||
- `STRIPE_BILLING_PRICE_GROWTH_MONTHLY`
|
||||
- `STRIPE_BILLING_PRICE_GROWTH_ANNUAL`
|
||||
- `STRIPE_BILLING_PRICE_SCALE_MONTHLY`
|
||||
- `STRIPE_BILLING_PRICE_SCALE_ANNUAL`
|
||||
|
||||
Optional:
|
||||
|
||||
- `STRIPE_BILLING_PORTAL_RETURN_URL` (defaults to `http://localhost:5173/settings/billing`)
|
||||
|
||||
@@ -17,12 +17,14 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/onnwee/pulse-score/internal/auth"
|
||||
billingcatalog "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/config"
|
||||
"github.com/onnwee/pulse-score/internal/database"
|
||||
"github.com/onnwee/pulse-score/internal/handler"
|
||||
"github.com/onnwee/pulse-score/internal/middleware"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
"github.com/onnwee/pulse-score/internal/service"
|
||||
billingsvc "github.com/onnwee/pulse-score/internal/service/billing"
|
||||
"github.com/onnwee/pulse-score/internal/service/scoring"
|
||||
)
|
||||
|
||||
@@ -100,9 +102,11 @@ func main() {
|
||||
if pool != nil {
|
||||
userRepo := repository.NewUserRepository(pool.P)
|
||||
orgRepo := repository.NewOrganizationRepository(pool.P)
|
||||
orgSubRepo := repository.NewOrgSubscriptionRepository(pool.P)
|
||||
refreshTokenRepo := repository.NewRefreshTokenRepository(pool.P)
|
||||
invitationRepo := repository.NewInvitationRepository(pool.P)
|
||||
passwordResetRepo := repository.NewPasswordResetRepository(pool.P)
|
||||
billingWebhookEventRepo := repository.NewBillingWebhookEventRepository(pool.P)
|
||||
|
||||
emailSvc := service.NewSendGridEmailService(service.SendGridConfig{
|
||||
APIKey: cfg.SendGrid.APIKey,
|
||||
@@ -136,6 +140,20 @@ func main() {
|
||||
onboardingEventRepo := repository.NewOnboardingEventRepository(pool.P)
|
||||
|
||||
// Stripe services
|
||||
planCatalog := billingcatalog.NewCatalog(billingcatalog.PriceConfig{
|
||||
GrowthMonthly: cfg.BillingStripe.PriceGrowthMonthly,
|
||||
GrowthAnnual: cfg.BillingStripe.PriceGrowthAnnual,
|
||||
ScaleMonthly: cfg.BillingStripe.PriceScaleMonthly,
|
||||
ScaleAnnual: cfg.BillingStripe.PriceScaleAnnual,
|
||||
})
|
||||
|
||||
if cfg.IsProd() {
|
||||
if err := billingcatalog.VerifyConfiguredPrices(context.Background(), cfg.BillingStripe.SecretKey, planCatalog); err != nil {
|
||||
slog.Error("invalid Stripe billing price configuration", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
stripeOAuthSvc := service.NewStripeOAuthService(service.StripeOAuthConfig{
|
||||
ClientID: cfg.Stripe.ClientID,
|
||||
SecretKey: cfg.Stripe.SecretKey,
|
||||
@@ -200,6 +218,46 @@ func main() {
|
||||
mrrSvc, paymentHealthSvc,
|
||||
)
|
||||
|
||||
billingSubscriptionSvc := billingsvc.NewSubscriptionService(
|
||||
orgSubRepo,
|
||||
orgRepo,
|
||||
customerRepo,
|
||||
connRepo,
|
||||
planCatalog,
|
||||
)
|
||||
|
||||
billingLimitsSvc := billingsvc.NewLimitsService(
|
||||
billingSubscriptionSvc,
|
||||
customerRepo,
|
||||
connRepo,
|
||||
connRepo,
|
||||
planCatalog,
|
||||
)
|
||||
|
||||
billingCheckoutSvc := billingsvc.NewCheckoutService(
|
||||
cfg.BillingStripe.SecretKey,
|
||||
cfg.SendGrid.FrontendURL,
|
||||
orgRepo,
|
||||
planCatalog,
|
||||
)
|
||||
|
||||
billingPortalSvc := billingsvc.NewPortalService(
|
||||
cfg.BillingStripe.SecretKey,
|
||||
cfg.BillingStripe.PortalReturnURL,
|
||||
cfg.SendGrid.FrontendURL,
|
||||
orgRepo,
|
||||
orgSubRepo,
|
||||
)
|
||||
|
||||
billingWebhookSvc := billingsvc.NewWebhookService(
|
||||
cfg.BillingStripe.WebhookSecret,
|
||||
pool.P,
|
||||
orgRepo,
|
||||
orgSubRepo,
|
||||
billingWebhookEventRepo,
|
||||
planCatalog,
|
||||
)
|
||||
|
||||
hubspotWebhookSvc := service.NewHubSpotWebhookService(
|
||||
cfg.HubSpot.WebhookSecret,
|
||||
hubspotSyncSvc,
|
||||
@@ -343,6 +401,10 @@ func main() {
|
||||
webhookHandler := handler.NewWebhookStripeHandler(stripeWebhookSvc)
|
||||
r.Post("/webhooks/stripe", webhookHandler.HandleWebhook)
|
||||
|
||||
// Stripe billing webhook (public — verified by signature)
|
||||
billingWebhookHandler := handler.NewWebhookStripeBillingHandler(billingWebhookSvc)
|
||||
r.Post("/webhooks/stripe-billing", billingWebhookHandler.HandleWebhook)
|
||||
|
||||
// SendGrid webhook (public — for delivery tracking)
|
||||
sendgridWebhookHandler := handler.NewWebhookSendGridHandler(alertHistoryRepo)
|
||||
r.Post("/webhooks/sendgrid", sendgridWebhookHandler.HandleWebhook)
|
||||
@@ -367,6 +429,21 @@ func main() {
|
||||
r.Get("/organizations/current", orgHandler.GetCurrent)
|
||||
r.Patch("/organizations/current", orgHandler.UpdateCurrent)
|
||||
|
||||
billingHandler := handler.NewBillingHandler(
|
||||
billingCheckoutSvc,
|
||||
billingPortalSvc,
|
||||
billingSubscriptionSvc,
|
||||
)
|
||||
r.Route("/billing", func(r chi.Router) {
|
||||
r.Get("/subscription", billingHandler.GetSubscription)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.RequireRole("admin"))
|
||||
r.Post("/checkout", billingHandler.CreateCheckout)
|
||||
r.Post("/portal-session", billingHandler.CreatePortalSession)
|
||||
r.Post("/cancel", billingHandler.CancelSubscription)
|
||||
})
|
||||
})
|
||||
|
||||
// User profile routes
|
||||
userSvc := service.NewUserService(userRepo, orgRepo)
|
||||
userHandler := handler.NewUserHandler(userSvc)
|
||||
@@ -451,7 +528,7 @@ func main() {
|
||||
stripeHandler := handler.NewIntegrationStripeHandler(stripeOAuthSvc, syncOrchestrator)
|
||||
r.Route("/integrations/stripe", func(r chi.Router) {
|
||||
r.Use(middleware.RequireRole("admin"))
|
||||
r.Get("/connect", stripeHandler.Connect)
|
||||
r.With(middleware.RequireIntegrationLimit(billingLimitsSvc, "stripe")).Get("/connect", stripeHandler.Connect)
|
||||
r.Get("/callback", stripeHandler.Callback)
|
||||
r.Get("/status", stripeHandler.Status)
|
||||
r.Delete("/", stripeHandler.Disconnect)
|
||||
@@ -462,7 +539,7 @@ func main() {
|
||||
hubspotHandler := handler.NewIntegrationHubSpotHandler(hubspotOAuthSvc, hubspotSyncOrchestrator)
|
||||
r.Route("/integrations/hubspot", func(r chi.Router) {
|
||||
r.Use(middleware.RequireRole("admin"))
|
||||
r.Get("/connect", hubspotHandler.Connect)
|
||||
r.With(middleware.RequireIntegrationLimit(billingLimitsSvc, "hubspot")).Get("/connect", hubspotHandler.Connect)
|
||||
r.Get("/callback", hubspotHandler.Callback)
|
||||
r.Get("/status", hubspotHandler.Status)
|
||||
r.Delete("/", hubspotHandler.Disconnect)
|
||||
@@ -473,7 +550,7 @@ func main() {
|
||||
intercomHandler := handler.NewIntegrationIntercomHandler(intercomOAuthSvc, intercomSyncOrchestrator)
|
||||
r.Route("/integrations/intercom", func(r chi.Router) {
|
||||
r.Use(middleware.RequireRole("admin"))
|
||||
r.Get("/connect", intercomHandler.Connect)
|
||||
r.With(middleware.RequireIntegrationLimit(billingLimitsSvc, "intercom")).Get("/connect", intercomHandler.Connect)
|
||||
r.Get("/callback", intercomHandler.Callback)
|
||||
r.Get("/status", intercomHandler.Status)
|
||||
r.Delete("/", intercomHandler.Disconnect)
|
||||
|
||||
200
internal/billing/plans.go
Normal file
200
internal/billing/plans.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const Unlimited = -1
|
||||
|
||||
type Tier string
|
||||
|
||||
const (
|
||||
TierFree Tier = "free"
|
||||
TierGrowth Tier = "growth"
|
||||
TierScale Tier = "scale"
|
||||
)
|
||||
|
||||
type BillingCycle string
|
||||
|
||||
const (
|
||||
BillingCycleMonthly BillingCycle = "monthly"
|
||||
BillingCycleAnnual BillingCycle = "annual"
|
||||
)
|
||||
|
||||
type UsageLimits struct {
|
||||
CustomerLimit int `json:"customer_limit"`
|
||||
IntegrationLimit int `json:"integration_limit"`
|
||||
}
|
||||
|
||||
type FeatureFlags struct {
|
||||
Playbooks bool `json:"playbooks"`
|
||||
AIInsights bool `json:"ai_insights"`
|
||||
}
|
||||
|
||||
type Plan struct {
|
||||
Tier Tier `json:"tier"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
MonthlyPriceCents int `json:"monthly_price_cents"`
|
||||
AnnualPriceCents int `json:"annual_price_cents"`
|
||||
StripeMonthlyPriceID string `json:"stripe_monthly_price_id"`
|
||||
StripeAnnualPriceID string `json:"stripe_annual_price_id"`
|
||||
Limits UsageLimits `json:"limits"`
|
||||
Features FeatureFlags `json:"features"`
|
||||
}
|
||||
|
||||
type PriceConfig struct {
|
||||
GrowthMonthly string
|
||||
GrowthAnnual string
|
||||
ScaleMonthly string
|
||||
ScaleAnnual string
|
||||
}
|
||||
|
||||
type Catalog struct {
|
||||
plans map[Tier]Plan
|
||||
priceIndex map[string]priceRef
|
||||
}
|
||||
|
||||
type priceRef struct {
|
||||
Tier Tier
|
||||
Cycle BillingCycle
|
||||
}
|
||||
|
||||
func NewCatalog(cfg PriceConfig) *Catalog {
|
||||
plans := map[Tier]Plan{
|
||||
TierFree: {
|
||||
Tier: TierFree,
|
||||
Name: "Free",
|
||||
Description: "Best for evaluating PulseScore with a small portfolio.",
|
||||
MonthlyPriceCents: 0,
|
||||
AnnualPriceCents: 0,
|
||||
Limits: UsageLimits{
|
||||
CustomerLimit: 10,
|
||||
IntegrationLimit: 1,
|
||||
},
|
||||
Features: FeatureFlags{},
|
||||
},
|
||||
TierGrowth: {
|
||||
Tier: TierGrowth,
|
||||
Name: "Growth",
|
||||
Description: "For fast-moving teams managing churn at scale.",
|
||||
MonthlyPriceCents: 4900,
|
||||
AnnualPriceCents: 49000,
|
||||
StripeMonthlyPriceID: strings.TrimSpace(cfg.GrowthMonthly),
|
||||
StripeAnnualPriceID: strings.TrimSpace(cfg.GrowthAnnual),
|
||||
Limits: UsageLimits{
|
||||
CustomerLimit: 250,
|
||||
IntegrationLimit: 3,
|
||||
},
|
||||
Features: FeatureFlags{
|
||||
Playbooks: true,
|
||||
},
|
||||
},
|
||||
TierScale: {
|
||||
Tier: TierScale,
|
||||
Name: "Scale",
|
||||
Description: "For mature revenue teams with complex customer motion.",
|
||||
MonthlyPriceCents: 14900,
|
||||
AnnualPriceCents: 149000,
|
||||
StripeMonthlyPriceID: strings.TrimSpace(cfg.ScaleMonthly),
|
||||
StripeAnnualPriceID: strings.TrimSpace(cfg.ScaleAnnual),
|
||||
Limits: UsageLimits{
|
||||
CustomerLimit: Unlimited,
|
||||
IntegrationLimit: Unlimited,
|
||||
},
|
||||
Features: FeatureFlags{
|
||||
Playbooks: true,
|
||||
AIInsights: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
priceIndex := map[string]priceRef{}
|
||||
for _, p := range plans {
|
||||
if p.StripeMonthlyPriceID != "" {
|
||||
priceIndex[p.StripeMonthlyPriceID] = priceRef{Tier: p.Tier, Cycle: BillingCycleMonthly}
|
||||
}
|
||||
if p.StripeAnnualPriceID != "" {
|
||||
priceIndex[p.StripeAnnualPriceID] = priceRef{Tier: p.Tier, Cycle: BillingCycleAnnual}
|
||||
}
|
||||
}
|
||||
|
||||
return &Catalog{plans: plans, priceIndex: priceIndex}
|
||||
}
|
||||
|
||||
func (c *Catalog) GetPlanByTier(tier string) (Plan, bool) {
|
||||
if c == nil {
|
||||
return Plan{}, false
|
||||
}
|
||||
plan, ok := c.plans[NormalizeTier(tier)]
|
||||
return plan, ok
|
||||
}
|
||||
|
||||
func (c *Catalog) GetLimits(tier string) (UsageLimits, bool) {
|
||||
plan, ok := c.GetPlanByTier(tier)
|
||||
if !ok {
|
||||
return UsageLimits{}, false
|
||||
}
|
||||
return plan.Limits, true
|
||||
}
|
||||
|
||||
func (c *Catalog) GetPriceID(tier string, annual bool) (string, error) {
|
||||
plan, ok := c.GetPlanByTier(tier)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown tier: %s", tier)
|
||||
}
|
||||
if plan.Tier == TierFree {
|
||||
return "", fmt.Errorf("free tier does not have a Stripe price")
|
||||
}
|
||||
|
||||
if annual {
|
||||
if plan.StripeAnnualPriceID == "" {
|
||||
return "", fmt.Errorf("annual price id is not configured for tier %s", plan.Tier)
|
||||
}
|
||||
return plan.StripeAnnualPriceID, nil
|
||||
}
|
||||
|
||||
if plan.StripeMonthlyPriceID == "" {
|
||||
return "", fmt.Errorf("monthly price id is not configured for tier %s", plan.Tier)
|
||||
}
|
||||
return plan.StripeMonthlyPriceID, nil
|
||||
}
|
||||
|
||||
func (c *Catalog) ResolveTierAndCycleByPriceID(priceID string) (Tier, BillingCycle, bool) {
|
||||
if c == nil {
|
||||
return "", "", false
|
||||
}
|
||||
ref, ok := c.priceIndex[strings.TrimSpace(priceID)]
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return ref.Tier, ref.Cycle, true
|
||||
}
|
||||
|
||||
func (c *Catalog) RecommendedUpgrade(tier string) Tier {
|
||||
switch NormalizeTier(tier) {
|
||||
case TierFree:
|
||||
return TierGrowth
|
||||
case TierGrowth:
|
||||
return TierScale
|
||||
default:
|
||||
return TierScale
|
||||
}
|
||||
}
|
||||
|
||||
func NormalizeTier(tier string) Tier {
|
||||
switch strings.ToLower(strings.TrimSpace(tier)) {
|
||||
case string(TierGrowth):
|
||||
return TierGrowth
|
||||
case string(TierScale):
|
||||
return TierScale
|
||||
default:
|
||||
return TierFree
|
||||
}
|
||||
}
|
||||
|
||||
func IsPaidTier(tier string) bool {
|
||||
t := NormalizeTier(tier)
|
||||
return t == TierGrowth || t == TierScale
|
||||
}
|
||||
82
internal/billing/plans_test.go
Normal file
82
internal/billing/plans_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package billing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGetPlanByTier(t *testing.T) {
|
||||
catalog := NewCatalog(PriceConfig{})
|
||||
|
||||
plan, ok := catalog.GetPlanByTier("growth")
|
||||
if !ok {
|
||||
t.Fatal("expected growth plan to exist")
|
||||
}
|
||||
if plan.Tier != TierGrowth {
|
||||
t.Fatalf("expected growth tier, got %s", plan.Tier)
|
||||
}
|
||||
|
||||
free, ok := catalog.GetPlanByTier("FREE")
|
||||
if !ok {
|
||||
t.Fatal("expected free plan to exist")
|
||||
}
|
||||
if free.Tier != TierFree {
|
||||
t.Fatalf("expected free tier, got %s", free.Tier)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLimits(t *testing.T) {
|
||||
catalog := NewCatalog(PriceConfig{})
|
||||
|
||||
freeLimits, ok := catalog.GetLimits("free")
|
||||
if !ok {
|
||||
t.Fatal("expected free limits")
|
||||
}
|
||||
if freeLimits.CustomerLimit != 10 {
|
||||
t.Fatalf("expected free customer limit 10, got %d", freeLimits.CustomerLimit)
|
||||
}
|
||||
if freeLimits.IntegrationLimit != 1 {
|
||||
t.Fatalf("expected free integration limit 1, got %d", freeLimits.IntegrationLimit)
|
||||
}
|
||||
|
||||
scaleLimits, ok := catalog.GetLimits("scale")
|
||||
if !ok {
|
||||
t.Fatal("expected scale limits")
|
||||
}
|
||||
if scaleLimits.CustomerLimit != Unlimited {
|
||||
t.Fatalf("expected scale customer limit unlimited, got %d", scaleLimits.CustomerLimit)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPriceMappingMonthlyAnnual(t *testing.T) {
|
||||
catalog := NewCatalog(PriceConfig{
|
||||
GrowthMonthly: "price_growth_monthly",
|
||||
GrowthAnnual: "price_growth_annual",
|
||||
ScaleMonthly: "price_scale_monthly",
|
||||
ScaleAnnual: "price_scale_annual",
|
||||
})
|
||||
|
||||
monthly, err := catalog.GetPriceID("growth", false)
|
||||
if err != nil {
|
||||
t.Fatalf("expected growth monthly price id, got error: %v", err)
|
||||
}
|
||||
if monthly != "price_growth_monthly" {
|
||||
t.Fatalf("expected growth monthly id, got %s", monthly)
|
||||
}
|
||||
|
||||
annual, err := catalog.GetPriceID("growth", true)
|
||||
if err != nil {
|
||||
t.Fatalf("expected growth annual price id, got error: %v", err)
|
||||
}
|
||||
if annual != "price_growth_annual" {
|
||||
t.Fatalf("expected growth annual id, got %s", annual)
|
||||
}
|
||||
|
||||
tier, cycle, ok := catalog.ResolveTierAndCycleByPriceID("price_scale_annual")
|
||||
if !ok {
|
||||
t.Fatal("expected price id mapping for scale annual")
|
||||
}
|
||||
if tier != TierScale {
|
||||
t.Fatalf("expected scale tier, got %s", tier)
|
||||
}
|
||||
if cycle != BillingCycleAnnual {
|
||||
t.Fatalf("expected annual cycle, got %s", cycle)
|
||||
}
|
||||
}
|
||||
72
internal/billing/verify_prices.go
Normal file
72
internal/billing/verify_prices.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
stripeprice "github.com/stripe/stripe-go/v81/price"
|
||||
)
|
||||
|
||||
// VerifyConfiguredPrices validates that configured Stripe price IDs exist and include required metadata.
|
||||
// Required metadata keys: tier, customer_limit, integration_limit.
|
||||
func VerifyConfiguredPrices(ctx context.Context, stripeSecretKey string, catalog *Catalog) error {
|
||||
if strings.TrimSpace(stripeSecretKey) == "" {
|
||||
return fmt.Errorf("stripe billing secret key is required for price verification")
|
||||
}
|
||||
if catalog == nil {
|
||||
return fmt.Errorf("billing catalog is required for price verification")
|
||||
}
|
||||
|
||||
client := stripeprice.Client{B: stripe.GetBackend(stripe.APIBackend), Key: stripeSecretKey}
|
||||
|
||||
for _, tier := range []Tier{TierGrowth, TierScale} {
|
||||
plan, ok := catalog.GetPlanByTier(string(tier))
|
||||
if !ok {
|
||||
return fmt.Errorf("plan not found for tier %s", tier)
|
||||
}
|
||||
|
||||
for _, cycle := range []BillingCycle{BillingCycleMonthly, BillingCycleAnnual} {
|
||||
priceID := plan.StripeMonthlyPriceID
|
||||
if cycle == BillingCycleAnnual {
|
||||
priceID = plan.StripeAnnualPriceID
|
||||
}
|
||||
if priceID == "" {
|
||||
return fmt.Errorf("missing configured price id for tier=%s cycle=%s", tier, cycle)
|
||||
}
|
||||
|
||||
p, err := client.Get(priceID, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch stripe price %s: %w", priceID, err)
|
||||
}
|
||||
|
||||
if got := strings.ToLower(strings.TrimSpace(p.Metadata["tier"])); got != string(tier) {
|
||||
return fmt.Errorf("price %s metadata tier mismatch: expected=%s got=%s", priceID, tier, got)
|
||||
}
|
||||
|
||||
if !limitMetadataMatches(p.Metadata["customer_limit"], plan.Limits.CustomerLimit) {
|
||||
return fmt.Errorf("price %s metadata customer_limit mismatch for tier %s", priceID, tier)
|
||||
}
|
||||
if !limitMetadataMatches(p.Metadata["integration_limit"], plan.Limits.IntegrationLimit) {
|
||||
return fmt.Errorf("price %s metadata integration_limit mismatch for tier %s", priceID, tier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = ctx // kept for future request scoping if Stripe SDK adds context support to this client path
|
||||
return nil
|
||||
}
|
||||
|
||||
func limitMetadataMatches(raw string, expected int) bool {
|
||||
raw = strings.ToLower(strings.TrimSpace(raw))
|
||||
if expected == Unlimited {
|
||||
return raw == "unlimited" || raw == "-1"
|
||||
}
|
||||
parsed, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return parsed == expected
|
||||
}
|
||||
@@ -4,22 +4,24 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds application configuration.
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
CORS CORSConfig
|
||||
Rate RateConfig
|
||||
JWT JWTConfig
|
||||
SendGrid SendGridConfig
|
||||
Stripe StripeConfig
|
||||
HubSpot HubSpotConfig
|
||||
Intercom IntercomConfig
|
||||
Scoring ScoringConfig
|
||||
Alert AlertConfig
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
CORS CORSConfig
|
||||
Rate RateConfig
|
||||
JWT JWTConfig
|
||||
SendGrid SendGridConfig
|
||||
Stripe StripeConfig
|
||||
BillingStripe BillingStripeConfig
|
||||
HubSpot HubSpotConfig
|
||||
Intercom IntercomConfig
|
||||
Scoring ScoringConfig
|
||||
Alert AlertConfig
|
||||
}
|
||||
|
||||
// AlertConfig holds alert engine settings.
|
||||
@@ -46,6 +48,18 @@ type StripeConfig struct {
|
||||
PaymentSyncDays int
|
||||
}
|
||||
|
||||
// BillingStripeConfig holds PulseScore-owned Stripe billing settings.
|
||||
type BillingStripeConfig struct {
|
||||
SecretKey string
|
||||
PublishableKey string
|
||||
WebhookSecret string
|
||||
PortalReturnURL string
|
||||
PriceGrowthMonthly string
|
||||
PriceGrowthAnnual string
|
||||
PriceScaleMonthly string
|
||||
PriceScaleAnnual string
|
||||
}
|
||||
|
||||
// HubSpotConfig holds HubSpot OAuth and webhook settings.
|
||||
type HubSpotConfig struct {
|
||||
ClientID string
|
||||
@@ -159,6 +173,16 @@ func Load() *Config {
|
||||
SyncIntervalMin: getInt("STRIPE_SYNC_INTERVAL_MIN", 15),
|
||||
PaymentSyncDays: getInt("STRIPE_PAYMENT_SYNC_DAYS", 90),
|
||||
},
|
||||
BillingStripe: BillingStripeConfig{
|
||||
SecretKey: getEnv("STRIPE_BILLING_SECRET_KEY", ""),
|
||||
PublishableKey: getEnv("STRIPE_BILLING_PUBLISHABLE_KEY", ""),
|
||||
WebhookSecret: getEnv("STRIPE_BILLING_WEBHOOK_SECRET", ""),
|
||||
PortalReturnURL: getEnv("STRIPE_BILLING_PORTAL_RETURN_URL", "http://localhost:5173/settings/billing"),
|
||||
PriceGrowthMonthly: getEnv("STRIPE_BILLING_PRICE_GROWTH_MONTHLY", ""),
|
||||
PriceGrowthAnnual: getEnv("STRIPE_BILLING_PRICE_GROWTH_ANNUAL", ""),
|
||||
PriceScaleMonthly: getEnv("STRIPE_BILLING_PRICE_SCALE_MONTHLY", ""),
|
||||
PriceScaleAnnual: getEnv("STRIPE_BILLING_PRICE_SCALE_ANNUAL", ""),
|
||||
},
|
||||
HubSpot: HubSpotConfig{
|
||||
ClientID: getEnv("HUBSPOT_CLIENT_ID", ""),
|
||||
ClientSecret: getEnv("HUBSPOT_CLIENT_SECRET", ""),
|
||||
@@ -196,6 +220,28 @@ func (c *Config) Validate() error {
|
||||
if c.JWT.Secret == "" || c.JWT.Secret == "dev-secret-change-me-in-production" {
|
||||
return fmt.Errorf("JWT_SECRET must be set to a secure value in production")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(c.BillingStripe.SecretKey) == "" {
|
||||
return fmt.Errorf("STRIPE_BILLING_SECRET_KEY is required in production")
|
||||
}
|
||||
if strings.TrimSpace(c.BillingStripe.PublishableKey) == "" {
|
||||
return fmt.Errorf("STRIPE_BILLING_PUBLISHABLE_KEY is required in production")
|
||||
}
|
||||
if strings.TrimSpace(c.BillingStripe.WebhookSecret) == "" {
|
||||
return fmt.Errorf("STRIPE_BILLING_WEBHOOK_SECRET is required in production")
|
||||
}
|
||||
|
||||
requiredPriceIDs := map[string]string{
|
||||
"STRIPE_BILLING_PRICE_GROWTH_MONTHLY": c.BillingStripe.PriceGrowthMonthly,
|
||||
"STRIPE_BILLING_PRICE_GROWTH_ANNUAL": c.BillingStripe.PriceGrowthAnnual,
|
||||
"STRIPE_BILLING_PRICE_SCALE_MONTHLY": c.BillingStripe.PriceScaleMonthly,
|
||||
"STRIPE_BILLING_PRICE_SCALE_ANNUAL": c.BillingStripe.PriceScaleAnnual,
|
||||
}
|
||||
for name, value := range requiredPriceIDs {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fmt.Errorf("%s is required in production", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -12,7 +12,11 @@ func clearEnv() {
|
||||
"DB_MAX_OPEN_CONNS", "DB_MAX_IDLE_CONNS",
|
||||
"DB_MAX_CONN_LIFETIME", "DB_HEALTH_CHECK_SEC",
|
||||
"READ_TIMEOUT", "WRITE_TIMEOUT", "IDLE_TIMEOUT",
|
||||
"CORS_ALLOWED_ORIGINS", "RATE_LIMIT_RPM",
|
||||
"CORS_ALLOWED_ORIGINS", "RATE_LIMIT_RPM", "JWT_SECRET",
|
||||
"STRIPE_BILLING_SECRET_KEY", "STRIPE_BILLING_PUBLISHABLE_KEY",
|
||||
"STRIPE_BILLING_WEBHOOK_SECRET", "STRIPE_BILLING_PORTAL_RETURN_URL",
|
||||
"STRIPE_BILLING_PRICE_GROWTH_MONTHLY", "STRIPE_BILLING_PRICE_GROWTH_ANNUAL",
|
||||
"STRIPE_BILLING_PRICE_SCALE_MONTHLY", "STRIPE_BILLING_PRICE_SCALE_ANNUAL",
|
||||
} {
|
||||
os.Unsetenv(key)
|
||||
}
|
||||
@@ -107,6 +111,72 @@ func TestValidateDevelopment(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadBillingStripeFromEnv(t *testing.T) {
|
||||
clearEnv()
|
||||
os.Setenv("STRIPE_BILLING_SECRET_KEY", "sk_test_123")
|
||||
os.Setenv("STRIPE_BILLING_PUBLISHABLE_KEY", "pk_test_123")
|
||||
os.Setenv("STRIPE_BILLING_WEBHOOK_SECRET", "whsec_123")
|
||||
os.Setenv("STRIPE_BILLING_PORTAL_RETURN_URL", "https://app.example.com/settings/billing")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_GROWTH_MONTHLY", "price_growth_monthly")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_GROWTH_ANNUAL", "price_growth_annual")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_SCALE_MONTHLY", "price_scale_monthly")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_SCALE_ANNUAL", "price_scale_annual")
|
||||
defer clearEnv()
|
||||
|
||||
cfg := Load()
|
||||
|
||||
if cfg.BillingStripe.SecretKey != "sk_test_123" {
|
||||
t.Errorf("expected billing secret key to load from env")
|
||||
}
|
||||
if cfg.BillingStripe.PublishableKey != "pk_test_123" {
|
||||
t.Errorf("expected billing publishable key to load from env")
|
||||
}
|
||||
if cfg.BillingStripe.WebhookSecret != "whsec_123" {
|
||||
t.Errorf("expected billing webhook secret to load from env")
|
||||
}
|
||||
if cfg.BillingStripe.PriceGrowthMonthly != "price_growth_monthly" {
|
||||
t.Errorf("expected growth monthly price id to load from env")
|
||||
}
|
||||
if cfg.BillingStripe.PriceScaleAnnual != "price_scale_annual" {
|
||||
t.Errorf("expected scale annual price id to load from env")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProductionRequiresBillingStripeConfig(t *testing.T) {
|
||||
clearEnv()
|
||||
os.Setenv("ENVIRONMENT", "production")
|
||||
os.Setenv("DATABASE_URL", "postgres://prod-db")
|
||||
os.Setenv("JWT_SECRET", "prod-super-secret")
|
||||
defer clearEnv()
|
||||
|
||||
cfg := Load()
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("expected validation error when billing stripe config is missing in production")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProductionWithBillingStripeConfig(t *testing.T) {
|
||||
clearEnv()
|
||||
os.Setenv("ENVIRONMENT", "production")
|
||||
os.Setenv("DATABASE_URL", "postgres://prod-db")
|
||||
os.Setenv("JWT_SECRET", "prod-super-secret")
|
||||
os.Setenv("STRIPE_BILLING_SECRET_KEY", "sk_live_123")
|
||||
os.Setenv("STRIPE_BILLING_PUBLISHABLE_KEY", "pk_live_123")
|
||||
os.Setenv("STRIPE_BILLING_WEBHOOK_SECRET", "whsec_live_123")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_GROWTH_MONTHLY", "price_growth_monthly")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_GROWTH_ANNUAL", "price_growth_annual")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_SCALE_MONTHLY", "price_scale_monthly")
|
||||
os.Setenv("STRIPE_BILLING_PRICE_SCALE_ANNUAL", "price_scale_annual")
|
||||
defer clearEnv()
|
||||
|
||||
cfg := Load()
|
||||
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("expected production validation to pass, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsProd(t *testing.T) {
|
||||
clearEnv()
|
||||
cfg := Load()
|
||||
|
||||
171
internal/handler/billing.go
Normal file
171
internal/handler/billing.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/onnwee/pulse-score/internal/auth"
|
||||
core "github.com/onnwee/pulse-score/internal/service"
|
||||
billing "github.com/onnwee/pulse-score/internal/service/billing"
|
||||
)
|
||||
|
||||
type billingCheckoutServicer interface {
|
||||
CreateCheckoutSession(ctx context.Context, orgID, userID uuid.UUID, req billing.CreateCheckoutSessionRequest) (*billing.CreateCheckoutSessionResponse, error)
|
||||
}
|
||||
|
||||
type billingPortalServicer interface {
|
||||
CreatePortalSession(ctx context.Context, orgID uuid.UUID) (*billing.PortalSessionResponse, error)
|
||||
CancelAtPeriodEnd(ctx context.Context, orgID uuid.UUID) error
|
||||
}
|
||||
|
||||
type billingSubscriptionServicer interface {
|
||||
GetSubscriptionSummary(ctx context.Context, orgID uuid.UUID) (*billing.SubscriptionSummary, error)
|
||||
}
|
||||
|
||||
type billingWebhookServicer interface {
|
||||
HandleEvent(ctx context.Context, payload []byte, sigHeader string) error
|
||||
}
|
||||
|
||||
// BillingHandler provides PulseScore billing endpoints.
|
||||
type BillingHandler struct {
|
||||
checkoutSvc billingCheckoutServicer
|
||||
portalSvc billingPortalServicer
|
||||
subscriptionSvc billingSubscriptionServicer
|
||||
}
|
||||
|
||||
func NewBillingHandler(
|
||||
checkoutSvc billingCheckoutServicer,
|
||||
portalSvc billingPortalServicer,
|
||||
subscriptionSvc billingSubscriptionServicer,
|
||||
) *BillingHandler {
|
||||
return &BillingHandler{
|
||||
checkoutSvc: checkoutSvc,
|
||||
portalSvc: portalSvc,
|
||||
subscriptionSvc: subscriptionSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCheckout handles POST /api/v1/billing/checkout.
|
||||
func (h *BillingHandler) CreateCheckout(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := auth.GetUserID(r.Context())
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
var req billing.CreateCheckoutSessionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse("invalid request body"))
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.checkoutSvc.CreateCheckoutSession(r.Context(), orgID, userID, req)
|
||||
if err != nil {
|
||||
handleServiceError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// CreatePortalSession handles POST /api/v1/billing/portal-session.
|
||||
func (h *BillingHandler) CreatePortalSession(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.portalSvc.CreatePortalSession(r.Context(), orgID)
|
||||
if err != nil {
|
||||
handleServiceError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// GetSubscription handles GET /api/v1/billing/subscription.
|
||||
func (h *BillingHandler) GetSubscription(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.subscriptionSvc.GetSubscriptionSummary(r.Context(), orgID)
|
||||
if err != nil {
|
||||
handleServiceError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// CancelSubscription handles POST /api/v1/billing/cancel.
|
||||
func (h *BillingHandler) CancelSubscription(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusUnauthorized, errorResponse("unauthorized"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.portalSvc.CancelAtPeriodEnd(r.Context(), orgID); err != nil {
|
||||
handleServiceError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "cancel_at_period_end"})
|
||||
}
|
||||
|
||||
// WebhookStripeBillingHandler handles Stripe billing webhooks.
|
||||
type WebhookStripeBillingHandler struct {
|
||||
webhookSvc billingWebhookServicer
|
||||
}
|
||||
|
||||
func NewWebhookStripeBillingHandler(webhookSvc billingWebhookServicer) *WebhookStripeBillingHandler {
|
||||
return &WebhookStripeBillingHandler{webhookSvc: webhookSvc}
|
||||
}
|
||||
|
||||
// HandleWebhook handles POST /api/v1/webhooks/stripe-billing.
|
||||
func (h *WebhookStripeBillingHandler) HandleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
const maxBodySize = 65536
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBodySize)
|
||||
|
||||
payload, err := readBody(r)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse("invalid request body"))
|
||||
return
|
||||
}
|
||||
|
||||
sigHeader := r.Header.Get("Stripe-Signature")
|
||||
if sigHeader == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse("missing Stripe-Signature header"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.webhookSvc.HandleEvent(r.Context(), payload, sigHeader); err != nil {
|
||||
var validationErr *core.ValidationError
|
||||
if errors.As(err, &validationErr) {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse(validationErr.Message))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Error("billing webhook processing error", "error", err)
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse("internal server error"))
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
156
internal/handler/billing_test.go
Normal file
156
internal/handler/billing_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/onnwee/pulse-score/internal/auth"
|
||||
core "github.com/onnwee/pulse-score/internal/service"
|
||||
billing "github.com/onnwee/pulse-score/internal/service/billing"
|
||||
)
|
||||
|
||||
type mockBillingCheckoutService struct {
|
||||
createFn func(ctx context.Context, orgID, userID uuid.UUID, req billing.CreateCheckoutSessionRequest) (*billing.CreateCheckoutSessionResponse, error)
|
||||
}
|
||||
|
||||
func (m *mockBillingCheckoutService) CreateCheckoutSession(ctx context.Context, orgID, userID uuid.UUID, req billing.CreateCheckoutSessionRequest) (*billing.CreateCheckoutSessionResponse, error) {
|
||||
return m.createFn(ctx, orgID, userID, req)
|
||||
}
|
||||
|
||||
type mockBillingPortalService struct {
|
||||
portalFn func(ctx context.Context, orgID uuid.UUID) (*billing.PortalSessionResponse, error)
|
||||
cancelFn func(ctx context.Context, orgID uuid.UUID) error
|
||||
}
|
||||
|
||||
func (m *mockBillingPortalService) CreatePortalSession(ctx context.Context, orgID uuid.UUID) (*billing.PortalSessionResponse, error) {
|
||||
return m.portalFn(ctx, orgID)
|
||||
}
|
||||
|
||||
func (m *mockBillingPortalService) CancelAtPeriodEnd(ctx context.Context, orgID uuid.UUID) error {
|
||||
return m.cancelFn(ctx, orgID)
|
||||
}
|
||||
|
||||
type mockBillingSubscriptionService struct {
|
||||
getFn func(ctx context.Context, orgID uuid.UUID) (*billing.SubscriptionSummary, error)
|
||||
}
|
||||
|
||||
func (m *mockBillingSubscriptionService) GetSubscriptionSummary(ctx context.Context, orgID uuid.UUID) (*billing.SubscriptionSummary, error) {
|
||||
return m.getFn(ctx, orgID)
|
||||
}
|
||||
|
||||
type mockBillingWebhookService struct {
|
||||
handleFn func(ctx context.Context, payload []byte, sigHeader string) error
|
||||
}
|
||||
|
||||
func (m *mockBillingWebhookService) HandleEvent(ctx context.Context, payload []byte, sigHeader string) error {
|
||||
return m.handleFn(ctx, payload, sigHeader)
|
||||
}
|
||||
|
||||
func TestBillingCreateCheckout_Unauthorized(t *testing.T) {
|
||||
h := NewBillingHandler(&mockBillingCheckoutService{}, &mockBillingPortalService{}, &mockBillingSubscriptionService{})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/billing/checkout", strings.NewReader(`{"tier":"growth"}`))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.CreateCheckout(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCreateCheckout_Success(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
userID := uuid.New()
|
||||
h := NewBillingHandler(
|
||||
&mockBillingCheckoutService{createFn: func(ctx context.Context, gotOrgID, gotUserID uuid.UUID, req billing.CreateCheckoutSessionRequest) (*billing.CreateCheckoutSessionResponse, error) {
|
||||
if gotOrgID != orgID || gotUserID != userID {
|
||||
t.Fatalf("unexpected org/user ids passed")
|
||||
}
|
||||
return &billing.CreateCheckoutSessionResponse{URL: "https://checkout.stripe.test"}, nil
|
||||
}},
|
||||
&mockBillingPortalService{},
|
||||
&mockBillingSubscriptionService{},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/billing/checkout", strings.NewReader(`{"tier":"growth","annual":true}`))
|
||||
req = req.WithContext(auth.WithUserID(auth.WithOrgID(req.Context(), orgID), userID))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.CreateCheckout(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingGetSubscription_Success(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
h := NewBillingHandler(
|
||||
&mockBillingCheckoutService{},
|
||||
&mockBillingPortalService{},
|
||||
&mockBillingSubscriptionService{getFn: func(context.Context, uuid.UUID) (*billing.SubscriptionSummary, error) {
|
||||
return &billing.SubscriptionSummary{Tier: "free", Status: "free"}, nil
|
||||
}},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/v1/billing/subscription", nil)
|
||||
req = req.WithContext(auth.WithOrgID(req.Context(), orgID))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.GetSubscription(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingCancelSubscription_Success(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
h := NewBillingHandler(
|
||||
&mockBillingCheckoutService{},
|
||||
&mockBillingPortalService{cancelFn: func(context.Context, uuid.UUID) error { return nil }},
|
||||
&mockBillingSubscriptionService{},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/billing/cancel", nil)
|
||||
req = req.WithContext(auth.WithOrgID(req.Context(), orgID))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.CancelSubscription(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingWebhook_MissingSignature(t *testing.T) {
|
||||
h := NewWebhookStripeBillingHandler(&mockBillingWebhookService{})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/webhooks/stripe-billing", strings.NewReader("{}"))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleWebhook(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingWebhook_ValidationError(t *testing.T) {
|
||||
h := NewWebhookStripeBillingHandler(&mockBillingWebhookService{handleFn: func(context.Context, []byte, string) error {
|
||||
return &core.ValidationError{Field: "signature", Message: "invalid webhook signature"}
|
||||
}})
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/v1/webhooks/stripe-billing", strings.NewReader("{}"))
|
||||
req.Header.Set("Stripe-Signature", "bad-signature")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleWebhook(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
81
internal/middleware/feature_gate.go
Normal file
81
internal/middleware/feature_gate.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/onnwee/pulse-score/internal/auth"
|
||||
billing "github.com/onnwee/pulse-score/internal/service/billing"
|
||||
)
|
||||
|
||||
// RequireIntegrationLimit enforces integration connection limits for the current org.
|
||||
func RequireIntegrationLimit(limitsSvc *billing.LimitsService, provider string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeFeatureGateJSON(w, http.StatusUnauthorized, map[string]string{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
decision, err := limitsSvc.CheckIntegrationLimit(r.Context(), orgID, provider)
|
||||
if err != nil {
|
||||
writeFeatureGateJSON(w, http.StatusInternalServerError, map[string]string{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
if !decision.Allowed {
|
||||
writeFeatureGateJSON(w, http.StatusPaymentRequired, map[string]any{
|
||||
"error": "plan limit reached",
|
||||
"current_plan": decision.CurrentPlan,
|
||||
"limit_type": decision.LimitType,
|
||||
"current_usage": decision.CurrentUsage,
|
||||
"limit": decision.Limit,
|
||||
"recommended_upgrade_tier": decision.RecommendedUpgradeTier,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireCustomerLimit enforces customer creation limits for the current org.
|
||||
func RequireCustomerLimit(limitsSvc *billing.LimitsService) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
orgID, ok := auth.GetOrgID(r.Context())
|
||||
if !ok {
|
||||
writeFeatureGateJSON(w, http.StatusUnauthorized, map[string]string{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
decision, err := limitsSvc.CheckCustomerLimit(r.Context(), orgID)
|
||||
if err != nil {
|
||||
writeFeatureGateJSON(w, http.StatusInternalServerError, map[string]string{"error": "internal server error"})
|
||||
return
|
||||
}
|
||||
|
||||
if !decision.Allowed {
|
||||
writeFeatureGateJSON(w, http.StatusPaymentRequired, map[string]any{
|
||||
"error": "plan limit reached",
|
||||
"current_plan": decision.CurrentPlan,
|
||||
"limit_type": decision.LimitType,
|
||||
"current_usage": decision.CurrentUsage,
|
||||
"limit": decision.Limit,
|
||||
"recommended_upgrade_tier": decision.RecommendedUpgradeTier,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func writeFeatureGateJSON(w http.ResponseWriter, status int, body any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(body)
|
||||
}
|
||||
43
internal/repository/billing_webhook_event.go
Normal file
43
internal/repository/billing_webhook_event.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// BillingWebhookEventRepository tracks processed billing webhook events for idempotency.
|
||||
type BillingWebhookEventRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewBillingWebhookEventRepository creates a new BillingWebhookEventRepository.
|
||||
func NewBillingWebhookEventRepository(pool *pgxpool.Pool) *BillingWebhookEventRepository {
|
||||
return &BillingWebhookEventRepository{pool: pool}
|
||||
}
|
||||
|
||||
// MarkProcessed inserts an event id and returns true if inserted, false if already processed.
|
||||
func (r *BillingWebhookEventRepository) MarkProcessed(ctx context.Context, eventID, eventType string) (bool, error) {
|
||||
cmdTag, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO billing_webhook_events (event_id, event_type)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (event_id) DO NOTHING`, eventID, eventType)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("insert billing webhook event: %w", err)
|
||||
}
|
||||
return cmdTag.RowsAffected() == 1, nil
|
||||
}
|
||||
|
||||
// MarkProcessedTx inserts an event id in an existing transaction.
|
||||
func (r *BillingWebhookEventRepository) MarkProcessedTx(ctx context.Context, tx pgx.Tx, eventID, eventType string) (bool, error) {
|
||||
cmdTag, err := tx.Exec(ctx, `
|
||||
INSERT INTO billing_webhook_events (event_id, event_type)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT (event_id) DO NOTHING`, eventID, eventType)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("insert billing webhook event in tx: %w", err)
|
||||
}
|
||||
return cmdTag.RowsAffected() == 1, nil
|
||||
}
|
||||
@@ -237,3 +237,14 @@ func (r *IntegrationConnectionRepository) GetCustomerCountBySource(ctx context.C
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountActiveByOrg returns the number of active integration connections for an org.
|
||||
func (r *IntegrationConnectionRepository) CountActiveByOrg(ctx context.Context, orgID uuid.UUID) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM integration_connections WHERE org_id = $1 AND status = 'active'`
|
||||
var count int
|
||||
err := r.pool.QueryRow(ctx, query, orgID).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("count active integrations: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
194
internal/repository/org_subscription.go
Normal file
194
internal/repository/org_subscription.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// OrgSubscription represents an org_subscriptions row.
|
||||
type OrgSubscription struct {
|
||||
ID uuid.UUID
|
||||
OrgID uuid.UUID
|
||||
StripeSubscriptionID string
|
||||
StripeCustomerID string
|
||||
PlanTier string
|
||||
BillingCycle string
|
||||
Status string
|
||||
CurrentPeriodStart *time.Time
|
||||
CurrentPeriodEnd *time.Time
|
||||
CancelAtPeriodEnd bool
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// OrgSubscriptionRepository handles org subscription database operations.
|
||||
type OrgSubscriptionRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewOrgSubscriptionRepository creates a new OrgSubscriptionRepository.
|
||||
func NewOrgSubscriptionRepository(pool *pgxpool.Pool) *OrgSubscriptionRepository {
|
||||
return &OrgSubscriptionRepository{pool: pool}
|
||||
}
|
||||
|
||||
const upsertOrgSubscriptionQuery = `
|
||||
INSERT INTO org_subscriptions (
|
||||
org_id, stripe_subscription_id, stripe_customer_id, plan_tier, billing_cycle,
|
||||
status, current_period_start, current_period_end, cancel_at_period_end
|
||||
)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
ON CONFLICT (org_id) DO UPDATE SET
|
||||
stripe_subscription_id = EXCLUDED.stripe_subscription_id,
|
||||
stripe_customer_id = EXCLUDED.stripe_customer_id,
|
||||
plan_tier = EXCLUDED.plan_tier,
|
||||
billing_cycle = EXCLUDED.billing_cycle,
|
||||
status = EXCLUDED.status,
|
||||
current_period_start = EXCLUDED.current_period_start,
|
||||
current_period_end = EXCLUDED.current_period_end,
|
||||
cancel_at_period_end = EXCLUDED.cancel_at_period_end
|
||||
RETURNING id, created_at, updated_at`
|
||||
|
||||
// UpsertByOrg creates or updates org subscription state by org ID.
|
||||
func (r *OrgSubscriptionRepository) UpsertByOrg(ctx context.Context, sub *OrgSubscription) error {
|
||||
return r.pool.QueryRow(ctx, upsertOrgSubscriptionQuery,
|
||||
sub.OrgID,
|
||||
emptyToNil(sub.StripeSubscriptionID),
|
||||
emptyToNil(sub.StripeCustomerID),
|
||||
sub.PlanTier,
|
||||
sub.BillingCycle,
|
||||
sub.Status,
|
||||
sub.CurrentPeriodStart,
|
||||
sub.CurrentPeriodEnd,
|
||||
sub.CancelAtPeriodEnd,
|
||||
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
|
||||
}
|
||||
|
||||
// UpsertByOrgTx creates or updates org subscription state by org ID in an existing transaction.
|
||||
func (r *OrgSubscriptionRepository) UpsertByOrgTx(ctx context.Context, tx pgx.Tx, sub *OrgSubscription) error {
|
||||
return tx.QueryRow(ctx, upsertOrgSubscriptionQuery,
|
||||
sub.OrgID,
|
||||
emptyToNil(sub.StripeSubscriptionID),
|
||||
emptyToNil(sub.StripeCustomerID),
|
||||
sub.PlanTier,
|
||||
sub.BillingCycle,
|
||||
sub.Status,
|
||||
sub.CurrentPeriodStart,
|
||||
sub.CurrentPeriodEnd,
|
||||
sub.CancelAtPeriodEnd,
|
||||
).Scan(&sub.ID, &sub.CreatedAt, &sub.UpdatedAt)
|
||||
}
|
||||
|
||||
// GetByOrg retrieves org subscription state by organization ID.
|
||||
func (r *OrgSubscriptionRepository) GetByOrg(ctx context.Context, orgID uuid.UUID) (*OrgSubscription, error) {
|
||||
query := `
|
||||
SELECT id, org_id, COALESCE(stripe_subscription_id, ''), COALESCE(stripe_customer_id, ''),
|
||||
plan_tier, billing_cycle, status, current_period_start, current_period_end,
|
||||
cancel_at_period_end, created_at, updated_at
|
||||
FROM org_subscriptions
|
||||
WHERE org_id = $1`
|
||||
|
||||
sub := &OrgSubscription{}
|
||||
err := r.pool.QueryRow(ctx, query, orgID).Scan(
|
||||
&sub.ID,
|
||||
&sub.OrgID,
|
||||
&sub.StripeSubscriptionID,
|
||||
&sub.StripeCustomerID,
|
||||
&sub.PlanTier,
|
||||
&sub.BillingCycle,
|
||||
&sub.Status,
|
||||
&sub.CurrentPeriodStart,
|
||||
&sub.CurrentPeriodEnd,
|
||||
&sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt,
|
||||
&sub.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org subscription by org: %w", err)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// GetByStripeSubscriptionID retrieves org subscription state by Stripe subscription ID.
|
||||
func (r *OrgSubscriptionRepository) GetByStripeSubscriptionID(ctx context.Context, stripeSubscriptionID string) (*OrgSubscription, error) {
|
||||
query := `
|
||||
SELECT id, org_id, COALESCE(stripe_subscription_id, ''), COALESCE(stripe_customer_id, ''),
|
||||
plan_tier, billing_cycle, status, current_period_start, current_period_end,
|
||||
cancel_at_period_end, created_at, updated_at
|
||||
FROM org_subscriptions
|
||||
WHERE stripe_subscription_id = $1`
|
||||
|
||||
sub := &OrgSubscription{}
|
||||
err := r.pool.QueryRow(ctx, query, stripeSubscriptionID).Scan(
|
||||
&sub.ID,
|
||||
&sub.OrgID,
|
||||
&sub.StripeSubscriptionID,
|
||||
&sub.StripeCustomerID,
|
||||
&sub.PlanTier,
|
||||
&sub.BillingCycle,
|
||||
&sub.Status,
|
||||
&sub.CurrentPeriodStart,
|
||||
&sub.CurrentPeriodEnd,
|
||||
&sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt,
|
||||
&sub.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org subscription by stripe subscription id: %w", err)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
// GetByStripeCustomerID retrieves org subscription state by Stripe customer ID.
|
||||
func (r *OrgSubscriptionRepository) GetByStripeCustomerID(ctx context.Context, stripeCustomerID string) (*OrgSubscription, error) {
|
||||
query := `
|
||||
SELECT id, org_id, COALESCE(stripe_subscription_id, ''), COALESCE(stripe_customer_id, ''),
|
||||
plan_tier, billing_cycle, status, current_period_start, current_period_end,
|
||||
cancel_at_period_end, created_at, updated_at
|
||||
FROM org_subscriptions
|
||||
WHERE stripe_customer_id = $1`
|
||||
|
||||
sub := &OrgSubscription{}
|
||||
err := r.pool.QueryRow(ctx, query, stripeCustomerID).Scan(
|
||||
&sub.ID,
|
||||
&sub.OrgID,
|
||||
&sub.StripeSubscriptionID,
|
||||
&sub.StripeCustomerID,
|
||||
&sub.PlanTier,
|
||||
&sub.BillingCycle,
|
||||
&sub.Status,
|
||||
&sub.CurrentPeriodStart,
|
||||
&sub.CurrentPeriodEnd,
|
||||
&sub.CancelAtPeriodEnd,
|
||||
&sub.CreatedAt,
|
||||
&sub.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org subscription by stripe customer id: %w", err)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
func emptyToNil(s string) any {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -142,6 +142,73 @@ func (r *OrganizationRepository) Update(ctx context.Context, orgID uuid.UUID, na
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePlan updates an org's billing plan.
|
||||
func (r *OrganizationRepository) UpdatePlan(ctx context.Context, orgID uuid.UUID, plan string) error {
|
||||
query := `UPDATE organizations SET plan = $2 WHERE id = $1 AND deleted_at IS NULL`
|
||||
_, err := r.pool.Exec(ctx, query, orgID, plan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update org plan: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePlanTx updates an org's billing plan inside an existing transaction.
|
||||
func (r *OrganizationRepository) UpdatePlanTx(ctx context.Context, tx pgx.Tx, orgID uuid.UUID, plan string) error {
|
||||
query := `UPDATE organizations SET plan = $2 WHERE id = $1 AND deleted_at IS NULL`
|
||||
_, err := tx.Exec(ctx, query, orgID, plan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update org plan in tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStripeCustomerID updates an org's Stripe customer ID.
|
||||
func (r *OrganizationRepository) UpdateStripeCustomerID(ctx context.Context, orgID uuid.UUID, stripeCustomerID string) error {
|
||||
query := `UPDATE organizations SET stripe_customer_id = $2 WHERE id = $1 AND deleted_at IS NULL`
|
||||
_, err := r.pool.Exec(ctx, query, orgID, stripeCustomerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update org stripe customer id: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStripeCustomerIDTx updates an org's Stripe customer ID inside an existing transaction.
|
||||
func (r *OrganizationRepository) UpdateStripeCustomerIDTx(ctx context.Context, tx pgx.Tx, orgID uuid.UUID, stripeCustomerID string) error {
|
||||
query := `UPDATE organizations SET stripe_customer_id = $2 WHERE id = $1 AND deleted_at IS NULL`
|
||||
_, err := tx.Exec(ctx, query, orgID, stripeCustomerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update org stripe customer id in tx: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByStripeCustomerID retrieves an organization by Stripe customer ID.
|
||||
func (r *OrganizationRepository) GetByStripeCustomerID(ctx context.Context, stripeCustomerID string) (*Organization, error) {
|
||||
query := `
|
||||
SELECT id, name, slug, plan, COALESCE(stripe_customer_id, ''), created_at, updated_at
|
||||
FROM organizations
|
||||
WHERE stripe_customer_id = $1 AND deleted_at IS NULL
|
||||
LIMIT 1`
|
||||
|
||||
o := &Organization{}
|
||||
err := r.pool.QueryRow(ctx, query, stripeCustomerID).Scan(
|
||||
&o.ID,
|
||||
&o.Name,
|
||||
&o.Slug,
|
||||
&o.Plan,
|
||||
&o.StripeCustomerID,
|
||||
&o.CreatedAt,
|
||||
&o.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org by stripe customer id: %w", err)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// CountMembers returns the number of members in an org.
|
||||
func (r *OrganizationRepository) CountMembers(ctx context.Context, orgID uuid.UUID) (int, error) {
|
||||
query := `SELECT COUNT(*) FROM user_organizations WHERE org_id = $1`
|
||||
@@ -155,13 +222,13 @@ func (r *OrganizationRepository) CountMembers(ctx context.Context, orgID uuid.UU
|
||||
|
||||
// OrgMember holds a member of an org with user details.
|
||||
type OrgMember struct {
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Role string `json:"role"`
|
||||
JoinedAt time.Time `json:"joined_at"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"first_name"`
|
||||
LastName string `json:"last_name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
Role string `json:"role"`
|
||||
JoinedAt time.Time `json:"joined_at"`
|
||||
}
|
||||
|
||||
// ListMembers returns all members of an org with user details.
|
||||
|
||||
@@ -58,6 +58,7 @@ type AuthOrg struct {
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Role string `json:"role"`
|
||||
Plan string `json:"plan"`
|
||||
}
|
||||
|
||||
// RefreshRequest holds the input for token refresh.
|
||||
@@ -189,6 +190,7 @@ func (s *AuthService) Register(ctx context.Context, req RegisterRequest) (*AuthR
|
||||
Name: org.Name,
|
||||
Slug: org.Slug,
|
||||
Role: "owner",
|
||||
Plan: "free",
|
||||
},
|
||||
Tokens: tokens,
|
||||
}, nil
|
||||
@@ -282,6 +284,7 @@ func (s *AuthService) Login(ctx context.Context, req LoginRequest) (*AuthRespons
|
||||
Name: orgDetails.Name,
|
||||
Slug: orgDetails.Slug,
|
||||
Role: defaultOrg.Role,
|
||||
Plan: orgDetails.Plan,
|
||||
},
|
||||
Tokens: tokens,
|
||||
}, nil
|
||||
@@ -411,6 +414,7 @@ func (s *AuthService) Refresh(ctx context.Context, req RefreshRequest) (*AuthRes
|
||||
Name: orgDetails.Name,
|
||||
Slug: orgDetails.Slug,
|
||||
Role: defaultOrg.Role,
|
||||
Plan: orgDetails.Plan,
|
||||
},
|
||||
Tokens: tokens,
|
||||
}, nil
|
||||
|
||||
170
internal/service/billing/checkout.go
Normal file
170
internal/service/billing/checkout.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
checkoutsession "github.com/stripe/stripe-go/v81/checkout/session"
|
||||
stripecustomer "github.com/stripe/stripe-go/v81/customer"
|
||||
|
||||
planmodel "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
core "github.com/onnwee/pulse-score/internal/service"
|
||||
)
|
||||
|
||||
type orgBillingRepository interface {
|
||||
GetByID(ctx context.Context, id uuid.UUID) (*repository.Organization, error)
|
||||
UpdateStripeCustomerID(ctx context.Context, orgID uuid.UUID, stripeCustomerID string) error
|
||||
}
|
||||
|
||||
// CreateCheckoutSessionRequest defines supported inputs for checkout creation.
|
||||
type CreateCheckoutSessionRequest struct {
|
||||
PriceID string `json:"priceId"`
|
||||
Tier string `json:"tier"`
|
||||
Cycle string `json:"cycle"`
|
||||
Annual bool `json:"annual"`
|
||||
}
|
||||
|
||||
// CreateCheckoutSessionResponse returns the Stripe hosted checkout URL.
|
||||
type CreateCheckoutSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// CheckoutService handles Stripe checkout session creation for PulseScore billing.
|
||||
type CheckoutService struct {
|
||||
stripeSecretKey string
|
||||
frontendURL string
|
||||
orgs orgBillingRepository
|
||||
catalog *planmodel.Catalog
|
||||
}
|
||||
|
||||
func NewCheckoutService(
|
||||
stripeSecretKey, frontendURL string,
|
||||
orgs orgBillingRepository,
|
||||
catalog *planmodel.Catalog,
|
||||
) *CheckoutService {
|
||||
return &CheckoutService{
|
||||
stripeSecretKey: strings.TrimSpace(stripeSecretKey),
|
||||
frontendURL: strings.TrimRight(strings.TrimSpace(frontendURL), "/"),
|
||||
orgs: orgs,
|
||||
catalog: catalog,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CheckoutService) CreateCheckoutSession(ctx context.Context, orgID, userID uuid.UUID, req CreateCheckoutSessionRequest) (*CreateCheckoutSessionResponse, error) {
|
||||
if s.stripeSecretKey == "" {
|
||||
return nil, &core.ValidationError{Field: "billing", Message: "stripe billing is not configured"}
|
||||
}
|
||||
|
||||
priceID, tier, cycle, err := s.resolvePriceDetails(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
org, err := s.orgs.GetByID(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org: %w", err)
|
||||
}
|
||||
if org == nil {
|
||||
return nil, &core.NotFoundError{Resource: "organization", Message: "organization not found"}
|
||||
}
|
||||
|
||||
customerID := strings.TrimSpace(org.StripeCustomerID)
|
||||
if customerID == "" {
|
||||
customerID, err = s.createStripeCustomer(ctx, org)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stripe customer: %w", err)
|
||||
}
|
||||
if err := s.orgs.UpdateStripeCustomerID(ctx, orgID, customerID); err != nil {
|
||||
return nil, fmt.Errorf("persist stripe customer id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
successURL := fmt.Sprintf("%s/settings/billing?checkout=success", s.frontendURL)
|
||||
cancelURL := fmt.Sprintf("%s/settings/billing?checkout=cancelled", s.frontendURL)
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
|
||||
SuccessURL: stripe.String(successURL),
|
||||
CancelURL: stripe.String(cancelURL),
|
||||
Customer: stripe.String(customerID),
|
||||
ClientReferenceID: stripe.String(orgID.String()),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{Price: stripe.String(priceID), Quantity: stripe.Int64(1)},
|
||||
},
|
||||
Metadata: map[string]string{
|
||||
"org_id": orgID.String(),
|
||||
"user_id": userID.String(),
|
||||
"tier": string(tier),
|
||||
"cycle": string(cycle),
|
||||
},
|
||||
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
|
||||
Metadata: map[string]string{
|
||||
"org_id": orgID.String(),
|
||||
"user_id": userID.String(),
|
||||
"tier": string(tier),
|
||||
"cycle": string(cycle),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
client := checkoutsession.Client{B: stripe.GetBackend(stripe.APIBackend), Key: s.stripeSecretKey}
|
||||
session, err := client.New(params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create checkout session: %w", err)
|
||||
}
|
||||
|
||||
return &CreateCheckoutSessionResponse{URL: session.URL}, nil
|
||||
}
|
||||
|
||||
func (s *CheckoutService) resolvePriceDetails(req CreateCheckoutSessionRequest) (string, planmodel.Tier, planmodel.BillingCycle, error) {
|
||||
if strings.TrimSpace(req.PriceID) != "" {
|
||||
tier, cycle, ok := s.catalog.ResolveTierAndCycleByPriceID(req.PriceID)
|
||||
if !ok {
|
||||
return "", "", "", &core.ValidationError{Field: "priceId", Message: "unsupported price id"}
|
||||
}
|
||||
return req.PriceID, tier, cycle, nil
|
||||
}
|
||||
|
||||
tier := planmodel.NormalizeTier(req.Tier)
|
||||
if tier == planmodel.TierFree {
|
||||
return "", "", "", &core.ValidationError{Field: "tier", Message: "free tier does not require checkout"}
|
||||
}
|
||||
|
||||
annual := req.Annual
|
||||
if strings.EqualFold(strings.TrimSpace(req.Cycle), string(planmodel.BillingCycleAnnual)) {
|
||||
annual = true
|
||||
}
|
||||
|
||||
priceID, err := s.catalog.GetPriceID(string(tier), annual)
|
||||
if err != nil {
|
||||
return "", "", "", &core.ValidationError{Field: "tier", Message: err.Error()}
|
||||
}
|
||||
|
||||
cycle := planmodel.BillingCycleMonthly
|
||||
if annual {
|
||||
cycle = planmodel.BillingCycleAnnual
|
||||
}
|
||||
|
||||
return priceID, tier, cycle, nil
|
||||
}
|
||||
|
||||
func (s *CheckoutService) createStripeCustomer(ctx context.Context, org *repository.Organization) (string, error) {
|
||||
params := &stripe.CustomerParams{
|
||||
Name: stripe.String(org.Name),
|
||||
Metadata: map[string]string{
|
||||
"org_id": org.ID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
client := stripecustomer.Client{B: stripe.GetBackend(stripe.APIBackend), Key: s.stripeSecretKey}
|
||||
cust, err := client.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return cust.ID, nil
|
||||
}
|
||||
160
internal/service/billing/limits.go
Normal file
160
internal/service/billing/limits.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
planmodel "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
)
|
||||
|
||||
type integrationLookup interface {
|
||||
GetByOrgAndProvider(ctx context.Context, orgID uuid.UUID, provider string) (*repository.IntegrationConnection, error)
|
||||
}
|
||||
|
||||
// LimitDecision is used for feature-gating responses and middleware decisions.
|
||||
type LimitDecision struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
CurrentPlan string `json:"current_plan"`
|
||||
LimitType string `json:"limit_type"`
|
||||
CurrentUsage int `json:"current_usage"`
|
||||
Limit int `json:"limit"`
|
||||
RecommendedUpgradeTier string `json:"recommended_upgrade_tier"`
|
||||
}
|
||||
|
||||
// FeatureDecision is used for plan-based feature access checks.
|
||||
type FeatureDecision struct {
|
||||
Allowed bool `json:"allowed"`
|
||||
CurrentPlan string `json:"current_plan"`
|
||||
Feature string `json:"feature"`
|
||||
RecommendedUpgradeTier string `json:"recommended_upgrade_tier"`
|
||||
}
|
||||
|
||||
// LimitsService handles server-side billing limits and feature access checks.
|
||||
type LimitsService struct {
|
||||
subscriptions *SubscriptionService
|
||||
customers customerCounter
|
||||
integrationCounter integrationCounter
|
||||
integrationLookup integrationLookup
|
||||
catalog *planmodel.Catalog
|
||||
}
|
||||
|
||||
func NewLimitsService(
|
||||
subscriptions *SubscriptionService,
|
||||
customers customerCounter,
|
||||
integrationCounter integrationCounter,
|
||||
integrationLookup integrationLookup,
|
||||
catalog *planmodel.Catalog,
|
||||
) *LimitsService {
|
||||
return &LimitsService{
|
||||
subscriptions: subscriptions,
|
||||
customers: customers,
|
||||
integrationCounter: integrationCounter,
|
||||
integrationLookup: integrationLookup,
|
||||
catalog: catalog,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LimitsService) CheckCustomerLimit(ctx context.Context, orgID uuid.UUID) (*LimitDecision, error) {
|
||||
tier, err := s.subscriptions.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
limits, ok := s.catalog.GetLimits(tier)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no limits configured for tier %s", tier)
|
||||
}
|
||||
|
||||
used, err := s.customers.CountByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count customers: %w", err)
|
||||
}
|
||||
|
||||
return s.buildLimitDecision(tier, "customer_limit", used, limits.CustomerLimit), nil
|
||||
}
|
||||
|
||||
func (s *LimitsService) CheckIntegrationLimit(ctx context.Context, orgID uuid.UUID, provider string) (*LimitDecision, error) {
|
||||
if provider != "" {
|
||||
conn, err := s.integrationLookup.GetByOrgAndProvider(ctx, orgID, provider)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get integration by provider: %w", err)
|
||||
}
|
||||
if conn != nil && conn.Status == "active" {
|
||||
tier, err := s.subscriptions.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LimitDecision{Allowed: true, CurrentPlan: tier}, nil
|
||||
}
|
||||
}
|
||||
|
||||
tier, err := s.subscriptions.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
limits, ok := s.catalog.GetLimits(tier)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no limits configured for tier %s", tier)
|
||||
}
|
||||
|
||||
used, err := s.integrationCounter.CountActiveByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count active integrations: %w", err)
|
||||
}
|
||||
|
||||
return s.buildLimitDecision(tier, "integration_limit", used, limits.IntegrationLimit), nil
|
||||
}
|
||||
|
||||
func (s *LimitsService) CanAccess(ctx context.Context, orgID uuid.UUID, featureName string) (*FeatureDecision, error) {
|
||||
tier, err := s.subscriptions.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plan, ok := s.catalog.GetPlanByTier(tier)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no plan configured for tier %s", tier)
|
||||
}
|
||||
|
||||
allowed := false
|
||||
switch featureName {
|
||||
case "playbooks":
|
||||
allowed = plan.Features.Playbooks
|
||||
case "ai_insights":
|
||||
allowed = plan.Features.AIInsights
|
||||
default:
|
||||
allowed = false
|
||||
}
|
||||
|
||||
decision := &FeatureDecision{
|
||||
Allowed: allowed,
|
||||
CurrentPlan: tier,
|
||||
Feature: featureName,
|
||||
}
|
||||
if !allowed {
|
||||
decision.RecommendedUpgradeTier = string(s.catalog.RecommendedUpgrade(tier))
|
||||
}
|
||||
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func (s *LimitsService) buildLimitDecision(tier, limitType string, used, limit int) *LimitDecision {
|
||||
decision := &LimitDecision{
|
||||
Allowed: true,
|
||||
CurrentPlan: tier,
|
||||
LimitType: limitType,
|
||||
CurrentUsage: used,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
if limit != planmodel.Unlimited && used >= limit {
|
||||
decision.Allowed = false
|
||||
decision.RecommendedUpgradeTier = string(s.catalog.RecommendedUpgrade(tier))
|
||||
}
|
||||
|
||||
return decision
|
||||
}
|
||||
148
internal/service/billing/portal.go
Normal file
148
internal/service/billing/portal.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
stripeportal "github.com/stripe/stripe-go/v81/billingportal/session"
|
||||
stripecustomer "github.com/stripe/stripe-go/v81/customer"
|
||||
stripesubscription "github.com/stripe/stripe-go/v81/subscription"
|
||||
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
core "github.com/onnwee/pulse-score/internal/service"
|
||||
)
|
||||
|
||||
type orgSubscriptionWriter interface {
|
||||
GetByOrg(ctx context.Context, orgID uuid.UUID) (*repository.OrgSubscription, error)
|
||||
UpsertByOrg(ctx context.Context, sub *repository.OrgSubscription) error
|
||||
}
|
||||
|
||||
// PortalSessionResponse returns hosted Stripe customer portal URL.
|
||||
type PortalSessionResponse struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// PortalService handles Stripe customer portal and cancellation operations.
|
||||
type PortalService struct {
|
||||
stripeSecretKey string
|
||||
portalReturnURL string
|
||||
frontendURL string
|
||||
orgs orgBillingRepository
|
||||
subscriptions orgSubscriptionWriter
|
||||
}
|
||||
|
||||
func NewPortalService(
|
||||
stripeSecretKey, portalReturnURL, frontendURL string,
|
||||
orgs orgBillingRepository,
|
||||
subscriptions orgSubscriptionWriter,
|
||||
) *PortalService {
|
||||
return &PortalService{
|
||||
stripeSecretKey: strings.TrimSpace(stripeSecretKey),
|
||||
portalReturnURL: strings.TrimSpace(portalReturnURL),
|
||||
frontendURL: strings.TrimRight(strings.TrimSpace(frontendURL), "/"),
|
||||
orgs: orgs,
|
||||
subscriptions: subscriptions,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PortalService) CreatePortalSession(ctx context.Context, orgID uuid.UUID) (*PortalSessionResponse, error) {
|
||||
if s.stripeSecretKey == "" {
|
||||
return nil, &core.ValidationError{Field: "billing", Message: "stripe billing is not configured"}
|
||||
}
|
||||
|
||||
org, err := s.orgs.GetByID(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get organization: %w", err)
|
||||
}
|
||||
if org == nil {
|
||||
return nil, &core.NotFoundError{Resource: "organization", Message: "organization not found"}
|
||||
}
|
||||
|
||||
customerID := strings.TrimSpace(org.StripeCustomerID)
|
||||
if customerID == "" {
|
||||
customerID, err = s.createStripeCustomer(ctx, org)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stripe customer: %w", err)
|
||||
}
|
||||
if err := s.orgs.UpdateStripeCustomerID(ctx, orgID, customerID); err != nil {
|
||||
return nil, fmt.Errorf("persist stripe customer id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
returnURL := s.portalReturnURL
|
||||
if returnURL == "" {
|
||||
returnURL = fmt.Sprintf("%s/settings/billing", s.frontendURL)
|
||||
}
|
||||
|
||||
params := &stripe.BillingPortalSessionParams{
|
||||
Customer: stripe.String(customerID),
|
||||
ReturnURL: stripe.String(returnURL),
|
||||
}
|
||||
|
||||
client := stripeportal.Client{B: stripe.GetBackend(stripe.APIBackend), Key: s.stripeSecretKey}
|
||||
portalSession, err := client.New(params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create portal session: %w", err)
|
||||
}
|
||||
|
||||
return &PortalSessionResponse{URL: portalSession.URL}, nil
|
||||
}
|
||||
|
||||
func (s *PortalService) CancelAtPeriodEnd(ctx context.Context, orgID uuid.UUID) error {
|
||||
if s.stripeSecretKey == "" {
|
||||
return &core.ValidationError{Field: "billing", Message: "stripe billing is not configured"}
|
||||
}
|
||||
|
||||
sub, err := s.subscriptions.GetByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get subscription: %w", err)
|
||||
}
|
||||
if sub == nil || strings.TrimSpace(sub.StripeSubscriptionID) == "" {
|
||||
return &core.NotFoundError{Resource: "subscription", Message: "no active Stripe subscription found"}
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{CancelAtPeriodEnd: stripe.Bool(true)}
|
||||
client := stripesubscription.Client{B: stripe.GetBackend(stripe.APIBackend), Key: s.stripeSecretKey}
|
||||
stripeSub, err := client.Update(sub.StripeSubscriptionID, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cancel subscription at period end: %w", err)
|
||||
}
|
||||
|
||||
sub.CancelAtPeriodEnd = true
|
||||
sub.Status = string(stripeSub.Status)
|
||||
if stripeSub.CurrentPeriodStart > 0 {
|
||||
t := time.Unix(stripeSub.CurrentPeriodStart, 0)
|
||||
sub.CurrentPeriodStart = &t
|
||||
}
|
||||
if stripeSub.CurrentPeriodEnd > 0 {
|
||||
t := time.Unix(stripeSub.CurrentPeriodEnd, 0)
|
||||
sub.CurrentPeriodEnd = &t
|
||||
}
|
||||
|
||||
if err := s.subscriptions.UpsertByOrg(ctx, sub); err != nil {
|
||||
return fmt.Errorf("persist subscription cancellation state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PortalService) createStripeCustomer(ctx context.Context, org *repository.Organization) (string, error) {
|
||||
params := &stripe.CustomerParams{
|
||||
Name: stripe.String(org.Name),
|
||||
Metadata: map[string]string{
|
||||
"org_id": org.ID.String(),
|
||||
},
|
||||
}
|
||||
|
||||
client := stripecustomer.Client{B: stripe.GetBackend(stripe.APIBackend), Key: s.stripeSecretKey}
|
||||
cust, err := client.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return cust.ID, nil
|
||||
}
|
||||
186
internal/service/billing/subscription.go
Normal file
186
internal/service/billing/subscription.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
planmodel "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
)
|
||||
|
||||
type orgSubscriptionReader interface {
|
||||
GetByOrg(ctx context.Context, orgID uuid.UUID) (*repository.OrgSubscription, error)
|
||||
}
|
||||
|
||||
type organizationReader interface {
|
||||
GetByID(ctx context.Context, id uuid.UUID) (*repository.Organization, error)
|
||||
}
|
||||
|
||||
type customerCounter interface {
|
||||
CountByOrg(ctx context.Context, orgID uuid.UUID) (int, error)
|
||||
}
|
||||
|
||||
type integrationCounter interface {
|
||||
CountActiveByOrg(ctx context.Context, orgID uuid.UUID) (int, error)
|
||||
}
|
||||
|
||||
// SubscriptionService exposes org-level subscription state and limit information.
|
||||
type SubscriptionService struct {
|
||||
subscriptions orgSubscriptionReader
|
||||
orgs organizationReader
|
||||
customers customerCounter
|
||||
integrations integrationCounter
|
||||
catalog *planmodel.Catalog
|
||||
}
|
||||
|
||||
// UsageSnapshot contains current usage against plan limits.
|
||||
type UsageSnapshot struct {
|
||||
Customers struct {
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
} `json:"customers"`
|
||||
Integrations struct {
|
||||
Used int `json:"used"`
|
||||
Limit int `json:"limit"`
|
||||
} `json:"integrations"`
|
||||
}
|
||||
|
||||
// SubscriptionSummary is returned by GET /api/v1/billing/subscription.
|
||||
type SubscriptionSummary struct {
|
||||
Tier string `json:"tier"`
|
||||
Status string `json:"status"`
|
||||
BillingCycle string `json:"billing_cycle"`
|
||||
RenewalDate *time.Time `json:"renewal_date"`
|
||||
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
||||
Usage UsageSnapshot `json:"usage"`
|
||||
Features map[string]any `json:"features"`
|
||||
}
|
||||
|
||||
func NewSubscriptionService(
|
||||
subscriptions orgSubscriptionReader,
|
||||
orgs organizationReader,
|
||||
customers customerCounter,
|
||||
integrations integrationCounter,
|
||||
catalog *planmodel.Catalog,
|
||||
) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
subscriptions: subscriptions,
|
||||
orgs: orgs,
|
||||
customers: customers,
|
||||
integrations: integrations,
|
||||
catalog: catalog,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCurrentPlan resolves the current tier, falling back to org.plan then free.
|
||||
func (s *SubscriptionService) GetCurrentPlan(ctx context.Context, orgID uuid.UUID) (string, error) {
|
||||
sub, err := s.subscriptions.GetByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get org subscription: %w", err)
|
||||
}
|
||||
if sub != nil && strings.TrimSpace(sub.PlanTier) != "" {
|
||||
return string(planmodel.NormalizeTier(sub.PlanTier)), nil
|
||||
}
|
||||
|
||||
org, err := s.orgs.GetByID(ctx, orgID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get organization: %w", err)
|
||||
}
|
||||
if org != nil && strings.TrimSpace(org.Plan) != "" {
|
||||
return string(planmodel.NormalizeTier(org.Plan)), nil
|
||||
}
|
||||
|
||||
return string(planmodel.TierFree), nil
|
||||
}
|
||||
|
||||
// IsActive reports whether the org subscription status is currently active.
|
||||
func (s *SubscriptionService) IsActive(ctx context.Context, orgID uuid.UUID) (bool, error) {
|
||||
sub, err := s.subscriptions.GetByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get org subscription: %w", err)
|
||||
}
|
||||
if sub == nil {
|
||||
return true, nil // Free tier fallback should remain usable.
|
||||
}
|
||||
|
||||
status := strings.ToLower(strings.TrimSpace(sub.Status))
|
||||
switch status {
|
||||
case "active", "trialing", "past_due":
|
||||
return true, nil
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GetUsageLimits resolves the current usage limits from the plan catalog.
|
||||
func (s *SubscriptionService) GetUsageLimits(ctx context.Context, orgID uuid.UUID) (planmodel.UsageLimits, error) {
|
||||
tier, err := s.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return planmodel.UsageLimits{}, err
|
||||
}
|
||||
|
||||
limits, ok := s.catalog.GetLimits(tier)
|
||||
if !ok {
|
||||
return planmodel.UsageLimits{}, fmt.Errorf("no limits configured for tier %s", tier)
|
||||
}
|
||||
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
// GetSubscriptionSummary returns the current subscription state and usage counters.
|
||||
func (s *SubscriptionService) GetSubscriptionSummary(ctx context.Context, orgID uuid.UUID) (*SubscriptionSummary, error) {
|
||||
tier, err := s.GetCurrentPlan(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
limits, err := s.GetUsageLimits(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
customerCount, err := s.customers.CountByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count customers: %w", err)
|
||||
}
|
||||
integrationCount, err := s.integrations.CountActiveByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("count integrations: %w", err)
|
||||
}
|
||||
|
||||
summary := &SubscriptionSummary{
|
||||
Tier: tier,
|
||||
Status: "free",
|
||||
BillingCycle: string(planmodel.BillingCycleMonthly),
|
||||
Features: map[string]any{},
|
||||
}
|
||||
|
||||
summary.Usage.Customers.Used = customerCount
|
||||
summary.Usage.Customers.Limit = limits.CustomerLimit
|
||||
summary.Usage.Integrations.Used = integrationCount
|
||||
summary.Usage.Integrations.Limit = limits.IntegrationLimit
|
||||
|
||||
if plan, ok := s.catalog.GetPlanByTier(tier); ok {
|
||||
summary.Features["playbooks"] = plan.Features.Playbooks
|
||||
summary.Features["ai_insights"] = plan.Features.AIInsights
|
||||
}
|
||||
|
||||
sub, err := s.subscriptions.GetByOrg(ctx, orgID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get org subscription: %w", err)
|
||||
}
|
||||
if sub != nil {
|
||||
summary.Status = sub.Status
|
||||
if strings.TrimSpace(sub.BillingCycle) != "" {
|
||||
summary.BillingCycle = sub.BillingCycle
|
||||
}
|
||||
summary.RenewalDate = sub.CurrentPeriodEnd
|
||||
summary.CancelAtPeriodEnd = sub.CancelAtPeriodEnd
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
102
internal/service/billing/subscription_test.go
Normal file
102
internal/service/billing/subscription_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
planmodel "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
)
|
||||
|
||||
type mockOrgSubscriptionReader struct {
|
||||
getByOrgFn func(ctx context.Context, orgID uuid.UUID) (*repository.OrgSubscription, error)
|
||||
}
|
||||
|
||||
func (m *mockOrgSubscriptionReader) GetByOrg(ctx context.Context, orgID uuid.UUID) (*repository.OrgSubscription, error) {
|
||||
return m.getByOrgFn(ctx, orgID)
|
||||
}
|
||||
|
||||
type mockOrganizationReader struct {
|
||||
getByIDFn func(ctx context.Context, id uuid.UUID) (*repository.Organization, error)
|
||||
}
|
||||
|
||||
func (m *mockOrganizationReader) GetByID(ctx context.Context, id uuid.UUID) (*repository.Organization, error) {
|
||||
return m.getByIDFn(ctx, id)
|
||||
}
|
||||
|
||||
type mockCustomerCounter struct {
|
||||
countByOrgFn func(ctx context.Context, orgID uuid.UUID) (int, error)
|
||||
}
|
||||
|
||||
func (m *mockCustomerCounter) CountByOrg(ctx context.Context, orgID uuid.UUID) (int, error) {
|
||||
return m.countByOrgFn(ctx, orgID)
|
||||
}
|
||||
|
||||
type mockIntegrationCounter struct {
|
||||
countActiveByOrgFn func(ctx context.Context, orgID uuid.UUID) (int, error)
|
||||
}
|
||||
|
||||
func (m *mockIntegrationCounter) CountActiveByOrg(ctx context.Context, orgID uuid.UUID) (int, error) {
|
||||
return m.countActiveByOrgFn(ctx, orgID)
|
||||
}
|
||||
|
||||
func TestGetCurrentPlan_DefaultsToFreeWhenNoSubscriptionRow(t *testing.T) {
|
||||
svc := NewSubscriptionService(
|
||||
&mockOrgSubscriptionReader{getByOrgFn: func(context.Context, uuid.UUID) (*repository.OrgSubscription, error) {
|
||||
return nil, nil
|
||||
}},
|
||||
&mockOrganizationReader{getByIDFn: func(context.Context, uuid.UUID) (*repository.Organization, error) {
|
||||
return nil, nil
|
||||
}},
|
||||
&mockCustomerCounter{countByOrgFn: func(context.Context, uuid.UUID) (int, error) { return 0, nil }},
|
||||
&mockIntegrationCounter{countActiveByOrgFn: func(context.Context, uuid.UUID) (int, error) { return 0, nil }},
|
||||
planmodel.NewCatalog(planmodel.PriceConfig{}),
|
||||
)
|
||||
|
||||
plan, err := svc.GetCurrentPlan(context.Background(), uuid.New())
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if plan != string(planmodel.TierFree) {
|
||||
t.Fatalf("expected free plan fallback, got %s", plan)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsActiveStatusTransitions(t *testing.T) {
|
||||
orgID := uuid.New()
|
||||
|
||||
statuses := map[string]bool{
|
||||
"active": true,
|
||||
"trialing": true,
|
||||
"past_due": true,
|
||||
"canceled": false,
|
||||
"unpaid": false,
|
||||
"incomplete": false,
|
||||
}
|
||||
|
||||
for status, expected := range statuses {
|
||||
t.Run(status, func(t *testing.T) {
|
||||
svc := NewSubscriptionService(
|
||||
&mockOrgSubscriptionReader{getByOrgFn: func(context.Context, uuid.UUID) (*repository.OrgSubscription, error) {
|
||||
return &repository.OrgSubscription{OrgID: orgID, Status: status, PlanTier: "growth"}, nil
|
||||
}},
|
||||
&mockOrganizationReader{getByIDFn: func(context.Context, uuid.UUID) (*repository.Organization, error) {
|
||||
return &repository.Organization{ID: orgID, Plan: "growth"}, nil
|
||||
}},
|
||||
&mockCustomerCounter{countByOrgFn: func(context.Context, uuid.UUID) (int, error) { return 0, nil }},
|
||||
&mockIntegrationCounter{countActiveByOrgFn: func(context.Context, uuid.UUID) (int, error) { return 0, nil }},
|
||||
planmodel.NewCatalog(planmodel.PriceConfig{}),
|
||||
)
|
||||
|
||||
active, err := svc.IsActive(context.Background(), orgID)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if active != expected {
|
||||
t.Fatalf("expected active=%v for status=%s, got %v", expected, status, active)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
277
internal/service/billing/webhook.go
Normal file
277
internal/service/billing/webhook.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stripe/stripe-go/v81"
|
||||
"github.com/stripe/stripe-go/v81/webhook"
|
||||
|
||||
planmodel "github.com/onnwee/pulse-score/internal/billing"
|
||||
"github.com/onnwee/pulse-score/internal/repository"
|
||||
core "github.com/onnwee/pulse-score/internal/service"
|
||||
)
|
||||
|
||||
// WebhookService handles PulseScore billing Stripe webhook events.
|
||||
type WebhookService struct {
|
||||
webhookSecret string
|
||||
pool *pgxpool.Pool
|
||||
orgs *repository.OrganizationRepository
|
||||
subscriptions *repository.OrgSubscriptionRepository
|
||||
processed *repository.BillingWebhookEventRepository
|
||||
catalog *planmodel.Catalog
|
||||
}
|
||||
|
||||
func NewWebhookService(
|
||||
webhookSecret string,
|
||||
pool *pgxpool.Pool,
|
||||
orgs *repository.OrganizationRepository,
|
||||
subscriptions *repository.OrgSubscriptionRepository,
|
||||
processed *repository.BillingWebhookEventRepository,
|
||||
catalog *planmodel.Catalog,
|
||||
) *WebhookService {
|
||||
return &WebhookService{
|
||||
webhookSecret: strings.TrimSpace(webhookSecret),
|
||||
pool: pool,
|
||||
orgs: orgs,
|
||||
subscriptions: subscriptions,
|
||||
processed: processed,
|
||||
catalog: catalog,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebhookService) HandleEvent(ctx context.Context, payload []byte, sigHeader string) error {
|
||||
if s.webhookSecret == "" {
|
||||
return &core.ValidationError{Field: "billing", Message: "stripe billing webhook secret is not configured"}
|
||||
}
|
||||
|
||||
event, err := webhook.ConstructEvent(payload, sigHeader, s.webhookSecret)
|
||||
if err != nil {
|
||||
return &core.ValidationError{Field: "signature", Message: "invalid webhook signature"}
|
||||
}
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin webhook tx: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
inserted, err := s.processed.MarkProcessedTx(ctx, tx, event.ID, string(event.Type))
|
||||
if err != nil {
|
||||
return fmt.Errorf("mark webhook event processed: %w", err)
|
||||
}
|
||||
if !inserted {
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "checkout.session.completed":
|
||||
err = s.handleCheckoutSessionCompleted(ctx, tx, event)
|
||||
case "customer.subscription.created", "customer.subscription.updated":
|
||||
err = s.handleSubscriptionUpsert(ctx, tx, event, false)
|
||||
case "customer.subscription.deleted":
|
||||
err = s.handleSubscriptionUpsert(ctx, tx, event, true)
|
||||
case "invoice.payment_succeeded":
|
||||
err = s.handleInvoiceStatusUpdate(ctx, tx, event, "active")
|
||||
case "invoice.payment_failed":
|
||||
err = s.handleInvoiceStatusUpdate(ctx, tx, event, "past_due")
|
||||
default:
|
||||
// Ignore unsupported events while still recording idempotency.
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (s *WebhookService) handleCheckoutSessionCompleted(ctx context.Context, tx pgx.Tx, event stripe.Event) error {
|
||||
var session stripe.CheckoutSession
|
||||
if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
|
||||
return fmt.Errorf("unmarshal checkout session: %w", err)
|
||||
}
|
||||
|
||||
orgID, err := resolveOrgIDFromCheckoutSession(&session)
|
||||
if err != nil {
|
||||
return nil // allow replay/reconciliation by subsequent subscription events
|
||||
}
|
||||
|
||||
if session.Customer != nil && session.Customer.ID != "" {
|
||||
if err := s.orgs.UpdateStripeCustomerIDTx(ctx, tx, orgID, session.Customer.ID); err != nil {
|
||||
return fmt.Errorf("update org stripe customer id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if tier := strings.TrimSpace(session.Metadata["tier"]); tier != "" {
|
||||
if err := s.orgs.UpdatePlanTx(ctx, tx, orgID, string(planmodel.NormalizeTier(tier))); err != nil {
|
||||
return fmt.Errorf("update org plan from checkout metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebhookService) handleSubscriptionUpsert(ctx context.Context, tx pgx.Tx, event stripe.Event, forceFreePlan bool) error {
|
||||
var stripeSub stripe.Subscription
|
||||
if err := json.Unmarshal(event.Data.Raw, &stripeSub); err != nil {
|
||||
return fmt.Errorf("unmarshal subscription: %w", err)
|
||||
}
|
||||
|
||||
orgID, err := s.resolveOrgIDForSubscription(ctx, &stripeSub)
|
||||
if err != nil {
|
||||
return nil // defer to replay once org can be resolved
|
||||
}
|
||||
|
||||
tier, cycle := s.resolveTierAndCycleForSubscription(&stripeSub)
|
||||
if forceFreePlan {
|
||||
tier = planmodel.TierFree
|
||||
cycle = planmodel.BillingCycleMonthly
|
||||
}
|
||||
|
||||
localSub := &repository.OrgSubscription{
|
||||
OrgID: orgID,
|
||||
StripeSubscriptionID: stripeSub.ID,
|
||||
PlanTier: string(tier),
|
||||
BillingCycle: string(cycle),
|
||||
Status: string(stripeSub.Status),
|
||||
CancelAtPeriodEnd: stripeSub.CancelAtPeriodEnd,
|
||||
}
|
||||
if stripeSub.Customer != nil {
|
||||
localSub.StripeCustomerID = stripeSub.Customer.ID
|
||||
}
|
||||
if stripeSub.CurrentPeriodStart > 0 {
|
||||
t := time.Unix(stripeSub.CurrentPeriodStart, 0)
|
||||
localSub.CurrentPeriodStart = &t
|
||||
}
|
||||
if stripeSub.CurrentPeriodEnd > 0 {
|
||||
t := time.Unix(stripeSub.CurrentPeriodEnd, 0)
|
||||
localSub.CurrentPeriodEnd = &t
|
||||
}
|
||||
|
||||
if err := s.subscriptions.UpsertByOrgTx(ctx, tx, localSub); err != nil {
|
||||
return fmt.Errorf("upsert org subscription: %w", err)
|
||||
}
|
||||
|
||||
if localSub.StripeCustomerID != "" {
|
||||
if err := s.orgs.UpdateStripeCustomerIDTx(ctx, tx, orgID, localSub.StripeCustomerID); err != nil {
|
||||
return fmt.Errorf("update org stripe customer id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.orgs.UpdatePlanTx(ctx, tx, orgID, string(tier)); err != nil {
|
||||
return fmt.Errorf("update org plan from subscription: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebhookService) handleInvoiceStatusUpdate(ctx context.Context, tx pgx.Tx, event stripe.Event, nextStatus string) error {
|
||||
var inv stripe.Invoice
|
||||
if err := json.Unmarshal(event.Data.Raw, &inv); err != nil {
|
||||
return fmt.Errorf("unmarshal invoice: %w", err)
|
||||
}
|
||||
|
||||
var sub *repository.OrgSubscription
|
||||
var err error
|
||||
if inv.Subscription != nil && inv.Subscription.ID != "" {
|
||||
sub, err = s.subscriptions.GetByStripeSubscriptionID(ctx, inv.Subscription.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get org subscription by stripe subscription id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if sub == nil && inv.Customer != nil && inv.Customer.ID != "" {
|
||||
sub, err = s.subscriptions.GetByStripeCustomerID(ctx, inv.Customer.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get org subscription by stripe customer id: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if sub == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sub.Status = nextStatus
|
||||
if err := s.subscriptions.UpsertByOrgTx(ctx, tx, sub); err != nil {
|
||||
return fmt.Errorf("persist invoice status update: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebhookService) resolveOrgIDForSubscription(ctx context.Context, stripeSub *stripe.Subscription) (uuid.UUID, error) {
|
||||
if stripeSub.Metadata != nil {
|
||||
if rawOrgID := strings.TrimSpace(stripeSub.Metadata["org_id"]); rawOrgID != "" {
|
||||
orgID, err := uuid.Parse(rawOrgID)
|
||||
if err == nil {
|
||||
return orgID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stripeSub.Customer != nil && stripeSub.Customer.ID != "" {
|
||||
org, err := s.orgs.GetByStripeCustomerID(ctx, stripeSub.Customer.ID)
|
||||
if err != nil {
|
||||
return uuid.Nil, fmt.Errorf("find org by stripe customer id: %w", err)
|
||||
}
|
||||
if org != nil {
|
||||
return org.ID, nil
|
||||
}
|
||||
}
|
||||
|
||||
return uuid.Nil, fmt.Errorf("could not resolve org for stripe subscription %s", stripeSub.ID)
|
||||
}
|
||||
|
||||
func (s *WebhookService) resolveTierAndCycleForSubscription(stripeSub *stripe.Subscription) (planmodel.Tier, planmodel.BillingCycle) {
|
||||
tier := planmodel.TierFree
|
||||
cycle := planmodel.BillingCycleMonthly
|
||||
|
||||
if stripeSub.Metadata != nil {
|
||||
if rawTier := strings.TrimSpace(stripeSub.Metadata["tier"]); rawTier != "" {
|
||||
tier = planmodel.NormalizeTier(rawTier)
|
||||
}
|
||||
if rawCycle := strings.TrimSpace(stripeSub.Metadata["cycle"]); strings.EqualFold(rawCycle, string(planmodel.BillingCycleAnnual)) {
|
||||
cycle = planmodel.BillingCycleAnnual
|
||||
}
|
||||
}
|
||||
|
||||
if stripeSub.Items != nil {
|
||||
for _, item := range stripeSub.Items.Data {
|
||||
if item.Price == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if mappedTier, mappedCycle, ok := s.catalog.ResolveTierAndCycleByPriceID(item.Price.ID); ok {
|
||||
tier = mappedTier
|
||||
cycle = mappedCycle
|
||||
break
|
||||
}
|
||||
|
||||
if item.Price.Recurring != nil && item.Price.Recurring.Interval == "year" {
|
||||
cycle = planmodel.BillingCycleAnnual
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tier, cycle
|
||||
}
|
||||
|
||||
func resolveOrgIDFromCheckoutSession(session *stripe.CheckoutSession) (uuid.UUID, error) {
|
||||
if session.Metadata != nil {
|
||||
if rawOrgID := strings.TrimSpace(session.Metadata["org_id"]); rawOrgID != "" {
|
||||
return uuid.Parse(rawOrgID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(session.ClientReferenceID) != "" {
|
||||
return uuid.Parse(session.ClientReferenceID)
|
||||
}
|
||||
|
||||
return uuid.Nil, fmt.Errorf("missing org metadata")
|
||||
}
|
||||
2
migrations/000021_create_org_subscriptions.down.sql
Normal file
2
migrations/000021_create_org_subscriptions.down.sql
Normal file
@@ -0,0 +1,2 @@
|
||||
DROP TRIGGER IF EXISTS set_org_subscriptions_updated_at ON org_subscriptions;
|
||||
DROP TABLE IF EXISTS org_subscriptions;
|
||||
25
migrations/000021_create_org_subscriptions.up.sql
Normal file
25
migrations/000021_create_org_subscriptions.up.sql
Normal file
@@ -0,0 +1,25 @@
|
||||
CREATE TABLE IF NOT EXISTS org_subscriptions (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
org_id UUID NOT NULL REFERENCES organizations (id) ON DELETE CASCADE,
|
||||
stripe_subscription_id VARCHAR(255) UNIQUE,
|
||||
stripe_customer_id VARCHAR(255),
|
||||
plan_tier VARCHAR(50) NOT NULL DEFAULT 'free',
|
||||
billing_cycle VARCHAR(20) NOT NULL DEFAULT 'monthly'
|
||||
CHECK (billing_cycle IN ('monthly', 'annual')),
|
||||
status VARCHAR(50) NOT NULL DEFAULT 'inactive',
|
||||
current_period_start TIMESTAMPTZ,
|
||||
current_period_end TIMESTAMPTZ,
|
||||
cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
|
||||
UNIQUE (org_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_org_subscriptions_org_status ON org_subscriptions (org_id, status);
|
||||
CREATE INDEX idx_org_subscriptions_customer ON org_subscriptions (stripe_customer_id) WHERE stripe_customer_id IS NOT NULL;
|
||||
|
||||
CREATE TRIGGER set_org_subscriptions_updated_at
|
||||
BEFORE UPDATE ON org_subscriptions
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION trigger_set_updated_at();
|
||||
1
migrations/000022_create_billing_webhook_events.down.sql
Normal file
1
migrations/000022_create_billing_webhook_events.down.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS billing_webhook_events;
|
||||
7
migrations/000022_create_billing_webhook_events.up.sql
Normal file
7
migrations/000022_create_billing_webhook_events.up.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS billing_webhook_events (
|
||||
event_id VARCHAR(255) PRIMARY KEY,
|
||||
event_type VARCHAR(120) NOT NULL,
|
||||
processed_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_billing_webhook_events_processed_at ON billing_webhook_events (processed_at DESC);
|
||||
335
plans/epic12-billing-subscription.md
Normal file
335
plans/epic12-billing-subscription.md
Normal file
@@ -0,0 +1,335 @@
|
||||
# Execution Plan: Epic 12 — Billing & Subscription (#20)
|
||||
|
||||
## Overview
|
||||
|
||||
**Epic:** [#20 — Billing & Subscription](https://github.com/subculture-collective/pulse-score/issues/20)
|
||||
**Sub-issues:** #153–#161 (9 issues)
|
||||
**Scope:** Implement PulseScore-owned Stripe billing (separate from Stripe data integration): plan catalog + Checkout + billing webhooks + subscription tracking + feature gating + pricing/checkout/subscription management UI + Stripe Customer Portal.
|
||||
|
||||
---
|
||||
|
||||
## Current State (What already exists)
|
||||
|
||||
- **Organization billing fields already exist:** `organizations.plan` and `organizations.stripe_customer_id` in migration `000002_create_organizations.up.sql`.
|
||||
- **Stripe integration is already implemented for customer-data sync**, not SaaS billing:
|
||||
- OAuth + sync endpoints in `internal/handler/integration_stripe.go` and `cmd/api/main.go`
|
||||
- Webhook endpoint `POST /api/v1/webhooks/stripe` handled by `internal/service/stripe_webhook.go`
|
||||
- Stripe subscription/payment repositories currently model **customer telemetry data** (`stripe_subscriptions`, `stripe_payments`) in `internal/repository/stripe_data.go`
|
||||
- **Frontend billing surface is a placeholder** in `web/src/pages/settings/BillingTab.tsx`.
|
||||
- **Landing pricing section exists** in `web/src/components/landing/PricingSection.tsx`, but currently static and not wired to checkout flow.
|
||||
|
||||
### Key Gap
|
||||
|
||||
Epic 12 requires a **new billing domain** for PulseScore product subscriptions. This must remain isolated from existing Stripe data-integration services to avoid token/config confusion.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Guardrails for Epic 12
|
||||
|
||||
1. **Separate Stripe configs for billing vs integration**
|
||||
- Keep existing `StripeConfig` for integration OAuth/data sync untouched.
|
||||
- Add billing-specific config/env (e.g., `STRIPE_BILLING_SECRET_KEY`, `STRIPE_BILLING_PUBLISHABLE_KEY`, `STRIPE_BILLING_WEBHOOK_SECRET`).
|
||||
|
||||
2. **Separate webhook endpoint**
|
||||
- Keep `POST /api/v1/webhooks/stripe` for integration events.
|
||||
- Add `POST /api/v1/webhooks/stripe-billing` for billing lifecycle events.
|
||||
|
||||
3. **Dedicated billing service layer**
|
||||
- Create billing domain under `internal/service/billing/*` and `internal/repository/*billing*`.
|
||||
|
||||
4. **Plan limits are enforced server-side**
|
||||
- UI hints are informational only.
|
||||
- Middleware/service checks are the source of truth for customer/integration limits.
|
||||
|
||||
---
|
||||
|
||||
## Dependency Graph
|
||||
|
||||
```text
|
||||
#153 Billing plans catalog/config
|
||||
├──► #154 Checkout session endpoint
|
||||
├──► #157 Feature-gating limits
|
||||
└──► #158 Pricing page
|
||||
|
||||
#154 Checkout
|
||||
└──► #155 Billing webhook handler
|
||||
|
||||
#155 Billing webhook handler
|
||||
└──► #156 Subscription tracking model/service
|
||||
|
||||
#156 Subscription tracking
|
||||
├──► #157 Feature gating middleware
|
||||
├──► #160 Subscription management page
|
||||
└──► #161 Portal session endpoint
|
||||
|
||||
#158 Pricing page
|
||||
└──► #159 Frontend checkout flow
|
||||
|
||||
#154 + #159 + #161 + #156
|
||||
└──► #160 Subscription management page
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Execution Phases
|
||||
|
||||
### Phase 1 — Billing Foundations (Plan catalog + subscription schema)
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#153](https://github.com/subculture-collective/pulse-score/issues/153) | Configure Stripe billing products and prices | critical | `internal/billing/plans.go`, `internal/config/config.go`, `.env.example`, docs under `docs/` |
|
||||
| [#156](https://github.com/subculture-collective/pulse-score/issues/156) | Implement subscription tracking model | high | `migrations/000021_*`, `internal/repository/org_subscription.go`, `internal/service/billing/subscription.go` |
|
||||
|
||||
#### #153 Implementation details
|
||||
|
||||
1. Create `internal/billing/plans.go` as canonical plan definitions:
|
||||
- `free`, `growth`, `scale`
|
||||
- Monthly + annual Stripe price IDs
|
||||
- Limits: `customer_limit`, `integration_limit`
|
||||
- Feature flags (e.g., playbooks, ai insights)
|
||||
2. Add billing Stripe config section in `internal/config/config.go`:
|
||||
- `BillingStripeSecretKey`
|
||||
- `BillingStripePublishableKey`
|
||||
- `BillingStripeWebhookSecret`
|
||||
- Optional: `BillingStripePortalReturnURL`
|
||||
3. Add startup validation so app fails fast if billing keys are missing in non-dev environments.
|
||||
4. Add a small verification utility/service to ensure configured Stripe price IDs exist and carry required metadata (`tier`, `customer_limit`, `integration_limit`).
|
||||
|
||||
**Tests**
|
||||
- Unit tests for plan lookup and limit access (`GetPlanByTier`, `GetLimits`, annual/monthly mapping)
|
||||
- Config loading tests for new billing env vars
|
||||
|
||||
#### #156 Implementation details
|
||||
|
||||
1. Add new migration for org-level subscription state (recommended table `org_subscriptions`):
|
||||
- `org_id` (unique)
|
||||
- `stripe_subscription_id`
|
||||
- `stripe_customer_id`
|
||||
- `plan_tier`
|
||||
- `status`
|
||||
- `current_period_start`, `current_period_end`
|
||||
- `cancel_at_period_end`
|
||||
- `created_at`, `updated_at`
|
||||
2. Add repository `internal/repository/org_subscription.go` with:
|
||||
- `UpsertByOrg`
|
||||
- `GetByOrg`
|
||||
- `GetByStripeSubscriptionID`
|
||||
3. Add service `internal/service/billing/subscription.go` with query helpers:
|
||||
- `GetCurrentPlan(orgID)`
|
||||
- `IsActive(orgID)`
|
||||
- `GetUsageLimits(orgID)`
|
||||
4. Default new organizations to Free tier if no subscription row exists.
|
||||
|
||||
**Tests**
|
||||
- Repository CRUD/upsert tests
|
||||
- Service tests for default free fallback and status transitions
|
||||
|
||||
---
|
||||
|
||||
### Phase 2 — Checkout + Portal backend APIs
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#154](https://github.com/subculture-collective/pulse-score/issues/154) | Implement Stripe Checkout session creation | critical | `internal/service/billing/checkout.go`, `internal/handler/billing.go`, `cmd/api/main.go` |
|
||||
| [#161](https://github.com/subculture-collective/pulse-score/issues/161) | Implement Stripe Customer Portal redirect | high | `internal/service/billing/portal.go`, `internal/handler/billing.go`, `cmd/api/main.go` |
|
||||
|
||||
#### #154 Implementation details
|
||||
|
||||
1. Add `POST /api/v1/billing/checkout` (admin-gated).
|
||||
2. Request: `{ priceId: string, annual: boolean }` (or `{ tier, cycle }` mapped server-side to configured price ID).
|
||||
3. Ensure org has Stripe customer:
|
||||
- Reuse `organizations.stripe_customer_id` if present.
|
||||
- Create new Stripe customer if absent and persist back.
|
||||
4. Create Stripe Checkout Session (`mode=subscription`) with metadata:
|
||||
- `org_id`, `user_id`, `tier`, `cycle`.
|
||||
5. Return `{ url }` for frontend redirect.
|
||||
|
||||
**Tests**
|
||||
- Handler tests for validation/authz
|
||||
- Service tests for customer-creation path and session creation path
|
||||
|
||||
#### #161 Implementation details
|
||||
|
||||
1. Add `POST /api/v1/billing/portal-session`.
|
||||
2. Ensure Stripe customer exists before creating portal session.
|
||||
3. Return hosted portal URL with return path `/settings/billing`.
|
||||
4. Expose in billing service for UI integration.
|
||||
|
||||
**Tests**
|
||||
- Portal session created with right customer
|
||||
- Error path when no customer + Stripe creation failure
|
||||
|
||||
---
|
||||
|
||||
### Phase 3 — Billing webhook and subscription state sync
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#155](https://github.com/subculture-collective/pulse-score/issues/155) | Implement billing webhook handler | critical | `internal/service/billing/webhook.go`, `internal/handler/webhook_billing.go`, `cmd/api/main.go` |
|
||||
|
||||
#### #155 Implementation details
|
||||
|
||||
1. Add public endpoint `POST /api/v1/webhooks/stripe-billing`.
|
||||
2. Verify billing webhook signatures with billing webhook secret (not integration webhook secret).
|
||||
3. Handle events:
|
||||
- `checkout.session.completed`
|
||||
- `customer.subscription.created`
|
||||
- `customer.subscription.updated`
|
||||
- `customer.subscription.deleted`
|
||||
- `invoice.payment_succeeded`
|
||||
- `invoice.payment_failed`
|
||||
4. Update `org_subscriptions` and `organizations.plan` in one transaction where possible.
|
||||
5. On `payment_failed`, flag status and trigger notification hooks.
|
||||
6. Add idempotency guard (event ID dedupe table or durable event log) to safely replay events.
|
||||
|
||||
**Tests**
|
||||
- Signature verification success/failure
|
||||
- Idempotency checks
|
||||
- Event-specific state transition tests
|
||||
|
||||
---
|
||||
|
||||
### Phase 4 — Enforce plan limits in backend
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#157](https://github.com/subculture-collective/pulse-score/issues/157) | Implement feature gating middleware | critical | `internal/middleware/feature_gate.go`, `internal/service/billing/limits.go`, selected handlers in `internal/handler/*` |
|
||||
|
||||
#### #157 Implementation details
|
||||
|
||||
1. Implement tier limit checks for:
|
||||
- Customer creation operations (max customers)
|
||||
- Integration connection operations (max integrations)
|
||||
2. Add feature access checks:
|
||||
- `CanAccess(orgID, featureName)`
|
||||
3. Return `402 Payment Required` with structured payload:
|
||||
- current plan
|
||||
- limit reached
|
||||
- recommended upgrade tier
|
||||
4. Apply middleware/checks to relevant routes/handlers:
|
||||
- customer create/upsert paths
|
||||
- integration connect endpoints
|
||||
5. Ensure downgrade behavior enforces Free limits for **new actions** while preserving existing over-limit data until user acts.
|
||||
|
||||
**Tests**
|
||||
- Per-tier customer/integration limit tests
|
||||
- Feature gate allow/deny matrix tests
|
||||
- Proper 402 payload contract tests
|
||||
|
||||
---
|
||||
|
||||
### Phase 5 — Pricing page + checkout UX
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#158](https://github.com/subculture-collective/pulse-score/issues/158) | Build pricing page component | high | `web/src/pages/PricingPage.tsx`, `web/src/lib/api.ts`, `web/src/components/landing/PricingSection.tsx` |
|
||||
| [#159](https://github.com/subculture-collective/pulse-score/issues/159) | Build checkout flow and success handling | high | `web/src/hooks/useCheckout.ts`, `web/src/pages/settings/BillingTab.tsx`, router integration |
|
||||
|
||||
#### #158 Implementation details
|
||||
|
||||
1. Create a reusable pricing component/page with Free/Growth/Scale comparison.
|
||||
2. Add monthly/annual toggle and savings badge.
|
||||
3. Highlight current plan for authenticated orgs.
|
||||
4. CTA buttons call checkout hook/API for paid plans.
|
||||
5. Optionally replace static landing `PricingSection` internals with shared pricing model to avoid drift.
|
||||
|
||||
**Tests**
|
||||
- Toggle behavior tests
|
||||
- Current plan highlighting tests
|
||||
- CTA callback wiring tests
|
||||
|
||||
#### #159 Implementation details
|
||||
|
||||
1. Implement `useCheckout` hook:
|
||||
- Loading state
|
||||
- API call to `/billing/checkout`
|
||||
- Redirect to Stripe URL
|
||||
2. Handle return states in billing settings page:
|
||||
- `?checkout=success` → success toast + refresh subscription
|
||||
- `?checkout=cancelled` → informative notice
|
||||
3. Refresh org subscription data after successful return.
|
||||
|
||||
**Tests**
|
||||
- Hook success/failure/loading tests
|
||||
- Success/cancel query param behavior tests
|
||||
|
||||
---
|
||||
|
||||
### Phase 6 — Subscription management UI
|
||||
|
||||
| Issue | Title | Priority | Files to Create/Modify |
|
||||
|------|------|------|------|
|
||||
| [#160](https://github.com/subculture-collective/pulse-score/issues/160) | Build subscription management page | high | `web/src/components/billing/SubscriptionManager.tsx`, `web/src/pages/settings/BillingTab.tsx`, `web/src/lib/api.ts`, backend billing handler for `GET /billing/subscription` |
|
||||
|
||||
#### #160 Implementation details
|
||||
|
||||
1. Replace placeholder `BillingTab` with real manager component.
|
||||
2. Add `GET /api/v1/billing/subscription` backend response:
|
||||
- current tier, status, billing cycle, renewal date
|
||||
- usage counters (customers/integrations) + limits
|
||||
3. UI actions:
|
||||
- Change plan (launch checkout)
|
||||
- Cancel at period end (confirmation modal)
|
||||
- Open Stripe portal (`POST /billing/portal-session`)
|
||||
4. Ensure plan/usage is consistent with middleware enforcement logic.
|
||||
|
||||
**Tests**
|
||||
- Component rendering with active/canceled/past_due states
|
||||
- Cancel confirmation flow tests
|
||||
- Portal redirect action tests
|
||||
|
||||
---
|
||||
|
||||
## API Surface to Add (Epic 12)
|
||||
|
||||
### Public
|
||||
- `POST /api/v1/webhooks/stripe-billing`
|
||||
|
||||
### Protected (admin where appropriate)
|
||||
- `POST /api/v1/billing/checkout`
|
||||
- `POST /api/v1/billing/portal-session`
|
||||
- `GET /api/v1/billing/subscription`
|
||||
- `POST /api/v1/billing/cancel` (if cancellation endpoint is implemented explicitly)
|
||||
|
||||
---
|
||||
|
||||
## Suggested PR Slicing
|
||||
|
||||
1. **PR-1:** #153 + config/env + plan catalog tests
|
||||
2. **PR-2:** #156 schema/repository/service foundation
|
||||
3. **PR-3:** #154 checkout endpoint + customer creation
|
||||
4. **PR-4:** #155 billing webhook + idempotency
|
||||
5. **PR-5:** #157 feature gating middleware
|
||||
6. **PR-6:** #158 pricing page + shared plan presentation
|
||||
7. **PR-7:** #159 checkout UX + success/cancel handling
|
||||
8. **PR-8:** #161 portal endpoint
|
||||
9. **PR-9:** #160 subscription management page + usage meters
|
||||
|
||||
---
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
1. **Confusing Stripe integration keys with billing keys**
|
||||
_Mitigation:_ strict config separation + explicit naming (`STRIPE_BILLING_*`).
|
||||
|
||||
2. **Event replay / duplicate webhook side effects**
|
||||
_Mitigation:_ durable idempotency key storage by Stripe event ID.
|
||||
|
||||
3. **Plan drift between backend and frontend pricing copy**
|
||||
_Mitigation:_ central shared plan model or API-driven pricing payload.
|
||||
|
||||
4. **Incorrect limit enforcement on downgrade**
|
||||
_Mitigation:_ define policy up-front (block new actions only) and encode in tests.
|
||||
|
||||
---
|
||||
|
||||
## Definition of Done (Epic 12)
|
||||
|
||||
Epic #20 is complete when all sub-issues #153–#161 are merged and validated end-to-end:
|
||||
|
||||
- Checkout works from pricing/settings to Stripe and back
|
||||
- Billing webhook updates org subscription state reliably in near real-time
|
||||
- Feature limits are enforced for customer/integration actions
|
||||
- Billing settings page shows current plan + usage + renewal details
|
||||
- Stripe Customer Portal opens successfully for payment/invoice self-service
|
||||
- Test coverage exists for all new handlers/services/middleware and key UI states
|
||||
@@ -21,5 +21,11 @@ export default defineConfig([
|
||||
globals: globals.browser,
|
||||
},
|
||||
},
|
||||
{
|
||||
files: ["src/contexts/**/*.{ts,tsx}"],
|
||||
rules: {
|
||||
"react-refresh/only-export-components": "off",
|
||||
},
|
||||
},
|
||||
eslintConfigPrettier,
|
||||
]);
|
||||
|
||||
@@ -6,6 +6,7 @@ import ProtectedRoute from "@/components/ProtectedRoute";
|
||||
import ErrorBoundary from "@/components/ErrorBoundary";
|
||||
import AppLayout from "@/layouts/AppLayout";
|
||||
import LandingPage from "@/pages/LandingPage";
|
||||
import PricingPage from "@/pages/PricingPage";
|
||||
import LoginPage from "@/pages/auth/LoginPage";
|
||||
import RegisterPage from "@/pages/auth/RegisterPage";
|
||||
import DashboardPage from "@/pages/DashboardPage";
|
||||
@@ -27,6 +28,7 @@ function App() {
|
||||
<Routes>
|
||||
{/* Public marketing + auth routes */}
|
||||
<Route path="/" element={<LandingPage />} />
|
||||
<Route path="/pricing" element={<PricingPage />} />
|
||||
<Route path="/login" element={<LoginPage />} />
|
||||
<Route path="/register" element={<RegisterPage />} />
|
||||
<Route path="/privacy" element={<PrivacyPage />} />
|
||||
|
||||
237
web/src/components/billing/SubscriptionManager.tsx
Normal file
237
web/src/components/billing/SubscriptionManager.tsx
Normal file
@@ -0,0 +1,237 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { ExternalLink, Loader2 } from "lucide-react";
|
||||
|
||||
import { useToast } from "@/contexts/ToastContext";
|
||||
import { useCheckout } from "@/hooks/useCheckout";
|
||||
import { billingApi, type BillingSubscriptionResponse } from "@/lib/api";
|
||||
import {
|
||||
billingPlans,
|
||||
type BillingCycle,
|
||||
type PlanTier,
|
||||
} from "@/lib/billingPlans";
|
||||
|
||||
interface SubscriptionManagerProps {
|
||||
checkoutState?: "success" | "cancelled" | null;
|
||||
}
|
||||
|
||||
function normalizeTier(value?: string): PlanTier {
|
||||
const normalized = value?.toLowerCase();
|
||||
if (normalized === "growth" || normalized === "scale") return normalized;
|
||||
return "free";
|
||||
}
|
||||
|
||||
function formatRenewalDate(date: string | null): string {
|
||||
if (!date) return "—";
|
||||
const parsed = new Date(date);
|
||||
if (Number.isNaN(parsed.getTime())) return "—";
|
||||
return parsed.toLocaleDateString();
|
||||
}
|
||||
|
||||
function formatLimit(limit: number): string {
|
||||
return limit < 0 ? "∞" : String(limit);
|
||||
}
|
||||
|
||||
function usagePercent(used: number, limit: number): number {
|
||||
if (limit <= 0) return 0;
|
||||
return Math.min(100, Math.round((used / limit) * 100));
|
||||
}
|
||||
|
||||
export default function SubscriptionManager({ checkoutState }: SubscriptionManagerProps) {
|
||||
const [subscription, setSubscription] = useState<BillingSubscriptionResponse | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [cancelling, setCancelling] = useState(false);
|
||||
const [openingPortal, setOpeningPortal] = useState(false);
|
||||
|
||||
const { loading: checkoutLoading, startCheckout } = useCheckout();
|
||||
const toast = useToast();
|
||||
|
||||
const currentTier = normalizeTier(subscription?.tier);
|
||||
const cycle = (subscription?.billing_cycle ?? "monthly") as BillingCycle;
|
||||
|
||||
const fetchSubscription = useCallback(async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const { data } = await billingApi.getSubscription();
|
||||
setSubscription(data);
|
||||
} catch {
|
||||
toast.error("Failed to load subscription details");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [toast]);
|
||||
|
||||
useEffect(() => {
|
||||
void fetchSubscription();
|
||||
}, [fetchSubscription]);
|
||||
|
||||
useEffect(() => {
|
||||
if (checkoutState === "success") {
|
||||
toast.success("Checkout complete. Subscription is refreshing...");
|
||||
void fetchSubscription();
|
||||
}
|
||||
if (checkoutState === "cancelled") {
|
||||
toast.info("Checkout cancelled. No changes were made.");
|
||||
}
|
||||
}, [checkoutState, fetchSubscription, toast]);
|
||||
|
||||
const recommendedPlans = useMemo(
|
||||
() => billingPlans.filter((plan) => plan.tier !== "free"),
|
||||
[],
|
||||
);
|
||||
|
||||
async function handleOpenPortal() {
|
||||
setOpeningPortal(true);
|
||||
try {
|
||||
const { data } = await billingApi.createPortalSession();
|
||||
window.location.href = data.url;
|
||||
} catch {
|
||||
toast.error("Unable to open Stripe customer portal right now.");
|
||||
} finally {
|
||||
setOpeningPortal(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleCancelAtPeriodEnd() {
|
||||
if (!window.confirm("Cancel this subscription at the end of the current billing period?")) {
|
||||
return;
|
||||
}
|
||||
|
||||
setCancelling(true);
|
||||
try {
|
||||
await billingApi.cancelAtPeriodEnd();
|
||||
toast.success("Your subscription will cancel at period end.");
|
||||
await fetchSubscription();
|
||||
} catch {
|
||||
toast.error("Failed to schedule cancellation. Please try again.");
|
||||
} finally {
|
||||
setCancelling(false);
|
||||
}
|
||||
}
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="flex justify-center py-10">
|
||||
<Loader2 className="h-6 w-6 animate-spin text-gray-400" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!subscription) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<section className="rounded-xl border border-gray-200 bg-white p-5 dark:border-gray-800 dark:bg-gray-900">
|
||||
<div className="flex flex-wrap items-start justify-between gap-4">
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-gray-900 dark:text-gray-100">
|
||||
{currentTier[0].toUpperCase() + currentTier.slice(1)} plan
|
||||
</h3>
|
||||
<p className="mt-1 text-sm text-gray-500 dark:text-gray-400">
|
||||
Status: <span className="font-medium text-gray-700 dark:text-gray-200">{subscription.status}</span>
|
||||
{" · "}
|
||||
Cycle: <span className="font-medium text-gray-700 dark:text-gray-200">{cycle}</span>
|
||||
{" · "}
|
||||
Renewal: <span className="font-medium text-gray-700 dark:text-gray-200">{formatRenewalDate(subscription.renewal_date)}</span>
|
||||
</p>
|
||||
{subscription.cancel_at_period_end && (
|
||||
<p className="mt-2 text-xs font-medium text-amber-600 dark:text-amber-300">
|
||||
This subscription is scheduled to cancel at period end.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<button
|
||||
onClick={handleOpenPortal}
|
||||
disabled={openingPortal}
|
||||
className="inline-flex items-center gap-1 rounded-lg border border-gray-300 px-3 py-2 text-sm font-medium text-gray-700 hover:bg-gray-50 disabled:opacity-60 dark:border-gray-700 dark:text-gray-200 dark:hover:bg-gray-800"
|
||||
>
|
||||
{openingPortal ? "Opening..." : "Open customer portal"}
|
||||
<ExternalLink className="h-4 w-4" />
|
||||
</button>
|
||||
{currentTier !== "free" && !subscription.cancel_at_period_end && (
|
||||
<button
|
||||
onClick={handleCancelAtPeriodEnd}
|
||||
disabled={cancelling}
|
||||
className="rounded-lg border border-rose-300 px-3 py-2 text-sm font-medium text-rose-700 hover:bg-rose-50 disabled:opacity-60 dark:border-rose-700 dark:text-rose-300 dark:hover:bg-rose-950/40"
|
||||
>
|
||||
{cancelling ? "Cancelling..." : "Cancel at period end"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section className="rounded-xl border border-gray-200 bg-white p-5 dark:border-gray-800 dark:bg-gray-900">
|
||||
<h4 className="text-sm font-semibold text-gray-900 dark:text-gray-100">Usage</h4>
|
||||
<div className="mt-4 space-y-4">
|
||||
{[
|
||||
{
|
||||
label: "Customers",
|
||||
used: subscription.usage.customers.used,
|
||||
limit: subscription.usage.customers.limit,
|
||||
},
|
||||
{
|
||||
label: "Integrations",
|
||||
used: subscription.usage.integrations.used,
|
||||
limit: subscription.usage.integrations.limit,
|
||||
},
|
||||
].map((item) => (
|
||||
<div key={item.label}>
|
||||
<div className="mb-1 flex items-center justify-between text-sm text-gray-600 dark:text-gray-300">
|
||||
<span>{item.label}</span>
|
||||
<span>
|
||||
{item.used} / {formatLimit(item.limit)}
|
||||
</span>
|
||||
</div>
|
||||
<div className="h-2 rounded-full bg-gray-200 dark:bg-gray-700">
|
||||
<div
|
||||
className="h-2 rounded-full bg-indigo-500"
|
||||
style={{ width: `${usagePercent(item.used, item.limit)}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section className="rounded-xl border border-gray-200 bg-white p-5 dark:border-gray-800 dark:bg-gray-900">
|
||||
<h4 className="text-sm font-semibold text-gray-900 dark:text-gray-100">Change plan</h4>
|
||||
<div className="mt-3 grid gap-3 md:grid-cols-2">
|
||||
{recommendedPlans.map((plan) => {
|
||||
const isCurrent = plan.tier === currentTier;
|
||||
return (
|
||||
<div
|
||||
key={plan.tier}
|
||||
className={`rounded-lg border p-4 ${
|
||||
isCurrent
|
||||
? "border-emerald-300 bg-emerald-50 dark:border-emerald-700 dark:bg-emerald-950/20"
|
||||
: "border-gray-200 bg-gray-50 dark:border-gray-700 dark:bg-gray-800/50"
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-start justify-between">
|
||||
<div>
|
||||
<p className="font-semibold text-gray-900 dark:text-gray-100">{plan.name}</p>
|
||||
<p className="text-xs text-gray-500 dark:text-gray-400">{plan.description}</p>
|
||||
</div>
|
||||
<p className="text-sm font-medium text-gray-700 dark:text-gray-200">
|
||||
${cycle === "monthly" ? plan.monthlyPrice : plan.annualPrice}/{cycle === "monthly" ? "mo" : "yr"}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
disabled={checkoutLoading || isCurrent}
|
||||
onClick={() => startCheckout({ tier: plan.tier, cycle })}
|
||||
className="mt-3 w-full rounded-lg bg-indigo-600 px-3 py-2 text-sm font-semibold text-white hover:bg-indigo-700 disabled:cursor-not-allowed disabled:opacity-60"
|
||||
>
|
||||
{isCurrent ? "Current plan" : checkoutLoading ? "Redirecting..." : `Switch to ${plan.name}`}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,83 +1,84 @@
|
||||
import { useMemo, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { Check, Sparkles } from "lucide-react";
|
||||
import { Link } from "react-router-dom";
|
||||
|
||||
type BillingCycle = "monthly" | "annual";
|
||||
import { useAuth } from "@/contexts/AuthContext";
|
||||
import { useCheckout } from "@/hooks/useCheckout";
|
||||
import { billingApi } from "@/lib/api";
|
||||
import {
|
||||
billingPlans,
|
||||
savingsBadge,
|
||||
type BillingCycle,
|
||||
type PlanTier,
|
||||
} from "@/lib/billingPlans";
|
||||
|
||||
interface Plan {
|
||||
name: string;
|
||||
monthlyPrice: number;
|
||||
annualPrice: number;
|
||||
description: string;
|
||||
ctaLabel: string;
|
||||
ctaHref: string;
|
||||
featured?: boolean;
|
||||
limits: {
|
||||
customers: string;
|
||||
integrations: string;
|
||||
alerts: string;
|
||||
support: string;
|
||||
};
|
||||
interface PricingSectionProps {
|
||||
showStandaloneHeader?: boolean;
|
||||
}
|
||||
|
||||
const plans: Plan[] = [
|
||||
{
|
||||
name: "Free",
|
||||
monthlyPrice: 0,
|
||||
annualPrice: 0,
|
||||
description: "Best for evaluating PulseScore with a small portfolio.",
|
||||
ctaLabel: "Get Started Free",
|
||||
ctaHref: "/register?plan=free",
|
||||
limits: {
|
||||
customers: "Up to 10 customers",
|
||||
integrations: "1 integration",
|
||||
alerts: "Basic alerts",
|
||||
support: "Community support",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Growth",
|
||||
monthlyPrice: 49,
|
||||
annualPrice: 490,
|
||||
description: "For fast-moving teams managing churn at scale.",
|
||||
ctaLabel: "Start Growth Trial",
|
||||
ctaHref: "/register?plan=growth",
|
||||
featured: true,
|
||||
limits: {
|
||||
customers: "Up to 250 customers",
|
||||
integrations: "Up to 3 integrations",
|
||||
alerts: "Advanced alert rules",
|
||||
support: "Priority email support",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Scale",
|
||||
monthlyPrice: 149,
|
||||
annualPrice: 1490,
|
||||
description: "For mature revenue teams with complex customer motion.",
|
||||
ctaLabel: "Start Scale Trial",
|
||||
ctaHref: "/register?plan=scale",
|
||||
limits: {
|
||||
customers: "Unlimited customers",
|
||||
integrations: "Unlimited integrations",
|
||||
alerts: "Advanced workflows",
|
||||
support: "Dedicated success partner",
|
||||
},
|
||||
},
|
||||
];
|
||||
function normalizeTier(value?: string | null): PlanTier | null {
|
||||
if (!value) return null;
|
||||
const tier = value.trim().toLowerCase();
|
||||
if (tier === "free" || tier === "growth" || tier === "scale") {
|
||||
return tier;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
export default function PricingSection() {
|
||||
export default function PricingSection({ showStandaloneHeader = false }: PricingSectionProps) {
|
||||
const [cycle, setCycle] = useState<BillingCycle>("monthly");
|
||||
const { isAuthenticated, organization } = useAuth();
|
||||
const { loading, startCheckout } = useCheckout();
|
||||
const [currentTier, setCurrentTier] = useState<PlanTier | null>(
|
||||
normalizeTier(organization?.plan),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
|
||||
async function loadSubscription() {
|
||||
if (!isAuthenticated) {
|
||||
setCurrentTier(null);
|
||||
return;
|
||||
}
|
||||
|
||||
setCurrentTier(normalizeTier(organization?.plan));
|
||||
|
||||
try {
|
||||
const { data } = await billingApi.getSubscription();
|
||||
if (!cancelled) {
|
||||
setCurrentTier(normalizeTier(data.tier));
|
||||
}
|
||||
} catch {
|
||||
// Keep fallback from auth payload if subscription endpoint is unavailable.
|
||||
}
|
||||
}
|
||||
|
||||
void loadSubscription();
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [isAuthenticated, organization?.plan]);
|
||||
|
||||
const annualSavingsText = useMemo(() => {
|
||||
const growthMonthlyAnnualized = plans[1].monthlyPrice * 12;
|
||||
const growthSavings = growthMonthlyAnnualized - plans[1].annualPrice;
|
||||
return `Save $${growthSavings}/yr`;
|
||||
const growthPlan = billingPlans.find((plan) => plan.tier === "growth") ?? billingPlans[0];
|
||||
return savingsBadge(growthPlan);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<section id="pricing" className="bg-white px-6 py-16 dark:bg-gray-950 sm:px-10 lg:px-14 lg:py-24">
|
||||
<div className="mx-auto max-w-7xl">
|
||||
{showStandaloneHeader && (
|
||||
<div className="mb-8">
|
||||
<h1 className="text-3xl font-bold tracking-tight text-gray-900 dark:text-gray-100 sm:text-4xl">
|
||||
Choose the right PulseScore plan
|
||||
</h1>
|
||||
<p className="mt-2 text-gray-600 dark:text-gray-300">
|
||||
Start on Free and upgrade when your customer health workflow scales.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col items-start justify-between gap-6 md:flex-row md:items-end">
|
||||
<div className="max-w-3xl">
|
||||
<h2 className="text-3xl font-bold tracking-tight text-gray-900 dark:text-gray-100 sm:text-4xl">
|
||||
@@ -120,20 +121,38 @@ export default function PricingSection() {
|
||||
</div>
|
||||
|
||||
<div className="mt-10 grid grid-cols-1 gap-4 lg:grid-cols-3">
|
||||
{plans.map((plan) => {
|
||||
{billingPlans.map((plan) => {
|
||||
const isFree = plan.monthlyPrice === 0;
|
||||
const displayPrice = cycle === "monthly" ? plan.monthlyPrice : plan.annualPrice;
|
||||
const period = cycle === "monthly" ? "/mo" : "/yr";
|
||||
const isCurrentPlan = currentTier === plan.tier;
|
||||
|
||||
const ctaLabel = isCurrentPlan
|
||||
? "Current plan"
|
||||
: isFree
|
||||
? isAuthenticated
|
||||
? "Stay on Free"
|
||||
: "Get Started Free"
|
||||
: isAuthenticated
|
||||
? `Choose ${plan.name}`
|
||||
: `Start ${plan.name}`;
|
||||
|
||||
return (
|
||||
<article
|
||||
key={plan.name}
|
||||
key={plan.tier}
|
||||
className={`relative rounded-2xl border p-6 shadow-sm ${
|
||||
plan.featured
|
||||
isCurrentPlan
|
||||
? "border-emerald-400 bg-emerald-50/40 dark:border-emerald-700 dark:bg-emerald-950/10"
|
||||
: plan.featured
|
||||
? "border-indigo-300 bg-indigo-50/50 dark:border-indigo-700 dark:bg-indigo-950/20"
|
||||
: "border-gray-200 bg-white dark:border-gray-800 dark:bg-gray-900"
|
||||
}`}
|
||||
>
|
||||
{isCurrentPlan && (
|
||||
<span className="absolute -top-3 left-4 rounded-full bg-emerald-600 px-3 py-1 text-xs font-semibold text-white">
|
||||
Current plan
|
||||
</span>
|
||||
)}
|
||||
{plan.featured && (
|
||||
<span className="absolute -top-3 right-4 rounded-full bg-indigo-600 px-3 py-1 text-xs font-semibold text-white">
|
||||
Most popular
|
||||
@@ -153,16 +172,30 @@ export default function PricingSection() {
|
||||
<p className="mt-1 text-xs text-gray-500 dark:text-gray-400">No credit card required</p>
|
||||
)}
|
||||
|
||||
<Link
|
||||
to={plan.ctaHref}
|
||||
className={`mt-5 inline-flex w-full items-center justify-center rounded-xl px-4 py-2.5 text-sm font-semibold transition ${
|
||||
plan.featured
|
||||
? "bg-indigo-600 text-white hover:bg-indigo-700"
|
||||
: "border border-gray-300 bg-white text-gray-800 hover:bg-gray-50 dark:border-gray-700 dark:bg-gray-900 dark:text-gray-200 dark:hover:bg-gray-800"
|
||||
}`}
|
||||
>
|
||||
{plan.ctaLabel}
|
||||
</Link>
|
||||
{isAuthenticated && !isFree && !isCurrentPlan ? (
|
||||
<button
|
||||
onClick={() => startCheckout({ tier: plan.tier, cycle })}
|
||||
disabled={loading}
|
||||
className={`mt-5 inline-flex w-full items-center justify-center rounded-xl px-4 py-2.5 text-sm font-semibold transition disabled:cursor-not-allowed disabled:opacity-60 ${
|
||||
plan.featured
|
||||
? "bg-indigo-600 text-white hover:bg-indigo-700"
|
||||
: "border border-gray-300 bg-white text-gray-800 hover:bg-gray-50 dark:border-gray-700 dark:bg-gray-900 dark:text-gray-200 dark:hover:bg-gray-800"
|
||||
}`}
|
||||
>
|
||||
{loading ? "Redirecting..." : ctaLabel}
|
||||
</button>
|
||||
) : (
|
||||
<Link
|
||||
to={isAuthenticated ? "/dashboard" : `/register?plan=${plan.tier}`}
|
||||
className={`mt-5 inline-flex w-full items-center justify-center rounded-xl px-4 py-2.5 text-sm font-semibold transition ${
|
||||
plan.featured
|
||||
? "bg-indigo-600 text-white hover:bg-indigo-700"
|
||||
: "border border-gray-300 bg-white text-gray-800 hover:bg-gray-50 dark:border-gray-700 dark:bg-gray-900 dark:text-gray-200 dark:hover:bg-gray-800"
|
||||
}`}
|
||||
>
|
||||
{ctaLabel}
|
||||
</Link>
|
||||
)}
|
||||
|
||||
<ul className="mt-5 space-y-2 text-sm text-gray-600 dark:text-gray-300">
|
||||
{Object.values(plan.limits).map((item) => (
|
||||
|
||||
@@ -47,6 +47,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
});
|
||||
const [loading, setLoading] = useState(true);
|
||||
const refreshTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const applySessionRef = useRef<(data: AuthResponse) => void>(() => undefined);
|
||||
|
||||
const clearSession = useCallback(() => {
|
||||
setState({ user: null, organization: null, accessToken: null });
|
||||
@@ -71,7 +72,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
refreshTimer.current = setTimeout(async () => {
|
||||
try {
|
||||
const { data } = await authApi.refresh(refreshToken);
|
||||
applySession(data);
|
||||
applySessionRef.current(data);
|
||||
} catch {
|
||||
clearSession();
|
||||
}
|
||||
@@ -97,6 +98,10 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
[scheduleRefresh],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
applySessionRef.current = applySession;
|
||||
}, [applySession]);
|
||||
|
||||
// On mount: attempt silent refresh to restore session
|
||||
useEffect(() => {
|
||||
async function tryRestore() {
|
||||
|
||||
37
web/src/hooks/useCheckout.ts
Normal file
37
web/src/hooks/useCheckout.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
import { useCallback, useState } from "react";
|
||||
import { AxiosError } from "axios";
|
||||
|
||||
import { billingApi, type CheckoutPayload } from "@/lib/api";
|
||||
import { useToast } from "@/contexts/ToastContext";
|
||||
|
||||
export function useCheckout() {
|
||||
const [loading, setLoading] = useState(false);
|
||||
const toast = useToast();
|
||||
|
||||
const startCheckout = useCallback(
|
||||
async (payload: CheckoutPayload) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const { data } = await billingApi.createCheckout(payload);
|
||||
if (!data.url) {
|
||||
throw new Error("missing checkout url");
|
||||
}
|
||||
window.location.href = data.url;
|
||||
} catch (error) {
|
||||
if (error instanceof AxiosError && error.response?.data?.error) {
|
||||
toast.error(String(error.response.data.error));
|
||||
} else {
|
||||
toast.error("Unable to start checkout right now.");
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
},
|
||||
[toast],
|
||||
);
|
||||
|
||||
return {
|
||||
loading,
|
||||
startCheckout,
|
||||
};
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { Outlet, useLocation } from "react-router-dom";
|
||||
import { useCallback, useState } from "react";
|
||||
import { Outlet } from "react-router-dom";
|
||||
import Sidebar from "@/components/Sidebar";
|
||||
import Header from "@/components/Header";
|
||||
|
||||
@@ -8,12 +8,6 @@ export default function AppLayout() {
|
||||
return localStorage.getItem("sidebar-collapsed") === "true";
|
||||
});
|
||||
const [mobileOpen, setMobileOpen] = useState(false);
|
||||
const location = useLocation();
|
||||
|
||||
// Close mobile drawer on route change
|
||||
useEffect(() => {
|
||||
setMobileOpen(false);
|
||||
}, [location.pathname]);
|
||||
|
||||
const toggleCollapse = useCallback(() => {
|
||||
setCollapsed((prev) => {
|
||||
|
||||
@@ -19,6 +19,7 @@ export interface AuthOrg {
|
||||
name: string;
|
||||
slug: string;
|
||||
role: string;
|
||||
plan: string;
|
||||
}
|
||||
|
||||
export interface TokenPair {
|
||||
@@ -64,6 +65,45 @@ export const authApi = {
|
||||
}),
|
||||
};
|
||||
|
||||
export interface BillingUsageMetric {
|
||||
used: number;
|
||||
limit: number;
|
||||
}
|
||||
|
||||
export interface BillingSubscriptionResponse {
|
||||
tier: string;
|
||||
status: string;
|
||||
billing_cycle: "monthly" | "annual";
|
||||
renewal_date: string | null;
|
||||
cancel_at_period_end: boolean;
|
||||
usage: {
|
||||
customers: BillingUsageMetric;
|
||||
integrations: BillingUsageMetric;
|
||||
};
|
||||
features: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface CheckoutPayload {
|
||||
priceId?: string;
|
||||
tier?: string;
|
||||
cycle?: "monthly" | "annual";
|
||||
annual?: boolean;
|
||||
}
|
||||
|
||||
export const billingApi = {
|
||||
getSubscription: () =>
|
||||
api.get<BillingSubscriptionResponse>("/billing/subscription"),
|
||||
|
||||
createCheckout: (payload: CheckoutPayload) =>
|
||||
api.post<{ url: string }>("/billing/checkout", payload),
|
||||
|
||||
createPortalSession: () =>
|
||||
api.post<{ url: string }>("/billing/portal-session"),
|
||||
|
||||
cancelAtPeriodEnd: () =>
|
||||
api.post<{ status: string }>("/billing/cancel"),
|
||||
};
|
||||
|
||||
// Alert types and API
|
||||
export interface AlertRule {
|
||||
id: string;
|
||||
|
||||
66
web/src/lib/billingPlans.ts
Normal file
66
web/src/lib/billingPlans.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
export type BillingCycle = "monthly" | "annual";
|
||||
export type PlanTier = "free" | "growth" | "scale";
|
||||
|
||||
export interface BillingPlanDefinition {
|
||||
tier: PlanTier;
|
||||
name: string;
|
||||
description: string;
|
||||
monthlyPrice: number;
|
||||
annualPrice: number;
|
||||
limits: {
|
||||
customers: string;
|
||||
integrations: string;
|
||||
alerts: string;
|
||||
support: string;
|
||||
};
|
||||
featured?: boolean;
|
||||
}
|
||||
|
||||
export const billingPlans: BillingPlanDefinition[] = [
|
||||
{
|
||||
tier: "free",
|
||||
name: "Free",
|
||||
monthlyPrice: 0,
|
||||
annualPrice: 0,
|
||||
description: "Best for evaluating PulseScore with a small portfolio.",
|
||||
limits: {
|
||||
customers: "Up to 10 customers",
|
||||
integrations: "1 integration",
|
||||
alerts: "Basic alerts",
|
||||
support: "Community support",
|
||||
},
|
||||
},
|
||||
{
|
||||
tier: "growth",
|
||||
name: "Growth",
|
||||
monthlyPrice: 49,
|
||||
annualPrice: 490,
|
||||
description: "For fast-moving teams managing churn at scale.",
|
||||
featured: true,
|
||||
limits: {
|
||||
customers: "Up to 250 customers",
|
||||
integrations: "Up to 3 integrations",
|
||||
alerts: "Advanced alert rules",
|
||||
support: "Priority email support",
|
||||
},
|
||||
},
|
||||
{
|
||||
tier: "scale",
|
||||
name: "Scale",
|
||||
monthlyPrice: 149,
|
||||
annualPrice: 1490,
|
||||
description: "For mature revenue teams with complex customer motion.",
|
||||
limits: {
|
||||
customers: "Unlimited customers",
|
||||
integrations: "Unlimited integrations",
|
||||
alerts: "Advanced workflows",
|
||||
support: "Dedicated success partner",
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
export function savingsBadge(plan: BillingPlanDefinition): string {
|
||||
const monthlyAnnualized = plan.monthlyPrice * 12;
|
||||
const delta = monthlyAnnualized - plan.annualPrice;
|
||||
return delta > 0 ? `Save $${delta}/yr` : "Annual billing";
|
||||
}
|
||||
@@ -429,8 +429,7 @@ function OnboardingContent() {
|
||||
}
|
||||
}
|
||||
|
||||
const steps: WizardShellStep[] = useMemo(
|
||||
() => [
|
||||
const steps: WizardShellStep[] = [
|
||||
{
|
||||
id: "welcome",
|
||||
label: "Welcome",
|
||||
@@ -549,29 +548,7 @@ function OnboardingContent() {
|
||||
});
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
welcomeValue,
|
||||
organization?.name,
|
||||
stripeStatus,
|
||||
stripeBusy,
|
||||
stripeError,
|
||||
hubSpotStatus,
|
||||
hubSpotBusy,
|
||||
hubSpotError,
|
||||
intercomStatus,
|
||||
intercomBusy,
|
||||
intercomError,
|
||||
connectedProviders,
|
||||
syncStatus,
|
||||
previewLoading,
|
||||
distribution,
|
||||
atRiskCustomers,
|
||||
fetchStripeStatus,
|
||||
fetchHubSpotStatus,
|
||||
fetchIntercomStatus,
|
||||
],
|
||||
);
|
||||
];
|
||||
|
||||
async function handleDone() {
|
||||
try {
|
||||
|
||||
15
web/src/pages/PricingPage.tsx
Normal file
15
web/src/pages/PricingPage.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
import PricingSection from "@/components/landing/PricingSection";
|
||||
import SeoMeta from "@/components/SeoMeta";
|
||||
|
||||
export default function PricingPage() {
|
||||
return (
|
||||
<div className="min-h-screen bg-white dark:bg-gray-950">
|
||||
<SeoMeta
|
||||
title="PulseScore Pricing"
|
||||
description="Compare PulseScore Free, Growth, and Scale plans with monthly and annual billing options."
|
||||
path="/pricing"
|
||||
/>
|
||||
<PricingSection showStandaloneHeader />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,9 +1,18 @@
|
||||
import { useMemo } from "react";
|
||||
import { useSearchParams } from "react-router-dom";
|
||||
|
||||
import SubscriptionManager from "@/components/billing/SubscriptionManager";
|
||||
|
||||
export default function BillingTab() {
|
||||
return (
|
||||
<div className="py-8 text-center">
|
||||
<p className="text-sm text-gray-500 dark:text-gray-400">
|
||||
Billing and subscription management coming soon.
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
const [searchParams] = useSearchParams();
|
||||
|
||||
const checkoutState = useMemo(() => {
|
||||
const value = searchParams.get("checkout");
|
||||
if (value === "success" || value === "cancelled") {
|
||||
return value;
|
||||
}
|
||||
return null;
|
||||
}, [searchParams]);
|
||||
|
||||
return <SubscriptionManager checkoutState={checkoutState} />;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useNavigate, useSearchParams } from "react-router-dom";
|
||||
import { hubspotApi } from "@/lib/hubspot";
|
||||
import BaseLayout from "@/components/BaseLayout";
|
||||
@@ -7,21 +7,24 @@ import { ONBOARDING_RESUME_STEP_STORAGE_KEY } from "@/contexts/onboarding/consta
|
||||
export default function HubSpotCallbackPage() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const [error, setError] = useState("");
|
||||
const [asyncError, setAsyncError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const error = useMemo(() => {
|
||||
const errParam = searchParams.get("error");
|
||||
|
||||
if (errParam) {
|
||||
const desc = searchParams.get("error_description") || errParam;
|
||||
setError(`HubSpot connection failed: ${desc}`);
|
||||
return;
|
||||
return `HubSpot connection failed: ${desc}`;
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
setError("Invalid callback parameters.");
|
||||
return "Invalid callback parameters.";
|
||||
}
|
||||
return asyncError;
|
||||
}, [searchParams, code, state, asyncError]);
|
||||
|
||||
useEffect(() => {
|
||||
if (error || !code || !state) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -37,9 +40,9 @@ export default function HubSpotCallbackPage() {
|
||||
navigate("/settings/integrations", { replace: true });
|
||||
})
|
||||
.catch(() => {
|
||||
setError("Failed to complete HubSpot connection.");
|
||||
setAsyncError("Failed to complete HubSpot connection.");
|
||||
});
|
||||
}, [searchParams, navigate]);
|
||||
}, [error, code, state, navigate]);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useNavigate, useSearchParams } from "react-router-dom";
|
||||
import { intercomApi } from "@/lib/intercom";
|
||||
import BaseLayout from "@/components/BaseLayout";
|
||||
@@ -7,21 +7,24 @@ import { ONBOARDING_RESUME_STEP_STORAGE_KEY } from "@/contexts/onboarding/consta
|
||||
export default function IntercomCallbackPage() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const [error, setError] = useState("");
|
||||
const [asyncError, setAsyncError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const error = useMemo(() => {
|
||||
const errParam = searchParams.get("error");
|
||||
|
||||
if (errParam) {
|
||||
const desc = searchParams.get("error_description") || errParam;
|
||||
setError(`Intercom connection failed: ${desc}`);
|
||||
return;
|
||||
return `Intercom connection failed: ${desc}`;
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
setError("Invalid callback parameters.");
|
||||
return "Invalid callback parameters.";
|
||||
}
|
||||
return asyncError;
|
||||
}, [searchParams, code, state, asyncError]);
|
||||
|
||||
useEffect(() => {
|
||||
if (error || !code || !state) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -37,9 +40,9 @@ export default function IntercomCallbackPage() {
|
||||
navigate("/settings/integrations", { replace: true });
|
||||
})
|
||||
.catch(() => {
|
||||
setError("Failed to complete Intercom connection.");
|
||||
setAsyncError("Failed to complete Intercom connection.");
|
||||
});
|
||||
}, [searchParams, navigate]);
|
||||
}, [error, code, state, navigate]);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useNavigate, useSearchParams } from "react-router-dom";
|
||||
import { stripeApi } from "@/lib/stripe";
|
||||
import BaseLayout from "@/components/BaseLayout";
|
||||
@@ -7,21 +7,24 @@ import { ONBOARDING_RESUME_STEP_STORAGE_KEY } from "@/contexts/onboarding/consta
|
||||
export default function StripeCallbackPage() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const navigate = useNavigate();
|
||||
const [error, setError] = useState("");
|
||||
const [asyncError, setAsyncError] = useState("");
|
||||
|
||||
useEffect(() => {
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
const error = useMemo(() => {
|
||||
const errParam = searchParams.get("error");
|
||||
|
||||
if (errParam) {
|
||||
const desc = searchParams.get("error_description") || errParam;
|
||||
setError(`Stripe connection failed: ${desc}`);
|
||||
return;
|
||||
return `Stripe connection failed: ${desc}`;
|
||||
}
|
||||
|
||||
if (!code || !state) {
|
||||
setError("Invalid callback parameters.");
|
||||
return "Invalid callback parameters.";
|
||||
}
|
||||
return asyncError;
|
||||
}, [searchParams, code, state, asyncError]);
|
||||
|
||||
useEffect(() => {
|
||||
if (error || !code || !state) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -37,9 +40,9 @@ export default function StripeCallbackPage() {
|
||||
navigate("/settings/integrations", { replace: true });
|
||||
})
|
||||
.catch(() => {
|
||||
setError("Failed to complete Stripe connection.");
|
||||
setAsyncError("Failed to complete Stripe connection.");
|
||||
});
|
||||
}, [searchParams, navigate]);
|
||||
}, [error, code, state, navigate]);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user