|
1 | 1 | package v1beta1connect
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + |
| 7 | + "connectrpc.com/connect" |
| 8 | + "github.com/raystack/frontier/billing/plan" |
| 9 | + "github.com/raystack/frontier/billing/product" |
4 | 10 | "github.com/raystack/frontier/billing/subscription"
|
5 | 11 | frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
|
6 | 12 | "google.golang.org/protobuf/types/known/timestamppb"
|
7 | 13 | )
|
8 | 14 |
|
| 15 | +type SubscriptionService interface { |
| 16 | + GetByID(ctx context.Context, id string) (subscription.Subscription, error) |
| 17 | + List(ctx context.Context, filter subscription.Filter) ([]subscription.Subscription, error) |
| 18 | + Cancel(ctx context.Context, id string, immediate bool) (subscription.Subscription, error) |
| 19 | + ChangePlan(ctx context.Context, id string, change subscription.ChangeRequest) (subscription.Phase, error) |
| 20 | + HasUserSubscribedBefore(ctx context.Context, customerID string, planID string) (bool, error) |
| 21 | +} |
| 22 | + |
| 23 | +type PlanService interface { |
| 24 | + GetByID(ctx context.Context, id string) (plan.Plan, error) |
| 25 | +} |
| 26 | + |
| 27 | +func (h *ConnectHandler) ListSubscriptions(ctx context.Context, request *connect.Request[frontierv1beta1.ListSubscriptionsRequest]) (*connect.Response[frontierv1beta1.ListSubscriptionsResponse], error) { |
| 28 | + if request.Msg.GetOrgId() == "" || request.Msg.GetBillingId() == "" { |
| 29 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) |
| 30 | + } |
| 31 | + planID := request.Msg.GetPlan() |
| 32 | + if planID != "" { |
| 33 | + plan, err := h.planService.GetByID(ctx, planID) |
| 34 | + if err != nil { |
| 35 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrBadRequest) |
| 36 | + } |
| 37 | + planID = plan.ID |
| 38 | + } |
| 39 | + |
| 40 | + var subscriptions []*frontierv1beta1.Subscription |
| 41 | + subscriptionList, err := h.subscriptionService.List(ctx, subscription.Filter{ |
| 42 | + CustomerID: request.Msg.GetBillingId(), |
| 43 | + State: request.Msg.GetState(), |
| 44 | + PlanID: planID, |
| 45 | + }) |
| 46 | + if err != nil { |
| 47 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 48 | + } |
| 49 | + for _, v := range subscriptionList { |
| 50 | + subscriptionPB, err := transformSubscriptionToPB(v) |
| 51 | + if err != nil { |
| 52 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 53 | + } |
| 54 | + subscriptions = append(subscriptions, subscriptionPB) |
| 55 | + } |
| 56 | + |
| 57 | + return connect.NewResponse(&frontierv1beta1.ListSubscriptionsResponse{ |
| 58 | + Subscriptions: subscriptions, |
| 59 | + }), nil |
| 60 | +} |
| 61 | + |
| 62 | +func (h *ConnectHandler) GetSubscription(ctx context.Context, request *connect.Request[frontierv1beta1.GetSubscriptionRequest]) (*connect.Response[frontierv1beta1.GetSubscriptionResponse], error) { |
| 63 | + subscription, err := h.subscriptionService.GetByID(ctx, request.Msg.GetId()) |
| 64 | + if err != nil { |
| 65 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 66 | + } |
| 67 | + |
| 68 | + subscriptionPB, err := transformSubscriptionToPB(subscription) |
| 69 | + if err != nil { |
| 70 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 71 | + } |
| 72 | + return connect.NewResponse(&frontierv1beta1.GetSubscriptionResponse{ |
| 73 | + Subscription: subscriptionPB, |
| 74 | + }), nil |
| 75 | +} |
| 76 | + |
| 77 | +func (h *ConnectHandler) CancelSubscription(ctx context.Context, request *connect.Request[frontierv1beta1.CancelSubscriptionRequest]) (*connect.Response[frontierv1beta1.CancelSubscriptionResponse], error) { |
| 78 | + _, err := h.subscriptionService.Cancel(ctx, request.Msg.GetId(), request.Msg.GetImmediate()) |
| 79 | + if err != nil { |
| 80 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 81 | + } |
| 82 | + return connect.NewResponse(&frontierv1beta1.CancelSubscriptionResponse{}), nil |
| 83 | +} |
| 84 | + |
| 85 | +func (h *ConnectHandler) ChangeSubscription(ctx context.Context, request *connect.Request[frontierv1beta1.ChangeSubscriptionRequest]) (*connect.Response[frontierv1beta1.ChangeSubscriptionResponse], error) { |
| 86 | + changeReq := subscription.ChangeRequest{ |
| 87 | + PlanID: request.Msg.GetPlan(), |
| 88 | + Immediate: request.Msg.GetImmediate(), |
| 89 | + CancelUpcoming: false, |
| 90 | + } |
| 91 | + if request.Msg.GetPlanChange() != nil { |
| 92 | + changeReq.PlanID = request.Msg.GetPlanChange().GetPlan() |
| 93 | + changeReq.Immediate = request.Msg.GetPlanChange().GetImmediate() |
| 94 | + } |
| 95 | + if request.Msg.GetPhaseChange() != nil { |
| 96 | + changeReq.CancelUpcoming = request.Msg.GetPhaseChange().GetCancelUpcomingChanges() |
| 97 | + } |
| 98 | + if changeReq.PlanID != "" && changeReq.CancelUpcoming { |
| 99 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrConflictingPlanChange) |
| 100 | + } |
| 101 | + if changeReq.PlanID == "" && !changeReq.CancelUpcoming { |
| 102 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrNoChangeRequested) |
| 103 | + } |
| 104 | + |
| 105 | + phase, err := h.subscriptionService.ChangePlan(ctx, request.Msg.GetId(), changeReq) |
| 106 | + if err != nil { |
| 107 | + if errors.Is(err, product.ErrPerSeatLimitReached) { |
| 108 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrPerSeatLimitReached) |
| 109 | + } |
| 110 | + if errors.Is(err, subscription.ErrAlreadyOnSamePlan) { |
| 111 | + return nil, connect.NewError(connect.CodeInvalidArgument, ErrAlreadyOnSamePlan) |
| 112 | + } |
| 113 | + return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError) |
| 114 | + } |
| 115 | + |
| 116 | + phasePb := &frontierv1beta1.Subscription_Phase{ |
| 117 | + PlanId: phase.PlanID, |
| 118 | + Reason: phase.Reason, |
| 119 | + } |
| 120 | + if !phase.EffectiveAt.IsZero() { |
| 121 | + phasePb.EffectiveAt = timestamppb.New(phase.EffectiveAt) |
| 122 | + } |
| 123 | + return connect.NewResponse(&frontierv1beta1.ChangeSubscriptionResponse{ |
| 124 | + Phase: phasePb, |
| 125 | + }), nil |
| 126 | +} |
| 127 | + |
9 | 128 | func transformSubscriptionToPB(subs subscription.Subscription) (*frontierv1beta1.Subscription, error) {
|
10 | 129 | metaData, err := subs.Metadata.ToStructPB()
|
11 | 130 | if err != nil {
|
|
0 commit comments