581 lines
19 KiB
C++
581 lines
19 KiB
C++
/*
|
|
* Copyright (c) 2018, Alliance for Open Media. All rights reserved
|
|
*
|
|
* This source code is subject to the terms of the BSD 2 Clause License and
|
|
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
|
|
* was not distributed with this source code in the LICENSE file, you can
|
|
* obtain it at www.aomedia.org/license/software. If the Alliance for Open
|
|
* Media Patent License 1.0 was not distributed with this source code in the
|
|
* PATENTS file, you can obtain it at www.aomedia.org/license/patent.
|
|
*/
|
|
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <math.h>
|
|
#include <float.h>
|
|
#include <string.h>
|
|
|
|
#include "tools/txfm_analyzer/txfm_graph.h"
|
|
|
|
// Target flavour of the code printed by the generator.
typedef enum CODE_TYPE {
  CODE_TYPE_C = 0,       // plain C built on half_btf()
  CODE_TYPE_SSE2 = 1,    // __m128i code using 16-bit (_epi16) intrinsics
  CODE_TYPE_SSE4_1 = 2,  // __m128i code using 32-bit (_epi32) intrinsics
} CODE_TYPE;
|
|
|
|
int get_cos_idx(double value, int mod) {
|
|
return round(acos(fabs(value)) / PI * mod);
|
|
}
|
|
|
|
char *cos_text_arr(double value, int mod, char *text, int size) {
|
|
int num = get_cos_idx(value, mod);
|
|
if (value < 0) {
|
|
snprintf(text, size, "-cospi[%2d]", num);
|
|
} else {
|
|
snprintf(text, size, " cospi[%2d]", num);
|
|
}
|
|
|
|
if (num == 0)
|
|
printf("v: %f -> %d/%d v==-1 is %d\n", value, num, mod, value == -1);
|
|
|
|
return text;
|
|
}
|
|
|
|
char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) {
|
|
int idx0 = get_cos_idx(w0, mod);
|
|
int idx1 = get_cos_idx(w1, mod);
|
|
char p[] = "p";
|
|
char n[] = "m";
|
|
char *sgn0 = w0 < 0 ? n : p;
|
|
char *sgn1 = w1 < 0 ? n : p;
|
|
snprintf(text, size, "cospi_%s%02d_%s%02d", sgn0, idx0, sgn1, idx1);
|
|
return text;
|
|
}
|
|
|
|
char *cos_text_sse4_1(double w, int mod, char *text, int size) {
|
|
int idx = get_cos_idx(w, mod);
|
|
char p[] = "p";
|
|
char n[] = "m";
|
|
char *sgn = w < 0 ? n : p;
|
|
snprintf(text, size, "cospi_%s%02d", sgn, idx);
|
|
return text;
|
|
}
|
|
|
|
// Prints the C statement computing buf1[node->nodeIdx] from two entries of
// buf0.
// - If either weight is something other than exactly 0 or +/-1, the node
//   becomes a half_btf() call with both weights looked up in cospi[].
// - Otherwise the node is written as a sum/difference of its +/-1 inputs;
//   inputs with weight 0 are dropped.
// - When both weights are +/-1, the expression is wrapped in
//   apply_value(..., stage_range[stage]).
void node_to_code_c(Node *node, const char *buf0, const char *buf1) {
  int trivial = 0;  // weights that are exactly 0 or +/-1
  int unit = 0;     // weights that are exactly +/-1
  for (int k = 0; k < 2; ++k) {
    const double mag = fabs(node->inWeight[k]);
    if (mag == 1 || mag == 0) ++trivial;
    if (mag == 1) ++unit;
  }

  if (trivial != 2) {
    char w0[100];
    char w1[100];
    printf(" %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], cos_bit);\n", buf1,
           node->nodeIdx, cos_text_arr(node->inWeight[0], COS_MOD, w0, 100),
           buf0, node->inNodeIdx[0],
           cos_text_arr(node->inWeight[1], COS_MOD, w1, 100), buf0,
           node->inNodeIdx[1]);
    return;
  }

  const int wrap = (unit == 2);
  printf(" %s[%d] =", buf1, node->nodeIdx);
  if (wrap) printf(" apply_value(");

  int emitted = 0;  // terms printed so far; decides the separator
  for (int k = 0; k < 2; ++k) {
    const double w = node->inWeight[k];
    const int src = node->inNodeIdx[k];
    if (w == 1) {
      if (emitted) {
        printf(" + %s[%d]", buf0, src);
      } else {
        printf(" %s[%d]", buf0, src);
      }
    } else if (w == -1) {
      if (emitted) {
        printf(" - %s[%d]", buf0, src);
      } else {
        printf("-%s[%d]", buf0, src);
      }
    } else {
      continue;  // zero weight: contributes nothing
    }
    ++emitted;
  }

  if (wrap) printf(", stage_range[stage])");
  printf(";\n");
}
|
|
|
|
// Prints the C implementation "av1_<fun_name>" of the transform graph `node`
// (stage_num stages of node_num nodes each, addressed through get_idx()).
// - Stage 0 only emits an apply_range() call on the input.
// - Stage 1 reads straight from `input`.
// - Later stages read bf0 and write bf1. The pointers alternate between the
//   local `step` array and `output`, chosen by stage parity, so the last
//   stage (stage_num - 1) writes into `output`.
// - A range check follows stage 1 and every later stage except the last.
// - The function ends with a final apply_range() on `output`.
void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // Stack buffer: the previous `new char[100]` was never deleted.
  char fun_name[100];
  get_fun_name(fun_name, sizeof(fun_name), type, node_num);

  printf("\n");
  printf(
      "void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, "
      "const int8_t* stage_range) "
      "{\n",
      fun_name);
  printf(" assert(output != input);\n");
  printf(" const int32_t size = %d;\n", node_num);
  printf(" const int32_t *cospi = cospi_arr(cos_bit);\n");
  printf("\n");

  printf(" int32_t stage = 0;\n");
  printf(" int32_t *bf0, *bf1;\n");
  printf(" int32_t step[%d];\n", node_num);

  const char *buf0 = "bf0";
  const char *buf1 = "bf1";
  const char *input = "input";

  // Stage 0: range handling of the raw input only.
  int si = 0;
  printf("\n");
  printf(" // stage %d;\n", si);
  printf(" apply_range(stage, input, %s, size, stage_range[stage]);\n", input);

  // Stage 1: reads directly from the input array.
  si = 1;
  printf("\n");
  printf(" // stage %d;\n", si);
  printf(" stage++;\n");
  // Same parity rule as the later stages, so the chain ends on `output`.
  if (si % 2 == (stage_num - 1) % 2) {
    printf(" %s = output;\n", buf1);
  } else {
    printf(" %s = step;\n", buf1);
  }

  for (int ni = 0; ni < node_num; ni++) {
    int idx = get_idx(si, ni, node_num);
    node_to_code_c(node + idx, input, buf1);
  }

  printf(" range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");

  // Remaining stages. `si` is reused rather than shadowed.
  for (si = 2; si < stage_num; si++) {
    printf("\n");
    printf(" // stage %d\n", si);
    printf(" stage++;\n");
    if (si % 2 == (stage_num - 1) % 2) {
      printf(" %s = step;\n", buf0);
      printf(" %s = output;\n", buf1);
    } else {
      printf(" %s = output;\n", buf0);
      printf(" %s = step;\n", buf1);
    }

    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      int idx = get_idx(si, ni, node_num);
      node_to_code_c(node + idx, buf0, buf1);
    }

    if (si != stage_num - 1) {
      printf(
          " range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
    }
  }
  printf(" apply_range(stage, input, output, size, stage_range[stage]);\n");
  printf("}\n");
}
|
|
|
|
// Prints one SSE2 assignment to buf1[node->nodeIdx] for a node whose weights
// are each exactly 0 or +/-1. The right-hand side depends on the weights:
// - (+1,+1): saturating 16-bit add.
// - (+1,-1) or (-1,+1): saturating 16-bit subtract.
// - One weight +1 and the other 0: plain copy of that input.
// - One weight -1 and the other 0: negation via subtraction from __zero.
// Any other combination leaves the right-hand side empty.
void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
  const double a = node->inWeight[0];
  const double b = node->inWeight[1];
  const int in0 = node->inNodeIdx[0];
  const int in1 = node->inNodeIdx[1];

  printf(" %s[%2d] =", buf1, node->nodeIdx);
  if (a == 1) {
    if (b == 1) {
      printf(" _mm_adds_epi16(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (b == -1) {
      printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, in0, buf0, in1);
    } else if (b == 0) {
      printf(" %s[%d]", buf0, in0);
    }
  } else if (a == -1) {
    if (b == 1) {
      printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, in1, buf0, in0);
    } else if (b == 0) {
      printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, in0);
    }
  } else if (a == 0) {
    if (b == 1) {
      printf(" %s[%d]", buf0, in1);
    } else if (b == -1) {
      printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, in1);
    }
  }
  printf(";\n");
}
|
|
|
|
// Prints one btf_16_sse2(...) call that computes buf1[node->nodeIdx] and
// buf1[partnerNode->nodeIdx] from node's two inputs in buf0.
// - The first constant is node's own (w0, w1) weight pair.
// - The second constant is the partner's pair. It is swapped when the
//   partner's first input differs from node's, presumably so its weights
//   line up with node's (in0, in1) order.
void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0,
                            const char *buf1) {
  char temp0[100];
  char temp1[100];

  const int swapped = node->inNodeIdx[0] != partnerNode->inNodeIdx[0];
  const double pw0 = swapped ? partnerNode->inWeight[1]
                             : partnerNode->inWeight[0];
  const double pw1 = swapped ? partnerNode->inWeight[0]
                             : partnerNode->inWeight[1];

  const char *own_const = cos_text_sse2(node->inWeight[0], node->inWeight[1],
                                        COS_MOD, temp0, 100);
  const char *partner_const = cos_text_sse2(pw0, pw1, COS_MOD, temp1, 100);

  printf(" btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n", own_const,
         partner_const, buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1],
         buf1, node->nodeIdx, buf1, partnerNode->nodeIdx);
}
|
|
|
|
// Returns the butterfly partner of `node`: the node in the same stage whose
// nodeIdx equals that of node's second input. It is located by offsetting
// `node` by the difference in nodeIdx.
// NOTE(review): relies on the nodes of one stage being stored contiguously
// in nodeIdx order.
Node *get_partner_node(Node *node) {
  const int offset = node->inNode[1]->nodeIdx - node->nodeIdx;
  return &node[offset];
}
|
|
|
|
// Prints the SSE2 code for `node` unless it was already handled (for example
// as another node's partner). Marks `node` visited, and its partner too when
// the partner is emitted here.
// - A weight other than exactly 0 or +/-1 gives one btf_16_sse2 butterfly
//   covering the node and its partner.
// - Both weights +/-1 gives two single add/sub statements, one for the node
//   and one for its partner.
// - Otherwise a single statement is emitted for this node alone.
void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
  if (node->visited != 0) return;
  node->visited = 1;

  int simple = 0;  // weights that are exactly 0 or +/-1
  int unit = 0;    // weights that are exactly +/-1
  for (int k = 0; k < 2; ++k) {
    const double mag = fabs(node->inWeight[k]);
    if (mag == 1 || mag == 0) ++simple;
    if (mag == 1) ++unit;
  }

  if (simple != 2) {
    Node *partner = get_partner_node(node);
    partner->visited = 1;
    pair_node_to_code_sse2(node, partner, buf0, buf1);
  } else if (unit == 2) {
    Node *partner = get_partner_node(node);
    partner->visited = 1;
    single_node_to_code_sse2(node, buf0, buf1);
    single_node_to_code_sse2(partner, buf0, buf1);
  } else {
    single_node_to_code_sse2(node, buf0, buf1);
  }
}
|
|
|
|
void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) {
|
|
int visited[65][65][2][2];
|
|
memset(visited, 0, sizeof(visited));
|
|
char text[100];
|
|
char text1[100];
|
|
char text2[100];
|
|
int size = 100;
|
|
printf("\n");
|
|
for (int si = 1; si < stage_num; si++) {
|
|
for (int ni = 0; ni < node_num; ni++) {
|
|
int idx = get_idx(si, ni, node_num);
|
|
int cnt = 0;
|
|
Node *node0 = node + idx;
|
|
if (node0->visited == 0) {
|
|
node0->visited = 1;
|
|
for (int i = 0; i < 2; i++) {
|
|
if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
|
|
cnt++;
|
|
}
|
|
if (cnt != 2) {
|
|
{
|
|
double w0 = node0->inWeight[0];
|
|
double w1 = node0->inWeight[1];
|
|
int idx0 = get_cos_idx(w0, COS_MOD);
|
|
int idx1 = get_cos_idx(w1, COS_MOD);
|
|
int sgn0 = w0 < 0 ? 1 : 0;
|
|
int sgn1 = w1 < 0 ? 1 : 0;
|
|
|
|
if (!visited[idx0][idx1][sgn0][sgn1]) {
|
|
visited[idx0][idx1][sgn0][sgn1] = 1;
|
|
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
|
|
cos_text_sse2(w0, w1, COS_MOD, text, size),
|
|
cos_text_arr(w0, COS_MOD, text1, size),
|
|
cos_text_arr(w1, COS_MOD, text2, size));
|
|
}
|
|
}
|
|
Node *node1 = get_partner_node(node0);
|
|
node1->visited = 1;
|
|
if (node1->inNode[0]->nodeIdx != node0->inNode[0]->nodeIdx) {
|
|
double w0 = node1->inWeight[0];
|
|
double w1 = node1->inWeight[1];
|
|
int idx0 = get_cos_idx(w0, COS_MOD);
|
|
int idx1 = get_cos_idx(w1, COS_MOD);
|
|
int sgn0 = w0 < 0 ? 1 : 0;
|
|
int sgn1 = w1 < 0 ? 1 : 0;
|
|
|
|
if (!visited[idx1][idx0][sgn1][sgn0]) {
|
|
visited[idx1][idx0][sgn1][sgn0] = 1;
|
|
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
|
|
cos_text_sse2(w1, w0, COS_MOD, text, size),
|
|
cos_text_arr(w1, COS_MOD, text1, size),
|
|
cos_text_arr(w0, COS_MOD, text2, size));
|
|
}
|
|
} else {
|
|
double w0 = node1->inWeight[0];
|
|
double w1 = node1->inWeight[1];
|
|
int idx0 = get_cos_idx(w0, COS_MOD);
|
|
int idx1 = get_cos_idx(w1, COS_MOD);
|
|
int sgn0 = w0 < 0 ? 1 : 0;
|
|
int sgn1 = w1 < 0 ? 1 : 0;
|
|
|
|
if (!visited[idx0][idx1][sgn0][sgn1]) {
|
|
visited[idx0][idx1][sgn0][sgn1] = 1;
|
|
printf(" __m128i %s = pair_set_epi16(%s, %s);\n",
|
|
cos_text_sse2(w0, w1, COS_MOD, text, size),
|
|
cos_text_arr(w0, COS_MOD, text1, size),
|
|
cos_text_arr(w1, COS_MOD, text2, size));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Prints an SSE2 implementation "<fun_name>_sse2" of the transform graph.
// - The shared constants are printed first: __zero, __rounding, and the
//   pair_set_epi16 constants from gen_cospi_list_sse2.
// - Each stage si = 1 .. stage_num-1 then reads the previous stage's array
//   ("input" for stage 1).
// - Each stage writes a new local array x<si>, except the last, which writes
//   into "output".
void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // Stack buffer: the previous `new char[100]` was never deleted.
  char fun_name[100];
  get_fun_name(fun_name, sizeof(fun_name), type, node_num);

  printf("\n");
  printf(
      "void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) "
      "{\n",
      fun_name);

  printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
  printf(" const __m128i __zero = _mm_setzero_si128();\n");
  printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");

  // Both the constant scan and the code emission use node->visited to skip
  // partners that were already handled, so the flags are cleared before each
  // pass.
  graph_reset_visited(node, stage_num, node_num);
  gen_cospi_list_sse2(node, stage_num, node_num);
  graph_reset_visited(node, stage_num, node_num);

  for (int si = 1; si < stage_num; si++) {
    char in[100];
    char out[100];
    printf("\n");
    printf(" // stage %d\n", si);
    if (si == 1) {
      snprintf(in, sizeof(in), "%s", "input");
    } else {
      snprintf(in, sizeof(in), "x%d", si - 1);
    }
    if (si == stage_num - 1) {
      snprintf(out, sizeof(out), "%s", "output");
    } else {
      snprintf(out, sizeof(out), "x%d", si);
      printf(" __m128i %s[%d];\n", out, node_num);
    }
    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      int idx = get_idx(si, ni, node_num);
      node_to_code_sse2(node + idx, in, out);
    }
  }

  printf("}\n");
}
|
|
// Prints each distinct "__m128i cospi_<s><NN> = _mm_set1_epi32(...)" constant
// that the SSE4.1 code for this graph needs, once per constant.
// - Only weights other than exactly 0 or +/-1 need a constant.
// - Only the scanned node's own weights are considered.
// - The node's partner is then marked visited so it is not scanned again.
//   This works because pair_node_to_code_sse4_1 references only the first
//   node's weights.
// Expects the visited flags to be cleared beforehand.
void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) {
  // visited[k][s]: constant for cospi index k with sign s (1 = negative) was
  // already printed.
  // NOTE(review): 65 entries assume get_cos_idx() never exceeds 64
  // (COS_MOD <= 128). Confirm against txfm_graph.h.
  int visited[65][2];
  memset(visited, 0, sizeof(visited));
  char text[100];
  char text1[100];
  const int size = 100;

  printf("\n");
  for (int si = 1; si < stage_num; si++) {
    for (int ni = 0; ni < node_num; ni++) {
      Node *node0 = node + get_idx(si, ni, node_num);
      if (node0->visited != 0) continue;
      node0->visited = 1;

      int cnt = 0;
      for (int i = 0; i < 2; i++) {
        if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
          cnt++;
      }
      if (cnt == 2) continue;  // plain add/sub/copy: no multiplier constant

      for (int i = 0; i < 2; i++) {
        const double w = node0->inWeight[i];
        if (fabs(w) == 1 || fabs(w) == 0) continue;
        // Named cos_idx so it no longer shadows the node index above.
        const int cos_idx = get_cos_idx(w, COS_MOD);
        const int sgn = w < 0 ? 1 : 0;
        if (visited[cos_idx][sgn]) continue;
        visited[cos_idx][sgn] = 1;
        printf(" __m128i %s = _mm_set1_epi32(%s);\n",
               cos_text_sse4_1(w, COS_MOD, text, size),
               cos_text_arr(w, COS_MOD, text1, size));
      }

      get_partner_node(node0)->visited = 1;
    }
  }
}
|
|
|
|
// Prints one SSE4.1 assignment to buf1[node->nodeIdx] for a node whose
// weights are each exactly 0 or +/-1, using 32-bit (non-saturating)
// intrinsics:
// - (+1,+1): add.
// - (+1,-1) or (-1,+1): subtract.
// - One weight +1 and the other 0: plain copy of that input.
// - One weight -1 and the other 0: negation via subtraction from __zero.
// Any other combination leaves the right-hand side empty.
void single_node_to_code_sse4_1(Node *node, const char *buf0,
                                const char *buf1) {
  const double first = node->inWeight[0];
  const double second = node->inWeight[1];
  const int src0 = node->inNodeIdx[0];
  const int src1 = node->inNodeIdx[1];

  printf(" %s[%2d] =", buf1, node->nodeIdx);
  if (second == 0) {
    if (first == 1) {
      printf(" %s[%d]", buf0, src0);
    } else if (first == -1) {
      printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, src0);
    }
  } else if (second == 1) {
    if (first == 1) {
      printf(" _mm_add_epi32(%s[%d], %s[%d])", buf0, src0, buf0, src1);
    } else if (first == -1) {
      printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, src1, buf0, src0);
    } else if (first == 0) {
      printf(" %s[%d]", buf0, src1);
    }
  } else if (second == -1) {
    if (first == 1) {
      printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, src0, buf0, src1);
    } else if (first == 0) {
      printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, src1);
    }
  }
  printf(";\n");
}
|
|
|
|
// Prints one 32-bit butterfly call that computes buf1[node->nodeIdx] and
// buf1[partnerNode->nodeIdx] from node's two inputs in buf0. The constants
// come from node's own two weights.
// - If the two nodes' first weights have opposite signs, the type0 helper is
//   used. Its matrix layout, per the original notes, is
//   [cos sin; sin -cos].
// - Otherwise the type1 helper is used, with layout [cos sin; -sin cos].
void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0,
                              const char *buf1) {
  char temp0[100];
  char temp1[100];

  const char *helper = (node->inWeight[0] * partnerNode->inWeight[0] < 0)
                           ? "btf_32_type0_sse4_1_new"
                           : "btf_32_type1_sse4_1_new";

  const char *c0 = cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100);
  const char *c1 = cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100);

  printf(" %s(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], __rounding, cos_bit);\n",
         helper, c0, c1, buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1],
         buf1, node->nodeIdx, buf1, partnerNode->nodeIdx);
}
|
|
|
|
// Prints the SSE4.1 code for `node` unless it was already handled (for
// example as another node's partner). Marks `node` visited, and its partner
// too when the partner is emitted here.
// - A weight other than exactly 0 or +/-1 gives one 32-bit butterfly
//   covering the node and its partner.
// - Both weights +/-1 gives two single add/sub statements, one for the node
//   and one for its partner.
// - Otherwise a single statement is emitted for this node alone.
void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) {
  if (node->visited != 0) return;
  node->visited = 1;

  int cnt = 0;   // weights that are exactly 0 or +/-1
  int cnt1 = 0;  // weights that are exactly +/-1
  for (int i = 0; i < 2; i++) {
    if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
    if (fabs(node->inWeight[i]) == 1) cnt1++;
  }

  if (cnt != 2) {
    Node *partnerNode = get_partner_node(node);
    partnerNode->visited = 1;
    pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1);
  } else if (cnt1 == 2) {
    // has a partner
    Node *partnerNode = get_partner_node(node);
    partnerNode->visited = 1;
    single_node_to_code_sse4_1(node, buf0, buf1);
    single_node_to_code_sse4_1(partnerNode, buf0, buf1);
  } else {
    // This must be the SSE4.1 emitter. The SSE2 one previously called here
    // negates with saturating 16-bit _mm_subs_epi16, which is wrong for the
    // 32-bit (_epi32) lanes this code path works on.
    single_node_to_code_sse4_1(node, buf0, buf1);
  }
}
|
|
|
|
// Prints an SSE4.1 implementation "<fun_name>_sse4_1" of the transform graph.
// - The shared constants are printed first: __zero, __rounding, and the
//   _mm_set1_epi32 constants from gen_cospi_list_sse4_1.
// - Each stage si = 1 .. stage_num-1 then reads the previous stage's array
//   ("input" for stage 1).
// - Each stage writes a new local array x<si>, except the last, which writes
//   into "output".
void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
  // Stack buffer: the previous `new char[100]` was never deleted.
  char fun_name[100];
  get_fun_name(fun_name, sizeof(fun_name), type, node_num);

  printf("\n");
  printf(
      "void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) "
      "{\n",
      fun_name);

  printf(" const int32_t* cospi = cospi_arr(cos_bit);\n");
  printf(" const __m128i __zero = _mm_setzero_si128();\n");
  printf(" const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");

  // Both the constant scan and the code emission use node->visited to skip
  // partners that were already handled, so the flags are cleared before each
  // pass.
  graph_reset_visited(node, stage_num, node_num);
  gen_cospi_list_sse4_1(node, stage_num, node_num);
  graph_reset_visited(node, stage_num, node_num);

  for (int si = 1; si < stage_num; si++) {
    char in[100];
    char out[100];
    printf("\n");
    printf(" // stage %d\n", si);
    if (si == 1) {
      snprintf(in, sizeof(in), "%s", "input");
    } else {
      snprintf(in, sizeof(in), "x%d", si - 1);
    }
    if (si == stage_num - 1) {
      snprintf(out, sizeof(out), "%s", "output");
    } else {
      snprintf(out, sizeof(out), "x%d", si);
      printf(" __m128i %s[%d];\n", out, node_num);
    }
    // computation code
    for (int ni = 0; ni < node_num; ni++) {
      int idx = get_idx(si, ni, node_num);
      node_to_code_sse4_1(node + idx, in, out);
    }
  }

  printf("}\n");
}
|
|
|
|
// Builds the butterfly graph for one transform type and size, then prints its
// code for the requested target. The node array is allocated and freed within
// this call. Code types other than the three handled ones print nothing.
void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) {
  const int stage_num = get_hybrid_stage_num(txfm_type, node_num);

  Node *graph = new Node[node_num * stage_num];
  init_graph(graph, stage_num, node_num);
  gen_hybrid_graph_1d(graph, stage_num, node_num, 0, 0, node_num, txfm_type);

  if (code_type == CODE_TYPE_C) {
    gen_code_c(graph, stage_num, node_num, txfm_type);
  } else if (code_type == CODE_TYPE_SSE2) {
    gen_code_sse2(graph, stage_num, node_num, txfm_type);
  } else if (code_type == CODE_TYPE_SSE4_1) {
    gen_code_sse4_1(graph, stage_num, node_num, txfm_type);
  }

  delete[] graph;
}
|
|
|
|
int main(int argc, char **argv) {
|
|
CODE_TYPE code_type = CODE_TYPE_SSE4_1;
|
|
for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) {
|
|
for (int node_num = 4; node_num <= 64; node_num *= 2) {
|
|
gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num);
|
|
}
|
|
}
|
|
return 0;
|
|
}
|