Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 23 additions & 87 deletions apps/backend/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { authMiddleware } from "./middleware";
import dotenv from "dotenv";

import paymentRoutes from "./routes/payment.routes";
import {router as webhookRouter} from './routes/webhook.routes';
import { router as webhookRouter } from "./routes/webhook.routes";

const IMAGE_GEN_CREDITS = 1;
const TRAIN_MODEL_CREDITS = 20;
Expand Down Expand Up @@ -63,6 +63,27 @@ app.post("/ai/training", authMiddleware, async (req, res) => {
return;
}

// Todo : For checking the credits two queries are being made, we can combine them into one
const credits = await prismaClient.userCredit.findUnique({
where: {
userId: req.userId!,
},
});
if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) {
res.status(411).json({
message: "Not enough credits",
});
return;
}
await prismaClient.userCredit.update({
where: {
userId: req.userId!,
},
data: {
amount: { decrement: TRAIN_MODEL_CREDITS },
},
});

const { request_id, response_url } = await falAiModel.trainModel(
parsedBody.data.zipUrl,
parsedBody.data.name
Expand Down Expand Up @@ -120,7 +141,6 @@ app.post("/ai/generate", authMiddleware, async (req, res) => {
});
return;
}

const { request_id, response_url } = await falAiModel.generateImage(
parsedBody.data.prompt,
model.tensorPath
Expand Down Expand Up @@ -268,94 +288,10 @@ app.get("/models", authMiddleware, async (req, res) => {
});
});

app.post("/fal-ai/webhook/train", async (req, res) => {
const requestId = req.body.request_id as string;

const result = await fal.queue.result("fal-ai/flux-lora", {
requestId,
});

// check if the user has enough credits
const credits = await prismaClient.userCredit.findUnique({
where: {
userId: req.userId!,
},
});

if ((credits?.amount ?? 0) < TRAIN_MODEL_CREDITS) {
res.status(411).json({
message: "Not enough credits",
});
return;
}

const { imageUrl } = await falAiModel.generateImageSync(
result.data.diffusers_lora_file.url
);

await prismaClient.model.updateMany({
where: {
falAiRequestId: requestId,
},
data: {
trainingStatus: "Generated",
//@ts-ignore
tensorPath: result.data.diffusers_lora_file.url,
thumbnail: imageUrl,
},
});

await prismaClient.userCredit.update({
where: {
userId: req.userId!,
},
data: {
amount: { decrement: TRAIN_MODEL_CREDITS },
},
});

res.json({
message: "Webhook received",
});
});

app.post("/fal-ai/webhook/image", async (req, res) => {
console.log("fal-ai/webhook/image");
console.log(req.body);
// update the status of the image in the DB
const requestId = req.body.request_id;

if (req.body.status === "ERROR") {
res.status(411).json({});
prismaClient.outputImages.updateMany({
where: {
falAiRequestId: requestId,
},
data: {
status: "Failed",
imageUrl: req.body.payload.images[0].url,
},
});
return;
}

await prismaClient.outputImages.updateMany({
where: {
falAiRequestId: requestId,
},
data: {
status: "Generated",
imageUrl: req.body.payload.images[0].url,
},
});

res.json({
message: "Webhook received",
});
});

app.use("/payment", paymentRoutes);
app.use("/api/webhook",webhookRouter );
app.use("/api/webhook", webhookRouter);

app.listen(PORT, () => {
console.log(`Server is running on port ${PORT}`);
Expand Down
4 changes: 2 additions & 2 deletions apps/backend/models/FalAIModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ export class FalAIModel {
prompt: prompt,
loras: [{ path: tensorPath, scale: 1 }]
},
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/image`,
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/api/webhook/fal-ai/image`,
});

return { request_id, response_url };
Expand All @@ -25,7 +25,7 @@ export class FalAIModel {
images_data_url: zipUrl,
trigger_word: triggerWord
},
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/fal-ai/webhook/train`,
webhookUrl: `${process.env.WEBHOOK_BASE_URL}/api/webhook/fal-ai/train`,
});

return { request_id, response_url };
Expand Down
79 changes: 72 additions & 7 deletions apps/backend/routes/webhook.routes.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import { prismaClient } from "db";
import { Router } from "express";
import { Webhook } from "svix";
import { fal } from "@fal-ai/client";
import { FalAIModel } from "../models/FalAIModel";

export const router = Router();

const IMAGE_GEN_CREDITS = 1;
const TRAIN_MODEL_CREDITS = 20;

const falAiModel = new FalAIModel();

/**
* POST api/webhook/clerk
* Clerk webhook endpoint
* Clerk will hit this endpoint when user is created, updated or deleted
*/
router.post("/clerk", async (req, res) => {
const SIGNING_SECRET =
Expand All @@ -27,10 +34,11 @@ router.post("/clerk", async (req, res) => {
const svix_signature = headers["svix-signature"];

if (!svix_id || !svix_timestamp || !svix_signature) {
return res.status(400).json({
res.status(400).json({
success: false,
message: "Error: Missing svix headers",
});
return;
}

let evt: any;
Expand All @@ -43,10 +51,11 @@ router.post("/clerk", async (req, res) => {
});
} catch (err) {
console.log("Error: Could not verify webhook:", err.message);
return res.status(400).json({
res.status(400).json({
success: false,
message: err.message,
});
return;
}

const { id } = evt.data;
Expand Down Expand Up @@ -86,9 +95,65 @@ router.post("/clerk", async (req, res) => {
}
} catch (error) {
console.error("Error handling webhook:", error);
return res
.status(500)
.json({ success: false, message: "Internal Server Error" });
res.status(500).json({ success: false, message: "Internal Server Error" });
return;
}
return res.status(200).json({ success: true, message: "Webhook received" });
res.status(200).json({ success: true, message: "Webhook received" });
return;
});

/**
* POST api/webhook/fal-ai/train
* Fal AI wil hit this endpoint when training is done
*/

router.post("/fal-ai/train", async (req, res) => {
const { requestId } = req.body;

const result = await fal.queue.result("fal-ai/flux-lora", {
requestId,
});

const { imageUrl } = await falAiModel.generateImageSync(
//@ts-ignore
result.data.diffusers_lora_file.url
);

await prismaClient.model.updateMany({
where: {
falAiRequestId: requestId,
},
data: {
trainingStatus: "Generated",
//@ts-ignore
tensorPath: result.data.diffusers_lora_file.url,
thumbnail: imageUrl,
},
});
res.json({
message: "Webhook received",
});
});

/**
* POST api/webhook/fal-ai/image
* Fal AI wil hit this endpoint when image is generated
*/

router.post("/fal-ai/image", async (req, res) => {
const { requestId } = req.body;

await prismaClient.outputImages.updateMany({
where: {
falAiRequestId: requestId,
},
data: {
status: "Generated",
imageUrl: req.body.payload.images[0].url,
},
});

res.status(200).json({
message: "Webhook received",
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AddForeignKey
ALTER TABLE "UserCredit" ADD CONSTRAINT "UserCredit_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
Loading