CodeGen 中,在完成内存分配之后,随即执行工作空间大小的更新:
backend::FunctionInfo func_info;
// defined() checks whether memory_plan_ actually holds data, i.e. whether
// memory planning succeeded; only then is the workspace info recomputed.
if (memory_plan_.defined()) {
// TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
// Recompute the module's workspace sizes from the fresh memory plan.
func_info =
relay::tec::UpdateMainWorkspaceSize(mod, config_, memory_plan_->expr_to_storage_info);
// Attach the updated function info to the module under the "main_func_info" attribute.
mod = WithAttr(mod, "main_func_info", func_info);
}
UpdateMainWorkspaceSize 的实现如下:
// Computes per-target memory statistics for the "main" function of `mod`, given
// the storage assignment produced by memory planning (`storage_info_map`):
// - workspace: bytes of intermediate tensors, per device/target;
// - io: bytes of the function's inputs and outputs;
// - consts: bytes of constants.
// Returns the totals packaged as a backend::FunctionInfo.
// NOTE(review): template arguments below were reconstructed from the TVM source
// (src/relay/backend/te_compiler.cc); the quoted snippet had lost all `<...>`.
backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const CompilationConfig& config,
                                              Map<Expr, backend::StorageInfo> storage_info_map) {
  Function func = Downcast<Function>(mod->Lookup("main"));

  VLOG_CONTEXT << "UpdateMainWorkspaceSize";
  VLOG(1) << "calculating FunctionInfo for main:" << std::endl << PrettyPrint(func);

  // This is a Map<device, Map<storage_id, size>>
  // TODO(mbs): Collapsing VirtualDevices to just device type.
  // Keyed by device type; the inner map records, per allocated storage block id,
  // the (maximal) size in bytes assigned to it.
  std::unordered_map<DLDeviceType, std::unordered_map<int64_t, int64_t>, backend::EnumClassHash>
      sid_workspace;
  // This is a Map<device, size_of_inputs_and_outputs>
  std::unordered_map<DLDeviceType, int64_t, backend::EnumClassHash> device_io;
  // This is a Map<device, size_of_constants>
  std::unordered_map<DLDeviceType, int64_t, backend::EnumClassHash> device_consts;

  // Initialize the mapping from all storage identifiers to workspace sizes,
  // the amount of device io, and the device constants. storage_info_map is the
  // memory plan: for every expression, the storage ids and virtual devices its
  // tensors were assigned to. All three tables are keyed by device type.
  for (const auto& kv : storage_info_map) {
    const backend::StorageInfo& storage_info = kv.second;
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;
    CHECK_EQ(storage_ids.size(), virtual_devices.size());
    for (uint32_t i = 0; i < virtual_devices.size(); i++) {
      DLDeviceType device_type = virtual_devices[i]->device_type();
      sid_workspace[device_type][storage_ids[i]] = 0;
      device_io[device_type] = 0;
      device_consts[device_type] = 0;
    }
  }

  // Iterate the storage map to compute all the tensor sizes in the program.
  // There are 3 cases in this code:
  //
  // First we need to compute the sizes of all
  // inline constants.
  //
  // Second we compute the size of any bound variable as these are input and output
  // sizes of the program.
  //
  // Finally for all other expressions we check which storage identifier they have
  // been assigned and we compute the maximal size of the storage, as tensors can
  // share storage with other tensors which are the same size or larger.
  //
  // In this final case there is only one allocation for all tensors which share storage
  // which will be the maximal size of all tensors which were assigned to it.
  for (const auto& kv : storage_info_map) {
    const Expr& expr = kv.first;
    const backend::StorageInfo& storage_info = kv.second;
    // Bytes needed by all tensors of this expression's (checked) type.
    int64_t size_bytes = backend::CalculateRelayExprSizeBytes(expr->checked_type());
    VLOG(1) << "expression:" << std::endl
            << PrettyPrint(expr) << std::endl
            << "of type:" << std::endl
            << PrettyPrint(expr->checked_type()) << std::endl
            << "has size " << size_bytes << " and storage info:" << std::endl
            << storage_info;
    // Storage block ids and devices assigned to this expression.
    const std::vector<int64_t>& storage_ids = storage_info->storage_ids;
    const std::vector<VirtualDevice>& virtual_devices = storage_info->virtual_devices;

    if (expr->IsInstance<ConstantNode>()) {
      // Case 1: inline constant — accumulate its bytes per device type.
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        ICHECK_EQ(device_consts.count(device_type), 1);
        device_consts[device_type] += size_bytes;
      }
    } else if (expr->IsInstance<VarNode>() || expr.same_as(func->body)) {
      // Case 2: a bound variable or the function result — these are the
      // program's inputs/outputs; accumulate io bytes per device type.
      CHECK(size_bytes == 0 || virtual_devices.size() >= 1) << "must be at least one device";
      for (const auto& virtual_device : virtual_devices) {
        DLDeviceType device_type = virtual_device->device_type();
        device_io[device_type] += size_bytes;
      }
    } else {
      // Case 3: intermediate tensor.
      // TODO(@electriclilies): This code is never being called which means sid_workspace is not
      // updated.. This means that storage info is probably not being created correctly. Or is not
      // equivalent to what was here previously
      for (uint32_t i = 0; i < storage_ids.size(); i++) {
        // Here we record the largest size of the tensor
        // that share the same storage id, because storage_id will
        // be shared between multiple tensors that are not live simultaneously:
        // one allocation of the maximal size serves them all.
        DLDeviceType device_type = virtual_devices[i]->device_type();
        if (size_bytes > sid_workspace[device_type][storage_ids[i]]) {
          sid_workspace[device_type][storage_ids[i]] = size_bytes;
        }
      }
    }
  }

  // This is a Map<device, workspace_size>
  std::unordered_map<DLDeviceType, int64_t, backend::EnumClassHash> device_workspace;
  // Once we know the sizes of sids, we need to accumulate per device:
  // total workspace for a device is the sum of its storage-block sizes.
  for (const auto& dev_sid_size : sid_workspace) {
    auto dev = dev_sid_size.first;
    device_workspace[dev] = 0;
    for (const auto& sid_size : dev_sid_size.second) {
      device_workspace[dev] += sid_size.second;
    }
  }

  Map<Target, Integer> workspace_sizes;
  Map<Target, Integer> io_sizes;
  Map<Target, Integer> constant_sizes;
  Map<Target, tir::PrimFunc> tir_primfuncs;
  Map<Target, Function> relay_primfuncs;

  // Initialize all target workspaces to zero
  for (const auto& target : config->primitive_targets) {
    workspace_sizes.Set(target, 0);
  }
  // Map each device's workspace total onto its primitive target, and associate
  // the Relay function with that target.
  for (const auto& dev_and_size : device_workspace) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    workspace_sizes.Set(target, dev_and_size.second);
    relay_primfuncs.Set(target, func);
  }
  // Record io bytes per target.
  for (const auto& dev_and_size : device_io) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    io_sizes.Set(target, dev_and_size.second);
  }
  // Record constant bytes per target.
  for (const auto& dev_and_size : device_consts) {
    Target target = config->FindPrimitiveTargetForDeviceOrFail(dev_and_size.first);
    ICHECK_EQ(constant_sizes.count(target), 0);
    constant_sizes.Set(target, dev_and_size.second);
  }

  backend::FunctionInfo func_info(std::move(workspace_sizes), std::move(io_sizes),
                                  std::move(constant_sizes), std::move(tir_primfuncs),
                                  std::move(relay_primfuncs));
  VLOG(1) << "func_info: " << func_info;
  // Return by value (not std::move) so NRVO can apply.
  return func_info;
}
简单地说,就是统计一个函数的输入输出占用了多少空间、函数内部中间变量占用了多少空间,以及函数使用的常量占用了多少空间。