import time
import pyopencl
from pyopencl import mem_flags
import numpy
from PIL import Image

# OpenCL C source for the ``mandelbrot`` kernel; main() compiles it at run
# time with pyopencl.Program(context, src).build().
#
# Each work-item renders one pixel:
#   * get_global_id(1) -> column i, real part      cr = x0 + pixel_size * i
#   * get_global_id(0) -> row j,    imaginary part ci = y0 + pixel_size * j
# Starting from z = 0 it iterates z <- z**2 + c for at most max_iter steps and
# stops as soon as zr**2 + zi**2 >= threshold.  NOTE: ``threshold`` is
# compared against the *squared* magnitude |z|**2 (so the escape radius is
# sqrt(threshold)), and the test in each step uses z from before that step's
# update.  The iteration count k is clamped to 255 and written as the
# grayscale value 255 - k to three consecutive bytes of ``out`` at offset
# (j * get_global_size(1) + i) * 3, i.e. row-major with 3 bytes per pixel and
# a row stride of get_global_size(1) pixels.  Pixels whose count reaches the
# clamp are written as 0.
src = '''
__kernel void mandelbrot(
    __global uchar* out,
    float pixel_size,
    float x0,
    float y0,
    int max_iter, 
    float threshold
)
{
    const int i = get_global_id(1);
    const int j = get_global_id(0);
    const float cr = x0 + pixel_size * i;
    const float ci = y0 + pixel_size * j;
    float zr = 0.0f;
    float zi = 0.0f;
    float zrzi, zr2, zi2;
    int k;
    for(k = 0; k < max_iter; k++) {
        zrzi = zr * zi;
        zr2 = zr * zr;
        zi2 = zi * zi;
        zr = zr2 - zi2 + cr;
        zi = zrzi + zrzi + ci;
        if(zi2 + zr2 >= threshold) {
            break;
        }
    }
    if(k > 255){
        k = 255;
    }
    const int base = (j * get_global_size(1) + i) * 3;
    out[base] = 255 - k;
    out[base + 1] = 255 - k;
    out[base + 2] = 255 - k;
}
'''

def main(width=300, height=200, pixel_size=0.01, x0=-2.0, y0=-1.0,
         max_iter=256, threshold=4.0, filename='mandelbrot.png'):
    """Render the Mandelbrot set with OpenCL, save it, and print the timing.

    Compiles the module-level kernel source ``src`` and runs its
    ``mandelbrot`` kernel over a ``height`` x ``width`` grid. It then copies
    the grayscale RGB result back to the host and writes it to ``filename``
    with PIL. It prints the elapsed wall-clock seconds for kernel execution
    plus the device-to-host copy.

    Every parameter has a default. Calling ``main()`` with no arguments
    renders a 300x200 view whose pixels start at (-2, -1) and step by 0.01,
    i.e. roughly the region [-2, 1) x [-1, 1).

    Args:
        width: image width in pixels (columns, real axis).
        height: image height in pixels (rows, imaginary axis).
        pixel_size: spacing in the complex plane between adjacent pixels.
        x0: real part of the pixel in column 0.
        y0: imaginary part of the pixel in row 0.
        max_iter: maximum number of iterations per pixel.
        threshold: bailout value compared against the *squared* magnitude
            zr**2 + zi**2 inside the kernel. 4.0 corresponds to the standard
            escape radius |z| = 2. The previous hard-coded 2.0 meant
            |z| = sqrt(2), which wrongly marked in-set points such as
            c = -1.5 as escaping.
        filename: path of the output image.

    Raises:
        ValueError: if ``width`` or ``height`` is not positive. A zero-byte
            device buffer cannot be allocated, so this is caught up front.
    """
    if width <= 0 or height <= 0:
        raise ValueError('width and height must be positive, got %r x %r'
                         % (width, height))

    context = pyopencl.create_some_context()
    queue = pyopencl.CommandQueue(context)
    program = pyopencl.Program(context, src).build()

    # Host image: rows x columns x 3 bytes, matching the kernel's
    # (j * get_global_size(1) + i) * 3 addressing.
    out = numpy.empty((height, width, 3), numpy.uint8)
    out_buf = pyopencl.Buffer(context, mem_flags.WRITE_ONLY, out.nbytes)
    # Dimension 0 indexes rows (get_global_id(0)) and dimension 1 indexes
    # columns (get_global_id(1)) in the kernel.
    global_size = (height, width)

    # Scalar kernel arguments must match the C types in the kernel signature.
    start = time.time()
    program.mandelbrot(queue, global_size, None, out_buf,
                       numpy.float32(pixel_size), numpy.float32(x0),
                       numpy.float32(y0), numpy.int32(max_iter),
                       numpy.float32(threshold))
    # Waiting on the copy also waits for the kernel queued before it, so the
    # timed span covers both.
    event = pyopencl.enqueue_copy(queue, out, out_buf)
    event.wait()
    stop = time.time()

    Image.fromarray(out).save(filename)
    # Function-call form prints the same value on Python 2 and is valid
    # syntax on Python 3 (the bare ``print`` statement is not).
    print(stop - start)

if __name__ == "__main__":
    # Render only when executed as a script, never on import.
    main()

