Commit 6db7f300 authored by javier's avatar javier
Browse files

'Initial Experiment on model validation'

parents
Pipeline #3296 passed with stage
.classpath
.project
.settings/
.vscode/
target/
\ No newline at end of file
<!DOCTYPE html>
<html>
<body>
<h1>CODE FOR SAC'21 PAPER: "Testing the Tests: Simulation of Rankings to Compare Statistical Significance Tests in Information Retrieval Evaluation."</h1>
<h2>REQUIREMENTS</h2>
<ul>
<li>JDK 11 or above</li>
<li>R 3.5 or above</li>
<li>R package nonpar installed <code>install.packages("nonpar")</code></li>
<li>Maven 3.6.0</li>
<li>TREC system runs and qrels</li>
</ul>
<h2>COMPILATION</h2>
<code>mvn install</code>
<h2>EXECUTION</h2>
<p>Experiments are executed by running the corresponding classes. If you want to use the jar file directly, update the mainClass property in the pom file and then run:</p>
<code>java -jar tests-1.0.0-jar-with-dependencies.jar </code>
</body>
</html>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.irlab</groupId>
<artifactId>tests</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<name>tests</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<pluginRepositories>
<pluginRepository>
<id>onejar-maven-plugin.googlecode.com</id>
<url>http://onejar-maven-plugin.googlecode.com/svn/mavenrepo</url>
</pluginRepository>
</pluginRepositories>
<repositories>
<repository>
<id>clojars</id>
<url>http://clojars.org/repo/</url>
</repository>
<repository>
<id>attlasian</id>
<name>Atlassian</name>
<url>https://packages.atlassian.com/maven-3rdparty/</url>
</repository>
</repositories>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.3</version>
<configuration>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<version>2.5.4</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
<archive>
<manifest>
<mainClass>org.irlab.tests.experiment.ExperimentTestLogisticRegressionModel</mainClass>
</manifest>
</archive>
</configuration>
</plugin>
</plugins>
</build>
<dependencies>
<dependency>
<groupId>com.github.jbytecode</groupId>
<artifactId>RCaller</artifactId>
<version>2.8</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.3</version>
</dependency>
<dependency>
<groupId>jsc</groupId>
<artifactId>jsc</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>org.javatuples</groupId>
<artifactId>javatuples</artifactId>
<version>1.2</version>
</dependency>
<dependency>
<groupId>com.koloboke</groupId>
<artifactId>koloboke-api-jdk8</artifactId>
<version>1.0.0</version>
</dependency>
<dependency>
<groupId>com.koloboke</groupId>
<artifactId>koloboke-impl-jdk8</artifactId>
<version>1.0.0</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.28</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-log4j12</artifactId>
<version>1.7.28</version>
</dependency>
<dependency>
<groupId>com.numericalmethod</groupId>
<artifactId>suanshu-20120606</artifactId>
<version>1.0.1-atlassian-hosted</version>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
<artifactId>fastutil</artifactId>
<version>8.3.0</version>
</dependency>
<dependency>
<groupId>com.github.haifengl</groupId>
<artifactId>smile-core</artifactId>
<version>2.0.0</version>
</dependency>
</dependencies>
</project>
package org.irlab.tests.experiment;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.correlation.KendallsCorrelation;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.correlation.SpearmansCorrelation;
import org.apache.log4j.BasicConfigurator;
import org.irlab.tests.models.LogisticRegressionProbabilityModel;
import org.irlab.tests.models.ProbabilityModel;
import org.irlab.tests.util.correlation.TauAP;
import org.irlab.tests.util.metric.MetricEnum;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class ExperimentTestLogisticRegressionModel {

    private static final Logger logger = LoggerFactory.getLogger(ExperimentTestLogisticRegressionModel.class);

    /**
     * Validates the logistic-regression probability model: for each TREC
     * collection, learns one model per system run, then compares the true mean
     * metric ranking of the systems against the ranking obtained by averaging
     * samples drawn from the model, reporting tau-AP, Kendall, Pearson and
     * Spearman correlations.
     *
     * @param args ignored
     * @throws IOException          if a run directory or qrel file cannot be read
     * @throws InterruptedException if the sampling threads are interrupted
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        BasicConfigurator.configure();
        final int THREADS = Runtime.getRuntime().availableProcessors();
        int[] trecs = new int[] { 3, 5, 6, 7, 8 };
        String[] runpaths = new String[] {
                "/mnt/datasets/trec_col/runs/adhoc-routing/trec3.results.input/adhoc",
                "/mnt/datasets/trec_col/runs/adhoc-routing/trec5.results.input/adhoc/CategoryA",
                "/mnt/datasets/trec_col/runs/adhoc-routing/trec6.results.input/adhoc",
                "/mnt/datasets/trec_col/runs/adhoc-routing/trec7.results.input/adhoc",
                "/mnt/datasets/trec_col/runs/adhoc-routing/trec8.results.input/adhoc" };
        String[] qrelpaths = new String[] { "/mnt/datasets/trec_col/qrels/qrels.151-200.trec3",
                "/mnt/datasets/trec_col/qrels/qrels.251-300.trec5",
                "/mnt/datasets/trec_col/qrels/qrels.301-350.trec6",
                "/mnt/datasets/trec_col/qrels/qrels.351-400.trec7",
                "/mnt/datasets/trec_col/qrels/qrels.401-450.trec8" };
        logger.info("Running with {} threads", THREADS);
        int samples = 1000;
        int cutoff = 1000;
        MetricEnum[] metrics = new MetricEnum[] { MetricEnum.MAP };
        for (int trec = 0; trec < trecs.length; trec++) {
            for (MetricEnum metric : metrics) {
                File folder = new File(runpaths[trec]);
                // File.list() re-reads the directory on every call and returns null
                // when the path does not exist or is unreadable; list once and fail
                // fast instead of hitting an NPE later.
                String[] runFiles = folder.list();
                if (runFiles == null) {
                    throw new IOException("Cannot list run directory: " + runpaths[trec]);
                }
                logger.info("Correlation for TREC {} and {} systems under {} with {} samples",
                        trecs[trec], runFiles.length, metric, samples);
                double[] systemMeanTrueMetricValues = new double[runFiles.length];
                double[][] systemMeanMetricValuesFromModel = new double[samples][runFiles.length];
                int system = 0;
                for (String file : runFiles) {
                    ProbabilityModel model = new LogisticRegressionProbabilityModel(samples,
                            THREADS, cutoff, qrelpaths[trec], metric);
                    String run = runpaths[trec] + File.separator + file;
                    double[] trueMetricValues = model
                            .learnProbabilitiesAndReturnMetricPerQuery(run);
                    systemMeanTrueMetricValues[system] = StatUtils.mean(trueMetricValues);
                    // One mean-metric estimate per sample, over the first 50 queries.
                    double[] meanMetricValuesFromModel = model.sampleMeanMetric(0, 50, false);
                    for (int sample = 0; sample < samples; sample++) {
                        systemMeanMetricValuesFromModel[sample][system] = meanMetricValuesFromModel[sample];
                    }
                    // NOTE(review): this per-query average is computed but never read
                    // afterwards; it only consumes random draws. It is kept so the
                    // shared random stream (and thus the experiment's numbers) is
                    // unchanged — confirm whether it can be removed.
                    double[] simulatedMetricValues = new double[50];
                    int samplesK = 1000;
                    for (int s = 0; s < samplesK; s++) {
                        double[] sample = model.sampleMetric(0, 50, false);
                        for (int p = 0; p < sample.length; p++) {
                            simulatedMetricValues[p] += sample[p] / samplesK;
                        }
                    }
                    system++;
                }
                TauAP tau = new TauAP();
                KendallsCorrelation kendalls = new KendallsCorrelation();
                PearsonsCorrelation pearson = new PearsonsCorrelation();
                SpearmansCorrelation spearman = new SpearmansCorrelation();
                // Average the sampled mean metric over all repetitions per system.
                double[] avgSystemMeanTrueMetricValues = new double[systemMeanTrueMetricValues.length];
                for (int sys = 0; sys < systemMeanTrueMetricValues.length; sys++) {
                    for (int rep = 0; rep < samples; rep++) {
                        avgSystemMeanTrueMetricValues[sys] += systemMeanMetricValuesFromModel[rep][sys];
                    }
                    avgSystemMeanTrueMetricValues[sys] /= samples;
                }
                // Defensive copies on every call: some correlation implementations
                // sort their input arrays in place.
                double k = kendalls.correlation(
                        Arrays.copyOf(systemMeanTrueMetricValues,
                                systemMeanTrueMetricValues.length),
                        Arrays.copyOf(avgSystemMeanTrueMetricValues,
                                avgSystemMeanTrueMetricValues.length));
                double t = tau.correlation(
                        Arrays.copyOf(systemMeanTrueMetricValues,
                                systemMeanTrueMetricValues.length),
                        Arrays.copyOf(avgSystemMeanTrueMetricValues,
                                avgSystemMeanTrueMetricValues.length));
                double p = pearson.correlation(
                        Arrays.copyOf(systemMeanTrueMetricValues,
                                systemMeanTrueMetricValues.length),
                        Arrays.copyOf(avgSystemMeanTrueMetricValues,
                                avgSystemMeanTrueMetricValues.length));
                double s = spearman.correlation(
                        Arrays.copyOf(systemMeanTrueMetricValues,
                                systemMeanTrueMetricValues.length),
                        Arrays.copyOf(avgSystemMeanTrueMetricValues,
                                avgSystemMeanTrueMetricValues.length));
                logger.info("TAUAP systems= {}", t);
                logger.info("Pearson systems= {}", p);
                logger.info("Spearman systems= {}", s);
                logger.info("Kendalls systems= {}", k);
            }
        }
    }
}
/*
* Copyright (c) 2019 Information Retrieval Lab - University of A Coruña
*/
package org.irlab.tests.models;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.irlab.tests.util.metric.MetricEnum;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import smile.classification.CustomLogitLogisticRegressor;
/**
* The Logistic Regression Probability Modelling of TREC systems
*
* @author Javier Parapar (javier.parapar@udc.es)
*/
/**
 * The Logistic Regression Probability Modelling of TREC systems.
 *
 * <p>Per query, fits a logistic regression of document relevance on 1-based
 * rank position, and samples binary relevance vectors from the fitted
 * (optionally perturbed) probabilities.
 *
 * @author Javier Parapar (javier.parapar@udc.es)
 */
public class LogisticRegressionProbabilityModel extends ProbabilityModel {

    private static final Logger logger = LoggerFactory.getLogger(LogisticRegressionProbabilityModel.class);

    public LogisticRegressionProbabilityModel(int samples, int threads, int cutoff, String pathToQrel,
            MetricEnum metric) {
        super(samples, threads, cutoff, pathToQrel, metric);
    }

    // Fitted coefficients per query id: a = w[0], b = w[1].
    private Int2DoubleMap a = new Int2DoubleOpenHashMap();
    private Int2DoubleMap b = new Int2DoubleOpenHashMap();
    // Per-query design matrix: one row per document, single feature = rank position.
    private Int2ObjectMap<double[][]> positions = new Int2ObjectOpenHashMap<>();
    // Per-query relevance labels (int and double views of the same values).
    private Int2ObjectMap<int[]> relevantsInt = new Int2ObjectOpenHashMap<>();
    private Int2ObjectMap<double[]> relevants = new Int2ObjectOpenHashMap<>();
    // queryId -> (perturbation key -> predicted probabilities). Guarded by
    // synchronized (probCache): the open hash maps are not thread-safe.
    private Int2ObjectMap<Int2ObjectMap<double[]>> probCache = new Int2ObjectOpenHashMap<>();

    /**
     * Fits a logistic regression of relevance on rank position for one query
     * and stores the training data and coefficients for later sampling.
     *
     * @param queryId the query to learn
     * @param ranking the ranked documents (id, score) for this query
     * @param qrels   graded relevance judgments per query and document
     * @return the predicted relevance probability per rank position, or the raw
     *         relevance labels if the regressor could not be fitted
     */
    public double[] learn(int queryId, List<Pair<String, Double>> ranking, Map<Integer, Map<String, Double>> qrels) {
        double[] probabilities = new double[ranking.size()];
        double[] qRelevants = new double[ranking.size()];
        int[] qRelevantsInt = new int[ranking.size()];
        double[][] qPositions = new double[ranking.size()][1];
        // Writers of probCache must hold the same lock as sampleRelevance readers.
        synchronized (probCache) {
            probCache.put(queryId, new Int2ObjectOpenHashMap<>());
        }
        int count = 0;
        for (Pair<String, Double> position : ranking) {
            // Unjudged documents default to non-relevant (0).
            qRelevants[count] = qrels.get(queryId).getOrDefault(position.getKey(), 0d);
            count++;
        }
        for (int i = 0; i < qPositions.length; i++) {
            qPositions[i][0] = i + 1d; // 1-based rank position is the single feature
            qRelevantsInt[i] = (int) qRelevants[i];
        }
        relevants.put(queryId, qRelevants);
        relevantsInt.put(queryId, qRelevantsInt);
        positions.put(queryId, qPositions);
        CustomLogitLogisticRegressor regressor;
        try {
            regressor = CustomLogitLogisticRegressor.fit(qPositions, qRelevantsInt);
        } catch (IllegalArgumentException e) {
            // Fitting can fail (e.g. degenerate labels); fall back to the labels.
            logger.warn("Failed to create regressor for query {} returning default relevants", queryId, e);
            return qRelevants;
        }
        double[] w = regressor.getW();
        a.put(queryId, w[0]);
        b.put(queryId, w[1]);
        double[] posterior = new double[2];
        for (int i = 0; i < qPositions.length; i++) {
            regressor.predict(new double[] { i + 1 }, posterior);
            probabilities[i] = posterior[1]; // probability of the relevant class
        }
        return probabilities;
    }

    /**
     * Draws a binary relevance sample for a query from the model, with the
     * fitted coefficients perturbed by {@code change}. Perturbed probabilities
     * are cached per (query, change) pair.
     *
     * @param queryID the query to sample
     * @param change  relative coefficient perturbation (0 = unperturbed)
     * @return one Bernoulli draw per rank position (1 = relevant)
     */
    public int[] sampleRelevance(int queryID, double change) {
        int cacheKey = (int) (change * 1000);
        double[] probabilities;
        // All cache access happens under one lock. The previous implementation
        // used double-checked locking, but Int2ObjectOpenHashMap is not
        // thread-safe, so the unsynchronized first read could observe a map in
        // the middle of a rehash (the old code even caught
        // ArrayIndexOutOfBoundsException to paper over that race).
        synchronized (probCache) {
            probabilities = probCache.get(queryID).get(cacheKey);
            if (probabilities == null) {
                probabilities = alteredProbabilities(queryID, change);
                probCache.get(queryID).put(cacheKey, probabilities);
            }
        }
        int[] sample = new int[probabilities.length];
        for (int i = 0; i < sample.length; i++) {
            sample[i] = randomGenerator.nextDouble() <= probabilities[i] ? 1 : 0;
        }
        return sample;
    }

    /**
     * Predicts relevance probabilities with the query's coefficients rescaled
     * by the original multiplicative rule (p or 1/p depending on sign, with
     * p = 1 + change). Falls back to the stored relevance labels when the
     * altered regressor cannot be built.
     */
    private double[] alteredProbabilities(int queryID, double change) {
        double p = 1d + change;
        double[] wchange = new double[2];
        wchange[0] = a.get(queryID) > 0 ? p * a.get(queryID) : (1 / p) * a.get(queryID);
        wchange[1] = b.get(queryID) > 0 ? p * b.get(queryID) : (1 / p) * b.get(queryID);
        CustomLogitLogisticRegressor changedRegressor;
        try {
            changedRegressor = CustomLogitLogisticRegressor.dummyAlteredRegressor(positions.get(queryID),
                    relevantsInt.get(queryID), wchange);
        } catch (IllegalArgumentException e) {
            logger.warn("Problems with regressor for query {}", queryID, e);
            return relevants.get(queryID);
        }
        double[] probabilities = new double[positions.get(queryID).length];
        double[] posterior = new double[2];
        for (int i = 0; i < probabilities.length; i++) {
            changedRegressor.predict(new double[] { i + 1 }, posterior);
            probabilities[i] = posterior[1];
        }
        return probabilities;
    }
}
/*
* Copyright (c) 2019 Information Retrieval Lab - University of A Coruña
*/
package org.irlab.tests.models;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.irlab.tests.models.util.ProbabilityModelThread;
import org.irlab.tests.util.ParallelDoubleArrays;
import org.irlab.tests.util.metric.Metric;
import org.irlab.tests.util.metric.MetricEnum;
import org.irlab.tests.util.metric.MetricFactory;
import org.irlab.tests.util.trec.TRECParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
/**
* The generic Probability Modelling of TREC systems
*
* @author Javier Parapar (javier.parapar@udc.es)
*/
public abstract class ProbabilityModel {
/** The Constant logger. */
private static final Logger logger = LoggerFactory.getLogger(ProbabilityModel.class);
// Shared, unseeded RNG used by subclasses when drawing relevance samples.
public static final Random randomGenerator = new Random();
// Default ranking depth. NOTE(review): not referenced in the visible code — confirm use.
static int TOP = 1000;
// queryId -> learned per-position relevance probabilities (filled by learnProbabilitiesAndReturnMetricPerQuery).
private Int2ObjectMap<double[]> probs;
// Evaluation metric built from the qrels and cutoff in the constructor.
private Metric metric;
// Number of samples each sampleMeanMetric call collects.
private int samples;
// Results handed back by ProbabilityModelThread workers: (sample index, metric values).
public BlockingQueue<AbstractMap.SimpleImmutableEntry<Integer, DoubleList>> queue;
// Size of the worker thread pool.
private int threads;
// queryId -> (documentId -> graded relevance), parsed from the qrel file.
private Map<Integer, Map<String, Double>> qrels;
// All query ids from the qrels, sorted ascending.
private int[] queryNumbers;
public ProbabilityModel(int samples, int threads, int cutoff, String pathToQrel, MetricEnum metric) {
super();
this.samples = samples;
this.threads = threads;
queue = new ArrayBlockingQueue<AbstractMap.SimpleImmutableEntry<Integer, DoubleList>>(samples);
this.qrels = TRECParser.parseQrels(pathToQrel);
this.queryNumbers = ArrayUtils.toPrimitive(qrels.keySet().toArray(new Integer[0]));
Arrays.sort(queryNumbers);
this.metric = MetricFactory.getMetric(metric, qrels, cutoff);
}
/**
 * Returns the per-query relevance probabilities learned by the last call to
 * {@link #learnProbabilitiesAndReturnMetricPerQuery(String)}, or {@code null}
 * if no run has been learned yet. The internal map is returned directly, not a copy.
 */
public Int2ObjectMap<double[]> getProbs() {
return probs;
}
/**
 * Parses a TREC run, learns per-query relevance probabilities for it, and
 * returns the true metric value per query.
 *
 * @param pathTorun path to the TREC run file
 * @return metric values ordered by ascending query id
 * @throws IOException          if the run file cannot be read
 * @throws InterruptedException declared for subclass implementations of learn
 */
public double[] learnProbabilitiesAndReturnMetricPerQuery(String pathTorun)
        throws IOException, InterruptedException {
    Map<Integer, List<Pair<String, Double>>> run = TRECParser.parseRun(pathTorun).getValue();
    Map<Integer, Double> perQueryMetricMap = metric.calculate(run);
    // Sort by query id so the array lines up with ascending query numbers.
    // (The original collected to a List and re-streamed it; mapToDouble on the
    // sorted entry stream does the same in one pipeline.)
    double[] perQueryMetricArrays = perQueryMetricMap.entrySet().stream()
            .sorted(Map.Entry.comparingByKey())
            .mapToDouble(Map.Entry::getValue)
            .toArray();
    probs = new Int2ObjectOpenHashMap<double[]>(run.size());
    for (Integer query : run.keySet()) {
        probs.put(query.intValue(), learn(query, run.get(query), qrels));
    }
    return perQueryMetricArrays;
}
/**
 * Returns {@code queries} distinct query ids drawn uniformly at random from
 * the learned queries.
 *
 * @param queries how many query ids to return (must not exceed the number learned)
 * @return the sampled query ids, in shuffled order
 */
public int[] getRandomSample(int queries) {
    ArrayList<Integer> queryIdList = new ArrayList<Integer>(probs.keySet());
    // Shuffle with the model's shared RNG so query selection draws from the
    // same random stream as relevance sampling (the no-Random overload of
    // Collections.shuffle would use its own internal generator).
    Collections.shuffle(queryIdList, randomGenerator);
    return ArrayUtils.toPrimitive(queryIdList.subList(0, queries).toArray(new Integer[0]));
}
/**
 * Samples the mean metric over a set of queries: either a random subset of
 * the learned queries or the first {@code queries} ids in ascending order.
 *
 * @param change  coefficient perturbation forwarded to the sampler
 * @param queries number of queries to evaluate on
 * @param random  true to pick queries at random, false for the first ids
 * @return one mean metric value per sample
 */
public double[] sampleMeanMetric(double change, int queries, boolean random) {
    final int[] chosen = random
            ? getRandomSample(queries)
            : ArrayUtils.subarray(queryNumbers, 0, queries);
    return sampleMeanMetric(change, chosen);
}
/**
 * Runs {@code samples} sampling tasks on a fixed thread pool and collects one
 * mean metric value per task from the shared queue.
 *
 * @param change   coefficient perturbation forwarded to each worker
 * @param queryIds the query ids each worker samples over
 * @return mean metric per sample, indexed by the sample id each worker reports;
 *         entries stay 0.0 for samples lost to timeout or interruption
 */
private double[] sampleMeanMetric(double change, int[] queryIds) {
    double[] meanMetricsProb = new double[samples];
    ExecutorService pool = Executors.newFixedThreadPool(threads);
    try {
        for (int i = 0; i < samples; i++) {
            pool.execute(new ProbabilityModelThread(this, i, change, queryIds, metric));
        }
        for (int l = 0; l < samples; l++) {
            SimpleImmutableEntry<Integer, DoubleList> fit;
            try {
                fit = queue.poll(5, TimeUnit.MINUTES);
            } catch (InterruptedException e) {
                logger.error("samplingMean polling", e);
                // Restore the interrupt flag and stop waiting for more results.
                Thread.currentThread().interrupt();
                break;
            }
            if (fit == null) {
                // poll() returns null on timeout; the original code would have
                // thrown an NPE on fit.getKey() here.
                logger.error("samplingMean: timed out waiting for sample {} of {}", l, samples);
                break;
            }
            meanMetricsProb[fit.getKey()] = fit.getValue().getDouble(0);
        }
    } finally {
        // Always release the pool, even if collection fails part-way.
        pool.shutdownNow();
    }
    return meanMetricsProb;
}
/**
 * Convenience overload: resolves the query-id selection (random subset or the
 * first {@code queries} ids) and delegates to the id-based sampler.
 *
 * @param change          coefficient perturbation forwarded to the sampler
 * @param queries         number of queries to evaluate on
 * @param queriesImproved number of queries the perturbation applies to
 * @param random          true to pick queries at random, false for the first ids
 * @return the sampled metric values
 * @throws IOException          propagated from the delegate
 * @throws InterruptedException propagated from the delegate
 */
public double[] sampleMetric(double change, int queries, int queriesImproved, boolean random)
        throws IOException, InterruptedException {
    int[] chosen;
    if (random) {
        chosen = getRandomSample(queries);
    } else {
        chosen = ArrayUtils.subarray(queryNumbers, 0, queries);
    }
    return sampleMetric(change, queries, queriesImproved, chosen);
}
public double[] sampleMetric(double change, int queries, boolean random) throws IOException, InterruptedException {