Skip to content

Commit 28827a2

Browse files
fix: additional records should be used for unrequired records
Link: https://datatracker.ietf.org/doc/html/rfc6763#section-12
1 parent d605749 commit 28827a2

File tree

4 files changed

+181
-103
lines changed

4 files changed

+181
-103
lines changed

src/dns_parser/builder.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ impl<T: MoveTo<Answers>> Builder<T> {
195195

196196
builder
197197
}
198+
199+
pub fn add_answers<'a, 'b>(
200+
self,
201+
name: &Name,
202+
cls: QueryClass,
203+
ttl: u32,
204+
data: impl Iterator<Item = RRData<'b>> + 'a,
205+
) -> Builder<Answers> {
206+
let mut builder = self.move_to::<Answers>();
207+
for item in data {
208+
builder.write_rr(name, cls, ttl, &item);
209+
Header::inc_answers(&mut builder.buf).expect("Too many answers");
210+
}
211+
builder
212+
}
198213
}
199214

200215
impl<T: MoveTo<Nameservers>> Builder<T> {
@@ -213,10 +228,25 @@ impl<T: MoveTo<Nameservers>> Builder<T> {
213228

214229
builder
215230
}
231+
232+
#[allow(dead_code)]
233+
pub fn add_nameservers<'a, 'b>(
234+
self,
235+
name: &Name,
236+
cls: QueryClass,
237+
ttl: u32,
238+
data: impl Iterator<Item = RRData<'b>> + 'a,
239+
) -> Builder<Nameservers> {
240+
let mut builder = self.move_to::<Nameservers>();
241+
for item in data {
242+
builder.write_rr(name, cls, ttl, &item);
243+
Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers");
244+
}
245+
builder
246+
}
216247
}
217248

218249
impl<T: MoveTo<Additional>> Builder<T> {
219-
#[allow(dead_code)]
220250
pub fn add_additional(
221251
self,
222252
name: &Name,
@@ -231,6 +261,21 @@ impl<T: MoveTo<Additional>> Builder<T> {
231261

232262
builder
233263
}
264+
265+
pub fn add_additionals<'a, 'b>(
266+
self,
267+
name: &Name,
268+
cls: QueryClass,
269+
ttl: u32,
270+
data: impl Iterator<Item = RRData<'b>> + 'a,
271+
) -> Builder<Additional> {
272+
let mut builder = self.move_to::<Additional>();
273+
for item in data {
274+
builder.write_rr(name, cls, ttl, &item);
275+
Header::inc_additional(&mut builder.buf).expect("Too many additional answers");
276+
}
277+
builder
278+
}
234279
}
235280

236281
#[cfg(test)]

src/dns_parser/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ pub use self::header::Header;
1212
mod rrdata;
1313
pub use self::rrdata::RRData;
1414
mod builder;
15-
pub use self::builder::{Answers, Builder, Questions};
15+
pub use self::builder::*;

src/fsm.rs

Lines changed: 118 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc};
1818

1919
use super::{DEFAULT_TTL, MDNS_PORT};
2020
use crate::address_family::AddressFamily;
21-
use crate::services::{ServiceData, Services};
21+
use crate::services::{ServiceData, Services, ServicesInner};
2222

2323
pub type AnswerBuilder = dns_parser::Builder<dns_parser::Answers>;
24+
pub type AdditionalBuilder = dns_parser::Builder<dns_parser::Additional>;
2425

2526
const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> =
2627
Cow::Borrowed("_services._dns-sd._udp.local");
@@ -104,57 +105,46 @@ impl<AF: AddressFamily> FSM<AF> {
104105
return;
105106
}
106107

107-
let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true)
108-
.move_to::<dns_parser::Answers>();
109-
let mut multicast_builder =
110-
dns_parser::Builder::new_response(packet.header.id, false, true)
111-
.move_to::<dns_parser::Answers>();
112-
unicast_builder.set_max_size(None);
113-
multicast_builder.set_max_size(None);
114-
115108
for question in packet.questions {
116109
debug!(
117110
"received question: {:?} {}",
118111
question.qclass, question.qname
119112
);
120113

121114
if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any {
115+
let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true)
116+
.move_to::<dns_parser::Answers>();
117+
builder.set_max_size(None);
118+
let builder = self.handle_question(&question, builder);
119+
if builder.is_empty() {
120+
continue;
121+
}
122+
let response = builder.build().unwrap_or_else(|x| x);
122123
if question.qu {
123-
unicast_builder = self.handle_question(&question, unicast_builder);
124+
self.outgoing.push_back((response, addr));
124125
} else {
125-
multicast_builder = self.handle_question(&question, multicast_builder);
126+
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
127+
self.outgoing.push_back((response, addr));
126128
}
127129
}
128130
}
129-
130-
if !multicast_builder.is_empty() {
131-
let response = multicast_builder.build().unwrap_or_else(|x| x);
132-
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
133-
self.outgoing.push_back((response, addr));
134-
}
135-
136-
if !unicast_builder.is_empty() {
137-
let response = unicast_builder.build().unwrap_or_else(|x| x);
138-
self.outgoing.push_back((response, addr));
139-
}
140131
}
141132

142133
/// https://www.rfc-editor.org/rfc/rfc6763#section-9
143134
fn handle_service_type_enumeration<'a>(
144135
question: &dns_parser::Question,
145-
services: impl Iterator<Item = &'a ServiceData>,
136+
services: &ServicesInner,
146137
mut builder: AnswerBuilder,
147138
) -> AnswerBuilder {
148139
let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME);
149140
if question.qname == service_type_enumeration_name {
150-
for svc in services {
151-
let svc_type = ServiceData {
152-
name: svc.typ.clone(),
153-
typ: service_type_enumeration_name.clone(),
154-
port: svc.port,
155-
txt: vec![],
156-
};
157-
builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL);
141+
for typ in services.all_types() {
142+
builder = builder.add_answer(
143+
&service_type_enumeration_name,
144+
QueryClass::IN,
145+
DEFAULT_TTL,
146+
&RRData::PTR(typ.clone()),
147+
);
158148
}
159149
}
160150

@@ -165,93 +155,138 @@ impl<AF: AddressFamily> FSM<AF> {
165155
&self,
166156
question: &dns_parser::Question,
167157
mut builder: AnswerBuilder,
168-
) -> AnswerBuilder {
158+
) -> AdditionalBuilder {
169159
let services = self.services.read().unwrap();
170160
let hostname = services.get_hostname();
171161

172162
match question.qtype {
173-
QueryType::A | QueryType::AAAA if question.qname == *hostname => {
174-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
175-
}
163+
QueryType::A | QueryType::AAAA if question.qname == *hostname => builder
164+
.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
165+
.move_to(),
176166
QueryType::All => {
167+
let mut include_ip_additionals = false;
177168
// A / AAAA
178169
if question.qname == *hostname {
179-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
170+
builder =
171+
builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
180172
}
181173
// PTR
182-
builder =
183-
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
174+
builder = Self::handle_service_type_enumeration(question, &services, builder);
184175
for svc in services.find_by_type(&question.qname) {
185-
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
186-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
187-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
188-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
176+
builder =
177+
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr());
178+
include_ip_additionals = true;
189179
}
190180
// SRV
191181
if let Some(svc) = services.find_by_name(&question.qname) {
192-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
193-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
182+
builder = builder
183+
.add_answer(
184+
&svc.name,
185+
QueryClass::IN,
186+
DEFAULT_TTL,
187+
&svc.srv_rr(hostname),
188+
)
189+
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
190+
include_ip_additionals = true;
191+
}
192+
let mut builder = builder.move_to::<dns_parser::Additional>();
193+
// PTR (additional)
194+
for svc in services.find_by_type(&question.qname) {
195+
builder = builder
196+
.add_additional(
197+
&svc.name,
198+
QueryClass::IN,
199+
DEFAULT_TTL,
200+
&svc.srv_rr(hostname),
201+
)
202+
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
203+
include_ip_additionals = true;
204+
}
205+
206+
if include_ip_additionals {
207+
builder = builder.add_additionals(
208+
hostname,
209+
QueryClass::IN,
210+
DEFAULT_TTL,
211+
self.ip_rr(),
212+
);
194213
}
214+
builder
195215
}
196216
QueryType::PTR => {
197-
builder =
198-
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
217+
let mut builder =
218+
Self::handle_service_type_enumeration(question, &services, builder);
219+
for svc in services.find_by_type(&question.qname) {
220+
builder =
221+
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr())
222+
}
223+
let mut builder = builder.move_to::<dns_parser::Additional>();
199224
for svc in services.find_by_type(&question.qname) {
200-
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
201-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
202-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
203-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
225+
builder = builder
226+
.add_additional(
227+
&svc.name,
228+
QueryClass::IN,
229+
DEFAULT_TTL,
230+
&svc.srv_rr(hostname),
231+
)
232+
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
233+
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
204234
}
235+
builder
205236
}
206237
QueryType::SRV => {
207238
if let Some(svc) = services.find_by_name(&question.qname) {
208-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
209-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
239+
builder
240+
.add_answer(
241+
&svc.name,
242+
QueryClass::IN,
243+
DEFAULT_TTL,
244+
&svc.srv_rr(hostname),
245+
)
246+
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
247+
.move_to()
248+
} else {
249+
builder.move_to()
210250
}
211251
}
212252
QueryType::TXT => {
213253
if let Some(svc) = services.find_by_name(&question.qname) {
214-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
254+
builder
255+
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
256+
.move_to()
257+
} else {
258+
builder.move_to()
215259
}
216260
}
217-
_ => (),
261+
_ => builder.move_to(),
218262
}
219-
220-
builder
221263
}
222264

223-
fn add_ip_rr(&self, hostname: &Name, mut builder: AnswerBuilder, ttl: u32) -> AnswerBuilder {
265+
fn ip_rr(&self) -> impl Iterator<Item = RRData<'static>> + '_ {
224266
let interfaces = match get_if_addrs() {
225267
Ok(interfaces) => interfaces,
226268
Err(err) => {
227269
error!("could not get list of interfaces: {}", err);
228-
return builder;
270+
vec![]
229271
}
230272
};
231-
232-
for iface in interfaces {
273+
interfaces.into_iter().filter_map(move |iface| {
233274
if iface.is_loopback() {
234-
continue;
275+
return None;
235276
}
236277

237278
trace!("found interface {:?}", iface);
238279
if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) {
239280
trace!(" -> interface dropped");
240-
continue;
281+
return None;
241282
}
242283

243284
match (iface.ip(), AF::DOMAIN) {
244-
(IpAddr::V4(ip), Domain::IPV4) => {
245-
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip))
246-
}
247-
(IpAddr::V6(ip), Domain::IPV6) => {
248-
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip))
249-
}
250-
_ => (),
285+
(IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)),
286+
(IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)),
287+
_ => None,
251288
}
252-
}
253-
254-
builder
289+
})
255290
}
256291

257292
fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) {
@@ -261,11 +296,17 @@ impl<AF: AddressFamily> FSM<AF> {
261296

262297
let services = self.services.read().unwrap();
263298

264-
builder = svc.add_ptr_rr(builder, ttl);
265-
builder = svc.add_srv_rr(services.get_hostname(), builder, ttl);
266-
builder = svc.add_txt_rr(builder, ttl);
299+
builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr());
300+
builder = builder.add_answer(
301+
&svc.name,
302+
QueryClass::IN,
303+
ttl,
304+
&svc.srv_rr(services.get_hostname()),
305+
);
306+
builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr());
267307
if include_ip {
268-
builder = self.add_ip_rr(services.get_hostname(), builder, ttl);
308+
builder =
309+
builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr());
269310
}
270311

271312
if !builder.is_empty() {
@@ -349,7 +390,7 @@ mod tests {
349390

350391
answer_builder = FSM::<Inet>::handle_service_type_enumeration(
351392
&question,
352-
services.read().unwrap().into_iter(),
393+
&services.read().unwrap(),
353394
answer_builder,
354395
);
355396

0 commit comments

Comments
 (0)