|
17 | 17 |
|
18 | 18 | namespace CSA {
|
19 | 19 |
|
20 |
| -template<bool PATH_RETRIEVAL = true, int ENABLE_PRUNING = 0, typename PROFILER = NoProfiler> |
| 20 | +template<bool PATH_RETRIEVAL = true, typename PROFILER = NoProfiler> |
21 | 21 | class CSA {
|
22 | 22 |
|
23 | 23 | public:
|
24 | 24 | constexpr static bool PathRetrieval = PATH_RETRIEVAL;
|
25 | 25 | using Profiler = PROFILER;
|
26 |
| - constexpr static int EnablePruning = ENABLE_PRUNING; |
27 |
| - using Type = CSA<PathRetrieval, EnablePruning, Profiler>; |
| 26 | + using Type = CSA<PathRetrieval, Profiler>; |
28 | 27 | using TripFlag = std::conditional_t<PathRetrieval, ConnectionId, bool>;
|
29 | 28 |
|
30 | 29 | private:
|
@@ -136,11 +135,6 @@ class CSA {
|
136 | 135 | const Connection& connection = data.connections[i];
|
137 | 136 | if (targetStop != noStop && connection.departureTime > arrivalTime[targetStop]) break;
|
138 | 137 | if (connectionIsReachable(connection, i)) {
|
139 |
| - if constexpr (EnablePruning == 1) { |
140 |
| - if (connection.arrivalTime > arrivalTime[targetStop]) { |
141 |
| - continue; |
142 |
| - } |
143 |
| - } |
144 | 138 | profiler.countMetric(METRIC_CONNECTIONS);
|
145 | 139 | arrivalByTrip(connection.arrivalStopId, connection.arrivalTime, connection.tripId);
|
146 | 140 | }
|
@@ -186,10 +180,206 @@ class CSA {
|
186 | 180 | profiler.countMetric(METRIC_EDGES);
|
187 | 181 | const StopId toStop = StopId(data.transferGraph.get(ToVertex, edge));
|
188 | 182 | const int newArrivalTime = time + data.transferGraph.get(TravelTime, edge);
|
189 |
| - if constexpr (EnablePruning == 1) { |
190 |
| - if (newArrivalTime > arrivalTime[targetStop]) { |
191 |
| - break; |
| 183 | + arrivalByTransfer(toStop, newArrivalTime, stop, edge); |
| 184 | + } |
| 185 | + } |
| 186 | + |
| 187 | + inline void arrivalByTransfer(const StopId stop, const int time, const StopId parent, const Edge edge) noexcept { |
| 188 | + if (arrivalTime[stop] <= time) return; |
| 189 | + profiler.countMetric(METRIC_STOPS_BY_TRANSFER); |
| 190 | + arrivalTime[stop] = time; |
| 191 | + if constexpr (PathRetrieval) { |
| 192 | + parentLabel[stop].parent = parent; |
| 193 | + parentLabel[stop].reachedByTransfer = true; |
| 194 | + parentLabel[stop].transferId = edge; |
| 195 | + } |
| 196 | + } |
| 197 | + |
| 198 | +private: |
| 199 | + const Data& data; |
| 200 | + |
| 201 | + StopId sourceStop; |
| 202 | + StopId targetStop; |
| 203 | + |
| 204 | + std::vector<TripFlag> tripReached; |
| 205 | + std::vector<int> arrivalTime; |
| 206 | + std::vector<ParentLabel> parentLabel; |
| 207 | + |
| 208 | + Profiler profiler; |
| 209 | +}; |
| 210 | + template<bool PATH_RETRIEVAL = true, typename PROFILER = NoProfiler> |
| 211 | +class CSA_prune { |
| 212 | + |
| 213 | +public: |
| 214 | + constexpr static bool PathRetrieval = PATH_RETRIEVAL; |
| 215 | + using Profiler = PROFILER; |
| 216 | + using Type = CSA_prune<PathRetrieval, Profiler>; |
| 217 | + using TripFlag = std::conditional_t<PathRetrieval, ConnectionId, bool>; |
| 218 | + |
| 219 | +private: |
| 220 | + struct ParentLabel { |
| 221 | + ParentLabel(const StopId parent = noStop, const bool reachedByTransfer = false, const TripId tripId = noTripId) : |
| 222 | + parent(parent), |
| 223 | + reachedByTransfer(reachedByTransfer), |
| 224 | + tripId(tripId) { |
| 225 | + } |
| 226 | + |
| 227 | + StopId parent; |
| 228 | + bool reachedByTransfer; |
| 229 | + union { |
| 230 | + TripId tripId; |
| 231 | + Edge transferId; |
| 232 | + }; |
| 233 | + }; |
| 234 | + |
| 235 | +public: |
| 236 | + CSA_prune(const Data& data, const Profiler& profilerTemplate = Profiler()) : |
| 237 | + data(data), |
| 238 | + sourceStop(noStop), |
| 239 | + targetStop(noStop), |
| 240 | + tripReached(data.numberOfTrips(), TripFlag()), |
| 241 | + arrivalTime(data.numberOfStops(), never), |
| 242 | + parentLabel(PathRetrieval ? data.numberOfStops() : 0), |
| 243 | + profiler(profilerTemplate) { |
| 244 | + Assert(Vector::isSorted(data.connections), "Connections must be sorted in ascending order!"); |
| 245 | + profiler.registerPhases({PHASE_CLEAR, PHASE_INITIALIZATION, PHASE_CONNECTION_SCAN}); |
| 246 | + profiler.registerMetrics({METRIC_CONNECTIONS, METRIC_EDGES, METRIC_STOPS_BY_TRIP, METRIC_STOPS_BY_TRANSFER}); |
| 247 | + profiler.initialize(); |
| 248 | + } |
| 249 | + |
| 250 | + inline void run(const StopId source, const int departureTime, const StopId target = noStop) noexcept { |
| 251 | + profiler.start(); |
| 252 | + |
| 253 | + profiler.startPhase(); |
| 254 | + Assert(data.isStop(source), "Source stop " << source << " is not a valid stop!"); |
| 255 | + clear(); |
| 256 | + profiler.donePhase(PHASE_CLEAR); |
| 257 | + |
| 258 | + profiler.startPhase(); |
| 259 | + sourceStop = source; |
| 260 | + targetStop = target; |
| 261 | + arrivalTime[sourceStop] = departureTime; |
| 262 | + const int targetArrivalTime = (targetStop != noStop) |
| 263 | + ? arrivalTime[targetStop] |
| 264 | + : never; |
| 265 | + relaxEdges(sourceStop, departureTime, targetArrivalTime); |
| 266 | + const ConnectionId firstConnection = firstReachableConnection(departureTime); |
| 267 | + profiler.donePhase(PHASE_INITIALIZATION); |
| 268 | + |
| 269 | + profiler.startPhase(); |
| 270 | + scanConnections(firstConnection, ConnectionId(data.connections.size())); |
| 271 | + profiler.donePhase(PHASE_CONNECTION_SCAN); |
| 272 | + |
| 273 | + profiler.done(); |
| 274 | + } |
| 275 | + |
| 276 | + inline bool reachable(const StopId stop) const noexcept { |
| 277 | + return arrivalTime[stop] < never; |
| 278 | + } |
| 279 | + |
| 280 | + inline int getEarliestArrivalTime(const StopId stop) const noexcept { |
| 281 | + return arrivalTime[stop]; |
| 282 | + } |
| 283 | + |
| 284 | + inline Journey getJourney() const noexcept requires PathRetrieval { |
| 285 | + return getJourney(targetStop); |
| 286 | + } |
| 287 | + |
| 288 | + inline Journey getJourney(StopId stop) const noexcept requires PathRetrieval { |
| 289 | + Journey journey; |
| 290 | + if (!reachable(stop)) return journey; |
| 291 | + while (stop != sourceStop) { |
| 292 | + const ParentLabel& label = parentLabel[stop]; |
| 293 | + if (label.reachedByTransfer) { |
| 294 | + const int travelTime = data.transferGraph.get(TravelTime, label.transferId); |
| 295 | + journey.emplace_back(label.parent, stop, arrivalTime[stop] - travelTime, arrivalTime[stop], label.transferId); |
| 296 | + } else { |
| 297 | + journey.emplace_back(label.parent, stop, data.connections[tripReached[label.tripId]].departureTime, arrivalTime[stop], label.tripId); |
| 298 | + } |
| 299 | + stop = label.parent; |
| 300 | + } |
| 301 | + Vector::reverse(journey); |
| 302 | + return journey; |
| 303 | + } |
| 304 | + |
| 305 | + inline const Profiler& getProfiler() const noexcept { |
| 306 | + return profiler; |
| 307 | + } |
| 308 | + |
| 309 | +private: |
| 310 | + inline void clear() { |
| 311 | + sourceStop = noStop; |
| 312 | + targetStop = noStop; |
| 313 | + Vector::fill(arrivalTime, never); |
| 314 | + Vector::fill(tripReached, TripFlag()); |
| 315 | + if constexpr (PathRetrieval) { |
| 316 | + Vector::fill(parentLabel, ParentLabel()); |
| 317 | + } |
| 318 | + } |
| 319 | + |
| 320 | + inline ConnectionId firstReachableConnection(const int departureTime) const noexcept { |
| 321 | + return ConnectionId(Vector::lowerBound(data.connections, departureTime, [](const Connection& connection, const int time) { |
| 322 | + return connection.departureTime < time; |
| 323 | + })); |
| 324 | + } |
| 325 | + |
| 326 | + inline void scanConnections(const ConnectionId begin, const ConnectionId end) noexcept { |
| 327 | + for (ConnectionId i = begin; i < end; i++) { |
| 328 | + |
| 329 | + const Connection& connection = data.connections[i]; |
| 330 | + if (targetStop != noStop && connection.departureTime > arrivalTime[targetStop]) break; |
| 331 | + if (connectionIsReachable(connection, i)) { |
| 332 | + if (connection.arrivalTime > arrivalTime[targetStop]) { |
| 333 | + continue; |
192 | 334 | }
|
| 335 | + |
| 336 | + profiler.countMetric(METRIC_CONNECTIONS); |
| 337 | + arrivalByTrip(connection.arrivalStopId, connection.arrivalTime, connection.tripId, arrivalTime[targetStop]); |
| 338 | + } |
| 339 | + } |
| 340 | + } |
| 341 | + |
| 342 | + inline bool connectionIsReachableFromStop(const Connection& connection) const noexcept { |
| 343 | + return arrivalTime[connection.departureStopId] <= connection.departureTime - data.minTransferTime(connection.departureStopId); |
| 344 | + } |
| 345 | + |
| 346 | + inline bool connectionIsReachableFromTrip(const Connection& connection) const noexcept { |
| 347 | + return tripReached[connection.tripId] != TripFlag(); |
| 348 | + } |
| 349 | + |
| 350 | + inline bool connectionIsReachable(const Connection& connection, const ConnectionId id) noexcept { |
| 351 | + if (connectionIsReachableFromTrip(connection)) return true; |
| 352 | + if (connectionIsReachableFromStop(connection)) { |
| 353 | + if constexpr (PathRetrieval) { |
| 354 | + tripReached[connection.tripId] = id; |
| 355 | + } else { |
| 356 | + suppressUnusedParameterWarning(id); |
| 357 | + tripReached[connection.tripId] = true; |
| 358 | + } |
| 359 | + return true; |
| 360 | + } |
| 361 | + return false; |
| 362 | + } |
| 363 | + |
| 364 | + inline void arrivalByTrip(const StopId stop, const int time, const TripId trip, const int targetArrivalTime) noexcept { |
| 365 | + if (arrivalTime[stop] <= time) return; |
| 366 | + profiler.countMetric(METRIC_STOPS_BY_TRIP); |
| 367 | + arrivalTime[stop] = time; |
| 368 | + if constexpr (PathRetrieval) { |
| 369 | + parentLabel[stop].parent = data.connections[tripReached[trip]].departureStopId; |
| 370 | + parentLabel[stop].reachedByTransfer = false; |
| 371 | + parentLabel[stop].tripId = trip; |
| 372 | + } |
| 373 | + relaxEdges(stop, time, targetArrivalTime); |
| 374 | + } |
| 375 | + |
| 376 | + inline void relaxEdges(const StopId stop, const int time, const int targetArrivalTime) noexcept { |
| 377 | + for (const Edge edge : data.transferGraph.edgesFrom(stop)) { |
| 378 | + profiler.countMetric(METRIC_EDGES); |
| 379 | + const StopId toStop = StopId(data.transferGraph.get(ToVertex, edge)); |
| 380 | + const int newArrivalTime = time + data.transferGraph.get(TravelTime, edge); |
| 381 | + if (newArrivalTime > targetArrivalTime) { |
| 382 | + break; |
193 | 383 | }
|
194 | 384 | arrivalByTransfer(toStop, newArrivalTime, stop, edge);
|
195 | 385 | }
|
|
0 commit comments