--- Run the backward pass over every node of the graph.
--
-- Nodes are visited in the order given by `self.backwardnodes`. Each node
-- collects the gradient w.r.t. its output through `getTotalGradOutput`
-- (a helper defined elsewhere in this file; presumably it combines the
-- entries queued in `node.data.gradOutput` -- confirm against its definition),
-- computes the gradient w.r.t. its input, and queues that gradient on each of
-- its children's `data.gradOutput` list.
--
-- @param input       Not read by this function; each node uses the value
--                    stored in its own `node.data.input`.
-- @param gradOutput  Gradient w.r.t. the graph output. It is stored as the
--                    first pending gradient of `self.outnode`. When the
--                    outnode has more than one child, `#gradOutput` must
--                    equal the number of children.
-- @param scale       Forwarded unchanged to every `module:backward` call.
-- @return            `self.gradInput`, i.e. the single gradient queued on
--                    `self.innode`.
function gModule:backward(input,gradOutput,scale)
   -- Node that selects one element of a split (table-valued) output.
   -- Its gradient is written into slot `selectindex` of a table held in the
   -- first pending-gradient entry of the node that was split.
   local function backwardSelector(node)
      assert(not node.data.module, "the selectindex-handling nodes should have no module")
      assert(#node.children == 1, "only the splitted node should be the input")
      local splitted = node.children[1]
      local total = getTotalGradOutput(node)
      splitted.data.gradOutput = splitted.data.gradOutput or {}
      assert(#splitted.data.gradOutput <= 1, "the splitted node should be used only once")
      -- The first entry is itself a table, indexed by selectindex.
      local slots = splitted.data.gradOutput[1] or {}
      splitted.data.gradOutput[1] = slots
      local index = node.data.selectindex
      assert(not slots[index], "no gradOutput should be assigned yet")
      slots[index] = total
   end

   -- Ordinary node: back-propagate through its module (if any) and hand the
   -- resulting gradient(s) to the children.
   local function backwardRegular(node)
      local total = getTotalGradOutput(node)
      local module = node.data.module
      local gradInput
      if module then
         local nodeInput = node.data.input
         -- A node with a module but no recorded input (the "parameter node"
         -- case) is given an empty input table.
         if nodeInput == nil then
            nodeInput = {}
         end
         -- A one-element input list is unwrapped to the element itself.
         if #nodeInput == 1 then
            nodeInput = nodeInput[1]
         end
         gradInput = module:backward(nodeInput, total, scale)
      else
         -- Without a module the gradient passes through unchanged.
         gradInput = total
      end

      -- With one child the whole gradInput goes to it; otherwise gradInput
      -- is indexed by the position recorded in node.data.mapindex.
      local children = node.children
      local singleChild = (#children == 1)
      for _, child in ipairs(children) do
         local pending = child.data.gradOutput or {}
         child.data.gradOutput = pending
         local mapindex = node.data.mapindex[child.data]
         local share
         if singleChild then
            share = gradInput
         else
            share = gradInput[mapindex]
         end
         table.insert(pending, share)
      end
   end

   local function backpropagate(node)
      if node.data.selectindex then
         backwardSelector(node)
      else
         backwardRegular(node)
      end
      if self.verbose then
         print(' V : ' .. node:label())
      end
   end

   local outnode = self.outnode
   local expected = #outnode.children
   if expected > 1 and #gradOutput ~= expected then
      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, expected))
   end

   -- Empty the pending-gradient lists in place; the tables themselves are
   -- kept and reused.
   for _, node in ipairs(self.backwardnodes) do
      local pending = node.data.gradOutput
      if pending then
         while #pending > 0 do
            table.remove(pending)
         end
      end
   end

   -- Set the starting gradOutput.
   outnode.data.gradOutput = outnode.data.gradOutput or {}
   outnode.data.gradOutput[1] = gradOutput

   for _, node in ipairs(self.backwardnodes) do
      backpropagate(node)
   end

   local received = self.innode.data.gradOutput
   assert(#received == 1, "expecting the innode to be used only once")
   self.gradInput = received[1]
   return self.gradInput
end