Writing a custom operator in TFLite — GPU

Registering the op


Writing the custom parser

struct SinAttributes {
  float frequency;
  float phase;
};

class SinOperationParser : public TFLiteOperationParser {
  // Two functions to be overridden:
  IsSupported(some args);
  Parse(some args);
};
// Inside Parse(): data and data_size come from tflite_node, a parameter of the
// function (its custom_initial_data and custom_initial_data_size fields)
auto cast_data = reinterpret_cast<const uint8_t*>(data);
const flexbuffers::Map map = flexbuffers::GetRoot(cast_data, data_size).AsMap();
SinAttributes attr;
// "freq" is the parameter name in the Python API of the op
attr.frequency = map["freq"].AsFloat();
attr.phase = map["phase"].AsFloat();
// Inside NewCustomOperationParser()
if (op_name == "sin") {
  return std::make_unique<SinOperationParser>();
}
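The name-based dispatch above can be tried in isolation. The following is a minimal sketch, not the delegate's actual code: OperationParserSketch, SinParserSketch and NewCustomParserSketch are hypothetical stand-ins for TFLiteOperationParser, SinOperationParser and NewCustomOperationParser(), and returning nullptr for an unknown op is a choice made for this sketch only.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for the delegate's TFLiteOperationParser base class.
class OperationParserSketch {
 public:
  virtual ~OperationParserSketch() = default;
  virtual std::string Name() const = 0;
};

// Hypothetical stand-in for SinOperationParser.
class SinParserSketch : public OperationParserSketch {
 public:
  std::string Name() const override { return "sin"; }
};

// Mirrors the shape of the dispatch: return a parser for a known custom op
// name. Returning nullptr for unknown names is an assumption of this sketch.
std::unique_ptr<OperationParserSketch> NewCustomParserSketch(
    const std::string& op_name) {
  if (op_name == "sin") {
    return std::make_unique<SinParserSketch>();
  }
  return nullptr;
}
```

Calling NewCustomParserSketch("sin") yields a parser object, while any other name yields nullptr, so each new custom op needs its own branch.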

Writing the shader

  • parameters — key-value pairs for simple data, like kernel_size, that will be hard-coded in the GLSL shader
  • objects — key-value pairs for read-only data other than input tensors, such as kernel weights, bound as uniform buffer objects in the GLSL shader
  • shared_variables — key-value pairs for data that is read and written but is not the output tensor, for example intermediate compute results, bound as SSBOs in the GLSL shader
  • workload — 3D int vector specifying how many threads to launch
  • workgroup — 3D int vector giving the number of threads in each work group, so that workload.x / workgroup.x is the number of work groups along x. Usually the workload is the output shape of the operator, with one thread per output element
  • source_code — the actual GLSL source code
  • input, output — enums that specify how you want to access the input and output tensors. For an element-wise op, where each thread only reads and writes its own element, use value_n = op(value_n). To access arbitrary elements of the input tensor, use output_data_n = op(input_data_n[x,y], input_data_n[x+dx,y+dy]), where x and y are obtained from the InvocationID of the thread. Here n is the index of the input/output tensor.
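To make the relation between workload and workgroup concrete, here is a small sketch of the arithmetic. It assumes that workload counts threads per axis and that workgroup is the number of threads in one group. Uint3 and WorkGroupCount are hypothetical names for this illustration, not part of the delegate, and every workgroup component is assumed to be non-zero.

```cpp
#include <array>
#include <cstdint>

// Hypothetical stand-in for the delegate's uint3 type.
using Uint3 = std::array<uint32_t, 3>;

// Number of work groups per axis needed to cover `workload` threads when each
// group holds `workgroup` threads: ceil(workload / workgroup), element-wise.
// Assumes every component of `workgroup` is non-zero.
Uint3 WorkGroupCount(const Uint3& workload, const Uint3& workgroup) {
  Uint3 count{};
  for (int i = 0; i < 3; ++i) {
    count[i] = (workload[i] + workgroup[i] - 1) / workgroup[i];
  }
  return count;
}
```

For example, a workload of (224, 224, 8) with a workgroup of (8, 8, 1) needs (28, 28, 8) groups. The division rounds up so that a workload that is not a multiple of the group size is still fully covered.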
std::vector<Variable> parameters = {
  {"frequency", ctx.attr.frequency},
  {"phase", ctx.attr.phase},
};
*generated_code = {
  std::move(parameters),  // parameters
  {},                     // objects
  {},                     // shared variables
  uint3(),                // workload; (0,0,0) means one thread per output element
  uint3(),                // workgroup
  "value_0 = sin(value_0 / $frequency$ + $phase$);",  // GLSL code
  IOStructure::AUTO,      // input access type
  IOStructure::AUTO,      // output access type
};
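The $frequency$ and $phase$ markers in the GLSL string are placeholders for the entries in parameters. As a rough illustration of what hard-coding them into the shader means, the sketch below does a plain text substitution. SubstituteParameters is a hypothetical helper written for this example; the delegate's real code generation is more involved and may handle parameters differently.

```cpp
#include <map>
#include <string>

// Replaces every $name$ placeholder in `source` with the text registered for
// `name`. Placeholders without a matching entry are left untouched.
std::string SubstituteParameters(
    std::string source, const std::map<std::string, std::string>& parameters) {
  for (const auto& entry : parameters) {
    const std::string placeholder = "$" + entry.first + "$";
    std::string::size_type pos = 0;
    while ((pos = source.find(placeholder, pos)) != std::string::npos) {
      source.replace(pos, placeholder.size(), entry.second);
      pos += entry.second.size();  // continue after the inserted value
    }
  }
  return source;
}
```

With frequency set to "2.0" and phase set to "0.5", the shader line from above becomes value_0 = sin(value_0 / 2.0 + 0.5); which is ordinary GLSL with the attribute values baked in.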





Data Science at ShareChat. Ola. IIT Madras.