Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions src/Agda/Compiler/Rust/AgdaToRustExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ module Agda.Compiler.Rust.AgdaToRustExpr ( compile, compileModule ) where
import Control.Monad.IO.Class ( MonadIO(liftIO) )
import qualified Data.List.NonEmpty as Nel

import Agda.Compiler.Backend ( IsMain )
import Agda.Compiler.Backend ( Defn(..), funCompiled, funClauses, IsMain, RecordData(..))
import Agda.Syntax.Abstract.Name ( QName )
import Agda.Syntax.Common.Pretty ( prettyShow )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), moduleNameParts )
import Agda.Syntax.Common ( moduleNameParts )
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), NamedName, WithOrigin(..), Ranged(..) )
import Agda.Syntax.Internal (
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..),
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), Dom'(..), unDom, PatternInfo(..), Pattern'(..),
qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) )
import Agda.Syntax.TopLevelModuleName ( TopLevelModuleName )
import Agda.TypeChecking.Monad.Base ( Definition(..) )
import Agda.TypeChecking.Monad
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )
import Agda.TypeChecking.Telescope ( teleNamedArgs, teleArgs, teleArgNames )

import Agda.Compiler.Rust.CommonTypes ( Options, CompiledDef, ModuleEnv )
import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustName, RustType, RustElem(..), FunBody )
Expand All @@ -30,24 +32,42 @@ compile _ _ _ Defn{..}

compileDefn :: QName -> Defn -> CompiledDef
compileDefn defName theDef =
-- https://hackage.haskell.org/package/Agda/docs/Agda-Compiler-Backend.html#t:Defn
case theDef of
Datatype{dataCons = fields} ->
compileDataType defName fields
Function{funCompiled = funDef, funClauses = fc} ->
compileFunction defName funDef fc
_ ->
Unhandled "compileDefn" (show defName ++ " = " ++ show theDef)
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
compileRecord defName recFields recTel
other ->
Unhandled "compileDefn" (show defName ++ "\n = \n" ++ show theDef)

compileDataType :: QName -> [QName] -> CompiledDef
compileDataType defName fields = TeEnum (showName defName) (map showName fields)
compileDataType defName fields = ReEnum (showName defName) (map showName fields)

compileRecord :: QName -> [Dom QName] -> Telescope -> CompiledDef
compileRecord defName recFields recTel = ReRec (showName defName) (foldl varsFromTelescope [] recTel)

varsFromTelescope :: [RustElem] -> Dom Type -> [RustElem]
varsFromTelescope xs dt = RustElem (nameFromDom dt) (fromDom dt) : xs

nameFromDom :: Dom Type -> RustName
nameFromDom dt = case (domName dt) of
Nothing -> error ("\nnameFromDom [" ++ show dt ++ "]\n")
Just a -> namedNameToStr a

-- https://hackage.haskell.org/package/Agda-2.6.4.3/docs/Agda-Syntax-Common.html#t:NamedName
namedNameToStr :: NamedName -> RustName
namedNameToStr n = rangedThing (woThing n)

compileFunction :: QName
-> Maybe CompiledClauses
-> [Clause]
-> CompiledDef
compileFunction defName funDef fc = TeFun
compileFunction defName funDef fc = ReFun
(showName defName)
(RustElem (compileFunctionArgument fc) (compileFunctionArgType fc))
[(RustElem (compileFunctionArgument fc) (compileFunctionArgType fc))]
(compileFunctionResultType fc)
(compileFunctionBody funDef)

Expand Down Expand Up @@ -120,7 +140,7 @@ showName = prettyShow . qnameName

compileModule :: TopLevelModuleName -> [CompiledDef] -> CompiledDef
compileModule mName cdefs =
TeMod (moduleName mName) cdefs
ReMod (moduleName mName) cdefs

moduleName :: TopLevelModuleName -> String
moduleName n = prettyShow (Nel.last (moduleNameParts n))
17 changes: 13 additions & 4 deletions src/Agda/Compiler/Rust/PrettyPrintingUtils.hs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
module Agda.Compiler.Rust.PrettyPrintingUtils ( prettyPrintRustExpr, moduleHeader ) where

import Data.List ( intersperse )
import Data.List ( intersperse, intercalate )
import Agda.Compiler.Rust.CommonTypes ( CompiledDef )
import Agda.Compiler.Rust.RustExpr ( RustExpr(..), RustElem(..), FunBody )

prettyPrintRustExpr :: CompiledDef -> String
prettyPrintRustExpr def = case def of
(TeEnum name fields) ->
(ReEnum name fields) ->
"enum" <> exprSeparator
<> name
<> exprSeparator
<> bracket (
indent -- TODO this is too simplistic indentation
<> concat (intersperse ", " fields))
<> defsSeparator
(TeFun fName (RustElem aName aType) resType fBody) ->
(ReFun fName [RustElem aName aType] resType fBody) ->
"pub fn" <> exprSeparator
<> fName
<> argList (
Expand All @@ -25,12 +25,15 @@ prettyPrintRustExpr def = case def of
<> exprSeparator <> bracket (
indent <> (prettyPrintFunctionBody fBody))
<> defsSeparator
(TeMod mName defs) ->
(ReMod mName defs) ->
moduleHeader mName
<> bracket (
defsSeparator -- empty line before first definition in module
<> combineLines (map prettyPrintRustExpr defs))
<> defsSeparator
(ReRec name args) -> "pub struct" <> exprSeparator <> name
<> exprSeparator <> (bracket (combineThem ",\n" (map (indent ++) (map printVar args))))
<> defsSeparator
(Unhandled name payload) -> ""
-- XXX at the end there should be no Unhandled expression
-- other -> "unsupported prettyPrintRustExpr " ++ (show other)
Expand All @@ -41,6 +44,9 @@ bracket str = "{\n" <> str <> "\n}"
argList :: String -> String
argList str = "(" <> str <> ")"

printVar :: RustElem -> String
printVar (RustElem sName sType) = sName <> ":" <> exprSeparator <> sType

indent :: String
indent = " "

Expand All @@ -59,6 +65,9 @@ funReturnTypeSeparator = "->"
combineLines :: [String] -> String
combineLines xs = unlines (filter (not . null) xs)

combineThem :: String -> [String] -> String
combineThem s xs = intercalate s xs

prettyPrintFunctionBody :: FunBody -> String
prettyPrintFunctionBody fBody = "return" <> exprSeparator <> fBody <> ";"

Expand Down
7 changes: 4 additions & 3 deletions src/Agda/Compiler/Rust/RustExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ data RustElem = RustElem RustName RustType
deriving ( Show )

data RustExpr
= TeMod RustName [RustExpr]
| TeEnum RustName [RustName]
| TeFun RustName RustElem RustType FunBody
= ReMod RustName [RustExpr]
| ReEnum RustName [RustName]
| ReFun RustName [RustElem] RustType FunBody
| ReRec RustName [RustElem]
| Unhandled RustName String
deriving ( Show )

Expand Down
20 changes: 10 additions & 10 deletions test/hello.agda
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,16 @@ id_rgb x = x

-- product types

-- record ThePair : Set where
-- field
-- pairFst : Rgb
-- pairSnd : WeekDay
-- {-# COMPILE AGDA2RUST ThePair #-}
record ThePair : Set where
field
pairFst : Rgb
pairSnd : WeekDay
{-# COMPILE AGDA2RUST ThePair #-}

-- record Foo (A : Set) : Set where
-- field
-- foo : Pair A A

-- TODO Data.Product as Rust tuple
record Foo : Set where
field
foo : ThePair
{-# COMPILE AGDA2RUST Foo #-}

-- TODO function returning constant result
-- as-friday : TheRgb → TheWeekDay
Expand All @@ -53,6 +52,7 @@ id_rgb x = x

-- TODO polymorphic types

-- TODO Data.Product as Rust tuple
-- TODO Data.Bool
-- TODO if expressions, and, or

Expand Down
9 changes: 9 additions & 0 deletions test/hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,14 @@ pub fn id_rgb(x: Rgb) -> Rgb {
return x;
}

pub struct ThePair {
pairSnd: WeekDay,
pairFst: Rgb
}

pub struct Foo {
foo: ThePair
}


}