- 
                Notifications
    
You must be signed in to change notification settings  - Fork 79
 
add example for symmetric matrices #2545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| 
          
 Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/docs/src/notebooks/structured.jl b/docs/src/notebooks/structured.jl
index 359b215..f21e9c7 100644
--- a/docs/src/notebooks/structured.jl
+++ b/docs/src/notebooks/structured.jl
@@ -28,19 +28,19 @@ import EnzymeCore: EnzymeRules
 
 # ╔═╡ 4aa3740e-211c-4fff-9707-b8731a3fa57f
 begin
-	struct MySymmetric{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
-	    data::S
-	    uplo::Char
-	
-	    function MySymmetric{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
-	        LinearAlgebra.require_one_based_indexing(data)
-	        (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
-	        new{T,S}(data, uplo)
-	    end
-	end
-	function MySymmetric(A, uplo='U')
-		 MySymmetric{eltype(A), typeof(A)}(A, 'U')
-	end
+    struct MySymmetric{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+        data::S
+        uplo::Char
+
+        function MySymmetric{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+            LinearAlgebra.require_one_based_indexing(data)
+            (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+            return new{T, S}(data, uplo)
+        end
+    end
+    function MySymmetric(A, uplo = 'U')
+        return MySymmetric{eltype(A), typeof(A)}(A, 'U')
+    end
 end
 
 # ╔═╡ d6572143-fc11-4fdc-9c23-34867b33ad85
@@ -57,13 +57,15 @@ end
 
 # ╔═╡ 8f6588f2-754b-435c-9998-38da8f6b14ad
 begin
-	Base.size(A::MySymmetric) = size(A.data)
-	Base.length(A::MySymmetric) = length(A.data)
+    Base.size(A::MySymmetric) = size(A.data)
+    Base.length(A::MySymmetric) = length(A.data)
 end
 
 # ╔═╡ de04ba90-a397-4b75-b6b4-977d5881e848
-x = [1.0 0.0
-	 0.0 1.0]
+x = [
+    1.0 0.0
+    0.0 1.0
+]
 
 # ╔═╡ 9beada1d-6281-4dee-9049-8a66af1199a4
 norm(x)
@@ -72,11 +74,13 @@ norm(x)
 Enzyme.gradient(Reverse, norm, x) |> only
 
 # ╔═╡ 4be9a16a-c14b-4378-aa02-fc4bfa783d10
-Enzyme.gradient(Reverse, norm, MySymmetric(x))|> only
+Enzyme.gradient(Reverse, norm, MySymmetric(x)) |> only
 
 # ╔═╡ de7e14eb-4e55-4661-8614-e8026d09e6d3
-x2 = [0.0 1.0
-	  1.0 0.0]
+x2 = [
+    0.0 1.0
+    1.0 0.0
+]
 
 # ╔═╡ 999fb61b-20c6-4dcf-ad34-eca257bfda9f
 d_x2 = Enzyme.gradient(Reverse, norm, x2) |> only
@@ -91,26 +95,26 @@ d_x2 == d_x2_sym
 sum(d_x2) == sum(d_x2_sym.data)
 
 # ╔═╡ 827485bf-0973-4650-a969-6225f72e5d6a
- Symmetric(x2) |> dump
+Symmetric(x2) |> dump
 
 # ╔═╡ dbe34880-93bf-4a5d-b28b-5e6b76267742
- d_x2_sym |> dump
+d_x2_sym |> dump
 
 # ╔═╡ 0122a4df-75d9-444e-8d83-d7a93b6dfeb5
 begin
-	struct MySymmetric2{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
-	    data::S
-	    uplo::Char
-	
-	    function MySymmetric2{T,S}(data, uplo::Char) where {T,S<:AbstractMatrix{<:T}}
-	        LinearAlgebra.require_one_based_indexing(data)
-	        (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
-	        new{T,S}(data, uplo)
-	    end
-	end
-	function MySymmetric2(A, uplo='U')
-		 MySymmetric2{eltype(A), typeof(A)}(A, 'U')
-	end
+    struct MySymmetric2{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
+        data::S
+        uplo::Char
+
+        function MySymmetric2{T, S}(data, uplo::Char) where {T, S <: AbstractMatrix{<:T}}
+            LinearAlgebra.require_one_based_indexing(data)
+            (uplo != 'U' && uplo != 'L') && LinearAlgebra.throw_uplo()
+            return new{T, S}(data, uplo)
+        end
+    end
+    function MySymmetric2(A, uplo = 'U')
+        return MySymmetric2{eltype(A), typeof(A)}(A, 'U')
+    end
 end
 
 # ╔═╡ 8497709d-d123-48bc-a86e-5f58aa1b0ebc
@@ -149,8 +153,8 @@ md"""
 
 # ╔═╡ d0a031e4-99a4-417b-8f57-58a67219fa23
 begin
-	Base.size(A::MySymmetric2) = size(A.data)
-	Base.length(A::MySymmetric2) = length(A.data)
+    Base.size(A::MySymmetric2) = size(A.data)
+    Base.length(A::MySymmetric2) = length(A.data)
 end
 
 # ╔═╡ d047c162-446f-4de4-b2fd-5f2550f0ad78
@@ -160,34 +164,36 @@ Now we can implement a rule where we adjust the gradient contribution to be half
 
 # ╔═╡ e81ac66b-75ba-4e88-8e9c-491a60a671dc
 begin
-	function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
-	    # Compute primal
-	    if needs_primal(config)
-	        primal = func.val(S.val, i.val, j.val)
-	    else
-	        primal = nothing
-	    end
-	
-	    # Return an AugmentedReturn object with shadow = nothing
-	    return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
-	end
-
-	function EnzymeRules.reverse(config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
-	                 S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
-		i = i.val
-		j = j.val
-		A = S.val
-		dA = S.dval
-		@inbounds if i == j
-        	dA.data[i, j] += dret.val
-    	elseif (A.uplo == 'U') == (i < j)
-        	dA.data[i, j] += dret.val / 2
-    	else
-	        dA.data[j, i] += dret.val / 2
-    	end
-		
-	    return (nothing, nothing, nothing)
-	end
+    function EnzymeRules.augmented_primal(config, func::Const{typeof(Base.getindex)}, ::Type{<:Active}, S::Duplicated{<:MySymmetric2}, i::Const, j::Const)
+        # Compute primal
+        if needs_primal(config)
+            primal = func.val(S.val, i.val, j.val)
+        else
+            primal = nothing
+        end
+
+        # Return an AugmentedReturn object with shadow = nothing
+        return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
+    end
+
+    function EnzymeRules.reverse(
+            config, ::Const{typeof(Base.getindex)}, dret::Active, tape,
+            S::Duplicated{<:MySymmetric2}, i::Const, j::Const
+        )
+        i = i.val
+        j = j.val
+        A = S.val
+        dA = S.dval
+        @inbounds if i == j
+            dA.data[i, j] += dret.val
+        elseif (A.uplo == 'U') == (i < j)
+            dA.data[i, j] += dret.val / 2
+        else
+            dA.data[j, i] += dret.val / 2
+        end
+
+        return (nothing, nothing, nothing)
+    end
 end
 
 # ╔═╡ fe7138ce-950a-4f79-b176-cdb227d4c898 | 
    
          Benchmark Results
 Benchmark PlotsA plot of the benchmark results has been uploaded as an artifact at https://github.com/EnzymeAD/Enzyme.jl/actions/runs/17343787618/artifacts/3889340236.  | 
    
          Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@           Coverage Diff           @@
##             main    #2545   +/-   ##
=======================================
  Coverage   74.92%   74.92%           
=======================================
  Files          56       56           
  Lines       17428    17428           
=======================================
  Hits        13058    13058           
  Misses       4370     4370           ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
  | 
    
| dA = S.dval | ||
| @inbounds if i == j | ||
| dA.data[i, j] += dret.val | ||
| elseif (A.uplo == 'U') == (i < j) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about the correctness of this generally
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, which is why I haven't pushed this as a general rule for Symmetric, this notebook is more meant to explain to users what is happening.
No description provided.