Fly #75 (Closed)
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,3 +1,3 @@
docs/build/

.vscode/
*.log
1 change: 1 addition & 0 deletions REQUIRE
@@ -12,3 +12,4 @@ Query
JSON
NearestNeighbors
ProgressMeter
DelimitedFiles
15,190 changes: 7,595 additions & 7,595 deletions asset/77625.swc

Large diffs are not rendered by default.

17,738 changes: 8,869 additions & 8,869 deletions asset/77641.swc

Large diffs are not rendered by default.

117 changes: 68 additions & 49 deletions src/NBLASTs.jl
@@ -5,6 +5,7 @@ using LinearAlgebra
using ProgressMeter
using CSV
using Distributed
import Statistics: mean

using ..RealNeuralNetworks.Utils.VectorClouds
using ..RealNeuralNetworks.NodeNets
@@ -23,17 +24,21 @@ export nblast, nblast_allbyall
"""
@inline function VectorCloud(neuron::SWC; k::Integer=20,
class::Union{Nothing, UInt8}=nothing,
downscaleFactor::Int=1)
VectorCloud(NodeNet(neuron); k=k, class=class, downscaleFactor=downscaleFactor)
downscaleFactor::Int=1,
recenter::Bool=false)
VectorCloud(NodeNet(neuron); k=k, class=class,
downscaleFactor=downscaleFactor, recenter=recenter)
end

"""
VectorCloud(neuron::Neuron; k::Integer=20, class::Union{Nothing, UInt8}=nothing)
"""
@inline function VectorCloud(neuron::Neuron{T}; k::Integer=20,
class::Union{Nothing, UInt8}=nothing,
downscaleFactor::Int=1) where T
VectorCloud(NodeNet(neuron); k=k, class=class, downscaleFactor=downscaleFactor)
downscaleFactor::Int=1,
recenter::Bool=false) where T
VectorCloud(NodeNet(neuron); k=k, class=class,
downscaleFactor=downscaleFactor, recenter=recenter)
end

"""
@@ -49,7 +54,8 @@ Return:
"""
function VectorCloud(neuron::NodeNet{T}; k::Integer=20,
class::Union{Nothing, UInt8}=nothing,
downscaleFactor::Int=1) where T
downscaleFactor::Int=1,
recenter::Bool=false) where T
nodeClassList = NodeNets.get_node_class_list(neuron)

# transform neuron to xyzmatrix
@@ -59,12 +65,13 @@ function VectorCloud(neuron::NodeNet{T}; k::Integer=20,
return zeros(T, (0,0))
end

xyzmatrix = Matrix{Float32}(undef, 3, N)
# the first 3 rows will be X,Y,Z, and the last 3 rows will be X,Y,Z of direction vector
vectorCloud = Matrix{Float32}(undef, 6, N)

if class == nothing
# all the nodes should be included
@inbounds for (i, node) in NodeNets.get_node_list(neuron) |> enumerate
xyzmatrix[:,i] = [node[1:3]...,]
vectorCloud[1:3, i] = [node[1:3]...,]
end
else
# only nodes matching the class should be included
@@ -73,30 +80,30 @@ function VectorCloud(neuron::NodeNet{T}; k::Integer=20,
@inbounds for (i, node) in NodeNets.get_node_list(neuron) |> enumerate
if nodeClassList[i] == class
j += 1
xyzmatrix[:, j] = [node[1:3]...,]
vectorCloud[1:3, j] = [node[1:3]...,]
end
end
end

xyzmatrix = vectorCloud[1:3, :]
tree = KDTree(xyzmatrix; leafsize=k)

# the first 3 rows will be X,Y,Z, and the last 3 rows will be X,Y,Z of direction vector
ret = Matrix{Float32}(undef, 6, N)

idxs, dists = knn(tree, xyzmatrix, k, false)

data = Matrix{Float32}(undef, 3, k)
@inbounds for (nodeIndex, indexList) in idxs |> enumerate
data = xyzmatrix[:, indexList]
PCs, eigenValueList = Mathes.pca(data)
PCs, eigenValueList = Mathes.pca(xyzmatrix[:, indexList])
# use the first principal component as the direction vector
ret[:, nodeIndex] = [xyzmatrix[:, nodeIndex]..., PCs[1]...,]
vectorCloud[:, nodeIndex] = [xyzmatrix[:, nodeIndex]..., PCs[1]...,]
end

if downscaleFactor != 1
ret[1:3,:] ./= T(1000)
vectorCloud[1:3,:] ./= T(1000)
end

if recenter
vectorCloud[1:3,:] .-= mean(vectorCloud[1:3,:], dims=2)
end
ret
vectorCloud
end
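
As an aside for readers of this diff: the function above derives each node's direction vector from the first principal component of its k nearest neighbours. `Mathes.pca` appears to be a package-internal utility, so the sketch below substitutes a plain eigen-decomposition of the neighbourhood covariance; it is an illustration of the idea, not the package API.

```julia
using LinearAlgebra
using Statistics: mean
using NearestNeighbors  # already listed in REQUIRE

# Illustrative only: build a 6×N vector cloud where rows 1:3 are node positions
# and rows 4:6 are unit direction vectors from the local PCA of k neighbours.
function direction_vector_cloud(xyz::Matrix{Float32}; k::Int=20)
    kk = min(k, size(xyz, 2))
    tree = KDTree(xyz; leafsize=kk)
    idxs, _ = knn(tree, xyz, kk, false)
    vectorCloud = Matrix{Float32}(undef, 6, size(xyz, 2))
    for (i, neighborIds) in enumerate(idxs)
        nbrs = xyz[:, neighborIds]
        centered = nbrs .- mean(nbrs, dims=2)
        # dominant eigenvector of the neighbourhood covariance = first principal component
        F = eigen(Symmetric(centered * centered'))
        pc1 = F.vectors[:, argmax(F.values)]
        vectorCloud[:, i] = vcat(xyz[:, i], pc1)
    end
    vectorCloud
end
```

The `recenter` option added in this PR then simply subtracts the centroid of the position rows, as the hunk above shows.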

function nblast(targetNeuron::Neuron{T}, queryNeuron::Neuron{T};
@@ -133,16 +140,17 @@ function nblast(target::Matrix{T}, query::Matrix{T};
targetTree::Union{Nothing, KDTree}=VectorClouds.to_kd_tree(target)) where T

if isempty(target)
#return ria[end, 1] * size(query, 2)
return -Inf32
return ria[end, 1] * size(query, 2)
#return -Inf32
elseif isempty(query)
# if one of them is empty, return the largest difference
#return ria[end, 1] * size(target, 2)
return -Inf32
return ria[end, 1] * size(target, 2)
#return -Inf32
end

totalScore = zero(T)


idxs, dists = knn(targetTree, query[1:3, :], 1, false)

@inbounds for (i, nodeIndexList) in idxs |> enumerate
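
The scoring loop is truncated by the hunk above, but its shape follows the usual NBLAST recipe: for each query vector, find the nearest target vector, then turn the point distance and the absolute dot product of the two direction vectors into a score via the `ria` lookup table and accumulate. A minimal sketch of that recipe, with a hypothetical `score_lookup(dist, adp)` standing in for indexing the package's `RangeIndexingArray`:

```julia
using LinearAlgebra: dot
using NearestNeighbors

# Illustrative only: raw NBLAST score of `query` against `target`.
# Both are 6×N matrices: rows 1:3 are positions, rows 4:6 are unit direction vectors.
# `score_lookup(dist, adp)` is a hypothetical stand-in for the RangeIndexingArray lookup.
function raw_nblast_score(target::Matrix{Float32}, query::Matrix{Float32},
                          score_lookup::Function)
    tree = KDTree(target[1:3, :])
    idxs, dists = knn(tree, query[1:3, :], 1, false)
    total = 0.0f0
    for (i, nearest) in enumerate(idxs)
        j = nearest[1]
        dist = dists[i][1]
        # absolute dot product of the two direction vectors
        adp = abs(dot(query[4:6, i], target[4:6, j]))
        total += score_lookup(dist, adp)
    end
    total
end
```

Under that reading, the empty-cloud branches changed above return the table's worst-case score once per query (or target) vector instead of `-Inf32`.
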
@@ -212,23 +220,28 @@ function nblast_allbyall(vectorCloudList::Vector{Matrix{T}};
num = length(vectorCloudList)
similarityMatrix = Matrix{T}(undef, num, num)

@inbounds @showprogress 1 "computing similarity matrix..." for targetIndex in 1:num
#Threads.@threads for targetIndex in 1:num
Threads.@threads for queryIndex in 1:num
#for queryIndex in 1:num
similarityMatrix[targetIndex, queryIndex] = nblast(
tasks = Task[]
@inbounds for targetIndex in 1:num
for queryIndex in 1:num
task = Threads.@spawn similarityMatrix[targetIndex, queryIndex] = nblast(
vectorCloudList, targetIndex, queryIndex;
ria=ria, targetTree=treeList[targetIndex] )
ria=ria, targetTree=treeList[targetIndex])
push!(tasks, task)
end
end
for task in tasks
wait(task)
end
similarityMatrix
end

"""
nblast_allbyall(neuronList::Vector{Neuron{T}};
semantic::Bool=false,
k::Int=20,
ria::RangeIndexingArray{TR}=RangeIndexingArray{Float32}()) where {T}
ria::RangeIndexingArray{TR}=RangeIndexingArray{Float32}(),
downscaleFactor::Number=1000,
recenter::Bool=false) where {T}
Note that the neuron coordinate unit should be nm; it will be converted to microns internally.
The semantic NBLAST will find the nearest vector pair with the same semantic labeling.
An axonal vector in neuron A will find the closest axonal vector in neuron B (a usage sketch follows below).
@@ -244,23 +257,27 @@ function nblast_allbyall(neuronList::Vector{Neuron{T}};
semantic::Bool=false,
k::Int=20,
ria::Union{Nothing, RangeIndexingArray{T,2}}=nothing,
downscaleFactor::Number=1000) where T
downscaleFactor::Number=1000,
recenter::Bool=false) where T
if ria == nothing
ria = RangeIndexingArray{T}()
end
if semantic
# transforming to vector cloud list
axonVectorCloudList = map(x->VectorCloud(x;class=Segments.AXON_CLASS,
k=k, downscaleFactor=downscaleFactor), neuronList)
k=k, downscaleFactor=downscaleFactor,
recenter=recenter), neuronList)
dendVectorCloudList = map(x->VectorCloud(x;class=Segments.DENDRITE_CLASS,
k=k, downscaleFactor=downscaleFactor), neuronList)
k=k, downscaleFactor=downscaleFactor,
recenter=recenter), neuronList)
axonSimilarityMatrix = nblast_allbyall(axonVectorCloudList; ria=ria)
dendSimilarityMatrix = nblast_allbyall(dendVectorCloudList; ria=ria)
rawSimilarityMatrix = axonSimilarityMatrix .+ dendSimilarityMatrix
else
# transforming to vector cloud list
vectorCloudList = map(x->VectorCloud(x;class=nothing,
k=k, downscaleFactor=downscaleFactor), neuronList)
k=k, downscaleFactor=downscaleFactor,
recenter=recenter), neuronList)
rawSimilarityMatrix = nblast_allbyall(vectorCloudList; ria=ria)
end
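
A hedged usage sketch of the keyword combination documented above; the module path and the pre-built `neuronList` are assumptions, and only keyword arguments visible in this diff are used.

```julia
# Module path assumed from the `using ..RealNeuralNetworks...` lines in this file.
using RealNeuralNetworks.NBLASTs: nblast_allbyall

# Assumed: `neuronList::Vector{Neuron{Float32}}` was built elsewhere, with node
# coordinates in nm; the default downscaleFactor=1000 converts them to microns.
similarityMatrix = nblast_allbyall(neuronList;
    semantic = true,   # match axon vectors with axon vectors, dendrite with dendrite
    k        = 20,     # neighbourhood size for the direction-vector PCA
    recenter = true)   # subtract each vector cloud's centroid (added in this PR)
```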

@@ -300,26 +317,24 @@ end
function nblast_allbyall_small2big(vectorCloudList::Vector{X};
ria::Union{Nothing, RangeIndexingArray{T,2}}=nothing,
treeList::Vector=pmap(VectorClouds.to_kd_tree, vectorCloudList),
normalized::Bool=true
) where {X<:Matrix{Float32}, T}
normalized::Bool = true) where {X<:Matrix{Float32}, T}
N = length(vectorCloudList)
rawSimilarityMatrix = ones(Float32, (N,N))
selfScoreMatrix = ones(Float32, (N,N))

# Threads.@threads for i in 1:N
@showprogress for i in 1:N
Threads.@threads for j in (i+1):N
#for j in (i+1):N
small_to_big_nblast!(rawSimilarityMatrix, selfScoreMatrix,
tasks = Task[]
@inbounds for i in 1:N
for j in (i+1):N
task = Threads.@spawn small_to_big_nblast!(rawSimilarityMatrix, selfScoreMatrix,
vectorCloudList, i, j, ria, treeList)
push!(tasks, task)
end
end

if normalized
return rawSimilarityMatrix ./ selfScoreMatrix
else
return rawSimilarityMatrix, selfScoreMatrix
for task in tasks
wait(task)
end

return rawSimilarityMatrix, selfScoreMatrix
end

"""
@@ -338,18 +353,22 @@ function nblast_allbyall_small2big(neuronList::Vector{Neuron{T}};
dendVectorCloudList = pmap(x->VectorCloud(x; class=Segments.DENDRITE_CLASS,
k=k, downscaleFactor=downscaleFactor), neuronList)
axonRawSimilarityMatrix, axonSelfScoreMatrix = nblast_allbyall_small2big(
axonVectorCloudList; ria=ria, normalized=false)
axonVectorCloudList; ria=ria)
dendRawSimilarityMatrix, dendSelfScoreMatrix = nblast_allbyall_small2big(
dendVectorCloudList; ria=ria, normalized=false)
dendVectorCloudList; ria=ria)
# normalization
# add a small number to avoid a divide-by-zero error
similarityMatrix = (axonRawSimilarityMatrix .+ dendRawSimilarityMatrix) ./
(axonSelfScoreMatrix .+ dendSelfScoreMatrix .+ T(1e-6))
@assert !any(isnan.(similarityMatrix))
return similarityMatrix
else
vectorCloudList = pmap(x->VectorCloud(x; k=k, downscaleFactor=downscaleFactor), neuronList);
return nblast_allbyall_small2big(vectorCloudList; ria=ria, normalized=true)

rawSimilarityMatrix, selfScoreMatrix = nblast_allbyall_small2big(vectorCloudList; ria=ria)
# normalization
similarityMatrix = rawSimilarityMatrix ./ selfScoreMatrix
end
return similarityMatrix
end

end # end of module
end # end of module
20 changes: 15 additions & 5 deletions src/Neurons.jl
@@ -53,12 +53,15 @@ function Neuron(nodeNet::NodeNet{T}, rootNodeId::Integer;
neuron = Neuron!(rootNodeId, nodeNet, collectedFlagVec)
# println(sum(collectedFlagVec), " visited voxels in ", length(collectedFlagVec))

# precompute the terminal node Id List
terminalNodeIdList2 = NodeNets.get_terminal_node_id_list(nodeNet)
while !all(collectedFlagVec)
# println(sum(collectedFlagVec), " visited voxels in ", length(collectedFlagVec))
# there exist some uncollected nodes
# find the uncollected node that is closest to the terminal nodes as seed
# terminal nodes should have higher priority than normal nodes.
seedNodeId2 = find_seed_node_id(neuron, nodeNet, collectedFlagVec)
seedNodeId2 = find_seed_node_id(neuron, nodeNet, collectedFlagVec;
terminalNodeIdList2=terminalNodeIdList2)

mergingSegmentId1, mergingNodeIdInSegment1, weightedTerminalDistance =
find_merging_terminal_node_id(neuron, nodeNet[seedNodeId2])
@@ -2194,17 +2197,19 @@ end

find the closest terminal node in the uncollected node set as the new growing seed.
"""
function find_seed_node_id(neuron::Neuron, nodeNet::NodeNet, collectedFlagVec::BitArray{1})
function find_seed_node_id(neuron::Neuron, nodeNet::NodeNet, collectedFlagVec::BitArray{1};
terminalNodeIdList2::Vector{Int} = NodeNets.get_terminal_node_id_list(nodeNet))
# number 1 means the already-grown neuron
# number 2 means the id in the raw node net
segmentList1 = get_segment_list(neuron)
# the new seed will be chosen from this terminal node set
terminalNodeIdList2 = NodeNets.get_terminal_node_id_list(nodeNet)
@assert all(terminalNodeIdList2 .> 0)

# initialization
seedNodeId2 = 0
distance = typemax(Float32)

@assert !isempty(terminalNodeIdList2)
@assert !isempty(segmentList1)
@assert !all(collectedFlagVec)

@@ -2220,14 +2225,19 @@ function find_seed_node_id(neuron::Neuron, nodeNet::NodeNet, collectedFlagVec::B
bbox_distance = Segments.get_bounding_box_distance(segment1, node2)
if bbox_distance < distance
d, _ = Segments.distance_from(segment1, node2)
if d < distance
if d < distance
# @show bbox_distance, distance, d, seedNodeId2, candidateSeedId2
distance = d
seedNodeId2 = candidateSeedId2
seedNodeId2 = candidateSeedId2
if seedNodeId2 == 0
@show d, bbox_distance
end
end
end
end
end
end
@assert seedNodeId2 > 0
return seedNodeId2
end
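
The bounding-box check above is a pruning step: `Segments.get_bounding_box_distance` is assumed to be a cheap lower bound on the exact `Segments.distance_from` result, so the expensive call only runs when the bound can still beat the current best. A generic sketch of that pattern, with hypothetical `cheap_lower_bound` and `exact_distance` stand-ins:

```julia
# Illustrative only: nearest-candidate search with lower-bound pruning.
# `cheap_lower_bound(item, query) <= exact_distance(item, query)` must hold
# for the pruning to be correct; both functions are hypothetical stand-ins.
function nearest_with_pruning(items, query, cheap_lower_bound, exact_distance)
    best = typemax(Float32)
    bestIndex = 0
    for (i, item) in enumerate(items)
        # skip the exact computation when even the lower bound cannot win
        cheap_lower_bound(item, query) < best || continue
        d = exact_distance(item, query)
        if d < best
            best = d
            bestIndex = i
        end
    end
    return bestIndex, best
end
```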

6 changes: 3 additions & 3 deletions src/NodeNets.jl
@@ -176,8 +176,8 @@ end
function NodeNet(swc::SWC)
nodeList = Vector{NTuple{4, Float32}}()
nodeClassList = Vector{UInt8}()
I = Vector{UInt32}()
J = Vector{UInt32}()
I = UInt32[]
J = UInt32[]
for (index, point) in enumerate(swc)
push!(nodeList, (point.x, point.y, point.z, point.radius))
push!(nodeClassList, point.class)
@@ -189,7 +189,7 @@ function NodeNet(swc::SWC)
push!(J, index)
end
end
connectivityMatrix = sparse(I,J,true, length(swc), length(swc))
connectivityMatrix = sparse(I,J, true, length(swc), length(swc))
NodeNet(nodeList, nodeClassList, connectivityMatrix)
end
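
For context on the `sparse(I, J, true, n, n)` call above: the loop appears to collect one (parent, child) index pair per edge, and the sparse constructor turns that edge list into a Boolean connectivity matrix. A minimal standalone sketch with a made-up parent-pointer vector (SWC files mark the root's parent as -1):

```julia
using SparseArrays

# Hypothetical parent pointers for 5 nodes; -1 marks the root (no parent).
parents = Int32[-1, 1, 2, 2, 4]

I = UInt32[]   # parent node index of each edge
J = UInt32[]   # child node index of each edge
for (index, parent) in enumerate(parents)
    if parent != -1
        push!(I, parent)
        push!(J, index)
    end
end

n = length(parents)
# true at (parent, child) for every edge; everything else is structurally zero
connectivityMatrix = sparse(I, J, true, n, n)
```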
