Skip to content

Commit bf5930d

Browse files
Add function support to skrypt
Co-Authored-By: Serg Kryvonos <[email protected]>
1 parent 03d60e3 commit bf5930d

File tree

4 files changed

+208
-1
lines changed

4 files changed

+208
-1
lines changed

libskrypt/skrypt.cpp

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ bool Skrypt::Add(std::string_view line) {
159159
return Add(v);
160160
}
161161

162+
bool Skrypt::Add(const ::omnn::math::Variable& var, ::omnn::math::Valuable&& val) {
163+
return base::Add(var, std::move(val));
164+
}
165+
162166
namespace {
163167
std::string Questionless(std::string s) {
164168
s.erase(std::remove(s.begin(), s.end(), '?'), s.end());
@@ -170,6 +174,132 @@ namespace {
170174
auto FindBracePos(std::string_view line) {
171175
return line.find_first_of("{}[]"sv);
172176
}
177+
178+
bool IsFunctionDefinition(std::string_view line) {
179+
auto openParenPos = line.find('(');
180+
if (openParenPos == std::string_view::npos) return false;
181+
182+
auto closeParenPos = line.find(')', openParenPos);
183+
if (closeParenPos == std::string_view::npos) return false;
184+
185+
auto equalPos = line.find('=', closeParenPos);
186+
if (equalPos == std::string_view::npos) return false;
187+
188+
if (openParenPos == 0) return false;
189+
190+
return true;
191+
}
192+
193+
bool IsFunctionApplication(std::string_view line) {
194+
auto equalPos = line.find('=');
195+
auto openParenPos = line.find('(');
196+
197+
if (openParenPos == std::string_view::npos) return false;
198+
199+
if (equalPos != std::string_view::npos && openParenPos < equalPos) return false;
200+
201+
auto closeParenPos = line.find(')', openParenPos);
202+
if (closeParenPos == std::string_view::npos) return false;
203+
204+
if (openParenPos == 0) return false;
205+
206+
return true;
207+
}
208+
209+
std::vector<std::string> ParseFunctionArgs(std::string_view argsStr) {
210+
std::vector<std::string> args;
211+
std::string currentArg;
212+
213+
for (size_t i = 0; i < argsStr.size(); ++i) {
214+
char c = argsStr[i];
215+
if (c == ',') {
216+
boost::algorithm::trim(currentArg);
217+
if (!currentArg.empty()) {
218+
args.push_back(currentArg);
219+
currentArg.clear();
220+
}
221+
} else {
222+
currentArg += c;
223+
}
224+
}
225+
226+
boost::algorithm::trim(currentArg);
227+
if (!currentArg.empty()) {
228+
args.push_back(currentArg);
229+
}
230+
231+
return args;
232+
}
233+
}
234+
235+
bool Skrypt::ProcessFunctionDefinition(std::string_view& line)
236+
{
237+
if (!IsFunctionDefinition(line)) {
238+
return false;
239+
}
240+
241+
auto openParenPos = line.find('(');
242+
auto closeParenPos = line.find(')', openParenPos);
243+
auto equalPos = line.find('=', closeParenPos);
244+
245+
std::string_view funcName = line.substr(0, openParenPos);
246+
boost::algorithm::trim(funcName);
247+
248+
std::string_view argsStr = line.substr(openParenPos + 1, closeParenPos - openParenPos - 1);
249+
auto args = ParseFunctionArgs(argsStr);
250+
251+
std::string_view exprStr = line.substr(equalPos + 1);
252+
boost::algorithm::trim(exprStr);
253+
254+
std::string lambdaExpr = std::string(exprStr);
255+
256+
auto& funcVar = varHost->Host(std::string(funcName));
257+
258+
std::string functionDef = "Function(";
259+
functionDef += std::string(funcName);
260+
functionDef += ", [";
261+
for (size_t i = 0; i < args.size(); ++i) {
262+
if (i > 0) functionDef += ", ";
263+
functionDef += args[i];
264+
}
265+
functionDef += "], ";
266+
functionDef += std::string(exprStr);
267+
functionDef += ")";
268+
269+
Add(funcVar, ::omnn::math::Valuable(functionDef, varHost));
270+
271+
return true;
272+
}
273+
274+
bool Skrypt::ProcessFunctionApplication(std::string_view& line)
275+
{
276+
if (!IsFunctionApplication(line)) {
277+
return false;
278+
}
279+
280+
auto equalPos = line.find('=');
281+
std::string_view leftSide;
282+
std::string_view rightSide;
283+
284+
if (equalPos != std::string_view::npos) {
285+
leftSide = line.substr(0, equalPos);
286+
boost::algorithm::trim(leftSide);
287+
rightSide = line.substr(equalPos + 1);
288+
} else {
289+
rightSide = line;
290+
}
291+
292+
boost::algorithm::trim(rightSide);
293+
294+
if (equalPos != std::string_view::npos) {
295+
auto& resultVar = varHost->Host(std::string(leftSide));
296+
297+
Add(resultVar, ::omnn::math::Valuable(rightSide, varHost));
298+
} else {
299+
Add(::omnn::math::Valuable(rightSide, varHost));
300+
}
301+
302+
return true;
173303
}
174304

175305
void Skrypt::ProcessQuestionLine(std::string_view& line)
@@ -332,6 +462,8 @@ bool Skrypt::ParseNextLine(std::istream& in, std::string_view& line)
332462

333463
if (boost::algorithm::contains(line, "?")) {
334464
ProcessQuestionLine(line);
465+
} else if (ProcessFunctionDefinition(line)) {
466+
} else if (ProcessFunctionApplication(line)) {
335467
} else {
336468
Valuable expression;
337469
try {
@@ -676,4 +808,4 @@ Skrypt::Skrypt(const skrypt::Skrypt& that)
676808
, sourceFilePath(that.sourceFilePath)
677809
, moduleFileSearchAdditionalPaths(that.moduleFileSearchAdditionalPaths)
678810
{
679-
}
811+
}

libskrypt/skrypt.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,13 @@ class Skrypt
6565

6666
bool Add(::omnn::math::Valuable&&);
6767
bool Add(std::string_view);
68+
bool Add(const ::omnn::math::Variable& var, ::omnn::math::Valuable&& val);
6869
bool ParseNextLine(std::istream&, std::string_view&);
6970
void PrintVarKnowns(const omnn::math::Variable&);
7071
void PrintAllKnowns();
7172
void ProcessQuestionLine(std::string_view&);
73+
bool ProcessFunctionDefinition(std::string_view&);
74+
bool ProcessFunctionApplication(std::string_view&);
7275

7376
void BindTargetStream(std::ostream&);
7477
void BindTargetStream(std::shared_ptr<std::ostream>);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
f(x) = x^2
2+
y = f(2)
3+
g(x, y) = x*y + x^2
4+
z = g(2, 3)
5+
h = f(g(3, 1))
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#define BOOST_TEST_MODULE skryptFunction test
2+
#include <boost/test/unit_test.hpp>
3+
#include "skrypt.h"
4+
5+
using namespace ::skrypt;
6+
using namespace std::string_literals;
7+
using namespace std::string_view_literals;
8+
9+
BOOST_AUTO_TEST_CASE(FunctionDefinitionTest) {
10+
Skrypt skrypt;
11+
12+
skrypt.Add("f(x) = x^2"sv);
13+
14+
auto varhost = skrypt.GetVarHost();
15+
auto& f = varhost->Host("f"s);
16+
17+
auto& solutions = skrypt.Known(f);
18+
BOOST_TEST(!solutions.empty());
19+
20+
skrypt.Add("y = f(2)"sv);
21+
auto& y = varhost->Host("y"s);
22+
auto& ySolutions = skrypt.Known(y);
23+
BOOST_TEST(ySolutions.size() == 1);
24+
for (auto& solution : ySolutions) {
25+
BOOST_TEST(solution == 4);
26+
}
27+
28+
skrypt.Add("z = 3"sv);
29+
skrypt.Add("w = f(z)"sv);
30+
auto& w = varhost->Host("w"s);
31+
auto& wSolutions = skrypt.Known(w);
32+
BOOST_TEST(wSolutions.size() == 1);
33+
for (auto& solution : wSolutions) {
34+
BOOST_TEST(solution == 9);
35+
}
36+
}
37+
38+
BOOST_AUTO_TEST_CASE(MultipleArgumentFunctionTest) {
39+
Skrypt skrypt;
40+
41+
skrypt.Add("g(x, y) = x*y + x^2"sv);
42+
43+
skrypt.Add("result = g(2, 3)"sv);
44+
auto varhost = skrypt.GetVarHost();
45+
auto& result = varhost->Host("result"s);
46+
auto& resultSolutions = skrypt.Known(result);
47+
BOOST_TEST(resultSolutions.size() == 1);
48+
for (auto& solution : resultSolutions) {
49+
BOOST_TEST(solution == 10); // 2*3 + 2^2 = 6 + 4 = 10
50+
}
51+
}
52+
53+
BOOST_AUTO_TEST_CASE(ComposedFunctionTest) {
54+
Skrypt skrypt;
55+
56+
skrypt.Add("f(x) = x^2"sv);
57+
skrypt.Add("g(x) = 2*x + 1"sv);
58+
59+
skrypt.Add("h = f(g(3))"sv);
60+
auto varhost = skrypt.GetVarHost();
61+
auto& h = varhost->Host("h"s);
62+
auto& hSolutions = skrypt.Known(h);
63+
BOOST_TEST(hSolutions.size() == 1);
64+
for (auto& solution : hSolutions) {
65+
BOOST_TEST(solution == 49); // f(g(3)) = f(2*3+1) = f(7) = 7^2 = 49
66+
}
67+
}

0 commit comments

Comments
 (0)