Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract MiniBatchSchema From CombineSchema To Support Column Types #8

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions cmake/mindalpha_shared.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ add_library(mindalpha_shared SHARED
cpp/mindalpha/ps_default_agent.cpp
cpp/mindalpha/ps_helper.cpp
cpp/mindalpha/combine_schema.cpp
cpp/mindalpha/minibatch_schema.cpp
cpp/mindalpha/index_batch.cpp
cpp/mindalpha/hash_uniquifier.cpp
cpp/mindalpha/model_metric_buffer.cpp
Expand Down
1 change: 1 addition & 0 deletions cmake/python_wheel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(python_files
python/mindalpha/initializer.py
python/mindalpha/updater.py
python/mindalpha/model.py
python/mindalpha/minibatch.py
python/mindalpha/distributed_trainer.py
python/mindalpha/distributed_tensor.py
python/mindalpha/agent.py
Expand Down
5 changes: 4 additions & 1 deletion cpp/mindalpha/combine_schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,10 @@ CombineSchema::CombineToIndicesAndOffsets(const IndexBatch& batch, bool feature_
const StringViewHashVector*
CombineSchema::GetCell(const IndexBatch& batch, size_t i, const std::string& column_name) const
{
const size_t column_index = column_name_map_.at(column_name);
size_t column_index = -1;
if (column_name_map_.size()) {
column_index = column_name_map_.at(column_name);
}
const StringViewHashVector& vec = batch.GetCell(i, column_index, column_name);
return vec.empty() ? nullptr : &vec;
}
Expand Down
28 changes: 26 additions & 2 deletions cpp/mindalpha/feature_extraction_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <mindalpha/pybind_utils.h>
#include <mindalpha/combine_schema.h>
#include <mindalpha/index_batch.h>
#include <mindalpha/minibatch_schema.h>
#include <mindalpha/hash_uniquifier.h>
#include <mindalpha/feature_extraction_python_bindings.h>

Expand Down Expand Up @@ -48,7 +49,8 @@ py::class_<mindalpha::CombineSchema, std::shared_ptr<mindalpha::CombineSchema>>(
return map;
})
.def("combine_to_indices_and_offsets",
[](const mindalpha::CombineSchema& schema, const mindalpha::IndexBatch& batch, bool feature_offset)
[](const mindalpha::CombineSchema& schema,
const mindalpha::IndexBatch& batch, bool feature_offset)
{
auto [indices, offsets] = schema.CombineToIndicesAndOffsets(batch, feature_offset);
py::array indices_arr = mindalpha::to_numpy_array(std::move(indices));
Expand Down Expand Up @@ -76,7 +78,7 @@ py::class_<mindalpha::CombineSchema, std::shared_ptr<mindalpha::CombineSchema>>(
auto& str2 = schema.GetCombineSchemaSource();
return py::make_tuple(str1, str2);
},
[](py::tuple t)
[](const py::tuple& t)
{
if (t.size() != 2)
throw std::runtime_error("invalid pickle state");
Expand All @@ -89,10 +91,32 @@ py::class_<mindalpha::CombineSchema, std::shared_ptr<mindalpha::CombineSchema>>(
}))
;

// Python binding for mindalpha::MinibatchSchema (instances are held through
// std::shared_ptr). Exposes the default constructor, clear, the two column-name
// loaders and the schema-string getter.
py::class_<mindalpha::MinibatchSchema, std::shared_ptr<mindalpha::MinibatchSchema>>(m, "MinibatchSchema")
.def(py::init<>())
.def("clear", &mindalpha::MinibatchSchema::Clear)
.def("load_column_name_from_source", &mindalpha::MinibatchSchema::LoadColumnNameFromSource)
.def("load_column_name_from_file", &mindalpha::MinibatchSchema::LoadColumnNameFromFile)
// The Python-side name is get_schema_str; it forwards to GetSchemaString.
.def("get_schema_str", &mindalpha::MinibatchSchema::GetSchemaString)
// Pickle support: the pickled state is the column-name source text, and
// unpickling re-parses that text into a freshly constructed schema.
.def(py::pickle(
[](const mindalpha::MinibatchSchema& schema)
{
// __getstate__: hand back the source text kept by the schema.
auto& str = schema.GetColumnNameSource();
return str;
},
[](const std::string& str)
{
// __setstate__: rebuild the schema by loading the saved source text.
auto schema = std::make_shared<mindalpha::MinibatchSchema>();
schema->LoadColumnNameFromSource(str);
return schema;

}))
;
// Python binding for mindalpha::IndexBatch (instances are held through
// std::shared_ptr). `rows` and `columns` are read-only properties.
py::class_<mindalpha::IndexBatch, std::shared_ptr<mindalpha::IndexBatch>>(m, "IndexBatch")
.def_property_readonly("rows", &mindalpha::IndexBatch::GetRows)
.def_property_readonly("columns", &mindalpha::IndexBatch::GetColumns)
// Three constructor overloads are exposed, selected by argument types:
// (str), (list, str) and (list, list, str).
.def(py::init<const std::string&>())
.def(py::init<py::list, const std::string&>())
.def(py::init<py::list, py::list, const std::string&>())
.def("to_list", &mindalpha::IndexBatch::ToList)
.def("__str__", &mindalpha::IndexBatch::ToString)
;
Expand Down
76 changes: 59 additions & 17 deletions cpp/mindalpha/index_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,62 @@

#include <stdexcept>
#include <mindalpha/index_batch.h>
#include <mindalpha/debug.h>

namespace mindalpha
{
IndexBatch::IndexBatch(const std::string& schema_file) {

IndexBatch::IndexBatch(pybind11::list columns, const std::string& delimiters)
}
/// Builds a batch from named columns.
///
/// Every entry of `column_names` is recorded, in order, in column_names_ and
/// mapped to its position in column_name_map_; the column data itself is then
/// converted by ConvertColumn.
///
/// @param column_names  Python list of column names.
/// @param columns       Python list of column arrays (validated by ConvertColumn).
/// @param delimiters    forwarded unchanged to ConvertColumn.
/// @throws std::runtime_error if `columns` is empty, plus anything
///         ConvertColumn throws.
IndexBatch::IndexBatch(pybind11::list column_names, pybind11::list columns, const std::string& delimiters)
{
    if (columns.empty()) {
        throw std::runtime_error("empty columns list");
    }

    column_names_.reserve(column_names.size());
    for (const auto& entry : column_names) {
        // Iterating a pybind11 list yields handles; cast to an object first,
        // see https://github.com/pybind/pybind11/issues/1201
        auto name_object = pybind11::cast<pybind11::object>(entry);
        auto [name_view, keep_alive] = get_string_object_tuple(name_object);
        // todo compare with string_view
        const std::string name(name_view.data(), name_view.size());
        // The position of a name is the number of names recorded before it.
        const size_t position = column_names_.size();
        column_name_map_.emplace(name, position);
        column_names_.emplace_back(name);
    }

    ConvertColumn(columns, delimiters);
}
/// Builds a batch from positional columns only; this overload registers no
/// column names and hands the list straight to ConvertColumn.
///
/// @param columns     Python list of column arrays (validated by ConvertColumn).
/// @param delimiters  forwarded unchanged to ConvertColumn.
IndexBatch::IndexBatch(pybind11::list columns, const std::string& delimiters)
{
    ConvertColumn(std::move(columns), delimiters);
}
/// Validates the input columns, splits each one with SplitColumn and appends
/// the result to split_columns_; on success rows_ is set to the common length.
///
/// Note that split_columns_ is appended to, not cleared, so calling this on an
/// object that already holds columns adds to them.
///
/// @param columns     Python list whose items must all be numpy arrays with
///                    dtype kind 'O', all of the same non-zero length.
/// @param delimiters  forwarded to SplitColumn for every array.
/// @throws std::runtime_error if the list is empty, an item is not a numpy
///         array, an array is not of object dtype, the column lengths differ,
///         or the number of rows is zero.
void IndexBatch::ConvertColumn(pybind11::list columns, const std::string& delimiters)
{
    if (columns.empty()) {
        throw std::runtime_error("empty columns list");
    }
    split_columns_.reserve(columns.size());
    size_t rows = 0;
    for (size_t j = 0; j < columns.size(); j++)
    {
        pybind11::object item = columns[j];
        if (!pybind11::isinstance<pybind11::array>(item)) {
            // The message has two placeholders, so it needs two arguments: the
            // column index and the name of the Python type actually received.
            throw std::runtime_error(fmt::format("column {} is not numpy ndarray, but {}",
                                                 j, Py_TYPE(item.ptr())->tp_name));
        }
        pybind11::array arr = item.cast<pybind11::array>();
        if (arr.dtype().kind() != 'O') {
            throw std::runtime_error(fmt::format("column {} is not numpy ndarray of object", j));
        }
        StringViewColumn column = SplitColumn(arr, delimiters);
        if (j == 0) {
            // The first column defines the expected row count.
            rows = column.size();
        }
        if (column.size() != rows) {
            throw std::runtime_error(fmt::format("column {} and column 0 are not of the same length; {} != {}",
                                                 j, column.size(), rows));
        }
        split_columns_.push_back(std::move(column));
    }
    if (rows == 0) {
        throw std::runtime_error("number of rows is zero");
    }
    rows_ = rows;
}

Expand All @@ -53,11 +81,15 @@ IndexBatch::SplitColumn(const pybind11::array& column, std::string_view delims)
const size_t rows = column.size();
StringViewColumn output;
output.reserve(rows);
#if 1
for (auto& item: column) {
#else
for (size_t i = 0; i < rows; i++)
{
const void* item_ptr = column.data(i);
// Consider avoiding complex casting here.
PyObject* item = (PyObject*)(*(void**)item_ptr);
#endif
pybind11::object cell = pybind11::reinterpret_borrow<pybind11::object>(item);
auto [str, obj] = get_string_object_tuple(cell);
auto items = SplitFilterStringViewHash(str, delims);
Expand All @@ -68,11 +100,21 @@ IndexBatch::SplitColumn(const pybind11::array& column, std::string_view delims)

const StringViewHashVector& IndexBatch::GetCell(size_t i, size_t j, const std::string& column_name) const
{
if (i >= rows_)
throw std::runtime_error("row index i is out of range; " + std::to_string(i) + " >= " + std::to_string(rows_));
if (j >= split_columns_.size())
throw std::runtime_error("column index j (" + column_name + ") is out of range; " + std::to_string(j) +
" >= " + std::to_string(split_columns_.size()));
if (i >= rows_) {
throw std::runtime_error(fmt::format("row index i is out of range; {}>={}", i, rows_));
}
if (column_names_.size()) {
auto iter = column_name_map_.find(column_name);
if (iter == column_name_map_.end()) {
throw std::runtime_error(fmt::format("can't find {} in column_name_map element_size {}",
column_name, column_name_map_.size()));
}
j = iter->second;
}
if (j >= split_columns_.size()) {
throw std::runtime_error(fmt::format("column index j ({}) is out of range; {} >={}",
column_name, j, split_columns_.size()));
}
auto& column = split_columns_.at(j);
return column.at(i).items_;
}
Expand Down Expand Up @@ -101,7 +143,7 @@ pybind11::list IndexBatch::ToList() const

std::string IndexBatch::ToString() const
{
return "[IndexBatch: " + std::to_string(GetRows()) + " x " + std::to_string(GetColumns()) + "]";
return fmt::format("[IndexBatch: {} x {}]", GetRows(), GetColumns());
}

}
10 changes: 9 additions & 1 deletion cpp/mindalpha/index_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ namespace mindalpha
class __attribute__((visibility("hidden"))) IndexBatch
{
public:
IndexBatch(const std::string& schema_file);
IndexBatch(pybind11::list columns, const std::string& delimiters);
IndexBatch(pybind11::list column_names, pybind11::list columns, const std::string& delimiters);

void ConvertColumn(pybind11::list columns, const std::string& delimiters);
const StringViewHashVector& GetCell(size_t i, size_t j, const std::string& column_name) const;

pybind11::list ToList() const;
Expand All @@ -36,7 +39,9 @@ class __attribute__((visibility("hidden"))) IndexBatch
size_t GetColumns() const { return split_columns_.size(); }

std::string ToString() const;

size_t GetColumnNameSize() {
return column_names_.size();
}
private:
struct __attribute__((visibility("hidden"))) string_view_cell
{
Expand All @@ -50,6 +55,9 @@ class __attribute__((visibility("hidden"))) IndexBatch

std::vector<StringViewColumn> split_columns_;
size_t rows_;

std::unordered_map<std::string, int> column_name_map_;
std::vector<std::string> column_names_;
};

}
78 changes: 78 additions & 0 deletions cpp/mindalpha/minibatch_schema.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <mindalpha/io.h>
#include <mindalpha/minibatch_schema.h>
#include <mindalpha/string_utils.h>
#include <fmt/format.h>
namespace mindalpha {
/// Creates an empty schema; every member is default-constructed.
MinibatchSchema::MinibatchSchema() = default;
/// Parses column names from `stream`, one column per line.
///
/// - Lines whose first character is '#' are comments.
/// - A line that splits on ' ' into exactly two fields is read as
///   `<index> <name>`; any other line uses its first field as the name and the
///   number of columns parsed so far as the index.
/// - A name that splits on '@' into exactly two parts is registered under the
///   second part; otherwise under the first part.
/// - Blank lines are skipped.
///
/// All text read (comments and blank lines included, each line terminated by
/// '\n') replaces the stored column-name source. Columns loaded by an earlier
/// call are NOT cleared; call Clear() first to reload from scratch.
///
/// @throws whatever std::stoi throws when the index field of a two-field line
///         is not a valid integer.
void MinibatchSchema::LoadColumnNameFromStream(std::istream &stream) {
    using namespace std::string_view_literals;
    std::string line;
    std::string source;
    int i = 0;
    while (std::getline(stream, line)) {
        source.append(line);
        source.push_back('\n');
        // Blank lines carry no column; previously they fell through to
        // svpair[0] on a possibly empty split result.
        if (line.empty() || line[0] == '#')
            continue;
        const auto svpair = SplitStringView(line, " "sv);
        // SplitStringView is defined elsewhere; guard in case it yields no
        // fields at all (e.g. a line consisting only of separators).
        if (svpair.empty())
            continue;
        int index = -1;
        std::string_view name;
        if (svpair.size() != 2) {
            index = i;
            name = svpair[0];
        }
        else {
            index = std::stoi(std::string(svpair[0]));
            name = svpair[1];
        }
        const auto name_alias_pair = SplitStringView(name, "@"sv);
        // Fall back to the unsplit name if the '@' split produced nothing, so
        // the result is never indexed out of range.
        std::string_view col_name = name;
        if (name_alias_pair.size() == 2)
            col_name = name_alias_pair[1];
        else if (!name_alias_pair.empty())
            col_name = name_alias_pair[0];
        column_name_map_[std::string(col_name)] = index;
        column_names_.emplace_back(col_name);
        i++;
    }
    column_name_source_ = std::move(source);
}

/// Parses column names from in-memory text by wrapping `source` in a string
/// stream and delegating to LoadColumnNameFromStream.
void MinibatchSchema::LoadColumnNameFromSource(const std::string &source) {
    std::istringstream input{source};
    LoadColumnNameFromStream(input);
}

/// Reads the whole content behind `uri` with StreamReadAll and parses it with
/// LoadColumnNameFromSource.
void MinibatchSchema::LoadColumnNameFromFile(const std::string &uri) {
    const std::string contents = StreamReadAll(uri);
    LoadColumnNameFromSource(contents);
}
/// Returns the descriptor of the column registered under `column_name`.
///
/// The index comes from column_name_map_; the type field is always 0 here.
/// @throws std::out_of_range (from unordered_map::at) if the name is unknown.
Column MinibatchSchema::GetColumn(const std::string &column_name) const {
    Column descriptor{};
    descriptor.idx = column_name_map_.at(column_name);
    descriptor.type = 0;
    return descriptor;
}
/// Resets the schema to empty: drops the name-to-index map, the ordered name
/// list and the stored source text.
void MinibatchSchema::Clear() {
    column_name_map_.clear();
    column_names_.clear();
    column_name_source_.clear();
}
/// Renders the loaded columns as a comma-separated list in which every column
/// is given the type `string`, e.g. "project string,article string".
/// Columns appear in load order; an empty schema yields an empty string.
std::string MinibatchSchema::GetSchemaString() const {
    std::string schema;
    for (const auto& column : column_names_) {
        schema += column;
        schema += " string,";
    }
    // Every entry ends with ','; strip the one after the last entry.
    if (!schema.empty()) {
        schema.pop_back();
    }
    return schema;
}
/// Placeholder: currently always returns an empty string.
std::string MinibatchSchema::ToString() const {
    // TODO
    return std::string{};
}
} // namespace mindalpha
30 changes: 30 additions & 0 deletions cpp/mindalpha/minibatch_schema.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <sstream>
#include <stdint.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace mindalpha {
/// Integer tag describing the type of a column.
using ColumnType = int;

/// Descriptor of a single column: its index and its type tag.
struct Column {
    size_t idx;       ///< column index
    ColumnType type;  ///< column type tag
};
// Holds the column names of a minibatch together with a name-to-index map and
// the source text the names were loaded from.
class MinibatchSchema {
public:
MinibatchSchema();
// Loaders: parse column names from a stream, from in-memory text, or from the
// content found at `uri`.
void LoadColumnNameFromStream(std::istream &stream);
void LoadColumnNameFromSource(const std::string &source);
void LoadColumnNameFromFile(const std::string &uri);
// Empties the schema.
void Clear();
// Returns the stored source text by reference; the reference is only valid
// while this object is alive.
const std::string& GetColumnNameSource() const { return column_name_source_; }
// Looks up the column registered under `column_name`.
Column GetColumn(const std::string &column_name) const;
// Returns a textual rendering of the schema.
std::string GetSchemaString() const;
std::string ToString() const;
private:
// column name -> column index
std::unordered_map<std::string, int> column_name_map_;
// column names
std::vector<std::string> column_names_;
// source text the column names were loaded from
std::string column_name_source_;
};
} // namespace mindalpha
1 change: 1 addition & 0 deletions python/mindalpha/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ._mindalpha import NodeRole
from ._mindalpha import ActorConfig
from ._mindalpha import PSRunner
from ._mindalpha import MinibatchSchema

from .embedding import EmbeddingSumConcat
from .embedding import EmbeddingRangeSum
Expand Down
7 changes: 4 additions & 3 deletions python/mindalpha/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def _clean(self):
self._output = torch.tensor(0.0)

@torch.jit.unused
def _do_cast(self, ndarrays):
def _do_cast(self, minibatch):
ndarrays = minibatch.column_values
columns = []
for name in self._selected_columns:
index = self._column_name_map[name]
Expand All @@ -89,10 +90,10 @@ def _do_cast(self, ndarrays):
return output

@torch.jit.unused
def _cast(self, ndarrays):
def _cast(self, minibatch):
self._clean()
self._ensure_column_name_map_loaded()
self._output = self._do_cast(ndarrays)
self._output = self._do_cast(minibatch)

def forward(self, x):
return self._output
Loading