diff --git a/backend/cmd/fix_rentables/main.go b/backend/cmd/fix_rentables/main.go index 25a2b8faf..54afcbc16 100644 --- a/backend/cmd/fix_rentables/main.go +++ b/backend/cmd/fix_rentables/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "kubecloud/internal/core/models" + corepersistence "kubecloud/internal/core/persistence" "kubecloud/internal/infrastructure/persistence" "strings" @@ -52,8 +53,9 @@ func main() { defer substrateClient.Close() // Get all user_nodes records - var allRecords []models.UserNodes - if err := db.GetDB().Order("created_at DESC, id DESC").Find(&allRecords).Error; err != nil { + contractsRepo := corepersistence.NewGormUserContractDataRepository(db) + allRecords, err := contractsRepo.ListAllReservedNodes() + if err != nil { log.Error().Err(err).Msg("Failed to get user_nodes records") return } diff --git a/backend/cmd/root.go b/backend/cmd/root.go index d3fe1dd1f..608b59c92 100644 --- a/backend/cmd/root.go +++ b/backend/cmd/root.go @@ -154,18 +154,28 @@ func addFlags() error { } // === Monitor Balance Interval In Hours === - if err := bindIntFlag(rootCmd, "monitor_balance_interval_in_minutes", 1, "Number of minutes to monitor balance"); err != nil { - return fmt.Errorf("failed to bind monitor_balance_interval_in_minutes flag: %w", err) + if err := bindIntFlag(rootCmd, "settle_transfer_records_interval_in_minutes", 1, "Number of minutes to monitor balance"); err != nil { + return fmt.Errorf("failed to bind settle_transfer_records_interval_in_minutes flag: %w", err) } if err := bindIntFlag(rootCmd, "notify_admins_for_pending_records_in_hours", 1, "Number of hours to notify admins about pending records"); err != nil { return fmt.Errorf("failed to bind notify_admins_for_pending_records_in_hours flag: %w", err) } + // === Applied Discount === + if err := bindStringFlag(rootCmd, "applied_discount", "", "Applied discount to fund users"); err != nil { + return fmt.Errorf("failed to bind applied_discount flag: %w", err) + } + + if err := bindIntFlag(rootCmd, "minimum_tft_amount_in_wallet", 10, "Minimum TFT amount in wallet"); err != nil { + return fmt.Errorf("failed to bind minimum_tft_amount_in_wallet flag: %w", err) + } + // === Users Balance Check Interval In Hours === if err := bindIntFlag(rootCmd, "users_balance_check_interval_in_hours", 6, "Number of hours to check users balance"); err != nil { return fmt.Errorf("failed to bind users_balance_check_interval_in_hours flag: %w", err) } + if err := bindIntFlag(rootCmd, "check_user_debt_interval_in_hours", 48, "Number of upcoming hours to check user debt"); err != nil { return fmt.Errorf("failed to bind check_user_debt_interval_in_hours flag: %w", err) } diff --git a/backend/config-example.json b/backend/config-example.json index e1033d030..3617d2ec8 100644 --- a/backend/config-example.json +++ b/backend/config-example.json @@ -49,7 +49,6 @@ "debug": false, "disable_sentry": true, "dev_mode": false, - "monitor_balance_interval_in_minutes": 120, "notify_admins_for_pending_records_in_hours": 24, "cluster_health_check_interval_in_hours": 1, "node_health_check" : { @@ -71,6 +70,9 @@ "host": "" } }, + "settle_transfer_records_interval_in_minutes": 5, + "applied_discount": "gold", + "minimum_tft_amount_in_wallet": 10, "telemetry": { "otlp_endpoint": "jaeger:4317" } diff --git a/backend/docs/swagger/docs.go b/backend/docs/swagger/docs.go index 6ab3aafd6..e0c27a1dd 100644 --- a/backend/docs/swagger/docs.go +++ b/backend/docs/swagger/docs.go @@ -1177,53 +1177,6 @@ const docTemplate = `{ } } }, - "/pending-records": { - "get": { - "security": [ - { - "AdminMiddleware": [] - } - ], - "description": "Returns all pending records in the system", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "admin" - ], - "summary": "List pending records", - "operationId": "list-pending-records", - "responses": { - "200": { - "description": "Pending records are retrieved successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/services.PendingRecordsWithUSDAmounts" - } - } - } - ] - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, "/stats": { "get": { "security": [ @@ -1362,6 +1315,47 @@ const docTemplate = `{ } } }, + "/transfer-records": { + "get": { + "security": [ + { + "AdminMiddleware": [] + } + ], + "description": "Returns all transfer records in the system", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "admin" + ], + "summary": "List transfer records", + "operationId": "list-transfer-records", + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/services.TransferRecordsWithTFTAmount" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.APIResponse" + } + } + } + } + }, "/twins/{twin_id}/account": { "get": { "description": "Retrieve the account ID associated with a specific twin ID", @@ -1465,52 +1459,7 @@ const docTemplate = `{ "type": "object", "properties": { "data": { - "$ref": "#/definitions/services.UserWithPendingBalance" - } - } - } - ] - } - }, - "404": { - "description": "User is not found", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, - "/user/balance": { - "get": { - "description": "Retrieves the user's balance in USD", - "produces": [ - "application/json" - ], - "tags": [ - "users" - ], - "summary": "Get user balance", - "operationId": "get-user-balance", - "responses": { - "200": { - "description": "Balance fetched successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/handlers.UserBalanceResponse" + "$ref": "#/definitions/services.UserWithBalancesInUSD" } } } @@ -2231,53 +2180,6 @@ const docTemplate = `{ } } }, - "/user/pending-records": { - "get": { - "security": [ - { - "BearerAuth": [] - } - ], - "description": "Returns user pending records in the system", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "users" - ], - "summary": "List user pending records", - "operationId": "list-user-pending-records", - "responses": { - "200": { - "description": "Pending records returned successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/handlers.PendingRecordsResponse" - } - } - } - ] - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, "/user/redeem/{voucher_code}": { "put": { "description": "Redeems a voucher for the user", @@ -2797,7 +2699,7 @@ const docTemplate = `{ "schema": { "type": "array", "items": { - "$ref": "#/definitions/services.UserWithUSDBalance" + "$ref": "#/definitions/services.UserWithTFTBalance" } } }, @@ -3350,7 +3252,6 @@ const docTemplate = `{ }, "gridtypes.Unit": { "type": "integer", - "format": "int64", "enum": [ 1024, 1048576, @@ -3826,17 +3727,6 @@ const docTemplate = `{ } } }, - "handlers.PendingRecordsResponse": { - "type": "object", - "properties": { - "pending_records": { - "type": "array", - "items": { - "$ref": "#/definitions/services.PendingRecordsWithUSDAmounts" - } - } - } - }, "handlers.RedeemVoucherResponse": { "type": "object", "properties": { @@ -3982,20 +3872,6 @@ const docTemplate = `{ } } }, - "handlers.UserBalanceResponse": { - "type": "object", - "properties": { - "balance_usd": { - "type": "number" - }, - "debt_usd": { - "type": "number" - }, - "pending_balance_usd": { - "type": "number" - } - } - }, "handlers.UserWorkflow": { "type": "object", "properties": { @@ -4153,6 +4029,9 @@ const docTemplate = `{ "type": "string" } }, + "healthy": { + "type": "boolean" + }, "ip": { "description": "Computed", "type": "string" @@ -4230,7 +4109,6 @@ const docTemplate = `{ } }, "tax": { - "description": "TODO:", "type": "number" }, "total": { @@ -4347,6 +4225,19 @@ const docTemplate = `{ } } }, + "models.State": { + "type": "string", + "enum": [ + "failed", + "success", + "pending" + ], + "x-enum-varnames": [ + "FailedState", + "SuccessState", + "PendingState" + ] + }, "models.Voucher": { "type": "object", "required": [ @@ -4371,11 +4262,28 @@ const docTemplate = `{ "redeemed": { "type": "boolean" }, + "user_id": { + "type": "integer" + }, + "username": { + "type": "string" + }, "value": { "type": "number" } } }, + "models.operation": { + "type": "string", + "enum": [ + "withdraw", + "deposit" + ], + "x-enum-varnames": [ + "WithdrawOperation", + "DepositOperation" + ] + }, "services.AdminWorkflow": { "type": "object", "properties": { @@ -4444,42 +4352,6 @@ const docTemplate = `{ } } }, - "services.PendingRecordsWithUSDAmounts": { - "type": "object", - "properties": { - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "tft_amount": { - "description": "TFTs are multiplied by 1e7", - "type": "integer" - }, - "transfer_mode": { - "type": "string" - }, - "transferred_tft_amount": { - "type": "integer" - }, - "transferred_usd_amount": { - "type": "number" - }, - "updated_at": { - "type": "string" - }, - "usd_amount": { - "type": "number" - }, - "user_id": { - "type": "integer" - }, - "username": { - "type": "string" - } - } - }, "services.Pool": { "type": "object", "properties": { @@ -4522,7 +4394,43 @@ const docTemplate = `{ } } }, - "services.UserWithPendingBalance": { + "services.TransferRecordsWithTFTAmount": { + "type": "object", + "properties": { + "created_at": { + "type": "string" + }, + "failure": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "operation": { + "$ref": "#/definitions/models.operation" + }, + "state": { + "$ref": "#/definitions/models.State" + }, + "tft_amount": { + "description": "TFTs are multiplied by 1e7", + "type": "integer" + }, + "tft_amount_in_whole_unit": { + "type": "number" + }, + "updated_at": { + "type": "string" + }, + "user_id": { + "type": "integer" + }, + "username": { + "type": "string" + } + } + }, + "services.UserWithBalancesInUSD": { "type": "object", "required": [ "email", @@ -4535,6 +4443,9 @@ const docTemplate = `{ "admin": { "type": "boolean" }, + "balance_in_tft": { + "type": "number" + }, "code": { "type": "integer" }, @@ -4545,23 +4456,29 @@ const docTemplate = `{ "description": "millicent, money from credit card", "type": "integer" }, + "credit_card_balance_in_usd": { + "type": "number" + }, "credited_balance": { "description": "millicent, manually added by admin or from vouchers", "type": "integer" }, + "credited_balance_in_usd": { + "type": "number" + }, "debt": { "description": "millicent", "type": "integer" }, + "debt_in_usd": { + "type": "number" + }, "email": { "type": "string" }, "id": { "type": "integer" }, - "pending_balance_usd": { - "type": "number" - }, "sponsored": { "type": "boolean" }, @@ -4582,7 +4499,7 @@ const docTemplate = `{ } } }, - "services.UserWithUSDBalance": { + "services.UserWithTFTBalance": { "type": "object", "required": [ "email", @@ -4595,8 +4512,7 @@ const docTemplate = `{ "admin": { "type": "boolean" }, - "balance": { - "description": "USD balance", + "balance_in_tft": { "type": "number" }, "code": { @@ -5014,8 +4930,7 @@ const docTemplate = `{ "externalSK": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "iprange": { @@ -5026,8 +4941,7 @@ const docTemplate = `{ "additionalProperties": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } }, @@ -5036,8 +4950,7 @@ const docTemplate = `{ "additionalProperties": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } }, @@ -5047,15 +4960,13 @@ const docTemplate = `{ "nodeDeploymentID": { "type": "object", "additionalProperties": { - "type": "integer", - "format": "int64" + "type": "integer" } }, "nodes": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "nodesIPRange": { @@ -5065,8 +4976,7 @@ const docTemplate = `{ } }, "publicNodeID": { - "type": "integer", - "format": "int32" + "type": "integer" }, "solutionType": { "type": "string" @@ -5086,16 +4996,14 @@ const docTemplate = `{ "description": "network number", "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "mask": { "description": "network mask", "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } } diff --git a/backend/docs/swagger/swagger.json b/backend/docs/swagger/swagger.json index 485993b7f..a88ca9c2c 100644 --- a/backend/docs/swagger/swagger.json +++ b/backend/docs/swagger/swagger.json @@ -1170,53 +1170,6 @@ } } }, - "/pending-records": { - "get": { - "security": [ - { - "AdminMiddleware": [] - } - ], - "description": "Returns all pending records in the system", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "admin" - ], - "summary": "List pending records", - "operationId": "list-pending-records", - "responses": { - "200": { - "description": "Pending records are retrieved successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/services.PendingRecordsWithUSDAmounts" - } - } - } - ] - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, "/stats": { "get": { "security": [ @@ -1355,6 +1308,47 @@ } } }, + "/transfer-records": { + "get": { + "security": [ + { + "AdminMiddleware": [] + } + ], + "description": "Returns all transfer records in the system", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "admin" + ], + "summary": "List transfer records", + "operationId": "list-transfer-records", + "responses": { + "200": { + "description": "OK", + "schema": { + "type": "array", + "items": { + "type": "array", + "items": { + "$ref": "#/definitions/services.TransferRecordsWithTFTAmount" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "$ref": "#/definitions/handlers.APIResponse" + } + } + } + } + }, "/twins/{twin_id}/account": { "get": { "description": "Retrieve the account ID associated with a specific twin ID", @@ -1458,52 +1452,7 @@ "type": "object", "properties": { "data": { - "$ref": "#/definitions/services.UserWithPendingBalance" - } - } - } - ] - } - }, - "404": { - "description": "User is not found", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, - "/user/balance": { - "get": { - "description": "Retrieves the user's balance in USD", - "produces": [ - "application/json" - ], - "tags": [ - "users" - ], - "summary": "Get user balance", - "operationId": "get-user-balance", - "responses": { - "200": { - "description": "Balance fetched successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/handlers.UserBalanceResponse" + "$ref": "#/definitions/services.UserWithBalancesInUSD" } } } @@ -2224,53 +2173,6 @@ } } }, - "/user/pending-records": { - "get": { - "security": [ - { - "BearerAuth": [] - } - ], - "description": "Returns user pending records in the system", - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "users" - ], - "summary": "List user pending records", - "operationId": "list-user-pending-records", - "responses": { - "200": { - "description": "Pending records returned successfully", - "schema": { - "allOf": [ - { - "$ref": "#/definitions/handlers.APIResponse" - }, - { - "type": "object", - "properties": { - "data": { - "$ref": "#/definitions/handlers.PendingRecordsResponse" - } - } - } - ] - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/handlers.APIResponse" - } - } - } - } - }, "/user/redeem/{voucher_code}": { "put": { "description": "Redeems a voucher for the user", @@ -2790,7 +2692,7 @@ "schema": { "type": "array", "items": { - "$ref": "#/definitions/services.UserWithUSDBalance" + "$ref": "#/definitions/services.UserWithTFTBalance" } } }, @@ -3343,7 +3245,6 @@ }, "gridtypes.Unit": { "type": "integer", - "format": "int64", "enum": [ 1024, 1048576, @@ -3819,17 +3720,6 @@ } } }, - "handlers.PendingRecordsResponse": { - "type": "object", - "properties": { - "pending_records": { - "type": "array", - "items": { - "$ref": "#/definitions/services.PendingRecordsWithUSDAmounts" - } - } - } - }, "handlers.RedeemVoucherResponse": { "type": "object", "properties": { @@ -3975,20 +3865,6 @@ } } }, - "handlers.UserBalanceResponse": { - "type": "object", - "properties": { - "balance_usd": { - "type": "number" - }, - "debt_usd": { - "type": "number" - }, - "pending_balance_usd": { - "type": "number" - } - } - }, "handlers.UserWorkflow": { "type": "object", "properties": { @@ -4146,6 +4022,9 @@ "type": "string" } }, + "healthy": { + "type": "boolean" + }, "ip": { "description": "Computed", "type": "string" @@ -4223,7 +4102,6 @@ } }, "tax": { - "description": "TODO:", "type": "number" }, "total": { @@ -4340,6 +4218,19 @@ } } }, + "models.State": { + "type": "string", + "enum": [ + "failed", + "success", + "pending" + ], + "x-enum-varnames": [ + "FailedState", + "SuccessState", + "PendingState" + ] + }, "models.Voucher": { "type": "object", "required": [ @@ -4364,11 +4255,28 @@ "redeemed": { "type": "boolean" }, + "user_id": { + "type": "integer" + }, + "username": { + "type": "string" + }, "value": { "type": "number" } } }, + "models.operation": { + "type": "string", + "enum": [ + "withdraw", + "deposit" + ], + "x-enum-varnames": [ + "WithdrawOperation", + "DepositOperation" + ] + }, "services.AdminWorkflow": { "type": "object", "properties": { @@ -4437,42 +4345,6 @@ } } }, - "services.PendingRecordsWithUSDAmounts": { - "type": "object", - "properties": { - "created_at": { - "type": "string" - }, - "id": { - "type": "integer" - }, - "tft_amount": { - "description": "TFTs are multiplied by 1e7", - "type": "integer" - }, - "transfer_mode": { - "type": "string" - }, - "transferred_tft_amount": { - "type": "integer" - }, - "transferred_usd_amount": { - "type": "number" - }, - "updated_at": { - "type": "string" - }, - "usd_amount": { - "type": "number" - }, - "user_id": { - "type": "integer" - }, - "username": { - "type": "string" - } - } - }, "services.Pool": { "type": "object", "properties": { @@ -4515,7 +4387,43 @@ } } }, - "services.UserWithPendingBalance": { + "services.TransferRecordsWithTFTAmount": { + "type": "object", + "properties": { + "created_at": { + "type": "string" + }, + "failure": { + "type": "string" + }, + "id": { + "type": "integer" + }, + "operation": { + "$ref": "#/definitions/models.operation" + }, + "state": { + "$ref": "#/definitions/models.State" + }, + "tft_amount": { + "description": "TFTs are multiplied by 1e7", + "type": "integer" + }, + "tft_amount_in_whole_unit": { + "type": "number" + }, + "updated_at": { + "type": "string" + }, + "user_id": { + "type": "integer" + }, + "username": { + "type": "string" + } + } + }, + "services.UserWithBalancesInUSD": { "type": "object", "required": [ "email", @@ -4528,6 +4436,9 @@ "admin": { "type": "boolean" }, + "balance_in_tft": { + "type": "number" + }, "code": { "type": "integer" }, @@ -4538,23 +4449,29 @@ "description": "millicent, money from credit card", "type": "integer" }, + "credit_card_balance_in_usd": { + "type": "number" + }, "credited_balance": { "description": "millicent, manually added by admin or from vouchers", "type": "integer" }, + "credited_balance_in_usd": { + "type": "number" + }, "debt": { "description": "millicent", "type": "integer" }, + "debt_in_usd": { + "type": "number" + }, "email": { "type": "string" }, "id": { "type": "integer" }, - "pending_balance_usd": { - "type": "number" - }, "sponsored": { "type": "boolean" }, @@ -4575,7 +4492,7 @@ } } }, - "services.UserWithUSDBalance": { + "services.UserWithTFTBalance": { "type": "object", "required": [ "email", @@ -4588,8 +4505,7 @@ "admin": { "type": "boolean" }, - "balance": { - "description": "USD balance", + "balance_in_tft": { "type": "number" }, "code": { @@ -5007,8 +4923,7 @@ "externalSK": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "iprange": { @@ -5019,8 +4934,7 @@ "additionalProperties": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } }, @@ -5029,8 +4943,7 @@ "additionalProperties": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } }, @@ -5040,15 +4953,13 @@ "nodeDeploymentID": { "type": "object", "additionalProperties": { - "type": "integer", - "format": "int64" + "type": "integer" } }, "nodes": { "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "nodesIPRange": { @@ -5058,8 +4969,7 @@ } }, "publicNodeID": { - "type": "integer", - "format": "int32" + "type": "integer" }, "solutionType": { "type": "string" @@ -5079,16 +4989,14 @@ "description": "network number", "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } }, "mask": { "description": "network mask", "type": "array", "items": { - "type": "integer", - "format": "int32" + "type": "integer" } } } diff --git a/backend/docs/swagger/swagger.yaml b/backend/docs/swagger/swagger.yaml index 883907d1a..701a760a2 100644 --- a/backend/docs/swagger/swagger.yaml +++ b/backend/docs/swagger/swagger.yaml @@ -21,7 +21,6 @@ definitions: - 1048576 - 1073741824 - 1099511627776 - format: int64 type: integer x-enum-varnames: - Kilobyte @@ -339,13 +338,6 @@ definitions: type: $ref: '#/definitions/models.NotificationType' type: object - handlers.PendingRecordsResponse: - properties: - pending_records: - items: - $ref: '#/definitions/services.PendingRecordsWithUSDAmounts' - type: array - type: object handlers.RedeemVoucherResponse: properties: amount: @@ -442,15 +434,6 @@ definitions: workflow_id: type: string type: object - handlers.UserBalanceResponse: - properties: - balance_usd: - type: number - debt_usd: - type: number - pending_balance_usd: - type: number - type: object handlers.UserWorkflow: properties: created_at: @@ -550,6 +533,8 @@ definitions: items: type: string type: array + healthy: + type: boolean ip: description: Computed type: string @@ -610,7 +595,6 @@ definitions: $ref: '#/definitions/models.NodeItem' type: array tax: - description: 'TODO:' type: number total: type: number @@ -697,6 +681,16 @@ definitions: - public_key - userID type: object + models.State: + enum: + - failed + - success + - pending + type: string + x-enum-varnames: + - FailedState + - SuccessState + - PendingState models.Voucher: properties: code: @@ -709,6 +703,10 @@ definitions: type: integer redeemed: type: boolean + user_id: + type: integer + username: + type: string value: type: number required: @@ -717,6 +715,14 @@ definitions: - expires_at - value type: object + models.operation: + enum: + - withdraw + - deposit + type: string + x-enum-varnames: + - WithdrawOperation + - DepositOperation services.AdminWorkflow: properties: created_at: @@ -762,30 +768,6 @@ definitions: updated_at: type: string type: object - services.PendingRecordsWithUSDAmounts: - properties: - created_at: - type: string - id: - type: integer - tft_amount: - description: TFTs are multiplied by 1e7 - type: integer - transfer_mode: - type: string - transferred_tft_amount: - type: integer - transferred_usd_amount: - type: number - updated_at: - type: string - usd_amount: - type: number - user_id: - type: integer - username: - type: string - type: object services.Pool: properties: free: @@ -814,12 +796,38 @@ definitions: up_nodes: type: integer type: object - services.UserWithPendingBalance: + services.TransferRecordsWithTFTAmount: + properties: + created_at: + type: string + failure: + type: string + id: + type: integer + operation: + $ref: '#/definitions/models.operation' + state: + $ref: '#/definitions/models.State' + tft_amount: + description: TFTs are multiplied by 1e7 + type: integer + tft_amount_in_whole_unit: + type: number + updated_at: + type: string + user_id: + type: integer + username: + type: string + type: object + services.UserWithBalancesInUSD: properties: account_address: type: string admin: type: boolean + balance_in_tft: + type: number code: type: integer created_at: @@ -827,18 +835,22 @@ definitions: credit_card_balance: description: millicent, money from credit card type: integer + credit_card_balance_in_usd: + type: number credited_balance: description: millicent, manually added by admin or from vouchers type: integer + credited_balance_in_usd: + type: number debt: description: millicent type: integer + debt_in_usd: + type: number email: type: string id: type: integer - pending_balance_usd: - type: number sponsored: type: boolean ssh_key: @@ -855,14 +867,13 @@ definitions: - email - username type: object - services.UserWithUSDBalance: + services.UserWithTFTBalance: properties: account_address: type: string admin: type: boolean - balance: - description: USD balance + balance_in_tft: type: number code: type: integer @@ -1143,7 +1154,6 @@ definitions: $ref: '#/definitions/zos.IPNet' externalSK: items: - format: int32 type: integer type: array iprange: @@ -1151,14 +1161,12 @@ definitions: keys: additionalProperties: items: - format: int32 type: integer type: array type: object myceliumKeys: additionalProperties: items: - format: int32 type: integer type: array type: object @@ -1166,12 +1174,10 @@ definitions: type: string nodeDeploymentID: additionalProperties: - format: int64 type: integer type: object nodes: items: - format: int32 type: integer type: array nodesIPRange: @@ -1179,7 +1185,6 @@ definitions: $ref: '#/definitions/zos.IPNet' type: object publicNodeID: - format: int32 type: integer solutionType: type: string @@ -1193,13 +1198,11 @@ definitions: ip: description: network number items: - format: int32 type: integer type: array mask: description: network mask items: - format: int32 type: integer type: array type: object @@ -1915,33 +1918,6 @@ paths: summary: Get unread notifications tags: - notifications - /pending-records: - get: - consumes: - - application/json - description: Returns all pending records in the system - operationId: list-pending-records - produces: - - application/json - responses: - "200": - description: Pending records are retrieved successfully - schema: - allOf: - - $ref: '#/definitions/handlers.APIResponse' - - properties: - data: - $ref: '#/definitions/services.PendingRecordsWithUSDAmounts' - type: object - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.APIResponse' - security: - - AdminMiddleware: [] - summary: List pending records - tags: - - admin /stats: get: consumes: @@ -2024,6 +2000,32 @@ paths: summary: Set maintenance mode tags: - admin + /transfer-records: + get: + consumes: + - application/json + description: Returns all transfer records in the system + operationId: list-transfer-records + produces: + - application/json + responses: + "200": + description: OK + schema: + items: + items: + $ref: '#/definitions/services.TransferRecordsWithTFTAmount' + type: array + type: array + "500": + description: Internal Server Error + schema: + $ref: '#/definitions/handlers.APIResponse' + security: + - AdminMiddleware: [] + summary: List transfer records + tags: + - admin /twins/{twin_id}/account: get: consumes: @@ -2088,7 +2090,7 @@ paths: - $ref: '#/definitions/handlers.APIResponse' - properties: data: - $ref: '#/definitions/services.UserWithPendingBalance' + $ref: '#/definitions/services.UserWithBalancesInUSD' type: object "404": description: User is not found @@ -2101,33 +2103,6 @@ paths: summary: Get user details tags: - users - /user/balance: - get: - description: Retrieves the user's balance in USD - operationId: get-user-balance - produces: - - application/json - responses: - "200": - description: Balance fetched successfully - schema: - allOf: - - $ref: '#/definitions/handlers.APIResponse' - - properties: - data: - $ref: '#/definitions/handlers.UserBalanceResponse' - type: object - "404": - description: User is not found - schema: - $ref: '#/definitions/handlers.APIResponse' - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.APIResponse' - summary: Get user balance - tags: - - users /user/balance/charge: post: consumes: @@ -2560,33 +2535,6 @@ paths: summary: Unreserve node tags: - nodes - /user/pending-records: - get: - consumes: - - application/json - description: Returns user pending records in the system - operationId: list-user-pending-records - produces: - - application/json - responses: - "200": - description: Pending records returned successfully - schema: - allOf: - - $ref: '#/definitions/handlers.APIResponse' - - properties: - data: - $ref: '#/definitions/handlers.PendingRecordsResponse' - type: object - "500": - description: Internal Server Error - schema: - $ref: '#/definitions/handlers.APIResponse' - security: - - BearerAuth: [] - summary: List user pending records - tags: - - users /user/redeem/{voucher_code}: put: description: Redeems a voucher for the user @@ -2905,7 +2853,7 @@ paths: description: OK schema: items: - $ref: '#/definitions/services.UserWithUSDBalance' + $ref: '#/definitions/services.UserWithTFTBalance' type: array "500": description: Internal Server Error diff --git a/backend/go.mod b/backend/go.mod index 0433e12ea..12809bc74 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -26,8 +26,8 @@ require ( github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.1 github.com/swaggo/swag v1.16.6 - github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20250929084418-b950278ead30 - github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251209150615-a3bb942f9860 + github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20251221150744-62c2f0fbdc2e + github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251221165053-aa8e353f7446 github.com/threefoldtech/tfgrid-sdk-go/grid-proxy v0.17.6-0.20251209150615-a3bb942f9860 github.com/vedhavyas/go-subkey v1.0.3 github.com/xmonader/ewf v0.0.0-20251127155219-5a8a59ee967f @@ -81,7 +81,7 @@ require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.14.2 // indirect github.com/bytedance/sonic/loader v0.4.0 // indirect - github.com/cenkalti/backoff v2.2.1+incompatible // indirect + github.com/cenkalti/backoff v2.2.1+incompatible github.com/cenkalti/backoff/v3 v3.2.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/centrifuge/go-substrate-rpc-client/v4 v4.2.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index 276ccae12..58b1212a5 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -363,10 +363,10 @@ github.com/swaggo/gin-swagger v1.6.1 h1:Ri06G4gc9N4t4k8hekMigJ9zKTFSlqj/9paAQCQs github.com/swaggo/gin-swagger v1.6.1/go.mod h1:LQ+hJStHakCWRiK/YNYtJOu4mR2FP+pxLnILT/qNiTw= github.com/swaggo/swag v1.16.6 h1:qBNcx53ZaX+M5dxVyTrgQ0PJ/ACK+NzhwcbieTt+9yI= github.com/swaggo/swag v1.16.6/go.mod h1:ngP2etMK5a0P3QBizic5MEwpRmluJZPHjXcMoj4Xesg= -github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20250929084418-b950278ead30 h1:sH/hiHxCEpeIm2gJsmu4GxKskfQVPZMz9PAgDwk1BfY= -github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20250929084418-b950278ead30/go.mod h1:cOL5YgHUmDG5SAXrsZxFjUECRQQuAqOoqvXhZG5sEUw= -github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251209150615-a3bb942f9860 h1:4nX4aoA+OteH7M2pTaWYIpN/qVGoqs/kkVAxSnhJbWU= -github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251209150615-a3bb942f9860/go.mod h1:I48ny96GOuMDo9o7INEb43UJIhO3l4mVtBLMNi8Si3M= +github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20251221150744-62c2f0fbdc2e h1:0U8ys+JAlxZyr2A69K8EdB37oQB2sFtCUbUeZBKLgio= +github.com/threefoldtech/tfchain/clients/tfchain-client-go v0.0.0-20251221150744-62c2f0fbdc2e/go.mod h1:cOL5YgHUmDG5SAXrsZxFjUECRQQuAqOoqvXhZG5sEUw= +github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251221165053-aa8e353f7446 h1:hjc3PKLQ5uZQpUZqmKOGuae94CnY9haB+LiK8LYPoiw= +github.com/threefoldtech/tfgrid-sdk-go/grid-client v0.17.6-0.20251221165053-aa8e353f7446/go.mod h1:Z+qUII96KHMN0pXmesfB41OdDOWIeS29uxLk4ZNEKfA= github.com/threefoldtech/tfgrid-sdk-go/grid-proxy v0.17.6-0.20251209150615-a3bb942f9860 h1:LlHQ2Du7S5IAoG0Q61v1s8OAbx8qqf/uO/5QwMf1zZY= github.com/threefoldtech/tfgrid-sdk-go/grid-proxy v0.17.6-0.20251209150615-a3bb942f9860/go.mod h1:J57xHAagUOddwz2nonrkB91/8T9TJ9IKk/6wzKVB5EM= github.com/threefoldtech/tfgrid-sdk-go/rmb-sdk-go v0.17.5 h1:zp5iZOvtvcQrcR7Po3UZBNk2uBYi1i1VxMA/ENIvCZY= diff --git a/backend/internal/api/app/app.go b/backend/internal/api/app/app.go index 56bdc34cf..96dc5a323 100644 --- a/backend/internal/api/app/app.go +++ b/backend/internal/api/app/app.go @@ -128,7 +128,7 @@ func (app *App) registerHandlers() { usersGroup.POST("/drain-all", app.handlers.adminHandler.DrainAllUsersHandler) } usersGroup.POST("/mail", app.handlers.adminHandler.SendMailToAllUsersHandler) - adminGroup.GET("/pending-records", app.handlers.adminHandler.ListPendingRecordsHandler) + adminGroup.GET("/transfer-records", app.handlers.adminHandler.ListTransferRecordsHandler) adminGroup.GET("/invoices", app.handlers.invoiceHandler.ListAllInvoicesHandler) adminGroup.GET("/workflows", app.handlers.adminHandler.ListAllWorkflowsHandler) @@ -161,9 +161,7 @@ func (app *App) registerHandlers() { authGroup.GET("/", app.handlers.userHandler.GetUserHandler) authGroup.POST("/balance/charge", app.handlers.userHandler.ChargeBalance) authGroup.PUT("/change_password", app.handlers.userHandler.ChangePasswordHandler) - authGroup.GET("/balance", app.handlers.userHandler.GetUserBalance) authGroup.PUT("/redeem/:voucher_code", app.handlers.userHandler.RedeemVoucherHandler) - authGroup.GET("/pending-records", app.handlers.userHandler.ListUserPendingRecordsHandler) authGroup.GET("/workflows", app.handlers.userHandler.ListUserRemainingWorkflowsHandler) authGroup.GET("/nodes", app.handlers.nodeHandler.ListNodesHandler) @@ -226,6 +224,7 @@ func (app *App) StartBackgroundWorkers() { go app.workers.MonitorSystemBalanceAndHandleSettlement() go app.workers.TrackClusterHealth() go app.workers.TrackReservedNodeHealth() + go app.workers.DeductUSDBalanceBasedOnUsage() go app.workers.CollectGORMMetrics() go app.workers.CollectGoRuntimeMetrics() } diff --git a/backend/internal/api/app/app_dependencies.go b/backend/internal/api/app/app_dependencies.go index 5f55af757..d8e874152 100644 --- a/backend/internal/api/app/app_dependencies.go +++ b/backend/internal/api/app/app_dependencies.go @@ -308,18 +308,23 @@ func (app *App) createHandlers() appHandlers { // Repositories userRepo := corepersistence.NewGormUserRepository(app.core.db) voucherRepo := corepersistence.NewGormVoucherRepository(app.core.db) - pendingRecordRepo := corepersistence.NewGormPendingRecordRepository(app.core.db) + transferRecordsRepo := corepersistence.NewGormTransferRecordRepository(app.core.db) notificationRepo := corepersistence.NewGormNotificationRepository(app.core.db) clusterRepo := corepersistence.NewGormClusterRepository(app.core.db) invoiceRepo := corepersistence.NewGormInvoiceRepository(app.core.db) - userNodesRepo := corepersistence.NewGormUserNodesRepository(app.core.db) + contractsRepo := corepersistence.NewGormUserContractDataRepository(app.core.db) transactionRepo := corepersistence.NewGormTransactionRepository(app.core.db) settingsRepo := corepersistence.NewGormSettingsRepository(app.core.db) ewfRepo := corepersistence.NewGormEWFRepository(app.core.db) // Services + billingService := services.NewBillingService( + userRepo, contractsRepo, transferRecordsRepo, clusterRepo, + app.infra.graphql, app.infra.gridClient, + uint64(app.config.MinimumTFTAmountInWallet), services.Discount(app.config.AppliedDiscount), + ) userService := services.NewUserService( - app.core.appCtx, userRepo, voucherRepo, pendingRecordRepo, + app.core.appCtx, userRepo, voucherRepo, app.infra.gridClient, app.core.ewfEngine, app.security.kycClient, app.core.metrics, app.config.MailSender.TimeoutMin, app.config.Admins, @@ -332,7 +337,7 @@ func (app *App) createHandlers() appHandlers { notificationAPIService := services.NewNotificationService(notificationRepo) nodeService := services.NewNodeService( - userNodesRepo, userRepo, app.core.appCtx, app.core.ewfEngine, + contractsRepo, userRepo, app.core.appCtx, app.core.ewfEngine, app.infra.gridClient, ) @@ -341,12 +346,12 @@ func (app *App) createHandlers() appHandlers { ) deploymentService := services.NewDeploymentService( - app.core.appCtx, clusterRepo, userRepo, userNodesRepo, app.core.ewfEngine, + app.core.appCtx, clusterRepo, userRepo, contractsRepo, app.core.ewfEngine, app.config.Debug, app.security.sshPublicKey, app.config.SSH.PrivateKeyPath, app.config.SystemAccount.Network, ) adminService := services.NewAdminService( - app.core.appCtx, userRepo, userNodesRepo, pendingRecordRepo, voucherRepo, + app.core.appCtx, userRepo, contractsRepo, transferRecordsRepo, voucherRepo, transactionRepo, app.infra.gridClient, app.core.ewfEngine, app.communication.mailService, app.communication.notificationDispatcher, ewfRepo, ) @@ -355,15 +360,15 @@ func (app *App) createHandlers() appHandlers { // Handlers stripeClient := &billing.DefaultStripeClient{} userHandler := handlers.NewUserHandler( - userService, app.communication.notificationDispatcher, + userService, billingService, app.communication.notificationDispatcher, app.communication.mailService, app.security.tokenManager, stripeClient, ) statsHandler := handlers.NewStatsHandler(statsService) notificationHandler := handlers.NewNotificationHandler(notificationAPIService) - nodeHandler := handlers.NewNodeHandler(nodeService) - deploymentHandler := handlers.NewDeploymentHandler(deploymentService) + nodeHandler := handlers.NewNodeHandler(nodeService, billingService) + deploymentHandler := handlers.NewDeploymentHandler(deploymentService, billingService) invoiceHandler := handlers.NewInvoiceHandler(invoiceService) - adminHandler := handlers.NewAdminHandler(adminService, app.communication.notificationDispatcher, app.communication.mailService) + adminHandler := handlers.NewAdminHandler(adminService, billingService, app.communication.notificationDispatcher, app.communication.mailService) healthHandler := handlers.NewHealthHandler(app.config.SystemAccount.Network, app.infra.firesquidClient, app.infra.graphql, app.core.db) settingsHandler := handlers.NewSettingsHandler(settingsService) @@ -382,19 +387,32 @@ func (app *App) createHandlers() appHandlers { func (app *App) createWorkers() workers.Workers { userRepo := corepersistence.NewGormUserRepository(app.core.db) - pendingRecordRepo := corepersistence.NewGormPendingRecordRepository(app.core.db) + transferRecordsRepo := corepersistence.NewGormTransferRecordRepository(app.core.db) clusterRepo := corepersistence.NewGormClusterRepository(app.core.db) invoiceRepo := corepersistence.NewGormInvoiceRepository(app.core.db) - userNodesRepo := corepersistence.NewGormUserNodesRepository(app.core.db) + contractsRepo := corepersistence.NewGormUserContractDataRepository(app.core.db) workersService := services.NewWorkersService( - app.core.appCtx, userRepo, userNodesRepo, invoiceRepo, clusterRepo, pendingRecordRepo, + app.core.appCtx, userRepo, contractsRepo, invoiceRepo, clusterRepo, transferRecordsRepo, app.communication.mailService, app.infra.gridClient, app.core.ewfEngine, app.communication.notificationDispatcher, app.infra.graphql, app.infra.firesquidClient, app.config.Invoice, app.config.SystemAccount.Mnemonic, app.config.Currency, app.config.ClusterHealthCheckIntervalInHours, - app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, app.config.MonitorBalanceIntervalInMinutes, app.config.NotifyAdminsForPendingRecordsInHours, app.config.UsersBalanceCheckIntervalInHours, app.config.CheckUserDebtIntervalInHours, + app.config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, + app.config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, + app.config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, + app.config.SettleTransferRecordsIntervalInMinutes, + app.config.NotifyAdminsForPendingRecordsInHours, + app.config.MinimumTFTAmountInWallet, services.Discount(app.config.AppliedDiscount), + app.config.UsersBalanceCheckIntervalInHours, + app.config.CheckUserDebtIntervalInHours, + ) + + billingService := services.NewBillingService( + userRepo, contractsRepo, transferRecordsRepo, clusterRepo, + app.infra.graphql, app.infra.gridClient, + uint64(app.config.MinimumTFTAmountInWallet), services.Discount(app.config.AppliedDiscount), ) - return workers.NewWorkers(app.core.appCtx, workersService, app.core.metrics, app.core.db) + return workers.NewWorkers(app.core.appCtx, workersService, billingService, app.core.metrics, app.core.db) } diff --git a/backend/internal/api/handlers/admin_handler.go b/backend/internal/api/handlers/admin_handler.go index 2a1380405..30c036776 100644 --- a/backend/internal/api/handlers/admin_handler.go +++ b/backend/internal/api/handlers/admin_handler.go @@ -15,7 +15,6 @@ import ( "time" "kubecloud/internal/core/services" - "kubecloud/internal/infrastructure/logger" "kubecloud/internal/infrastructure/mailservice" mailsender "kubecloud/internal/infrastructure/mailservice/mail_sender" "kubecloud/internal/infrastructure/notification" @@ -26,15 +25,17 @@ import ( type AdminHandler struct { svc services.AdminService + billingService services.BillingService notificationDispatcher *notification.NotificationDispatcher mailService mailservice.MailService } -func NewAdminHandler(svc services.AdminService, +func NewAdminHandler(svc services.AdminService, billingService services.BillingService, notificationDispatcher *notification.NotificationDispatcher, mailService mailservice.MailService, ) AdminHandler { return AdminHandler{ svc: svc, + billingService: billingService, notificationDispatcher: notificationDispatcher, mailService: mailService, } @@ -77,7 +78,7 @@ type MaintenanceModeStatus struct { // @ID get-all-users // @Accept json // @Produce json -// @Success 200 {array} services.UserWithUSDBalance +// @Success 200 {array} services.UserWithTFTBalance // @Failure 500 {object} APIResponse // @Security AdminMiddleware // @Router /users [get] @@ -139,7 +140,7 @@ func (h *AdminHandler) DeleteUsersHandler(c *gin.Context) { NotFound(c, "User not found") return } - logger.GetLogger().Error().Err(err).Msg("failed to delete user by id") + reqLog.Error().Err(err).Msg("failed to delete user by id") InternalServerError(c) return } @@ -255,6 +256,17 @@ func (h *AdminHandler) CreditUserHandler(c *gin.Context) { return } + user, err := h.svc.GetUserByID(id) + if err != nil { + if errors.Is(err, models.ErrUserNotFound) { + NotFound(c, "User is not found") + return + } + reqLog.Error().Err(err).Msg("failed to retrieve user") + InternalServerError(c) + return + } + transaction := models.Transaction{ UserID: id, AdminID: adminID, @@ -263,41 +275,59 @@ func (h *AdminHandler) CreditUserHandler(c *gin.Context) { CreatedAt: time.Now(), } - if err := h.svc.AsyncCreditUserUSD(&transaction); err != nil { - reqLog.Error().Err(err).Msg("failed to credit user") + if err := h.svc.CreditUserBalance(c.Request.Context(), transaction, &user); err != nil { + reqLog.Error().Err(err).Msg("Failed to credit user balance") InternalServerError(c) return } - Accepted(c, "Transaction is created successfully, Money transfer is in progress", CreditUserResponse{ + if err := h.billingService.AfterUserGetCredit(c.Request.Context(), &user); err != nil { + reqLog.Error().Err(err).Msg("Failed to credit user balance") + InternalServerError(c) + return + } + + notif := notification.BillingNotification(adminID). + Success(fmt.Sprintf("Admin %s has credited your account with %v$ successfully", user.Username, request.AmountUSD)). + WithSubject("Admin Credited Your Account"). + WithStatus("succeeded"). + WithChannels(notification.ChannelUI). + NoPersist(). + Build() + + if err := h.notificationDispatcher.Send(c.Request.Context(), notif); err != nil { + reqLog.Error().Err(err).Msg("failed to send UI ") + } + + Success(c, http.StatusCreated, fmt.Sprintf("User is credited with %v$ successfully", request.AmountUSD), CreditUserResponse{ AmountUSD: request.AmountUSD, Memo: request.Memo, }) } -// @Summary List pending records -// @Description Returns all pending records in the system +// @Summary List transfer records +// @Description Returns all transfer records in the system // @Tags admin -// @ID list-pending-records +// @ID list-transfer-records // @Accept json // @Produce json -// @Success 200 {object} APIResponse{data=services.PendingRecordsWithUSDAmounts} "Pending records are retrieved successfully" +// @Success 200 {array} []services.TransferRecordsWithTFTAmount // @Failure 500 {object} APIResponse // @Security AdminMiddleware -// @Router /pending-records [get] -// ListPendingRecordsHandler returns all pending records in the system -func (h *AdminHandler) ListPendingRecordsHandler(c *gin.Context) { +// @Router /transfer-records [get] +// ListTransferRecordsHandler returns all transfer records in the system +func (h *AdminHandler) ListTransferRecordsHandler(c *gin.Context) { reqLog := requestLogger(c, "ListPendingRecordsHandler") - pendingRecordsResponse, err := h.svc.ListAllPendingRecordsWithUSDAmounts() + transferRecordsResponse, err := h.svc.ListAllTransferRecordsWithTFTAmount() if err != nil { reqLog.Error().Err(err).Msg("failed to list all pending records") InternalServerError(c) return } - OK(c, "Pending records are retrieved successfully", gin.H{ - "pending_records": pendingRecordsResponse, + OK(c, "Transfer records are retrieved successfully", gin.H{ + "transfer_records": transferRecordsResponse, }) } @@ -373,7 +403,6 @@ func (h *AdminHandler) parseAttachments(fileHeaders []*multipart.FileHeader) ([] attachment, err := h.parseAttachment(fh) if err != nil { - logger.GetLogger().Error().Err(err).Str("filename", fh.Filename).Msg("failed to parse attachment") mu.Lock() multiErr = multierror.Append(multiErr, err) mu.Unlock() diff --git a/backend/internal/api/handlers/admin_handler_test.go b/backend/internal/api/handlers/admin_handler_test.go index de7e946ce..ded1a9421 100644 --- a/backend/internal/api/handlers/admin_handler_test.go +++ b/backend/internal/api/handlers/admin_handler_test.go @@ -330,11 +330,11 @@ func TestCreditUserHandler(t *testing.T) { req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusAccepted, resp.Code) + assert.Equal(t, http.StatusCreated, resp.Code) var result map[string]interface{} err := json.Unmarshal(resp.Body.Bytes(), &result) assert.NoError(t, err) - assert.Equal(t, "Transaction is created successfully, Money transfer is in progress", result["message"]) + assert.Equal(t, "User is credited with 1$ successfully", result["message"]) assert.NotNil(t, result["data"]) data, ok := result["data"].(map[string]interface{}) assert.True(t, ok) @@ -421,34 +421,34 @@ func TestListPendingRecordsHandler(t *testing.T) { adminUser := setup.CreateTestUser(t, "admin@example.com", "Admin User", []byte("securepassword"), true, true, false, 0, time.Now()) nonAdminUser := setup.CreateTestUser(t, "user@example.com", "Normal User", []byte("securepassword"), true, false, false, 0, time.Now()) - t.Run("Test ListPendingRecordsHandler successfully", func(t *testing.T) { + t.Run("Test ListTransferRecordsHandler successfully", func(t *testing.T) { token := setup.GetAuthToken(t, adminUser.ID, adminUser.Email, adminUser.Username, true) - req, _ := http.NewRequest("GET", "/api/v1/pending-records", nil) + req, _ := http.NewRequest("GET", "/api/v1/transfer-records", nil) req.Header.Set("Authorization", "Bearer "+token) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) assert.Equal(t, http.StatusOK, resp.Code) }) - t.Run("Test ListPendingRecordsHandler with no token", func(t *testing.T) { - req, _ := http.NewRequest("GET", "/api/v1/pending-records", nil) + t.Run("Test ListTransferRecordsHandler with no token", func(t *testing.T) { + req, _ := http.NewRequest("GET", "/api/v1/transfer-records", nil) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) assert.Equal(t, http.StatusUnauthorized, resp.Code) }) - t.Run("Test ListPendingRecordsHandler with non-admin user", func(t *testing.T) { + t.Run("Test ListTransferRecordsHandler with non-admin user", func(t *testing.T) { token := setup.GetAuthToken(t, nonAdminUser.ID, nonAdminUser.Email, nonAdminUser.Username, false) - req, _ := http.NewRequest("GET", "/api/v1/pending-records", nil) + req, _ := http.NewRequest("GET", "/api/v1/transfer-records", nil) req.Header.Set("Authorization", "Bearer "+token) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) assert.Equal(t, http.StatusForbidden, resp.Code) }) - t.Run("Test ListPendingRecordsHandler with non-existing user", func(t *testing.T) { + t.Run("Test ListTransferRecordsHandler with non-existing user", func(t *testing.T) { token := setup.GetAuthToken(t, adminUser.ID, adminUser.Email, adminUser.Username, true) - req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/pending-records/%d", nonAdminUser.ID+1), nil) + req, _ := http.NewRequest("GET", fmt.Sprintf("/api/v1/transfer-records/%d", nonAdminUser.ID+1), nil) req.Header.Set("Authorization", "Bearer "+token) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) diff --git a/backend/internal/api/handlers/deployment_handler.go b/backend/internal/api/handlers/deployment_handler.go index c76560ff0..536d6f1c0 100644 --- a/backend/internal/api/handlers/deployment_handler.go +++ b/backend/internal/api/handlers/deployment_handler.go @@ -12,12 +12,14 @@ import ( ) type DeploymentHandler struct { - svc services.DeploymentService + svc services.DeploymentService + billingService services.BillingService } -func NewDeploymentHandler(svc services.DeploymentService) DeploymentHandler { +func NewDeploymentHandler(svc services.DeploymentService, billingService services.BillingService) DeploymentHandler { return DeploymentHandler{ - svc: svc, + svc: svc, + billingService: billingService, } } @@ -222,6 +224,23 @@ func (h *DeploymentHandler) HandleDeployCluster(c *gin.Context) { return } + user, err := h.svc.GetUserByID(config.UserID) + if err != nil { + if errors.Is(err, models.ErrUserNotFound) { + NotFound(c, "User not found") + return + } + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + + if err := h.billingService.FundUserToFulfillDiscount(c.Request.Context(), &user, nil, cluster.Nodes); err != nil { + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + projectName := kubedeployer.GetProjectName(config.UserID, cluster.Name) logWithProject := reqLog.With().Str("project_name", projectName).Logger() reqLog = &logWithProject @@ -409,6 +428,23 @@ func (h *DeploymentHandler) HandleAddNode(c *gin.Context) { } } + user, err := h.svc.GetUserByID(config.UserID) + if err != nil { + if errors.Is(err, models.ErrUserNotFound) { + NotFound(c, "User not found") + return + } + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + + if err := h.billingService.FundUserToFulfillDiscount(c.Request.Context(), &user, nil, cluster.Nodes); err != nil { + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + wfUUID, wfStatus, err := h.svc.AsyncAddNode(config, cl, cluster.Nodes[0]) if err != nil { reqLog.Error().Err(err).Msg("failed to start add node workflow") diff --git a/backend/internal/api/handlers/node_handler.go b/backend/internal/api/handlers/node_handler.go index 470b10834..bb9571eb8 100644 --- a/backend/internal/api/handlers/node_handler.go +++ b/backend/internal/api/handlers/node_handler.go @@ -17,12 +17,14 @@ import ( ) type NodeHandler struct { - svc services.NodeService + svc services.NodeService + billingService services.BillingService } -func NewNodeHandler(svc services.NodeService) NodeHandler { +func NewNodeHandler(svc services.NodeService, billingService services.BillingService) NodeHandler { return NodeHandler{ - svc: svc, + svc: svc, + billingService: billingService, } } @@ -290,6 +292,12 @@ func (h *NodeHandler) ReserveNodeHandler(c *gin.Context) { return } + if err := h.billingService.FundUserToFulfillDiscount(c.Request.Context(), &user, nil, nil); err != nil { + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + wfUUID, err := h.svc.AsyncReserveNode(userID, user.Mnemonic, nodeID) if err != nil { reqLog.Error().Err(err).Msg("failed to start workflow to reserve node") diff --git a/backend/internal/api/handlers/setup.go b/backend/internal/api/handlers/setup.go index aeeade99b..83e29d7f9 100644 --- a/backend/internal/api/handlers/setup.go +++ b/backend/internal/api/handlers/setup.go @@ -111,9 +111,8 @@ func SetUp(t testing.TB) (setup, error) { "private_key_path": "%s", "public_key_path": "%s" }, - "monitor_balance_interval_in_minutes": 2, + "settle_transfer_records_interval_in_minutes": 2, "notify_admins_for_pending_records_in_hours": 1, - "verification_code_length": 4, "kyc_verifier_api_url": "https://kyc.dev.grid.tf", "kyc_challenge_domain": "kyc.dev.grid.tf", "cluster_health_check_interval_in_hours": 1, diff --git a/backend/internal/api/handlers/user_handler.go b/backend/internal/api/handlers/user_handler.go index 41233c435..6b1774c45 100644 --- a/backend/internal/api/handlers/user_handler.go +++ b/backend/internal/api/handlers/user_handler.go @@ -10,6 +10,7 @@ import ( "kubecloud/internal/infrastructure/gridclient" "kubecloud/internal/infrastructure/mailservice" "kubecloud/internal/infrastructure/notification" + "net/http" "sort" "strconv" "strings" @@ -26,6 +27,7 @@ import ( type UserHandler struct { svc services.UserService + billingService services.BillingService notificationDispatcher *notification.NotificationDispatcher mailService mailservice.MailService tokenManager auth.TokenManager @@ -34,6 +36,7 @@ type UserHandler struct { func NewUserHandler( svc services.UserService, + billing services.BillingService, notificationDispatcher *notification.NotificationDispatcher, mailService mailservice.MailService, tokenManager auth.TokenManager, @@ -109,13 +112,6 @@ type ChargeBalanceResponse struct { Email string `json:"email"` } -// UserBalanceResponse struct holds the response data for user balance -type UserBalanceResponse struct { - BalanceUSD float64 `json:"balance_usd"` - DebtUSD float64 `json:"debt_usd"` - PendingBalanceUSD float64 `json:"pending_balance_usd"` -} - // SSHKeyInput struct for adding SSH keys type SSHKeyInput struct { Name string `json:"name" binding:"required"` @@ -134,11 +130,6 @@ type VerifyRegisterUserResponse struct { *auth.TokenPair } -// PendingRecordsResponse swagger model -type PendingRecordsResponse struct { - PendingRecords []services.PendingRecordsWithUSDAmounts `json:"pending_records"` -} - // RedeemVoucherResponse holds the response for redeeming a voucher type RedeemVoucherResponse struct { WorkflowID string `json:"workflow_id"` @@ -662,6 +653,12 @@ func (h *UserHandler) ChargeBalance(c *gin.Context) { return } + if err := h.billingService.AfterUserGetCredit(c.Request.Context(), &user); err != nil { + reqLog.Error().Err(err).Msg("Failed to credit user balance") + InternalServerError(c) + return + } + Accepted(c, "Charge in progress. You can check its status using the workflow id.", ChargeBalanceResponse{ WorkflowID: wfUUID, Email: user.Email, @@ -673,7 +670,7 @@ func (h *UserHandler) ChargeBalance(c *gin.Context) { // @Tags users // @ID get-user // @Produce json -// @Success 200 {object} APIResponse{data=services.UserWithPendingBalance} "User is retrieved successfully" +// @Success 200 {object} APIResponse{data=services.UserWithBalancesInUSD} "User is retrieved successfully" // @Failure 404 {object} APIResponse "User is not found" // @Failure 500 {object} APIResponse // @Router /user [get] @@ -682,7 +679,7 @@ func (h *UserHandler) GetUserHandler(c *gin.Context) { userID := c.GetInt("user_id") reqLog := requestLogger(c, "GetUserHandler") - userWithPendingBalance, err := h.svc.GetUserWithPendingBalance(userID) + userWithBalancesInUSD, err := h.svc.GetUserWithBalancesInUSD(userID) if err != nil { if errors.Is(err, models.ErrUserNotFound) { NotFound(c, "User is not found") @@ -695,53 +692,7 @@ func (h *UserHandler) GetUserHandler(c *gin.Context) { } OK(c, "User is retrieved successfully", gin.H{ - "user": userWithPendingBalance, - }) -} - -// @Summary Get user balance -// @Description Retrieves the user's balance in USD -// @Tags users -// @ID get-user-balance -// @Produce json -// @Success 200 {object} APIResponse{data=UserBalanceResponse} "Balance fetched successfully" -// @Failure 404 {object} APIResponse "User is not found" -// @Failure 500 {object} APIResponse -// @Router /user/balance [get] -// GetUserBalance returns user's balance in usd -func (h *UserHandler) GetUserBalance(c *gin.Context) { - userID := c.GetInt("user_id") - reqLog := requestLogger(c, "GetUserBalance") - - user, err := h.svc.GetUserByID(userID) - if err != nil { - if errors.Is(err, models.ErrUserNotFound) { - NotFound(c, "User is not found") - return - } - reqLog.Error().Err(err).Msg("User is not found") - InternalServerError(c) - return - } - - usdMillicentBalance, err := h.svc.GetUserBalanceInUSDMillicent(user.Mnemonic) - if err != nil { - reqLog.Error().Err(err).Msg("failed to get user balance in usd millicent") - InternalServerError(c) - return - } - - pendingAmountInUSDMillicent, err := h.svc.GetUserPendingBalanceInUSDMillicent(userID) - if err != nil { - reqLog.Error().Err(err).Msg("failed to list pending records") - InternalServerError(c) - return - } - - OK(c, "Balance is fetched", UserBalanceResponse{ - BalanceUSD: gridclient.FromUSDMilliCentToUSD(usdMillicentBalance), - DebtUSD: gridclient.FromUSDMilliCentToUSD(user.Debt), - PendingBalanceUSD: gridclient.FromUSDMilliCentToUSD(pendingAmountInUSDMillicent), + "user": userWithBalancesInUSD, }) } @@ -800,15 +751,25 @@ func (h *UserHandler) RedeemVoucherHandler(c *gin.Context) { return } - wfUUID, err := h.svc.AsyncRedeemVoucher(user.ID, voucher.Value, user.Mnemonic, user.Username, voucher.Code) - if err != nil { - reqLog.Error().Err(err).Msg("failed to redeem voucher") + millicentAmount := gridclient.FromUSDToUSDMillicent(voucher.Value) + user.CreditedBalance += millicentAmount + if err := h.svc.UpdateUserByID(&user); err != nil { + if errors.Is(err, models.ErrUserNotFound) { + NotFound(c, "User is not found") + return + } + reqLog.Error().Err(err).Send() + InternalServerError(c) + return + } + + if err := h.billingService.AfterUserGetCredit(c.Request.Context(), &user); err != nil { + reqLog.Error().Err(err).Msg("Failed to credit user balance") InternalServerError(c) return } - Accepted(c, "Voucher is redeemed successfully. Money transfer in progress.", RedeemVoucherResponse{ - WorkflowID: wfUUID, + Success(c, http.StatusOK, fmt.Sprintf("Voucher with value %v$ is redeemed successfully.", voucher.Value), RedeemVoucherResponse{ VoucherCode: voucher.Code, Amount: voucher.Value, Email: user.Email, @@ -1000,33 +961,6 @@ func (h *UserHandler) GetWorkflowStatus(c *gin.Context) { OK(c, "Status returned successfully", workflowStatus) } -// @Summary List user pending records -// @Description Returns user pending records in the system -// @Tags users -// @ID list-user-pending-records -// @Accept json -// @Produce json -// @Success 200 {object} APIResponse{data=PendingRecordsResponse} "Pending records returned successfully" -// @Failure 500 {object} APIResponse -// @Security BearerAuth -// @Router /user/pending-records [get] -// ListUserPendingRecordsHandler returns user pending records in the system -func (h *UserHandler) ListUserPendingRecordsHandler(c *gin.Context) { - userID := c.GetInt("user_id") - reqLog := requestLogger(c, "ListUserPendingRecordsHandler") - - pendingRecordsWithUSDAmounts, err := h.svc.ListUserPendingRecordsWithUSDAmounts(userID) - if err != nil { - reqLog.Error().Err(err).Msg("failed to list pending records with usd amounts") - InternalServerError(c) - return - } - - OK(c, "Pending records are retrieved successfully", gin.H{ - "pending_records": pendingRecordsWithUSDAmounts, - }) -} - // @Summary List remaining user workflows // @Description Returns all pending/running workflows belonging to the authenticated user. // @Tags workflow diff --git a/backend/internal/api/handlers/user_handler_test.go b/backend/internal/api/handlers/user_handler_test.go index e61d537d7..3eb50156f 100644 --- a/backend/internal/api/handlers/user_handler_test.go +++ b/backend/internal/api/handlers/user_handler_test.go @@ -612,47 +612,6 @@ func TestGetUserHandler(t *testing.T) { } -func TestGetUserBalanceHandler(t *testing.T) { - setup, err := SetUp(t) - require.NoError(t, err) - router := setup.router - t.Run("Test Get balance successfully", func(t *testing.T) { - - user := setup.CreateTestUser(t, "balanceuser@example.com", "Balance User", []byte("securepassword"), true, false, true, 0, time.Now()) - - assert.NoError(t, err) - token := setup.GetAuthToken(t, user.ID, user.Email, user.Username, false) - req, _ := http.NewRequest("GET", "/api/v1/user/balance", nil) - req.Header.Set("Authorization", "Bearer "+token) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - var result map[string]interface{} - err = json.Unmarshal(resp.Body.Bytes(), &result) - assert.NoError(t, err) - assert.Equal(t, "Balance is fetched", result["message"]) - assert.NotNil(t, result["data"]) - data := result["data"].(map[string]interface{}) - assert.Contains(t, data, "balance_usd") - assert.Contains(t, data, "debt_usd") - }) - - t.Run("Test Get balance for non-existing user", func(t *testing.T) { - - token := setup.GetAuthToken(t, 999, "notfound@example.com", "Not Found", false) - req, _ := http.NewRequest("GET", "/api/v1/user/balance", nil) - req.Header.Set("Authorization", "Bearer "+token) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusNotFound, resp.Code) - var result map[string]interface{} - err = json.Unmarshal(resp.Body.Bytes(), &result) - assert.NoError(t, err) - assert.Contains(t, result["message"], "User is not found") - }) - -} - func TestRedeemVoucherHandler(t *testing.T) { setup, err := SetUp(t) require.NoError(t, err) @@ -676,13 +635,11 @@ func TestRedeemVoucherHandler(t *testing.T) { req.Header.Set("Authorization", "Bearer "+token) resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusAccepted, resp.Code) + assert.Equal(t, http.StatusOK, resp.Code) var result map[string]interface{} err = json.Unmarshal(resp.Body.Bytes(), &result) assert.NoError(t, err) - assert.Equal(t, "Voucher is redeemed successfully. Money transfer in progress.", result["message"]) - assert.NotNil(t, result["data"]) - assert.NotEmpty(t, result["data"].(map[string]interface{})["workflow_id"]) + assert.Equal(t, "Voucher with value 50$ is redeemed successfully.", result["message"]) }) t.Run("Test redeem non-existing voucher", func(t *testing.T) { @@ -895,63 +852,6 @@ func TestAddSSHKeyHandler(t *testing.T) { router.ServeHTTP(resp2, req2) assert.Equal(t, http.StatusBadRequest, resp2.Code) }) - -} - -func TestListUserPendingRecordsHandler(t *testing.T) { - setup, err := SetUp(t) - require.NoError(t, err) - router := setup.router - user := setup.CreateTestUser(t, "pendinguser@example.com", "Pending User", []byte("securepassword"), true, false, false, 0, time.Now()) - token := setup.GetAuthToken(t, user.ID, user.Email, user.Username, false) - t.Run("Test list user pending records successfully", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/v1/user/pending-records", nil) - assert.NoError(t, err) - req.Header.Set("Authorization", "Bearer "+token) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - var result map[string]interface{} - err = json.Unmarshal(resp.Body.Bytes(), &result) - assert.NoError(t, err) - assert.Equal(t, "Pending records are retrieved successfully", result["message"]) - assert.NotNil(t, result["data"]) - }) - - t.Run("Test list user pending records with no records", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/v1/user/pending-records", nil) - assert.NoError(t, err) - - req.Header.Set("Authorization", "Bearer "+token) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusOK, resp.Code) - var result map[string]interface{} - err = json.Unmarshal(resp.Body.Bytes(), &result) - assert.NoError(t, err) - assert.Equal(t, "Pending records are retrieved successfully", result["message"]) - assert.NotNil(t, result["data"]) - }) - - t.Run("Test list user pending records with no token", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/v1/user/pending-records", nil) - assert.NoError(t, err) - - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusUnauthorized, resp.Code) - }) - - t.Run("Test list user pending records with invalid token", func(t *testing.T) { - req, err := http.NewRequest("GET", "/api/v1/user/pending-records", nil) - assert.NoError(t, err) - - req.Header.Set("Authorization", "Bearer invalidtoken") - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) - assert.Equal(t, http.StatusUnauthorized, resp.Code) - }) - } func TestListUserWorkflowsHandler(t *testing.T) { diff --git a/backend/internal/billing/contracts_billing.go b/backend/internal/billing/contracts_billing.go index 90ee6119b..84c24beb4 100644 --- a/backend/internal/billing/contracts_billing.go +++ b/backend/internal/billing/contracts_billing.go @@ -17,16 +17,15 @@ type ContractBillReports struct { } type Report struct { - ContractID string `json:"contractID"` - Timestamp string `json:"timestamp"` - AmountBilled string `json:"amountBilled"` + ContractID string `json:"contractID"` + Timestamp string `json:"timestamp"` + AmountBilled string `json:"amountBilled"` + DiscountReceived string `json:"discountReceived"` } -// ListContractBillReportsPerMonth returns bill reports for contract ID month ago -func ListContractBillReportsPerMonth(graphqlClient graphql.GraphQl, contractID uint64, currentTime time.Time) (ContractBillReports, error) { - monthAgo := currentTime.AddDate(0, -1, 0) - - options := fmt.Sprintf(`(where: {contractID_eq: %v, timestamp_lte: %v, timestamp_gte: %v}, orderBy: id_ASC)`, contractID, currentTime.Unix(), monthAgo.Unix()) +// ListContractBillReports returns bill reports for contract ID month ago +func ListContractBillReports(graphqlClient graphql.GraphQl, contractID uint64, startTime, endTime time.Time) (ContractBillReports, error) { + options := fmt.Sprintf(`(where: {contractID_eq: %v, timestamp_lte: %v, timestamp_gte: %v}, orderBy: id_ASC)`, contractID, endTime.Unix(), startTime.Unix()) billingReportsCount, err := graphqlClient.GetItemTotalCount("contractBillReports", options) if err != nil { return ContractBillReports{}, err @@ -36,8 +35,9 @@ func ListContractBillReportsPerMonth(graphqlClient graphql.GraphQl, contractID u contractID timestamp amountBilled + discountReceived } - }`, contractID, currentTime.Unix(), monthAgo.Unix()), + }`, contractID, endTime.Unix(), startTime.Unix()), map[string]interface{}{ "billingReportsCount": billingReportsCount, }) @@ -60,7 +60,7 @@ func ListContractBillReportsPerMonth(graphqlClient graphql.GraphQl, contractID u } // TODO: check returned float or int -func AmountBilledPerMonth(reports ContractBillReports) (uint64, error) { +func CalculateTotalAmountBilledForReports(reports ContractBillReports) (uint64, error) { var totalAmount uint64 for _, report := range reports.Reports { amount, err := strconv.ParseInt(report.AmountBilled, 10, 64) diff --git a/backend/internal/billing/contracts_billing_test.go b/backend/internal/billing/contracts_billing_test.go index 4c08d27ed..286630e64 100644 --- a/backend/internal/billing/contracts_billing_test.go +++ b/backend/internal/billing/contracts_billing_test.go @@ -24,14 +24,14 @@ func (m *MockGraphQLClient) GetItemTotalCount(itemType string, options string) ( return 0, nil } -// TestAmountBilledPerMonth tests amount calculation from billing reports. +// TestCalculateTotalAmountBilledForReports tests amount calculation from billing reports. // This scenario covers: // - Single report with amount is calculated correctly // - Multiple reports are summed correctly // - Zero amount reports are handled // - Large amounts are handled correctly // - Invalid amount string fails with error -func TestAmountBilledPerMonth(t *testing.T) { +func TestCalculateTotalAmountBilledForReports(t *testing.T) { tests := []struct { name string reports ContractBillReports @@ -122,15 +122,15 @@ func TestAmountBilledPerMonth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := AmountBilledPerMonth(tt.reports) + result, err := CalculateTotalAmountBilledForReports(tt.reports) if (err != nil) != tt.expectError { - t.Errorf("AmountBilledPerMonth() error = %v, expectError %v (%s)", err, tt.expectError, tt.description) + t.Errorf("CalculateTotalAmountBilledForReports() error = %v, expectError %v (%s)", err, tt.expectError, tt.description) return } if !tt.expectError && result != tt.expected { - t.Errorf("AmountBilledPerMonth() = %d, want %d (%s)", result, tt.expected, tt.description) + t.Errorf("CalculateTotalAmountBilledForReports() = %d, want %d (%s)", result, tt.expected, tt.description) } }) } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 1886e2100..b3731d06c 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -14,30 +14,33 @@ import ( ) type Configuration struct { - Server Server `json:"server" validate:"required,dive"` - Database DB `json:"database" validate:"required"` - JwtToken JwtToken `json:"jwt_token" validate:"required"` - Admins []string `json:"admins" validate:"required"` - MailSender MailSender `json:"mailSender"` - Currency string `json:"currency" default:"usd"` - StripeSecret string `json:"stripe_secret" validate:"required"` - VoucherNameLength int `json:"voucher_name_length" validate:"required,gt=0" default:"8"` - VerificationCodeLength int `json:"verification_code_length" validate:"gt=0" default:"4"` - TermsANDConditions TermsANDConditions `json:"terms_and_conditions"` - SystemAccount GridAccount `json:"system_account"` - DeployerWorkersNum int `json:"deployer_workers_num" default:"1"` - Invoice InvoiceCompanyData `json:"invoice"` - SSH SSHConfig `json:"ssh" validate:"required,dive"` - Redis RedisConfig `json:"redis" validate:"dive"` - Debug bool `json:"debug"` - DisableSentry bool `json:"disable_sentry" default:"true"` - DevMode bool `json:"dev_mode"` // When true, allows empty SendGridKey and uses FakeMailService - MonitorBalanceIntervalInMinutes int `json:"monitor_balance_interval_in_minutes" validate:"required,gt=0"` - NotifyAdminsForPendingRecordsInHours int `json:"notify_admins_for_pending_records_in_hours" validate:"required,gt=0"` - ClusterHealthCheckIntervalInHours int `json:"cluster_health_check_interval_in_hours" validate:"gt=0" default:"1"` - UsersBalanceCheckIntervalInHours int `json:"users_balance_check_interval_in_hours" validate:"gt=0" default:"6"` - CheckUserDebtIntervalInHours int `json:"check_user_debt_interval_in_hours" validate:"gt=0" default:"48"` - NodeHealthCheck ReservedNodeHealthCheckConfig `json:"node_health_check" validate:"required,dive"` + Server Server `json:"server" validate:"required,dive"` + Database DB `json:"database" validate:"required"` + JwtToken JwtToken `json:"jwt_token" validate:"required"` + Admins []string `json:"admins" validate:"required"` + MailSender MailSender `json:"mailSender"` + Currency string `json:"currency" default:"usd"` + StripeSecret string `json:"stripe_secret" validate:"required"` + VoucherNameLength int `json:"voucher_name_length" validate:"required,gt=0" default:"8"` + VerificationCodeLength int `json:"verification_code_length" validate:"gt=0" default:"4"` + TermsANDConditions TermsANDConditions `json:"terms_and_conditions"` + SystemAccount GridAccount `json:"system_account"` + DeployerWorkersNum int `json:"deployer_workers_num" default:"1"` + Invoice InvoiceCompanyData `json:"invoice"` + SSH SSHConfig `json:"ssh" validate:"required,dive"` + Redis RedisConfig `json:"redis" validate:"dive"` + Debug bool `json:"debug"` + DisableSentry bool `json:"disable_sentry" default:"true"` + DevMode bool `json:"dev_mode"` // When true, allows empty SendGridKey and uses FakeMailService + NotifyAdminsForPendingRecordsInHours int `json:"notify_admins_for_pending_records_in_hours" validate:"required,gt=0"` + ClusterHealthCheckIntervalInHours int `json:"cluster_health_check_interval_in_hours" validate:"gt=0" default:"1"` + SettleTransferRecordsIntervalInMinutes int `json:"settle_transfer_records_interval_in_minutes" validate:"required,gt=0"` + NodeHealthCheck ReservedNodeHealthCheckConfig `json:"node_health_check" validate:"required,dive"` + UsersBalanceCheckIntervalInHours int `json:"users_balance_check_interval_in_hours" validate:"gt=0" default:"6"` + CheckUserDebtIntervalInHours int `json:"check_user_debt_interval_in_hours" validate:"gt=0" default:"48"` + + AppliedDiscount string `json:"applied_discount" validate:"required"` + MinimumTFTAmountInWallet int `json:"minimum_tft_amount_in_wallet" default:"10" validate:"required,gt=0"` Logger LoggerConfig `json:"logger"` Loki LokiConfig `json:"loki"` @@ -321,6 +324,11 @@ func applyDefaultValues(config *Configuration) { config.ClusterHealthCheckIntervalInHours = 1 } + // SettleTransferRecordsIntervalInMinutes default + if config.SettleTransferRecordsIntervalInMinutes == 0 { + config.SettleTransferRecordsIntervalInMinutes = 5 + } + // JwtToken defaults if config.JwtToken.AccessExpiryMinutes == 0 { config.JwtToken.AccessExpiryMinutes = 60 @@ -384,13 +392,18 @@ func applyDefaultValues(config *Configuration) { config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum = 10 } - if config.MonitorBalanceIntervalInMinutes == 0 { - config.MonitorBalanceIntervalInMinutes = 120 - } if config.NotifyAdminsForPendingRecordsInHours == 0 { config.NotifyAdminsForPendingRecordsInHours = 24 } + if config.MinimumTFTAmountInWallet == 0 { + config.MinimumTFTAmountInWallet = 10 + } + + if config.AppliedDiscount == "" { + config.AppliedDiscount = "gold" + } + if config.Telemetry.OTLPEndpoint == "" { config.Telemetry.OTLPEndpoint = "jaeger:4317" } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 25e2a88a9..4cf684d01 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -138,7 +138,7 @@ func TestLoadConfig(t *testing.T) { assert.Equal(t, "Test Governorate", config.Invoice.Governorate) assert.Equal(t, privateKeyPath, config.SSH.PrivateKeyPath) assert.Equal(t, publicKeyPath, config.SSH.PublicKeyPath) - assert.Equal(t, 5, config.MonitorBalanceIntervalInMinutes) + assert.Equal(t, 5, config.SettleTransferRecordsIntervalInMinutes) assert.Equal(t, 2, config.NotifyAdminsForPendingRecordsInHours) assert.Equal(t, 1, config.ClusterHealthCheckIntervalInHours) assert.Equal(t, 1, config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours) @@ -244,7 +244,7 @@ func TestDefaultTagsInConfig(t *testing.T) { assert.Equal(t, 1, config.NodeHealthCheck.ReservedNodeHealthCheckIntervalInHours, "Default node health check interval not applied") assert.Equal(t, 1, config.NodeHealthCheck.ReservedNodeHealthCheckTimeoutInMinutes, "Default node health check timeout not applied") assert.Equal(t, 10, config.NodeHealthCheck.ReservedNodeHealthCheckWorkersNum, "Default node health check workers num not applied") - assert.Equal(t, 120, config.MonitorBalanceIntervalInMinutes, "Default monitor balance interval not applied") + assert.Equal(t, 5, config.SettleTransferRecordsIntervalInMinutes, "Default monitor balance interval not applied") assert.Equal(t, 24, config.NotifyAdminsForPendingRecordsInHours, "Default notify admins for pending records interval not applied") } diff --git a/backend/internal/core/models/invoice.go b/backend/internal/core/models/invoice.go index 0ca8fa4c6..f5e43c417 100644 --- a/backend/internal/core/models/invoice.go +++ b/backend/internal/core/models/invoice.go @@ -5,14 +5,13 @@ import ( ) type Invoice struct { - ID int `json:"id" gorm:"primaryKey"` - UserID int `json:"user_id" binding:"required"` - Total float64 `json:"total"` - Nodes []NodeItem `json:"nodes" gorm:"foreignKey:invoice_id"` - // TODO: - Tax float64 `json:"tax"` - CreatedAt time.Time `json:"created_at"` - FileData []byte `json:"-" gorm:"type:bytea;column:file_data"` + ID int `json:"id" gorm:"primaryKey"` + UserID int `json:"user_id" binding:"required"` + Total float64 `json:"total"` + Nodes []NodeItem `json:"nodes" gorm:"foreignKey:invoice_id"` + Tax float64 `json:"tax"` + CreatedAt time.Time `json:"created_at"` + FileData []byte `json:"-" gorm:"type:bytea;column:file_data"` } type NodeItem struct { diff --git a/backend/internal/core/models/pending_record.go b/backend/internal/core/models/pending_record.go deleted file mode 100644 index db32b9677..000000000 --- a/backend/internal/core/models/pending_record.go +++ /dev/null @@ -1,23 +0,0 @@ -package models - -import ( - "time" -) - -const ( - ChargeBalanceMode = "charge_balance" - RedeemVoucherMode = "redeem_voucher" - AdminCreditMode = "admin_credit" -) - -type PendingRecord struct { - ID int `json:"id" gorm:"primaryKey;autoIncrement"` - UserID int `json:"user_id" gorm:"not null"` - Username string `json:"username"` - // TFTs are multiplied by 1e7 - TFTAmount uint64 `json:"tft_amount" gorm:"not null"` - TransferredTFTAmount uint64 `json:"transferred_tft_amount" gorm:"not null"` - TransferMode string `json:"transfer_mode"` - CreatedAt time.Time `json:"created_at" gorm:"not null"` - UpdatedAt time.Time `json:"updated_at" gorm:"not null"` -} diff --git a/backend/internal/core/models/repositories.go b/backend/internal/core/models/repositories.go index 8a8c716ad..1fab3c81b 100644 --- a/backend/internal/core/models/repositories.go +++ b/backend/internal/core/models/repositories.go @@ -1,11 +1,14 @@ package models +import "time" + // UserRepository defines operations for user data persistence type UserRepository interface { RegisterUser(user *User) error GetUserByEmail(email string) (User, error) GetUserByID(userID int) (User, error) UpdateUserByID(user *User) error + DeductUserBalance(user *User, amount uint64) error ListAllUsers() ([]User, error) ListAdmins() ([]User, error) DeleteUserByID(userID int) error @@ -15,6 +18,10 @@ type UserRepository interface { // stats methods CountAllUsers() (int64, error) + // Usage calculation time methods + GetUserLastCalcTime(userID int) (time.Time, error) + UpdateUserLastCalcTime(userID int, lastCalcTime time.Time) error + // SSH Key methods CreateSSHKey(sshKey *SSHKey) error ListUserSSHKeys(userID int) ([]SSHKey, error) @@ -27,23 +34,24 @@ type ClusterRepository interface { CreateCluster(userID int, cluster *Cluster) error ListUserClusters(userID int) ([]Cluster, error) GetClusterByName(userID int, projectName string) (Cluster, error) - UpdateCluster(cluster *Cluster) error + UpdateCluster(contractsRepo ContractDataRepository, cluster *Cluster) error DeleteCluster(userID int, projectName string) error - DeleteAllUserClusters(userID int) error + DeleteAllUserClusters(contractsRepo ContractDataRepository, userID int) error // stats methods CountAllClusters() (int64, error) ListAllClusters() ([]Cluster, error) } -// UserNodesRepository defines operations for user nodes data persistence -type UserNodesRepository interface { - CreateUserNode(userNode *UserNodes) error - DeleteUserNode(contractID uint64) error - ListUserNodes(userID int) ([]UserNodes, error) - GetUserNodeByNodeID(nodeID uint64) (UserNodes, error) - GetUserNodeByContractID(contractID uint64) (UserNodes, error) - ListAllReservedNodes() ([]UserNodes, error) +// ContractDataRepository defines operations for contract data persistence +type ContractDataRepository interface { + CreateUserContractData(contractData *UserContractData) error + DeleteUserContract(contractID uint64) error + ListUserRentedNodes(userID int) ([]UserContractData, error) + GetUserNodeByNodeID(nodeID uint64) (UserContractData, error) + GetUserNodeByContractID(contractID uint64) (UserContractData, error) + ListAllReservedNodes() ([]UserContractData, error) + ListAllContractsInPeriod(userID int, start, end time.Time) ([]UserContractData, error) } // VoucherRepository defines operations for voucher data persistence @@ -80,13 +88,15 @@ type NotificationRepository interface { DeleteAllNotifications(userID int) error } -// PendingRecordRepository defines operations for pending record data persistence -type PendingRecordRepository interface { - CreatePendingRecord(record *PendingRecord) error - ListAllPendingRecords() ([]PendingRecord, error) - ListOnlyPendingRecords() ([]PendingRecord, error) - ListUserPendingRecords(userID int) ([]PendingRecord, error) - UpdatePendingRecordTransferredAmount(id int, amount uint64) error +// TransferRecordRepository defines operations for transfer record data persistence +type TransferRecordRepository interface { + CreateTransferRecord(record *TransferRecord) error + ListTransferRecords() ([]TransferRecord, error) + ListUserTransferRecords(userID int) ([]TransferRecord, error) + ListPendingTransferRecords() ([]TransferRecord, error) + ListFailedTransferRecords() ([]TransferRecord, error) + UpdateTransferRecordState(recordID int, state State, failure string) error + CalculateTotalPendingTFTAmountPerUser(userID int) (uint64, error) } // SettingsRepository defines operations for settings data persistence diff --git a/backend/internal/core/models/transfer_record.go b/backend/internal/core/models/transfer_record.go new file mode 100644 index 000000000..60460e22a --- /dev/null +++ b/backend/internal/core/models/transfer_record.go @@ -0,0 +1,27 @@ +package models + +import "time" + +type operation string +type State string + +const ( + WithdrawOperation operation = "withdraw" + DepositOperation operation = "deposit" + + FailedState State = "failed" + SuccessState State = "success" + PendingState State = "pending" +) + +type TransferRecord struct { + ID int `json:"id" gorm:"primaryKey;autoIncrement"` + UserID int `json:"user_id" gorm:"not null"` + Username string `json:"username"` + TFTAmount uint64 `json:"tft_amount" gorm:"not null"` // TFTs are multiplied by 1e7 + Operation operation `json:"operation" gorm:"not null"` + State State `json:"state" gorm:"not null;default:pending"` + Failure string `json:"failure" gorm:"not null"` + CreatedAt time.Time `json:"created_at" gorm:"not null"` + UpdatedAt time.Time `json:"updated_at" gorm:"not null"` +} diff --git a/backend/internal/core/models/usage_calculation_time.go b/backend/internal/core/models/usage_calculation_time.go new file mode 100644 index 000000000..524859d00 --- /dev/null +++ b/backend/internal/core/models/usage_calculation_time.go @@ -0,0 +1,11 @@ +package models + +import "time" + +// UserUsageCalculationTime represents the last time a user's usage was calculated +type UserUsageCalculationTime struct { + ID int `gorm:"primaryKey;autoIncrement;column:id"` + UserID int `gorm:"user_id;index:idx_user_id,unique" binding:"required"` + LastCalcTime time.Time `json:"last_calc_time"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/backend/internal/core/models/user.go b/backend/internal/core/models/user.go index 0d886e6c1..f3f843ac5 100644 --- a/backend/internal/core/models/user.go +++ b/backend/internal/core/models/user.go @@ -6,6 +6,13 @@ import ( "gorm.io/gorm" ) +type ContractType string + +const ( + ContractTypeRented ContractType = "rented" + ContractTypeDeployed ContractType = "deployed" +) + // User represents a user in the system type User struct { ID int `json:"id" gorm:"primaryKey;autoIncrement;column:id"` diff --git a/backend/internal/core/models/user_nodes.go b/backend/internal/core/models/user_nodes.go index ba6d97249..6b0680008 100644 --- a/backend/internal/core/models/user_nodes.go +++ b/backend/internal/core/models/user_nodes.go @@ -6,12 +6,13 @@ import ( "gorm.io/gorm" ) -// UserNodes model holds info of reserved nodes of user -type UserNodes struct { +// UserContractData model holds info of contracts of user +type UserContractData struct { ID int `gorm:"primaryKey;autoIncrement;column:id"` UserID int `gorm:"user_id" binding:"required"` ContractID uint64 `gorm:"contract_id" binding:"required"` NodeID uint32 `gorm:"column:node_id;index:idx_user_node_id,unique,where:deleted_at IS NULL" binding:"required"` + Type ContractType `gorm:"type" binding:"required"` CreatedAt time.Time `json:"created_at"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` } diff --git a/backend/internal/core/persistence/gorm_repositories.go b/backend/internal/core/persistence/gorm_repositories.go index 668a37aa7..bb00b16ca 100644 --- a/backend/internal/core/persistence/gorm_repositories.go +++ b/backend/internal/core/persistence/gorm_repositories.go @@ -26,7 +26,11 @@ func NewGormUserRepository(db models.DB) models.UserRepository { // RegisterUser registers a new user to the system func (r *GormUserRepository) RegisterUser(user *models.User) error { - return r.db.Create(user).Error + if err := r.db.Create(user).Error; err != nil { + return err + } + + return r.UpdateUserLastCalcTime(user.ID, time.Now()) } // SetStateUserID sets gorm user ID in workflow state @@ -86,6 +90,36 @@ func (r *GormUserRepository) UpdateUserByID(user *models.User) error { return nil } +func (r *GormUserRepository) DeductUserBalance(user *models.User, amount uint64) error { + if user.CreditedBalance >= amount { + return r.db.Model(&models.User{}). + Where("id = ?", user.ID). + UpdateColumn("credited_balance", gorm.Expr("credited_balance - ?", amount)). + Error + } + + if user.CreditedBalance > 0 && user.CreditCardBalance >= amount-user.CreditedBalance { + return r.db.Model(&models.User{}). + Where("id = ?", user.ID). + UpdateColumn("credited_balance", gorm.Expr("credited_balance - ?", user.CreditedBalance)). + UpdateColumn("credit_card_balance", gorm.Expr("credit_card_balance - ?", amount-user.CreditedBalance)). + Error + } + + if user.CreditCardBalance >= amount { + return r.db.Model(&models.User{}). + Where("id = ?", user.ID). + UpdateColumn("credit_card_balance", gorm.Expr("credit_card_balance - ?", amount)). + Error + } + + // if credit card balance is not enough, add debt + return r.db.Model(&models.User{}). + Where("id = ?", user.ID). + UpdateColumn("debt", gorm.Expr("debt + ?", amount)). + Error +} + // ListAllUsers lists all users in system func (r *GormUserRepository) ListAllUsers() ([]models.User, error) { var users []models.User @@ -120,6 +154,44 @@ func (r *GormUserRepository) DeleteUserByID(userID int) error { return nil } +// GetUserLastCalcTime returns the last calculation time for a user +func (r *GormUserRepository) GetUserLastCalcTime(userID int) (time.Time, error) { + var calcTime models.UserUsageCalculationTime + err := r.db.Where("user_id = ?", userID).First(&calcTime).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + // If no record exists, return zero time + return time.Time{}, nil + } + return time.Time{}, err + } + return calcTime.LastCalcTime, nil +} + +// UpdateUserLastCalcTime updates the last calculation time for a user +func (r *GormUserRepository) UpdateUserLastCalcTime(userID int, lastCalcTime time.Time) error { + var calcTime models.UserUsageCalculationTime + err := r.db.Where("user_id = ?", userID).First(&calcTime).Error + + if err != nil { + if err == gorm.ErrRecordNotFound { + // Create a new record if one doesn't exist + calcTime = models.UserUsageCalculationTime{ + UserID: userID, + LastCalcTime: lastCalcTime, + UpdatedAt: time.Now(), + } + return r.db.Create(&calcTime).Error + } + return err + } + + // Update existing record + calcTime.LastCalcTime = lastCalcTime + calcTime.UpdatedAt = time.Now() + return r.db.Save(&calcTime).Error +} + // CreditUserBalance add credited balance to user by its ID func (r *GormUserRepository) CreditUserBalance(userID int, amount uint64) error { return r.db.Model(&models.User{}). @@ -219,7 +291,29 @@ func (r *GormClusterRepository) CreateCluster(userID int, cluster *models.Cluste cluster.CreatedAt = time.Now() cluster.UpdatedAt = time.Now() cluster.UserID = userID - return r.db.Create(cluster).Error + + clusterData, err := cluster.GetClusterResult() + if err != nil { + return err + } + + return r.db.Transaction(func(tx *gorm.DB) error { + for _, node := range clusterData.Nodes { + contractData := &models.UserContractData{ + UserID: userID, + ContractID: node.ContractID, + NodeID: node.NodeID, + Type: models.ContractTypeDeployed, + CreatedAt: time.Now(), + } + + if err := tx.Create(contractData).Error; err != nil { + return fmt.Errorf("failed to create contract data for node %d: %w", node.NodeID, err) + } + } + + return tx.Create(cluster).Error + }) } func (r *GormClusterRepository) ListUserClusters(userID int) ([]models.Cluster, error) { @@ -238,7 +332,58 @@ func (r *GormClusterRepository) GetClusterByName(userID int, projectName string) return cluster, query.Error } -func (r *GormClusterRepository) UpdateCluster(cluster *models.Cluster) error { +func (r *GormClusterRepository) UpdateCluster(contractsRepo models.ContractDataRepository, cluster *models.Cluster) error { + existingCluster, err := r.GetClusterByName(cluster.UserID, cluster.ProjectName) + if err != nil { + return fmt.Errorf("failed to get existing cluster: %w", err) + } + + existingClusterData, err := existingCluster.GetClusterResult() + if err != nil { + return fmt.Errorf("failed to parse existing cluster data: %w", err) + } + + newClusterData, err := cluster.GetClusterResult() + if err != nil { + return fmt.Errorf("failed to parse new cluster data: %w", err) + } + + existingNodes := make(map[uint64]struct{}) + for _, node := range existingClusterData.Nodes { + if node.ContractID != 0 { + existingNodes[node.ContractID] = struct{}{} + } + } + + for _, node := range newClusterData.Nodes { + if node.ContractID != 0 { + if _, exists := existingNodes[node.ContractID]; !exists { + // This is a new node, create a contract for it + if err := contractsRepo.CreateUserContractData( + &models.UserContractData{ + UserID: cluster.UserID, + ContractID: node.ContractID, + NodeID: node.NodeID, + Type: models.ContractTypeDeployed, + CreatedAt: time.Now(), + }, + ); err != nil { + return fmt.Errorf("failed to create contract for new node: %w", err) + } + } + // Remove from existing nodes map to track what is processed + delete(existingNodes, node.ContractID) + } + } + + // Handle removed nodes - delete contracts for nodes that exist in old but not in new + for contractID := range existingNodes { + if err := contractsRepo.DeleteUserContract(contractID); err != nil { + return fmt.Errorf("failed to delete contract for removed node: %w", err) + } + } + + // Update the cluster record cluster.UpdatedAt = time.Now() return r.db.Model(&models.Cluster{}). Where("user_id = ? AND project_name = ?", cluster.UserID, cluster.ProjectName). @@ -246,20 +391,51 @@ func (r *GormClusterRepository) UpdateCluster(cluster *models.Cluster) error { } func (r *GormClusterRepository) DeleteCluster(userID int, projectName string) error { - query := r.db.Where("user_id = ? AND project_name = ?", userID, projectName).Delete(&models.Cluster{}) - - if errors.Is(query.Error, gorm.ErrRecordNotFound) { - return models.ErrClusterNotFound + cluster, err := r.GetClusterByName(userID, projectName) + if err != nil { + return err } - if query.RowsAffected == 0 { - return models.ErrClusterNotFound + clusterData, err := cluster.GetClusterResult() + if err != nil { + return err } - return query.Error + err = r.db.Transaction(func(tx *gorm.DB) error { + for _, node := range clusterData.Nodes { + if err := tx.Model(&models.UserContractData{}). + Where("contract_id = ?", node.ContractID). + Update("deleted_at", time.Now()).Error; err != nil { + return fmt.Errorf("failed to delete contract for node %d: %w", node.NodeID, err) + } + } + + return tx.Where("user_id = ? AND project_name = ?", userID, projectName). + Delete(&models.Cluster{}).Error + }) + + return err } -func (r *GormClusterRepository) DeleteAllUserClusters(userID int) error { +func (r *GormClusterRepository) DeleteAllUserClusters(contractsRepo models.ContractDataRepository, userID int) error { + clusters, err := r.ListUserClusters(userID) + if err != nil { + return err + } + + for _, cluster := range clusters { + clusterData, err := cluster.GetClusterResult() + if err != nil { + return err + } + + for _, node := range clusterData.Nodes { + if err := contractsRepo.DeleteUserContract(node.ContractID); err != nil { + return err + } + } + } + return r.db.Where("user_id = ?", userID).Delete(&models.Cluster{}).Error } @@ -331,55 +507,71 @@ func (r *GormVoucherRepository) RedeemVoucher(userID int, username, code string) return nil } -// UserNodes Repository +// UserContractData Repository -var _ models.UserNodesRepository = (*GormUserNodesRepository)(nil) +var _ models.ContractDataRepository = (*GormUserContractDataRepository)(nil) -type GormUserNodesRepository struct { +type GormUserContractDataRepository struct { db *gorm.DB } -func NewGormUserNodesRepository(db models.DB) models.UserNodesRepository { - return &GormUserNodesRepository{db: db.GetDB()} +func NewGormUserContractDataRepository(db models.DB) models.ContractDataRepository { + return &GormUserContractDataRepository{db: db.GetDB()} } -func (r *GormUserNodesRepository) CreateUserNode(userNode *models.UserNodes) error { - return r.db.Create(&userNode).Error +// CreateUserContractData creates new contract record for user +func (r *GormUserContractDataRepository) CreateUserContractData(contractData *models.UserContractData) error { + return r.db.Create(&contractData).Error } -func (r *GormUserNodesRepository) DeleteUserNode(contractID uint64) error { - return r.db.Where("contract_id = ?", contractID).Delete(&models.UserNodes{}).Error +// DeleteUserContract updates deleted time of a contract record for user by its contract ID +func (r *GormUserContractDataRepository) DeleteUserContract(contractID uint64) error { + return r.db.Where("contract_id = ?", contractID).Update("deleted_at", time.Now()).Error } -func (r *GormUserNodesRepository) ListUserNodes(userID int) ([]models.UserNodes, error) { - var userNodes []models.UserNodes - return userNodes, r.db.Where("user_id = ?", userID).Find(&userNodes).Error +// ListUserRentedNodes returns all nodes records for user by its ID +func (r *GormUserContractDataRepository) ListUserRentedNodes(userID int) ([]models.UserContractData, error) { + var userNodes []models.UserContractData + return userNodes, r.db.Where("user_id = ? and type = ? and deleted_at = ?", userID, models.ContractTypeRented, time.Time{}).Find(&userNodes).Error } -func (r *GormUserNodesRepository) ListAllReservedNodes() ([]models.UserNodes, error) { - var userNodes []models.UserNodes - return userNodes, r.db.Find(&userNodes).Error -} +// ListAllContractsInPeriod returns all contracts that existed during the specified time period. +// This includes: +// 1. Contracts created before or during the period end date +// 2. AND either not deleted (deleted_at is zero time) OR deleted after the period start date +// If userID is provided (non-zero), it will only return contracts for that specific user. +// If userID is 0, it will return contracts for all users. +func (r *GormUserContractDataRepository) ListAllContractsInPeriod(userID int, start, end time.Time) ([]models.UserContractData, error) { + var userNodes []models.UserContractData -func (r *GormUserNodesRepository) GetUserNodeByNodeID(nodeID uint64) (models.UserNodes, error) { - var userNode models.UserNodes - result := r.db.Where("node_id = ?", nodeID).First(&userNode) + // Query for contracts that: + // - Were created on or before the end date of the period + // - AND are either not deleted (deleted_at is zero) OR were deleted after the start of the period + query := r.db.Where("created_at <= ?", end). + Where("(deleted_at = ? OR deleted_at >= ?)", time.Time{}, start) - if result.Error != nil && errors.Is(result.Error, gorm.ErrRecordNotFound) { - return models.UserNodes{}, models.ErrUserNodeNotFound + // If userID is provided (non-zero), filter by that user + if userID > 0 { + query = query.Where("user_id = ?", userID) } - return userNode, result.Error + return userNodes, query.Find(&userNodes).Error } -func (r *GormUserNodesRepository) GetUserNodeByContractID(contractID uint64) (models.UserNodes, error) { - var userNode models.UserNodes - result := r.db.Where("contract_id = ?", contractID).First(&userNode) - if result.Error != nil && errors.Is(result.Error, gorm.ErrRecordNotFound) { - return models.UserNodes{}, models.ErrUserNodeNotFound - } +// ListAllReservedNodes returns all reserved nodes from all users +func (r *GormUserContractDataRepository) ListAllReservedNodes() ([]models.UserContractData, error) { + var userNodes []models.UserContractData + return userNodes, r.db.Where("type = ? and deleted_at = ?", models.ContractTypeRented, time.Time{}).Find(&userNodes).Error +} + +func (r *GormUserContractDataRepository) GetUserNodeByNodeID(nodeID uint64) (models.UserContractData, error) { + var userNode models.UserContractData + return userNode, r.db.Where("node_id = ? and deleted_at = ?", nodeID, time.Time{}).First(&userNode).Error +} - return userNode, result.Error +func (r *GormUserContractDataRepository) GetUserNodeByContractID(contractID uint64) (models.UserContractData, error) { + var userNode models.UserContractData + return userNode, r.db.Where("contract_id = ? and deleted_at = ?", contractID, time.Time{}).First(&userNode).Error } // Transaction Repository @@ -458,44 +650,57 @@ func (r *GormSettingsRepository) GetMaintenanceMode() (bool, error) { return value == maintenanceModeEnabled, nil } -// PendingRecord Repository +// TransferRecord Repository -var _ models.PendingRecordRepository = (*GormPendingRecordRepository)(nil) +var _ models.TransferRecordRepository = (*GormTransferRecordRepository)(nil) -type GormPendingRecordRepository struct { +type GormTransferRecordRepository struct { db *gorm.DB } -func NewGormPendingRecordRepository(db models.DB) models.PendingRecordRepository { - return &GormPendingRecordRepository{db: db.GetDB()} +func NewGormTransferRecordRepository(db models.DB) models.TransferRecordRepository { + return &GormTransferRecordRepository{db: db.GetDB()} } -func (r *GormPendingRecordRepository) CreatePendingRecord(record *models.PendingRecord) error { +func (r *GormTransferRecordRepository) CreateTransferRecord(record *models.TransferRecord) error { record.CreatedAt = time.Now() return r.db.Create(record).Error } +func (r *GormTransferRecordRepository) ListTransferRecords() ([]models.TransferRecord, error) { + var TransferRecords []models.TransferRecord + return TransferRecords, r.db.Find(&TransferRecords).Error +} -func (r *GormPendingRecordRepository) ListAllPendingRecords() ([]models.PendingRecord, error) { - var pendingRecords []models.PendingRecord - return pendingRecords, r.db.Find(&pendingRecords).Error +func (r *GormTransferRecordRepository) CalculateTotalPendingTFTAmountPerUser(userID int) (uint64, error) { + var totalAmount uint64 + err := r.db.Model(&models.TransferRecord{}). + Select("COALESCE(SUM(tft_amount), 0)"). + Where("user_id = ? AND state = ?", userID, models.PendingState). + Scan(&totalAmount).Error + if err != nil { + return 0, err + } + return totalAmount, nil } -func (r *GormPendingRecordRepository) ListOnlyPendingRecords() ([]models.PendingRecord, error) { - var pendingRecords []models.PendingRecord - return pendingRecords, r.db.Where("tft_amount > transferred_tft_amount").Find(&pendingRecords).Error +func (r *GormTransferRecordRepository) ListUserTransferRecords(userID int) ([]models.TransferRecord, error) { + var TransferRecords []models.TransferRecord + return TransferRecords, r.db.Where("user_id = ?", userID).Find(&TransferRecords).Error } -func (r *GormPendingRecordRepository) ListUserPendingRecords(userID int) ([]models.PendingRecord, error) { - var pendingRecords []models.PendingRecord - return pendingRecords, r.db.Where("user_id = ?", userID).Find(&pendingRecords).Error +func (r *GormTransferRecordRepository) ListPendingTransferRecords() ([]models.TransferRecord, error) { + var TransferRecords []models.TransferRecord + return TransferRecords, r.db.Where("state = ?", models.PendingState).Find(&TransferRecords).Error } -func (r *GormPendingRecordRepository) UpdatePendingRecordTransferredAmount(id int, amount uint64) error { - return r.db.Model(&models.PendingRecord{}). - Where("id = ?", id). - UpdateColumn("transferred_tft_amount", gorm.Expr("transferred_tft_amount + ?", amount)). - UpdateColumn("updated_at", gorm.Expr("?", time.Now())). - Error +func (r *GormTransferRecordRepository) ListFailedTransferRecords() ([]models.TransferRecord, error) { + var TransferRecords []models.TransferRecord + return TransferRecords, r.db.Where("state = ?", models.FailedState).Find(&TransferRecords).Error +} + +func (r *GormTransferRecordRepository) UpdateTransferRecordState(recordID int, state models.State, failure string) error { + return r.db.Model(&models.TransferRecord{}).Where("id = ?", recordID).Updates( + map[string]interface{}{"state": state, "failure": failure, "updated_at": time.Now()}).Error } // Notification Repository diff --git a/backend/internal/core/services/admin_service.go b/backend/internal/core/services/admin_service.go index cf26df18c..6683c4acb 100644 --- a/backend/internal/core/services/admin_service.go +++ b/backend/internal/core/services/admin_service.go @@ -21,12 +21,12 @@ import ( ) type AdminService struct { - userRepo models.UserRepository - nodesRepo models.UserNodesRepository - prRepo models.PendingRecordRepository - voucherRepo models.VoucherRepository - transRepo models.TransactionRepository - ewfRepo *persistence.GormEWFRepository + userRepo models.UserRepository + contractsRepo models.ContractDataRepository + transferRecordsRepo models.TransferRecordRepository + voucherRepo models.VoucherRepository + transRepo models.TransactionRepository + ewfRepo *persistence.GormEWFRepository appCtx context.Context gridClient gridclient.GridClient @@ -37,8 +37,8 @@ type AdminService struct { func NewAdminService(appCtx context.Context, userRepo models.UserRepository, - userNodeRepo models.UserNodesRepository, - pendingRecordRepo models.PendingRecordRepository, + contractsRepo models.ContractDataRepository, + transferRecordsRepo models.TransferRecordRepository, voucherRepo models.VoucherRepository, transactionRepo models.TransactionRepository, gridClient gridclient.GridClient, @@ -48,12 +48,12 @@ func NewAdminService(appCtx context.Context, ewfRepo *persistence.GormEWFRepository, ) AdminService { return AdminService{ - userRepo: userRepo, - nodesRepo: userNodeRepo, - prRepo: pendingRecordRepo, - voucherRepo: voucherRepo, - transRepo: transactionRepo, - ewfRepo: ewfRepo, + userRepo: userRepo, + contractsRepo: contractsRepo, + transferRecordsRepo: transferRecordsRepo, + voucherRepo: voucherRepo, + transRepo: transactionRepo, + ewfRepo: ewfRepo, appCtx: appCtx, gridClient: gridClient, @@ -65,22 +65,25 @@ func NewAdminService(appCtx context.Context, const maxConcurrentBalanceFetches = 20 -type UserWithUSDBalance struct { +type UserWithTFTBalance struct { models.User - Balance float64 `json:"balance"` // USD balance + BalanceInTFT float64 `json:"balance_in_tft"` } -type PendingRecordsWithUSDAmounts struct { - models.PendingRecord - USDAmount float64 `json:"usd_amount"` - TransferredUSDAmount float64 `json:"transferred_usd_amount"` +type TransferRecordsWithTFTAmount struct { + models.TransferRecord + TFTAmountInWholeUnit float32 `json:"tft_amount_in_whole_unit"` } func (svc *AdminService) ListAllUsers() ([]models.User, error) { return svc.userRepo.ListAllUsers() } -func (svc *AdminService) ListAllUsersIncludingUSDBalance() ([]UserWithUSDBalance, error) { +func (svc *AdminService) GetUserByID(id int) (models.User, error) { + return svc.userRepo.GetUserByID(id) +} + +func (svc *AdminService) ListAllUsersIncludingUSDBalance() ([]UserWithTFTBalance, error) { users, err := svc.ListAllUsers() // Here is the only critical errors, not the balance related ones if err != nil { @@ -88,7 +91,7 @@ func (svc *AdminService) ListAllUsersIncludingUSDBalance() ([]UserWithUSDBalance } var ( - usersWithBalance []UserWithUSDBalance + usersWithBalance []UserWithTFTBalance wg sync.WaitGroup mu sync.Mutex balanceErrors *multierror.Error @@ -104,18 +107,19 @@ func (svc *AdminService) ListAllUsersIncludingUSDBalance() ([]UserWithUSDBalance defer wg.Done() defer func() { <-balanceConcurrencyLimiter }() - balance, err := svc.gridClient.GetUserBalanceUSD(user.Mnemonic) + balanceInTFTUnit, err := svc.gridClient.GetFreeBalanceTFT(user.Mnemonic) if err != nil { + logger.GetLogger().Error().Err(err).Int("user_id", user.ID).Msg("failed to get user balance") mu.Lock() balanceErrors = multierror.Append(balanceErrors, fmt.Errorf("failed to get balance for user %d: %w", user.ID, err)) mu.Unlock() - balance = 0.0 + return } mu.Lock() - usersWithBalance = append(usersWithBalance, UserWithUSDBalance{ - User: user, - Balance: balance, + usersWithBalance = append(usersWithBalance, UserWithTFTBalance{ + User: user, + BalanceInTFT: float64(balanceInTFTUnit) / TFTUnitFactor, }) mu.Unlock() }(user) @@ -136,40 +140,6 @@ func (svc *AdminService) DeleteUserByID(userID int) error { return svc.userRepo.DeleteUserByID(userID) } -func (svc *AdminService) AsyncCreditUserUSD(transaction *models.Transaction) error { - if err := svc.transRepo.CreateTransaction(transaction); err != nil { - return err - } - - user, err := svc.userRepo.GetUserByID(transaction.UserID) - if err != nil { - return err - } - - displayName := fmt.Sprintf("Admin credit balance for %s", user.Username) - wf, err := svc.ewfEngine.NewWorkflow(workflows.WorkflowAdminCreditBalance, ewf.WithDisplayName(displayName)) - if err != nil { - return err - } - - wf.State = map[string]interface{}{ - "amount": gridclient.FromUSDToUSDMillicent(transaction.Amount), - "username": user.Username, - "transfer_mode": models.AdminCreditMode, - "admin_id": transaction.AdminID, - "config": map[string]interface{}{ - "user_id": transaction.UserID, - "mnemonic": user.Mnemonic, - }, - } - - if err = persistence.SetStateUserID(&wf, transaction.AdminID); err != nil { - return err - } - - return svc.ewfEngine.Run(svc.appCtx, wf, ewf.WithAsync()) -} - func (svc *AdminService) GenerateVouchers(count, expireAfterDays int, voucherValue float64) ([]models.Voucher, error) { var vouchers []models.Voucher @@ -194,32 +164,21 @@ func (svc *AdminService) ListAllVouchers() ([]models.Voucher, error) { return svc.voucherRepo.ListAllVouchers() } -func (svc *AdminService) ListAllPendingRecordsWithUSDAmounts() ([]PendingRecordsWithUSDAmounts, error) { - pendingRecords, err := svc.prRepo.ListAllPendingRecords() +func (svc *AdminService) ListAllTransferRecordsWithTFTAmount() ([]TransferRecordsWithTFTAmount, error) { + transferRecords, err := svc.transferRecordsRepo.ListTransferRecords() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to list transfer records: %w", err) } - var pendingRecordsWithUSDAmounts []PendingRecordsWithUSDAmounts - for _, record := range pendingRecords { - usdAmount, err := svc.gridClient.FromTFTtoUSDMillicent(record.TFTAmount) - if err != nil { - return nil, err - } - - usdTransferredAmount, err := svc.gridClient.FromTFTtoUSDMillicent(record.TransferredTFTAmount) - if err != nil { - return nil, err - } - - pendingRecordsWithUSDAmounts = append(pendingRecordsWithUSDAmounts, PendingRecordsWithUSDAmounts{ - PendingRecord: record, - USDAmount: gridclient.FromUSDMilliCentToUSD(usdAmount), - TransferredUSDAmount: gridclient.FromUSDMilliCentToUSD(usdTransferredAmount), + var transferRecordsResponse []TransferRecordsWithTFTAmount + for _, transferRecord := range transferRecords { + transferRecordsResponse = append(transferRecordsResponse, TransferRecordsWithTFTAmount{ + TransferRecord: transferRecord, + TFTAmountInWholeUnit: float32(transferRecord.TFTAmount) / TFTUnitFactor, }) } - return pendingRecordsWithUSDAmounts, nil + return transferRecordsResponse, nil } func (svc *AdminService) generateVoucherWithTimestamp() string { @@ -228,6 +187,20 @@ func (svc *AdminService) generateVoucherWithTimestamp() string { return fmt.Sprintf("%s-%s", voucherCode, timestampPart) } +func (svc *AdminService) CreditUserBalance(ctx context.Context, transaction models.Transaction, user *models.User) error { + if err := svc.transRepo.CreateTransaction(&transaction); err != nil { + return fmt.Errorf("failed to create transaction: %w", err) + } + + millicentAmount := gridclient.FromUSDToUSDMillicent(transaction.Amount) + user.CreditedBalance += millicentAmount + if err := svc.userRepo.UpdateUserByID(&models.User{ID: transaction.UserID}); err != nil { + return fmt.Errorf("failed to update user: %w", err) + } + + return nil +} + // AsyncDrainUserUSD drains a specific user's balance to the system account func (svc *AdminService) AsyncDrainUserUSD(userID, adminID int) error { wf, err := svc.ewfEngine.NewWorkflow(workflows.WorkflowDrainUser, ewf.WithDisplayName("Drain user balance")) diff --git a/backend/internal/core/services/admin_service_test.go b/backend/internal/core/services/admin_service_test.go index ba0ebe539..cddfff70c 100644 --- a/backend/internal/core/services/admin_service_test.go +++ b/backend/internal/core/services/admin_service_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "kubecloud/internal/core/models" "kubecloud/internal/infrastructure/gridclient" @@ -14,52 +15,80 @@ import ( "github.com/stretchr/testify/require" ) -type mockUserNodesRepo struct { +// We're using the mock implementations from user_service_test.go +// Adding only the missing methods needed for DeductUserBalance + +func (m *mockUserRepo) DeductUserBalance(user *models.User, amount uint64) error { + args := m.Called(user, amount) + return args.Error(0) +} + +func (m *mockUserRepo) GetUserLastCalcTime(userID int) (time.Time, error) { + args := m.Called(userID) + return args.Get(0).(time.Time), args.Error(1) +} + +func (m *mockUserRepo) UpdateUserLastCalcTime(userID int, lastCalcTime time.Time) error { + args := m.Called(userID, lastCalcTime) + return args.Error(0) +} + +// Contract data repository mock +type mockContractDataRepo struct { mock.Mock } -func (m *mockUserNodesRepo) CreateUserNode(userNode *models.UserNodes) error { - args := m.Called(userNode) +func (m *mockContractDataRepo) CreateUserContractData(contractData *models.UserContractData) error { + args := m.Called(contractData) return args.Error(0) } -func (m *mockUserNodesRepo) DeleteUserNode(contractID uint64) error { +func (m *mockContractDataRepo) DeleteUserContract(contractID uint64) error { args := m.Called(contractID) return args.Error(0) } -func (m *mockUserNodesRepo) ListUserNodes(userID int) ([]models.UserNodes, error) { +func (m *mockContractDataRepo) ListUserRentedNodes(userID int) ([]models.UserContractData, error) { args := m.Called(userID) if args.Get(0) == nil { return nil, args.Error(1) } - return args.Get(0).([]models.UserNodes), args.Error(1) + return args.Get(0).([]models.UserContractData), args.Error(1) } -func (m *mockUserNodesRepo) GetUserNodeByNodeID(nodeID uint64) (models.UserNodes, error) { +func (m *mockContractDataRepo) GetUserNodeByNodeID(nodeID uint64) (models.UserContractData, error) { args := m.Called(nodeID) if args.Get(0) == nil { - return models.UserNodes{}, args.Error(1) + return models.UserContractData{}, args.Error(1) } - return args.Get(0).(models.UserNodes), args.Error(1) + return args.Get(0).(models.UserContractData), args.Error(1) } -func (m *mockUserNodesRepo) GetUserNodeByContractID(contractID uint64) (models.UserNodes, error) { +func (m *mockContractDataRepo) GetUserNodeByContractID(contractID uint64) (models.UserContractData, error) { args := m.Called(contractID) if args.Get(0) == nil { - return models.UserNodes{}, args.Error(1) + return models.UserContractData{}, args.Error(1) } - return args.Get(0).(models.UserNodes), args.Error(1) + return args.Get(0).(models.UserContractData), args.Error(1) } -func (m *mockUserNodesRepo) ListAllReservedNodes() ([]models.UserNodes, error) { +func (m *mockContractDataRepo) ListAllReservedNodes() ([]models.UserContractData, error) { args := m.Called() if args.Get(0) == nil { return nil, args.Error(1) } - return args.Get(0).([]models.UserNodes), args.Error(1) + return args.Get(0).([]models.UserContractData), args.Error(1) +} + +func (m *mockContractDataRepo) ListAllContractsInPeriod(userID int, start, end time.Time) ([]models.UserContractData, error) { + args := m.Called(userID, start, end) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]models.UserContractData), args.Error(1) } +// Transaction repository mock type mockTransactionRepo struct { mock.Mock } @@ -69,13 +98,65 @@ func (m *mockTransactionRepo) CreateTransaction(transaction *models.Transaction) return args.Error(0) } +// Transfer record repository mock +type mockTransferRecordRepo struct { + mock.Mock +} + +func (m *mockTransferRecordRepo) CreateTransferRecord(record *models.TransferRecord) error { + args := m.Called(record) + return args.Error(0) +} + +func (m *mockTransferRecordRepo) ListTransferRecords() ([]models.TransferRecord, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]models.TransferRecord), args.Error(1) +} + +func (m *mockTransferRecordRepo) ListUserTransferRecords(userID int) ([]models.TransferRecord, error) { + args := m.Called(userID) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]models.TransferRecord), args.Error(1) +} + +func (m *mockTransferRecordRepo) ListPendingTransferRecords() ([]models.TransferRecord, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]models.TransferRecord), args.Error(1) +} + +func (m *mockTransferRecordRepo) ListFailedTransferRecords() ([]models.TransferRecord, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]models.TransferRecord), args.Error(1) +} + +func (m *mockTransferRecordRepo) UpdateTransferRecordState(recordID int, state models.State, failure string) error { + args := m.Called(recordID, state, failure) + return args.Error(0) +} + +func (m *mockTransferRecordRepo) CalculateTotalPendingTFTAmountPerUser(userID int) (uint64, error) { + args := m.Called(userID) + return args.Get(0).(uint64), args.Error(1) +} + var dummyMailService = mailservice.MailService{} // Test 1: ListAllUsers - SUCCESS func TestAdminService_ListAllUsers_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -90,7 +171,7 @@ func TestAdminService_ListAllUsers_Success(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -105,8 +186,8 @@ func TestAdminService_ListAllUsers_Success(t *testing.T) { // Test 2: ListAllUsers - EMPTY func TestAdminService_ListAllUsers_Empty(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -116,7 +197,7 @@ func TestAdminService_ListAllUsers_Empty(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -129,8 +210,8 @@ func TestAdminService_ListAllUsers_Empty(t *testing.T) { // Test 3: ListAllUsers - ERROR func TestAdminService_ListAllUsers_Error(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -140,7 +221,7 @@ func TestAdminService_ListAllUsers_Error(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -153,8 +234,8 @@ func TestAdminService_ListAllUsers_Error(t *testing.T) { // Test 4: DeleteUserByID - SUCCESS func TestAdminService_DeleteUserByID_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -164,7 +245,7 @@ func TestAdminService_DeleteUserByID_Success(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -177,8 +258,8 @@ func TestAdminService_DeleteUserByID_Success(t *testing.T) { // Test 5: DeleteUserByID - USER NOT FOUND func TestAdminService_DeleteUserByID_NotFound(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -188,7 +269,7 @@ func TestAdminService_DeleteUserByID_NotFound(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -201,8 +282,8 @@ func TestAdminService_DeleteUserByID_NotFound(t *testing.T) { // Test 6: GenerateVouchers - SUCCESS func TestAdminService_GenerateVouchers_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -215,7 +296,7 @@ func TestAdminService_GenerateVouchers_Success(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -230,8 +311,8 @@ func TestAdminService_GenerateVouchers_Success(t *testing.T) { // Test 7: GenerateVouchers - ZERO COUNT func TestAdminService_GenerateVouchers_ZeroCount(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -239,7 +320,7 @@ func TestAdminService_GenerateVouchers_ZeroCount(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) @@ -252,8 +333,8 @@ func TestAdminService_GenerateVouchers_ZeroCount(t *testing.T) { // Test 8: GenerateVouchers - LARGE COUNT func TestAdminService_GenerateVouchers_LargeCount(t *testing.T) { mockUserRepo := new(mockUserRepo) - mockNodesRepo := new(mockUserNodesRepo) - mockPRRepo := new(mockPendingRecordRepo) + mockContractsRepo := new(mockContractDataRepo) + mockTransferRecordsRepo := new(mockTransferRecordRepo) mockVoucherRepo := new(mockVoucherRepo) mockTransRepo := new(mockTransactionRepo) @@ -266,7 +347,7 @@ func TestAdminService_GenerateVouchers_LargeCount(t *testing.T) { service := NewAdminService( context.Background(), - mockUserRepo, mockNodesRepo, mockPRRepo, mockVoucherRepo, mockTransRepo, + mockUserRepo, mockContractsRepo, mockTransferRecordsRepo, mockVoucherRepo, mockTransRepo, gridClient, nil, dummyMailService, nil, nil, ) diff --git a/backend/internal/core/services/billing_service.go b/backend/internal/core/services/billing_service.go new file mode 100644 index 000000000..91644fa9a --- /dev/null +++ b/backend/internal/core/services/billing_service.go @@ -0,0 +1,498 @@ +package services + +import ( + "context" + "fmt" + "kubecloud/internal/billing" + "kubecloud/internal/core/models" + "kubecloud/internal/deployment/kubedeployer" + "kubecloud/internal/infrastructure/gridclient" + "kubecloud/internal/infrastructure/logger" + "math" + "strconv" + "time" + + "github.com/threefoldtech/tfgrid-sdk-go/grid-client/graphql" + "github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/pkg/types" +) + +type Discount string + +type DiscountPackage struct { + DurationInMonth float64 + Discount int +} + +type BillingService struct { + userRepo models.UserRepository + contractsRepo models.ContractDataRepository + transferRecordsRepo models.TransferRecordRepository + clusterRepo models.ClusterRepository + + gridClient gridclient.GridClient + graphql graphql.GraphQl + + minimumTFTAmountInWallet uint64 + appliedDiscount Discount +} + +func NewBillingService(userRepo models.UserRepository, contractsRepo models.ContractDataRepository, + transferRecordsRepo models.TransferRecordRepository, clusterRepo models.ClusterRepository, + graphql graphql.GraphQl, gridClient gridclient.GridClient, + minimumTFTAmountInWallet uint64, appliedDiscount Discount, +) BillingService { + return BillingService{ + userRepo: userRepo, + contractsRepo: contractsRepo, + transferRecordsRepo: transferRecordsRepo, + clusterRepo: clusterRepo, + + gridClient: gridClient, + graphql: graphql, + + appliedDiscount: appliedDiscount, + minimumTFTAmountInWallet: minimumTFTAmountInWallet, + } +} + +func (svc *BillingService) SettleUserUsage(user *models.User) error { + usageInUSDMillicent, err := svc.getUserLatestUsageInUSD(user.ID) + if err != nil { + return err + } + + return svc.userRepo.DeductUserBalance(user, usageInUSDMillicent) +} + +func (svc *BillingService) AfterUserGetCredit(ctx context.Context, user *models.User) error { + if err := svc.CreateTransferRecordToChargeUserWithMinTFTAmount(user.ID, user.Username, user.Mnemonic); err != nil { + return err + } + + if err := svc.SettleUserUsage(user); err != nil { + return err + } + + return svc.FundUserToFulfillDiscount(ctx, user, nil, nil) +} + +func (svc *BillingService) CreateTransferRecordToChargeUserWithMinTFTAmount(userID int, username, userMnemonic string) error { + userTFTBalance, err := svc.gridClient.GetFreeBalanceTFT(userMnemonic) + if err != nil { + return err + } + + totalPendingTFTAmount, err := svc.transferRecordsRepo.CalculateTotalPendingTFTAmountPerUser(userID) + if err != nil { + return err + } + + if userTFTBalance+totalPendingTFTAmount >= zeroTFTBalanceValue { + return nil + } + + return svc.transferRecordsRepo.CreateTransferRecord(&models.TransferRecord{ + UserID: userID, + Username: username, + TFTAmount: svc.minimumTFTAmountInWallet * TFTUnitFactor, + Operation: models.DepositOperation, + }) +} + +func (svc *BillingService) FundUserToFulfillDiscount(ctx context.Context, user *models.User, addedRentedNodes []types.Node, addedSharedNodes []kubedeployer.Node) error { + if user.CreditCardBalance+user.CreditedBalance-user.Debt <= 0 { + // user has no USD balance, skip + return nil + } + + // calculate resources usage in USD applying discount + // I took the cluster nodes since only the new node is in cluster.Nodes + dailyUsageInUSDMillicent, err := svc.calculateResourcesUsageInUSDApplyingDiscount(ctx, user.ID, user.Mnemonic, addedRentedNodes, addedSharedNodes, svc.appliedDiscount) + if err != nil { + return err + } + + dailyUsageInTFT, err := svc.gridClient.FromUSDMillicentToTFT(dailyUsageInUSDMillicent) + if err != nil { + return err + } + + totalPendingTFTAmount, err := svc.transferRecordsRepo.CalculateTotalPendingTFTAmountPerUser(user.ID) + if err != nil { + return err + } + + userTFTBalance, err := svc.gridClient.GetFreeBalanceTFT(user.Mnemonic) + if err != nil { + return err + } + + // fund user to fulfill discount + // make sure no old payments will fund more than needed + if totalPendingTFTAmount+userTFTBalance < dailyUsageInTFT && + dailyUsageInTFT > 0 { + if err := svc.transferRecordsRepo.CreateTransferRecord(&models.TransferRecord{ + UserID: user.ID, + Username: user.Username, + TFTAmount: dailyUsageInTFT - userTFTBalance - totalPendingTFTAmount, + Operation: models.DepositOperation, + }); err != nil { + return err + } + } + + return nil +} + +func (svc *BillingService) calculateResourcesUsageInUSDApplyingDiscount( + ctx context.Context, + userID int, + userMnemonic string, + addedRentedNodes []types.Node, + addedSharedNodes []kubedeployer.Node, + configuredDiscount Discount, +) (uint64, error) { + calculator, err := svc.gridClient.NewCalculator(userMnemonic) + if err != nil { + return 0, fmt.Errorf("failed to create calculator: %w", err) + } + + var totalResourcesCostMillicent uint64 + + rentedNodes, _, err := svc.getRentedNodesForUser(ctx, userID, true) + if err != nil { + return 0, err + } + if addedRentedNodes == nil { + rentedNodes = append(rentedNodes, addedRentedNodes...) + } + + // Calculate rented nodes + for _, node := range rentedNodes { + resourcesCost, err := calculator.CalculateCost( + node.TotalResources.CRU, + uint64(node.TotalResources.MRU), + uint64(node.TotalResources.HRU), + uint64(node.TotalResources.SRU), + len(node.PublicConfig.Ipv4) > 0, + node.CertificationType == nodeCertified, + ) + if err != nil { + return 0, err + } + + // resources cost per month + pricingPolicy, err := svc.gridClient.GetPricingPolicy(defaultPricingPolicyID) + if err != nil { + return 0, err + } + dedicatedDiscountPercentage := float64(pricingPolicy.DedicatedNodesDiscount / 100) + totalResourcesCostMillicent += gridclient.FromUSDToUSDMillicent(resourcesCost * dedicatedDiscountPercentage) + } + + sharedNodes, err := svc.getUserNodes(userID) + if err != nil { + return 0, err + } + if addedSharedNodes != nil { + sharedNodes = append(sharedNodes, addedSharedNodes...) + } + + // Calculate shared nodes + for _, node := range sharedNodes { + proxyNode, err := svc.gridClient.Node(ctx, node.NodeID) + if err != nil { + return 0, err + } + + if proxyNode.Rented { + twinID, err := svc.gridClient.GetTwin(userMnemonic) + if err != nil { + return 0, err + } + + if proxyNode.RentedByTwinID == uint(twinID) { + // skip rented nodes as they are already calculated + continue + } + } + + // Calculate total disk size (sum all data disks + root size) + totalDiskSize := node.RootSize + for _, diskSize := range node.DataDisks { + totalDiskSize += diskSize + } + + resourcesCost, err := calculator.CalculateCost( + uint64(node.CPU), + node.Memory, + 0, + totalDiskSize, + false, + proxyNode.CertificationType == nodeCertified, + ) + if err != nil { + return 0, err + } + + // resources cost per month + totalResourcesCostMillicent += gridclient.FromUSDToUSDMillicent(resourcesCost) + } + + // Calculate name contracts + nameContracts, err := svc.listNameContractsForUser(userID) + if err != nil { + return 0, err + } + + nameContractMonthlyCostInUSD, err := svc.calculateUniqueNameMonthlyCost() + if err != nil { + return 0, err + } + + totalResourcesCostMillicent += gridclient.FromUSDToUSDMillicent(float64(len(nameContracts)) * nameContractMonthlyCostInUSD) + + discount := getDiscountPackage(configuredDiscount).DurationInMonth + if discount == 0 { + return totalResourcesCostMillicent, nil + } + + return uint64(float64(totalResourcesCostMillicent) * discount), nil +} + +func (svc *BillingService) getUserNodes(userID int) ([]kubedeployer.Node, error) { + userClusters, err := svc.clusterRepo.ListUserClusters(userID) + if err != nil { + return nil, err + } + + var sharedNodes []kubedeployer.Node + for _, cluster := range userClusters { + clusterResult, err := cluster.GetClusterResult() + if err != nil { + return nil, err + } + sharedNodes = append(sharedNodes, clusterResult.Nodes...) + } + + return sharedNodes, nil +} + +func (svc *BillingService) calculateUniqueNameMonthlyCost() (float64, error) { + pricingPolicy, err := svc.gridClient.GetPricingPolicy(defaultPricingPolicyID) + if err != nil { + return 0, err + } + + // cost in unit-USD + monthlyCost := float64(pricingPolicy.UniqueName.Value) * 24 * 30 + + costInUSD := monthlyCost / TFTUnitFactor + return costInUSD, nil +} + +func (svc *BillingService) getRentedNodesForUser(ctx context.Context, userID int, healthy bool) ([]types.Node, int, error) { + twinID, err := svc.getTwinIDFromUserID(userID) + if err != nil { + return nil, 0, err + } + + filter := types.NodeFilter{ + RentedBy: &twinID, + Features: Zos3NodeFeatures, + } + + if healthy { + filter.Healthy = &healthy + } + + limit := types.DefaultLimit() + + nodes, count, err := svc.gridClient.Nodes(ctx, filter, limit) + if err != nil { + return nil, 0, err + } + + return nodes, count, nil +} + +func (svc *BillingService) listNameContractsForUser(userID int) ([]graphql.Contract, error) { + twinID, err := svc.getTwinIDFromUserID(userID) + if err != nil { + return nil, err + } + + contractGetter := svc.gridClient.NewContractsGetter( + uint32(twinID), + svc.graphql, + ) + + contractsList, err := contractGetter.ListContractsByTwinID([]string{"Created, GracePeriod"}) + if err != nil { + return nil, err + } + + return contractsList.NameContracts, nil +} + +func (svc *BillingService) getTwinIDFromUserID(userID int) (uint64, error) { + user, err := svc.userRepo.GetUserByID(userID) + if err != nil { + return 0, err + } + + twinID, err := svc.gridClient.GetTwin(user.Mnemonic) + if err != nil { + return 0, err + } + + return uint64(twinID), nil +} + +func (svc *BillingService) getUserLatestUsageInUSD(userID int) (uint64, error) { + now := time.Now() + // Define the end of the day (next day at 00:00) + endOfDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, time.Local) + + // Get the last calculation time for this user from the database, or use a default if not available + lastCalcTime, err := svc.userRepo.GetUserLastCalcTime(userID) + if err != nil { + return 0, err + } + + // If this is the first time or no record exists, use the start of the day as default + if lastCalcTime.IsZero() { + lastCalcTime = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.Local) + } + + contracts, err := svc.contractsRepo.ListAllContractsInPeriod(userID, lastCalcTime, endOfDay) + if err != nil { + return 0, err + } + + if len(contracts) == 0 { + return 0, nil + } + + var totalDailyUsageInUSDMillicent uint64 + + for _, record := range contracts { + // Get bill reports from the last calculation time to the end of day + billReports, err := billing.ListContractBillReports(svc.graphql, record.ContractID, lastCalcTime, endOfDay) + if err != nil { + return 0, err + } + + totalAmountBilledInUSDMillicent, err := svc.calculateTotalUsageOfReportsInUSDMillicent(billReports.Reports) + if err != nil { + return 0, err + } + + totalDailyUsageInUSDMillicent += totalAmountBilledInUSDMillicent + } + + // Update the last calculation time for this user in the database + if err := svc.userRepo.UpdateUserLastCalcTime(userID, now); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to update last calculation time for user %d", userID) + } + + return totalDailyUsageInUSDMillicent, nil +} + +func (svc *BillingService) calculateTotalUsageOfReportsInUSDMillicent(reports []billing.Report) (uint64, error) { + var totalAmountBilledInUSDMillicent uint64 + for _, report := range reports { + amountInTFT, err := removeDiscountFromReport(&report) + if err != nil { + return 0, err + } + + amountInUSDMillicent, err := svc.fromTFTtoUSDMillicent(amountInTFT, report) + if err != nil { + return 0, err + } + + totalAmountBilledInUSDMillicent += amountInUSDMillicent + } + + return totalAmountBilledInUSDMillicent, nil +} + +func (svc *BillingService) fromTFTtoUSDMillicent(amount uint64, report billing.Report) (uint64, error) { + price, err := svc.getBillingRateAt(report) + if err != nil { + return 0, err + } + + usdMillicentBalance := uint64(math.Round((float64(amount) / TFTUnitFactor) * float64(price))) + return usdMillicentBalance, nil +} + +func (svc *BillingService) getBillingRateAt(report billing.Report) (float64, error) { + block_duration := 6 // in seconds + now := time.Now().Unix() + + reportTimestamp, err := strconv.ParseInt(report.Timestamp, 10, 64) + if err != nil { + return 0, err + } + + timeBetweenNowAndReport := now - reportTimestamp // seconds + + // Calculate number of blocks since report + numberOfBlocks := math.Round(float64(timeBetweenNowAndReport) / float64(block_duration)) + + nowBlock, err := svc.gridClient.GetCurrentHeight() + if err != nil { + return 0, err + } + reportBlock := nowBlock - uint32(numberOfBlocks) + + tftPrice, err := svc.gridClient.GetTFTBillingRateAt(uint64(reportBlock)) + if err != nil { + return 0, err + } + + return float64(tftPrice), nil +} + +func removeDiscountFromReport(report *billing.Report) (uint64, error) { + discountPackage := getDiscountPackage(Discount(report.DiscountReceived)) + + amountBilled, err := strconv.ParseInt(report.AmountBilled, 10, 64) + if err != nil { + return 0, err + } + + amountBilledWithNoDiscount := float64(amountBilled) / float64(1-discountPackage.Discount/100) + return uint64(amountBilledWithNoDiscount), nil +} + +func getDiscountPackage(discountInput Discount) DiscountPackage { + oneDayMargin := 1.0 / 30.0 + + discountPackages := map[Discount]DiscountPackage{ + "none": { + DurationInMonth: oneDayMargin * 3, + Discount: 0, + }, + "default": { + DurationInMonth: 1.5 + oneDayMargin, + Discount: 20, + }, + "bronze": { + DurationInMonth: 3 + oneDayMargin, + Discount: 30, + }, + "silver": { + DurationInMonth: 6 + oneDayMargin, + Discount: 40, + }, + "gold": { + DurationInMonth: 10 + oneDayMargin, + Discount: 60, + }, + } + + return discountPackages[discountInput] +} diff --git a/backend/internal/core/services/deployment_service.go b/backend/internal/core/services/deployment_service.go index f7d3e90b0..0fd1e2c35 100644 --- a/backend/internal/core/services/deployment_service.go +++ b/backend/internal/core/services/deployment_service.go @@ -20,8 +20,9 @@ import ( ) type DeploymentService struct { - clusterRepo models.ClusterRepository - userRepo models.UserRepository + clusterRepo models.ClusterRepository + userRepo models.UserRepository + contractsRepo models.ContractDataRepository appCtx context.Context ewfEngine *ewf.Engine @@ -36,12 +37,14 @@ type DeploymentService struct { func NewDeploymentService(appCtx context.Context, clusterRepo models.ClusterRepository, userRepo models.UserRepository, - userNodesRepo models.UserNodesRepository, ewfEngine *ewf.Engine, + contractsRepo models.ContractDataRepository, + ewfEngine *ewf.Engine, debug bool, sshPublicKey, sshPrivateKeyPath, systemNetwork string, ) DeploymentService { return DeploymentService{ - clusterRepo: clusterRepo, - userRepo: userRepo, + clusterRepo: clusterRepo, + userRepo: userRepo, + contractsRepo: contractsRepo, appCtx: appCtx, ewfEngine: ewfEngine, @@ -164,7 +167,7 @@ func (svc *DeploymentService) GetClusterKubeconfig(ctx context.Context, cluster } cluster.Kubeconfig = kubeconfig - if err = svc.clusterRepo.UpdateCluster(cluster); err != nil { + if err = svc.clusterRepo.UpdateCluster(svc.contractsRepo, cluster); err != nil { telemetry.RecordError(span, err) return "", err } @@ -172,8 +175,12 @@ func (svc *DeploymentService) GetClusterKubeconfig(ctx context.Context, cluster return kubeconfig, nil } +func (svc *DeploymentService) GetUserByID(userID int) (models.User, error) { + return svc.userRepo.GetUserByID(userID) +} + func (svc *DeploymentService) GetClientConfig(userID int) (statemanager.ClientConfig, error) { - user, err := svc.userRepo.GetUserByID(userID) + user, err := svc.GetUserByID(userID) if err != nil { return statemanager.ClientConfig{}, fmt.Errorf("failed to get user: %v", err) } diff --git a/backend/internal/core/services/deployment_service_test.go b/backend/internal/core/services/deployment_service_test.go index a084e5c3b..109a718dc 100644 --- a/backend/internal/core/services/deployment_service_test.go +++ b/backend/internal/core/services/deployment_service_test.go @@ -36,8 +36,8 @@ func (m *mockClusterRepo) GetClusterByName(userID int, projectName string) (mode return args.Get(0).(models.Cluster), args.Error(1) } -func (m *mockClusterRepo) UpdateCluster(cluster *models.Cluster) error { - args := m.Called(cluster) +func (m *mockClusterRepo) UpdateCluster(contractsRepo models.ContractDataRepository, cluster *models.Cluster) error { + args := m.Called(contractsRepo, cluster) return args.Error(0) } @@ -46,8 +46,8 @@ func (m *mockClusterRepo) DeleteCluster(userID int, projectName string) error { return args.Error(0) } -func (m *mockClusterRepo) DeleteAllUserClusters(userID int) error { - args := m.Called(userID) +func (m *mockClusterRepo) DeleteAllUserClusters(contractsRepo models.ContractDataRepository, userID int) error { + args := m.Called(contractsRepo, userID) return args.Error(0) } @@ -78,8 +78,9 @@ func TestDeploymentService_GetClusterByName_Success(t *testing.T) { mockClusterRepo.On("GetClusterByName", 1, "test-project").Return(cluster, nil) service := DeploymentService{ - clusterRepo: mockClusterRepo, - userRepo: mockUserRepo, + clusterRepo: mockClusterRepo, + userRepo: mockUserRepo, + contractsRepo: new(mockContractDataRepo), } result, err := service.GetClusterByName(1, "test-project") @@ -97,8 +98,9 @@ func TestDeploymentService_GetClusterByName_NotFound(t *testing.T) { mockClusterRepo.On("GetClusterByName", 1, "nonexistent").Return(models.Cluster{}, fmt.Errorf("cluster not found")) service := DeploymentService{ - clusterRepo: mockClusterRepo, - userRepo: mockUserRepo, + clusterRepo: mockClusterRepo, + userRepo: mockUserRepo, + contractsRepo: new(mockContractDataRepo), } _, err := service.GetClusterByName(1, "nonexistent") @@ -128,8 +130,9 @@ func TestDeploymentService_ListUserClusters_Success(t *testing.T) { mockClusterRepo.On("ListUserClusters", 1).Return(clusters, nil) service := DeploymentService{ - clusterRepo: mockClusterRepo, - userRepo: mockUserRepo, + clusterRepo: mockClusterRepo, + userRepo: mockUserRepo, + contractsRepo: new(mockContractDataRepo), } result, err := service.ListUserClusters(1) @@ -147,8 +150,9 @@ func TestDeploymentService_ListUserClusters_Empty(t *testing.T) { mockClusterRepo.On("ListUserClusters", 999).Return([]models.Cluster{}, nil) service := DeploymentService{ - clusterRepo: mockClusterRepo, - userRepo: mockUserRepo, + clusterRepo: mockClusterRepo, + userRepo: mockUserRepo, + contractsRepo: new(mockContractDataRepo), } result, err := service.ListUserClusters(999) @@ -165,8 +169,9 @@ func TestDeploymentService_ListUserClusters_Error(t *testing.T) { mockClusterRepo.On("ListUserClusters", 1).Return(nil, fmt.Errorf("database error")) service := DeploymentService{ - clusterRepo: mockClusterRepo, - userRepo: mockUserRepo, + clusterRepo: mockClusterRepo, + userRepo: mockUserRepo, + contractsRepo: new(mockContractDataRepo), } _, err := service.ListUserClusters(1) diff --git a/backend/internal/core/services/node_service.go b/backend/internal/core/services/node_service.go index 49ee228ef..180717fae 100644 --- a/backend/internal/core/services/node_service.go +++ b/backend/internal/core/services/node_service.go @@ -23,8 +23,8 @@ var Zos3NodeFeatures = []string{ } type NodeService struct { - nodesRepo models.UserNodesRepository - userRepo models.UserRepository + contractsRepo models.ContractDataRepository + userRepo models.UserRepository appCtx context.Context ewfEngine *ewf.Engine @@ -33,16 +33,16 @@ type NodeService struct { } func NewNodeService( - userNodesRepo models.UserNodesRepository, userRepo models.UserRepository, + contractsRepo models.ContractDataRepository, userRepo models.UserRepository, appCtx context.Context, ewfEngine *ewf.Engine, gridClient gridclient.GridClient, ) NodeService { return NodeService{ - nodesRepo: userNodesRepo, - userRepo: userRepo, - appCtx: appCtx, - ewfEngine: ewfEngine, - gridClient: gridClient, - tracer: telemetry.NewServiceTracer("node_service"), + contractsRepo: contractsRepo, + userRepo: userRepo, + appCtx: appCtx, + ewfEngine: ewfEngine, + gridClient: gridClient, + tracer: telemetry.NewServiceTracer("node_service"), } } @@ -82,8 +82,8 @@ func (svc *NodeService) GetUserByID(userID int) (models.User, error) { return svc.userRepo.GetUserByID(userID) } -func (svc *NodeService) GetUserNodeByNodeID(nodeID uint32) (models.UserNodes, error) { - return svc.nodesRepo.GetUserNodeByNodeID(uint64(nodeID)) +func (svc *NodeService) GetUserNodeByNodeID(nodeID uint32) (models.UserContractData, error) { + return svc.contractsRepo.GetUserNodeByNodeID(uint64(nodeID)) } func (svc *NodeService) CheckUserBalanceForOneHour(ctx context.Context, userMnemonic string, userDebt uint64, nodePriceUsd float64) error { @@ -117,8 +117,8 @@ func (svc *NodeService) CheckUserBalanceForOneHour(ctx context.Context, userMnem return nil } -func (svc *NodeService) GetUserNodeByContractID(contractID uint64) (models.UserNodes, error) { - return svc.nodesRepo.GetUserNodeByContractID(contractID) +func (svc *NodeService) GetUserNodeByContractID(contractID uint64) (models.UserContractData, error) { + return svc.contractsRepo.GetUserNodeByContractID(contractID) } func (svc *NodeService) GetTwinIDFromUserID(ctx context.Context, userID int) (uint64, error) { diff --git a/backend/internal/core/services/node_service_test.go b/backend/internal/core/services/node_service_test.go index 7eecc3443..e70a77d07 100644 --- a/backend/internal/core/services/node_service_test.go +++ b/backend/internal/core/services/node_service_test.go @@ -12,10 +12,10 @@ import ( // Test 1: NodeService - GetUserNodeByNodeID SUCCESS func TestNodeService_GetUserNodeByNodeID_Success(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) - node := models.UserNodes{ + node := models.UserContractData{ ID: 1, UserID: 1, NodeID: 100, @@ -25,8 +25,8 @@ func TestNodeService_GetUserNodeByNodeID_Success(t *testing.T) { mockNodesRepo.On("GetUserNodeByNodeID", uint64(100)).Return(node, nil) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } result, err := service.GetUserNodeByNodeID(100) @@ -37,14 +37,14 @@ func TestNodeService_GetUserNodeByNodeID_Success(t *testing.T) { // Test 2: NodeService - GetUserNodeByNodeID NOT FOUND func TestNodeService_GetUserNodeByNodeID_NotFound(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) - mockNodesRepo.On("GetUserNodeByNodeID", uint64(999)).Return(models.UserNodes{}, fmt.Errorf("node not found")) + mockNodesRepo.On("GetUserNodeByNodeID", uint64(999)).Return(models.UserContractData{}, fmt.Errorf("node not found")) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } _, err := service.GetUserNodeByNodeID(999) @@ -55,10 +55,10 @@ func TestNodeService_GetUserNodeByNodeID_NotFound(t *testing.T) { // Test 3: NodeService - GetUserNodeByContractID SUCCESS func TestNodeService_GetUserNodeByContractID_Success(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) - node := models.UserNodes{ + node := models.UserContractData{ ID: 1, UserID: 1, NodeID: 100, @@ -68,8 +68,8 @@ func TestNodeService_GetUserNodeByContractID_Success(t *testing.T) { mockNodesRepo.On("GetUserNodeByContractID", uint64(123)).Return(node, nil) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } result, err := service.GetUserNodeByContractID(123) @@ -80,14 +80,14 @@ func TestNodeService_GetUserNodeByContractID_Success(t *testing.T) { // Test 4: NodeService - GetUserNodeByContractID ERROR func TestNodeService_GetUserNodeByContractID_Error(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) - mockNodesRepo.On("GetUserNodeByContractID", uint64(999)).Return(models.UserNodes{}, fmt.Errorf("contract error")) + mockNodesRepo.On("GetUserNodeByContractID", uint64(999)).Return(models.UserContractData{}, fmt.Errorf("contract error")) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } _, err := service.GetUserNodeByContractID(999) @@ -98,7 +98,7 @@ func TestNodeService_GetUserNodeByContractID_Error(t *testing.T) { // Test 5: NodeService - GetUserByID SUCCESS func TestNodeService_GetUserByID_Success(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) user := models.User{ @@ -110,8 +110,8 @@ func TestNodeService_GetUserByID_Success(t *testing.T) { mockUserRepo.On("GetUserByID", 1).Return(user, nil) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } result, err := service.GetUserByID(1) @@ -122,14 +122,14 @@ func TestNodeService_GetUserByID_Success(t *testing.T) { // Test 6: NodeService - GetUserByID NOT FOUND func TestNodeService_GetUserByID_NotFound(t *testing.T) { - mockNodesRepo := new(mockUserNodesRepo) + mockNodesRepo := new(mockContractDataRepo) mockUserRepo := new(mockUserRepo) mockUserRepo.On("GetUserByID", 999).Return(models.User{}, fmt.Errorf("user not found")) service := NodeService{ - nodesRepo: mockNodesRepo, - userRepo: mockUserRepo, + contractsRepo: mockNodesRepo, + userRepo: mockUserRepo, } _, err := service.GetUserByID(999) diff --git a/backend/internal/core/services/user_service.go b/backend/internal/core/services/user_service.go index 74f801666..297cd0675 100644 --- a/backend/internal/core/services/user_service.go +++ b/backend/internal/core/services/user_service.go @@ -20,7 +20,6 @@ import ( type UserService struct { userRepo models.UserRepository voucherRepo models.VoucherRepository - prRepo models.PendingRecordRepository appCtx context.Context gridClient gridclient.GridClient @@ -36,7 +35,6 @@ type UserService struct { func NewUserService(appCtx context.Context, userRepo models.UserRepository, voucherRepo models.VoucherRepository, - pendingRecordRepo models.PendingRecordRepository, gridClient gridclient.GridClient, ewfEngine *ewf.Engine, kycClient *kyc.KYCClient, @@ -47,7 +45,6 @@ func NewUserService(appCtx context.Context, return UserService{ userRepo: userRepo, voucherRepo: voucherRepo, - prRepo: pendingRecordRepo, appCtx: appCtx, gridClient: gridClient, @@ -60,9 +57,12 @@ func NewUserService(appCtx context.Context, } } -type UserWithPendingBalance struct { +type UserWithBalancesInUSD struct { models.User - PendingBalanceUSD float64 `json:"pending_balance_usd"` + CreditCardBalanceInUSD float64 `json:"credit_card_balance_in_usd"` + CreditedBalanceInUSD float64 `json:"credited_balance_in_usd"` + DebtInUSD float64 `json:"debt_in_usd"` + BalanceInTFT float64 `json:"balance_in_tft,omitempty"` } func (svc *UserService) GetUserByEmail(email string) (models.User, error) { @@ -73,70 +73,20 @@ func (svc *UserService) GetUserByID(userID int) (models.User, error) { return svc.userRepo.GetUserByID(userID) } -func (svc *UserService) GetUserWithPendingBalance(userID int) (UserWithPendingBalance, error) { +func (svc *UserService) GetUserWithBalancesInUSD(userID int) (UserWithBalancesInUSD, error) { user, err := svc.GetUserByID(userID) if err != nil { - return UserWithPendingBalance{}, err + return UserWithBalancesInUSD{}, err } - usdMillicentPendingAmount, err := svc.GetUserPendingBalanceInUSDMillicent(userID) - if err != nil { - return UserWithPendingBalance{}, err - } - - return UserWithPendingBalance{ - User: user, - PendingBalanceUSD: gridclient.FromUSDMilliCentToUSD(usdMillicentPendingAmount), + return UserWithBalancesInUSD{ + User: user, + CreditedBalanceInUSD: gridclient.FromUSDMilliCentToUSD(user.CreditedBalance), + CreditCardBalanceInUSD: gridclient.FromUSDMilliCentToUSD(user.CreditCardBalance), + DebtInUSD: gridclient.FromUSDMilliCentToUSD(user.Debt), }, nil } -func (svc *UserService) GetUserPendingBalanceInUSDMillicent(userID int) (uint64, error) { - pendingRecords, err := svc.prRepo.ListUserPendingRecords(userID) - if err != nil { - return 0, err - } - - var tftPendingAmount uint64 - for _, record := range pendingRecords { - tftPendingAmount += record.TFTAmount - record.TransferredTFTAmount - } - - usdMillicentPendingAmount, err := svc.gridClient.FromTFTtoUSDMillicent(tftPendingAmount) - if err != nil { - return 0, err - } - - return usdMillicentPendingAmount, nil -} - -func (svc *UserService) ListUserPendingRecordsWithUSDAmounts(userID int) ([]PendingRecordsWithUSDAmounts, error) { - pendingRecords, err := svc.prRepo.ListUserPendingRecords(userID) - if err != nil { - return nil, err - } - - var pendingRecordsWithUSDAmounts []PendingRecordsWithUSDAmounts - for _, record := range pendingRecords { - usdAmount, err := svc.gridClient.FromTFTtoUSDMillicent(record.TFTAmount) - if err != nil { - return nil, err - } - - usdTransferredAmount, err := svc.gridClient.FromTFTtoUSDMillicent(record.TransferredTFTAmount) - if err != nil { - return nil, err - } - - pendingRecordsWithUSDAmounts = append(pendingRecordsWithUSDAmounts, PendingRecordsWithUSDAmounts{ - PendingRecord: record, - USDAmount: gridclient.FromUSDMilliCentToUSD(usdAmount), - TransferredUSDAmount: gridclient.FromUSDMilliCentToUSD(usdTransferredAmount), - }) - } - - return pendingRecordsWithUSDAmounts, nil -} - func (svc *UserService) ListRemainingWorkflowsByUserID(userID int) ([]*ewf.Workflow, error) { records, err := svc.userRepo.ListRemainingWorkflowsByUserID(userID) if err != nil { @@ -267,36 +217,6 @@ func (svc *UserService) AsyncStripeChargeBalance(userID int, userStripeCustomerI "payment_method_id": paymentMethodID, "amount": gridclient.FromUSDToUSDMillicent(requestAmount), "username": username, - "transfer_mode": models.ChargeBalanceMode, - "config": map[string]interface{}{ - "user_id": userID, - "mnemonic": userMnemonic, - }, - } - - if err = persistence.SetStateUserID(&wf, userID); err != nil { - return "", err - } - - err = svc.ewfEngine.Run(svc.appCtx, wf, ewf.WithAsync()) - return wf.UUID, err -} - -func (svc *UserService) AsyncRedeemVoucher(userID int, voucherValue float64, userMnemonic, userUsername, voucherCode string) (string, error) { - err := svc.voucherRepo.RedeemVoucher(userID, userUsername, voucherCode) - if err != nil { - return "", err - } - - wf, err := svc.ewfEngine.NewWorkflow(workflows.WorkflowRedeemVoucher, ewf.WithDisplayName(fmt.Sprintf("Redeem voucher %s", voucherCode))) - if err != nil { - return "", err - } - - wf.State = map[string]interface{}{ - "amount": gridclient.FromUSDToUSDMillicent(voucherValue), - "username": userUsername, - "transfer_mode": models.RedeemVoucherMode, "config": map[string]interface{}{ "user_id": userID, "mnemonic": userMnemonic, diff --git a/backend/internal/core/services/user_service_test.go b/backend/internal/core/services/user_service_test.go index 6f84fac3a..918e57d92 100644 --- a/backend/internal/core/services/user_service_test.go +++ b/backend/internal/core/services/user_service_test.go @@ -145,43 +145,7 @@ func (m *mockVoucherRepo) RedeemVoucher(userID int, username, code string) error return args.Error(0) } -type mockPendingRecordRepo struct { - mock.Mock -} - -func (m *mockPendingRecordRepo) CreatePendingRecord(record *models.PendingRecord) error { - args := m.Called(record) - return args.Error(0) -} - -func (m *mockPendingRecordRepo) ListAllPendingRecords() ([]models.PendingRecord, error) { - args := m.Called() - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]models.PendingRecord), args.Error(1) -} - -func (m *mockPendingRecordRepo) ListOnlyPendingRecords() ([]models.PendingRecord, error) { - args := m.Called() - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]models.PendingRecord), args.Error(1) -} - -func (m *mockPendingRecordRepo) ListUserPendingRecords(userID int) ([]models.PendingRecord, error) { - args := m.Called(userID) - if args.Get(0) == nil { - return nil, args.Error(1) - } - return args.Get(0).([]models.PendingRecord), args.Error(1) -} - -func (m *mockPendingRecordRepo) UpdatePendingRecordTransferredAmount(id int, amount uint64) error { - args := m.Called(id, amount) - return args.Error(0) -} +// Using mockTransferRecordRepo from admin_service_test.go // ============================================================================ // TESTS FOR METHODS THAT DON'T REQUIRE EXTERNAL DEPENDENCIES @@ -191,7 +155,6 @@ func (m *mockPendingRecordRepo) UpdatePendingRecordTransferredAmount(id int, amo func TestUserService_GetUserByEmail_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) expectedUser := models.User{ ID: 1, @@ -204,7 +167,7 @@ func TestUserService_GetUserByEmail_Success(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -219,7 +182,6 @@ func TestUserService_GetUserByEmail_Success(t *testing.T) { func TestUserService_GetUserByEmail_NotFound(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) mockUserRepo.On("GetUserByEmail", "invalid@example.com"). Return(nil, fmt.Errorf("user not found")) @@ -227,7 +189,7 @@ func TestUserService_GetUserByEmail_NotFound(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -241,14 +203,13 @@ func TestUserService_GetUserByEmail_NotFound(t *testing.T) { func TestUserService_CreateSSHKey_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) mockUserRepo.On("CreateSSHKey", mock.AnythingOfType("*models.SSHKey")).Return(nil) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -264,14 +225,13 @@ func TestUserService_CreateSSHKey_Success(t *testing.T) { func TestUserService_DeleteSSHKey_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) mockUserRepo.On("DeleteSSHKey", 5, 1).Return("my-key", nil) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -286,7 +246,6 @@ func TestUserService_DeleteSSHKey_Success(t *testing.T) { func TestUserService_IsVerificationCodeExpired_Expired(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) // Code from 20 minutes ago, timeout is 5 minutes oldTime := time.Now().Add(-20 * time.Minute) @@ -294,7 +253,7 @@ func TestUserService_IsVerificationCodeExpired_Expired(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, // 5 minute timeout ) @@ -307,7 +266,6 @@ func TestUserService_IsVerificationCodeExpired_Expired(t *testing.T) { func TestUserService_IsVerificationCodeExpired_NotExpired(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) // Code from 2 minutes ago, timeout is 5 minutes recentTime := time.Now().Add(-2 * time.Minute) @@ -315,7 +273,7 @@ func TestUserService_IsVerificationCodeExpired_NotExpired(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, // 5 minute timeout ) @@ -328,12 +286,11 @@ func TestUserService_IsVerificationCodeExpired_NotExpired(t *testing.T) { func TestUserService_IsSystemAdmin_True(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{"admin@example.com", "superuser@example.com"}, ) @@ -346,12 +303,11 @@ func TestUserService_IsSystemAdmin_True(t *testing.T) { func TestUserService_IsSystemAdmin_False(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{"admin@example.com"}, ) @@ -364,7 +320,6 @@ func TestUserService_IsSystemAdmin_False(t *testing.T) { func TestUserService_GetVoucherByCode_Success(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) expectedVoucher := models.Voucher{ ID: 1, @@ -376,7 +331,7 @@ func TestUserService_GetVoucherByCode_Success(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -390,7 +345,6 @@ func TestUserService_GetVoucherByCode_Success(t *testing.T) { func TestUserService_GetVoucherByCode_NotFound(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) mockVoucherRepo.On("GetVoucherByCode", "INVALID"). Return(nil, fmt.Errorf("voucher not found")) @@ -398,7 +352,7 @@ func TestUserService_GetVoucherByCode_NotFound(t *testing.T) { var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -412,12 +366,11 @@ func TestUserService_GetVoucherByCode_NotFound(t *testing.T) { func TestUserService_GenerateRandomCode_ValidRange(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 5, []string{}, ) @@ -432,12 +385,11 @@ func TestUserService_GenerateRandomCode_ValidRange(t *testing.T) { func TestUserService_CodeTimeoutInMinutes(t *testing.T) { mockUserRepo := new(mockUserRepo) mockVoucherRepo := new(mockVoucherRepo) - mockPRRepo := new(mockPendingRecordRepo) var gridClient gridclient.GridClient service := NewUserService( context.Background(), - mockUserRepo, mockVoucherRepo, mockPRRepo, gridClient, + mockUserRepo, mockVoucherRepo, gridClient, nil, nil, nil, 15, []string{}, ) diff --git a/backend/internal/core/services/workers_service.go b/backend/internal/core/services/workers_service.go index e3ebb2fcc..0dd07104a 100644 --- a/backend/internal/core/services/workers_service.go +++ b/backend/internal/core/services/workers_service.go @@ -17,19 +17,33 @@ import ( "sync" "time" + "github.com/cenkalti/backoff" "github.com/hashicorp/go-multierror" "github.com/threefoldtech/tfgrid-sdk-go/grid-client/graphql" "github.com/xmonader/ewf" ) +const ( + // UnitFactor represents the smallest unit conversion factor for both USD and TFT + TFTUnitFactor = 1e7 + transferFees = 0.01 * TFTUnitFactor // 0.01 TFT + nodeCertified = "Certified" + + zeroTFTBalanceValue = 0.05 * TFTUnitFactor // 0.05 TFT + defaultPricingPolicyID = uint32(1) + + TrackingDebtPeriod = time.Hour + reties = 3 +) + type WorkerService struct { ctx context.Context - userRepo models.UserRepository - nodesRepo models.UserNodesRepository - invoicesRepo models.InvoiceRepository - clusterRepo models.ClusterRepository - pendingRecordsRepo models.PendingRecordRepository + userRepo models.UserRepository + contractsRepo models.ContractDataRepository + invoicesRepo models.InvoiceRepository + clusterRepo models.ClusterRepository + transferRecordsRepo models.TransferRecordRepository mailService mailservice.MailService graphql graphql.GraphQl @@ -47,31 +61,34 @@ type WorkerService struct { reservedNodeHealthCheckIntervalInHours int reservedNodeHealthCheckTimeoutInMinutes int reservedNodeHealthCheckWorkersNum int - monitorBalanceIntervalInMinutes int + settleTransferRecordsIntervalInMinutes int notifyAdminsForPendingRecordsInHours int + minimumTFTAmountInWallet int + appliedDiscount Discount usersBalanceCheckIntervalInHours int } func NewWorkersService( - ctx context.Context, userRepo models.UserRepository, nodesRepo models.UserNodesRepository, - invoicesRepo models.InvoiceRepository, clusterRepo models.ClusterRepository, pendingRecordsRepo models.PendingRecordRepository, + ctx context.Context, userRepo models.UserRepository, contractsRepo models.ContractDataRepository, + invoicesRepo models.InvoiceRepository, clusterRepo models.ClusterRepository, transferRecordsRepo models.TransferRecordRepository, mailService mailservice.MailService, gridClient gridclient.GridClient, ewfEngine *ewf.Engine, notificationDispatcher *notification.NotificationDispatcher, graphql graphql.GraphQl, firesquidClient graphql.GraphQl, invoiceCompanyData config.InvoiceCompanyData, systemMnemonic, currency string, clusterHealthCheckIntervalInHours, reservedNodeHealthCheckIntervalInHours, reservedNodeHealthCheckTimeoutInMinutes, reservedNodeHealthCheckWorkersNum, - monitorBalanceIntervalInMinutes, notifyAdminsForPendingRecordsInHours int, - usersBalanceCheckIntervalInHours int, + settleTransferRecordsIntervalInMinutes, notifyAdminsForPendingRecordsInHours, + minimumTFTAmountInWallet int, appliedDiscount Discount, + usersBalanceCheckIntervalInHours, checkUserDebtIntervalInHours int, ) WorkerService { return WorkerService{ - ctx: ctx, - userRepo: userRepo, - nodesRepo: nodesRepo, - invoicesRepo: invoicesRepo, - clusterRepo: clusterRepo, - pendingRecordsRepo: pendingRecordsRepo, + ctx: ctx, + userRepo: userRepo, + contractsRepo: contractsRepo, + invoicesRepo: invoicesRepo, + clusterRepo: clusterRepo, + transferRecordsRepo: transferRecordsRepo, mailService: mailService, notificationDispatcher: notificationDispatcher, @@ -89,9 +106,12 @@ func NewWorkersService( reservedNodeHealthCheckIntervalInHours: reservedNodeHealthCheckIntervalInHours, reservedNodeHealthCheckTimeoutInMinutes: reservedNodeHealthCheckTimeoutInMinutes, reservedNodeHealthCheckWorkersNum: reservedNodeHealthCheckWorkersNum, - monitorBalanceIntervalInMinutes: monitorBalanceIntervalInMinutes, + settleTransferRecordsIntervalInMinutes: settleTransferRecordsIntervalInMinutes, notifyAdminsForPendingRecordsInHours: notifyAdminsForPendingRecordsInHours, - usersBalanceCheckIntervalInHours: usersBalanceCheckIntervalInHours, + + minimumTFTAmountInWallet: minimumTFTAmountInWallet, + appliedDiscount: appliedDiscount, + usersBalanceCheckIntervalInHours: usersBalanceCheckIntervalInHours, } } @@ -108,12 +128,20 @@ func (svc WorkerService) ListAllClusters() ([]models.Cluster, error) { return svc.clusterRepo.ListAllClusters() } -func (svc WorkerService) ListAllReservedNodes() ([]models.UserNodes, error) { - return svc.nodesRepo.ListAllReservedNodes() +func (svc WorkerService) ListUserClusters(userID int) ([]models.Cluster, error) { + return svc.clusterRepo.ListUserClusters(userID) +} + +func (svc WorkerService) ListAllReservedNodes() ([]models.UserContractData, error) { + return svc.contractsRepo.ListAllReservedNodes() } -func (svc WorkerService) ListOnlyPendingRecords() ([]models.PendingRecord, error) { - return svc.pendingRecordsRepo.ListOnlyPendingRecords() +func (svc WorkerService) ListFailedTransferRecords() ([]models.TransferRecord, error) { + return svc.transferRecordsRepo.ListFailedTransferRecords() +} + +func (svc WorkerService) ListPendingTransferRecords() ([]models.TransferRecord, error) { + return svc.transferRecordsRepo.ListPendingTransferRecords() } func (svc WorkerService) GetClusterHealthCheckInterval() time.Duration { @@ -124,12 +152,12 @@ func (svc WorkerService) GetReservedNodeHealthCheckInterval() time.Duration { return time.Duration(svc.reservedNodeHealthCheckIntervalInHours) * time.Hour } -func (svc WorkerService) GetCheckUserDebtInterval() time.Duration { - return time.Duration(svc.checkUserDebtIntervalInHours) * time.Hour +func (svc WorkerService) GetSettleTransferRecordsInterval() time.Duration { + return time.Duration(svc.settleTransferRecordsIntervalInMinutes) * time.Minute } -func (svc WorkerService) GetMonitorBalanceInterval() time.Duration { - return time.Duration(svc.monitorBalanceIntervalInMinutes) * time.Minute +func (svc WorkerService) GetCheckUserDebtInterval() time.Duration { + return time.Duration(svc.checkUserDebtIntervalInHours) * time.Hour } func (svc WorkerService) GetNotifyAdminsForPendingRecordsInterval() time.Duration { @@ -140,43 +168,40 @@ func (svc WorkerService) GetUsersBalanceCheckInterval() time.Duration { return time.Duration(svc.usersBalanceCheckIntervalInHours) * time.Hour } -func (svc WorkerService) CreateUserInvoice(user models.User) error { - records, err := svc.nodesRepo.ListUserNodes(user.ID) +func (svc WorkerService) CreateUserInvoice(BillingService BillingService, user models.User) error { + now := time.Now() + timeMonthAgo := now.AddDate(0, -1, 0) + + contracts, err := svc.contractsRepo.ListAllContractsInPeriod(user.ID, timeMonthAgo, now) if err != nil { return err } - if len(records) == 0 { + if len(contracts) == 0 { return nil } - now := time.Now() - var invoiceItems []models.NodeItem var totalInvoiceCostUSD float64 - for _, record := range records { - billReports, err := billing.ListContractBillReportsPerMonth(svc.graphql, record.ContractID, now) + for _, contract := range contracts { + billReports, err := billing.ListContractBillReports(svc.graphql, contract.ContractID, timeMonthAgo, now) if err != nil { return err } - totalAmountTFT, err := billing.AmountBilledPerMonth(billReports) - if err != nil { - return err - } - totalAmountUSDMillicent, err := svc.gridClient.FromTFTtoUSDMillicent(totalAmountTFT) + totalAmountBilledInUSDMillicent, err := BillingService.calculateTotalUsageOfReportsInUSDMillicent(billReports.Reports) if err != nil { return err } + rentRecordStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()) - if record.CreatedAt.After(rentRecordStart) { - rentRecordStart = record.CreatedAt + if contract.CreatedAt.After(rentRecordStart) { + rentRecordStart = contract.CreatedAt } var totalHours int - cancellationDate, err := billing.GetRentContractCancellationDate(svc.firesquidClient, record.ContractID) - + cancellationDate, err := billing.GetRentContractCancellationDate(svc.firesquidClient, contract.ContractID) if errors.Is(err, billing.ErrorEventsNotFound) { totalHours = getHoursOfGivenPeriod(rentRecordStart, time.Now()) } else if err != nil { @@ -185,11 +210,11 @@ func (svc WorkerService) CreateUserInvoice(user models.User) error { totalHours = getHoursOfGivenPeriod(rentRecordStart, cancellationDate) } - totalAmountUSD := gridclient.FromUSDMilliCentToUSD(totalAmountUSDMillicent) + totalAmountUSD := gridclient.FromUSDMilliCentToUSD(totalAmountBilledInUSDMillicent) invoiceItems = append(invoiceItems, models.NodeItem{ - NodeID: record.NodeID, - ContractID: record.ContractID, + NodeID: contract.NodeID, + ContractID: contract.ContractID, RentCreatedAt: rentRecordStart, PeriodInHours: float64(totalHours), Cost: totalAmountUSD, @@ -233,16 +258,18 @@ func (svc WorkerService) UpdateUserDebt() error { } for _, user := range users { - userNodes, err := svc.nodesRepo.ListUserNodes(user.ID) + userContracts, err := svc.contractsRepo.ListAllContractsInPeriod(user.ID, time.Now().Add(-TrackingDebtPeriod), time.Now()) if err != nil { - logger.ForOperation("debt_tracker", "list_user_nodes").Error().Err(err).Msg("Failed to list user nodes") + logger.ForOperation("debt_tracker", "list_user_contracts").Error().Err(err).Msg("Failed to list user contracts") continue } - contractIDs := make([]uint64, len(userNodes)) - for i, node := range userNodes { - contractIDs[i] = node.ContractID + + var contractIDs []uint64 + for _, contract := range userContracts { + contractIDs = append(contractIDs, contract.ContractID) } - userDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs, time.Hour) + + userDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs) if err != nil { logger.ForOperation("debt_tracker", "calculate_debt").Error().Err(err).Msg("Failed to calculate user debt") continue @@ -258,20 +285,30 @@ func (svc WorkerService) UpdateUserDebt() error { return nil } -func (svc WorkerService) calculateDebt(userMnemonic string, contractIDs []uint64, debtPeriod time.Duration) (uint64, error) { +func (svc WorkerService) calculateDebt(userMnemonic string, contractIDs []uint64) (uint64, error) { + if len(contractIDs) == 0 { + return 0, nil + } calculatorClient, err := svc.gridClient.NewCalculator(userMnemonic) if err != nil { - return 0, fmt.Errorf("failed to create new calculator: %w", err) + return 0, fmt.Errorf("failed to create calculator: %w", err) } var totalDebt int64 for _, contractID := range contractIDs { - debt, err := calculatorClient.CalculateContractOverdue(contractID, debtPeriod) - if err != nil { + var debt int64 + if err = backoff.Retry(func() error { + debt, err = calculatorClient.CalculateContractOverdue(contractID, time.Hour) + return err + }, backoff.WithMaxRetries( + backoff.NewExponentialBackOff(), + reties, + )); err != nil { logger.ForOperation("debt_tracker", "calc_overdue").Error().Err(err).Msg("Failed to calculate contract overdue") continue } + totalDebt += debt } @@ -284,7 +321,7 @@ func (svc WorkerService) calculateDebt(userMnemonic string, contractIDs []uint64 } // checkNodesWithWorkerPool uses a worker pool to check node health concurrently -func (svc WorkerService) CheckNodesWithWorkerPool(reservedNodes []models.UserNodes) { +func (svc WorkerService) CheckNodesWithWorkerPool(reservedNodes []models.UserContractData) { timeout := time.Duration(svc.reservedNodeHealthCheckTimeoutInMinutes) * time.Minute ctx, cancel := context.WithTimeout(context.Background(), timeout) @@ -295,7 +332,7 @@ func (svc WorkerService) CheckNodesWithWorkerPool(reservedNodes []models.UserNod workerCount = len(reservedNodes) } - jobs := make(chan models.UserNodes, len(reservedNodes)) + jobs := make(chan models.UserContractData, len(reservedNodes)) results := make(chan NodeHealthResult, len(reservedNodes)) var wg sync.WaitGroup @@ -354,7 +391,7 @@ func (svc WorkerService) CheckNodesWithWorkerPool(reservedNodes []models.UserNod } } -func (svc WorkerService) healthCheckWorker(ctx context.Context, wg *sync.WaitGroup, jobs <-chan models.UserNodes, results chan<- NodeHealthResult) { +func (svc WorkerService) healthCheckWorker(ctx context.Context, wg *sync.WaitGroup, jobs <-chan models.UserContractData, results chan<- NodeHealthResult) { defer wg.Done() log := logger.ForOperation("health_tracker", "health_check_worker") @@ -381,15 +418,22 @@ func (svc WorkerService) healthCheckWorker(ctx context.Context, wg *sync.WaitGro } } -func (svc WorkerService) SettlePendingPayments(records []models.PendingRecord) { +func (svc WorkerService) SettlePendingPayments(records []models.TransferRecord) { log := logger.ForOperation("balance_monitor", "settle_pending_payments") for _, record := range records { + if record.Operation == models.WithdrawOperation { + continue + } + // Already settled - if record.TransferredTFTAmount >= record.TFTAmount { + if record.State == models.SuccessState { continue } + transferState := models.SuccessState + var transferFailure string + // getting balance every time to ensure we have the latest balance systemTFTBalance, err := svc.gridClient.GetFreeBalanceTFT(svc.systemMnemonic) if err != nil { @@ -397,43 +441,82 @@ func (svc WorkerService) SettlePendingPayments(records []models.PendingRecord) { continue } - amountToTransfer := record.TFTAmount - record.TransferredTFTAmount - if systemTFTBalance < amountToTransfer { - log.Warn(). - Int("record_id", record.ID). - Uint64("system_balance", systemTFTBalance). - Uint64("amount_needed", amountToTransfer). - Msg("Insufficient system balance to settle pending record") + if systemTFTBalance < record.TFTAmount { + logger.GetLogger().Warn().Msgf("Insufficient system balance to settle pending record ID %d", record.ID) continue } - if err = svc.transferTFTsToUser(record.UserID, record.ID, amountToTransfer); err != nil { - log.Error().Err(err).Int("user_id", record.UserID).Int("record_id", record.ID).Msg("Failed to transfer TFTs to user") - continue + if err = svc.transferTFTsToUser(record.UserID, record.TFTAmount); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to settle pending record ID %d", record.ID) + + transferState = models.FailedState + transferFailure = err.Error() + } + + if err := svc.transferRecordsRepo.UpdateTransferRecordState(record.ID, transferState, transferFailure); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to update pending record ID %d state", record.ID) } } } -func (svc WorkerService) transferTFTsToUser(userID, recordID int, amountToTransfer uint64) error { +func (svc WorkerService) transferTFTsToUser(userID int, amountToTransfer uint64) error { user, err := svc.userRepo.GetUserByID(userID) if err != nil { - return fmt.Errorf("failed to get user for pending record ID %d: %w", recordID, err) + return fmt.Errorf("failed to get user %d: %w", userID, err) } err = svc.gridClient.TransferTFTsFromSystem(amountToTransfer, user.Mnemonic) if err != nil { - return fmt.Errorf("failed to transfer TFTs for pending record ID %d: %w", recordID, err) + return fmt.Errorf("failed to transfer TFTs for user %d: %w", userID, err) } - err = svc.pendingRecordsRepo.UpdatePendingRecordTransferredAmount(recordID, amountToTransfer) - if err != nil { - return fmt.Errorf("failed to update transferred amount for pending record ID %d: %w", recordID, err) + return nil +} + +func (svc *WorkerService) ResetUsersTFTsWithNoUSDBalance(users []models.User) error { + for _, user := range users { + if user.CreditedBalance+user.CreditCardBalance-user.Debt <= 0 { + logger.GetLogger().Info().Msgf("User %d has no USD balance, withdrawing all TFTs except for %d", user.ID, svc.minimumTFTAmountInWallet) + + userTFTBalance, err := svc.gridClient.GetFreeBalanceTFT(user.Mnemonic) + if err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to get user TFT balance for user %d", user.ID) + continue + } + + if userTFTBalance <= uint64(svc.minimumTFTAmountInWallet)*TFTUnitFactor { + continue + } + + if userTFTBalance <= uint64(svc.minimumTFTAmountInWallet)*TFTUnitFactor+transferFees { + continue + } + + transferRecord := models.TransferRecord{ + UserID: user.ID, + Username: user.Username, + TFTAmount: userTFTBalance - transferFees - uint64(svc.minimumTFTAmountInWallet)*TFTUnitFactor, + Operation: models.WithdrawOperation, + State: models.SuccessState, + } + + if err = svc.gridClient.TransferTFTsToSystem(userTFTBalance, user.Mnemonic); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to withdraw all TFTs for user %d", user.ID) + + transferRecord.State = models.FailedState + transferRecord.Failure = err.Error() + } + + if err := svc.transferRecordsRepo.CreateTransferRecord(&transferRecord); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to create transfer record for user %d", user.ID) + } + } } return nil } -func (svc WorkerService) NotifyAdminWithPendingRecords(records []models.PendingRecord) error { +func (svc WorkerService) NotifyAdminWithPendingRecords(records []models.TransferRecord) error { admins, err := svc.userRepo.ListAdmins() if err != nil { @@ -473,7 +556,7 @@ func (svc WorkerService) AsyncTrackClusterHealth(cluster models.Cluster) error { } func (svc WorkerService) checkUserDebt(user models.User, contractIDs []uint64) error { - totalDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs, svc.GetCheckUserDebtInterval()) + totalDebt, err := svc.calculateDebt(user.Mnemonic, contractIDs) if err != nil { return fmt.Errorf("failed to calculate debt: %w", err) } @@ -559,7 +642,7 @@ func (svc WorkerService) getUserContractIDs() (map[int][]uint64, error) { return nil, fmt.Errorf("failed to list clusters: %w", err) } - reservedNodes, err := svc.nodesRepo.ListAllReservedNodes() + reservedNodes, err := svc.contractsRepo.ListAllReservedNodes() if err != nil { return nil, fmt.Errorf("failed to list reserved nodes: %w", err) } diff --git a/backend/internal/core/workers/balance_monitor.go b/backend/internal/core/workers/balance_monitor.go index 0281aa72e..9dab546af 100644 --- a/backend/internal/core/workers/balance_monitor.go +++ b/backend/internal/core/workers/balance_monitor.go @@ -7,32 +7,97 @@ import ( ) func (w Workers) MonitorSystemBalanceAndHandleSettlement() { - balanceTicker := time.NewTicker(w.svc.GetMonitorBalanceInterval()) + settleTransfersTicker := time.NewTicker(w.svc.GetSettleTransferRecordsInterval()) adminNotifyTicker := time.NewTicker(w.svc.GetNotifyAdminsForPendingRecordsInterval()) - defer balanceTicker.Stop() + zeroUSDBalanceTicker := time.NewTicker(time.Minute) + zeroTFTBalanceTicker := time.NewTicker(time.Minute) + fundUserTFTBalanceTicker := time.NewTicker(24 * time.Hour) + defer settleTransfersTicker.Stop() defer adminNotifyTicker.Stop() + defer zeroUSDBalanceTicker.Stop() + defer zeroTFTBalanceTicker.Stop() + defer fundUserTFTBalanceTicker.Stop() for { select { case <-w.ctx.Done(): return - case <-balanceTicker.C: - records, err := w.svc.ListOnlyPendingRecords() + case <-settleTransfersTicker.C: + records, err := w.svc.ListPendingTransferRecords() if err != nil { continue } - w.svc.SettlePendingPayments(records) + failedRecords, err := w.svc.ListFailedTransferRecords() + if err != nil { + continue + } + + w.svc.SettlePendingPayments(append(records, failedRecords...)) case <-adminNotifyTicker.C: - records, err := w.svc.ListOnlyPendingRecords() + records, err := w.svc.ListPendingTransferRecords() + if err != nil { + continue + } + + if len(records) == 0 { + continue + } + + if err := w.svc.NotifyAdminWithPendingRecords(records); err != nil { + logger.ForOperation("balance_monitor", "notify_admins_pending_records").Error().Err(err).Msg("Failed to notify admins with pending records") + } + + case <-zeroUSDBalanceTicker.C: + users, err := w.svc.ListAllUsers() + if err != nil { + logger.GetLogger().Error().Err(err).Msg("Failed to list users") + continue + } + + if err := w.svc.ResetUsersTFTsWithNoUSDBalance(users); err != nil { + logger.GetLogger().Error().Err(err).Send() + } + + case <-zeroTFTBalanceTicker.C: + users, err := w.svc.ListAllUsers() + if err != nil { + logger.GetLogger().Error().Err(err).Msg("Failed to list users") + continue + } + + for _, user := range users { + clusters, err := w.svc.ListUserClusters(user.ID) + if err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to list user clusters") + continue + } + + if len(clusters) > 0 { + // user has deployed workloads, skip + continue + } + + if user.CreditedBalance+user.CreditCardBalance-user.Debt <= 0 { + continue + } + + if err := w.billingService.CreateTransferRecordToChargeUserWithMinTFTAmount(user.ID, user.Username, user.Mnemonic); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to create transfer record for user %d", user.ID) + } + } + + case <-fundUserTFTBalanceTicker.C: + users, err := w.svc.ListAllUsers() if err != nil { + logger.GetLogger().Error().Err(err).Msg("Failed to list users") continue } - if len(records) > 0 { - if err := w.svc.NotifyAdminWithPendingRecords(records); err != nil { - logger.ForOperation("balance_monitor", "notify_admins_pending_records").Error().Err(err).Msg("Failed to notify admins with pending records") + for _, user := range users { + if err := w.billingService.FundUserToFulfillDiscount(w.ctx, &user, nil, nil); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to fund user %d to claim discount", user.ID) } } } diff --git a/backend/internal/core/workers/create_invoices.go b/backend/internal/core/workers/create_invoices.go index 327e93536..fb2988d02 100644 --- a/backend/internal/core/workers/create_invoices.go +++ b/backend/internal/core/workers/create_invoices.go @@ -34,7 +34,7 @@ func (w Workers) MonthlyInvoicesHandler() { } for _, user := range users { - if err = w.svc.CreateUserInvoice(user); err != nil { + if err = w.svc.CreateUserInvoice(w.billingService, user); err != nil { baseLog.Error().Err(err).Int("user_id", user.ID).Msg("failed to create invoice for user") } } diff --git a/backend/internal/core/workers/debt_tracker.go b/backend/internal/core/workers/debt_tracker.go index 8065bcdcf..99a767bbd 100644 --- a/backend/internal/core/workers/debt_tracker.go +++ b/backend/internal/core/workers/debt_tracker.go @@ -3,11 +3,12 @@ package workers import ( "time" + "kubecloud/internal/core/services" "kubecloud/internal/infrastructure/logger" ) func (w Workers) TrackUserDebt() { - ticker := time.NewTicker(1 * time.Hour) + ticker := time.NewTicker(services.TrackingDebtPeriod) defer ticker.Stop() for { diff --git a/backend/internal/core/workers/settle_usage.go b/backend/internal/core/workers/settle_usage.go new file mode 100644 index 000000000..5517e3b46 --- /dev/null +++ b/backend/internal/core/workers/settle_usage.go @@ -0,0 +1,32 @@ +package workers + +import ( + "kubecloud/internal/infrastructure/logger" + "time" +) + +// DeductUSDBalanceBasedOnUsage deducts the user balance based on the usage +// This function is called every 24 hours +func (w Workers) DeductUSDBalanceBasedOnUsage() { + usageDeductionTicker := time.NewTicker(24 * time.Hour) + defer usageDeductionTicker.Stop() + + for { + select { + case <-w.ctx.Done(): + return + case <-usageDeductionTicker.C: + users, err := w.svc.ListAllUsers() + if err != nil { + logger.GetLogger().Error().Err(err).Msg("Failed to list users") + continue + } + + for _, user := range users { + if err := w.billingService.SettleUserUsage(&user); err != nil { + logger.GetLogger().Error().Err(err).Msgf("Failed to settle daily usage for user %d", user.ID) + } + } + } + } +} diff --git a/backend/internal/core/workers/workers.go b/backend/internal/core/workers/workers.go index c3291b6c4..453ec781f 100644 --- a/backend/internal/core/workers/workers.go +++ b/backend/internal/core/workers/workers.go @@ -8,13 +8,14 @@ import ( ) type Workers struct { - ctx context.Context - svc services.WorkerService - metrics *metrics.Metrics - db models.DB + ctx context.Context + svc services.WorkerService + billingService services.BillingService + metrics *metrics.Metrics + db models.DB } -func NewWorkers(ctx context.Context, svc services.WorkerService, metrics *metrics.Metrics, db models.DB) Workers { +func NewWorkers(ctx context.Context, svc services.WorkerService, billingService services.BillingService, metrics *metrics.Metrics, db models.DB) Workers { return Workers{ ctx: ctx, svc: svc, diff --git a/backend/internal/core/workflows/deployer_activities.go b/backend/internal/core/workflows/deployer_activities.go index 0e8c57ad1..849903310 100644 --- a/backend/internal/core/workflows/deployer_activities.go +++ b/backend/internal/core/workflows/deployer_activities.go @@ -304,7 +304,7 @@ func BatchDeployAllNodesStep(metrics *metricsLib.Metrics) ewf.StepFn { } } -func StoreDeploymentStep(clusterRepo models.ClusterRepository) ewf.StepFn { +func StoreDeploymentStep(clusterRepo models.ClusterRepository, contractsRepo models.ContractDataRepository) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { log := logger.ForOperation("deployer_activities", "store_deployment") cluster, err := statemanager.GetCluster(state) @@ -341,7 +341,7 @@ func StoreDeploymentStep(clusterRepo models.ClusterRepository) ewf.StepFn { } else { // cluster exists, update it existingCluster.Result = dbCluster.Result existingCluster.Kubeconfig = dbCluster.Kubeconfig - if err := clusterRepo.UpdateCluster(&existingCluster); err != nil { + if err := clusterRepo.UpdateCluster(contractsRepo, &existingCluster); err != nil { return fmt.Errorf("failed to update cluster %s in database (user_id=%d): %w", cluster.Name, config.UserID, err) } } @@ -498,7 +498,7 @@ func BatchCancelContractsStep() ewf.StepFn { } } -func DeleteAllUserClustersStep(clusterRepo models.ClusterRepository, metrics *metricsLib.Metrics) ewf.StepFn { +func DeleteAllUserClustersStep(clusterRepo models.ClusterRepository, contractsRepo models.ContractDataRepository, metrics *metricsLib.Metrics) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { config, err := getConfig(state) if err != nil { @@ -511,7 +511,7 @@ func DeleteAllUserClustersStep(clusterRepo models.ClusterRepository, metrics *me } clusterCount := len(clusters) - if err := clusterRepo.DeleteAllUserClusters(config.UserID); err != nil { + if err := clusterRepo.DeleteAllUserClusters(contractsRepo, config.UserID); err != nil { return fmt.Errorf("failed to delete all user clusters from database (user_id=%d): %w", config.UserID, err) } @@ -641,7 +641,7 @@ func createAddNodeWorkflowTemplate(notificationDispatcher *notification.Notifica return template } -func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metrics, clusterRepo models.ClusterRepository, notificationDispatcher *notification.NotificationDispatcher, config cfg.Configuration) { +func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metrics, clusterRepo models.ClusterRepository, contractsRepo models.ContractDataRepository, notificationDispatcher *notification.NotificationDispatcher, config cfg.Configuration) { engine.Register(StepDeployNetwork, DeployNetworkStep()) engine.Register(StepDeployLeaderNode, DeployLeaderNodeStep()) engine.Register(StepBatchDeployAllNodes, BatchDeployAllNodesStep(metrics)) @@ -649,14 +649,14 @@ func registerDeploymentActivities(engine *ewf.Engine, metrics *metricsLib.Metric engine.Register(StepAddNode, AddNodeStep()) engine.Register(StepUpdateNetwork, UpdateNetworkStep()) engine.Register(StepRemoveNode, RemoveDeploymentNodeStep()) - engine.Register(StepStoreDeployment, StoreDeploymentStep(clusterRepo)) + engine.Register(StepStoreDeployment, StoreDeploymentStep(clusterRepo, contractsRepo)) engine.Register(StepFetchKubeconfig, FetchKubeconfigStep(clusterRepo, config.SSH.PrivateKeyPath)) engine.Register(StepVerifyClusterReady, VerifyClusterReadyStep()) engine.Register(StepVerifyNewNodes, VerifyAddedNodeStep(clusterRepo, config.SSH.PrivateKeyPath)) engine.Register(StepRemoveClusterFromDB, RemoveClusterFromDBStep(clusterRepo, metrics)) engine.Register(StepGatherAllContractIDs, GatherAllContractIDsStep(clusterRepo)) engine.Register(StepBatchCancelContracts, BatchCancelContractsStep()) - engine.Register(StepDeleteAllUserClusters, DeleteAllUserClustersStep(clusterRepo, metrics)) + engine.Register(StepDeleteAllUserClusters, DeleteAllUserClustersStep(clusterRepo, contractsRepo, metrics)) deployWFTemplate := createDeployerWorkflowTemplate(notificationDispatcher, engine, metrics) deployWFTemplate.Steps = []ewf.Step{ @@ -968,7 +968,7 @@ func VerifyClusterInDBStep(clusterRepo models.ClusterRepository) ewf.StepFn { } } -func CheckClusterNodesHealthStep(clusterRepo models.ClusterRepository) ewf.StepFn { +func CheckClusterNodesHealthStep(clusterRepo models.ClusterRepository, contractsRepo models.ContractDataRepository) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { config, err := getConfig(state) if err != nil { @@ -1020,7 +1020,7 @@ func CheckClusterNodesHealthStep(clusterRepo models.ClusterRepository) ewf.StepF if err := dbCluster.SetClusterResult(cluster); err != nil { return fmt.Errorf("failed to set cluster result for cluster %s: %w", cluster.Name, err) } - if err := clusterRepo.UpdateCluster(&dbCluster); err != nil { + if err := clusterRepo.UpdateCluster(contractsRepo, &dbCluster); err != nil { return fmt.Errorf("failed to update cluster %s in database: %w", cluster.Name, err) } return nil diff --git a/backend/internal/core/workflows/names.go b/backend/internal/core/workflows/names.go index 648257356..0a0b8649d 100644 --- a/backend/internal/core/workflows/names.go +++ b/backend/internal/core/workflows/names.go @@ -3,10 +3,8 @@ package workflows // Workflow names const ( WorkflowChargeBalance = "charge-balance" - WorkflowAdminCreditBalance = "admin-credit-balance" WorkflowUserRegistration = "user-registration" WorkflowUserVerification = "user-verification" - WorkflowRedeemVoucher = "redeem-voucher" WorkflowReserveNode = "reserve-node" WorkflowUnreserveNode = "unreserve-node" WorkflowDeployCluster = "deploy-cluster" @@ -25,7 +23,6 @@ const ( // Step names const ( StepCreatePaymentIntent = "create_payment_intent" - StepCreatePendingRecord = "create_pending_record" StepUpdateCreditCardBalance = "update_user_balance" StepSendVerificationEmail = "send_verification_email" StepCreateUser = "create_user" @@ -37,7 +34,6 @@ const ( StepCreateIdentity = "create_identity" StepReserveNode = "reserve_node" StepUnreserveNode = "unreserve-node" - StepUpdateCreditedBalance = "update-credited-balance" StepRemoveNode = "remove-node" StepStoreDeployment = "store-deployment" StepAddNode = "add-node" diff --git a/backend/internal/core/workflows/node_activities.go b/backend/internal/core/workflows/node_activities.go index b39395480..662ff0623 100644 --- a/backend/internal/core/workflows/node_activities.go +++ b/backend/internal/core/workflows/node_activities.go @@ -17,7 +17,7 @@ const ( NodeHasActiveContracts = "NodeHasActiveContracts" ) -func ReserveNodeStep(userNodesRepo models.UserNodesRepository, gridClient gridclient.GridClient) ewf.StepFn { +func ReserveNodeStep(contractsRepo models.ContractDataRepository, gridClient gridclient.GridClient) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { config, err := getConfig(state) if err != nil { @@ -39,10 +39,11 @@ func ReserveNodeStep(userNodesRepo models.UserNodesRepository, gridClient gridcl return fmt.Errorf("failed to create rent contract for node_id=%d (user_id=%d): %w", nodeID, config.UserID, err) } - err = userNodesRepo.CreateUserNode(&models.UserNodes{ + err = contractsRepo.CreateUserContractData(&models.UserContractData{ UserID: config.UserID, ContractID: contractID, NodeID: nodeID, + Type: models.ContractTypeRented, CreatedAt: time.Now(), }) if err != nil { @@ -54,7 +55,7 @@ func ReserveNodeStep(userNodesRepo models.UserNodesRepository, gridClient gridcl } } -func UnreserveNodeStep(userNodesRepo models.UserNodesRepository, gridClient gridclient.GridClient) ewf.StepFn { +func UnreserveNodeStep(contractsRepo models.ContractDataRepository, gridClient gridclient.GridClient) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { contractIDVal, ok := state["contract_id"] if !ok { @@ -76,7 +77,7 @@ func UnreserveNodeStep(userNodesRepo models.UserNodesRepository, gridClient grid return fmt.Errorf("failed to cancel contract: %w", err) } - err = userNodesRepo.DeleteUserNode(contractID) + err = contractsRepo.DeleteUserContract(contractID) if err != nil { return fmt.Errorf("failed to delete user node: %w", err) } diff --git a/backend/internal/core/workflows/notify.go b/backend/internal/core/workflows/notify.go index 945a61a92..288430239 100644 --- a/backend/internal/core/workflows/notify.go +++ b/backend/internal/core/workflows/notify.go @@ -182,10 +182,6 @@ func sendBillingWorkflowNotifications(ctx context.Context, notificationDispatche displayName := getWorkflowDisplayName(wf) - if wf.Name == WorkflowAdminCreditBalance { - return sendAdminCreditBalanceWorkflowNotification(ctx, notificationDispatcher, wf, err) - } - amount, amountErr := getUint64FromState(wf.State, "amount") if amountErr != nil { log.Error().Err(amountErr).Msg("failed to get amount from state") @@ -208,12 +204,6 @@ func sendBillingWorkflowNotifications(ctx context.Context, notificationDispatche subject = "Adding Funds Succeeded" message = fmt.Sprintf("Funds were added successfully to your account. Amount added: $%.2f. New balance will be: $%.2f.", amountUSD, newBalanceUSD) - if wf.Name == WorkflowRedeemVoucher { - status = "voucher_redeemed" - subject = "Voucher Redeemed" - message = fmt.Sprintf("Voucher redeemed successfully. Amount added: $%.2f.", amountUSD) - } - notif := notification.BillingNotification(config.UserID). Success(message). WithSubject(subject). @@ -239,78 +229,6 @@ func sendBillingWorkflowNotifications(ctx context.Context, notificationDispatche return notificationDispatcher.Send(ctx, notif) } -func sendAdminCreditBalanceWorkflowNotification(ctx context.Context, notificationDispatcher *notification.NotificationDispatcher, wf *ewf.Workflow, err error) error { - log := logger.ForOperation("workflow", "create_admin_credit_balance_notification").With().Str("workflow_name", wf.Name).Logger() - adminID, adminIDErr := getIntFromState(wf.State, "admin_id") - if adminIDErr != nil { - log.Error().Err(adminIDErr).Msg("failed to get admin ID from state") - return adminIDErr - } - - username, usernameErr := getFromState[string](wf.State, "username") - if usernameErr != nil { - log.Warn().Err(usernameErr).Msg("failed to get username from state") - notif := buildGenericWorkflowNotification(wf, adminID, err) - return notificationDispatcher.Send(ctx, notif) - } - - displayName := getWorkflowDisplayName(wf) - amount, amountErr := getUint64FromState(wf.State, "amount") - if amountErr != nil { - log.Error().Err(amountErr).Msg("failed to get amount from state") - notif := buildGenericWorkflowNotification(wf, adminID, err) - return notificationDispatcher.Send(ctx, notif) - } - - amountUSD := gridclient.FromUSDMilliCentToUSD(amount) - - // Admin notification - adminNotif := notification.BillingNotification(adminID). - Success(fmt.Sprintf("User %s was credited successfully, money transferred successfully to their account (Amount: $%.2f)", username, amountUSD)). - WithSubject("Money transfer to user's account succeeded"). - WithStatus("succeeded"). - WithExtra("amount", fmt.Sprintf("%.2f", amountUSD)). - WithExtra("workflow_name", displayName). - WithChannels(notification.ChannelUI). - Build() - - if err != nil { - adminNotif = notification.BillingNotification(adminID). - Failure(fmt.Sprintf("Money transfer to user %s's account failed", username), err). - WithSubject("Money transfer to user's account failed"). - WithChannels(notification.ChannelUI). - Build() - } - - if sendErr := notificationDispatcher.Send(ctx, adminNotif); sendErr != nil { - return sendErr - } - - // User notification - config, confErr := getConfig(wf.State) - if confErr != nil { - log.Error().Msg("Missing or invalid 'config' in workflow state") - return confErr - } - userBuilder := notification.BillingNotification(config.UserID) - if err != nil { - userBuilder = userBuilder.Failure("Funds transfer to your account failed", err). - WithSubject("Your Account Credit Failed") - } else { - userBuilder = userBuilder.Success("Funds were credited to your account."). - WithSubject("Your Account Has Been Credited"). - WithStatus("succeeded"). - WithExtra("amount", fmt.Sprintf("%.2f", amountUSD)) - } - - userNotif := userBuilder. - WithExtra("workflow_name", displayName). - WithChannels(notification.ChannelEmail). - Build() - - return notificationDispatcher.Send(ctx, userNotif) -} - func sendNodeWorkflowNotification(ctx context.Context, notificationDispatcher *notification.NotificationDispatcher, wf *ewf.Workflow, err error) error { log := logger.ForOperation("workflow", "create_node_notification").With().Str("workflow_name", wf.Name).Logger() config, confErr := getConfig(wf.State) @@ -427,7 +345,7 @@ func sendUserWorkflowNotification(ctx context.Context, notificationDispatcher *n } func workflowToNotificationType(workflowName string) models.NotificationType { - billingWf := []string{WorkflowChargeBalance, WorkflowAdminCreditBalance, WorkflowRedeemVoucher, WorkflowDrainUser, WorkflowDrainAllUsers} + billingWf := []string{WorkflowChargeBalance, WorkflowDrainUser, WorkflowDrainAllUsers} deployWf := []string{WorkflowDeleteAllClusters, WorkflowDeleteCluster, WorkflowRemoveNode, WorkflowAddNode, WorkflowRollbackFailedDeployment} nodesWf := []string{WorkflowReserveNode, WorkflowUnreserveNode} userWf := []string{WorkflowUserVerification, WorkflowUserRegistration} diff --git a/backend/internal/core/workflows/state_converters.go b/backend/internal/core/workflows/state_converters.go index 489181a45..6e126e708 100644 --- a/backend/internal/core/workflows/state_converters.go +++ b/backend/internal/core/workflows/state_converters.go @@ -6,21 +6,6 @@ import ( "github.com/xmonader/ewf" ) -// toInt safely converts various numeric types to int -// Handles int, float64, and int64 types commonly found in workflow state after JSON unmarshaling -func toInt(val interface{}) (int, error) { - switch v := val.(type) { - case int: - return v, nil - case float64: - return int(v), nil - case int64: - return int(v), nil - default: - return 0, fmt.Errorf("cannot convert %T to int", val) - } -} - // toUint32 safely converts various numeric types to uint32 // Handles uint32, float64, int64, and int types commonly found in workflow state func toUint32(val interface{}) (uint32, error) { @@ -64,10 +49,6 @@ func getFromStateWithConverter[T any](state ewf.State, key string, converter fun return converter(val) } -func getIntFromState(state ewf.State, key string) (int, error) { - return getFromStateWithConverter(state, key, toInt) -} - func getUint64FromState(state ewf.State, key string) (uint64, error) { return getFromStateWithConverter(state, key, toUint64) } diff --git a/backend/internal/core/workflows/user_activities.go b/backend/internal/core/workflows/user_activities.go index 77f790b65..82a728f84 100644 --- a/backend/internal/core/workflows/user_activities.go +++ b/backend/internal/core/workflows/user_activities.go @@ -366,62 +366,6 @@ func CreatePaymentIntentStep(currency string, metrics *metrics.Metrics, stripeCl } } -func CreatePendingRecord(gridClient gridclient.GridClient, pendingRecordRepo models.PendingRecordRepository) ewf.StepFn { - return func(ctx context.Context, state ewf.State) error { - log := logger.ForOperation("user_activities", "create_pending_record") - amountVal, ok := state["amount"] - if !ok { - return fmt.Errorf("missing 'amount' in state") - } - - amount, ok := amountVal.(uint64) - if !ok { - return fmt.Errorf("'amount' in state is not a uint64") - } - - config, err := getConfig(state) - if err != nil { - return fmt.Errorf("failed to get config from state: %w", err) - } - - usernameVal, ok := state["username"] - if !ok { - return fmt.Errorf("missing 'username' in state") - } - username, ok := usernameVal.(string) - if !ok { - return fmt.Errorf("'username' in state is not a string") - } - - transferModeVal, ok := state["transfer_mode"] - if !ok { - return fmt.Errorf("missing 'transfer_mode' in state") - } - transferMode, ok := transferModeVal.(string) - if !ok { - return fmt.Errorf("'transfer_mode' in state is not a string") - } - - requestedTFTs, err := gridClient.FromUSDMillicentToTFT(amount) - if err != nil { - log.Error().Err(err).Msg("error converting USD to TFT") - return err - } - - if err = pendingRecordRepo.CreatePendingRecord(&models.PendingRecord{ - UserID: config.UserID, - Username: username, - TFTAmount: requestedTFTs, - TransferMode: transferMode, - }); err != nil { - log.Error().Err(err).Msg("failed to create pending record") - return err - } - - return nil - } -} - func UpdateCreditCardBalanceStep(userRepo models.UserRepository) ewf.StepFn { return func(ctx context.Context, state ewf.State) error { config, err := getConfig(state) diff --git a/backend/internal/core/workflows/workflow.go b/backend/internal/core/workflows/workflow.go index 09af2e426..82a358cd0 100644 --- a/backend/internal/core/workflows/workflow.go +++ b/backend/internal/core/workflows/workflow.go @@ -32,8 +32,7 @@ func RegisterEWFWorkflows( ) { userRepo := persistence.NewGormUserRepository(db) clusterRepo := persistence.NewGormClusterRepository(db) - userNodesRepo := persistence.NewGormUserNodesRepository(db) - pendingRecordRepo := persistence.NewGormPendingRecordRepository(db) + contractsRepo := persistence.NewGormUserContractDataRepository(db) engine.Register(StepSendVerificationEmail, SendVerificationEmailStep(mailService, config)) engine.Register(StepCreateUser, CreateUserStep(config, userRepo)) @@ -43,17 +42,15 @@ func RegisterEWFWorkflows( engine.Register(StepCreateKYCSponsorship, CreateKYCSponsorship(kycClient, sponsorAddress, sponsorKeyPair, userRepo)) engine.Register(StepSendWelcomeEmail, SendWelcomeEmailStep(mailService, metrics)) engine.Register(StepCreatePaymentIntent, CreatePaymentIntentStep(config.Currency, metrics, stripeClient)) - engine.Register(StepCreatePendingRecord, CreatePendingRecord(gridClient, pendingRecordRepo)) engine.Register(StepUpdateCreditCardBalance, UpdateCreditCardBalanceStep(userRepo)) - engine.Register(StepReserveNode, ReserveNodeStep(userNodesRepo, gridClient)) - engine.Register(StepUnreserveNode, UnreserveNodeStep(userNodesRepo, gridClient)) - engine.Register(StepUpdateCreditedBalance, UpdateCreditedBalanceStep(userRepo)) + engine.Register(StepReserveNode, ReserveNodeStep(contractsRepo, gridClient)) + engine.Register(StepUnreserveNode, UnreserveNodeStep(contractsRepo, gridClient)) engine.Register(StepSendEmailNotification, SendEmailNotificationStep(userRepo, mailService)) engine.Register(StepVerifyNodeState, VerifyNodeStateStep(gridClient)) engine.Register(StepVerifyClusterInDB, VerifyClusterInDBStep(clusterRepo)) engine.Register(StepDrainUserBalance, DrainUserBalanceStep(userRepo, gridClient)) engine.Register(StepDrainAllUsersBalance, DrainAllUsersBalanceStep(userRepo, engine, config.MailSender.MaxConcurrentSends)) - engine.Register(StepCheckClusterNodesHealth, CheckClusterNodesHealthStep(clusterRepo)) + engine.Register(StepCheckClusterNodesHealth, CheckClusterNodesHealthStep(clusterRepo, contractsRepo)) engine.Register(StepCheckClusterHealth, CheckClusterHealthStep(config.SSH.PrivateKeyPath)) registerWorkflowTemplate := newKubecloudWorkflowTemplate(notificationDispatcher) @@ -102,24 +99,9 @@ func RegisterEWFWorkflows( chargeBalanceTemplate.Steps = []ewf.Step{ {Name: StepCreatePaymentIntent, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, {Name: StepUpdateCreditCardBalance, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, - {Name: StepCreatePendingRecord, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, } engine.RegisterTemplate(WorkflowChargeBalance, &chargeBalanceTemplate) - adminCreditBalanceTemplate := newKubecloudWorkflowTemplate(notificationDispatcher) - adminCreditBalanceTemplate.Steps = []ewf.Step{ - {Name: StepUpdateCreditedBalance, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, - {Name: StepCreatePendingRecord, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, - } - engine.RegisterTemplate(WorkflowAdminCreditBalance, &adminCreditBalanceTemplate) - - redeemVoucherTemplate := newKubecloudWorkflowTemplate(notificationDispatcher) - redeemVoucherTemplate.Steps = []ewf.Step{ - {Name: StepUpdateCreditedBalance, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, - {Name: StepCreatePendingRecord, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, - } - engine.RegisterTemplate(WorkflowRedeemVoucher, &redeemVoucherTemplate) - reserveNodeTemplate := newKubecloudWorkflowTemplate(notificationDispatcher) reserveNodeTemplate.Steps = []ewf.Step{ {Name: StepReserveNode, RetryPolicy: &ewf.RetryPolicy{MaxAttempts: 2, BackOff: ewf.ConstantBackoff(2 * time.Second)}}, @@ -144,7 +126,7 @@ func RegisterEWFWorkflows( // trackClusterHealthWFTemplate.BeforeWorkflowHooks = []ewf.BeforeWorkflowHook{hookNotificationWorkflowStarted} engine.RegisterTemplate(WorkflowTrackClusterHealth, &trackClusterHealthWFTemplate) - registerDeploymentActivities(engine, metrics, clusterRepo, notificationDispatcher, config) + registerDeploymentActivities(engine, metrics, clusterRepo, contractsRepo, notificationDispatcher, config) // Email-only workflow for guaranteed email delivery with retries emailNotificationTemplate := ewf.WorkflowTemplate{ diff --git a/backend/internal/infrastructure/gridclient/grid_client.go b/backend/internal/infrastructure/gridclient/grid_client.go index 446831653..dd737dc25 100644 --- a/backend/internal/infrastructure/gridclient/grid_client.go +++ b/backend/internal/infrastructure/gridclient/grid_client.go @@ -17,6 +17,7 @@ import ( substrate "github.com/threefoldtech/tfchain/clients/tfchain-client-go" "github.com/threefoldtech/tfgrid-sdk-go/grid-client/calculator" "github.com/threefoldtech/tfgrid-sdk-go/grid-client/deployer" + "github.com/threefoldtech/tfgrid-sdk-go/grid-client/graphql" client "github.com/threefoldtech/tfgrid-sdk-go/grid-client/node" "github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/pkg/types" sdktrace "go.opentelemetry.io/otel/sdk/trace" @@ -37,6 +38,7 @@ type GridClient interface { TransferTFTsToSystem(tftBalance uint64, userMnemonic string) error SystemIdentity() (substrate.Identity, error) GetTwinIDFromUserMnemonic(mnemonic string) (uint64, error) + GetTwin(mnemonic string) (uint32, error) GetFreeBalanceTFT(mnemonic string) (uint64, error) GetUserAddress(mnemonic string) (string, error) AcceptTermsAndConditions(mnemonic, docLink, docHash string) error @@ -44,6 +46,9 @@ type GridClient interface { CreateRentContract(mnemonic string, nodeID uint32) (uint64, error) CancelContract(mnemonic string, contractID uint64) error SetupUserOnTFChain(termsAndConditions config.TermsANDConditions) (mnemonic string, twinID uint32, err error) + GetPricingPolicy(policyID uint32) (pricingPolicy substrate.PricingPolicy, err error) + GetTFTBillingRateAt(block uint64) (float64, error) + GetCurrentHeight() (uint32, error) // node methods GetNodeClient(nodeID uint32) (*client.NodeClient, error) @@ -51,6 +56,9 @@ type GridClient interface { // calculator methods NewCalculator(mnemonic string) (calculator.Calculator, error) + // contracts getter methods + NewContractsGetter(twinID uint32, graphqlClient graphql.GraphQl) graphql.ContractsGetter + // grid-proxy client methods Node(ctx context.Context, nodeID uint32) (res types.NodeWithNestedCapacity, err error) Nodes(ctx context.Context, filter types.NodeFilter, pagination types.Limit) (res []types.Node, totalCount int, err error) @@ -264,6 +272,15 @@ func (s *gridClient) GetTwinIDFromUserMnemonic(mnemonic string) (uint64, error) return uint64(twinID), nil } +func (s *gridClient) GetTwin(mnemonic string) (uint32, error) { + identity, err := s.getIdentity(mnemonic) + if err != nil { + return 0, err + } + + return s.gridClient.SubstrateConn.GetTwinByPubKey(identity.PublicKey()) +} + // GetNodeClient gents the node client given nodeID func (s *gridClient) GetNodeClient(nodeID uint32) (*client.NodeClient, error) { return s.gridClient.NcPool.GetNodeClient(s.gridClient.SubstrateConn, nodeID) @@ -279,6 +296,10 @@ func (s *gridClient) NewCalculator(mnemonic string) (calculator.Calculator, erro return calculator.NewCalculator(s.gridClient.SubstrateConn, identity), nil } +func (s *gridClient) NewContractsGetter(twinID uint32, graphqlClient graphql.GraphQl) graphql.ContractsGetter { + return graphql.NewContractsGetter(twinID, graphqlClient, s.gridClient.SubstrateConn, s.gridClient.NcPool) +} + // GetFreeBalance returns free balance from user mnemonic func (s *gridClient) GetFreeBalanceTFT(mnemonic string) (uint64, error) { identity, err := s.getIdentity(mnemonic) @@ -418,6 +439,18 @@ func (s *gridClient) SetupUserOnTFChain(termsAndConditions config.TermsANDCondit return mnemonic, twinID, nil } +func (s *gridClient) GetPricingPolicy(policyID uint32) (pricingPolicy substrate.PricingPolicy, err error) { + return s.gridClient.SubstrateConn.GetPricingPolicy(policyID) +} + +func (s *gridClient) GetTFTBillingRateAt(block uint64) (float64, error) { + return s.gridClient.SubstrateConn.GetTFTBillingRateAt(block) +} + +func (s *gridClient) GetCurrentHeight() (uint32, error) { + return s.gridClient.SubstrateConn.GetCurrentHeight() +} + func (s *gridClient) Close() { s.gridClient.Close() } diff --git a/backend/internal/infrastructure/persistence/gorm.go b/backend/internal/infrastructure/persistence/gorm.go index 1ed25a920..f7b70cf2b 100644 --- a/backend/internal/infrastructure/persistence/gorm.go +++ b/backend/internal/infrastructure/persistence/gorm.go @@ -34,14 +34,15 @@ func newGormDB(db *gorm.DB) (*GormDB, error) { err := db.AutoMigrate( &models.User{}, &models.Voucher{}, - models.Transaction{}, - models.Invoice{}, - models.NodeItem{}, - models.UserNodes{}, + &models.Transaction{}, + &models.Invoice{}, + &models.NodeItem{}, + &models.UserContractData{}, &models.Notification{}, &models.SSHKey{}, &models.Cluster{}, - &models.PendingRecord{}, + &models.TransferRecord{}, + &models.UserUsageCalculationTime{}, &models.Settings{}, ) if err != nil { @@ -82,7 +83,7 @@ func ensureSoftDeleteIndexes(db *gorm.DB) error { `DROP INDEX IF EXISTS idx_user_project`, `CREATE UNIQUE INDEX IF NOT EXISTS idx_user_project ON clusters (user_id, project_name) WHERE deleted_at IS NULL`, `DROP INDEX IF EXISTS idx_user_node_id`, - `CREATE UNIQUE INDEX IF NOT EXISTS idx_user_node_id ON user_nodes (node_id) WHERE deleted_at IS NULL`, + `CREATE UNIQUE INDEX IF NOT EXISTS idx_user_node_id ON user_contract_data (node_id) WHERE deleted_at IS NULL`, `DROP INDEX IF EXISTS idx_users_email`, `CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users (email) WHERE deleted_at IS NULL`, `DROP INDEX IF EXISTS idx_user_name`, diff --git a/backend/internal/infrastructure/persistence/migrate.go b/backend/internal/infrastructure/persistence/migrate.go index 4d515cecc..676bfc4c4 100644 --- a/backend/internal/infrastructure/persistence/migrate.go +++ b/backend/internal/infrastructure/persistence/migrate.go @@ -30,13 +30,13 @@ func MigrateAll(ctx context.Context, src models.DB, dst models.DB) error { if err := migrateNodeItems(ctx, srcGormDB, dstGormDB); err != nil { return fmt.Errorf("node_items: %w", err) } - if err := migrateUserNodes(ctx, srcGormDB, dstGormDB); err != nil { - return fmt.Errorf("user_nodes: %w", err) + if err := migrateUserContractData(ctx, srcGormDB, dstGormDB); err != nil { + return fmt.Errorf("user_contract_data: %w", err) } if err := migrateClusters(ctx, srcGormDB, dstGormDB); err != nil { return fmt.Errorf("clusters: %w", err) } - if err := migratePendingRecords(ctx, srcGormDB, dstGormDB); err != nil { + if err := migrateTransferRecords(ctx, srcGormDB, dstGormDB); err != nil { return fmt.Errorf("pending_records: %w", err) } if err := migrateNotificationsToDst(ctx, srcGormDB, dstGormDB); err != nil { @@ -107,8 +107,8 @@ func migrateNodeItems(ctx context.Context, src *GormDB, dst *GormDB) error { return insertOnConflictReturnError(ctx, dst, rows) } -func migrateUserNodes(ctx context.Context, src *GormDB, dst *GormDB) error { - var rows []models.UserNodes +func migrateUserContractData(ctx context.Context, src *GormDB, dst *GormDB) error { + var rows []models.UserContractData if err := src.db.WithContext(ctx).Find(&rows).Error; err != nil { return err } @@ -123,8 +123,8 @@ func migrateClusters(ctx context.Context, src *GormDB, dst *GormDB) error { return insertOnConflictReturnError(ctx, dst, rows) } -func migratePendingRecords(ctx context.Context, src *GormDB, dst *GormDB) error { - var rows []models.PendingRecord +func migrateTransferRecords(ctx context.Context, src *GormDB, dst *GormDB) error { + var rows []models.TransferRecord if err := src.db.WithContext(ctx).Find(&rows).Error; err != nil { return err }