Merge pull request #95 from myui/feature/densemodel
Feature/densemodel
myui committed Aug 9, 2014
2 parents cb34e0a + ac6f907 commit ecf4d1f
Showing 50 changed files with 1,791 additions and 721 deletions.
16 changes: 16 additions & 0 deletions scripts/ddl/define-all.hive
@@ -171,12 +171,20 @@ create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifie
drop temporary function conv2dense;
create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';

+ -- for backward compatibility
drop temporary function addBias;
create temporary function addBias as 'hivemall.ftvec.AddBiasUDF';

+ drop temporary function add_bias;
+ create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';

+ -- for backward compatibility
drop temporary function sortByFeature;
create temporary function sortByFeature as 'hivemall.ftvec.SortByFeatureUDF';

+ drop temporary function sort_by_feature;
+ create temporary function sort_by_feature as 'hivemall.ftvec.SortByFeatureUDF';

drop temporary function extract_feature;
create temporary function extract_feature as 'hivemall.ftvec.ExtractFeatureUDF';

@@ -247,9 +255,13 @@ create temporary function arowe2_regress as 'hivemall.regression.AROWRegressionU
-- array functions --
---------------------

+ -- alias for backward compatibility
drop temporary function AllocFloatArray;
create temporary function AllocFloatArray as 'hivemall.tools.array.AllocFloatArrayUDF';

+ drop temporary function float_array;
+ create temporary function float_array as 'hivemall.tools.array.AllocFloatArrayUDF';

drop temporary function array_remove;
create temporary function array_remove as 'hivemall.tools.array.ArrayRemoveUDF';

@@ -318,9 +330,13 @@ create temporary function generate_series as 'hivemall.tools.GenerateSeriesUDTF'
-- string functions --
----------------------

+ -- alias for backward compatibility
drop temporary function isStopword;
create temporary function isStopword as 'hivemall.tools.string.StopwordUDF';

+ drop temporary function is_stopword;
+ create temporary function is_stopword as 'hivemall.tools.string.StopwordUDF';

drop temporary function split_words;
create temporary function split_words as 'hivemall.tools.string.SplitWordsUDF';

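Illustration (not part of the diff): each legacy camelCase name and its new snake_case counterpart are registered against the same UDF class, so the two forms are interchangeable. A hypothetical session, assuming define-all.hive has been sourced:

select addBias(array('1:0.5', '2:1.0'));   -- legacy alias
select add_bias(array('1:0.5', '2:1.0'));  -- new name, same UDF, same result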
8 changes: 8 additions & 0 deletions scripts/ddl/define-ftvec-udf.hive
@@ -52,12 +52,20 @@ create temporary function rand_amplify as 'hivemall.ftvec.amplify.RandomAmplifie
drop temporary function conv2dense;
create temporary function conv2dense as 'hivemall.ftvec.ConvertToDenseModelUDAF';

+ -- for backward compatibility
drop temporary function addBias;
create temporary function addBias as 'hivemall.ftvec.AddBiasUDF';

+ drop temporary function add_bias;
+ create temporary function add_bias as 'hivemall.ftvec.AddBiasUDF';

+ -- for backward compatibility
drop temporary function sortByFeature;
create temporary function sortByFeature as 'hivemall.ftvec.SortByFeatureUDF';

+ drop temporary function sort_by_feature;
+ create temporary function sort_by_feature as 'hivemall.ftvec.SortByFeatureUDF';

drop temporary function extract_feature;
create temporary function extract_feature as 'hivemall.ftvec.ExtractFeatureUDF';

8 changes: 8 additions & 0 deletions scripts/ddl/define-tools-udf.hive
@@ -10,9 +10,13 @@
-- array functions --
---------------------

+ -- alias for backward compatibility
drop temporary function AllocFloatArray;
create temporary function AllocFloatArray as 'hivemall.tools.array.AllocFloatArrayUDF';

+ drop temporary function float_array;
+ create temporary function float_array as 'hivemall.tools.array.AllocFloatArrayUDF';

drop temporary function array_remove;
create temporary function array_remove as 'hivemall.tools.array.ArrayRemoveUDF';

@@ -74,9 +78,13 @@ create temporary function distcache_gets as 'hivemall.tools.mapred.DistributedCa
-- string functions --
----------------------

+ -- alias for backward compatibility
drop temporary function isStopword;
create temporary function isStopword as 'hivemall.tools.string.StopwordUDF';

+ drop temporary function is_stopword;
+ create temporary function is_stopword as 'hivemall.tools.string.StopwordUDF';

drop temporary function split_words;
create temporary function split_words as 'hivemall.tools.string.SplitWordsUDF';

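Usage sketch (hypothetical; it assumes StopwordUDF returns a boolean marking common English stopwords):

select is_stopword('the');   -- expected: true
select is_stopword('hive');  -- expected: false
select isStopword('the');    -- legacy alias bound to the same class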
4 changes: 2 additions & 2 deletions src/main/hivemall/HivemallConstants.java
@@ -22,8 +22,8 @@

public final class HivemallConstants {

- public static final String BIAS_CLAUSE = "+bias";
- public static final int BIAS_CLAUSE_INT = -1;
+ public static final String BIAS_CLAUSE = "0";
+ public static final String CONFKEY_RAND_AMPLIFY_SEED = "hivemall.amplify.seed";

// org.apache.hadoop.hive.serde.Constants (hive 0.9)
// org.apache.hadoop.hive.serde.serdeConstants (hive 0.10 or later)
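The bias feature is now encoded as index "0" instead of the "+bias" pseudo-feature, and the integer form BIAS_CLAUSE_INT is gone. A hedged sketch of the query-level effect, assuming AddBiasUDF appends BIAS_CLAUSE with weight 1.0:

select add_bias(array('1:0.5', '3:1.0'));
-- hypothetical result after this change: ['1:0.5', '3:1.0', '0:1.0']
-- (previously the appended element would have been '+bias:1.0')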
104 changes: 59 additions & 45 deletions src/main/hivemall/LearnerBaseUDTF.java
@@ -21,12 +21,16 @@
package hivemall;

import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
+ import hivemall.common.DenseModel;
+ import hivemall.common.PredictionModel;
+ import hivemall.common.SpaceEfficientDenseModel;
+ import hivemall.common.SparseModel;
import hivemall.common.WeightValue;
import hivemall.common.WeightValue.WeightValueWithCovar;
- import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
+ import hivemall.utils.lang.Primitives;

import java.io.BufferedReader;
import java.io.File;
@@ -49,93 +53,103 @@
import org.apache.hadoop.io.Text;

public abstract class LearnerBaseUDTF extends UDTFWithOptions {

private static final Log logger = LogFactory.getLog(LearnerBaseUDTF.class);

protected boolean feature_hashing;
- protected float bias;
protected String preloadedModelFile;
protected boolean skipUntouched;
+ protected boolean dense_model;
+ protected int model_dims;
+ protected boolean disable_halffloat;

public LearnerBaseUDTF() {}

- protected boolean returnCovariance() {
+ protected boolean useCovariance() {
return false;
}

public boolean isFeatureHashingEnabled() {
return feature_hashing;
}

- public float getBias() {
- return bias;
- }

@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("fh", "fhash", false, "Enable feature hashing (only used when feature is TEXT type) [default: off]");
opts.addOption("b", "bias", true, "Bias clause [default 0.0 (disable)]");
opts.addOption("loadmodel", true, "Model file name in the distributed cache");
opts.addOption("output_untouched", false, "Output feature weights not touched in the training");
opts.addOption("dense", "densemodel", false, "Use dense model or not");
opts.addOption("dims", "feature_dimensions", true, "The dimension of model [default: 16777216 (2^24)]");
opts.addOption("disable_halffloat", false, "Toggle this option to disable the use of SpaceEfficientDenseModel");
return opts;
}

@Override
protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
boolean fhashFlag = false;
- float biasValue = 0.f;
String modelfile = null;
boolean output_untouched = true; // emit every weight by default
+ boolean denseModel = false;
+ int modelDims = -1;
+ boolean disableHalfFloat = false;

CommandLine cl = null;
if(argOIs.length >= 3) {
String rawArgs = HiveUtils.getConstString(argOIs[2]);
cl = parseOptions(rawArgs);

if(cl.hasOption("fh")) {
fhashFlag = true;
}
modelfile = cl.getOptionValue("loadmodel");

String biasStr = cl.getOptionValue("b");
if(biasStr != null) {
biasValue = Float.parseFloat(biasStr);
denseModel = cl.hasOption("dense");
if(denseModel) {
modelDims = Primitives.parseInt(cl.getOptionValue("dims"), 16777216);
}

modelfile = cl.getOptionValue("loadmodel");
if(modelfile != null) {
output_untouched = cl.hasOption("output_untouched");
}
disableHalfFloat = cl.hasOption("disable_halffloat");
}

this.feature_hashing = fhashFlag;
- this.bias = biasValue;
this.preloadedModelFile = modelfile;
this.skipUntouched = output_untouched ? false : true;
+ this.dense_model = denseModel;
+ this.model_dims = modelDims;
+ this.disable_halffloat = disableHalfFloat;
return cl;
}

- protected void loadPredictionModel(OpenHashMap<Object, WeightValue> map, String filename, PrimitiveObjectInspector keyOI) {
+ protected PredictionModel createModel() {
+ if(dense_model) {
+ boolean useCovar = useCovariance();
+ if(model_dims > 16777216) {
+ logger.info("Build a space efficient dense model with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ return new SpaceEfficientDenseModel(model_dims, useCovar);
+ } else {
+ logger.info("Build a dense model with " + model_dims + " initial dimensions" + (useCovar ? " w/ covariances" : ""));
+ return new DenseModel(model_dims, useCovar);
+ }
+ } else {
+ int initModelSize = getInitialModelSize();
+ logger.info("Build a sparse model with " + initModelSize + " initial dimensions");
+ return new SparseModel(initModelSize);
+ }
+ }
+
+ protected int getInitialModelSize() {
+ return 16384;
+ }

+ protected void loadPredictionModel(PredictionModel model, String filename, PrimitiveObjectInspector keyOI) {
final StopWatch elapsed = new StopWatch();
final long lines;
try {
- if(returnCovariance()) {
- lines = loadPredictionModel(map, new File(filename), keyOI, writableFloatObjectInspector, writableFloatObjectInspector);
+ if(useCovariance()) {
+ lines = loadPredictionModel(model, new File(filename), keyOI, writableFloatObjectInspector, writableFloatObjectInspector);
} else {
- lines = loadPredictionModel(map, new File(filename), keyOI, writableFloatObjectInspector);
+ lines = loadPredictionModel(model, new File(filename), keyOI, writableFloatObjectInspector);
}
} catch (IOException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
} catch (SerDeException e) {
throw new RuntimeException("Failed to load a model: " + filename, e);
}
- if(!map.isEmpty()) {
- logger.info("Loaded " + map.size() + " features from distributed cache '" + filename
+ if(model.size() > 0) {
+ logger.info("Loaded " + model.size() + " features from distributed cache '" + filename
+ "' (" + lines + " lines) in " + elapsed);
}
}

- private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, File file, PrimitiveObjectInspector keyOI, WritableFloatObjectInspector valueOI)
+ private static long loadPredictionModel(PredictionModel model, File file, PrimitiveObjectInspector keyOI, WritableFloatObjectInspector valueOI)
throws IOException, SerDeException {
long count = 0L;
if(!file.exists()) {
@@ -144,7 +158,7 @@ private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
if(!file.getName().endsWith(".crc")) {
if(file.isDirectory()) {
for(File f : file.listFiles()) {
- count += loadPredictionModel(map, f, keyOI, valueOI);
+ count += loadPredictionModel(model, f, keyOI, valueOI);
}
} else {
LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
@@ -169,7 +183,7 @@ private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
}
Object k = keyRefOI.getPrimitiveWritableObject(keyRefOI.copyObject(f0));
float v = varRefOI.get(f1);
- map.put(k, new WeightValue(v, false));
+ model.set(k, new WeightValue(v, false));
}
} finally {
reader.close();
@@ -179,7 +193,7 @@ private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
return count;
}

- private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, File file, PrimitiveObjectInspector featureOI, WritableFloatObjectInspector weightOI, WritableFloatObjectInspector covarOI)
+ private static long loadPredictionModel(PredictionModel model, File file, PrimitiveObjectInspector featureOI, WritableFloatObjectInspector weightOI, WritableFloatObjectInspector covarOI)
throws IOException, SerDeException {
long count = 0L;
if(!file.exists()) {
Expand All @@ -188,7 +202,7 @@ private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
if(!file.getName().endsWith(".crc")) {
if(file.isDirectory()) {
for(File f : file.listFiles()) {
- count += loadPredictionModel(map, f, featureOI, weightOI, covarOI);
+ count += loadPredictionModel(model, f, featureOI, weightOI, covarOI);
}
} else {
LazySimpleSerDe serde = HiveUtils.getLineSerde(featureOI, weightOI, covarOI);
@@ -218,7 +232,7 @@ private static long loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
float v = c2oi.get(f1);
float cov = (f2 == null) ? WeightValueWithCovar.DEFAULT_COVAR
: c3oi.get(f2);
- map.put(k, new WeightValueWithCovar(v, cov, false));
+ model.set(k, new WeightValueWithCovar(v, cov, false));
}
} finally {
reader.close();
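In short, createModel() picks the backend from the new options: -dense builds an array-backed DenseModel sized by -dims (default 16777216, i.e. 2^24), dimensions above that threshold get the half-float SpaceEfficientDenseModel (which -disable_halffloat is meant to opt out of, per its help text), and without -dense a SparseModel with 16384 initial entries is used. A hedged usage sketch — the learner name train_arow and table training are hypothetical; any UDTF extending LearnerBaseUDTF accepts these options:

select feature, avg(weight) as weight
from (
  select train_arow(add_bias(features), label, '-dense -dims 16777216') as (feature, weight)
  from training
) t
group by feature;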
18 changes: 6 additions & 12 deletions src/main/hivemall/classifier/AROWClassifierUDTF.java
@@ -58,7 +58,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
}

@Override
- protected boolean returnCovariance() {
+ protected boolean useCovariance() {
return true;
}

@@ -118,23 +118,17 @@ protected void update(final List<?> features, final float y, final float alpha,
}
final Object k;
final float v;
- if(parseX) {
- FeatureValue fv = FeatureValue.parse(f, feature_hashing);
+ if(parseFeature) {
+ FeatureValue fv = FeatureValue.parse(f);
k = fv.getFeature();
v = fv.getValue();
} else {
k = ObjectInspectorUtils.copyToStandardObject(f, featureInspector);
v = 1.f;
}
- WeightValue old_w = weights.get(k);
+ WeightValue old_w = model.get(k);
WeightValue new_w = getNewWeight(old_w, v, y, alpha, beta);
- weights.put(k, new_w);
}

- if(biasKey != null) {
- WeightValue old_bias = weights.get(biasKey);
- WeightValue new_bias = getNewWeight(old_bias, bias, y, alpha, beta);
- weights.put(biasKey, new_bias);
+ model.set(k, new_w);
}
}

@@ -145,7 +139,7 @@ private static WeightValue getNewWeight(final WeightValue old, final float x, fi
old_w = 0.f;
old_cov = 1.f;
} else {
- old_w = old.getValue();
+ old_w = old.get();
old_cov = old.getCovariance();
}

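To resume from a previously trained model, -loadmodel pairs with Hive's distributed cache; loadPredictionModel() then reads the feature/weight lines back through the key-value line SerDe. A sketch under the same hypothetical names:

add file /path/to/prev_model.tsv;

select train_arow(add_bias(features), label, '-loadmodel prev_model.tsv') as (feature, weight)
from training;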