From f96c5b5788269d11010ecca1b68e382045d89850 Mon Sep 17 00:00:00 2001 From: rohitdhakane Date: Wed, 26 Feb 2025 01:15:43 +0530 Subject: [PATCH] Fix bug where users can repeatedly train models without reducing credits --- apps/backend/index.ts | 110 ++++-------------- apps/backend/models/FalAIModel.ts | 4 +- apps/backend/routes/webhook.routes.ts | 79 +++++++++++-- .../migration.sql | 2 + packages/db/prisma/schema.prisma | 102 ++++++++-------- 5 files changed, 152 insertions(+), 145 deletions(-) create mode 100644 packages/db/prisma/migrations/20250225175603_add_relation_usercredit_to_user/migration.sql diff --git a/apps/backend/index.ts b/apps/backend/index.ts index 6281e76..68e2182 100644 --- a/apps/backend/index.ts +++ b/apps/backend/index.ts @@ -13,7 +13,7 @@ import { authMiddleware } from "./middleware"; import dotenv from "dotenv"; import paymentRoutes from "./routes/payment.routes"; -import {router as webhookRouter} from './routes/webhook.routes'; +import { router as webhookRouter } from "./routes/webhook.routes"; const IMAGE_GEN_CREDITS = 1; const TRAIN_MODEL_CREDITS = 20; @@ -63,6 +63,27 @@ app.post("/ai/training", authMiddleware, async (req, res) => { return; } + // Todo : For checking the credits two queries are being made, we can combine them into one + const credits = await prismaClient.userCredit.findUnique({ + where: { + userId: req.userId!, + }, + }); + if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) { + res.status(411).json({ + message: "Not enough credits", + }); + return; + } + await prismaClient.userCredit.update({ + where: { + userId: req.userId!, + }, + data: { + amount: { decrement: TRAIN_MODEL_CREDITS }, + }, + }); + const { request_id, response_url } = await falAiModel.trainModel( parsedBody.data.zipUrl, parsedBody.data.name @@ -120,7 +141,6 @@ app.post("/ai/generate", authMiddleware, async (req, res) => { }); return; } - const { request_id, response_url } = await falAiModel.generateImage( parsedBody.data.prompt, model.tensorPath @@ -268,94 +288,10 @@ app.get("/models", authMiddleware, async (req, res) => { }); }); -app.post("/fal-ai/webhook/train", async (req, res) => { - const requestId = req.body.request_id as string; - const result = await fal.queue.result("fal-ai/flux-lora", { - requestId, - }); - - // check if the user has enough credits - const credits = await prismaClient.userCredit.findUnique({ - where: { - userId: req.userId!, - }, - }); - - if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) { - res.status(411).json({ - message: "Not enough credits", - }); - return; - } - - const { imageUrl } = await falAiModel.generateImageSync( - result.data.diffusers_lora_file.url - ); - - await prismaClient.model.updateMany({ - where: { - falAiRequestId: requestId, - }, - data: { - trainingStatus: "Generated", - //@ts-ignore - tensorPath: result.data.diffusers_lora_file.url, - thumbnail: imageUrl, - }, - }); - - await prismaClient.userCredit.update({ - where: { - userId: req.userId!, - }, - data: { - amount: { decrement: TRAIN_MODEL_CREDITS }, - }, - }); - - res.json({ - message: "Webhook received", - }); -}); - -app.post("/fal-ai/webhook/image", async (req, res) => { - console.log("fal-ai/webhook/image"); - console.log(req.body); - // update the status of the image in the DB - const requestId = req.body.request_id; - - if (req.body.status === "ERROR") { - res.status(411).json({}); - prismaClient.outputImages.updateMany({ - where: { - falAiRequestId: requestId, - }, - data: { - status: "Failed", - imageUrl: req.body.payload.images[0].url, - }, - }); - return; - } - - await prismaClient.outputImages.updateMany({ - where: { - falAiRequestId: requestId, - }, - data: { - status: "Generated", - imageUrl: req.body.payload.images[0].url, - }, - }); - - res.json({ - message: "Webhook received", - }); -}); app.use("/payment", paymentRoutes); -app.use("/api/webhook",webhookRouter ); +app.use("/api/webhook", webhookRouter); app.listen(PORT, () => { console.log(`Server is running on port ${PORT}`); diff --git a/apps/backend/models/FalAIModel.ts b/apps/backend/models/FalAIModel.ts index 7ce4776..9186a9c 100644 --- a/apps/backend/models/FalAIModel.ts +++ b/apps/backend/models/FalAIModel.ts @@ -12,7 +12,7 @@ export class FalAIModel { prompt: prompt, loras: [{ path: tensorPath, scale: 1 }] }, - webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/image`, + webhookUrl: `${process.env.WEBHOOK_BASE_URL}/api/webhook/fal-ai/image`, }); return { request_id, response_url }; @@ -25,7 +25,7 @@ export class FalAIModel { images_data_url: zipUrl, trigger_word: triggerWord }, - webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/train`, + webhookUrl: `${process.env.WEBHOOK_BASE_URL}/api/webhook/fal-ai/train`, }); return { request_id, response_url }; diff --git a/apps/backend/routes/webhook.routes.ts b/apps/backend/routes/webhook.routes.ts index f5138f2..c113171 100644 --- a/apps/backend/routes/webhook.routes.ts +++ b/apps/backend/routes/webhook.routes.ts @@ -1,12 +1,19 @@ import { prismaClient } from "db"; import { Router } from "express"; import { Webhook } from "svix"; +import { fal } from "@fal-ai/client"; +import { FalAIModel } from "../models/FalAIModel"; export const router = Router(); +const IMAGE_GEN_CREDITS = 1; +const TRAIN_MODEL_CREDITS = 20; + +const falAiModel = new FalAIModel(); + /** * POST api/webhook/clerk - * Clerk webhook endpoint + * Clerk will hit this endpoint when user is created, updated or deleted */ router.post("/clerk", async (req, res) => { const SIGNING_SECRET = @@ -27,10 +34,11 @@ router.post("/clerk", async (req, res) => { const svix_signature = headers["svix-signature"]; if (!svix_id || !svix_timestamp || !svix_signature) { - return res.status(400).json({ + res.status(400).json({ success: false, message: "Error: Missing svix headers", }); + return; } let evt: any; @@ -43,10 +51,11 @@ router.post("/clerk", async (req, res) => { }); } catch (err) { console.log("Error: Could not verify webhook:", err.message); - return res.status(400).json({ + res.status(400).json({ success: false, message: err.message, }); + return; } const { id } = evt.data; @@ -86,9 +95,65 @@ router.post("/clerk", async (req, res) => { } } catch (error) { console.error("Error handling webhook:", error); - return res - .status(500) - .json({ success: false, message: "Internal Server Error" }); + res.status(500).json({ success: false, message: "Internal Server Error" }); + return; } - return res.status(200).json({ success: true, message: "Webhook received" }); + res.status(200).json({ success: true, message: "Webhook received" }); + return; +}); + +/** + * POST api/webhook/fal-ai/train + * Fal AI wil hit this endpoint when training is done + */ + +router.post("/fal-ai/train", async (req, res) => { + const { requestId } = req.body; + + const result = await fal.queue.result("fal-ai/flux-lora", { + requestId, + }); + + const { imageUrl } = await falAiModel.generateImageSync( + //@ts-ignore + result.data.diffusers_lora_file.url + ); + + await prismaClient.model.updateMany({ + where: { + falAiRequestId: requestId, + }, + data: { + trainingStatus: "Generated", + //@ts-ignore + tensorPath: result.data.diffusers_lora_file.url, + thumbnail: imageUrl, + }, + }); + res.json({ + message: "Webhook received", + }); +}); + +/** + * POST api/webhook/fal-ai/image + * Fal AI wil hit this endpoint when image is generated + */ + +router.post("/fal-ai/image", async (req, res) => { + const { requestId } = req.body; + + await prismaClient.outputImages.updateMany({ + where: { + falAiRequestId: requestId, + }, + data: { + status: "Generated", + imageUrl: req.body.payload.images[0].url, + }, + }); + + res.status(200).json({ + message: "Webhook received", + }); }); diff --git a/packages/db/prisma/migrations/20250225175603_add_relation_usercredit_to_user/migration.sql b/packages/db/prisma/migrations/20250225175603_add_relation_usercredit_to_user/migration.sql new file mode 100644 index 0000000..e01ae7e --- /dev/null +++ b/packages/db/prisma/migrations/20250225175603_add_relation_usercredit_to_user/migration.sql @@ -0,0 +1,2 @@ +-- AddForeignKey +ALTER TABLE "UserCredit" ADD CONSTRAINT "UserCredit_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE; diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 3305819..a8e8c1a 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -14,13 +14,14 @@ datasource db { } model User { - id String @id @default(uuid()) - clerkId String @unique - email String @unique - name String? + id String @id @default(uuid()) + clerkId String @unique + email String @unique + name String? profilePicture String? - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt + credits UserCredit? + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt } enum ModelTrainingStatusEnum { @@ -30,24 +31,25 @@ enum ModelTrainingStatusEnum { } model Model { - id String @id @default(uuid()) - name String - type ModelTypeEnum - age Int - ethinicity EthenecityEnum - eyeColor EyeColorEnum - bald Boolean - userId String - triggerWord String? - tensorPath String? - thumbnail String? - trainingStatus ModelTrainingStatusEnum @default(Pending) - outputImages OutputImages[] - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - falAiRequestId String? - zipUrl String - open Boolean @default(false) + id String @id @default(uuid()) + name String + type ModelTypeEnum + age Int + ethinicity EthenecityEnum + eyeColor EyeColorEnum + bald Boolean + userId String + triggerWord String? + tensorPath String? + thumbnail String? + trainingStatus ModelTrainingStatusEnum @default(Pending) + outputImages OutputImages[] + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + falAiRequestId String? + zipUrl String + open Boolean @default(false) + @@index([falAiRequestId]) } @@ -58,33 +60,34 @@ enum OutputImageStatusEnum { } model OutputImages { - id String @id @default(uuid()) - imageUrl String @default("") - modelId String - userId String - prompt String - falAiRequestId String? - status OutputImageStatusEnum @default(Pending) - model Model @relation(fields: [modelId], references: [id]) - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt + id String @id @default(uuid()) + imageUrl String @default("") + modelId String + userId String + prompt String + falAiRequestId String? + status OutputImageStatusEnum @default(Pending) + model Model @relation(fields: [modelId], references: [id]) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + @@index([falAiRequestId]) } model Packs { - id String @id @default(uuid()) - name String - description String @default("") - imageUrl1 String @default("") - imageUrl2 String @default("") - prompts PackPrompts[] + id String @id @default(uuid()) + name String + description String @default("") + imageUrl1 String @default("") + imageUrl2 String @default("") + prompts PackPrompts[] } model PackPrompts { - id String @id @default(uuid()) - prompt String - packId String - pack Packs @relation(fields: [packId], references: [id]) + id String @id @default(uuid()) + prompt String + packId String + pack Packs @relation(fields: [packId], references: [id]) } model Subscription { @@ -111,11 +114,11 @@ enum ModelTypeEnum { enum EthenecityEnum { White Black - Asian_American @map("Asian American") - East_Asian @map("East Asian") - South_East_Asian @map("South East Asian") - South_Asian @map("South Asian") - Middle_Eastern @map("Middle Eastern") + Asian_American @map("Asian American") + East_Asian @map("East Asian") + South_East_Asian @map("South East Asian") + South_Asian @map("South Asian") + Middle_Eastern @map("Middle Eastern") Pacific Hispanic } @@ -130,6 +133,7 @@ enum EyeColorEnum { model UserCredit { id String @id @default(cuid()) userId String @unique + user User @relation(fields: [userId], references: [id]) amount Int @default(0) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt