# -*- coding: utf-8 -*-

import pyopencl
from pyopencl import mem_flags
import numpy
import time

# Host-side data: two square matrices of random values in [0, 256) stored as
# 32-bit integers, plus an uninitialised array of the same shape and dtype
# that receives the product read back from the device.
size = 1024
matrix_shape = (size, size)
a, b = (
    numpy.random.randint(0, 256, matrix_shape).astype(numpy.int32)
    for _ in range(2)
)
dest = numpy.empty(a.shape, dtype=a.dtype)

# OpenCL setup: create a context non-interactively and a command queue on it.
context = pyopencl.create_some_context(interactive=False)
queue = pyopencl.CommandQueue(context)

# Both operands are read-only on the device and initialised from the host
# arrays; the result buffer is write-only and sized to hold `dest`.
input_flags = mem_flags.READ_ONLY | mem_flags.COPY_HOST_PTR
a_buf = pyopencl.Buffer(context, input_flags, hostbuf=a)
b_buf = pyopencl.Buffer(context, input_flags, hostbuf=b)
dest_buf = pyopencl.Buffer(context, mem_flags.WRITE_ONLY, size=dest.nbytes)

# Build the OpenCL program that provides the `matrix_mul` kernel.
#
# The kernel computes dest = a * b for row-major n x n int matrices using
# m x m tiles staged in __local memory:
#   * the work-item with global id (i, j) writes dest[j * n + i], so
#     dimension 0 indexes the column and dimension 1 indexes the row;
#   * each of the n / m outer iterations has every work-item copy one element
#     of the current `a` tile and one element of the current `b` tile into
#     local_a / local_b, after which each work-item accumulates m products
#     from those tiles; the `a` tile then advances m columns to the right and
#     the `b` tile advances m rows down;
#   * the first barrier makes the freshly loaded tiles visible to the whole
#     work-group before they are read, the second keeps the next iteration's
#     loads from overwriting tiles that are still being read.
# The indexing assumes an m x m work-group, local buffers holding m * m ints
# each, and n being a multiple of m (n / m is integer division); the kernel
# performs no bounds checking.
program = pyopencl.Program(context, '''
__kernel void matrix_mul(
    __global const int* a,
    __global const int* b,
    __global int* dest,
    __local int* local_a,
    __local int* local_b,
    const int n,
    const int m
)
{
    const int i = get_global_id(0);
    const int j = get_global_id(1);

    const int local_i = get_local_id(0);
    const int local_j = get_local_id(1);

    int tmp = 0;

    int local_a_base = get_group_id(1) * m * n;
    int local_b_base = get_group_id(0) * m;
    int local_index = local_j * m + local_i;
    int global_index = local_j * n + local_i;
    for(int l = 0; l < n / m; l++){
        local_a[local_index] = a[global_index + local_a_base];
        local_b[local_index] = b[global_index + local_b_base];
        barrier(CLK_LOCAL_MEM_FENCE);
        for(int k = 0; k < m; k++){
            tmp += local_a[local_j * m + k] * local_b[k * m + local_i];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
        local_a_base += m;
        local_b_base += m * n;
    }
    dest[j * n + i] = tmp;
}
''').build()

# Kernel launch parameters.  Scalars are passed as numpy.int32 so they match
# the kernel's `const int` arguments.
n = numpy.int32(size)

# The kernel needs a square local_size x local_size work-group whose edge
# divides the matrix size.  Start from the preferred 32 x 32 tile and halve
# it until it fits the device limits; a fixed 32 x 32 group (1024 work-items)
# is rejected with CL_INVALID_WORK_GROUP_SIZE on devices that allow fewer
# work-items per group.  On devices that do allow 1024 the value stays 32.
# NOTE(review): the per-kernel limit can be lower than the device limit;
# this only checks the device-wide values.
device = queue.device
local_size = 32
while local_size > 1 and (
        local_size * local_size > device.max_work_group_size
        or local_size > min(device.max_work_item_sizes[:2])
        or size % local_size != 0):
    local_size //= 2
m = numpy.int32(local_size)

# Each work-group stages one tile of `a` and one tile of `b` in local memory;
# derive the byte count from the element type instead of a hard-coded 4.
tile_bytes = a.dtype.itemsize * local_size * local_size
local_a = pyopencl.LocalMemory(tile_bytes)
local_b = pyopencl.LocalMemory(tile_bytes)

# Launch one work-item per output element, grouped into
# local_size x local_size work-groups, and time the kernel from enqueue to
# completion (wall-clock, so enqueue overhead is included).
start = time.time()
e = program.matrix_mul(queue, a.shape, (local_size,local_size),
        a_buf, b_buf, dest_buf, local_a, local_b, n, m)
e.wait()
stop = time.time()

# Read the result back into the host array.
pyopencl.enqueue_copy(queue, dest, dest_buf)

# Report whether the device result matches numpy's product, then the elapsed
# kernel time in seconds.  print(...) with a single argument behaves the same
# on Python 2 and also runs on Python 3, where the print statement is a
# SyntaxError.
print(numpy.all(numpy.dot(a, b) == dest))
print(stop - start)

