Skip to content

Commit 2876381

Browse files
authored
Merge pull request #12 from CoLearn-Dev/init_func
- support init function
2 parents c397bce + 586f21d commit 2876381

File tree

4 files changed

+101
-2
lines changed

4 files changed

+101
-2
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "colink"
3-
version = "0.1.16"
3+
version = "0.1.17"
44
edition = "2021"
55
description = "CoLink Rust SDK"
66
license = "MIT"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ CoLink SDK helps both application adnd protocol developers access the functional
99
Add this to your Cargo.toml:
1010
```toml
1111
[dependencies]
12-
colink = "0.1.15"
12+
colink = "0.1.17"
1313
```
1414

1515
## Getting Started
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#![allow(unused_variables)]
2+
use colink::{CoLink, Participant, ProtocolEntry};
3+
4+
struct Init;
5+
#[colink::async_trait]
6+
impl ProtocolEntry for Init {
7+
async fn start(
8+
&self,
9+
cl: CoLink,
10+
_param: Vec<u8>, // For init function, param is empty
11+
_participants: Vec<Participant>, // For init function, participants is empty
12+
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
13+
// Init
14+
Ok(())
15+
}
16+
}
17+
18+
struct Initiator;
19+
#[colink::async_trait]
20+
impl ProtocolEntry for Initiator {
21+
async fn start(
22+
&self,
23+
cl: CoLink,
24+
param: Vec<u8>,
25+
participants: Vec<Participant>,
26+
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
27+
println!("initiator");
28+
Ok(())
29+
}
30+
}
31+
32+
struct Receiver;
33+
#[colink::async_trait]
34+
impl ProtocolEntry for Receiver {
35+
async fn start(
36+
&self,
37+
cl: CoLink,
38+
param: Vec<u8>,
39+
participants: Vec<Participant>,
40+
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
41+
println!("{}", String::from_utf8_lossy(&param));
42+
cl.create_entry(&format!("tasks:{}:output", cl.get_task_id()?), &param)
43+
.await?;
44+
Ok(())
45+
}
46+
}
47+
48+
colink::protocol_start!(
49+
("greetings:@init", Init), // bind init function
50+
("greetings:initiator", Initiator), // bind initiator's entry function
51+
("greetings:receiver", Receiver) // bind receiver's entry function
52+
);

src/protocol.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,54 @@ pub fn _protocol_start(
160160
cl: CoLink,
161161
user_funcs: HashMap<String, Box<dyn ProtocolEntry + Send + Sync>>,
162162
) -> Result<(), Error> {
163+
let mut operator_funcs: HashMap<String, Box<dyn ProtocolEntry + Send + Sync>> = HashMap::new();
164+
let mut protocols = vec![];
163165
for (protocol_and_role, user_func) in user_funcs {
166+
let cl = cl.clone();
167+
if protocol_and_role.ends_with(":@init") {
168+
let protocol_name = protocol_and_role[..protocol_and_role.len() - 6].to_string();
169+
let _ = tokio::runtime::Builder::new_multi_thread()
170+
.enable_all()
171+
.build()
172+
.unwrap()
173+
.block_on(async move {
174+
let is_initialized_key =
175+
format!("_internal:protocols:{}:_is_initialized", protocol_name);
176+
let lock = cl.lock(&is_initialized_key).await?;
177+
let res = cl.read_entry(&is_initialized_key).await;
178+
if res.is_err() || res.unwrap()[0] == 0 {
179+
let cl_clone = cl.clone();
180+
match user_func
181+
.start(cl_clone, Default::default(), Default::default())
182+
.await
183+
{
184+
Ok(_) => {}
185+
Err(e) => error!("{}: {}.", protocol_and_role, e),
186+
}
187+
cl.update_entry(&is_initialized_key, &[1]).await?;
188+
}
189+
cl.unlock(lock).await?;
190+
Ok::<(), Box<dyn std::error::Error + Send + Sync + 'static>>(())
191+
});
192+
} else {
193+
protocols.push(protocol_and_role[..protocol_and_role.rfind(':').unwrap()].to_string());
194+
operator_funcs.insert(protocol_and_role, user_func);
195+
}
196+
}
197+
let cl_clone = cl.clone();
198+
let _ = tokio::runtime::Builder::new_multi_thread()
199+
.enable_all()
200+
.build()
201+
.unwrap()
202+
.block_on(async move {
203+
for protocol_name in protocols {
204+
let is_initialized_key =
205+
format!("_internal:protocols:{}:_is_initialized", protocol_name);
206+
cl_clone.update_entry(&is_initialized_key, &[1]).await?;
207+
}
208+
Ok::<(), Box<dyn std::error::Error + Send + Sync + 'static>>(())
209+
});
210+
for (protocol_and_role, user_func) in operator_funcs {
164211
let cl = cl.clone();
165212
thread::spawn(|| {
166213
tokio::runtime::Builder::new_multi_thread()

0 commit comments

Comments
 (0)