public class LMNNGradientCalculator extends Object
Constructor and Description |
---|
LMNNGradientCalculator(Sequence[] data, int[] trainingLabels, AlignmentAlgorithm<? extends AlignmentDerivativeAlgorithm> algo) |
Modifier and Type | Method and Description |
---|---|
<X extends Value,Y> Y | calculateParameterGradient(DerivableComparator<X,Y> comp, String keyword, double[][] D): Calculates the gradient of the LMNN cost function with respect to the parameters of the given comparator. |
double[] | calculateWeightGradient(double[][] D): Calculates the gradient of the LMNN cost function with respect to the keyword weights. |
AlignmentAlgorithm<? extends AlignmentDerivativeAlgorithm> | getAlgo(): Returns the algorithm that is used to compute the gradient of the alignment distance. |
Sequence[] | getData(): Returns the training data points as Sequences. |
double | getMargin(): Returns the margin of safety that is required by the LMNN cost function. |
int | getNumberOfThreads(): Returns the number of threads used in the parallel computation of the gradient of the LMNN cost function. |
ProgressReporter | getReporter(): Returns the ProgressReporter that is used to report progress. |
int[] | getTrainingLabels(): Returns the true labels for all training data points. |
void | setK(int K): Sets the number of considered nearest neighbors in the LMNN cost function. |
void | setMargin(double margin): Sets the margin of safety that is required by the LMNN cost function. |
void | setNumberOfThreads(int numberOfThreads): Sets the number of threads used in the parallel computation of the gradient of the LMNN cost function. |
void | setReporter(ProgressReporter reporter): Sets the ProgressReporter that is used to report progress. |
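For orientation, here is a minimal sketch of the standard LMNN cost (in the formulation of Weinberger and Saul) evaluated on a precomputed distance matrix D, showing how K (the number of same-label target neighbors per point) and the margin enter it. It assumes unsquared distances and equal weighting of the pull and push terms; the exact variant used by LMNNGradientCalculator is not documented on this page and may differ. The class LmnnCostSketch and its methods are hypothetical and not part of the toolbox.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of the standard LMNN cost on a precomputed distance
// matrix D. Assumption: distances are used unsquared and the pull and push
// terms are weighted equally; the toolbox's exact variant may differ.
final class LmnnCostSketch {

    // Returns the indices of the (at most) k points closest to point i that
    // carry the same label as i; these are the "target neighbors" of i.
    static int[] targetNeighbors(double[][] d, int[] labels, int i, int k) {
        List<Integer> sameLabel = new ArrayList<>();
        for (int j = 0; j < labels.length; j++) {
            if (j != i && labels[j] == labels[i]) {
                sameLabel.add(j);
            }
        }
        sameLabel.sort(Comparator.comparingDouble(idx -> d[i][idx]));
        int m = Math.min(k, sameLabel.size());
        int[] result = new int[m];
        for (int t = 0; t < m; t++) {
            result[t] = sameLabel.get(t);
        }
        return result;
    }

    // E = sum_i sum_{j in T_k(i)} ( D[i][j]
    //       + sum_{l : label_l != label_i} max(0, margin + D[i][j] - D[i][l]) )
    static double cost(double[][] d, int[] labels, int k, double margin) {
        double e = 0.0;
        for (int i = 0; i < labels.length; i++) {
            for (int j : targetNeighbors(d, labels, i, k)) {
                e += d[i][j]; // pull term: target neighbors should be close
                for (int l = 0; l < labels.length; l++) {
                    if (labels[l] != labels[i]) {
                        // push term: differently labeled points should be farther
                        // away than the target neighbor by at least the margin
                        e += Math.max(0.0, margin + d[i][j] - d[i][l]);
                    }
                }
            }
        }
        return e;
    }

    public static void main(String[] args) {
        // four points on a line at 0, 1, 3, 4 with labels 0, 0, 1, 1
        double[] x = {0, 1, 3, 4};
        int[] labels = {0, 0, 1, 1};
        double[][] d = new double[4][4];
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) {
                d[i][j] = Math.abs(x[i] - x[j]);
            }
        }
        System.out.println(cost(d, labels, 1, 1.0)); // prints 4.0
        System.out.println(cost(d, labels, 1, 2.0)); // prints 6.0
    }
}
```

With a margin of 1.0 no differently labeled point comes within the margin in this toy example, so only the four pull terms contribute; raising the margin to 2.0 activates two hinge terms.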
public LMNNGradientCalculator(Sequence[] data, int[] trainingLabels, AlignmentAlgorithm<? extends AlignmentDerivativeAlgorithm> algo)
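The gradient methods of this class expect D to be an N x N matrix over the N training points, computed with the same distance scheme as the algorithm given to the constructor. The sketch below only illustrates that expected shape. It swaps in plain Levenshtein distance on Strings as a stand-in for the toolbox's AlignmentAlgorithm and Sequence types, so DistanceMatrixSketch and its methods are hypothetical and do not reproduce the toolbox's distance.

```java
import java.util.function.ToDoubleBiFunction;

// Hypothetical sketch: builds an N x N distance matrix D over N data points.
// Levenshtein distance on Strings stands in for the toolbox's alignment
// algorithm; it is NOT the distance scheme LMNNGradientCalculator uses.
final class DistanceMatrixSketch {

    // Computes D[i][j] = dist(data[i], data[j]) for all pairs i, j.
    static <T> double[][] distanceMatrix(T[] data, ToDoubleBiFunction<T, T> dist) {
        int n = data.length;
        double[][] d = new double[n][n]; // diagonal stays 0.0
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                // assumes a symmetric distance, so each pair is computed once
                double v = dist.applyAsDouble(data[i], data[j]);
                d[i][j] = v;
                d[j][i] = v;
            }
        }
        return d;
    }

    // Plain Levenshtein distance with unit costs, using two rolling rows.
    static double levenshtein(String a, String b) {
        int[] prev = new int[b.length() + 1];
        int[] curr = new int[b.length() + 1];
        for (int j = 0; j <= b.length(); j++) {
            prev[j] = j;
        }
        for (int i = 1; i <= a.length(); i++) {
            curr[0] = i;
            for (int j = 1; j <= b.length(); j++) {
                int sub = prev[j - 1] + (a.charAt(i - 1) == b.charAt(j - 1) ? 0 : 1);
                curr[j] = Math.min(sub, Math.min(prev[j] + 1, curr[j - 1] + 1));
            }
            int[] tmp = prev; // the finished row becomes the previous row
            prev = curr;
            curr = tmp;
        }
        return prev[b.length()];
    }

    public static void main(String[] args) {
        String[] data = {"abc", "abd", "xyz"};
        double[][] d = distanceMatrix(data, DistanceMatrixSketch::levenshtein);
        // prints [[0.0, 1.0, 3.0], [1.0, 0.0, 3.0], [3.0, 3.0, 0.0]]
        System.out.println(java.util.Arrays.deepToString(d));
    }
}
```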
public Sequence[] getData()
public int[] getTrainingLabels()
public AlignmentAlgorithm<? extends AlignmentDerivativeAlgorithm> getAlgo()
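The gradient with respect to the keyword weights follows from the chain rule: the LMNN cost E depends on the weights only through the distances D[i][j], so entry k of the gradient sums dE/dD[i][j] times dD[i][j]/dw_k over all pairs. A minimal sketch, assuming both factors are already available as arrays; in the toolbox the second factor would come from the alignment derivative algorithm returned by getAlgo(), but the array layout used here is an assumption. ChainRuleSketch is hypothetical.

```java
// Hypothetical sketch of the chain rule that links the LMNN cost to the
// keyword weights. Assumption: the caller supplies
//   outer[i][j]    = dE / dD[i][j]     (from the LMNN cost)
//   inner[i][j][k] = dD[i][j] / dw_k   (from the alignment algorithm)
final class ChainRuleSketch {

    // gradient[k] = sum over i, j of outer[i][j] * inner[i][j][k]
    static double[] weightGradient(double[][] outer, double[][][] inner) {
        int m = inner[0][0].length; // number of keyword weights (assumes N >= 1)
        double[] gradient = new double[m];
        for (int i = 0; i < outer.length; i++) {
            for (int j = 0; j < outer[i].length; j++) {
                if (outer[i][j] == 0.0) {
                    continue; // pair does not influence the cost; skip it
                }
                for (int k = 0; k < m; k++) {
                    gradient[k] += outer[i][j] * inner[i][j][k];
                }
            }
        }
        return gradient;
    }

    public static void main(String[] args) {
        double[][] outer = {{0, 2}, {1, 0}};
        double[][][] inner = {{{0, 0}, {1, 3}}, {{4, 0}, {0, 0}}};
        // prints [6.0, 6.0]
        System.out.println(java.util.Arrays.toString(weightGradient(outer, inner)));
    }
}
```

Only pairs with a nonzero outer derivative (target neighbors and margin-violating impostors) contribute, so an implementation that computes the comparatively expensive alignment derivative lazily could skip all other pairs.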
public void setK(int K)
Parameters:
K - the number of considered nearest neighbors in the LMNN cost function.
public double getMargin()
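The margin decides which differently labeled points ("impostors") contribute to the gradient: a triplet (i, j, l) with target neighbor j and impostor l adds to dE/dD only while margin + D[i][j] - D[i][l] > 0. The sketch below computes dE/dD for an assumed LMNN cost made of a pull term D[i][j] per target pair plus hinge terms max(0, margin + D[i][j] - D[i][l]) on unsquared distances, with the target neighbors passed in precomputed. It treats the hinge as inactive at exactly zero, which is one valid subgradient choice. LmnnDerivativeSketch is hypothetical, and the toolbox's own formulation may differ.

```java
// Hypothetical sketch: derivative of an assumed LMNN cost (pull term plus
// hinge push term, unsquared distances) with respect to every entry D[i][j].
// Assumption: targets[i] lists the precomputed target neighbors of point i.
final class LmnnDerivativeSketch {

    static double[][] costDerivative(double[][] d, int[] labels, int[][] targets, double margin) {
        int n = labels.length;
        double[][] g = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j : targets[i]) {
                g[i][j] += 1.0; // pull term: dE/dD[i][j] = 1
                for (int l = 0; l < n; l++) {
                    if (labels[l] == labels[i]) {
                        continue; // only differently labeled points are impostors
                    }
                    // hinge is active only if impostor l violates the margin;
                    // exactly zero counts as inactive (subgradient choice)
                    if (margin + d[i][j] - d[i][l] > 0.0) {
                        g[i][j] += 1.0;
                        g[i][l] -= 1.0;
                    }
                }
            }
        }
        return g;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 3, 4};
        int[] labels = {0, 0, 1, 1};
        int[][] targets = {{1}, {0}, {3}, {2}};
        double[][] d = new double[4][4];
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) {
                d[i][j] = Math.abs(x[i] - x[j]);
            }
        }
        double[][] g = costDerivative(d, labels, targets, 2.0);
        System.out.println(java.util.Arrays.deepToString(g));
    }
}
```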
public void setMargin(double margin)
Parameters:
margin - the margin of safety that is required by the LMNN cost function.
public int getNumberOfThreads()
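The gradient is a sum over pairs of training points, so it parallelizes naturally by splitting the rows of D across threads. The sketch below shows one way to do this with a fixed thread pool: each task accumulates a private partial gradient, and the partial results are added at the end. It assumes the per-pair factors dE/dD[i][j] and dD[i][j]/dw_k are precomputed arrays; how LMNNGradientCalculator actually partitions its work is not documented here, and ParallelGradientSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: distributes the chain-rule sum over rows i of the
// distance matrix across a fixed thread pool and adds up partial results.
// Assumed inputs: outer[i][j] = dE/dD[i][j], inner[i][j][k] = dD[i][j]/dw_k.
final class ParallelGradientSketch {

    static double[] weightGradient(double[][] outer, double[][][] inner, int numberOfThreads) {
        int n = outer.length;
        int m = inner[0][0].length;
        ExecutorService pool = Executors.newFixedThreadPool(numberOfThreads);
        try {
            List<Future<double[]>> parts = new ArrayList<>();
            for (int t = 0; t < numberOfThreads; t++) {
                final int start = t;
                // task t handles rows t, t + T, t + 2T, ... (T = numberOfThreads)
                parts.add(pool.submit(() -> {
                    double[] partial = new double[m];
                    for (int i = start; i < n; i += numberOfThreads) {
                        for (int j = 0; j < n; j++) {
                            for (int k = 0; k < m; k++) {
                                partial[k] += outer[i][j] * inner[i][j][k];
                            }
                        }
                    }
                    return partial;
                }));
            }
            // each task owns its partial array, so no locking is needed until
            // the results are summed here on the calling thread
            double[] gradient = new double[m];
            for (Future<double[]> part : parts) {
                double[] partial = part.get();
                for (int k = 0; k < m; k++) {
                    gradient[k] += partial[k];
                }
            }
            return gradient;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt flag
            throw new IllegalStateException("interrupted while computing the gradient", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a gradient task failed", e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        double[][] outer = {{0, 2}, {1, 0}};
        double[][][] inner = {{{0, 0}, {1, 3}}, {{4, 0}, {0, 0}}};
        // prints [6.0, 6.0]
        System.out.println(java.util.Arrays.toString(weightGradient(outer, inner, 2)));
    }
}
```

Because floating-point addition is not associative, results on non-integer data may differ in the last bits depending on the number of threads.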
public void setNumberOfThreads(int numberOfThreads)
Parameters:
numberOfThreads - the number of threads used in the parallel computation of the gradient of the LMNN cost function.
public ProgressReporter getReporter()
public void setReporter(ProgressReporter reporter)
Parameters:
reporter - the ProgressReporter that is used to report progress. This is a CommandLineProgressReporter by default. If it is set to null, progress is not reported.
public <X extends Value,Y> Y calculateParameterGradient(DerivableComparator<X,Y> comp, String keyword, double[][] D)
Type Parameters:
X - the value class the given comparator operates on.
Y - the format of the parameters for the given comparator.
Parameters:
comp - the comparator itself.
keyword - the keyword for which the given comparator is used in the AlignmentSpecification.
D - given N training data points, this should be an N x N matrix of alignment distances computed with the same distance scheme as the algorithm given to this LMNNGradientCalculator. This distance matrix serves as the basis for evaluating the LMNN cost function.
public double[] calculateWeightGradient(double[][] D)
Parameters:
D - given N training data points, this should be an N x N matrix of alignment distances computed with the same distance scheme as the algorithm given to this LMNNGradientCalculator. This distance matrix serves as the basis for evaluating the LMNN cost function.

Copyright (C) 2013-2015 Benjamin Paaßen, Georg Zentgraf, AG Theoretical Computer Science, Centre of Excellence Cognitive Interaction Technology (CITEC), University of Bielefeld, licensed under the AGPL v. 3: http://openresearch.cit-ec.de/projects/tcs . This documentation is licensed under the conditions of CC-BY-SA 4.0: https://creativecommons.org/licenses/by-sa/4.0/