Implement updateSubscription feature and refactor billing services

This commit introduces the updateSubscription method to the BillingStrategyProviderService, ensuring that subscriptions can be updated within the billing core. Additionally, a refactor has been applied to the BillingGatewayFactoryService and stripe-billing-strategy.service to improve error handling and the robustness of subscription updates. Logging in the webhook route has been adjusted for clarity and the data model has been enhanced.
This commit is contained in:
giancarlo
2024-04-04 20:15:12 +08:00
parent 4a122ee5df
commit 220a23e185
26 changed files with 1499 additions and 993 deletions

View File

@@ -103,6 +103,34 @@ export const PlanSchema = z
message: 'Line item IDs must be unique',
path: ['lineItems'],
},
)
.refine(
(data) => {
if (data.paymentType === 'one-time') {
const meteredItems = data.lineItems.filter(
(item) => item.type === 'metered',
);
return meteredItems.length === 0;
}
},
{
message: 'One-time plans must not have metered line items',
path: ['paymentType', 'lineItems'],
},
)
.refine(
(data) => {
if (data.paymentType === 'one-time') {
const baseItems = data.lineItems.filter((item) => item.type !== 'base');
return baseItems.length === 0;
}
},
{
message: 'One-time plans must not have non-base line items',
path: ['paymentType', 'lineItems'],
},
);
const ProductSchema = z
@@ -259,3 +287,20 @@ export function getProductPlanPairByVariantId(
throw new Error('Plan not found');
}
export function getLineItemTypeById(
config: z.infer<typeof BillingSchema>,
id: string,
) {
for (const product of config.products) {
for (const plan of product.plans) {
for (const lineItem of plan.lineItems) {
if (lineItem.type === id) {
return lineItem.type;
}
}
}
}
throw new Error(`Line Item with ID ${id} not found`);
}

View File

@@ -3,3 +3,4 @@ export * from './create-biling-portal-session.schema';
export * from './retrieve-checkout-session.schema';
export * from './cancel-subscription-params.schema';
export * from './report-billing-usage.schema';
export * from './update-subscription-params.schema';

View File

@@ -0,0 +1,7 @@
import { z } from 'zod';
export const UpdateSubscriptionParamsSchema = z.object({
subscriptionId: z.string().min(1),
subscriptionItemId: z.string().min(1),
quantity: z.number().min(1),
});

View File

@@ -4,9 +4,10 @@ import {
CancelSubscriptionParamsSchema,
CreateBillingCheckoutSchema,
CreateBillingPortalSessionSchema,
ReportBillingUsageSchema,
RetrieveCheckoutSessionSchema,
UpdateSubscriptionParamsSchema,
} from '../schema';
import { ReportBillingUsageSchema } from '../schema';
export abstract class BillingStrategyProviderService {
abstract createBillingPortalSession(
@@ -44,4 +45,10 @@ export abstract class BillingStrategyProviderService {
): Promise<{
success: boolean;
}>;
abstract updateSubscription(
params: z.infer<typeof UpdateSubscriptionParamsSchema>,
): Promise<{
success: boolean;
}>;
}

View File

@@ -154,7 +154,7 @@ export class BillingEventHandlerService {
Logger.info(
{
namespace: 'billing',
namespace: this.namespace,
sessionId,
},
'Successfully updated payment status',

View File

@@ -1,6 +1,7 @@
import { z } from 'zod';
import {
BillingConfig,
BillingProviderSchema,
BillingWebhookHandlerService,
} from '@kit/billing';
@@ -8,12 +9,13 @@ import {
export class BillingEventHandlerFactoryService {
static async GetProviderStrategy(
provider: z.infer<typeof BillingProviderSchema>,
config: BillingConfig,
): Promise<BillingWebhookHandlerService> {
switch (provider) {
case 'stripe': {
const { StripeWebhookHandlerService } = await import('@kit/stripe');
return new StripeWebhookHandlerService();
return new StripeWebhookHandlerService(config);
}
case 'lemon-squeezy': {
@@ -21,7 +23,7 @@ export class BillingEventHandlerFactoryService {
'@kit/lemon-squeezy'
);
return new LemonSqueezyWebhookHandlerService();
return new LemonSqueezyWebhookHandlerService(config);
}
case 'paddle': {

View File

@@ -1,3 +1,4 @@
import { BillingConfig } from '@kit/billing';
import { Database } from '@kit/supabase/database';
import { getSupabaseServerActionClient } from '@kit/supabase/server-actions-client';
@@ -12,9 +13,12 @@ import { BillingEventHandlerFactoryService } from './billing-gateway-factory.ser
export async function getBillingEventHandlerService(
clientProvider: () => ReturnType<typeof getSupabaseServerActionClient>,
provider: Database['public']['Enums']['billing_provider'],
config: BillingConfig,
) {
const strategy =
await BillingEventHandlerFactoryService.GetProviderStrategy(provider);
const strategy = await BillingEventHandlerFactoryService.GetProviderStrategy(
provider,
config,
);
return new BillingEventHandlerService(clientProvider, strategy);
}

View File

@@ -5,7 +5,9 @@ import {
CancelSubscriptionParamsSchema,
CreateBillingCheckoutSchema,
CreateBillingPortalSessionSchema,
ReportBillingUsageSchema,
RetrieveCheckoutSessionSchema,
UpdateSubscriptionParamsSchema,
} from '@kit/billing/schema';
import { BillingGatewayFactoryService } from './billing-gateway-factory.service';
@@ -92,4 +94,35 @@ export class BillingGatewayService {
return strategy.cancelSubscription(payload);
}
/**
* Reports the usage of the billing.
* @description This is used to report the usage of the billing to the provider.
* @param params
*/
async reportUsage(params: z.infer<typeof ReportBillingUsageSchema>) {
const strategy = await BillingGatewayFactoryService.GetProviderStrategy(
this.provider,
);
const payload = ReportBillingUsageSchema.parse(params);
return strategy.reportUsage(payload);
}
/**
* Updates a subscription with the specified parameters.
* @param params
*/
async updateSubscriptionItem(
params: z.infer<typeof UpdateSubscriptionParamsSchema>,
) {
const strategy = await BillingGatewayFactoryService.GetProviderStrategy(
this.provider,
);
const payload = UpdateSubscriptionParamsSchema.parse(params);
return strategy.updateSubscription(payload);
}
}

View File

@@ -2,6 +2,7 @@ import {
cancelSubscription,
createUsageRecord,
getCheckout,
updateSubscriptionItem,
} from '@lemonsqueezy/lemonsqueezy.js';
import 'server-only';
import { z } from 'zod';
@@ -13,6 +14,7 @@ import {
CreateBillingPortalSessionSchema,
ReportBillingUsageSchema,
RetrieveCheckoutSessionSchema,
UpdateSubscriptionParamsSchema,
} from '@kit/billing/schema';
import { Logger } from '@kit/shared/logger';
@@ -240,4 +242,35 @@ export class LemonSqueezyBillingStrategyService
return { success: true };
}
async updateSubscription(
params: z.infer<typeof UpdateSubscriptionParamsSchema>,
) {
const ctx = {
name: 'billing.lemon-squeezy',
...params,
};
Logger.info(ctx, 'Updating subscription...');
const { error } = await updateSubscriptionItem(params.subscriptionItemId, {
quantity: params.quantity,
});
if (error) {
Logger.error(
{
...ctx,
error,
},
'Failed to update subscription',
);
throw error;
}
Logger.info(ctx, 'Subscription updated successfully');
return { success: true };
}
}

View File

@@ -1,7 +1,11 @@
import { getOrder, getVariant } from '@lemonsqueezy/lemonsqueezy.js';
import { createHmac, timingSafeEqual } from 'crypto';
import { BillingWebhookHandlerService } from '@kit/billing';
import {
BillingConfig,
BillingWebhookHandlerService,
getLineItemTypeById,
} from '@kit/billing';
import { Logger } from '@kit/shared/logger';
import { Database } from '@kit/supabase/database';
@@ -35,6 +39,8 @@ export class LemonSqueezyWebhookHandlerService
private readonly namespace = 'billing.lemon-squeezy';
constructor(private readonly config: BillingConfig) {}
/**
* @description Verifies the webhook signature - should throw an error if the signature is invalid
*/
@@ -307,6 +313,7 @@ export class LemonSqueezyWebhookHandlerService
product_id: item.product,
variant_id: item.variant,
price_amount: item.unitAmount,
type: getLineItemTypeById(this.config, item.id),
};
});

View File

@@ -9,6 +9,7 @@ import {
CreateBillingPortalSessionSchema,
ReportBillingUsageSchema,
RetrieveCheckoutSessionSchema,
UpdateSubscriptionParamsSchema,
} from '@kit/billing/schema';
import { Logger } from '@kit/shared/logger';
@@ -198,6 +199,52 @@ export class StripeBillingStrategyService
return { success: true };
}
async updateSubscription(
params: z.infer<typeof UpdateSubscriptionParamsSchema>,
) {
const stripe = await this.stripeProvider();
Logger.info(
{
name: 'billing.stripe',
...params,
},
'Updating subscription...',
);
try {
await stripe.subscriptions.update(params.subscriptionId, {
items: [
{
id: params.subscriptionItemId,
quantity: params.quantity,
},
],
});
Logger.info(
{
name: 'billing.stripe',
...params,
},
'Subscription updated successfully',
);
return { success: true };
} catch (e) {
Logger.error(
{
name: 'billing.stripe',
...params,
error: e,
},
'Failed to update subscription',
);
throw new Error('Failed to update subscription');
}
}
private async stripeProvider(): Promise<Stripe> {
return createStripeClient();
}

View File

@@ -1,6 +1,10 @@
import Stripe from 'stripe';
import { BillingWebhookHandlerService } from '@kit/billing';
import {
BillingConfig,
BillingWebhookHandlerService,
getLineItemTypeById,
} from '@kit/billing';
import { Logger } from '@kit/shared/logger';
import { Database } from '@kit/supabase/database';
@@ -18,6 +22,8 @@ export class StripeWebhookHandlerService
{
private stripe: Stripe | undefined;
constructor(private readonly config: BillingConfig) {}
private readonly provider: Database['public']['Enums']['billing_provider'] =
'stripe';
@@ -134,6 +140,8 @@ export class StripeWebhookHandlerService
const accountId = session.client_reference_id!;
const customerId = session.customer as string;
// if it's a subscription, we need to retrieve the subscription
// and build the payload for the subscription
if (isSubscription) {
const subscriptionId = session.subscription as string;
const subscription = await stripe.subscriptions.retrieve(subscriptionId);
@@ -154,8 +162,10 @@ export class StripeWebhookHandlerService
return onCheckoutCompletedCallback(payload);
} else {
// if it's a one-time payment, we need to retrieve the session
const sessionId = event.data.object.id;
// from the session, we need to retrieve the line items
const sessionWithLineItems = await stripe.checkout.sessions.retrieve(
event.data.object.id,
{
@@ -280,6 +290,7 @@ export class StripeWebhookHandlerService
price_amount: item.price?.unit_amount,
interval: item.price?.recurring?.interval as string,
interval_count: item.price?.recurring?.interval_count as number,
type: getLineItemTypeById(this.config, item.id),
};
});