Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
47ba4cd
Fixed bug in numBinScalarOp and supported ewadd, ewmul, and smul
lexu27 Aug 15, 2025
c979269
Added associativity and distributive operator properties to smul
lexu27 Aug 15, 2025
34aa319
Finished TASO Relu
lexu27 Aug 15, 2025
a8a2103
finished concat rules
lexu27 Aug 16, 2025
60c282e
added geometry rule for concat
lexu27 Aug 16, 2025
1ecd2d1
Finished transpose implementation and added colored printing to verif…
lexu27 Aug 20, 2025
f10a792
Added pretty printing for verification in Verify.hs
lexu27 Aug 20, 2025
a6961c5
finished enlarge
lexu27 Aug 22, 2025
9030966
Remove extra files
jaiarora0011 Sep 4, 2025
8ec47fc
Update hie.yaml
jaiarora0011 Sep 4, 2025
51347fd
Pretty print warnings; fix runall.sh
jaiarora0011 Sep 5, 2025
581fb6f
PrintTitle for XLA rules
jaiarora0011 Sep 5, 2025
44bbbbd
Merging to new branch changes
lexu27 Sep 15, 2025
d4e11aa
Merge remote-tracking branch 'refs/remotes/origin/taso_rules' into ta…
lexu27 Sep 15, 2025
a611d29
Added rank precondition
lexu27 Sep 16, 2025
220cdac
Fixed a bug in rankConditions
lexu27 Sep 16, 2025
1c8d22d
Fixed intractible verification for enlarge
lexu27 Oct 2, 2025
fe1ebca
added ewadd identity matrix rule
lexu27 Oct 2, 2025
d0baf85
finished split
lexu27 Oct 8, 2025
9510300
Finished convolution operator
lexu27 Oct 19, 2025
3e858e0
Fixed haskell formatting and hie.yaml
lexu27 Oct 19, 2025
3282dee
Use stack config for gen-hie
jaiarora0011 Nov 5, 2025
29aeda0
Remove addition lemmas for TensorElemSum
jaiarora0011 Nov 6, 2025
64094e6
Remove incorrect clamp operation with TensorElemSum
jaiarora0011 Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ result.txt
.stack-work/
.envrc
.direnv
.vscode/
dist-newstyle/
66 changes: 66 additions & 0 deletions hie.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,72 @@ cradle:
- path: "./rules/debug/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-debug"

- path: "./rules/taso/concat/Main.hs"
component: "tensor-right:exe:rules-taso-concat"

- path: "./rules/taso/concat/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-concat"

- path: "./rules/taso/conv/Main.hs"
component: "tensor-right:exe:rules-taso-conv"

- path: "./rules/taso/conv/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-conv"

- path: "./rules/taso/enlarge/Main.hs"
component: "tensor-right:exe:rules-taso-enlarge"

- path: "./rules/taso/enlarge/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-enlarge"

- path: "./rules/taso/ewadd/Main.hs"
component: "tensor-right:exe:rules-taso-ewadd"

- path: "./rules/taso/ewadd/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-ewadd"

- path: "./rules/taso/ewmul/Main.hs"
component: "tensor-right:exe:rules-taso-ewmul"

- path: "./rules/taso/ewmul/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-ewmul"

- path: "./rules/taso/matmul2D/Main.hs"
component: "tensor-right:exe:rules-taso-matmul2D"

- path: "./rules/taso/matmul2D/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-matmul2D"

- path: "./rules/taso/matmul3D/Main.hs"
component: "tensor-right:exe:rules-taso-matmul3D"

- path: "./rules/taso/matmul3D/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-matmul3D"

- path: "./rules/taso/relu/Main.hs"
component: "tensor-right:exe:rules-taso-relu"

- path: "./rules/taso/relu/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-relu"

- path: "./rules/taso/smul/Main.hs"
component: "tensor-right:exe:rules-taso-smul"

- path: "./rules/taso/smul/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-smul"

- path: "./rules/taso/split/Main.hs"
component: "tensor-right:exe:rules-taso-split"

- path: "./rules/taso/split/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-split"

- path: "./rules/taso/transpose/Main.hs"
component: "tensor-right:exe:rules-taso-transpose"

- path: "./rules/taso/transpose/Paths_tensor_right.hs"
component: "tensor-right:exe:rules-taso-transpose"

- path: "./rules/xla/add/Main.hs"
component: "tensor-right:exe:rules-xla-add"

Expand Down
67 changes: 67 additions & 0 deletions package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,73 @@ executables:
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
# TASO Executables
rules-taso-ewadd:
source-dirs: rules/taso/ewadd
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-ewmul:
source-dirs: rules/taso/ewmul
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-smul:
source-dirs: rules/taso/smul
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-relu:
source-dirs: rules/taso/relu
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-concat:
source-dirs: rules/taso/concat
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-transpose:
source-dirs: rules/taso/transpose
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-enlarge:
source-dirs: rules/taso/enlarge
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-matmul2D:
source-dirs: rules/taso/matmul2D
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-matmul3D:
source-dirs: rules/taso/matmul3D
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-split:
source-dirs: rules/taso/split
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
rules-taso-conv:
source-dirs: rules/taso/conv
main: Main.hs
dependencies: tensor-right
ghc-options: *exe-ghc-options
default-extensions: *exe-extensions
# Other Executables
rules-debug:
source-dirs: rules/debug
Expand Down
2 changes: 1 addition & 1 deletion plot/timing_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def first_num_of_tasks(self) -> int:
def overall_time(self) -> float:
return sum(x.time for x in self.results)


# TODO: Handle ANSI color escape codes
def parse_file(lines: Sequence[str]) -> list[Rule]:
"""
The file looks like this:
Expand Down
95 changes: 95 additions & 0 deletions rules/taso/concat/Main.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
{-# OPTIONS_GHC -Wno-missing-import-lists #-}

import Grisette hiding ((-->))
import TensorRight
import TensorRight.Internal.DSL.TASO (concat, ewadd, ewmul, relu, smul)
import Prelude hiding (concat)

desugar :: forall a. NumRule a
desugar _ = do
r <- newRClass "r"
[sa, sb] <- newMaps ["sa", "sb"] r
a <- newTensor @a "A" [r --> sa]
b <- newTensor @a "B" [r --> sb]
let d = ByRClass r
lhs <- concat d a b
rhs <- concatTensor a b d
rewrite "concat(d, A, B) ⇒ Concatenate((A, B), d)" lhs rhs

smulAssociativity :: forall a. NumRule a
smulAssociativity _ = do
let w = ("w" :: a)
r <- newRClass "r"
s <- newMap "s" r
x <- newTensor @a "x" [r --> s]
y <- newTensor @a "y" [r --> s]
let d = ByRClass r
lhs <- concat d (smul x w) (smul y w)
rhs <- smul (concat d x y) w
rewrite "concat(d, smul(x, w), smul(y, w)) ⇒ smul(concat(d, x, y), w)" lhs rhs

ewaddAssociativity :: forall a. NumRule a
ewaddAssociativity _ = do
r <- newRClass "r"
s <- newMap "s" r
x <- newTensor @a "x" [r --> s]
y <- newTensor @a "y" [r --> s]
z <- newTensor @a "z" [r --> s]
w <- newTensor @a "w" [r --> s]
let d = ByRClass r
lhs <- concat d (ewadd x y) (ewadd z w)
rhs <- ewadd (concat d x z) (concat d y w)
rewrite "concat(d, ewadd(x, y), ewadd(z, w)) ⇒ ewadd(concat(d, x, z), concat(d, y, w))" lhs rhs

ewmulAssociativity :: forall a. NumRule a
ewmulAssociativity _ = do
r <- newRClass "r"
s <- newMap "s" r
x <- newTensor @a "x" [r --> s]
y <- newTensor @a "y" [r --> s]
z <- newTensor @a "z" [r --> s]
w <- newTensor @a "w" [r --> s]
let d = ByRClass r
lhs <- concat d (ewmul x y) (ewmul z w)
rhs <- ewmul (concat d x z) (concat d y w)
rewrite "concat(d, ewmul(x, y), ewmul(z, w)) ⇒ ewmul(concat(d, x, z), concat(d, y, w))" lhs rhs

reluAssociativity :: forall a. NumRule a
reluAssociativity _ = do
r <- newRClass "r"
s <- newMap "s" r
x <- newTensor @a "x" [r --> s]
y <- newTensor @a "y" [r --> s]
let d = ByRClass r
lhs <- concat d (relu @a x) (relu @a y)
rhs <- relu @a $ concat d x y
rewrite "" lhs rhs

geometry :: forall a. NumRule a
geometry _ = do
[d0, d1, d2] <- newRClasses ["d0", "d1", "d2"]
d0S <- newMap "d0S" d0
d1S <- newMap "d1S" d1
d2S <- newMap "d2S" d2
x <- newTensor @a "x" [d0 --> d0S, d1 --> d1S, d2 --> d2S]
y <- newTensor @a "y" [d0 --> d0S, d1 --> d1S, d2 --> d2S]
z <- newTensor @a "z" [d0 --> d0S, d1 --> d1S, d2 --> d2S]
w <- newTensor @a "w" [d0 --> d0S, d1 --> d1S, d2 --> d2S]
lhs <- concat (ByRClass d0) (concat (ByRClass d1) x y) (concat (ByRClass d1) z w)
rhs <- concat (ByRClass d1) (concat (ByRClass d0) x z) (concat (ByRClass d0) y w)
rewrite "concat(d0, concat(d1, x, y), concat(d1, z, w)) ⇒ concat(d1, concat(d0, x, z), concat(0, y, w))" lhs rhs

main :: IO ()
main = do
printTitle "######################## desugarOneRole ########################"
verifyNumDSL desugar
printTitle "######################## smulAssociativity #####################"
verifyNumDSL smulAssociativity
printTitle "######################## ewaddAssociativity ####################"
verifyNumDSL ewaddAssociativity
printTitle "######################## ewmulAssociativity ####################"
verifyNumDSL ewmulAssociativity
printTitle "######################## reluAssociativity ####################"
verifyNumDSL reluAssociativity
printTitle "######################## geometry #############################"
verifyNumDSL geometry
Loading