// NOTE(review): this block looks like a Doxygen source-browser extract rather
// than compilable source. The leading numbers (42, 43, ...) appear to be
// listing line numbers fused into the code, and the gaps in that numbering
// (e.g. 45 -> 47, 58 -> 62, 86 -> end) show that original lines are absent,
// including the declarations of `spacing` and `buffer`, the condition that
// chooses between the two header expressions (the second one below has no
// `output +=`), the `else` that should guard the throw, the return statement,
// and most closing braces. Restore from the real source before building.
//
// Show a 2D array as a string, base implementation.
//
// Renders the rows x cols values at `a`, read as a[i * cols + j]. The string
// starts with "\n" followed by a header, either "\n<name> (rows, cols)\n\n" or
// "\n(rows, cols)\n\n" -- presumably the latter when `name` is empty; the
// selecting condition is not visible here.
// Cell formatting by element type:
//   - int:   right-aligned "%*d", width = max(0, (int)log10(max + .01)) + 2,
//            where max is the largest element.
//   - float: "%9.2f" when 0.01 < |x| < 1000 or x == 0.0, else "%10.2e".
//   - any other numtype: throws std::runtime_error.
42template <
typename numtype>
43std::string
show(
const numtype *a,
size_t rows,
size_t cols,
44 const std::string &name =
"") {
45 std::string output =
"\n";
47 output +=
"\n" + name +
" (" + std::to_string(rows) +
", " +
48 std::to_string(cols) +
")\n\n";
51 "\n(" + std::to_string(rows) +
", " + std::to_string(cols) +
")\n\n";
// NOTE(review): if rows * cols == 0, max_element returns its end pointer and
// dereferencing it is undefined behavior.
55 numtype max = *std::max_element(a, a + rows * cols);
// NOTE(review): the int width is derived from the largest value only. If
// every element is negative, log10(max + .01) is NaN and the (int) cast is
// undefined; a negative value with larger magnitude than max (plus its '-'
// sign) is not accounted for either.
56 if constexpr (std::is_same<numtype, int>::value) {
57 spacing = std::max(0, (
int)log10(max + .01)) + 2;
58 }
else if constexpr (std::is_same<numtype, float>::value) {
62 throw std::runtime_error(
"Unsupported number type for show()");
65 for (
size_t i = 0; i < rows; i++) {
70 for (
size_t j = 0; j < cols; j++) {
76 if constexpr (std::is_same<numtype, int>::value) {
// NOTE(review): the snprintf size argument is `spacing`, the same value used
// as the "%*d" field width. snprintf writes at most size-1 characters plus a
// NUL, and the field is at least `spacing` characters wide, so the last
// character of every int cell (its least-significant digit) is dropped. The
// size should be the buffer's capacity (e.g. sizeof(buffer) if buffer is a
// char array -- its declaration is not visible here).
77 snprintf(buffer, spacing,
"%*d", spacing, a[i * cols + j]);
78 }
else if constexpr (std::is_same<numtype, float>::value) {
// && binds tighter than ||, so this reads (0.01 < |x| < 1000) || x == 0.0;
// explicit parentheses would make that obvious and silence -Wparentheses.
79 if (std::abs(a[i * cols + j]) < 1000 &&
80 std::abs(a[i * cols + j]) > 0.01 ||
81 a[i * cols + j] == 0.0) {
82 snprintf(buffer, 16,
"%9.2f", a[i * cols + j]);
84 snprintf(buffer, 16,
"%10.2e", a[i * cols + j]);
86 throw std::runtime_error(
"Unsupported number type for show()");
// NOTE(review): Doxygen listing residue -- the leading "107"/"108"/"109" look
// like fused line numbers, and the function body is not part of this extract.
//
// Show a std::array holding a rows x cols matrix as a string.
//
// `rows` and `cols` appear only inside the expression `rows * cols`, which is
// a non-deduced context, so they cannot be inferred from the argument; callers
// must spell them out, e.g. show<float, 2, 3>(arr). Presumably the body
// forwards to the pointer-based show() with a.data() -- confirm against the
// real source.
107template <
typename numtype,
size_t rows,
size_t cols>
108std::string
show(
const std::array<numtype, rows * cols> &a,
109 const std::string &name =
"") {
// NOTE(review): Doxygen listing residue -- the leading "125"/"126"/"127" look
// like fused line numbers, and the function body is not part of this extract.
//
// float-only convenience overload of show() for std::array, so callers can
// write show<rows, cols>(arr) without naming the element type. `rows` and
// `cols` sit in a non-deduced context (`rows * cols`) and must be given
// explicitly. Presumably the body forwards to the pointer-based show() --
// confirm against the real source.
125template <
size_t rows,
size_t cols>
126std::string
show(
const std::array<float, rows * cols> &a,
127 const std::string &name =
"") {
/**
 * @brief Populate the array with a range of values. This is mostly for
 * testing purposes.
 *
 * Writes input[i] = start + i * step for every i in [0, N).
 *
 * @param input Destination buffer; must hold at least N floats.
 * @param N Number of elements to write (0 is a no-op).
 * @param start Value written to input[0].
 * @param step Difference between consecutive elements.
 */
void range(float *input, size_t N, float start = 0.0, float step = 1.0) {
  for (size_t i = 0; i < N; ++i) {
    // Computing each value from its index (rather than accumulating `step`)
    // keeps rounding error from building up along the buffer.
    input[i] = start + static_cast<float>(i) * step;
  }
}
/**
 * @brief std::array overload of range(): writes input[i] = start + i * step
 * for every i in [0, N).
 *
 * @tparam N Array length.
 * @param input Array to overwrite; every element is written.
 * @param start Value written to input[0].
 * @param step Difference between consecutive elements.
 */
template <size_t N>
void range(std::array<float, N> &input, float start = 0.0, float step = 1.0) {
  // The index always starts at 0: `start` is the first *value*, not a
  // starting offset. Seeding the index with `start` would leave the leading
  // elements unwritten (and write nothing at all when start >= N).
  for (size_t i = 0; i < N; ++i) {
    input[i] = start + static_cast<float>(i) * step;
  }
}
/**
 * @brief Populate the array with random integers.
 *
 * Each of the N elements is drawn from std::uniform_int_distribution over the
 * closed interval [min, max] and stored as a float.
 *
 * @param a Destination buffer; must hold at least N floats.
 * @param N Number of elements to write.
 * @param gen Random engine the values are drawn from.
 * @param min Inclusive lower bound; must not exceed max.
 * @param max Inclusive upper bound.
 */
void randint(float *a, size_t N, std::mt19937 &gen, int min = -1,
             int max = 1) {
  std::uniform_int_distribution<> dist(min, max);
  // size_t index to match N: an int index is a signed/unsigned comparison
  // and would overflow for N > INT_MAX.
  for (size_t i = 0; i < N; ++i) {
    a[i] = static_cast<float>(dist(gen));
  }
}
/**
 * @brief Populate a std::array with random integers.
 *
 * Each element, in index order, is drawn from std::uniform_int_distribution
 * over the closed interval [min, max] and converted to numtype.
 *
 * @tparam numtype Element type of the array.
 * @tparam size Number of elements.
 * @param a Array to overwrite.
 * @param gen Random engine the values are drawn from.
 * @param min Inclusive lower bound; must not exceed max.
 * @param max Inclusive upper bound (same default as randint(float*, ...)).
 */
template <typename numtype, size_t size>
void randint(std::array<numtype, size> &a, std::mt19937 &gen, int min = -1,
             int max = 1) {
  std::uniform_int_distribution<> dist(min, max);
  for (numtype &value : a) {
    value = static_cast<numtype>(dist(gen));
  }
}
/**
 * @brief Populate the array with random floats, generated from a Gaussian
 * distribution.
 *
 * @param a Destination buffer; must hold at least N floats.
 * @param N Number of elements to write.
 * @param gen Random engine the values are drawn from.
 * @param mean Mean of the distribution.
 * @param std Standard deviation; must be > 0. The name hides namespace `std`
 *        only for unqualified lookup -- `std::` qualified names in the body
 *        still resolve to the namespace.
 */
void randn(float *a, size_t N, std::mt19937 &gen, float mean = 0.0,
           float std = 1.0) {
  std::normal_distribution<float> dist(mean, std);
  // size_t index to match N (avoids a signed/unsigned comparison).
  for (size_t i = 0; i < N; ++i) {
    a[i] = dist(gen);  // normal_distribution<float> already yields float
  }
}
/**
 * @brief Populate a std::array with random floats, generated from a Gaussian
 * distribution.
 *
 * @tparam size Number of elements.
 * @param a Array to overwrite, in index order.
 * @param gen Random engine the values are drawn from.
 * @param mean Mean of the distribution.
 * @param std Standard deviation; must be > 0. The name hides namespace `std`
 *        only for unqualified lookup -- `std::` qualified names in the body
 *        still resolve to the namespace.
 */
template <size_t size>
void randn(std::array<float, size> &a, std::mt19937 &gen, float mean = 0.0,
           float std = 1.0) {
  std::normal_distribution<float> dist(mean, std);
  for (float &value : a) {
    value = dist(gen);
  }
}
/**
 * @brief Populate a square matrix with the identity matrix.
 *
 * Overwrites all N * N elements of `a` (element (i, j) at a[i * N + j]):
 * 1 on the main diagonal, 0 everywhere else.
 *
 * @param a Destination buffer; must hold at least N * N floats.
 * @param N Matrix dimension (0 is a no-op).
 */
inline void eye(float *a, size_t N) {
  for (size_t row = 0; row < N; ++row) {
    float *cells = a + row * N;
    for (size_t col = 0; col < N; ++col) {
      cells[col] = (row == col) ? 1.0f : 0.0f;
    }
  }
}
/**
 * @brief Transpose a matrix.
 *
 * Reads the M x N matrix `input` (element (i, j) at input[i * N + j]) and
 * writes its N x M transpose to `output` (element (j, i) at output[j * M + i]).
 *
 * @param input Source matrix, M * N floats; not modified.
 * @param output Destination, M * N floats. Must not overlap `input`: an
 *        in-place call would overwrite elements before they are read.
 * @param M Number of rows of `input`.
 * @param N Number of columns of `input`.
 */
inline void transpose(const float *input, float *output, size_t M, size_t N) {
  for (size_t i = 0; i < M; ++i) {
    const float *src_row = input + i * N;
    for (size_t j = 0; j < N; ++j) {
      output[j * M + i] = src_row[j];
    }
  }
}
/**
 * @brief Flip a matrix horizontally or vertically, in place.
 *
 * The R x C matrix is addressed as a[i * C + j].
 *  - horizontal == true:  reverses every row (column j <-> column C - 1 - j).
 *  - horizontal == false: reverses the row order (row i <-> row R - 1 - i).
 * For an odd number of columns/rows the middle one stays where it is.
 *
 * @param a Matrix to modify; R * C floats.
 * @param R Number of rows.
 * @param C Number of columns.
 * @param horizontal Flip direction, as described above.
 */
inline void flip(float *a, size_t R, size_t C, bool horizontal = true) {
  if (horizontal) {
    for (size_t i = 0; i < R; ++i) {
      float *row = a + i * C;
      for (size_t j = 0; j < C / 2; ++j) {
        std::swap(row[j], row[C - j - 1]);
      }
    }
  } else {
    for (size_t i = 0; i < R / 2; ++i) {
      float *top = a + i * C;
      float *bottom = a + (R - i - 1) * C;
      for (size_t j = 0; j < C; ++j) {
        std::swap(top[j], bottom[j]);
      }
    }
  }
}
/**
 * @brief Determine if the values of two arrays are close to each other.
 *
 * @param a First array, n floats; not modified.
 * @param b Second array, n floats; not modified.
 * @param n Number of elements to compare.
 * @param tol Maximum allowed absolute difference per element.
 * @return false if, for some i in [0, n), a[i] or b[i] is NaN or
 *         |a[i] - b[i]| > tol; true otherwise (including n == 0). Equal
 *         infinities count as close, since inf - inf is NaN and NaN > tol
 *         is false.
 */
bool isclose(const float *a, const float *b, size_t n, float tol = 1e-3) {
  for (size_t i = 0; i < n; ++i) {
    // NaN inputs must be rejected explicitly: every comparison with NaN is
    // false, so the tolerance test alone would let them through.
    if (std::isnan(a[i]) || std::isnan(b[i]) || std::abs(a[i] - b[i]) > tol) {
      return false;
    }
  }
  return true;
}
static Logger kDefLog
Default logger for logging messages to stdout at the info level. Output stream and logging level for ...
void range(float *input, size_t N, float start=0.0, float step=1.0)
Populate the array with a range of values. This is mostly for testing purposes.
void LOG(Logger &logger, int level, const char *message,...)
Log a message to the logger. If NDEBUG is defined in a source or as a compiler flag,...
void eye(float *a, size_t N)
Populate a square matrix with the identity matrix.
void randn(float *a, size_t N, std::mt19937 &gen, float mean=0.0, float std=1.0)
Populate the array with random floats, generated from a Gaussian distribution.
bool isclose(float *a, float *b, size_t n, float tol=1e-3)
Determine if the values of two arrays are close to each other.
std::string show(const numtype *a, size_t rows, size_t cols, const std::string &name="")
Show a 2D array as a string, base implementation.
void transpose(float *input, float *output, size_t M, size_t N)
Transpose a matrix.
void randint(float *a, size_t N, std::mt19937 &gen, int min=-1, int max=1)
Populate the array with random integers.
static constexpr int kShowMaxCols
size_t size(const Shape &shape)
Returns the number of elements in a tensor with the given shape, which is equal to the product of the...
static constexpr int kShowMaxRows
void flip(float *a, size_t R, size_t C, bool horizontal=true)
Flip a matrix horizontally or vertically.