Mianzhi Wang

Ph.D. in Electrical Engineering

Notes on Migrating doa-tools from MATLAB to Python


Recently, I migrated my MATLAB library for my research, namely doa-tools, to Python because I will no longer have access to MATLAB after graduation. MATLAB is a great tool for array processing related research. However, it is not freely accessible to everyone. Python, on the other hand, is more broadly accessible.

The overall migration process was not easy, but it went quite smoothly. I had to rewrite thousands of lines of code to replace MATLAB functions with their NumPy, SciPy, and Matplotlib counterparts. The original MATLAB code was written without classes in mind. With Python classes, I can better organize my code. I also added tests and set up Sphinx to generate documentation. You can check the fully documented Python version here.

This article summarizes the notes I took during the migration process. For a complete manual on moving from MATLAB to Python, always start with NumPy's official migration guide here. The code samples in this article are based on MATLAB 2017b and Python 3.6. In all Python code samples, we will assume that numpy is imported as np and that matplotlib.pyplot is imported as plt.

Multidimensional arrays

In MATLAB, every variable is a multidimensional array[1]. This is definitely not the case for Python. In Python, we use NumPy to manipulate multidimensional arrays. Although the MATLAB syntax and NumPy syntax share a lot of similarities, there are still many notable differences.

Row-major order is used by default

NumPy stores elements in row-major order by default, while MATLAB stores elements in column-major order. Therefore, in NumPy, reshape will generally produce results different from those in MATLAB.

MATLAB (column-major order):

$$\text{Matrix layout: } \begin{bmatrix} A_{11} & A_{12} & A_{13}\\ A_{21} & A_{22} & A_{23}\\ A_{31} & A_{32} & A_{33} \end{bmatrix} \implies \text{Memory layout: } \begin{bmatrix} A_{11} & A_{21} & A_{31} & A_{12} & A_{22} & A_{32} & A_{13} & A_{23} & A_{33} \end{bmatrix}$$

NumPy (row-major order):

$$\text{Matrix layout: } \begin{bmatrix} A_{11} & A_{12} & A_{13}\\ A_{21} & A_{22} & A_{23}\\ A_{31} & A_{32} & A_{33} \end{bmatrix} \implies \text{Memory layout: } \begin{bmatrix} A_{11} & A_{12} & A_{13} & A_{21} & A_{22} & A_{23} & A_{31} & A_{32} & A_{33} \end{bmatrix}$$

I first encountered this issue when implementing the $\mathrm{vec}(\cdot)$ operation, which converts a matrix to a column vector by stacking all its columns together. In MATLAB, this operation can be simply implemented using reshape(A, [], 1) (or A(:)). In NumPy, one must specify order='F' when calling reshape to enforce MATLAB-like index ordering.

MATLAB:

A = [1 2;3 4];
v = [1 2 3 4 5 6];

% Reshaping
reshape(v, 2, 3);
% will produce
% [ 1 3 5;
%   2 4 6 ]

% Stack all columns of A together.
a = A(:);

Python:

A = np.array([[1., 2.], [3., 4.]])
v = np.array([[1., 2., 3., 4., 5., 6.]])

# Reshaping
v.reshape((2, 3))
# will produce
# [[1, 2, 3],
#  [4, 5, 6]]

# Stack all columns of A together
a = A.reshape((-1,), order='F')
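
As a further illustration, the sketch below applies order='F' to the vector v so that NumPy reproduces MATLAB's reshape output, and wraps the column-stacking operation in two small helpers. The names vec and unvec are hypothetical helpers introduced here for illustration; they are not NumPy functions.

```python
import numpy as np


def vec(A):
    """Stack the columns of a 2D array into a 1D array (like MATLAB's A(:))."""
    return A.reshape(-1, order='F')


def unvec(a, shape):
    """Inverse of vec: rebuild a matrix of the given shape column by column."""
    return a.reshape(shape, order='F')


v = np.array([1., 2., 3., 4., 5., 6.])
# Column-major reshape, matching MATLAB's reshape(v, 2, 3).
V = v.reshape((2, 3), order='F')
print(V)
# [[1. 3. 5.]
#  [2. 4. 6.]]

A = np.array([[1., 2.], [3., 4.]])
a = vec(A)
print(a)  # [1. 3. 2. 4.]
print(np.array_equal(unvec(a, A.shape), A))  # True
```

Keeping order='F' inside a pair of helpers like these makes it harder to forget the flag in one place and silently get row-major results.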

There is no linear indexing

In MATLAB, one can access a single element in a multidimensional array with a single index. This indexing scheme is called linear indexing[2]. For instance, if $\mathbf{A}$ is a $3\times3$ matrix, A(8) will refer to A(2,3), or equivalently the 8-th element in A(:) (remember that MATLAB uses column-major order). NumPy, however, does not support linear indexing. In general, linear indexing is not needed in NumPy. If you really want to do it, you can mimic linear indexing by utilizing reshape.

MATLAB:

A = [11 12 13;
     21 22 23;
     31 32 33];

% Linear indexing
disp(A(3)); % 31
disp(A(8)); % 23

% Linear indexing of multiple elements
disp(A([2 4])); % [21 12]

Python:

A = np.array([[11, 12, 13],
              [21, 22, 23],
              [31, 32, 33]])

# No linear indexing!
print(A[1]) # [21, 22, 23]
print(A[7]) # Error

# Mimic linear indexing
print(A.reshape((-1,), order='F')[[1, 3]])
# [21, 12]
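
If you only need to translate between a MATLAB-style linear index and a pair of subscripts, NumPy's np.unravel_index and np.ravel_multi_index accept an order argument. The following is a minimal sketch that assumes zero-based indices, so MATLAB's A(8) corresponds to linear index 7.

```python
import numpy as np

A = np.array([[11, 12, 13],
              [21, 22, 23],
              [31, 32, 33]])

# Convert the zero-based, column-major linear index 7 (MATLAB's A(8))
# into a (row, column) pair.
row, col = np.unravel_index(7, A.shape, order='F')
print(row, col)     # 1 2
print(A[row, col])  # 23

# Go the other way: subscripts to a column-major linear index.
k = np.ravel_multi_index((1, 2), A.shape, order='F')
print(k)  # 7
```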

There exist 1D arrays and 0D arrays

MATLAB is designed to manipulate matrices, and every matrix variable is at least 2D (size(0.5) will return [1 1]). However, in NumPy, there are 1D arrays and even 0D arrays. Using 1D arrays is perfectly fine in NumPy. However, you should be careful when mixing 1D arrays and 2D matrices. Another thing to keep in mind is that the transpose operation in NumPy is defined for general multidimensional arrays. Therefore, transposing a 1D array in NumPy still produces a 1D array.

MATLAB:

% A scalar is a 1x1 matrix in MATLAB.
x = 1;
disp(ndims(x)); % 2
disp(size(x)); % [1 1]


% A row vector is a 1xn matrix in MATLAB.
y = [1 1 2 3 5];
disp(ndims(y)); % 2
disp(size(y)); % [1 5]

% A 2x2 matrix.
A = [1, 4; 2, 3];
disp(ndims(A)); % 2
disp(size(A)); % [2 2]

% Transposing a row vector results in
% a column vector.
z = transpose(y);
disp(ndims(z)); % 2
disp(size(z)); % [5 1]

Python:

# You can wrap a scalar.
x = np.array(1.)
# You get a 0D array with an empty shape.
print(x.ndim) # 0
print(x.shape) # ()

# Create a 1D array.
y = np.array([1., 1., 2., 3., 5.])
print(y.ndim) # 1
print(y.shape) # (5,)

# A 2x2 matrix.
A = np.array([[1., 4.], [2., 3.]]) 
print(A.ndim) # 2
print(A.shape) # (2, 2)

# Transposing a 1D array results in
# a 1D array
z = y.T
print(z.ndim) # 1
print(z.shape) # (5,)
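
To make the warning about mixing 1D arrays and 2D matrices concrete, here is a small sketch of the shapes that result from a few common operations. The variable names are illustrative only.

```python
import numpy as np

y = np.array([1., 2., 3.])  # shape (3,)
col = y[:, np.newaxis]      # shape (3, 1), a column vector
row = y[np.newaxis, :]      # shape (1, 3), a row vector

# Adding a 1D array to a column vector broadcasts to a full matrix,
# which may be surprising if an element-wise sum was intended.
print((y + col).shape)  # (3, 3)

# The @ operator treats a 1D array as a plain vector.
A = np.arange(9.).reshape((3, 3))
print((A @ y).shape)    # (3,)
print((A @ col).shape)  # (3, 1)

# For an outer product, use explicit 2D shapes (or np.outer).
print((col @ row).shape)  # (3, 3)
```

When a routine relies on MATLAB-style row or column vectors, converting its inputs to explicit 2D shapes at the top is one way to avoid these surprises.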

Singleton dimensions can be automatically removed

In MATLAB, when extracting a row, a column, or a slice from a multidimensional array, singleton dimensions are not automatically removed, and sometimes you need to use squeeze()[3] to remove them. Common reduction operations such as sum() and mean() behave similarly.

In NumPy, singleton dimensions can be automatically removed, saving you from extra calls to np.squeeze(). For common reduction operations, you can even control this behavior with the keepdims option. However, you should be careful when extracting columns from a matrix, because A[:, 0] will return the first column as a 1D array.

MATLAB:

A = [1 4;2 3];
% Retrieves the first column.
v = A(:,1);
disp(size(v)); % [2 1]





% Extract a matrix from a tensor.
T = ones(3, 4, 5);
t = squeeze(T(1,:,:));
disp(size(t)); % [4 5]

% Sum each row
A = repmat([1 2 3], [3 1]);
disp(sum(A, 2)); % [6; 6; 6]


Python:

A = np.array([[1., 4.], [2., 3.]])
# Retrieves the first column as a 1D vector.
v = A[:, 0]
print(v.shape) # (2,)
# Retrieves the first column as a 2D
# column vector.
v = A[:, 0, np.newaxis]
print(v.shape) # (2, 1)

# Extract a matrix from a tensor.
T = np.ones((3, 4, 5))
t = T[0, :, :]
print(t.shape) # (4, 5)

# Sum each row
A = np.tile([1, 2, 3], (3, 1))
print(np.sum(A, axis=1)) # [6, 6, 6]
print(np.sum(A, axis=1, keepdims=True))
# [[6], [6], [6]]
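
One place where keepdims matters in practice is broadcasting a reduction back against the original array. The sketch below removes the mean of each row; the matrix is an arbitrary example.

```python
import numpy as np

A = np.array([[1., 2., 3.],
              [4., 5., 6.]])

# With keepdims=True the row means have shape (2, 1) and broadcast
# against A row by row.
centered = A - A.mean(axis=1, keepdims=True)
print(centered)
# [[-1.  0.  1.]
#  [-1.  0.  1.]]

# Without keepdims the means have shape (2,), which cannot be
# broadcast against the (2, 3) array.
try:
    A - A.mean(axis=1)
except ValueError as err:
    print('Broadcast failed:', err)
```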

You need to explicitly copy an array if you need to

In MATLAB, variables are passed by value. You can simply copy a multidimensional array with the assignment operator, =. In addition, various operations, including slicing and transpose, always create a copy. For instance, A(:,1) creates a copy of the first column of A.

In Python, everything is an object and objects are passed by reference. Therefore, multiple variables can refer to the same object. If this object is mutable, changes will be visible to all the variables referencing it. In NumPy, basic slicing and transpose create a view instead of a copy, which shares the underlying data storage with the original array[4]. If you need a copy, you must explicitly invoke the copy() method.

MATLAB:

x = [1 2 3];
% Create a copy of x.
y = x;

A = [1 4;2 3];
% Get the first row and modify it.
v = A(1,:);
v(1) = 10;
% Changing v does not affect A.
disp(A); % [1 4;2 3]

A = [1 4;2 3];
% Get the transpose and modify it.
B = transpose(A);
% Changing B does not affect A.
B(1,1) = 10;
disp(B); % [10 2;4 3]
disp(A); % [1 4;2 3]

Python:

x = np.array([1., 2., 3.])
# Create a copy of x.
y = x.copy()

A = np.array([[1., 4.], [2., 3.]])
# Get the first row and modify it.
v = A[0, :]
v[0] = 10.
# Changing v affects A.
print(A) # [[10., 4.], [2., 3.]]

A = np.array([[1., 4.], [2., 3.]])
# Get the transpose and modify it.
B = A.T
# Changing B affects A.
B[0, 0] = 10.
print(B) # [[10., 2.], [4., 3.]]
print(A) # [[10., 4.], [2., 3.]]
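
When it is unclear whether an expression returned a view or a copy, np.shares_memory gives a quick check. A minimal sketch with illustrative variable names:

```python
import numpy as np

A = np.array([[1., 4.], [2., 3.]])

row_view = A[0, :]         # basic slicing: a view
row_copy = A[0, :].copy()  # explicit copy
fancy = A[[0], :]          # integer-array (advanced) indexing: a copy

print(np.shares_memory(A, row_view))  # True
print(np.shares_memory(A, row_copy))  # False
print(np.shares_memory(A, fancy))     # False
print(np.shares_memory(A, A.T))       # True
```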

You will need np.ix_ when extracting submatrices

Suppose we have two vectors, row_ids and col_ids, where row_ids is a list of row indices and col_ids is a list of column indices. We want to extract the submatrix specified by row_ids and col_ids from a matrix, A. In MATLAB, this can be simply done with A(row_ids, col_ids). However, if you try it in NumPy, you will get unexpected results and even errors. In NumPy, we need to use np.ix_.

MATLAB:

A = rand(10, 10);
row_ids = [1 2 9];
col_ids = [4 7];
B = A(row_ids, col_ids);
disp(size(B)); % [3 2]
% B consists of
% A(1,4) A(1,7)
% A(2,4) A(2,7)
% A(9,4) A(9,7)

Python:

A = np.random.rand(10, 10)
row_ids = [0, 1, 8]
col_ids = [3, 6]
B = A[np.ix_(row_ids, col_ids)]
print(B.shape) # (3, 2)
# B consists of
# A[0,3] A[0,6]
# A[1,3] A[1,6]
# A[8,3] A[8,6]
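
To see what goes wrong without np.ix_, the sketch below uses a matrix whose entries encode their own position (A[i, j] = 10*i + j, chosen here purely for readability) and compares the three indexing forms.

```python
import numpy as np

A = np.arange(100).reshape((10, 10))  # A[i, j] == 10 * i + j
row_ids = [0, 1, 8]
col_ids = [3, 6]

# Index lists of different lengths cannot be broadcast together.
try:
    A[row_ids, col_ids]
except IndexError as err:
    print('IndexError:', err)

# Index lists of equal length are paired element by element:
# this picks A[0, 3] and A[1, 6], not a 2x2 submatrix.
print(A[[0, 1], [3, 6]])  # [ 3 16]

# np.ix_ builds an open mesh, giving the full submatrix.
B = A[np.ix_(row_ids, col_ids)]
print(B)
# [[ 3  6]
#  [13 16]
#  [83 86]]
```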

You can get MATLAB-like concatenation with np.block

In MATLAB, you can easily concatenate many smaller matrices into a bigger one with []. In NumPy, you can use np.block() to perform similar operations.

MATLAB:

A = [1 3;2 4];
v1 = [-1 -1];
v2 = [1; 1];
S = [9 v1;v2 A];
disp(S);
% [9 -1 -1
%  1  1  3
%  1  2  4]

Python:

A = np.array([[1., 3.], [2., 4.]])
v1 = np.array([[-1., -1.]])
v2 = np.array([[1.], [1.]])
S = np.block([[9., v1], [v2, A]])
print(S)
# [[ 9. -1. -1.]
#  [ 1.  1.  3.]
#  [ 1.  2.  4.]]

Linear algebra

In MATLAB, the linear algebra routines are globally available. In Python, these routines can be found under np.linalg. Some other matrix functions such as sqrtm() can be found under scipy.linalg[5].

You cannot use apostrophe to obtain the conjugate transpose

In MATLAB, the conjugate transpose (or Hermitian) of a matrix, A, can be simply expressed with A' (a single apostrophe means conjugate transpose, while .' means normal transpose). In NumPy, there is no such shortcut.

MATLAB:

% Compute A^H A
B = A' * A;

Python:

# Compute A^H A
B = A.conj().T @ A
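
If the expression A.conj().T appears often, a tiny helper keeps the code readable. The name herm below is a hypothetical helper, not a NumPy function, and the complex matrix is an arbitrary example showing that the plain transpose gives a different result.

```python
import numpy as np


def herm(A):
    """Return the conjugate (Hermitian) transpose of A, like MATLAB's A'."""
    return A.conj().T


A = np.array([[1 + 2j, 3 - 1j],
              [0 + 1j, 2 + 0j]])

B = herm(A) @ A
# A^H A is Hermitian, so it equals its own conjugate transpose.
print(np.allclose(B, herm(B)))  # True
# Using the plain transpose instead gives a different matrix.
print(np.allclose(A.T @ A, B))  # False
```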

You can use eigh instead of eig for Hermitian matrices

In MATLAB, the eig function is designed for general eigendecomposition problems. Therefore, if the input matrix is expected to be Hermitian but is not exactly so due to numerical errors, eig will use the algorithms for non-Hermitian matrices, which may produce undesired results. Hence, if the input matrix, A, is expected to be Hermitian, we usually pass in 0.5*(A + A') instead of A. For more details, you can check my previous article here.

In NumPy, there is no such problem because np.linalg provides a function named eigh, which is designed for Hermitian matrices.

MATLAB:

% Even if A is supposed to be Hermitian, we
% need to force it to be numerically Hermitian.
A = 0.5 * (A + A');
[E, V] = eig(A, 'vector');

Python:

# We can use `eigh` instead of `eig`, which will
# only use the upper (or lower) triangular part.
v, E = np.linalg.eigh(A)
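
The practical differences between eigh and eig can be checked on a small example. The sketch below builds a Hermitian matrix from random data (the seed and size are arbitrary) and inspects the outputs of both functions.

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(4, 4) + 1j * rng.randn(4, 4)
A = X @ X.conj().T  # Hermitian by construction

# eigh returns real eigenvalues sorted in ascending order.
w, E = np.linalg.eigh(A)
print(w.dtype)                  # float64
print(np.all(np.diff(w) >= 0))  # True

# The eigenvectors are orthonormal and reconstruct A.
print(np.allclose(E.conj().T @ E, np.eye(4)))       # True
print(np.allclose(E @ np.diag(w) @ E.conj().T, A))  # True

# eig makes no Hermitian assumption: for complex input the
# eigenvalues come back as complex numbers in no particular order.
w_general, _ = np.linalg.eig(A)
print(w_general.dtype)  # complex128
```

Sorted, real eigenvalues are convenient for subspace-based estimators, where the signal and noise subspaces are usually split by eigenvalue magnitude.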

Use np.linalg.lstsq to implement mldivide and mrdivide

MATLAB overloads / and \ with matrix right division, mrdivide, and matrix left division, mldivide, respectively[6]. In NumPy, \ is not an operator and / means element-wise division. You can use np.linalg.lstsq to implement mldivide and mrdivide (or np.linalg.solve when the matrix is square and invertible).

MATLAB:

% Evaluate mrdivide.
A / B 
% Evaluate mldivide.
A \ B
% Evaluate A B^{-1}.
A / B 
% Evaluate A^{-1} B.
A \ B

Python:

# Evaluate mrdivide.
np.linalg.lstsq(B.T, A.T, rcond=None)[0].T
# Evaluate mldivide.
np.linalg.lstsq(A, B, rcond=None)[0]
# Evaluate A B^{-1}.
np.linalg.solve(B.T, A.T).T
# Evaluate A^{-1} B.
np.linalg.solve(A, B)
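
The calls above can be wrapped into two small helpers. The function names mldivide and mrdivide below are hypothetical wrappers written for this note, and they only approximate MATLAB's behavior: for underdetermined systems np.linalg.lstsq returns the minimum-norm solution, which can differ from what MATLAB returns.

```python
import numpy as np


def mldivide(A, B):
    """Least-squares analogue of MATLAB's mldivide (solves A X = B)."""
    return np.linalg.lstsq(A, B, rcond=None)[0]


def mrdivide(A, B):
    """Least-squares analogue of MATLAB's mrdivide (solves X B = A)."""
    # X B = A is equivalent to B^T X^T = A^T.
    return np.linalg.lstsq(B.T, A.T, rcond=None)[0].T


A = np.array([[4., 1.], [2., 3.]])
B = np.array([[1., 2.], [3., 4.]])

X_left = mldivide(A, B)   # A^{-1} B
X_right = mrdivide(A, B)  # A B^{-1}
print(np.allclose(A @ X_left, B))                  # True
print(np.allclose(X_right @ B, A))                 # True
print(np.allclose(X_left, np.linalg.solve(A, B)))  # True
```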

Signal processing

MATLAB provides a rich set of functions for signal processing applications (e.g., spectral analysis, filtering, and filter design). In Python, relevant functions can be found under scipy.signal[7].

SciPy has find_peaks, but does not have imregionalmax

When implementing spectrum-based direction-of-arrival estimators, one needs to find the top peaks in the resulting spectrum. In MATLAB, findpeaks() can find peaks in a 1D spectrum and imregionalmax() can identify local maxima in a 2D spectrum. In Python, SciPy recently added find_peaks() under scipy.signal, which is similar to MATLAB's findpeaks(). However, there is no function similar to MATLAB's imregionalmax().

MATLAB:

% Find peaks of a sine wave.
x = sin(linspace(-8, 8, 100));
[peaks, indices] = findpeaks(x);



Python:

from scipy.signal import find_peaks

# Find peaks of a sine wave.
x = np.sin(np.linspace(-8, 8, 100))
indices, props = find_peaks(x)
peaks = x[indices] # Get peak values
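
For 2D spectra, one possible substitute for imregionalmax() is to compare every sample with the maximum of its 3x3 neighborhood using scipy.ndimage.maximum_filter. This is only a rough sketch, not an equivalent of imregionalmax(): the helper name local_maxima_2d is made up for this note, and flat plateaus are marked entirely even when they are not true regional maxima. The test spectrum is a synthetic sum of two Gaussian bumps.

```python
import numpy as np
from scipy.ndimage import maximum_filter


def local_maxima_2d(x):
    """Boolean mask of samples equal to the max of their 3x3 neighborhood."""
    x = np.asarray(x, dtype=float)
    # Pad with -inf so that samples on the border are compared only
    # with neighbors that actually exist.
    neighborhood_max = maximum_filter(
        x, size=3, mode='constant', cval=-np.inf)
    return x == neighborhood_max


# Synthetic 2D spectrum with two peaks near (1, 1) and (-1, -1).
u = np.linspace(-3, 3, 61)
X, Y = np.meshgrid(u, u)
spectrum = (np.exp(-((X - 1)**2 + (Y - 1)**2))
            + 0.5 * np.exp(-((X + 1)**2 + (Y + 1)**2)))

mask = local_maxima_2d(spectrum)
rows, cols = np.nonzero(mask)
# Sort the detected peaks by height, highest first.
order = np.argsort(spectrum[rows, cols])[::-1]
for r, c in zip(rows[order], cols[order]):
    print('peak at x={:.1f}, y={:.1f}'.format(u[c], u[r]))
```

Sorting the detected maxima by height and keeping the first few mirrors how the top peaks are selected from a 1D spectrum.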

Optimization

Optimization problems exist in various engineering applications. doa-tools implemented both the maximum-likelihood based estimators and sparsity-based estimators. The former require a solver that is capable of solving constrained non-convex problems, while the latter require a solver that can efficiently solve $l_1$-regularized least squares problems. MATLAB has built-in functions to solve various optimization problems, while in Python, various solvers are available from third-party libraries.

Use scipy.optimize.minimize for general constrained non-convex problems

In MATLAB, sub-optimal solutions to constrained non-convex problems can be obtained with fmincon(). In Python, a similar function minimize() is provided by scipy.optimize.

MATLAB:

f = @(x) x^4 - 3*x^2 + 1;
% Minimize f subject to x >= -1.
% Starting from x = 0.5.
x_opt = fmincon(...
  f, 0.5, [], [], [], [], -1, inf...
);
disp(x_opt); % 1.2247

Python:

from scipy.optimize import minimize

# Minimize f subject to x >= -1.
# Starting from x = 0.5.
f = lambda x: x**4 - 3*x**2 + 1
res = minimize(f, 0.5, bounds=[(-1, np.inf)])
print('{0:.4f}'.format(res.x[0])) # 1.2247 
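
Beyond simple bounds, minimize() also accepts general constraints. The sketch below uses a deliberately simple convex toy problem (not one taken from doa-tools) to show the syntax with the SLSQP method: it finds the point closest to (1, 2.5) subject to x0 + x1 <= 2 and non-negative coordinates.

```python
import numpy as np
from scipy.optimize import minimize


def objective(x):
    # Squared distance to the point (1, 2.5).
    return (x[0] - 1.0)**2 + (x[1] - 2.5)**2


# SciPy expects inequality constraints in the form fun(x) >= 0,
# so x0 + x1 <= 2 is written as 2 - x0 - x1 >= 0.
constraints = [{'type': 'ineq', 'fun': lambda x: 2.0 - x[0] - x[1]}]
bounds = [(0.0, None), (0.0, None)]

res = minimize(objective, x0=np.array([0.0, 0.0]), method='SLSQP',
               bounds=bounds, constraints=constraints)
print(res.success)
print(np.round(res.x, 3))  # approximately [0.25 1.75]
```

As with fmincon(), the solution found for a non-convex problem depends on the starting point, so it is worth trying several initial values.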

The Python version of CVX is CVXPY

CVX is a modeling framework for convex programming problems designed for MATLAB. The Python version is available and is named CVXPY. CVXPY's syntax is very similar to that of CVX. One advantage of CVXPY is that it supports defining parameters, allowing the creation of reusable optimization problems.

Note: When installing CVXPY under Windows, you may encounter errors when installing ECOS. To fix this issue, manually install the latest version of ECOS. Check the discussion here.

MATLAB:



m = 4;
n = 10;
rng(42);
A = rand(m, n);
b = rand(m, 1);


% l1-regularized least squares.
cvx_begin
  variable x(n, 1)
  minimize sum_square(A*x - b) + norm(x, 1)
  subject to
    x >= 0
cvx_end


disp(x);

Python:

import cvxpy as cvx

m = 4
n = 10
np.random.seed(42)
# MATLAB uses column-major order.
# Ensure A is the same under the same seed.
A = np.random.rand(n, m).T
b = np.random.rand(m, 1)
# l1-regularized least squares.
x = cvx.Variable((n, 1))
objective = cvx.Minimize(
    cvx.sum_squares(cvx.matmul(A, x) - b)
    + cvx.norm1(x)
)
constraints = [x >= 0]
problem = cvx.Problem(objective, constraints)
result = problem.solve()
print(x.value)

Plotting

MATLAB provides a rich set of plotting options. In Python, similar functionalities are provided by Matplotlib. In the code examples, we assume that matplotlib.pyplot has been imported as plt.

hold on is no longer needed

In MATLAB, hold on is required to prevent new plots from deleting existing plots. With Matplotlib, this command is not needed[8].

MATLAB:

% Three line plots in the same figure.
figure;
plot(rand(50, 1)); hold on;
plot(rand(50, 1) + 1); hold on;
plot(rand(50, 1) + 2); hold off;

Python:

# Three line plots in the same figure.
plt.figure()
plt.plot(np.random.rand(50, 1))
plt.plot(np.random.rand(50, 1) + 1)
plt.plot(np.random.rand(50, 1) + 2)
plt.show()

Remember to call plt.tight_layout() when using subplots

Subplots are supported by Matplotlib. Before calling plt.show(), remember to call plt.tight_layout() to avoid overlapping titles, labels, and other elements[9].
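
A minimal sketch of a 2x2 grid of subplots is given below. The plotted sine waves, titles, and figure size are arbitrary choices made for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 1, 200)
fig, axes = plt.subplots(2, 2, figsize=(6, 4))
for k, ax in enumerate(axes.flat, start=1):
    ax.plot(t, np.sin(2 * np.pi * k * t))
    ax.set_title('Frequency {} Hz'.format(k))
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')

# Adjust the subplot spacing so that titles and labels do not overlap.
plt.tight_layout()
# Call plt.show() here when running interactively.
```

Without the plt.tight_layout() call, the titles of the bottom row tend to collide with the x-axis labels of the top row at this figure size.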

Stem plots come with baselines by default

In MATLAB, stem plots do not come with baselines. However, Matplotlib's stem plots will have baselines drawn by default. You can remove the baselines by setting basefmt to ' '.

MATLAB:

% A simple stem plot.
figure;
stem([1 2 3 2 1]);

Python:

# A simple stem plot.
plt.figure()
plt.stem([1, 2, 3, 2, 1], basefmt=' ')
plt.show()

You need to manually create 3D axes in Matplotlib

By default, Matplotlib's axes are for 2D plots. To enable 3D plots, you need to import Axes3D from mpl_toolkits.mplot3d and manually create 3D axes.

MATLAB:

% Visualize 3D Gaussian samples.
x = randn(100, 3);
figure;
scatter3(x(:,1), x(:,2), x(:,3));




Python:

# Visualize 3D Gaussian samples.
from mpl_toolkits.mplot3d import Axes3D

x = np.random.randn(100, 3)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[:, 0], x[:, 1], x[:, 2])
plt.show()

3D plotting can be very slow in Matplotlib

Rendering 3D plots in MATLAB on modern PCs is very fast because MATLAB's plot rendering is hardware accelerated. Matplotlib currently does not support hardware acceleration, and large 3D plots can take a long time to render. To display complex 3D plots, consider using Mayavi instead of Matplotlib[10].

set_aspect('equal') is broken in 3D plots

In some cases, you want the three axes in a 3D plot to share the same scale (e.g., when visualizing points on a unit sphere). If you use ax.set_aspect('equal'), you will obtain weird results. This is a known issue and you can find the discussion here on GitHub. Currently my workaround is to manually set the limits of the three axes. You can find the relevant code here.
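
One way to set the limits manually is sketched below. This is an illustrative helper written for this note and not the code linked above; the name set_equal_limits and the test data (random points on the unit sphere) are assumptions.

```python
import numpy as np
import matplotlib.pyplot as plt
# Importing Axes3D registers the 3D projection on older Matplotlib versions.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401


def set_equal_limits(ax, points):
    """Give the three axes of a 3D plot the same data range.

    `points` is an (n, 3) array holding the plotted coordinates.
    """
    lower = points.min(axis=0)
    upper = points.max(axis=0)
    centers = 0.5 * (upper + lower)
    # Use the largest extent among x, y, and z for all three axes.
    half_range = 0.5 * (upper - lower).max()
    ax.set_xlim(centers[0] - half_range, centers[0] + half_range)
    ax.set_ylim(centers[1] - half_range, centers[1] + half_range)
    ax.set_zlim(centers[2] - half_range, centers[2] + half_range)


# Random points on the unit sphere.
rng = np.random.RandomState(0)
p = rng.randn(200, 3)
p /= np.linalg.norm(p, axis=1, keepdims=True)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(p[:, 0], p[:, 1], p[:, 2])
set_equal_limits(ax, p)
# Call plt.show() here when running interactively.
```

Equal data ranges only give equal visual scale when the plot box itself is a cube. On Matplotlib 3.3 and later, you may additionally need ax.set_box_aspect((1, 1, 1)) for that.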


  1. See https://www.mathworks.com/help/matlab/learn_matlab/matrices-and-arrays.html.

  2. Refer to "Indexing with a Single Index" here: https://www.mathworks.com/help/matlab/math/array-indexing.html.

  3. See https://www.mathworks.com/help/matlab/ref/squeeze.html.

  4. For more details, check https://docs.scipy.org/doc/numpy-1.15.0/reference/arrays.indexing.html.

  5. See https://docs.scipy.org/doc/scipy/reference/linalg.html.

  6. See https://www.mathworks.com/help/matlab/ref/mldivide.html and https://www.mathworks.com/help/matlab/ref/mrdivide.html.

  7. See https://docs.scipy.org/doc/scipy/reference/signal.html.

  8. Matplotlib used to have the hold() command. Now it is obsolete: https://matplotlib.org/api/_as_gen/matplotlib.pyplot.hold.html.

  9. See the examples here: https://matplotlib.org/users/tight_layout_guide.html.

  10. See https://matplotlib.org/mpl_toolkits/mplot3d/faq.html for FAQs regarding 3D plots in Matplotlib.