Mandelbrot 集合与分形

背景简介

Mandelbrot 集合,是满足以下条件的所有复数 $c$ 的集合:

$$z_{n+1} \leftarrow z_n^2 + c,$$

并且 $z_0 = 0$. 选取一个 $c$, 如果 $n\rightarrow \infty$, $|z|$ 始终有限,则 $c$ 属于 Mandelbrot 集合;否则,不属于该集合。在 Youtube 上有一个非常好在视频,介始该集合,点击此处.

数值计算

在数值计算的时候,通常有以下几点考虑

  • 在通常的数值计算里,我们不可能让 $n\rightarrow \infty$,因此通常都会选定一个最大迭代次数或迭代深度 depth。
  • 数学上证明过,一旦某次迭代给出 $|z| > 2$ ,则该点一定不属于 M-集合。
  • 为了上色美观,对于一个给定的 $c$, 我们会记录下 $|z_n|>2$ 的迭代次数。通过此次数的某种函数选取对应颜色。

单个点的判断函数

基于以上 Mandelbrot 复数集合的介绍,很容易写出如下的函数来判断给定单个点,它是否属于这个集合。我们先导入相关的计算库。

In [1]:
import numpy as np
import matplotlib.pyplot as plt
In [2]:
def mandelbrot(c, depth):
    z = c
    
    for ii in range(depth):
        z = z**2 + c
        if abs(z) > 2:
            break
        
    if ii < depth-1:
        mu = ii-np.log2(np.log2(abs(z)))  # 没有直接选用 ii, 是为了让最终的颜色图象更加平滑
    else:
        mu = ii

    return mu

M-集合的生成函数

接下来定义第二个函数,用于处理批量的点。利用 mandelbrot 函数来判断这些点里,哪里不在集合里。 这些点的中心在 center, 宽度为 center 上下 $\pm$ width. 每个方向上选取的点数则为 grid. 这个函数的原型来源于 Matlab 里对应的函数。

In [3]:
def mandelbrot_set(center, width, grid, depth):
    x0 = center.real
    y0 = center.imag
    x = np.linspace(x0-width/2.0, x0+width/2.0, grid)
    y = np.linspace(y0-width/2.0, y0+width/2.0, grid)
    m = np.empty((grid, grid))

    for i in range(grid):
        for j in range(grid):
            m[i,j] = mandelbrot(x[i]+y[j]*1j, depth)
                
    return x, y, m

我们接下来就可以用这两个函数来作不同的初始点,不同的宽度,不同的格点数和深度的 Mandelbrot 图。

In [4]:
center = -0.5+0j
width = 3
grid = 512
depth = 256
    
x, y, m = mandelbrot_set(center, width, grid, depth)

plt.close()    
fig,ax=plt.subplots(figsize=(8,6))
ax.pcolormesh(x, y, m.transpose(), cmap='jet')
plt.show()

也可以选取不同的初始点,尝试不同的初始点。

In [5]:
# x-wing
center = -1.6735-0.0003318j
width = 1.5e-4
grid = 1024
depth = 160
    
x, y, m = mandelbrot_set(center, width, grid, depth)
fig,ax=plt.subplots(figsize=(8,6))
ax.pcolormesh(x, y, m.transpose(), cmap='jet')
plt.show()

M-集用计算的加速

以上的程序会运行非常非常地缓慢,主要是因为存在大量的 for 循环,计算量是在 grid$^2\times$depth 量级,因此 Python 这种解释型语言运算速度慢的劣势非常明显。

利用 numpy 加速

考虑到不同初始点的计算之前并无依赖性,可以非常方便地数组化以上程序,然后利用 numpy 提高计算效率。这是利用 python 实现快速科学计算的一个基本方法,也是在 MATLAB 等其它类似语言里提高效率的基本方法。以下为一种网上其他人的实现方式. 为了与以上的直接方式相区别,我们将相关函数名加上 npy 后缀。

In [6]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

def mandelbrot_npy(c, depth):
    z = np.zeros(c.shape, dtype=np.complex)
    output = np.zeros(c.shape)

    for i in range(depth):
        notdone = np.less(z.real*z.real+z.imag*z.imag, 4.0)
        output[notdone] = i
        z[notdone] = z[notdone]**2 + c[notdone]

    index_out = output < depth-1
    output[index_out] = output[index_out]-np.log2(np.log2(np.abs(z[index_out])))

    return output

def mandelbrot_set_npy(center,width,grid,depth):
    cc = complex(center)

    x = np.linspace(cc.real-width/2, cc.real+width/2, grid)
    y = np.linspace(cc.imag-width/2, cc.imag+width/2, grid)

    xmesh, ymesh = np.meshgrid(x,y,indexing='ij')

    c = xmesh+1j*ymesh
    m = mandelbrot_npy(c, depth)
    
    return x, y, m

if __name__ == '__main__':

    center = -0.5+0j
    width = 3
    grid = 512
    depth = 256

    x,y,m = mandelbrot_set_npy(center,width,grid,depth)
    
    fig,ax=plt.subplots(figsize=(8,6))
    ax.pcolormesh(x, y, m.transpose(), cmap='jet')
    plt.show()

现在我们可用用 %timeit 命令来比较 numpy 实现了多少的提速。先统一定义一下 center 等变量。

In [7]:
center = -0.5+0j
width = 3
grid = 512
depth = 256
In [8]:
%timeit mandelbrot_set(center,width,grid,depth)
11.7 s ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [9]:
%timeit mandelbrot_set_npy(center,width,grid,depth)
446 ms ± 8.46 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

可以看到利用 numpy 的向量化,程序运行效率提高了 26 倍。这是因为很多的 numpy 函数底层都是用编译型语言 (c 语言)实现的。

利用 numba 库加速

有时候会碰到的问题,比较难或者不可能回避掉 for 循环。这时候可以考虑用 numba 库加速。其它类似的库还用 Cython。这些库的原理都是把一部分 python 程序编译,从而加速。Numba 的优点是对 python 程序的改动非常少,仅仅通过添加几个修饰器就可以了。

这里我把最原始的 Mandelbrot 程序添加几个 @nb.njit 这里 jit 是 just-in-time 的意思, njit 则是 jit(nopython==True) 的简写,意思是将该函数编译运行,完全不用 python 解释器。而如果不设 nopython = True 的话,则编译失败的情况下,会跳到 python object 模式,对效率提升有限。一般提升效率都用 njit,如果编译失败,则根据错误信息提示修改。

In [11]:
import numpy as np
import matplotlib.pyplot as plt
import numba as nb

@nb.njit
def mandelbrot_nb(c, depth):
    z = c
    
    for ii in range(depth):
        z = z**2 + c
        if abs(z) > 2:
            break
        
    if ii < depth-1:
        mu = ii-np.log2(np.log2(abs(z)))  # 没有直接选用 ii, 是为了让最终的颜色图象更加平滑
    else:
        mu = ii

    return mu

@nb.njit
def mandelbrot_set_nb(center, width, grid, depth):
    x0 = center.real
    y0 = center.imag
    x = np.linspace(x0-width/2.0, x0+width/2.0, grid)
    y = np.linspace(y0-width/2.0, y0+width/2.0, grid)
    m = np.empty((grid, grid))

    for i in range(grid):
        for j in range(grid):
            m[i,j] = mandelbrot_nb(x[i]+y[j]*1j, depth)
                
    return x, y, m

if __name__ == '__main__':

    center = -0.5+0j
    width = 3
    grid = 512
    depth = 256

    x,y,m = mandelbrot_set_nb(center,width,grid,depth)
    
    fig,ax=plt.subplots(figsize=(8,6))
    ax.pcolormesh(x, y, m.transpose(), cmap='jet')
    plt.show()

现在再来看一下它的效率。

In [12]:
%timeit mandelbrot_set_nb(center,width,grid,depth)
72.9 ms ± 2.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

很显然,在这个例子里,numba 比 numpy 实现了更好的加速,进一步提速了6倍。原因应该是在 numpy 版本里,还是存在着 python 模式下的 for-loop,而 numba 下,则是全部编译了。而与原始版本相比,在仅仅添加了三行代码 (import numba as nb; 两行 @nb.njit) 的情况下,效率提升了 156 倍。