Predicting Turtle Hatchling Strandings in the Western Cape¶

Every year around Autumn, turtle hatchlings are found stranded on beaches across the Western Cape. These hatchlings originate from northern KZN and travel southward within the Agulhas Current. Stranding events occur when unfavorable conditions push the hatchlings out of the warm Agulhas Current and onto beaches across the Western Cape.¶

Here the conditions that result in strandings will be identified, before a model is built to predict when a stranding event may occur, with the ultimate goal of developing an early warning system for future stranding events.¶

In [1]:
import xarray as xr
import zarr
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs

from datetime import datetime, timedelta
import seaborn as sns
import math

import pickle

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

import warnings
warnings.filterwarnings('ignore')

Records have been kept of when and where strandings have previously occurred¶

In [2]:
stranding_data = pd.read_csv('2015_2021_hatchling_stranding_summary.csv', names = ['Id', 'Date','Species','Location'])

I will only consider the Caretta caretta species, in a bid to reduce noise, as this is by far the dominant species¶

In [3]:
# keep only the dominant species, Caretta caretta
stranding_data = stranding_data.loc[stranding_data['Species'] == 'Caretta caretta']
#Correct a few inconsistent spellings
stranding_data.loc[stranding_data['Location'] == 'Mosselbay', 'Location'] = 'Mossel Bay'
stranding_data.loc[stranding_data['Location'] == 'Witsand', 'Location'] = 'Witsands'
stranding_data.loc[stranding_data['Location'] == 'Tenikwa PLETT', 'Location'] = 'Plettenberg Bay'
stranding_data.loc[stranding_data['Location'] == 'Tenikwa', 'Location'] = 'Plettenberg Bay'

First let's consider the locations of the strandings by plotting the beaches with at least 10 recorded strandings¶

In [4]:
stranding_data = stranding_data.groupby('Location').filter(lambda x: len(x) >= 10) # consider only locations with at least 10 strandings
In [5]:
# Manually set coordinates for the most common stranding locations
latlong_dict = {
    'Struisbaai':      [-34.80, 20.10],
    'Mossel Bay':      [-34.15, 22.20],
    'Muizenberg':      [-34.11, 18.52],
    'Witsands':        [-34.41, 20.92],
    'Arniston':        [-34.67, 20.27],
    'Plettenberg Bay': [-34.07, 23.44],
    'Hermanus':        [-34.44, 19.25],
    'Sedgefield':      [-34.10, 22.78],
    'Gaansbaai':       [-34.64, 19.35],
}
In [6]:
stranding_count = stranding_data.Location.value_counts() 

location = []
lats = []
lons = []
count = []
for index, n in stranding_count.items():
    location.append(index)
    lats.append(latlong_dict[index][0])
    lons.append(latlong_dict[index][1])
    count.append(n)


fig = plt.figure(figsize=(16,9), dpi = 350)

ax = plt.axes(projection=ccrs.PlateCarree())
ax.coastlines('50m', linewidth=0.8)

extent = (np.min(lons)-0.2,np.max(lons)+0.6, np.min(lats)-0.2,np.max(lats)+0.2)
ax.set_extent(extent)

scatterplot = ax.scatter(x = lons, y = lats, c=count, s = [x*10 for x in count], cmap = 'viridis', vmin = 0, vmax = 400, alpha=0.9,transform=ccrs.PlateCarree())

for loc,lat,lon in zip(location, lats, lons):
    ax.text(lon+0.08, lat-0.03, str(loc))

plt.colorbar(scatterplot, shrink = 0.5, orientation = 'horizontal', label = 'Number Of Strandings', pad=0.05)

plt.title('Location of Turtle Hatchling Strandings in the Western Cape')

plt.show()

Struisbaai is by far the most common stranding location.¶

Next let's consider the timing of the strandings¶

In [7]:
stranding_freq = stranding_data.groupby(stranding_data.Date).count()['Species'].rename('Strandings')

stranding_freq.index = pd.to_datetime(stranding_freq.index, infer_datetime_format = True)

weekly={}
yearly = {}

for Location in list(set(stranding_data.Location)):
    location_data = stranding_data.loc[stranding_data['Location'] == str(Location)]
    location_freq = location_data.groupby(location_data.Date).count()['Species'].rename('Strandings')
    location_freq.index = pd.to_datetime(location_freq.index, infer_datetime_format=True)
    weekly[str(Location)] = location_freq.groupby(location_freq.index.week).sum()
    yearly[str(Location)] = location_freq.groupby(location_freq.index.year).sum()
    


weekly = pd.DataFrame(data = weekly).reindex(np.linspace(1,52,52, dtype = int))
yearly = pd.DataFrame(data = yearly)

fig = px.bar(yearly, labels={
                     "value": "Number of Strandings",
                     "Date": "Year"
                 },
                title="Turtle Strandings in the Western Cape")

fig.update_layout(
    legend=dict(title = 'Stranding Location'))


fig.show()

Some years have vastly more strandings than others.¶

In [8]:
fig = px.bar(weekly,labels={
                     "value": "Number of Strandings",
                     "Date": "Week of the Year"
                 },
                title="Turtle Strandings in the Western Cape")

fig.update_layout(
    legend=dict(title = 'Stranding Location'))

fig.show()

Strandings primarily occur in March, April and May.¶

Next let's consider the weather conditions directly preceding the largest stranding events in Struisbaai.¶

In [9]:
location = 'Struisbaai'
In [10]:
ds = xr.open_dataset('ERA5_reanalysis_East_Coast.nc').load() 

# ensure lat lon corresponds to a point where sst is available
ocean_lons = ds.sst[0].dropna(dim = 'lat', how = 'any').dropna(dim = 'lon', how = 'any').lon.values
ocean_lats = ds.sst[0].dropna(dim = 'lat', how = 'any').dropna(dim = 'lon', how = 'any').lat.values

def return_nearest_ocean_latlon(latlon):
    ocean_lat = min(ocean_lats, key=lambda x:abs(x-latlon[0]))
    ocean_lon = min(ocean_lons, key=lambda x:abs(x-latlon[1]))
    return [ocean_lat,ocean_lon]


ocean_latlong_dict = {}
for index in latlong_dict:
    ocean_latlong_dict[index] = return_nearest_ocean_latlon(latlong_dict[index])


ds_loc = ds
ds_loc = ds_loc.sel(lat = ocean_latlong_dict[str(location)][0], method = 'nearest').sel(lon = ocean_latlong_dict[str(location)][1], method = 'nearest')

location_data = stranding_data.loc[stranding_data['Location'] == str(location)]
location_freq = location_data.groupby(location_data.Date).count()['Species'].rename('Strandings')
location_freq.index = pd.to_datetime(location_freq.index, infer_datetime_format=True)
location_freq = location_freq.sort_values(ascending=False)
location_freq = location_freq[location_freq.values>1]
dates_sorted = list(location_freq.sort_values(ascending=False).index)

ws = {}
sst = {}
direc = {}
for date in dates_sorted:
    date_7 = date + timedelta(days = -7)
    d = ds_loc.sel(time = slice(str(date_7), str(date)))
    if len(d.ws.values) >= 169:
        ws[date] = d.ws.values
        sst[date] = d.sst.values
        direc[date] = d.dir.values
    else:
        pass

ws = pd.DataFrame(ws)
sst = pd.DataFrame(sst)
direc = pd.DataFrame(direc)

n_events = len(direc.columns)-1

show_hide = {}
for x in range(n_events):
    trues = []
    for i in range(n_events):
        if i ==x:
            trues.append(True)
            trues.append(True)
        else:
            trues.append(False)
            trues.append(False)
    show_hide[x] = trues
    

def get_vector(direction):
    # unit-vector components for a wind direction given in degrees clockwise from north
    x = math.sin(math.radians(direction))
    y = math.cos(math.radians(direction))
    return (x, -y)  # y is negated because plotly pixel offsets increase downwards

arrows = {}
for n in range(n_events):
    arrows[n] = []
    direc_ = direc.iloc[:,n].values[0::10]
    for x_,i in zip(range(160,0,-10),range(16)):
        ax_,ay_ = get_vector(direc_[i])
        arrows[n].append(dict(
            x=x_,
            y=-5,
            xref = 'x',
            yref = 'y',
            axref = 'pixel',
            ayref = 'pixel',
            ax=(ax_*25),
            ay=(ay_*30),
            arrowhead = 2,
            arrowsize = 1,
            arrowwidth = 1.4,
            xanchor = 'center',
            yanchor = 'center',
            row = 1,
            col = 1
        )
        )

        
button_dicts = []
for i in range(n_events):
    button_dicts.append(
        dict(label=str(dates_sorted[i].date()) + ', '+str(location_freq[dates_sorted[i]]) + ' Strandings',
                     method="update",
                     args=[
                         {"visible": show_hide[i]},{'annotations':arrows[i]}
                     ]) 
    )   
    
    
fig = make_subplots(rows=2, cols=1, row_heights=[0.7, 0.3])
#fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.update_xaxes(autorange="reversed")

for i in range(n_events):
    if i == 0:
        fig.add_trace(
            go.Scatter(x = np.linspace(168,0,169), y= ws.iloc[:,i].values, name='Wind Speed (knots)', visible=True),
            row = 1, col = 1
        )
        fig.add_trace(
            go.Scatter(x = np.linspace(168,0,169), y= sst.iloc[:,i].values, name='Sea Surface Temperature (celsius)', visible=True),
            row = 2, col = 1
        )
    else:
        fig.add_trace(
            go.Scatter(x = np.linspace(168,0,169), y= ws.iloc[:,i].values, name='Wind Speed (knots)', visible=False),
            row = 1, col = 1
        )
        fig.add_trace(
            go.Scatter(x = np.linspace(168,0,169), y= sst.iloc[:,i].values, name='Sea Surface Temperature (celsius)', visible=False),
            row = 2, col = 1
        )      
    
fig.update_yaxes(range=[-10, 30], row = 1)
fig.update_xaxes(range=[170, 0], row = 1)
fig.update_xaxes(range=[170, 0], row = 2)
fig.update_xaxes(title="Hours Before Strandings", row = 2)
fig.update_xaxes(tick0=168, dtick=24)

fig.update_layout(
    updatemenus=[
        dict(
            buttons=button_dicts,
            direction="down",
            showactive=True
        )
    ]
)

# Update remaining layout properties
fig.update_layout(
    title_text=str(location)+" Stranding Events",
    showlegend=True,
)


fig.show()
fig.write_html("historical_"+str(location)+".html")

This plot is interactive; use the dropdown menu on the left to change the stranding event and activate the wind direction arrows.

Full-screen versions of this plot, and equivalent plots for the other locations, are available at:

  • Struisbaai - https://petemarsh.com/historical_html/historical_Struisbaai
  • Mossel Bay - https://petemarsh.com/historical_html/historical_Mossel_Bay
  • Muizenberg - https://petemarsh.com/historical_html/historical_Muizenberg
  • Witsands - https://petemarsh.com/historical_html/historical_Witsands
  • Arniston - https://petemarsh.com/historical_html/historical_Arniston
  • Plettenberg Bay - https://petemarsh.com/historical_html/historical_Plettenberg_Bay
  • Hermanus - https://petemarsh.com/historical_html/historical_Hermanus
  • Sedgefield - https://petemarsh.com/historical_html/historical_Sedgefield
  • Gaansbaai - https://petemarsh.com/historical_html/historical_Gaansbaai

A common theme across these large stranding events is consistent easterly winds over the preceding 4-5 days and a reduction in sea surface temperature over the previous 7 days.¶
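As a rough way to put numbers on these two signals, the sketch below (not part of the original analysis; it assumes ERA5's sign convention, where a positive u10 is a westerly wind, so easterly winds give u10 < 0) flags days on which the 5-day mean eastward wind component and the 7-day change in sea surface temperature are both negative.

In [ ]:
# Rough sketch: flag days matching the two precursor signals noted above.
# Assumes the ERA5 sign convention (positive u10 = westerly wind), so easterly winds give u10 < 0.
daily = ds_loc[['u10', 'sst']].resample(time = 'D').mean()

u10_5d = daily.u10.rolling(time = 5).mean()            # 5-day mean eastward wind component
sst_7d_change = daily.sst - daily.sst.shift(time = 7)  # change in SST over the previous 7 days

candidate_days = daily.time.where((u10_5d < 0) & (sst_7d_change < 0), drop = True)
print(len(candidate_days), 'candidate days with persistent easterly winds and cooling water')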

Next let's consider the distribution of common weather variables over the days preceding stranding events in Struisbaai, relative to the Autumn average.¶

In [11]:
# forecast data is only available in 3-hour steps, so the reanalysis is subset to match
ds_loc = ds_loc.where(ds_loc.time.dt.hour.isin([0, 3, 6, 9, 12, 15, 18, 21]), drop = True)

ds_loc_daily = ds_loc.resample(time = 'D', closed = "right").mean().load()
ds_seasonal = ds_loc_daily.where(ds_loc_daily.time.dt.season == 'MAM').dropna(dim='time', how = 'all')

strandings = list(set(stranding_data.loc[stranding_data['Location'] == str(location)].Date))
strandings = pd.to_datetime(strandings)

march = [x for x in strandings if x.month == 3]
april = [x for x in strandings if x.month == 4]
may = [x for x in strandings if x.month == 5]
dates = march + april + may
dates.sort()

lagged = {}
for lag in [0, -1,-3,-7]:
    dates_1 = [date + timedelta(days = lag) for date in dates]
    lagged[lag] = ds_loc_daily.sel(time = dates_1)

def make_plot():
    fig, axs = plt.subplots(nrows=3,ncols=2,figsize=(16,12))
    axs=axs.flatten()
    for i,var in enumerate(['sst','t2m','ws','dir','u10','v10']):
        sns.kdeplot(ds_seasonal[str(var)].values, ax = axs[i], label = 'Autumn', lw= 5)
        for key in lagged:
            sns.kdeplot(lagged[key][str(var)].values, ax = axs[i],  label = 'Stranding Days '+str(key), lw = 2.5)
        #axs[i].set_title(str(var))
    axs[0].legend()
    axs[0].set_xlabel('Sea Surface Temperature ($^\circ$C)')
    axs[1].set_xlabel('Air Temperature ($^\circ$C)')
    axs[2].set_xlabel('Wind Speed (knots)')
    axs[3].set_xlabel('Wind Direction')
    axs[4].set_xlabel('Eastward Wind Component (knots)')
    axs[5].set_xlabel('Northward Wind Component (knots)')
    fig.suptitle('', fontsize=14, fontweight='bold')


        
make_plot()

There aren't many clear distribution shifts here, other than that wind speed on stranding days appears to be lower than normal and easterly winds are more common 1, 3 and 7 days before stranding events.¶

In [12]:
df = ds_loc_daily.to_dataframe()

x = ds_loc_daily[['u10','v10']].rolling(time = 5).mean()  # 5-day rolling mean of the wind components
df['u10_5m'] = x.u10
df['v10_5m'] = x.v10

# change in sea surface temperature over the previous 7 days (first week padded with zeros)
sst_delta = (ds_loc_daily.sst.values[7:] - ds_loc_daily.sst.values[:-7])
df['sst_delta'] = np.concatenate([np.zeros(7), sst_delta])

# eastward wind component offset by one day (zero-padded at the end)
df['u10_1'] = np.concatenate([df['u10'].iloc[1:].values, np.zeros(1)])
df = df.fillna(0)

location_data = stranding_data.loc[stranding_data['Location'] == str(location)]
location_freq = location_data.groupby(location_data.Date).count()['Species'].rename('Strandings')
location_freq.index = pd.to_datetime(location_freq.index)
df['strandings'] = location_freq.reindex(df.index).fillna(0)
df['month'] = df.index.month

fig, axs = plt.subplots(nrows=1,ncols=3,figsize=(16,4))
axs=axs.flatten()
for i,var in enumerate(['v10_5m', 'u10_5m', 'sst_delta']):
    sns.kdeplot(df[var], label = 'all', ax = axs[i],lw = 5)
    sns.kdeplot(df[var].where(df.strandings >= 1), label = 'stranding', ax = axs[i],lw = 3)
axs[0].legend()
axs[0].set_xlabel('5-Day Average Northward Wind Component (knots)')
axs[1].set_xlabel('5-Day Average Eastward Wind Component (knots)')
axs[2].set_xlabel('Change in Sea Surface Temperature over 7 days')
plt.show()

The 5-day average of the eastward wind component on stranding days shows a clear distribution shift, as does the change in sea surface temperature over the preceding 7 days.¶

Next let's try to build a model to predict stranding events in Struisbaai.¶

Here I have chosen to use a random forest, a supervised machine learning model made up of many largely uncorrelated decision trees whose combined prediction is more accurate than that of any individual tree; since the target is the number of strandings on a given day, the regression variant is used. The input features are: the month of the year, the wind speed on the day of stranding, the eastward wind component on the day before stranding, the average eastward wind component over the 5 days preceding stranding, and the change in sea surface temperature over the 7 days preceding stranding.¶

In [13]:
from sklearn.ensemble import RandomForestRegressor

# train on the first ~80% of the daily record and hold out the remainder for testing
train = df.iloc[:2045]
test = df.iloc[2045:]

sel_var = ['month', 'ws','u10_1', 'u10_5m','sst_delta']

train_features = np.array(train[sel_var])
train_labels = np.array(train['strandings'])
test_features = np.array(test[sel_var])
test_labels = np.array(test['strandings'])

# Instantiate model
rf = RandomForestRegressor(n_estimators = 100, random_state = 42)
# Train the model on training data
rf.fit(train_features, train_labels)
#save model
pickle.dump(rf, open('model_train.pkl', 'wb'))
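As a quick check on which of these inputs the fitted forest actually relies on, its impurity-based feature importances can be inspected; a minimal sketch using the rf and sel_var objects defined above (the exact values will depend on the training split):

In [ ]:
# Rank the selected features by the fitted forest's impurity-based importance
importances = pd.Series(rf.feature_importances_, index = sel_var).sort_values(ascending = False)
print(importances)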
In [14]:
# Use the forest's predict method on the test data
predictions = rf.predict(test_features)
# Calculate the absolute errors
errors = abs(predictions - test_labels)
# Print out the mean absolute error (MAE)
print('Mean Absolute Error:', round(np.mean(errors), 5), 'strandings.')
Mean Absolute Error: 0.17265 strandings.
In [15]:
out = pd.DataFrame({'day': test.index, 'predictions' : list(predictions), "test_labels": list(test_labels)}).set_index('day')
pred_ = out.predictions.where(out.predictions>=1).dropna()
maybe_ = out.predictions.where(out.predictions>0.6).where(out.predictions<1).dropna()
test_ = out.test_labels.where(out.test_labels>=1).dropna()

fig = go.Figure()

fig.add_trace(go.Scatter(x=pred_.index, y=pred_.values,
                    marker = dict(size=12, color = 'red'),
                    mode='markers',
                    name='Predicted'))

fig.add_trace(go.Scatter(x=maybe_.index, y=maybe_.values, 
                    marker = dict(size=12, color = 'yellow'),
                    mode='markers',
                    name='Maybe'))

fig.add_trace(go.Scatter(x=test_.index, y=test_.values,
                    marker = dict(size=12, color = 'black'),
                    mode='markers',
                    name='Observed'))

fig.update_yaxes(title="Number of Strandings")

fig.show()

To test the viability of the model, it has been trained on the first 80% of the available data and used to predict the remaining 20%, shown above. While the model is far from perfect, its predictions do appear to cluster around observed strandings, particularly when the number of strandings is high.¶
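One rough way to quantify this is to treat any prediction above the 'maybe' threshold used in the plot above (0.6 strandings) as an alert and count hits, misses and false alarms; a minimal sketch using the out DataFrame from the previous cell (simple counts, not a formal skill score):

In [ ]:
# Treat predictions above the 'maybe' threshold as an alert and compare against observed stranding days
alert = out.predictions > 0.6
observed = out.test_labels >= 1

hits = (alert & observed).sum()
misses = (~alert & observed).sum()
false_alarms = (alert & ~observed).sum()

print('hits:', hits, 'misses:', misses, 'false alarms:', false_alarms)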

A key barrier to implementing this model for future projections is that sea surface temperature is not among the freely available ECMWF forecast variables. It would be possible to motivate for a research license for this data; however, here sea surface temperature has simply been removed from the prediction model, which has a slight negative impact on model performance. Below is an example of the model trained without sea surface temperature¶
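The pickled model loaded below was trained outside this notebook; a minimal sketch of how an equivalent model could be produced here, assuming the same hyperparameters as before and simply dropping the sst_delta feature:

In [ ]:
# Assumed reconstruction of model_no_sst_train.pkl: same forest, trained without the sst_delta feature
sel_var_no_sst = [v for v in sel_var if v != 'sst_delta']

rf_no_sst = RandomForestRegressor(n_estimators = 100, random_state = 42)
rf_no_sst.fit(train[sel_var_no_sst].values, train['strandings'].values)

pickle.dump(rf_no_sst, open('model_no_sst_train.pkl', 'wb'))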

In [16]:
# load the model trained without the sea surface temperature feature
rf = pickle.load(open('model_no_sst_train.pkl', 'rb'))
In [17]:
# Use the forest's predict method on the test data, dropping the final column (sst_delta)
predictions = rf.predict(test_features[:,:-1])
# Calculate the absolute errors
errors = abs(predictions - test_labels)
# Print out the mean absolute error (MAE)
print('Mean Absolute Error:', round(np.mean(errors), 5), 'strandings.')
Mean Absolute Error: 0.18096 strandings.
In [18]:
out = pd.DataFrame({'day': test.index, 'predictions' : list(predictions), "test_labels": list(test_labels)}).set_index('day')
pred_ = out.predictions.where(out.predictions>=1).dropna()
maybe_ = out.predictions.where(out.predictions>0.6).where(out.predictions<1).dropna()
test_ = out.test_labels.where(out.test_labels>=1).dropna()

fig = go.Figure()

fig.add_trace(go.Scatter(x=pred_.index, y=pred_.values,
                    marker = dict(size=12, color = 'red'),
                    mode='markers',
                    name='Predicted'))

fig.add_trace(go.Scatter(x=maybe_.index, y=maybe_.values, 
                    marker = dict(size=12, color = 'yellow'),
                    mode='markers',
                    name='Maybe'))

fig.add_trace(go.Scatter(x=test_.index, y=test_.values,
                    marker = dict(size=12, color = 'black'),
                    mode='markers',
                    name='Observed'))

fig.update_yaxes(title="Number of Strandings")
fig.show()
  • A 5-day forecast implementation of the above can be viewed at https://petemarsh.com/forecast

  • The notebook implementing the forecast can be viewed at https://petemarsh.com/Turtle_Forecast_Notebook

  • To reproduce this notebook please refer to https://github.com/peterm790/Turtle_Stranding
