host_common_util.hpp Source File

host_common_util.hpp Source File

Composable Kernel: host_common_util.hpp Source File
host_common_util.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <vector>
#include <array>
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <algorithm>
#include <cstddef>

#include "ck/ck.hpp"
14
15namespace ck {
16
17namespace host_common {
18
/// Write the raw bytes of a host buffer to a binary file.
///
/// Exactly `dataNumItems * sizeof(T)` bytes starting at `data` are written
/// verbatim (no header, no conversion). A status message is printed to
/// std::cout in every case: success, failure to open, or failure to write.
///
/// @param fileName     Path of the output file (opened in binary mode, truncating).
/// @param data         Pointer to the first element to dump.
/// @param dataNumItems Number of elements of type T to write.
template <typename T>
static inline void dumpBufferToFile(const char* fileName, T* data, std::size_t dataNumItems)
{
    std::ofstream outFile(fileName, std::ios::binary);
    if(!outFile)
    {
        std::cout << "Could not open file " << fileName << " for writing" << std::endl;
        return;
    }

    outFile.write(reinterpret_cast<const char*>(data),
                  static_cast<std::streamsize>(dataNumItems * sizeof(T)));

    // close() flushes the buffered bytes; inspect the stream afterwards so a
    // failed write or flush (e.g. disk full) is not reported as success.
    outFile.close();
    if(!outFile)
    {
        std::cout << "Could not write output to file " << fileName << std::endl;
        return;
    }

    std::cout << "Write output to file " << fileName << std::endl;
}
34
/// Parse a single value of type T from a string via stream extraction.
///
/// Leading whitespace is skipped and extraction stops at the first character
/// that cannot belong to a T; any trailing text is ignored.
///
/// @param valueStr Text to parse, e.g. "42" or " 1.5".
/// @return The extracted value. The result starts value-initialized, so for
///         arithmetic T an empty or all-whitespace string yields 0.
template <typename T>
static inline T getSingleValueFromString(const std::string& valueStr)
{
    std::istringstream iss(valueStr);

    // Value-initialize: when the string is empty or all whitespace, the stream
    // sentry fails before any conversion runs and operator>> leaves `val`
    // untouched; returning an uninitialized arithmetic value would be UB.
    T val{};
    iss >> val;

    return val;
}
46
47template <typename T>
48static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
49{
50 std::string valuesStr(cstr_values);
51
52 std::vector<T> values;
53 std::size_t pos = 0;
54 std::size_t new_pos;
55
56 new_pos = valuesStr.find(',', pos);
57 while(new_pos != std::string::npos)
58 {
59 const std::string sliceStr = valuesStr.substr(pos, new_pos - pos);
60
61 T val = getSingleValueFromString<T>(sliceStr);
62
63 values.push_back(val);
64
65 pos = new_pos + 1;
66 new_pos = valuesStr.find(',', pos);
67 };
68
69 std::string sliceStr = valuesStr.substr(pos);
70 T val = getSingleValueFromString<T>(sliceStr);
71
72 values.push_back(val);
73
74 return (values);
75}
76
77template <int NDim>
78static inline std::vector<std::array<index_t, NDim>>
79get_index_set(const std::array<index_t, NDim>& dim_lengths)
80{
81 static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
82
83 if constexpr(NDim == 1)
84 {
85 std::vector<std::array<index_t, NDim>> index_set;
86
87 for(int i = 0; i < dim_lengths[0]; i++)
88 {
89 std::array<index_t, 1> index{i};
90
91 index_set.push_back(index);
92 };
93
94 return index_set;
95 }
96 else
97 {
98 std::vector<std::array<index_t, NDim>> index_set;
99 std::array<index_t, NDim - 1> partial_dim_lengths;
100
101 std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());
102
103 std::vector<std::array<index_t, NDim - 1>> partial_index_set;
104
105 partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);
106
107 for(index_t i = 0; i < dim_lengths[0]; i++)
108 for(const auto& partial_index : partial_index_set)
109 {
110 std::array<index_t, NDim> index;
111
112 index[0] = i;
113
114 std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);
115
116 index_set.push_back(index);
117 };
118
119 return index_set;
120 };
121};
122
123template <int NDim>
124static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
125 const std::array<index_t, NDim>& index)
126{
127 size_t offset = 0;
128
129 for(int i = 0; i < NDim; i++)
130 offset += index[i] * strides[i];
131
132 return (offset);
133};
134
135} // namespace host_common
136} // namespace ck
Definition host_common_util.hpp:17
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299