2016-12-02 128 views
0

我有一个MultilayerPerceptronClassificationModel建立并培训(以同样的方式在this教程),现在我想坚持它,以便在神经网络下一次我需要一些数据进行分类再利用。该模型有loadsave方法要在文件中保留和恢复。但是有没有办法在数据库中保存(以及后来 - 加载)模型? (在我的情况下它是CassandraDB)。如何在数据库中保存Spark MLlib模型?

回答

1

好的,我自己找到了答案。不知道这是最好的解决方案,但它对我来说工作正常。

MultilayerPerceptronClassificationModel(以及据我所知,每个型号的MLlib包)都实现了Serializable接口。所以它可以被序列化/反序列化为ByteArray

让我们做一个表,用于保存在卡桑德拉DB模式:

CREATE TABLE models (
    uid TEXT, 
    name TEXT, 
    model BLOB, 

    PRIMARY KEY (uid) 
); 

现在我们可以编写模型到DB:

def saveModel(model: MultilayerPerceptronClassificationModel) = { 
    val baos = new ByteArrayOutputStream() 
    val oos = new ObjectOutputStream(baos) 

    oos.writeObject(model) 
    oos.flush() 
    oos.close() 

    sc.parallelize(Seq((model.uid, "my-neural-network-model", baos.toByteArray))) 
    .saveToCassandra("mykeyspace", "models", SomeColumns("uid", "name", "model")) 
} 

和读取模型回:

def loadModel(): MultilayerPerceptronClassificationModel = { 
    sc.cassandraTable("mykeyspace", "models").map { r => 
    val bis = new ByteArrayInputStream(r.getBytes("model").array()) 
    val ois = new ObjectInputStream(bis) 

    ois.readObject.asInstanceOf[MultilayerPerceptronClassificationModel] 
    }.first() 
}