From 80208ba0fa1a38a41ad35d09794c61643714d292 Mon Sep 17 00:00:00 2001 From: ds5678 <49847914+ds5678@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:20:02 -0800 Subject: [PATCH] Add ReadOnlySpan overloads to many methods --- src/TorchSharp/Autograd.cs | 12 +- src/TorchSharp/AutogradFunction.cs | 10 +- src/TorchSharp/LinearAlgebra.cs | 2 +- src/TorchSharp/NN/Utils/RNNUtils.cs | 4 +- src/TorchSharp/Optimizers/LBFGS.cs | 2 +- ...torch.IndexingSlicingJoiningMutatingOps.cs | 142 ++++++++++++++---- .../Tensor/torch.OtherOperations.cs | 35 +++-- src/TorchSharp/Tensor/torch.cs | 67 +++++++-- src/TorchSharp/Torch.cs | 8 +- src/TorchSharp/Utils/OverloadHelper.cs | 40 +++++ 10 files changed, 245 insertions(+), 77 deletions(-) create mode 100644 src/TorchSharp/Utils/OverloadHelper.cs diff --git a/src/TorchSharp/Autograd.cs b/src/TorchSharp/Autograd.cs index 4c73fce46..c043225da 100644 --- a/src/TorchSharp/Autograd.cs +++ b/src/TorchSharp/Autograd.cs @@ -135,9 +135,9 @@ public static IList grad(IList outputs, IList inputs, IL using var grads = new PinnedArray(); using var results = new PinnedArray(); - IntPtr outsRef = outs.CreateArray(outputs.Select(p => p.Handle).ToArray()); - IntPtr insRef = ins.CreateArray(inputs.Select(p => p.Handle).ToArray()); - IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.Select(p => p.Handle).ToArray()); + IntPtr outsRef = outs.CreateArray(outputs.ToHandleArray()); + IntPtr insRef = ins.CreateArray(inputs.ToHandleArray()); + IntPtr gradsRef = grad_outputs == null ? IntPtr.Zero : grads.CreateArray(grad_outputs.ToHandleArray()); long gradsLength = grad_outputs == null ? 
0 : grads.Array.Length; THSAutograd_grad(outsRef, outs.Array.Length, insRef, ins.Array.Length, gradsRef, gradsLength, retain_graph, create_graph, allow_unused, results.CreateArray); @@ -178,9 +178,9 @@ public static void backward(IList tensors, IList grad_tensors = using var ts = new PinnedArray(); using var gts = new PinnedArray(); using var ins = new PinnedArray(); - IntPtr tensRef = ts.CreateArray(tensors.Select(p => p.Handle).ToArray()); - IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.Select(p => p.Handle).ToArray()); - IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.Select(p => p.Handle).ToArray()); + IntPtr tensRef = ts.CreateArray(tensors.ToHandleArray()); + IntPtr gradsRef = grad_tensors == null ? IntPtr.Zero : gts.CreateArray(grad_tensors.ToHandleArray()); + IntPtr insRef = inputs == null ? IntPtr.Zero : ins.CreateArray(inputs.ToHandleArray()); long insLength = inputs == null ? 0 : ins.Array.Length; long gradsLength = grad_tensors == null ? 
0 : gts.Array.Length; diff --git a/src/TorchSharp/AutogradFunction.cs b/src/TorchSharp/AutogradFunction.cs index 390ce94c6..49adb2ee1 100644 --- a/src/TorchSharp/AutogradFunction.cs +++ b/src/TorchSharp/AutogradFunction.cs @@ -148,7 +148,7 @@ internal List ComputeVariableInput(object[] args) internal void SetNextEdges(List inputVars, bool isExecutable) { using var l = new PinnedArray(); - THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray()), isExecutable); + THSAutograd_CSharpNode_setNextEdges(handle, l.CreateArrayWithSize(inputVars.ToHandleArray()), isExecutable); CheckForErrors(); } @@ -166,10 +166,10 @@ internal List WrapOutputs(List inputVars, List outputs, using var outputArr = new PinnedArray(); using var resultsArr = new PinnedArray(); - var varsPtr = varsArr.CreateArrayWithSize(inputVars.Select(v => v.Handle).ToArray()); - var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.Select(v => v.Handle).ToArray()); - var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.Select(v => v.Handle).ToArray()); - var outputPtr = outputArr.CreateArrayWithSize(outputs.Select(v => v.Handle).ToArray()); + var varsPtr = varsArr.CreateArrayWithSize(inputVars.ToHandleArray()); + var diffsPtr = diffArr.CreateArrayWithSize(_context.NonDifferentiableTensors.ToHandleArray()); + var dirtyPtr = diffArr.CreateArrayWithSize(_context.DirtyTensors.ToHandleArray()); + var outputPtr = outputArr.CreateArrayWithSize(outputs.ToHandleArray()); THSAutograd_Function_wrapOutputs(varsPtr, diffsPtr, dirtyPtr, outputPtr, isExecutable ? 
handle : new(), resultsArr.CreateArray); CheckForErrors(); diff --git a/src/TorchSharp/LinearAlgebra.cs b/src/TorchSharp/LinearAlgebra.cs index 436650ac7..91a22e3b2 100644 --- a/src/TorchSharp/LinearAlgebra.cs +++ b/src/TorchSharp/LinearAlgebra.cs @@ -444,7 +444,7 @@ public static Tensor multi_dot(IList tensors) } using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var res = THSLinalg_multi_dot(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) torch.CheckForErrors(); diff --git a/src/TorchSharp/NN/Utils/RNNUtils.cs b/src/TorchSharp/NN/Utils/RNNUtils.cs index ab0b62cc5..eb486a912 100644 --- a/src/TorchSharp/NN/Utils/RNNUtils.cs +++ b/src/TorchSharp/NN/Utils/RNNUtils.cs @@ -55,7 +55,7 @@ public static (torch.Tensor, torch.Tensor) pad_packed_sequence(PackedSequence se /// The padded tensor public static torch.Tensor pad_sequence(IEnumerable sequences, bool batch_first = false, double padding_value = 0.0) { - var sequences_arg = sequences.Select(p => p.Handle).ToArray(); + var sequences_arg = sequences.ToHandleArray(); var res = THSNN_pad_sequence(sequences_arg, sequences_arg.Length, batch_first, padding_value); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new torch.Tensor(res); @@ -69,7 +69,7 @@ public static torch.Tensor pad_sequence(IEnumerable sequences, boo /// The packed batch of variable length sequences public static PackedSequence pack_sequence(IEnumerable sequences, bool enforce_sorted = true) { - var sequences_arg = sequences.Select(p => p.Handle).ToArray(); + var sequences_arg = sequences.ToHandleArray(); var res = THSNN_pack_sequence(sequences_arg, sequences_arg.Length, enforce_sorted); if (res.IsInvalid) { torch.CheckForErrors(); } return new PackedSequence(res); diff --git a/src/TorchSharp/Optimizers/LBFGS.cs b/src/TorchSharp/Optimizers/LBFGS.cs index 1249b5ba5..a06424dce 100644 --- 
a/src/TorchSharp/Optimizers/LBFGS.cs +++ b/src/TorchSharp/Optimizers/LBFGS.cs @@ -47,7 +47,7 @@ public static LBFGS LBFGS(IEnumerable parameters, double lr = 0.01, l if (!max_eval.HasValue) max_eval = 5 * max_iter / 4; using var parray = new PinnedArray(); - IntPtr paramsRef = parray.CreateArray(parameters.Select(p => p.Handle).ToArray()); + IntPtr paramsRef = parray.CreateArray(parameters.ToHandleArray()); var res = THSNN_LBFGS_ctor(paramsRef, parray.Array.Length, lr, max_iter, max_eval.Value, tolerange_grad, tolerance_change, history_size); if (res == IntPtr.Zero) { torch.CheckForErrors(); } diff --git a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs index c55ec9f4c..15824430c 100644 --- a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs +++ b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs @@ -44,7 +44,39 @@ public static Tensor cat(IList tensors, long dim = 0) } using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); + + var res = THSTensor_cat(tensorsRef, parray.Array.Length, dim); + if (res == IntPtr.Zero) CheckForErrors(); + return new Tensor(res); + } + + // https://pytorch.org/docs/stable/generated/torch.cat + /// + /// Concatenates the given sequence of tensors in the given dimension. + /// + /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension. + /// The dimension over which the tensors are concatenated + /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. + public static Tensor cat(Tensor[] tensors, long dim = 0) => torch.cat((ReadOnlySpan)tensors, dim); + + // https://pytorch.org/docs/stable/generated/torch.cat + /// + /// Concatenates the given sequence of tensors in the given dimension. 
case <=0: throw new ArgumentException("The sequence of tensors must be non-empty.", nameof(tensors));
+ public static Tensor concat(ReadOnlySpan tensors, long dim = 0) => torch.cat(tensors, dim); + // https://pytorch.org/docs/stable/generated/torch.conj /// /// Returns a view of input with a flipped conjugate bit. If input has a non-complex dtype, this function just returns input. @@ -103,7 +153,7 @@ public static Tensor[] dsplit(Tensor input, (long, long, long, long) indices_or_ /// /// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d(). public static Tensor dstack(params Tensor[] tensors) - => dstack((IEnumerable)tensors); + => dstack(tensors.ToHandleArray()); // https://pytorch.org/docs/stable/generated/torch.dstack /// @@ -113,16 +163,19 @@ public static Tensor dstack(params Tensor[] tensors) /// /// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d(). public static Tensor dstack(IList tensors) - { - using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + => dstack(tensors.ToHandleArray()); - var res = THSTensor_dstack(tensorsRef, parray.Array.Length); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } - } + // https://pytorch.org/docs/stable/generated/torch.dstack + /// + /// Stack tensors in sequence depthwise (along third axis). + /// + /// + /// + /// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d(). + public static Tensor dstack(ReadOnlySpan tensors) + => dstack(tensors.ToHandleArray()); + // https://pytorch.org/docs/stable/generated/torch.dstack /// /// Stack tensors in sequence depthwise (along third axis). /// @@ -130,14 +183,19 @@ public static Tensor dstack(IList tensors) /// /// This is equivalent to concatenation along the third axis after 1-D and 2-D tensors have been reshaped by torch.atleast_3d(). 
public static Tensor dstack(IEnumerable tensors) + => dstack(tensors.ToHandleArray()); + + static Tensor dstack(IntPtr[] tensors) { - using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); - var res = THSTensor_dstack(tensorsRef, parray.Array.Length); - if (res == IntPtr.Zero) { CheckForErrors(); } - return new Tensor(res); + using (var parray = new PinnedArray()) { + IntPtr tensorsRef = parray.CreateArray(tensors); + + var res = THSTensor_dstack(tensorsRef, parray.Array.Length); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } } - + // https://pytorch.org/docs/stable/generated/torch.gather /// /// Gathers values along an axis specified by dim. @@ -192,14 +250,7 @@ public static Tensor[] hsplit(Tensor input, (long, long, long, long) indices_or_ /// /// public static Tensor hstack(IList tensors) - { - using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); - - var res = THSTensor_hstack(tensorsRef, parray.Array.Length); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } + => hstack(tensors.ToHandleArray()); // https://pytorch.org/docs/stable/generated/torch.hstack /// @@ -208,9 +259,7 @@ public static Tensor hstack(IList tensors) /// /// public static Tensor hstack(params Tensor[] tensors) - { - return hstack((IEnumerable)tensors); - } + => hstack(tensors.ToHandleArray()); // https://pytorch.org/docs/stable/generated/torch.hstack /// @@ -219,9 +268,21 @@ public static Tensor hstack(params Tensor[] tensors) /// /// public static Tensor hstack(IEnumerable tensors) + => hstack(tensors.ToHandleArray()); + + // https://pytorch.org/docs/stable/generated/torch.hstack + /// + /// Stack tensors in sequence horizontally (column wise). 
+ /// + /// + /// + public static Tensor hstack(ReadOnlySpan tensors) + => hstack(tensors.ToHandleArray()); + + static Tensor hstack(IntPtr[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors); var res = THSTensor_hstack(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } @@ -474,7 +535,7 @@ public static Tensor[] split(Tensor tensor, long[] split_size_or_sections, long public static Tensor stack(IEnumerable tensors, long dim = 0) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var res = THSTensor_stack(tensorsRef, parray.Array.Length, dim); if (res == IntPtr.Zero) { CheckForErrors(); } @@ -560,9 +621,30 @@ public static Tensor[] vsplit(Tensor input, long[] indices_or_sections) /// /// public static Tensor vstack(IList tensors) + => vstack(tensors.ToHandleArray()); + + // https://pytorch.org/docs/stable/generated/torch.vstack + /// + /// Stack tensors in sequence vertically (row wise). + /// + /// + /// + public static Tensor vstack(Tensor[] tensors) + => vstack(tensors.ToHandleArray()); + + // https://pytorch.org/docs/stable/generated/torch.vstack + /// + /// Stack tensors in sequence vertically (row wise). 
+ /// + /// + /// + public static Tensor vstack(ReadOnlySpan tensors) + => vstack(tensors.ToHandleArray()); + + static Tensor vstack(IntPtr[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors); var res = THSTensor_vstack(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } diff --git a/src/TorchSharp/Tensor/torch.OtherOperations.cs b/src/TorchSharp/Tensor/torch.OtherOperations.cs index 4edfcf715..a41cb340d 100644 --- a/src/TorchSharp/Tensor/torch.OtherOperations.cs +++ b/src/TorchSharp/Tensor/torch.OtherOperations.cs @@ -45,7 +45,7 @@ public static partial class torch public static Tensor block_diag(params Tensor[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var res = THSTensor_block_diag(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } @@ -71,7 +71,7 @@ public static IList broadcast_tensors(params Tensor[] tensors) using (var pa = new PinnedArray()) using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); THSTensor_broadcast_tensors(tensorsRef, tensors.Length, pa.CreateArray); CheckForErrors(); @@ -125,22 +125,31 @@ public static Tensor bucketize(Tensor input, Tensor boundaries, bool outInt32 = /// Do cartesian product of the given sequence of tensors. 
/// /// - public static Tensor cartesian_prod(IList tensors) - { - using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + public static Tensor cartesian_prod(IList tensors) => cartesian_prod(tensors.ToHandleArray()); - var res = THSTensor_cartesian_prod(tensorsRef, parray.Array.Length); - if (res == IntPtr.Zero) { torch.CheckForErrors(); } - return new Tensor(res); - } + // https://pytorch.org/docs/stable/generated/torch.cartesian_prod + /// + /// Do cartesian product of the given sequence of tensors. + /// + /// + public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod(tensors.ToHandleArray()); // https://pytorch.org/docs/stable/generated/torch.cartesian_prod /// /// Do cartesian product of the given sequence of tensors. /// /// - public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod((IList)tensors); + public static Tensor cartesian_prod(ReadOnlySpan tensors) => cartesian_prod(tensors.ToHandleArray()); + + static Tensor cartesian_prod(IntPtr[] tensors) + { + using var parray = new PinnedArray(); + IntPtr tensorsRef = parray.CreateArray(tensors); + + var res = THSTensor_cartesian_prod(tensorsRef, parray.Array.Length); + if (res == IntPtr.Zero) { torch.CheckForErrors(); } + return new Tensor(res); + } // https://pytorch.org/docs/stable/generated/torch.cdist /// @@ -350,7 +359,7 @@ public static Tensor diag_embed(Tensor input, long offset = 0L, long dim1 = -2L, public static Tensor einsum(string equation, params Tensor[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var res = THSTensor_einsum(equation, tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } @@ -512,7 +521,7 @@ public static Tensor[] meshgrid(IEnumerable tensors, string indexing = " IntPtr[] ptrArray; using (var 
parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); _ = THSTensor_meshgrid(tensorsRef, parray.Array.Length, indexing, parray.CreateArray); CheckForErrors(); ptrArray = parray.Array; diff --git a/src/TorchSharp/Tensor/torch.cs b/src/TorchSharp/Tensor/torch.cs index 6892d2b69..3361c4964 100644 --- a/src/TorchSharp/Tensor/torch.cs +++ b/src/TorchSharp/Tensor/torch.cs @@ -33,6 +33,24 @@ public static partial class torch /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. public static Tensor concatenate(IList tensors, long axis = 0) => torch.cat(tensors, axis); + /// + /// Concatenates the given sequence of tensors along the given axis (dimension). + /// + /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension. + /// The dimension over which the tensors are concatenated + /// + /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. + public static Tensor concatenate(Tensor[] tensors, long axis = 0) => torch.cat(tensors, axis); + + /// + /// Concatenates the given sequence of tensors along the given axis (dimension). + /// + /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension. + /// The dimension over which the tensors are concatenated + /// + /// All tensors must either have the same shape (except in the concatenating dimension) or be empty. + public static Tensor concatenate(ReadOnlySpan tensors, long axis = 0) => torch.cat(tensors, axis); + /// /// Returns a tensor with all the dimensions of input of size 1 removed. When dim is given, a squeeze operation is done only in the given dimension. /// @@ -55,10 +73,28 @@ public static partial class torch /// A list of input tensors. 
/// /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally. - public static Tensor column_stack(IList tensors) + public static Tensor column_stack(IList tensors) => column_stack(tensors.ToHandleArray()); + + /// + /// Creates a new tensor by horizontally stacking the input tensors. + /// + /// A list of input tensors. + /// + /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally. + public static Tensor column_stack(params Tensor[] tensors) => column_stack(tensors.ToHandleArray()); + + /// + /// Creates a new tensor by horizontally stacking the input tensors. + /// + /// A list of input tensors. + /// + /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally. + public static Tensor column_stack(ReadOnlySpan tensors) => column_stack(tensors.ToHandleArray()); + + static Tensor column_stack(IntPtr[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors); var res = THSTensor_column_stack(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } @@ -66,35 +102,36 @@ public static Tensor column_stack(IList tensors) } /// - /// Creates a new tensor by horizontally stacking the input tensors. + /// Stack tensors in sequence vertically (row wise). /// - /// A list of input tensors. + /// /// - /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally. 
- public static Tensor column_stack(params Tensor[] tensors) => column_stack((IList)tensors); + public static Tensor row_stack(IList tensors) => row_stack(tensors.ToHandleArray()); + + /// + /// Stack tensors in sequence vertically (row wise). + /// + /// + /// + public static Tensor row_stack(params Tensor[] tensors) => row_stack(tensors.ToHandleArray()); /// /// Stack tensors in sequence vertically (row wise). /// /// /// - public static Tensor row_stack(IList tensors) + public static Tensor row_stack(ReadOnlySpan tensors) => row_stack(tensors.ToHandleArray()); + + static Tensor row_stack(IntPtr[] tensors) { using var parray = new PinnedArray(); - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors); var res = THSTensor_row_stack(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) { CheckForErrors(); } return new Tensor(res); } - /// - /// Stack tensors in sequence vertically (row wise). - /// - /// - /// - public static Tensor row_stack(params Tensor[] tensors) => row_stack((IList)tensors); - /// /// Removes a tensor dimension. 
/// diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs index c59196bfe..f35c5df71 100644 --- a/src/TorchSharp/Torch.cs +++ b/src/TorchSharp/Torch.cs @@ -393,7 +393,7 @@ public static partial class utils public static double clip_grad_norm_(IEnumerable tensors, double max_norm, double norm_type = 2.0) { using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var value = THSTensor_clip_grad_norm_(tensorsRef, parray.Array.Length, max_norm, norm_type); CheckForErrors(); return value; @@ -409,7 +409,7 @@ public static double clip_grad_norm_(IEnumerable tensors, dou public static void clip_grad_value_(IEnumerable tensors, double clip_value) { using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); THSTensor_clip_grad_value_(tensorsRef, parray.Array.Length, clip_value); CheckForErrors(); } @@ -423,7 +423,7 @@ public static void clip_grad_value_(IEnumerable tensors, doub public static Tensor parameters_to_vector(IEnumerable tensors) { using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); var res = THSTensor_parameters_to_vector(tensorsRef, parray.Array.Length); if (res == IntPtr.Zero) @@ -441,7 +441,7 @@ public static Tensor parameters_to_vector(IEnumerable tensors public static void vector_to_parameters(Tensor vec, IEnumerable tensors) { using (var parray = new PinnedArray()) { - IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray()); + IntPtr tensorsRef = parray.CreateArray(tensors.ToHandleArray()); THSTensor_vector_to_parameters(vec.Handle, tensorsRef, parray.Array.Length); CheckForErrors(); diff --git 
a/src/TorchSharp/Utils/OverloadHelper.cs b/src/TorchSharp/Utils/OverloadHelper.cs new file mode 100644 index 000000000..14316a0a1 --- /dev/null +++ b/src/TorchSharp/Utils/OverloadHelper.cs @@ -0,0 +1,40 @@ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace TorchSharp +{ + static class OverloadHelper + { + public static IntPtr[] ToHandleArray(this ReadOnlySpan span) + { + if (span.Length == 0) + return Array.Empty(); + + var result = new IntPtr[span.Length]; + for (int i = 0; i < span.Length; i++) + result[i] = span[i].Handle; + + return result; + } + + public static IntPtr[] ToHandleArray(this IList list) + { + if (list.Count == 0) + return Array.Empty(); + + var result = new IntPtr[list.Count]; + for (int i = 0; i < list.Count; i++) + result[i] = list[i].Handle; + + return result; + } + + public static IntPtr[] ToHandleArray(this torch.Tensor[] array) => ToHandleArray((ReadOnlySpan)array); + + public static IntPtr[] ToHandleArray(this IEnumerable enumerable) + { + return enumerable.Select(t => t.Handle).ToArray(); + } + } +}