Recently, I migrated my MATLAB research library, doa-tools, to Python because I will no longer have access to MATLAB after graduation. MATLAB is a great tool for array-processing research, but it is not freely accessible to everyone. Python, on the other hand, is much more broadly accessible.

The migration was not easy, but overall it went quite smoothly. I had to rewrite thousands of lines of code, replacing MATLAB functions with their NumPy, SciPy, and Matplotlib counterparts. The original MATLAB code was not written with classes in mind, and Python classes let me organize the code better. I also added tests and set up Sphinx to generate documentation. You can check the fully documented Python version here.

This article summarizes the notes I took during the migration. For a complete manual on moving from MATLAB to Python, always start with NumPy's official migration guide here. The code samples in this article are based on MATLAB 2017b and Python 3.6. In all Python code samples, we assume that `numpy` is imported as `np` and that `matplotlib.pyplot` is imported as `plt`.

## Multidimensional arrays

In MATLAB, every variable is a multidimensional array^{[1]}. This is definitely not the case for Python. In Python, we use NumPy to manipulate multidimensional arrays. Although the MATLAB syntax and NumPy syntax share a lot of similarities, there are still many notable differences.

### Row-major order is used by default

NumPy stores elements in row-major order by default, while MATLAB stores elements in column-major order. Therefore, in NumPy, `reshape` will generally produce results different from those in MATLAB.

*[Figure: element ordering in MATLAB (column-major order) vs. NumPy (row-major order)]*

I first encountered this issue when implementing the $\mathrm{vec}(\cdot)$ operation, which converts a matrix to a column vector by stacking all its columns together. In MATLAB, this operation can be implemented simply with `reshape(A, [], 1)` (or `A(:)`). In NumPy, one must specify `order='F'` (Fortran-style ordering) when calling `reshape` to enforce MATLAB-like index ordering.

**MATLAB:**

```
A = [1 2;3 4];
v = [1 2 3 4 5 6];
% Reshaping
reshape(v, 2, 3);
% will produce
% [ 1 3 5;
%   2 4 6 ]
% Stack all columns of A together.
a = A(:);
```

**Python:**

```
A = np.array([[1., 2.], [3., 4.]])
v = np.array([[1., 2., 3., 4., 5., 6.]])
# Reshaping
v.reshape((2, 3))
# will produce
# [[1, 2, 3],
# [4, 5, 6]]
# Stack all columns of A together
a = A.reshape((-1,), order='F')
```
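To make the round trip explicit, here is a minimal sketch of $\mathrm{vec}(\cdot)$ and its inverse built on `order='F'`. The helper names `vec` and `unvec` are hypothetical; they are not part of NumPy or doa-tools.

```python
import numpy as np


def vec(A):
    """Stack the columns of A into a 1D array, like MATLAB's A(:)."""
    return np.asarray(A).reshape(-1, order='F')


def unvec(v, shape):
    """Refill a matrix column by column, like MATLAB's reshape(v, m, n)."""
    return np.asarray(v).reshape(shape, order='F')


A = np.array([[1, 2], [3, 4]])
print(vec(A))  # [1 3 2 4]

v = np.arange(1, 7)
print(unvec(v, (2, 3)))
# [[1 3 5]
#  [2 4 6]]
```

Because both helpers use the same ordering, `unvec(vec(A), A.shape)` reproduces `A`.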

### There is no linear indexing

In MATLAB, one can access a single element of a multidimensional array with a single index. This indexing scheme is called linear indexing^{[2]}. For instance, if $\mathbf{A}$ is a $3\times3$ matrix, `A(8)` refers to `A(2,3)`, or equivalently the 8th element of `A(:)` (remember that MATLAB uses column-major order). NumPy, however, does not support linear indexing: indexing a 2D array with a single integer selects a whole row. In general, linear indexing is not needed in NumPy. If you really want it, you can mimic linear indexing with `reshape`.

**MATLAB:**

```
A = [11 12 13;
21 22 23;
31 32 33];
% Linear indexing
disp(A(3)); % 31
disp(A(8)); % 23
% Linear indexing of multiple elements
disp(A([2 4])); % [21 12]
```

**Python:**

```
A = np.array([[11, 12, 13],
[21, 22, 23],
[31, 32, 33]])
# No linear indexing!
print(A[1]) # [21, 22, 23]
print(A[7]) # IndexError (A has only 3 rows)
# Mimic linear indexing
print(A.reshape((-1,), order='F')[[1, 3]])
# [21, 12]
```
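If you need to translate MATLAB linear indices directly, `np.unravel_index` and `np.ravel_multi_index` with `order='F'` play the roles of MATLAB's `ind2sub` and `sub2ind`. The sketch below assumes the indices come from MATLAB, so it subtracts 1 to make them 0-based.

```python
import numpy as np

A = np.array([[11, 12, 13],
              [21, 22, 23],
              [31, 32, 33]])

# MATLAB linear indices: 1-based and column-major.
matlab_idx = np.array([3, 8])
# Like MATLAB's ind2sub: linear index -> (row, col) subscripts.
rows, cols = np.unravel_index(matlab_idx - 1, A.shape, order='F')
print(A[rows, cols])  # [31 23]

# Like MATLAB's sub2ind: (row, col) subscripts -> 1-based linear index.
lin = np.ravel_multi_index((rows, cols), A.shape, order='F') + 1
print(lin)  # [3 8]
```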

### There exist 1D arrays and 0D arrays

MATLAB is designed to manipulate matrices, and every matrix variable is at least 2D (`size(0.5)` returns `[1 1]`). In NumPy, however, there are 1D arrays and even 0D arrays. Using 1D arrays is perfectly fine in NumPy, but you should be careful when mixing 1D arrays and 2D matrices. Also keep in mind that the transpose operation in NumPy is defined for general multidimensional arrays (it reverses the order of the axes), so transposing a 1D array still produces a 1D array.

**MATLAB:**

```
% A scalar is a 1x1 matrix in MATLAB.
x = 1;
disp(ndims(x)); % 2
disp(size(x)); % [1 1]
% A row vector is a 1xn matrix in MATLAB.
y = [1 1 2 3 5];
disp(ndims(y)); % 2
disp(size(y)); % [1 5]
% A 2x2 matrix.
A = [1, 4; 2, 3];
disp(ndims(A)); % 2
disp(size(A)); % [2 2]
% Transposing a row vector results in
% a column vector.
z = transpose(y);
disp(ndims(z)); % 2
disp(size(z)); % [5 1]
```

**Python:**

```
# You can wrap a scalar.
x = np.array(1.)
# You get an 0D array with an empty shape.
print(x.ndim) # 0
print(x.shape) # ()
# Create a 1D array.
y = np.array([1., 1., 2., 3., 5.])
print(y.ndim) # 1
print(y.shape) # (5,)
# A 2x2 matrix.
A = np.array([[1., 4.], [2., 3.]])
print(A.ndim) # 2
print(A.shape) # (2, 2)
# Transposing a 1D array results in
# a 1D array
z = y.T
print(z.ndim) # 1
print(z.shape) # (5,)
```
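A short sketch of the pitfall mentioned above: a 1D array mixed with a 2D column vector is broadcast rather than combined element-wise. Converting 1D arrays to explicit row or column vectors is one way to get MATLAB-like shapes.

```python
import numpy as np

y = np.array([1., 1., 2., 3., 5.])  # 1D, shape (5,)
col = y.reshape(-1, 1)              # 2D column vector, shape (5, 1)
row = y[np.newaxis, :]              # 2D row vector, shape (1, 5)
print(col.shape, row.shape)         # (5, 1) (1, 5)

# Pitfall: adding a 1D array to a 2D column vector broadcasts to a
# 5x5 matrix instead of adding the two vectors element-wise.
print((col + y).shape)              # (5, 5)

# Transposing a 2D row vector does give a column vector.
print(row.T.shape)                  # (5, 1)
```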

### Singleton dimensions can be automatically removed

In MATLAB, when extracting a row, a column, or a slice from a multidimensional array, singleton dimensions are not automatically removed, and sometimes you need to use `squeeze()`^{[3]} to remove them. Common reduction operations such as `sum()` and `mean()` behave similarly: the reduced dimension is kept with size 1.

In NumPy, indexing with a scalar removes the corresponding dimension automatically, saving you from extra calls to `np.squeeze()`. For common reduction operations, you can even control this behavior with the `keepdims` option. However, be careful when extracting columns from a matrix, because `A[:, 0]` returns the first column as a 1D array.

**MATLAB:**

```
A = [1 4;2 3];
% Retrieves the first column.
v = A(:,1);
disp(size(v)); % [2 1]
% Extract a matrix from a tensor.
T = ones(3, 4, 5);
t = squeeze(T(1,:,:));
disp(size(t)); % [4 5]
% Sum each row
A = repmat([1 2 3], [3 1]);
disp(sum(A, 2)); % [6; 6; 6]
```

**Python:**

```
A = np.array([[1., 4.], [2., 3.]])
# Retrieves the first column as a 1D vector.
v = A[:, 0]
print(v.shape) # (2,)
# Retrieves the first column as a 2D
# column vector.
v = A[:, 0, np.newaxis]
print(v.shape) # (2, 1)
# Extract a matrix from a tensor.
T = np.ones((3, 4, 5))
t = T[0, :, :]
print(t.shape) # (4, 5)
# Sum each row
A = np.tile([1, 2, 3], (3, 1))
print(np.sum(A, axis=1)) # [6, 6, 6]
print(np.sum(A, axis=1, keepdims=True))
# [[6], [6], [6]]
```
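A few more ways to control singleton dimensions in NumPy are sketched below: indexing with a slice or a list (instead of a scalar) keeps the dimension, `np.squeeze()` removes it explicitly, and `keepdims=True` keeps a reduced result broadcastable against the original array, e.g., for centering each row.

```python
import numpy as np

A = np.array([[1., 4.], [2., 3.]])
# A slice or a list keeps the column dimension, like MATLAB's A(:,1).
print(A[:, 0:1].shape)  # (2, 1)
print(A[:, [0]].shape)  # (2, 1)

# Explicit squeeze, mirroring MATLAB's squeeze(T(1,:,:)).
T = np.ones((3, 4, 5))
print(np.squeeze(T[0:1, :, :]).shape)  # (4, 5)

# keepdims=True yields shape (2, 1), which broadcasts against A,
# so each row has its own mean subtracted.
centered = A - A.mean(axis=1, keepdims=True)
print(centered)
# [[-1.5  1.5]
#  [-0.5  0.5]]
```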

### You need to copy arrays explicitly when you need a copy

In MATLAB, variables behave as if they are passed by value. You can copy a multidimensional array simply with the assignment operator, `=`. In addition, operations such as slicing and transposing always create a copy. For instance, `A(:,1)` creates a copy of the first column of `A`.

In Python, everything is an object, and variables are references to objects. Therefore, multiple variables can refer to the same object. If this object is mutable, changes made through one variable are visible through all the others, so you need to make a copy explicitly when you want one. In NumPy, basic slicing and transposing create a view instead of a copy, and the view shares the underlying data storage with the original array^{[4]}. If you need a copy, invoke the `copy()` method explicitly.

**MATLAB:**

```
x = [1 2 3];
% Create a copy of x.
y = x;
A = [1 4;2 3];
% Get the first row and modify it.
v = A(1,:);
v(1) = 10;
% Changing v does not affect A.
disp(A); % [1 4;2 3]
A = [1 4;2 3];
% Get the transpose and modify it.
B = transpose(A);
% Changing B does not affect A.
B(1,1) = 10;
disp(B); % [10 2;4 3]
disp(A); % [1 4;2 3]
```

**Python:**

```
x = np.array([1., 2., 3.])
# Create a copy of x.
y = x.copy()
A = np.array([[1., 4.], [2., 3.]])
# Get the first row and modify it.
v = A[0, :]
v[0] = 10.
# Changing v affects A.
print(A) # [[10., 4.], [2., 3.]]
A = np.array([[1., 4.], [2., 3.]])
# Get the transpose and modify it.
B = A.T
# Changing B affects A.
B[0, 0] = 10.
print(B) # [[10., 2.], [4., 3.]]
print(A) # [[10., 4.], [2., 3.]]
```
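One way to check whether two arrays share data is `np.shares_memory`. The sketch below also shows that plain assignment never copies in Python, and that fancy (integer-array) indexing, unlike basic slicing, returns a copy.

```python
import numpy as np

x = np.array([1., 2., 3.])
y = x            # no copy: x and y name the same array
y[0] = 100.
print(x[0])      # 100.0

A = np.array([[1., 4.], [2., 3.]])
v = A[0, :]          # basic slicing -> view
w = A[0, :].copy()   # explicit copy
f = A[[0], :]        # fancy indexing -> copy
print(np.shares_memory(A, v))  # True
print(np.shares_memory(A, w))  # False
print(np.shares_memory(A, f))  # False
```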

### You will need `np.ix_` when extracting submatrices

Suppose we have two vectors, `row_ids` and `col_ids`, where `row_ids` is a list of row indices and `col_ids` is a list of column indices. We want to extract the submatrix of a matrix, `A`, specified by `row_ids` and `col_ids`. In MATLAB, this can be done simply with `A(row_ids, col_ids)`. However, if you try the same in NumPy, you will get unexpected results or even errors, because NumPy pairs the two index lists element-wise (broadcasting them against each other) instead of taking every combination of rows and columns. In NumPy, we need to use `np.ix_` to build the index grid.

**MATLAB:**

```
A = rand(10, 10);
row_ids = [1 2 9];
col_ids = [4 7];
B = A(row_ids, col_ids);
disp(size(B)); % [3 2]
% B consists of
% A(1,4) A(1,7)
% A(2,4) A(2,7)
% A(9,4) A(9,7)
```

**Python:**

```
A = np.random.rand(10, 10)
row_ids = [0, 1, 8]
col_ids = [3, 6]
B = A[np.ix_(row_ids, col_ids)]
print(B.shape) # (3, 2)
# B consists of
# A[0,3] A[0,6]
# A[1,3] A[1,6]
# A[8,3] A[8,6]
```
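To see why `np.ix_` is needed, the sketch below uses a small matrix whose entries encode their own position (`A[i, j] == 10*i + j`) and compares plain integer-list indexing with `np.ix_`.

```python
import numpy as np

A = np.arange(100).reshape(10, 10)  # A[i, j] == 10*i + j
row_ids = [0, 1, 8]
col_ids = [3, 6]

try:
    A[row_ids, col_ids]
except IndexError as e:
    # The index lists are broadcast against each other, and
    # shapes (3,) and (2,) are incompatible.
    print('IndexError:', e)

# With equal-length lists, NumPy pairs them element-wise
# instead of taking the Cartesian product.
print(A[[0, 1], [3, 6]])  # [ 3 16]

# np.ix_ builds an open mesh, giving the MATLAB-style submatrix.
print(A[np.ix_(row_ids, col_ids)])
# [[ 3  6]
#  [13 16]
#  [83 86]]
```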

### You can get MATLAB-like concatenation with `np.block`

In MATLAB, you can easily concatenate many smaller matrices into a bigger one with `[]`. In NumPy, you can use `np.block()` to perform similar operations.

**MATLAB:**

```
A = [1 3;2 4];
v1 = [-1 -1];
v2 = [1; 1];
S = [9 v1;v2 A];
disp(S);
% [9 -1 -1
% 1 1 3
% 1 2 4]
```

**Python:**

```
A = np.array([[1., 3.], [2., 4.]])
v1 = np.array([[-1., -1.]])
v2 = np.array([[1.], [1.]])
S = np.block([[9., v1], [v2, A]])
print(S)
# [[ 9. -1. -1.]
# [ 1. 1. 3.]
# [ 1. 2. 4.]]
```
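For the two most common cases, MATLAB's `[A, B]` and `[A; B]`, NumPy also offers `np.hstack` and `np.vstack`. The sketch below, with arbitrary example matrices, checks that they agree with `np.block`.

```python
import numpy as np

A = np.array([[1., 3.], [2., 4.]])
B = np.eye(2)
H = np.hstack([A, B])  # MATLAB: [A, B]
V = np.vstack([A, B])  # MATLAB: [A; B]
print(H.shape, V.shape)  # (2, 4) (4, 2)

# np.block gives the same results: a flat list concatenates
# horizontally, a list of lists stacks the rows vertically.
print(np.array_equal(np.block([A, B]), H))      # True
print(np.array_equal(np.block([[A], [B]]), V))  # True
```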

## Linear algebra

In MATLAB, the linear algebra routines are globally available. In Python, these routines can be found under `np.linalg`. Some other matrix functions, such as `sqrtm()`, can be found under `scipy.linalg`^{[5]}.

### You cannot use an apostrophe to obtain the conjugate transpose

In MATLAB, the conjugate transpose (or Hermitian transpose) of a matrix, `A`, can be expressed simply as `A'` (a single apostrophe means conjugate transpose, while `.'` means plain transpose). In NumPy, there is no such shortcut: `A.T` is the plain transpose, and you need to conjugate explicitly.

**MATLAB:**

```
% Compute A^H A
B = A' * A;
```

**Python:**

```
# Compute A^H A
B = A.conj().T @ A
```
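A quick check with a small complex matrix (values chosen arbitrarily) shows the difference between NumPy's plain transpose `.T` and the conjugate transpose `.conj().T`.

```python
import numpy as np

A = np.array([[1 + 2j, 3], [0, 1j]])
At = A.T         # MATLAB: A.'  (plain transpose)
Ah = A.conj().T  # MATLAB: A'   (conjugate transpose)
print(At[0, 0], Ah[0, 0])  # (1+2j) (1-2j)

# A^H A is Hermitian.
B = Ah @ A
print(np.allclose(B, B.conj().T))  # True
```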

### You can use `eigh` instead of `eig` for Hermitian matrices

In MATLAB, the `eig` function is designed for general eigendecomposition problems. Therefore, if the input matrix is expected to be Hermitian but is not exactly Hermitian due to numerical errors, `eig` will use the algorithms for non-Hermitian matrices, which may produce undesired results (e.g., eigenvalues with tiny imaginary parts, or eigenvectors that are not exactly orthogonal). Hence, if the input matrix, `A`, is expected to be Hermitian, we usually pass in `0.5*(A + A')` instead of `A`. For more details, you can check my previous article here.

In NumPy, there is no such problem because `np.linalg` provides a function named `eigh`, which is designed for Hermitian matrices.

**MATLAB:**

```
% Even if A is supposed to be Hermitian, we
% need to force it to be numerically Hermitian.
A = 0.5 * (A + A');
[E, V] = eig(A, 'vector');
```

**Python:**

```
# We can use `eigh` instead of `eig`, which will
# only use the upper (or lower) triangular part.
v, E = np.linalg.eigh(A)
```
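The sketch below, using a made-up random matrix, shows what `eigh` guarantees: real eigenvalues in ascending order, computed from the lower triangle only (the default `UPLO='L'`), so a tiny asymmetry in the upper triangle has no effect. For comparison, `np.linalg.eig` returns complex eigenvalues for a complex input.

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(4, 4) + 1j * rng.randn(4, 4)
R = X @ X.conj().T   # Hermitian by construction
R[0, 1] += 1e-12     # simulate a tiny numerical asymmetry

# eigh reads only the lower triangle by default (UPLO='L'),
# so the perturbation of R[0, 1] is ignored.
w, V = np.linalg.eigh(R)
print(w.dtype)                  # float64: eigenvalues are real
print(np.all(np.diff(w) >= 0))  # True: sorted in ascending order

# The general solver returns complex eigenvalues for a complex input.
w_general = np.linalg.eig(R)[0]
print(w_general.dtype)          # complex128
```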

### Use `np.linalg.lstsq` to implement `mldivide` and `mrdivide`

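MATLAB overloads `\` and `/` with matrix left division, `mldivide`, and matrix right division, `mrdivide`, respectively^{[6]}. In Python, `\` is not an operator, and in NumPy `/` means element-wise division. You can use `np.linalg.lstsq` to implement `mldivide` and `mrdivide` (or `np.linalg.solve` for the square, invertible case). Note that the results may differ for rank-deficient systems: `np.linalg.lstsq` returns the minimum-norm least-squares solution, while MATLAB's `\` returns a basic solution with at most $r$ nonzero entries, where $r$ is the rank of the matrix.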

**MATLAB:**

```
% Evaluate mrdivide.
A / B
% Evaluate mldivide.
A \ B
% Evaluate A B^{-1}.
A / B
% Evaluate A^{-1} B.
A \ B
```

**Python:**

```
# Evaluate mrdivide.
np.linalg.lstsq(B.T, A.T, rcond=None)[0].T
# Evaluate mldivide.
np.linalg.lstsq(A, B, rcond=None)[0]
# Evaluate A B^{-1}.
np.linalg.solve(B.T, A.T).T
# Evaluate A^{-1} B.
np.linalg.solve(A, B)
```
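If you use these operations often, small wrappers help. The sketch below defines two hypothetical helpers, `mldivide` and `mrdivide`, that pick `np.linalg.solve` for square matrices and fall back to `np.linalg.lstsq` otherwise. They approximate, rather than reproduce, MATLAB's behavior (see the rank-deficient caveat above).

```python
import numpy as np


def mldivide(A, B):
    """Solve A X = B, approximating MATLAB's matrix left division."""
    A, B = np.asarray(A), np.asarray(B)
    if A.shape[0] == A.shape[1]:
        # Square case: raises LinAlgError if A is singular,
        # whereas MATLAB only issues a warning.
        return np.linalg.solve(A, B)
    # Rectangular case: minimum-norm least-squares solution.
    return np.linalg.lstsq(A, B, rcond=None)[0]


def mrdivide(A, B):
    """Solve X B = A, approximating MATLAB's matrix right division."""
    # X B = A  <=>  B^T X^T = A^T
    return mldivide(np.asarray(B).T, np.asarray(A).T).T


A = np.array([[2., 1.], [1., 3.]])
B = np.array([[1.], [2.]])
X = mldivide(A, B)
print(np.allclose(A @ X, B))  # True

C = np.array([[1., 2.]])
Y = mrdivide(C, A)
print(np.allclose(Y @ A, C))  # True

# Overdetermined system: fit y = c0 + c1*t through three points.
M = np.array([[1., 0.], [1., 1.], [1., 2.]])
y = np.array([1., 2., 4.])
print(mldivide(M, y))  # approximately [0.8333 1.5]
```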

## Signal processing

MATLAB provides a rich set of functions for signal processing applications (e.g., spectral analysis, filtering, and filter design). In Python, the relevant functions can be found under `scipy.signal`^{[7]}.

### SciPy has `find_peaks`, but does not have `imregionalmax`

When implementing spectrum-based direction-of-arrival estimators, one needs to find the top peaks in the resulting spectrum. In MATLAB, `findpeaks()` can find peaks in a 1D spectrum, and `imregionalmax()` can identify local maxima in a 2D spectrum. In Python, SciPy recently added `find_peaks()` under `scipy.signal` (in SciPy 1.1), which is similar to MATLAB's `findpeaks()`. However, there is no function similar to MATLAB's `imregionalmax()`.

**MATLAB:**

```
% Find peaks of a sine wave.
x = sin(linspace(-8, 8, 100));
[peaks, indices] = findpeaks(x);
```

**Python:**

```
from scipy.signal import find_peaks
# Find peaks of a sine wave.
x = np.sin(np.linspace(-8, 8, 100))
indices, props = find_peaks(x)
peaks = x[indices] # Get peak values
```
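As a rough stand-in for `imregionalmax()`, one can compare each sample with the maximum over its 3x3 neighborhood using `scipy.ndimage.maximum_filter`. This is a sketch, not an exact equivalent: `imregionalmax()` treats connected plateaus as regions, while this helper flags every sample that is not smaller than any of its eight neighbors, so samples inside a flat region are flagged even when that region is not a maximum. The helper name `local_maxima_2d` and the two-Gaussian test spectrum are made up for illustration.

```python
import numpy as np
from scipy.ndimage import maximum_filter


def local_maxima_2d(Z):
    """Boolean mask of samples that are >= all of their 8 neighbors."""
    Z = np.asarray(Z, dtype=float)
    # Pad with -inf so border samples are compared only with real neighbors.
    neighborhood_max = maximum_filter(Z, size=3, mode='constant',
                                      cval=-np.inf)
    return Z == neighborhood_max


# A synthetic 2D "spectrum" with two peaks, at (1, 0) and (-1.5, -1).
x = np.linspace(-3, 3, 61)
X, Y = np.meshgrid(x, x)
Z = (np.exp(-((X - 1)**2 + Y**2))
     + 0.5 * np.exp(-((X + 1.5)**2 + (Y + 1)**2)))

mask = local_maxima_2d(Z)
rows, cols = np.nonzero(mask)
print(list(zip(x[cols], x[rows])))  # (x, y) locations of the peaks
```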

## Optimization

Optimization problems arise in many engineering applications. *doa-tools* implements both maximum-likelihood-based estimators and sparsity-based estimators. The former require a solver capable of handling constrained non-convex problems, while the latter require a solver that can efficiently solve $\ell_1$-regularized least squares problems. MATLAB has built-in functions for many kinds of optimization problems, while in Python, solvers are available from third-party libraries.

### Use `scipy.optimize.minimize` for general constrained non-convex problems

In MATLAB, (locally optimal) solutions to constrained non-convex problems can be obtained with `fmincon()`. In Python, a similar function, `minimize()`, is provided by `scipy.optimize`.

**MATLAB:**

```
f = @(x) x^4 - 3*x^2 + 1;
% Minimize f subject to x >= -1.
% Starting from x = 0.5.
x_opt = fmincon(...
f, 0.5, [], [], [], [], -1, inf...
);
disp(x_opt); % 1.2247
```

**Python:**

```
from scipy.optimize import minimize
# Minimize f subject to x >= -1.
# Starting from x = 0.5.
f = lambda x: x**4 - 3*x**2 + 1
res = minimize(f, 0.5, bounds=[(-1, np.inf)])
print('{0:.4f}'.format(res.x[0])) # 1.2247
```
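`fmincon()` also accepts nonlinear constraints. With `minimize()`, these can be passed as a list of dictionaries through `constraints`, together with a method that supports them, such as SLSQP. The problem below is a made-up example: the same quartic summed over two variables and restricted to the unit disk. Its minimum value on the disk is $-0.5$, attained where $x_0^2 = x_1^2 = 0.5$.

```python
import numpy as np
from scipy.optimize import minimize


def f(x):
    # The 1D quartic x**4 - 3*x**2 + 1, summed over both coordinates.
    return np.sum(x**4 - 3 * x**2 + 1)


# An 'ineq' constraint is satisfied when fun(x) >= 0,
# so this encodes x0**2 + x1**2 <= 1.
constraints = [{'type': 'ineq', 'fun': lambda x: 1.0 - np.sum(x**2)}]
res = minimize(f, x0=np.array([0.5, 0.1]), method='SLSQP',
               constraints=constraints)
print(res.success, np.round(res.x, 4), round(res.fun, 4))
```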

### The Python version of CVX is CVXPY

CVX is a modeling framework for convex optimization problems designed for MATLAB. A Python counterpart, CVXPY, is available, and its syntax is very similar to CVX's. One advantage of CVXPY is that it supports defining parameters, allowing the creation of reusable optimization problems.

Note: When installing CVXPY under Windows, you may encounter errors when installing ECOS. To fix this issue, manually install the latest version of ECOS. Check the discussion here.

**MATLAB:**

```
m = 4;
n = 10;
rng(42);
A = rand(m, n);
b = rand(m, 1);
% l1-regularized least squares.
cvx_begin
variable x(n, 1)
minimize(sum_square(A*x - b) + norm(x, 1))
subject to
x >= 0
cvx_end
disp(x);
```

**Python:**

```
import cvxpy as cvx
m = 4
n = 10
np.random.seed(42)
# MATLAB uses column-major order.
# Ensure A is the same under the same seed.
A = np.random.rand(n, m).T
b = np.random.rand(m, 1)
# l1-regularized least squares.
x = cvx.Variable((n, 1))
objective = cvx.Minimize(
cvx.sum_squares(cvx.matmul(A, x) - b)
+ cvx.norm1(x)
)
constraints = [x >= 0]
problem = cvx.Problem(objective, constraints)
result = problem.solve()
print(x.value)
```

## Plotting

MATLAB provides a rich set of plotting options. In Python, similar functionality is provided by Matplotlib. In the code examples, we assume that `matplotlib.pyplot` has been imported as `plt`.

### `hold on` is no longer needed

In MATLAB, `hold on` is required to prevent new plots from deleting existing plots. With Matplotlib, this command is not needed^{[8]}: successive plotting calls draw on the same axes by default.

**MATLAB:**

```
% Three line plots in the same figure.
figure;
plot(rand(50, 1)); hold on;
plot(rand(50, 1) + 1); hold on;
plot(rand(50, 1) + 2); hold off;
```

**Python:**

```
# Three line plots in the same figure.
plt.figure()
plt.plot(np.random.rand(50, 1))
plt.plot(np.random.rand(50, 1) + 1)
plt.plot(np.random.rand(50, 1) + 2)
plt.show()
```

### Remember to call `plt.tight_layout()` when using subplots

Subplots are supported by Matplotlib. Before calling `plt.show()`, remember to call `plt.tight_layout()` to avoid overlapping titles, labels, etc.^{[9]}
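A minimal sketch with a 2x2 grid of subplots (the sine-wave data and titles are arbitrary); without the `plt.tight_layout()` call, the x-axis labels of the top row tend to collide with the titles of the bottom row.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0, 1, 200)
fig, axes = plt.subplots(2, 2)
for k, ax in enumerate(axes.flat, start=1):
    ax.plot(t, np.sin(2 * np.pi * k * t))
    ax.set_title('{} Hz'.format(k))
    ax.set_xlabel('Time (s)')
# Adjust subplot spacing so titles and axis labels do not overlap.
plt.tight_layout()
plt.show()
```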

### Stem plots come with baselines by default

In MATLAB, stem plots do not come with baselines. However, Matplotlib's stem plots have baselines drawn by default. You can remove the baselines by setting `basefmt` to `' '`.

**MATLAB:**

```
% A simple stem plot.
figure;
stem([1 2 3 2 1]);
```

**Python:**

```
# A simple stem plot.
plt.figure()
plt.stem([1, 2, 3, 2, 1], basefmt=' ')
plt.show()
```

### You need to manually create 3D axes in `matplotlib`

By default, Matplotlib's axes are for 2D plots. To enable 3D plots, you need to import `Axes3D` from `mpl_toolkits.mplot3d` and manually create 3D axes.

**MATLAB:**

```
% Visualize 3D Gaussian samples.
x = randn(100, 3);
figure;
scatter3(x(:,1), x(:,2), x(:,3));
```

**Python:**

```
# Visualize 3D Gaussian samples.
from mpl_toolkits.mplot3d import Axes3D
x = np.random.randn(100, 3)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x[:, 0], x[:, 1], x[:, 2])
plt.show()
```

### 3D plotting can be very slow in Matplotlib

Rendering 3D plots in MATLAB on modern PCs is very fast because MATLAB's plot rendering is hardware accelerated. Matplotlib currently does not support hardware acceleration, so large 3D plots can take a long time to render. To display complex 3D plots, consider using Mayavi instead of Matplotlib^{[10]}.

### `set_aspect('equal')` is broken in 3D plots

In some cases, you want all three axes of a 3D plot to share the same scale (e.g., when visualizing points on a unit sphere). If you use `ax.set_aspect('equal')`, you will obtain weird results. This is a known issue, and you can find the discussion here on GitHub. Currently, my workaround is to manually set the limits of the three axes. You can find the relevant code here.
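The idea behind that workaround can be sketched as follows; this is an independent illustration, not the linked code. The hypothetical helper `set_axes_equal_3d` expands all three axes to a common range centered on the current limits. As far as I know, Matplotlib 3.3 and later draw a non-cubic box by default, so the sketch also calls `set_box_aspect` when that method exists.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older Matplotlib)


def set_axes_equal_3d(ax):
    """Give the x, y, and z axes of a 3D plot a common scale."""
    limits = np.array([ax.get_xlim(), ax.get_ylim(), ax.get_zlim()])
    centers = limits.mean(axis=1)
    # Half of the widest range becomes the half-width of every axis.
    radius = 0.5 * np.max(limits[:, 1] - limits[:, 0])
    ax.set_xlim(centers[0] - radius, centers[0] + radius)
    ax.set_ylim(centers[1] - radius, centers[1] + radius)
    ax.set_zlim(centers[2] - radius, centers[2] + radius)
    # Make the drawn box cubic when this method is available.
    if hasattr(ax, 'set_box_aspect'):
        ax.set_box_aspect((1, 1, 1))


# Random points on the unit sphere.
rng = np.random.RandomState(0)
p = rng.randn(300, 3)
p /= np.linalg.norm(p, axis=1, keepdims=True)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(p[:, 0], p[:, 1], p[:, 2])
set_axes_equal_3d(ax)
plt.show()
```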

[1] See https://www.mathworks.com/help/matlab/learn_matlab/matrices-and-arrays.html.

[2] Refer to "Indexing with a Single Index" here: https://www.mathworks.com/help/matlab/math/array-indexing.html.

[3] See https://www.mathworks.com/help/matlab/ref/squeeze.html.

[4] For more details, check https://docs.scipy.org/doc/numpy-1.15.0/reference/arrays.indexing.html.

[5] See https://docs.scipy.org/doc/scipy/reference/linalg.html.

[6] See https://www.mathworks.com/help/matlab/ref/mldivide.html and https://www.mathworks.com/help/matlab/ref/mrdivide.html.

[7] See https://docs.scipy.org/doc/scipy/reference/signal.html.

[8] Matplotlib used to have the `hold()` command. Now it is obsolete: https://matplotlib.org/api/_as_gen/matplotlib.pyplot.hold.html.

[9] See the examples here: https://matplotlib.org/users/tight_layout_guide.html.

[10] See https://matplotlib.org/mpl_toolkits/mplot3d/faq.html for FAQs regarding 3D plots in Matplotlib.