diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..899af19 --- /dev/null +++ b/example/main.go @@ -0,0 +1,171 @@ +package main + +import ( + "context" + "fmt" + "log" + + rbac "github.com/codescalers/rbac/pkg" + "github.com/codescalers/rbac/pkg/store" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +// Blog represents a blog post resource +type Blog struct { + ID string + Title string + Content string + OwnerID string +} + +// Name implements the rbac.Resource interface +func (b Blog) Name() string { + return "blog" +} + +// BlogOwnershipRule ensures users can only access their own blogs +type BlogOwnershipRule struct{} + +func (r BlogOwnershipRule) Name() string { + return "blog_ownership" +} + +func (r BlogOwnershipRule) Evaluate(ctx context.Context, subjectID string, resource rbac.Resource) (bool, error) { + blog, ok := resource.(Blog) + if !ok { + return false, fmt.Errorf("expected Blog resource, got %T", resource) + } + + // Allow access if the user is the owner + return blog.OwnerID == subjectID, nil +} + +func main() { + ctx := context.Background() + + // Initialize SQLite database + db, err := gorm.Open(sqlite.Open("rbac.db"), &gorm.Config{}) + if err != nil { + log.Fatal("Failed to connect to database:", err) + } + + // Create GORM store + gormStore, err := store.NewGormStore(db) + if err != nil { + log.Fatal("Failed to create store:", err) + } + + // Initialize RBAC + r, err := rbac.NewRBAC(ctx, gormStore) + if err != nil { + log.Fatal(err) + } + + // Register business rule for blog ownership + bizRole := rbac.BizRule(BlogOwnershipRule{}) + if err := r.RegisterBizRule(bizRole); err != nil { + log.Fatal(err) + } + + // Create permissions + readPerm, err := r.CreatePermission(ctx, "blog", "read") + if err != nil { + log.Fatal(err) + } + updateOwnPerm, err := r.CreatePermission(ctx, "blog", "update", bizRole.Name()) + if err != nil { + log.Fatal(err) + } + deleteOwnPerm, err := r.CreatePermission(ctx, "blog", "delete", bizRole.Name()) + if err != nil { + log.Fatal(err) + } + createPerm, err := r.CreatePermission(ctx, "blog", "create") + if err != nil { + log.Fatal(err) + } + + updateAll, err := r.CreatePermission(ctx, "blog", "update") + if err != nil { + log.Fatal(err) + } + deleteAll, err := r.CreatePermission(ctx, "blog", "delete") + if err != nil { + log.Fatal(err) + } + + // Create roles + userRole, err := r.CreateRole(ctx, "user", "Regular user with blog access") + if err != nil { + log.Fatal(err) + } + adminRole, err := r.CreateRole(ctx, "admin", "Administrator with full access", userRole.ID) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Created roles: user=%s, admin=%s\n", userRole.ID, adminRole.ID) + + //Add user permissions + if err := r.AddPermissionToRole(ctx, "user", readPerm.ID); err != nil { + log.Fatal(err) + } + if err := r.AddPermissionToRole(ctx, "user", createPerm.ID); err != nil { + log.Fatal(err) + } + if err := r.AddPermissionToRole(ctx, "user", updateOwnPerm.ID); err != nil { + log.Fatal(err) + } + if err := r.AddPermissionToRole(ctx, "user", deleteOwnPerm.ID); err != nil { + log.Fatal(err) + } + + //Add admin permissions + if err := r.AddPermissionToRole(ctx, "admin", updateAll.ID); err != nil { + log.Fatal(err) + } + if err := r.AddPermissionToRole(ctx, "admin", deleteAll.ID); err != nil { + log.Fatal(err) + } + + // Create test users + adminUserID := "admin-user-123" + regularUserID := "regular-user-456" + + // Create subjects with roles using role names + if err := r.CreateSubjectWithRole(ctx, adminUserID, "admin"); err != nil { + log.Fatal(err) + } + if err := r.CreateSubjectWithRole(ctx, regularUserID, "user"); err != nil { + log.Fatal(err) + } + + fmt.Println("Created subjects with roles") + + // Test blogs + blog1 := Blog{ID: "blog-1", Title: "Admin's Blog", OwnerID: adminUserID} + blog2 := Blog{ID: "blog-2", Title: "User's Blog", OwnerID: regularUserID} + + hasPerm, err := r.Can(ctx, regularUserID, "update", blog2) + if err != nil { + log.Fatal(err) + } + fmt.Printf("User has permission to update blog2: %v\n", hasPerm) + hasPerm, err = r.Can(ctx, regularUserID, "update", blog1) + if err != nil { + log.Fatal(err) + } + fmt.Printf("User has permission to update blog1: %v\n", hasPerm) + + hasPerm, err = r.Can(ctx, adminUserID, "update", blog1) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Admin has permission to update blog1: %v\n", hasPerm) + hasPerm, err = r.Can(ctx, adminUserID, "update", blog2) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Admin has permission to update blog2: %v\n", hasPerm) +} diff --git a/go.mod b/go.mod index 6b1b798..05d2714 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,23 @@ module github.com/codescalers/rbac go 1.24.6 -require github.com/google/uuid v1.6.0 +require ( + github.com/google/uuid v1.6.0 + go.uber.org/mock v0.6.0 + gorm.io/gorm v1.31.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/text v0.20.0 // indirect + gorm.io/driver/sqlite v1.6.0 +) diff --git a/go.sum b/go.sum index 7790d7c..84568cb 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,25 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/internal/mocks/store_mock.go b/internal/mocks/store_mock.go new file mode 100644 index 0000000..a1c792c --- /dev/null +++ b/internal/mocks/store_mock.go @@ -0,0 +1,245 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: pkg/store.go +// +// Generated by this command: +// +// mockgen -source=pkg/store.go -destination=internal/mocks/store_mock.go -package=mocks +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + rbac "github.com/codescalers/rbac/pkg" + gomock "go.uber.org/mock/gomock" +) + +// MockStore is a mock of Store interface. +type MockStore struct { + ctrl *gomock.Controller + recorder *MockStoreMockRecorder + isgomock struct{} +} + +// MockStoreMockRecorder is the mock recorder for MockStore. +type MockStoreMockRecorder struct { + mock *MockStore +} + +// NewMockStore creates a new mock instance. +func NewMockStore(ctrl *gomock.Controller) *MockStore { + mock := &MockStore{ctrl: ctrl} + mock.recorder = &MockStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStore) EXPECT() *MockStoreMockRecorder { + return m.recorder +} + +// CreatePermission mocks base method. +func (m *MockStore) CreatePermission(ctx context.Context, p rbac.Permission) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePermission", ctx, p) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreatePermission indicates an expected call of CreatePermission. +func (mr *MockStoreMockRecorder) CreatePermission(ctx, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePermission", reflect.TypeOf((*MockStore)(nil).CreatePermission), ctx, p) +} + +// CreateRole mocks base method. +func (m *MockStore) CreateRole(ctx context.Context, role rbac.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateRole", ctx, role) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateRole indicates an expected call of CreateRole. +func (mr *MockStoreMockRecorder) CreateRole(ctx, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRole", reflect.TypeOf((*MockStore)(nil).CreateRole), ctx, role) +} + +// CreateSubject mocks base method. +func (m *MockStore) CreateSubject(ctx context.Context, subject rbac.Subject) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSubject", ctx, subject) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateSubject indicates an expected call of CreateSubject. +func (mr *MockStoreMockRecorder) CreateSubject(ctx, subject any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSubject", reflect.TypeOf((*MockStore)(nil).CreateSubject), ctx, subject) +} + +// GetPermission mocks base method. +func (m *MockStore) GetPermission(ctx context.Context, id string) (rbac.Permission, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPermission", ctx, id) + ret0, _ := ret[0].(rbac.Permission) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPermission indicates an expected call of GetPermission. +func (mr *MockStoreMockRecorder) GetPermission(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermission", reflect.TypeOf((*MockStore)(nil).GetPermission), ctx, id) +} + +// GetRole mocks base method. +func (m *MockStore) GetRole(ctx context.Context, roleID string) (rbac.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRole", ctx, roleID) + ret0, _ := ret[0].(rbac.Role) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRole indicates an expected call of GetRole. +func (mr *MockStoreMockRecorder) GetRole(ctx, roleID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRole", reflect.TypeOf((*MockStore)(nil).GetRole), ctx, roleID) +} + +// GetRoleByName mocks base method. +func (m *MockStore) GetRoleByName(ctx context.Context, name string) (rbac.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRoleByName", ctx, name) + ret0, _ := ret[0].(rbac.Role) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRoleByName indicates an expected call of GetRoleByName. +func (mr *MockStoreMockRecorder) GetRoleByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoleByName", reflect.TypeOf((*MockStore)(nil).GetRoleByName), ctx, name) +} + +// GetSubject mocks base method. +func (m *MockStore) GetSubject(ctx context.Context, subjectID string) (rbac.Subject, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSubject", ctx, subjectID) + ret0, _ := ret[0].(rbac.Subject) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSubject indicates an expected call of GetSubject. +func (mr *MockStoreMockRecorder) GetSubject(ctx, subjectID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubject", reflect.TypeOf((*MockStore)(nil).GetSubject), ctx, subjectID) +} + +// ListPermissions mocks base method. +func (m *MockStore) ListPermissions(ctx context.Context) ([]rbac.Permission, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListPermissions", ctx) + ret0, _ := ret[0].([]rbac.Permission) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListPermissions indicates an expected call of ListPermissions. +func (mr *MockStoreMockRecorder) ListPermissions(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListPermissions", reflect.TypeOf((*MockStore)(nil).ListPermissions), ctx) +} + +// ListRoles mocks base method. +func (m *MockStore) ListRoles(ctx context.Context) ([]rbac.Role, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRoles", ctx) + ret0, _ := ret[0].([]rbac.Role) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRoles indicates an expected call of ListRoles. +func (mr *MockStoreMockRecorder) ListRoles(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRoles", reflect.TypeOf((*MockStore)(nil).ListRoles), ctx) +} + +// ListSubjects mocks base method. +func (m *MockStore) ListSubjects(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSubjects", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSubjects indicates an expected call of ListSubjects. +func (mr *MockStoreMockRecorder) ListSubjects(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSubjects", reflect.TypeOf((*MockStore)(nil).ListSubjects), ctx) +} + +// RemovePermission mocks base method. +func (m *MockStore) RemovePermission(ctx context.Context, id string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemovePermission", ctx, id) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemovePermission indicates an expected call of RemovePermission. +func (mr *MockStoreMockRecorder) RemovePermission(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePermission", reflect.TypeOf((*MockStore)(nil).RemovePermission), ctx, id) +} + +// RemoveRole mocks base method. +func (m *MockStore) RemoveRole(ctx context.Context, roleID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoveRole", ctx, roleID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RemoveRole indicates an expected call of RemoveRole. +func (mr *MockStoreMockRecorder) RemoveRole(ctx, roleID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveRole", reflect.TypeOf((*MockStore)(nil).RemoveRole), ctx, roleID) +} + +// UpdateRole mocks base method. +func (m *MockStore) UpdateRole(ctx context.Context, role rbac.Role) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateRole", ctx, role) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateRole indicates an expected call of UpdateRole. +func (mr *MockStoreMockRecorder) UpdateRole(ctx, role any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRole", reflect.TypeOf((*MockStore)(nil).UpdateRole), ctx, role) +} + +// UpdateSubject mocks base method. +func (m *MockStore) UpdateSubject(ctx context.Context, subject rbac.Subject) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSubject", ctx, subject) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSubject indicates an expected call of UpdateSubject. +func (mr *MockStoreMockRecorder) UpdateSubject(ctx, subject any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubject", reflect.TypeOf((*MockStore)(nil).UpdateSubject), ctx, subject) +} diff --git a/pkg/biz_rules.go b/pkg/biz_rules.go new file mode 100644 index 0000000..fe0719d --- /dev/null +++ b/pkg/biz_rules.go @@ -0,0 +1,50 @@ +package rbac + +import ( + "context" + "fmt" +) + +// Resource represents any entity that can be accessed or modified +type Resource interface { + Name() string +} + +// BizRule defines a custom business rule for fine-grained authorization +type BizRule interface { + Name() string + Evaluate(ctx context.Context, subjectID string, resource Resource) (bool, error) +} + +// RegisterBizRule registers a custom business rule +func (r *RBAC) RegisterBizRule(rule BizRule) error { + if rule == nil { + return fmt.Errorf("business rule cannot be nil") + } + + name := rule.Name() + if name == "" { + return fmt.Errorf("business rule name cannot be empty") + } + + if r.bizRules == nil { + r.bizRules = make(map[string]BizRule) + } + + if _, exists := r.bizRules[name]; exists { + return fmt.Errorf("business rule %q already registered", name) + } + + r.bizRules[name] = rule + return nil +} + +// GetBizRule retrieves a registered business rule by name +func (r *RBAC) GetBizRule(name string) (BizRule, bool) { + if r.bizRules == nil { + return nil, false + } + + rule, exists := r.bizRules[name] + return rule, exists +} diff --git a/pkg/errors.go b/pkg/errors.go index 74d7bc9..fc7e129 100644 --- a/pkg/errors.go +++ b/pkg/errors.go @@ -1,5 +1,6 @@ package rbac +// Common RBAC errors var ( ErrAlreadyExists = errorString("already exists") ErrInvalidResourceOrAction = errorString("invalid resource or action") @@ -9,6 +10,8 @@ var ( ErrPermissionInUse = errorString("permission is in use and cannot be removed") ErrDuplicateRole = errorString("role with this name already exists") ErrDuplicatePermission = errorString("permission with this resource and action already exists") + ErrRoleCycle = errorString("role hierarchy cycle detected") + ErrRoleHasChildren = errorString("role has child roles and cannot be removed") ) type errorString string diff --git a/pkg/helpers.go b/pkg/helpers.go new file mode 100644 index 0000000..b81b2be --- /dev/null +++ b/pkg/helpers.go @@ -0,0 +1,39 @@ +package rbac + +import ( + "strings" + + "github.com/google/uuid" +) + +func normalizeString(s string) string { + return strings.ToLower(strings.TrimSpace(s)) +} + +func validateUUIDs(ids ...string) error { + for _, id := range ids { + if _, err := uuid.Parse(id); err != nil { + return err + } + } + return nil +} + +func roleHasPermission(role *Role, permID string) bool { + for _, p := range role.Permissions { + if p.ID == permID { + return true + } + } + return false +} + +func filterOutPermission(perms []Permission, permID string) []Permission { + filtered := perms[:0] + for _, p := range perms { + if p.ID != permID { + filtered = append(filtered, p) + } + } + return filtered +} diff --git a/pkg/hierarchy.go b/pkg/hierarchy.go new file mode 100644 index 0000000..b9f715a --- /dev/null +++ b/pkg/hierarchy.go @@ -0,0 +1,52 @@ +package rbac + +import "context" + +type roleVisitor func(role Role) error + +func (r *RBAC) traverseRoleHierarchy(ctx context.Context, role Role, visitor roleVisitor) error { + currentRole := role + for { + if err := visitor(currentRole); err != nil { + return err + } + + if currentRole.ParentID == "" { + break + } + + parentRole, err := r.store.GetRole(ctx, currentRole.ParentID) + if err != nil { + return err + } + + currentRole = parentRole + } + + return nil +} + +func (r *RBAC) checkRoleHierarchyCycle(ctx context.Context, parentID, childID string) error { + visited := make(map[string]bool) + currentID := parentID + + for currentID != "" { + if currentID == childID { + return ErrRoleCycle + } + + if visited[currentID] { + return ErrRoleCycle + } + visited[currentID] = true + + role, err := r.store.GetRole(ctx, currentID) + if err != nil { + return err + } + + currentID = role.ParentID + } + + return nil +} diff --git a/pkg/rbac.go b/pkg/rbac.go index 0319d2c..e9a7ae9 100644 --- a/pkg/rbac.go +++ b/pkg/rbac.go @@ -2,17 +2,22 @@ package rbac import ( "context" - "strings" "github.com/google/uuid" ) +// RBAC is the main role-based access control manager type RBAC struct { - store Store + store Store + bizRules map[string]BizRule } -func New(ctx context.Context, s Store, opts ...Option) (*RBAC, error) { - r := &RBAC{store: s} +// NewRBAC creates a new RBAC instance with the given store and optional configuration +func NewRBAC(ctx context.Context, s Store, opts ...Option) (*RBAC, error) { + r := &RBAC{ + store: s, + bizRules: make(map[string]BizRule), + } for _, opt := range opts { if err := opt(ctx, r); err != nil { return nil, err @@ -21,8 +26,10 @@ func New(ctx context.Context, s Store, opts ...Option) (*RBAC, error) { return r, nil } +// Option is a function that configures the RBAC instance during initialization type Option func(ctx context.Context, r *RBAC) error +// WithSeed seeds the RBAC system with predefined roles func WithSeed(roles []Role) Option { return func(ctx context.Context, r *RBAC) error { return r.initFromSeed(ctx, roles) @@ -38,34 +45,75 @@ func (r *RBAC) initFromSeed(ctx context.Context, roles []Role) error { return nil } -func (r *RBAC) CreateRole(ctx context.Context, name, description string) error { - n := strings.ToLower(strings.TrimSpace(name)) +// CreateRole creates a new role with optional parent for hierarchy +func (r *RBAC) CreateRole(ctx context.Context, name, description string, parentID ...string) (Role, error) { + n := normalizeString(name) if n == "" { - return ErrInvalidName + return Role{}, ErrInvalidName } exists, err := r.roleNameExists(ctx, n) if err != nil { - return err + return Role{}, err } if exists { - return ErrDuplicateRole + return Role{}, ErrDuplicateRole } role := Role{ID: uuid.New().String(), Name: n, Description: description} - return r.store.CreateRole(ctx, role) + + if len(parentID) > 0 && parentID[0] != "" { + if err := validateUUIDs(parentID[0]); err != nil { + return Role{}, err + } + + // Verify parent exists + if _, err := r.store.GetRole(ctx, parentID[0]); err != nil { + return Role{}, err + } + + role.ParentID = parentID[0] + } + + if err := r.store.CreateRole(ctx, role); err != nil { + return Role{}, err + } + return role, nil } -func (r *RBAC) RemoveRole(ctx context.Context, roleID string) error { - if err := validateUUIDs(roleID); err != nil { - return err +// UpdateRole updates an existing role's parent, which can be used to reorganize the hierarchy +func (r *RBAC) UpdateRole(ctx context.Context, roleName, newParentName string) error { + role, err := r.store.GetRoleByName(ctx, roleName) + if err != nil { + return ErrNotFound } - if _, err := r.store.GetRole(ctx, roleID); err != nil { + if newParentName != "" { + parent, err := r.store.GetRoleByName(ctx, newParentName) + if err != nil { + return ErrNotFound + } + + if err := r.checkRoleHierarchyCycle(ctx, parent.ID, role.ID); err != nil { + return err + } + + role.ParentID = parent.ID + } else { + role.ParentID = "" + } + + return r.store.UpdateRole(ctx, role) +} + +// RemoveRole deletes a role if it's not in use by any subject +func (r *RBAC) RemoveRole(ctx context.Context, roleName string) error { + role, err := r.store.GetRoleByName(ctx, roleName) + if err != nil { return ErrNotFound } - inUse, err := r.isRoleInUse(ctx, roleID) + inUse, err := r.isRoleInUse(ctx, role.ID) if err != nil { return err } @@ -73,28 +121,39 @@ func (r *RBAC) RemoveRole(ctx context.Context, roleID string) error { return ErrRoleInUse } - return r.store.RemoveRole(ctx, roleID) + return r.store.RemoveRole(ctx, role.ID) } -func (r *RBAC) CreatePermission(ctx context.Context, resource, action string) error { - res := strings.ToLower(strings.TrimSpace(resource)) - a := strings.ToLower(strings.TrimSpace(action)) +// CreatePermission creates a new permission with optional business rule +func (r *RBAC) CreatePermission(ctx context.Context, resource, action string, bizRuleName ...string) (Permission, error) { + res := normalizeString(resource) + a := normalizeString(action) if res == "" || a == "" { - return ErrInvalidResourceOrAction + return Permission{}, ErrInvalidResourceOrAction + } + + bizRule := "" + if len(bizRuleName) > 0 && bizRuleName[0] != "" { + bizRule = bizRuleName[0] } - exists, err := r.permissionExists(ctx, res, a) + exists, err := r.permissionExists(ctx, res, a, bizRule) if err != nil { - return err + return Permission{}, err } if exists { - return ErrDuplicatePermission + return Permission{}, ErrDuplicatePermission } - p := Permission{ID: uuid.New().String(), Resource: res, Action: a} - return r.store.CreatePermission(ctx, p) + p := Permission{ID: uuid.New().String(), Resource: res, Action: a, BizRule: bizRule} + + if err := r.store.CreatePermission(ctx, p); err != nil { + return Permission{}, err + } + return p, nil } +// RemovePermission deletes a permission if it's not assigned to any role func (r *RBAC) RemovePermission(ctx context.Context, permID string) error { if err := validateUUIDs(permID); err != nil { return err @@ -115,69 +174,46 @@ func (r *RBAC) RemovePermission(ctx context.Context, permID string) error { return r.store.RemovePermission(ctx, permID) } -func (r *RBAC) AssignRole(ctx context.Context, subjectID, roleID string) error { - if err := validateUUIDs(roleID); err != nil { - return err - } - - if _, err := r.store.GetRole(ctx, roleID); err != nil { - return ErrNotFound - } - - user, err := r.store.GetSubject(ctx, subjectID) +// CreateSubjectWithRole creates a new subject and assigns a role by role name +func (r *RBAC) CreateSubjectWithRole(ctx context.Context, subjectID, roleName string) error { + role, err := r.store.GetRoleByName(ctx, roleName) if err != nil { - return err - } - - for _, role := range user.Roles { - if role.ID == roleID { - return ErrAlreadyExists - } + return ErrNotFound } - role, err := r.store.GetRole(ctx, roleID) - if err != nil { - return err + subject := Subject{ + ID: subjectID, + RoleID: role.ID, } - - user.Roles = append(user.Roles, role) - return r.store.UpdateSubject(ctx, user) + return r.store.CreateSubject(ctx, subject) } -func (r *RBAC) RevokeRole(ctx context.Context, subjectID, roleID string) error { - if err := validateUUIDs(roleID); err != nil { - return err +// AssignRole assigns a role to a subject +func (r *RBAC) AssignRole(ctx context.Context, subjectID, roleName string) error { + role, err := r.store.GetRoleByName(ctx, roleName) + if err != nil { + return ErrNotFound } - user, err := r.store.GetSubject(ctx, subjectID) + subject, err := r.store.GetSubject(ctx, subjectID) if err != nil { return err } - found := false - for _, role := range user.Roles { - if role.ID == roleID { - found = true - break - } - } - if !found { - return ErrNotFound - } - - user.Roles = r.filterOutRole(user.Roles, roleID) - return r.store.UpdateSubject(ctx, user) + subject.RoleID = role.ID + return r.store.UpdateSubject(ctx, subject) } -func (r *RBAC) AddPermissionToRole(ctx context.Context, roleID, permID string) error { - if err := validateUUIDs(roleID, permID); err != nil { +// AddPermissionToRole adds a permission to a role +func (r *RBAC) AddPermissionToRole(ctx context.Context, roleName, permID string) error { + if err := validateUUIDs(permID); err != nil { return err } - role, err := r.store.GetRole(ctx, roleID) + role, err := r.store.GetRoleByName(ctx, roleName) if err != nil { return ErrNotFound } - if r.roleHasPermission(&role, permID) { + if roleHasPermission(&role, permID) { return ErrAlreadyExists } p, err := r.store.GetPermission(ctx, permID) @@ -188,222 +224,103 @@ func (r *RBAC) AddPermissionToRole(ctx context.Context, roleID, permID string) e return r.store.UpdateRole(ctx, role) } -func (r *RBAC) RemovePermissionFromRole(ctx context.Context, roleID, permID string) error { - if err := validateUUIDs(roleID, permID); err != nil { +// RemovePermissionFromRole removes a permission from a role +func (r *RBAC) RemovePermissionFromRole(ctx context.Context, roleName, permID string) error { + if err := validateUUIDs(permID); err != nil { return err } - role, err := r.store.GetRole(ctx, roleID) + role, err := r.store.GetRoleByName(ctx, roleName) if err != nil { return ErrNotFound } - if !r.roleHasPermission(&role, permID) { + if !roleHasPermission(&role, permID) { return ErrNotFound } - role.Permissions = r.filterOutPermission(role.Permissions, permID) + role.Permissions = filterOutPermission(role.Permissions, permID) return r.store.UpdateRole(ctx, role) } -// Direct grants -func (r *RBAC) GrantSubject(ctx context.Context, subjectID string, resource, action, resourceID string) error { - res := strings.ToLower(strings.TrimSpace(resource)) - act := strings.ToLower(strings.TrimSpace(action)) - if res == "" || act == "" { - return ErrInvalidResourceOrAction - } - - user, err := r.store.GetSubject(ctx, subjectID) +// Can checks if a subject has permission to perform an action on a resource +func (r *RBAC) Can(ctx context.Context, subjectID, action string, resource Resource) (bool, error) { + subject, err := r.store.GetSubject(ctx, subjectID) if err != nil { - return err + return false, err } - rid := strings.ToLower(strings.TrimSpace(resourceID)) - grant := Grant{ID: uuid.New().String(), Resource: res, Action: act, ResourceID: rid} - user.Grants = append(user.Grants, grant) - - return r.store.UpdateSubject(ctx, user) -} - -func (r *RBAC) RevokeSubjectGrant(ctx context.Context, subjectID, grantID string) error { - if err := validateUUIDs(grantID); err != nil { - return err + if resource == nil { + return false, ErrInvalidResourceOrAction } - user, err := r.store.GetSubject(ctx, subjectID) - if err != nil { - return err + res := normalizeString(resource.Name()) + act := normalizeString(action) + if res == "" || act == "" { + return false, ErrInvalidResourceOrAction } - found := false - for _, grant := range user.Grants { - if grant.ID == grantID { - found = true - break - } - } - if !found { - return ErrNotFound + if subject.RoleID == "" { + return false, nil } - user.Grants = r.filterOutGrant(user.Grants, grantID) - return r.store.UpdateSubject(ctx, user) -} + // Validate roleID + if err := validateUUIDs(subject.RoleID); err != nil { + return false, err + } -func (r *RBAC) Can(ctx context.Context, subjectID, action, resource string, resourceID ...string) (bool, error) { - user, err := r.store.GetSubject(ctx, subjectID) + role, err := r.store.GetRole(ctx, subject.RoleID) if err != nil { return false, err } - id := "" - if len(resourceID) > 0 && resourceID[0] != "" { - id = resourceID[0] - } - res := strings.ToLower(strings.TrimSpace(resource)) - act := strings.ToLower(strings.TrimSpace(action)) - if res == "" || act == "" { - return false, ErrInvalidResourceOrAction - } + return r.checkRolePermission(ctx, role, subjectID, act, resource) +} - //direct grants - for _, g := range user.Grants { - if g.Resource != res { - continue - } - if g.Action != act { - continue - } - if g.ResourceID != id && g.ResourceID != "*" { - continue - } - return true, nil - } +func (r *RBAC) checkRolePermission(ctx context.Context, role Role, subjectID, action string, resourceObj Resource) (bool, error) { + var hasPermission bool + var permErr error - //role permissions - for _, role := range user.Roles { - for _, p := range role.Permissions { - if p.Resource != res { + err := r.traverseRoleHierarchy(ctx, role, func(currentRole Role) error { + for _, p := range currentRole.Permissions { + if p.Resource != resourceObj.Name() { continue } - if p.Action != act { + if p.Action != action { continue } - return true, nil - } - } - return false, nil -} + if p.BizRule == "" { + hasPermission = true + return nil + } -func (r *RBAC) roleHasPermission(role *Role, permID string) bool { - for _, p := range role.Permissions { - if p.ID == permID { - return true - } - } - return false -} + rule, exists := r.GetBizRule(p.BizRule) + if !exists { + continue + } -func (r *RBAC) filterOutPermission(perms []Permission, permID string) []Permission { - filtered := perms[:0] - for _, p := range perms { - if p.ID != permID { - filtered = append(filtered, p) - } - } - return filtered -} + allowed, err := rule.Evaluate(ctx, subjectID, resourceObj) + if err != nil { + permErr = err + return err + } -func validateUUIDs(ids ...string) error { - for _, id := range ids { - if _, err := uuid.Parse(id); err != nil { - return err - } - } - return nil -} + if !allowed { + continue + } -func (r *RBAC) roleNameExists(ctx context.Context, name string) (bool, error) { - roles, err := r.store.ListRoles(ctx) - if err != nil { - return false, err - } - for _, role := range roles { - if role.Name == name { - return true, nil + hasPermission = true + return nil } - } - return false, nil -} + return nil + }) -func (r *RBAC) permissionExists(ctx context.Context, resource, action string) (bool, error) { - perms, err := r.store.ListPermissions(ctx) if err != nil { - return false, err - } - for _, p := range perms { - if p.Resource != resource { - continue + if permErr != nil { + return false, permErr } - if p.Action != action { - continue - } - return true, nil - } - return false, nil -} - -func (r *RBAC) isRoleInUse(ctx context.Context, roleID string) (bool, error) { - subjects, err := r.store.ListSubjects(ctx) - if err != nil { return false, err } - for _, subjectID := range subjects { - user, err := r.store.GetSubject(ctx, subjectID) - if err != nil { - continue - } - for _, role := range user.Roles { - if role.ID == roleID { - return true, nil - } - } - } - return false, nil -} -func (r *RBAC) filterOutRole(roles []Role, roleID string) []Role { - filtered := roles[:0] - for _, role := range roles { - if role.ID != roleID { - filtered = append(filtered, role) - } - } - return filtered -} - -func (r *RBAC) filterOutGrant(grants []Grant, grantID string) []Grant { - filtered := grants[:0] - for _, grant := range grants { - if grant.ID != grantID { - filtered = append(filtered, grant) - } - } - return filtered -} - -func (r *RBAC) isPermissionInUse(ctx context.Context, permID string) (bool, error) { - roles, err := r.store.ListRoles(ctx) - if err != nil { - return false, err - } - for _, role := range roles { - for _, p := range role.Permissions { - if p.ID == permID { - return true, nil - } - } - } - return false, nil + return hasPermission, nil } diff --git a/pkg/rbac_test.go b/pkg/rbac_test.go new file mode 100644 index 0000000..87934ba --- /dev/null +++ b/pkg/rbac_test.go @@ -0,0 +1,875 @@ +package rbac_test + +import ( + "context" + "testing" + + "github.com/codescalers/rbac/internal/mocks" + rbac "github.com/codescalers/rbac/pkg" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestNewRBAC(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + t.Run("Create new RBAC instance", func(t *testing.T) { + r, err := rbac.NewRBAC(ctx, store) + + assert.NoError(t, err) + assert.NotNil(t, r) + }) + t.Run("Create new RBAC instance with seed", func(t *testing.T) { + roles := []rbac.Role{ + {ID: "1", Name: "admin", Permissions: []rbac.Permission{{ID: "p1", Resource: "blog", Action: "read"}}}, + {ID: "2", Name: "user", Permissions: []rbac.Permission{{ID: "p2", Resource: "blog", Action: "update"}}}, + } + store.EXPECT().CreateRole(ctx, roles[0]).Return(nil) + store.EXPECT().CreateRole(ctx, roles[1]).Return(nil) + + r, err := rbac.NewRBAC(ctx, store, rbac.WithSeed(roles)) + + assert.NoError(t, err) + assert.NotNil(t, r) + }) +} + +func TestCreateRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + adminID := "550e8400-e29b-41d4-a716-446655440000" + existingAdmin := rbac.Role{ID: adminID, Name: "admin"} + + store.EXPECT().ListRoles(ctx).Return([]rbac.Role{existingAdmin}, nil).Times(4) + store.EXPECT().GetRole(ctx, adminID).Return(rbac.Role{ID: adminID, Name: "admin"}, nil).Times(1) + store.EXPECT().CreateRole(ctx, gomock.Any()).Return(nil).Times(2) + + tests := []struct { + name string + roleName string + description string + parentID string + setupMock func() + expectedError error + validateRole func(role rbac.Role) + }{ + { + name: "Create role with empty name", + roleName: "", + description: "description", + expectedError: rbac.ErrInvalidName, + }, + { + name: "Create role with duplicate name", + roleName: "admin", + description: "description", + expectedError: rbac.ErrDuplicateRole, + }, + { + name: "Create role with non-existing parent", + roleName: "user", + description: "description", + parentID: "660e8400-e29b-41d4-a716-446655440099", + setupMock: func() { + store.EXPECT().GetRole(ctx, "660e8400-e29b-41d4-a716-446655440099").Return(rbac.Role{}, rbac.ErrNotFound).Times(1) + }, + expectedError: rbac.ErrNotFound, + }, + { + name: "Create role successfully without parent", + roleName: "user", + description: "User role", + validateRole: func(role rbac.Role) { + assert.Equal(t, "user", role.Name) + assert.Equal(t, "User role", role.Description) + }, + }, + { + name: "Create role successfully with parent", + roleName: "editor", + description: "Editor role", + parentID: adminID, + validateRole: func(role rbac.Role) { + assert.Equal(t, "editor", role.Name) + assert.Equal(t, "Editor role", role.Description) + assert.Equal(t, adminID, role.ParentID) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupMock != nil { + tt.setupMock() + } + + var role rbac.Role + var err error + if tt.parentID != "" { + role, err = r.CreateRole(ctx, tt.roleName, tt.description, tt.parentID) + } else { + role, err = r.CreateRole(ctx, tt.roleName, tt.description) + } + + if tt.expectedError == nil { + assert.NoError(t, err) + if tt.validateRole != nil { + tt.validateRole(role) + } + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestUpdateRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + adminID := "550e8400-e29b-41d4-a716-446655440000" + editorID := "550e8400-e29b-41d4-a716-446655440001" + viewerID := "550e8400-e29b-41d4-a716-446655440002" + + existingAdmin := rbac.Role{ID: adminID, Name: "admin"} + existingEditor := rbac.Role{ID: editorID, Name: "editor", ParentID: adminID} + existingViewer := rbac.Role{ID: viewerID, Name: "viewer", ParentID: editorID} + + store.EXPECT().GetRoleByName(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, name string) (rbac.Role, error) { + switch name { + case "admin": + return existingAdmin, nil + case "editor": + return existingEditor, nil + case "viewer": + return existingViewer, nil + default: + return rbac.Role{}, rbac.ErrNotFound + } + }).AnyTimes() + store.EXPECT().GetRole(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, id string) (rbac.Role, error) { + switch id { + case adminID: + return existingAdmin, nil + case editorID: + return existingEditor, nil + case viewerID: + return existingViewer, nil + default: + return rbac.Role{}, rbac.ErrNotFound + } + }).AnyTimes() + store.EXPECT().UpdateRole(ctx, gomock.Any()).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + roleName string + newParentName string + expectedError error + expectAnyError bool + }{ + { + name: "Update non-existing role", + roleName: "nonexistent", + newParentName: "admin", + expectedError: rbac.ErrNotFound, + }, + { + name: "Update role with non-existing parent", + roleName: "viewer", + newParentName: "nonexistent", + expectedError: rbac.ErrNotFound, + }, + { + name: "Update role creates cycle", + roleName: "admin", + newParentName: "editor", + expectedError: rbac.ErrRoleCycle, + }, + { + name: "Update role successfully", + roleName: "viewer", + newParentName: "admin", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.UpdateRole(ctx, tt.roleName, tt.newParentName) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestRemoveRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + roleName := "user" + roleID := "550e8400-e29b-41d4-a716-446655440000" + + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ID: roleID, Name: roleName}, nil).Times(3) + store.EXPECT().ListSubjects(ctx).Return([]string{"user1", "user2"}, nil).Times(3) + store.EXPECT().GetSubject(ctx, "user1").Return(rbac.Subject{ID: "user1", RoleID: "some-other-role"}, nil).Times(3) + store.EXPECT().GetSubject(ctx, "user2").Return(rbac.Subject{ID: "user2", RoleID: roleID}, nil).Times(1) + store.EXPECT().GetSubject(ctx, "user2").Return(rbac.Subject{ID: "user2", RoleID: "another-role"}, nil).Times(2) + store.EXPECT().RemoveRole(ctx, roleID).Return(rbac.ErrRoleHasChildren).Times(1) + store.EXPECT().RemoveRole(ctx, roleID).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + roleName string + expectedError error + expectAnyError bool + }{ + { + name: "Remove non-existing role", + roleName: roleName, + expectedError: rbac.ErrNotFound, + }, + { + name: "Remove role that is in use", + roleName: roleName, + expectedError: rbac.ErrRoleInUse, + }, + { + name: "Remove role that has children", + roleName: roleName, + expectedError: rbac.ErrRoleHasChildren, + }, + { + name: "Remove role successfully", + roleName: roleName, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.RemoveRole(ctx, tt.roleName) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestCreatePermission(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + store.EXPECT().ListPermissions(ctx).Return([]rbac.Permission{ + {ID: "p1", Resource: "blog", Action: "read", BizRule: ""}, + }, nil).AnyTimes() + store.EXPECT().CreatePermission(ctx, gomock.Any()).Return(nil).Times(2) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + resource string + action string + bizRuleName []string + expectedError error + validatePerm func(*rbac.Permission) + }{ + { + name: "Create permission with empty resource", + resource: "", + action: "read", + expectedError: rbac.ErrInvalidResourceOrAction, + }, + { + name: "Create permission with empty action", + resource: "blog", + action: "", + expectedError: rbac.ErrInvalidResourceOrAction, + }, + { + name: "Create duplicate permission", + resource: "blog", + action: "read", + expectedError: rbac.ErrDuplicatePermission, + }, + { + name: "Create permission successfully without bizrule", + resource: "article", + action: "write", + validatePerm: func(p *rbac.Permission) { + assert.NotEmpty(t, p.ID) + assert.Equal(t, "article", p.Resource) + assert.Equal(t, "write", p.Action) + assert.Equal(t, "", p.BizRule) + }, + }, + { + name: "Create permission successfully with bizrule", + resource: "post", + action: "delete", + bizRuleName: []string{"is_owner"}, + validatePerm: func(p *rbac.Permission) { + assert.NotEmpty(t, p.ID) + assert.Equal(t, "post", p.Resource) + assert.Equal(t, "delete", p.Action) + assert.Equal(t, "is_owner", p.BizRule) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + perm, err := r.CreatePermission(ctx, tt.resource, tt.action, tt.bizRuleName...) + + if tt.expectedError == nil { + assert.NoError(t, err) + if tt.validatePerm != nil { + tt.validatePerm(&perm) + } + return + } + + assert.Error(t, err) + assert.ErrorIs(t, err, tt.expectedError) + }) + } +} + +func TestRemovePermission(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + permID := "550e8400-e29b-41d4-a716-446655440000" + roleID := "550e8400-e29b-41d4-a716-446655440001" + + store.EXPECT().GetPermission(ctx, permID).Return(rbac.Permission{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetPermission(ctx, permID).Return(rbac.Permission{ID: permID}, nil).Times(2) + store.EXPECT().ListRoles(ctx).Return([]rbac.Role{ + { + ID: roleID, + Name: "admin", + Permissions: []rbac.Permission{ + {ID: permID, Resource: "blog", Action: "read"}, + }, + }, + }, nil).Times(1) + store.EXPECT().ListRoles(ctx).Return([]rbac.Role{ + { + ID: roleID, + Name: "admin", + Permissions: []rbac.Permission{}, + }, + }, nil).Times(1) + store.EXPECT().RemovePermission(ctx, permID).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + permID string + expectedError error + expectAnyError bool + }{ + { + name: "Remove permission with invalid UUID", + permID: "invalid-uuid", + expectAnyError: true, + }, + { + name: "Remove non-existing permission", + permID: permID, + expectedError: rbac.ErrNotFound, + }, + { + name: "Remove permission that is in use", + permID: permID, + expectedError: rbac.ErrPermissionInUse, + }, + { + name: "Remove permission successfully", + permID: permID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.RemovePermission(ctx, tt.permID) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestAddPermissionToRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + roleName := "editor" + permID := "550e8400-e29b-41d4-a716-446655440001" + existingPermID := "550e8400-e29b-41d4-a716-446655440002" + + existingPerm := rbac.Permission{ID: existingPermID, Resource: "blog", Action: "read"} + newPerm := rbac.Permission{ID: permID, Resource: "article", Action: "write"} + + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Name: roleName, + Permissions: []rbac.Permission{existingPerm}, + }, nil).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ + ID: "550e8400-e29b-41d4-a716-446655440000", + Name: roleName, + Permissions: []rbac.Permission{}, + }, nil).Times(2) + store.EXPECT().GetPermission(ctx, permID).Return(rbac.Permission{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetPermission(ctx, permID).Return(newPerm, nil).Times(1) + store.EXPECT().UpdateRole(ctx, gomock.Any()).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + roleName string + permID string + expectedError error + expectAnyError bool + }{ + { + name: "Add permission with invalid permission UUID", + roleName: roleName, + permID: "invalid-uuid", + expectAnyError: true, + }, + { + name: "Add permission to non-existing role", + roleName: roleName, + permID: permID, + expectedError: rbac.ErrNotFound, + }, + { + name: "Add duplicate permission to role", + roleName: roleName, + permID: existingPermID, + expectedError: rbac.ErrAlreadyExists, + }, + { + name: "Add non-existing permission to role", + roleName: roleName, + permID: permID, + expectedError: rbac.ErrNotFound, + }, + { + name: "Add permission to role successfully", + roleName: roleName, + permID: permID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.AddPermissionToRole(ctx, tt.roleName, tt.permID) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestRemovePermissionFromRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + roleName := "editor" + roleID := "550e8400-e29b-41d4-a716-446655440000" + permID := "550e8400-e29b-41d4-a716-446655440001" + otherPermID := "550e8400-e29b-41d4-a716-446655440002" + + existingPerm := rbac.Permission{ID: permID, Resource: "blog", Action: "read"} + + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ + ID: roleID, + Name: roleName, + Permissions: []rbac.Permission{}, + }, nil).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ + ID: roleID, + Name: roleName, + Permissions: []rbac.Permission{existingPerm}, + }, nil).Times(1) + store.EXPECT().UpdateRole(ctx, gomock.Any()).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + roleName string + permID string + expectedError error + expectAnyError bool + }{ + { + name: "Remove permission with invalid permission UUID", + roleName: roleName, + permID: "invalid-uuid", + expectAnyError: true, + }, + { + name: "Remove permission from non-existing role", + roleName: roleName, + permID: permID, + expectedError: rbac.ErrNotFound, + }, + { + name: "Remove non-existing permission from role", + roleName: roleName, + permID: otherPermID, + expectedError: rbac.ErrNotFound, + }, + { + name: "Remove permission from role successfully", + roleName: roleName, + permID: permID, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.RemovePermissionFromRole(ctx, tt.roleName, tt.permID) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +func TestAssignRole(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + roleName := "admin" + roleID := "550e8400-e29b-41d4-a716-446655440000" + subjectID := "subject-123" + + existingSubject := rbac.Subject{ID: subjectID, RoleID: ""} + + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetRoleByName(ctx, roleName).Return(rbac.Role{ID: roleID, Name: roleName}, nil).Times(2) + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{}, rbac.ErrNotFound).Times(1) + store.EXPECT().GetSubject(ctx, subjectID).Return(existingSubject, nil).Times(1) + store.EXPECT().UpdateSubject(ctx, rbac.Subject{ID: subjectID, RoleID: roleID}).Return(nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + tests := []struct { + name string + subjectID string + roleName string + expectedError error + expectAnyError bool + }{ + { + name: "Assign non-existing role", + subjectID: subjectID, + roleName: roleName, + expectedError: rbac.ErrNotFound, + }, + { + name: "Assign role to non-existing subject", + subjectID: subjectID, + roleName: roleName, + expectAnyError: true, + }, + { + name: "Assign role successfully", + subjectID: subjectID, + roleName: roleName, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := r.AssignRole(ctx, tt.subjectID, tt.roleName) + + if tt.expectedError == nil && !tt.expectAnyError { + assert.NoError(t, err) + return + } + + assert.Error(t, err) + if tt.expectedError != nil { + assert.ErrorIs(t, err, tt.expectedError) + } + }) + } +} + +type MockResource struct { + name string + ownerID string +} + +func (mr MockResource) Name() string { + return mr.name +} + +type MockOwnershipBizRule struct{} + +func (mbr MockOwnershipBizRule) Name() string { + return "test_ownership" +} + +func (mbr MockOwnershipBizRule) Evaluate(ctx context.Context, subjectID string, resource rbac.Resource) (bool, error) { + mockRes, ok := resource.(MockResource) + if !ok { + return false, nil + } + return mockRes.ownerID == subjectID, nil +} + +func TestCan(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + store := mocks.NewMockStore(ctrl) + + subjectID := "subject-123" + otherSubjectID := "subject-456" + adminRoleID := "550e8400-e29b-41d4-a716-446655440000" + userRoleID := "550e8400-e29b-41d4-a716-446655440001" + readPermID := "perm-read-123" + writePermID := "perm-write-123" + deleteWithRulePermID := "perm-delete-123" + + adminRole := rbac.Role{ + ID: adminRoleID, + Name: "admin", + Permissions: []rbac.Permission{ + {ID: readPermID, Resource: "blog", Action: "read"}, + {ID: writePermID, Resource: "blog", Action: "write"}, + }, + } + + userRole := rbac.Role{ + ID: userRoleID, + Name: "user", + ParentID: adminRoleID, + Permissions: []rbac.Permission{ + {ID: deleteWithRulePermID, Resource: "blog", Action: "delete", BizRule: "test_ownership"}, + }, + } + + resource := MockResource{name: "blog", ownerID: subjectID} + + store.EXPECT().GetRole(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, id string) (rbac.Role, error) { + switch id { + case adminRoleID: + return adminRole, nil + case userRoleID: + return userRole, nil + default: + return rbac.Role{}, rbac.ErrNotFound + } + }).AnyTimes() + + store.EXPECT().GetSubject(ctx, "non-existing").Return(rbac.Subject{}, rbac.ErrNotFound).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: ""}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: "invalid-uuid"}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: adminRoleID}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: adminRoleID}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: userRoleID}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, subjectID).Return(rbac.Subject{ID: subjectID, RoleID: userRoleID}, nil).Times(1) + + store.EXPECT().GetSubject(ctx, otherSubjectID).Return(rbac.Subject{ID: otherSubjectID, RoleID: userRoleID}, nil).Times(1) + + r, err := rbac.NewRBAC(ctx, store) + assert.NoError(t, err) + assert.NotNil(t, r) + + err = r.RegisterBizRule(MockOwnershipBizRule{}) + assert.NoError(t, err) + + tests := []struct { + name string + subjectID string + action string + resource rbac.Resource + expectedResult bool + expectError bool + }{ + { + name: "Subject not found", + subjectID: "non-existing", + action: "read", + resource: resource, + expectError: true, + }, + { + name: "Subject without role", + subjectID: subjectID, + action: "read", + resource: resource, + expectedResult: false, + expectError: false, + }, + { + name: "Subject with invalid role UUID", + subjectID: subjectID, + action: "read", + resource: resource, + expectError: true, + }, + { + name: "Subject does not have permission", + subjectID: subjectID, + action: "update", + resource: resource, + expectedResult: false, + expectError: false, + }, + { + name: "Subject has direct permission", + subjectID: subjectID, + action: "read", + resource: resource, + expectedResult: true, + expectError: false, + }, + { + name: "Subject has permission via hierarchy", + subjectID: subjectID, + action: "write", + resource: resource, + expectedResult: true, + expectError: false, + }, + { + name: "Subject has permission with business rule (passes)", + subjectID: subjectID, + action: "delete", + resource: resource, + expectedResult: true, + expectError: false, + }, + { + name: "Subject has permission with business rule (fails)", + subjectID: otherSubjectID, + action: "delete", + resource: resource, + expectedResult: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := r.Can(ctx, tt.subjectID, tt.action, tt.resource) + + if tt.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/pkg/store.go b/pkg/store.go index da2c113..7d82a38 100644 --- a/pkg/store.go +++ b/pkg/store.go @@ -2,10 +2,12 @@ package rbac import "context" +// Store defines the interface for persisting RBAC entities type Store interface { // Roles CreateRole(ctx context.Context, role Role) error GetRole(ctx context.Context, roleID string) (Role, error) + GetRoleByName(ctx context.Context, name string) (Role, error) UpdateRole(ctx context.Context, role Role) error RemoveRole(ctx context.Context, roleID string) error ListRoles(ctx context.Context) ([]Role, error) @@ -17,9 +19,8 @@ type Store interface { RemovePermission(ctx context.Context, id string) error // Subjects - GetSubject(ctx context.Context, subjectID string) (User, error) - UpdateSubject(ctx context.Context, user User) error + CreateSubject(ctx context.Context, subject Subject) error + GetSubject(ctx context.Context, subjectID string) (Subject, error) + UpdateSubject(ctx context.Context, subject Subject) error ListSubjects(ctx context.Context) ([]string, error) - ListSubjectRoles(ctx context.Context, subjectID string) ([]Role, error) - ListSubjectGrants(ctx context.Context, subjectID string) ([]Grant, error) } diff --git a/pkg/store/gorm.go b/pkg/store/gorm.go new file mode 100644 index 0000000..ae7702d --- /dev/null +++ b/pkg/store/gorm.go @@ -0,0 +1,283 @@ +package store + +import ( + "context" + "fmt" + + rbac "github.com/codescalers/rbac/pkg" + "gorm.io/gorm" +) + +type GormStore struct { + db *gorm.DB +} + +func NewGormStore(db *gorm.DB) (*GormStore, error) { + store := &GormStore{db: db} + if err := store.migrate(); err != nil { + return nil, err + } + return store, nil +} + +func (s *GormStore) migrate() error { + return s.db.AutoMigrate( + &Role{}, + &Permission{}, + &Subject{}, + ) +} + +type Role struct { + ID string `gorm:"primaryKey"` + Name string `gorm:"uniqueIndex;not null"` + Description string `gorm:"type:text"` + ParentID string `gorm:"index;constraint:OnDelete:RESTRICT"` + Parent *Role `gorm:"foreignKey:ParentID;references:ID"` + Permissions []Permission `gorm:"many2many:role_permissions;"` +} + +type Permission struct { + ID string `gorm:"primaryKey"` + Resource string `gorm:"not null;index:idx_resource_action"` + Action string `gorm:"not null;index:idx_resource_action"` + BizRule string `gorm:"type:text"` + Roles []Role `gorm:"many2many:role_permissions;"` +} + +type Subject struct { + ID string `gorm:"primaryKey"` + RoleID string `gorm:"index"` +} + +func (s *GormStore) Close() error { + db, err := s.db.DB() + if err != nil { + return err + } + return db.Close() +} + +func (s *GormStore) CreateRole(ctx context.Context, role rbac.Role) error { + r := Role{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + ParentID: role.ParentID, + } + + r.Permissions = convertPermissionsFromRBAC(role.Permissions) + + return s.db.WithContext(ctx).Create(&r).Error +} + +func (s *GormStore) GetRole(ctx context.Context, roleID string) (rbac.Role, error) { + var r Role + err := s.db.WithContext(ctx).Preload("Permissions").First(&r, "id = ?", roleID).Error + if err != nil { + return rbac.Role{}, err + } + + role := rbac.Role{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + ParentID: r.ParentID, + Permissions: convertToRBACPermissions(r.Permissions), + } + + return role, nil +} + +func (s *GormStore) GetRoleByName(ctx context.Context, name string) (rbac.Role, error) { + var r Role + err := s.db.WithContext(ctx).Preload("Permissions").First(&r, "name = ?", name).Error + if err != nil { + return rbac.Role{}, err + } + + role := rbac.Role{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + ParentID: r.ParentID, + Permissions: convertToRBACPermissions(r.Permissions), + } + + return role, nil +} + +func (s *GormStore) UpdateRole(ctx context.Context, role rbac.Role) error { + var permIDs []string + for _, p := range role.Permissions { + permIDs = append(permIDs, p.ID) + } + + var perms []Permission + if len(permIDs) > 0 { + if err := s.db.WithContext(ctx).Find(&perms, "id IN ?", permIDs).Error; err != nil { + return err + } + if len(perms) != len(permIDs) { + return fmt.Errorf("some permissions not found for IDs: %v", permIDs) + } + } + + r := Role{ + ID: role.ID, + Name: role.Name, + Description: role.Description, + ParentID: role.ParentID, + } + + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Model(&r).Association("Permissions").Replace(perms); err != nil { + return err + } + return tx.Save(&r).Error + }) +} + +func (s *GormStore) RemoveRole(ctx context.Context, roleID string) error { + return s.db.WithContext(ctx).Delete(&Role{}, "id = ?", roleID).Error +} + +func (s *GormStore) ListRoles(ctx context.Context) ([]rbac.Role, error) { + var roles []Role + err := s.db.WithContext(ctx).Preload("Permissions").Find(&roles).Error + if err != nil { + return nil, err + } + + result := make([]rbac.Role, 0, len(roles)) + for _, r := range roles { + role := rbac.Role{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + ParentID: r.ParentID, + Permissions: convertToRBACPermissions(r.Permissions), + } + result = append(result, role) + } + + return result, nil +} + +func (s *GormStore) CreatePermission(ctx context.Context, p rbac.Permission) error { + perm := Permission{ + ID: p.ID, + Resource: p.Resource, + Action: p.Action, + BizRule: p.BizRule, + } + return s.db.WithContext(ctx).Create(&perm).Error +} + +func (s *GormStore) GetPermission(ctx context.Context, id string) (rbac.Permission, error) { + var permission Permission + err := s.db.WithContext(ctx).First(&permission, "id = ?", id).Error + if err != nil { + return rbac.Permission{}, err + } + + return rbac.Permission{ + ID: permission.ID, + Resource: permission.Resource, + Action: permission.Action, + BizRule: permission.BizRule, + }, nil +} + +func (s *GormStore) ListPermissions(ctx context.Context) ([]rbac.Permission, error) { + var perms []Permission + err := s.db.WithContext(ctx).Find(&perms).Error + if err != nil { + return nil, err + } + + result := make([]rbac.Permission, 0, len(perms)) + for _, perm := range perms { + result = append(result, rbac.Permission{ + ID: perm.ID, + Resource: perm.Resource, + Action: perm.Action, + BizRule: perm.BizRule, + }) + } + return result, nil +} + +func (s *GormStore) RemovePermission(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&Permission{}, "id = ?", id).Error +} + +func (s *GormStore) CreateSubject(ctx context.Context, subject rbac.Subject) error { + sub := Subject{ + ID: subject.ID, + RoleID: subject.RoleID, + } + return s.db.WithContext(ctx).Create(&sub).Error +} + +func (s *GormStore) GetSubject(ctx context.Context, subjectID string) (rbac.Subject, error) { + var sub Subject + err := s.db.WithContext(ctx).First(&sub, "id = ?", subjectID).Error + if err != nil { + return rbac.Subject{}, err + } + + return rbac.Subject{ + ID: sub.ID, + RoleID: sub.RoleID, + }, nil +} + +func (s *GormStore) UpdateSubject(ctx context.Context, subject rbac.Subject) error { + sub := Subject{ + ID: subject.ID, + RoleID: subject.RoleID, + } + return s.db.WithContext(ctx).Save(&sub).Error +} + +func (s *GormStore) ListSubjects(ctx context.Context) ([]string, error) { + var subjects []Subject + err := s.db.WithContext(ctx).Find(&subjects).Error + if err != nil { + return nil, err + } + + result := make([]string, 0, len(subjects)) + for _, sub := range subjects { + result = append(result, sub.ID) + } + + return result, nil +} + +func convertPermissionsFromRBAC(rbacPerms []rbac.Permission) []Permission { + perms := make([]Permission, 0, len(rbacPerms)) + for _, p := range rbacPerms { + perms = append(perms, Permission{ + ID: p.ID, + Resource: p.Resource, + Action: p.Action, + BizRule: p.BizRule, + }) + } + return perms +} + +func convertToRBACPermissions(perms []Permission) []rbac.Permission { + rbacPerms := make([]rbac.Permission, 0, len(perms)) + for _, p := range perms { + rbacPerms = append(rbacPerms, rbac.Permission{ + ID: p.ID, + Resource: p.Resource, + Action: p.Action, + BizRule: p.BizRule, + }) + } + return rbacPerms +} diff --git a/pkg/types.go b/pkg/types.go index 6e9e14e..8622b63 100644 --- a/pkg/types.go +++ b/pkg/types.go @@ -1,27 +1,24 @@ package rbac +// Permission represents an action that can be performed on a resource type Permission struct { ID string `json:"id"` Resource string `json:"resource"` Action string `json:"action"` + BizRule string `json:"biz_rule,omitempty"` } +// Role represents a collection of permissions with optional parent hierarchy type Role struct { ID string `json:"id"` Name string `json:"name"` Description string `json:"description"` + ParentID string `json:"parent_id,omitempty"` Permissions []Permission `json:"permissions"` } -type Grant struct { - ID string `json:"id"` - Resource string `json:"resource"` - ResourceID string `json:"resource_id"` - Action string `json:"action"` -} - -type User struct { - ID string `json:"id"` - Roles []Role `json:"roles"` - Grants []Grant `json:"grants"` +// Subject represents a subject with a single assigned role +type Subject struct { + ID string `json:"id"` + RoleID string `json:"role_id"` } diff --git a/pkg/validation.go b/pkg/validation.go new file mode 100644 index 0000000..5c17cca --- /dev/null +++ b/pkg/validation.go @@ -0,0 +1,68 @@ +package rbac + +import "context" + +func (r *RBAC) roleNameExists(ctx context.Context, name string) (bool, error) { + roles, err := r.store.ListRoles(ctx) + if err != nil { + return false, err + } + for _, role := range roles { + if role.Name == name { + return true, nil + } + } + return false, nil +} + +func (r *RBAC) permissionExists(ctx context.Context, resource, action, bizRule string) (bool, error) { + perms, err := r.store.ListPermissions(ctx) + if err != nil { + return false, err + } + for _, p := range perms { + if p.Resource != resource { + continue + } + if p.Action != action { + continue + } + if p.BizRule != bizRule { + continue + } + return true, nil + } + return false, nil +} + +func (r *RBAC) isRoleInUse(ctx context.Context, roleID string) (bool, error) { + subjects, err := r.store.ListSubjects(ctx) + if err != nil { + return false, err + } + for _, subjectID := range subjects { + subject, err := r.store.GetSubject(ctx, subjectID) + if err != nil { + continue + } + if subject.RoleID == roleID { + return true, nil + } + } + return false, nil +} + +func (r *RBAC) isPermissionInUse(ctx context.Context, permID string) (bool, error) { + roles, err := r.store.ListRoles(ctx) + if err != nil { + return false, err + } + for _, role := range roles { + for _, p := range role.Permissions { + if p.ID == permID { + return true, nil + } + } + } + return false, nil +}